Merge pull request #1364 from fatedier/dev

bump version to v0.28.1 and remove support for go1.11
This commit is contained in:
fatedier 2019-08-08 17:32:57 +08:00 committed by GitHub
commit ae08811636
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
597 changed files with 269990 additions and 5727 deletions

View File

@ -2,7 +2,6 @@ sudo: false
language: go language: go
go: go:
- 1.11.x
- 1.12.x - 1.12.x
install: install:

View File

@ -86,8 +86,6 @@ func (svr *Service) Run() error {
if g.GlbClientCfg.LoginFailExit { if g.GlbClientCfg.LoginFailExit {
return err return err
} else { } else {
conn.Close()
session.Close()
time.Sleep(10 * time.Second) time.Sleep(10 * time.Second)
} }
} else { } else {
@ -169,6 +167,9 @@ func (svr *Service) login() (conn frpNet.Conn, session *fmux.Session, err error)
defer func() { defer func() {
if err != nil { if err != nil {
conn.Close() conn.Close()
if session != nil {
session.Close()
}
} }
}() }()

16
go.mod
View File

@ -4,14 +4,12 @@ go 1.12
require ( require (
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5
github.com/davecgh/go-spew v1.1.0 // indirect
github.com/fatedier/beego v0.0.0-20171024143340-6c6a4f5bd5eb github.com/fatedier/beego v0.0.0-20171024143340-6c6a4f5bd5eb
github.com/fatedier/golib v0.0.0-20181107124048-ff8cd814b049 github.com/fatedier/golib v0.0.0-20181107124048-ff8cd814b049
github.com/fatedier/kcp-go v2.0.4-0.20190317085623-2063a803e6fe+incompatible github.com/fatedier/kcp-go v2.0.4-0.20190803094908-fe8645b0a904+incompatible
github.com/golang/snappy v0.0.0-20170215233205-553a64147049 // indirect github.com/golang/snappy v0.0.0-20170215233205-553a64147049 // indirect
github.com/gorilla/context v1.1.1 // indirect github.com/gorilla/mux v1.7.3
github.com/gorilla/mux v1.6.2 github.com/gorilla/websocket v1.4.0
github.com/gorilla/websocket v1.2.0
github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d
github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.0.0 // indirect
github.com/klauspost/cpuid v1.2.0 // indirect github.com/klauspost/cpuid v1.2.0 // indirect
@ -19,16 +17,16 @@ require (
github.com/mattn/go-runewidth v0.0.4 // indirect github.com/mattn/go-runewidth v0.0.4 // indirect
github.com/pires/go-proxyproto v0.0.0-20190111085350-4d51b51e3bfc github.com/pires/go-proxyproto v0.0.0-20190111085350-4d51b51e3bfc
github.com/pkg/errors v0.8.0 // indirect github.com/pkg/errors v0.8.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rakyll/statik v0.1.1 github.com/rakyll/statik v0.1.1
github.com/rodaine/table v1.0.0 github.com/rodaine/table v1.0.0
github.com/spf13/cobra v0.0.3 github.com/spf13/cobra v0.0.3
github.com/spf13/pflag v1.0.1 // indirect github.com/spf13/pflag v1.0.1 // indirect
github.com/stretchr/testify v1.2.1 github.com/stretchr/testify v1.3.0
github.com/templexxx/cpufeat v0.0.0-20170927014610-3794dfbfb047 // indirect github.com/templexxx/cpufeat v0.0.0-20170927014610-3794dfbfb047 // indirect
github.com/templexxx/xor v0.0.0-20170926022130-0af8e873c554 // indirect github.com/templexxx/xor v0.0.0-20170926022130-0af8e873c554 // indirect
github.com/tjfoc/gmsm v0.0.0-20171124023159-98aa888b79d8 // indirect github.com/tjfoc/gmsm v0.0.0-20171124023159-98aa888b79d8 // indirect
github.com/vaughan0/go-ini v0.0.0-20130923145212-a98ad7ee00ec github.com/vaughan0/go-ini v0.0.0-20130923145212-a98ad7ee00ec
golang.org/x/crypto v0.0.0-20180505025534-4ec37c66abab // indirect github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae // indirect
golang.org/x/net v0.0.0-20180524181706-dfa909b99c79 golang.org/x/net v0.0.0-20190724013045-ca1201d0de80
golang.org/x/text v0.3.2 // indirect
) )

33
go.sum
View File

@ -5,15 +5,14 @@ github.com/fatedier/beego v0.0.0-20171024143340-6c6a4f5bd5eb h1:wCrNShQidLmvVWn/
github.com/fatedier/beego v0.0.0-20171024143340-6c6a4f5bd5eb/go.mod h1:wx3gB6dbIfBRcucp94PI9Bt3I0F2c/MyNEWuhzpWiwk= github.com/fatedier/beego v0.0.0-20171024143340-6c6a4f5bd5eb/go.mod h1:wx3gB6dbIfBRcucp94PI9Bt3I0F2c/MyNEWuhzpWiwk=
github.com/fatedier/golib v0.0.0-20181107124048-ff8cd814b049 h1:teH578mf2ii42NHhIp3PhgvjU5bv+NFMq9fSQR8NaG8= github.com/fatedier/golib v0.0.0-20181107124048-ff8cd814b049 h1:teH578mf2ii42NHhIp3PhgvjU5bv+NFMq9fSQR8NaG8=
github.com/fatedier/golib v0.0.0-20181107124048-ff8cd814b049/go.mod h1:DqIrnl0rp3Zybg9zbJmozTy1n8fYJoX+QoAj9slIkKM= github.com/fatedier/golib v0.0.0-20181107124048-ff8cd814b049/go.mod h1:DqIrnl0rp3Zybg9zbJmozTy1n8fYJoX+QoAj9slIkKM=
github.com/fatedier/kcp-go v2.0.4-0.20190317085623-2063a803e6fe+incompatible h1:pNNeBKz1jtMDupiwvtEGFTujA3J86xoEXGSkwVeYFsw= github.com/fatedier/kcp-go v2.0.4-0.20190803094908-fe8645b0a904+incompatible h1:ssXat9YXFvigNge/IkkZvFMn8yeYKFX+uI6wn2mLJ74=
github.com/fatedier/kcp-go v2.0.4-0.20190317085623-2063a803e6fe+incompatible/go.mod h1:YpCOaxj7vvMThhIQ9AfTOPW2sfztQR5WDfs7AflSy4s= github.com/fatedier/kcp-go v2.0.4-0.20190803094908-fe8645b0a904+incompatible/go.mod h1:YpCOaxj7vvMThhIQ9AfTOPW2sfztQR5WDfs7AflSy4s=
github.com/golang/snappy v0.0.0-20170215233205-553a64147049 h1:K9KHZbXKpGydfDN0aZrsoHpLJlZsBrGMFWbgLDGnPZk= github.com/golang/snappy v0.0.0-20170215233205-553a64147049 h1:K9KHZbXKpGydfDN0aZrsoHpLJlZsBrGMFWbgLDGnPZk=
github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.0-20170215233205-553a64147049/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/gorilla/mux v1.6.2 h1:Pgr17XVTNXAk3q/r4CpKzC5xBM/qW1uVLV+IhRZpIIk= github.com/gorilla/websocket v1.4.0 h1:WDFjx/TMzVgy9VdMMQi2K2Emtwi2QcUQsztZ/zLaH/Q=
github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/gorilla/websocket v1.2.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d h1:kJCB4vdITiW1eC1vq2e6IsrXKrZit1bv/TDYFGMp4BQ= github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d h1:kJCB4vdITiW1eC1vq2e6IsrXKrZit1bv/TDYFGMp4BQ=
github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM= github.com/hashicorp/yamux v0.0.0-20181012175058-2f1d1f20f75d/go.mod h1:+NfK9FKeTrX5uv1uIXGdwYDTeHna2qgaIlx54MXqjAM=
github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM=
@ -37,7 +36,9 @@ github.com/spf13/cobra v0.0.3 h1:ZlrZ4XsMRm04Fr5pSFxBgfND2EBVa1nLpiy1stUsX/8=
github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ= github.com/spf13/cobra v0.0.3/go.mod h1:1l0Ry5zgKvJasoi3XT1TypsSe7PqH0Sj9dhYf7v3XqQ=
github.com/spf13/pflag v1.0.1 h1:aCvUg6QPl3ibpQUxyLkrEkCHtPqYJL4x9AuhqVqFis4= github.com/spf13/pflag v1.0.1 h1:aCvUg6QPl3ibpQUxyLkrEkCHtPqYJL4x9AuhqVqFis4=
github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= github.com/spf13/pflag v1.0.1/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4=
github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/templexxx/cpufeat v0.0.0-20170927014610-3794dfbfb047 h1:K+jtWCOuZgCra7eXZ/VWn2FbJmrA/D058mTXhh2rq+8= github.com/templexxx/cpufeat v0.0.0-20170927014610-3794dfbfb047 h1:K+jtWCOuZgCra7eXZ/VWn2FbJmrA/D058mTXhh2rq+8=
github.com/templexxx/cpufeat v0.0.0-20170927014610-3794dfbfb047/go.mod h1:wM7WEvslTq+iOEAMDLSzhVuOt5BRZ05WirO+b09GHQU= github.com/templexxx/cpufeat v0.0.0-20170927014610-3794dfbfb047/go.mod h1:wM7WEvslTq+iOEAMDLSzhVuOt5BRZ05WirO+b09GHQU=
github.com/templexxx/xor v0.0.0-20170926022130-0af8e873c554 h1:pexgSe+JCFuxG+uoMZLO+ce8KHtdHGhst4cs6rw3gmk= github.com/templexxx/xor v0.0.0-20170926022130-0af8e873c554 h1:pexgSe+JCFuxG+uoMZLO+ce8KHtdHGhst4cs6rw3gmk=
@ -46,7 +47,15 @@ github.com/tjfoc/gmsm v0.0.0-20171124023159-98aa888b79d8 h1:6CNSDqI1wiE+JqyOy5Qt
github.com/tjfoc/gmsm v0.0.0-20171124023159-98aa888b79d8/go.mod h1:XxO4hdhhrzAd+G4CjDqaOkd0hUzmtPR/d3EiBBMn/wc= github.com/tjfoc/gmsm v0.0.0-20171124023159-98aa888b79d8/go.mod h1:XxO4hdhhrzAd+G4CjDqaOkd0hUzmtPR/d3EiBBMn/wc=
github.com/vaughan0/go-ini v0.0.0-20130923145212-a98ad7ee00ec h1:DGmKwyZwEB8dI7tbLt/I/gQuP559o/0FrAkHKlQM/Ks= github.com/vaughan0/go-ini v0.0.0-20130923145212-a98ad7ee00ec h1:DGmKwyZwEB8dI7tbLt/I/gQuP559o/0FrAkHKlQM/Ks=
github.com/vaughan0/go-ini v0.0.0-20130923145212-a98ad7ee00ec/go.mod h1:owBmyHYMLkxyrugmfwE/DLJyW8Ro9mkphwuVErQ0iUw= github.com/vaughan0/go-ini v0.0.0-20130923145212-a98ad7ee00ec/go.mod h1:owBmyHYMLkxyrugmfwE/DLJyW8Ro9mkphwuVErQ0iUw=
golang.org/x/crypto v0.0.0-20180505025534-4ec37c66abab h1:w4c/LoOA2vE8SYwh8wEEQVRUwpph7TtcjH7AtZvOjy0= github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae h1:J0GxkO96kL4WF+AIT3M4mfUVinOCPgf2uUWYFUzN0sM=
golang.org/x/crypto v0.0.0-20180505025534-4ec37c66abab/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae/go.mod h1:gXtu8J62kEgmN++bm9BVICuT/e8yiLI2KFobd/TRFsE=
golang.org/x/net v0.0.0-20180524181706-dfa909b99c79 h1:1FDlG4HI84rVePw1/0E/crL5tt2N+1blLJpY6UZ6krs= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
golang.org/x/net v0.0.0-20180524181706-dfa909b99c79/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20190724013045-ca1201d0de80 h1:Ao/3l156eZf2AW5wK8a7/smtodRU+gha3+BeqJ69lRk=
golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@ -19,7 +19,7 @@ import (
"strings" "strings"
) )
var version string = "0.28.0" var version string = "0.28.1"
func Full() string { func Full() string {
return version return version

View File

@ -89,14 +89,13 @@ func NewHttpReverseProxy(option HttpReverseProxyOptions, vhostRouter *VhostRoute
return rp.CreateConnection(host, url, remote) return rp.CreateConnection(host, url, remote)
}, },
}, },
WebSocketDialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
url := ctx.Value("url").(string)
host := getHostFromAddr(ctx.Value("host").(string))
remote := ctx.Value("remote").(string)
return rp.CreateConnection(host, url, remote)
},
BufferPool: newWrapPool(), BufferPool: newWrapPool(),
ErrorLog: log.New(newWrapLogger(), "", 0), ErrorLog: log.New(newWrapLogger(), "", 0),
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
frpLog.Warn("do http proxy request error: %v", err)
rw.WriteHeader(http.StatusNotFound)
rw.Write(getNotFoundPageContent())
},
} }
rp.proxy = proxy rp.proxy = proxy
return rp return rp

View File

@ -8,6 +8,7 @@ package vhost
import ( import (
"context" "context"
"fmt"
"io" "io"
"log" "log"
"net" "net"
@ -17,13 +18,9 @@ import (
"sync" "sync"
"time" "time"
frpIo "github.com/fatedier/golib/io" "golang.org/x/net/http/httpguts"
) )
// onExitFlushLoop is a callback set by tests to detect the state of the
// flushLoop() goroutine.
var onExitFlushLoop func()
// ReverseProxy is an HTTP Handler that takes an incoming request and // ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the // sends it to another server, proxying the response back to the
// client. // client.
@ -44,12 +41,17 @@ type ReverseProxy struct {
// to flush to the client while copying the // to flush to the client while copying the
// response body. // response body.
// If zero, no periodic flushing is done. // If zero, no periodic flushing is done.
// A negative value means to flush immediately
// after each write to the client.
// The FlushInterval is ignored when ReverseProxy
// recognizes a response as a streaming response;
// for such responses, writes are flushed to the client
// immediately.
FlushInterval time.Duration FlushInterval time.Duration
// ErrorLog specifies an optional logger for errors // ErrorLog specifies an optional logger for errors
// that occur when attempting to proxy the request. // that occur when attempting to proxy the request.
// If nil, logging goes to os.Stderr via the log package's // If nil, logging is done via the log package's standard logger.
// standard logger.
ErrorLog *log.Logger ErrorLog *log.Logger
// BufferPool optionally specifies a buffer pool to // BufferPool optionally specifies a buffer pool to
@ -57,12 +59,23 @@ type ReverseProxy struct {
// copying HTTP response bodies. // copying HTTP response bodies.
BufferPool BufferPool BufferPool BufferPool
// ModifyResponse is an optional function that // ModifyResponse is an optional function that modifies the
// modifies the Response from the backend. // Response from the backend. It is called if the backend
// If it returns an error, the proxy returns a StatusBadGateway error. // returns a response at all, with any HTTP status code.
// If the backend is unreachable, the optional ErrorHandler is
// called without any call to ModifyResponse.
//
// If ModifyResponse returns an error, ErrorHandler is called
// with its error value. If ErrorHandler is nil, its default
// implementation is used.
ModifyResponse func(*http.Response) error ModifyResponse func(*http.Response) error
WebSocketDialContext func(ctx context.Context, network, addr string) (net.Conn, error) // ErrorHandler is an optional function that handles errors
// reaching the backend or errors from ModifyResponse.
//
// If nil, the default is to log the provided error and return
// a 502 Status Bad Gateway response.
ErrorHandler func(http.ResponseWriter, *http.Request, error)
} }
// A BufferPool is an interface for getting and returning temporary // A BufferPool is an interface for getting and returning temporary
@ -118,18 +131,11 @@ func copyHeader(dst, src http.Header) {
} }
} }
func cloneHeader(h http.Header) http.Header {
h2 := make(http.Header, len(h))
for k, vv := range h {
vv2 := make([]string, len(vv))
copy(vv2, vv)
h2[k] = vv2
}
return h2
}
// Hop-by-hop headers. These are removed when sent to the backend. // Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html // As of RFC 7230, hop-by-hop headers are required to appear in the
// Connection header field. These are the headers defined by the
// obsoleted RFC 2616 (section 13.5.1) and are used for backward
// compatibility.
var hopHeaders = []string{ var hopHeaders = []string{
"Connection", "Connection",
"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
@ -137,55 +143,38 @@ var hopHeaders = []string{
"Proxy-Authenticate", "Proxy-Authenticate",
"Proxy-Authorization", "Proxy-Authorization",
"Te", // canonicalized version of "TE" "Te", // canonicalized version of "TE"
"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522 "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522
"Transfer-Encoding", "Transfer-Encoding",
"Upgrade", "Upgrade",
} }
func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
p.logf("http: proxy error: %v", err)
rw.WriteHeader(http.StatusBadGateway)
}
func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
if p.ErrorHandler != nil {
return p.ErrorHandler
}
return p.defaultErrorHandler
}
// modifyResponse conditionally runs the optional ModifyResponse hook
// and reports whether the request should proceed.
func (p *ReverseProxy) modifyResponse(rw http.ResponseWriter, res *http.Response, req *http.Request) bool {
if p.ModifyResponse == nil {
return true
}
if err := p.ModifyResponse(res); err != nil {
res.Body.Close()
p.getErrorHandler()(rw, req, err)
return false
}
return true
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
if IsWebsocketRequest(req) {
p.serveWebSocket(rw, req)
} else {
p.serveHTTP(rw, req)
}
}
func (p *ReverseProxy) serveWebSocket(rw http.ResponseWriter, req *http.Request) {
if p.WebSocketDialContext == nil {
rw.WriteHeader(500)
return
}
req = req.WithContext(context.WithValue(req.Context(), "url", req.URL.Path))
req = req.WithContext(context.WithValue(req.Context(), "host", req.Host))
req = req.WithContext(context.WithValue(req.Context(), "remote", req.RemoteAddr))
targetConn, err := p.WebSocketDialContext(req.Context(), "tcp", "")
if err != nil {
rw.WriteHeader(501)
return
}
defer targetConn.Close()
p.Director(req)
hijacker, ok := rw.(http.Hijacker)
if !ok {
rw.WriteHeader(500)
return
}
conn, _, errHijack := hijacker.Hijack()
if errHijack != nil {
rw.WriteHeader(500)
return
}
defer conn.Close()
req.Write(targetConn)
frpIo.Join(conn, targetConn)
}
func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport transport := p.Transport
if transport == nil { if transport == nil {
transport = http.DefaultTransport transport = http.DefaultTransport
@ -206,38 +195,49 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
}() }()
} }
outreq := req.WithContext(ctx) // includes shallow copies of maps, but okay outreq := req.WithContext(ctx)
if req.ContentLength == 0 { if req.ContentLength == 0 {
outreq.Body = nil // Issue 16036: nil Body for http.Transport retries outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
} }
outreq.Header = cloneHeader(req.Header) // =============================
// Modified for frp
// Modify for frp
outreq = outreq.WithContext(context.WithValue(outreq.Context(), "url", req.URL.Path)) outreq = outreq.WithContext(context.WithValue(outreq.Context(), "url", req.URL.Path))
outreq = outreq.WithContext(context.WithValue(outreq.Context(), "host", req.Host)) outreq = outreq.WithContext(context.WithValue(outreq.Context(), "host", req.Host))
outreq = outreq.WithContext(context.WithValue(outreq.Context(), "remote", req.RemoteAddr)) outreq = outreq.WithContext(context.WithValue(outreq.Context(), "remote", req.RemoteAddr))
// =============================
p.Director(outreq) p.Director(outreq)
outreq.Close = false outreq.Close = false
// Remove hop-by-hop headers listed in the "Connection" header. reqUpType := upgradeType(outreq.Header)
// See RFC 2616, section 14.10. removeConnectionHeaders(outreq.Header)
if c := outreq.Header.Get("Connection"); c != "" {
for _, f := range strings.Split(c, ",") {
if f = strings.TrimSpace(f); f != "" {
outreq.Header.Del(f)
}
}
}
// Remove hop-by-hop headers to the backend. Especially // Remove hop-by-hop headers to the backend. Especially
// important is "Connection" because we want a persistent // important is "Connection" because we want a persistent
// connection, regardless of what the client sent to us. // connection, regardless of what the client sent to us.
for _, h := range hopHeaders { for _, h := range hopHeaders {
if outreq.Header.Get(h) != "" { hv := outreq.Header.Get(h)
outreq.Header.Del(h) if hv == "" {
continue
} }
if h == "Te" && hv == "trailers" {
// Issue 21096: tell backend applications that
// care about trailer support that we support
// trailers. (We do, but we don't go out of
// our way to advertise that unless the
// incoming client request thought it was
// worth mentioning)
continue
}
outreq.Header.Del(h)
}
// After stripping all the hop-by-hop connection headers above, add back any
// necessary for protocol upgrades, such as for websockets.
if reqUpType != "" {
outreq.Header.Set("Connection", "Upgrade")
outreq.Header.Set("Upgrade", reqUpType)
} }
if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { if clientIP, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
@ -252,33 +252,27 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
res, err := transport.RoundTrip(outreq) res, err := transport.RoundTrip(outreq)
if err != nil { if err != nil {
p.logf("http: proxy error: %v", err) p.getErrorHandler()(rw, outreq, err)
rw.WriteHeader(http.StatusNotFound)
rw.Write(getNotFoundPageContent())
return return
} }
// Remove hop-by-hop headers listed in the // Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
// "Connection" header of the response. if res.StatusCode == http.StatusSwitchingProtocols {
if c := res.Header.Get("Connection"); c != "" { if !p.modifyResponse(rw, res, outreq) {
for _, f := range strings.Split(c, ",") { return
if f = strings.TrimSpace(f); f != "" {
res.Header.Del(f)
}
} }
p.handleUpgradeResponse(rw, outreq, res)
return
} }
removeConnectionHeaders(res.Header)
for _, h := range hopHeaders { for _, h := range hopHeaders {
res.Header.Del(h) res.Header.Del(h)
} }
if p.ModifyResponse != nil { if !p.modifyResponse(rw, res, outreq) {
if err := p.ModifyResponse(res); err != nil { return
p.logf("http: proxy error: %v", err)
rw.WriteHeader(http.StatusBadGateway)
return
}
} }
copyHeader(rw.Header(), res.Header) copyHeader(rw.Header(), res.Header)
@ -295,6 +289,21 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
} }
rw.WriteHeader(res.StatusCode) rw.WriteHeader(res.StatusCode)
err = p.copyResponse(rw, res.Body, p.flushInterval(req, res))
if err != nil {
defer res.Body.Close()
// Since we're streaming the response, if we run into an error all we can do
// is abort the request. Issue 23643: ReverseProxy should use ErrAbortHandler
// on read error while copying body.
if !shouldPanicOnCopyError(req) {
p.logf("suppressing panic for copyResponse error in test; copy error: %v", err)
return
}
panic(http.ErrAbortHandler)
}
res.Body.Close() // close now, instead of defer, to populate res.Trailer
if len(res.Trailer) > 0 { if len(res.Trailer) > 0 {
// Force chunking if we saw a response trailer. // Force chunking if we saw a response trailer.
// This prevents net/http from calculating the length for short // This prevents net/http from calculating the length for short
@ -303,8 +312,6 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
fl.Flush() fl.Flush()
} }
} }
p.copyResponse(rw, res.Body)
res.Body.Close() // close now, instead of defer, to populate res.Trailer
if len(res.Trailer) == announcedTrailers { if len(res.Trailer) == announcedTrailers {
copyHeader(rw.Header(), res.Trailer) copyHeader(rw.Header(), res.Trailer)
@ -319,16 +326,68 @@ func (p *ReverseProxy) serveHTTP(rw http.ResponseWriter, req *http.Request) {
} }
} }
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { var inOurTests bool // whether we're in our own tests
if p.FlushInterval != 0 {
// shouldPanicOnCopyError reports whether the reverse proxy should
// panic with http.ErrAbortHandler. This is the right thing to do by
// default, but Go 1.10 and earlier did not, so existing unit tests
// weren't expecting panics. Only panic in our own tests, or when
// running under the HTTP server.
func shouldPanicOnCopyError(req *http.Request) bool {
if inOurTests {
// Our tests know to handle this panic.
return true
}
if req.Context().Value(http.ServerContextKey) != nil {
// We seem to be running under an HTTP server, so
// it'll recover the panic.
return true
}
// Otherwise act like Go 1.10 and earlier to not break
// existing tests.
return false
}
// removeConnectionHeaders removes hop-by-hop headers listed in the "Connection" header of h.
// See RFC 7230, section 6.1
func removeConnectionHeaders(h http.Header) {
for _, f := range h["Connection"] {
for _, sf := range strings.Split(f, ",") {
if sf = strings.TrimSpace(sf); sf != "" {
h.Del(sf)
}
}
}
}
// flushInterval returns the p.FlushInterval value, conditionally
// overriding its value for a specific request/response.
func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time.Duration {
resCT := res.Header.Get("Content-Type")
// For Server-Sent Events responses, flush immediately.
// The MIME type is defined in https://www.w3.org/TR/eventsource/#text-event-stream
if resCT == "text/event-stream" {
return -1 // negative means immediately
}
// TODO: more specific cases? e.g. res.ContentLength == -1?
return p.FlushInterval
}
func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader, flushInterval time.Duration) error {
if flushInterval != 0 {
if wf, ok := dst.(writeFlusher); ok { if wf, ok := dst.(writeFlusher); ok {
mlw := &maxLatencyWriter{ mlw := &maxLatencyWriter{
dst: wf, dst: wf,
latency: p.FlushInterval, latency: flushInterval,
done: make(chan bool),
} }
go mlw.flushLoop()
defer mlw.stop() defer mlw.stop()
// set up initial timer so headers get flushed even if body writes are delayed
mlw.flushPending = true
mlw.t = time.AfterFunc(flushInterval, mlw.delayedFlush)
dst = mlw dst = mlw
} }
} }
@ -336,13 +395,14 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
var buf []byte var buf []byte
if p.BufferPool != nil { if p.BufferPool != nil {
buf = p.BufferPool.Get() buf = p.BufferPool.Get()
defer p.BufferPool.Put(buf)
} }
p.copyBuffer(dst, src, buf) _, err := p.copyBuffer(dst, src, buf)
if p.BufferPool != nil { return err
p.BufferPool.Put(buf)
}
} }
// copyBuffer returns any write errors or non-EOF read errors, and the amount
// of bytes written.
func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) {
if len(buf) == 0 { if len(buf) == 0 {
buf = make([]byte, 32*1024) buf = make([]byte, 32*1024)
@ -366,6 +426,9 @@ func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int
} }
} }
if rerr != nil { if rerr != nil {
if rerr == io.EOF {
rerr = nil
}
return written, rerr return written, rerr
} }
} }
@ -386,47 +449,115 @@ type writeFlusher interface {
type maxLatencyWriter struct { type maxLatencyWriter struct {
dst writeFlusher dst writeFlusher
latency time.Duration latency time.Duration // non-zero; negative means to flush immediately
mu sync.Mutex // protects Write + Flush mu sync.Mutex // protects t, flushPending, and dst.Flush
done chan bool t *time.Timer
flushPending bool
} }
func (m *maxLatencyWriter) Write(p []byte) (int, error) { func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
return m.dst.Write(p) n, err = m.dst.Write(p)
if m.latency < 0 {
m.dst.Flush()
return
}
if m.flushPending {
return
}
if m.t == nil {
m.t = time.AfterFunc(m.latency, m.delayedFlush)
} else {
m.t.Reset(m.latency)
}
m.flushPending = true
return
} }
func (m *maxLatencyWriter) flushLoop() { func (m *maxLatencyWriter) delayedFlush() {
t := time.NewTicker(m.latency) m.mu.Lock()
defer t.Stop() defer m.mu.Unlock()
for { if !m.flushPending { // if stop was called but AfterFunc already started this goroutine
select { return
case <-m.done: }
if onExitFlushLoop != nil { m.dst.Flush()
onExitFlushLoop() m.flushPending = false
} }
return
case <-t.C: func (m *maxLatencyWriter) stop() {
m.mu.Lock() m.mu.Lock()
m.dst.Flush() defer m.mu.Unlock()
m.mu.Unlock() m.flushPending = false
} if m.t != nil {
m.t.Stop()
} }
} }
func (m *maxLatencyWriter) stop() { m.done <- true } func upgradeType(h http.Header) string {
if !httpguts.HeaderValuesContainsToken(h["Connection"], "Upgrade") {
func IsWebsocketRequest(req *http.Request) bool { return ""
containsHeader := func(name, value string) bool {
items := strings.Split(req.Header.Get(name), ",")
for _, item := range items {
if value == strings.ToLower(strings.TrimSpace(item)) {
return true
}
}
return false
} }
return containsHeader("Connection", "upgrade") && containsHeader("Upgrade", "websocket") return strings.ToLower(h.Get("Upgrade"))
}
func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) {
reqUpType := upgradeType(req.Header)
resUpType := upgradeType(res.Header)
if reqUpType != resUpType {
p.getErrorHandler()(rw, req, fmt.Errorf("backend tried to switch protocol %q when %q was requested", resUpType, reqUpType))
return
}
copyHeader(res.Header, rw.Header())
hj, ok := rw.(http.Hijacker)
if !ok {
p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw))
return
}
backConn, ok := res.Body.(io.ReadWriteCloser)
if !ok {
p.getErrorHandler()(rw, req, fmt.Errorf("internal error: 101 switching protocols response with non-writable body"))
return
}
defer backConn.Close()
conn, brw, err := hj.Hijack()
if err != nil {
p.getErrorHandler()(rw, req, fmt.Errorf("Hijack failed on protocol switch: %v", err))
return
}
defer conn.Close()
res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
if err := res.Write(brw); err != nil {
p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err))
return
}
if err := brw.Flush(); err != nil {
p.getErrorHandler()(rw, req, fmt.Errorf("response flush: %v", err))
return
}
errc := make(chan error, 1)
spc := switchProtocolCopier{user: conn, backend: backConn}
go spc.copyToBackend(errc)
go spc.copyFromBackend(errc)
<-errc
return
}
// switchProtocolCopier exists so goroutines proxying data back and
// forth have nice names in stacks.
type switchProtocolCopier struct {
user, backend io.ReadWriter
}
func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
_, err := io.Copy(c.user, c.backend)
errc <- err
}
func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
_, err := io.Copy(c.backend, c.user)
errc <- err
} }

View File

@ -34,6 +34,7 @@ This library intents to provide a **smooth, resilient, ordered, error-checked an
1. Packet level encryption support with [AES](https://en.wikipedia.org/wiki/Advanced_Encryption_Standard), [TEA](https://en.wikipedia.org/wiki/Tiny_Encryption_Algorithm), [3DES](https://en.wikipedia.org/wiki/Triple_DES), [Blowfish](https://en.wikipedia.org/wiki/Blowfish_(cipher)), [Cast5](https://en.wikipedia.org/wiki/CAST-128), [Salsa20]( https://en.wikipedia.org/wiki/Salsa20), etc. in [CFB](https://en.wikipedia.org/wiki/Block_cipher_mode_of_operation#Cipher_Feedback_.28CFB.29) mode, which generates completely anonymous packet. 1. Packet level encryption support with [AES](https://en.wikipedia.org/wiki/Advanced_Encryption_Standard), [TEA](https://en.wikipedia.org/wiki/Tiny_Encryption_Algorithm), [3DES](https://en.wikipedia.org/wiki/Triple_DES), [Blowfish](https://en.wikipedia.org/wiki/Blowfish_(cipher)), [Cast5](https://en.wikipedia.org/wiki/CAST-128), [Salsa20]( https://en.wikipedia.org/wiki/Salsa20), etc. in [CFB](https://en.wikipedia.org/wiki/Block_cipher_mode_of_operation#Cipher_Feedback_.28CFB.29) mode, which generates completely anonymous packet.
1. Only **A fixed number of goroutines** will be created for the entire server application, costs in **context switch** between goroutines have been taken into consideration. 1. Only **A fixed number of goroutines** will be created for the entire server application, costs in **context switch** between goroutines have been taken into consideration.
1. Compatible with [skywind3000's](https://github.com/skywind3000) C version with various improvements. 1. Compatible with [skywind3000's](https://github.com/skywind3000) C version with various improvements.
1. Platform-dependent optimizations: [sendmmsg](http://man7.org/linux/man-pages/man2/sendmmsg.2.html) and [recvmmsg](http://man7.org/linux/man-pages/man2/recvmmsg.2.html) were expoloited for linux.
## Documentation ## Documentation
@ -43,6 +44,24 @@ For complete documentation, see the associated [Godoc](https://godoc.org/github.
<img src="frame.png" alt="Frame Format" height="109px" /> <img src="frame.png" alt="Frame Format" height="109px" />
```
NONCE:
16bytes cryptographically secure random number, nonce changes for every packet.
CRC32:
CRC-32 checksum of data using the IEEE polynomial
FEC TYPE:
typeData = 0xF1
typeParity = 0xF2
FEC SEQID:
monotonically increasing in range: [0, (0xffffffff/shardSize) * shardSize - 1]
SIZE:
The size of KCP frame plus 2
```
``` ```
+-----------------+ +-----------------+
| SESSION | | SESSION |
@ -65,16 +84,11 @@ For complete documentation, see the associated [Godoc](https://godoc.org/github.
``` ```
## Usage ## Examples
Client: [full demo](https://github.com/xtaci/kcptun/blob/master/client/main.go) 1. [simple examples](https://github.com/xtaci/kcp-go/tree/master/examples)
```go 2. [kcptun client](https://github.com/xtaci/kcptun/blob/master/client/main.go)
kcpconn, err := kcp.DialWithOptions("192.168.0.1:10000", nil, 10, 3) 3. [kcptun server](https://github.com/xtaci/kcptun/blob/master/server/main.go)
```
Server: [full demo](https://github.com/xtaci/kcptun/blob/master/server/main.go)
```go
lis, err := kcp.ListenWithOptions(":10000", nil, 10, 3)
```
## Benchmark ## Benchmark
``` ```
@ -128,6 +142,10 @@ PASS
ok github.com/xtaci/kcp-go 50.349s ok github.com/xtaci/kcp-go 50.349s
``` ```
## Typical Flame Graph
![Flame Graph in kcptun](flame.png)
## Key Design Considerations ## Key Design Considerations
1. slice vs. container/list 1. slice vs. container/list
@ -159,6 +177,18 @@ BenchmarkNow-4 100000000 15.6 ns/op
In kcp-go, after each `kcp.output()` function call, current clock time will be updated upon return, and for a single `kcp.flush()` operation, current time will be queried from system once. For most of the time, 5000 connections costs 5000 * 15.6ns = 78us(a fixed cost while no packet needs to be sent), as for 10MB/s data transfering with 1400 MTU, `kcp.output()` will be called around 7500 times and costs 117us for `time.Now()` in **every second**. In kcp-go, after each `kcp.output()` function call, current clock time will be updated upon return, and for a single `kcp.flush()` operation, current time will be queried from system once. For most of the time, 5000 connections costs 5000 * 15.6ns = 78us(a fixed cost while no packet needs to be sent), as for 10MB/s data transfering with 1400 MTU, `kcp.output()` will be called around 7500 times and costs 117us for `time.Now()` in **every second**.
3. Memory management
Primary memory allocation are done from a global buffer pool xmit.Buf, in kcp-go, when we need to allocate some bytes, we can get from that pool, and a fixed-capacity 1500 bytes(mtuLimit) will be returned, the rx queue, tx queue and fec queue all receive bytes from there, and they will return the bytes to the pool after using to prevent unnecessary zer0ing of bytes. The pool mechanism maintained a high watermark for slice objects, these in-flight objects from the pool will survive from the perodical garbage collection, meanwhile the pool kept the ability to return the memory to runtime if in idle.
4. Information security
kcp-go is shipped with builtin packet encryption powered by various block encryption algorithms and works in [Cipher Feedback Mode](https://en.wikipedia.org/wiki/Block_cipher_mode_of_operation#Cipher_Feedback_(CFB)), for each packet to be sent, the encryption process will start from encrypting a [nonce](https://en.wikipedia.org/wiki/Cryptographic_nonce) from the [system entropy](https://en.wikipedia.org/wiki//dev/random), so encryption to same plaintexts never leads to a same ciphertexts thereafter.
The contents of the packets are completely anonymous with encryption, including the headers(FEC,KCP), checksums and contents. Note that, no matter which encryption method you choose on you upper layer, if you disable encryption, the transmit will be insecure somehow, since the header is ***PLAINTEXT*** to everyone it would be susceptible to header tampering, such as jamming the *sliding window size*, *round-trip time*, *FEC property* and *checksums*. ```AES-128``` is suggested for minimal encryption since modern CPUs are shipped with [AES-NI](https://en.wikipedia.org/wiki/AES_instruction_set) instructions and performs even better than `salsa20`(check the table above).
Other possible attacks to kcp-go includes: a) [traffic analysis](https://en.wikipedia.org/wiki/Traffic_analysis), dataflow on specific websites may have pattern while interchanging data, but this type of eavesdropping has been mitigated by adapting [smux](https://github.com/xtaci/smux) to mix data streams so as to introduce noises, perfect solution to this has not appeared yet, theroretically by shuffling/mixing messages on larger scale network may mitigate this problem. b) [replay attack](https://en.wikipedia.org/wiki/Replay_attack), since the asymmetrical encryption has not been introduced into kcp-go for some reason, capturing the packets and replay them on a different machine is possible, (notice: hijacking the session and decrypting the contents is still *impossible*), so upper layers should contain a asymmetrical encryption system to guarantee the authenticity of each message(to process message exactly once), such as HTTPS/OpenSSL/LibreSSL, only by signing the requests with private keys can eliminate this type of attack.
## Connection Termination ## Connection Termination
Control messages like **SYN/FIN/RST** in TCP **are not defined** in KCP, you need some **keepalive/heartbeat mechanism** in the application-level. A real world example is to use some **multiplexing** protocol over session, such as [smux](https://github.com/xtaci/smux)(with embedded keepalive mechanism), see [kcptun](https://github.com/xtaci/kcptun) for example. Control messages like **SYN/FIN/RST** in TCP **are not defined** in KCP, you need some **keepalive/heartbeat mechanism** in the application-level. A real world example is to use some **multiplexing** protocol over session, such as [smux](https://github.com/xtaci/smux)(with embedded keepalive mechanism), see [kcptun](https://github.com/xtaci/kcptun) for example.
@ -169,6 +199,14 @@ Q: I'm handling >5K connections on my server, the CPU utilization is so high.
A: A standalone `agent` or `gate` server for running kcp-go is suggested, not only for CPU utilization, but also important to the **precision** of RTT measurements(timing) which indirectly affects retransmission. By increasing update `interval` with `SetNoDelay` like `conn.SetNoDelay(1, 40, 1, 1)` will dramatically reduce system load, but lower the performance. A: A standalone `agent` or `gate` server for running kcp-go is suggested, not only for CPU utilization, but also important to the **precision** of RTT measurements(timing) which indirectly affects retransmission. By increasing update `interval` with `SetNoDelay` like `conn.SetNoDelay(1, 40, 1, 1)` will dramatically reduce system load, but lower the performance.
Q: When should I enable FEC?
A: Forward error correction is critical to long-distance transmission, because a packet loss will lead to a huge penalty in time. And for the complicated packet routing network in modern world, round-trip time based loss check will not always be efficient, the big deviation of RTT samples in the long way usually leads to a larger RTO value in typical rtt estimator, which in other words, slows down the transmission.
Q: Should I enable encryption?
A: Yes, for the safety of protocol, even if the upper layer has encrypted.
## Who is using this? ## Who is using this?
1. https://github.com/xtaci/kcptun -- A Secure Tunnel Based On KCP over UDP. 1. https://github.com/xtaci/kcptun -- A Secure Tunnel Based On KCP over UDP.

12
vendor/github.com/fatedier/kcp-go/batchconn.go generated vendored Normal file
View File

@ -0,0 +1,12 @@
package kcp
import "golang.org/x/net/ipv4"
const (
batchSize = 16
)
type batchConn interface {
WriteBatch(ms []ipv4.Message, flags int) (int, error)
ReadBatch(ms []ipv4.Message, flags int) (int, error)
}

View File

@ -11,36 +11,34 @@ const (
fecHeaderSize = 6 fecHeaderSize = 6
fecHeaderSizePlus2 = fecHeaderSize + 2 // plus 2B data size fecHeaderSizePlus2 = fecHeaderSize + 2 // plus 2B data size
typeData = 0xf1 typeData = 0xf1
typeFEC = 0xf2 typeParity = 0xf2
) )
type ( // fecPacket is a decoded FEC packet
// fecPacket is a decoded FEC packet type fecPacket []byte
fecPacket struct {
seqid uint32
flag uint16
data []byte
}
// fecDecoder for decoding incoming packets func (bts fecPacket) seqid() uint32 { return binary.LittleEndian.Uint32(bts) }
fecDecoder struct { func (bts fecPacket) flag() uint16 { return binary.LittleEndian.Uint16(bts[4:]) }
rxlimit int // queue size limit func (bts fecPacket) data() []byte { return bts[6:] }
dataShards int
parityShards int
shardSize int
rx []fecPacket // ordered receive queue
// caches // fecDecoder for decoding incoming packets
decodeCache [][]byte type fecDecoder struct {
flagCache []bool rxlimit int // queue size limit
dataShards int
parityShards int
shardSize int
rx []fecPacket // ordered receive queue
// zeros // caches
zeros []byte decodeCache [][]byte
flagCache []bool
// RS decoder // zeros
codec reedsolomon.Encoder zeros []byte
}
) // RS decoder
codec reedsolomon.Encoder
}
func newFECDecoder(rxlimit, dataShards, parityShards int) *fecDecoder { func newFECDecoder(rxlimit, dataShards, parityShards int) *fecDecoder {
if dataShards <= 0 || parityShards <= 0 { if dataShards <= 0 || parityShards <= 0 {
@ -66,33 +64,24 @@ func newFECDecoder(rxlimit, dataShards, parityShards int) *fecDecoder {
return dec return dec
} }
// decodeBytes a fec packet
func (dec *fecDecoder) decodeBytes(data []byte) fecPacket {
var pkt fecPacket
pkt.seqid = binary.LittleEndian.Uint32(data)
pkt.flag = binary.LittleEndian.Uint16(data[4:])
// allocate memory & copy
buf := xmitBuf.Get().([]byte)[:len(data)-6]
copy(buf, data[6:])
pkt.data = buf
return pkt
}
// decode a fec packet // decode a fec packet
func (dec *fecDecoder) decode(pkt fecPacket) (recovered [][]byte) { func (dec *fecDecoder) decode(in fecPacket) (recovered [][]byte) {
// insertion // insertion
n := len(dec.rx) - 1 n := len(dec.rx) - 1
insertIdx := 0 insertIdx := 0
for i := n; i >= 0; i-- { for i := n; i >= 0; i-- {
if pkt.seqid == dec.rx[i].seqid { // de-duplicate if in.seqid() == dec.rx[i].seqid() { // de-duplicate
xmitBuf.Put(pkt.data)
return nil return nil
} else if _itimediff(pkt.seqid, dec.rx[i].seqid) > 0 { // insertion } else if _itimediff(in.seqid(), dec.rx[i].seqid()) > 0 { // insertion
insertIdx = i + 1 insertIdx = i + 1
break break
} }
} }
// make a copy
pkt := fecPacket(xmitBuf.Get().([]byte)[:len(in)])
copy(pkt, in)
// insert into ordered rx queue // insert into ordered rx queue
if insertIdx == n+1 { if insertIdx == n+1 {
dec.rx = append(dec.rx, pkt) dec.rx = append(dec.rx, pkt)
@ -103,11 +92,11 @@ func (dec *fecDecoder) decode(pkt fecPacket) (recovered [][]byte) {
} }
// shard range for current packet // shard range for current packet
shardBegin := pkt.seqid - pkt.seqid%uint32(dec.shardSize) shardBegin := pkt.seqid() - pkt.seqid()%uint32(dec.shardSize)
shardEnd := shardBegin + uint32(dec.shardSize) - 1 shardEnd := shardBegin + uint32(dec.shardSize) - 1
// max search range in ordered queue for current shard // max search range in ordered queue for current shard
searchBegin := insertIdx - int(pkt.seqid%uint32(dec.shardSize)) searchBegin := insertIdx - int(pkt.seqid()%uint32(dec.shardSize))
if searchBegin < 0 { if searchBegin < 0 {
searchBegin = 0 searchBegin = 0
} }
@ -130,21 +119,21 @@ func (dec *fecDecoder) decode(pkt fecPacket) (recovered [][]byte) {
// shard assembly // shard assembly
for i := searchBegin; i <= searchEnd; i++ { for i := searchBegin; i <= searchEnd; i++ {
seqid := dec.rx[i].seqid seqid := dec.rx[i].seqid()
if _itimediff(seqid, shardEnd) > 0 { if _itimediff(seqid, shardEnd) > 0 {
break break
} else if _itimediff(seqid, shardBegin) >= 0 { } else if _itimediff(seqid, shardBegin) >= 0 {
shards[seqid%uint32(dec.shardSize)] = dec.rx[i].data shards[seqid%uint32(dec.shardSize)] = dec.rx[i].data()
shardsflag[seqid%uint32(dec.shardSize)] = true shardsflag[seqid%uint32(dec.shardSize)] = true
numshard++ numshard++
if dec.rx[i].flag == typeData { if dec.rx[i].flag() == typeData {
numDataShard++ numDataShard++
} }
if numshard == 1 { if numshard == 1 {
first = i first = i
} }
if len(dec.rx[i].data) > maxlen { if len(dec.rx[i].data()) > maxlen {
maxlen = len(dec.rx[i].data) maxlen = len(dec.rx[i].data())
} }
} }
} }
@ -159,11 +148,14 @@ func (dec *fecDecoder) decode(pkt fecPacket) (recovered [][]byte) {
dlen := len(shards[k]) dlen := len(shards[k])
shards[k] = shards[k][:maxlen] shards[k] = shards[k][:maxlen]
copy(shards[k][dlen:], dec.zeros) copy(shards[k][dlen:], dec.zeros)
} else {
shards[k] = xmitBuf.Get().([]byte)[:0]
} }
} }
if err := dec.codec.ReconstructData(shards); err == nil { if err := dec.codec.ReconstructData(shards); err == nil {
for k := range shards[:dec.dataShards] { for k := range shards[:dec.dataShards] {
if !shardsflag[k] { if !shardsflag[k] {
// recovered data should be recycled
recovered = append(recovered, shards[k]) recovered = append(recovered, shards[k])
} }
} }
@ -174,7 +166,7 @@ func (dec *fecDecoder) decode(pkt fecPacket) (recovered [][]byte) {
// keep rxlimit // keep rxlimit
if len(dec.rx) > dec.rxlimit { if len(dec.rx) > dec.rxlimit {
if dec.rx[0].flag == typeData { // track the unrecoverable data if dec.rx[0].flag() == typeData { // track the unrecoverable data
atomic.AddUint64(&DefaultSnmp.FECShortShards, 1) atomic.AddUint64(&DefaultSnmp.FECShortShards, 1)
} }
dec.rx = dec.freeRange(0, 1, dec.rx) dec.rx = dec.freeRange(0, 1, dec.rx)
@ -182,15 +174,16 @@ func (dec *fecDecoder) decode(pkt fecPacket) (recovered [][]byte) {
return return
} }
// free a range of fecPacket, and zero for GC recycling // free a range of fecPacket
func (dec *fecDecoder) freeRange(first, n int, q []fecPacket) []fecPacket { func (dec *fecDecoder) freeRange(first, n int, q []fecPacket) []fecPacket {
for i := first; i < first+n; i++ { // recycle buffer for i := first; i < first+n; i++ { // recycle buffer
xmitBuf.Put(q[i].data) xmitBuf.Put([]byte(q[i]))
}
if first == 0 && n < cap(q)/2 {
return q[n:]
} }
copy(q[first:], q[first+n:]) copy(q[first:], q[first+n:])
for i := 0; i < n; i++ { // dereference data
q[len(q)-1-i].data = nil
}
return q[:len(q)-n] return q[:len(q)-n]
} }
@ -229,7 +222,7 @@ func newFECEncoder(dataShards, parityShards, offset int) *fecEncoder {
enc.dataShards = dataShards enc.dataShards = dataShards
enc.parityShards = parityShards enc.parityShards = parityShards
enc.shardSize = dataShards + parityShards enc.shardSize = dataShards + parityShards
enc.paws = (0xffffffff/uint32(enc.shardSize) - 1) * uint32(enc.shardSize) enc.paws = 0xffffffff / uint32(enc.shardSize) * uint32(enc.shardSize)
enc.headerOffset = offset enc.headerOffset = offset
enc.payloadOffset = enc.headerOffset + fecHeaderSize enc.payloadOffset = enc.headerOffset + fecHeaderSize
@ -252,13 +245,16 @@ func newFECEncoder(dataShards, parityShards, offset int) *fecEncoder {
// encodes the packet, outputs parity shards if we have collected quorum datashards // encodes the packet, outputs parity shards if we have collected quorum datashards
// notice: the contents of 'ps' will be re-written in successive calling // notice: the contents of 'ps' will be re-written in successive calling
func (enc *fecEncoder) encode(b []byte) (ps [][]byte) { func (enc *fecEncoder) encode(b []byte) (ps [][]byte) {
// The header format:
// | FEC SEQID(4B) | FEC TYPE(2B) | SIZE (2B) | PAYLOAD(SIZE-2) |
// |<-headerOffset |<-payloadOffset
enc.markData(b[enc.headerOffset:]) enc.markData(b[enc.headerOffset:])
binary.LittleEndian.PutUint16(b[enc.payloadOffset:], uint16(len(b[enc.payloadOffset:]))) binary.LittleEndian.PutUint16(b[enc.payloadOffset:], uint16(len(b[enc.payloadOffset:])))
// copy data to fec datashards // copy data from payloadOffset to fec shard cache
sz := len(b) sz := len(b)
enc.shardCache[enc.shardCount] = enc.shardCache[enc.shardCount][:sz] enc.shardCache[enc.shardCount] = enc.shardCache[enc.shardCount][:sz]
copy(enc.shardCache[enc.shardCount], b) copy(enc.shardCache[enc.shardCount][enc.payloadOffset:], b[enc.payloadOffset:])
enc.shardCount++ enc.shardCount++
// track max datashard length // track max datashard length
@ -285,7 +281,7 @@ func (enc *fecEncoder) encode(b []byte) (ps [][]byte) {
if err := enc.codec.Encode(cache); err == nil { if err := enc.codec.Encode(cache); err == nil {
ps = enc.shardCache[enc.dataShards:] ps = enc.shardCache[enc.dataShards:]
for k := range ps { for k := range ps {
enc.markFEC(ps[k][enc.headerOffset:]) enc.markParity(ps[k][enc.headerOffset:])
ps[k] = ps[k][:enc.maxSize] ps[k] = ps[k][:enc.maxSize]
} }
} }
@ -304,8 +300,9 @@ func (enc *fecEncoder) markData(data []byte) {
enc.next++ enc.next++
} }
func (enc *fecEncoder) markFEC(data []byte) { func (enc *fecEncoder) markParity(data []byte) {
binary.LittleEndian.PutUint32(data, enc.next) binary.LittleEndian.PutUint32(data, enc.next)
binary.LittleEndian.PutUint16(data[4:], typeFEC) binary.LittleEndian.PutUint16(data[4:], typeParity)
// sequence wrap will only happen at parity shard
enc.next = (enc.next + 1) % enc.paws enc.next = (enc.next + 1) % enc.paws
} }

BIN
vendor/github.com/fatedier/kcp-go/flame.png generated vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 56 KiB

View File

@ -1,9 +1,9 @@
// Package kcp - A Fast and Reliable ARQ Protocol
package kcp package kcp
import ( import (
"encoding/binary" "encoding/binary"
"sync/atomic" "sync/atomic"
"time"
) )
const ( const (
@ -30,6 +30,12 @@ const (
IKCP_PROBE_LIMIT = 120000 // up to 120 secs to probe window IKCP_PROBE_LIMIT = 120000 // up to 120 secs to probe window
) )
// monotonic reference time point
var refTime time.Time = time.Now()
// currentMs returns current elasped monotonic milliseconds since program startup
func currentMs() uint32 { return uint32(time.Now().Sub(refTime) / time.Millisecond) }
// output_callback is a prototype which ought capture conn and call conn.Write // output_callback is a prototype which ought capture conn and call conn.Write
type output_callback func(buf []byte, size int) type output_callback func(buf []byte, size int)
@ -145,8 +151,9 @@ type KCP struct {
acklist []ackItem acklist []ackItem
buffer []byte buffer []byte
output output_callback reserved int
output output_callback
} }
type ackItem struct { type ackItem struct {
@ -154,8 +161,11 @@ type ackItem struct {
ts uint32 ts uint32
} }
// NewKCP create a new kcp control object, 'conv' must equal in two endpoint // NewKCP create a new kcp state machine
// from the same connection. //
// 'conv' must be equal in the connection peers, or else data will be silently rejected.
//
// 'output' function will be called whenever these is data to be sent on wire.
func NewKCP(conv uint32, output output_callback) *KCP { func NewKCP(conv uint32, output output_callback) *KCP {
kcp := new(KCP) kcp := new(KCP)
kcp.conv = conv kcp.conv = conv
@ -164,7 +174,7 @@ func NewKCP(conv uint32, output output_callback) *KCP {
kcp.rmt_wnd = IKCP_WND_RCV kcp.rmt_wnd = IKCP_WND_RCV
kcp.mtu = IKCP_MTU_DEF kcp.mtu = IKCP_MTU_DEF
kcp.mss = kcp.mtu - IKCP_OVERHEAD kcp.mss = kcp.mtu - IKCP_OVERHEAD
kcp.buffer = make([]byte, (kcp.mtu+IKCP_OVERHEAD)*3) kcp.buffer = make([]byte, kcp.mtu)
kcp.rx_rto = IKCP_RTO_DEF kcp.rx_rto = IKCP_RTO_DEF
kcp.rx_minrto = IKCP_RTO_MIN kcp.rx_minrto = IKCP_RTO_MIN
kcp.interval = IKCP_INTERVAL kcp.interval = IKCP_INTERVAL
@ -189,6 +199,19 @@ func (kcp *KCP) delSegment(seg *segment) {
} }
} }
// ReserveBytes keeps n bytes untouched from the beginning of the buffer,
// the output_callback function should be aware of this.
//
// Return false if n >= mss
func (kcp *KCP) ReserveBytes(n int) bool {
if n >= int(kcp.mtu-IKCP_OVERHEAD) || n < 0 {
return false
}
kcp.reserved = n
kcp.mss = kcp.mtu - IKCP_OVERHEAD - uint32(n)
return true
}
// PeekSize checks the size of next message in the recv queue // PeekSize checks the size of next message in the recv queue
func (kcp *KCP) PeekSize() (length int) { func (kcp *KCP) PeekSize() (length int) {
if len(kcp.rcv_queue) == 0 { if len(kcp.rcv_queue) == 0 {
@ -214,19 +237,21 @@ func (kcp *KCP) PeekSize() (length int) {
return return
} }
// Recv is user/upper level recv: returns size, returns below zero for EAGAIN // Receive data from kcp state machine
//
// Return number of bytes read.
//
// Return -1 when there is no readable data.
//
// Return -2 if len(buffer) is smaller than kcp.PeekSize().
func (kcp *KCP) Recv(buffer []byte) (n int) { func (kcp *KCP) Recv(buffer []byte) (n int) {
if len(kcp.rcv_queue) == 0 { peeksize := kcp.PeekSize()
if peeksize < 0 {
return -1 return -1
} }
peeksize := kcp.PeekSize()
if peeksize < 0 {
return -2
}
if peeksize > len(buffer) { if peeksize > len(buffer) {
return -3 return -2
} }
var fast_recover bool var fast_recover bool
@ -255,7 +280,7 @@ func (kcp *KCP) Recv(buffer []byte) (n int) {
count = 0 count = 0
for k := range kcp.rcv_buf { for k := range kcp.rcv_buf {
seg := &kcp.rcv_buf[k] seg := &kcp.rcv_buf[k]
if seg.sn == kcp.rcv_nxt && len(kcp.rcv_queue) < int(kcp.rcv_wnd) { if seg.sn == kcp.rcv_nxt && len(kcp.rcv_queue)+count < int(kcp.rcv_wnd) {
kcp.rcv_nxt++ kcp.rcv_nxt++
count++ count++
} else { } else {
@ -386,6 +411,10 @@ func (kcp *KCP) parse_ack(sn uint32) {
for k := range kcp.snd_buf { for k := range kcp.snd_buf {
seg := &kcp.snd_buf[k] seg := &kcp.snd_buf[k]
if sn == seg.sn { if sn == seg.sn {
// mark and free space, but leave the segment here,
// and wait until `una` to delete this, then we don't
// have to shift the segments behind forward,
// which is an expensive operation for large window
seg.acked = 1 seg.acked = 1
kcp.delSegment(seg) kcp.delSegment(seg)
break break
@ -474,7 +503,7 @@ func (kcp *KCP) parse_data(newseg segment) bool {
count := 0 count := 0
for k := range kcp.rcv_buf { for k := range kcp.rcv_buf {
seg := &kcp.rcv_buf[k] seg := &kcp.rcv_buf[k]
if seg.sn == kcp.rcv_nxt && len(kcp.rcv_queue) < int(kcp.rcv_wnd) { if seg.sn == kcp.rcv_nxt && len(kcp.rcv_queue)+count < int(kcp.rcv_wnd) {
kcp.rcv_nxt++ kcp.rcv_nxt++
count++ count++
} else { } else {
@ -489,8 +518,12 @@ func (kcp *KCP) parse_data(newseg segment) bool {
return repeat return repeat
} }
// Input when you received a low level packet (eg. UDP packet), call it // Input a packet into kcp state machine.
// regular indicates a regular packet has received(not from FEC) //
// 'regular' indicates it's a real data packet from remote, and it means it's not generated from ReedSolomon
// codecs.
//
// 'ackNoDelay' will trigger immediate ACK, but surely it will not be efficient in bandwidth
func (kcp *KCP) Input(data []byte, regular, ackNoDelay bool) int { func (kcp *KCP) Input(data []byte, regular, ackNoDelay bool) int {
snd_una := kcp.snd_una snd_una := kcp.snd_una
if len(data) < IKCP_OVERHEAD { if len(data) < IKCP_OVERHEAD {
@ -634,14 +667,28 @@ func (kcp *KCP) flush(ackOnly bool) uint32 {
seg.una = kcp.rcv_nxt seg.una = kcp.rcv_nxt
buffer := kcp.buffer buffer := kcp.buffer
// flush acknowledges ptr := buffer[kcp.reserved:] // keep n bytes untouched
ptr := buffer
for i, ack := range kcp.acklist { // makeSpace makes room for writing
makeSpace := func(space int) {
size := len(buffer) - len(ptr) size := len(buffer) - len(ptr)
if size+IKCP_OVERHEAD > int(kcp.mtu) { if size+space > int(kcp.mtu) {
kcp.output(buffer, size) kcp.output(buffer, size)
ptr = buffer ptr = buffer[kcp.reserved:]
} }
}
// flush bytes in buffer if there is any
flushBuffer := func() {
size := len(buffer) - len(ptr)
if size > kcp.reserved {
kcp.output(buffer, size)
}
}
// flush acknowledges
for i, ack := range kcp.acklist {
makeSpace(IKCP_OVERHEAD)
// filter jitters caused by bufferbloat // filter jitters caused by bufferbloat
if ack.sn >= kcp.rcv_nxt || len(kcp.acklist)-1 == i { if ack.sn >= kcp.rcv_nxt || len(kcp.acklist)-1 == i {
seg.sn, seg.ts = ack.sn, ack.ts seg.sn, seg.ts = ack.sn, ack.ts
@ -651,10 +698,7 @@ func (kcp *KCP) flush(ackOnly bool) uint32 {
kcp.acklist = kcp.acklist[0:0] kcp.acklist = kcp.acklist[0:0]
if ackOnly { // flash remain ack segments if ackOnly { // flash remain ack segments
size := len(buffer) - len(ptr) flushBuffer()
if size > 0 {
kcp.output(buffer, size)
}
return kcp.interval return kcp.interval
} }
@ -685,22 +729,14 @@ func (kcp *KCP) flush(ackOnly bool) uint32 {
// flush window probing commands // flush window probing commands
if (kcp.probe & IKCP_ASK_SEND) != 0 { if (kcp.probe & IKCP_ASK_SEND) != 0 {
seg.cmd = IKCP_CMD_WASK seg.cmd = IKCP_CMD_WASK
size := len(buffer) - len(ptr) makeSpace(IKCP_OVERHEAD)
if size+IKCP_OVERHEAD > int(kcp.mtu) {
kcp.output(buffer, size)
ptr = buffer
}
ptr = seg.encode(ptr) ptr = seg.encode(ptr)
} }
// flush window probing commands // flush window probing commands
if (kcp.probe & IKCP_ASK_TELL) != 0 { if (kcp.probe & IKCP_ASK_TELL) != 0 {
seg.cmd = IKCP_CMD_WINS seg.cmd = IKCP_CMD_WINS
size := len(buffer) - len(ptr) makeSpace(IKCP_OVERHEAD)
if size+IKCP_OVERHEAD > int(kcp.mtu) {
kcp.output(buffer, size)
ptr = buffer
}
ptr = seg.encode(ptr) ptr = seg.encode(ptr)
} }
@ -779,20 +815,14 @@ func (kcp *KCP) flush(ackOnly bool) uint32 {
} }
if needsend { if needsend {
current = currentMs() // time update for a blocking call current = currentMs()
segment.xmit++ segment.xmit++
segment.ts = current segment.ts = current
segment.wnd = seg.wnd segment.wnd = seg.wnd
segment.una = seg.una segment.una = seg.una
size := len(buffer) - len(ptr)
need := IKCP_OVERHEAD + len(segment.data) need := IKCP_OVERHEAD + len(segment.data)
makeSpace(need)
if size+need > int(kcp.mtu) {
kcp.output(buffer, size)
ptr = buffer
}
ptr = segment.encode(ptr) ptr = segment.encode(ptr)
copy(ptr, segment.data) copy(ptr, segment.data)
ptr = ptr[len(segment.data):] ptr = ptr[len(segment.data):]
@ -809,10 +839,7 @@ func (kcp *KCP) flush(ackOnly bool) uint32 {
} }
// flash remain segments // flash remain segments
size := len(buffer) - len(ptr) flushBuffer()
if size > 0 {
kcp.output(buffer, size)
}
// counter updates // counter updates
sum := lostSegs sum := lostSegs
@ -864,6 +891,8 @@ func (kcp *KCP) flush(ackOnly bool) uint32 {
return uint32(minrto) return uint32(minrto)
} }
// (deprecated)
//
// Update updates state (call it repeatedly, every 10ms-100ms), or you can ask // Update updates state (call it repeatedly, every 10ms-100ms), or you can ask
// ikcp_check when to call it again (without ikcp_input/_send calling). // ikcp_check when to call it again (without ikcp_input/_send calling).
// 'current' - current timestamp in millisec. // 'current' - current timestamp in millisec.
@ -892,6 +921,8 @@ func (kcp *KCP) Update() {
} }
} }
// (deprecated)
//
// Check determines when should you invoke ikcp_update: // Check determines when should you invoke ikcp_update:
// returns when you should invoke ikcp_update in millisec, if there // returns when you should invoke ikcp_update in millisec, if there
// is no ikcp_input/_send calling. you can call ikcp_update in that // is no ikcp_input/_send calling. you can call ikcp_update in that
@ -947,12 +978,16 @@ func (kcp *KCP) SetMtu(mtu int) int {
if mtu < 50 || mtu < IKCP_OVERHEAD { if mtu < 50 || mtu < IKCP_OVERHEAD {
return -1 return -1
} }
buffer := make([]byte, (mtu+IKCP_OVERHEAD)*3) if kcp.reserved >= int(kcp.mtu-IKCP_OVERHEAD) || kcp.reserved < 0 {
return -1
}
buffer := make([]byte, mtu)
if buffer == nil { if buffer == nil {
return -2 return -2
} }
kcp.mtu = uint32(mtu) kcp.mtu = uint32(mtu)
kcp.mss = kcp.mtu - IKCP_OVERHEAD kcp.mss = kcp.mtu - IKCP_OVERHEAD - uint32(kcp.reserved)
kcp.buffer = buffer kcp.buffer = buffer
return 0 return 0
} }
@ -1006,7 +1041,13 @@ func (kcp *KCP) WaitSnd() int {
} }
// remove front n elements from queue // remove front n elements from queue
// if the number of elements to remove is more than half of the size.
// just shift the rear elements to front, otherwise just reslice q to q[n:]
// then the cost of runtime.growslice can always be less than n/2
func (kcp *KCP) remove_front(q []segment, n int) []segment { func (kcp *KCP) remove_front(q []segment, n int) []segment {
newn := copy(q, q[n:]) if n > cap(q)/2 {
return q[:newn] newn := copy(q, q[n:])
return q[:newn]
}
return q[n:]
} }

48
vendor/github.com/fatedier/kcp-go/readloop.go generated vendored Normal file
View File

@ -0,0 +1,48 @@
package kcp
import (
"sync/atomic"
"github.com/pkg/errors"
)
func (s *UDPSession) defaultReadLoop() {
buf := make([]byte, mtuLimit)
var src string
for {
if n, addr, err := s.conn.ReadFrom(buf); err == nil {
// make sure the packet is from the same source
if src == "" { // set source address
src = addr.String()
} else if addr.String() != src {
atomic.AddUint64(&DefaultSnmp.InErrs, 1)
continue
}
if n >= s.headerSize+IKCP_OVERHEAD {
s.packetInput(buf[:n])
} else {
atomic.AddUint64(&DefaultSnmp.InErrs, 1)
}
} else {
s.notifyReadError(errors.WithStack(err))
return
}
}
}
func (l *Listener) defaultMonitor() {
buf := make([]byte, mtuLimit)
for {
if n, from, err := l.conn.ReadFrom(buf); err == nil {
if n >= l.headerSize+IKCP_OVERHEAD {
l.packetInput(buf[:n], from)
} else {
atomic.AddUint64(&DefaultSnmp.InErrs, 1)
}
} else {
l.notifyReadError(errors.WithStack(err))
return
}
}
}

11
vendor/github.com/fatedier/kcp-go/readloop_generic.go generated vendored Normal file
View File

@ -0,0 +1,11 @@
// +build !linux
package kcp
func (s *UDPSession) readLoop() {
s.defaultReadLoop()
}
func (l *Listener) monitor() {
l.defaultMonitor()
}

120
vendor/github.com/fatedier/kcp-go/readloop_linux.go generated vendored Normal file
View File

@ -0,0 +1,120 @@
// +build linux
package kcp
import (
"net"
"os"
"sync/atomic"
"github.com/pkg/errors"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// the read loop for a client session
func (s *UDPSession) readLoop() {
// default version
if s.xconn == nil {
s.defaultReadLoop()
return
}
// x/net version
var src string
msgs := make([]ipv4.Message, batchSize)
for k := range msgs {
msgs[k].Buffers = [][]byte{make([]byte, mtuLimit)}
}
for {
if count, err := s.xconn.ReadBatch(msgs, 0); err == nil {
for i := 0; i < count; i++ {
msg := &msgs[i]
// make sure the packet is from the same source
if src == "" { // set source address if nil
src = msg.Addr.String()
} else if msg.Addr.String() != src {
atomic.AddUint64(&DefaultSnmp.InErrs, 1)
continue
}
if msg.N < s.headerSize+IKCP_OVERHEAD {
atomic.AddUint64(&DefaultSnmp.InErrs, 1)
continue
}
// source and size has validated
s.packetInput(msg.Buffers[0][:msg.N])
}
} else {
// compatibility issue:
// for linux kernel<=2.6.32, support for sendmmsg is not available
// an error of type os.SyscallError will be returned
if operr, ok := err.(*net.OpError); ok {
if se, ok := operr.Err.(*os.SyscallError); ok {
if se.Syscall == "recvmmsg" {
s.defaultReadLoop()
return
}
}
}
s.notifyReadError(errors.WithStack(err))
return
}
}
}
// monitor incoming data for all connections of server
func (l *Listener) monitor() {
var xconn batchConn
if _, ok := l.conn.(*net.UDPConn); ok {
addr, err := net.ResolveUDPAddr("udp", l.conn.LocalAddr().String())
if err == nil {
if addr.IP.To4() != nil {
xconn = ipv4.NewPacketConn(l.conn)
} else {
xconn = ipv6.NewPacketConn(l.conn)
}
}
}
// default version
if xconn == nil {
l.defaultMonitor()
return
}
// x/net version
msgs := make([]ipv4.Message, batchSize)
for k := range msgs {
msgs[k].Buffers = [][]byte{make([]byte, mtuLimit)}
}
for {
if count, err := xconn.ReadBatch(msgs, 0); err == nil {
for i := 0; i < count; i++ {
msg := &msgs[i]
if msg.N >= l.headerSize+IKCP_OVERHEAD {
l.packetInput(msg.Buffers[0][:msg.N], msg.Addr)
} else {
atomic.AddUint64(&DefaultSnmp.InErrs, 1)
}
}
} else {
// compatibility issue:
// for linux kernel<=2.6.32, support for sendmmsg is not available
// an error of type os.SyscallError will be returned
if operr, ok := err.(*net.OpError); ok {
if se, ok := operr.Err.(*os.SyscallError); ok {
if se.Syscall == "recvmmsg" {
l.defaultMonitor()
return
}
}
}
l.notifyReadError(errors.WithStack(err))
return
}
}
}

View File

@ -1,9 +1,17 @@
// Package kcp-go is a Reliable-UDP library for golang.
//
// This library intents to provide a smooth, resilient, ordered,
// error-checked and anonymous delivery of streams over UDP packets.
//
// The interfaces of this package aims to be compatible with
// net.Conn in standard library, but offers powerful features for advanced users.
package kcp package kcp
import ( import (
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"hash/crc32" "hash/crc32"
"io"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -14,14 +22,6 @@ import (
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
) )
type errTimeout struct {
error
}
func (errTimeout) Timeout() bool { return true }
func (errTimeout) Temporary() bool { return true }
func (errTimeout) Error() string { return "i/o timeout" }
const ( const (
// 16-bytes nonce for each packet // 16-bytes nonce for each packet
nonceSize = 16 nonceSize = 16
@ -42,9 +42,9 @@ const (
acceptBacklog = 128 acceptBacklog = 128
) )
const ( var (
errBrokenPipe = "broken pipe" errInvalidOperation = errors.New("invalid operation")
errInvalidOperation = "invalid operation" errTimeout = errors.New("timeout")
) )
var ( var (
@ -72,8 +72,6 @@ type (
// recvbuf turns packets into stream // recvbuf turns packets into stream
recvbuf []byte recvbuf []byte
bufptr []byte bufptr []byte
// header extended output buffer, if has header
ext []byte
// FEC codec // FEC codec
fecDecoder *fecDecoder fecDecoder *fecDecoder
@ -90,16 +88,27 @@ type (
// notifications // notifications
die chan struct{} // notify current session has Closed die chan struct{} // notify current session has Closed
dieOnce sync.Once
chReadEvent chan struct{} // notify Read() can be called without blocking chReadEvent chan struct{} // notify Read() can be called without blocking
chWriteEvent chan struct{} // notify Write() can be called without blocking chWriteEvent chan struct{} // notify Write() can be called without blocking
chReadError chan error // notify PacketConn.Read() have an error
chWriteError chan error // notify PacketConn.Write() have an error // socket error handling
socketReadError atomic.Value
socketWriteError atomic.Value
chSocketReadError chan struct{}
chSocketWriteError chan struct{}
socketReadErrorOnce sync.Once
socketWriteErrorOnce sync.Once
// nonce generator // nonce generator
nonce Entropy nonce Entropy
isClosed bool // flag the session has Closed // packets waiting to be sent on wire
mu sync.Mutex txqueue []ipv4.Message
xconn batchConn // for x/net
xconnWriteError error
mu sync.Mutex
} }
setReadBuffer interface { setReadBuffer interface {
@ -119,14 +128,26 @@ func newUDPSession(conv uint32, dataShards, parityShards int, l *Listener, conn
sess.nonce.Init() sess.nonce.Init()
sess.chReadEvent = make(chan struct{}, 1) sess.chReadEvent = make(chan struct{}, 1)
sess.chWriteEvent = make(chan struct{}, 1) sess.chWriteEvent = make(chan struct{}, 1)
sess.chReadError = make(chan error, 1) sess.chSocketReadError = make(chan struct{})
sess.chWriteError = make(chan error, 1) sess.chSocketWriteError = make(chan struct{})
sess.remote = remote sess.remote = remote
sess.conn = conn sess.conn = conn
sess.l = l sess.l = l
sess.block = block sess.block = block
sess.recvbuf = make([]byte, mtuLimit) sess.recvbuf = make([]byte, mtuLimit)
// cast to writebatch conn
if _, ok := conn.(*net.UDPConn); ok {
addr, err := net.ResolveUDPAddr("udp", conn.LocalAddr().String())
if err == nil {
if addr.IP.To4() != nil {
sess.xconn = ipv4.NewPacketConn(conn)
} else {
sess.xconn = ipv6.NewPacketConn(conn)
}
}
}
// FEC codec initialization // FEC codec initialization
sess.fecDecoder = newFECDecoder(rxFECMulti*(dataShards+parityShards), dataShards, parityShards) sess.fecDecoder = newFECDecoder(rxFECMulti*(dataShards+parityShards), dataShards, parityShards)
if sess.block != nil { if sess.block != nil {
@ -143,17 +164,12 @@ func newUDPSession(conv uint32, dataShards, parityShards int, l *Listener, conn
sess.headerSize += fecHeaderSizePlus2 sess.headerSize += fecHeaderSizePlus2
} }
// we only need to allocate extended packet buffer if we have the additional header
if sess.headerSize > 0 {
sess.ext = make([]byte, mtuLimit)
}
sess.kcp = NewKCP(conv, func(buf []byte, size int) { sess.kcp = NewKCP(conv, func(buf []byte, size int) {
if size >= IKCP_OVERHEAD { if size >= IKCP_OVERHEAD+sess.headerSize {
sess.output(buf[:size]) sess.output(buf[:size])
} }
}) })
sess.kcp.SetMtu(IKCP_MTU_DEF - sess.headerSize) sess.kcp.ReserveBytes(sess.headerSize)
// register current session to the global updater, // register current session to the global updater,
// which call sess.update() periodically. // which call sess.update() periodically.
@ -165,6 +181,7 @@ func newUDPSession(conv uint32, dataShards, parityShards int, l *Listener, conn
} else { } else {
atomic.AddUint64(&DefaultSnmp.PassiveOpens, 1) atomic.AddUint64(&DefaultSnmp.PassiveOpens, 1)
} }
currestab := atomic.AddUint64(&DefaultSnmp.CurrEstab, 1) currestab := atomic.AddUint64(&DefaultSnmp.CurrEstab, 1)
maxconn := atomic.LoadUint64(&DefaultSnmp.MaxConn) maxconn := atomic.LoadUint64(&DefaultSnmp.MaxConn)
if currestab > maxconn { if currestab > maxconn {
@ -186,11 +203,6 @@ func (s *UDPSession) Read(b []byte) (n int, err error) {
return n, nil return n, nil
} }
if s.isClosed {
s.mu.Unlock()
return 0, errors.New(errBrokenPipe)
}
if size := s.kcp.PeekSize(); size > 0 { // peek data size from kcp if size := s.kcp.PeekSize(); size > 0 { // peek data size from kcp
if len(b) >= size { // receive data into 'b' directly if len(b) >= size { // receive data into 'b' directly
s.kcp.Recv(b) s.kcp.Recv(b)
@ -220,7 +232,7 @@ func (s *UDPSession) Read(b []byte) (n int, err error) {
if !s.rd.IsZero() { if !s.rd.IsZero() {
if time.Now().After(s.rd) { if time.Now().After(s.rd) {
s.mu.Unlock() s.mu.Unlock()
return 0, errTimeout{} return 0, errors.WithStack(errTimeout)
} }
delay := s.rd.Sub(time.Now()) delay := s.rd.Sub(time.Now())
@ -229,63 +241,66 @@ func (s *UDPSession) Read(b []byte) (n int, err error) {
} }
s.mu.Unlock() s.mu.Unlock()
// wait for read event or timeout // wait for read event or timeout or error
select { select {
case <-s.chReadEvent: case <-s.chReadEvent:
case <-c:
case <-s.die:
case err = <-s.chReadError:
if timeout != nil { if timeout != nil {
timeout.Stop() timeout.Stop()
} }
return n, err case <-c:
} return 0, errors.WithStack(errTimeout)
case <-s.chSocketReadError:
if timeout != nil { return 0, s.socketReadError.Load().(error)
timeout.Stop() case <-s.die:
return 0, errors.WithStack(io.ErrClosedPipe)
} }
} }
} }
// Write implements net.Conn // Write implements net.Conn
func (s *UDPSession) Write(b []byte) (n int, err error) { func (s *UDPSession) Write(b []byte) (n int, err error) { return s.WriteBuffers([][]byte{b}) }
// WriteBuffers write a vector of byte slices to the underlying connection
func (s *UDPSession) WriteBuffers(v [][]byte) (n int, err error) {
for { for {
s.mu.Lock() select {
if s.isClosed { case <-s.chSocketWriteError:
s.mu.Unlock() return 0, s.socketWriteError.Load().(error)
return 0, errors.New(errBrokenPipe) case <-s.die:
return 0, errors.WithStack(io.ErrClosedPipe)
default:
} }
// controls how much data will be sent to kcp core s.mu.Lock()
// to prevent the memory from exhuasting
if s.kcp.WaitSnd() < int(s.kcp.snd_wnd) { if s.kcp.WaitSnd() < int(s.kcp.snd_wnd) {
n = len(b) for _, b := range v {
for { n += len(b)
if len(b) <= int(s.kcp.mss) { for {
s.kcp.Send(b) if len(b) <= int(s.kcp.mss) {
break s.kcp.Send(b)
} else { break
s.kcp.Send(b[:s.kcp.mss]) } else {
b = b[s.kcp.mss:] s.kcp.Send(b[:s.kcp.mss])
b = b[s.kcp.mss:]
}
} }
} }
// flush immediately if the queue is full
if s.kcp.WaitSnd() >= int(s.kcp.snd_wnd) || !s.writeDelay { if s.kcp.WaitSnd() >= int(s.kcp.snd_wnd) || !s.writeDelay {
s.kcp.flush(false) s.kcp.flush(false)
s.uncork()
} }
s.mu.Unlock() s.mu.Unlock()
atomic.AddUint64(&DefaultSnmp.BytesSent, uint64(n)) atomic.AddUint64(&DefaultSnmp.BytesSent, uint64(n))
return n, nil return n, nil
} }
// deadline for current writing operation
var timeout *time.Timer var timeout *time.Timer
var c <-chan time.Time var c <-chan time.Time
if !s.wd.IsZero() { if !s.wd.IsZero() {
if time.Now().After(s.wd) { if time.Now().After(s.wd) {
s.mu.Unlock() s.mu.Unlock()
return 0, errTimeout{} return 0, errors.WithStack(errTimeout)
} }
delay := s.wd.Sub(time.Now()) delay := s.wd.Sub(time.Now())
timeout = time.NewTimer(delay) timeout = time.NewTimer(delay)
@ -293,44 +308,52 @@ func (s *UDPSession) Write(b []byte) (n int, err error) {
} }
s.mu.Unlock() s.mu.Unlock()
// wait for write event or timeout
select { select {
case <-s.chWriteEvent: case <-s.chWriteEvent:
case <-c:
case <-s.die:
case err = <-s.chWriteError:
if timeout != nil { if timeout != nil {
timeout.Stop() timeout.Stop()
} }
return n, err case <-c:
} return 0, errors.WithStack(errTimeout)
case <-s.chSocketWriteError:
if timeout != nil { return 0, s.socketWriteError.Load().(error)
timeout.Stop() case <-s.die:
return 0, errors.WithStack(io.ErrClosedPipe)
} }
} }
} }
// uncork sends data in txqueue if there is any
func (s *UDPSession) uncork() {
if len(s.txqueue) > 0 {
s.tx(s.txqueue)
s.txqueue = s.txqueue[:0]
}
return
}
// Close closes the connection. // Close closes the connection.
func (s *UDPSession) Close() error { func (s *UDPSession) Close() error {
// remove current session from updater & listener(if necessary) var once bool
updater.removeSession(s) s.dieOnce.Do(func() {
if s.l != nil { // notify listener close(s.die)
s.l.closeSession(s.remote) once = true
} })
s.mu.Lock() if once {
defer s.mu.Unlock() // remove from updater
if s.isClosed { updater.removeSession(s)
return errors.New(errBrokenPipe) atomic.AddUint64(&DefaultSnmp.CurrEstab, ^uint64(0))
if s.l != nil { // belongs to listener
s.l.closeSession(s.remote)
return nil
} else { // client socket close
return s.conn.Close()
}
} else {
return errors.WithStack(io.ErrClosedPipe)
} }
close(s.die)
s.isClosed = true
atomic.AddUint64(&DefaultSnmp.CurrEstab, ^uint64(0))
if s.l == nil { // client socket close
return s.conn.Close()
}
return nil
} }
// LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it. // LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it.
@ -390,7 +413,7 @@ func (s *UDPSession) SetMtu(mtu int) bool {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.kcp.SetMtu(mtu - s.headerSize) s.kcp.SetMtu(mtu)
return true return true
} }
@ -412,7 +435,9 @@ func (s *UDPSession) SetACKNoDelay(nodelay bool) {
s.ackNoDelay = nodelay s.ackNoDelay = nodelay
} }
// SetDUP duplicates udp packets for kcp output, for testing purpose only // (deprecated)
//
// SetDUP duplicates udp packets for kcp output.
func (s *UDPSession) SetDUP(dup int) { func (s *UDPSession) SetDUP(dup int) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@ -427,19 +452,29 @@ func (s *UDPSession) SetNoDelay(nodelay, interval, resend, nc int) {
s.kcp.NoDelay(nodelay, interval, resend, nc) s.kcp.NoDelay(nodelay, interval, resend, nc)
} }
// SetDSCP sets the 6bit DSCP field of IP header, no effect if it's accepted from Listener // SetDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header.
//
// It has no effect if it's accepted from Listener.
func (s *UDPSession) SetDSCP(dscp int) error { func (s *UDPSession) SetDSCP(dscp int) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.l == nil { if s.l != nil {
if nc, ok := s.conn.(net.Conn); ok { return errInvalidOperation
if err := ipv4.NewConn(nc).SetTOS(dscp << 2); err != nil { }
return ipv6.NewConn(nc).SetTrafficClass(dscp) if nc, ok := s.conn.(net.Conn); ok {
} var succeed bool
if err := ipv4.NewConn(nc).SetTOS(dscp << 2); err == nil {
succeed = true
}
if err := ipv6.NewConn(nc).SetTrafficClass(dscp); err == nil {
succeed = true
}
if succeed {
return nil return nil
} }
} }
return errors.New(errInvalidOperation) return errInvalidOperation
} }
// SetReadBuffer sets the socket read buffer, no effect if it's accepted from Listener // SetReadBuffer sets the socket read buffer, no effect if it's accepted from Listener
@ -451,7 +486,7 @@ func (s *UDPSession) SetReadBuffer(bytes int) error {
return nc.SetReadBuffer(bytes) return nc.SetReadBuffer(bytes)
} }
} }
return errors.New(errInvalidOperation) return errInvalidOperation
} }
// SetWriteBuffer sets the socket write buffer, no effect if it's accepted from Listener // SetWriteBuffer sets the socket write buffer, no effect if it's accepted from Listener
@ -463,37 +498,29 @@ func (s *UDPSession) SetWriteBuffer(bytes int) error {
return nc.SetWriteBuffer(bytes) return nc.SetWriteBuffer(bytes)
} }
} }
return errors.New(errInvalidOperation) return errInvalidOperation
} }
// post-processing for sending a packet from kcp core // post-processing for sending a packet from kcp core
// steps: // steps:
// 0. Header extending
// 1. FEC packet generation // 1. FEC packet generation
// 2. CRC32 integrity // 2. CRC32 integrity
// 3. Encryption // 3. Encryption
// 4. WriteTo kernel // 4. TxQueue
func (s *UDPSession) output(buf []byte) { func (s *UDPSession) output(buf []byte) {
var ecc [][]byte var ecc [][]byte
// 0. extend buf's header space(if necessary)
ext := buf
if s.headerSize > 0 {
ext = s.ext[:s.headerSize+len(buf)]
copy(ext[s.headerSize:], buf)
}
// 1. FEC encoding // 1. FEC encoding
if s.fecEncoder != nil { if s.fecEncoder != nil {
ecc = s.fecEncoder.encode(ext) ecc = s.fecEncoder.encode(buf)
} }
// 2&3. crc32 & encryption // 2&3. crc32 & encryption
if s.block != nil { if s.block != nil {
s.nonce.Fill(ext[:nonceSize]) s.nonce.Fill(buf[:nonceSize])
checksum := crc32.ChecksumIEEE(ext[cryptHeaderSize:]) checksum := crc32.ChecksumIEEE(buf[cryptHeaderSize:])
binary.LittleEndian.PutUint32(ext[nonceSize:], checksum) binary.LittleEndian.PutUint32(buf[nonceSize:], checksum)
s.block.Encrypt(ext, ext) s.block.Encrypt(buf, buf)
for k := range ecc { for k := range ecc {
s.nonce.Fill(ecc[k][:nonceSize]) s.nonce.Fill(ecc[k][:nonceSize])
@ -503,28 +530,23 @@ func (s *UDPSession) output(buf []byte) {
} }
} }
// 4. WriteTo kernel // 4. TxQueue
nbytes := 0 var msg ipv4.Message
npkts := 0
for i := 0; i < s.dup+1; i++ { for i := 0; i < s.dup+1; i++ {
if n, err := s.conn.WriteTo(ext, s.remote); err == nil { bts := xmitBuf.Get().([]byte)[:len(buf)]
nbytes += n copy(bts, buf)
npkts++ msg.Buffers = [][]byte{bts}
} else { msg.Addr = s.remote
s.notifyWriteError(err) s.txqueue = append(s.txqueue, msg)
}
} }
for k := range ecc { for k := range ecc {
if n, err := s.conn.WriteTo(ecc[k], s.remote); err == nil { bts := xmitBuf.Get().([]byte)[:len(ecc[k])]
nbytes += n copy(bts, ecc[k])
npkts++ msg.Buffers = [][]byte{bts}
} else { msg.Addr = s.remote
s.notifyWriteError(err) s.txqueue = append(s.txqueue, msg)
}
} }
atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts))
atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes))
} }
// kcp update, returns interval for next calling // kcp update, returns interval for next calling
@ -535,6 +557,7 @@ func (s *UDPSession) update() (interval time.Duration) {
if s.kcp.WaitSnd() < waitsnd { if s.kcp.WaitSnd() < waitsnd {
s.notifyWriteEvent() s.notifyWriteEvent()
} }
s.uncork()
s.mu.Unlock() s.mu.Unlock()
return return
} }
@ -556,10 +579,39 @@ func (s *UDPSession) notifyWriteEvent() {
} }
} }
func (s *UDPSession) notifyReadError(err error) {
s.socketReadErrorOnce.Do(func() {
s.socketReadError.Store(err)
close(s.chSocketReadError)
})
}
func (s *UDPSession) notifyWriteError(err error) { func (s *UDPSession) notifyWriteError(err error) {
select { s.socketWriteErrorOnce.Do(func() {
case s.chWriteError <- err: s.socketWriteError.Store(err)
default: close(s.chSocketWriteError)
})
}
// packet input stage
func (s *UDPSession) packetInput(data []byte) {
dataValid := false
if s.block != nil {
s.block.Decrypt(data, data)
data = data[nonceSize:]
checksum := crc32.ChecksumIEEE(data[crcSize:])
if checksum == binary.LittleEndian.Uint32(data) {
data = data[crcSize:]
dataValid = true
} else {
atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1)
}
} else if s.block == nil {
dataValid = true
}
if dataValid {
s.kcpInput(data)
} }
} }
@ -568,16 +620,16 @@ func (s *UDPSession) kcpInput(data []byte) {
if s.fecDecoder != nil { if s.fecDecoder != nil {
if len(data) > fecHeaderSize { // must be larger than fec header size if len(data) > fecHeaderSize { // must be larger than fec header size
f := s.fecDecoder.decodeBytes(data) f := fecPacket(data)
if f.flag == typeData || f.flag == typeFEC { // header check if f.flag() == typeData || f.flag() == typeParity { // header check
if f.flag == typeFEC { if f.flag() == typeParity {
fecParityShards++ fecParityShards++
} }
recovers := s.fecDecoder.decode(f) recovers := s.fecDecoder.decode(f)
s.mu.Lock() s.mu.Lock()
waitsnd := s.kcp.WaitSnd() waitsnd := s.kcp.WaitSnd()
if f.flag == typeData { if f.flag() == typeData {
if ret := s.kcp.Input(data[fecHeaderSizePlus2:], true, s.ackNoDelay); ret != 0 { if ret := s.kcp.Input(data[fecHeaderSizePlus2:], true, s.ackNoDelay); ret != 0 {
kcpInErrors++ kcpInErrors++
} }
@ -598,6 +650,8 @@ func (s *UDPSession) kcpInput(data []byte) {
} else { } else {
fecErrs++ fecErrs++
} }
// recycle the recovers
xmitBuf.Put(r)
} }
// to notify the readers to receive the data // to notify the readers to receive the data
@ -608,6 +662,7 @@ func (s *UDPSession) kcpInput(data []byte) {
if s.kcp.WaitSnd() < waitsnd { if s.kcp.WaitSnd() < waitsnd {
s.notifyWriteEvent() s.notifyWriteEvent()
} }
s.uncork()
s.mu.Unlock() s.mu.Unlock()
} else { } else {
atomic.AddUint64(&DefaultSnmp.InErrs, 1) atomic.AddUint64(&DefaultSnmp.InErrs, 1)
@ -627,6 +682,7 @@ func (s *UDPSession) kcpInput(data []byte) {
if s.kcp.WaitSnd() < waitsnd { if s.kcp.WaitSnd() < waitsnd {
s.notifyWriteEvent() s.notifyWriteEvent()
} }
s.uncork()
s.mu.Unlock() s.mu.Unlock()
} }
@ -644,50 +700,7 @@ func (s *UDPSession) kcpInput(data []byte) {
if fecRecovered > 0 { if fecRecovered > 0 {
atomic.AddUint64(&DefaultSnmp.FECRecovered, fecRecovered) atomic.AddUint64(&DefaultSnmp.FECRecovered, fecRecovered)
} }
}
// the read loop for a client session
func (s *UDPSession) readLoop() {
buf := make([]byte, mtuLimit)
var src string
for {
if n, addr, err := s.conn.ReadFrom(buf); err == nil {
// make sure the packet is from the same source
if src == "" { // set source address
src = addr.String()
} else if addr.String() != src {
atomic.AddUint64(&DefaultSnmp.InErrs, 1)
continue
}
if n >= s.headerSize+IKCP_OVERHEAD {
data := buf[:n]
dataValid := false
if s.block != nil {
s.block.Decrypt(data, data)
data = data[nonceSize:]
checksum := crc32.ChecksumIEEE(data[crcSize:])
if checksum == binary.LittleEndian.Uint32(data) {
data = data[crcSize:]
dataValid = true
} else {
atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1)
}
} else if s.block == nil {
dataValid = true
}
if dataValid {
s.kcpInput(data)
}
} else {
atomic.AddUint64(&DefaultSnmp.InErrs, 1)
}
} else {
s.chReadError <- err
return
}
}
} }
type ( type (
@ -704,98 +717,91 @@ type (
chAccepts chan *UDPSession // Listen() backlog chAccepts chan *UDPSession // Listen() backlog
chSessionClosed chan net.Addr // session close queue chSessionClosed chan net.Addr // session close queue
headerSize int // the additional header to a KCP frame headerSize int // the additional header to a KCP frame
die chan struct{} // notify the listener has closed
rd atomic.Value // read deadline for Accept() die chan struct{} // notify the listener has closed
wd atomic.Value dieOnce sync.Once
// socket error handling
socketReadError atomic.Value
chSocketReadError chan struct{}
socketReadErrorOnce sync.Once
rd atomic.Value // read deadline for Accept()
} }
) )
// monitor incoming data for all connections of server // packet input stage
func (l *Listener) monitor() { func (l *Listener) packetInput(data []byte, addr net.Addr) {
// a cache for session object last used dataValid := false
var lastAddr string if l.block != nil {
var lastSession *UDPSession l.block.Decrypt(data, data)
buf := make([]byte, mtuLimit) data = data[nonceSize:]
for { checksum := crc32.ChecksumIEEE(data[crcSize:])
if n, from, err := l.conn.ReadFrom(buf); err == nil { if checksum == binary.LittleEndian.Uint32(data) {
if n >= l.headerSize+IKCP_OVERHEAD { data = data[crcSize:]
data := buf[:n] dataValid = true
dataValid := false } else {
if l.block != nil { atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1)
l.block.Decrypt(data, data) }
data = data[nonceSize:] } else if l.block == nil {
checksum := crc32.ChecksumIEEE(data[crcSize:]) dataValid = true
if checksum == binary.LittleEndian.Uint32(data) { }
data = data[crcSize:]
dataValid = true if dataValid {
} else { l.sessionLock.Lock()
atomic.AddUint64(&DefaultSnmp.InCsumErrors, 1) s, ok := l.sessions[addr.String()]
l.sessionLock.Unlock()
if !ok { // new address:port
if len(l.chAccepts) < cap(l.chAccepts) { // do not let the new sessions overwhelm accept queue
var conv uint32
convValid := false
if l.fecDecoder != nil {
isfec := binary.LittleEndian.Uint16(data[4:])
if isfec == typeData {
conv = binary.LittleEndian.Uint32(data[fecHeaderSizePlus2:])
convValid = true
} }
} else if l.block == nil { } else {
dataValid = true conv = binary.LittleEndian.Uint32(data)
convValid = true
} }
if dataValid { if convValid { // creates a new session only if the 'conv' field in kcp is accessible
addr := from.String() s := newUDPSession(conv, l.dataShards, l.parityShards, l, l.conn, addr, l.block)
var s *UDPSession s.kcpInput(data)
var ok bool l.sessionLock.Lock()
l.sessions[addr.String()] = s
// the packets received from an address always come in batch, l.sessionLock.Unlock()
// cache the session for next packet, without querying map. l.chAccepts <- s
if addr == lastAddr {
s, ok = lastSession, true
} else {
l.sessionLock.Lock()
if s, ok = l.sessions[addr]; ok {
lastSession = s
lastAddr = addr
}
l.sessionLock.Unlock()
}
if !ok { // new session
if len(l.chAccepts) < cap(l.chAccepts) { // do not let the new sessions overwhelm accept queue
var conv uint32
convValid := false
if l.fecDecoder != nil {
isfec := binary.LittleEndian.Uint16(data[4:])
if isfec == typeData {
conv = binary.LittleEndian.Uint32(data[fecHeaderSizePlus2:])
convValid = true
}
} else {
conv = binary.LittleEndian.Uint32(data)
convValid = true
}
if convValid { // creates a new session only if the 'conv' field in kcp is accessible
s := newUDPSession(conv, l.dataShards, l.parityShards, l, l.conn, from, l.block)
s.kcpInput(data)
l.sessionLock.Lock()
l.sessions[addr] = s
l.sessionLock.Unlock()
l.chAccepts <- s
}
}
} else {
s.kcpInput(data)
}
} }
} else {
atomic.AddUint64(&DefaultSnmp.InErrs, 1)
} }
} else { } else {
return s.kcpInput(data)
} }
} }
} }
func (l *Listener) notifyReadError(err error) {
l.socketReadErrorOnce.Do(func() {
l.socketReadError.Store(err)
close(l.chSocketReadError)
// propagate read error to all sessions
l.sessionLock.Lock()
for _, s := range l.sessions {
s.notifyReadError(err)
}
l.sessionLock.Unlock()
})
}
// SetReadBuffer sets the socket read buffer for the Listener // SetReadBuffer sets the socket read buffer for the Listener
func (l *Listener) SetReadBuffer(bytes int) error { func (l *Listener) SetReadBuffer(bytes int) error {
if nc, ok := l.conn.(setReadBuffer); ok { if nc, ok := l.conn.(setReadBuffer); ok {
return nc.SetReadBuffer(bytes) return nc.SetReadBuffer(bytes)
} }
return errors.New(errInvalidOperation) return errInvalidOperation
} }
// SetWriteBuffer sets the socket write buffer for the Listener // SetWriteBuffer sets the socket write buffer for the Listener
@ -803,18 +809,25 @@ func (l *Listener) SetWriteBuffer(bytes int) error {
if nc, ok := l.conn.(setWriteBuffer); ok { if nc, ok := l.conn.(setWriteBuffer); ok {
return nc.SetWriteBuffer(bytes) return nc.SetWriteBuffer(bytes)
} }
return errors.New(errInvalidOperation) return errInvalidOperation
} }
// SetDSCP sets the 6bit DSCP field of IP header // SetDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header.
func (l *Listener) SetDSCP(dscp int) error { func (l *Listener) SetDSCP(dscp int) error {
if nc, ok := l.conn.(net.Conn); ok { if nc, ok := l.conn.(net.Conn); ok {
if err := ipv4.NewConn(nc).SetTOS(dscp << 2); err != nil { var succeed bool
return ipv6.NewConn(nc).SetTrafficClass(dscp) if err := ipv4.NewConn(nc).SetTOS(dscp << 2); err == nil {
succeed = true
}
if err := ipv6.NewConn(nc).SetTrafficClass(dscp); err == nil {
succeed = true
}
if succeed {
return nil
} }
return nil
} }
return errors.New(errInvalidOperation) return errInvalidOperation
} }
// Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn. // Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn.
@ -831,11 +844,13 @@ func (l *Listener) AcceptKCP() (*UDPSession, error) {
select { select {
case <-timeout: case <-timeout:
return nil, &errTimeout{} return nil, errors.WithStack(errTimeout)
case c := <-l.chAccepts: case c := <-l.chAccepts:
return c, nil return c, nil
case <-l.chSocketReadError:
return nil, l.socketReadError.Load().(error)
case <-l.die: case <-l.die:
return nil, errors.New(errBrokenPipe) return nil, errors.WithStack(io.ErrClosedPipe)
} }
} }
@ -853,15 +868,21 @@ func (l *Listener) SetReadDeadline(t time.Time) error {
} }
// SetWriteDeadline implements the Conn SetWriteDeadline method. // SetWriteDeadline implements the Conn SetWriteDeadline method.
func (l *Listener) SetWriteDeadline(t time.Time) error { func (l *Listener) SetWriteDeadline(t time.Time) error { return errInvalidOperation }
l.wd.Store(t)
return nil
}
// Close stops listening on the UDP address. Already Accepted connections are not closed. // Close stops listening on the UDP address, and closes the socket
func (l *Listener) Close() error { func (l *Listener) Close() error {
close(l.die) var once bool
return l.conn.Close() l.dieOnce.Do(func() {
close(l.die)
once = true
})
if once {
return l.conn.Close()
} else {
return errors.WithStack(io.ErrClosedPipe)
}
} }
// closeSession notify the listener that a session has closed // closeSession notify the listener that a session has closed
@ -881,16 +902,21 @@ func (l *Listener) Addr() net.Addr { return l.conn.LocalAddr() }
// Listen listens for incoming KCP packets addressed to the local address laddr on the network "udp", // Listen listens for incoming KCP packets addressed to the local address laddr on the network "udp",
func Listen(laddr string) (net.Listener, error) { return ListenWithOptions(laddr, nil, 0, 0) } func Listen(laddr string) (net.Listener, error) { return ListenWithOptions(laddr, nil, 0, 0) }
// ListenWithOptions listens for incoming KCP packets addressed to the local address laddr on the network "udp" with packet encryption, // ListenWithOptions listens for incoming KCP packets addressed to the local address laddr on the network "udp" with packet encryption.
// dataShards, parityShards defines Reed-Solomon Erasure Coding parameters //
// 'block' is the block encryption algorithm to encrypt packets.
//
// 'dataShards', 'parityShards' specifiy how many parity packets will be generated following the data packets.
//
// Check https://github.com/klauspost/reedsolomon for details
func ListenWithOptions(laddr string, block BlockCrypt, dataShards, parityShards int) (*Listener, error) { func ListenWithOptions(laddr string, block BlockCrypt, dataShards, parityShards int) (*Listener, error) {
udpaddr, err := net.ResolveUDPAddr("udp", laddr) udpaddr, err := net.ResolveUDPAddr("udp", laddr)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "net.ResolveUDPAddr") return nil, errors.WithStack(err)
} }
conn, err := net.ListenUDP("udp", udpaddr) conn, err := net.ListenUDP("udp", udpaddr)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "net.ListenUDP") return nil, errors.WithStack(err)
} }
return ServeConn(block, dataShards, parityShards, conn) return ServeConn(block, dataShards, parityShards, conn)
@ -908,6 +934,7 @@ func ServeConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketCo
l.parityShards = parityShards l.parityShards = parityShards
l.block = block l.block = block
l.fecDecoder = newFECDecoder(rxFECMulti*(dataShards+parityShards), dataShards, parityShards) l.fecDecoder = newFECDecoder(rxFECMulti*(dataShards+parityShards), dataShards, parityShards)
l.chSocketReadError = make(chan struct{})
// calculate header size // calculate header size
if l.block != nil { if l.block != nil {
@ -921,15 +948,21 @@ func ServeConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketCo
return l, nil return l, nil
} }
// Dial connects to the remote address "raddr" on the network "udp" // Dial connects to the remote address "raddr" on the network "udp" without encryption and FEC
func Dial(raddr string) (net.Conn, error) { return DialWithOptions(raddr, nil, 0, 0) } func Dial(raddr string) (net.Conn, error) { return DialWithOptions(raddr, nil, 0, 0) }
// DialWithOptions connects to the remote address "raddr" on the network "udp" with packet encryption // DialWithOptions connects to the remote address "raddr" on the network "udp" with packet encryption
//
// 'block' is the block encryption algorithm to encrypt packets.
//
// 'dataShards', 'parityShards' specifiy how many parity packets will be generated following the data packets.
//
// Check https://github.com/klauspost/reedsolomon for details
func DialWithOptions(raddr string, block BlockCrypt, dataShards, parityShards int) (*UDPSession, error) { func DialWithOptions(raddr string, block BlockCrypt, dataShards, parityShards int) (*UDPSession, error) {
// network type detection // network type detection
udpaddr, err := net.ResolveUDPAddr("udp", raddr) udpaddr, err := net.ResolveUDPAddr("udp", raddr)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "net.ResolveUDPAddr") return nil, errors.WithStack(err)
} }
network := "udp4" network := "udp4"
if udpaddr.IP.To4() == nil { if udpaddr.IP.To4() == nil {
@ -938,30 +971,33 @@ func DialWithOptions(raddr string, block BlockCrypt, dataShards, parityShards in
conn, err := net.ListenUDP(network, nil) conn, err := net.ListenUDP(network, nil)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "net.DialUDP") return nil, errors.WithStack(err)
} }
return NewConn(raddr, block, dataShards, parityShards, conn) return NewConn(raddr, block, dataShards, parityShards, conn)
} }
// NewConn3 establishes a session and talks KCP protocol over a packet connection.
func NewConn3(convid uint32, raddr net.Addr, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) {
return newUDPSession(convid, dataShards, parityShards, nil, conn, raddr, block), nil
}
// NewConn2 establishes a session and talks KCP protocol over a packet connection.
func NewConn2(raddr net.Addr, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) {
var convid uint32
binary.Read(rand.Reader, binary.LittleEndian, &convid)
return NewConn3(convid, raddr, block, dataShards, parityShards, conn)
}
// NewConn establishes a session and talks KCP protocol over a packet connection. // NewConn establishes a session and talks KCP protocol over a packet connection.
func NewConn(raddr string, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) { func NewConn(raddr string, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) {
udpaddr, err := net.ResolveUDPAddr("udp", raddr) udpaddr, err := net.ResolveUDPAddr("udp", raddr)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "net.ResolveUDPAddr") return nil, errors.WithStack(err)
} }
return NewConn2(udpaddr, block, dataShards, parityShards, conn)
var convid uint32
binary.Read(rand.Reader, binary.LittleEndian, &convid)
return newUDPSession(convid, dataShards, parityShards, nil, conn, udpaddr, block), nil
} }
// monotonic reference time point
var refTime time.Time = time.Now()
// currentMs returns current elasped monotonic milliseconds since program startup
func currentMs() uint32 { return uint32(time.Now().Sub(refTime) / time.Millisecond) }
func NewConnEx(convid uint32, connected bool, raddr string, block BlockCrypt, dataShards, parityShards int, conn *net.UDPConn) (*UDPSession, error) { func NewConnEx(convid uint32, connected bool, raddr string, block BlockCrypt, dataShards, parityShards int, conn *net.UDPConn) (*UDPSession, error) {
udpaddr, err := net.ResolveUDPAddr("udp", raddr) udpaddr, err := net.ResolveUDPAddr("udp", raddr)
if err != nil { if err != nil {

25
vendor/github.com/fatedier/kcp-go/tx.go generated vendored Normal file
View File

@ -0,0 +1,25 @@
package kcp
import (
"sync/atomic"
"github.com/pkg/errors"
"golang.org/x/net/ipv4"
)
func (s *UDPSession) defaultTx(txqueue []ipv4.Message) {
nbytes := 0
npkts := 0
for k := range txqueue {
if n, err := s.conn.WriteTo(txqueue[k].Buffers[0], txqueue[k].Addr); err == nil {
nbytes += n
npkts++
xmitBuf.Put(txqueue[k].Buffers[0])
} else {
s.notifyWriteError(errors.WithStack(err))
break
}
}
atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts))
atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes))
}

11
vendor/github.com/fatedier/kcp-go/tx_generic.go generated vendored Normal file
View File

@ -0,0 +1,11 @@
// +build !linux
package kcp
import (
"golang.org/x/net/ipv4"
)
func (s *UDPSession) tx(txqueue []ipv4.Message) {
s.defaultTx(txqueue)
}

52
vendor/github.com/fatedier/kcp-go/tx_linux.go generated vendored Normal file
View File

@ -0,0 +1,52 @@
// +build linux
package kcp
import (
"net"
"os"
"sync/atomic"
"github.com/pkg/errors"
"golang.org/x/net/ipv4"
)
func (s *UDPSession) tx(txqueue []ipv4.Message) {
// default version
if s.xconn == nil || s.xconnWriteError != nil {
s.defaultTx(txqueue)
return
}
// x/net version
nbytes := 0
npkts := 0
for len(txqueue) > 0 {
if n, err := s.xconn.WriteBatch(txqueue, 0); err == nil {
for k := range txqueue[:n] {
nbytes += len(txqueue[k].Buffers[0])
xmitBuf.Put(txqueue[k].Buffers[0])
}
npkts += n
txqueue = txqueue[n:]
} else {
// compatibility issue:
// for linux kernel<=2.6.32, support for sendmmsg is not available
// an error of type os.SyscallError will be returned
if operr, ok := err.(*net.OpError); ok {
if se, ok := operr.Err.(*os.SyscallError); ok {
if se.Syscall == "sendmmsg" {
s.xconnWriteError = se
s.defaultTx(txqueue)
return
}
}
}
s.notifyWriteError(errors.WithStack(err))
break
}
}
atomic.AddUint64(&DefaultSnmp.OutPkts, uint64(npkts))
atomic.AddUint64(&DefaultSnmp.OutBytes, uint64(nbytes))
}

View File

@ -76,10 +76,10 @@ func (h *updateHeap) wakeup() {
} }
func (h *updateHeap) updateTask() { func (h *updateHeap) updateTask() {
var timer <-chan time.Time timer := time.NewTimer(0)
for { for {
select { select {
case <-timer: case <-timer.C:
case <-h.chWakeUp: case <-h.chWakeUp:
} }
@ -87,7 +87,7 @@ func (h *updateHeap) updateTask() {
hlen := h.Len() hlen := h.Len()
for i := 0; i < hlen; i++ { for i := 0; i < hlen; i++ {
entry := &h.entries[0] entry := &h.entries[0]
if time.Now().After(entry.ts) { if !time.Now().Before(entry.ts) {
interval := entry.s.update() interval := entry.s.update()
entry.ts = time.Now().Add(interval) entry.ts = time.Now().Add(interval)
heap.Fix(h, 0) heap.Fix(h, 0)
@ -97,7 +97,7 @@ func (h *updateHeap) updateTask() {
} }
if hlen > 0 { if hlen > 0 {
timer = time.After(h.entries[0].ts.Sub(time.Now())) timer.Reset(h.entries[0].ts.Sub(time.Now()))
} }
h.mu.Unlock() h.mu.Unlock()
} }

View File

@ -1,19 +0,0 @@
language: go
sudo: false
matrix:
include:
- go: 1.3
- go: 1.4
- go: 1.5
- go: 1.6
- go: 1.7
- go: tip
allow_failures:
- go: tip
script:
- go get -t -v ./...
- diff -u <(echo -n) <(gofmt -d .)
- go vet $(go list ./... | grep -v /vendor/)
- go test -v -race ./...

View File

@ -1,10 +0,0 @@
context
=======
[![Build Status](https://travis-ci.org/gorilla/context.png?branch=master)](https://travis-ci.org/gorilla/context)
gorilla/context is a general purpose registry for global request variables.
> Note: gorilla/context, having been born well before `context.Context` existed, does not play well
> with the shallow copying of the request that [`http.Request.WithContext`](https://golang.org/pkg/net/http/#Request.WithContext) (added to net/http Go 1.7 onwards) performs. You should either use *just* gorilla/context, or moving forward, the new `http.Request.Context()`.
Read the full documentation here: http://www.gorillatoolkit.org/pkg/context

View File

@ -1,143 +0,0 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package context
import (
"net/http"
"sync"
"time"
)
var (
mutex sync.RWMutex
data = make(map[*http.Request]map[interface{}]interface{})
datat = make(map[*http.Request]int64)
)
// Set stores a value for a given key in a given request.
func Set(r *http.Request, key, val interface{}) {
mutex.Lock()
if data[r] == nil {
data[r] = make(map[interface{}]interface{})
datat[r] = time.Now().Unix()
}
data[r][key] = val
mutex.Unlock()
}
// Get returns a value stored for a given key in a given request.
func Get(r *http.Request, key interface{}) interface{} {
mutex.RLock()
if ctx := data[r]; ctx != nil {
value := ctx[key]
mutex.RUnlock()
return value
}
mutex.RUnlock()
return nil
}
// GetOk returns stored value and presence state like multi-value return of map access.
func GetOk(r *http.Request, key interface{}) (interface{}, bool) {
mutex.RLock()
if _, ok := data[r]; ok {
value, ok := data[r][key]
mutex.RUnlock()
return value, ok
}
mutex.RUnlock()
return nil, false
}
// GetAll returns all stored values for the request as a map. Nil is returned for invalid requests.
func GetAll(r *http.Request) map[interface{}]interface{} {
mutex.RLock()
if context, ok := data[r]; ok {
result := make(map[interface{}]interface{}, len(context))
for k, v := range context {
result[k] = v
}
mutex.RUnlock()
return result
}
mutex.RUnlock()
return nil
}
// GetAllOk returns all stored values for the request as a map and a boolean value that indicates if
// the request was registered.
func GetAllOk(r *http.Request) (map[interface{}]interface{}, bool) {
mutex.RLock()
context, ok := data[r]
result := make(map[interface{}]interface{}, len(context))
for k, v := range context {
result[k] = v
}
mutex.RUnlock()
return result, ok
}
// Delete removes a value stored for a given key in a given request.
func Delete(r *http.Request, key interface{}) {
mutex.Lock()
if data[r] != nil {
delete(data[r], key)
}
mutex.Unlock()
}
// Clear removes all values stored for a given request.
//
// This is usually called by a handler wrapper to clean up request
// variables at the end of a request lifetime. See ClearHandler().
func Clear(r *http.Request) {
mutex.Lock()
clear(r)
mutex.Unlock()
}
// clear is Clear without the lock.
func clear(r *http.Request) {
delete(data, r)
delete(datat, r)
}
// Purge removes request data stored for longer than maxAge, in seconds.
// It returns the amount of requests removed.
//
// If maxAge <= 0, all request data is removed.
//
// This is only used for sanity check: in case context cleaning was not
// properly set some request data can be kept forever, consuming an increasing
// amount of memory. In case this is detected, Purge() must be called
// periodically until the problem is fixed.
func Purge(maxAge int) int {
mutex.Lock()
count := 0
if maxAge <= 0 {
count = len(data)
data = make(map[*http.Request]map[interface{}]interface{})
datat = make(map[*http.Request]int64)
} else {
min := time.Now().Unix() - int64(maxAge)
for r := range data {
if datat[r] < min {
clear(r)
count++
}
}
}
mutex.Unlock()
return count
}
// ClearHandler wraps an http.Handler and clears request values at the end
// of a request lifetime.
func ClearHandler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
defer Clear(r)
h.ServeHTTP(w, r)
})
}

View File

@ -1,88 +0,0 @@
// Copyright 2012 The Gorilla Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package context stores values shared during a request lifetime.
Note: gorilla/context, having been born well before `context.Context` existed,
does not play well > with the shallow copying of the request that
[`http.Request.WithContext`](https://golang.org/pkg/net/http/#Request.WithContext)
(added to net/http Go 1.7 onwards) performs. You should either use *just*
gorilla/context, or moving forward, the new `http.Request.Context()`.
For example, a router can set variables extracted from the URL and later
application handlers can access those values, or it can be used to store
sessions values to be saved at the end of a request. There are several
others common uses.
The idea was posted by Brad Fitzpatrick to the go-nuts mailing list:
http://groups.google.com/group/golang-nuts/msg/e2d679d303aa5d53
Here's the basic usage: first define the keys that you will need. The key
type is interface{} so a key can be of any type that supports equality.
Here we define a key using a custom int type to avoid name collisions:
package foo
import (
"github.com/gorilla/context"
)
type key int
const MyKey key = 0
Then set a variable. Variables are bound to an http.Request object, so you
need a request instance to set a value:
context.Set(r, MyKey, "bar")
The application can later access the variable using the same key you provided:
func MyHandler(w http.ResponseWriter, r *http.Request) {
// val is "bar".
val := context.Get(r, foo.MyKey)
// returns ("bar", true)
val, ok := context.GetOk(r, foo.MyKey)
// ...
}
And that's all about the basic usage. We discuss some other ideas below.
Any type can be stored in the context. To enforce a given type, make the key
private and wrap Get() and Set() to accept and return values of a specific
type:
type key int
const mykey key = 0
// GetMyKey returns a value for this package from the request values.
func GetMyKey(r *http.Request) SomeType {
if rv := context.Get(r, mykey); rv != nil {
return rv.(SomeType)
}
return nil
}
// SetMyKey sets a value for this package in the request values.
func SetMyKey(r *http.Request, val SomeType) {
context.Set(r, mykey, val)
}
Variables must be cleared at the end of a request, to remove all values
that were stored. This can be done in an http.Handler, after a request was
served. Just call Clear() passing the request:
context.Clear(r)
...or use ClearHandler(), which conveniently wraps an http.Handler to clear
variables at the end of a request lifetime.
The Routers from the packages gorilla/mux and gorilla/pat call Clear()
so if you are using either of them you don't need to clear the context manually.
*/
package context

View File

@ -1,23 +0,0 @@
language: go
sudo: false
matrix:
include:
- go: 1.5.x
- go: 1.6.x
- go: 1.7.x
- go: 1.8.x
- go: 1.9.x
- go: 1.10.x
- go: tip
allow_failures:
- go: tip
install:
- # Skip
script:
- go get -t -v ./...
- diff -u <(echo -n) <(gofmt -d .)
- go tool vet .
- go test -v -race ./...

8
vendor/github.com/gorilla/mux/AUTHORS generated vendored Normal file
View File

@ -0,0 +1,8 @@
# This is the official list of gorilla/mux authors for copyright purposes.
#
# Please keep the list sorted.
Google LLC (https://opensource.google.com/)
Kamil Kisielk <kamil@kamilkisiel.net>
Matt Silverlock <matt@eatsleeprepeat.net>
Rodrigo Moraes (https://github.com/moraes)

View File

@ -1,11 +0,0 @@
**What version of Go are you running?** (Paste the output of `go version`)
**What version of gorilla/mux are you at?** (Paste the output of `git rev-parse HEAD` inside `$GOPATH/src/github.com/gorilla/mux`)
**Describe your problem** (and what you have tried so far)
**Paste a minimal, runnable, reproduction of your issue below** (use backticks to format it)

View File

@ -1,4 +1,4 @@
Copyright (c) 2012 Rodrigo Moraes. All rights reserved. Copyright (c) 2012-2018 The Gorilla Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are modification, are permitted provided that the following conditions are

View File

@ -2,11 +2,12 @@
[![GoDoc](https://godoc.org/github.com/gorilla/mux?status.svg)](https://godoc.org/github.com/gorilla/mux) [![GoDoc](https://godoc.org/github.com/gorilla/mux?status.svg)](https://godoc.org/github.com/gorilla/mux)
[![Build Status](https://travis-ci.org/gorilla/mux.svg?branch=master)](https://travis-ci.org/gorilla/mux) [![Build Status](https://travis-ci.org/gorilla/mux.svg?branch=master)](https://travis-ci.org/gorilla/mux)
[![CircleCI](https://circleci.com/gh/gorilla/mux.svg?style=svg)](https://circleci.com/gh/gorilla/mux)
[![Sourcegraph](https://sourcegraph.com/github.com/gorilla/mux/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/mux?badge) [![Sourcegraph](https://sourcegraph.com/github.com/gorilla/mux/-/badge.svg)](https://sourcegraph.com/github.com/gorilla/mux?badge)
![Gorilla Logo](http://www.gorillatoolkit.org/static/images/gorilla-icon-64.png) ![Gorilla Logo](http://www.gorillatoolkit.org/static/images/gorilla-icon-64.png)
http://www.gorillatoolkit.org/pkg/mux https://www.gorillatoolkit.org/pkg/mux
Package `gorilla/mux` implements a request router and dispatcher for matching incoming requests to Package `gorilla/mux` implements a request router and dispatcher for matching incoming requests to
their respective handler. their respective handler.
@ -29,6 +30,7 @@ The name mux stands for "HTTP request multiplexer". Like the standard `http.Serv
* [Walking Routes](#walking-routes) * [Walking Routes](#walking-routes)
* [Graceful Shutdown](#graceful-shutdown) * [Graceful Shutdown](#graceful-shutdown)
* [Middleware](#middleware) * [Middleware](#middleware)
* [Handling CORS Requests](#handling-cors-requests)
* [Testing Handlers](#testing-handlers) * [Testing Handlers](#testing-handlers)
* [Full Example](#full-example) * [Full Example](#full-example)
@ -88,7 +90,7 @@ r := mux.NewRouter()
// Only matches if domain is "www.example.com". // Only matches if domain is "www.example.com".
r.Host("www.example.com") r.Host("www.example.com")
// Matches a dynamic subdomain. // Matches a dynamic subdomain.
r.Host("{subdomain:[a-z]+}.domain.com") r.Host("{subdomain:[a-z]+}.example.com")
``` ```
There are several other matchers that can be added. To match path prefixes: There are several other matchers that can be added. To match path prefixes:
@ -238,13 +240,13 @@ This also works for host and query value variables:
```go ```go
r := mux.NewRouter() r := mux.NewRouter()
r.Host("{subdomain}.domain.com"). r.Host("{subdomain}.example.com").
Path("/articles/{category}/{id:[0-9]+}"). Path("/articles/{category}/{id:[0-9]+}").
Queries("filter", "{filter}"). Queries("filter", "{filter}").
HandlerFunc(ArticleHandler). HandlerFunc(ArticleHandler).
Name("article") Name("article")
// url.String() will be "http://news.domain.com/articles/technology/42?filter=gorilla" // url.String() will be "http://news.example.com/articles/technology/42?filter=gorilla"
url, err := r.Get("article").URL("subdomain", "news", url, err := r.Get("article").URL("subdomain", "news",
"category", "technology", "category", "technology",
"id", "42", "id", "42",
@ -264,7 +266,7 @@ r.HeadersRegexp("Content-Type", "application/(text|json)")
There's also a way to build only the URL host or path for a route: use the methods `URLHost()` or `URLPath()` instead. For the previous route, we would do: There's also a way to build only the URL host or path for a route: use the methods `URLHost()` or `URLPath()` instead. For the previous route, we would do:
```go ```go
// "http://news.domain.com/" // "http://news.example.com/"
host, err := r.Get("article").URLHost("subdomain", "news") host, err := r.Get("article").URLHost("subdomain", "news")
// "/articles/technology/42" // "/articles/technology/42"
@ -275,12 +277,12 @@ And if you use subrouters, host and path defined separately can be built as well
```go ```go
r := mux.NewRouter() r := mux.NewRouter()
s := r.Host("{subdomain}.domain.com").Subrouter() s := r.Host("{subdomain}.example.com").Subrouter()
s.Path("/articles/{category}/{id:[0-9]+}"). s.Path("/articles/{category}/{id:[0-9]+}").
HandlerFunc(ArticleHandler). HandlerFunc(ArticleHandler).
Name("article") Name("article")
// "http://news.domain.com/articles/technology/42" // "http://news.example.com/articles/technology/42"
url, err := r.Get("article").URL("subdomain", "news", url, err := r.Get("article").URL("subdomain", "news",
"category", "technology", "category", "technology",
"id", "42") "id", "42")
@ -491,6 +493,73 @@ r.Use(amw.Middleware)
Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to. Middlewares _should_ write to `ResponseWriter` if they _are_ going to terminate the request, and they _should not_ write to `ResponseWriter` if they _are not_ going to terminate it. Note: The handler chain will be stopped if your middleware doesn't call `next.ServeHTTP()` with the corresponding parameters. This can be used to abort a request if the middleware writer wants to. Middlewares _should_ write to `ResponseWriter` if they _are_ going to terminate the request, and they _should not_ write to `ResponseWriter` if they _are not_ going to terminate it.
### Handling CORS Requests
[CORSMethodMiddleware](https://godoc.org/github.com/gorilla/mux#CORSMethodMiddleware) intends to make it easier to strictly set the `Access-Control-Allow-Methods` response header.
* You will still need to use your own CORS handler to set the other CORS headers such as `Access-Control-Allow-Origin`
* The middleware will set the `Access-Control-Allow-Methods` header to all the method matchers (e.g. `r.Methods(http.MethodGet, http.MethodPut, http.MethodOptions)` -> `Access-Control-Allow-Methods: GET,PUT,OPTIONS`) on a route
* If you do not specify any methods, then:
> _Important_: there must be an `OPTIONS` method matcher for the middleware to set the headers.
Here is an example of using `CORSMethodMiddleware` along with a custom `OPTIONS` handler to set all the required CORS headers:
```go
package main
import (
"net/http"
"github.com/gorilla/mux"
)
func main() {
r := mux.NewRouter()
// IMPORTANT: you must specify an OPTIONS method matcher for the middleware to set CORS headers
r.HandleFunc("/foo", fooHandler).Methods(http.MethodGet, http.MethodPut, http.MethodPatch, http.MethodOptions)
r.Use(mux.CORSMethodMiddleware(r))
http.ListenAndServe(":8080", r)
}
func fooHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
if r.Method == http.MethodOptions {
return
}
w.Write([]byte("foo"))
}
```
And an request to `/foo` using something like:
```bash
curl localhost:8080/foo -v
```
Would look like:
```bash
* Trying ::1...
* TCP_NODELAY set
* Connected to localhost (::1) port 8080 (#0)
> GET /foo HTTP/1.1
> Host: localhost:8080
> User-Agent: curl/7.59.0
> Accept: */*
>
< HTTP/1.1 200 OK
< Access-Control-Allow-Methods: GET,PUT,PATCH,OPTIONS
< Access-Control-Allow-Origin: *
< Date: Fri, 28 Jun 2019 20:13:30 GMT
< Content-Length: 3
< Content-Type: text/plain; charset=utf-8
<
* Connection #0 to host localhost left intact
foo
```
### Testing Handlers ### Testing Handlers
Testing handlers in a Go web application is straightforward, and _mux_ doesn't complicate this any further. Given two files: `endpoints.go` and `endpoints_test.go`, here's how we'd test an application using _mux_. Testing handlers in a Go web application is straightforward, and _mux_ doesn't complicate this any further. Given two files: `endpoints.go` and `endpoints_test.go`, here's how we'd test an application using _mux_.
@ -503,8 +572,8 @@ package main
func HealthCheckHandler(w http.ResponseWriter, r *http.Request) { func HealthCheckHandler(w http.ResponseWriter, r *http.Request) {
// A very simple health check. // A very simple health check.
w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
// In the future we could report back on the status of our DB, or our cache // In the future we could report back on the status of our DB, or our cache
// (e.g. Redis) by performing a simple PING, and include them in the response. // (e.g. Redis) by performing a simple PING, and include them in the response.

View File

@ -1,5 +1,3 @@
// +build go1.7
package mux package mux
import ( import (
@ -18,7 +16,3 @@ func contextSet(r *http.Request, key, val interface{}) *http.Request {
return r.WithContext(context.WithValue(r.Context(), key, val)) return r.WithContext(context.WithValue(r.Context(), key, val))
} }
func contextClear(r *http.Request) {
return
}

View File

@ -1,26 +0,0 @@
// +build !go1.7
package mux
import (
"net/http"
"github.com/gorilla/context"
)
func contextGet(r *http.Request, key interface{}) interface{} {
return context.Get(r, key)
}
func contextSet(r *http.Request, key, val interface{}) *http.Request {
if val == nil {
return r
}
context.Set(r, key, val)
return r
}
func contextClear(r *http.Request) {
context.Clear(r)
}

View File

@ -295,7 +295,7 @@ A more complex authentication middleware, which maps session token to users, cou
r := mux.NewRouter() r := mux.NewRouter()
r.HandleFunc("/", handler) r.HandleFunc("/", handler)
amw := authenticationMiddleware{} amw := authenticationMiddleware{tokenUsers: make(map[string]string)}
amw.Populate() amw.Populate()
r.Use(amw.Middleware) r.Use(amw.Middleware)

1
vendor/github.com/gorilla/mux/go.mod generated vendored Normal file
View File

@ -0,0 +1 @@
module github.com/gorilla/mux

View File

@ -32,37 +32,19 @@ func (r *Router) useInterface(mw middleware) {
r.middlewares = append(r.middlewares, mw) r.middlewares = append(r.middlewares, mw)
} }
// CORSMethodMiddleware sets the Access-Control-Allow-Methods response header // CORSMethodMiddleware automatically sets the Access-Control-Allow-Methods response header
// on a request, by matching routes based only on paths. It also handles // on requests for routes that have an OPTIONS method matcher to all the method matchers on
// OPTIONS requests, by settings Access-Control-Allow-Methods, and then // the route. Routes that do not explicitly handle OPTIONS requests will not be processed
// returning without calling the next http handler. // by the middleware. See examples for usage.
func CORSMethodMiddleware(r *Router) MiddlewareFunc { func CORSMethodMiddleware(r *Router) MiddlewareFunc {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
var allMethods []string allMethods, err := getAllMethodsForRoute(r, req)
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
for _, m := range route.matchers {
if _, ok := m.(*routeRegexp); ok {
if m.Match(req, &RouteMatch{}) {
methods, err := route.GetMethods()
if err != nil {
return err
}
allMethods = append(allMethods, methods...)
}
break
}
}
return nil
})
if err == nil { if err == nil {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(append(allMethods, "OPTIONS"), ",")) for _, v := range allMethods {
if v == http.MethodOptions {
if req.Method == "OPTIONS" { w.Header().Set("Access-Control-Allow-Methods", strings.Join(allMethods, ","))
return }
} }
} }
@ -70,3 +52,28 @@ func CORSMethodMiddleware(r *Router) MiddlewareFunc {
}) })
} }
} }
// getAllMethodsForRoute returns all the methods from method matchers matching a given
// request.
func getAllMethodsForRoute(r *Router, req *http.Request) ([]string, error) {
var allMethods []string
err := r.Walk(func(route *Route, _ *Router, _ []*Route) error {
for _, m := range route.matchers {
if _, ok := m.(*routeRegexp); ok {
if m.Match(req, &RouteMatch{}) {
methods, err := route.GetMethods()
if err != nil {
return err
}
allMethods = append(allMethods, methods...)
}
break
}
}
return nil
})
return allMethods, err
}

129
vendor/github.com/gorilla/mux/mux.go generated vendored
View File

@ -22,7 +22,7 @@ var (
// NewRouter returns a new router instance. // NewRouter returns a new router instance.
func NewRouter() *Router { func NewRouter() *Router {
return &Router{namedRoutes: make(map[string]*Route), KeepContext: false} return &Router{namedRoutes: make(map[string]*Route)}
} }
// Router registers routes to be matched and dispatches a handler. // Router registers routes to be matched and dispatches a handler.
@ -50,24 +50,78 @@ type Router struct {
// Configurable Handler to be used when the request method does not match the route. // Configurable Handler to be used when the request method does not match the route.
MethodNotAllowedHandler http.Handler MethodNotAllowedHandler http.Handler
// Parent route, if this is a subrouter.
parent parentRoute
// Routes to be matched, in order. // Routes to be matched, in order.
routes []*Route routes []*Route
// Routes by name for URL building. // Routes by name for URL building.
namedRoutes map[string]*Route namedRoutes map[string]*Route
// See Router.StrictSlash(). This defines the flag for new routes.
strictSlash bool
// See Router.SkipClean(). This defines the flag for new routes.
skipClean bool
// If true, do not clear the request context after handling the request. // If true, do not clear the request context after handling the request.
// This has no effect when go1.7+ is used, since the context is stored //
// Deprecated: No effect when go1.7+ is used, since the context is stored
// on the request itself. // on the request itself.
KeepContext bool KeepContext bool
// see Router.UseEncodedPath(). This defines a flag for all routes.
useEncodedPath bool
// Slice of middlewares to be called after a match is found // Slice of middlewares to be called after a match is found
middlewares []middleware middlewares []middleware
// configuration shared with `Route`
routeConf
}
// common route configuration shared between `Router` and `Route`
type routeConf struct {
// If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to"
useEncodedPath bool
// If true, when the path pattern is "/path/", accessing "/path" will
// redirect to the former and vice versa.
strictSlash bool
// If true, when the path pattern is "/path//to", accessing "/path//to"
// will not redirect
skipClean bool
// Manager for the variables from host and path.
regexp routeRegexpGroup
// List of matchers.
matchers []matcher
// The scheme used when building URLs.
buildScheme string
buildVarsFunc BuildVarsFunc
}
// returns an effective deep copy of `routeConf`
func copyRouteConf(r routeConf) routeConf {
c := r
if r.regexp.path != nil {
c.regexp.path = copyRouteRegexp(r.regexp.path)
}
if r.regexp.host != nil {
c.regexp.host = copyRouteRegexp(r.regexp.host)
}
c.regexp.queries = make([]*routeRegexp, 0, len(r.regexp.queries))
for _, q := range r.regexp.queries {
c.regexp.queries = append(c.regexp.queries, copyRouteRegexp(q))
}
c.matchers = make([]matcher, 0, len(r.matchers))
for _, m := range r.matchers {
c.matchers = append(c.matchers, m)
}
return c
}
func copyRouteRegexp(r *routeRegexp) *routeRegexp {
c := *r
return &c
} }
// Match attempts to match the given request against the router's registered routes. // Match attempts to match the given request against the router's registered routes.
@ -155,22 +209,18 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
handler = http.NotFoundHandler() handler = http.NotFoundHandler()
} }
if !r.KeepContext {
defer contextClear(req)
}
handler.ServeHTTP(w, req) handler.ServeHTTP(w, req)
} }
// Get returns a route registered with the given name. // Get returns a route registered with the given name.
func (r *Router) Get(name string) *Route { func (r *Router) Get(name string) *Route {
return r.getNamedRoutes()[name] return r.namedRoutes[name]
} }
// GetRoute returns a route registered with the given name. This method // GetRoute returns a route registered with the given name. This method
// was renamed to Get() and remains here for backwards compatibility. // was renamed to Get() and remains here for backwards compatibility.
func (r *Router) GetRoute(name string) *Route { func (r *Router) GetRoute(name string) *Route {
return r.getNamedRoutes()[name] return r.namedRoutes[name]
} }
// StrictSlash defines the trailing slash behavior for new routes. The initial // StrictSlash defines the trailing slash behavior for new routes. The initial
@ -221,55 +271,24 @@ func (r *Router) UseEncodedPath() *Router {
return r return r
} }
// ----------------------------------------------------------------------------
// parentRoute
// ----------------------------------------------------------------------------
func (r *Router) getBuildScheme() string {
if r.parent != nil {
return r.parent.getBuildScheme()
}
return ""
}
// getNamedRoutes returns the map where named routes are registered.
func (r *Router) getNamedRoutes() map[string]*Route {
if r.namedRoutes == nil {
if r.parent != nil {
r.namedRoutes = r.parent.getNamedRoutes()
} else {
r.namedRoutes = make(map[string]*Route)
}
}
return r.namedRoutes
}
// getRegexpGroup returns regexp definitions from the parent route, if any.
func (r *Router) getRegexpGroup() *routeRegexpGroup {
if r.parent != nil {
return r.parent.getRegexpGroup()
}
return nil
}
func (r *Router) buildVars(m map[string]string) map[string]string {
if r.parent != nil {
m = r.parent.buildVars(m)
}
return m
}
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Route factories // Route factories
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// NewRoute registers an empty route. // NewRoute registers an empty route.
func (r *Router) NewRoute() *Route { func (r *Router) NewRoute() *Route {
route := &Route{parent: r, strictSlash: r.strictSlash, skipClean: r.skipClean, useEncodedPath: r.useEncodedPath} // initialize a route with a copy of the parent router's configuration
route := &Route{routeConf: copyRouteConf(r.routeConf), namedRoutes: r.namedRoutes}
r.routes = append(r.routes, route) r.routes = append(r.routes, route)
return route return route
} }
// Name registers a new route with a name.
// See Route.Name().
func (r *Router) Name(name string) *Route {
return r.NewRoute().Name(name)
}
// Handle registers a new route with a matcher for the URL path. // Handle registers a new route with a matcher for the URL path.
// See Route.Path() and Route.Handler(). // See Route.Path() and Route.Handler().
func (r *Router) Handle(path string, handler http.Handler) *Route { func (r *Router) Handle(path string, handler http.Handler) *Route {

View File

@ -113,6 +113,13 @@ func newRouteRegexp(tpl string, typ regexpType, options routeRegexpOptions) (*ro
if typ != regexpTypePrefix { if typ != regexpTypePrefix {
pattern.WriteByte('$') pattern.WriteByte('$')
} }
var wildcardHostPort bool
if typ == regexpTypeHost {
if !strings.Contains(pattern.String(), ":") {
wildcardHostPort = true
}
}
reverse.WriteString(raw) reverse.WriteString(raw)
if endSlash { if endSlash {
reverse.WriteByte('/') reverse.WriteByte('/')
@ -131,13 +138,14 @@ func newRouteRegexp(tpl string, typ regexpType, options routeRegexpOptions) (*ro
// Done! // Done!
return &routeRegexp{ return &routeRegexp{
template: template, template: template,
regexpType: typ, regexpType: typ,
options: options, options: options,
regexp: reg, regexp: reg,
reverse: reverse.String(), reverse: reverse.String(),
varsN: varsN, varsN: varsN,
varsR: varsR, varsR: varsR,
wildcardHostPort: wildcardHostPort,
}, nil }, nil
} }
@ -158,11 +166,22 @@ type routeRegexp struct {
varsN []string varsN []string
// Variable regexps (validators). // Variable regexps (validators).
varsR []*regexp.Regexp varsR []*regexp.Regexp
// Wildcard host-port (no strict port match in hostname)
wildcardHostPort bool
} }
// Match matches the regexp against the URL host or path. // Match matches the regexp against the URL host or path.
func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool { func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
if r.regexpType != regexpTypeHost { if r.regexpType == regexpTypeHost {
host := getHost(req)
if r.wildcardHostPort {
// Don't be strict on the port match
if i := strings.Index(host, ":"); i != -1 {
host = host[:i]
}
}
return r.regexp.MatchString(host)
} else {
if r.regexpType == regexpTypeQuery { if r.regexpType == regexpTypeQuery {
return r.matchQueryString(req) return r.matchQueryString(req)
} }
@ -172,8 +191,6 @@ func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
} }
return r.regexp.MatchString(path) return r.regexp.MatchString(path)
} }
return r.regexp.MatchString(getHost(req))
} }
// url builds a URL part using the given values. // url builds a URL part using the given values.
@ -267,7 +284,7 @@ type routeRegexpGroup struct {
} }
// setMatch extracts the variables from the URL once a route matches. // setMatch extracts the variables from the URL once a route matches.
func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) { func (v routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) {
// Store host variables. // Store host variables.
if v.host != nil { if v.host != nil {
host := getHost(req) host := getHost(req)
@ -296,7 +313,7 @@ func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route)
} else { } else {
u.Path += "/" u.Path += "/"
} }
m.Handler = http.RedirectHandler(u.String(), 301) m.Handler = http.RedirectHandler(u.String(), http.StatusMovedPermanently)
} }
} }
} }
@ -312,17 +329,13 @@ func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route)
} }
// getHost tries its best to return the request host. // getHost tries its best to return the request host.
// According to section 14.23 of RFC 2616 the Host header
// can include the port number if the default value of 80 is not used.
func getHost(r *http.Request) string { func getHost(r *http.Request) string {
if r.URL.IsAbs() { if r.URL.IsAbs() {
return r.URL.Host return r.URL.Host
} }
host := r.Host return r.Host
// Slice off any port information.
if i := strings.Index(host, ":"); i != -1 {
host = host[:i]
}
return host
} }
func extractVars(input string, matches []int, names []string, output map[string]string) { func extractVars(input string, matches []int, names []string, output map[string]string) {

View File

@ -15,24 +15,8 @@ import (
// Route stores information to match a request and build URLs. // Route stores information to match a request and build URLs.
type Route struct { type Route struct {
// Parent where the route was registered (a Router).
parent parentRoute
// Request handler for the route. // Request handler for the route.
handler http.Handler handler http.Handler
// List of matchers.
matchers []matcher
// Manager for the variables from host and path.
regexp *routeRegexpGroup
// If true, when the path pattern is "/path/", accessing "/path" will
// redirect to the former and vice versa.
strictSlash bool
// If true, when the path pattern is "/path//to", accessing "/path//to"
// will not redirect
skipClean bool
// If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to"
useEncodedPath bool
// The scheme used when building URLs.
buildScheme string
// If true, this route never matches: it is only used to build URLs. // If true, this route never matches: it is only used to build URLs.
buildOnly bool buildOnly bool
// The name used to build URLs. // The name used to build URLs.
@ -40,7 +24,11 @@ type Route struct {
// Error resulted from building a route. // Error resulted from building a route.
err error err error
buildVarsFunc BuildVarsFunc // "global" reference to all named routes
namedRoutes map[string]*Route
// config possibly passed in from `Router`
routeConf
} }
// SkipClean reports whether path cleaning is enabled for this route via // SkipClean reports whether path cleaning is enabled for this route via
@ -64,6 +52,18 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
matchErr = ErrMethodMismatch matchErr = ErrMethodMismatch
continue continue
} }
// Ignore ErrNotFound errors. These errors arise from match call
// to Subrouters.
//
// This prevents subsequent matching subrouters from failing to
// run middleware. If not ignored, the middleware would see a
// non-nil MatchErr and be skipped, even when there was a
// matching route.
if match.MatchErr == ErrNotFound {
match.MatchErr = nil
}
matchErr = nil matchErr = nil
return false return false
} }
@ -93,9 +93,7 @@ func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
} }
// Set variables. // Set variables.
if r.regexp != nil { r.regexp.setMatch(req, match, r)
r.regexp.setMatch(req, match, r)
}
return true return true
} }
@ -137,7 +135,7 @@ func (r *Route) GetHandler() http.Handler {
// Name ----------------------------------------------------------------------- // Name -----------------------------------------------------------------------
// Name sets the name for the route, used to build URLs. // Name sets the name for the route, used to build URLs.
// If the name was registered already it will be overwritten. // It is an error to call Name more than once on a route.
func (r *Route) Name(name string) *Route { func (r *Route) Name(name string) *Route {
if r.name != "" { if r.name != "" {
r.err = fmt.Errorf("mux: route already has name %q, can't set %q", r.err = fmt.Errorf("mux: route already has name %q, can't set %q",
@ -145,7 +143,7 @@ func (r *Route) Name(name string) *Route {
} }
if r.err == nil { if r.err == nil {
r.name = name r.name = name
r.getNamedRoutes()[name] = r r.namedRoutes[name] = r
} }
return r return r
} }
@ -177,7 +175,6 @@ func (r *Route) addRegexpMatcher(tpl string, typ regexpType) error {
if r.err != nil { if r.err != nil {
return r.err return r.err
} }
r.regexp = r.getRegexpGroup()
if typ == regexpTypePath || typ == regexpTypePrefix { if typ == regexpTypePath || typ == regexpTypePrefix {
if len(tpl) > 0 && tpl[0] != '/' { if len(tpl) > 0 && tpl[0] != '/' {
return fmt.Errorf("mux: path must start with a slash, got %q", tpl) return fmt.Errorf("mux: path must start with a slash, got %q", tpl)
@ -386,7 +383,7 @@ func (r *Route) PathPrefix(tpl string) *Route {
// The above route will only match if the URL contains the defined queries // The above route will only match if the URL contains the defined queries
// values, e.g.: ?foo=bar&id=42. // values, e.g.: ?foo=bar&id=42.
// //
// It the value is an empty string, it will match any value if the key is set. // If the value is an empty string, it will match any value if the key is set.
// //
// Variables can define an optional regexp pattern to be matched: // Variables can define an optional regexp pattern to be matched:
// //
@ -424,7 +421,7 @@ func (r *Route) Schemes(schemes ...string) *Route {
for k, v := range schemes { for k, v := range schemes {
schemes[k] = strings.ToLower(v) schemes[k] = strings.ToLower(v)
} }
if r.buildScheme == "" && len(schemes) > 0 { if len(schemes) > 0 {
r.buildScheme = schemes[0] r.buildScheme = schemes[0]
} }
return r.addMatcher(schemeMatcher(schemes)) return r.addMatcher(schemeMatcher(schemes))
@ -439,7 +436,15 @@ type BuildVarsFunc func(map[string]string) map[string]string
// BuildVarsFunc adds a custom function to be used to modify build variables // BuildVarsFunc adds a custom function to be used to modify build variables
// before a route's URL is built. // before a route's URL is built.
func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route { func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route {
r.buildVarsFunc = f if r.buildVarsFunc != nil {
// compose the old and new functions
old := r.buildVarsFunc
r.buildVarsFunc = func(m map[string]string) map[string]string {
return f(old(m))
}
} else {
r.buildVarsFunc = f
}
return r return r
} }
@ -458,7 +463,8 @@ func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route {
// Here, the routes registered in the subrouter won't be tested if the host // Here, the routes registered in the subrouter won't be tested if the host
// doesn't match. // doesn't match.
func (r *Route) Subrouter() *Router { func (r *Route) Subrouter() *Router {
router := &Router{parent: r, strictSlash: r.strictSlash} // initialize a subrouter with a copy of the parent route's configuration
router := &Router{routeConf: copyRouteConf(r.routeConf), namedRoutes: r.namedRoutes}
r.addMatcher(router) r.addMatcher(router)
return router return router
} }
@ -502,9 +508,6 @@ func (r *Route) URL(pairs ...string) (*url.URL, error) {
if r.err != nil { if r.err != nil {
return nil, r.err return nil, r.err
} }
if r.regexp == nil {
return nil, errors.New("mux: route doesn't have a host or path")
}
values, err := r.prepareVars(pairs...) values, err := r.prepareVars(pairs...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -516,8 +519,8 @@ func (r *Route) URL(pairs ...string) (*url.URL, error) {
return nil, err return nil, err
} }
scheme = "http" scheme = "http"
if s := r.getBuildScheme(); s != "" { if r.buildScheme != "" {
scheme = s scheme = r.buildScheme
} }
} }
if r.regexp.path != nil { if r.regexp.path != nil {
@ -547,7 +550,7 @@ func (r *Route) URLHost(pairs ...string) (*url.URL, error) {
if r.err != nil { if r.err != nil {
return nil, r.err return nil, r.err
} }
if r.regexp == nil || r.regexp.host == nil { if r.regexp.host == nil {
return nil, errors.New("mux: route doesn't have a host") return nil, errors.New("mux: route doesn't have a host")
} }
values, err := r.prepareVars(pairs...) values, err := r.prepareVars(pairs...)
@ -562,8 +565,8 @@ func (r *Route) URLHost(pairs ...string) (*url.URL, error) {
Scheme: "http", Scheme: "http",
Host: host, Host: host,
} }
if s := r.getBuildScheme(); s != "" { if r.buildScheme != "" {
u.Scheme = s u.Scheme = r.buildScheme
} }
return u, nil return u, nil
} }
@ -575,7 +578,7 @@ func (r *Route) URLPath(pairs ...string) (*url.URL, error) {
if r.err != nil { if r.err != nil {
return nil, r.err return nil, r.err
} }
if r.regexp == nil || r.regexp.path == nil { if r.regexp.path == nil {
return nil, errors.New("mux: route doesn't have a path") return nil, errors.New("mux: route doesn't have a path")
} }
values, err := r.prepareVars(pairs...) values, err := r.prepareVars(pairs...)
@ -600,7 +603,7 @@ func (r *Route) GetPathTemplate() (string, error) {
if r.err != nil { if r.err != nil {
return "", r.err return "", r.err
} }
if r.regexp == nil || r.regexp.path == nil { if r.regexp.path == nil {
return "", errors.New("mux: route doesn't have a path") return "", errors.New("mux: route doesn't have a path")
} }
return r.regexp.path.template, nil return r.regexp.path.template, nil
@ -614,7 +617,7 @@ func (r *Route) GetPathRegexp() (string, error) {
if r.err != nil { if r.err != nil {
return "", r.err return "", r.err
} }
if r.regexp == nil || r.regexp.path == nil { if r.regexp.path == nil {
return "", errors.New("mux: route does not have a path") return "", errors.New("mux: route does not have a path")
} }
return r.regexp.path.regexp.String(), nil return r.regexp.path.regexp.String(), nil
@ -629,7 +632,7 @@ func (r *Route) GetQueriesRegexp() ([]string, error) {
if r.err != nil { if r.err != nil {
return nil, r.err return nil, r.err
} }
if r.regexp == nil || r.regexp.queries == nil { if r.regexp.queries == nil {
return nil, errors.New("mux: route doesn't have queries") return nil, errors.New("mux: route doesn't have queries")
} }
var queries []string var queries []string
@ -648,7 +651,7 @@ func (r *Route) GetQueriesTemplates() ([]string, error) {
if r.err != nil { if r.err != nil {
return nil, r.err return nil, r.err
} }
if r.regexp == nil || r.regexp.queries == nil { if r.regexp.queries == nil {
return nil, errors.New("mux: route doesn't have queries") return nil, errors.New("mux: route doesn't have queries")
} }
var queries []string var queries []string
@ -683,7 +686,7 @@ func (r *Route) GetHostTemplate() (string, error) {
if r.err != nil { if r.err != nil {
return "", r.err return "", r.err
} }
if r.regexp == nil || r.regexp.host == nil { if r.regexp.host == nil {
return "", errors.New("mux: route doesn't have a host") return "", errors.New("mux: route doesn't have a host")
} }
return r.regexp.host.template, nil return r.regexp.host.template, nil
@ -700,64 +703,8 @@ func (r *Route) prepareVars(pairs ...string) (map[string]string, error) {
} }
func (r *Route) buildVars(m map[string]string) map[string]string { func (r *Route) buildVars(m map[string]string) map[string]string {
if r.parent != nil {
m = r.parent.buildVars(m)
}
if r.buildVarsFunc != nil { if r.buildVarsFunc != nil {
m = r.buildVarsFunc(m) m = r.buildVarsFunc(m)
} }
return m return m
} }
// ----------------------------------------------------------------------------
// parentRoute
// ----------------------------------------------------------------------------
// parentRoute allows routes to know about parent host and path definitions.
type parentRoute interface {
getBuildScheme() string
getNamedRoutes() map[string]*Route
getRegexpGroup() *routeRegexpGroup
buildVars(map[string]string) map[string]string
}
func (r *Route) getBuildScheme() string {
if r.buildScheme != "" {
return r.buildScheme
}
if r.parent != nil {
return r.parent.getBuildScheme()
}
return ""
}
// getNamedRoutes returns the map where named routes are registered.
func (r *Route) getNamedRoutes() map[string]*Route {
if r.parent == nil {
// During tests router is not always set.
r.parent = NewRouter()
}
return r.parent.getNamedRoutes()
}
// getRegexpGroup returns regexp definitions from this route.
func (r *Route) getRegexpGroup() *routeRegexpGroup {
if r.regexp == nil {
if r.parent == nil {
// During tests router is not always set.
r.parent = NewRouter()
}
regexp := r.parent.getRegexpGroup()
if regexp == nil {
r.regexp = new(routeRegexpGroup)
} else {
// Copy.
r.regexp = &routeRegexpGroup{
host: regexp.host,
path: regexp.path,
queries: regexp.queries,
}
}
}
return r.regexp
}

View File

@ -3,11 +3,11 @@ sudo: false
matrix: matrix:
include: include:
- go: 1.4 - go: 1.7.x
- go: 1.5 - go: 1.8.x
- go: 1.6 - go: 1.9.x
- go: 1.7 - go: 1.10.x
- go: 1.8 - go: 1.11.x
- go: tip - go: tip
allow_failures: allow_failures:
- go: tip - go: tip

View File

@ -4,5 +4,6 @@
# Please keep the list sorted. # Please keep the list sorted.
Gary Burd <gary@beagledreams.com> Gary Burd <gary@beagledreams.com>
Google LLC (https://opensource.google.com/)
Joachim Bauch <mail@joachim-bauch.de> Joachim Bauch <mail@joachim-bauch.de>

View File

@ -5,15 +5,15 @@
package websocket package websocket
import ( import (
"bufio"
"bytes" "bytes"
"context"
"crypto/tls" "crypto/tls"
"encoding/base64"
"errors" "errors"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
"net/http" "net/http"
"net/http/httptrace"
"net/url" "net/url"
"strings" "strings"
"time" "time"
@ -53,6 +53,10 @@ type Dialer struct {
// NetDial is nil, net.Dial is used. // NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error) NetDial func(network, addr string) (net.Conn, error)
// NetDialContext specifies the dial function for creating TCP connections. If
// NetDialContext is nil, net.DialContext is used.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given // Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the // Request. If the function returns a non-nil error, the
// request is aborted with the provided error. // request is aborted with the provided error.
@ -71,6 +75,17 @@ type Dialer struct {
// do not limit the size of the messages that can be sent or received. // do not limit the size of the messages that can be sent or received.
ReadBufferSize, WriteBufferSize int ReadBufferSize, WriteBufferSize int
// WriteBufferPool is a pool of buffers for write operations. If the value
// is not set, then write buffers are allocated to the connection for the
// lifetime of the connection.
//
// A pool is most useful when the application has a modest volume of writes
// across a large number of connections.
//
// Applications should use a single pool for each unique value of
// WriteBufferSize.
WriteBufferPool BufferPool
// Subprotocols specifies the client's requested subprotocols. // Subprotocols specifies the client's requested subprotocols.
Subprotocols []string Subprotocols []string
@ -86,52 +101,13 @@ type Dialer struct {
Jar http.CookieJar Jar http.CookieJar
} }
var errMalformedURL = errors.New("malformed ws or wss URL") // Dial creates a new client connection by calling DialContext with a background context.
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
// parseURL parses the URL. return d.DialContext(context.Background(), urlStr, requestHeader)
//
// This function is a replacement for the standard library url.Parse function.
// In Go 1.4 and earlier, url.Parse loses information from the path.
func parseURL(s string) (*url.URL, error) {
// From the RFC:
//
// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
var u url.URL
switch {
case strings.HasPrefix(s, "ws://"):
u.Scheme = "ws"
s = s[len("ws://"):]
case strings.HasPrefix(s, "wss://"):
u.Scheme = "wss"
s = s[len("wss://"):]
default:
return nil, errMalformedURL
}
if i := strings.Index(s, "?"); i >= 0 {
u.RawQuery = s[i+1:]
s = s[:i]
}
if i := strings.Index(s, "/"); i >= 0 {
u.Opaque = s[i:]
s = s[:i]
} else {
u.Opaque = "/"
}
u.Host = s
if strings.Contains(u.Host, "@") {
// Don't bother parsing user information because user information is
// not allowed in websocket URIs.
return nil, errMalformedURL
}
return &u, nil
} }
var errMalformedURL = errors.New("malformed ws or wss URL")
func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
hostPort = u.Host hostPort = u.Host
hostNoPort = u.Host hostNoPort = u.Host
@ -150,26 +126,29 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
return hostPort, hostNoPort return hostPort, hostNoPort
} }
// DefaultDialer is a dialer with all fields set to the default zero values. // DefaultDialer is a dialer with all fields set to the default values.
var DefaultDialer = &Dialer{ var DefaultDialer = &Dialer{
Proxy: http.ProxyFromEnvironment, Proxy: http.ProxyFromEnvironment,
HandshakeTimeout: 45 * time.Second,
} }
// Dial creates a new client connection. Use requestHeader to specify the // nilDialer is dialer to use when receiver is nil.
var nilDialer = *DefaultDialer
// DialContext creates a new client connection. Use requestHeader to specify the
// origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
// Use the response.Header to get the selected subprotocol // Use the response.Header to get the selected subprotocol
// (Sec-WebSocket-Protocol) and cookies (Set-Cookie). // (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
// //
// The context will be used in the request and in the Dialer
//
// If the WebSocket handshake fails, ErrBadHandshake is returned along with a // If the WebSocket handshake fails, ErrBadHandshake is returned along with a
// non-nil *http.Response so that callers can handle redirects, authentication, // non-nil *http.Response so that callers can handle redirects, authentication,
// etcetera. The response body may not contain the entire response and does not // etcetera. The response body may not contain the entire response and does not
// need to be closed by the application. // need to be closed by the application.
func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
if d == nil { if d == nil {
d = &Dialer{ d = &nilDialer
Proxy: http.ProxyFromEnvironment,
}
} }
challengeKey, err := generateChallengeKey() challengeKey, err := generateChallengeKey()
@ -177,7 +156,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
return nil, nil, err return nil, nil, err
} }
u, err := parseURL(urlStr) u, err := url.Parse(urlStr)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -205,6 +184,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
Header: make(http.Header), Header: make(http.Header),
Host: u.Host, Host: u.Host,
} }
req = req.WithContext(ctx)
// Set the cookies present in the cookie jar of the dialer // Set the cookies present in the cookie jar of the dialer
if d.Jar != nil { if d.Jar != nil {
@ -237,45 +217,83 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
k == "Sec-Websocket-Extensions" || k == "Sec-Websocket-Extensions" ||
(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
case k == "Sec-Websocket-Protocol":
req.Header["Sec-WebSocket-Protocol"] = vs
default: default:
req.Header[k] = vs req.Header[k] = vs
} }
} }
if d.EnableCompression { if d.EnableCompression {
req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") req.Header["Sec-WebSocket-Extensions"] = []string{"permessage-deflate; server_no_context_takeover; client_no_context_takeover"}
}
if d.HandshakeTimeout != 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, d.HandshakeTimeout)
defer cancel()
}
// Get network dial function.
var netDial func(network, add string) (net.Conn, error)
if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
} else {
netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr)
}
}
// If needed, wrap the dial function to set the connection deadline.
if deadline, ok := ctx.Deadline(); ok {
forwardDial := netDial
netDial = func(network, addr string) (net.Conn, error) {
c, err := forwardDial(network, addr)
if err != nil {
return nil, err
}
err = c.SetDeadline(deadline)
if err != nil {
c.Close()
return nil, err
}
return c, nil
}
}
// If needed, wrap the dial function to connect through a proxy.
if d.Proxy != nil {
proxyURL, err := d.Proxy(req)
if err != nil {
return nil, nil, err
}
if proxyURL != nil {
dialer, err := proxy_FromURL(proxyURL, netDialerFunc(netDial))
if err != nil {
return nil, nil, err
}
netDial = dialer.Dial
}
} }
hostPort, hostNoPort := hostPortNoPort(u) hostPort, hostNoPort := hostPortNoPort(u)
trace := httptrace.ContextClientTrace(ctx)
var proxyURL *url.URL if trace != nil && trace.GetConn != nil {
// Check wether the proxy method has been configured trace.GetConn(hostPort)
if d.Proxy != nil {
proxyURL, err = d.Proxy(req)
}
if err != nil {
return nil, nil, err
} }
var targetHostPort string netConn, err := netDial("tcp", hostPort)
if proxyURL != nil { if trace != nil && trace.GotConn != nil {
targetHostPort, _ = hostPortNoPort(proxyURL) trace.GotConn(httptrace.GotConnInfo{
} else { Conn: netConn,
targetHostPort = hostPort })
} }
var deadline time.Time
if d.HandshakeTimeout != 0 {
deadline = time.Now().Add(d.HandshakeTimeout)
}
netDial := d.NetDial
if netDial == nil {
netDialer := &net.Dialer{Deadline: deadline}
netDial = netDialer.Dial
}
netConn, err := netDial("tcp", targetHostPort)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -286,42 +304,6 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
} }
}() }()
if err := netConn.SetDeadline(deadline); err != nil {
return nil, nil, err
}
if proxyURL != nil {
connectHeader := make(http.Header)
if user := proxyURL.User; user != nil {
proxyUser := user.Username()
if proxyPassword, passwordSet := user.Password(); passwordSet {
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
}
}
connectReq := &http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: hostPort},
Host: hostPort,
Header: connectHeader,
}
connectReq.Write(netConn)
// Read response.
// Okay to use and discard buffered reader here, because
// TLS server will not speak until spoken to.
br := bufio.NewReader(netConn)
resp, err := http.ReadResponse(br, connectReq)
if err != nil {
return nil, nil, err
}
if resp.StatusCode != 200 {
f := strings.SplitN(resp.Status, " ", 2)
return nil, nil, errors.New(f[1])
}
}
if u.Scheme == "https" { if u.Scheme == "https" {
cfg := cloneTLSConfig(d.TLSClientConfig) cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" { if cfg.ServerName == "" {
@ -329,22 +311,31 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
} }
tlsConn := tls.Client(netConn, cfg) tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn netConn = tlsConn
if err := tlsConn.Handshake(); err != nil {
return nil, nil, err var err error
if trace != nil {
err = doHandshakeWithTrace(trace, tlsConn, cfg)
} else {
err = doHandshake(tlsConn, cfg)
} }
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil { if err != nil {
return nil, nil, err return nil, nil, err
}
} }
} }
conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize, d.WriteBufferPool, nil, nil)
if err := req.Write(netConn); err != nil { if err := req.Write(netConn); err != nil {
return nil, nil, err return nil, nil, err
} }
if trace != nil && trace.GotFirstResponseByte != nil {
if peek, err := conn.br.Peek(1); err == nil && len(peek) == 1 {
trace.GotFirstResponseByte()
}
}
resp, err := http.ReadResponse(conn.br, req) resp, err := http.ReadResponse(conn.br, req)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -390,3 +381,15 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re
netConn = nil // to avoid close in defer. netConn = nil // to avoid close in defer.
return conn, resp, nil return conn, resp, nil
} }
func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.Handshake(); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

View File

@ -76,7 +76,7 @@ const (
// is UTF-8 encoded text. // is UTF-8 encoded text.
PingMessage = 9 PingMessage = 9
// PongMessage denotes a ping control message. The optional message payload // PongMessage denotes a pong control message. The optional message payload
// is UTF-8 encoded text. // is UTF-8 encoded text.
PongMessage = 10 PongMessage = 10
) )
@ -100,9 +100,8 @@ func (e *netError) Error() string { return e.msg }
func (e *netError) Temporary() bool { return e.temporary } func (e *netError) Temporary() bool { return e.temporary }
func (e *netError) Timeout() bool { return e.timeout } func (e *netError) Timeout() bool { return e.timeout }
// CloseError represents close frame. // CloseError represents a close message.
type CloseError struct { type CloseError struct {
// Code is defined in RFC 6455, section 11.7. // Code is defined in RFC 6455, section 11.7.
Code int Code int
@ -224,6 +223,20 @@ func isValidReceivedCloseCode(code int) bool {
return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
} }
// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this
// interface. The type of the value stored in a pool is not specified.
type BufferPool interface {
// Get gets a value from the pool or returns nil if the pool is empty.
Get() interface{}
// Put adds a value to the pool.
Put(interface{})
}
// writePoolData is the type added to the write buffer pool. This wrapper is
// used to prevent applications from peeking at and depending on the values
// added to the pool.
type writePoolData struct{ buf []byte }
// The Conn type represents a WebSocket connection. // The Conn type represents a WebSocket connection.
type Conn struct { type Conn struct {
conn net.Conn conn net.Conn
@ -233,6 +246,8 @@ type Conn struct {
// Write fields // Write fields
mu chan bool // used as mutex to protect write to conn mu chan bool // used as mutex to protect write to conn
writeBuf []byte // frame is constructed in this buffer. writeBuf []byte // frame is constructed in this buffer.
writePool BufferPool
writeBufSize int
writeDeadline time.Time writeDeadline time.Time
writer io.WriteCloser // the current writer returned to the application writer io.WriteCloser // the current writer returned to the application
isWriting bool // for best-effort concurrent write detection isWriting bool // for best-effort concurrent write detection
@ -264,64 +279,29 @@ type Conn struct {
newDecompressionReader func(io.Reader) io.ReadCloser newDecompressionReader func(io.Reader) io.ReadCloser
} }
func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int) *Conn { func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn {
return newConnBRW(conn, isServer, readBufferSize, writeBufferSize, nil)
}
type writeHook struct {
p []byte
}
func (wh *writeHook) Write(p []byte) (int, error) {
wh.p = p
return len(p), nil
}
func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, brw *bufio.ReadWriter) *Conn {
mu := make(chan bool, 1)
mu <- true
var br *bufio.Reader
if readBufferSize == 0 && brw != nil && brw.Reader != nil {
// Reuse the supplied bufio.Reader if the buffer has a useful size.
// This code assumes that peek on a reader returns
// bufio.Reader.buf[:0].
brw.Reader.Reset(conn)
if p, err := brw.Reader.Peek(0); err == nil && cap(p) >= 256 {
br = brw.Reader
}
}
if br == nil { if br == nil {
if readBufferSize == 0 { if readBufferSize == 0 {
readBufferSize = defaultReadBufferSize readBufferSize = defaultReadBufferSize
} } else if readBufferSize < maxControlFramePayloadSize {
if readBufferSize < maxControlFramePayloadSize { // must be large enough for control frame
readBufferSize = maxControlFramePayloadSize readBufferSize = maxControlFramePayloadSize
} }
br = bufio.NewReaderSize(conn, readBufferSize) br = bufio.NewReaderSize(conn, readBufferSize)
} }
var writeBuf []byte if writeBufferSize <= 0 {
if writeBufferSize == 0 && brw != nil && brw.Writer != nil { writeBufferSize = defaultWriteBufferSize
// Use the bufio.Writer's buffer if the buffer has a useful size. This }
// code assumes that bufio.Writer.buf[:1] is passed to the writeBufferSize += maxFrameHeaderSize
// bufio.Writer's underlying writer.
var wh writeHook if writeBuf == nil && writeBufferPool == nil {
brw.Writer.Reset(&wh) writeBuf = make([]byte, writeBufferSize)
brw.Writer.WriteByte(0)
brw.Flush()
if cap(wh.p) >= maxFrameHeaderSize+256 {
writeBuf = wh.p[:cap(wh.p)]
}
}
if writeBuf == nil {
if writeBufferSize == 0 {
writeBufferSize = defaultWriteBufferSize
}
writeBuf = make([]byte, writeBufferSize+maxFrameHeaderSize)
} }
mu := make(chan bool, 1)
mu <- true
c := &Conn{ c := &Conn{
isServer: isServer, isServer: isServer,
br: br, br: br,
@ -329,6 +309,8 @@ func newConnBRW(conn net.Conn, isServer bool, readBufferSize, writeBufferSize in
mu: mu, mu: mu,
readFinal: true, readFinal: true,
writeBuf: writeBuf, writeBuf: writeBuf,
writePool: writeBufferPool,
writeBufSize: writeBufferSize,
enableWriteCompression: true, enableWriteCompression: true,
compressionLevel: defaultCompressionLevel, compressionLevel: defaultCompressionLevel,
} }
@ -343,7 +325,8 @@ func (c *Conn) Subprotocol() string {
return c.subprotocol return c.subprotocol
} }
// Close closes the underlying network connection without sending or waiting for a close frame. // Close closes the underlying network connection without sending or waiting
// for a close message.
func (c *Conn) Close() error { func (c *Conn) Close() error {
return c.conn.Close() return c.conn.Close()
} }
@ -370,7 +353,16 @@ func (c *Conn) writeFatal(err error) error {
return err return err
} }
func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error { func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
c.br.Discard(len(p))
return p, err
}
func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error {
<-c.mu <-c.mu
defer func() { c.mu <- true }() defer func() { c.mu <- true }()
@ -382,15 +374,14 @@ func (c *Conn) write(frameType int, deadline time.Time, bufs ...[]byte) error {
} }
c.conn.SetWriteDeadline(deadline) c.conn.SetWriteDeadline(deadline)
for _, buf := range bufs { if len(buf1) == 0 {
if len(buf) > 0 { _, err = c.conn.Write(buf0)
_, err := c.conn.Write(buf) } else {
if err != nil { err = c.writeBufs(buf0, buf1)
return c.writeFatal(err) }
} if err != nil {
} return c.writeFatal(err)
} }
if frameType == CloseMessage { if frameType == CloseMessage {
c.writeFatal(ErrCloseSent) c.writeFatal(ErrCloseSent)
} }
@ -476,7 +467,19 @@ func (c *Conn) prepWrite(messageType int) error {
c.writeErrMu.Lock() c.writeErrMu.Lock()
err := c.writeErr err := c.writeErr
c.writeErrMu.Unlock() c.writeErrMu.Unlock()
return err if err != nil {
return err
}
if c.writeBuf == nil {
wpd, ok := c.writePool.Get().(writePoolData)
if ok {
c.writeBuf = wpd.buf
} else {
c.writeBuf = make([]byte, c.writeBufSize)
}
}
return nil
} }
// NextWriter returns a writer for the next message to send. The writer's Close // NextWriter returns a writer for the next message to send. The writer's Close
@ -484,6 +487,9 @@ func (c *Conn) prepWrite(messageType int) error {
// //
// There can be at most one open writer on a connection. NextWriter closes the // There can be at most one open writer on a connection. NextWriter closes the
// previous writer if the application has not already done so. // previous writer if the application has not already done so.
//
// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and
// PongMessage) are supported.
func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) {
if err := c.prepWrite(messageType); err != nil { if err := c.prepWrite(messageType); err != nil {
return nil, err return nil, err
@ -599,6 +605,10 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error {
if final { if final {
c.writer = nil c.writer = nil
if c.writePool != nil {
c.writePool.Put(writePoolData{buf: c.writeBuf})
c.writeBuf = nil
}
return nil return nil
} }
@ -764,7 +774,6 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
// Read methods // Read methods
func (c *Conn) advanceFrame() (int, error) { func (c *Conn) advanceFrame() (int, error) {
// 1. Skip remainder of previous frame. // 1. Skip remainder of previous frame.
if c.readRemaining > 0 { if c.readRemaining > 0 {
@ -1033,7 +1042,7 @@ func (c *Conn) SetReadDeadline(t time.Time) error {
} }
// SetReadLimit sets the maximum size for a message read from the peer. If a // SetReadLimit sets the maximum size for a message read from the peer. If a
// message exceeds the limit, the connection sends a close frame to the peer // message exceeds the limit, the connection sends a close message to the peer
// and returns ErrReadLimit to the application. // and returns ErrReadLimit to the application.
func (c *Conn) SetReadLimit(limit int64) { func (c *Conn) SetReadLimit(limit int64) {
c.readLimit = limit c.readLimit = limit
@ -1046,24 +1055,22 @@ func (c *Conn) CloseHandler() func(code int, text string) error {
// SetCloseHandler sets the handler for close messages received from the peer. // SetCloseHandler sets the handler for close messages received from the peer.
// The code argument to h is the received close code or CloseNoStatusReceived // The code argument to h is the received close code or CloseNoStatusReceived
// if the close message is empty. The default close handler sends a close frame // if the close message is empty. The default close handler sends a close
// back to the peer. // message back to the peer.
// //
// The application must read the connection to process close messages as // The handler function is called from the NextReader, ReadMessage and message
// described in the section on Control Frames above. // reader Read methods. The application must read the connection to process
// close messages as described in the section on Control Messages above.
// //
// The connection read methods return a CloseError when a close frame is // The connection read methods return a CloseError when a close message is
// received. Most applications should handle close messages as part of their // received. Most applications should handle close messages as part of their
// normal error handling. Applications should only set a close handler when the // normal error handling. Applications should only set a close handler when the
// application must perform some action before sending a close frame back to // application must perform some action before sending a close message back to
// the peer. // the peer.
func (c *Conn) SetCloseHandler(h func(code int, text string) error) { func (c *Conn) SetCloseHandler(h func(code int, text string) error) {
if h == nil { if h == nil {
h = func(code int, text string) error { h = func(code int, text string) error {
message := []byte{} message := FormatCloseMessage(code, "")
if code != CloseNoStatusReceived {
message = FormatCloseMessage(code, "")
}
c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) c.WriteControl(CloseMessage, message, time.Now().Add(writeWait))
return nil return nil
} }
@ -1077,11 +1084,12 @@ func (c *Conn) PingHandler() func(appData string) error {
} }
// SetPingHandler sets the handler for ping messages received from the peer. // SetPingHandler sets the handler for ping messages received from the peer.
// The appData argument to h is the PING frame application data. The default // The appData argument to h is the PING message application data. The default
// ping handler sends a pong to the peer. // ping handler sends a pong to the peer.
// //
// The application must read the connection to process ping messages as // The handler function is called from the NextReader, ReadMessage and message
// described in the section on Control Frames above. // reader Read methods. The application must read the connection to process
// ping messages as described in the section on Control Messages above.
func (c *Conn) SetPingHandler(h func(appData string) error) { func (c *Conn) SetPingHandler(h func(appData string) error) {
if h == nil { if h == nil {
h = func(message string) error { h = func(message string) error {
@ -1103,11 +1111,12 @@ func (c *Conn) PongHandler() func(appData string) error {
} }
// SetPongHandler sets the handler for pong messages received from the peer. // SetPongHandler sets the handler for pong messages received from the peer.
// The appData argument to h is the PONG frame application data. The default // The appData argument to h is the PONG message application data. The default
// pong handler does nothing. // pong handler does nothing.
// //
// The application must read the connection to process ping messages as // The handler function is called from the NextReader, ReadMessage and message
// described in the section on Control Frames above. // reader Read methods. The application must read the connection to process
// pong messages as described in the section on Control Messages above.
func (c *Conn) SetPongHandler(h func(appData string) error) { func (c *Conn) SetPongHandler(h func(appData string) error) {
if h == nil { if h == nil {
h = func(string) error { return nil } h = func(string) error { return nil }
@ -1141,7 +1150,14 @@ func (c *Conn) SetCompressionLevel(level int) error {
} }
// FormatCloseMessage formats closeCode and text as a WebSocket close message. // FormatCloseMessage formats closeCode and text as a WebSocket close message.
// An empty message is returned for code CloseNoStatusReceived.
func FormatCloseMessage(closeCode int, text string) []byte { func FormatCloseMessage(closeCode int, text string) []byte {
if closeCode == CloseNoStatusReceived {
// Return empty message because it's illegal to send
// CloseNoStatusReceived. Return non-nil value in case application
// checks for nil.
return []byte{}
}
buf := make([]byte, 2+len(text)) buf := make([]byte, 2+len(text))
binary.BigEndian.PutUint16(buf, uint16(closeCode)) binary.BigEndian.PutUint16(buf, uint16(closeCode))
copy(buf[2:], text) copy(buf[2:], text)

View File

@ -1,21 +0,0 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.5
package websocket
import "io"
func (c *Conn) read(n int) ([]byte, error) {
p, err := c.br.Peek(n)
if err == io.EOF {
err = errUnexpectedEOF
}
if len(p) > 0 {
// advance over the bytes just read
io.ReadFull(c.br, p)
}
return p, err
}

View File

@ -2,17 +2,14 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build go1.5 // +build go1.8
package websocket package websocket
import "io" import "net"
func (c *Conn) read(n int) ([]byte, error) { func (c *Conn) writeBufs(bufs ...[]byte) error {
p, err := c.br.Peek(n) b := net.Buffers(bufs)
if err == io.EOF { _, err := b.WriteTo(c.conn)
err = errUnexpectedEOF return err
}
c.br.Discard(len(p))
return p, err
} }

View File

@ -0,0 +1,18 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.8
package websocket
func (c *Conn) writeBufs(bufs ...[]byte) error {
for _, buf := range bufs {
if len(buf) > 0 {
if _, err := c.conn.Write(buf); err != nil {
return err
}
}
}
return nil
}

View File

@ -6,9 +6,8 @@
// //
// Overview // Overview
// //
// The Conn type represents a WebSocket connection. A server application uses // The Conn type represents a WebSocket connection. A server application calls
// the Upgrade function from an Upgrader object with a HTTP request handler // the Upgrader.Upgrade method from an HTTP request handler to get a *Conn:
// to get a pointer to a Conn:
// //
// var upgrader = websocket.Upgrader{ // var upgrader = websocket.Upgrader{
// ReadBufferSize: 1024, // ReadBufferSize: 1024,
@ -31,10 +30,12 @@
// for { // for {
// messageType, p, err := conn.ReadMessage() // messageType, p, err := conn.ReadMessage()
// if err != nil { // if err != nil {
// log.Println(err)
// return // return
// } // }
// if err = conn.WriteMessage(messageType, p); err != nil { // if err := conn.WriteMessage(messageType, p); err != nil {
// return err // log.Println(err)
// return
// } // }
// } // }
// //
@ -85,20 +86,26 @@
// and pong. Call the connection WriteControl, WriteMessage or NextWriter // and pong. Call the connection WriteControl, WriteMessage or NextWriter
// methods to send a control message to the peer. // methods to send a control message to the peer.
// //
// Connections handle received close messages by sending a close message to the // Connections handle received close messages by calling the handler function
// peer and returning a *CloseError from the the NextReader, ReadMessage or the // set with the SetCloseHandler method and by returning a *CloseError from the
// message Read method. // NextReader, ReadMessage or the message Read method. The default close
// handler sends a close message to the peer.
// //
// Connections handle received ping and pong messages by invoking callback // Connections handle received ping messages by calling the handler function
// functions set with SetPingHandler and SetPongHandler methods. The callback // set with the SetPingHandler method. The default ping handler sends a pong
// functions are called from the NextReader, ReadMessage and the message Read // message to the peer.
// methods.
// //
// The default ping handler sends a pong to the peer. The application's reading // Connections handle received pong messages by calling the handler function
// goroutine can block for a short time while the handler writes the pong data // set with the SetPongHandler method. The default pong handler does nothing.
// to the connection. // If an application sends ping messages, then the application should set a
// pong handler to receive the corresponding pong.
// //
// The application must read the connection to process ping, pong and close // The control message handler functions are called from the NextReader,
// ReadMessage and message reader Read methods. The default close and ping
// handlers can block these methods for a short time when the handler writes to
// the connection.
//
// The application must read the connection to process close, ping and pong
// messages sent from the peer. If the application is not otherwise interested // messages sent from the peer. If the application is not otherwise interested
// in messages from the peer, then the application should start a goroutine to // in messages from the peer, then the application should start a goroutine to
// read and discard messages from the peer. A simple example is: // read and discard messages from the peer. A simple example is:
@ -137,19 +144,12 @@
// method fails the WebSocket handshake with HTTP status 403. // method fails the WebSocket handshake with HTTP status 403.
// //
// If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail // If the CheckOrigin field is nil, then the Upgrader uses a safe default: fail
// the handshake if the Origin request header is present and not equal to the // the handshake if the Origin request header is present and the Origin host is
// Host request header. // not equal to the Host request header.
// //
// An application can allow connections from any origin by specifying a // The deprecated package-level Upgrade function does not perform origin
// function that always returns true: // checking. The application is responsible for checking the Origin header
// // before calling the Upgrade function.
// var upgrader = websocket.Upgrader{
// CheckOrigin: func(r *http.Request) bool { return true },
// }
//
// The deprecated Upgrade function does not enforce an origin policy. It's the
// application's responsibility to check the Origin header before calling
// Upgrade.
// //
// Compression EXPERIMENTAL // Compression EXPERIMENTAL
// //

View File

@ -9,12 +9,14 @@ import (
"io" "io"
) )
// WriteJSON is deprecated, use c.WriteJSON instead. // WriteJSON writes the JSON encoding of v as a message.
//
// Deprecated: Use c.WriteJSON instead.
func WriteJSON(c *Conn, v interface{}) error { func WriteJSON(c *Conn, v interface{}) error {
return c.WriteJSON(v) return c.WriteJSON(v)
} }
// WriteJSON writes the JSON encoding of v to the connection. // WriteJSON writes the JSON encoding of v as a message.
// //
// See the documentation for encoding/json Marshal for details about the // See the documentation for encoding/json Marshal for details about the
// conversion of Go values to JSON. // conversion of Go values to JSON.
@ -31,7 +33,10 @@ func (c *Conn) WriteJSON(v interface{}) error {
return err2 return err2
} }
// ReadJSON is deprecated, use c.ReadJSON instead. // ReadJSON reads the next JSON-encoded message from the connection and stores
// it in the value pointed to by v.
//
// Deprecated: Use c.ReadJSON instead.
func ReadJSON(c *Conn, v interface{}) error { func ReadJSON(c *Conn, v interface{}) error {
return c.ReadJSON(v) return c.ReadJSON(v)
} }

View File

@ -11,7 +11,6 @@ import "unsafe"
const wordSize = int(unsafe.Sizeof(uintptr(0))) const wordSize = int(unsafe.Sizeof(uintptr(0)))
func maskBytes(key [4]byte, pos int, b []byte) int { func maskBytes(key [4]byte, pos int, b []byte) int {
// Mask one byte at a time for small buffers. // Mask one byte at a time for small buffers.
if len(b) < 2*wordSize { if len(b) < 2*wordSize {
for i := range b { for i := range b {

View File

@ -19,7 +19,6 @@ import (
type PreparedMessage struct { type PreparedMessage struct {
messageType int messageType int
data []byte data []byte
err error
mu sync.Mutex mu sync.Mutex
frames map[prepareKey]*preparedFrame frames map[prepareKey]*preparedFrame
} }

77
vendor/github.com/gorilla/websocket/proxy.go generated vendored Normal file
View File

@ -0,0 +1,77 @@
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package websocket
import (
"bufio"
"encoding/base64"
"errors"
"net"
"net/http"
"net/url"
"strings"
)
type netDialerFunc func(network, addr string) (net.Conn, error)
func (fn netDialerFunc) Dial(network, addr string) (net.Conn, error) {
return fn(network, addr)
}
func init() {
proxy_RegisterDialerType("http", func(proxyURL *url.URL, forwardDialer proxy_Dialer) (proxy_Dialer, error) {
return &httpProxyDialer{proxyURL: proxyURL, fowardDial: forwardDialer.Dial}, nil
})
}
type httpProxyDialer struct {
proxyURL *url.URL
fowardDial func(network, addr string) (net.Conn, error)
}
func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error) {
hostPort, _ := hostPortNoPort(hpd.proxyURL)
conn, err := hpd.fowardDial(network, hostPort)
if err != nil {
return nil, err
}
connectHeader := make(http.Header)
if user := hpd.proxyURL.User; user != nil {
proxyUser := user.Username()
if proxyPassword, passwordSet := user.Password(); passwordSet {
credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
connectHeader.Set("Proxy-Authorization", "Basic "+credential)
}
}
connectReq := &http.Request{
Method: "CONNECT",
URL: &url.URL{Opaque: addr},
Host: addr,
Header: connectHeader,
}
if err := connectReq.Write(conn); err != nil {
conn.Close()
return nil, err
}
// Read response. It's OK to use and discard buffered reader here becaue
// the remote server does not speak until spoken to.
br := bufio.NewReader(conn)
resp, err := http.ReadResponse(br, connectReq)
if err != nil {
conn.Close()
return nil, err
}
if resp.StatusCode != 200 {
conn.Close()
f := strings.SplitN(resp.Status, " ", 2)
return nil, errors.New(f[1])
}
return conn, nil
}

View File

@ -7,7 +7,7 @@ package websocket
import ( import (
"bufio" "bufio"
"errors" "errors"
"net" "io"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
@ -33,10 +33,23 @@ type Upgrader struct {
// or received. // or received.
ReadBufferSize, WriteBufferSize int ReadBufferSize, WriteBufferSize int
// WriteBufferPool is a pool of buffers for write operations. If the value
// is not set, then write buffers are allocated to the connection for the
// lifetime of the connection.
//
// A pool is most useful when the application has a modest volume of writes
// across a large number of connections.
//
// Applications should use a single pool for each unique value of
// WriteBufferSize.
WriteBufferPool BufferPool
// Subprotocols specifies the server's supported protocols in order of // Subprotocols specifies the server's supported protocols in order of
// preference. If this field is set, then the Upgrade method negotiates a // preference. If this field is not nil, then the Upgrade method negotiates a
// subprotocol by selecting the first match in this list with a protocol // subprotocol by selecting the first match in this list with a protocol
// requested by the client. // requested by the client. If there's no match, then no protocol is
// negotiated (the Sec-Websocket-Protocol header is not included in the
// handshake response).
Subprotocols []string Subprotocols []string
// Error specifies the function for generating HTTP error responses. If Error // Error specifies the function for generating HTTP error responses. If Error
@ -44,8 +57,12 @@ type Upgrader struct {
Error func(w http.ResponseWriter, r *http.Request, status int, reason error) Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
// CheckOrigin returns true if the request Origin header is acceptable. If // CheckOrigin returns true if the request Origin header is acceptable. If
// CheckOrigin is nil, the host in the Origin header must not be set or // CheckOrigin is nil, then a safe default is used: return false if the
// must match the host of the request. // Origin request header is present and the origin host is not equal to
// request Host header.
//
// A CheckOrigin function should carefully validate the request origin to
// prevent cross-site request forgery.
CheckOrigin func(r *http.Request) bool CheckOrigin func(r *http.Request) bool
// EnableCompression specify if the server should attempt to negotiate per // EnableCompression specify if the server should attempt to negotiate per
@ -76,7 +93,7 @@ func checkSameOrigin(r *http.Request) bool {
if err != nil { if err != nil {
return false return false
} }
return u.Host == r.Host return equalASCIIFold(u.Host, r.Host)
} }
func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string { func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
@ -99,42 +116,44 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
// //
// The responseHeader is included in the response to the client's upgrade // The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the // request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-Websocket-Protocol). // application negotiated subprotocol (Sec-WebSocket-Protocol).
// //
// If the upgrade fails, then Upgrade replies to the client with an HTTP error // If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response. // response.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
if r.Method != "GET" { const badHandshake = "websocket: the client is not using the websocket protocol: "
return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: not a websocket handshake: request method is not GET")
}
if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-Websocket-Extensions' headers are unsupported")
}
if !tokenListContainsValue(r.Header, "Connection", "upgrade") { if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'upgrade' token not found in 'Connection' header") return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
} }
if !tokenListContainsValue(r.Header, "Upgrade", "websocket") { if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'websocket' token not found in 'Upgrade' header") return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
}
if r.Method != "GET" {
return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
} }
if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") { if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
} }
if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
}
checkOrigin := u.CheckOrigin checkOrigin := u.CheckOrigin
if checkOrigin == nil { if checkOrigin == nil {
checkOrigin = checkSameOrigin checkOrigin = checkSameOrigin
} }
if !checkOrigin(r) { if !checkOrigin(r) {
return u.returnError(w, r, http.StatusForbidden, "websocket: 'Origin' header value not allowed") return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
} }
challengeKey := r.Header.Get("Sec-Websocket-Key") challengeKey := r.Header.Get("Sec-Websocket-Key")
if challengeKey == "" { if challengeKey == "" {
return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-Websocket-Key' header is missing or blank") return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank")
} }
subprotocol := u.selectSubprotocol(r, responseHeader) subprotocol := u.selectSubprotocol(r, responseHeader)
@ -151,17 +170,12 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
} }
} }
var (
netConn net.Conn
err error
)
h, ok := w.(http.Hijacker) h, ok := w.(http.Hijacker)
if !ok { if !ok {
return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker") return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
} }
var brw *bufio.ReadWriter var brw *bufio.ReadWriter
netConn, brw, err = h.Hijack() netConn, brw, err := h.Hijack()
if err != nil { if err != nil {
return u.returnError(w, r, http.StatusInternalServerError, err.Error()) return u.returnError(w, r, http.StatusInternalServerError, err.Error())
} }
@ -171,7 +185,21 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
return nil, errors.New("websocket: client sent data before handshake is complete") return nil, errors.New("websocket: client sent data before handshake is complete")
} }
c := newConnBRW(netConn, true, u.ReadBufferSize, u.WriteBufferSize, brw) var br *bufio.Reader
if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
// Reuse hijacked buffered reader as connection reader.
br = brw.Reader
}
buf := bufioWriterBuffer(netConn, brw.Writer)
var writeBuf []byte
if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
// Reuse hijacked write buffer as connection buffer.
writeBuf = buf
}
c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
c.subprotocol = subprotocol c.subprotocol = subprotocol
if compress { if compress {
@ -179,17 +207,23 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
c.newDecompressionReader = decompressNoContextTakeover c.newDecompressionReader = decompressNoContextTakeover
} }
p := c.writeBuf[:0] // Use larger of hijacked buffer and connection write buffer for header.
p := buf
if len(c.writeBuf) > len(p) {
p = c.writeBuf
}
p = p[:0]
p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
p = append(p, computeAcceptKey(challengeKey)...) p = append(p, computeAcceptKey(challengeKey)...)
p = append(p, "\r\n"...) p = append(p, "\r\n"...)
if c.subprotocol != "" { if c.subprotocol != "" {
p = append(p, "Sec-Websocket-Protocol: "...) p = append(p, "Sec-WebSocket-Protocol: "...)
p = append(p, c.subprotocol...) p = append(p, c.subprotocol...)
p = append(p, "\r\n"...) p = append(p, "\r\n"...)
} }
if compress { if compress {
p = append(p, "Sec-Websocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...) p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
} }
for k, vs := range responseHeader { for k, vs := range responseHeader {
if k == "Sec-Websocket-Protocol" { if k == "Sec-Websocket-Protocol" {
@ -230,13 +264,14 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
// Upgrade upgrades the HTTP server connection to the WebSocket protocol. // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
// //
// This function is deprecated, use websocket.Upgrader instead. // Deprecated: Use websocket.Upgrader instead.
// //
// The application is responsible for checking the request origin before // Upgrade does not perform origin checking. The application is responsible for
// calling Upgrade. An example implementation of the same origin policy is: // checking the Origin header before calling Upgrade. An example implementation
// of the same origin policy check is:
// //
// if req.Header.Get("Origin") != "http://"+req.Host { // if req.Header.Get("Origin") != "http://"+req.Host {
// http.Error(w, "Origin not allowed", 403) // http.Error(w, "Origin not allowed", http.StatusForbidden)
// return // return
// } // }
// //
@ -289,3 +324,40 @@ func IsWebSocketUpgrade(r *http.Request) bool {
return tokenListContainsValue(r.Header, "Connection", "upgrade") && return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
tokenListContainsValue(r.Header, "Upgrade", "websocket") tokenListContainsValue(r.Header, "Upgrade", "websocket")
} }
// bufioReaderSize size returns the size of a bufio.Reader.
func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
// This code assumes that peek on a reset reader returns
// bufio.Reader.buf[:0].
// TODO: Use bufio.Reader.Size() after Go 1.10
br.Reset(originalReader)
if p, err := br.Peek(0); err == nil {
return cap(p)
}
return 0
}
// writeHook is an io.Writer that records the last slice passed to it vio
// io.Writer.Write.
type writeHook struct {
p []byte
}
func (wh *writeHook) Write(p []byte) (int, error) {
wh.p = p
return len(p), nil
}
// bufioWriterBuffer grabs the buffer from a bufio.Writer.
func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
// This code assumes that bufio.Writer.buf[:1] is passed to the
// bufio.Writer's underlying writer.
var wh writeHook
bw.Reset(&wh)
bw.WriteByte(0)
bw.Flush()
bw.Reset(originalWriter)
return wh.p[:cap(wh.p)]
}

19
vendor/github.com/gorilla/websocket/trace.go generated vendored Normal file
View File

@ -0,0 +1,19 @@
// +build go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
if trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := doHandshake(tlsConn, cfg)
if trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
return err
}

12
vendor/github.com/gorilla/websocket/trace_17.go generated vendored Normal file
View File

@ -0,0 +1,12 @@
// +build !go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
return doHandshake(tlsConn, cfg)
}

View File

@ -11,6 +11,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"unicode/utf8"
) )
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
@ -111,14 +112,14 @@ func nextTokenOrQuoted(s string) (value string, rest string) {
case escape: case escape:
escape = false escape = false
p[j] = b p[j] = b
j += 1 j++
case b == '\\': case b == '\\':
escape = true escape = true
case b == '"': case b == '"':
return string(p[:j]), s[i+1:] return string(p[:j]), s[i+1:]
default: default:
p[j] = b p[j] = b
j += 1 j++
} }
} }
return "", "" return "", ""
@ -127,8 +128,31 @@ func nextTokenOrQuoted(s string) (value string, rest string) {
return "", "" return "", ""
} }
// equalASCIIFold returns true if s is equal to t with ASCII case folding.
func equalASCIIFold(s, t string) bool {
for s != "" && t != "" {
sr, size := utf8.DecodeRuneInString(s)
s = s[size:]
tr, size := utf8.DecodeRuneInString(t)
t = t[size:]
if sr == tr {
continue
}
if 'A' <= sr && sr <= 'Z' {
sr = sr + 'a' - 'A'
}
if 'A' <= tr && tr <= 'Z' {
tr = tr + 'a' - 'A'
}
if sr != tr {
return false
}
}
return s == t
}
// tokenListContainsValue returns true if the 1#token header with the given // tokenListContainsValue returns true if the 1#token header with the given
// name contains token. // name contains a token equal to value with ASCII case folding.
func tokenListContainsValue(header http.Header, name string, value string) bool { func tokenListContainsValue(header http.Header, name string, value string) bool {
headers: headers:
for _, s := range header[name] { for _, s := range header[name] {
@ -142,7 +166,7 @@ headers:
if s != "" && s[0] != ',' { if s != "" && s[0] != ',' {
continue headers continue headers
} }
if strings.EqualFold(t, value) { if equalASCIIFold(t, value) {
return true return true
} }
if s == "" { if s == "" {
@ -154,9 +178,8 @@ headers:
return false return false
} }
// parseExtensiosn parses WebSocket extensions from a header. // parseExtensions parses WebSocket extensions from a header.
func parseExtensions(header http.Header) []map[string]string { func parseExtensions(header http.Header) []map[string]string {
// From RFC 6455: // From RFC 6455:
// //
// Sec-WebSocket-Extensions = extension-list // Sec-WebSocket-Extensions = extension-list

473
vendor/github.com/gorilla/websocket/x_net_proxy.go generated vendored Normal file
View File

@ -0,0 +1,473 @@
// Code generated by golang.org/x/tools/cmd/bundle. DO NOT EDIT.
//go:generate bundle -o x_net_proxy.go golang.org/x/net/proxy
// Package proxy provides support for a variety of protocols to proxy network
// data.
//
package websocket
import (
"errors"
"io"
"net"
"net/url"
"os"
"strconv"
"strings"
"sync"
)
type proxy_direct struct{}
// Direct is a direct proxy: one that makes network connections directly.
var proxy_Direct = proxy_direct{}
func (proxy_direct) Dial(network, addr string) (net.Conn, error) {
return net.Dial(network, addr)
}
// A PerHost directs connections to a default Dialer unless the host name
// requested matches one of a number of exceptions.
type proxy_PerHost struct {
def, bypass proxy_Dialer
bypassNetworks []*net.IPNet
bypassIPs []net.IP
bypassZones []string
bypassHosts []string
}
// NewPerHost returns a PerHost Dialer that directs connections to either
// defaultDialer or bypass, depending on whether the connection matches one of
// the configured rules.
func proxy_NewPerHost(defaultDialer, bypass proxy_Dialer) *proxy_PerHost {
return &proxy_PerHost{
def: defaultDialer,
bypass: bypass,
}
}
// Dial connects to the address addr on the given network through either
// defaultDialer or bypass.
func (p *proxy_PerHost) Dial(network, addr string) (c net.Conn, err error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
return p.dialerForRequest(host).Dial(network, addr)
}
func (p *proxy_PerHost) dialerForRequest(host string) proxy_Dialer {
if ip := net.ParseIP(host); ip != nil {
for _, net := range p.bypassNetworks {
if net.Contains(ip) {
return p.bypass
}
}
for _, bypassIP := range p.bypassIPs {
if bypassIP.Equal(ip) {
return p.bypass
}
}
return p.def
}
for _, zone := range p.bypassZones {
if strings.HasSuffix(host, zone) {
return p.bypass
}
if host == zone[1:] {
// For a zone ".example.com", we match "example.com"
// too.
return p.bypass
}
}
for _, bypassHost := range p.bypassHosts {
if bypassHost == host {
return p.bypass
}
}
return p.def
}
// AddFromString parses a string that contains comma-separated values
// specifying hosts that should use the bypass proxy. Each value is either an
// IP address, a CIDR range, a zone (*.example.com) or a host name
// (localhost). A best effort is made to parse the string and errors are
// ignored.
func (p *proxy_PerHost) AddFromString(s string) {
hosts := strings.Split(s, ",")
for _, host := range hosts {
host = strings.TrimSpace(host)
if len(host) == 0 {
continue
}
if strings.Contains(host, "/") {
// We assume that it's a CIDR address like 127.0.0.0/8
if _, net, err := net.ParseCIDR(host); err == nil {
p.AddNetwork(net)
}
continue
}
if ip := net.ParseIP(host); ip != nil {
p.AddIP(ip)
continue
}
if strings.HasPrefix(host, "*.") {
p.AddZone(host[1:])
continue
}
p.AddHost(host)
}
}
// AddIP specifies an IP address that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match an IP.
func (p *proxy_PerHost) AddIP(ip net.IP) {
p.bypassIPs = append(p.bypassIPs, ip)
}
// AddNetwork specifies an IP range that will use the bypass proxy. Note that
// this will only take effect if a literal IP address is dialed. A connection
// to a named host will never match.
func (p *proxy_PerHost) AddNetwork(net *net.IPNet) {
p.bypassNetworks = append(p.bypassNetworks, net)
}
// AddZone specifies a DNS suffix that will use the bypass proxy. A zone of
// "example.com" matches "example.com" and all of its subdomains.
func (p *proxy_PerHost) AddZone(zone string) {
if strings.HasSuffix(zone, ".") {
zone = zone[:len(zone)-1]
}
if !strings.HasPrefix(zone, ".") {
zone = "." + zone
}
p.bypassZones = append(p.bypassZones, zone)
}
// AddHost specifies a host name that will use the bypass proxy.
func (p *proxy_PerHost) AddHost(host string) {
if strings.HasSuffix(host, ".") {
host = host[:len(host)-1]
}
p.bypassHosts = append(p.bypassHosts, host)
}
// A Dialer is a means to establish a connection.
type proxy_Dialer interface {
// Dial connects to the given address via the proxy.
Dial(network, addr string) (c net.Conn, err error)
}
// Auth contains authentication parameters that specific Dialers may require.
type proxy_Auth struct {
User, Password string
}
// FromEnvironment returns the dialer specified by the proxy related variables in
// the environment.
func proxy_FromEnvironment() proxy_Dialer {
allProxy := proxy_allProxyEnv.Get()
if len(allProxy) == 0 {
return proxy_Direct
}
proxyURL, err := url.Parse(allProxy)
if err != nil {
return proxy_Direct
}
proxy, err := proxy_FromURL(proxyURL, proxy_Direct)
if err != nil {
return proxy_Direct
}
noProxy := proxy_noProxyEnv.Get()
if len(noProxy) == 0 {
return proxy
}
perHost := proxy_NewPerHost(proxy, proxy_Direct)
perHost.AddFromString(noProxy)
return perHost
}
// proxySchemes is a map from URL schemes to a function that creates a Dialer
// from a URL with such a scheme.
var proxy_proxySchemes map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error)
// RegisterDialerType takes a URL scheme and a function to generate Dialers from
// a URL with that scheme and a forwarding Dialer. Registered schemes are used
// by FromURL.
func proxy_RegisterDialerType(scheme string, f func(*url.URL, proxy_Dialer) (proxy_Dialer, error)) {
if proxy_proxySchemes == nil {
proxy_proxySchemes = make(map[string]func(*url.URL, proxy_Dialer) (proxy_Dialer, error))
}
proxy_proxySchemes[scheme] = f
}
// FromURL returns a Dialer given a URL specification and an underlying
// Dialer for it to make network requests.
func proxy_FromURL(u *url.URL, forward proxy_Dialer) (proxy_Dialer, error) {
var auth *proxy_Auth
if u.User != nil {
auth = new(proxy_Auth)
auth.User = u.User.Username()
if p, ok := u.User.Password(); ok {
auth.Password = p
}
}
switch u.Scheme {
case "socks5":
return proxy_SOCKS5("tcp", u.Host, auth, forward)
}
// If the scheme doesn't match any of the built-in schemes, see if it
// was registered by another package.
if proxy_proxySchemes != nil {
if f, ok := proxy_proxySchemes[u.Scheme]; ok {
return f(u, forward)
}
}
return nil, errors.New("proxy: unknown scheme: " + u.Scheme)
}
var (
proxy_allProxyEnv = &proxy_envOnce{
names: []string{"ALL_PROXY", "all_proxy"},
}
proxy_noProxyEnv = &proxy_envOnce{
names: []string{"NO_PROXY", "no_proxy"},
}
)
// envOnce looks up an environment variable (optionally by multiple
// names) once. It mitigates expensive lookups on some platforms
// (e.g. Windows).
// (Borrowed from net/http/transport.go)
type proxy_envOnce struct {
names []string
once sync.Once
val string
}
func (e *proxy_envOnce) Get() string {
e.once.Do(e.init)
return e.val
}
func (e *proxy_envOnce) init() {
for _, n := range e.names {
e.val = os.Getenv(n)
if e.val != "" {
return
}
}
}
// SOCKS5 returns a Dialer that makes SOCKSv5 connections to the given address
// with an optional username and password. See RFC 1928 and RFC 1929.
func proxy_SOCKS5(network, addr string, auth *proxy_Auth, forward proxy_Dialer) (proxy_Dialer, error) {
s := &proxy_socks5{
network: network,
addr: addr,
forward: forward,
}
if auth != nil {
s.user = auth.User
s.password = auth.Password
}
return s, nil
}
type proxy_socks5 struct {
user, password string
network, addr string
forward proxy_Dialer
}
const proxy_socks5Version = 5
const (
proxy_socks5AuthNone = 0
proxy_socks5AuthPassword = 2
)
const proxy_socks5Connect = 1
const (
proxy_socks5IP4 = 1
proxy_socks5Domain = 3
proxy_socks5IP6 = 4
)
var proxy_socks5Errors = []string{
"",
"general failure",
"connection forbidden",
"network unreachable",
"host unreachable",
"connection refused",
"TTL expired",
"command not supported",
"address type not supported",
}
// Dial connects to the address addr on the given network via the SOCKS5 proxy.
func (s *proxy_socks5) Dial(network, addr string) (net.Conn, error) {
switch network {
case "tcp", "tcp6", "tcp4":
default:
return nil, errors.New("proxy: no support for SOCKS5 proxy connections of type " + network)
}
conn, err := s.forward.Dial(s.network, s.addr)
if err != nil {
return nil, err
}
if err := s.connect(conn, addr); err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
// connect takes an existing connection to a socks5 proxy server,
// and commands the server to extend that connection to target,
// which must be a canonical address with a host and port.
func (s *proxy_socks5) connect(conn net.Conn, target string) error {
host, portStr, err := net.SplitHostPort(target)
if err != nil {
return err
}
port, err := strconv.Atoi(portStr)
if err != nil {
return errors.New("proxy: failed to parse port number: " + portStr)
}
if port < 1 || port > 0xffff {
return errors.New("proxy: port number out of range: " + portStr)
}
// the size here is just an estimate
buf := make([]byte, 0, 6+len(host))
buf = append(buf, proxy_socks5Version)
if len(s.user) > 0 && len(s.user) < 256 && len(s.password) < 256 {
buf = append(buf, 2 /* num auth methods */, proxy_socks5AuthNone, proxy_socks5AuthPassword)
} else {
buf = append(buf, 1 /* num auth methods */, proxy_socks5AuthNone)
}
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write greeting to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read greeting from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[0] != 5 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " has unexpected version " + strconv.Itoa(int(buf[0])))
}
if buf[1] == 0xff {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " requires authentication")
}
// See RFC 1929
if buf[1] == proxy_socks5AuthPassword {
buf = buf[:0]
buf = append(buf, 1 /* password protocol version */)
buf = append(buf, uint8(len(s.user)))
buf = append(buf, s.user...)
buf = append(buf, uint8(len(s.password)))
buf = append(buf, s.password...)
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write authentication request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read authentication reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if buf[1] != 0 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " rejected username/password")
}
}
buf = buf[:0]
buf = append(buf, proxy_socks5Version, proxy_socks5Connect, 0 /* reserved */)
if ip := net.ParseIP(host); ip != nil {
if ip4 := ip.To4(); ip4 != nil {
buf = append(buf, proxy_socks5IP4)
ip = ip4
} else {
buf = append(buf, proxy_socks5IP6)
}
buf = append(buf, ip...)
} else {
if len(host) > 255 {
return errors.New("proxy: destination host name too long: " + host)
}
buf = append(buf, proxy_socks5Domain)
buf = append(buf, byte(len(host)))
buf = append(buf, host...)
}
buf = append(buf, byte(port>>8), byte(port))
if _, err := conn.Write(buf); err != nil {
return errors.New("proxy: failed to write connect request to SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
if _, err := io.ReadFull(conn, buf[:4]); err != nil {
return errors.New("proxy: failed to read connect reply from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
failure := "unknown error"
if int(buf[1]) < len(proxy_socks5Errors) {
failure = proxy_socks5Errors[buf[1]]
}
if len(failure) > 0 {
return errors.New("proxy: SOCKS5 proxy at " + s.addr + " failed to connect: " + failure)
}
bytesToDiscard := 0
switch buf[3] {
case proxy_socks5IP4:
bytesToDiscard = net.IPv4len
case proxy_socks5IP6:
bytesToDiscard = net.IPv6len
case proxy_socks5Domain:
_, err := io.ReadFull(conn, buf[:1])
if err != nil {
return errors.New("proxy: failed to read domain length from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
bytesToDiscard = int(buf[0])
default:
return errors.New("proxy: got unknown address type " + strconv.Itoa(int(buf[3])) + " from SOCKS5 proxy at " + s.addr)
}
if cap(buf) < bytesToDiscard {
buf = make([]byte, bytesToDiscard)
} else {
buf = buf[:bytesToDiscard]
}
if _, err := io.ReadFull(conn, buf); err != nil {
return errors.New("proxy: failed to read address from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
// Also need to discard the port number
if _, err := io.ReadFull(conn, buf[:2]); err != nil {
return errors.New("proxy: failed to read port from SOCKS5 proxy at " + s.addr + ": " + err.Error())
}
return nil
}

View File

@ -1,476 +0,0 @@
// +build ignore
package main
import (
"bytes"
"fmt"
"go/ast"
"go/parser"
"go/printer"
"go/token"
"io"
"io/ioutil"
"log"
"os"
"reflect"
"strings"
"unicode"
"unicode/utf8"
)
var inFiles = []string{"cpuid.go", "cpuid_test.go"}
var copyFiles = []string{"cpuid_amd64.s", "cpuid_386.s", "detect_ref.go", "detect_intel.go"}
var fileSet = token.NewFileSet()
var reWrites = []rewrite{
initRewrite("CPUInfo -> cpuInfo"),
initRewrite("Vendor -> vendor"),
initRewrite("Flags -> flags"),
initRewrite("Detect -> detect"),
initRewrite("CPU -> cpu"),
}
var excludeNames = map[string]bool{"string": true, "join": true, "trim": true,
// cpuid_test.go
"t": true, "println": true, "logf": true, "log": true, "fatalf": true, "fatal": true,
}
var excludePrefixes = []string{"test", "benchmark"}
func main() {
Package := "private"
parserMode := parser.ParseComments
exported := make(map[string]rewrite)
for _, file := range inFiles {
in, err := os.Open(file)
if err != nil {
log.Fatalf("opening input", err)
}
src, err := ioutil.ReadAll(in)
if err != nil {
log.Fatalf("reading input", err)
}
astfile, err := parser.ParseFile(fileSet, file, src, parserMode)
if err != nil {
log.Fatalf("parsing input", err)
}
for _, rw := range reWrites {
astfile = rw(astfile)
}
// Inspect the AST and print all identifiers and literals.
var startDecl token.Pos
var endDecl token.Pos
ast.Inspect(astfile, func(n ast.Node) bool {
var s string
switch x := n.(type) {
case *ast.Ident:
if x.IsExported() {
t := strings.ToLower(x.Name)
for _, pre := range excludePrefixes {
if strings.HasPrefix(t, pre) {
return true
}
}
if excludeNames[t] != true {
//if x.Pos() > startDecl && x.Pos() < endDecl {
exported[x.Name] = initRewrite(x.Name + " -> " + t)
}
}
case *ast.GenDecl:
if x.Tok == token.CONST && x.Lparen > 0 {
startDecl = x.Lparen
endDecl = x.Rparen
// fmt.Printf("Decl:%s -> %s\n", fileSet.Position(startDecl), fileSet.Position(endDecl))
}
}
if s != "" {
fmt.Printf("%s:\t%s\n", fileSet.Position(n.Pos()), s)
}
return true
})
for _, rw := range exported {
astfile = rw(astfile)
}
var buf bytes.Buffer
printer.Fprint(&buf, fileSet, astfile)
// Remove package documentation and insert information
s := buf.String()
ind := strings.Index(buf.String(), "\npackage cpuid")
s = s[ind:]
s = "// Generated, DO NOT EDIT,\n" +
"// but copy it to your own project and rename the package.\n" +
"// See more at http://github.com/klauspost/cpuid\n" +
s
outputName := Package + string(os.PathSeparator) + file
err = ioutil.WriteFile(outputName, []byte(s), 0644)
if err != nil {
log.Fatalf("writing output: %s", err)
}
log.Println("Generated", outputName)
}
for _, file := range copyFiles {
dst := ""
if strings.HasPrefix(file, "cpuid") {
dst = Package + string(os.PathSeparator) + file
} else {
dst = Package + string(os.PathSeparator) + "cpuid_" + file
}
err := copyFile(file, dst)
if err != nil {
log.Fatalf("copying file: %s", err)
}
log.Println("Copied", dst)
}
}
// CopyFile copies a file from src to dst. If src and dst files exist, and are
// the same, then return success. Copy the file contents from src to dst.
func copyFile(src, dst string) (err error) {
sfi, err := os.Stat(src)
if err != nil {
return
}
if !sfi.Mode().IsRegular() {
// cannot copy non-regular files (e.g., directories,
// symlinks, devices, etc.)
return fmt.Errorf("CopyFile: non-regular source file %s (%q)", sfi.Name(), sfi.Mode().String())
}
dfi, err := os.Stat(dst)
if err != nil {
if !os.IsNotExist(err) {
return
}
} else {
if !(dfi.Mode().IsRegular()) {
return fmt.Errorf("CopyFile: non-regular destination file %s (%q)", dfi.Name(), dfi.Mode().String())
}
if os.SameFile(sfi, dfi) {
return
}
}
err = copyFileContents(src, dst)
return
}
// copyFileContents copies the contents of the file named src to the file named
// by dst. The file will be created if it does not already exist. If the
// destination file exists, all it's contents will be replaced by the contents
// of the source file.
func copyFileContents(src, dst string) (err error) {
in, err := os.Open(src)
if err != nil {
return
}
defer in.Close()
out, err := os.Create(dst)
if err != nil {
return
}
defer func() {
cerr := out.Close()
if err == nil {
err = cerr
}
}()
if _, err = io.Copy(out, in); err != nil {
return
}
err = out.Sync()
return
}
type rewrite func(*ast.File) *ast.File
// Mostly copied from gofmt
func initRewrite(rewriteRule string) rewrite {
f := strings.Split(rewriteRule, "->")
if len(f) != 2 {
fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n")
os.Exit(2)
}
pattern := parseExpr(f[0], "pattern")
replace := parseExpr(f[1], "replacement")
return func(p *ast.File) *ast.File { return rewriteFile(pattern, replace, p) }
}
// parseExpr parses s as an expression.
// It might make sense to expand this to allow statement patterns,
// but there are problems with preserving formatting and also
// with what a wildcard for a statement looks like.
func parseExpr(s, what string) ast.Expr {
x, err := parser.ParseExpr(s)
if err != nil {
fmt.Fprintf(os.Stderr, "parsing %s %s at %s\n", what, s, err)
os.Exit(2)
}
return x
}
// Keep this function for debugging.
/*
func dump(msg string, val reflect.Value) {
fmt.Printf("%s:\n", msg)
ast.Print(fileSet, val.Interface())
fmt.Println()
}
*/
// rewriteFile applies the rewrite rule 'pattern -> replace' to an entire file.
func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File {
cmap := ast.NewCommentMap(fileSet, p, p.Comments)
m := make(map[string]reflect.Value)
pat := reflect.ValueOf(pattern)
repl := reflect.ValueOf(replace)
var rewriteVal func(val reflect.Value) reflect.Value
rewriteVal = func(val reflect.Value) reflect.Value {
// don't bother if val is invalid to start with
if !val.IsValid() {
return reflect.Value{}
}
for k := range m {
delete(m, k)
}
val = apply(rewriteVal, val)
if match(m, pat, val) {
val = subst(m, repl, reflect.ValueOf(val.Interface().(ast.Node).Pos()))
}
return val
}
r := apply(rewriteVal, reflect.ValueOf(p)).Interface().(*ast.File)
r.Comments = cmap.Filter(r).Comments() // recreate comments list
return r
}
// set is a wrapper for x.Set(y); it protects the caller from panics if x cannot be changed to y.
func set(x, y reflect.Value) {
// don't bother if x cannot be set or y is invalid
if !x.CanSet() || !y.IsValid() {
return
}
defer func() {
if x := recover(); x != nil {
if s, ok := x.(string); ok &&
(strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) {
// x cannot be set to y - ignore this rewrite
return
}
panic(x)
}
}()
x.Set(y)
}
// Values/types for special cases.
var (
objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
scopePtrNil = reflect.ValueOf((*ast.Scope)(nil))
identType = reflect.TypeOf((*ast.Ident)(nil))
objectPtrType = reflect.TypeOf((*ast.Object)(nil))
positionType = reflect.TypeOf(token.NoPos)
callExprType = reflect.TypeOf((*ast.CallExpr)(nil))
scopePtrType = reflect.TypeOf((*ast.Scope)(nil))
)
// apply replaces each AST field x in val with f(x), returning val.
// To avoid extra conversions, f operates on the reflect.Value form.
func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
if !val.IsValid() {
return reflect.Value{}
}
// *ast.Objects introduce cycles and are likely incorrect after
// rewrite; don't follow them but replace with nil instead
if val.Type() == objectPtrType {
return objectPtrNil
}
// similarly for scopes: they are likely incorrect after a rewrite;
// replace them with nil
if val.Type() == scopePtrType {
return scopePtrNil
}
switch v := reflect.Indirect(val); v.Kind() {
case reflect.Slice:
for i := 0; i < v.Len(); i++ {
e := v.Index(i)
set(e, f(e))
}
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
e := v.Field(i)
set(e, f(e))
}
case reflect.Interface:
e := v.Elem()
set(v, f(e))
}
return val
}
func isWildcard(s string) bool {
rune, size := utf8.DecodeRuneInString(s)
return size == len(s) && unicode.IsLower(rune)
}
// match returns true if pattern matches val,
// recording wildcard submatches in m.
// If m == nil, match checks whether pattern == val.
func match(m map[string]reflect.Value, pattern, val reflect.Value) bool {
// Wildcard matches any expression. If it appears multiple
// times in the pattern, it must match the same expression
// each time.
if m != nil && pattern.IsValid() && pattern.Type() == identType {
name := pattern.Interface().(*ast.Ident).Name
if isWildcard(name) && val.IsValid() {
// wildcards only match valid (non-nil) expressions.
if _, ok := val.Interface().(ast.Expr); ok && !val.IsNil() {
if old, ok := m[name]; ok {
return match(nil, old, val)
}
m[name] = val
return true
}
}
}
// Otherwise, pattern and val must match recursively.
if !pattern.IsValid() || !val.IsValid() {
return !pattern.IsValid() && !val.IsValid()
}
if pattern.Type() != val.Type() {
return false
}
// Special cases.
switch pattern.Type() {
case identType:
// For identifiers, only the names need to match
// (and none of the other *ast.Object information).
// This is a common case, handle it all here instead
// of recursing down any further via reflection.
p := pattern.Interface().(*ast.Ident)
v := val.Interface().(*ast.Ident)
return p == nil && v == nil || p != nil && v != nil && p.Name == v.Name
case objectPtrType, positionType:
// object pointers and token positions always match
return true
case callExprType:
// For calls, the Ellipsis fields (token.Position) must
// match since that is how f(x) and f(x...) are different.
// Check them here but fall through for the remaining fields.
p := pattern.Interface().(*ast.CallExpr)
v := val.Interface().(*ast.CallExpr)
if p.Ellipsis.IsValid() != v.Ellipsis.IsValid() {
return false
}
}
p := reflect.Indirect(pattern)
v := reflect.Indirect(val)
if !p.IsValid() || !v.IsValid() {
return !p.IsValid() && !v.IsValid()
}
switch p.Kind() {
case reflect.Slice:
if p.Len() != v.Len() {
return false
}
for i := 0; i < p.Len(); i++ {
if !match(m, p.Index(i), v.Index(i)) {
return false
}
}
return true
case reflect.Struct:
for i := 0; i < p.NumField(); i++ {
if !match(m, p.Field(i), v.Field(i)) {
return false
}
}
return true
case reflect.Interface:
return match(m, p.Elem(), v.Elem())
}
// Handle token integers, etc.
return p.Interface() == v.Interface()
}
// subst returns a copy of pattern with values from m substituted in place
// of wildcards and pos used as the position of tokens from the pattern.
// if m == nil, subst returns a copy of pattern and doesn't change the line
// number information.
func subst(m map[string]reflect.Value, pattern reflect.Value, pos reflect.Value) reflect.Value {
if !pattern.IsValid() {
return reflect.Value{}
}
// Wildcard gets replaced with map value.
if m != nil && pattern.Type() == identType {
name := pattern.Interface().(*ast.Ident).Name
if isWildcard(name) {
if old, ok := m[name]; ok {
return subst(nil, old, reflect.Value{})
}
}
}
if pos.IsValid() && pattern.Type() == positionType {
// use new position only if old position was valid in the first place
if old := pattern.Interface().(token.Pos); !old.IsValid() {
return pattern
}
return pos
}
// Otherwise copy.
switch p := pattern; p.Kind() {
case reflect.Slice:
v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
for i := 0; i < p.Len(); i++ {
v.Index(i).Set(subst(m, p.Index(i), pos))
}
return v
case reflect.Struct:
v := reflect.New(p.Type()).Elem()
for i := 0; i < p.NumField(); i++ {
v.Field(i).Set(subst(m, p.Field(i), pos))
}
return v
case reflect.Ptr:
v := reflect.New(p.Type()).Elem()
if elem := p.Elem(); elem.IsValid() {
v.Set(subst(m, elem, pos).Addr())
}
return v
case reflect.Interface:
v := reflect.New(p.Type()).Elem()
if elem := p.Elem(); elem.IsValid() {
v.Set(subst(m, elem, pos))
}
return v
}
return pattern
}

View File

@ -1,132 +0,0 @@
//+build ignore
package main
import (
"fmt"
)
var logTable = [fieldSize]int16{
-1, 0, 1, 25, 2, 50, 26, 198,
3, 223, 51, 238, 27, 104, 199, 75,
4, 100, 224, 14, 52, 141, 239, 129,
28, 193, 105, 248, 200, 8, 76, 113,
5, 138, 101, 47, 225, 36, 15, 33,
53, 147, 142, 218, 240, 18, 130, 69,
29, 181, 194, 125, 106, 39, 249, 185,
201, 154, 9, 120, 77, 228, 114, 166,
6, 191, 139, 98, 102, 221, 48, 253,
226, 152, 37, 179, 16, 145, 34, 136,
54, 208, 148, 206, 143, 150, 219, 189,
241, 210, 19, 92, 131, 56, 70, 64,
30, 66, 182, 163, 195, 72, 126, 110,
107, 58, 40, 84, 250, 133, 186, 61,
202, 94, 155, 159, 10, 21, 121, 43,
78, 212, 229, 172, 115, 243, 167, 87,
7, 112, 192, 247, 140, 128, 99, 13,
103, 74, 222, 237, 49, 197, 254, 24,
227, 165, 153, 119, 38, 184, 180, 124,
17, 68, 146, 217, 35, 32, 137, 46,
55, 63, 209, 91, 149, 188, 207, 205,
144, 135, 151, 178, 220, 252, 190, 97,
242, 86, 211, 171, 20, 42, 93, 158,
132, 60, 57, 83, 71, 109, 65, 162,
31, 45, 67, 216, 183, 123, 164, 118,
196, 23, 73, 236, 127, 12, 111, 246,
108, 161, 59, 82, 41, 157, 85, 170,
251, 96, 134, 177, 187, 204, 62, 90,
203, 89, 95, 176, 156, 169, 160, 81,
11, 245, 22, 235, 122, 117, 44, 215,
79, 174, 213, 233, 230, 231, 173, 232,
116, 214, 244, 234, 168, 80, 88, 175,
}
const (
// The number of elements in the field.
fieldSize = 256
// The polynomial used to generate the logarithm table.
//
// There are a number of polynomials that work to generate
// a Galois field of 256 elements. The choice is arbitrary,
// and we just use the first one.
//
// The possibilities are: 29, 43, 45, 77, 95, 99, 101, 105,
//* 113, 135, 141, 169, 195, 207, 231, and 245.
generatingPolynomial = 29
)
func main() {
t := generateExpTable()
fmt.Printf("var expTable = %#v\n", t)
//t2 := generateMulTableSplit(t)
//fmt.Printf("var mulTable = %#v\n", t2)
low, high := generateMulTableHalf(t)
fmt.Printf("var mulTableLow = %#v\n", low)
fmt.Printf("var mulTableHigh = %#v\n", high)
}
/**
* Generates the inverse log table.
*/
func generateExpTable() []byte {
result := make([]byte, fieldSize*2-2)
for i := 1; i < fieldSize; i++ {
log := logTable[i]
result[log] = byte(i)
result[log+fieldSize-1] = byte(i)
}
return result
}
func generateMulTable(expTable []byte) []byte {
result := make([]byte, 256*256)
for v := range result {
a := byte(v & 0xff)
b := byte(v >> 8)
if a == 0 || b == 0 {
result[v] = 0
continue
}
logA := int(logTable[a])
logB := int(logTable[b])
result[v] = expTable[logA+logB]
}
return result
}
func generateMulTableSplit(expTable []byte) [256][256]byte {
var result [256][256]byte
for a := range result {
for b := range result[a] {
if a == 0 || b == 0 {
result[a][b] = 0
continue
}
logA := int(logTable[a])
logB := int(logTable[b])
result[a][b] = expTable[logA+logB]
}
}
return result
}
func generateMulTableHalf(expTable []byte) (low [256][16]byte, high [256][16]byte) {
for a := range low {
for b := range low {
result := 0
if !(a == 0 || b == 0) {
logA := int(logTable[a])
logB := int(logTable[b])
result = int(expTable[logA+logB])
}
if (b & 0xf) == b {
low[a][b] = byte(result)
}
if (b & 0xf0) == b {
high[a][b>>4] = byte(result)
}
}
}
return
}

View File

@ -1,22 +1,21 @@
Copyright (c) 2012 - 2013 Mat Ryer and Tyler Bunnell MIT License
Please consider promoting this project if you find it useful. Copyright (c) 2012-2018 Mat Ryer and Tyler Bunnell
Permission is hereby granted, free of charge, to any person Permission is hereby granted, free of charge, to any person obtaining a copy
obtaining a copy of this software and associated documentation of this software and associated documentation files (the "Software"), to deal
files (the "Software"), to deal in the Software without restriction, in the Software without restriction, including without limitation the rights
including without limitation the rights to use, copy, modify, merge, to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
publish, distribute, sublicense, and/or sell copies of the Software, copies of the Software, and to permit persons to whom the Software is
and to permit persons to whom the Software is furnished to do so, furnished to do so, subject to the following conditions:
subject to the following conditions:
The above copyright notice and this permission notice shall be included The above copyright notice and this permission notice shall be included in all
in all copies or substantial portions of the Software. copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. SOFTWARE.

View File

@ -13,6 +13,9 @@ import (
// Conditionf uses a Comparison to assert a complex condition. // Conditionf uses a Comparison to assert a complex condition.
func Conditionf(t TestingT, comp Comparison, msg string, args ...interface{}) bool { func Conditionf(t TestingT, comp Comparison, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Condition(t, comp, append([]interface{}{msg}, args...)...) return Condition(t, comp, append([]interface{}{msg}, args...)...)
} }
@ -23,11 +26,17 @@ func Conditionf(t TestingT, comp Comparison, msg string, args ...interface{}) bo
// assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted") // assert.Containsf(t, ["Hello", "World"], "World", "error message %s", "formatted")
// assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted") // assert.Containsf(t, {"Hello": "World"}, "Hello", "error message %s", "formatted")
func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool { func Containsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Contains(t, s, contains, append([]interface{}{msg}, args...)...) return Contains(t, s, contains, append([]interface{}{msg}, args...)...)
} }
// DirExistsf checks whether a directory exists in the given path. It also fails if the path is a file rather a directory or there is an error checking whether it exists. // DirExistsf checks whether a directory exists in the given path. It also fails if the path is a file rather a directory or there is an error checking whether it exists.
func DirExistsf(t TestingT, path string, msg string, args ...interface{}) bool { func DirExistsf(t TestingT, path string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return DirExists(t, path, append([]interface{}{msg}, args...)...) return DirExists(t, path, append([]interface{}{msg}, args...)...)
} }
@ -37,6 +46,9 @@ func DirExistsf(t TestingT, path string, msg string, args ...interface{}) bool {
// //
// assert.ElementsMatchf(t, [1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted") // assert.ElementsMatchf(t, [1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted")
func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) bool { func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return ElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...) return ElementsMatch(t, listA, listB, append([]interface{}{msg}, args...)...)
} }
@ -45,6 +57,9 @@ func ElementsMatchf(t TestingT, listA interface{}, listB interface{}, msg string
// //
// assert.Emptyf(t, obj, "error message %s", "formatted") // assert.Emptyf(t, obj, "error message %s", "formatted")
func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool { func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Empty(t, object, append([]interface{}{msg}, args...)...) return Empty(t, object, append([]interface{}{msg}, args...)...)
} }
@ -56,6 +71,9 @@ func Emptyf(t TestingT, object interface{}, msg string, args ...interface{}) boo
// referenced values (as opposed to the memory addresses). Function equality // referenced values (as opposed to the memory addresses). Function equality
// cannot be determined and will always fail. // cannot be determined and will always fail.
func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Equal(t, expected, actual, append([]interface{}{msg}, args...)...) return Equal(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
@ -65,6 +83,9 @@ func Equalf(t TestingT, expected interface{}, actual interface{}, msg string, ar
// actualObj, err := SomeFunction() // actualObj, err := SomeFunction()
// assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted") // assert.EqualErrorf(t, err, expectedErrorString, "error message %s", "formatted")
func EqualErrorf(t TestingT, theError error, errString string, msg string, args ...interface{}) bool { func EqualErrorf(t TestingT, theError error, errString string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return EqualError(t, theError, errString, append([]interface{}{msg}, args...)...) return EqualError(t, theError, errString, append([]interface{}{msg}, args...)...)
} }
@ -73,6 +94,9 @@ func EqualErrorf(t TestingT, theError error, errString string, msg string, args
// //
// assert.EqualValuesf(t, uint32(123, "error message %s", "formatted"), int32(123)) // assert.EqualValuesf(t, uint32(123, "error message %s", "formatted"), int32(123))
func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return EqualValues(t, expected, actual, append([]interface{}{msg}, args...)...) return EqualValues(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
@ -83,6 +107,9 @@ func EqualValuesf(t TestingT, expected interface{}, actual interface{}, msg stri
// assert.Equal(t, expectedErrorf, err) // assert.Equal(t, expectedErrorf, err)
// } // }
func Errorf(t TestingT, err error, msg string, args ...interface{}) bool { func Errorf(t TestingT, err error, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Error(t, err, append([]interface{}{msg}, args...)...) return Error(t, err, append([]interface{}{msg}, args...)...)
} }
@ -90,16 +117,25 @@ func Errorf(t TestingT, err error, msg string, args ...interface{}) bool {
// //
// assert.Exactlyf(t, int32(123, "error message %s", "formatted"), int64(123)) // assert.Exactlyf(t, int32(123, "error message %s", "formatted"), int64(123))
func Exactlyf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { func Exactlyf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Exactly(t, expected, actual, append([]interface{}{msg}, args...)...) return Exactly(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
// Failf reports a failure through // Failf reports a failure through
func Failf(t TestingT, failureMessage string, msg string, args ...interface{}) bool { func Failf(t TestingT, failureMessage string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Fail(t, failureMessage, append([]interface{}{msg}, args...)...) return Fail(t, failureMessage, append([]interface{}{msg}, args...)...)
} }
// FailNowf fails test // FailNowf fails test
func FailNowf(t TestingT, failureMessage string, msg string, args ...interface{}) bool { func FailNowf(t TestingT, failureMessage string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return FailNow(t, failureMessage, append([]interface{}{msg}, args...)...) return FailNow(t, failureMessage, append([]interface{}{msg}, args...)...)
} }
@ -107,31 +143,43 @@ func FailNowf(t TestingT, failureMessage string, msg string, args ...interface{}
// //
// assert.Falsef(t, myBool, "error message %s", "formatted") // assert.Falsef(t, myBool, "error message %s", "formatted")
func Falsef(t TestingT, value bool, msg string, args ...interface{}) bool { func Falsef(t TestingT, value bool, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return False(t, value, append([]interface{}{msg}, args...)...) return False(t, value, append([]interface{}{msg}, args...)...)
} }
// FileExistsf checks whether a file exists in the given path. It also fails if the path points to a directory or there is an error when trying to check the file. // FileExistsf checks whether a file exists in the given path. It also fails if the path points to a directory or there is an error when trying to check the file.
func FileExistsf(t TestingT, path string, msg string, args ...interface{}) bool { func FileExistsf(t TestingT, path string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return FileExists(t, path, append([]interface{}{msg}, args...)...) return FileExists(t, path, append([]interface{}{msg}, args...)...)
} }
// HTTPBodyContainsf asserts that a specified handler returns a // HTTPBodyContainsf asserts that a specified handler returns a
// body that contains a string. // body that contains a string.
// //
// assert.HTTPBodyContainsf(t, myHandler, "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // assert.HTTPBodyContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted")
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func HTTPBodyContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool { func HTTPBodyContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return HTTPBodyContains(t, handler, method, url, values, str, append([]interface{}{msg}, args...)...) return HTTPBodyContains(t, handler, method, url, values, str, append([]interface{}{msg}, args...)...)
} }
// HTTPBodyNotContainsf asserts that a specified handler returns a // HTTPBodyNotContainsf asserts that a specified handler returns a
// body that does not contain a string. // body that does not contain a string.
// //
// assert.HTTPBodyNotContainsf(t, myHandler, "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // assert.HTTPBodyNotContainsf(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted")
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool { func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return HTTPBodyNotContains(t, handler, method, url, values, str, append([]interface{}{msg}, args...)...) return HTTPBodyNotContains(t, handler, method, url, values, str, append([]interface{}{msg}, args...)...)
} }
@ -141,6 +189,9 @@ func HTTPBodyNotContainsf(t TestingT, handler http.HandlerFunc, method string, u
// //
// Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false). // Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false).
func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return HTTPError(t, handler, method, url, values, append([]interface{}{msg}, args...)...) return HTTPError(t, handler, method, url, values, append([]interface{}{msg}, args...)...)
} }
@ -150,6 +201,9 @@ func HTTPErrorf(t TestingT, handler http.HandlerFunc, method string, url string,
// //
// Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false). // Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false).
func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return HTTPRedirect(t, handler, method, url, values, append([]interface{}{msg}, args...)...) return HTTPRedirect(t, handler, method, url, values, append([]interface{}{msg}, args...)...)
} }
@ -159,6 +213,9 @@ func HTTPRedirectf(t TestingT, handler http.HandlerFunc, method string, url stri
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return HTTPSuccess(t, handler, method, url, values, append([]interface{}{msg}, args...)...) return HTTPSuccess(t, handler, method, url, values, append([]interface{}{msg}, args...)...)
} }
@ -166,6 +223,9 @@ func HTTPSuccessf(t TestingT, handler http.HandlerFunc, method string, url strin
// //
// assert.Implementsf(t, (*MyInterface, "error message %s", "formatted")(nil), new(MyObject)) // assert.Implementsf(t, (*MyInterface, "error message %s", "formatted")(nil), new(MyObject))
func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool { func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Implements(t, interfaceObject, object, append([]interface{}{msg}, args...)...) return Implements(t, interfaceObject, object, append([]interface{}{msg}, args...)...)
} }
@ -173,31 +233,49 @@ func Implementsf(t TestingT, interfaceObject interface{}, object interface{}, ms
// //
// assert.InDeltaf(t, math.Pi, (22 / 7.0, "error message %s", "formatted"), 0.01) // assert.InDeltaf(t, math.Pi, (22 / 7.0, "error message %s", "formatted"), 0.01)
func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { func InDeltaf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return InDelta(t, expected, actual, delta, append([]interface{}{msg}, args...)...) return InDelta(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
} }
// InDeltaMapValuesf is the same as InDelta, but it compares all values between two maps. Both maps must have exactly the same keys. // InDeltaMapValuesf is the same as InDelta, but it compares all values between two maps. Both maps must have exactly the same keys.
func InDeltaMapValuesf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { func InDeltaMapValuesf(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return InDeltaMapValues(t, expected, actual, delta, append([]interface{}{msg}, args...)...) return InDeltaMapValues(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
} }
// InDeltaSlicef is the same as InDelta, except it compares two slices. // InDeltaSlicef is the same as InDelta, except it compares two slices.
func InDeltaSlicef(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { func InDeltaSlicef(t TestingT, expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return InDeltaSlice(t, expected, actual, delta, append([]interface{}{msg}, args...)...) return InDeltaSlice(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
} }
// InEpsilonf asserts that expected and actual have a relative error less than epsilon // InEpsilonf asserts that expected and actual have a relative error less than epsilon
func InEpsilonf(t TestingT, expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool { func InEpsilonf(t TestingT, expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return InEpsilon(t, expected, actual, epsilon, append([]interface{}{msg}, args...)...) return InEpsilon(t, expected, actual, epsilon, append([]interface{}{msg}, args...)...)
} }
// InEpsilonSlicef is the same as InEpsilon, except it compares each value from two slices. // InEpsilonSlicef is the same as InEpsilon, except it compares each value from two slices.
func InEpsilonSlicef(t TestingT, expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool { func InEpsilonSlicef(t TestingT, expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return InEpsilonSlice(t, expected, actual, epsilon, append([]interface{}{msg}, args...)...) return InEpsilonSlice(t, expected, actual, epsilon, append([]interface{}{msg}, args...)...)
} }
// IsTypef asserts that the specified objects are of the same type. // IsTypef asserts that the specified objects are of the same type.
func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) bool { func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return IsType(t, expectedType, object, append([]interface{}{msg}, args...)...) return IsType(t, expectedType, object, append([]interface{}{msg}, args...)...)
} }
@ -205,6 +283,9 @@ func IsTypef(t TestingT, expectedType interface{}, object interface{}, msg strin
// //
// assert.JSONEqf(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted") // assert.JSONEqf(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted")
func JSONEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) bool { func JSONEqf(t TestingT, expected string, actual string, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return JSONEq(t, expected, actual, append([]interface{}{msg}, args...)...) return JSONEq(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
@ -213,6 +294,9 @@ func JSONEqf(t TestingT, expected string, actual string, msg string, args ...int
// //
// assert.Lenf(t, mySlice, 3, "error message %s", "formatted") // assert.Lenf(t, mySlice, 3, "error message %s", "formatted")
func Lenf(t TestingT, object interface{}, length int, msg string, args ...interface{}) bool { func Lenf(t TestingT, object interface{}, length int, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Len(t, object, length, append([]interface{}{msg}, args...)...) return Len(t, object, length, append([]interface{}{msg}, args...)...)
} }
@ -220,6 +304,9 @@ func Lenf(t TestingT, object interface{}, length int, msg string, args ...interf
// //
// assert.Nilf(t, err, "error message %s", "formatted") // assert.Nilf(t, err, "error message %s", "formatted")
func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) bool { func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Nil(t, object, append([]interface{}{msg}, args...)...) return Nil(t, object, append([]interface{}{msg}, args...)...)
} }
@ -230,6 +317,9 @@ func Nilf(t TestingT, object interface{}, msg string, args ...interface{}) bool
// assert.Equal(t, expectedObj, actualObj) // assert.Equal(t, expectedObj, actualObj)
// } // }
func NoErrorf(t TestingT, err error, msg string, args ...interface{}) bool { func NoErrorf(t TestingT, err error, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NoError(t, err, append([]interface{}{msg}, args...)...) return NoError(t, err, append([]interface{}{msg}, args...)...)
} }
@ -240,6 +330,9 @@ func NoErrorf(t TestingT, err error, msg string, args ...interface{}) bool {
// assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted") // assert.NotContainsf(t, ["Hello", "World"], "Earth", "error message %s", "formatted")
// assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted") // assert.NotContainsf(t, {"Hello": "World"}, "Earth", "error message %s", "formatted")
func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool { func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotContains(t, s, contains, append([]interface{}{msg}, args...)...) return NotContains(t, s, contains, append([]interface{}{msg}, args...)...)
} }
@ -250,6 +343,9 @@ func NotContainsf(t TestingT, s interface{}, contains interface{}, msg string, a
// assert.Equal(t, "two", obj[1]) // assert.Equal(t, "two", obj[1])
// } // }
func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool { func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotEmpty(t, object, append([]interface{}{msg}, args...)...) return NotEmpty(t, object, append([]interface{}{msg}, args...)...)
} }
@ -260,6 +356,9 @@ func NotEmptyf(t TestingT, object interface{}, msg string, args ...interface{})
// Pointer variable equality is determined based on the equality of the // Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses). // referenced values (as opposed to the memory addresses).
func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool { func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotEqual(t, expected, actual, append([]interface{}{msg}, args...)...) return NotEqual(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
@ -267,6 +366,9 @@ func NotEqualf(t TestingT, expected interface{}, actual interface{}, msg string,
// //
// assert.NotNilf(t, err, "error message %s", "formatted") // assert.NotNilf(t, err, "error message %s", "formatted")
func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) bool { func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotNil(t, object, append([]interface{}{msg}, args...)...) return NotNil(t, object, append([]interface{}{msg}, args...)...)
} }
@ -274,6 +376,9 @@ func NotNilf(t TestingT, object interface{}, msg string, args ...interface{}) bo
// //
// assert.NotPanicsf(t, func(){ RemainCalm() }, "error message %s", "formatted") // assert.NotPanicsf(t, func(){ RemainCalm() }, "error message %s", "formatted")
func NotPanicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool { func NotPanicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotPanics(t, f, append([]interface{}{msg}, args...)...) return NotPanics(t, f, append([]interface{}{msg}, args...)...)
} }
@ -282,6 +387,9 @@ func NotPanicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bo
// assert.NotRegexpf(t, regexp.MustCompile("starts", "error message %s", "formatted"), "it's starting") // assert.NotRegexpf(t, regexp.MustCompile("starts", "error message %s", "formatted"), "it's starting")
// assert.NotRegexpf(t, "^start", "it's not starting", "error message %s", "formatted") // assert.NotRegexpf(t, "^start", "it's not starting", "error message %s", "formatted")
func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool { func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotRegexp(t, rx, str, append([]interface{}{msg}, args...)...) return NotRegexp(t, rx, str, append([]interface{}{msg}, args...)...)
} }
@ -290,11 +398,17 @@ func NotRegexpf(t TestingT, rx interface{}, str interface{}, msg string, args ..
// //
// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") // assert.NotSubsetf(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted")
func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotSubset(t, list, subset, append([]interface{}{msg}, args...)...) return NotSubset(t, list, subset, append([]interface{}{msg}, args...)...)
} }
// NotZerof asserts that i is not the zero value for its type. // NotZerof asserts that i is not the zero value for its type.
func NotZerof(t TestingT, i interface{}, msg string, args ...interface{}) bool { func NotZerof(t TestingT, i interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotZero(t, i, append([]interface{}{msg}, args...)...) return NotZero(t, i, append([]interface{}{msg}, args...)...)
} }
@ -302,6 +416,9 @@ func NotZerof(t TestingT, i interface{}, msg string, args ...interface{}) bool {
// //
// assert.Panicsf(t, func(){ GoCrazy() }, "error message %s", "formatted") // assert.Panicsf(t, func(){ GoCrazy() }, "error message %s", "formatted")
func Panicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool { func Panicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Panics(t, f, append([]interface{}{msg}, args...)...) return Panics(t, f, append([]interface{}{msg}, args...)...)
} }
@ -310,6 +427,9 @@ func Panicsf(t TestingT, f PanicTestFunc, msg string, args ...interface{}) bool
// //
// assert.PanicsWithValuef(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted") // assert.PanicsWithValuef(t, "crazy error", func(){ GoCrazy() }, "error message %s", "formatted")
func PanicsWithValuef(t TestingT, expected interface{}, f PanicTestFunc, msg string, args ...interface{}) bool { func PanicsWithValuef(t TestingT, expected interface{}, f PanicTestFunc, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return PanicsWithValue(t, expected, f, append([]interface{}{msg}, args...)...) return PanicsWithValue(t, expected, f, append([]interface{}{msg}, args...)...)
} }
@ -318,6 +438,9 @@ func PanicsWithValuef(t TestingT, expected interface{}, f PanicTestFunc, msg str
// assert.Regexpf(t, regexp.MustCompile("start", "error message %s", "formatted"), "it's starting") // assert.Regexpf(t, regexp.MustCompile("start", "error message %s", "formatted"), "it's starting")
// assert.Regexpf(t, "start...$", "it's not starting", "error message %s", "formatted") // assert.Regexpf(t, "start...$", "it's not starting", "error message %s", "formatted")
func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool { func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Regexp(t, rx, str, append([]interface{}{msg}, args...)...) return Regexp(t, rx, str, append([]interface{}{msg}, args...)...)
} }
@ -326,6 +449,9 @@ func Regexpf(t TestingT, rx interface{}, str interface{}, msg string, args ...in
// //
// assert.Subsetf(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") // assert.Subsetf(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted")
func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Subset(t, list, subset, append([]interface{}{msg}, args...)...) return Subset(t, list, subset, append([]interface{}{msg}, args...)...)
} }
@ -333,6 +459,9 @@ func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args
// //
// assert.Truef(t, myBool, "error message %s", "formatted") // assert.Truef(t, myBool, "error message %s", "formatted")
func Truef(t TestingT, value bool, msg string, args ...interface{}) bool { func Truef(t TestingT, value bool, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return True(t, value, append([]interface{}{msg}, args...)...) return True(t, value, append([]interface{}{msg}, args...)...)
} }
@ -340,10 +469,16 @@ func Truef(t TestingT, value bool, msg string, args ...interface{}) bool {
// //
// assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") // assert.WithinDurationf(t, time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted")
func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool { func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return WithinDuration(t, expected, actual, delta, append([]interface{}{msg}, args...)...) return WithinDuration(t, expected, actual, delta, append([]interface{}{msg}, args...)...)
} }
// Zerof asserts that i is the zero value for its type. // Zerof asserts that i is the zero value for its type.
func Zerof(t TestingT, i interface{}, msg string, args ...interface{}) bool { func Zerof(t TestingT, i interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return Zero(t, i, append([]interface{}{msg}, args...)...) return Zero(t, i, append([]interface{}{msg}, args...)...)
} }

View File

@ -1,4 +1,5 @@
{{.CommentFormat}} {{.CommentFormat}}
func {{.DocInfo.Name}}f(t TestingT, {{.ParamsFormat}}) bool { func {{.DocInfo.Name}}f(t TestingT, {{.ParamsFormat}}) bool {
if h, ok := t.(tHelper); ok { h.Helper() }
return {{.DocInfo.Name}}(t, {{.ForwardedParamsFormat}}) return {{.DocInfo.Name}}(t, {{.ForwardedParamsFormat}})
} }

View File

@ -13,11 +13,17 @@ import (
// Condition uses a Comparison to assert a complex condition. // Condition uses a Comparison to assert a complex condition.
func (a *Assertions) Condition(comp Comparison, msgAndArgs ...interface{}) bool { func (a *Assertions) Condition(comp Comparison, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Condition(a.t, comp, msgAndArgs...) return Condition(a.t, comp, msgAndArgs...)
} }
// Conditionf uses a Comparison to assert a complex condition. // Conditionf uses a Comparison to assert a complex condition.
func (a *Assertions) Conditionf(comp Comparison, msg string, args ...interface{}) bool { func (a *Assertions) Conditionf(comp Comparison, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Conditionf(a.t, comp, msg, args...) return Conditionf(a.t, comp, msg, args...)
} }
@ -28,6 +34,9 @@ func (a *Assertions) Conditionf(comp Comparison, msg string, args ...interface{}
// a.Contains(["Hello", "World"], "World") // a.Contains(["Hello", "World"], "World")
// a.Contains({"Hello": "World"}, "Hello") // a.Contains({"Hello": "World"}, "Hello")
func (a *Assertions) Contains(s interface{}, contains interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) Contains(s interface{}, contains interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Contains(a.t, s, contains, msgAndArgs...) return Contains(a.t, s, contains, msgAndArgs...)
} }
@ -38,16 +47,25 @@ func (a *Assertions) Contains(s interface{}, contains interface{}, msgAndArgs ..
// a.Containsf(["Hello", "World"], "World", "error message %s", "formatted") // a.Containsf(["Hello", "World"], "World", "error message %s", "formatted")
// a.Containsf({"Hello": "World"}, "Hello", "error message %s", "formatted") // a.Containsf({"Hello": "World"}, "Hello", "error message %s", "formatted")
func (a *Assertions) Containsf(s interface{}, contains interface{}, msg string, args ...interface{}) bool { func (a *Assertions) Containsf(s interface{}, contains interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Containsf(a.t, s, contains, msg, args...) return Containsf(a.t, s, contains, msg, args...)
} }
// DirExists checks whether a directory exists in the given path. It also fails if the path is a file rather a directory or there is an error checking whether it exists. // DirExists checks whether a directory exists in the given path. It also fails if the path is a file rather a directory or there is an error checking whether it exists.
func (a *Assertions) DirExists(path string, msgAndArgs ...interface{}) bool { func (a *Assertions) DirExists(path string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return DirExists(a.t, path, msgAndArgs...) return DirExists(a.t, path, msgAndArgs...)
} }
// DirExistsf checks whether a directory exists in the given path. It also fails if the path is a file rather a directory or there is an error checking whether it exists. // DirExistsf checks whether a directory exists in the given path. It also fails if the path is a file rather a directory or there is an error checking whether it exists.
func (a *Assertions) DirExistsf(path string, msg string, args ...interface{}) bool { func (a *Assertions) DirExistsf(path string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return DirExistsf(a.t, path, msg, args...) return DirExistsf(a.t, path, msg, args...)
} }
@ -57,6 +75,9 @@ func (a *Assertions) DirExistsf(path string, msg string, args ...interface{}) bo
// //
// a.ElementsMatch([1, 3, 2, 3], [1, 3, 3, 2]) // a.ElementsMatch([1, 3, 2, 3], [1, 3, 3, 2])
func (a *Assertions) ElementsMatch(listA interface{}, listB interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) ElementsMatch(listA interface{}, listB interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ElementsMatch(a.t, listA, listB, msgAndArgs...) return ElementsMatch(a.t, listA, listB, msgAndArgs...)
} }
@ -66,6 +87,9 @@ func (a *Assertions) ElementsMatch(listA interface{}, listB interface{}, msgAndA
// //
// a.ElementsMatchf([1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted") // a.ElementsMatchf([1, 3, 2, 3], [1, 3, 3, 2], "error message %s", "formatted")
func (a *Assertions) ElementsMatchf(listA interface{}, listB interface{}, msg string, args ...interface{}) bool { func (a *Assertions) ElementsMatchf(listA interface{}, listB interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return ElementsMatchf(a.t, listA, listB, msg, args...) return ElementsMatchf(a.t, listA, listB, msg, args...)
} }
@ -74,6 +98,9 @@ func (a *Assertions) ElementsMatchf(listA interface{}, listB interface{}, msg st
// //
// a.Empty(obj) // a.Empty(obj)
func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Empty(a.t, object, msgAndArgs...) return Empty(a.t, object, msgAndArgs...)
} }
@ -82,6 +109,9 @@ func (a *Assertions) Empty(object interface{}, msgAndArgs ...interface{}) bool {
// //
// a.Emptyf(obj, "error message %s", "formatted") // a.Emptyf(obj, "error message %s", "formatted")
func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) bool { func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Emptyf(a.t, object, msg, args...) return Emptyf(a.t, object, msg, args...)
} }
@ -93,6 +123,9 @@ func (a *Assertions) Emptyf(object interface{}, msg string, args ...interface{})
// referenced values (as opposed to the memory addresses). Function equality // referenced values (as opposed to the memory addresses). Function equality
// cannot be determined and will always fail. // cannot be determined and will always fail.
func (a *Assertions) Equal(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) Equal(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Equal(a.t, expected, actual, msgAndArgs...) return Equal(a.t, expected, actual, msgAndArgs...)
} }
@ -102,6 +135,9 @@ func (a *Assertions) Equal(expected interface{}, actual interface{}, msgAndArgs
// actualObj, err := SomeFunction() // actualObj, err := SomeFunction()
// a.EqualError(err, expectedErrorString) // a.EqualError(err, expectedErrorString)
func (a *Assertions) EqualError(theError error, errString string, msgAndArgs ...interface{}) bool { func (a *Assertions) EqualError(theError error, errString string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return EqualError(a.t, theError, errString, msgAndArgs...) return EqualError(a.t, theError, errString, msgAndArgs...)
} }
@ -111,6 +147,9 @@ func (a *Assertions) EqualError(theError error, errString string, msgAndArgs ...
// actualObj, err := SomeFunction() // actualObj, err := SomeFunction()
// a.EqualErrorf(err, expectedErrorString, "error message %s", "formatted") // a.EqualErrorf(err, expectedErrorString, "error message %s", "formatted")
func (a *Assertions) EqualErrorf(theError error, errString string, msg string, args ...interface{}) bool { func (a *Assertions) EqualErrorf(theError error, errString string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return EqualErrorf(a.t, theError, errString, msg, args...) return EqualErrorf(a.t, theError, errString, msg, args...)
} }
@ -119,6 +158,9 @@ func (a *Assertions) EqualErrorf(theError error, errString string, msg string, a
// //
// a.EqualValues(uint32(123), int32(123)) // a.EqualValues(uint32(123), int32(123))
func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return EqualValues(a.t, expected, actual, msgAndArgs...) return EqualValues(a.t, expected, actual, msgAndArgs...)
} }
@ -127,6 +169,9 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn
// //
// a.EqualValuesf(uint32(123, "error message %s", "formatted"), int32(123)) // a.EqualValuesf(uint32(123, "error message %s", "formatted"), int32(123))
func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return EqualValuesf(a.t, expected, actual, msg, args...) return EqualValuesf(a.t, expected, actual, msg, args...)
} }
@ -138,6 +183,9 @@ func (a *Assertions) EqualValuesf(expected interface{}, actual interface{}, msg
// referenced values (as opposed to the memory addresses). Function equality // referenced values (as opposed to the memory addresses). Function equality
// cannot be determined and will always fail. // cannot be determined and will always fail.
func (a *Assertions) Equalf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { func (a *Assertions) Equalf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Equalf(a.t, expected, actual, msg, args...) return Equalf(a.t, expected, actual, msg, args...)
} }
@ -148,6 +196,9 @@ func (a *Assertions) Equalf(expected interface{}, actual interface{}, msg string
// assert.Equal(t, expectedError, err) // assert.Equal(t, expectedError, err)
// } // }
func (a *Assertions) Error(err error, msgAndArgs ...interface{}) bool { func (a *Assertions) Error(err error, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Error(a.t, err, msgAndArgs...) return Error(a.t, err, msgAndArgs...)
} }
@ -158,6 +209,9 @@ func (a *Assertions) Error(err error, msgAndArgs ...interface{}) bool {
// assert.Equal(t, expectedErrorf, err) // assert.Equal(t, expectedErrorf, err)
// } // }
func (a *Assertions) Errorf(err error, msg string, args ...interface{}) bool { func (a *Assertions) Errorf(err error, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Errorf(a.t, err, msg, args...) return Errorf(a.t, err, msg, args...)
} }
@ -165,6 +219,9 @@ func (a *Assertions) Errorf(err error, msg string, args ...interface{}) bool {
// //
// a.Exactly(int32(123), int64(123)) // a.Exactly(int32(123), int64(123))
func (a *Assertions) Exactly(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) Exactly(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Exactly(a.t, expected, actual, msgAndArgs...) return Exactly(a.t, expected, actual, msgAndArgs...)
} }
@ -172,26 +229,41 @@ func (a *Assertions) Exactly(expected interface{}, actual interface{}, msgAndArg
// //
// a.Exactlyf(int32(123, "error message %s", "formatted"), int64(123)) // a.Exactlyf(int32(123, "error message %s", "formatted"), int64(123))
func (a *Assertions) Exactlyf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { func (a *Assertions) Exactlyf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Exactlyf(a.t, expected, actual, msg, args...) return Exactlyf(a.t, expected, actual, msg, args...)
} }
// Fail reports a failure through // Fail reports a failure through
func (a *Assertions) Fail(failureMessage string, msgAndArgs ...interface{}) bool { func (a *Assertions) Fail(failureMessage string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Fail(a.t, failureMessage, msgAndArgs...) return Fail(a.t, failureMessage, msgAndArgs...)
} }
// FailNow fails test // FailNow fails test
func (a *Assertions) FailNow(failureMessage string, msgAndArgs ...interface{}) bool { func (a *Assertions) FailNow(failureMessage string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return FailNow(a.t, failureMessage, msgAndArgs...) return FailNow(a.t, failureMessage, msgAndArgs...)
} }
// FailNowf fails test // FailNowf fails test
func (a *Assertions) FailNowf(failureMessage string, msg string, args ...interface{}) bool { func (a *Assertions) FailNowf(failureMessage string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return FailNowf(a.t, failureMessage, msg, args...) return FailNowf(a.t, failureMessage, msg, args...)
} }
// Failf reports a failure through // Failf reports a failure through
func (a *Assertions) Failf(failureMessage string, msg string, args ...interface{}) bool { func (a *Assertions) Failf(failureMessage string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Failf(a.t, failureMessage, msg, args...) return Failf(a.t, failureMessage, msg, args...)
} }
@ -199,6 +271,9 @@ func (a *Assertions) Failf(failureMessage string, msg string, args ...interface{
// //
// a.False(myBool) // a.False(myBool)
func (a *Assertions) False(value bool, msgAndArgs ...interface{}) bool { func (a *Assertions) False(value bool, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return False(a.t, value, msgAndArgs...) return False(a.t, value, msgAndArgs...)
} }
@ -206,56 +281,77 @@ func (a *Assertions) False(value bool, msgAndArgs ...interface{}) bool {
// //
// a.Falsef(myBool, "error message %s", "formatted") // a.Falsef(myBool, "error message %s", "formatted")
func (a *Assertions) Falsef(value bool, msg string, args ...interface{}) bool { func (a *Assertions) Falsef(value bool, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Falsef(a.t, value, msg, args...) return Falsef(a.t, value, msg, args...)
} }
// FileExists checks whether a file exists in the given path. It also fails if the path points to a directory or there is an error when trying to check the file. // FileExists checks whether a file exists in the given path. It also fails if the path points to a directory or there is an error when trying to check the file.
func (a *Assertions) FileExists(path string, msgAndArgs ...interface{}) bool { func (a *Assertions) FileExists(path string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return FileExists(a.t, path, msgAndArgs...) return FileExists(a.t, path, msgAndArgs...)
} }
// FileExistsf checks whether a file exists in the given path. It also fails if the path points to a directory or there is an error when trying to check the file. // FileExistsf checks whether a file exists in the given path. It also fails if the path points to a directory or there is an error when trying to check the file.
func (a *Assertions) FileExistsf(path string, msg string, args ...interface{}) bool { func (a *Assertions) FileExistsf(path string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return FileExistsf(a.t, path, msg, args...) return FileExistsf(a.t, path, msg, args...)
} }
// HTTPBodyContains asserts that a specified handler returns a // HTTPBodyContains asserts that a specified handler returns a
// body that contains a string. // body that contains a string.
// //
// a.HTTPBodyContains(myHandler, "www.google.com", nil, "I'm Feeling Lucky") // a.HTTPBodyContains(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPBodyContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) HTTPBodyContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return HTTPBodyContains(a.t, handler, method, url, values, str, msgAndArgs...) return HTTPBodyContains(a.t, handler, method, url, values, str, msgAndArgs...)
} }
// HTTPBodyContainsf asserts that a specified handler returns a // HTTPBodyContainsf asserts that a specified handler returns a
// body that contains a string. // body that contains a string.
// //
// a.HTTPBodyContainsf(myHandler, "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // a.HTTPBodyContainsf(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted")
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPBodyContainsf(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool { func (a *Assertions) HTTPBodyContainsf(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return HTTPBodyContainsf(a.t, handler, method, url, values, str, msg, args...) return HTTPBodyContainsf(a.t, handler, method, url, values, str, msg, args...)
} }
// HTTPBodyNotContains asserts that a specified handler returns a // HTTPBodyNotContains asserts that a specified handler returns a
// body that does not contain a string. // body that does not contain a string.
// //
// a.HTTPBodyNotContains(myHandler, "www.google.com", nil, "I'm Feeling Lucky") // a.HTTPBodyNotContains(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPBodyNotContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) HTTPBodyNotContains(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return HTTPBodyNotContains(a.t, handler, method, url, values, str, msgAndArgs...) return HTTPBodyNotContains(a.t, handler, method, url, values, str, msgAndArgs...)
} }
// HTTPBodyNotContainsf asserts that a specified handler returns a // HTTPBodyNotContainsf asserts that a specified handler returns a
// body that does not contain a string. // body that does not contain a string.
// //
// a.HTTPBodyNotContainsf(myHandler, "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted") // a.HTTPBodyNotContainsf(myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky", "error message %s", "formatted")
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPBodyNotContainsf(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool { func (a *Assertions) HTTPBodyNotContainsf(handler http.HandlerFunc, method string, url string, values url.Values, str interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return HTTPBodyNotContainsf(a.t, handler, method, url, values, str, msg, args...) return HTTPBodyNotContainsf(a.t, handler, method, url, values, str, msg, args...)
} }
@ -265,6 +361,9 @@ func (a *Assertions) HTTPBodyNotContainsf(handler http.HandlerFunc, method strin
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPError(handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) bool { func (a *Assertions) HTTPError(handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return HTTPError(a.t, handler, method, url, values, msgAndArgs...) return HTTPError(a.t, handler, method, url, values, msgAndArgs...)
} }
@ -274,6 +373,9 @@ func (a *Assertions) HTTPError(handler http.HandlerFunc, method string, url stri
// //
// Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false). // Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false).
func (a *Assertions) HTTPErrorf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { func (a *Assertions) HTTPErrorf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return HTTPErrorf(a.t, handler, method, url, values, msg, args...) return HTTPErrorf(a.t, handler, method, url, values, msg, args...)
} }
@ -283,6 +385,9 @@ func (a *Assertions) HTTPErrorf(handler http.HandlerFunc, method string, url str
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPRedirect(handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) bool { func (a *Assertions) HTTPRedirect(handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return HTTPRedirect(a.t, handler, method, url, values, msgAndArgs...) return HTTPRedirect(a.t, handler, method, url, values, msgAndArgs...)
} }
@ -292,6 +397,9 @@ func (a *Assertions) HTTPRedirect(handler http.HandlerFunc, method string, url s
// //
// Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false). // Returns whether the assertion was successful (true, "error message %s", "formatted") or not (false).
func (a *Assertions) HTTPRedirectf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { func (a *Assertions) HTTPRedirectf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return HTTPRedirectf(a.t, handler, method, url, values, msg, args...) return HTTPRedirectf(a.t, handler, method, url, values, msg, args...)
} }
@ -301,6 +409,9 @@ func (a *Assertions) HTTPRedirectf(handler http.HandlerFunc, method string, url
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPSuccess(handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) bool { func (a *Assertions) HTTPSuccess(handler http.HandlerFunc, method string, url string, values url.Values, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return HTTPSuccess(a.t, handler, method, url, values, msgAndArgs...) return HTTPSuccess(a.t, handler, method, url, values, msgAndArgs...)
} }
@ -310,6 +421,9 @@ func (a *Assertions) HTTPSuccess(handler http.HandlerFunc, method string, url st
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func (a *Assertions) HTTPSuccessf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool { func (a *Assertions) HTTPSuccessf(handler http.HandlerFunc, method string, url string, values url.Values, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return HTTPSuccessf(a.t, handler, method, url, values, msg, args...) return HTTPSuccessf(a.t, handler, method, url, values, msg, args...)
} }
@ -317,6 +431,9 @@ func (a *Assertions) HTTPSuccessf(handler http.HandlerFunc, method string, url s
// //
// a.Implements((*MyInterface)(nil), new(MyObject)) // a.Implements((*MyInterface)(nil), new(MyObject))
func (a *Assertions) Implements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) Implements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Implements(a.t, interfaceObject, object, msgAndArgs...) return Implements(a.t, interfaceObject, object, msgAndArgs...)
} }
@ -324,6 +441,9 @@ func (a *Assertions) Implements(interfaceObject interface{}, object interface{},
// //
// a.Implementsf((*MyInterface, "error message %s", "formatted")(nil), new(MyObject)) // a.Implementsf((*MyInterface, "error message %s", "formatted")(nil), new(MyObject))
func (a *Assertions) Implementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool { func (a *Assertions) Implementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Implementsf(a.t, interfaceObject, object, msg, args...) return Implementsf(a.t, interfaceObject, object, msg, args...)
} }
@ -331,26 +451,41 @@ func (a *Assertions) Implementsf(interfaceObject interface{}, object interface{}
// //
// a.InDelta(math.Pi, (22 / 7.0), 0.01) // a.InDelta(math.Pi, (22 / 7.0), 0.01)
func (a *Assertions) InDelta(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { func (a *Assertions) InDelta(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return InDelta(a.t, expected, actual, delta, msgAndArgs...) return InDelta(a.t, expected, actual, delta, msgAndArgs...)
} }
// InDeltaMapValues is the same as InDelta, but it compares all values between two maps. Both maps must have exactly the same keys. // InDeltaMapValues is the same as InDelta, but it compares all values between two maps. Both maps must have exactly the same keys.
func (a *Assertions) InDeltaMapValues(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { func (a *Assertions) InDeltaMapValues(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return InDeltaMapValues(a.t, expected, actual, delta, msgAndArgs...) return InDeltaMapValues(a.t, expected, actual, delta, msgAndArgs...)
} }
// InDeltaMapValuesf is the same as InDelta, but it compares all values between two maps. Both maps must have exactly the same keys. // InDeltaMapValuesf is the same as InDelta, but it compares all values between two maps. Both maps must have exactly the same keys.
func (a *Assertions) InDeltaMapValuesf(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { func (a *Assertions) InDeltaMapValuesf(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return InDeltaMapValuesf(a.t, expected, actual, delta, msg, args...) return InDeltaMapValuesf(a.t, expected, actual, delta, msg, args...)
} }
// InDeltaSlice is the same as InDelta, except it compares two slices. // InDeltaSlice is the same as InDelta, except it compares two slices.
func (a *Assertions) InDeltaSlice(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { func (a *Assertions) InDeltaSlice(expected interface{}, actual interface{}, delta float64, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return InDeltaSlice(a.t, expected, actual, delta, msgAndArgs...) return InDeltaSlice(a.t, expected, actual, delta, msgAndArgs...)
} }
// InDeltaSlicef is the same as InDelta, except it compares two slices. // InDeltaSlicef is the same as InDelta, except it compares two slices.
func (a *Assertions) InDeltaSlicef(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { func (a *Assertions) InDeltaSlicef(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return InDeltaSlicef(a.t, expected, actual, delta, msg, args...) return InDeltaSlicef(a.t, expected, actual, delta, msg, args...)
} }
@ -358,36 +493,57 @@ func (a *Assertions) InDeltaSlicef(expected interface{}, actual interface{}, del
// //
// a.InDeltaf(math.Pi, (22 / 7.0, "error message %s", "formatted"), 0.01) // a.InDeltaf(math.Pi, (22 / 7.0, "error message %s", "formatted"), 0.01)
func (a *Assertions) InDeltaf(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool { func (a *Assertions) InDeltaf(expected interface{}, actual interface{}, delta float64, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return InDeltaf(a.t, expected, actual, delta, msg, args...) return InDeltaf(a.t, expected, actual, delta, msg, args...)
} }
// InEpsilon asserts that expected and actual have a relative error less than epsilon // InEpsilon asserts that expected and actual have a relative error less than epsilon
func (a *Assertions) InEpsilon(expected interface{}, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool { func (a *Assertions) InEpsilon(expected interface{}, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return InEpsilon(a.t, expected, actual, epsilon, msgAndArgs...) return InEpsilon(a.t, expected, actual, epsilon, msgAndArgs...)
} }
// InEpsilonSlice is the same as InEpsilon, except it compares each value from two slices. // InEpsilonSlice is the same as InEpsilon, except it compares each value from two slices.
func (a *Assertions) InEpsilonSlice(expected interface{}, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool { func (a *Assertions) InEpsilonSlice(expected interface{}, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return InEpsilonSlice(a.t, expected, actual, epsilon, msgAndArgs...) return InEpsilonSlice(a.t, expected, actual, epsilon, msgAndArgs...)
} }
// InEpsilonSlicef is the same as InEpsilon, except it compares each value from two slices. // InEpsilonSlicef is the same as InEpsilon, except it compares each value from two slices.
func (a *Assertions) InEpsilonSlicef(expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool { func (a *Assertions) InEpsilonSlicef(expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return InEpsilonSlicef(a.t, expected, actual, epsilon, msg, args...) return InEpsilonSlicef(a.t, expected, actual, epsilon, msg, args...)
} }
// InEpsilonf asserts that expected and actual have a relative error less than epsilon // InEpsilonf asserts that expected and actual have a relative error less than epsilon
func (a *Assertions) InEpsilonf(expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool { func (a *Assertions) InEpsilonf(expected interface{}, actual interface{}, epsilon float64, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return InEpsilonf(a.t, expected, actual, epsilon, msg, args...) return InEpsilonf(a.t, expected, actual, epsilon, msg, args...)
} }
// IsType asserts that the specified objects are of the same type. // IsType asserts that the specified objects are of the same type.
func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) IsType(expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return IsType(a.t, expectedType, object, msgAndArgs...) return IsType(a.t, expectedType, object, msgAndArgs...)
} }
// IsTypef asserts that the specified objects are of the same type. // IsTypef asserts that the specified objects are of the same type.
func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg string, args ...interface{}) bool { func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return IsTypef(a.t, expectedType, object, msg, args...) return IsTypef(a.t, expectedType, object, msg, args...)
} }
@ -395,6 +551,9 @@ func (a *Assertions) IsTypef(expectedType interface{}, object interface{}, msg s
// //
// a.JSONEq(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) // a.JSONEq(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`)
func (a *Assertions) JSONEq(expected string, actual string, msgAndArgs ...interface{}) bool { func (a *Assertions) JSONEq(expected string, actual string, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return JSONEq(a.t, expected, actual, msgAndArgs...) return JSONEq(a.t, expected, actual, msgAndArgs...)
} }
@ -402,6 +561,9 @@ func (a *Assertions) JSONEq(expected string, actual string, msgAndArgs ...interf
// //
// a.JSONEqf(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted") // a.JSONEqf(`{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`, "error message %s", "formatted")
func (a *Assertions) JSONEqf(expected string, actual string, msg string, args ...interface{}) bool { func (a *Assertions) JSONEqf(expected string, actual string, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return JSONEqf(a.t, expected, actual, msg, args...) return JSONEqf(a.t, expected, actual, msg, args...)
} }
@ -410,6 +572,9 @@ func (a *Assertions) JSONEqf(expected string, actual string, msg string, args ..
// //
// a.Len(mySlice, 3) // a.Len(mySlice, 3)
func (a *Assertions) Len(object interface{}, length int, msgAndArgs ...interface{}) bool { func (a *Assertions) Len(object interface{}, length int, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Len(a.t, object, length, msgAndArgs...) return Len(a.t, object, length, msgAndArgs...)
} }
@ -418,6 +583,9 @@ func (a *Assertions) Len(object interface{}, length int, msgAndArgs ...interface
// //
// a.Lenf(mySlice, 3, "error message %s", "formatted") // a.Lenf(mySlice, 3, "error message %s", "formatted")
func (a *Assertions) Lenf(object interface{}, length int, msg string, args ...interface{}) bool { func (a *Assertions) Lenf(object interface{}, length int, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Lenf(a.t, object, length, msg, args...) return Lenf(a.t, object, length, msg, args...)
} }
@ -425,6 +593,9 @@ func (a *Assertions) Lenf(object interface{}, length int, msg string, args ...in
// //
// a.Nil(err) // a.Nil(err)
func (a *Assertions) Nil(object interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) Nil(object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Nil(a.t, object, msgAndArgs...) return Nil(a.t, object, msgAndArgs...)
} }
@ -432,6 +603,9 @@ func (a *Assertions) Nil(object interface{}, msgAndArgs ...interface{}) bool {
// //
// a.Nilf(err, "error message %s", "formatted") // a.Nilf(err, "error message %s", "formatted")
func (a *Assertions) Nilf(object interface{}, msg string, args ...interface{}) bool { func (a *Assertions) Nilf(object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Nilf(a.t, object, msg, args...) return Nilf(a.t, object, msg, args...)
} }
@ -442,6 +616,9 @@ func (a *Assertions) Nilf(object interface{}, msg string, args ...interface{}) b
// assert.Equal(t, expectedObj, actualObj) // assert.Equal(t, expectedObj, actualObj)
// } // }
func (a *Assertions) NoError(err error, msgAndArgs ...interface{}) bool { func (a *Assertions) NoError(err error, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NoError(a.t, err, msgAndArgs...) return NoError(a.t, err, msgAndArgs...)
} }
@ -452,6 +629,9 @@ func (a *Assertions) NoError(err error, msgAndArgs ...interface{}) bool {
// assert.Equal(t, expectedObj, actualObj) // assert.Equal(t, expectedObj, actualObj)
// } // }
func (a *Assertions) NoErrorf(err error, msg string, args ...interface{}) bool { func (a *Assertions) NoErrorf(err error, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NoErrorf(a.t, err, msg, args...) return NoErrorf(a.t, err, msg, args...)
} }
@ -462,6 +642,9 @@ func (a *Assertions) NoErrorf(err error, msg string, args ...interface{}) bool {
// a.NotContains(["Hello", "World"], "Earth") // a.NotContains(["Hello", "World"], "Earth")
// a.NotContains({"Hello": "World"}, "Earth") // a.NotContains({"Hello": "World"}, "Earth")
func (a *Assertions) NotContains(s interface{}, contains interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) NotContains(s interface{}, contains interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotContains(a.t, s, contains, msgAndArgs...) return NotContains(a.t, s, contains, msgAndArgs...)
} }
@ -472,6 +655,9 @@ func (a *Assertions) NotContains(s interface{}, contains interface{}, msgAndArgs
// a.NotContainsf(["Hello", "World"], "Earth", "error message %s", "formatted") // a.NotContainsf(["Hello", "World"], "Earth", "error message %s", "formatted")
// a.NotContainsf({"Hello": "World"}, "Earth", "error message %s", "formatted") // a.NotContainsf({"Hello": "World"}, "Earth", "error message %s", "formatted")
func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg string, args ...interface{}) bool { func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotContainsf(a.t, s, contains, msg, args...) return NotContainsf(a.t, s, contains, msg, args...)
} }
@ -482,6 +668,9 @@ func (a *Assertions) NotContainsf(s interface{}, contains interface{}, msg strin
// assert.Equal(t, "two", obj[1]) // assert.Equal(t, "two", obj[1])
// } // }
func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotEmpty(a.t, object, msgAndArgs...) return NotEmpty(a.t, object, msgAndArgs...)
} }
@ -492,6 +681,9 @@ func (a *Assertions) NotEmpty(object interface{}, msgAndArgs ...interface{}) boo
// assert.Equal(t, "two", obj[1]) // assert.Equal(t, "two", obj[1])
// } // }
func (a *Assertions) NotEmptyf(object interface{}, msg string, args ...interface{}) bool { func (a *Assertions) NotEmptyf(object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotEmptyf(a.t, object, msg, args...) return NotEmptyf(a.t, object, msg, args...)
} }
@ -502,6 +694,9 @@ func (a *Assertions) NotEmptyf(object interface{}, msg string, args ...interface
// Pointer variable equality is determined based on the equality of the // Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses). // referenced values (as opposed to the memory addresses).
func (a *Assertions) NotEqual(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) NotEqual(expected interface{}, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotEqual(a.t, expected, actual, msgAndArgs...) return NotEqual(a.t, expected, actual, msgAndArgs...)
} }
@ -512,6 +707,9 @@ func (a *Assertions) NotEqual(expected interface{}, actual interface{}, msgAndAr
// Pointer variable equality is determined based on the equality of the // Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses). // referenced values (as opposed to the memory addresses).
func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool { func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotEqualf(a.t, expected, actual, msg, args...) return NotEqualf(a.t, expected, actual, msg, args...)
} }
@ -519,6 +717,9 @@ func (a *Assertions) NotEqualf(expected interface{}, actual interface{}, msg str
// //
// a.NotNil(err) // a.NotNil(err)
func (a *Assertions) NotNil(object interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) NotNil(object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotNil(a.t, object, msgAndArgs...) return NotNil(a.t, object, msgAndArgs...)
} }
@ -526,6 +727,9 @@ func (a *Assertions) NotNil(object interface{}, msgAndArgs ...interface{}) bool
// //
// a.NotNilf(err, "error message %s", "formatted") // a.NotNilf(err, "error message %s", "formatted")
func (a *Assertions) NotNilf(object interface{}, msg string, args ...interface{}) bool { func (a *Assertions) NotNilf(object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotNilf(a.t, object, msg, args...) return NotNilf(a.t, object, msg, args...)
} }
@ -533,6 +737,9 @@ func (a *Assertions) NotNilf(object interface{}, msg string, args ...interface{}
// //
// a.NotPanics(func(){ RemainCalm() }) // a.NotPanics(func(){ RemainCalm() })
func (a *Assertions) NotPanics(f PanicTestFunc, msgAndArgs ...interface{}) bool { func (a *Assertions) NotPanics(f PanicTestFunc, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotPanics(a.t, f, msgAndArgs...) return NotPanics(a.t, f, msgAndArgs...)
} }
@ -540,6 +747,9 @@ func (a *Assertions) NotPanics(f PanicTestFunc, msgAndArgs ...interface{}) bool
// //
// a.NotPanicsf(func(){ RemainCalm() }, "error message %s", "formatted") // a.NotPanicsf(func(){ RemainCalm() }, "error message %s", "formatted")
func (a *Assertions) NotPanicsf(f PanicTestFunc, msg string, args ...interface{}) bool { func (a *Assertions) NotPanicsf(f PanicTestFunc, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotPanicsf(a.t, f, msg, args...) return NotPanicsf(a.t, f, msg, args...)
} }
@ -548,6 +758,9 @@ func (a *Assertions) NotPanicsf(f PanicTestFunc, msg string, args ...interface{}
// a.NotRegexp(regexp.MustCompile("starts"), "it's starting") // a.NotRegexp(regexp.MustCompile("starts"), "it's starting")
// a.NotRegexp("^start", "it's not starting") // a.NotRegexp("^start", "it's not starting")
func (a *Assertions) NotRegexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) NotRegexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotRegexp(a.t, rx, str, msgAndArgs...) return NotRegexp(a.t, rx, str, msgAndArgs...)
} }
@ -556,6 +769,9 @@ func (a *Assertions) NotRegexp(rx interface{}, str interface{}, msgAndArgs ...in
// a.NotRegexpf(regexp.MustCompile("starts", "error message %s", "formatted"), "it's starting") // a.NotRegexpf(regexp.MustCompile("starts", "error message %s", "formatted"), "it's starting")
// a.NotRegexpf("^start", "it's not starting", "error message %s", "formatted") // a.NotRegexpf("^start", "it's not starting", "error message %s", "formatted")
func (a *Assertions) NotRegexpf(rx interface{}, str interface{}, msg string, args ...interface{}) bool { func (a *Assertions) NotRegexpf(rx interface{}, str interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotRegexpf(a.t, rx, str, msg, args...) return NotRegexpf(a.t, rx, str, msg, args...)
} }
@ -564,6 +780,9 @@ func (a *Assertions) NotRegexpf(rx interface{}, str interface{}, msg string, arg
// //
// a.NotSubset([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") // a.NotSubset([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]")
func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotSubset(a.t, list, subset, msgAndArgs...) return NotSubset(a.t, list, subset, msgAndArgs...)
} }
@ -572,16 +791,25 @@ func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs
// //
// a.NotSubsetf([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") // a.NotSubsetf([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted")
func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotSubsetf(a.t, list, subset, msg, args...) return NotSubsetf(a.t, list, subset, msg, args...)
} }
// NotZero asserts that i is not the zero value for its type. // NotZero asserts that i is not the zero value for its type.
func (a *Assertions) NotZero(i interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) NotZero(i interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotZero(a.t, i, msgAndArgs...) return NotZero(a.t, i, msgAndArgs...)
} }
// NotZerof asserts that i is not the zero value for its type. // NotZerof asserts that i is not the zero value for its type.
func (a *Assertions) NotZerof(i interface{}, msg string, args ...interface{}) bool { func (a *Assertions) NotZerof(i interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotZerof(a.t, i, msg, args...) return NotZerof(a.t, i, msg, args...)
} }
@ -589,6 +817,9 @@ func (a *Assertions) NotZerof(i interface{}, msg string, args ...interface{}) bo
// //
// a.Panics(func(){ GoCrazy() }) // a.Panics(func(){ GoCrazy() })
func (a *Assertions) Panics(f PanicTestFunc, msgAndArgs ...interface{}) bool { func (a *Assertions) Panics(f PanicTestFunc, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Panics(a.t, f, msgAndArgs...) return Panics(a.t, f, msgAndArgs...)
} }
@ -597,6 +828,9 @@ func (a *Assertions) Panics(f PanicTestFunc, msgAndArgs ...interface{}) bool {
// //
// a.PanicsWithValue("crazy error", func(){ GoCrazy() }) // a.PanicsWithValue("crazy error", func(){ GoCrazy() })
func (a *Assertions) PanicsWithValue(expected interface{}, f PanicTestFunc, msgAndArgs ...interface{}) bool { func (a *Assertions) PanicsWithValue(expected interface{}, f PanicTestFunc, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return PanicsWithValue(a.t, expected, f, msgAndArgs...) return PanicsWithValue(a.t, expected, f, msgAndArgs...)
} }
@ -605,6 +839,9 @@ func (a *Assertions) PanicsWithValue(expected interface{}, f PanicTestFunc, msgA
// //
// a.PanicsWithValuef("crazy error", func(){ GoCrazy() }, "error message %s", "formatted") // a.PanicsWithValuef("crazy error", func(){ GoCrazy() }, "error message %s", "formatted")
func (a *Assertions) PanicsWithValuef(expected interface{}, f PanicTestFunc, msg string, args ...interface{}) bool { func (a *Assertions) PanicsWithValuef(expected interface{}, f PanicTestFunc, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return PanicsWithValuef(a.t, expected, f, msg, args...) return PanicsWithValuef(a.t, expected, f, msg, args...)
} }
@ -612,6 +849,9 @@ func (a *Assertions) PanicsWithValuef(expected interface{}, f PanicTestFunc, msg
// //
// a.Panicsf(func(){ GoCrazy() }, "error message %s", "formatted") // a.Panicsf(func(){ GoCrazy() }, "error message %s", "formatted")
func (a *Assertions) Panicsf(f PanicTestFunc, msg string, args ...interface{}) bool { func (a *Assertions) Panicsf(f PanicTestFunc, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Panicsf(a.t, f, msg, args...) return Panicsf(a.t, f, msg, args...)
} }
@ -620,6 +860,9 @@ func (a *Assertions) Panicsf(f PanicTestFunc, msg string, args ...interface{}) b
// a.Regexp(regexp.MustCompile("start"), "it's starting") // a.Regexp(regexp.MustCompile("start"), "it's starting")
// a.Regexp("start...$", "it's not starting") // a.Regexp("start...$", "it's not starting")
func (a *Assertions) Regexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) Regexp(rx interface{}, str interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Regexp(a.t, rx, str, msgAndArgs...) return Regexp(a.t, rx, str, msgAndArgs...)
} }
@ -628,6 +871,9 @@ func (a *Assertions) Regexp(rx interface{}, str interface{}, msgAndArgs ...inter
// a.Regexpf(regexp.MustCompile("start", "error message %s", "formatted"), "it's starting") // a.Regexpf(regexp.MustCompile("start", "error message %s", "formatted"), "it's starting")
// a.Regexpf("start...$", "it's not starting", "error message %s", "formatted") // a.Regexpf("start...$", "it's not starting", "error message %s", "formatted")
func (a *Assertions) Regexpf(rx interface{}, str interface{}, msg string, args ...interface{}) bool { func (a *Assertions) Regexpf(rx interface{}, str interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Regexpf(a.t, rx, str, msg, args...) return Regexpf(a.t, rx, str, msg, args...)
} }
@ -636,6 +882,9 @@ func (a *Assertions) Regexpf(rx interface{}, str interface{}, msg string, args .
// //
// a.Subset([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") // a.Subset([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]")
func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Subset(a.t, list, subset, msgAndArgs...) return Subset(a.t, list, subset, msgAndArgs...)
} }
@ -644,6 +893,9 @@ func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...
// //
// a.Subsetf([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") // a.Subsetf([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted")
func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Subsetf(a.t, list, subset, msg, args...) return Subsetf(a.t, list, subset, msg, args...)
} }
@ -651,6 +903,9 @@ func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, a
// //
// a.True(myBool) // a.True(myBool)
func (a *Assertions) True(value bool, msgAndArgs ...interface{}) bool { func (a *Assertions) True(value bool, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return True(a.t, value, msgAndArgs...) return True(a.t, value, msgAndArgs...)
} }
@ -658,6 +913,9 @@ func (a *Assertions) True(value bool, msgAndArgs ...interface{}) bool {
// //
// a.Truef(myBool, "error message %s", "formatted") // a.Truef(myBool, "error message %s", "formatted")
func (a *Assertions) Truef(value bool, msg string, args ...interface{}) bool { func (a *Assertions) Truef(value bool, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Truef(a.t, value, msg, args...) return Truef(a.t, value, msg, args...)
} }
@ -665,6 +923,9 @@ func (a *Assertions) Truef(value bool, msg string, args ...interface{}) bool {
// //
// a.WithinDuration(time.Now(), time.Now(), 10*time.Second) // a.WithinDuration(time.Now(), time.Now(), 10*time.Second)
func (a *Assertions) WithinDuration(expected time.Time, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) bool { func (a *Assertions) WithinDuration(expected time.Time, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return WithinDuration(a.t, expected, actual, delta, msgAndArgs...) return WithinDuration(a.t, expected, actual, delta, msgAndArgs...)
} }
@ -672,15 +933,24 @@ func (a *Assertions) WithinDuration(expected time.Time, actual time.Time, delta
// //
// a.WithinDurationf(time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted") // a.WithinDurationf(time.Now(), time.Now(), 10*time.Second, "error message %s", "formatted")
func (a *Assertions) WithinDurationf(expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool { func (a *Assertions) WithinDurationf(expected time.Time, actual time.Time, delta time.Duration, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return WithinDurationf(a.t, expected, actual, delta, msg, args...) return WithinDurationf(a.t, expected, actual, delta, msg, args...)
} }
// Zero asserts that i is the zero value for its type. // Zero asserts that i is the zero value for its type.
func (a *Assertions) Zero(i interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) Zero(i interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Zero(a.t, i, msgAndArgs...) return Zero(a.t, i, msgAndArgs...)
} }
// Zerof asserts that i is the zero value for its type. // Zerof asserts that i is the zero value for its type.
func (a *Assertions) Zerof(i interface{}, msg string, args ...interface{}) bool { func (a *Assertions) Zerof(i interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return Zerof(a.t, i, msg, args...) return Zerof(a.t, i, msg, args...)
} }

View File

@ -1,4 +1,5 @@
{{.CommentWithoutT "a"}} {{.CommentWithoutT "a"}}
func (a *Assertions) {{.DocInfo.Name}}({{.Params}}) bool { func (a *Assertions) {{.DocInfo.Name}}({{.Params}}) bool {
if h, ok := a.t.(tHelper); ok { h.Helper() }
return {{.DocInfo.Name}}(a.t, {{.ForwardedParams}}) return {{.DocInfo.Name}}(a.t, {{.ForwardedParams}})
} }

View File

@ -27,6 +27,22 @@ type TestingT interface {
Errorf(format string, args ...interface{}) Errorf(format string, args ...interface{})
} }
// ComparisonAssertionFunc is a common function prototype when comparing two values. Can be useful
// for table driven tests.
type ComparisonAssertionFunc func(TestingT, interface{}, interface{}, ...interface{}) bool
// ValueAssertionFunc is a common function prototype when validating a single value. Can be useful
// for table driven tests.
type ValueAssertionFunc func(TestingT, interface{}, ...interface{}) bool
// BoolAssertionFunc is a common function prototype when validating a bool value. Can be useful
// for table driven tests.
type BoolAssertionFunc func(TestingT, bool, ...interface{}) bool
// ErrorAssertionFunc is a common function prototype when validating an error value. Can be useful
// for table driven tests.
type ErrorAssertionFunc func(TestingT, error, ...interface{}) bool
// Comparison a custom function that returns true on success and false on failure // Comparison a custom function that returns true on success and false on failure
type Comparison func() (success bool) type Comparison func() (success bool)
@ -38,21 +54,23 @@ type Comparison func() (success bool)
// //
// This function does no assertion of any kind. // This function does no assertion of any kind.
func ObjectsAreEqual(expected, actual interface{}) bool { func ObjectsAreEqual(expected, actual interface{}) bool {
if expected == nil || actual == nil { if expected == nil || actual == nil {
return expected == actual return expected == actual
} }
if exp, ok := expected.([]byte); ok {
act, ok := actual.([]byte)
if !ok {
return false
} else if exp == nil || act == nil {
return exp == nil && act == nil
}
return bytes.Equal(exp, act)
}
return reflect.DeepEqual(expected, actual)
exp, ok := expected.([]byte)
if !ok {
return reflect.DeepEqual(expected, actual)
}
act, ok := actual.([]byte)
if !ok {
return false
}
if exp == nil || act == nil {
return exp == nil && act == nil
}
return bytes.Equal(exp, act)
} }
// ObjectsAreEqualValues gets whether two objects are equal, or if their // ObjectsAreEqualValues gets whether two objects are equal, or if their
@ -156,27 +174,16 @@ func isTest(name, prefix string) bool {
return !unicode.IsLower(rune) return !unicode.IsLower(rune)
} }
// getWhitespaceString returns a string that is long enough to overwrite the default
// output from the go testing framework.
func getWhitespaceString() string {
_, file, line, ok := runtime.Caller(1)
if !ok {
return ""
}
parts := strings.Split(file, "/")
file = parts[len(parts)-1]
return strings.Repeat(" ", len(fmt.Sprintf("%s:%d: ", file, line)))
}
func messageFromMsgAndArgs(msgAndArgs ...interface{}) string { func messageFromMsgAndArgs(msgAndArgs ...interface{}) string {
if len(msgAndArgs) == 0 || msgAndArgs == nil { if len(msgAndArgs) == 0 || msgAndArgs == nil {
return "" return ""
} }
if len(msgAndArgs) == 1 { if len(msgAndArgs) == 1 {
return msgAndArgs[0].(string) msg := msgAndArgs[0]
if msgAsStr, ok := msg.(string); ok {
return msgAsStr
}
return fmt.Sprintf("%+v", msg)
} }
if len(msgAndArgs) > 1 { if len(msgAndArgs) > 1 {
return fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...) return fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...)
@ -195,7 +202,7 @@ func indentMessageLines(message string, longestLabelLen int) string {
// no need to align first line because it starts at the correct location (after the label) // no need to align first line because it starts at the correct location (after the label)
if i != 0 { if i != 0 {
// append alignLen+1 spaces to align with "{{longestLabel}}:" before adding tab // append alignLen+1 spaces to align with "{{longestLabel}}:" before adding tab
outBuf.WriteString("\n\r\t" + strings.Repeat(" ", longestLabelLen+1) + "\t") outBuf.WriteString("\n\t" + strings.Repeat(" ", longestLabelLen+1) + "\t")
} }
outBuf.WriteString(scanner.Text()) outBuf.WriteString(scanner.Text())
} }
@ -209,6 +216,9 @@ type failNower interface {
// FailNow fails test // FailNow fails test
func FailNow(t TestingT, failureMessage string, msgAndArgs ...interface{}) bool { func FailNow(t TestingT, failureMessage string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
Fail(t, failureMessage, msgAndArgs...) Fail(t, failureMessage, msgAndArgs...)
// We cannot extend TestingT with FailNow() and // We cannot extend TestingT with FailNow() and
@ -227,8 +237,11 @@ func FailNow(t TestingT, failureMessage string, msgAndArgs ...interface{}) bool
// Fail reports a failure through // Fail reports a failure through
func Fail(t TestingT, failureMessage string, msgAndArgs ...interface{}) bool { func Fail(t TestingT, failureMessage string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
content := []labeledContent{ content := []labeledContent{
{"Error Trace", strings.Join(CallerInfo(), "\n\r\t\t\t")}, {"Error Trace", strings.Join(CallerInfo(), "\n\t\t\t")},
{"Error", failureMessage}, {"Error", failureMessage},
} }
@ -244,7 +257,7 @@ func Fail(t TestingT, failureMessage string, msgAndArgs ...interface{}) bool {
content = append(content, labeledContent{"Messages", message}) content = append(content, labeledContent{"Messages", message})
} }
t.Errorf("%s", "\r"+getWhitespaceString()+labeledOutput(content...)) t.Errorf("\n%s", ""+labeledOutput(content...))
return false return false
} }
@ -256,7 +269,7 @@ type labeledContent struct {
// labeledOutput returns a string consisting of the provided labeledContent. Each labeled output is appended in the following manner: // labeledOutput returns a string consisting of the provided labeledContent. Each labeled output is appended in the following manner:
// //
// \r\t{{label}}:{{align_spaces}}\t{{content}}\n // \t{{label}}:{{align_spaces}}\t{{content}}\n
// //
// The initial carriage return is required to undo/erase any padding added by testing.T.Errorf. The "\t{{label}}:" is for the label. // The initial carriage return is required to undo/erase any padding added by testing.T.Errorf. The "\t{{label}}:" is for the label.
// If a label is shorter than the longest label provided, padding spaces are added to make all the labels match in length. Once this // If a label is shorter than the longest label provided, padding spaces are added to make all the labels match in length. Once this
@ -272,7 +285,7 @@ func labeledOutput(content ...labeledContent) string {
} }
var output string var output string
for _, v := range content { for _, v := range content {
output += "\r\t" + v.label + ":" + strings.Repeat(" ", longestLabel-len(v.label)) + "\t" + indentMessageLines(v.content, longestLabel) + "\n" output += "\t" + v.label + ":" + strings.Repeat(" ", longestLabel-len(v.label)) + "\t" + indentMessageLines(v.content, longestLabel) + "\n"
} }
return output return output
} }
@ -281,6 +294,9 @@ func labeledOutput(content ...labeledContent) string {
// //
// assert.Implements(t, (*MyInterface)(nil), new(MyObject)) // assert.Implements(t, (*MyInterface)(nil), new(MyObject))
func Implements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool { func Implements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
interfaceType := reflect.TypeOf(interfaceObject).Elem() interfaceType := reflect.TypeOf(interfaceObject).Elem()
if object == nil { if object == nil {
@ -295,6 +311,9 @@ func Implements(t TestingT, interfaceObject interface{}, object interface{}, msg
// IsType asserts that the specified objects are of the same type. // IsType asserts that the specified objects are of the same type.
func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool { func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if !ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) { if !ObjectsAreEqual(reflect.TypeOf(object), reflect.TypeOf(expectedType)) {
return Fail(t, fmt.Sprintf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object)), msgAndArgs...) return Fail(t, fmt.Sprintf("Object expected to be of type %v, but was %v", reflect.TypeOf(expectedType), reflect.TypeOf(object)), msgAndArgs...)
@ -311,6 +330,9 @@ func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs
// referenced values (as opposed to the memory addresses). Function equality // referenced values (as opposed to the memory addresses). Function equality
// cannot be determined and will always fail. // cannot be determined and will always fail.
func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { func Equal(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if err := validateEqualArgs(expected, actual); err != nil { if err := validateEqualArgs(expected, actual); err != nil {
return Fail(t, fmt.Sprintf("Invalid operation: %#v == %#v (%s)", return Fail(t, fmt.Sprintf("Invalid operation: %#v == %#v (%s)",
expected, actual, err), msgAndArgs...) expected, actual, err), msgAndArgs...)
@ -349,6 +371,9 @@ func formatUnequalValues(expected, actual interface{}) (e string, a string) {
// //
// assert.EqualValues(t, uint32(123), int32(123)) // assert.EqualValues(t, uint32(123), int32(123))
func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if !ObjectsAreEqualValues(expected, actual) { if !ObjectsAreEqualValues(expected, actual) {
diff := diff(expected, actual) diff := diff(expected, actual)
@ -366,12 +391,15 @@ func EqualValues(t TestingT, expected, actual interface{}, msgAndArgs ...interfa
// //
// assert.Exactly(t, int32(123), int64(123)) // assert.Exactly(t, int32(123), int64(123))
func Exactly(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { func Exactly(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
aType := reflect.TypeOf(expected) aType := reflect.TypeOf(expected)
bType := reflect.TypeOf(actual) bType := reflect.TypeOf(actual)
if aType != bType { if aType != bType {
return Fail(t, fmt.Sprintf("Types expected to match exactly\n\r\t%v != %v", aType, bType), msgAndArgs...) return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...)
} }
return Equal(t, expected, actual, msgAndArgs...) return Equal(t, expected, actual, msgAndArgs...)
@ -382,12 +410,26 @@ func Exactly(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}
// //
// assert.NotNil(t, err) // assert.NotNil(t, err)
func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if !isNil(object) { if !isNil(object) {
return true return true
} }
return Fail(t, "Expected value not to be nil.", msgAndArgs...) return Fail(t, "Expected value not to be nil.", msgAndArgs...)
} }
// containsKind checks if a specified kind in the slice of kinds.
func containsKind(kinds []reflect.Kind, kind reflect.Kind) bool {
for i := 0; i < len(kinds); i++ {
if kind == kinds[i] {
return true
}
}
return false
}
// isNil checks if a specified object is nil or not, without Failing. // isNil checks if a specified object is nil or not, without Failing.
func isNil(object interface{}) bool { func isNil(object interface{}) bool {
if object == nil { if object == nil {
@ -396,7 +438,14 @@ func isNil(object interface{}) bool {
value := reflect.ValueOf(object) value := reflect.ValueOf(object)
kind := value.Kind() kind := value.Kind()
if kind >= reflect.Chan && kind <= reflect.Slice && value.IsNil() { isNilableKind := containsKind(
[]reflect.Kind{
reflect.Chan, reflect.Func,
reflect.Interface, reflect.Map,
reflect.Ptr, reflect.Slice},
kind)
if isNilableKind && value.IsNil() {
return true return true
} }
@ -407,6 +456,9 @@ func isNil(object interface{}) bool {
// //
// assert.Nil(t, err) // assert.Nil(t, err)
func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { func Nil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if isNil(object) { if isNil(object) {
return true return true
} }
@ -446,6 +498,9 @@ func isEmpty(object interface{}) bool {
// //
// assert.Empty(t, obj) // assert.Empty(t, obj)
func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
pass := isEmpty(object) pass := isEmpty(object)
if !pass { if !pass {
@ -463,6 +518,9 @@ func Empty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
// assert.Equal(t, "two", obj[1]) // assert.Equal(t, "two", obj[1])
// } // }
func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
pass := !isEmpty(object) pass := !isEmpty(object)
if !pass { if !pass {
@ -490,6 +548,9 @@ func getLen(x interface{}) (ok bool, length int) {
// //
// assert.Len(t, mySlice, 3) // assert.Len(t, mySlice, 3)
func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{}) bool { func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
ok, l := getLen(object) ok, l := getLen(object)
if !ok { if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", object), msgAndArgs...) return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", object), msgAndArgs...)
@ -505,6 +566,14 @@ func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{})
// //
// assert.True(t, myBool) // assert.True(t, myBool)
func True(t TestingT, value bool, msgAndArgs ...interface{}) bool { func True(t TestingT, value bool, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if h, ok := t.(interface {
Helper()
}); ok {
h.Helper()
}
if value != true { if value != true {
return Fail(t, "Should be true", msgAndArgs...) return Fail(t, "Should be true", msgAndArgs...)
@ -518,6 +587,9 @@ func True(t TestingT, value bool, msgAndArgs ...interface{}) bool {
// //
// assert.False(t, myBool) // assert.False(t, myBool)
func False(t TestingT, value bool, msgAndArgs ...interface{}) bool { func False(t TestingT, value bool, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if value != false { if value != false {
return Fail(t, "Should be false", msgAndArgs...) return Fail(t, "Should be false", msgAndArgs...)
@ -534,6 +606,9 @@ func False(t TestingT, value bool, msgAndArgs ...interface{}) bool {
// Pointer variable equality is determined based on the equality of the // Pointer variable equality is determined based on the equality of the
// referenced values (as opposed to the memory addresses). // referenced values (as opposed to the memory addresses).
func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool { func NotEqual(t TestingT, expected, actual interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if err := validateEqualArgs(expected, actual); err != nil { if err := validateEqualArgs(expected, actual); err != nil {
return Fail(t, fmt.Sprintf("Invalid operation: %#v != %#v (%s)", return Fail(t, fmt.Sprintf("Invalid operation: %#v != %#v (%s)",
expected, actual, err), msgAndArgs...) expected, actual, err), msgAndArgs...)
@ -592,6 +667,9 @@ func includeElement(list interface{}, element interface{}) (ok, found bool) {
// assert.Contains(t, ["Hello", "World"], "World") // assert.Contains(t, ["Hello", "World"], "World")
// assert.Contains(t, {"Hello": "World"}, "Hello") // assert.Contains(t, {"Hello": "World"}, "Hello")
func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool { func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
ok, found := includeElement(s, contains) ok, found := includeElement(s, contains)
if !ok { if !ok {
@ -612,6 +690,9 @@ func Contains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bo
// assert.NotContains(t, ["Hello", "World"], "Earth") // assert.NotContains(t, ["Hello", "World"], "Earth")
// assert.NotContains(t, {"Hello": "World"}, "Earth") // assert.NotContains(t, {"Hello": "World"}, "Earth")
func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool { func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
ok, found := includeElement(s, contains) ok, found := includeElement(s, contains)
if !ok { if !ok {
@ -630,6 +711,9 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
// //
// assert.Subset(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") // assert.Subset(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]")
func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if subset == nil { if subset == nil {
return true // we consider nil to be equal to the nil set return true // we consider nil to be equal to the nil set
} }
@ -671,6 +755,9 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
// //
// assert.NotSubset(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") // assert.NotSubset(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]")
func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if subset == nil { if subset == nil {
return Fail(t, fmt.Sprintf("nil is the empty set which is a subset of every set"), msgAndArgs...) return Fail(t, fmt.Sprintf("nil is the empty set which is a subset of every set"), msgAndArgs...)
} }
@ -713,6 +800,9 @@ func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{})
// //
// assert.ElementsMatch(t, [1, 3, 2, 3], [1, 3, 3, 2]) // assert.ElementsMatch(t, [1, 3, 2, 3], [1, 3, 3, 2])
func ElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) { func ElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if isEmpty(listA) && isEmpty(listB) { if isEmpty(listA) && isEmpty(listB) {
return true return true
} }
@ -763,6 +853,9 @@ func ElementsMatch(t TestingT, listA, listB interface{}, msgAndArgs ...interface
// Condition uses a Comparison to assert a complex condition. // Condition uses a Comparison to assert a complex condition.
func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool { func Condition(t TestingT, comp Comparison, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
result := comp() result := comp()
if !result { if !result {
Fail(t, "Condition failed!", msgAndArgs...) Fail(t, "Condition failed!", msgAndArgs...)
@ -800,9 +893,12 @@ func didPanic(f PanicTestFunc) (bool, interface{}) {
// //
// assert.Panics(t, func(){ GoCrazy() }) // assert.Panics(t, func(){ GoCrazy() })
func Panics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool { func Panics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if funcDidPanic, panicValue := didPanic(f); !funcDidPanic { if funcDidPanic, panicValue := didPanic(f); !funcDidPanic {
return Fail(t, fmt.Sprintf("func %#v should panic\n\r\tPanic value:\t%v", f, panicValue), msgAndArgs...) return Fail(t, fmt.Sprintf("func %#v should panic\n\tPanic value:\t%#v", f, panicValue), msgAndArgs...)
} }
return true return true
@ -813,13 +909,16 @@ func Panics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool {
// //
// assert.PanicsWithValue(t, "crazy error", func(){ GoCrazy() }) // assert.PanicsWithValue(t, "crazy error", func(){ GoCrazy() })
func PanicsWithValue(t TestingT, expected interface{}, f PanicTestFunc, msgAndArgs ...interface{}) bool { func PanicsWithValue(t TestingT, expected interface{}, f PanicTestFunc, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
funcDidPanic, panicValue := didPanic(f) funcDidPanic, panicValue := didPanic(f)
if !funcDidPanic { if !funcDidPanic {
return Fail(t, fmt.Sprintf("func %#v should panic\n\r\tPanic value:\t%v", f, panicValue), msgAndArgs...) return Fail(t, fmt.Sprintf("func %#v should panic\n\tPanic value:\t%#v", f, panicValue), msgAndArgs...)
} }
if panicValue != expected { if panicValue != expected {
return Fail(t, fmt.Sprintf("func %#v should panic with value:\t%v\n\r\tPanic value:\t%v", f, expected, panicValue), msgAndArgs...) return Fail(t, fmt.Sprintf("func %#v should panic with value:\t%#v\n\tPanic value:\t%#v", f, expected, panicValue), msgAndArgs...)
} }
return true return true
@ -829,9 +928,12 @@ func PanicsWithValue(t TestingT, expected interface{}, f PanicTestFunc, msgAndAr
// //
// assert.NotPanics(t, func(){ RemainCalm() }) // assert.NotPanics(t, func(){ RemainCalm() })
func NotPanics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool { func NotPanics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if funcDidPanic, panicValue := didPanic(f); funcDidPanic { if funcDidPanic, panicValue := didPanic(f); funcDidPanic {
return Fail(t, fmt.Sprintf("func %#v should not panic\n\r\tPanic value:\t%v", f, panicValue), msgAndArgs...) return Fail(t, fmt.Sprintf("func %#v should not panic\n\tPanic value:\t%v", f, panicValue), msgAndArgs...)
} }
return true return true
@ -841,6 +943,9 @@ func NotPanics(t TestingT, f PanicTestFunc, msgAndArgs ...interface{}) bool {
// //
// assert.WithinDuration(t, time.Now(), time.Now(), 10*time.Second) // assert.WithinDuration(t, time.Now(), time.Now(), 10*time.Second)
func WithinDuration(t TestingT, expected, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) bool { func WithinDuration(t TestingT, expected, actual time.Time, delta time.Duration, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
dt := expected.Sub(actual) dt := expected.Sub(actual)
if dt < -delta || dt > delta { if dt < -delta || dt > delta {
@ -890,6 +995,9 @@ func toFloat(x interface{}) (float64, bool) {
// //
// assert.InDelta(t, math.Pi, (22 / 7.0), 0.01) // assert.InDelta(t, math.Pi, (22 / 7.0), 0.01)
func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
af, aok := toFloat(expected) af, aok := toFloat(expected)
bf, bok := toFloat(actual) bf, bok := toFloat(actual)
@ -916,6 +1024,9 @@ func InDelta(t TestingT, expected, actual interface{}, delta float64, msgAndArgs
// InDeltaSlice is the same as InDelta, except it compares two slices. // InDeltaSlice is the same as InDelta, except it compares two slices.
func InDeltaSlice(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { func InDeltaSlice(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if expected == nil || actual == nil || if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice || reflect.TypeOf(actual).Kind() != reflect.Slice ||
reflect.TypeOf(expected).Kind() != reflect.Slice { reflect.TypeOf(expected).Kind() != reflect.Slice {
@ -937,6 +1048,9 @@ func InDeltaSlice(t TestingT, expected, actual interface{}, delta float64, msgAn
// InDeltaMapValues is the same as InDelta, but it compares all values between two maps. Both maps must have exactly the same keys. // InDeltaMapValues is the same as InDelta, but it compares all values between two maps. Both maps must have exactly the same keys.
func InDeltaMapValues(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool { func InDeltaMapValues(t TestingT, expected, actual interface{}, delta float64, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if expected == nil || actual == nil || if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Map || reflect.TypeOf(actual).Kind() != reflect.Map ||
reflect.TypeOf(expected).Kind() != reflect.Map { reflect.TypeOf(expected).Kind() != reflect.Map {
@ -994,6 +1108,9 @@ func calcRelativeError(expected, actual interface{}) (float64, error) {
// InEpsilon asserts that expected and actual have a relative error less than epsilon // InEpsilon asserts that expected and actual have a relative error less than epsilon
func InEpsilon(t TestingT, expected, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool { func InEpsilon(t TestingT, expected, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
actualEpsilon, err := calcRelativeError(expected, actual) actualEpsilon, err := calcRelativeError(expected, actual)
if err != nil { if err != nil {
return Fail(t, err.Error(), msgAndArgs...) return Fail(t, err.Error(), msgAndArgs...)
@ -1008,6 +1125,9 @@ func InEpsilon(t TestingT, expected, actual interface{}, epsilon float64, msgAnd
// InEpsilonSlice is the same as InEpsilon, except it compares each value from two slices. // InEpsilonSlice is the same as InEpsilon, except it compares each value from two slices.
func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool { func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if expected == nil || actual == nil || if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice || reflect.TypeOf(actual).Kind() != reflect.Slice ||
reflect.TypeOf(expected).Kind() != reflect.Slice { reflect.TypeOf(expected).Kind() != reflect.Slice {
@ -1038,6 +1158,9 @@ func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, m
// assert.Equal(t, expectedObj, actualObj) // assert.Equal(t, expectedObj, actualObj)
// } // }
func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool { func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if err != nil { if err != nil {
return Fail(t, fmt.Sprintf("Received unexpected error:\n%+v", err), msgAndArgs...) return Fail(t, fmt.Sprintf("Received unexpected error:\n%+v", err), msgAndArgs...)
} }
@ -1052,6 +1175,9 @@ func NoError(t TestingT, err error, msgAndArgs ...interface{}) bool {
// assert.Equal(t, expectedError, err) // assert.Equal(t, expectedError, err)
// } // }
func Error(t TestingT, err error, msgAndArgs ...interface{}) bool { func Error(t TestingT, err error, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if err == nil { if err == nil {
return Fail(t, "An error is expected but got nil.", msgAndArgs...) return Fail(t, "An error is expected but got nil.", msgAndArgs...)
@ -1066,6 +1192,9 @@ func Error(t TestingT, err error, msgAndArgs ...interface{}) bool {
// actualObj, err := SomeFunction() // actualObj, err := SomeFunction()
// assert.EqualError(t, err, expectedErrorString) // assert.EqualError(t, err, expectedErrorString)
func EqualError(t TestingT, theError error, errString string, msgAndArgs ...interface{}) bool { func EqualError(t TestingT, theError error, errString string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if !Error(t, theError, msgAndArgs...) { if !Error(t, theError, msgAndArgs...) {
return false return false
} }
@ -1099,6 +1228,9 @@ func matchRegexp(rx interface{}, str interface{}) bool {
// assert.Regexp(t, regexp.MustCompile("start"), "it's starting") // assert.Regexp(t, regexp.MustCompile("start"), "it's starting")
// assert.Regexp(t, "start...$", "it's not starting") // assert.Regexp(t, "start...$", "it's not starting")
func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
match := matchRegexp(rx, str) match := matchRegexp(rx, str)
@ -1114,6 +1246,9 @@ func Regexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface
// assert.NotRegexp(t, regexp.MustCompile("starts"), "it's starting") // assert.NotRegexp(t, regexp.MustCompile("starts"), "it's starting")
// assert.NotRegexp(t, "^start", "it's not starting") // assert.NotRegexp(t, "^start", "it's not starting")
func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) bool { func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
match := matchRegexp(rx, str) match := matchRegexp(rx, str)
if match { if match {
@ -1126,6 +1261,9 @@ func NotRegexp(t TestingT, rx interface{}, str interface{}, msgAndArgs ...interf
// Zero asserts that i is the zero value for its type. // Zero asserts that i is the zero value for its type.
func Zero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool { func Zero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if i != nil && !reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) { if i != nil && !reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) {
return Fail(t, fmt.Sprintf("Should be zero, but was %v", i), msgAndArgs...) return Fail(t, fmt.Sprintf("Should be zero, but was %v", i), msgAndArgs...)
} }
@ -1134,6 +1272,9 @@ func Zero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool {
// NotZero asserts that i is not the zero value for its type. // NotZero asserts that i is not the zero value for its type.
func NotZero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool { func NotZero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
if i == nil || reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) { if i == nil || reflect.DeepEqual(i, reflect.Zero(reflect.TypeOf(i)).Interface()) {
return Fail(t, fmt.Sprintf("Should not be zero, but was %v", i), msgAndArgs...) return Fail(t, fmt.Sprintf("Should not be zero, but was %v", i), msgAndArgs...)
} }
@ -1142,6 +1283,9 @@ func NotZero(t TestingT, i interface{}, msgAndArgs ...interface{}) bool {
// FileExists checks whether a file exists in the given path. It also fails if the path points to a directory or there is an error when trying to check the file. // FileExists checks whether a file exists in the given path. It also fails if the path points to a directory or there is an error when trying to check the file.
func FileExists(t TestingT, path string, msgAndArgs ...interface{}) bool { func FileExists(t TestingT, path string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
info, err := os.Lstat(path) info, err := os.Lstat(path)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -1157,6 +1301,9 @@ func FileExists(t TestingT, path string, msgAndArgs ...interface{}) bool {
// DirExists checks whether a directory exists in the given path. It also fails if the path is a file rather a directory or there is an error checking whether it exists. // DirExists checks whether a directory exists in the given path. It also fails if the path is a file rather a directory or there is an error checking whether it exists.
func DirExists(t TestingT, path string, msgAndArgs ...interface{}) bool { func DirExists(t TestingT, path string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
info, err := os.Lstat(path) info, err := os.Lstat(path)
if err != nil { if err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {
@ -1174,6 +1321,9 @@ func DirExists(t TestingT, path string, msgAndArgs ...interface{}) bool {
// //
// assert.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`) // assert.JSONEq(t, `{"hello": "world", "foo": "bar"}`, `{"foo": "bar", "hello": "world"}`)
func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) bool { func JSONEq(t TestingT, expected string, actual string, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
var expectedJSONAsInterface, actualJSONAsInterface interface{} var expectedJSONAsInterface, actualJSONAsInterface interface{}
if err := json.Unmarshal([]byte(expected), &expectedJSONAsInterface); err != nil { if err := json.Unmarshal([]byte(expected), &expectedJSONAsInterface); err != nil {
@ -1199,7 +1349,7 @@ func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) {
} }
// diff returns a diff of both values as long as both are of the same type and // diff returns a diff of both values as long as both are of the same type and
// are a struct, map, slice or array. Otherwise it returns an empty string. // are a struct, map, slice, array or string. Otherwise it returns an empty string.
func diff(expected interface{}, actual interface{}) string { func diff(expected interface{}, actual interface{}) string {
if expected == nil || actual == nil { if expected == nil || actual == nil {
return "" return ""
@ -1212,12 +1362,18 @@ func diff(expected interface{}, actual interface{}) string {
return "" return ""
} }
if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array { if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array && ek != reflect.String {
return "" return ""
} }
e := spewConfig.Sdump(expected) var e, a string
a := spewConfig.Sdump(actual) if et != reflect.TypeOf("") {
e = spewConfig.Sdump(expected)
a = spewConfig.Sdump(actual)
} else {
e = expected.(string)
a = actual.(string)
}
diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{ diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
A: difflib.SplitLines(e), A: difflib.SplitLines(e),
@ -1254,3 +1410,7 @@ var spewConfig = spew.ConfigState{
DisableCapacities: true, DisableCapacities: true,
SortKeys: true, SortKeys: true,
} }
type tHelper interface {
Helper()
}

View File

@ -12,10 +12,11 @@ import (
// an error if building a new request fails. // an error if building a new request fails.
func httpCode(handler http.HandlerFunc, method, url string, values url.Values) (int, error) { func httpCode(handler http.HandlerFunc, method, url string, values url.Values) (int, error) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
req, err := http.NewRequest(method, url+"?"+values.Encode(), nil) req, err := http.NewRequest(method, url, nil)
if err != nil { if err != nil {
return -1, err return -1, err
} }
req.URL.RawQuery = values.Encode()
handler(w, req) handler(w, req)
return w.Code, nil return w.Code, nil
} }
@ -26,6 +27,9 @@ func httpCode(handler http.HandlerFunc, method, url string, values url.Values) (
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool { func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
code, err := httpCode(handler, method, url, values) code, err := httpCode(handler, method, url, values)
if err != nil { if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err))
@ -46,6 +50,9 @@ func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, value
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool { func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
code, err := httpCode(handler, method, url, values) code, err := httpCode(handler, method, url, values)
if err != nil { if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err))
@ -66,6 +73,9 @@ func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, valu
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool { func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
code, err := httpCode(handler, method, url, values) code, err := httpCode(handler, method, url, values)
if err != nil { if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err))
@ -95,10 +105,13 @@ func HTTPBody(handler http.HandlerFunc, method, url string, values url.Values) s
// HTTPBodyContains asserts that a specified handler returns a // HTTPBodyContains asserts that a specified handler returns a
// body that contains a string. // body that contains a string.
// //
// assert.HTTPBodyContains(t, myHandler, "www.google.com", nil, "I'm Feeling Lucky") // assert.HTTPBodyContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool { func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
body := HTTPBody(handler, method, url, values) body := HTTPBody(handler, method, url, values)
contains := strings.Contains(body, fmt.Sprint(str)) contains := strings.Contains(body, fmt.Sprint(str))
@ -112,10 +125,13 @@ func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string,
// HTTPBodyNotContains asserts that a specified handler returns a // HTTPBodyNotContains asserts that a specified handler returns a
// body that does not contain a string. // body that does not contain a string.
// //
// assert.HTTPBodyNotContains(t, myHandler, "www.google.com", nil, "I'm Feeling Lucky") // assert.HTTPBodyNotContains(t, myHandler, "GET", "www.google.com", nil, "I'm Feeling Lucky")
// //
// Returns whether the assertion was successful (true) or not (false). // Returns whether the assertion was successful (true) or not (false).
func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool { func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method, url string, values url.Values, str interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
body := HTTPBody(handler, method, url, values) body := HTTPBody(handler, method, url, values)
contains := strings.Contains(body, fmt.Sprint(str)) contains := strings.Contains(body, fmt.Sprint(str))

View File

@ -3,6 +3,14 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package blowfish implements Bruce Schneier's Blowfish encryption algorithm. // Package blowfish implements Bruce Schneier's Blowfish encryption algorithm.
//
// Blowfish is a legacy cipher and its short block size makes it vulnerable to
// birthday bound attacks (see https://sweet32.info). It should only be used
// where compatibility with legacy systems, not security, is the goal.
//
// Deprecated: any new system should use AES (from crypto/aes, if necessary in
// an AEAD mode like crypto/cipher.NewGCM) or XChaCha20-Poly1305 (from
// golang.org/x/crypto/chacha20poly1305).
package blowfish // import "golang.org/x/crypto/blowfish" package blowfish // import "golang.org/x/crypto/blowfish"
// The code is a port of Bruce Schneier's C implementation. // The code is a port of Bruce Schneier's C implementation.

View File

@ -2,8 +2,15 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package cast5 implements CAST5, as defined in RFC 2144. CAST5 is a common // Package cast5 implements CAST5, as defined in RFC 2144.
// OpenPGP cipher. //
// CAST5 is a legacy cipher and its short block size makes it vulnerable to
// birthday bound attacks (see https://sweet32.info). It should only be used
// where compatibility with legacy systems, not security, is the goal.
//
// Deprecated: any new system should use AES (from crypto/aes, if necessary in
// an AEAD mode like crypto/cipher.NewGCM) or XChaCha20-Poly1305 (from
// golang.org/x/crypto/chacha20poly1305).
package cast5 // import "golang.org/x/crypto/cast5" package cast5 // import "golang.org/x/crypto/cast5"
import "errors" import "errors"

32
vendor/golang.org/x/crypto/internal/subtle/aliasing.go generated vendored Normal file
View File

@ -0,0 +1,32 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !appengine
// Package subtle implements functions that are often useful in cryptographic
// code but require careful thought to use correctly.
package subtle // import "golang.org/x/crypto/internal/subtle"
import "unsafe"
// AnyOverlap reports whether x and y share memory at any (not necessarily
// corresponding) index. The memory beyond the slice length is ignored.
func AnyOverlap(x, y []byte) bool {
return len(x) > 0 && len(y) > 0 &&
uintptr(unsafe.Pointer(&x[0])) <= uintptr(unsafe.Pointer(&y[len(y)-1])) &&
uintptr(unsafe.Pointer(&y[0])) <= uintptr(unsafe.Pointer(&x[len(x)-1]))
}
// InexactOverlap reports whether x and y share memory at any non-corresponding
// index. The memory beyond the slice length is ignored. Note that x and y can
// have different lengths and still not have any inexact overlap.
//
// InexactOverlap can be used to implement the requirements of the crypto/cipher
// AEAD, Block, BlockMode and Stream interfaces.
func InexactOverlap(x, y []byte) bool {
if len(x) == 0 || len(y) == 0 || &x[0] == &y[0] {
return false
}
return AnyOverlap(x, y)
}

View File

@ -0,0 +1,35 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build appengine
// Package subtle implements functions that are often useful in cryptographic
// code but require careful thought to use correctly.
package subtle // import "golang.org/x/crypto/internal/subtle"
// This is the Google App Engine standard variant based on reflect
// because the unsafe package and cgo are disallowed.
import "reflect"
// AnyOverlap reports whether x and y share memory at any (not necessarily
// corresponding) index. The memory beyond the slice length is ignored.
func AnyOverlap(x, y []byte) bool {
return len(x) > 0 && len(y) > 0 &&
reflect.ValueOf(&x[0]).Pointer() <= reflect.ValueOf(&y[len(y)-1]).Pointer() &&
reflect.ValueOf(&y[0]).Pointer() <= reflect.ValueOf(&x[len(x)-1]).Pointer()
}
// InexactOverlap reports whether x and y share memory at any non-corresponding
// index. The memory beyond the slice length is ignored. Note that x and y can
// have different lengths and still not have any inexact overlap.
//
// InexactOverlap can be used to implement the requirements of the crypto/cipher
// AEAD, Block, BlockMode and Stream interfaces.
func InexactOverlap(x, y []byte) bool {
if len(x) == 0 || len(y) == 0 || &x[0] == &y[0] {
return false
}
return AnyOverlap(x, y)
}

View File

@ -24,6 +24,7 @@ package salsa20 // import "golang.org/x/crypto/salsa20"
// TODO(agl): implement XORKeyStream12 and XORKeyStream8 - the reduced round variants of Salsa20. // TODO(agl): implement XORKeyStream12 and XORKeyStream8 - the reduced round variants of Salsa20.
import ( import (
"golang.org/x/crypto/internal/subtle"
"golang.org/x/crypto/salsa20/salsa" "golang.org/x/crypto/salsa20/salsa"
) )
@ -32,7 +33,10 @@ import (
// be either 8 or 24 bytes long. // be either 8 or 24 bytes long.
func XORKeyStream(out, in []byte, nonce []byte, key *[32]byte) { func XORKeyStream(out, in []byte, nonce []byte, key *[32]byte) {
if len(out) < len(in) { if len(out) < len(in) {
in = in[:len(out)] panic("salsa20: output smaller than input")
}
if subtle.InexactOverlap(out[:len(in)], in) {
panic("salsa20: invalid buffer overlap")
} }
var subNonce [16]byte var subNonce [16]byte

View File

@ -5,6 +5,14 @@
// Package tea implements the TEA algorithm, as defined in Needham and // Package tea implements the TEA algorithm, as defined in Needham and
// Wheeler's 1994 technical report, “TEA, a Tiny Encryption Algorithm”. See // Wheeler's 1994 technical report, “TEA, a Tiny Encryption Algorithm”. See
// http://www.cix.co.uk/~klockstone/tea.pdf for details. // http://www.cix.co.uk/~klockstone/tea.pdf for details.
//
// TEA is a legacy cipher and its short block size makes it vulnerable to
// birthday bound attacks (see https://sweet32.info). It should only be used
// where compatibility with legacy systems, not security, is the goal.
//
// Deprecated: any new system should use AES (from crypto/aes, if necessary in
// an AEAD mode like crypto/cipher.NewGCM) or XChaCha20-Poly1305 (from
// golang.org/x/crypto/chacha20poly1305).
package tea package tea
import ( import (

View File

@ -3,6 +3,12 @@
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Package twofish implements Bruce Schneier's Twofish encryption algorithm. // Package twofish implements Bruce Schneier's Twofish encryption algorithm.
//
// Deprecated: Twofish is a legacy cipher and should not be used for new
// applications. Also, this package does not and will not provide an optimized
// implementation. Instead, use AES (from crypto/aes, if necessary in an AEAD
// mode like crypto/cipher.NewGCM) or XChaCha20-Poly1305 (from
// golang.org/x/crypto/chacha20poly1305).
package twofish // import "golang.org/x/crypto/twofish" package twofish // import "golang.org/x/crypto/twofish"
// Twofish is defined in https://www.schneier.com/paper-twofish-paper.pdf [TWOFISH] // Twofish is defined in https://www.schneier.com/paper-twofish-paper.pdf [TWOFISH]

View File

@ -4,6 +4,14 @@
// Package xtea implements XTEA encryption, as defined in Needham and Wheeler's // Package xtea implements XTEA encryption, as defined in Needham and Wheeler's
// 1997 technical report, "Tea extensions." // 1997 technical report, "Tea extensions."
//
// XTEA is a legacy cipher and its short block size makes it vulnerable to
// birthday bound attacks (see https://sweet32.info). It should only be used
// where compatibility with legacy systems, not security, is the goal.
//
// Deprecated: any new system should use AES (from crypto/aes, if necessary in
// an AEAD mode like crypto/cipher.NewGCM) or XChaCha20-Poly1305 (from
// golang.org/x/crypto/chacha20poly1305).
package xtea // import "golang.org/x/crypto/xtea" package xtea // import "golang.org/x/crypto/xtea"
// For details, see http://www.cix.co.uk/~klockstone/xtea.pdf // For details, see http://www.cix.co.uk/~klockstone/xtea.pdf

View File

@ -38,6 +38,7 @@ const (
type JumpTest uint16 type JumpTest uint16
// Supported operators for conditional jumps. // Supported operators for conditional jumps.
// K can be RegX for JumpIfX
const ( const (
// K == A // K == A
JumpEqual JumpTest = iota JumpEqual JumpTest = iota
@ -134,12 +135,9 @@ const (
opMaskLoadDest = 0x01 opMaskLoadDest = 0x01
opMaskLoadWidth = 0x18 opMaskLoadWidth = 0x18
opMaskLoadMode = 0xe0 opMaskLoadMode = 0xe0
// opClsALU // opClsALU & opClsJump
opMaskOperandSrc = 0x08 opMaskOperand = 0x08
opMaskOperator = 0xf0 opMaskOperator = 0xf0
// opClsJump
opMaskJumpConst = 0x0f
opMaskJumpCond = 0xf0
) )
const ( const (
@ -192,15 +190,21 @@ const (
opLoadWidth1 opLoadWidth1
) )
// Operator defined by ALUOp* // Operand for ALU and Jump instructions
type opOperand uint16
// Supported operand sources.
const ( const (
opALUSrcConstant uint16 = iota << 3 opOperandConstant opOperand = iota << 3
opALUSrcX opOperandX
) )
// An jumpOp is a conditional jump condition.
type jumpOp uint16
// Supported jump conditions.
const ( const (
opJumpAlways = iota << 4 opJumpAlways jumpOp = iota << 4
opJumpEqual opJumpEqual
opJumpGT opJumpGT
opJumpGE opJumpGE

View File

@ -89,10 +89,14 @@ func (ri RawInstruction) Disassemble() Instruction {
case opClsALU: case opClsALU:
switch op := ALUOp(ri.Op & opMaskOperator); op { switch op := ALUOp(ri.Op & opMaskOperator); op {
case ALUOpAdd, ALUOpSub, ALUOpMul, ALUOpDiv, ALUOpOr, ALUOpAnd, ALUOpShiftLeft, ALUOpShiftRight, ALUOpMod, ALUOpXor: case ALUOpAdd, ALUOpSub, ALUOpMul, ALUOpDiv, ALUOpOr, ALUOpAnd, ALUOpShiftLeft, ALUOpShiftRight, ALUOpMod, ALUOpXor:
if ri.Op&opMaskOperandSrc != 0 { switch operand := opOperand(ri.Op & opMaskOperand); operand {
case opOperandX:
return ALUOpX{Op: op} return ALUOpX{Op: op}
case opOperandConstant:
return ALUOpConstant{Op: op, Val: ri.K}
default:
return ri
} }
return ALUOpConstant{Op: op, Val: ri.K}
case aluOpNeg: case aluOpNeg:
return NegateA{} return NegateA{}
default: default:
@ -100,63 +104,18 @@ func (ri RawInstruction) Disassemble() Instruction {
} }
case opClsJump: case opClsJump:
if ri.Op&opMaskJumpConst != opClsJump { switch op := jumpOp(ri.Op & opMaskOperator); op {
return ri
}
switch ri.Op & opMaskJumpCond {
case opJumpAlways: case opJumpAlways:
return Jump{Skip: ri.K} return Jump{Skip: ri.K}
case opJumpEqual: case opJumpEqual, opJumpGT, opJumpGE, opJumpSet:
if ri.Jt == 0 { cond, skipTrue, skipFalse := jumpOpToTest(op, ri.Jt, ri.Jf)
return JumpIf{ switch operand := opOperand(ri.Op & opMaskOperand); operand {
Cond: JumpNotEqual, case opOperandX:
Val: ri.K, return JumpIfX{Cond: cond, SkipTrue: skipTrue, SkipFalse: skipFalse}
SkipTrue: ri.Jf, case opOperandConstant:
SkipFalse: 0, return JumpIf{Cond: cond, Val: ri.K, SkipTrue: skipTrue, SkipFalse: skipFalse}
} default:
} return ri
return JumpIf{
Cond: JumpEqual,
Val: ri.K,
SkipTrue: ri.Jt,
SkipFalse: ri.Jf,
}
case opJumpGT:
if ri.Jt == 0 {
return JumpIf{
Cond: JumpLessOrEqual,
Val: ri.K,
SkipTrue: ri.Jf,
SkipFalse: 0,
}
}
return JumpIf{
Cond: JumpGreaterThan,
Val: ri.K,
SkipTrue: ri.Jt,
SkipFalse: ri.Jf,
}
case opJumpGE:
if ri.Jt == 0 {
return JumpIf{
Cond: JumpLessThan,
Val: ri.K,
SkipTrue: ri.Jf,
SkipFalse: 0,
}
}
return JumpIf{
Cond: JumpGreaterOrEqual,
Val: ri.K,
SkipTrue: ri.Jt,
SkipFalse: ri.Jf,
}
case opJumpSet:
return JumpIf{
Cond: JumpBitsSet,
Val: ri.K,
SkipTrue: ri.Jt,
SkipFalse: ri.Jf,
} }
default: default:
return ri return ri
@ -187,6 +146,41 @@ func (ri RawInstruction) Disassemble() Instruction {
} }
} }
func jumpOpToTest(op jumpOp, skipTrue uint8, skipFalse uint8) (JumpTest, uint8, uint8) {
var test JumpTest
// Decode "fake" jump conditions that don't appear in machine code
// Ensures the Assemble -> Disassemble stage recreates the same instructions
// See https://github.com/golang/go/issues/18470
if skipTrue == 0 {
switch op {
case opJumpEqual:
test = JumpNotEqual
case opJumpGT:
test = JumpLessOrEqual
case opJumpGE:
test = JumpLessThan
case opJumpSet:
test = JumpBitsNotSet
}
return test, skipFalse, 0
}
switch op {
case opJumpEqual:
test = JumpEqual
case opJumpGT:
test = JumpGreaterThan
case opJumpGE:
test = JumpGreaterOrEqual
case opJumpSet:
test = JumpBitsSet
}
return test, skipTrue, skipFalse
}
// LoadConstant loads Val into register Dst. // LoadConstant loads Val into register Dst.
type LoadConstant struct { type LoadConstant struct {
Dst Register Dst Register
@ -413,7 +407,7 @@ type ALUOpConstant struct {
// Assemble implements the Instruction Assemble method. // Assemble implements the Instruction Assemble method.
func (a ALUOpConstant) Assemble() (RawInstruction, error) { func (a ALUOpConstant) Assemble() (RawInstruction, error) {
return RawInstruction{ return RawInstruction{
Op: opClsALU | opALUSrcConstant | uint16(a.Op), Op: opClsALU | uint16(opOperandConstant) | uint16(a.Op),
K: a.Val, K: a.Val,
}, nil }, nil
} }
@ -454,7 +448,7 @@ type ALUOpX struct {
// Assemble implements the Instruction Assemble method. // Assemble implements the Instruction Assemble method.
func (a ALUOpX) Assemble() (RawInstruction, error) { func (a ALUOpX) Assemble() (RawInstruction, error) {
return RawInstruction{ return RawInstruction{
Op: opClsALU | opALUSrcX | uint16(a.Op), Op: opClsALU | uint16(opOperandX) | uint16(a.Op),
}, nil }, nil
} }
@ -509,7 +503,7 @@ type Jump struct {
// Assemble implements the Instruction Assemble method. // Assemble implements the Instruction Assemble method.
func (a Jump) Assemble() (RawInstruction, error) { func (a Jump) Assemble() (RawInstruction, error) {
return RawInstruction{ return RawInstruction{
Op: opClsJump | opJumpAlways, Op: opClsJump | uint16(opJumpAlways),
K: a.Skip, K: a.Skip,
}, nil }, nil
} }
@ -530,11 +524,39 @@ type JumpIf struct {
// Assemble implements the Instruction Assemble method. // Assemble implements the Instruction Assemble method.
func (a JumpIf) Assemble() (RawInstruction, error) { func (a JumpIf) Assemble() (RawInstruction, error) {
return jumpToRaw(a.Cond, opOperandConstant, a.Val, a.SkipTrue, a.SkipFalse)
}
// String returns the instruction in assembler notation.
func (a JumpIf) String() string {
return jumpToString(a.Cond, fmt.Sprintf("#%d", a.Val), a.SkipTrue, a.SkipFalse)
}
// JumpIfX skips the following Skip instructions in the program if A
// <Cond> X is true.
type JumpIfX struct {
Cond JumpTest
SkipTrue uint8
SkipFalse uint8
}
// Assemble implements the Instruction Assemble method.
func (a JumpIfX) Assemble() (RawInstruction, error) {
return jumpToRaw(a.Cond, opOperandX, 0, a.SkipTrue, a.SkipFalse)
}
// String returns the instruction in assembler notation.
func (a JumpIfX) String() string {
return jumpToString(a.Cond, "x", a.SkipTrue, a.SkipFalse)
}
// jumpToRaw assembles a jump instruction into a RawInstruction
func jumpToRaw(test JumpTest, operand opOperand, k uint32, skipTrue, skipFalse uint8) (RawInstruction, error) {
var ( var (
cond uint16 cond jumpOp
flip bool flip bool
) )
switch a.Cond { switch test {
case JumpEqual: case JumpEqual:
cond = opJumpEqual cond = opJumpEqual
case JumpNotEqual: case JumpNotEqual:
@ -552,63 +574,63 @@ func (a JumpIf) Assemble() (RawInstruction, error) {
case JumpBitsNotSet: case JumpBitsNotSet:
cond, flip = opJumpSet, true cond, flip = opJumpSet, true
default: default:
return RawInstruction{}, fmt.Errorf("unknown JumpTest %v", a.Cond) return RawInstruction{}, fmt.Errorf("unknown JumpTest %v", test)
} }
jt, jf := a.SkipTrue, a.SkipFalse jt, jf := skipTrue, skipFalse
if flip { if flip {
jt, jf = jf, jt jt, jf = jf, jt
} }
return RawInstruction{ return RawInstruction{
Op: opClsJump | cond, Op: opClsJump | uint16(cond) | uint16(operand),
Jt: jt, Jt: jt,
Jf: jf, Jf: jf,
K: a.Val, K: k,
}, nil }, nil
} }
// String returns the instruction in assembler notation. // jumpToString converts a jump instruction to assembler notation
func (a JumpIf) String() string { func jumpToString(cond JumpTest, operand string, skipTrue, skipFalse uint8) string {
switch a.Cond { switch cond {
// K == A // K == A
case JumpEqual: case JumpEqual:
return conditionalJump(a, "jeq", "jneq") return conditionalJump(operand, skipTrue, skipFalse, "jeq", "jneq")
// K != A // K != A
case JumpNotEqual: case JumpNotEqual:
return fmt.Sprintf("jneq #%d,%d", a.Val, a.SkipTrue) return fmt.Sprintf("jneq %s,%d", operand, skipTrue)
// K > A // K > A
case JumpGreaterThan: case JumpGreaterThan:
return conditionalJump(a, "jgt", "jle") return conditionalJump(operand, skipTrue, skipFalse, "jgt", "jle")
// K < A // K < A
case JumpLessThan: case JumpLessThan:
return fmt.Sprintf("jlt #%d,%d", a.Val, a.SkipTrue) return fmt.Sprintf("jlt %s,%d", operand, skipTrue)
// K >= A // K >= A
case JumpGreaterOrEqual: case JumpGreaterOrEqual:
return conditionalJump(a, "jge", "jlt") return conditionalJump(operand, skipTrue, skipFalse, "jge", "jlt")
// K <= A // K <= A
case JumpLessOrEqual: case JumpLessOrEqual:
return fmt.Sprintf("jle #%d,%d", a.Val, a.SkipTrue) return fmt.Sprintf("jle %s,%d", operand, skipTrue)
// K & A != 0 // K & A != 0
case JumpBitsSet: case JumpBitsSet:
if a.SkipFalse > 0 { if skipFalse > 0 {
return fmt.Sprintf("jset #%d,%d,%d", a.Val, a.SkipTrue, a.SkipFalse) return fmt.Sprintf("jset %s,%d,%d", operand, skipTrue, skipFalse)
} }
return fmt.Sprintf("jset #%d,%d", a.Val, a.SkipTrue) return fmt.Sprintf("jset %s,%d", operand, skipTrue)
// K & A == 0, there is no assembler instruction for JumpBitNotSet, use JumpBitSet and invert skips // K & A == 0, there is no assembler instruction for JumpBitNotSet, use JumpBitSet and invert skips
case JumpBitsNotSet: case JumpBitsNotSet:
return JumpIf{Cond: JumpBitsSet, SkipTrue: a.SkipFalse, SkipFalse: a.SkipTrue, Val: a.Val}.String() return jumpToString(JumpBitsSet, operand, skipFalse, skipTrue)
default: default:
return fmt.Sprintf("unknown instruction: %#v", a) return fmt.Sprintf("unknown JumpTest %#v", cond)
} }
} }
func conditionalJump(inst JumpIf, positiveJump, negativeJump string) string { func conditionalJump(operand string, skipTrue, skipFalse uint8, positiveJump, negativeJump string) string {
if inst.SkipTrue > 0 { if skipTrue > 0 {
if inst.SkipFalse > 0 { if skipFalse > 0 {
return fmt.Sprintf("%s #%d,%d,%d", positiveJump, inst.Val, inst.SkipTrue, inst.SkipFalse) return fmt.Sprintf("%s %s,%d,%d", positiveJump, operand, skipTrue, skipFalse)
} }
return fmt.Sprintf("%s #%d,%d", positiveJump, inst.Val, inst.SkipTrue) return fmt.Sprintf("%s %s,%d", positiveJump, operand, skipTrue)
} }
return fmt.Sprintf("%s #%d,%d", negativeJump, inst.Val, inst.SkipFalse) return fmt.Sprintf("%s %s,%d", negativeJump, operand, skipFalse)
} }
// RetA exits the BPF program, returning the value of register A. // RetA exits the BPF program, returning the value of register A.

10
vendor/golang.org/x/net/bpf/vm.go generated vendored
View File

@ -35,6 +35,13 @@ func NewVM(filter []Instruction) (*VM, error) {
if check <= int(ins.SkipFalse) { if check <= int(ins.SkipFalse) {
return nil, fmt.Errorf("cannot jump %d instructions in false case; jumping past program bounds", ins.SkipFalse) return nil, fmt.Errorf("cannot jump %d instructions in false case; jumping past program bounds", ins.SkipFalse)
} }
case JumpIfX:
if check <= int(ins.SkipTrue) {
return nil, fmt.Errorf("cannot jump %d instructions in true case; jumping past program bounds", ins.SkipTrue)
}
if check <= int(ins.SkipFalse) {
return nil, fmt.Errorf("cannot jump %d instructions in false case; jumping past program bounds", ins.SkipFalse)
}
// Check for division or modulus by zero // Check for division or modulus by zero
case ALUOpConstant: case ALUOpConstant:
if ins.Val != 0 { if ins.Val != 0 {
@ -109,6 +116,9 @@ func (v *VM) Run(in []byte) (int, error) {
case JumpIf: case JumpIf:
jump := jumpIf(ins, regA) jump := jumpIf(ins, regA)
i += jump i += jump
case JumpIfX:
jump := jumpIfX(ins, regA, regX)
i += jump
case LoadAbsolute: case LoadAbsolute:
regA, ok = loadAbsolute(ins, in) regA, ok = loadAbsolute(ins, in)
case LoadConstant: case LoadConstant:

View File

@ -55,34 +55,41 @@ func aluOpCommon(op ALUOp, regA uint32, value uint32) uint32 {
} }
} }
func jumpIf(ins JumpIf, value uint32) int { func jumpIf(ins JumpIf, regA uint32) int {
var ok bool return jumpIfCommon(ins.Cond, ins.SkipTrue, ins.SkipFalse, regA, ins.Val)
inV := uint32(ins.Val) }
switch ins.Cond { func jumpIfX(ins JumpIfX, regA uint32, regX uint32) int {
return jumpIfCommon(ins.Cond, ins.SkipTrue, ins.SkipFalse, regA, regX)
}
func jumpIfCommon(cond JumpTest, skipTrue, skipFalse uint8, regA uint32, value uint32) int {
var ok bool
switch cond {
case JumpEqual: case JumpEqual:
ok = value == inV ok = regA == value
case JumpNotEqual: case JumpNotEqual:
ok = value != inV ok = regA != value
case JumpGreaterThan: case JumpGreaterThan:
ok = value > inV ok = regA > value
case JumpLessThan: case JumpLessThan:
ok = value < inV ok = regA < value
case JumpGreaterOrEqual: case JumpGreaterOrEqual:
ok = value >= inV ok = regA >= value
case JumpLessOrEqual: case JumpLessOrEqual:
ok = value <= inV ok = regA <= value
case JumpBitsSet: case JumpBitsSet:
ok = (value & inV) != 0 ok = (regA & value) != 0
case JumpBitsNotSet: case JumpBitsNotSet:
ok = (value & inV) == 0 ok = (regA & value) == 0
} }
if ok { if ok {
return int(ins.SkipTrue) return int(skipTrue)
} }
return int(ins.SkipFalse) return int(skipFalse)
} }
func loadAbsolute(ins LoadAbsolute, in []byte) (uint32, bool) { func loadAbsolute(ins LoadAbsolute, in []byte) (uint32, bool) {
@ -122,7 +129,8 @@ func loadIndirect(ins LoadIndirect, in []byte, regX uint32) (uint32, bool) {
func loadMemShift(ins LoadMemShift, in []byte) (uint32, bool) { func loadMemShift(ins LoadMemShift, in []byte) (uint32, bool) {
offset := int(ins.Off) offset := int(ins.Off)
if !inBounds(len(in), offset, 0) { // Size of LoadMemShift is always 1 byte
if !inBounds(len(in), offset, 1) {
return 0, false return 0, false
} }

50
vendor/golang.org/x/net/http/httpguts/guts.go generated vendored Normal file
View File

@ -0,0 +1,50 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package httpguts provides functions implementing various details
// of the HTTP specification.
//
// This package is shared by the standard library (which vendors it)
// and x/net/http2. It comes with no API stability promise.
package httpguts
import (
"net/textproto"
"strings"
)
// ValidTrailerHeader reports whether name is a valid header field name to appear
// in trailers.
// See RFC 7230, Section 4.1.2
func ValidTrailerHeader(name string) bool {
name = textproto.CanonicalMIMEHeaderKey(name)
if strings.HasPrefix(name, "If-") || badTrailer[name] {
return false
}
return true
}
var badTrailer = map[string]bool{
"Authorization": true,
"Cache-Control": true,
"Connection": true,
"Content-Encoding": true,
"Content-Length": true,
"Content-Range": true,
"Content-Type": true,
"Expect": true,
"Host": true,
"Keep-Alive": true,
"Max-Forwards": true,
"Pragma": true,
"Proxy-Authenticate": true,
"Proxy-Authorization": true,
"Proxy-Connection": true,
"Range": true,
"Realm": true,
"Te": true,
"Trailer": true,
"Transfer-Encoding": true,
"Www-Authenticate": true,
}

346
vendor/golang.org/x/net/http/httpguts/httplex.go generated vendored Normal file
View File

@ -0,0 +1,346 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package httpguts
import (
"net"
"strings"
"unicode/utf8"
"golang.org/x/net/idna"
)
var isTokenTable = [127]bool{
'!': true,
'#': true,
'$': true,
'%': true,
'&': true,
'\'': true,
'*': true,
'+': true,
'-': true,
'.': true,
'0': true,
'1': true,
'2': true,
'3': true,
'4': true,
'5': true,
'6': true,
'7': true,
'8': true,
'9': true,
'A': true,
'B': true,
'C': true,
'D': true,
'E': true,
'F': true,
'G': true,
'H': true,
'I': true,
'J': true,
'K': true,
'L': true,
'M': true,
'N': true,
'O': true,
'P': true,
'Q': true,
'R': true,
'S': true,
'T': true,
'U': true,
'W': true,
'V': true,
'X': true,
'Y': true,
'Z': true,
'^': true,
'_': true,
'`': true,
'a': true,
'b': true,
'c': true,
'd': true,
'e': true,
'f': true,
'g': true,
'h': true,
'i': true,
'j': true,
'k': true,
'l': true,
'm': true,
'n': true,
'o': true,
'p': true,
'q': true,
'r': true,
's': true,
't': true,
'u': true,
'v': true,
'w': true,
'x': true,
'y': true,
'z': true,
'|': true,
'~': true,
}
func IsTokenRune(r rune) bool {
i := int(r)
return i < len(isTokenTable) && isTokenTable[i]
}
func isNotToken(r rune) bool {
return !IsTokenRune(r)
}
// HeaderValuesContainsToken reports whether any string in values
// contains the provided token, ASCII case-insensitively.
func HeaderValuesContainsToken(values []string, token string) bool {
for _, v := range values {
if headerValueContainsToken(v, token) {
return true
}
}
return false
}
// isOWS reports whether b is an optional whitespace byte, as defined
// by RFC 7230 section 3.2.3.
func isOWS(b byte) bool { return b == ' ' || b == '\t' }
// trimOWS returns x with all optional whitespace removes from the
// beginning and end.
func trimOWS(x string) string {
// TODO: consider using strings.Trim(x, " \t") instead,
// if and when it's fast enough. See issue 10292.
// But this ASCII-only code will probably always beat UTF-8
// aware code.
for len(x) > 0 && isOWS(x[0]) {
x = x[1:]
}
for len(x) > 0 && isOWS(x[len(x)-1]) {
x = x[:len(x)-1]
}
return x
}
// headerValueContainsToken reports whether v (assumed to be a
// 0#element, in the ABNF extension described in RFC 7230 section 7)
// contains token amongst its comma-separated tokens, ASCII
// case-insensitively.
func headerValueContainsToken(v string, token string) bool {
v = trimOWS(v)
if comma := strings.IndexByte(v, ','); comma != -1 {
return tokenEqual(trimOWS(v[:comma]), token) || headerValueContainsToken(v[comma+1:], token)
}
return tokenEqual(v, token)
}
// lowerASCII returns the ASCII lowercase version of b.
func lowerASCII(b byte) byte {
if 'A' <= b && b <= 'Z' {
return b + ('a' - 'A')
}
return b
}
// tokenEqual reports whether t1 and t2 are equal, ASCII case-insensitively.
func tokenEqual(t1, t2 string) bool {
if len(t1) != len(t2) {
return false
}
for i, b := range t1 {
if b >= utf8.RuneSelf {
// No UTF-8 or non-ASCII allowed in tokens.
return false
}
if lowerASCII(byte(b)) != lowerASCII(t2[i]) {
return false
}
}
return true
}
// isLWS reports whether b is linear white space, according
// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2
// LWS = [CRLF] 1*( SP | HT )
func isLWS(b byte) bool { return b == ' ' || b == '\t' }
// isCTL reports whether b is a control byte, according
// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2
// CTL = <any US-ASCII control character
// (octets 0 - 31) and DEL (127)>
func isCTL(b byte) bool {
const del = 0x7f // a CTL
return b < ' ' || b == del
}
// ValidHeaderFieldName reports whether v is a valid HTTP/1.x header name.
// HTTP/2 imposes the additional restriction that uppercase ASCII
// letters are not allowed.
//
// RFC 7230 says:
// header-field = field-name ":" OWS field-value OWS
// field-name = token
// token = 1*tchar
// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
// "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA
func ValidHeaderFieldName(v string) bool {
if len(v) == 0 {
return false
}
for _, r := range v {
if !IsTokenRune(r) {
return false
}
}
return true
}
// ValidHostHeader reports whether h is a valid host header.
func ValidHostHeader(h string) bool {
// The latest spec is actually this:
//
// http://tools.ietf.org/html/rfc7230#section-5.4
// Host = uri-host [ ":" port ]
//
// Where uri-host is:
// http://tools.ietf.org/html/rfc3986#section-3.2.2
//
// But we're going to be much more lenient for now and just
// search for any byte that's not a valid byte in any of those
// expressions.
for i := 0; i < len(h); i++ {
if !validHostByte[h[i]] {
return false
}
}
return true
}
// See the validHostHeader comment.
var validHostByte = [256]bool{
'0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true,
'8': true, '9': true,
'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true, 'h': true,
'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true, 'p': true,
'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true, 'x': true,
'y': true, 'z': true,
'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true,
'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true, 'P': true,
'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, 'W': true, 'X': true,
'Y': true, 'Z': true,
'!': true, // sub-delims
'$': true, // sub-delims
'%': true, // pct-encoded (and used in IPv6 zones)
'&': true, // sub-delims
'(': true, // sub-delims
')': true, // sub-delims
'*': true, // sub-delims
'+': true, // sub-delims
',': true, // sub-delims
'-': true, // unreserved
'.': true, // unreserved
':': true, // IPv6address + Host expression's optional port
';': true, // sub-delims
'=': true, // sub-delims
'[': true,
'\'': true, // sub-delims
']': true,
'_': true, // unreserved
'~': true, // unreserved
}
// ValidHeaderFieldValue reports whether v is a valid "field-value" according to
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 :
//
// message-header = field-name ":" [ field-value ]
// field-value = *( field-content | LWS )
// field-content = <the OCTETs making up the field-value
// and consisting of either *TEXT or combinations
// of token, separators, and quoted-string>
//
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 :
//
// TEXT = <any OCTET except CTLs,
// but including LWS>
// LWS = [CRLF] 1*( SP | HT )
// CTL = <any US-ASCII control character
// (octets 0 - 31) and DEL (127)>
//
// RFC 7230 says:
// field-value = *( field-content / obs-fold )
// obj-fold = N/A to http2, and deprecated
// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ]
// field-vchar = VCHAR / obs-text
// obs-text = %x80-FF
// VCHAR = "any visible [USASCII] character"
//
// http2 further says: "Similarly, HTTP/2 allows header field values
// that are not valid. While most of the values that can be encoded
// will not alter header field parsing, carriage return (CR, ASCII
// 0xd), line feed (LF, ASCII 0xa), and the zero character (NUL, ASCII
// 0x0) might be exploited by an attacker if they are translated
// verbatim. Any request or response that contains a character not
// permitted in a header field value MUST be treated as malformed
// (Section 8.1.2.6). Valid characters are defined by the
// field-content ABNF rule in Section 3.2 of [RFC7230]."
//
// This function does not (yet?) properly handle the rejection of
// strings that begin or end with SP or HTAB.
func ValidHeaderFieldValue(v string) bool {
for i := 0; i < len(v); i++ {
b := v[i]
if isCTL(b) && !isLWS(b) {
return false
}
}
return true
}
func isASCII(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] >= utf8.RuneSelf {
return false
}
}
return true
}
// PunycodeHostPort returns the IDNA Punycode version
// of the provided "host" or "host:port" string.
func PunycodeHostPort(v string) (string, error) {
if isASCII(v) {
return v, nil
}
host, port, err := net.SplitHostPort(v)
if err != nil {
// The input 'v' argument was just a "host" argument,
// without a port. This error should not be returned
// to the caller.
host = v
port = ""
}
host, err = idna.ToASCII(host)
if err != nil {
// Non-UTF-8? Not representable in Punycode, in any
// case.
return "", err
}
if port == "" {
return host, nil
}
return net.JoinHostPort(host, port), nil
}

734
vendor/golang.org/x/net/idna/idna10.0.0.go generated vendored Normal file
View File

@ -0,0 +1,734 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.10
// Package idna implements IDNA2008 using the compatibility processing
// defined by UTS (Unicode Technical Standard) #46, which defines a standard to
// deal with the transition from IDNA2003.
//
// IDNA2008 (Internationalized Domain Names for Applications), is defined in RFC
// 5890, RFC 5891, RFC 5892, RFC 5893 and RFC 5894.
// UTS #46 is defined in https://www.unicode.org/reports/tr46.
// See https://unicode.org/cldr/utility/idna.jsp for a visualization of the
// differences between these two standards.
package idna // import "golang.org/x/net/idna"
import (
"fmt"
"strings"
"unicode/utf8"
"golang.org/x/text/secure/bidirule"
"golang.org/x/text/unicode/bidi"
"golang.org/x/text/unicode/norm"
)
// NOTE: Unlike common practice in Go APIs, the functions will return a
// sanitized domain name in case of errors. Browsers sometimes use a partially
// evaluated string as lookup.
// TODO: the current error handling is, in my opinion, the least opinionated.
// Other strategies are also viable, though:
// Option 1) Return an empty string in case of error, but allow the user to
// specify explicitly which errors to ignore.
// Option 2) Return the partially evaluated string if it is itself a valid
// string, otherwise return the empty string in case of error.
// Option 3) Option 1 and 2.
// Option 4) Always return an empty string for now and implement Option 1 as
// needed, and document that the return string may not be empty in case of
// error in the future.
// I think Option 1 is best, but it is quite opinionated.
// ToASCII is a wrapper for Punycode.ToASCII.
func ToASCII(s string) (string, error) {
return Punycode.process(s, true)
}
// ToUnicode is a wrapper for Punycode.ToUnicode.
func ToUnicode(s string) (string, error) {
return Punycode.process(s, false)
}
// An Option configures a Profile at creation time.
type Option func(*options)
// Transitional sets a Profile to use the Transitional mapping as defined in UTS
// #46. This will cause, for example, "ß" to be mapped to "ss". Using the
// transitional mapping provides a compromise between IDNA2003 and IDNA2008
// compatibility. It is used by most browsers when resolving domain names. This
// option is only meaningful if combined with MapForLookup.
func Transitional(transitional bool) Option {
return func(o *options) { o.transitional = true }
}
// VerifyDNSLength sets whether a Profile should fail if any of the IDN parts
// are longer than allowed by the RFC.
func VerifyDNSLength(verify bool) Option {
return func(o *options) { o.verifyDNSLength = verify }
}
// RemoveLeadingDots removes leading label separators. Leading runes that map to
// dots, such as U+3002 IDEOGRAPHIC FULL STOP, are removed as well.
//
// This is the behavior suggested by the UTS #46 and is adopted by some
// browsers.
func RemoveLeadingDots(remove bool) Option {
return func(o *options) { o.removeLeadingDots = remove }
}
// ValidateLabels sets whether to check the mandatory label validation criteria
// as defined in Section 5.4 of RFC 5891. This includes testing for correct use
// of hyphens ('-'), normalization, validity of runes, and the context rules.
func ValidateLabels(enable bool) Option {
return func(o *options) {
// Don't override existing mappings, but set one that at least checks
// normalization if it is not set.
if o.mapping == nil && enable {
o.mapping = normalize
}
o.trie = trie
o.validateLabels = enable
o.fromPuny = validateFromPunycode
}
}
// StrictDomainName limits the set of permissible ASCII characters to those
// allowed in domain names as defined in RFC 1034 (A-Z, a-z, 0-9 and the
// hyphen). This is set by default for MapForLookup and ValidateForRegistration.
//
// This option is useful, for instance, for browsers that allow characters
// outside this range, for example a '_' (U+005F LOW LINE). See
// http://www.rfc-editor.org/std/std3.txt for more details This option
// corresponds to the UseSTD3ASCIIRules option in UTS #46.
func StrictDomainName(use bool) Option {
return func(o *options) {
o.trie = trie
o.useSTD3Rules = use
o.fromPuny = validateFromPunycode
}
}
// NOTE: the following options pull in tables. The tables should not be linked
// in as long as the options are not used.
// BidiRule enables the Bidi rule as defined in RFC 5893. Any application
// that relies on proper validation of labels should include this rule.
func BidiRule() Option {
return func(o *options) { o.bidirule = bidirule.ValidString }
}
// ValidateForRegistration sets validation options to verify that a given IDN is
// properly formatted for registration as defined by Section 4 of RFC 5891.
func ValidateForRegistration() Option {
return func(o *options) {
o.mapping = validateRegistration
StrictDomainName(true)(o)
ValidateLabels(true)(o)
VerifyDNSLength(true)(o)
BidiRule()(o)
}
}
// MapForLookup sets validation and mapping options such that a given IDN is
// transformed for domain name lookup according to the requirements set out in
// Section 5 of RFC 5891. The mappings follow the recommendations of RFC 5894,
// RFC 5895 and UTS 46. It does not add the Bidi Rule. Use the BidiRule option
// to add this check.
//
// The mappings include normalization and mapping case, width and other
// compatibility mappings.
func MapForLookup() Option {
return func(o *options) {
o.mapping = validateAndMap
StrictDomainName(true)(o)
ValidateLabels(true)(o)
}
}
type options struct {
transitional bool
useSTD3Rules bool
validateLabels bool
verifyDNSLength bool
removeLeadingDots bool
trie *idnaTrie
// fromPuny calls validation rules when converting A-labels to U-labels.
fromPuny func(p *Profile, s string) error
// mapping implements a validation and mapping step as defined in RFC 5895
// or UTS 46, tailored to, for example, domain registration or lookup.
mapping func(p *Profile, s string) (mapped string, isBidi bool, err error)
// bidirule, if specified, checks whether s conforms to the Bidi Rule
// defined in RFC 5893.
bidirule func(s string) bool
}
// A Profile defines the configuration of an IDNA mapper.
type Profile struct {
options
}
func apply(o *options, opts []Option) {
for _, f := range opts {
f(o)
}
}
// New creates a new Profile.
//
// With no options, the returned Profile is the most permissive and equals the
// Punycode Profile. Options can be passed to further restrict the Profile. The
// MapForLookup and ValidateForRegistration options set a collection of options,
// for lookup and registration purposes respectively, which can be tailored by
// adding more fine-grained options, where later options override earlier
// options.
func New(o ...Option) *Profile {
p := &Profile{}
apply(&p.options, o)
return p
}
// ToASCII converts a domain or domain label to its ASCII form. For example,
// ToASCII("bücher.example.com") is "xn--bcher-kva.example.com", and
// ToASCII("golang") is "golang". If an error is encountered it will return
// an error and a (partially) processed result.
func (p *Profile) ToASCII(s string) (string, error) {
return p.process(s, true)
}
// ToUnicode converts a domain or domain label to its Unicode form. For example,
// ToUnicode("xn--bcher-kva.example.com") is "bücher.example.com", and
// ToUnicode("golang") is "golang". If an error is encountered it will return
// an error and a (partially) processed result.
func (p *Profile) ToUnicode(s string) (string, error) {
pp := *p
pp.transitional = false
return pp.process(s, false)
}
// String reports a string with a description of the profile for debugging
// purposes. The string format may change with different versions.
func (p *Profile) String() string {
s := ""
if p.transitional {
s = "Transitional"
} else {
s = "NonTransitional"
}
if p.useSTD3Rules {
s += ":UseSTD3Rules"
}
if p.validateLabels {
s += ":ValidateLabels"
}
if p.verifyDNSLength {
s += ":VerifyDNSLength"
}
return s
}
var (
// Punycode is a Profile that does raw punycode processing with a minimum
// of validation.
Punycode *Profile = punycode
// Lookup is the recommended profile for looking up domain names, according
// to Section 5 of RFC 5891. The exact configuration of this profile may
// change over time.
Lookup *Profile = lookup
// Display is the recommended profile for displaying domain names.
// The configuration of this profile may change over time.
Display *Profile = display
// Registration is the recommended profile for checking whether a given
// IDN is valid for registration, according to Section 4 of RFC 5891.
Registration *Profile = registration
punycode = &Profile{}
lookup = &Profile{options{
transitional: true,
useSTD3Rules: true,
validateLabels: true,
trie: trie,
fromPuny: validateFromPunycode,
mapping: validateAndMap,
bidirule: bidirule.ValidString,
}}
display = &Profile{options{
useSTD3Rules: true,
validateLabels: true,
trie: trie,
fromPuny: validateFromPunycode,
mapping: validateAndMap,
bidirule: bidirule.ValidString,
}}
registration = &Profile{options{
useSTD3Rules: true,
validateLabels: true,
verifyDNSLength: true,
trie: trie,
fromPuny: validateFromPunycode,
mapping: validateRegistration,
bidirule: bidirule.ValidString,
}}
// TODO: profiles
// Register: recommended for approving domain names: don't do any mappings
// but rather reject on invalid input. Bundle or block deviation characters.
)
type labelError struct{ label, code_ string }
func (e labelError) code() string { return e.code_ }
func (e labelError) Error() string {
return fmt.Sprintf("idna: invalid label %q", e.label)
}
type runeError rune
func (e runeError) code() string { return "P1" }
func (e runeError) Error() string {
return fmt.Sprintf("idna: disallowed rune %U", e)
}
// process implements the algorithm described in section 4 of UTS #46,
// see https://www.unicode.org/reports/tr46.
func (p *Profile) process(s string, toASCII bool) (string, error) {
var err error
var isBidi bool
if p.mapping != nil {
s, isBidi, err = p.mapping(p, s)
}
// Remove leading empty labels.
if p.removeLeadingDots {
for ; len(s) > 0 && s[0] == '.'; s = s[1:] {
}
}
// TODO: allow for a quick check of the tables data.
// It seems like we should only create this error on ToASCII, but the
// UTS 46 conformance tests suggests we should always check this.
if err == nil && p.verifyDNSLength && s == "" {
err = &labelError{s, "A4"}
}
labels := labelIter{orig: s}
for ; !labels.done(); labels.next() {
label := labels.label()
if label == "" {
// Empty labels are not okay. The label iterator skips the last
// label if it is empty.
if err == nil && p.verifyDNSLength {
err = &labelError{s, "A4"}
}
continue
}
if strings.HasPrefix(label, acePrefix) {
u, err2 := decode(label[len(acePrefix):])
if err2 != nil {
if err == nil {
err = err2
}
// Spec says keep the old label.
continue
}
isBidi = isBidi || bidirule.DirectionString(u) != bidi.LeftToRight
labels.set(u)
if err == nil && p.validateLabels {
err = p.fromPuny(p, u)
}
if err == nil {
// This should be called on NonTransitional, according to the
// spec, but that currently does not have any effect. Use the
// original profile to preserve options.
err = p.validateLabel(u)
}
} else if err == nil {
err = p.validateLabel(label)
}
}
if isBidi && p.bidirule != nil && err == nil {
for labels.reset(); !labels.done(); labels.next() {
if !p.bidirule(labels.label()) {
err = &labelError{s, "B"}
break
}
}
}
if toASCII {
for labels.reset(); !labels.done(); labels.next() {
label := labels.label()
if !ascii(label) {
a, err2 := encode(acePrefix, label)
if err == nil {
err = err2
}
label = a
labels.set(a)
}
n := len(label)
if p.verifyDNSLength && err == nil && (n == 0 || n > 63) {
err = &labelError{label, "A4"}
}
}
}
s = labels.result()
if toASCII && p.verifyDNSLength && err == nil {
// Compute the length of the domain name minus the root label and its dot.
n := len(s)
if n > 0 && s[n-1] == '.' {
n--
}
if len(s) < 1 || n > 253 {
err = &labelError{s, "A4"}
}
}
return s, err
}
func normalize(p *Profile, s string) (mapped string, isBidi bool, err error) {
// TODO: consider first doing a quick check to see if any of these checks
// need to be done. This will make it slower in the general case, but
// faster in the common case.
mapped = norm.NFC.String(s)
isBidi = bidirule.DirectionString(mapped) == bidi.RightToLeft
return mapped, isBidi, nil
}
func validateRegistration(p *Profile, s string) (idem string, bidi bool, err error) {
// TODO: filter need for normalization in loop below.
if !norm.NFC.IsNormalString(s) {
return s, false, &labelError{s, "V1"}
}
for i := 0; i < len(s); {
v, sz := trie.lookupString(s[i:])
if sz == 0 {
return s, bidi, runeError(utf8.RuneError)
}
bidi = bidi || info(v).isBidi(s[i:])
// Copy bytes not copied so far.
switch p.simplify(info(v).category()) {
// TODO: handle the NV8 defined in the Unicode idna data set to allow
// for strict conformance to IDNA2008.
case valid, deviation:
case disallowed, mapped, unknown, ignored:
r, _ := utf8.DecodeRuneInString(s[i:])
return s, bidi, runeError(r)
}
i += sz
}
return s, bidi, nil
}
func (c info) isBidi(s string) bool {
if !c.isMapped() {
return c&attributesMask == rtl
}
// TODO: also store bidi info for mapped data. This is possible, but a bit
// cumbersome and not for the common case.
p, _ := bidi.LookupString(s)
switch p.Class() {
case bidi.R, bidi.AL, bidi.AN:
return true
}
return false
}
func validateAndMap(p *Profile, s string) (vm string, bidi bool, err error) {
var (
b []byte
k int
)
// combinedInfoBits contains the or-ed bits of all runes. We use this
// to derive the mayNeedNorm bit later. This may trigger normalization
// overeagerly, but it will not do so in the common case. The end result
// is another 10% saving on BenchmarkProfile for the common case.
var combinedInfoBits info
for i := 0; i < len(s); {
v, sz := trie.lookupString(s[i:])
if sz == 0 {
b = append(b, s[k:i]...)
b = append(b, "\ufffd"...)
k = len(s)
if err == nil {
err = runeError(utf8.RuneError)
}
break
}
combinedInfoBits |= info(v)
bidi = bidi || info(v).isBidi(s[i:])
start := i
i += sz
// Copy bytes not copied so far.
switch p.simplify(info(v).category()) {
case valid:
continue
case disallowed:
if err == nil {
r, _ := utf8.DecodeRuneInString(s[start:])
err = runeError(r)
}
continue
case mapped, deviation:
b = append(b, s[k:start]...)
b = info(v).appendMapping(b, s[start:i])
case ignored:
b = append(b, s[k:start]...)
// drop the rune
case unknown:
b = append(b, s[k:start]...)
b = append(b, "\ufffd"...)
}
k = i
}
if k == 0 {
// No changes so far.
if combinedInfoBits&mayNeedNorm != 0 {
s = norm.NFC.String(s)
}
} else {
b = append(b, s[k:]...)
if norm.NFC.QuickSpan(b) != len(b) {
b = norm.NFC.Bytes(b)
}
// TODO: the punycode converters require strings as input.
s = string(b)
}
return s, bidi, err
}
// A labelIter allows iterating over domain name labels.
type labelIter struct {
orig string
slice []string
curStart int
curEnd int
i int
}
func (l *labelIter) reset() {
l.curStart = 0
l.curEnd = 0
l.i = 0
}
func (l *labelIter) done() bool {
return l.curStart >= len(l.orig)
}
func (l *labelIter) result() string {
if l.slice != nil {
return strings.Join(l.slice, ".")
}
return l.orig
}
func (l *labelIter) label() string {
if l.slice != nil {
return l.slice[l.i]
}
p := strings.IndexByte(l.orig[l.curStart:], '.')
l.curEnd = l.curStart + p
if p == -1 {
l.curEnd = len(l.orig)
}
return l.orig[l.curStart:l.curEnd]
}
// next sets the value to the next label. It skips the last label if it is empty.
func (l *labelIter) next() {
l.i++
if l.slice != nil {
if l.i >= len(l.slice) || l.i == len(l.slice)-1 && l.slice[l.i] == "" {
l.curStart = len(l.orig)
}
} else {
l.curStart = l.curEnd + 1
if l.curStart == len(l.orig)-1 && l.orig[l.curStart] == '.' {
l.curStart = len(l.orig)
}
}
}
func (l *labelIter) set(s string) {
if l.slice == nil {
l.slice = strings.Split(l.orig, ".")
}
l.slice[l.i] = s
}
// acePrefix is the ASCII Compatible Encoding prefix.
const acePrefix = "xn--"
func (p *Profile) simplify(cat category) category {
switch cat {
case disallowedSTD3Mapped:
if p.useSTD3Rules {
cat = disallowed
} else {
cat = mapped
}
case disallowedSTD3Valid:
if p.useSTD3Rules {
cat = disallowed
} else {
cat = valid
}
case deviation:
if !p.transitional {
cat = valid
}
case validNV8, validXV8:
// TODO: handle V2008
cat = valid
}
return cat
}
func validateFromPunycode(p *Profile, s string) error {
if !norm.NFC.IsNormalString(s) {
return &labelError{s, "V1"}
}
// TODO: detect whether string may have to be normalized in the following
// loop.
for i := 0; i < len(s); {
v, sz := trie.lookupString(s[i:])
if sz == 0 {
return runeError(utf8.RuneError)
}
if c := p.simplify(info(v).category()); c != valid && c != deviation {
return &labelError{s, "V6"}
}
i += sz
}
return nil
}
const (
zwnj = "\u200c"
zwj = "\u200d"
)
type joinState int8
const (
stateStart joinState = iota
stateVirama
stateBefore
stateBeforeVirama
stateAfter
stateFAIL
)
var joinStates = [][numJoinTypes]joinState{
stateStart: {
joiningL: stateBefore,
joiningD: stateBefore,
joinZWNJ: stateFAIL,
joinZWJ: stateFAIL,
joinVirama: stateVirama,
},
stateVirama: {
joiningL: stateBefore,
joiningD: stateBefore,
},
stateBefore: {
joiningL: stateBefore,
joiningD: stateBefore,
joiningT: stateBefore,
joinZWNJ: stateAfter,
joinZWJ: stateFAIL,
joinVirama: stateBeforeVirama,
},
stateBeforeVirama: {
joiningL: stateBefore,
joiningD: stateBefore,
joiningT: stateBefore,
},
stateAfter: {
joiningL: stateFAIL,
joiningD: stateBefore,
joiningT: stateAfter,
joiningR: stateStart,
joinZWNJ: stateFAIL,
joinZWJ: stateFAIL,
joinVirama: stateAfter, // no-op as we can't accept joiners here
},
stateFAIL: {
0: stateFAIL,
joiningL: stateFAIL,
joiningD: stateFAIL,
joiningT: stateFAIL,
joiningR: stateFAIL,
joinZWNJ: stateFAIL,
joinZWJ: stateFAIL,
joinVirama: stateFAIL,
},
}
// validateLabel validates the criteria from Section 4.1. Item 1, 4, and 6 are
// already implicitly satisfied by the overall implementation.
func (p *Profile) validateLabel(s string) (err error) {
if s == "" {
if p.verifyDNSLength {
return &labelError{s, "A4"}
}
return nil
}
if !p.validateLabels {
return nil
}
trie := p.trie // p.validateLabels is only set if trie is set.
if len(s) > 4 && s[2] == '-' && s[3] == '-' {
return &labelError{s, "V2"}
}
if s[0] == '-' || s[len(s)-1] == '-' {
return &labelError{s, "V3"}
}
// TODO: merge the use of this in the trie.
v, sz := trie.lookupString(s)
x := info(v)
if x.isModifier() {
return &labelError{s, "V5"}
}
// Quickly return in the absence of zero-width (non) joiners.
if strings.Index(s, zwj) == -1 && strings.Index(s, zwnj) == -1 {
return nil
}
st := stateStart
for i := 0; ; {
jt := x.joinType()
if s[i:i+sz] == zwj {
jt = joinZWJ
} else if s[i:i+sz] == zwnj {
jt = joinZWNJ
}
st = joinStates[st][jt]
if x.isViramaModifier() {
st = joinStates[st][joinVirama]
}
if i += sz; i == len(s) {
break
}
v, sz = trie.lookupString(s[i:])
x = info(v)
}
if st == stateFAIL || st == stateAfter {
return &labelError{s, "C"}
}
return nil
}
func ascii(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] >= utf8.RuneSelf {
return false
}
}
return true
}

682
vendor/golang.org/x/net/idna/idna9.0.0.go generated vendored Normal file
View File

@ -0,0 +1,682 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.10
// Package idna implements IDNA2008 using the compatibility processing
// defined by UTS (Unicode Technical Standard) #46, which defines a standard to
// deal with the transition from IDNA2003.
//
// IDNA2008 (Internationalized Domain Names for Applications), is defined in RFC
// 5890, RFC 5891, RFC 5892, RFC 5893 and RFC 5894.
// UTS #46 is defined in https://www.unicode.org/reports/tr46.
// See https://unicode.org/cldr/utility/idna.jsp for a visualization of the
// differences between these two standards.
package idna // import "golang.org/x/net/idna"
import (
"fmt"
"strings"
"unicode/utf8"
"golang.org/x/text/secure/bidirule"
"golang.org/x/text/unicode/norm"
)
// NOTE: Unlike common practice in Go APIs, the functions will return a
// sanitized domain name in case of errors. Browsers sometimes use a partially
// evaluated string as lookup.
// TODO: the current error handling is, in my opinion, the least opinionated.
// Other strategies are also viable, though:
// Option 1) Return an empty string in case of error, but allow the user to
// specify explicitly which errors to ignore.
// Option 2) Return the partially evaluated string if it is itself a valid
// string, otherwise return the empty string in case of error.
// Option 3) Option 1 and 2.
// Option 4) Always return an empty string for now and implement Option 1 as
// needed, and document that the return string may not be empty in case of
// error in the future.
// I think Option 1 is best, but it is quite opinionated.
// ToASCII is a wrapper for Punycode.ToASCII.
func ToASCII(s string) (string, error) {
return Punycode.process(s, true)
}
// ToUnicode is a wrapper for Punycode.ToUnicode.
func ToUnicode(s string) (string, error) {
return Punycode.process(s, false)
}
// An Option configures a Profile at creation time.
type Option func(*options)
// Transitional sets a Profile to use the Transitional mapping as defined in UTS
// #46. This will cause, for example, "ß" to be mapped to "ss". Using the
// transitional mapping provides a compromise between IDNA2003 and IDNA2008
// compatibility. It is used by most browsers when resolving domain names. This
// option is only meaningful if combined with MapForLookup.
func Transitional(transitional bool) Option {
return func(o *options) { o.transitional = true }
}
// VerifyDNSLength sets whether a Profile should fail if any of the IDN parts
// are longer than allowed by the RFC.
func VerifyDNSLength(verify bool) Option {
return func(o *options) { o.verifyDNSLength = verify }
}
// RemoveLeadingDots removes leading label separators. Leading runes that map to
// dots, such as U+3002 IDEOGRAPHIC FULL STOP, are removed as well.
//
// This is the behavior suggested by the UTS #46 and is adopted by some
// browsers.
func RemoveLeadingDots(remove bool) Option {
return func(o *options) { o.removeLeadingDots = remove }
}
// ValidateLabels sets whether to check the mandatory label validation criteria
// as defined in Section 5.4 of RFC 5891. This includes testing for correct use
// of hyphens ('-'), normalization, validity of runes, and the context rules.
func ValidateLabels(enable bool) Option {
return func(o *options) {
// Don't override existing mappings, but set one that at least checks
// normalization if it is not set.
if o.mapping == nil && enable {
o.mapping = normalize
}
o.trie = trie
o.validateLabels = enable
o.fromPuny = validateFromPunycode
}
}
// StrictDomainName limits the set of permissable ASCII characters to those
// allowed in domain names as defined in RFC 1034 (A-Z, a-z, 0-9 and the
// hyphen). This is set by default for MapForLookup and ValidateForRegistration.
//
// This option is useful, for instance, for browsers that allow characters
// outside this range, for example a '_' (U+005F LOW LINE). See
// http://www.rfc-editor.org/std/std3.txt for more details This option
// corresponds to the UseSTD3ASCIIRules option in UTS #46.
func StrictDomainName(use bool) Option {
return func(o *options) {
o.trie = trie
o.useSTD3Rules = use
o.fromPuny = validateFromPunycode
}
}
// NOTE: the following options pull in tables. The tables should not be linked
// in as long as the options are not used.
// BidiRule enables the Bidi rule as defined in RFC 5893. Any application
// that relies on proper validation of labels should include this rule.
func BidiRule() Option {
return func(o *options) { o.bidirule = bidirule.ValidString }
}
// ValidateForRegistration sets validation options to verify that a given IDN is
// properly formatted for registration as defined by Section 4 of RFC 5891.
func ValidateForRegistration() Option {
return func(o *options) {
o.mapping = validateRegistration
StrictDomainName(true)(o)
ValidateLabels(true)(o)
VerifyDNSLength(true)(o)
BidiRule()(o)
}
}
// MapForLookup sets validation and mapping options such that a given IDN is
// transformed for domain name lookup according to the requirements set out in
// Section 5 of RFC 5891. The mappings follow the recommendations of RFC 5894,
// RFC 5895 and UTS 46. It does not add the Bidi Rule. Use the BidiRule option
// to add this check.
//
// The mappings include normalization and mapping case, width and other
// compatibility mappings.
func MapForLookup() Option {
return func(o *options) {
o.mapping = validateAndMap
StrictDomainName(true)(o)
ValidateLabels(true)(o)
RemoveLeadingDots(true)(o)
}
}
type options struct {
transitional bool
useSTD3Rules bool
validateLabels bool
verifyDNSLength bool
removeLeadingDots bool
trie *idnaTrie
// fromPuny calls validation rules when converting A-labels to U-labels.
fromPuny func(p *Profile, s string) error
// mapping implements a validation and mapping step as defined in RFC 5895
// or UTS 46, tailored to, for example, domain registration or lookup.
mapping func(p *Profile, s string) (string, error)
// bidirule, if specified, checks whether s conforms to the Bidi Rule
// defined in RFC 5893.
bidirule func(s string) bool
}
// A Profile defines the configuration of a IDNA mapper.
type Profile struct {
options
}
func apply(o *options, opts []Option) {
for _, f := range opts {
f(o)
}
}
// New creates a new Profile.
//
// With no options, the returned Profile is the most permissive and equals the
// Punycode Profile. Options can be passed to further restrict the Profile. The
// MapForLookup and ValidateForRegistration options set a collection of options,
// for lookup and registration purposes respectively, which can be tailored by
// adding more fine-grained options, where later options override earlier
// options.
func New(o ...Option) *Profile {
p := &Profile{}
apply(&p.options, o)
return p
}
// ToASCII converts a domain or domain label to its ASCII form. For example,
// ToASCII("bücher.example.com") is "xn--bcher-kva.example.com", and
// ToASCII("golang") is "golang". If an error is encountered it will return
// an error and a (partially) processed result.
func (p *Profile) ToASCII(s string) (string, error) {
return p.process(s, true)
}
// ToUnicode converts a domain or domain label to its Unicode form. For example,
// ToUnicode("xn--bcher-kva.example.com") is "bücher.example.com", and
// ToUnicode("golang") is "golang". If an error is encountered it will return
// an error and a (partially) processed result.
func (p *Profile) ToUnicode(s string) (string, error) {
pp := *p
pp.transitional = false
return pp.process(s, false)
}
// String reports a string with a description of the profile for debugging
// purposes. The string format may change with different versions.
func (p *Profile) String() string {
s := ""
if p.transitional {
s = "Transitional"
} else {
s = "NonTransitional"
}
if p.useSTD3Rules {
s += ":UseSTD3Rules"
}
if p.validateLabels {
s += ":ValidateLabels"
}
if p.verifyDNSLength {
s += ":VerifyDNSLength"
}
return s
}
var (
// Punycode is a Profile that does raw punycode processing with a minimum
// of validation.
Punycode *Profile = punycode
// Lookup is the recommended profile for looking up domain names, according
// to Section 5 of RFC 5891. The exact configuration of this profile may
// change over time.
Lookup *Profile = lookup
// Display is the recommended profile for displaying domain names.
// The configuration of this profile may change over time.
Display *Profile = display
// Registration is the recommended profile for checking whether a given
// IDN is valid for registration, according to Section 4 of RFC 5891.
Registration *Profile = registration
punycode = &Profile{}
lookup = &Profile{options{
transitional: true,
useSTD3Rules: true,
validateLabels: true,
removeLeadingDots: true,
trie: trie,
fromPuny: validateFromPunycode,
mapping: validateAndMap,
bidirule: bidirule.ValidString,
}}
display = &Profile{options{
useSTD3Rules: true,
validateLabels: true,
removeLeadingDots: true,
trie: trie,
fromPuny: validateFromPunycode,
mapping: validateAndMap,
bidirule: bidirule.ValidString,
}}
registration = &Profile{options{
useSTD3Rules: true,
validateLabels: true,
verifyDNSLength: true,
trie: trie,
fromPuny: validateFromPunycode,
mapping: validateRegistration,
bidirule: bidirule.ValidString,
}}
// TODO: profiles
// Register: recommended for approving domain names: don't do any mappings
// but rather reject on invalid input. Bundle or block deviation characters.
)
type labelError struct{ label, code_ string }
func (e labelError) code() string { return e.code_ }
func (e labelError) Error() string {
return fmt.Sprintf("idna: invalid label %q", e.label)
}
type runeError rune
func (e runeError) code() string { return "P1" }
func (e runeError) Error() string {
return fmt.Sprintf("idna: disallowed rune %U", e)
}
// process implements the algorithm described in section 4 of UTS #46,
// see https://www.unicode.org/reports/tr46.
func (p *Profile) process(s string, toASCII bool) (string, error) {
var err error
if p.mapping != nil {
s, err = p.mapping(p, s)
}
// Remove leading empty labels.
if p.removeLeadingDots {
for ; len(s) > 0 && s[0] == '.'; s = s[1:] {
}
}
// It seems like we should only create this error on ToASCII, but the
// UTS 46 conformance tests suggests we should always check this.
if err == nil && p.verifyDNSLength && s == "" {
err = &labelError{s, "A4"}
}
labels := labelIter{orig: s}
for ; !labels.done(); labels.next() {
label := labels.label()
if label == "" {
// Empty labels are not okay. The label iterator skips the last
// label if it is empty.
if err == nil && p.verifyDNSLength {
err = &labelError{s, "A4"}
}
continue
}
if strings.HasPrefix(label, acePrefix) {
u, err2 := decode(label[len(acePrefix):])
if err2 != nil {
if err == nil {
err = err2
}
// Spec says keep the old label.
continue
}
labels.set(u)
if err == nil && p.validateLabels {
err = p.fromPuny(p, u)
}
if err == nil {
// This should be called on NonTransitional, according to the
// spec, but that currently does not have any effect. Use the
// original profile to preserve options.
err = p.validateLabel(u)
}
} else if err == nil {
err = p.validateLabel(label)
}
}
if toASCII {
for labels.reset(); !labels.done(); labels.next() {
label := labels.label()
if !ascii(label) {
a, err2 := encode(acePrefix, label)
if err == nil {
err = err2
}
label = a
labels.set(a)
}
n := len(label)
if p.verifyDNSLength && err == nil && (n == 0 || n > 63) {
err = &labelError{label, "A4"}
}
}
}
s = labels.result()
if toASCII && p.verifyDNSLength && err == nil {
// Compute the length of the domain name minus the root label and its dot.
n := len(s)
if n > 0 && s[n-1] == '.' {
n--
}
if len(s) < 1 || n > 253 {
err = &labelError{s, "A4"}
}
}
return s, err
}
func normalize(p *Profile, s string) (string, error) {
return norm.NFC.String(s), nil
}
func validateRegistration(p *Profile, s string) (string, error) {
if !norm.NFC.IsNormalString(s) {
return s, &labelError{s, "V1"}
}
for i := 0; i < len(s); {
v, sz := trie.lookupString(s[i:])
// Copy bytes not copied so far.
switch p.simplify(info(v).category()) {
// TODO: handle the NV8 defined in the Unicode idna data set to allow
// for strict conformance to IDNA2008.
case valid, deviation:
case disallowed, mapped, unknown, ignored:
r, _ := utf8.DecodeRuneInString(s[i:])
return s, runeError(r)
}
i += sz
}
return s, nil
}
func validateAndMap(p *Profile, s string) (string, error) {
var (
err error
b []byte
k int
)
for i := 0; i < len(s); {
v, sz := trie.lookupString(s[i:])
start := i
i += sz
// Copy bytes not copied so far.
switch p.simplify(info(v).category()) {
case valid:
continue
case disallowed:
if err == nil {
r, _ := utf8.DecodeRuneInString(s[start:])
err = runeError(r)
}
continue
case mapped, deviation:
b = append(b, s[k:start]...)
b = info(v).appendMapping(b, s[start:i])
case ignored:
b = append(b, s[k:start]...)
// drop the rune
case unknown:
b = append(b, s[k:start]...)
b = append(b, "\ufffd"...)
}
k = i
}
if k == 0 {
// No changes so far.
s = norm.NFC.String(s)
} else {
b = append(b, s[k:]...)
if norm.NFC.QuickSpan(b) != len(b) {
b = norm.NFC.Bytes(b)
}
// TODO: the punycode converters require strings as input.
s = string(b)
}
return s, err
}
// A labelIter allows iterating over domain name labels.
type labelIter struct {
orig string
slice []string
curStart int
curEnd int
i int
}
func (l *labelIter) reset() {
l.curStart = 0
l.curEnd = 0
l.i = 0
}
func (l *labelIter) done() bool {
return l.curStart >= len(l.orig)
}
func (l *labelIter) result() string {
if l.slice != nil {
return strings.Join(l.slice, ".")
}
return l.orig
}
func (l *labelIter) label() string {
if l.slice != nil {
return l.slice[l.i]
}
p := strings.IndexByte(l.orig[l.curStart:], '.')
l.curEnd = l.curStart + p
if p == -1 {
l.curEnd = len(l.orig)
}
return l.orig[l.curStart:l.curEnd]
}
// next sets the value to the next label. It skips the last label if it is empty.
func (l *labelIter) next() {
l.i++
if l.slice != nil {
if l.i >= len(l.slice) || l.i == len(l.slice)-1 && l.slice[l.i] == "" {
l.curStart = len(l.orig)
}
} else {
l.curStart = l.curEnd + 1
if l.curStart == len(l.orig)-1 && l.orig[l.curStart] == '.' {
l.curStart = len(l.orig)
}
}
}
func (l *labelIter) set(s string) {
if l.slice == nil {
l.slice = strings.Split(l.orig, ".")
}
l.slice[l.i] = s
}
// acePrefix is the ASCII Compatible Encoding prefix.
const acePrefix = "xn--"
func (p *Profile) simplify(cat category) category {
switch cat {
case disallowedSTD3Mapped:
if p.useSTD3Rules {
cat = disallowed
} else {
cat = mapped
}
case disallowedSTD3Valid:
if p.useSTD3Rules {
cat = disallowed
} else {
cat = valid
}
case deviation:
if !p.transitional {
cat = valid
}
case validNV8, validXV8:
// TODO: handle V2008
cat = valid
}
return cat
}
func validateFromPunycode(p *Profile, s string) error {
if !norm.NFC.IsNormalString(s) {
return &labelError{s, "V1"}
}
for i := 0; i < len(s); {
v, sz := trie.lookupString(s[i:])
if c := p.simplify(info(v).category()); c != valid && c != deviation {
return &labelError{s, "V6"}
}
i += sz
}
return nil
}
const (
zwnj = "\u200c"
zwj = "\u200d"
)
type joinState int8
const (
stateStart joinState = iota
stateVirama
stateBefore
stateBeforeVirama
stateAfter
stateFAIL
)
var joinStates = [][numJoinTypes]joinState{
stateStart: {
joiningL: stateBefore,
joiningD: stateBefore,
joinZWNJ: stateFAIL,
joinZWJ: stateFAIL,
joinVirama: stateVirama,
},
stateVirama: {
joiningL: stateBefore,
joiningD: stateBefore,
},
stateBefore: {
joiningL: stateBefore,
joiningD: stateBefore,
joiningT: stateBefore,
joinZWNJ: stateAfter,
joinZWJ: stateFAIL,
joinVirama: stateBeforeVirama,
},
stateBeforeVirama: {
joiningL: stateBefore,
joiningD: stateBefore,
joiningT: stateBefore,
},
stateAfter: {
joiningL: stateFAIL,
joiningD: stateBefore,
joiningT: stateAfter,
joiningR: stateStart,
joinZWNJ: stateFAIL,
joinZWJ: stateFAIL,
joinVirama: stateAfter, // no-op as we can't accept joiners here
},
stateFAIL: {
0: stateFAIL,
joiningL: stateFAIL,
joiningD: stateFAIL,
joiningT: stateFAIL,
joiningR: stateFAIL,
joinZWNJ: stateFAIL,
joinZWJ: stateFAIL,
joinVirama: stateFAIL,
},
}
// validateLabel validates the criteria from Section 4.1. Item 1, 4, and 6 are
// already implicitly satisfied by the overall implementation.
func (p *Profile) validateLabel(s string) error {
if s == "" {
if p.verifyDNSLength {
return &labelError{s, "A4"}
}
return nil
}
if p.bidirule != nil && !p.bidirule(s) {
return &labelError{s, "B"}
}
if !p.validateLabels {
return nil
}
trie := p.trie // p.validateLabels is only set if trie is set.
if len(s) > 4 && s[2] == '-' && s[3] == '-' {
return &labelError{s, "V2"}
}
if s[0] == '-' || s[len(s)-1] == '-' {
return &labelError{s, "V3"}
}
// TODO: merge the use of this in the trie.
v, sz := trie.lookupString(s)
x := info(v)
if x.isModifier() {
return &labelError{s, "V5"}
}
// Quickly return in the absence of zero-width (non) joiners.
if strings.Index(s, zwj) == -1 && strings.Index(s, zwnj) == -1 {
return nil
}
st := stateStart
for i := 0; ; {
jt := x.joinType()
if s[i:i+sz] == zwj {
jt = joinZWJ
} else if s[i:i+sz] == zwnj {
jt = joinZWNJ
}
st = joinStates[st][jt]
if x.isViramaModifier() {
st = joinStates[st][joinVirama]
}
if i += sz; i == len(s) {
break
}
v, sz = trie.lookupString(s[i:])
x = info(v)
}
if st == stateFAIL || st == stateAfter {
return &labelError{s, "C"}
}
return nil
}
func ascii(s string) bool {
for i := 0; i < len(s); i++ {
if s[i] >= utf8.RuneSelf {
return false
}
}
return true
}

203
vendor/golang.org/x/net/idna/punycode.go generated vendored Normal file
View File

@ -0,0 +1,203 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package idna
// This file implements the Punycode algorithm from RFC 3492.
import (
"math"
"strings"
"unicode/utf8"
)
// These parameter values are specified in section 5.
//
// All computation is done with int32s, so that overflow behavior is identical
// regardless of whether int is 32-bit or 64-bit.
const (
base int32 = 36
damp int32 = 700
initialBias int32 = 72
initialN int32 = 128
skew int32 = 38
tmax int32 = 26
tmin int32 = 1
)
func punyError(s string) error { return &labelError{s, "A3"} }
// decode decodes a string as specified in section 6.2.
func decode(encoded string) (string, error) {
if encoded == "" {
return "", nil
}
pos := 1 + strings.LastIndex(encoded, "-")
if pos == 1 {
return "", punyError(encoded)
}
if pos == len(encoded) {
return encoded[:len(encoded)-1], nil
}
output := make([]rune, 0, len(encoded))
if pos != 0 {
for _, r := range encoded[:pos-1] {
output = append(output, r)
}
}
i, n, bias := int32(0), initialN, initialBias
for pos < len(encoded) {
oldI, w := i, int32(1)
for k := base; ; k += base {
if pos == len(encoded) {
return "", punyError(encoded)
}
digit, ok := decodeDigit(encoded[pos])
if !ok {
return "", punyError(encoded)
}
pos++
i += digit * w
if i < 0 {
return "", punyError(encoded)
}
t := k - bias
if t < tmin {
t = tmin
} else if t > tmax {
t = tmax
}
if digit < t {
break
}
w *= base - t
if w >= math.MaxInt32/base {
return "", punyError(encoded)
}
}
x := int32(len(output) + 1)
bias = adapt(i-oldI, x, oldI == 0)
n += i / x
i %= x
if n > utf8.MaxRune || len(output) >= 1024 {
return "", punyError(encoded)
}
output = append(output, 0)
copy(output[i+1:], output[i:])
output[i] = n
i++
}
return string(output), nil
}
// encode encodes a string as specified in section 6.3 and prepends prefix to
// the result.
//
// The "while h < length(input)" line in the specification becomes "for
// remaining != 0" in the Go code, because len(s) in Go is in bytes, not runes.
func encode(prefix, s string) (string, error) {
output := make([]byte, len(prefix), len(prefix)+1+2*len(s))
copy(output, prefix)
delta, n, bias := int32(0), initialN, initialBias
b, remaining := int32(0), int32(0)
for _, r := range s {
if r < 0x80 {
b++
output = append(output, byte(r))
} else {
remaining++
}
}
h := b
if b > 0 {
output = append(output, '-')
}
for remaining != 0 {
m := int32(0x7fffffff)
for _, r := range s {
if m > r && r >= n {
m = r
}
}
delta += (m - n) * (h + 1)
if delta < 0 {
return "", punyError(s)
}
n = m
for _, r := range s {
if r < n {
delta++
if delta < 0 {
return "", punyError(s)
}
continue
}
if r > n {
continue
}
q := delta
for k := base; ; k += base {
t := k - bias
if t < tmin {
t = tmin
} else if t > tmax {
t = tmax
}
if q < t {
break
}
output = append(output, encodeDigit(t+(q-t)%(base-t)))
q = (q - t) / (base - t)
}
output = append(output, encodeDigit(q))
bias = adapt(delta, h+1, h == b)
delta = 0
h++
remaining--
}
delta++
n++
}
return string(output), nil
}
func decodeDigit(x byte) (digit int32, ok bool) {
switch {
case '0' <= x && x <= '9':
return int32(x - ('0' - 26)), true
case 'A' <= x && x <= 'Z':
return int32(x - 'A'), true
case 'a' <= x && x <= 'z':
return int32(x - 'a'), true
}
return 0, false
}
func encodeDigit(digit int32) byte {
switch {
case 0 <= digit && digit < 26:
return byte(digit + 'a')
case 26 <= digit && digit < 36:
return byte(digit + ('0' - 26))
}
panic("idna: internal error in punycode encoding")
}
// adapt is the bias adaptation function specified in section 6.1.
func adapt(delta, numPoints int32, firstTime bool) int32 {
if firstTime {
delta /= damp
} else {
delta /= 2
}
delta += delta / numPoints
k := int32(0)
for delta > ((base-tmin)*tmax)/2 {
delta /= base - tmin
k += base
}
return k + (base-tmin+1)*delta/(delta+skew)
}

4559
vendor/golang.org/x/net/idna/tables10.0.0.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

4653
vendor/golang.org/x/net/idna/tables11.0.0.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

4486
vendor/golang.org/x/net/idna/tables9.0.0.go generated vendored Normal file

File diff suppressed because it is too large Load Diff

72
vendor/golang.org/x/net/idna/trie.go generated vendored Normal file
View File

@ -0,0 +1,72 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package idna
// appendMapping appends the mapping for the respective rune. isMapped must be
// true. A mapping is a categorization of a rune as defined in UTS #46.
func (c info) appendMapping(b []byte, s string) []byte {
index := int(c >> indexShift)
if c&xorBit == 0 {
s := mappings[index:]
return append(b, s[1:s[0]+1]...)
}
b = append(b, s...)
if c&inlineXOR == inlineXOR {
// TODO: support and handle two-byte inline masks
b[len(b)-1] ^= byte(index)
} else {
for p := len(b) - int(xorData[index]); p < len(b); p++ {
index++
b[p] ^= xorData[index]
}
}
return b
}
// Sparse block handling code.
type valueRange struct {
value uint16 // header: value:stride
lo, hi byte // header: lo:n
}
type sparseBlocks struct {
values []valueRange
offset []uint16
}
var idnaSparse = sparseBlocks{
values: idnaSparseValues[:],
offset: idnaSparseOffset[:],
}
// Don't use newIdnaTrie to avoid unconditional linking in of the table.
var trie = &idnaTrie{}
// lookup determines the type of block n and looks up the value for b.
// For n < t.cutoff, the block is a simple lookup table. Otherwise, the block
// is a list of ranges with an accompanying value. Given a matching range r,
// the value for b is by r.value + (b - r.lo) * stride.
func (t *sparseBlocks) lookup(n uint32, b byte) uint16 {
offset := t.offset[n]
header := t.values[offset]
lo := offset + 1
hi := lo + uint16(header.lo)
for lo < hi {
m := lo + (hi-lo)/2
r := t.values[m]
if r.lo <= b && b <= r.hi {
return r.value + uint16(b-r.lo)*header.value
}
if b < r.lo {
hi = m
} else {
lo = m + 1
}
}
return 0
}

119
vendor/golang.org/x/net/idna/trieval.go generated vendored Normal file
View File

@ -0,0 +1,119 @@
// Code generated by running "go generate" in golang.org/x/text. DO NOT EDIT.
package idna
// This file contains definitions for interpreting the trie value of the idna
// trie generated by "go run gen*.go". It is shared by both the generator
// program and the resultant package. Sharing is achieved by the generator
// copying gen_trieval.go to trieval.go and changing what's above this comment.
// info holds information from the IDNA mapping table for a single rune. It is
// the value returned by a trie lookup. In most cases, all information fits in
// a 16-bit value. For mappings, this value may contain an index into a slice
// with the mapped string. Such mappings can consist of the actual mapped value
// or an XOR pattern to be applied to the bytes of the UTF8 encoding of the
// input rune. This technique is used by the cases packages and reduces the
// table size significantly.
//
// The per-rune values have the following format:
//
// if mapped {
// if inlinedXOR {
// 15..13 inline XOR marker
// 12..11 unused
// 10..3 inline XOR mask
// } else {
// 15..3 index into xor or mapping table
// }
// } else {
// 15..14 unused
// 13 mayNeedNorm
// 12..11 attributes
// 10..8 joining type
// 7..3 category type
// }
// 2 use xor pattern
// 1..0 mapped category
//
// See the definitions below for a more detailed description of the various
// bits.
type info uint16
const (
catSmallMask = 0x3
catBigMask = 0xF8
indexShift = 3
xorBit = 0x4 // interpret the index as an xor pattern
inlineXOR = 0xE000 // These bits are set if the XOR pattern is inlined.
joinShift = 8
joinMask = 0x07
// Attributes
attributesMask = 0x1800
viramaModifier = 0x1800
modifier = 0x1000
rtl = 0x0800
mayNeedNorm = 0x2000
)
// A category corresponds to a category defined in the IDNA mapping table.
type category uint16
const (
unknown category = 0 // not currently defined in unicode.
mapped category = 1
disallowedSTD3Mapped category = 2
deviation category = 3
)
const (
valid category = 0x08
validNV8 category = 0x18
validXV8 category = 0x28
disallowed category = 0x40
disallowedSTD3Valid category = 0x80
ignored category = 0xC0
)
// join types and additional rune information
const (
joiningL = (iota + 1)
joiningD
joiningT
joiningR
//the following types are derived during processing
joinZWJ
joinZWNJ
joinVirama
numJoinTypes
)
func (c info) isMapped() bool {
return c&0x3 != 0
}
func (c info) category() category {
small := c & catSmallMask
if small != 0 {
return category(small)
}
return category(c & catBigMask)
}
func (c info) joinType() info {
if c.isMapped() {
return 0
}
return (c >> joinShift) & joinMask
}
func (c info) isModifier() bool {
return c&(modifier|catSmallMask) == modifier
}
func (c info) isViramaModifier() bool {
return c&(attributesMask|catSmallMask) == viramaModifier
}

View File

@ -4,38 +4,34 @@
// Package iana provides protocol number resources managed by the Internet Assigned Numbers Authority (IANA). // Package iana provides protocol number resources managed by the Internet Assigned Numbers Authority (IANA).
package iana // import "golang.org/x/net/internal/iana" package iana // import "golang.org/x/net/internal/iana"
// Differentiated Services Field Codepoints (DSCP), Updated: 2017-05-12 // Differentiated Services Field Codepoints (DSCP), Updated: 2018-05-04
const ( const (
DiffServCS0 = 0x0 // CS0 DiffServCS0 = 0x00 // CS0
DiffServCS1 = 0x20 // CS1 DiffServCS1 = 0x20 // CS1
DiffServCS2 = 0x40 // CS2 DiffServCS2 = 0x40 // CS2
DiffServCS3 = 0x60 // CS3 DiffServCS3 = 0x60 // CS3
DiffServCS4 = 0x80 // CS4 DiffServCS4 = 0x80 // CS4
DiffServCS5 = 0xa0 // CS5 DiffServCS5 = 0xa0 // CS5
DiffServCS6 = 0xc0 // CS6 DiffServCS6 = 0xc0 // CS6
DiffServCS7 = 0xe0 // CS7 DiffServCS7 = 0xe0 // CS7
DiffServAF11 = 0x28 // AF11 DiffServAF11 = 0x28 // AF11
DiffServAF12 = 0x30 // AF12 DiffServAF12 = 0x30 // AF12
DiffServAF13 = 0x38 // AF13 DiffServAF13 = 0x38 // AF13
DiffServAF21 = 0x48 // AF21 DiffServAF21 = 0x48 // AF21
DiffServAF22 = 0x50 // AF22 DiffServAF22 = 0x50 // AF22
DiffServAF23 = 0x58 // AF23 DiffServAF23 = 0x58 // AF23
DiffServAF31 = 0x68 // AF31 DiffServAF31 = 0x68 // AF31
DiffServAF32 = 0x70 // AF32 DiffServAF32 = 0x70 // AF32
DiffServAF33 = 0x78 // AF33 DiffServAF33 = 0x78 // AF33
DiffServAF41 = 0x88 // AF41 DiffServAF41 = 0x88 // AF41
DiffServAF42 = 0x90 // AF42 DiffServAF42 = 0x90 // AF42
DiffServAF43 = 0x98 // AF43 DiffServAF43 = 0x98 // AF43
DiffServEF = 0xb8 // EF DiffServEF = 0xb8 // EF
DiffServVOICEADMIT = 0xb0 // VOICE-ADMIT DiffServVOICEADMIT = 0xb0 // VOICE-ADMIT
) NotECNTransport = 0x00 // Not-ECT (Not ECN-Capable Transport)
ECNTransport1 = 0x01 // ECT(1) (ECN-Capable Transport(1))
// IPv4 TOS Byte and IPv6 Traffic Class Octet, Updated: 2001-09-06 ECNTransport0 = 0x02 // ECT(0) (ECN-Capable Transport(0))
const ( CongestionExperienced = 0x03 // CE (Congestion Experienced)
NotECNTransport = 0x0 // Not-ECT (Not ECN-Capable Transport)
ECNTransport1 = 0x1 // ECT(1) (ECN-Capable Transport(1))
ECNTransport0 = 0x2 // ECT(0) (ECN-Capable Transport(0))
CongestionExperienced = 0x3 // CE (Congestion Experienced)
) )
// Protocol Numbers, Updated: 2017-10-13 // Protocol Numbers, Updated: 2017-10-13
@ -179,7 +175,7 @@ const (
ProtocolReserved = 255 // Reserved ProtocolReserved = 255 // Reserved
) )
// Address Family Numbers, Updated: 2016-10-25 // Address Family Numbers, Updated: 2018-04-02
const ( const (
AddrFamilyIPv4 = 1 // IP (IP version 4) AddrFamilyIPv4 = 1 // IP (IP version 4)
AddrFamilyIPv6 = 2 // IP6 (IP version 6) AddrFamilyIPv6 = 2 // IP6 (IP version 6)

View File

@ -1,387 +0,0 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ignore
//go:generate go run gen.go
// This program generates internet protocol constants and tables by
// reading IANA protocol registries.
package main
import (
"bytes"
"encoding/xml"
"fmt"
"go/format"
"io"
"io/ioutil"
"net/http"
"os"
"strconv"
"strings"
)
var registries = []struct {
url string
parse func(io.Writer, io.Reader) error
}{
{
"https://www.iana.org/assignments/dscp-registry/dscp-registry.xml",
parseDSCPRegistry,
},
{
"https://www.iana.org/assignments/ipv4-tos-byte/ipv4-tos-byte.xml",
parseTOSTCByte,
},
{
"https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xml",
parseProtocolNumbers,
},
{
"http://www.iana.org/assignments/address-family-numbers/address-family-numbers.xml",
parseAddrFamilyNumbers,
},
}
func main() {
var bb bytes.Buffer
fmt.Fprintf(&bb, "// go generate gen.go\n")
fmt.Fprintf(&bb, "// Code generated by the command above; DO NOT EDIT.\n\n")
fmt.Fprintf(&bb, "// Package iana provides protocol number resources managed by the Internet Assigned Numbers Authority (IANA).\n")
fmt.Fprintf(&bb, `package iana // import "golang.org/x/net/internal/iana"`+"\n\n")
for _, r := range registries {
resp, err := http.Get(r.url)
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
fmt.Fprintf(os.Stderr, "got HTTP status code %v for %v\n", resp.StatusCode, r.url)
os.Exit(1)
}
if err := r.parse(&bb, resp.Body); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
fmt.Fprintf(&bb, "\n")
}
b, err := format.Source(bb.Bytes())
if err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
if err := ioutil.WriteFile("const.go", b, 0644); err != nil {
fmt.Fprintln(os.Stderr, err)
os.Exit(1)
}
}
func parseDSCPRegistry(w io.Writer, r io.Reader) error {
dec := xml.NewDecoder(r)
var dr dscpRegistry
if err := dec.Decode(&dr); err != nil {
return err
}
drs := dr.escape()
fmt.Fprintf(w, "// %s, Updated: %s\n", dr.Title, dr.Updated)
fmt.Fprintf(w, "const (\n")
for _, dr := range drs {
fmt.Fprintf(w, "DiffServ%s = %#x", dr.Name, dr.Value)
fmt.Fprintf(w, "// %s\n", dr.OrigName)
}
fmt.Fprintf(w, ")\n")
return nil
}
type dscpRegistry struct {
XMLName xml.Name `xml:"registry"`
Title string `xml:"title"`
Updated string `xml:"updated"`
Note string `xml:"note"`
RegTitle string `xml:"registry>title"`
PoolRecords []struct {
Name string `xml:"name"`
Space string `xml:"space"`
} `xml:"registry>record"`
Records []struct {
Name string `xml:"name"`
Space string `xml:"space"`
} `xml:"registry>registry>record"`
}
type canonDSCPRecord struct {
OrigName string
Name string
Value int
}
func (drr *dscpRegistry) escape() []canonDSCPRecord {
drs := make([]canonDSCPRecord, len(drr.Records))
sr := strings.NewReplacer(
"+", "",
"-", "",
"/", "",
".", "",
" ", "",
)
for i, dr := range drr.Records {
s := strings.TrimSpace(dr.Name)
drs[i].OrigName = s
drs[i].Name = sr.Replace(s)
n, err := strconv.ParseUint(dr.Space, 2, 8)
if err != nil {
continue
}
drs[i].Value = int(n) << 2
}
return drs
}
func parseTOSTCByte(w io.Writer, r io.Reader) error {
dec := xml.NewDecoder(r)
var ttb tosTCByte
if err := dec.Decode(&ttb); err != nil {
return err
}
trs := ttb.escape()
fmt.Fprintf(w, "// %s, Updated: %s\n", ttb.Title, ttb.Updated)
fmt.Fprintf(w, "const (\n")
for _, tr := range trs {
fmt.Fprintf(w, "%s = %#x", tr.Keyword, tr.Value)
fmt.Fprintf(w, "// %s\n", tr.OrigKeyword)
}
fmt.Fprintf(w, ")\n")
return nil
}
type tosTCByte struct {
XMLName xml.Name `xml:"registry"`
Title string `xml:"title"`
Updated string `xml:"updated"`
Note string `xml:"note"`
RegTitle string `xml:"registry>title"`
Records []struct {
Binary string `xml:"binary"`
Keyword string `xml:"keyword"`
} `xml:"registry>record"`
}
type canonTOSTCByteRecord struct {
OrigKeyword string
Keyword string
Value int
}
func (ttb *tosTCByte) escape() []canonTOSTCByteRecord {
trs := make([]canonTOSTCByteRecord, len(ttb.Records))
sr := strings.NewReplacer(
"Capable", "",
"(", "",
")", "",
"+", "",
"-", "",
"/", "",
".", "",
" ", "",
)
for i, tr := range ttb.Records {
s := strings.TrimSpace(tr.Keyword)
trs[i].OrigKeyword = s
ss := strings.Split(s, " ")
if len(ss) > 1 {
trs[i].Keyword = strings.Join(ss[1:], " ")
} else {
trs[i].Keyword = ss[0]
}
trs[i].Keyword = sr.Replace(trs[i].Keyword)
n, err := strconv.ParseUint(tr.Binary, 2, 8)
if err != nil {
continue
}
trs[i].Value = int(n)
}
return trs
}
func parseProtocolNumbers(w io.Writer, r io.Reader) error {
dec := xml.NewDecoder(r)
var pn protocolNumbers
if err := dec.Decode(&pn); err != nil {
return err
}
prs := pn.escape()
prs = append([]canonProtocolRecord{{
Name: "IP",
Descr: "IPv4 encapsulation, pseudo protocol number",
Value: 0,
}}, prs...)
fmt.Fprintf(w, "// %s, Updated: %s\n", pn.Title, pn.Updated)
fmt.Fprintf(w, "const (\n")
for _, pr := range prs {
if pr.Name == "" {
continue
}
fmt.Fprintf(w, "Protocol%s = %d", pr.Name, pr.Value)
s := pr.Descr
if s == "" {
s = pr.OrigName
}
fmt.Fprintf(w, "// %s\n", s)
}
fmt.Fprintf(w, ")\n")
return nil
}
type protocolNumbers struct {
XMLName xml.Name `xml:"registry"`
Title string `xml:"title"`
Updated string `xml:"updated"`
RegTitle string `xml:"registry>title"`
Note string `xml:"registry>note"`
Records []struct {
Value string `xml:"value"`
Name string `xml:"name"`
Descr string `xml:"description"`
} `xml:"registry>record"`
}
type canonProtocolRecord struct {
OrigName string
Name string
Descr string
Value int
}
func (pn *protocolNumbers) escape() []canonProtocolRecord {
prs := make([]canonProtocolRecord, len(pn.Records))
sr := strings.NewReplacer(
"-in-", "in",
"-within-", "within",
"-over-", "over",
"+", "P",
"-", "",
"/", "",
".", "",
" ", "",
)
for i, pr := range pn.Records {
if strings.Contains(pr.Name, "Deprecated") ||
strings.Contains(pr.Name, "deprecated") {
continue
}
prs[i].OrigName = pr.Name
s := strings.TrimSpace(pr.Name)
switch pr.Name {
case "ISIS over IPv4":
prs[i].Name = "ISIS"
case "manet":
prs[i].Name = "MANET"
default:
prs[i].Name = sr.Replace(s)
}
ss := strings.Split(pr.Descr, "\n")
for i := range ss {
ss[i] = strings.TrimSpace(ss[i])
}
if len(ss) > 1 {
prs[i].Descr = strings.Join(ss, " ")
} else {
prs[i].Descr = ss[0]
}
prs[i].Value, _ = strconv.Atoi(pr.Value)
}
return prs
}
func parseAddrFamilyNumbers(w io.Writer, r io.Reader) error {
dec := xml.NewDecoder(r)
var afn addrFamilylNumbers
if err := dec.Decode(&afn); err != nil {
return err
}
afrs := afn.escape()
fmt.Fprintf(w, "// %s, Updated: %s\n", afn.Title, afn.Updated)
fmt.Fprintf(w, "const (\n")
for _, afr := range afrs {
if afr.Name == "" {
continue
}
fmt.Fprintf(w, "AddrFamily%s = %d", afr.Name, afr.Value)
fmt.Fprintf(w, "// %s\n", afr.Descr)
}
fmt.Fprintf(w, ")\n")
return nil
}
type addrFamilylNumbers struct {
XMLName xml.Name `xml:"registry"`
Title string `xml:"title"`
Updated string `xml:"updated"`
RegTitle string `xml:"registry>title"`
Note string `xml:"registry>note"`
Records []struct {
Value string `xml:"value"`
Descr string `xml:"description"`
} `xml:"registry>record"`
}
type canonAddrFamilyRecord struct {
Name string
Descr string
Value int
}
func (afn *addrFamilylNumbers) escape() []canonAddrFamilyRecord {
afrs := make([]canonAddrFamilyRecord, len(afn.Records))
sr := strings.NewReplacer(
"IP version 4", "IPv4",
"IP version 6", "IPv6",
"Identifier", "ID",
"-", "",
"-", "",
"/", "",
".", "",
" ", "",
)
for i, afr := range afn.Records {
if strings.Contains(afr.Descr, "Unassigned") ||
strings.Contains(afr.Descr, "Reserved") {
continue
}
afrs[i].Descr = afr.Descr
s := strings.TrimSpace(afr.Descr)
switch s {
case "IP (IP version 4)":
afrs[i].Name = "IPv4"
case "IP6 (IP version 6)":
afrs[i].Name = "IPv6"
case "AFI for L2VPN information":
afrs[i].Name = "L2VPN"
case "E.164 with NSAP format subaddress":
afrs[i].Name = "E164withSubaddress"
case "MT IP: Multi-Topology IP version 4":
afrs[i].Name = "MTIPv4"
case "MAC/24":
afrs[i].Name = "MACFinal24bits"
case "MAC/40":
afrs[i].Name = "MACFinal40bits"
case "IPv6/64":
afrs[i].Name = "IPv6Initial64bits"
default:
n := strings.Index(s, "(")
if n > 0 {
s = s[:n]
}
n = strings.Index(s, ":")
if n > 0 {
s = s[:n]
}
afrs[i].Name = sr.Replace(s)
}
afrs[i].Value, _ = strconv.Atoi(afr.Value)
}
return afrs
}

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux netbsd openbsd solaris // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package socket package socket

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd netbsd openbsd // +build aix darwin dragonfly freebsd netbsd openbsd
package socket package socket

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build arm64 amd64 ppc64 ppc64le mips64 mips64le s390x // +build arm64 amd64 ppc64 ppc64le mips64 mips64le riscv64 s390x
// +build linux // +build linux
package socket package socket

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build !darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris // +build !aix,!darwin,!dragonfly,!freebsd,!linux,!netbsd,!openbsd,!solaris
package socket package socket

View File

@ -1,44 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
// +godefs map struct_in6_addr [16]byte /* in6_addr */
package socket
/*
#include <sys/socket.h>
#include <netinet/in.h>
*/
import "C"
const (
sysAF_UNSPEC = C.AF_UNSPEC
sysAF_INET = C.AF_INET
sysAF_INET6 = C.AF_INET6
sysSOCK_RAW = C.SOCK_RAW
)
type iovec C.struct_iovec
type msghdr C.struct_msghdr
type cmsghdr C.struct_cmsghdr
type sockaddrInet C.struct_sockaddr_in
type sockaddrInet6 C.struct_sockaddr_in6
const (
sizeofIovec = C.sizeof_struct_iovec
sizeofMsghdr = C.sizeof_struct_msghdr
sizeofCmsghdr = C.sizeof_struct_cmsghdr
sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)

View File

@ -1,44 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
// +godefs map struct_in6_addr [16]byte /* in6_addr */
package socket
/*
#include <sys/socket.h>
#include <netinet/in.h>
*/
import "C"
const (
sysAF_UNSPEC = C.AF_UNSPEC
sysAF_INET = C.AF_INET
sysAF_INET6 = C.AF_INET6
sysSOCK_RAW = C.SOCK_RAW
)
type iovec C.struct_iovec
type msghdr C.struct_msghdr
type cmsghdr C.struct_cmsghdr
type sockaddrInet C.struct_sockaddr_in
type sockaddrInet6 C.struct_sockaddr_in6
const (
sizeofIovec = C.sizeof_struct_iovec
sizeofMsghdr = C.sizeof_struct_msghdr
sizeofCmsghdr = C.sizeof_struct_cmsghdr
sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)

View File

@ -1,44 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
// +godefs map struct_in6_addr [16]byte /* in6_addr */
package socket
/*
#include <sys/socket.h>
#include <netinet/in.h>
*/
import "C"
const (
sysAF_UNSPEC = C.AF_UNSPEC
sysAF_INET = C.AF_INET
sysAF_INET6 = C.AF_INET6
sysSOCK_RAW = C.SOCK_RAW
)
type iovec C.struct_iovec
type msghdr C.struct_msghdr
type cmsghdr C.struct_cmsghdr
type sockaddrInet C.struct_sockaddr_in
type sockaddrInet6 C.struct_sockaddr_in6
const (
sizeofIovec = C.sizeof_struct_iovec
sizeofMsghdr = C.sizeof_struct_msghdr
sizeofCmsghdr = C.sizeof_struct_cmsghdr
sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)

View File

@ -1,49 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
// +godefs map struct_in6_addr [16]byte /* in6_addr */
package socket
/*
#include <linux/in.h>
#include <linux/in6.h>
#define _GNU_SOURCE
#include <sys/socket.h>
*/
import "C"
const (
sysAF_UNSPEC = C.AF_UNSPEC
sysAF_INET = C.AF_INET
sysAF_INET6 = C.AF_INET6
sysSOCK_RAW = C.SOCK_RAW
)
type iovec C.struct_iovec
type msghdr C.struct_msghdr
type mmsghdr C.struct_mmsghdr
type cmsghdr C.struct_cmsghdr
type sockaddrInet C.struct_sockaddr_in
type sockaddrInet6 C.struct_sockaddr_in6
const (
sizeofIovec = C.sizeof_struct_iovec
sizeofMsghdr = C.sizeof_struct_msghdr
sizeofMmsghdr = C.sizeof_struct_mmsghdr
sizeofCmsghdr = C.sizeof_struct_cmsghdr
sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)

View File

@ -1,47 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
// +godefs map struct_in6_addr [16]byte /* in6_addr */
package socket
/*
#include <sys/socket.h>
#include <netinet/in.h>
*/
import "C"
const (
sysAF_UNSPEC = C.AF_UNSPEC
sysAF_INET = C.AF_INET
sysAF_INET6 = C.AF_INET6
sysSOCK_RAW = C.SOCK_RAW
)
type iovec C.struct_iovec
type msghdr C.struct_msghdr
type mmsghdr C.struct_mmsghdr
type cmsghdr C.struct_cmsghdr
type sockaddrInet C.struct_sockaddr_in
type sockaddrInet6 C.struct_sockaddr_in6
const (
sizeofIovec = C.sizeof_struct_iovec
sizeofMsghdr = C.sizeof_struct_msghdr
sizeofMmsghdr = C.sizeof_struct_mmsghdr
sizeofCmsghdr = C.sizeof_struct_cmsghdr
sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)

View File

@ -1,44 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
// +godefs map struct_in6_addr [16]byte /* in6_addr */
package socket
/*
#include <sys/socket.h>
#include <netinet/in.h>
*/
import "C"
const (
sysAF_UNSPEC = C.AF_UNSPEC
sysAF_INET = C.AF_INET
sysAF_INET6 = C.AF_INET6
sysSOCK_RAW = C.SOCK_RAW
)
type iovec C.struct_iovec
type msghdr C.struct_msghdr
type cmsghdr C.struct_cmsghdr
type sockaddrInet C.struct_sockaddr_in
type sockaddrInet6 C.struct_sockaddr_in6
const (
sizeofIovec = C.sizeof_struct_iovec
sizeofMsghdr = C.sizeof_struct_msghdr
sizeofCmsghdr = C.sizeof_struct_cmsghdr
sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)

View File

@ -1,44 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ignore
// +godefs map struct_in_addr [4]byte /* in_addr */
// +godefs map struct_in6_addr [16]byte /* in6_addr */
package socket
/*
#include <sys/socket.h>
#include <netinet/in.h>
*/
import "C"
const (
sysAF_UNSPEC = C.AF_UNSPEC
sysAF_INET = C.AF_INET
sysAF_INET6 = C.AF_INET6
sysSOCK_RAW = C.SOCK_RAW
)
type iovec C.struct_iovec
type msghdr C.struct_msghdr
type cmsghdr C.struct_cmsghdr
type sockaddrInet C.struct_sockaddr_in
type sockaddrInet6 C.struct_sockaddr_in6
const (
sizeofIovec = C.sizeof_struct_iovec
sizeofMsghdr = C.sizeof_struct_msghdr
sizeofCmsghdr = C.sizeof_struct_cmsghdr
sizeofSockaddrInet = C.sizeof_struct_sockaddr_in
sizeofSockaddrInet6 = C.sizeof_struct_sockaddr_in6
)

7
vendor/golang.org/x/net/internal/socket/empty.s generated vendored Normal file
View File

@ -0,0 +1,7 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin,go1.12
// This exists solely so we can linkname in symbols from syscall.

View File

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux netbsd openbsd solaris // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris
package socket package socket

Some files were not shown because too many files have changed in this diff Show More