/*
 * Copyright (c) 2014 Cesanta Software Limited
 * All rights reserved
 */

/*
 * HTTP load balancer / reverse proxy built on Mongoose. Client requests are
 * routed to backends selected by Host header and URI prefix (see -v / -b),
 * optionally reusing pooled keep-alive backend connections (-k).
 */

#include "../../mongoose.h"

#include <signal.h>
#include <sys/queue.h>

/* Maximum number of idle keep-alive connections pooled per backend. */
#define MAX_IDLE_CONNS 5
/* Seconds after which idle pooled backend and idle client connections close. */
#define CONN_IDLE_TIMEOUT 30
/* Capacity of each backend table filled from the command line. */
#define MAX_BACKENDS 100

struct http_backend;

/* A connection to a backend; linked into http_backend.conns while idle. */
struct be_conn {
  struct http_backend *be;
  struct mg_connection *nc;
  /* > 0 while this entry is linked into be->conns, 0 otherwise. */
  time_t idle_deadline;
  STAILQ_ENTRY(be_conn) conns;
};

STAILQ_HEAD(be_conn_list_head, be_conn);

struct http_backend {
  const char *vhost;      /* NULL if any host */
  const char *uri_prefix; /* URI prefix, e.g. "/api/v1/", "/static/" */
  /* Replaces uri_prefix in forwarded requests (defaults to uri_prefix). */
  const char *uri_prefix_replacement;
  const char *host_port;          /* Backend address */
  int redirect;                   /* if true redirect instead of proxy */
  int usage_counter;              /* Number of times this backend was chosen */
  struct be_conn_list_head conns; /* Idle keep-alive connections */
  int num_conns;                  /* Number of entries in conns */
};

struct peer {
  struct mg_connection *nc;
  int64_t body_len;  /* Size of the HTTP body to forward */
  int64_t body_sent; /* Number of bytes already forwarded */
  struct {
    /* Headers have been sent, no more headers. */
    unsigned int headers_sent : 1;
    unsigned int keep_alive : 1;
  } flags;
};

struct conn_data {
  struct be_conn *be_conn; /* Chosen backend */
  struct peer client;      /* Client peer */
  struct peer backend;     /* Backend peer */
  time_t last_activity;
};

static const char *s_error_500 = "HTTP/1.1 500 Failed\r\n";
static const char *s_content_len_0 = "Content-Length: 0\r\n";
static const char *s_connection_close = "Connection: close\r\n";
static const char *s_http_port = "8000";
static struct http_backend s_vhost_backends[MAX_BACKENDS],
    s_default_backends[MAX_BACKENDS];
static int s_num_vhost_backends = 0, s_num_default_backends = 0;
/* Written from signal_handler, so it must be volatile sig_atomic_t. */
static volatile sig_atomic_t s_sig_num = 0;
static int s_backend_keepalive = 0;
static FILE *s_log_file = NULL;
#if MG_ENABLE_SSL
const char *s_ssl_cert = NULL;
#endif

static void ev_handler(struct mg_connection *nc, int ev, void *ev_data);
static void write_log(const char *fmt, ...);

/* Records the received signal so the main loop in main() can exit. */
static void signal_handler(int sig_num) {
  signal(sig_num, signal_handler);
  s_sig_num = sig_num;
}

/* Writes err_line followed by empty-body, connection-close headers. */
static void send_http_err(struct mg_connection *nc, const char *err_line) {
  mg_printf(nc, "%s%s%s\r\n", err_line, s_content_len_0, s_connection_close);
}

/*
 * Sends err_line to the client unless headers were already sent to it, then
 * marks the client connection to close once its output is flushed.
 * Does nothing if the client is already gone.
 */
static void respond_with_error(struct conn_data *conn, const char *err_line) {
  struct mg_connection *nc = conn->client.nc;
  int headers_sent = conn->client.flags.headers_sent;
#ifdef DEBUG
  write_log("conn=%p nc=%p respond_with_error %d\n", (void *) conn,
            (void *) nc, headers_sent);
#endif
  if (nc == NULL) return;
  if (!headers_sent) {
    send_http_err(nc, err_line);
    conn->client.flags.headers_sent = 1;
  }
  nc->flags |= MG_F_SEND_AND_CLOSE;
}

