add auth for reload api

This commit is contained in:
vashstorm 2016-12-20 18:32:17 +08:00
parent b8a28e945c
commit 5eb5fec761
2 changed files with 59 additions and 4 deletions

View File

@ -15,6 +15,7 @@
package main package main
import ( import (
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
@ -38,7 +39,7 @@ var usage string = `frps is the server of frp
Usage: Usage:
frps [-c config_file] [-L log_file] [--log-level=<log_level>] [--addr=<bind_addr>] frps [-c config_file] [-L log_file] [--log-level=<log_level>] [--addr=<bind_addr>]
frps --reload frps [-c config_file] --reload
frps -h | --help frps -h | --help
frps -v | --version frps -v | --version
@ -68,7 +69,18 @@ func main() {
// reload check // reload check
if args["--reload"] != nil { if args["--reload"] != nil {
if args["--reload"].(bool) { if args["--reload"].(bool) {
resp, err := http.Get("http://" + server.BindAddr + ":" + fmt.Sprintf("%d", server.DashboardPort) + "/api/reload") req, err := http.NewRequest("GET", "http://"+server.BindAddr+":"+fmt.Sprintf("%d", server.DashboardPort)+"/api/reload", nil)
if err != nil {
fmt.Printf("frps reload error: %v\n", err)
os.Exit(1)
}
authStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(server.DashboardUsername+":"+server.DashboardPassword))
req.Header.Add("Authorization", authStr)
defaultClient := &http.Client{}
resp, err := defaultClient.Do(req)
if err != nil { if err != nil {
fmt.Printf("frps reload error: %v\n", err) fmt.Printf("frps reload error: %v\n", err)
os.Exit(1) os.Exit(1)

View File

@ -15,9 +15,11 @@
package server package server
import ( import (
"encoding/base64"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/fatedier/frp/src/assets" "github.com/fatedier/frp/src/assets"
@ -32,13 +34,14 @@ func RunDashboardServer(addr string, port int64) (err error) {
// url router // url router
mux := http.NewServeMux() mux := http.NewServeMux()
// api, see dashboard_api.go // api, see dashboard_api.go
mux.HandleFunc("/api/reload", apiReload) // mux.HandleFunc("/api/reload", apiReload)
mux.HandleFunc("/api/reload", use(apiReload, basicAuth))
mux.HandleFunc("/api/proxies", apiProxies) mux.HandleFunc("/api/proxies", apiProxies)
// view, see dashboard_view.go // view, see dashboard_view.go
mux.Handle("/favicon.ico", http.FileServer(assets.FileSystem)) mux.Handle("/favicon.ico", http.FileServer(assets.FileSystem))
mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(assets.FileSystem))) mux.Handle("/static/", http.StripPrefix("/static/", http.FileServer(assets.FileSystem)))
mux.HandleFunc("/", viewDashboard) mux.HandleFunc("/", use(viewDashboard, basicAuth))
address := fmt.Sprintf("%s:%d", addr, port) address := fmt.Sprintf("%s:%d", addr, port)
server := &http.Server{ server := &http.Server{
@ -58,3 +61,43 @@ func RunDashboardServer(addr string, port int64) (err error) {
go server.Serve(ln) go server.Serve(ln)
return return
} }
func use(h http.HandlerFunc, middleware ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc {
for _, m := range middleware {
h = m(h)
}
return h
}
func basicAuth(h http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`)
s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
if len(s) != 2 {
http.Error(w, "Not authorized", 401)
return
}
b, err := base64.StdEncoding.DecodeString(s[1])
if err != nil {
http.Error(w, err.Error(), 401)
return
}
pair := strings.SplitN(string(b), ":", 2)
if len(pair) != 2 {
http.Error(w, "Not authorized", 401)
return
}
if pair[0] != DashboardUsername || pair[1] != DashboardPassword {
http.Error(w, "Not authorized", 401)
return
}
h.ServeHTTP(w, r)
}
}