Misc TLS cleanups

This commit is contained in:
Sergio R. Caprile 2024-07-01 13:00:40 -03:00
parent 8c4cfc8101
commit 5a8c56e784
4 changed files with 76 additions and 28 deletions

View File

@ -8145,7 +8145,9 @@ bool mg_match(struct mg_str s, struct mg_str p, struct mg_str *caps) {
size_t i = 0, j = 0, ni = 0, nj = 0;
if (caps) caps->buf = NULL, caps->len = 0;
while (i < p.len || j < s.len) {
if (i < p.len && j < s.len && (p.buf[i] == '?' || s.buf[j] == p.buf[i])) {
if (i < p.len && j < s.len &&
(p.buf[i] == '?' ||
(p.buf[i] != '*' && p.buf[i] != '#' && s.buf[j] == p.buf[i]))) {
if (caps == NULL) {
} else if (p.buf[i] == '?') {
caps->buf = &s.buf[j], caps->len = 1; // Finalize `?` cap
@ -8188,10 +8190,10 @@ bool mg_span(struct mg_str s, struct mg_str *a, struct mg_str *b, char sep) {
bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
size_t i = 0, ndigits = 0;
uint64_t max = val_len == sizeof(uint8_t) ? 0xFF
uint64_t max = val_len == sizeof(uint8_t) ? 0xFF
: val_len == sizeof(uint16_t) ? 0xFFFF
: val_len == sizeof(uint32_t) ? 0xFFFFFFFF
: (uint64_t) ~0;
: (uint64_t) ~0;
uint64_t result = 0;
if (max == (uint64_t) ~0 && val_len != sizeof(uint64_t)) return false;
if (base == 0 && str.len >= 2) {
@ -8207,7 +8209,7 @@ bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
case 2:
while (i < str.len && (str.buf[i] == '0' || str.buf[i] == '1')) {
uint64_t digit = (uint64_t) (str.buf[i] - '0');
if (result > max/2) return false; // Overflow
if (result > max / 2) return false; // Overflow
result *= 2;
if (result > max - digit) return false; // Overflow
result += digit;
@ -8217,12 +8219,12 @@ bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
case 10:
while (i < str.len && str.buf[i] >= '0' && str.buf[i] <= '9') {
uint64_t digit = (uint64_t) (str.buf[i] - '0');
if (result > max/10) return false; // Overflow
if (result > max / 10) return false; // Overflow
result *= 10;
if (result > max - digit) return false; // Overflow
result += digit;
i++, ndigits++;
}
}
break;
case 16:
while (i < str.len) {
@ -8232,7 +8234,7 @@ bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
: (c >= 'a' && c <= 'f') ? (uint64_t) (c - 'W')
: (uint64_t) ~0;
if (digit == (uint64_t) ~0) break;
if (result > max/16) return false; // Overflow
if (result > max / 16) return false; // Overflow
result *= 16;
if (result > max - digit) return false; // Overflow
result += digit;
@ -9651,13 +9653,15 @@ static void mg_tls_drop_record(struct mg_connection *c) {
static void mg_tls_drop_message(struct mg_connection *c) {
uint32_t len;
struct tls_data *tls = (struct tls_data *) c->tls;
if (tls->recv.len == 0) {
if (tls->recv.len == 0) return;
len = MG_LOAD_BE24(tls->recv.buf + 1) + TLS_MSGHDR_SIZE;
if (tls->recv.len < len) {
mg_error(c, "wrong size");
return;
}
len = MG_LOAD_BE24(tls->recv.buf + 1);
mg_sha256_update(&tls->sha256, tls->recv.buf, len + TLS_MSGHDR_SIZE);
tls->recv.buf += len + TLS_MSGHDR_SIZE;
tls->recv.len -= len + TLS_MSGHDR_SIZE;
mg_sha256_update(&tls->sha256, tls->recv.buf, len);
tls->recv.buf += len;
tls->recv.len -= len;
if (tls->recv.len == 0) {
mg_tls_drop_record(c);
}
@ -9918,6 +9922,10 @@ static int mg_tls_recv_record(struct mg_connection *c) {
free(dec);
}
#else
if (msgsz < 16) {
mg_error(c, "wrong size");
return -1;
}
mg_aes_gcm_decrypt(msg, msg, msgsz - 16, key, 16, nonce, sizeof(nonce));
#endif
r = msgsz - 16 - 1;
@ -9981,8 +9989,10 @@ static int mg_tls_server_recv_hello(struct mg_connection *c) {
MG_INFO(("bad session id len"));
}
cipher_suites_len = MG_LOAD_BE16(rio->buf + 44 + session_id_len);
if (cipher_suites_len > (rio->len - 46 - session_id_len)) goto fail;
ext_len = MG_LOAD_BE16(rio->buf + 48 + session_id_len + cipher_suites_len);
ext = rio->buf + 50 + session_id_len + cipher_suites_len;
if (ext_len > (rio->len - 52 - session_id_len - cipher_suites_len)) goto fail;
for (j = 0; j < ext_len;) {
uint16_t k;
uint16_t key_exchange_len;
@ -9993,10 +10003,14 @@ static int mg_tls_server_recv_hello(struct mg_connection *c) {
j += (uint16_t) (n + 4);
continue;
}
key_exchange_len = MG_LOAD_BE16(ext + j + 5);
key_exchange_len = MG_LOAD_BE16(ext + j + 4);
key_exchange = ext + j + 6;
if (key_exchange_len >
rio->len - (uint16_t) ((size_t) key_exchange - (size_t) rio->buf) - 2)
goto fail;
for (k = 0; k < key_exchange_len;) {
uint16_t m = MG_LOAD_BE16(key_exchange + k + 2);
if (m > (key_exchange_len - k - 4)) goto fail;
if (m == 32 && key_exchange[k] == 0x00 && key_exchange[k + 1] == 0x1d) {
memmove(tls->x25519_cli, key_exchange + k + 4, m);
mg_tls_drop_record(c);
@ -10006,6 +10020,7 @@ static int mg_tls_server_recv_hello(struct mg_connection *c) {
}
j += (uint16_t) (n + 4);
}
fail:
mg_error(c, "bad client hello");
return -1;
}
@ -10324,6 +10339,7 @@ static int mg_tls_client_recv_hello(struct mg_connection *c) {
ext_len = MG_LOAD_BE16(rio->buf + 5 + 39 + 32 + 3);
ext = rio->buf + 5 + 39 + 32 + 3 + 2;
if (ext_len > (rio->len - (5 + 39 + 32 + 3 + 2))) goto fail;
for (j = 0; j < ext_len;) {
uint16_t ext_type = MG_LOAD_BE16(ext + j);
@ -10331,6 +10347,7 @@ static int mg_tls_client_recv_hello(struct mg_connection *c) {
uint16_t group;
uint8_t *key_exchange;
uint16_t key_exchange_len;
if (ext_len2 > (ext_len - j - 4)) goto fail;
if (ext_type != 0x0033) { // not a key share extension, ignore
j += (uint16_t) (ext_len2 + 4);
continue;
@ -10353,6 +10370,7 @@ static int mg_tls_client_recv_hello(struct mg_connection *c) {
mg_tls_generate_handshake_keys(c);
return 0;
}
fail:
mg_error(c, "bad client hello");
return -1;
}
@ -10663,7 +10681,7 @@ static int mg_parse_pem(const struct mg_str pem, const struct mg_str label,
size_t n = 0, m = 0;
char *s;
const char *c;
struct mg_str caps[5];
struct mg_str caps[6]; // number of wildcards + 1
if (!mg_match(pem, mg_str("#-----BEGIN #-----#-----END #-----#"), caps)) {
*der = mg_strdup(pem);
return 0;
@ -10713,6 +10731,7 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
if (opts->name.len > 0) {
if (opts->name.len >= sizeof(tls->hostname) - 1) {
mg_error(c, "hostname too long");
return;
}
strncpy((char *) tls->hostname, opts->name.buf, sizeof(tls->hostname) - 1);
tls->hostname[opts->name.len] = 0;

View File

@ -69,7 +69,9 @@ bool mg_match(struct mg_str s, struct mg_str p, struct mg_str *caps) {
size_t i = 0, j = 0, ni = 0, nj = 0;
if (caps) caps->buf = NULL, caps->len = 0;
while (i < p.len || j < s.len) {
if (i < p.len && j < s.len && (p.buf[i] == '?' || s.buf[j] == p.buf[i])) {
if (i < p.len && j < s.len &&
(p.buf[i] == '?' ||
(p.buf[i] != '*' && p.buf[i] != '#' && s.buf[j] == p.buf[i]))) {
if (caps == NULL) {
} else if (p.buf[i] == '?') {
caps->buf = &s.buf[j], caps->len = 1; // Finalize `?` cap
@ -112,10 +114,10 @@ bool mg_span(struct mg_str s, struct mg_str *a, struct mg_str *b, char sep) {
bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
size_t i = 0, ndigits = 0;
uint64_t max = val_len == sizeof(uint8_t) ? 0xFF
uint64_t max = val_len == sizeof(uint8_t) ? 0xFF
: val_len == sizeof(uint16_t) ? 0xFFFF
: val_len == sizeof(uint32_t) ? 0xFFFFFFFF
: (uint64_t) ~0;
: (uint64_t) ~0;
uint64_t result = 0;
if (max == (uint64_t) ~0 && val_len != sizeof(uint64_t)) return false;
if (base == 0 && str.len >= 2) {
@ -131,7 +133,7 @@ bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
case 2:
while (i < str.len && (str.buf[i] == '0' || str.buf[i] == '1')) {
uint64_t digit = (uint64_t) (str.buf[i] - '0');
if (result > max/2) return false; // Overflow
if (result > max / 2) return false; // Overflow
result *= 2;
if (result > max - digit) return false; // Overflow
result += digit;
@ -141,12 +143,12 @@ bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
case 10:
while (i < str.len && str.buf[i] >= '0' && str.buf[i] <= '9') {
uint64_t digit = (uint64_t) (str.buf[i] - '0');
if (result > max/10) return false; // Overflow
if (result > max / 10) return false; // Overflow
result *= 10;
if (result > max - digit) return false; // Overflow
result += digit;
i++, ndigits++;
}
}
break;
case 16:
while (i < str.len) {
@ -156,7 +158,7 @@ bool mg_str_to_num(struct mg_str str, int base, void *val, size_t val_len) {
: (c >= 'a' && c <= 'f') ? (uint64_t) (c - 'W')
: (uint64_t) ~0;
if (digit == (uint64_t) ~0) break;
if (result > max/16) return false; // Overflow
if (result > max / 16) return false; // Overflow
result *= 16;
if (result > max - digit) return false; // Overflow
result += digit;

View File

@ -214,13 +214,15 @@ static void mg_tls_drop_record(struct mg_connection *c) {
static void mg_tls_drop_message(struct mg_connection *c) {
uint32_t len;
struct tls_data *tls = (struct tls_data *) c->tls;
if (tls->recv.len == 0) {
if (tls->recv.len == 0) return;
len = MG_LOAD_BE24(tls->recv.buf + 1) + TLS_MSGHDR_SIZE;
if (tls->recv.len < len) {
mg_error(c, "wrong size");
return;
}
len = MG_LOAD_BE24(tls->recv.buf + 1);
mg_sha256_update(&tls->sha256, tls->recv.buf, len + TLS_MSGHDR_SIZE);
tls->recv.buf += len + TLS_MSGHDR_SIZE;
tls->recv.len -= len + TLS_MSGHDR_SIZE;
mg_sha256_update(&tls->sha256, tls->recv.buf, len);
tls->recv.buf += len;
tls->recv.len -= len;
if (tls->recv.len == 0) {
mg_tls_drop_record(c);
}
@ -481,6 +483,10 @@ static int mg_tls_recv_record(struct mg_connection *c) {
free(dec);
}
#else
if (msgsz < 16) {
mg_error(c, "wrong size");
return -1;
}
mg_aes_gcm_decrypt(msg, msg, msgsz - 16, key, 16, nonce, sizeof(nonce));
#endif
r = msgsz - 16 - 1;
@ -544,8 +550,10 @@ static int mg_tls_server_recv_hello(struct mg_connection *c) {
MG_INFO(("bad session id len"));
}
cipher_suites_len = MG_LOAD_BE16(rio->buf + 44 + session_id_len);
if (cipher_suites_len > (rio->len - 46 - session_id_len)) goto fail;
ext_len = MG_LOAD_BE16(rio->buf + 48 + session_id_len + cipher_suites_len);
ext = rio->buf + 50 + session_id_len + cipher_suites_len;
if (ext_len > (rio->len - 52 - session_id_len - cipher_suites_len)) goto fail;
for (j = 0; j < ext_len;) {
uint16_t k;
uint16_t key_exchange_len;
@ -556,10 +564,14 @@ static int mg_tls_server_recv_hello(struct mg_connection *c) {
j += (uint16_t) (n + 4);
continue;
}
key_exchange_len = MG_LOAD_BE16(ext + j + 5);
key_exchange_len = MG_LOAD_BE16(ext + j + 4);
key_exchange = ext + j + 6;
if (key_exchange_len >
rio->len - (uint16_t) ((size_t) key_exchange - (size_t) rio->buf) - 2)
goto fail;
for (k = 0; k < key_exchange_len;) {
uint16_t m = MG_LOAD_BE16(key_exchange + k + 2);
if (m > (key_exchange_len - k - 4)) goto fail;
if (m == 32 && key_exchange[k] == 0x00 && key_exchange[k + 1] == 0x1d) {
memmove(tls->x25519_cli, key_exchange + k + 4, m);
mg_tls_drop_record(c);
@ -569,6 +581,7 @@ static int mg_tls_server_recv_hello(struct mg_connection *c) {
}
j += (uint16_t) (n + 4);
}
fail:
mg_error(c, "bad client hello");
return -1;
}
@ -887,6 +900,7 @@ static int mg_tls_client_recv_hello(struct mg_connection *c) {
ext_len = MG_LOAD_BE16(rio->buf + 5 + 39 + 32 + 3);
ext = rio->buf + 5 + 39 + 32 + 3 + 2;
if (ext_len > (rio->len - (5 + 39 + 32 + 3 + 2))) goto fail;
for (j = 0; j < ext_len;) {
uint16_t ext_type = MG_LOAD_BE16(ext + j);
@ -894,6 +908,7 @@ static int mg_tls_client_recv_hello(struct mg_connection *c) {
uint16_t group;
uint8_t *key_exchange;
uint16_t key_exchange_len;
if (ext_len2 > (ext_len - j - 4)) goto fail;
if (ext_type != 0x0033) { // not a key share extension, ignore
j += (uint16_t) (ext_len2 + 4);
continue;
@ -916,6 +931,7 @@ static int mg_tls_client_recv_hello(struct mg_connection *c) {
mg_tls_generate_handshake_keys(c);
return 0;
}
fail:
mg_error(c, "bad client hello");
return -1;
}
@ -1226,7 +1242,7 @@ static int mg_parse_pem(const struct mg_str pem, const struct mg_str label,
size_t n = 0, m = 0;
char *s;
const char *c;
struct mg_str caps[5];
struct mg_str caps[6]; // number of wildcards + 1
if (!mg_match(pem, mg_str("#-----BEGIN #-----#-----END #-----#"), caps)) {
*der = mg_strdup(pem);
return 0;
@ -1276,6 +1292,7 @@ void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
if (opts->name.len > 0) {
if (opts->name.len >= sizeof(tls->hostname) - 1) {
mg_error(c, "hostname too long");
return;
}
strncpy((char *) tls->hostname, opts->name.buf, sizeof(tls->hostname) - 1);
tls->hostname[opts->name.len] = 0;

View File

@ -95,6 +95,16 @@ static void test_match(void) {
ASSERT(mg_strcmp(caps[0], mg_str("a")) == 0);
ASSERT(mg_strcmp(caps[1], mg_str("bc")) == 0);
ASSERT(mg_strcmp(caps[2], mg_str("")) == 0);
ASSERT(mg_match(mg_str("a#c"), mg_str("?#"), caps) == true);
ASSERT(mg_strcmp(caps[0], mg_str("a")) == 0);
ASSERT(mg_strcmp(caps[1], mg_str("#c")) == 0);
ASSERT(mg_strcmp(caps[2], mg_str("")) == 0);
ASSERT(mg_match(mg_str("a*c"), mg_str("?*"), caps) == true);
ASSERT(mg_strcmp(caps[0], mg_str("a")) == 0);
ASSERT(mg_strcmp(caps[1], mg_str("*c")) == 0);
ASSERT(mg_strcmp(caps[2], mg_str("")) == 0);
}
}