/* Returns non-zero if uri starts with prefix (byte-wise comparison). */
static int has_prefix(const struct mg_str *uri, const char *prefix) {
  size_t prefix_len = strlen(prefix);
  return uri->len >= prefix_len && memcmp(uri->p, prefix, prefix_len) == 0;
}

/* Returns non-zero if vhost is NULL (wildcard) or equals host exactly. */
static int matches_vhost(const struct mg_str *host, const char *vhost) {
  size_t vhost_len;
  if (vhost == NULL) {
    return 1;
  }
  vhost_len = strlen(vhost);
  return host->len == vhost_len && memcmp(host->p, vhost, vhost_len) == 0;
}

/* printf-style logging to s_log_file (-l); a no-op when no log is open. */
static void write_log(const char *fmt, ...) {
  va_list ap;
  if (s_log_file != NULL) {
    va_start(ap, fmt);
    vfprintf(s_log_file, fmt, ap);
    fflush(s_log_file);
    va_end(ap);
  }
}

/*
 * Picks the best backend in backends[0..num_backends) for hm: the URI prefix
 * and vhost (Host header without port) must match; the longest prefix wins,
 * ties go to the least used backend. Returns NULL if nothing matches.
 */
static struct http_backend *choose_backend_from_list(
    struct http_message *hm, struct http_backend *backends, int num_backends) {
  int i;
  struct mg_str vhost = {"", 0};
  const struct mg_str *host = mg_get_http_header(hm, "host");
  const char *vhost_end;
  struct http_backend *chosen = NULL;

  if (host != NULL) vhost = *host;

  /* Strip an optional ":port" suffix from the Host header. */
  vhost_end = vhost.p;
  while (vhost_end < vhost.p + vhost.len && *vhost_end != ':') {
    vhost_end++;
  }
  vhost.len = vhost_end - vhost.p;

  for (i = 0; i < num_backends; i++) {
    struct http_backend *be = &backends[i];
    if (has_prefix(&hm->uri, be->uri_prefix) &&
        matches_vhost(&vhost, be->vhost) &&
        (chosen == NULL ||
         /* Prefer most specific URI prefixes */
         strlen(be->uri_prefix) > strlen(chosen->uri_prefix) ||
         /* Among prefixes of the same length chose the least used. */
         (strlen(be->uri_prefix) == strlen(chosen->uri_prefix) &&
          be->usage_counter < chosen->usage_counter))) {
      chosen = be;
    }
  }

  return chosen;
}

/*
 * Chooses a backend for hm, trying vhost-specific backends first and then
 * the vhost-less defaults. Bumps the chosen backend's usage counter.
 */
static struct http_backend *choose_backend(struct http_message *hm) {
  struct http_backend *chosen =
      choose_backend_from_list(hm, s_vhost_backends, s_num_vhost_backends);

  /* Nothing was chosen for this vhost, look for vhost == NULL backends. */
  if (chosen == NULL) {
    chosen = choose_backend_from_list(hm, s_default_backends,
                                      s_num_default_backends);
  }

  if (chosen != NULL) chosen->usage_counter++;

  return chosen;
}

/*
 * Moves up to the remaining (body_len - body_sent) bytes currently buffered
 * on src's connection to dst's connection.
 */
static void forward_body(struct peer *src, struct peer *dst) {
  struct mbuf *src_io = &src->nc->recv_mbuf;
  if (src->body_sent < src->body_len) {
    size_t to_send = src->body_len - src->body_sent;
    if (src_io->len < to_send) {
      to_send = src_io->len;
    }
    mg_send(dst->nc, src_io->buf, to_send);
    src->body_sent += to_send;
    mbuf_remove(src_io, to_send);
  }
#ifdef DEBUG
  write_log("forward_body %p (ka=%d) -> %p sent %lld of %lld\n",
            (void *) src->nc, src->flags.keep_alive, (void *) dst->nc,
            (long long) src->body_sent, (long long) src->body_len);
#endif
}

/*
 * Forwards one HTTP message from src_peer to dst_peer.
 * - Requests have their URI prefix rewritten to uri_prefix_replacement.
 * - The Connection header is always regenerated.
 * - "Transfer-Encoding: chunked" is dropped.
 * After the headers, whatever body bytes are buffered are forwarded.
 */
