This commit is contained in:
wingdeans 2024-05-04 03:51:26 +02:00 committed by GitHub
commit db51f2dd19
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 289 additions and 6 deletions

View file

@ -26,8 +26,8 @@ static const char kUtf8Dispatch[] = {
1, 1, 1, 1, 1, 1, 1, 1, // 0320
1, 1, 1, 1, 1, 1, 1, 1, // 0330
2, 3, 3, 3, 3, 3, 3, 3, // 0340 utf8-3
3, 3, 3, 3, 3, 3, 3, 3, // 0350
4, 5, 5, 5, 5, 0, 0, 0, // 0360 utf8-4
3, 3, 3, 3, 3, 4, 3, 3, // 0350
5, 6, 6, 6, 7, 0, 0, 0, // 0360 utf8-4
0, 0, 0, 0, 0, 0, 0, 0, // 0370
};
@ -95,6 +95,7 @@ bool32 isutf8(const void *data, size_t size) {
}
// fallthrough
case 3:
case3:
if (p + 2 <= e && //
(p[0] & 0300) == 0200 && //
(p[1] & 0300) == 0200) { //
@ -104,11 +105,17 @@ bool32 isutf8(const void *data, size_t size) {
return false; // missing cont
}
case 4:
if (p < e && (*p & 040)) {
return false; // utf-16 surrogate
}
goto case3;
case 5:
if (p < e && (*p & 0377) < 0220) {
return false; // overlong
}
// fallthrough
case 5:
case 6:
case6:
if (p + 3 <= e && //
(((uint32_t)(p[+2] & 0377) << 030 | //
(uint32_t)(p[+1] & 0377) << 020 | //
@ -120,6 +127,11 @@ bool32 isutf8(const void *data, size_t size) {
} else {
return false; // missing cont
}
case 7:
if (p < e && (*p & 0x3F) > 0xF) {
return false; // over limit
}
goto case6;
default:
__builtin_unreachable();
}

View file

@ -104,3 +104,4 @@ CF-Visitor, kHttpCfVisitor
CF-Connecting-IP, kHttpCfConnectingIp
CF-IPCountry, kHttpCfIpcountry
CDN-Loop, kHttpCdnLoop
Sec-WebSocket-Key, kHttpWebsocketKey

View file

@ -39,7 +39,7 @@
#line 12 "gethttpheader.gperf"
struct thatispacked HttpHeaderSlot { char *name; char code; };
#define TOTAL_KEYWORDS 93
#define TOTAL_KEYWORDS 94
#define MIN_WORD_LENGTH 2
#define MAX_WORD_LENGTH 32
#define MIN_HASH_VALUE 3
@ -387,7 +387,10 @@ LookupHttpHeader (register const char *str, register size_t len)
#line 87 "gethttpheader.gperf"
{"Strict-Transport-Security", kHttpStrictTransportSecurity},
{""}, {""}, {""}, {""}, {""}, {""}, {""}, {""}, {""},
{""}, {""}, {""}, {""}, {""},
{""}, {""},
#line 107 "gethttpheader.gperf"
{"Sec-WebSocket-Key", kHttpWebsocketKey},
{""}, {""},
#line 22 "gethttpheader.gperf"
{"X-Forwarded-For", kHttpXForwardedFor},
{""},

View file

@ -206,6 +206,8 @@ const char *GetHttpHeaderName(int h) {
return "CDN-Loop";
case kHttpSecChUaPlatform:
return "Sec-CH-UA-Platform";
case kHttpWebsocketKey:
return "Sec-WebSocket-Key";
default:
return NULL;
}

View file

@ -138,7 +138,8 @@
#define kHttpCfIpcountry 90
#define kHttpSecChUaPlatform 91
#define kHttpCdnLoop 92
#define kHttpHeadersMax 93
#define kHttpWebsocketKey 93
#define kHttpHeadersMax 94
COSMOPOLITAN_C_START_

View file

@ -40,6 +40,7 @@
#include "libc/intrin/atomic.h"
#include "libc/intrin/bsr.h"
#include "libc/intrin/likely.h"
#include "libc/intrin/newbie.h"
#include "libc/intrin/nomultics.internal.h"
#include "libc/intrin/safemacros.internal.h"
#include "libc/log/appendresourcereport.internal.h"
@ -121,6 +122,7 @@
#include "third_party/mbedtls/net_sockets.h"
#include "third_party/mbedtls/oid.h"
#include "third_party/mbedtls/san.h"
#include "third_party/mbedtls/sha1.h"
#include "third_party/mbedtls/ssl.h"
#include "third_party/mbedtls/ssl_ticket.h"
#include "third_party/mbedtls/x509.h"
@ -406,6 +408,7 @@ struct ClearedPerMessage {
bool hascontenttype;
bool gotcachecontrol;
bool gotxcontenttypeoptions;
char wstype;
int frags;
int statuscode;
int isyielding;
@ -487,6 +490,8 @@ static uint8_t *zmap;
static uint8_t *zcdir;
static size_t hdrsize;
static size_t amtread;
static size_t wsfragread;
static char wsfragtype;
static reader_f reader;
static writer_f writer;
static char *extrahdrs;
@ -5141,6 +5146,195 @@ static bool LuaRunAsset(const char *path, bool mandatory) {
return !!a;
}
static int LuaWSUpgrade(lua_State *L) {
size_t i;
char *p, *q;
bool haskey;
mbedtls_sha1_context ctx;
unsigned char hash[20];
if (cpm.generator)
luaL_error(L, "Cannot upgrade to websocket after yielding normally");
if (!HasHeader(kHttpWebsocketKey))
luaL_error(L, "No Sec-WebSocket-Key header");
mbedtls_sha1_init(&ctx);
mbedtls_sha1_starts_ret(&ctx);
mbedtls_sha1_update_ret(&ctx, (unsigned char*)
HeaderData(kHttpWebsocketKey),
HeaderLength(kHttpWebsocketKey));
p = SetStatus(101, "Switching Protocols");
while (p - hdrbuf.p + (20 + 21 + (20 + 28 + 4)) + 512 > hdrbuf.n) {
hdrbuf.n += hdrbuf.n >> 1;
q = xrealloc(hdrbuf.p, hdrbuf.n);
cpm.luaheaderp = p = q + (p - hdrbuf.p);
hdrbuf.p = q;
}
mbedtls_sha1_update_ret(
&ctx, (unsigned char *)"258EAFA5-E914-47DA-95CA-C5AB0DC85B11", 36);
mbedtls_sha1_finish_ret(&ctx, hash);
char *accept = EncodeBase64((char *)hash, 20, NULL);
p = stpcpy(p, "Upgrade: websocket\r\n");
p = stpcpy(p, "Connection: upgrade\r\n");
p = AppendHeader(p, "Sec-WebSocket-Accept", accept);
cpm.luaheaderp = p;
cpm.wstype = 1;
return 0;
}
static int LuaWSRead(lua_State *L) {
ssize_t rc;
size_t i, got, amt, bufsize;
unsigned char wshdr[10], wshdrlen, *extlen, *mask, op;
char *bufstart;
uint64_t len;
struct iovec iov[2];
OnlyCallDuringRequest(L, "ws.Read");
got = 0;
do {
if ((rc = reader(client, wshdr + got, 2 - got)) == -1)
luaL_error(L, "Could not read WS header");
} while ((got += rc) < 2);
op = wshdr[0] & 0xF;
if (wshdr[0] & 0x70) goto close; // reserved bit set
if (!(wshdr[1] | (1 << 7))) goto close; // unmasked
if ((wshdr[0] & 0x7) >= 0x3) goto close; // reserved opcode
if (!wsfragtype && !op) goto close; // not in continuation
len = wshdr[1] & ~(1 << 7);
if (wshdr[0] & 0x8) { // control frame
if (!(wshdr[0] & 0x80) || len >= 126) goto close; // fragmented or too long
} else {
if (op && wsfragtype) goto close; // during fragmented seq
}
wshdrlen = 6;
if (len == 126) {
wshdrlen = 8;
} else if (len == 127) {
wshdrlen = 14;
}
while (got < wshdrlen) {
if ((rc = reader(client, wshdr + got, wshdrlen - got)) == -1)
luaL_error(L, "Could not read WS extended length");
got += rc;
}
extlen = &wshdr[2];
mask = &wshdr[wshdrlen - 4];
if (len == 126) {
len = be16toh(*(uint16_t *)extlen);
} else if (len == 127) {
len = be64toh(*(uint64_t *)extlen);
}
if (len >= inbuf.n - wsfragread)
luaL_error(L, "Required %d bytes to read WS frame, %d bytes available", len,
inbuf.n - wsfragread);
for (got = 0, amt = wsfragread; got < len; got += rc, amt += rc) {
if ((rc = reader(client, inbuf.p + amt, len - got)) == -1)
luaL_error(L, "Could not read WS data");
}
for (i = 0, amt = wsfragread; i < got; ++i, ++amt)
inbuf.p[amt] ^= mask[i & 0x3];
if (op == 0x9) {
wshdr[0] = (wshdr[0] & ~0xF) | 0xA;
wshdr[1] = wshdr[1] & ~0x80;
iov[0].iov_base = wshdr;
iov[0].iov_len = wshdrlen - 4;
iov[1].iov_base = inbuf.p + wsfragread;
iov[1].iov_len = got;
Send(iov, 2);
}
if (wshdr[0] & 0x80) {
if (op) {
bufstart = inbuf.p + wsfragread;
bufsize = got;
if (op == 0x1 && !isutf8(bufstart, bufsize)) goto close;
lua_pushlstring(L, bufstart, bufsize);
lua_pushinteger(L, op);
} else {
bufstart = inbuf.p + amtread;
bufsize = (wsfragread - amtread) + got;
if (wsfragtype == 0x1 && !isutf8(bufstart, bufsize)) goto close;
lua_pushlstring(L, bufstart, bufsize);
lua_pushinteger(L, wsfragtype);
wsfragread = amtread;
wsfragtype = 0;
}
} else {
lua_pushnil(L);
lua_pushinteger(L, 0);
if (!wsfragtype) wsfragtype = op;
wsfragread += got;
}
return 2;
close:
lua_pushnil(L);
lua_pushinteger(L, 0x08);
return 2;
}
static int LuaWSWrite(lua_State *L) {
int type;
size_t size;
const char *data;
OnlyCallDuringRequest(L, "ws.Write");
if (!cpm.wstype)
LuaWSUpgrade(L);
type = luaL_optinteger(L, 2, -1);
if (type == 1 || type == 2) {
cpm.wstype = type;
} else if (type != -1) {
luaL_error(L, "Invalid WS type");
}
if (!lua_isnil(L, 1)) {
data = luaL_checklstring(L, 1, &size);
appendd(&cpm.outbuf, data, size);
}
return 0;
}
static const luaL_Reg kLuaWS[] = {
{"Read", LuaWSRead}, //
{"Write", LuaWSWrite}, //
{0} //
};
int LuaWS(lua_State *L) {
luaL_newlib(L, kLuaWS);
lua_pushinteger(L, 0); lua_setfield(L, -2, "CONT");
lua_pushinteger(L, 1); lua_setfield(L, -2, "TEXT");
lua_pushinteger(L, 2); lua_setfield(L, -2, "BIN");
lua_pushinteger(L, 8); lua_setfield(L, -2, "CLOSE");
lua_pushinteger(L, 9); lua_setfield(L, -2, "PING");
lua_pushinteger(L, 10); lua_setfield(L, -2, "PONG");
return 1;
}
// <SORTED>
// list of functions that can't be run from the repl
static const char *const kDontAutoComplete[] = {
@ -5407,6 +5601,7 @@ static const luaL_Reg kLuaLibs[] = {
{"path", LuaPath}, //
{"re", LuaRe}, //
{"unix", LuaUnix}, //
{"ws", LuaWS} //
};
static void LuaSetArgv(lua_State *L) {
@ -6453,6 +6648,73 @@ static bool StreamResponse(char *p) {
return true;
}
static bool StreamWS(char *p) {
ssize_t rc;
struct iovec iov[2];
char *s, wshdr[10], *extlen;
int nresults, status;
p = AppendCrlf(p);
CHECK_LE(p - hdrbuf.p, hdrbuf.n);
if (logmessages) {
LogMessage("sending", hdrbuf.p, p - hdrbuf.p);
}
iov[0].iov_base = hdrbuf.p;
iov[0].iov_len = p - hdrbuf.p;
Send(iov, 1);
bzero(iov, sizeof(iov));
iov[0].iov_base = wshdr;
extlen = &wshdr[2];
wsfragread = amtread;
wsfragtype = 0;
for (;;) {
if (!YL || lua_status(YL) != LUA_YIELD) break; // done yielding
cpm.contentlength = 0;
status = lua_resume(YL, NULL, 0, &nresults);
if (status == LUA_OK) {
lua_pop(YL, nresults);
break;
} else if (status != LUA_YIELD) {
LogLuaError("resume", lua_tostring(YL, -1));
lua_pop(YL, 1);
break;
}
lua_pop(YL, nresults);
if (!cpm.contentlength) UseOutput();
DEBUGF("(lua) ws yielded with %ld bytes generated", cpm.contentlength);
iov[1].iov_base = cpm.content;
iov[1].iov_len = rc = cpm.contentlength;
if (rc < 126) {
wshdr[1] = rc;
iov[0].iov_len = 2;
} else if (rc <= 0xFFFF) {
wshdr[1] = 126;
*(uint16_t *)extlen = htobe16(rc);
iov[0].iov_len = 4;
} else {
wshdr[1] = 127;
*(uint64_t *)extlen = htobe64(rc);
iov[0].iov_len = 10;
}
wshdr[0] = cpm.wstype | (1 << 7);
if (Send(iov, 2) == -1) break;
}
wshdr[0] = 0x8 | (1 << 7);
wshdr[1] = 0;
iov[0].iov_len = 2;
Send(iov, 1);
connectionclose = true;
return true;
}
static bool HandleMessageActual(void) {
int rc;
long reqtime, contime;
@ -6515,6 +6777,8 @@ static bool HandleMessageActual(void) {
}
if (!cpm.generator) {
return TransmitResponse(p);
} else if (cpm.wstype) {
return StreamWS(p);
} else {
return StreamResponse(p);
}