diff --git a/weed/command/s3.go b/weed/command/s3.go index 0575cfa58..f4c7166c3 100644 --- a/weed/command/s3.go +++ b/weed/command/s3.go @@ -328,6 +328,10 @@ func (s3opt *S3Options) startS3Server() bool { ClientAuth: clientAuth, ClientCAs: caCertPool, } + err = security.FixTlsConfig(util.GetViper(), httpS.TLSConfig) + if err != nil { + glog.Fatalf("error with tls config: %v", err) + } if *s3opt.portHttps == 0 { glog.V(0).Infof("Start Seaweed S3 API Server %s at https port %d", util.Version(), *s3opt.port) if s3ApiLocalListener != nil { diff --git a/weed/security/tls.go b/weed/security/tls.go index 977234ee0..1a9dfacb5 100644 --- a/weed/security/tls.go +++ b/weed/security/tls.go @@ -4,16 +4,17 @@ import ( "crypto/tls" "crypto/x509" "fmt" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/credentials/tls/certprovider/pemfile" - "google.golang.org/grpc/security/advancedtls" "os" + "slices" "strings" "time" "github.com/seaweedfs/seaweedfs/weed/glog" "github.com/seaweedfs/seaweedfs/weed/util" "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/credentials/tls/certprovider/pemfile" + "google.golang.org/grpc/security/advancedtls" ) const CredRefreshingInterval = time.Duration(5) * time.Hour @@ -62,7 +63,22 @@ func LoadServerTLS(config *util.ViperProxy, component string) (grpc.ServerOption RootProvider: serverRootProvider, }, RequireClientCert: true, - VerificationType: advancedtls.CertVerification, + VerificationType: advancedtls.CertVerification, + } + options.MinTLSVersion, err = TlsVersionByName(config.GetString("tls.min_version")) + if err != nil { + glog.Warningf("tls min version parse failed, %v", err) + return nil, nil + } + options.MaxTLSVersion, err = TlsVersionByName(config.GetString("tls.max_version")) + if err != nil { + glog.Warningf("tls max version parse failed, %v", err) + return nil, nil + } + options.CipherSuites, err = TlsCipherSuiteByNames(config.GetString("tls.cipher_suites")) + if err != nil { + glog.Warningf("tls cipher suite parse failed, %v", err) + return nil, nil } allowedCommonNames := config.GetString(component + ".allowed_commonNames") allowedWildcardDomain := config.GetString("grpc.allowed_wildcard_domain") @@ -123,8 +139,8 @@ func LoadClientTLS(config *util.ViperProxy, component string) grpc.DialOption { IdentityProvider: clientProvider, }, AdditionalPeerVerification: func(params *advancedtls.HandshakeVerificationInfo) (*advancedtls.PostHandshakeVerificationResults, error) { - return &advancedtls.PostHandshakeVerificationResults{}, nil - }, + return &advancedtls.PostHandshakeVerificationResults{}, nil + }, RootOptions: advancedtls.RootCertificateOptions{ RootProvider: clientRootProvider, }, @@ -166,3 +182,57 @@ func (a Authenticator) Authenticate(params *advancedtls.HandshakeVerificationInf glog.Error(err) return nil, err } + +func FixTlsConfig(viper *util.ViperProxy, config *tls.Config) error { + var err error + config.MinVersion, err = TlsVersionByName(viper.GetString("tls.min_version")) + if err != nil { + return err + } + config.MaxVersion, err = TlsVersionByName(viper.GetString("tls.max_version")) + if err != nil { + return err + } + config.CipherSuites, err = TlsCipherSuiteByNames(viper.GetString("tls.cipher_suites")) + return err +} + +func TlsVersionByName(name string) (uint16, error) { + switch name { + case "": + return 0, nil + case "SSLv3": + return tls.VersionSSL30, nil + case "TLS 1.0": + return tls.VersionTLS10, nil + case "TLS 1.1": + return tls.VersionTLS11, nil + case "TLS 1.2": + return tls.VersionTLS12, nil + case "TLS 1.3": + return tls.VersionTLS13, nil + default: + return 0, fmt.Errorf("invalid tls version %s", name) + } +} + +func TlsCipherSuiteByNames(cipherSuiteNames string) ([]uint16, error) { + cipherSuiteNames = strings.TrimSpace(cipherSuiteNames) + if cipherSuiteNames == "" { + return nil, nil + } + names := strings.Split(cipherSuiteNames, ",") + cipherSuites := tls.CipherSuites() + cipherIds := make([]uint16, 0, len(names)) + for _, name := range names { + name = strings.TrimSpace(name) + index := slices.IndexFunc(cipherSuites, func(suite *tls.CipherSuite) bool { + return name == suite.Name + }) + if index == -1 { + return nil, fmt.Errorf("invalid tls cipher suite name %s", name) + } + cipherIds = append(cipherIds, cipherSuites[index].ID) + } + return cipherIds, nil +}