static void forward(struct conn_data *conn, struct http_message *hm,
                    struct peer *src_peer, struct peer *dst_peer) {
  struct mg_connection *src = src_peer->nc;
  struct mg_connection *dst = dst_peer->nc;
  struct mbuf *io = &src->recv_mbuf;
  int i;
  int is_request = (src_peer == &conn->client);
  struct http_backend *be = conn->be_conn->be;
  const char *connection_mode = "close";

  /* Every forwarded message starts a fresh body count. */
  src_peer->body_len = hm->body.len;
  src_peer->body_sent = 0;

  if (is_request) {
    /* Write rewritten request line. */
    size_t trim_len = strlen(be->uri_prefix);
    mg_printf(dst, "%.*s%s%.*s\r\n", (int) (hm->uri.p - io->buf), io->buf,
              be->uri_prefix_replacement,
              (int) (hm->proto.p + hm->proto.len - (hm->uri.p + trim_len)),
              hm->uri.p + trim_len);
  } else {
    /* Reply line goes without modification */
    mg_printf(dst, "%.*s %d %.*s\r\n", (int) hm->proto.len, hm->proto.p,
              (int) hm->resp_code, (int) hm->resp_status_msg.len,
              hm->resp_status_msg.p);
  }

  /* Headers. */
  for (i = 0; i < MG_MAX_HTTP_HEADERS && hm->header_names[i].len > 0; i++) {
    struct mg_str hn = hm->header_names[i];
    struct mg_str hv = hm->header_values[i];

#if MG_ENABLE_SSL
    /*
     * If we terminate SSL and backend redirects to local HTTP port,
     * strip protocol to let client use HTTPS.
     * TODO(lsm): web page content may also contain local HTTP references,
     * they need to be rewritten too.
     */
    if (mg_vcasecmp(&hn, "Location") == 0 && s_ssl_cert != NULL) {
      size_t hlen = strlen(be->host_port);
      const char *hp = be->host_port, *p = memchr(hp, ':', hlen);
      size_t host_len;

      if (p == NULL) {
        p = hp + hlen;
      }
      host_len = (size_t) (p - hp);

      /* Length check first: never compare past the end of the value. */
      if (hv.len >= 7 + host_len && mg_ncasecmp(hv.p, "http://", 7) == 0 &&
          mg_ncasecmp(hv.p + 7, hp, host_len) == 0) {
        mg_printf(dst, "Location: %.*s\r\n", (int) (hv.len - (7 + host_len)),
                  hv.p + 7 + host_len);
        continue;
      }
    }
#endif

    /* We always rewrite the connection header depending on the settings. */
    if (mg_vcasecmp(&hn, "Connection") == 0) continue;

    /* Don't pass chunked transfer encoding to the client */
    if (mg_vcasecmp(&hn, "Transfer-encoding") == 0 &&
        mg_vcasecmp(&hv, "chunked") == 0) {
      continue;
    }

    mg_printf(dst, "%.*s: %.*s\r\n", (int) hn.len, hn.p, (int) hv.len, hv.p);
  }

  /* Emit the connection header. */
  if (dst_peer == &conn->backend) {
    if (s_backend_keepalive) connection_mode = "keep-alive";
  } else {
    if (conn->client.flags.keep_alive) connection_mode = "keep-alive";
  }
  mg_printf(dst, "Connection: %s\r\n", connection_mode);

  mg_printf(dst, "%s", "\r\n");

  mbuf_remove(io, hm->body.p - hm->message.p); /* We've forwarded headers */
  dst_peer->flags.headers_sent = 1;

  forward_body(src_peer, dst_peer);
}

/*
 * Takes an idle pooled connection to be out of its pool, or returns NULL
 * if there is none. Pooled connections already flagged for closing are
 * dropped from the pool; their idle_backend_handler frees them on close.
 */
struct be_conn *get_conn(struct http_backend *be) {
  while (!STAILQ_EMPTY(&be->conns)) {
    struct be_conn *result = STAILQ_FIRST(&be->conns);
    STAILQ_REMOVE_HEAD(&be->conns, conns);
    be->num_conns--;
    /* No longer linked: idle_backend_handler must not unlink it again. */
    result->idle_deadline = 0;
    if (!(result->nc->flags & MG_F_CLOSE_IMMEDIATELY)) {
      return result;
    }
  }
  return NULL;
}

