mirror of
https://github.com/fatedier/frp.git
synced 2024-11-24 11:09:21 +08:00
newhttp: support websocket
This commit is contained in:
parent
3f64d73ea9
commit
cf9193a429
@ -79,6 +79,11 @@ func NewHttpReverseProxy() *HttpReverseProxy {
|
||||
return rp.CreateConnection(host, url)
|
||||
},
|
||||
},
|
||||
WebSocketDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
url := ctx.Value("url").(string)
|
||||
host := getHostFromAddr(ctx.Value("host").(string))
|
||||
return rp.CreateConnection(host, url)
|
||||
},
|
||||
BufferPool: newWrapPool(),
|
||||
ErrorLog: log.New(newWrapLogger(), "", 0),
|
||||
}
|
||||
|
@ -16,6 +16,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
frpIo "github.com/fatedier/frp/utils/io"
|
||||
)
|
||||
|
||||
// onExitFlushLoop is a callback set by tests to detect the state of the
|
||||
@ -59,6 +61,8 @@ type ReverseProxy struct {
|
||||
// modifies the Response from the backend.
|
||||
// If it returns an error, the proxy returns a StatusBadGateway error.
|
||||
ModifyResponse func(*http.Response) error
|
||||
|
||||
WebSocketDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// A BufferPool is an interface for getting and returning temporary
|
||||
@ -139,6 +143,48 @@ var hopHeaders = []string{
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
if IsWebsocketRequest(req) {
|
||||
p.serveWebSocket(rw, req)
|
||||
} else {
|
||||
p.serveHTTP(rw, req)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) serveWebSocket(rw http.ResponseWriter, req *http.Request) {
|
||||
if p.WebSocketDialContext == nil {
|
||||
rw.WriteHeader(500)
|
||||
return
|
||||
}
|
||||
|
||||
req = req.WithContext(context.WithValue(req.Context(), "url", req.URL.Path))
|
||||
req = req.WithContext(context.WithValue(req.Context(), "host", req.Host))
|
||||
|
||||
targetConn, err := p.WebSocketDialContext(req.Context(), "tcp", "")
|
||||
if err != nil {
|
||||
rw.WriteHeader(501)
|
||||
return
|
||||
}
|
||||
defer targetConn.Close()
|
||||
|
||||
p.Director(req)
|
||||
|
||||
hijacker, ok := rw.(http.Hijacker)
|
||||
if !ok {
|
||||
rw.WriteHeader(500)
|
||||
return
|
||||
}
|
||||
conn, _, errHijack := hijacker.Hijack()
|
||||
if errHijack != nil {
|
||||
rw.WriteHeader(500)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
req.Write(targetConn)
|
||||
frpIo.Join(conn, targetConn)
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
transport := p.Transport
|
||||
if transport == nil {
|
||||
transport = http.DefaultTransport
|
||||
@ -368,3 +414,16 @@ func (m *maxLatencyWriter) flushLoop() {
|
||||
}
|
||||
|
||||
func (m *maxLatencyWriter) stop() { m.done <- true }
|
||||
|
||||
func IsWebsocketRequest(req *http.Request) bool {
|
||||
containsHeader := func(name, value string) bool {
|
||||
items := strings.Split(req.Header.Get(name), ",")
|
||||
for _, item := range items {
|
||||
if value == strings.ToLower(strings.TrimSpace(item)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
return containsHeader("Connection", "upgrade") && containsHeader("Upgrade", "websocket")
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user