Implement client support for TLS-PSK

For both OpenSSL and mbedTLS

PUBLISHED_FROM=0bfd5f128b4c4c062cb6f0ca0da9b30790aa8bf8
This commit is contained in:
Deomid Ryabkov 2017-02-15 08:05:37 +00:00 committed by Cesanta Bot
parent d6d956b9d8
commit d4b23f08b6
3 changed files with 167 additions and 11 deletions

View File

@ -44,6 +44,15 @@ signature: |
* name verification. * name verification.
*/ */
const char *ssl_server_name; const char *ssl_server_name;
/*
* PSK identity and key. Identity is a NUL-terminated string and key is a hex
* string. Key must be either 16 or 32 bytes (32 or 64 hex digits) for AES-128
* or AES-256 respectively.
* Note: Default list of cipher suites does not include PSK suites, if you
* want to use PSK you will need to set ssl_cipher_suites as well.
*/
const char *ssl_psk_identity;
const char *ssl_psk_key;
#endif #endif
}; };
--- ---

View File

@ -2678,7 +2678,8 @@ struct mg_connection *mg_connect_opt(struct mg_mgr *mgr, const char *address,
(opts.ssl_key ? opts.ssl_key : "-"), (opts.ssl_key ? opts.ssl_key : "-"),
(opts.ssl_ca_cert ? opts.ssl_ca_cert : "-"))); (opts.ssl_ca_cert ? opts.ssl_ca_cert : "-")));
if (opts.ssl_cert != NULL || opts.ssl_ca_cert != NULL) { if (opts.ssl_cert != NULL || opts.ssl_ca_cert != NULL ||
opts.ssl_psk_identity != NULL) {
const char *err_msg = NULL; const char *err_msg = NULL;
struct mg_ssl_if_conn_params params; struct mg_ssl_if_conn_params params;
if (nc->flags & MG_F_UDP) { if (nc->flags & MG_F_UDP) {
@ -2691,6 +2692,8 @@ struct mg_connection *mg_connect_opt(struct mg_mgr *mgr, const char *address,
params.key = opts.ssl_key; params.key = opts.ssl_key;
params.ca_cert = opts.ssl_ca_cert; params.ca_cert = opts.ssl_ca_cert;
params.cipher_suites = opts.ssl_cipher_suites; params.cipher_suites = opts.ssl_cipher_suites;
params.psk_identity = opts.ssl_psk_identity;
params.psk_key = opts.ssl_psk_key;
if (opts.ssl_ca_cert != NULL) { if (opts.ssl_ca_cert != NULL) {
if (opts.ssl_server_name != NULL) { if (opts.ssl_server_name != NULL) {
if (strcmp(opts.ssl_server_name, "*") != 0) { if (strcmp(opts.ssl_server_name, "*") != 0) {
@ -3983,6 +3986,8 @@ struct mg_iface_vtable mg_tun_iface_vtable = MG_TUN_IFACE_VTABLE;
struct mg_ssl_if_ctx { struct mg_ssl_if_ctx {
SSL *ssl; SSL *ssl;
SSL_CTX *ssl_ctx; SSL_CTX *ssl_ctx;
struct mbuf psk;
size_t identity_len;
}; };
void mg_ssl_if_init() { void mg_ssl_if_init() {
@ -4007,6 +4012,9 @@ static enum mg_ssl_if_result mg_use_cert(SSL_CTX *ctx, const char *cert,
const char *key, const char **err_msg); const char *key, const char **err_msg);
static enum mg_ssl_if_result mg_use_ca_cert(SSL_CTX *ctx, const char *cert); static enum mg_ssl_if_result mg_use_ca_cert(SSL_CTX *ctx, const char *cert);
static enum mg_ssl_if_result mg_set_cipher_list(SSL_CTX *ctx, const char *cl); static enum mg_ssl_if_result mg_set_cipher_list(SSL_CTX *ctx, const char *cl);
static enum mg_ssl_if_result mg_ssl_if_ossl_set_psk(struct mg_ssl_if_ctx *ctx,
const char *identity,
const char *key_str);
enum mg_ssl_if_result mg_ssl_if_conn_init( enum mg_ssl_if_result mg_ssl_if_conn_init(
struct mg_connection *nc, const struct mg_ssl_if_conn_params *params, struct mg_connection *nc, const struct mg_ssl_if_conn_params *params,
@ -4056,6 +4064,13 @@ enum mg_ssl_if_result mg_ssl_if_conn_init(
return MG_SSL_ERROR; return MG_SSL_ERROR;
} }
mbuf_init(&ctx->psk, 0);
if (mg_ssl_if_ossl_set_psk(ctx, params->psk_identity, params->psk_key) !=
MG_SSL_OK) {
MG_SET_PTRPTR(err_msg, "Invalid PSK settings");
return MG_SSL_ERROR;
}
if (!(nc->flags & MG_F_LISTENING) && if (!(nc->flags & MG_F_LISTENING) &&
(ctx->ssl = SSL_new(ctx->ssl_ctx)) == NULL) { (ctx->ssl = SSL_new(ctx->ssl_ctx)) == NULL) {
MG_SET_PTRPTR(err_msg, "Failed to create SSL session"); MG_SET_PTRPTR(err_msg, "Failed to create SSL session");
@ -4114,6 +4129,7 @@ void mg_ssl_if_conn_free(struct mg_connection *nc) {
nc->ssl_if_data = NULL; nc->ssl_if_data = NULL;
if (ctx->ssl != NULL) SSL_free(ctx->ssl); if (ctx->ssl != NULL) SSL_free(ctx->ssl);
if (ctx->ssl_ctx != NULL && nc->listener == NULL) SSL_CTX_free(ctx->ssl_ctx); if (ctx->ssl_ctx != NULL && nc->listener == NULL) SSL_CTX_free(ctx->ssl_ctx);
mbuf_free(&ctx->psk);
memset(ctx, 0, sizeof(*ctx)); memset(ctx, 0, sizeof(*ctx));
MG_FREE(ctx); MG_FREE(ctx);
} }
@ -4242,6 +4258,78 @@ static enum mg_ssl_if_result mg_set_cipher_list(SSL_CTX *ctx, const char *cl) {
: MG_SSL_ERROR); : MG_SSL_ERROR);
} }
#ifndef KR_VERSION
static unsigned int mg_ssl_if_ossl_psk_cb(SSL *ssl, const char *hint,
char *identity,
unsigned int max_identity_len,
unsigned char *psk,
unsigned int max_psk_len) {
struct mg_ssl_if_ctx *ctx =
(struct mg_ssl_if_ctx *) ssl->ctx->msg_callback_arg;
size_t key_len = ctx->psk.len - ctx->identity_len - 1;
DBG(("hint: '%s'", (hint ? hint : "")));
if (ctx->identity_len + 1 > max_identity_len) {
DBG(("identity too long"));
return 0;
}
if (key_len > max_psk_len) {
DBG(("key too long"));
return 0;
}
memcpy(identity, ctx->psk.buf, ctx->identity_len + 1);
memcpy(psk, ctx->psk.buf + ctx->identity_len + 1, key_len);
(void) ssl;
return key_len;
}
static enum mg_ssl_if_result mg_ssl_if_ossl_set_psk(struct mg_ssl_if_ctx *ctx,
const char *identity,
const char *key_str) {
unsigned char key[32];
size_t key_len;
size_t i = 0;
if (identity == NULL && key_str == NULL) return MG_SSL_OK;
if (identity == NULL || key_str == NULL) return MG_SSL_ERROR;
key_len = strlen(key_str);
if (key_len != 32 && key_len != 64) return MG_SSL_ERROR;
memset(key, 0, sizeof(key));
key_len = 0;
for (i = 0; key_str[i] != '\0'; i++) {
unsigned char c;
char hc = tolower((int) key_str[i]);
if (hc >= '0' && hc <= '9') {
c = hc - '0';
} else if (hc >= 'a' && hc <= 'f') {
c = hc - 'a' + 0xa;
} else {
return MG_SSL_ERROR;
}
key_len = i / 2;
key[key_len] <<= 4;
key[key_len] |= c;
}
key_len++;
DBG(("identity = '%s', key = (%u)", identity, (unsigned int) key_len));
ctx->identity_len = strlen(identity);
mbuf_append(&ctx->psk, identity, ctx->identity_len + 1);
mbuf_append(&ctx->psk, key, key_len);
SSL_CTX_set_psk_client_callback(ctx->ssl_ctx, mg_ssl_if_ossl_psk_cb);
/* Hack: there is no field for us to keep this, so we use msg_callback_arg */
ctx->ssl_ctx->msg_callback_arg = ctx;
return MG_SSL_OK;
}
#else
static enum mg_ssl_if_result mg_ssl_if_ossl_set_psk(struct mg_ssl_if_ctx *ctx,
const char *identity,
const char *key_str) {
(void) ctx;
(void) identity;
(void) key_str;
/* Krypton does not support PSK. */
return MG_SSL_ERROR;
}
#endif /* defined(KR_VERSION) */
const char *mg_set_ssl(struct mg_connection *nc, const char *cert, const char *mg_set_ssl(struct mg_connection *nc, const char *cert,
const char *ca_cert) { const char *ca_cert) {
const char *err_msg = NULL; const char *err_msg = NULL;
@ -4314,7 +4402,7 @@ enum mg_ssl_if_result mg_ssl_if_conn_accept(struct mg_connection *nc,
struct mg_ssl_if_ctx *lc_ctx = (struct mg_ssl_if_ctx *) lc->ssl_if_data; struct mg_ssl_if_ctx *lc_ctx = (struct mg_ssl_if_ctx *) lc->ssl_if_data;
nc->ssl_if_data = ctx; nc->ssl_if_data = ctx;
if (ctx == NULL || lc_ctx == NULL) return MG_SSL_ERROR; if (ctx == NULL || lc_ctx == NULL) return MG_SSL_ERROR;
ctx->ssl = MG_CALLOC(1, sizeof(*ctx->ssl)); ctx->ssl = (mbedtls_ssl_context *) MG_CALLOC(1, sizeof(*ctx->ssl));
if (mbedtls_ssl_setup(ctx->ssl, lc_ctx->conf) != 0) { if (mbedtls_ssl_setup(ctx->ssl, lc_ctx->conf) != 0) {
return MG_SSL_ERROR; return MG_SSL_ERROR;
} }
@ -4328,6 +4416,9 @@ static enum mg_ssl_if_result mg_use_ca_cert(struct mg_ssl_if_ctx *ctx,
const char *cert); const char *cert);
static enum mg_ssl_if_result mg_set_cipher_list(struct mg_ssl_if_ctx *ctx, static enum mg_ssl_if_result mg_set_cipher_list(struct mg_ssl_if_ctx *ctx,
const char *ciphers); const char *ciphers);
static enum mg_ssl_if_result mg_ssl_if_mbed_set_psk(struct mg_ssl_if_ctx *ctx,
const char *identity,
const char *key);
enum mg_ssl_if_result mg_ssl_if_conn_init( enum mg_ssl_if_result mg_ssl_if_conn_init(
struct mg_connection *nc, const struct mg_ssl_if_conn_params *params, struct mg_connection *nc, const struct mg_ssl_if_conn_params *params,
@ -4343,7 +4434,7 @@ enum mg_ssl_if_result mg_ssl_if_conn_init(
return MG_SSL_ERROR; return MG_SSL_ERROR;
} }
nc->ssl_if_data = ctx; nc->ssl_if_data = ctx;
ctx->conf = MG_CALLOC(1, sizeof(*ctx->conf)); ctx->conf = (mbedtls_ssl_config *) MG_CALLOC(1, sizeof(*ctx->conf));
mbuf_init(&ctx->cipher_suites, 0); mbuf_init(&ctx->cipher_suites, 0);
mbedtls_ssl_config_init(ctx->conf); mbedtls_ssl_config_init(ctx->conf);
mbedtls_ssl_conf_dbg(ctx->conf, mg_ssl_mbed_log, nc); mbedtls_ssl_conf_dbg(ctx->conf, mg_ssl_mbed_log, nc);
@ -4354,6 +4445,7 @@ enum mg_ssl_if_result mg_ssl_if_conn_init(
MG_SET_PTRPTR(err_msg, "Failed to init SSL config"); MG_SET_PTRPTR(err_msg, "Failed to init SSL config");
return MG_SSL_ERROR; return MG_SSL_ERROR;
} }
/* TLS 1.2 and up */ /* TLS 1.2 and up */
mbedtls_ssl_conf_min_version(ctx->conf, MBEDTLS_SSL_MAJOR_VERSION_3, mbedtls_ssl_conf_min_version(ctx->conf, MBEDTLS_SSL_MAJOR_VERSION_3,
MBEDTLS_SSL_MINOR_VERSION_3); MBEDTLS_SSL_MINOR_VERSION_3);
@ -4375,8 +4467,14 @@ enum mg_ssl_if_result mg_ssl_if_conn_init(
return MG_SSL_ERROR; return MG_SSL_ERROR;
} }
if (mg_ssl_if_mbed_set_psk(ctx, params->psk_identity, params->psk_key) !=
MG_SSL_OK) {
MG_SET_PTRPTR(err_msg, "Invalid PSK settings");
return MG_SSL_ERROR;
}
if (!(nc->flags & MG_F_LISTENING)) { if (!(nc->flags & MG_F_LISTENING)) {
ctx->ssl = MG_CALLOC(1, sizeof(*ctx->ssl)); ctx->ssl = (mbedtls_ssl_context *) MG_CALLOC(1, sizeof(*ctx->ssl));
mbedtls_ssl_init(ctx->ssl); mbedtls_ssl_init(ctx->ssl);
if (mbedtls_ssl_setup(ctx->ssl, ctx->conf) != 0) { if (mbedtls_ssl_setup(ctx->ssl, ctx->conf) != 0) {
MG_SET_PTRPTR(err_msg, "Failed to create SSL session"); MG_SET_PTRPTR(err_msg, "Failed to create SSL session");
@ -4497,7 +4595,7 @@ enum mg_ssl_if_result mg_ssl_if_handshake(struct mg_connection *nc) {
int mg_ssl_if_read(struct mg_connection *nc, void *buf, size_t buf_size) { int mg_ssl_if_read(struct mg_connection *nc, void *buf, size_t buf_size) {
struct mg_ssl_if_ctx *ctx = (struct mg_ssl_if_ctx *) nc->ssl_if_data; struct mg_ssl_if_ctx *ctx = (struct mg_ssl_if_ctx *) nc->ssl_if_data;
int n = mbedtls_ssl_read(ctx->ssl, buf, buf_size); int n = mbedtls_ssl_read(ctx->ssl, (unsigned char *) buf, buf_size);
DBG(("%p %d -> %d", nc, (int) buf_size, n)); DBG(("%p %d -> %d", nc, (int) buf_size, n));
if (n < 0) return mg_ssl_if_mbed_err(nc, n); if (n < 0) return mg_ssl_if_mbed_err(nc, n);
if (n == 0) nc->flags |= MG_F_CLOSE_IMMEDIATELY; if (n == 0) nc->flags |= MG_F_CLOSE_IMMEDIATELY;
@ -4506,7 +4604,7 @@ int mg_ssl_if_read(struct mg_connection *nc, void *buf, size_t buf_size) {
int mg_ssl_if_write(struct mg_connection *nc, const void *data, size_t len) { int mg_ssl_if_write(struct mg_connection *nc, const void *data, size_t len) {
struct mg_ssl_if_ctx *ctx = (struct mg_ssl_if_ctx *) nc->ssl_if_data; struct mg_ssl_if_ctx *ctx = (struct mg_ssl_if_ctx *) nc->ssl_if_data;
int n = mbedtls_ssl_write(ctx->ssl, data, len); int n = mbedtls_ssl_write(ctx->ssl, (const unsigned char *) data, len);
DBG(("%p %d -> %d", nc, (int) len, n)); DBG(("%p %d -> %d", nc, (int) len, n));
if (n < 0) return mg_ssl_if_mbed_err(nc, n); if (n < 0) return mg_ssl_if_mbed_err(nc, n);
return n; return n;
@ -4535,7 +4633,7 @@ static enum mg_ssl_if_result mg_use_ca_cert(struct mg_ssl_if_ctx *ctx,
if (ca_cert == NULL || strcmp(ca_cert, "*") == 0) { if (ca_cert == NULL || strcmp(ca_cert, "*") == 0) {
return MG_SSL_OK; return MG_SSL_OK;
} }
ctx->ca_cert = MG_CALLOC(1, sizeof(*ctx->ca_cert)); ctx->ca_cert = (mbedtls_x509_crt *) MG_CALLOC(1, sizeof(*ctx->ca_cert));
mbedtls_x509_crt_init(ctx->ca_cert); mbedtls_x509_crt_init(ctx->ca_cert);
if (mbedtls_x509_crt_parse_file(ctx->ca_cert, ca_cert) != 0) { if (mbedtls_x509_crt_parse_file(ctx->ca_cert, ca_cert) != 0) {
return MG_SSL_ERROR; return MG_SSL_ERROR;
@ -4552,9 +4650,9 @@ static enum mg_ssl_if_result mg_use_cert(struct mg_ssl_if_ctx *ctx,
if (cert == NULL || cert[0] == '\0' || key == NULL || key[0] == '\0') { if (cert == NULL || cert[0] == '\0' || key == NULL || key[0] == '\0') {
return MG_SSL_OK; return MG_SSL_OK;
} }
ctx->cert = MG_CALLOC(1, sizeof(*ctx->cert)); ctx->cert = (mbedtls_x509_crt *) MG_CALLOC(1, sizeof(*ctx->cert));
mbedtls_x509_crt_init(ctx->cert); mbedtls_x509_crt_init(ctx->cert);
ctx->key = MG_CALLOC(1, sizeof(*ctx->key)); ctx->key = (mbedtls_pk_context *) MG_CALLOC(1, sizeof(*ctx->key));
mbedtls_pk_init(ctx->key); mbedtls_pk_init(ctx->key);
if (mbedtls_x509_crt_parse_file(ctx->cert, cert) != 0) { if (mbedtls_x509_crt_parse_file(ctx->cert, cert) != 0) {
MG_SET_PTRPTR(err_msg, "Invalid SSL cert"); MG_SET_PTRPTR(err_msg, "Invalid SSL cert");
@ -4596,8 +4694,8 @@ static enum mg_ssl_if_result mg_set_cipher_list(struct mg_ssl_if_ctx *ctx,
const char *ciphers) { const char *ciphers) {
if (ciphers != NULL) { if (ciphers != NULL) {
int l, id; int l, id;
const char *s = ciphers; const char *s = ciphers, *e;
char *e, tmp[50]; char tmp[50];
while (s != NULL) { while (s != NULL) {
e = strchr(s, ':'); e = strchr(s, ':');
l = (e != NULL ? (e - s) : (int) strlen(s)); l = (e != NULL ? (e - s) : (int) strlen(s));
@ -4613,6 +4711,7 @@ static enum mg_ssl_if_result mg_set_cipher_list(struct mg_ssl_if_ctx *ctx,
if (ctx->cipher_suites.len == 0) return MG_SSL_ERROR; if (ctx->cipher_suites.len == 0) return MG_SSL_ERROR;
id = 0; id = 0;
mbuf_append(&ctx->cipher_suites, &id, sizeof(id)); mbuf_append(&ctx->cipher_suites, &id, sizeof(id));
mbuf_trim(&ctx->cipher_suites);
mbedtls_ssl_conf_ciphersuites(ctx->conf, mbedtls_ssl_conf_ciphersuites(ctx->conf,
(const int *) ctx->cipher_suites.buf); (const int *) ctx->cipher_suites.buf);
} else { } else {
@ -4621,6 +4720,43 @@ static enum mg_ssl_if_result mg_set_cipher_list(struct mg_ssl_if_ctx *ctx,
return MG_SSL_OK; return MG_SSL_OK;
} }
static enum mg_ssl_if_result mg_ssl_if_mbed_set_psk(struct mg_ssl_if_ctx *ctx,
const char *identity,
const char *key_str) {
unsigned char key[32];
size_t key_len;
if (identity == NULL && key_str == NULL) return MG_SSL_OK;
if (identity == NULL || key_str == NULL) return MG_SSL_ERROR;
key_len = strlen(key_str);
if (key_len != 32 && key_len != 64) return MG_SSL_ERROR;
size_t i = 0;
memset(key, 0, sizeof(key));
key_len = 0;
for (i = 0; key_str[i] != '\0'; i++) {
unsigned char c;
char hc = tolower((int) key_str[i]);
if (hc >= '0' && hc <= '9') {
c = hc - '0';
} else if (hc >= 'a' && hc <= 'f') {
c = hc - 'a' + 0xa;
} else {
return MG_SSL_ERROR;
}
key_len = i / 2;
key[key_len] <<= 4;
key[key_len] |= c;
}
key_len++;
DBG(("identity = '%s', key = (%u)", identity, (unsigned int) key_len));
/* mbedTLS makes copies of psk and identity. */
if (mbedtls_ssl_conf_psk(ctx->conf, (const unsigned char *) key, key_len,
(const unsigned char *) identity,
strlen(identity)) != 0) {
return MG_SSL_ERROR;
}
return MG_SSL_OK;
}
const char *mg_set_ssl(struct mg_connection *nc, const char *cert, const char *mg_set_ssl(struct mg_connection *nc, const char *cert,
const char *ca_cert) { const char *ca_cert) {
const char *err_msg = NULL; const char *err_msg = NULL;

View File

@ -3127,6 +3127,8 @@ struct mg_ssl_if_conn_params {
const char *ca_cert; const char *ca_cert;
const char *server_name; const char *server_name;
const char *cipher_suites; const char *cipher_suites;
const char *psk_identity;
const char *psk_key;
}; };
enum mg_ssl_if_result mg_ssl_if_conn_init( enum mg_ssl_if_result mg_ssl_if_conn_init(
@ -3544,6 +3546,15 @@ struct mg_connect_opts {
* name verification. * name verification.
*/ */
const char *ssl_server_name; const char *ssl_server_name;
/*
* PSK identity and key. Identity is a NUL-terminated string and key is a hex
* string. Key must be either 16 or 32 bytes (32 or 64 hex digits) for AES-128
* or AES-256 respectively.
* Note: Default list of cipher suites does not include PSK suites, if you
* want to use PSK you will need to set ssl_cipher_suites as well.
*/
const char *ssl_psk_identity;
const char *ssl_psk_key;
#endif #endif
}; };