/*
 * connect_backend routes a parsed client request to a backend chosen by
 * choose_backend().
 * - For redirect backends it writes a 302 to the client and leaves
 *   conn->backend.nc NULL.
 * - Otherwise it attaches an idle pooled connection or opens a new one.
 * Returns 0 if no backend matched or a connection could not be created,
 * 1 otherwise. It assumes clients don't pipeline: only one request per client
 * connection is in flight. forward() tells the backend "Connection: close"
 * unless backend keep-alive (-k) is enabled.
 */
static int connect_backend(struct conn_data *conn, struct http_message *hm) {
  struct mg_connection *nc = conn->client.nc;
  struct http_backend *be = choose_backend(hm);
  struct be_conn *bec;

  write_log("%ld %.*s %.*s backend=%s\n", (long) time(NULL),
            (int) hm->method.len, hm->method.p, (int) hm->uri.len, hm->uri.p,
            be ? be->host_port : "not defined");

  if (be == NULL) return 0;
  if (be->redirect != 0) {
    mg_printf(nc, "HTTP/1.1 302 Found\r\nLocation: %s\r\n\r\n", be->host_port);
    return 1;
  }

  bec = get_conn(be);
  if (bec != NULL) {
    bec->nc->handler = ev_handler;
#ifdef DEBUG
    write_log("conn=%p to %p (%s) reusing bec=%p\n", (void *) conn,
              (void *) be, be->host_port, (void *) bec);
#endif
  } else {
    bec = calloc(1, sizeof(*bec));
    if (bec == NULL) {
      write_log("Connection to [%s] failed\n", be->host_port);
      return 0;
    }
    bec->nc = mg_connect(nc->mgr, be->host_port, ev_handler);
#ifdef DEBUG
    write_log("conn=%p new conn to %p (%s) bec=%p\n", (void *) conn,
              (void *) be, be->host_port, (void *) bec);
#endif
    if (bec->nc == NULL) {
      free(bec);
      write_log("Connection to [%s] failed\n", be->host_port);
      return 0;
    }
  }
  bec->be = be;
  conn->be_conn = bec;
  conn->backend.nc = bec->nc;
  conn->backend.nc->user_data = conn;
  mg_set_protocol_http_websocket(conn->backend.nc);
  return 1;
}

/*
 * Returns non-zero if the message allows connection reuse. That requires:
 * - an explicit "Connection: keep-alive", or HTTP/1.1 with no Connection
 *   header; and
 * - a Content-Length header.
 */
static int is_keep_alive(struct http_message *hm) {
  const struct mg_str *connection_header = mg_get_http_header(hm, "Connection");
  if (connection_header == NULL) {
    /* HTTP/1.1 connections are keep-alive by default. */
    if (mg_vcasecmp(&hm->proto, "HTTP/1.1") != 0) return 0;
  } else if (mg_vcasecmp(connection_header, "keep-alive") != 0) {
    return 0;
  }
  /* We must also have Content-Length. */
  return mg_get_http_header(hm, "Content-Length") != NULL;
}

/*
 * Event handler for backend connections not attached to a client.
 * - On MG_EV_POLL, closes the connection once its idle deadline has passed.
 * - On MG_EV_CLOSE, unlinks it from the pool (if pooled) and frees it.
 */
static void idle_backend_handler(struct mg_connection *nc, int ev,
                                 void *ev_data) {
  struct be_conn *bec = nc->user_data;
  const time_t now = time(NULL);
  (void) ev_data; /* Unused. */
#ifdef DEBUG
  write_log("%ld idle bec=%p nc=%p ev=%d deadline=%ld\n", (long) now,
            (void *) bec, (void *) nc, ev, (long) bec->idle_deadline);
#endif
  switch (ev) {
    case MG_EV_POLL: {
      if (bec->idle_deadline > 0 && now > bec->idle_deadline) {
#ifdef DEBUG
        write_log("bec=%p nc=%p closing due to idleness\n", (void *) bec,
                  (void *) bec->nc);
#endif
        bec->nc->flags |= MG_F_CLOSE_IMMEDIATELY;
      }
      break;
    }
    case MG_EV_CLOSE: {
#ifdef DEBUG
      write_log("bec=%p closed\n", (void *) bec);
#endif
      if (bec->idle_deadline > 0) {
        STAILQ_REMOVE(&bec->be->conns, bec, be_conn, conns);
        /* Keep the pool size in sync, or eviction in release_backend drifts. */
        bec->be->num_conns--;
      }
      free(bec);
      break;
    }
    default:
      break;
  }
}

