diff --git a/client/admin_api.go b/client/admin_api.go index c5548fb5..49c2abfd 100644 --- a/client/admin_api.go +++ b/client/admin_api.go @@ -33,36 +33,19 @@ type GeneralResponse struct { } // GET api/reload - func (svr *Service) apiReload(w http.ResponseWriter, r *http.Request) { res := GeneralResponse{Code: 200} - log.Info("Http request [/api/reload]") + log.Info("api request [/api/reload]") defer func() { - log.Info("Http response [/api/reload], code [%d]", res.Code) + log.Info("api response [/api/reload], code [%d]", res.Code) w.WriteHeader(res.Code) if len(res.Msg) > 0 { w.Write([]byte(res.Msg)) } }() - content, err := config.GetRenderedConfFromFile(svr.cfgFile) - if err != nil { - res.Code = 400 - res.Msg = err.Error() - log.Warn("reload frpc config file error: %s", res.Msg) - return - } - - newCommonCfg, err := config.UnmarshalClientConfFromIni(content) - if err != nil { - res.Code = 400 - res.Msg = err.Error() - log.Warn("reload frpc common section error: %s", res.Msg) - return - } - - pxyCfgs, visitorCfgs, err := config.LoadAllProxyConfsFromIni(svr.cfg.User, content, newCommonCfg.Start) + _, pxyCfgs, visitorCfgs, err := config.ParseClientConfig(svr.cfgFile) if err != nil { res.Code = 400 res.Msg = err.Error() @@ -70,8 +53,7 @@ func (svr *Service) apiReload(w http.ResponseWriter, r *http.Request) { return } - err = svr.ReloadConf(pxyCfgs, visitorCfgs) - if err != nil { + if err = svr.ReloadConf(pxyCfgs, visitorCfgs); err != nil { res.Code = 500 res.Msg = err.Error() log.Warn("reload frpc proxy config error: %s", res.Msg) diff --git a/cmd/frpc/sub/http.go b/cmd/frpc/sub/http.go index 03593f1a..2e19fce4 100644 --- a/cmd/frpc/sub/http.go +++ b/cmd/frpc/sub/http.go @@ -47,7 +47,7 @@ var httpCmd = &cobra.Command{ Use: "http", Short: "Run frpc with a single http proxy", RunE: func(cmd *cobra.Command, args []string) error { - clientCfg, err := parseClientCommonCfg(CfgFileTypeCmd, nil) + clientCfg, err := parseClientCommonCfgFromCmd() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/https.go b/cmd/frpc/sub/https.go index d636f426..8a14d39d 100644 --- a/cmd/frpc/sub/https.go +++ b/cmd/frpc/sub/https.go @@ -43,7 +43,7 @@ var httpsCmd = &cobra.Command{ Use: "https", Short: "Run frpc with a single https proxy", RunE: func(cmd *cobra.Command, args []string) error { - clientCfg, err := parseClientCommonCfg(CfgFileTypeCmd, nil) + clientCfg, err := parseClientCommonCfgFromCmd() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/reload.go b/cmd/frpc/sub/reload.go index 44b16770..35f160d3 100644 --- a/cmd/frpc/sub/reload.go +++ b/cmd/frpc/sub/reload.go @@ -35,19 +35,13 @@ var reloadCmd = &cobra.Command{ Use: "reload", Short: "Hot-Reload frpc configuration", RunE: func(cmd *cobra.Command, args []string) error { - iniContent, err := config.GetRenderedConfFromFile(cfgFile) + cfg, _, _, err := config.ParseClientConfig(cfgFile) if err != nil { fmt.Println(err) os.Exit(1) } - clientCfg, err := parseClientCommonCfg(CfgFileTypeIni, iniContent) - if err != nil { - fmt.Println(err) - os.Exit(1) - } - - err = reload(clientCfg) + err = reload(cfg) if err != nil { fmt.Printf("frpc reload error: %v\n", err) os.Exit(1) diff --git a/cmd/frpc/sub/root.go b/cmd/frpc/sub/root.go index a8012491..2c1b3e34 100644 --- a/cmd/frpc/sub/root.go +++ b/cmd/frpc/sub/root.go @@ -15,14 +15,11 @@ package sub import ( - "bytes" "context" "fmt" - "io/ioutil" "net" "os" "os/signal" - "path/filepath" "strconv" "strings" "syscall" @@ -132,25 +129,6 @@ func handleSignal(svr *client.Service) { close(kcpDoneCh) } -func parseClientCommonCfg(fileType int, source []byte) (cfg config.ClientCommonConf, err error) { - if fileType == CfgFileTypeIni { - cfg, err = config.UnmarshalClientConfFromIni(source) - } else if fileType == CfgFileTypeCmd { - cfg, err = parseClientCommonCfgFromCmd() - } - if err != nil { - return - } - - cfg.Complete() - err = cfg.Validate() - if err != nil { - err = fmt.Errorf("Parse config error: %v", err) - return - } - return -} - func parseClientCommonCfgFromCmd() (cfg config.ClientCommonConf, err error) { cfg = config.GetDefaultClientConf() @@ -179,89 +157,22 @@ func parseClientCommonCfgFromCmd() (cfg config.ClientCommonConf, err error) { cfg.Token = token cfg.TLSEnable = tlsEnable + cfg.Complete() + if err = cfg.Validate(); err != nil { + err = fmt.Errorf("Parse config error: %v", err) + return + } return } func runClient(cfgFilePath string) error { - cfg, pxyCfgs, visitorCfgs, err := parseConfig(cfgFilePath) + cfg, pxyCfgs, visitorCfgs, err := config.ParseClientConfig(cfgFilePath) if err != nil { return err } return startService(cfg, pxyCfgs, visitorCfgs, cfgFilePath) } -func parseConfig(cfgFilePath string) ( - cfg config.ClientCommonConf, - pxyCfgs map[string]config.ProxyConf, - visitorCfgs map[string]config.VisitorConf, - err error, -) { - var content []byte - content, err = config.GetRenderedConfFromFile(cfgFilePath) - if err != nil { - return - } - configBuffer := bytes.NewBuffer(nil) - configBuffer.Write(content) - - // Parse common section. - cfg, err = parseClientCommonCfg(CfgFileTypeIni, content) - if err != nil { - return - } - - // Aggregate proxy configs from include files. - var buf []byte - buf, err = getIncludeContents(cfg.IncludeConfigFiles) - if err != nil { - err = fmt.Errorf("getIncludeContents error: %v", err) - return - } - configBuffer.WriteString("\n") - configBuffer.Write(buf) - - // Parse all proxy and visitor configs. - pxyCfgs, visitorCfgs, err = config.LoadAllProxyConfsFromIni(cfg.User, configBuffer.Bytes(), cfg.Start) - if err != nil { - return - } - return -} - -// getIncludeContents renders all configs from paths. -// files format can be a single file path or directory or regex path. -func getIncludeContents(paths []string) ([]byte, error) { - out := bytes.NewBuffer(nil) - for _, path := range paths { - absDir, err := filepath.Abs(filepath.Dir(path)) - if err != nil { - return nil, err - } - if _, err := os.Stat(absDir); os.IsNotExist(err) { - return nil, err - } - files, err := ioutil.ReadDir(absDir) - if err != nil { - return nil, err - } - for _, fi := range files { - if fi.IsDir() { - continue - } - absFile := filepath.Join(absDir, fi.Name()) - if matched, _ := filepath.Match(filepath.Join(absDir, filepath.Base(path)), absFile); matched { - tmpContent, err := config.GetRenderedConfFromFile(absFile) - if err != nil { - return nil, fmt.Errorf("render extra config %s error: %v", absFile, err) - } - out.Write(tmpContent) - out.WriteString("\n") - } - } - } - return out.Bytes(), nil -} - func startService( cfg config.ClientCommonConf, pxyCfgs map[string]config.ProxyConf, diff --git a/cmd/frpc/sub/status.go b/cmd/frpc/sub/status.go index 774de53f..c71b4b9d 100644 --- a/cmd/frpc/sub/status.go +++ b/cmd/frpc/sub/status.go @@ -38,20 +38,13 @@ var statusCmd = &cobra.Command{ Use: "status", Short: "Overview of all proxies status", RunE: func(cmd *cobra.Command, args []string) error { - iniContent, err := config.GetRenderedConfFromFile(cfgFile) + cfg, _, _, err := config.ParseClientConfig(cfgFile) if err != nil { fmt.Println(err) os.Exit(1) } - clientCfg, err := parseClientCommonCfg(CfgFileTypeIni, iniContent) - if err != nil { - fmt.Println(err) - os.Exit(1) - } - - err = status(clientCfg) - if err != nil { + if err = status(cfg); err != nil { fmt.Printf("frpc get status error: %v\n", err) os.Exit(1) } diff --git a/cmd/frpc/sub/stcp.go b/cmd/frpc/sub/stcp.go index 673a268e..45f01e59 100644 --- a/cmd/frpc/sub/stcp.go +++ b/cmd/frpc/sub/stcp.go @@ -45,7 +45,7 @@ var stcpCmd = &cobra.Command{ Use: "stcp", Short: "Run frpc with a single stcp proxy", RunE: func(cmd *cobra.Command, args []string) error { - clientCfg, err := parseClientCommonCfg(CfgFileTypeCmd, nil) + clientCfg, err := parseClientCommonCfgFromCmd() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/sudp.go b/cmd/frpc/sub/sudp.go index 3c3d5a8b..45c5ad61 100644 --- a/cmd/frpc/sub/sudp.go +++ b/cmd/frpc/sub/sudp.go @@ -45,7 +45,7 @@ var sudpCmd = &cobra.Command{ Use: "sudp", Short: "Run frpc with a single sudp proxy", RunE: func(cmd *cobra.Command, args []string) error { - clientCfg, err := parseClientCommonCfg(CfgFileTypeCmd, nil) + clientCfg, err := parseClientCommonCfgFromCmd() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/tcp.go b/cmd/frpc/sub/tcp.go index b62cb74a..7e867345 100644 --- a/cmd/frpc/sub/tcp.go +++ b/cmd/frpc/sub/tcp.go @@ -41,7 +41,7 @@ var tcpCmd = &cobra.Command{ Use: "tcp", Short: "Run frpc with a single tcp proxy", RunE: func(cmd *cobra.Command, args []string) error { - clientCfg, err := parseClientCommonCfg(CfgFileTypeCmd, nil) + clientCfg, err := parseClientCommonCfgFromCmd() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/tcpmux.go b/cmd/frpc/sub/tcpmux.go index 6f46cf76..cef845d6 100644 --- a/cmd/frpc/sub/tcpmux.go +++ b/cmd/frpc/sub/tcpmux.go @@ -44,7 +44,7 @@ var tcpMuxCmd = &cobra.Command{ Use: "tcpmux", Short: "Run frpc with a single tcpmux proxy", RunE: func(cmd *cobra.Command, args []string) error { - clientCfg, err := parseClientCommonCfg(CfgFileTypeCmd, nil) + clientCfg, err := parseClientCommonCfgFromCmd() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/udp.go b/cmd/frpc/sub/udp.go index 7f6dd3f0..2ce4327e 100644 --- a/cmd/frpc/sub/udp.go +++ b/cmd/frpc/sub/udp.go @@ -41,7 +41,7 @@ var udpCmd = &cobra.Command{ Use: "udp", Short: "Run frpc with a single udp proxy", RunE: func(cmd *cobra.Command, args []string) error { - clientCfg, err := parseClientCommonCfg(CfgFileTypeCmd, nil) + clientCfg, err := parseClientCommonCfgFromCmd() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/verify.go b/cmd/frpc/sub/verify.go index 4e76d0f3..76872b90 100644 --- a/cmd/frpc/sub/verify.go +++ b/cmd/frpc/sub/verify.go @@ -18,6 +18,8 @@ import ( "fmt" "os" + "github.com/fatedier/frp/pkg/config" + "github.com/spf13/cobra" ) @@ -29,7 +31,7 @@ var verifyCmd = &cobra.Command{ Use: "verify", Short: "Verify that the configures is valid", RunE: func(cmd *cobra.Command, args []string) error { - _, _, _, err := parseConfig(cfgFile) + _, _, _, err := config.ParseClientConfig(cfgFile) if err != nil { fmt.Println(err) os.Exit(1) diff --git a/cmd/frpc/sub/xtcp.go b/cmd/frpc/sub/xtcp.go index 1eb096f7..6c6c7d84 100644 --- a/cmd/frpc/sub/xtcp.go +++ b/cmd/frpc/sub/xtcp.go @@ -45,7 +45,7 @@ var xtcpCmd = &cobra.Command{ Use: "xtcp", Short: "Run frpc with a single xtcp proxy", RunE: func(cmd *cobra.Command, args []string) error { - clientCfg, err := parseClientCommonCfg(CfgFileTypeCmd, nil) + clientCfg, err := parseClientCommonCfgFromCmd() if err != nil { fmt.Println(err) os.Exit(1) diff --git a/pkg/config/parse.go b/pkg/config/parse.go new file mode 100644 index 00000000..cf994c96 --- /dev/null +++ b/pkg/config/parse.go @@ -0,0 +1,100 @@ +// Copyright 2021 The frp Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "bytes" + "fmt" + "io/ioutil" + "os" + "path/filepath" +) + +func ParseClientConfig(filePath string) ( + cfg ClientCommonConf, + pxyCfgs map[string]ProxyConf, + visitorCfgs map[string]VisitorConf, + err error, +) { + var content []byte + content, err = GetRenderedConfFromFile(filePath) + if err != nil { + return + } + configBuffer := bytes.NewBuffer(nil) + configBuffer.Write(content) + + // Parse common section. + cfg, err = UnmarshalClientConfFromIni(content) + if err != nil { + return + } + cfg.Complete() + if err = cfg.Validate(); err != nil { + err = fmt.Errorf("Parse config error: %v", err) + return + } + + // Aggregate proxy configs from include files. + var buf []byte + buf, err = getIncludeContents(cfg.IncludeConfigFiles) + if err != nil { + err = fmt.Errorf("getIncludeContents error: %v", err) + return + } + configBuffer.WriteString("\n") + configBuffer.Write(buf) + + // Parse all proxy and visitor configs. + pxyCfgs, visitorCfgs, err = LoadAllProxyConfsFromIni(cfg.User, configBuffer.Bytes(), cfg.Start) + if err != nil { + return + } + return +} + +// getIncludeContents renders all configs from paths. +// files format can be a single file path or directory or regex path. +func getIncludeContents(paths []string) ([]byte, error) { + out := bytes.NewBuffer(nil) + for _, path := range paths { + absDir, err := filepath.Abs(filepath.Dir(path)) + if err != nil { + return nil, err + } + if _, err := os.Stat(absDir); os.IsNotExist(err) { + return nil, err + } + files, err := ioutil.ReadDir(absDir) + if err != nil { + return nil, err + } + for _, fi := range files { + if fi.IsDir() { + continue + } + absFile := filepath.Join(absDir, fi.Name()) + if matched, _ := filepath.Match(filepath.Join(absDir, filepath.Base(path)), absFile); matched { + tmpContent, err := GetRenderedConfFromFile(absFile) + if err != nil { + return nil, fmt.Errorf("render extra config %s error: %v", absFile, err) + } + out.Write(tmpContent) + out.WriteString("\n") + } + } + } + return out.Bytes(), nil +} diff --git a/pkg/config/proxy.go b/pkg/config/proxy.go index 8aacece0..c000bb30 100644 --- a/pkg/config/proxy.go +++ b/pkg/config/proxy.go @@ -143,7 +143,6 @@ type BaseProxyConf struct { // meta info for each proxy Metas map[string]string `ini:"-" json:"metas"` - // TODO: LocalSvrConf => LocalAppConf LocalSvrConf `ini:",extends"` HealthCheckConf `ini:",extends"` } diff --git a/pkg/config/server.go b/pkg/config/server.go index 92ca7baa..a5af0faf 100644 --- a/pkg/config/server.go +++ b/pkg/config/server.go @@ -223,7 +223,6 @@ func UnmarshalServerConfFromIni(source interface{}) (ServerCommonConf, error) { s, err := f.GetSection("common") if err != nil { - // TODO: add error info return ServerCommonConf{}, err } diff --git a/server/control.go b/server/control.go index bec64aca..7632bbed 100644 --- a/server/control.go +++ b/server/control.go @@ -248,12 +248,10 @@ func (ctl *Control) GetWorkConn() (workConn net.Conn, err error) { xl.Debug("get work connection from pool") default: // no work connections available in the poll, send message to frpc to get more - err = errors.PanicToError(func() { + if err = errors.PanicToError(func() { ctl.sendCh <- &msg.ReqWorkConn{} - }) - if err != nil { - xl.Error("%v", err) - return + }); err != nil { + return nil, fmt.Errorf("control is already closed") } select { @@ -357,15 +355,15 @@ func (ctl *Control) stoper() { ctl.allShutdown.WaitStart() + ctl.conn.Close() + ctl.readerShutdown.WaitDone() + close(ctl.readCh) ctl.managerShutdown.WaitDone() close(ctl.sendCh) ctl.writerShutdown.WaitDone() - ctl.conn.Close() - ctl.readerShutdown.WaitDone() - ctl.mu.Lock() defer ctl.mu.Unlock() diff --git a/test/e2e/basic/basic.go b/test/e2e/basic/basic.go index 5e5f432b..3c18b941 100644 --- a/test/e2e/basic/basic.go +++ b/test/e2e/basic/basic.go @@ -6,6 +6,7 @@ import ( "github.com/fatedier/frp/test/e2e/framework" "github.com/fatedier/frp/test/e2e/framework/consts" + "github.com/fatedier/frp/test/e2e/mock/server" "github.com/fatedier/frp/test/e2e/pkg/port" "github.com/fatedier/frp/test/e2e/pkg/request" @@ -80,7 +81,7 @@ var _ = Describe("[Feature: Basic]", func() { for _, test := range tests { framework.NewRequestExpect(f). - Request(framework.SetRequestProtocol(protocol)). + RequestModify(framework.SetRequestProtocol(protocol)). PortName(test.portName). Explain(test.proxyName). Ensure() @@ -185,7 +186,7 @@ var _ = Describe("[Feature: Basic]", func() { for _, test := range tests { framework.NewRequestExpect(f). - Request(framework.SetRequestProtocol(protocol)). + RequestModify(framework.SetRequestProtocol(protocol)). PortName(test.bindPortName). Explain(test.proxyName). ExpectError(test.expectError). @@ -213,12 +214,11 @@ var _ = Describe("[Feature: Basic]", func() { multiplexer = httpconnect local_port = {{ .%s }} custom_domains = %s - `+extra, proxyName, framework.TCPEchoServerPort, proxyName) + `+extra, proxyName, port.GenName(proxyName), proxyName) } tests := []struct { proxyName string - portName string extraConfig string }{ { @@ -244,7 +244,11 @@ var _ = Describe("[Feature: Basic]", func() { // build all client config for _, test := range tests { clientConf += getProxyConf(test.proxyName, test.extraConfig) + "\n" + + localServer := server.New(server.TCP, server.WithBindPort(f.AllocPort()), server.WithRespContent([]byte(test.proxyName))) + f.RunServer(port.GenName(test.proxyName), localServer) } + // run frps and frpc f.RunProcesses([]string{serverConf}, []string{clientConf}) @@ -257,15 +261,15 @@ var _ = Describe("[Feature: Basic]", func() { proxyURL := fmt.Sprintf("http://127.0.0.1:%d", f.PortByName(tcpmuxHTTPConnectPortName)) // Request with incorrect connect hostname - framework.NewRequestExpect(f).Request(func(r *request.Request) { + framework.NewRequestExpect(f).RequestModify(func(r *request.Request) { r.Proxy(proxyURL, "invalid") }).ExpectError(true).Explain("request without HTTP connect expect error").Ensure() // Request with correct connect hostname for _, test := range tests { - framework.NewRequestExpect(f).Request(func(r *request.Request) { + framework.NewRequestExpect(f).RequestModify(func(r *request.Request) { r.Proxy(proxyURL, test.proxyName) - }).Explain(test.proxyName).Ensure() + }).ExpectResp([]byte(test.proxyName)).Explain(test.proxyName).Ensure() } }) }) diff --git a/test/e2e/basic/client_server.go b/test/e2e/basic/client_server.go index d070ff8f..97b97fc6 100644 --- a/test/e2e/basic/client_server.go +++ b/test/e2e/basic/client_server.go @@ -49,7 +49,7 @@ func defineClientServerTest(desc string, f *framework.Framework, configures *gen f.RunProcesses([]string{serverConf}, []string{clientConf}) framework.NewRequestExpect(f).PortName(tcpPortName).ExpectError(configures.expectError).Explain("tcp proxy").Ensure() - framework.NewRequestExpect(f).Request(framework.SetRequestProtocol("udp")). + framework.NewRequestExpect(f).RequestModify(framework.SetRequestProtocol("udp")). PortName(udpPortName).ExpectError(configures.expectError).Explain("udp proxy").Ensure() }) } diff --git a/test/e2e/basic/server.go b/test/e2e/basic/server.go index 2e2b6be8..ed979c57 100644 --- a/test/e2e/basic/server.go +++ b/test/e2e/basic/server.go @@ -62,17 +62,17 @@ var _ = Describe("[Feature: Server Manager]", func() { framework.NewRequestExpect(f).PortName(tcpPortName).Ensure() // Not Allowed - framework.NewRequestExpect(f).Request(framework.SetRequestPort(20001)).ExpectError(true).Ensure() + framework.NewRequestExpect(f).RequestModify(framework.SetRequestPort(20001)).ExpectError(true).Ensure() // Unavailable, already bind by frps framework.NewRequestExpect(f).PortName(consts.PortServerName).ExpectError(true).Ensure() // UDP // Allowed in range - framework.NewRequestExpect(f).Request(framework.SetRequestProtocol("udp")).PortName(udpPortName).Ensure() + framework.NewRequestExpect(f).RequestModify(framework.SetRequestProtocol("udp")).PortName(udpPortName).Ensure() // Not Allowed - framework.NewRequestExpect(f).Request(func(r *request.Request) { + framework.NewRequestExpect(f).RequestModify(func(r *request.Request) { r.UDP().Port(20003) }).ExpectError(true).Ensure() }) diff --git a/test/e2e/e2e.go b/test/e2e/e2e.go index ce830fdc..b392954b 100644 --- a/test/e2e/e2e.go +++ b/test/e2e/e2e.go @@ -49,7 +49,6 @@ func RunE2ETests(t *testing.T) { // accepting the byte array. func setupSuite() { // Run only on Ginkgo node 1 - // TODO } // setupSuitePerGinkgoNode is the boilerplate that can be used to setup ginkgo test suites, on the SynchronizedBeforeSuite step. diff --git a/test/e2e/framework/framework.go b/test/e2e/framework/framework.go index 19564d75..ed1bfe62 100644 --- a/test/e2e/framework/framework.go +++ b/test/e2e/framework/framework.go @@ -9,6 +9,7 @@ import ( "strings" "text/template" + "github.com/fatedier/frp/test/e2e/mock/server" "github.com/fatedier/frp/test/e2e/pkg/port" "github.com/fatedier/frp/test/e2e/pkg/process" @@ -32,7 +33,7 @@ type Framework struct { // portAllocator to alloc port for this test case. portAllocator *port.Allocator - // Multiple mock servers used for e2e testing. + // Multiple default mock servers used for e2e testing. mockServers *MockServers // To make sure that this framework cleans up after itself, no matter what, @@ -47,6 +48,9 @@ type Framework struct { serverProcesses []*process.Process clientConfPaths []string clientProcesses []*process.Process + + // Manual registered mock servers. + servers []*server.Server } func NewDefaultFramework() *Framework { @@ -62,6 +66,7 @@ func NewDefaultFramework() *Framework { func NewFramework(opt Options) *Framework { f := &Framework{ portAllocator: port.NewAllocator(opt.FromPortIndex, opt.ToPortIndex, opt.TotalParallelNode, opt.CurrentNodeIndex-1), + usedPorts: make(map[string]int), } ginkgo.BeforeEach(f.BeforeEach) @@ -110,9 +115,14 @@ func (f *Framework) AfterEach() { f.serverProcesses = nil f.clientProcesses = nil - // close mock servers + // close default mock servers f.mockServers.Close() + // close manual registered mock servers + for _, s := range f.servers { + s.Close() + } + // clean directory os.RemoveAll(f.TempDirectory) f.TempDirectory = "" @@ -123,7 +133,7 @@ func (f *Framework) AfterEach() { for _, port := range f.usedPorts { f.portAllocator.Release(port) } - f.usedPorts = nil + f.usedPorts = make(map[string]int) } var portRegex = regexp.MustCompile(`{{ \.Port.*? }}`) @@ -161,7 +171,6 @@ func (f *Framework) genPortsFromTemplates(templates []string) (ports map[string] ports[name] = port } return - } // RenderTemplates alloc all ports for port names placeholder. @@ -176,6 +185,10 @@ func (f *Framework) RenderTemplates(templates []string) (outs []string, ports ma params[name] = port } + for name, port := range f.usedPorts { + params[name] = port + } + for _, t := range templates { tmpl, err := template.New("").Parse(t) if err != nil { @@ -193,3 +206,22 @@ func (f *Framework) RenderTemplates(templates []string) (outs []string, ports ma func (f *Framework) PortByName(name string) int { return f.usedPorts[name] } + +func (f *Framework) AllocPort() int { + port := f.portAllocator.Get() + ExpectTrue(port > 0, "alloc port failed") + return port +} + +func (f *Framework) ReleasePort(port int) { + f.portAllocator.Release(port) +} + +func (f *Framework) RunServer(portName string, s *server.Server) { + f.servers = append(f.servers, s) + if s.BindPort() > 0 { + f.usedPorts[portName] = s.BindPort() + } + err := s.Run() + ExpectNoError(err, portName) +} diff --git a/test/e2e/framework/mockservers.go b/test/e2e/framework/mockservers.go index 3598aac1..1935a2b1 100644 --- a/test/e2e/framework/mockservers.go +++ b/test/e2e/framework/mockservers.go @@ -4,7 +4,7 @@ import ( "fmt" "os" - "github.com/fatedier/frp/test/e2e/mock/echoserver" + "github.com/fatedier/frp/test/e2e/mock/server" "github.com/fatedier/frp/test/e2e/pkg/port" ) @@ -15,36 +15,22 @@ const ( ) type MockServers struct { - tcpEchoServer *echoserver.Server - udpEchoServer *echoserver.Server - udsEchoServer *echoserver.Server + tcpEchoServer *server.Server + udpEchoServer *server.Server + udsEchoServer *server.Server } func NewMockServers(portAllocator *port.Allocator) *MockServers { s := &MockServers{} tcpPort := portAllocator.Get() udpPort := portAllocator.Get() - s.tcpEchoServer = echoserver.New(echoserver.Options{ - Type: echoserver.TCP, - BindAddr: "127.0.0.1", - BindPort: int32(tcpPort), - RepeatNum: 1, - }) - s.udpEchoServer = echoserver.New(echoserver.Options{ - Type: echoserver.UDP, - BindAddr: "127.0.0.1", - BindPort: int32(udpPort), - RepeatNum: 1, - }) + s.tcpEchoServer = server.New(server.TCP, server.WithBindPort(tcpPort), server.WithEchoMode(true)) + s.udpEchoServer = server.New(server.UDP, server.WithBindPort(udpPort), server.WithEchoMode(true)) udsIndex := portAllocator.Get() udsAddr := fmt.Sprintf("%s/frp_echo_server_%d.sock", os.TempDir(), udsIndex) os.Remove(udsAddr) - s.udsEchoServer = echoserver.New(echoserver.Options{ - Type: echoserver.Unix, - BindAddr: udsAddr, - RepeatNum: 1, - }) + s.udsEchoServer = server.New(server.Unix, server.WithBindAddr(udsAddr), server.WithEchoMode(true)) return s } @@ -65,14 +51,14 @@ func (m *MockServers) Close() { m.tcpEchoServer.Close() m.udpEchoServer.Close() m.udsEchoServer.Close() - os.Remove(m.udsEchoServer.GetOptions().BindAddr) + os.Remove(m.udsEchoServer.BindAddr()) } func (m *MockServers) GetTemplateParams() map[string]interface{} { ret := make(map[string]interface{}) - ret[TCPEchoServerPort] = m.tcpEchoServer.GetOptions().BindPort - ret[UDPEchoServerPort] = m.udpEchoServer.GetOptions().BindPort - ret[UDSEchoServerAddr] = m.udsEchoServer.GetOptions().BindAddr + ret[TCPEchoServerPort] = m.tcpEchoServer.BindPort() + ret[UDPEchoServerPort] = m.udpEchoServer.BindPort() + ret[UDSEchoServerAddr] = m.udsEchoServer.BindAddr() return ret } diff --git a/test/e2e/framework/process.go b/test/e2e/framework/process.go index 2d200cf0..40df9c9e 100644 --- a/test/e2e/framework/process.go +++ b/test/e2e/framework/process.go @@ -28,7 +28,9 @@ func (f *Framework) RunProcesses(serverTemplates []string, clientTemplates []str ExpectNoError(err) ExpectTrue(len(templates) > 0) - f.usedPorts = ports + for name, port := range ports { + f.usedPorts[name] = port + } for i := range serverTemplates { path := filepath.Join(f.TempDirectory, fmt.Sprintf("frp-e2e-server-%d", i)) diff --git a/test/e2e/framework/request.go b/test/e2e/framework/request.go index decf6bf3..847f1aa7 100644 --- a/test/e2e/framework/request.go +++ b/test/e2e/framework/request.go @@ -54,7 +54,7 @@ func NewRequestExpect(f *Framework) *RequestExpect { } } -func (e *RequestExpect) Request(f func(r *request.Request)) *RequestExpect { +func (e *RequestExpect) RequestModify(f func(r *request.Request)) *RequestExpect { f(e.req) return e } @@ -66,6 +66,11 @@ func (e *RequestExpect) PortName(name string) *RequestExpect { return e } +func (e *RequestExpect) ExpectResp(resp []byte) *RequestExpect { + e.expectResp = resp + return e +} + func (e *RequestExpect) ExpectError(expectErr bool) *RequestExpect { e.expectError = expectErr return e diff --git a/test/e2e/mock/echoserver/echoserver.go b/test/e2e/mock/echoserver/echoserver.go deleted file mode 100644 index 09a20954..00000000 --- a/test/e2e/mock/echoserver/echoserver.go +++ /dev/null @@ -1,111 +0,0 @@ -package echoserver - -import ( - "fmt" - "net" - "strings" - - fnet "github.com/fatedier/frp/pkg/util/net" -) - -type ServerType string - -const ( - TCP ServerType = "tcp" - UDP ServerType = "udp" - Unix ServerType = "unix" -) - -type Options struct { - Type ServerType - BindAddr string - BindPort int32 - RepeatNum int - SpecifiedResponse string -} - -type Server struct { - opt Options - - l net.Listener -} - -func New(opt Options) *Server { - if opt.Type == "" { - opt.Type = TCP - } - if opt.BindAddr == "" { - opt.BindAddr = "127.0.0.1" - } - if opt.RepeatNum <= 0 { - opt.RepeatNum = 1 - } - return &Server{ - opt: opt, - } -} - -func (s *Server) GetOptions() Options { - return s.opt -} - -func (s *Server) Run() error { - if err := s.initListener(); err != nil { - return err - } - - go func() { - for { - c, err := s.l.Accept() - if err != nil { - return - } - go s.handle(c) - } - }() - return nil -} - -func (s *Server) Close() error { - if s.l != nil { - return s.l.Close() - } - return nil -} - -func (s *Server) initListener() (err error) { - switch s.opt.Type { - case TCP: - s.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", s.opt.BindAddr, s.opt.BindPort)) - case UDP: - s.l, err = fnet.ListenUDP(s.opt.BindAddr, int(s.opt.BindPort)) - case Unix: - s.l, err = net.Listen("unix", s.opt.BindAddr) - default: - return fmt.Errorf("unknown server type: %s", s.opt.Type) - } - if err != nil { - return - } - return nil -} - -func (s *Server) handle(c net.Conn) { - defer c.Close() - - buf := make([]byte, 2048) - for { - n, err := c.Read(buf) - if err != nil { - return - } - - var response string - if len(s.opt.SpecifiedResponse) > 0 { - response = s.opt.SpecifiedResponse - } else { - response = strings.Repeat(string(buf[:n]), s.opt.RepeatNum) - } - c.Write([]byte(response)) - } -} diff --git a/test/e2e/mock/server/server.go b/test/e2e/mock/server/server.go new file mode 100644 index 00000000..5ce7307e --- /dev/null +++ b/test/e2e/mock/server/server.go @@ -0,0 +1,142 @@ +package server + +import ( + "fmt" + "net" + + libnet "github.com/fatedier/frp/pkg/util/net" +) + +type ServerType string + +const ( + TCP ServerType = "tcp" + UDP ServerType = "udp" + Unix ServerType = "unix" +) + +type Server struct { + netType ServerType + bindAddr string + bindPort int + respContent []byte + bufSize int64 + + echoMode bool + + l net.Listener +} + +type Option func(*Server) *Server + +func New(netType ServerType, options ...Option) *Server { + s := &Server{ + netType: netType, + bindAddr: "127.0.0.1", + bufSize: 2048, + } + + for _, option := range options { + s = option(s) + } + return s +} + +func WithBindAddr(addr string) Option { + return func(s *Server) *Server { + s.bindAddr = addr + return s + } +} + +func WithBindPort(port int) Option { + return func(s *Server) *Server { + s.bindPort = port + return s + } +} + +func WithRespContent(content []byte) Option { + return func(s *Server) *Server { + s.respContent = content + return s + } +} + +func WithBufSize(bufSize int64) Option { + return func(s *Server) *Server { + s.bufSize = bufSize + return s + } +} + +func WithEchoMode(echoMode bool) Option { + return func(s *Server) *Server { + s.echoMode = echoMode + return s + } +} + +func (s *Server) Run() error { + if err := s.initListener(); err != nil { + return err + } + + go func() { + for { + c, err := s.l.Accept() + if err != nil { + return + } + go s.handle(c) + } + }() + return nil +} + +func (s *Server) Close() error { + if s.l != nil { + return s.l.Close() + } + return nil +} + +func (s *Server) initListener() (err error) { + switch s.netType { + case TCP: + s.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", s.bindAddr, s.bindPort)) + case UDP: + s.l, err = libnet.ListenUDP(s.bindAddr, s.bindPort) + case Unix: + s.l, err = net.Listen("unix", s.bindAddr) + default: + return fmt.Errorf("unknown server type: %s", s.netType) + } + return err +} + +func (s *Server) handle(c net.Conn) { + defer c.Close() + + buf := make([]byte, s.bufSize) + for { + n, err := c.Read(buf) + if err != nil { + return + } + + if s.echoMode { + c.Write(buf[:n]) + } else { + c.Write(s.respContent) + } + } +} + +func (s *Server) BindAddr() string { + return s.bindAddr +} + +func (s *Server) BindPort() int { + return s.bindPort +} diff --git a/test/e2e/pkg/port/port.go b/test/e2e/pkg/port/port.go index 296cb18b..1812e906 100644 --- a/test/e2e/pkg/port/port.go +++ b/test/e2e/pkg/port/port.go @@ -56,7 +56,6 @@ func (pa *Allocator) GetByName(portName string) int { return 0 } - // TODO: Distinguish between TCP and UDP l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) if err != nil { // Maybe not controlled by us, mark it used. diff --git a/test/e2e/pkg/port/util.go b/test/e2e/pkg/port/util.go index 74e39593..9cf1204f 100644 --- a/test/e2e/pkg/port/util.go +++ b/test/e2e/pkg/port/util.go @@ -59,9 +59,11 @@ func WithRangePorts(from, to int) NameOption { } func GenName(name string, options ...NameOption) string { + name = strings.ReplaceAll(name, "-", "") + name = strings.ReplaceAll(name, "_", "") builder := &nameBuilder{name: name} for _, option := range options { - option(builder) + builder = option(builder) } return builder.String() } diff --git a/test/e2e/suites.go b/test/e2e/suites.go index 268dcb45..1201bd3b 100644 --- a/test/e2e/suites.go +++ b/test/e2e/suites.go @@ -6,11 +6,9 @@ package e2e // and then the function that only runs on the first Ginkgo node. func CleanupSuite() { // Run on all Ginkgo nodes - // TODO } // AfterSuiteActions are actions that are run on ginkgo's SynchronizedAfterSuite func AfterSuiteActions() { // Run only Ginkgo on node 1 - // TODO }