Refactor tls receive pointer scheme

This commit is contained in:
Sergio R. Caprile 2024-07-23 17:02:29 -03:00
parent 4d6b126b9d
commit fe77075996
2 changed files with 98 additions and 80 deletions

View File

@ -9504,9 +9504,9 @@ struct tls_data {
enum mg_tls_hs_state state; // keep track of connection handshake progress
struct mg_iobuf send; // For the receive path, we're reusing c->rtls
struct mg_iobuf recv; // While c->rtls contains full records, recv reuses
// the same underlying buffer but points at individual
// decrypted messages
size_t recv_offset; // While c->rtls contains full records, reuse that
size_t recv_len; // buffer but point at individual decrypted messages
uint8_t content_type; // Last received record content type
mg_sha256_ctx sha256; // incremental SHA-256 hash for TLS handshake
@ -9653,16 +9653,17 @@ static void mg_tls_drop_record(struct mg_connection *c) {
static void mg_tls_drop_message(struct mg_connection *c) {
uint32_t len;
struct tls_data *tls = (struct tls_data *) c->tls;
if (tls->recv.len == 0) return;
len = MG_LOAD_BE24(tls->recv.buf + 1) + TLS_MSGHDR_SIZE;
if (tls->recv.len < len) {
unsigned char *recv_buf = &c->rtls.buf[tls->recv_offset];
if (tls->recv_len == 0) return;
len = MG_LOAD_BE24(recv_buf + 1) + TLS_MSGHDR_SIZE;
if (tls->recv_len < len) {
mg_error(c, "wrong size");
return;
}
mg_sha256_update(&tls->sha256, tls->recv.buf, len);
tls->recv.buf += len;
tls->recv.len -= len;
if (tls->recv.len == 0) {
mg_sha256_update(&tls->sha256, recv_buf, len);
tls->recv_offset += len;
tls->recv_len -= len;
if (tls->recv_len == 0) {
mg_tls_drop_record(c);
}
}
@ -9877,7 +9878,7 @@ static int mg_tls_recv_record(struct mg_connection *c) {
uint8_t *iv =
c->is_client ? tls->enc.server_write_iv : tls->enc.client_write_iv;
if (tls->recv.len > 0) {
if (tls->recv_len > 0) {
return 0; /* some data from previous record is still present */
}
for (;;) {
@ -9930,8 +9931,8 @@ static int mg_tls_recv_record(struct mg_connection *c) {
#endif
r = msgsz - 16 - 1;
tls->content_type = msg[msgsz - 16 - 1];
tls->recv.buf = msg;
tls->recv.size = tls->recv.len = msgsz - 16 - 1;
tls->recv_offset = (size_t) msg - (size_t) rio->buf;
tls->recv_len = msgsz - 16 - 1;
c->is_client ? tls->enc.sseq++ : tls->enc.cseq++;
return r;
}
@ -10193,6 +10194,7 @@ static void mg_tls_server_send_finish(struct mg_connection *c) {
static int mg_tls_server_recv_finish(struct mg_connection *c) {
struct tls_data *tls = (struct tls_data *) c->tls;
unsigned char *recv_buf;
// we have to backup sha256 value to restore it later, since Finished record
// is exceptional and is not supposed to be added to the rolling hash
// calculation.
@ -10200,8 +10202,9 @@ static int mg_tls_server_recv_finish(struct mg_connection *c) {
if (mg_tls_recv_record(c) < 0) {
return -1;
}
if (tls->recv.buf[0] != MG_TLS_FINISHED) {
mg_error(c, "expected Finish but got msg 0x%02x", tls->recv.buf[0]);
recv_buf = &c->rtls.buf[tls->recv_offset];
if (recv_buf[0] != MG_TLS_FINISHED) {
mg_error(c, "expected Finish but got msg 0x%02x", recv_buf[0]);
return -1;
}
mg_tls_drop_message(c);
@ -10377,12 +10380,13 @@ fail:
static int mg_tls_client_recv_ext(struct mg_connection *c) {
struct tls_data *tls = (struct tls_data *) c->tls;
unsigned char *recv_buf;
if (mg_tls_recv_record(c) < 0) {
return -1;
}
if (tls->recv.buf[0] != MG_TLS_ENCRYPTED_EXTENSIONS) {
mg_error(c, "expected server extensions but got msg 0x%02x",
tls->recv.buf[0]);
recv_buf = &c->rtls.buf[tls->recv_offset];
if (recv_buf[0] != MG_TLS_ENCRYPTED_EXTENSIONS) {
mg_error(c, "expected server extensions but got msg 0x%02x", recv_buf[0]);
return -1;
}
mg_tls_drop_message(c);
@ -10395,18 +10399,19 @@ static int mg_tls_client_recv_cert(struct mg_connection *c) {
struct mg_der_tlv oid, pubkey, seq, subj;
int subj_match = 0;
struct tls_data *tls = (struct tls_data *) c->tls;
unsigned char *recv_buf;
if (mg_tls_recv_record(c) < 0) {
return -1;
}
if (tls->recv.buf[0] == MG_TLS_CERTIFICATE_REQUEST) {
recv_buf = &c->rtls.buf[tls->recv_offset];
if (recv_buf[0] == MG_TLS_CERTIFICATE_REQUEST) {
MG_VERBOSE(("got certificate request"));
mg_tls_drop_message(c);
tls->cert_requested = 1;
return -1;
}
if (tls->recv.buf[0] != MG_TLS_CERTIFICATE) {
mg_error(c, "expected server certificate but got msg 0x%02x",
tls->recv.buf[0]);
if (recv_buf[0] != MG_TLS_CERTIFICATE) {
mg_error(c, "expected server certificate but got msg 0x%02x", recv_buf[0]);
return -1;
}
if (tls->skip_verification) {
@ -10414,15 +10419,15 @@ static int mg_tls_client_recv_cert(struct mg_connection *c) {
return 0;
}
if (tls->recv.len < 11) {
if (tls->recv_len < 11) {
mg_error(c, "certificate list too short");
return -1;
}
cert = tls->recv.buf + 11;
certsz = MG_LOAD_BE24(tls->recv.buf + 8);
if (certsz > tls->recv.len - 11) {
mg_error(c, "certificate too long: %d vs %d", certsz, tls->recv.len - 11);
cert = recv_buf + 11;
certsz = MG_LOAD_BE24(recv_buf + 8);
if (certsz > tls->recv_len - 11) {
mg_error(c, "certificate too long: %d vs %d", certsz, tls->recv_len - 11);
return -1;
}
@ -10496,12 +10501,13 @@ static int mg_tls_client_recv_cert(struct mg_connection *c) {
static int mg_tls_client_recv_cert_verify(struct mg_connection *c) {
struct tls_data *tls = (struct tls_data *) c->tls;
unsigned char *recv_buf;
if (mg_tls_recv_record(c) < 0) {
return -1;
}
if (tls->recv.buf[0] != MG_TLS_CERTIFICATE_VERIFY) {
mg_error(c, "expected server certificate verify but got msg 0x%02x",
tls->recv.buf[0]);
recv_buf = &c->rtls.buf[tls->recv_offset];
if (recv_buf[0] != MG_TLS_CERTIFICATE_VERIFY) {
mg_error(c, "expected server certificate verify but got msg 0x%02x", recv_buf[0]);
return -1;
}
// Ignore CertificateVerify is strict checks are not required
@ -10514,7 +10520,7 @@ static int mg_tls_client_recv_cert_verify(struct mg_connection *c) {
do {
uint8_t sig[64];
struct mg_der_tlv seq, a, b;
if (mg_der_to_tlv(tls->recv.buf + 8, tls->recv.len - 8, &seq) < 0) {
if (mg_der_to_tlv(recv_buf + 8, tls->recv_len - 8, &seq) < 0) {
mg_error(c, "verification message is not an ASN.1 DER sequence");
return -1;
}
@ -10552,12 +10558,13 @@ static int mg_tls_client_recv_cert_verify(struct mg_connection *c) {
static int mg_tls_client_recv_finish(struct mg_connection *c) {
struct tls_data *tls = (struct tls_data *) c->tls;
unsigned char *recv_buf;
if (mg_tls_recv_record(c) < 0) {
return -1;
}
if (tls->recv.buf[0] != MG_TLS_FINISHED) {
mg_error(c, "expected server finished but got msg 0x%02x",
tls->recv.buf[0]);
recv_buf = &c->rtls.buf[tls->recv_offset];
if (recv_buf[0] != MG_TLS_FINISHED) {
mg_error(c, "expected server finished but got msg 0x%02x", recv_buf[0]);
return -1;
}
mg_tls_drop_message(c);
@ -10803,23 +10810,25 @@ long mg_tls_send(struct mg_connection *c, const void *buf, size_t len) {
long mg_tls_recv(struct mg_connection *c, void *buf, size_t len) {
int r = 0;
struct tls_data *tls = (struct tls_data *) c->tls;
unsigned char *recv_buf;
size_t minlen;
r = mg_tls_recv_record(c);
if (r < 0) {
return r;
}
recv_buf = &c->rtls.buf[tls->recv_offset];
if (tls->content_type != MG_TLS_APP_DATA) {
tls->recv.len = 0;
tls->recv_len = 0;
mg_tls_drop_record(c);
return MG_IO_WAIT;
}
minlen = len < tls->recv.len ? len : tls->recv.len;
memmove(buf, tls->recv.buf, minlen);
tls->recv.buf += minlen;
tls->recv.len -= minlen;
if (tls->recv.len == 0) {
minlen = len < tls->recv_len ? len : tls->recv_len;
memmove(buf, recv_buf, minlen);
tls->recv_offset += minlen;
tls->recv_len -= minlen;
if (tls->recv_len == 0) {
mg_tls_drop_record(c);
}
return (long) minlen;

View File

@ -65,9 +65,9 @@ struct tls_data {
enum mg_tls_hs_state state; // keep track of connection handshake progress
struct mg_iobuf send; // For the receive path, we're reusing c->rtls
struct mg_iobuf recv; // While c->rtls contains full records, recv reuses
// the same underlying buffer but points at individual
// decrypted messages
size_t recv_offset; // While c->rtls contains full records, reuse that
size_t recv_len; // buffer but point at individual decrypted messages
uint8_t content_type; // Last received record content type
mg_sha256_ctx sha256; // incremental SHA-256 hash for TLS handshake
@ -214,16 +214,17 @@ static void mg_tls_drop_record(struct mg_connection *c) {
static void mg_tls_drop_message(struct mg_connection *c) {
uint32_t len;
struct tls_data *tls = (struct tls_data *) c->tls;
if (tls->recv.len == 0) return;
len = MG_LOAD_BE24(tls->recv.buf + 1) + TLS_MSGHDR_SIZE;
if (tls->recv.len < len) {
unsigned char *recv_buf = &c->rtls.buf[tls->recv_offset];
if (tls->recv_len == 0) return;
len = MG_LOAD_BE24(recv_buf + 1) + TLS_MSGHDR_SIZE;
if (tls->recv_len < len) {
mg_error(c, "wrong size");
return;
}
mg_sha256_update(&tls->sha256, tls->recv.buf, len);
tls->recv.buf += len;
tls->recv.len -= len;
if (tls->recv.len == 0) {
mg_sha256_update(&tls->sha256, recv_buf, len);
tls->recv_offset += len;
tls->recv_len -= len;
if (tls->recv_len == 0) {
mg_tls_drop_record(c);
}
}
@ -438,7 +439,7 @@ static int mg_tls_recv_record(struct mg_connection *c) {
uint8_t *iv =
c->is_client ? tls->enc.server_write_iv : tls->enc.client_write_iv;
if (tls->recv.len > 0) {
if (tls->recv_len > 0) {
return 0; /* some data from previous record is still present */
}
for (;;) {
@ -491,8 +492,8 @@ static int mg_tls_recv_record(struct mg_connection *c) {
#endif
r = msgsz - 16 - 1;
tls->content_type = msg[msgsz - 16 - 1];
tls->recv.buf = msg;
tls->recv.size = tls->recv.len = msgsz - 16 - 1;
tls->recv_offset = (size_t) msg - (size_t) rio->buf;
tls->recv_len = msgsz - 16 - 1;
c->is_client ? tls->enc.sseq++ : tls->enc.cseq++;
return r;
}
@ -754,6 +755,7 @@ static void mg_tls_server_send_finish(struct mg_connection *c) {
static int mg_tls_server_recv_finish(struct mg_connection *c) {
struct tls_data *tls = (struct tls_data *) c->tls;
unsigned char *recv_buf;
// we have to backup sha256 value to restore it later, since Finished record
// is exceptional and is not supposed to be added to the rolling hash
// calculation.
@ -761,8 +763,9 @@ static int mg_tls_server_recv_finish(struct mg_connection *c) {
if (mg_tls_recv_record(c) < 0) {
return -1;
}
if (tls->recv.buf[0] != MG_TLS_FINISHED) {
mg_error(c, "expected Finish but got msg 0x%02x", tls->recv.buf[0]);
recv_buf = &c->rtls.buf[tls->recv_offset];
if (recv_buf[0] != MG_TLS_FINISHED) {
mg_error(c, "expected Finish but got msg 0x%02x", recv_buf[0]);
return -1;
}
mg_tls_drop_message(c);
@ -938,12 +941,13 @@ fail:
static int mg_tls_client_recv_ext(struct mg_connection *c) {
struct tls_data *tls = (struct tls_data *) c->tls;
unsigned char *recv_buf;
if (mg_tls_recv_record(c) < 0) {
return -1;
}
if (tls->recv.buf[0] != MG_TLS_ENCRYPTED_EXTENSIONS) {
mg_error(c, "expected server extensions but got msg 0x%02x",
tls->recv.buf[0]);
recv_buf = &c->rtls.buf[tls->recv_offset];
if (recv_buf[0] != MG_TLS_ENCRYPTED_EXTENSIONS) {
mg_error(c, "expected server extensions but got msg 0x%02x", recv_buf[0]);
return -1;
}
mg_tls_drop_message(c);
@ -956,18 +960,19 @@ static int mg_tls_client_recv_cert(struct mg_connection *c) {
struct mg_der_tlv oid, pubkey, seq, subj;
int subj_match = 0;
struct tls_data *tls = (struct tls_data *) c->tls;
unsigned char *recv_buf;
if (mg_tls_recv_record(c) < 0) {
return -1;
}
if (tls->recv.buf[0] == MG_TLS_CERTIFICATE_REQUEST) {
recv_buf = &c->rtls.buf[tls->recv_offset];
if (recv_buf[0] == MG_TLS_CERTIFICATE_REQUEST) {
MG_VERBOSE(("got certificate request"));
mg_tls_drop_message(c);
tls->cert_requested = 1;
return -1;
}
if (tls->recv.buf[0] != MG_TLS_CERTIFICATE) {
mg_error(c, "expected server certificate but got msg 0x%02x",
tls->recv.buf[0]);
if (recv_buf[0] != MG_TLS_CERTIFICATE) {
mg_error(c, "expected server certificate but got msg 0x%02x", recv_buf[0]);
return -1;
}
if (tls->skip_verification) {
@ -975,15 +980,15 @@ static int mg_tls_client_recv_cert(struct mg_connection *c) {
return 0;
}
if (tls->recv.len < 11) {
if (tls->recv_len < 11) {
mg_error(c, "certificate list too short");
return -1;
}
cert = tls->recv.buf + 11;
certsz = MG_LOAD_BE24(tls->recv.buf + 8);
if (certsz > tls->recv.len - 11) {
mg_error(c, "certificate too long: %d vs %d", certsz, tls->recv.len - 11);
cert = recv_buf + 11;
certsz = MG_LOAD_BE24(recv_buf + 8);
if (certsz > tls->recv_len - 11) {
mg_error(c, "certificate too long: %d vs %d", certsz, tls->recv_len - 11);
return -1;
}
@ -1057,12 +1062,13 @@ static int mg_tls_client_recv_cert(struct mg_connection *c) {
static int mg_tls_client_recv_cert_verify(struct mg_connection *c) {
struct tls_data *tls = (struct tls_data *) c->tls;
unsigned char *recv_buf;
if (mg_tls_recv_record(c) < 0) {
return -1;
}
if (tls->recv.buf[0] != MG_TLS_CERTIFICATE_VERIFY) {
mg_error(c, "expected server certificate verify but got msg 0x%02x",
tls->recv.buf[0]);
recv_buf = &c->rtls.buf[tls->recv_offset];
if (recv_buf[0] != MG_TLS_CERTIFICATE_VERIFY) {
mg_error(c, "expected server certificate verify but got msg 0x%02x", recv_buf[0]);
return -1;
}
// Ignore CertificateVerify is strict checks are not required
@ -1075,7 +1081,7 @@ static int mg_tls_client_recv_cert_verify(struct mg_connection *c) {
do {
uint8_t sig[64];
struct mg_der_tlv seq, a, b;
if (mg_der_to_tlv(tls->recv.buf + 8, tls->recv.len - 8, &seq) < 0) {
if (mg_der_to_tlv(recv_buf + 8, tls->recv_len - 8, &seq) < 0) {
mg_error(c, "verification message is not an ASN.1 DER sequence");
return -1;
}
@ -1113,12 +1119,13 @@ static int mg_tls_client_recv_cert_verify(struct mg_connection *c) {
static int mg_tls_client_recv_finish(struct mg_connection *c) {
struct tls_data *tls = (struct tls_data *) c->tls;
unsigned char *recv_buf;
if (mg_tls_recv_record(c) < 0) {
return -1;
}
if (tls->recv.buf[0] != MG_TLS_FINISHED) {
mg_error(c, "expected server finished but got msg 0x%02x",
tls->recv.buf[0]);
recv_buf = &c->rtls.buf[tls->recv_offset];
if (recv_buf[0] != MG_TLS_FINISHED) {
mg_error(c, "expected server finished but got msg 0x%02x", recv_buf[0]);
return -1;
}
mg_tls_drop_message(c);
@ -1364,23 +1371,25 @@ long mg_tls_send(struct mg_connection *c, const void *buf, size_t len) {
long mg_tls_recv(struct mg_connection *c, void *buf, size_t len) {
int r = 0;
struct tls_data *tls = (struct tls_data *) c->tls;
unsigned char *recv_buf;
size_t minlen;
r = mg_tls_recv_record(c);
if (r < 0) {
return r;
}
recv_buf = &c->rtls.buf[tls->recv_offset];
if (tls->content_type != MG_TLS_APP_DATA) {
tls->recv.len = 0;
tls->recv_len = 0;
mg_tls_drop_record(c);
return MG_IO_WAIT;
}
minlen = len < tls->recv.len ? len : tls->recv.len;
memmove(buf, tls->recv.buf, minlen);
tls->recv.buf += minlen;
tls->recv.len -= minlen;
if (tls->recv.len == 0) {
minlen = len < tls->recv_len ? len : tls->recv_len;
memmove(buf, recv_buf, minlen);
tls->recv_offset += minlen;
tls->recv_len -= minlen;
if (tls->recv_len == 0) {
mg_tls_drop_record(c);
}
return (long) minlen;