/*
 * Detaches conn's backend connection.
 * - If the backend allowed keep-alive, it is pooled (oldest entries beyond
 *   MAX_IDLE_CONNS are evicted); otherwise it is closed.
 * - Either way conn->backend is reset to its initial "no reply yet" state.
 */
void release_backend(struct conn_data *conn) {
  /* Disassociate the backend, put back on the pool. */
  struct be_conn *bec = conn->be_conn;
  conn->be_conn = NULL;
  if (bec->nc == NULL) {
    free(bec);
  } else {
    struct http_backend *be = bec->be;
    bec->nc->user_data = bec;
    bec->nc->handler = idle_backend_handler;
    if (conn->backend.flags.keep_alive) {
      bec->idle_deadline = time(NULL) + CONN_IDLE_TIMEOUT;
      STAILQ_INSERT_TAIL(&be->conns, bec, conns);
#ifdef DEBUG
      write_log("bec=%p becoming idle\n", (void *) bec);
#endif
      be->num_conns++;
      while (be->num_conns > MAX_IDLE_CONNS) {
        struct be_conn *victim = STAILQ_FIRST(&be->conns);
        STAILQ_REMOVE_HEAD(&be->conns, conns);
        be->num_conns--;
        victim->idle_deadline = 0;
        /* OR in the flag: overwriting would clear the other flags. */
        victim->nc->flags |= MG_F_CLOSE_IMMEDIATELY;
#ifdef DEBUG
        write_log("bec=%p evicted\n", (void *) victim);
#endif
      }
    } else {
      bec->idle_deadline = 0;
      bec->nc->flags |= MG_F_CLOSE_IMMEDIATELY;
    }
  }
  memset(&conn->backend, 0, sizeof(conn->backend));
  /* -1 means "no reply forwarded yet", matching the state set on accept. */
  conn->backend.body_len = -1;
}

/*
 * Main event handler for client connections and for backend connections
 * while they are attached to a client. Both kinds carry the shared
 * conn_data in nc->user_data.
 */
