frp/utils/net/websocket.go
2018-08-05 12:55:31 +08:00

128 lines
2.4 KiB
Go

package net
import (
"fmt"
"net"
"net/http"
"net/url"
"sync/atomic"
"time"
"github.com/fatedier/frp/utils/log"
"golang.org/x/net/websocket"
)
// ErrWebsocketListenerClosed is returned by WebsocketListener.Accept once the
// listener has been closed, either explicitly or because its HTTP server stopped.
var ErrWebsocketListenerClosed = fmt.Errorf("websocket listener closed")

// WebsocketListener serves websocket upgrades over HTTP on top of a
// net.Listener and hands every upgraded connection out through Accept.
type WebsocketListener struct {
	log.Logger
	server    *http.Server
	httpMutex *http.ServeMux
	connChan  chan *WebsocketConn // unbuffered: handlers hand conns straight to Accept
	closeFlag int32               // set to 1 atomically by the first Close
	closeCh   chan struct{}       // closed by the first Close to unblock Accept and handlers
}

// NewWebsocketListener starts an HTTP server on ln that upgrades every request
// to a websocket connection and queues it for Accept.
//
// If filter is non-nil it runs first for every request; when it returns false
// the request is not upgraded (filter is expected to have written the response).
//
// The server runs in a background goroutine. If it stops within the first
// millisecond, its error is returned and the listener is already closed. If it
// stops later, the listener is closed so Accept returns
// ErrWebsocketListenerClosed instead of blocking forever.
func NewWebsocketListener(ln net.Listener,
	filter func(w http.ResponseWriter, r *http.Request) bool) (l *WebsocketListener, err error) {
	l = &WebsocketListener{
		httpMutex: http.NewServeMux(),
		connChan:  make(chan *WebsocketConn),
		closeCh:   make(chan struct{}),
		Logger:    log.NewPrefixLogger(""),
	}

	l.httpMutex.Handle("/", websocket.Handler(func(c *websocket.Conn) {
		conn := NewWebScoketConn(c)
		select {
		case l.connChan <- conn:
		case <-l.closeCh:
			// Nobody will ever Accept this connection; drop it rather than
			// parking this handler goroutine forever.
			conn.Close()
			return
		}
		// x/net/websocket closes the hijacked connection once this handler
		// returns, so keep the handler alive until the accepted conn is closed.
		conn.waitClose()
	}))

	l.server = &http.Server{
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if filter != nil && !filter(w, r) {
				return
			}
			l.httpMutex.ServeHTTP(w, r)
		}),
	}

	// Serve's result travels over a buffered channel instead of being written to
	// the named result from another goroutine, which raced with this function's
	// return. The buffer lets the goroutine finish even if nobody reads it.
	serveErrCh := make(chan error, 1)
	go func() {
		serveErr := l.server.Serve(ln)
		// Serve has stopped, so no further connections can arrive.
		l.Close()
		serveErrCh <- serveErr
	}()

	// Give Serve a brief window to fail fast (e.g. ln was already closed).
	select {
	case err = <-serveErrCh:
		return nil, err
	case <-time.After(time.Millisecond):
	}
	return l, nil
}

// ListenWebsocket binds a TCP listener on bindAddr:bindPort and serves websocket
// connections on it without a request filter.
func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) {
	ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
	if err != nil {
		return nil, err
	}
	return NewWebsocketListener(ln, nil)
}

// Accept blocks until a websocket connection has been upgraded and returns it.
// Once the listener is closed it returns ErrWebsocketListenerClosed.
func (p *WebsocketListener) Accept() (Conn, error) {
	select {
	case c := <-p.connChan:
		return c, nil
	case <-p.closeCh:
		return nil, ErrWebsocketListenerClosed
	}
}

// Close stops the HTTP server and wakes every pending Accept. Connections that
// were already returned by Accept are hijacked and therefore left open. Only
// the first call has any effect; later calls return nil.
func (p *WebsocketListener) Close() error {
	if !atomic.CompareAndSwapInt32(&p.closeFlag, 0, 1) {
		return nil
	}
	close(p.closeCh)
	return p.server.Close()
}
// WebsocketConn wraps a websocket connection as a net.Conn with a prefixed
// logger attached. Close is idempotent, and it also releases any goroutine
// blocked in waitClose.
type WebsocketConn struct {
	net.Conn
	log.Logger
	closed int32         // set to 1 by the first Close (atomic.SwapInt32)
	wait   chan struct{} // closed by the first Close; waitClose blocks on it
}
// NewWebScoketConn wraps conn in a WebsocketConn that has an empty-prefix
// logger and an open wait channel.
func NewWebScoketConn(conn net.Conn) (c *WebsocketConn) {
	c = new(WebsocketConn)
	c.Conn = conn
	c.Logger = log.NewPrefixLogger("")
	c.wait = make(chan struct{})
	return c
}
// Close closes the underlying connection and wakes anyone blocked in
// waitClose. Only the first call does any work; repeated calls return nil, so
// the wait channel is never closed twice.
func (p *WebsocketConn) Close() error {
	if atomic.CompareAndSwapInt32(&p.closed, 0, 1) {
		close(p.wait)
		return p.Conn.Close()
	}
	return nil
}
// waitClose blocks until p.wait is closed, which Close does on its first call.
func (p *WebsocketConn) waitClose() {
	done := p.wait
	<-done
}
// ConnectWebsocketServer dials a websocket server and wraps the connection as
// a Conn.
//
// addr is the server address WITHOUT a scheme, e.g. "domain:port"; this
// function prepends "ws://" itself. The Origin header is "http://" plus the
// host parsed from that URL, and the underlying TCP dial times out after
// 10 seconds.
func ConnectWebsocketServer(addr string) (Conn, error) {
	addr = "ws://" + addr
	uri, err := url.Parse(addr)
	if err != nil {
		return nil, err
	}

	origin := "http://" + uri.Host
	cfg, err := websocket.NewConfig(addr, origin)
	if err != nil {
		return nil, err
	}
	cfg.Dialer = &net.Dialer{
		Timeout: time.Second * 10,
	}

	conn, err := websocket.DialConfig(cfg)
	if err != nil {
		return nil, err
	}
	return NewWebScoketConn(conn), nil
}