mirror of
https://github.com/fatedier/frp.git
synced 2024-11-24 02:59:19 +08:00
458 lines
12 KiB
Go
458 lines
12 KiB
Go
// Copyright 2017 fatedier, fatedier@gmail.com
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package visitor
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
libio "github.com/fatedier/golib/io"
|
|
fmux "github.com/hashicorp/yamux"
|
|
quic "github.com/quic-go/quic-go"
|
|
"golang.org/x/time/rate"
|
|
|
|
v1 "github.com/fatedier/frp/pkg/config/v1"
|
|
"github.com/fatedier/frp/pkg/msg"
|
|
"github.com/fatedier/frp/pkg/nathole"
|
|
"github.com/fatedier/frp/pkg/transport"
|
|
netpkg "github.com/fatedier/frp/pkg/util/net"
|
|
"github.com/fatedier/frp/pkg/util/util"
|
|
"github.com/fatedier/frp/pkg/util/xlog"
|
|
)
|
|
|
|
var ErrNoTunnelSession = errors.New("no tunnel session")
|
|
|
|
type XTCPVisitor struct {
|
|
*BaseVisitor
|
|
session TunnelSession
|
|
startTunnelCh chan struct{}
|
|
retryLimiter *rate.Limiter
|
|
cancel context.CancelFunc
|
|
|
|
cfg *v1.XTCPVisitorConfig
|
|
}
|
|
|
|
func (sv *XTCPVisitor) Run() (err error) {
|
|
sv.ctx, sv.cancel = context.WithCancel(sv.ctx)
|
|
|
|
if sv.cfg.Protocol == "kcp" {
|
|
sv.session = NewKCPTunnelSession()
|
|
} else {
|
|
sv.session = NewQUICTunnelSession(sv.clientCfg)
|
|
}
|
|
|
|
if sv.cfg.BindPort > 0 {
|
|
sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
|
|
if err != nil {
|
|
return
|
|
}
|
|
go sv.worker()
|
|
}
|
|
|
|
go sv.internalConnWorker()
|
|
go sv.processTunnelStartEvents()
|
|
if sv.cfg.KeepTunnelOpen {
|
|
sv.retryLimiter = rate.NewLimiter(rate.Every(time.Hour/time.Duration(sv.cfg.MaxRetriesAnHour)), sv.cfg.MaxRetriesAnHour)
|
|
go sv.keepTunnelOpenWorker()
|
|
}
|
|
return
|
|
}
|
|
|
|
func (sv *XTCPVisitor) Close() {
|
|
sv.mu.Lock()
|
|
defer sv.mu.Unlock()
|
|
sv.BaseVisitor.Close()
|
|
if sv.cancel != nil {
|
|
sv.cancel()
|
|
}
|
|
if sv.session != nil {
|
|
sv.session.Close()
|
|
}
|
|
}
|
|
|
|
func (sv *XTCPVisitor) worker() {
|
|
xl := xlog.FromContextSafe(sv.ctx)
|
|
for {
|
|
conn, err := sv.l.Accept()
|
|
if err != nil {
|
|
xl.Warnf("xtcp local listener closed")
|
|
return
|
|
}
|
|
go sv.handleConn(conn)
|
|
}
|
|
}
|
|
|
|
func (sv *XTCPVisitor) internalConnWorker() {
|
|
xl := xlog.FromContextSafe(sv.ctx)
|
|
for {
|
|
conn, err := sv.internalLn.Accept()
|
|
if err != nil {
|
|
xl.Warnf("xtcp internal listener closed")
|
|
return
|
|
}
|
|
go sv.handleConn(conn)
|
|
}
|
|
}
|
|
|
|
func (sv *XTCPVisitor) processTunnelStartEvents() {
|
|
for {
|
|
select {
|
|
case <-sv.ctx.Done():
|
|
return
|
|
case <-sv.startTunnelCh:
|
|
start := time.Now()
|
|
sv.makeNatHole()
|
|
duration := time.Since(start)
|
|
// avoid too frequently
|
|
if duration < 10*time.Second {
|
|
time.Sleep(10*time.Second - duration)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (sv *XTCPVisitor) keepTunnelOpenWorker() {
|
|
xl := xlog.FromContextSafe(sv.ctx)
|
|
ticker := time.NewTicker(time.Duration(sv.cfg.MinRetryInterval) * time.Second)
|
|
defer ticker.Stop()
|
|
|
|
sv.startTunnelCh <- struct{}{}
|
|
for {
|
|
select {
|
|
case <-sv.ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
xl.Debugf("keepTunnelOpenWorker try to check tunnel...")
|
|
conn, err := sv.getTunnelConn()
|
|
if err != nil {
|
|
xl.Warnf("keepTunnelOpenWorker get tunnel connection error: %v", err)
|
|
_ = sv.retryLimiter.Wait(sv.ctx)
|
|
continue
|
|
}
|
|
xl.Debugf("keepTunnelOpenWorker check success")
|
|
if conn != nil {
|
|
conn.Close()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
|
|
xl := xlog.FromContextSafe(sv.ctx)
|
|
isConnTrasfered := false
|
|
defer func() {
|
|
if !isConnTrasfered {
|
|
userConn.Close()
|
|
}
|
|
}()
|
|
|
|
xl.Debugf("get a new xtcp user connection")
|
|
|
|
// Open a tunnel connection to the server. If there is already a successful hole-punching connection,
|
|
// it will be reused. Otherwise, it will block and wait for a successful hole-punching connection until timeout.
|
|
ctx := context.Background()
|
|
if sv.cfg.FallbackTo != "" {
|
|
timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(sv.cfg.FallbackTimeoutMs)*time.Millisecond)
|
|
defer cancel()
|
|
ctx = timeoutCtx
|
|
}
|
|
tunnelConn, err := sv.openTunnel(ctx)
|
|
if err != nil {
|
|
xl.Errorf("open tunnel error: %v", err)
|
|
// no fallback, just return
|
|
if sv.cfg.FallbackTo == "" {
|
|
return
|
|
}
|
|
|
|
xl.Debugf("try to transfer connection to visitor: %s", sv.cfg.FallbackTo)
|
|
if err := sv.helper.TransferConn(sv.cfg.FallbackTo, userConn); err != nil {
|
|
xl.Errorf("transfer connection to visitor %s error: %v", sv.cfg.FallbackTo, err)
|
|
return
|
|
}
|
|
isConnTrasfered = true
|
|
return
|
|
}
|
|
|
|
var muxConnRWCloser io.ReadWriteCloser = tunnelConn
|
|
if sv.cfg.Transport.UseEncryption {
|
|
muxConnRWCloser, err = libio.WithEncryption(muxConnRWCloser, []byte(sv.cfg.SecretKey))
|
|
if err != nil {
|
|
xl.Errorf("create encryption stream error: %v", err)
|
|
return
|
|
}
|
|
}
|
|
if sv.cfg.Transport.UseCompression {
|
|
var recycleFn func()
|
|
muxConnRWCloser, recycleFn = libio.WithCompressionFromPool(muxConnRWCloser)
|
|
defer recycleFn()
|
|
}
|
|
|
|
_, _, errs := libio.Join(userConn, muxConnRWCloser)
|
|
xl.Debugf("join connections closed")
|
|
if len(errs) > 0 {
|
|
xl.Tracef("join connections errors: %v", errs)
|
|
}
|
|
}
|
|
|
|
// openTunnel will open a tunnel connection to the target server.
|
|
func (sv *XTCPVisitor) openTunnel(ctx context.Context) (conn net.Conn, err error) {
|
|
xl := xlog.FromContextSafe(sv.ctx)
|
|
ticker := time.NewTicker(500 * time.Millisecond)
|
|
defer ticker.Stop()
|
|
|
|
timeoutC := time.After(20 * time.Second)
|
|
immediateTrigger := make(chan struct{}, 1)
|
|
defer close(immediateTrigger)
|
|
immediateTrigger <- struct{}{}
|
|
|
|
for {
|
|
select {
|
|
case <-sv.ctx.Done():
|
|
return nil, sv.ctx.Err()
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case <-immediateTrigger:
|
|
conn, err = sv.getTunnelConn()
|
|
case <-ticker.C:
|
|
conn, err = sv.getTunnelConn()
|
|
case <-timeoutC:
|
|
return nil, fmt.Errorf("open tunnel timeout")
|
|
}
|
|
|
|
if err != nil {
|
|
if err != ErrNoTunnelSession {
|
|
xl.Warnf("get tunnel connection error: %v", err)
|
|
}
|
|
continue
|
|
}
|
|
return conn, nil
|
|
}
|
|
}
|
|
|
|
func (sv *XTCPVisitor) getTunnelConn() (net.Conn, error) {
|
|
conn, err := sv.session.OpenConn(sv.ctx)
|
|
if err == nil {
|
|
return conn, nil
|
|
}
|
|
sv.session.Close()
|
|
|
|
select {
|
|
case sv.startTunnelCh <- struct{}{}:
|
|
default:
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// 0. PreCheck
|
|
// 1. Prepare
|
|
// 2. ExchangeInfo
|
|
// 3. MakeNATHole
|
|
// 4. Create a tunnel session using an underlying UDP connection.
|
|
func (sv *XTCPVisitor) makeNatHole() {
|
|
xl := xlog.FromContextSafe(sv.ctx)
|
|
xl.Tracef("makeNatHole start")
|
|
if err := nathole.PreCheck(sv.ctx, sv.helper.MsgTransporter(), sv.cfg.ServerName, 5*time.Second); err != nil {
|
|
xl.Warnf("nathole precheck error: %v", err)
|
|
return
|
|
}
|
|
|
|
xl.Tracef("nathole prepare start")
|
|
prepareResult, err := nathole.Prepare([]string{sv.clientCfg.NatHoleSTUNServer})
|
|
if err != nil {
|
|
xl.Warnf("nathole prepare error: %v", err)
|
|
return
|
|
}
|
|
xl.Infof("nathole prepare success, nat type: %s, behavior: %s, addresses: %v, assistedAddresses: %v",
|
|
prepareResult.NatType, prepareResult.Behavior, prepareResult.Addrs, prepareResult.AssistedAddrs)
|
|
|
|
listenConn := prepareResult.ListenConn
|
|
|
|
// send NatHoleVisitor to server
|
|
now := time.Now().Unix()
|
|
transactionID := nathole.NewTransactionID()
|
|
natHoleVisitorMsg := &msg.NatHoleVisitor{
|
|
TransactionID: transactionID,
|
|
ProxyName: sv.cfg.ServerName,
|
|
Protocol: sv.cfg.Protocol,
|
|
SignKey: util.GetAuthKey(sv.cfg.SecretKey, now),
|
|
Timestamp: now,
|
|
MappedAddrs: prepareResult.Addrs,
|
|
AssistedAddrs: prepareResult.AssistedAddrs,
|
|
}
|
|
|
|
xl.Tracef("nathole exchange info start")
|
|
natHoleRespMsg, err := nathole.ExchangeInfo(sv.ctx, sv.helper.MsgTransporter(), transactionID, natHoleVisitorMsg, 5*time.Second)
|
|
if err != nil {
|
|
listenConn.Close()
|
|
xl.Warnf("nathole exchange info error: %v", err)
|
|
return
|
|
}
|
|
|
|
xl.Infof("get natHoleRespMsg, sid [%s], protocol [%s], candidate address %v, assisted address %v, detectBehavior: %+v",
|
|
natHoleRespMsg.Sid, natHoleRespMsg.Protocol, natHoleRespMsg.CandidateAddrs,
|
|
natHoleRespMsg.AssistedAddrs, natHoleRespMsg.DetectBehavior)
|
|
|
|
newListenConn, raddr, err := nathole.MakeHole(sv.ctx, listenConn, natHoleRespMsg, []byte(sv.cfg.SecretKey))
|
|
if err != nil {
|
|
listenConn.Close()
|
|
xl.Warnf("make hole error: %v", err)
|
|
return
|
|
}
|
|
listenConn = newListenConn
|
|
xl.Infof("establishing nat hole connection successful, sid [%s], remoteAddr [%s]", natHoleRespMsg.Sid, raddr)
|
|
|
|
if err := sv.session.Init(listenConn, raddr); err != nil {
|
|
listenConn.Close()
|
|
xl.Warnf("init tunnel session error: %v", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
type TunnelSession interface {
|
|
Init(listenConn *net.UDPConn, raddr *net.UDPAddr) error
|
|
OpenConn(context.Context) (net.Conn, error)
|
|
Close()
|
|
}
|
|
|
|
type KCPTunnelSession struct {
|
|
session *fmux.Session
|
|
lConn *net.UDPConn
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
func NewKCPTunnelSession() TunnelSession {
|
|
return &KCPTunnelSession{}
|
|
}
|
|
|
|
func (ks *KCPTunnelSession) Init(listenConn *net.UDPConn, raddr *net.UDPAddr) error {
|
|
listenConn.Close()
|
|
laddr, _ := net.ResolveUDPAddr("udp", listenConn.LocalAddr().String())
|
|
lConn, err := net.DialUDP("udp", laddr, raddr)
|
|
if err != nil {
|
|
return fmt.Errorf("dial udp error: %v", err)
|
|
}
|
|
remote, err := netpkg.NewKCPConnFromUDP(lConn, true, raddr.String())
|
|
if err != nil {
|
|
return fmt.Errorf("create kcp connection from udp connection error: %v", err)
|
|
}
|
|
|
|
fmuxCfg := fmux.DefaultConfig()
|
|
fmuxCfg.KeepAliveInterval = 10 * time.Second
|
|
fmuxCfg.MaxStreamWindowSize = 6 * 1024 * 1024
|
|
fmuxCfg.LogOutput = io.Discard
|
|
session, err := fmux.Client(remote, fmuxCfg)
|
|
if err != nil {
|
|
remote.Close()
|
|
return fmt.Errorf("initial client session error: %v", err)
|
|
}
|
|
ks.mu.Lock()
|
|
ks.session = session
|
|
ks.lConn = lConn
|
|
ks.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (ks *KCPTunnelSession) OpenConn(_ context.Context) (net.Conn, error) {
|
|
ks.mu.RLock()
|
|
defer ks.mu.RUnlock()
|
|
session := ks.session
|
|
if session == nil {
|
|
return nil, ErrNoTunnelSession
|
|
}
|
|
return session.Open()
|
|
}
|
|
|
|
func (ks *KCPTunnelSession) Close() {
|
|
ks.mu.Lock()
|
|
defer ks.mu.Unlock()
|
|
if ks.session != nil {
|
|
_ = ks.session.Close()
|
|
ks.session = nil
|
|
}
|
|
if ks.lConn != nil {
|
|
_ = ks.lConn.Close()
|
|
ks.lConn = nil
|
|
}
|
|
}
|
|
|
|
type QUICTunnelSession struct {
|
|
session quic.Connection
|
|
listenConn *net.UDPConn
|
|
mu sync.RWMutex
|
|
|
|
clientCfg *v1.ClientCommonConfig
|
|
}
|
|
|
|
func NewQUICTunnelSession(clientCfg *v1.ClientCommonConfig) TunnelSession {
|
|
return &QUICTunnelSession{
|
|
clientCfg: clientCfg,
|
|
}
|
|
}
|
|
|
|
func (qs *QUICTunnelSession) Init(listenConn *net.UDPConn, raddr *net.UDPAddr) error {
|
|
tlsConfig, err := transport.NewClientTLSConfig("", "", "", raddr.String())
|
|
if err != nil {
|
|
return fmt.Errorf("create tls config error: %v", err)
|
|
}
|
|
tlsConfig.NextProtos = []string{"frp"}
|
|
quicConn, err := quic.Dial(context.Background(), listenConn, raddr, tlsConfig,
|
|
&quic.Config{
|
|
MaxIdleTimeout: time.Duration(qs.clientCfg.Transport.QUIC.MaxIdleTimeout) * time.Second,
|
|
MaxIncomingStreams: int64(qs.clientCfg.Transport.QUIC.MaxIncomingStreams),
|
|
KeepAlivePeriod: time.Duration(qs.clientCfg.Transport.QUIC.KeepalivePeriod) * time.Second,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("dial quic error: %v", err)
|
|
}
|
|
qs.mu.Lock()
|
|
qs.session = quicConn
|
|
qs.listenConn = listenConn
|
|
qs.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
func (qs *QUICTunnelSession) OpenConn(ctx context.Context) (net.Conn, error) {
|
|
qs.mu.RLock()
|
|
defer qs.mu.RUnlock()
|
|
session := qs.session
|
|
if session == nil {
|
|
return nil, ErrNoTunnelSession
|
|
}
|
|
stream, err := session.OpenStreamSync(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return netpkg.QuicStreamToNetConn(stream, session), nil
|
|
}
|
|
|
|
func (qs *QUICTunnelSession) Close() {
|
|
qs.mu.Lock()
|
|
defer qs.mu.Unlock()
|
|
if qs.session != nil {
|
|
_ = qs.session.CloseWithError(0, "")
|
|
qs.session = nil
|
|
}
|
|
if qs.listenConn != nil {
|
|
_ = qs.listenConn.Close()
|
|
qs.listenConn = nil
|
|
}
|
|
}
|