static void ev_handler(struct mg_connection *nc, int ev, void *ev_data) {
  struct conn_data *conn = (struct conn_data *) nc->user_data;
  const time_t now = time(NULL);
#ifdef DEBUG
  write_log("%ld conn=%p nc=%p ev=%d ev_data=%p bec=%p bec_nc=%p\n",
            (long) now, (void *) conn, (void *) nc, ev, ev_data,
            conn != NULL ? (void *) conn->be_conn : NULL,
            conn != NULL && conn->be_conn != NULL ? (void *) conn->be_conn->nc
                                                  : NULL);
#endif

  if (conn == NULL) {
    if (ev == MG_EV_ACCEPT) {
      conn = calloc(1, sizeof(*conn));
      if (conn == NULL) {
        send_http_err(nc, s_error_500);
        nc->flags |= MG_F_SEND_AND_CLOSE;
      } else {
        nc->user_data = conn;
        conn->client.nc = nc;
        conn->client.body_len = -1;
        conn->backend.body_len = -1;
        conn->last_activity = now;
      }
    } else if (ev != MG_EV_POLL) {
      /* Not attached to any client (e.g. a failed backend connect). */
      nc->flags |= MG_F_CLOSE_IMMEDIATELY;
    }
    return;
  }

  if (ev != MG_EV_POLL) conn->last_activity = now;

  switch (ev) {
    case MG_EV_HTTP_REQUEST: { /* From client */
      struct http_message *hm = (struct http_message *) ev_data;
      assert(conn->be_conn == NULL);
      /* A new request on a kept-alive client gets fresh reply headers. */
      conn->client.flags.headers_sent = 0;
      conn->client.flags.keep_alive = is_keep_alive(hm);
      if (!connect_backend(conn, hm)) {
        respond_with_error(conn, s_error_500);
        break;
      }
      if (conn->backend.nc == NULL) {
        /* This is a redirect, we're done. */
        conn->client.nc->flags |= MG_F_SEND_AND_CLOSE;
        break;
      }
      forward(conn, hm, &conn->client, &conn->backend);
      break;
    }

    case MG_EV_CONNECT: { /* To backend */
      int status = *(int *) ev_data;
      assert(conn->be_conn != NULL);
      if (status != 0) {
        write_log("Error connecting to %s: %d (%s)\n",
                  conn->be_conn->be->host_port, status, strerror(status));
        /* TODO(lsm): mark backend as defunct, try it later on */
        respond_with_error(conn, s_error_500);
        /*
         * Detach the failed connection so its MG_EV_CLOSE, which may arrive
         * after conn has been freed, is not routed to conn.
         */
        nc->user_data = NULL;
        nc->flags |= MG_F_CLOSE_IMMEDIATELY;
        conn->be_conn->nc = NULL;
        release_backend(conn);
        if (conn->client.nc == NULL) {
          /* Client already closed; no further event will free conn. */
          free(conn);
        }
      }
      break;
    }

    case MG_EV_HTTP_REPLY: { /* From backend */
      struct http_message *hm = (struct http_message *) ev_data;
      if (conn->client.nc == NULL) {
        /* Client is gone: drop the reply; MG_EV_CLOSE on nc cleans up. */
        nc->flags |= MG_F_CLOSE_IMMEDIATELY;
        break;
      }
      conn->backend.flags.keep_alive =
          s_backend_keepalive && is_keep_alive(hm);
      forward(conn, hm, &conn->backend, &conn->client);
      release_backend(conn);
      if (!conn->client.flags.keep_alive) {
        conn->client.nc->flags |= MG_F_SEND_AND_CLOSE;
      } else {
#ifdef DEBUG
        write_log("conn=%p remains open\n", (void *) conn);
#endif
      }
      break;
    }

    case MG_EV_POLL: {
      if (now - conn->last_activity > CONN_IDLE_TIMEOUT &&
          conn->backend.nc == NULL /* not waiting for backend */ &&
          conn->client.nc != NULL) {
#ifdef DEBUG
        write_log("conn=%p has been idle for too long\n", (void *) conn);
#endif
        /* Must be outside #ifdef DEBUG so the timeout works in all builds. */
        conn->client.nc->flags |= MG_F_SEND_AND_CLOSE;
      }
      break;
    }

    case MG_EV_CLOSE: {
      if (nc == conn->client.nc) {
#ifdef DEBUG
        write_log("conn=%p nc=%p client closed, body_sent=%lld\n",
                  (void *) conn, (void *) nc,
                  (long long) conn->backend.body_sent);
#endif
        conn->client.nc = NULL;
        if (conn->backend.nc != NULL) {
          conn->backend.nc->flags |= MG_F_CLOSE_IMMEDIATELY;
        }
      } else if (nc == conn->backend.nc) {
#ifdef DEBUG
        write_log("conn=%p nc=%p backend closed\n", (void *) conn,
                  (void *) nc);
#endif
        conn->backend.nc = NULL;
        if (conn->client.nc != NULL &&
            (conn->backend.body_len < 0 ||
             conn->backend.body_sent < conn->backend.body_len)) {
          write_log("Backend %s disconnected.\n",
                    conn->be_conn->be->host_port);
          respond_with_error(conn, s_error_500);
        }
        /* The connection is gone and cannot be pooled: free its entry. */
        free(conn->be_conn);
        conn->be_conn = NULL;
        memset(&conn->backend, 0, sizeof(conn->backend));
        conn->backend.body_len = -1;
      }
      if (conn->client.nc == NULL && conn->backend.nc == NULL) {
        free(conn);
      }
      break;
    }

    default:
      break;
  }
}

/* Prints command-line usage to stderr and exits with failure. */
static void print_usage_and_exit(const char *prog_name) {
  fprintf(stderr,
          "Usage: %s [-D debug_dump_file] [-p http_port] [-l log] [-k] "
#if MG_ENABLE_SSL
          "[-s ssl_cert] "
#endif
          "<[-r] [-v vhost] -b uri_prefix[=replacement] host_port> ... \n",
          prog_name);
  exit(EXIT_FAILURE);
}

