diff --git a/docs/README.md b/docs/README.md index cfdbb463..c5644bc0 100644 --- a/docs/README.md +++ b/docs/README.md @@ -806,10 +806,13 @@ Create client Websocket connection. ### mg\_ws\_upgrade() ```c -void mg_ws_upgrade(struct mg_connection *, struct mg_http_message *); +void mg_ws_upgrade(struct mg_connection *, struct mg_http_message *, + const char *fmt, ...); ``` -Upgrade given HTTP connection to Websocket. +Upgrade given HTTP connection to Websocket. The `fmt` is a printf-like +format string for the extra HTTP headers returned to the client in a +Websocket handshake. Set `fmt` to `NULL` if no extra headers needs to be passed. ### mg\_ws\_send() diff --git a/examples/timers/main.c b/examples/timers/main.c index ae0104ce..f00bf200 100644 --- a/examples/timers/main.c +++ b/examples/timers/main.c @@ -18,7 +18,7 @@ static const char *s_listen_on = "http://localhost:8000"; static void fn(struct mg_connection *c, int ev, void *ev_data, void *fn_data) { if (ev == MG_EV_HTTP_MSG) { struct mg_http_message *hm = (struct mg_http_message *) ev_data; - mg_ws_upgrade(c, hm); + mg_ws_upgrade(c, hm, NULL); } else if (ev == MG_EV_WS_MSG) { // Got websocket frame. Received data is wm->data. Echo it back! struct mg_ws_message *wm = (struct mg_ws_message *) ev_data; @@ -42,9 +42,9 @@ static void timer_fn(void *arg) { } int main(void) { - struct mg_mgr mgr; // Event manager - struct mg_timer t1; // Timer - mg_mgr_init(&mgr); // Initialise event manager + struct mg_mgr mgr; // Event manager + struct mg_timer t1; // Timer + mg_mgr_init(&mgr); // Initialise event manager mg_timer_init(&t1, 300, MG_TIMER_REPEAT, timer_fn, &mgr); // Init timer mg_http_listen(&mgr, s_listen_on, fn, NULL); // Create HTTP listener for (;;) mg_mgr_poll(&mgr, 1000); // Infinite event loop diff --git a/examples/websocket-server/main.c b/examples/websocket-server/main.c index d9df7cc0..80c279d1 100644 --- a/examples/websocket-server/main.c +++ b/examples/websocket-server/main.c @@ -21,7 +21,7 @@ static void fn(struct mg_connection *c, int ev, void *ev_data, void *fn_data) { if (mg_http_match_uri(hm, "/websocket")) { // Upgrade to websocket. From now on, a connection is a full-duplex // Websocket connection, which will receive MG_EV_WS_MSG events. - mg_ws_upgrade(c, hm); + mg_ws_upgrade(c, hm, NULL); } else if (mg_http_match_uri(hm, "/rest")) { // Serve REST response mg_http_reply(c, 200, "", "{\"result\": %d}\n", 123); diff --git a/mongoose.c b/mongoose.c index 5fae3525..ec4cdf48 100644 --- a/mongoose.c +++ b/mongoose.c @@ -3967,25 +3967,27 @@ struct ws_msg { }; static void ws_handshake(struct mg_connection *c, const char *key, - size_t key_len) { + size_t key_len, const char *fmt, va_list ap) { const char *magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; unsigned char sha[20], b64_sha[30]; - // mem[256], *buf = mem; - // int len = 0; + char mem[128], *buf = mem; + mg_sha1_ctx sha_ctx; mg_sha1_init(&sha_ctx); mg_sha1_update(&sha_ctx, (unsigned char *) key, key_len); mg_sha1_update(&sha_ctx, (unsigned char *) magic, 36); mg_sha1_final(sha, &sha_ctx); mg_base64_encode(sha, sizeof(sha), (char *) b64_sha); + buf[0] = '\0'; + if (fmt != NULL) mg_vasprintf(&buf, sizeof(mem), fmt, ap); mg_printf(c, "HTTP/1.1 101 Switching Protocols\r\n" "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept: %s\r\n\r\n", - b64_sha); - // mg_send(c, buf, len); - // if (buf != mem) free(buf); + "Sec-WebSocket-Accept: %s\r\n" + "%s\r\n", + b64_sha, buf); + if (buf != mem) free(buf); } static size_t ws_process(uint8_t *buf, size_t len, struct ws_msg *msg) { @@ -4145,9 +4147,15 @@ struct mg_connection *mg_ws_connect(struct mg_mgr *mgr, const char *url, return c; } -void mg_ws_upgrade(struct mg_connection *c, struct mg_http_message *hm) { +void mg_ws_upgrade(struct mg_connection *c, struct mg_http_message *hm, + const char *fmt, ...) { struct mg_str *wskey = mg_http_get_header(hm, "Sec-WebSocket-Key"); c->pfn = mg_ws_cb; - if (wskey != NULL) ws_handshake(c, wskey->ptr, wskey->len); + if (wskey != NULL) { + va_list ap; + va_start(ap, fmt); + ws_handshake(c, wskey->ptr, wskey->len, fmt, ap); + va_end(ap); + } c->is_websocket = 1; } diff --git a/mongoose.h b/mongoose.h index 8da88991..7fa8f33a 100644 --- a/mongoose.h +++ b/mongoose.h @@ -740,7 +740,8 @@ struct mg_ws_message { struct mg_connection *mg_ws_connect(struct mg_mgr *, const char *url, mg_event_handler_t fn, void *fn_data, const char *fmt, ...); -void mg_ws_upgrade(struct mg_connection *, struct mg_http_message *); +void mg_ws_upgrade(struct mg_connection *, struct mg_http_message *, + const char *fmt, ...); size_t mg_ws_send(struct mg_connection *, const char *buf, size_t len, int op); diff --git a/src/ws.c b/src/ws.c index f2e2fd2a..4cb883ac 100644 --- a/src/ws.c +++ b/src/ws.c @@ -15,25 +15,27 @@ struct ws_msg { }; static void ws_handshake(struct mg_connection *c, const char *key, - size_t key_len) { + size_t key_len, const char *fmt, va_list ap) { const char *magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; unsigned char sha[20], b64_sha[30]; - // mem[256], *buf = mem; - // int len = 0; + char mem[128], *buf = mem; + mg_sha1_ctx sha_ctx; mg_sha1_init(&sha_ctx); mg_sha1_update(&sha_ctx, (unsigned char *) key, key_len); mg_sha1_update(&sha_ctx, (unsigned char *) magic, 36); mg_sha1_final(sha, &sha_ctx); mg_base64_encode(sha, sizeof(sha), (char *) b64_sha); + buf[0] = '\0'; + if (fmt != NULL) mg_vasprintf(&buf, sizeof(mem), fmt, ap); mg_printf(c, "HTTP/1.1 101 Switching Protocols\r\n" "Upgrade: websocket\r\n" "Connection: Upgrade\r\n" - "Sec-WebSocket-Accept: %s\r\n\r\n", - b64_sha); - // mg_send(c, buf, len); - // if (buf != mem) free(buf); + "Sec-WebSocket-Accept: %s\r\n" + "%s\r\n", + b64_sha, buf); + if (buf != mem) free(buf); } static size_t ws_process(uint8_t *buf, size_t len, struct ws_msg *msg) { @@ -193,9 +195,15 @@ struct mg_connection *mg_ws_connect(struct mg_mgr *mgr, const char *url, return c; } -void mg_ws_upgrade(struct mg_connection *c, struct mg_http_message *hm) { +void mg_ws_upgrade(struct mg_connection *c, struct mg_http_message *hm, + const char *fmt, ...) { struct mg_str *wskey = mg_http_get_header(hm, "Sec-WebSocket-Key"); c->pfn = mg_ws_cb; - if (wskey != NULL) ws_handshake(c, wskey->ptr, wskey->len); + if (wskey != NULL) { + va_list ap; + va_start(ap, fmt); + ws_handshake(c, wskey->ptr, wskey->len, fmt, ap); + va_end(ap); + } c->is_websocket = 1; } diff --git a/src/ws.h b/src/ws.h index 37729c96..6413c79c 100644 --- a/src/ws.h +++ b/src/ws.h @@ -20,5 +20,6 @@ struct mg_ws_message { struct mg_connection *mg_ws_connect(struct mg_mgr *, const char *url, mg_event_handler_t fn, void *fn_data, const char *fmt, ...); -void mg_ws_upgrade(struct mg_connection *, struct mg_http_message *); +void mg_ws_upgrade(struct mg_connection *, struct mg_http_message *, + const char *fmt, ...); size_t mg_ws_send(struct mg_connection *, const char *buf, size_t len, int op); diff --git a/test/unit_test.c b/test/unit_test.c index 4f2168a6..6d21fa1c 100644 --- a/test/unit_test.c +++ b/test/unit_test.c @@ -331,7 +331,7 @@ static void eh1(struct mg_connection *c, int ev, void *ev_data, void *fn_data) { if (mg_http_match_uri(hm, "/foo/*")) { mg_http_reply(c, 200, "", "uri: %.*s", hm->uri.len - 5, hm->uri.ptr + 5); } else if (mg_http_match_uri(hm, "/ws")) { - mg_ws_upgrade(c, hm); + mg_ws_upgrade(c, hm, NULL); } else if (mg_http_match_uri(hm, "/body")) { mg_http_reply(c, 200, "", "%.*s", (int) hm->body.len, hm->body.ptr); } else if (mg_http_match_uri(hm, "/bar")) {