Improve ECH support

This commit is contained in:
世界
2023-08-29 13:43:42 +08:00
parent be61ca64d4
commit 5b343d4c72
34 changed files with 434 additions and 84 deletions

View File

@@ -13,29 +13,29 @@ import (
aTLS "github.com/sagernet/sing/common/tls"
)
func NewDialerFromOptions(router adapter.Router, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) {
func NewDialerFromOptions(ctx context.Context, router adapter.Router, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) {
if !options.Enabled {
return dialer, nil
}
config, err := NewClient(router, serverAddress, options)
config, err := NewClient(ctx, serverAddress, options)
if err != nil {
return nil, err
}
return NewDialer(dialer, config), nil
}
func NewClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
func NewClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
if !options.Enabled {
return nil, nil
}
if options.ECH != nil && options.ECH.Enabled {
return NewECHClient(router, serverAddress, options)
return NewECHClient(ctx, serverAddress, options)
} else if options.Reality != nil && options.Reality.Enabled {
return NewRealityClient(router, serverAddress, options)
return NewRealityClient(ctx, serverAddress, options)
} else if options.UTLS != nil && options.UTLS.Enabled {
return NewUTLSClient(router, serverAddress, options)
return NewUTLSClient(ctx, serverAddress, options)
} else {
return NewSTDClient(router, serverAddress, options)
return NewSTDClient(ctx, serverAddress, options)
}
}

View File