int main(int argc, char *argv[]) {
  struct mg_mgr mgr;
  struct mg_connection *nc;
  int i, redirect = 0;
  const char *vhost = NULL;

  mg_mgr_init(&mgr, NULL);

  /* Parse command line arguments; options taking a value check argc first. */
  for (i = 1; i < argc; i++) {
    if (strcmp(argv[i], "-D") == 0 && i + 1 < argc) {
      mgr.hexdump_file = argv[i + 1];
      i++;
    } else if (strcmp(argv[i], "-k") == 0) {
      s_backend_keepalive = 1;
    } else if (strcmp(argv[i], "-l") == 0 && i + 1 < argc) {
      if (strcmp(argv[i + 1], "-") == 0) {
        s_log_file = stdout;
      } else {
        s_log_file = fopen(argv[i + 1], "a");
        if (s_log_file == NULL) {
          perror("fopen");
          exit(EXIT_FAILURE);
        }
      }
      i++;
    } else if (strcmp(argv[i], "-p") == 0 && i + 1 < argc) {
      s_http_port = argv[i + 1];
      i++;
    } else if (strcmp(argv[i], "-r") == 0 && i + 1 < argc) {
      redirect = 1;
    } else if (strcmp(argv[i], "-v") == 0 && i + 1 < argc) {
      if (strcmp(argv[i + 1], "") == 0) {
        vhost = NULL;
      } else {
        vhost = argv[i + 1];
      }
      i++;
    } else if (strcmp(argv[i], "-b") == 0 && i + 2 < argc) {
      int *count =
          vhost != NULL ? &s_num_vhost_backends : &s_num_default_backends;
      struct http_backend *table =
          vhost != NULL ? s_vhost_backends : s_default_backends;
      struct http_backend *be;
      char *r = NULL;

      if (*count >= MAX_BACKENDS) {
        fprintf(stderr, "Too many backends (max %d)\n", MAX_BACKENDS);
        exit(EXIT_FAILURE);
      }
      be = &table[(*count)++];
      STAILQ_INIT(&be->conns);
      be->vhost = vhost;
      be->uri_prefix = argv[i + 1];
      be->host_port = argv[i + 2];
      be->redirect = redirect;
      be->uri_prefix_replacement = be->uri_prefix;
      /* "prefix=replacement": split argv in place at '='. */
      if ((r = strchr(argv[i + 1], '=')) != NULL) {
        *r = '\0';
        be->uri_prefix_replacement = r + 1;
      }
      printf(
          "Adding backend for %s%s : %s "
          "[redirect=%d,prefix_replacement=%s]\n",
          be->vhost == NULL ? "" : be->vhost, be->uri_prefix, be->host_port,
          be->redirect, be->uri_prefix_replacement);

      /* -v and -r apply only to the following -b. */
      vhost = NULL;
      redirect = 0;
      i += 2;
#if MG_ENABLE_SSL
    } else if (strcmp(argv[i], "-s") == 0 && i + 1 < argc) {
      s_ssl_cert = argv[++i];
#endif
    } else {
      print_usage_and_exit(argv[0]);
    }
  }

  /* Open listening socket */
  if ((nc = mg_bind(&mgr, s_http_port, ev_handler)) == NULL) {
    fprintf(stderr, "mg_bind(%s) failed\n", s_http_port);
    exit(EXIT_FAILURE);
  }

#if MG_ENABLE_SSL
  if (s_ssl_cert != NULL) {
    const char *err_str = mg_set_ssl(nc, s_ssl_cert, NULL);
    if (err_str != NULL) {
      fprintf(stderr, "Error loading SSL cert: %s\n", err_str);
      exit(1);
    }
  }
#endif
  mg_set_protocol_http_websocket(nc);

  if (s_num_vhost_backends + s_num_default_backends == 0) {
    print_usage_and_exit(argv[0]);
  }

  signal(SIGINT, signal_handler);
  signal(SIGTERM, signal_handler);

  /* Run event loop until signal is received */
  printf("Starting LB on port %s\n", s_http_port);
  while (s_sig_num == 0) {
    mg_mgr_poll(&mgr, 1000);
  }

  /* Cleanup */
  mg_mgr_free(&mgr);

  printf("Exiting on signal %d\n", (int) s_sig_num);

  return EXIT_SUCCESS;
}