newhttp: support websocket

This commit is contained in:
fatedier 2018-01-23 01:29:52 +08:00
parent 3f64d73ea9
commit cf9193a429
2 changed files with 64 additions and 0 deletions

View File

@ -79,6 +79,11 @@ func NewHttpReverseProxy() *HttpReverseProxy {
return rp.CreateConnection(host, url) return rp.CreateConnection(host, url)
}, },
}, },
WebSocketDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
url := ctx.Value("url").(string)
host := getHostFromAddr(ctx.Value("host").(string))
return rp.CreateConnection(host, url)
},
BufferPool: newWrapPool(), BufferPool: newWrapPool(),
ErrorLog: log.New(newWrapLogger(), "", 0), ErrorLog: log.New(newWrapLogger(), "", 0),
} }

View File

@ -16,6 +16,8 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
frpIo "github.com/fatedier/frp/utils/io"
) )
// onExitFlushLoop is a callback set by tests to detect the state of the // onExitFlushLoop is a callback set by tests to detect the state of the
@ -59,6 +61,8 @@ type ReverseProxy struct {
// modifies the Response from the backend. // modifies the Response from the backend.
// If it returns an error, the proxy returns a StatusBadGateway error. // If it returns an error, the proxy returns a StatusBadGateway error.
ModifyResponse func(*http.Response) error ModifyResponse func(*http.Response) error
WebSocketDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
} }
// A BufferPool is an interface for getting and returning temporary // A BufferPool is an interface for getting and returning temporary
@ -139,6 +143,48 @@ var hopHeaders = []string{
} }
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if IsWebsocketRequest(req) {
p.serveWebSocket(rw, req)
} else {
p.serveHTTP(rw, req)
}
}
func (p *ReverseProxy) serveWebSocket(rw http.ResponseWriter, req *http.Request) {
if p.WebSocketDialContext == nil {
rw.WriteHeader(500)
return
}
req = req.WithContext(context.WithValue(req.Context(), "url", req.URL.Path))
req = req.WithContext(context.WithValue(req.Context(), "host", req.Host))
targetConn, err := p.WebSocketDialContext(req.Context(), "tcp", "")
if err != nil {
rw.WriteHeader(501)
return
}
defer targetConn.Close()
p.Director(req)
hijacker, ok := rw.(http.Hijacker)
if !ok {
rw.WriteHeader(500)
return
}
conn, _, errHijack := hijacker.Hijack()
if errHijack != nil {
rw.WriteHeader(500)
return
}
defer conn.Close()
req.Write(targetConn)
frpIo.Join(conn, targetConn)
}
func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport transport := p.Transport
if transport == nil { if transport == nil {
transport = http.DefaultTransport transport = http.DefaultTransport
@ -368,3 +414,16 @@ func (m *maxLatencyWriter) flushLoop() {
} }
func (m *maxLatencyWriter) stop() { m.done <- true } func (m *maxLatencyWriter) stop() { m.done <- true }
func IsWebsocketRequest(req *http.Request) bool {
containsHeader := func(name, value string) bool {
items := strings.Split(req.Header.Get(name), ",")
for _, item := range items {
if value == strings.ToLower(strings.TrimSpace(item)) {
return true
}
}
return false
}
return containsHeader("Connection", "upgrade") && containsHeader("Upgrade", "websocket")
}