@@ -7,15 +7,18 @@ import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"net"
"net/netip"
"os"
"strings"
cftls "github.com/sagernet/cloudflare-tls"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"
mDNS "github.com/miekg/dns"
)
@@ -80,7 +83,7 @@ func (c *echConnWrapper) Upstream() any {
return c.Conn
}
func NewECHClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
func NewECHClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
var serverName string
if options.ServerName != "" {
serverName = options.ServerName
@@ -94,7 +97,7 @@ func NewECHClient(router adapter.Router, serverAddress string, options option.Ou
}
var tlsConfig cftls.Config
tlsConfig.Time = router.TimeFunc()
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
if options.DisableSNI {
tlsConfig.ServerName = "127.0.0.1"
} else {
@@ -168,24 +171,24 @@ func NewECHClient(router adapter.Router, serverAddress string, options option.Ou
tlsConfig.ECHEnabled = true
tlsConfig.PQSignatureSchemesEnabled = options.ECH.PQSignatureSchemesEnabled
tlsConfig.DynamicRecordSizingDisabled = options.ECH.DynamicRecordSizingDisabled
if options.ECH.Config != "" {
clientConfigContent, err := base64.StdEncoding.DecodeString(options.ECH.Config)
if err != nil {
return nil, err
if len(options.ECH.Config) > 0 {
block, rest := pem.Decode([]byte(strings.Join(options.ECH.Config, "\n")))
if block == nil || block.Type != "ECH CONFIGS" || len(rest) > 0 {
return nil, E.New("invalid ECH configs pem")
}
clientConfig, err := cftls.UnmarshalECHConfigs(clientConfigContent)
echConfigs, err := cftls.UnmarshalECHConfigs(block.Bytes)
if err != nil {
return nil, err
return nil, E.Cause(err, "parse ECH configs")
}
tlsConfig.ClientECHConfigs = clientConfig
tlsConfig.ClientECHConfigs = echConfigs
} else {
tlsConfig.GetClientECHConfigs = fetchECHClientConfig(router)
tlsConfig.GetClientECHConfigs = fetchECHClientConfig(ctx)
}
return &ECHClientConfig{&tlsConfig}, nil
}
func fetchECHClientConfig(router adapter.Router) func(ctx context.Context, serverName string) ([]cftls.ECHConfig, error) {
return func(ctx context.Context, serverName string) ([]cftls.ECHConfig, error) {
func fetchECHClientConfig(ctx context.Context) func(_ context.Context, serverName string) ([]cftls.ECHConfig, error) {
return func(_ context.Context, serverName string) ([]cftls.ECHConfig, error) {
message := &mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
RecursionDesired: true,
@@ -198,7 +201,7 @@ func fetchECHClientConfig(router adapter.Router) func(ctx context.Context, serve
},
},
}
response, err := router.Exchange(ctx, message)
response, err := adapter.RouterFromContext(ctx).Exchange(ctx, message)
if err != nil {
return nil, err
}

330
common/tls/ech_server.go Normal file
View File

@@ -0,0 +1,330 @@
//go:build with_ech
package tls
import (
"context"
"crypto/tls"
"encoding/pem"
"net"
"os"
"strings"
cftls "github.com/sagernet/cloudflare-tls"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"
"github.com/fsnotify/fsnotify"
)
type echServerConfig struct {
config *cftls.Config
logger log.Logger
certificate []byte
key []byte
certificatePath string
keyPath string
watcher *fsnotify.Watcher
echKeyPath string
echWatcher *fsnotify.Watcher
}
func (c *echServerConfig) ServerName() string {
return c.config.ServerName
}
func (c *echServerConfig) SetServerName(serverName string) {
c.config.ServerName = serverName
}
func (c *echServerConfig) NextProtos() []string {
return c.config.NextProtos
}
func (c *echServerConfig) SetNextProtos(nextProto []string) {
c.config.NextProtos = nextProto
}
func (c *echServerConfig) Config() (*STDConfig, error) {
return nil, E.New("unsupported usage for ECH")
}
func (c *echServerConfig) Client(conn net.Conn) (Conn, error) {
return &echConnWrapper{cftls.Client(conn, c.config)}, nil
}
func (c *echServerConfig) Server(conn net.Conn) (Conn, error) {
return &echConnWrapper{cftls.Server(conn, c.config)}, nil
}
func (c *echServerConfig) Clone() Config {
return &echServerConfig{
config: c.config.Clone(),
}
}
func (c *echServerConfig) Start() error {
if c.certificatePath != "" && c.keyPath != "" {
err := c.startWatcher()
if err != nil {
c.logger.Warn("create fsnotify watcher: ", err)
}
}
if c.echKeyPath != "" {
err := c.startECHWatcher()
if err != nil {
c.logger.Warn("create fsnotify watcher: ", err)
}
}
return nil
}
func (c *echServerConfig) startWatcher() error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
if c.certificatePath != "" {
err = watcher.Add(c.certificatePath)
if err != nil {
return err
}
}
if c.keyPath != "" {
err = watcher.Add(c.keyPath)
if err != nil {
return err
}
}
c.watcher = watcher
go c.loopUpdate()
return nil
}
func (c *echServerConfig) loopUpdate() {
for {
select {
case event, ok := <-c.watcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write != fsnotify.Write {
continue
}
err := c.reloadKeyPair()
if err != nil {
c.logger.Error(E.Cause(err, "reload TLS key pair"))
}
case err, ok := <-c.watcher.Errors:
if !ok {
return
}
c.logger.Error(E.Cause(err, "fsnotify error"))
}
}
}
func (c *echServerConfig) reloadKeyPair() error {
if c.certificatePath != "" {
certificate, err := os.ReadFile(c.certificatePath)
if err != nil {
return E.Cause(err, "reload certificate from ", c.certificatePath)
}
c.certificate = certificate
}
if c.keyPath != "" {
key, err := os.ReadFile(c.keyPath)
if err != nil {
return E.Cause(err, "reload key from ", c.keyPath)
}
c.key = key
}
keyPair, err := cftls.X509KeyPair(c.certificate, c.key)
if err != nil {
return E.Cause(err, "reload key pair")
}
c.config.Certificates = []cftls.Certificate{keyPair}
c.logger.Info("reloaded TLS certificate")
return nil
}
func (c *echServerConfig) startECHWatcher() error {
watcher, err := fsnotify.NewWatcher()
if err != nil {
return err
}
err = watcher.Add(c.echKeyPath)
if err != nil {
return err
}
c.watcher = watcher
go c.loopECHUpdate()
return nil
}
func (c *echServerConfig) loopECHUpdate() {
for {
select {
case event, ok := <-c.echWatcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write != fsnotify.Write {
continue
}
err := c.reloadECHKey()
if err != nil {
c.logger.Error(E.Cause(err, "reload ECH key"))
}
case err, ok := <-c.watcher.Errors:
if !ok {
return
}
c.logger.Error(E.Cause(err, "fsnotify error"))
}
}
}
func (c *echServerConfig) reloadECHKey() error {
echKeyContent, err := os.ReadFile(c.echKeyPath)
if err != nil {
return err
}
block, rest := pem.Decode(echKeyContent)
if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
return E.New("invalid ECH keys pem")
}
echKeys, err := cftls.EXP_UnmarshalECHKeys(block.Bytes)
if err != nil {
return E.Cause(err, "parse ECH keys")
}
echKeySet, err := cftls.EXP_NewECHKeySet(echKeys)
if err != nil {
return E.Cause(err, "create ECH key set")
}
c.config.ServerECHProvider = echKeySet
c.logger.Info("reloaded ECH keys")
return nil
}
func (c *echServerConfig) Close() error {
var err error
if c.watcher != nil {
err = E.Append(err, c.watcher.Close(), func(err error) error {
return E.Cause(err, "close certificate watcher")
})
}
if c.echWatcher != nil {
err = E.Append(err, c.echWatcher.Close(), func(err error) error {
return E.Cause(err, "close ECH key watcher")
})
}
return err
}
func NewECHServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
if !options.Enabled {
return nil, nil
}
var tlsConfig cftls.Config
if options.ACME != nil && len(options.ACME.Domain) > 0 {
return nil, E.New("acme is unavailable in ech")
}
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
if options.ServerName != "" {
tlsConfig.ServerName = options.ServerName
}
if len(options.ALPN) > 0 {
tlsConfig.NextProtos = append(options.ALPN, tlsConfig.NextProtos...)
}
if options.MinVersion != "" {
minVersion, err := ParseTLSVersion(options.MinVersion)
if err != nil {
return nil, E.Cause(err, "parse min_version")
}
tlsConfig.MinVersion = minVersion
}
if options.MaxVersion != "" {
maxVersion, err := ParseTLSVersion(options.MaxVersion)
if err != nil {
return nil, E.Cause(err, "parse max_version")
}
tlsConfig.MaxVersion = maxVersion
}
if options.CipherSuites != nil {
find:
for _, cipherSuite := range options.CipherSuites {
for _, tlsCipherSuite := range tls.CipherSuites() {
if cipherSuite == tlsCipherSuite.Name {
tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID)
continue find
}
}
return nil, E.New("unknown cipher_suite: ", cipherSuite)
}
}
var certificate []byte
var key []byte
if len(options.Certificate) > 0 {
certificate = []byte(strings.Join(options.Certificate, "\n"))
} else if options.CertificatePath != "" {
content, err := os.ReadFile(options.CertificatePath)
if err != nil {
return nil, E.Cause(err, "read certificate")
}
certificate = content
}
if len(options.Key) > 0 {
key = []byte(strings.Join(options.Key, ""))
} else if options.KeyPath != "" {
content, err := os.ReadFile(options.KeyPath)
if err != nil {
return nil, E.Cause(err, "read key")
}
key = content
}
if certificate == nil {
return nil, E.New("missing certificate")
} else if key == nil {
return nil, E.New("missing key")
}
keyPair, err := cftls.X509KeyPair(certificate, key)
if err != nil {
return nil, E.Cause(err, "parse x509 key pair")
}
tlsConfig.Certificates = []cftls.Certificate{keyPair}
block, rest := pem.Decode([]byte(strings.Join(options.ECH.Key, "\n")))
if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 {
return nil, E.New("invalid ECH keys pem")
}
echKeys, err := cftls.EXP_UnmarshalECHKeys(block.Bytes)
if err != nil {
return nil, E.Cause(err, "parse ECH keys")
}
echKeySet, err := cftls.EXP_NewECHKeySet(echKeys)
if err != nil {
return nil, E.Cause(err, "create ECH key set")
}
tlsConfig.ECHEnabled = true
tlsConfig.PQSignatureSchemesEnabled = options.ECH.PQSignatureSchemesEnabled
tlsConfig.DynamicRecordSizingDisabled = options.ECH.DynamicRecordSizingDisabled
tlsConfig.ServerECHProvider = echKeySet
return &echServerConfig{
config: &tlsConfig,
logger: logger,
certificate: certificate,
key: key,
certificatePath: options.CertificatePath,
keyPath: options.KeyPath,
echKeyPath: options.ECH.KeyPath,
}, nil
}

View File

@@ -3,11 +3,17 @@
package tls
import (
"github.com/sagernet/sing-box/adapter"
"context"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
func NewECHClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
func NewECHServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
return nil, E.New(`ECH is not included in this build, rebuild with -tags with_ech`)
}
func NewECHClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
return nil, E.New(`ECH is not included in this build, rebuild with -tags with_ech`)
}

View File

@@ -26,7 +26,6 @@ import (
"time"
"unsafe"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
@@ -45,12 +44,12 @@ type RealityClientConfig struct {
shortID [8]byte
}
func NewRealityClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (*RealityClientConfig, error) {
func NewRealityClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (*RealityClientConfig, error) {
if options.UTLS == nil || !options.UTLS.Enabled {
return nil, E.New("uTLS is required by reality client")
}
uClient, err := NewUTLSClient(router, serverAddress, options)
uClient, err := NewUTLSClient(ctx, serverAddress, options)
if err != nil {
return nil, err
}

View File

@@ -19,6 +19,7 @@ import (
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
)
var _ ServerConfigCompat = (*RealityServerConfig)(nil)
@@ -27,13 +28,13 @@ type RealityServerConfig struct {
config *reality.Config
}
func NewRealityServer(ctx context.Context, router adapter.Router, logger log.Logger, options option.InboundTLSOptions) (*RealityServerConfig, error) {
func NewRealityServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (*RealityServerConfig, error) {
var tlsConfig reality.Config
if options.ACME != nil && len(options.ACME.Domain) > 0 {
return nil, E.New("acme is unavailable in reality")
}
tlsConfig.Time = router.TimeFunc()
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
if options.ServerName != "" {
tlsConfig.ServerName = options.ServerName
}
@@ -101,7 +102,7 @@ func NewRealityServer(ctx context.Context, router adapter.Router, logger log.Log
tlsConfig.ShortIds[shortID] = true
}
handshakeDialer, err := dialer.New(router, options.Reality.Handshake.DialerOptions)
handshakeDialer, err := dialer.New(adapter.RouterFromContext(ctx), options.Reality.Handshake.DialerOptions)
if err != nil {
return nil, err
}

View File

@@ -5,12 +5,11 @@ package tls
import (
"context"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
func NewRealityServer(ctx context.Context, router adapter.Router, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
func NewRealityServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
return nil, E.New(`reality server is not included in this build, rebuild with -tags with_reality_server`)
}

View File

@@ -4,21 +4,22 @@ import (
"context"
"net"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
aTLS "github.com/sagernet/sing/common/tls"
)
func NewServer(ctx context.Context, router adapter.Router, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
func NewServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
if !options.Enabled {
return nil, nil
}
if options.Reality != nil && options.Reality.Enabled {
return NewRealityServer(ctx, router, logger, options)
if options.ECH != nil && options.ECH.Enabled {
return NewECHServer(ctx, logger, options)
} else if options.Reality != nil && options.Reality.Enabled {
return NewRealityServer(ctx, logger, options)
} else {
return NewSTDServer(ctx, router, logger, options)
return NewSTDServer(ctx, logger, options)
}
}

View File

@@ -1,15 +1,16 @@
package tls
import (
"context"
"crypto/tls"
"crypto/x509"
"net"
"net/netip"
"os"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"
)
type STDClientConfig struct {
@@ -44,7 +45,7 @@ func (s *STDClientConfig) Clone() Config {
return &STDClientConfig{s.config.Clone()}
}
func NewSTDClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
func NewSTDClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
var serverName string
if options.ServerName != "" {
serverName = options.ServerName
@@ -58,7 +59,7 @@ func NewSTDClient(router adapter.Router, serverAddress string, options option.Ou
}
var tlsConfig tls.Config
tlsConfig.Time = router.TimeFunc()
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
if options.DisableSNI {
tlsConfig.ServerName = "127.0.0.1"
} else {

View File

@@ -11,6 +11,7 @@ import (
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"
"github.com/fsnotify/fsnotify"
)
@@ -156,7 +157,7 @@ func (c *STDServerConfig) Close() error {
return nil
}
func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
func NewSTDServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
if !options.Enabled {
return nil, nil
}
@@ -175,7 +176,7 @@ func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger,
} else {
tlsConfig = &tls.Config{}
}
tlsConfig.Time = router.TimeFunc()
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
if options.ServerName != "" {
tlsConfig.ServerName = options.ServerName
}
@@ -231,7 +232,7 @@ func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger,
}
if certificate == nil && key == nil && options.Insecure {
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return GenerateKeyPair(router.TimeFunc(), info.ServerName)
return GenerateKeyPair(ntp.TimeFuncFromContext(ctx), info.ServerName)
}
} else {
if certificate == nil {

View File

@@ -11,9 +11,9 @@ import (
"net/netip"
"os"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp"
utls "github.com/sagernet/utls"
"golang.org/x/net/http2"
@@ -113,7 +113,7 @@ func (c *utlsALPNWrapper) HandshakeContext(ctx context.Context) error {
return c.UConn.HandshakeContext(ctx)
}
func NewUTLSClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (*UTLSClientConfig, error) {
func NewUTLSClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (*UTLSClientConfig, error) {
var serverName string
if options.ServerName != "" {
serverName = options.ServerName
@@ -127,7 +127,7 @@ func NewUTLSClient(router adapter.Router, serverAddress string, options option.O
}
var tlsConfig utls.Config
tlsConfig.Time = router.TimeFunc()
tlsConfig.Time = ntp.TimeFuncFromContext(ctx)
if options.DisableSNI {
tlsConfig.ServerName = "127.0.0.1"
} else {

View File

@@ -3,15 +3,16 @@
package tls
import (
"github.com/sagernet/sing-box/adapter"
"context"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
)
func NewUTLSClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
func NewUTLSClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
return nil, E.New(`uTLS is not included in this build, rebuild with -tags with_utls`)
}
func NewRealityClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
func NewRealityClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
return nil, E.New(`uTLS, which is required by reality client is not included in this build, rebuild with -tags with_utls`)
}