diff --git a/mongoose.c b/mongoose.c index 39615f00..49716cf7 100644 --- a/mongoose.c +++ b/mongoose.c @@ -47,7 +47,7 @@ #define _WIN32_WINNT 0x0400 // To make it link in VS2005 #include -#ifndef PATH_MAX +#ifndef PATH_MAX #define PATH_MAX MAX_PATH #endif @@ -210,11 +210,13 @@ typedef int SOCKET; #if defined(DEBUG) #define DEBUG_TRACE(x) do { \ + flockfile(stdout); \ printf("*** [%lu] thread %p: %s: ", \ (unsigned long) time(NULL), (void *) pthread_self(), __func__); \ printf x; \ putchar('\n'); \ fflush(stdout); \ + funlockfile(stdout); \ } while (0) #else #define DEBUG_TRACE(x) @@ -360,7 +362,7 @@ struct vec { // Structure used by mg_stat() function. Uses 64 bit file length. struct mgstat { - int is_directory; // Directory marker + int is_directory; // Directory marker int64_t size; // File size time_t mtime; // Modification time }; @@ -432,6 +434,7 @@ struct mg_context { }; struct mg_connection { + struct mg_connection *peer; // Remote target in proxy mode struct mg_request_info request_info; struct mg_context *ctx; SSL *ssl; // SSL descriptor @@ -719,9 +722,12 @@ static int match_extension(const char *path, const char *ext_list) { #endif // !NO_CGI // HTTP 1.1 assumes keep alive if "Connection:" header is not set +// This function must tolerate situations when connection info is not +// set up, for example if request parsing failed. static int should_keep_alive(const struct mg_connection *conn) { + const char *http_version = conn->request_info.http_version; const char *header = mg_get_header(conn, "Connection"); - return (header == NULL && !strcmp(conn->request_info.http_version, "1.1")) || + return (header == NULL && http_version && !strcmp(http_version, "1.1")) || (header != NULL && !strcmp(header, "keep-alive")); } @@ -937,7 +943,7 @@ static size_t strftime(char *dst, size_t dst_size, const char *fmt, const struct tm *tm) { (void) snprintf(dst, dst_size, "implement strftime() for WinCE"); return 0; -} +} #endif static int mg_rename(const char* oldname, const char* newname) { @@ -1584,6 +1590,7 @@ static int get_request_len(const char *buf, int buflen) { const char *s, *e; int len = 0; + DEBUG_TRACE(("buf: %p, len: %d", buf, buflen)); for (s = buf, e = s + buflen - 1; len <= 0 && s < e; s++) // Control characters are not allowed but >=128 is. if (!isprint(* (unsigned char *) s) && *s != '\r' && @@ -2533,11 +2540,9 @@ static void parse_http_headers(char **buf, struct mg_request_info *ri) { } static int is_valid_http_method(const char *method) { - return !strcmp(method, "GET") || - !strcmp(method, "POST") || - !strcmp(method, "HEAD") || - !strcmp(method, "PUT") || - !strcmp(method, "DELETE"); + return !strcmp(method, "GET") || !strcmp(method, "POST") || + !strcmp(method, "HEAD") || !strcmp(method, "CONNECT") || + !strcmp(method, "PUT") || !strcmp(method, "DELETE"); } // Parse HTTP request, fill in mg_request_info structure. @@ -2554,7 +2559,6 @@ static int parse_http_request(char *buf, struct mg_request_info *ri) { ri->http_version = skip(&buf, "\r\n"); if (is_valid_http_method(ri->request_method) && - ri->uri[0] == '/' && strncmp(ri->http_version, "HTTP/", 5) == 0) { ri->http_version += 5; /* Skip "HTTP/" */ parse_http_headers(&buf, ri); @@ -2641,7 +2645,8 @@ static int is_not_modified(const struct mg_connection *conn, return ims != NULL && stp->mtime <= parse_date_string(ims); } -static int handle_request_body(struct mg_connection *conn, FILE *fp) { +static int forward_body_data(struct mg_connection *conn, FILE *fp, + SOCKET sock, SSL *ssl) { const char *expect, *buffered; char buf[BUFSIZ]; int to_read, nread, buffered_len, success = 0; @@ -2667,7 +2672,7 @@ static int handle_request_body(struct mg_connection *conn, FILE *fp) { if ((int64_t) buffered_len > conn->content_len) { buffered_len = (int) conn->content_len; } - push(fp, INVALID_SOCKET, NULL, buffered, (int64_t) buffered_len); + push(fp, sock, ssl, buffered, (int64_t) buffered_len); conn->consumed_content += buffered_len; } @@ -2677,7 +2682,7 @@ static int handle_request_body(struct mg_connection *conn, FILE *fp) { to_read = (int) (conn->content_len - conn->consumed_content); } nread = pull(NULL, conn->client.sock, conn->ssl, buf, to_read); - if (nread <= 0 || push(fp, INVALID_SOCKET, NULL, buf, nread) != nread) { + if (nread <= 0 || push(fp, sock, ssl, buf, nread) != nread) { break; } conn->consumed_content += nread; @@ -2890,7 +2895,7 @@ static void handle_cgi_request(struct mg_connection *conn, const char *prog) { // Send POST data to the CGI process if needed if (!strcmp(conn->request_info.request_method, "POST") && - !handle_request_body(conn, in)) { + !forward_body_data(conn, in, INVALID_SOCKET, NULL)) { goto done; } @@ -3010,7 +3015,7 @@ static void put_file(struct mg_connection *conn, const char *path) { // TODO(lsm): handle seek error (void) fseeko(fp, (off_t) r1, SEEK_SET); } - if (handle_request_body(conn, fp)) + if (forward_body_data(conn, fp, INVALID_SOCKET, NULL)) (void) mg_printf(conn, "HTTP/1.1 %d OK\r\n\r\n", conn->request_info.status_code); (void) fclose(fp); @@ -3242,7 +3247,7 @@ static void close_all_listening_sockets(struct mg_context *ctx) { } } -// Valid listening port specification is: [ip_address:]port[s[p]] +// Valid listening port specification is: [ip_address:]port[s|p] // Examples: 80, 443s, 127.0.0.1:3128p, 1.2.3.4:8080sp static int parse_port_string(const struct vec *vec, struct socket *so) { struct usa *usa = &so->lsa; @@ -3260,17 +3265,14 @@ static int parse_port_string(const struct vec *vec, struct socket *so) { } else { return 0; } - assert(len > 0 && len <= (int) vec->len); - so->is_ssl = vec->ptr[len] == 's'; - so->is_proxy = vec->ptr[len] == 'p' || - (vec->ptr[len] == 's' && vec->ptr[len + 1] == 'p'); - if (vec->ptr[len + so->is_ssl + so->is_proxy] != '\0' && - vec->ptr[len + so->is_ssl + so->is_proxy] != ',') { + if (strchr("sp,", vec->ptr[len]) == NULL) { return 0; } + so->is_ssl = vec->ptr[len] == 's'; + so->is_proxy = vec->ptr[len] == 'p'; usa->len = sizeof(usa->u.sin); usa->u.sin.sin_family = AF_INET; usa->u.sin.sin_port = htons((uint16_t) port); @@ -3288,7 +3290,7 @@ static int set_ports_option(struct mg_context *ctx) { while (success && (list = next_option(list, &vec, NULL)) != NULL) { if (!parse_port_string(&vec, &so)) { cry(fc(ctx), "%s: %.*s: invalid port spec. Expecting list of: %s", - __func__, vec.len, vec.ptr, "[IP_ADDRESS:]PORT[s[p]]"); + __func__, vec.len, vec.ptr, "[IP_ADDRESS:]PORT[s|p]"); success = 0; } else if (so.is_ssl && ctx->ssl_ctx == NULL) { cry(fc(ctx), "Cannot add SSL socket, is -ssl_cert option set?"); @@ -3586,11 +3588,16 @@ static int set_acl_option(struct mg_context *ctx) { } static void reset_per_request_attributes(struct mg_connection *conn) { - if (conn->request_info.remote_user != NULL) { - free((void *) conn->request_info.remote_user); - conn->request_info.remote_user = NULL; + struct mg_request_info *ri = &conn->request_info; + + // Reset request info attributes. DO NOT TOUCH is_ssl, remote_ip, remote_port + if (ri->remote_user != NULL) { + free((void *) ri->remote_user); } - conn->request_info.status_code = -1; + ri->remote_user = ri->request_method = ri->uri = ri->http_version = NULL; + ri->num_headers = 0; + ri->status_code = -1; + conn->num_bytes_sent = conn->consumed_content = 0; conn->content_len = -1; conn->request_len = conn->data_len = 0; @@ -3648,6 +3655,78 @@ static void discard_current_request_from_buffer(struct mg_connection *conn) { memmove(conn->buf, conn->buf + conn->request_len + body_len, conn->data_len); } +static int parse_url(const char *url, char *host, int *port) { + int len; + + if (url == NULL) { + return 0; + }; + + if (!strncmp(url, "http://", 7)) { + url += 7; + } + + if (sscanf(url, "%1024[^:]:%d/%n", host, port, &len) == 2) { + } else { + sscanf(url, "%1024[^/]/%n", host, &len); + *port = 80; + } + DEBUG_TRACE(("Host:%s, port:%d", host, *port)); + + return len > 0 && url[len - 1] == '/' ? len - 1 : len; +} + +static void handle_proxy_request(struct mg_connection *conn) { + struct mg_request_info *ri = &conn->request_info; + char host[1025], buf[BUFSIZ]; + int port, is_ssl, len, i, n; + + DEBUG_TRACE(("URL: %s", ri->uri)); + if (conn->request_info.uri[0] == '/' || + (len = parse_url(ri->uri, host, &port)) == 0) { + return; + } + + if (conn->peer == NULL) { + is_ssl = !strcmp(ri->request_method, "CONNECT"); + if ((conn->peer = mg_connect(conn, host, port, is_ssl)) == NULL) { + return; + } + conn->peer->client.is_ssl = is_ssl; + } + + // Forward client's request to the target + mg_printf(conn->peer, "%s %s HTTP/%s\r\n", ri->request_method, ri->uri + len, + ri->http_version); + + // And also all headers. TODO(lsm): anonymize! + for (i = 0; i < ri->num_headers; i++) { + mg_printf(conn->peer, "%s: %s\r\n", ri->http_headers[i].name, + ri->http_headers[i].value); + } + // End of headers, final newline + mg_write(conn->peer, "\r\n", 2); + + // Read and forward body data if any + if (!strcmp(ri->request_method, "POST")) { + forward_body_data(conn, NULL, conn->peer->client.sock, conn->peer->ssl); + } + + // Read data from the target and forward it to the client + while ((n = pull(NULL, conn->peer->client.sock, conn->peer->ssl, + buf, sizeof(buf))) > 0) { + if (mg_write(conn, buf, n) != n) { + break; + } + } + + if (!conn->peer->client.is_ssl) { + close_connection(conn->peer); + free(conn->peer); + conn->peer = NULL; + } +} + static void process_new_connection(struct mg_connection *conn) { struct mg_request_info *ri = &conn->request_info; int keep_alive_enabled; @@ -3664,13 +3743,17 @@ static void process_new_connection(struct mg_connection *conn) { conn->buf, sizeof(conn->buf), &conn->data_len); } assert(conn->data_len >= conn->request_len); - if (conn->request_len <= 0) { + if (conn->request_len == 0 && conn->data_len == sizeof(conn->buf)) { + send_http_error(conn, 413, "Request Too Large", ""); + return; + } if (conn->request_len <= 0) { return; // Remote end closed the connection } // Nul-terminate the request cause parse_http_request() uses sscanf conn->buf[conn->request_len - 1] = '\0'; - if (!parse_http_request(conn->buf, ri)) { + if (!parse_http_request(conn->buf, ri) || + (!conn->client.is_proxy && ri->uri[0] != '/')) { // Do not put garbage in the access log, just send it back to the client send_http_error(conn, 400, "Bad Request", "Cannot parse HTTP request: [%.*s]", conn->data_len, conn->buf); @@ -3684,11 +3767,16 @@ static void process_new_connection(struct mg_connection *conn) { cl = get_header(ri, "Content-Length"); conn->content_len = cl == NULL ? -1 : strtoll(cl, NULL, 10); conn->birth_time = time(NULL); - handle_request(conn); + if (conn->client.is_proxy) { + handle_proxy_request(conn); + } else { + handle_request(conn); + } log_access(conn); discard_current_request_from_buffer(conn); } - } while (keep_alive_enabled && should_keep_alive(conn)); + // conn->peer is not NULL only for SSL-ed proxy connections + } while (conn->peer || (keep_alive_enabled && should_keep_alive(conn))); } // Worker threads take accepted socket from the queue @@ -3796,6 +3884,7 @@ static void accept_new_connection(const struct socket *listener, // Put accepted socket structure into the queue DEBUG_TRACE(("accepted socket %d", accepted.sock)); accepted.is_ssl = listener->is_ssl; + accepted.is_proxy = listener->is_proxy; produce_socket(ctx, &accepted); } else { cry(fc(ctx), "%s: %s is not allowed to connect",