mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-06-02 17:27:32 +03:00
Refactor TLS
This commit is contained in:
77
common/tls/acme.go
Normal file
77
common/tls/acme.go
Normal file
@@ -0,0 +1,77 @@
|
||||
//go:build with_acme
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/certmagic"
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"github.com/mholt/acmez/acme"
|
||||
)
|
||||
|
||||
type acmeWrapper struct {
|
||||
ctx context.Context
|
||||
cfg *certmagic.Config
|
||||
domain []string
|
||||
}
|
||||
|
||||
func (w *acmeWrapper) Start() error {
|
||||
return w.cfg.ManageSync(w.ctx, w.domain)
|
||||
}
|
||||
|
||||
func (w *acmeWrapper) Close() error {
|
||||
w.cfg.Unmanage(w.domain)
|
||||
return nil
|
||||
}
|
||||
|
||||
func startACME(ctx context.Context, options option.InboundACMEOptions) (*tls.Config, adapter.Service, error) {
|
||||
var acmeServer string
|
||||
switch options.Provider {
|
||||
case "", "letsencrypt":
|
||||
acmeServer = certmagic.LetsEncryptProductionCA
|
||||
case "zerossl":
|
||||
acmeServer = certmagic.ZeroSSLProductionCA
|
||||
default:
|
||||
if !strings.HasPrefix(options.Provider, "https://") {
|
||||
return nil, nil, E.New("unsupported acme provider: " + options.Provider)
|
||||
}
|
||||
acmeServer = options.Provider
|
||||
}
|
||||
var storage certmagic.Storage
|
||||
if options.DataDirectory != "" {
|
||||
storage = &certmagic.FileStorage{
|
||||
Path: options.DataDirectory,
|
||||
}
|
||||
} else {
|
||||
storage = certmagic.Default.Storage
|
||||
}
|
||||
config := &certmagic.Config{
|
||||
DefaultServerName: options.DefaultServerName,
|
||||
Storage: storage,
|
||||
}
|
||||
acmeConfig := certmagic.ACMEIssuer{
|
||||
CA: acmeServer,
|
||||
Email: options.Email,
|
||||
Agreed: true,
|
||||
DisableHTTPChallenge: options.DisableHTTPChallenge,
|
||||
DisableTLSALPNChallenge: options.DisableTLSALPNChallenge,
|
||||
AltHTTPPort: int(options.AlternativeHTTPPort),
|
||||
AltTLSALPNPort: int(options.AlternativeTLSPort),
|
||||
}
|
||||
if options.ExternalAccount != nil {
|
||||
acmeConfig.ExternalAccount = (*acme.EAB)(options.ExternalAccount)
|
||||
}
|
||||
config.Issuers = []certmagic.Issuer{certmagic.NewACMEIssuer(config, acmeConfig)}
|
||||
config = certmagic.New(certmagic.NewCache(certmagic.CacheOptions{
|
||||
GetConfigForCert: func(certificate certmagic.Certificate) (*certmagic.Config, error) {
|
||||
return config, nil
|
||||
},
|
||||
}), *config)
|
||||
return config.TLSConfig(), &acmeWrapper{ctx, config, options.Domain}, nil
|
||||
}
|
||||
16
common/tls/acme_stub.go
Normal file
16
common/tls/acme_stub.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build !with_acme
|
||||
|
||||
package tls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func startACME(ctx context.Context, options option.InboundACMEOptions) (*tls.Config, adapter.Service, error) {
|
||||
return nil, nil, E.New(`ACME is not included in this build, rebuild with -tags with_acme`)
|
||||
}
|
||||
57
common/tls/client.go
Normal file
57
common/tls/client.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func NewDialerFromOptions(router adapter.Router, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) {
|
||||
config, err := NewClient(router, serverAddress, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewDialer(dialer, config), nil
|
||||
}
|
||||
|
||||
func NewClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
return newStdClient(serverAddress, options)
|
||||
}
|
||||
|
||||
func ClientHandshake(ctx context.Context, conn net.Conn, config Config) (net.Conn, error) {
|
||||
tlsConn := config.Client(conn)
|
||||
ctx, cancel := context.WithTimeout(ctx, C.TCPTimeout)
|
||||
defer cancel()
|
||||
err := tlsConn.HandshakeContext(ctx)
|
||||
return tlsConn, err
|
||||
}
|
||||
|
||||
type Dialer struct {
|
||||
dialer N.Dialer
|
||||
config Config
|
||||
}
|
||||
|
||||
func NewDialer(dialer N.Dialer, config Config) N.Dialer {
|
||||
return &Dialer{dialer, config}
|
||||
}
|
||||
|
||||
func (d *Dialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if network != N.NetworkTCP {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
conn, err := d.dialer.DialContext(ctx, network, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ClientHandshake(ctx, conn, d.config)
|
||||
}
|
||||
|
||||
func (d *Dialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
12
common/tls/common.go
Normal file
12
common/tls/common.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package tls
|
||||
|
||||
const (
|
||||
VersionTLS10 = 0x0301
|
||||
VersionTLS11 = 0x0302
|
||||
VersionTLS12 = 0x0303
|
||||
VersionTLS13 = 0x0304
|
||||
|
||||
// Deprecated: SSLv3 is cryptographically broken, and is no longer
|
||||
// supported by this package. See golang.org/issue/32716.
|
||||
VersionSSL30 = 0x0300
|
||||
)
|
||||
46
common/tls/config.go
Normal file
46
common/tls/config.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type (
|
||||
STDConfig = tls.Config
|
||||
STDConn = tls.Conn
|
||||
)
|
||||
|
||||
type Config interface {
|
||||
Config() (*STDConfig, error)
|
||||
Client(conn net.Conn) Conn
|
||||
}
|
||||
|
||||
type ServerConfig interface {
|
||||
Config
|
||||
adapter.Service
|
||||
Server(conn net.Conn) Conn
|
||||
}
|
||||
|
||||
type Conn interface {
|
||||
net.Conn
|
||||
HandshakeContext(ctx context.Context) error
|
||||
}
|
||||
|
||||
func ParseTLSVersion(version string) (uint16, error) {
|
||||
switch version {
|
||||
case "1.0":
|
||||
return tls.VersionTLS10, nil
|
||||
case "1.1":
|
||||
return tls.VersionTLS11, nil
|
||||
case "1.2":
|
||||
return tls.VersionTLS12, nil
|
||||
case "1.3":
|
||||
return tls.VersionTLS13, nil
|
||||
default:
|
||||
return 0, E.New("unknown tls version:", version)
|
||||
}
|
||||
}
|
||||
@@ -1,34 +1,26 @@
|
||||
package dialer
|
||||
package tls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
type TLSDialer struct {
|
||||
dialer N.Dialer
|
||||
type stdClientConfig struct {
|
||||
config *tls.Config
|
||||
}
|
||||
|
||||
func TLSConfig(serverAddress string, options option.OutboundTLSOptions) (*tls.Config, error) {
|
||||
if !options.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
func newStdClient(serverAddress string, options option.OutboundTLSOptions) (Config, error) {
|
||||
var serverName string
|
||||
if options.ServerName != "" {
|
||||
serverName = options.ServerName
|
||||
} else if serverAddress != "" {
|
||||
if _, err := netip.ParseAddr(serverName); err == nil {
|
||||
if _, err := netip.ParseAddr(serverName); err != nil {
|
||||
serverName = serverAddress
|
||||
}
|
||||
}
|
||||
@@ -62,14 +54,14 @@ func TLSConfig(serverAddress string, options option.OutboundTLSOptions) (*tls.Co
|
||||
tlsConfig.NextProtos = options.ALPN
|
||||
}
|
||||
if options.MinVersion != "" {
|
||||
minVersion, err := option.ParseTLSVersion(options.MinVersion)
|
||||
minVersion, err := ParseTLSVersion(options.MinVersion)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse min_version")
|
||||
}
|
||||
tlsConfig.MinVersion = minVersion
|
||||
}
|
||||
if options.MaxVersion != "" {
|
||||
maxVersion, err := option.ParseTLSVersion(options.MaxVersion)
|
||||
maxVersion, err := ParseTLSVersion(options.MaxVersion)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse max_version")
|
||||
}
|
||||
@@ -104,42 +96,13 @@ func TLSConfig(serverAddress string, options option.OutboundTLSOptions) (*tls.Co
|
||||
}
|
||||
tlsConfig.RootCAs = certPool
|
||||
}
|
||||
return &tlsConfig, nil
|
||||
return &stdClientConfig{&tlsConfig}, nil
|
||||
}
|
||||
|
||||
func NewTLS(dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) {
|
||||
if !options.Enabled {
|
||||
return dialer, nil
|
||||
}
|
||||
tlsConfig, err := TLSConfig(serverAddress, options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &TLSDialer{
|
||||
dialer: dialer,
|
||||
config: tlsConfig,
|
||||
}, nil
|
||||
func (s *stdClientConfig) Config() (*STDConfig, error) {
|
||||
return s.config, nil
|
||||
}
|
||||
|
||||
func (d *TLSDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
if network != N.NetworkTCP {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
conn, err := d.dialer.DialContext(ctx, network, destination)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return TLSClient(ctx, conn, d.config)
|
||||
}
|
||||
|
||||
func (d *TLSDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
|
||||
func TLSClient(ctx context.Context, conn net.Conn, tlsConfig *tls.Config) (*tls.Conn, error) {
|
||||
tlsConn := tls.Client(conn, tlsConfig)
|
||||
ctx, cancel := context.WithTimeout(ctx, C.TCPTimeout)
|
||||
defer cancel()
|
||||
err := tlsConn.HandshakeContext(ctx)
|
||||
return tlsConn, err
|
||||
func (s *stdClientConfig) Client(conn net.Conn) Conn {
|
||||
return tls.Client(conn, s.config)
|
||||
}
|
||||
224
common/tls/std_server.go
Normal file
224
common/tls/std_server.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package tls
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
type STDServerConfig struct {
|
||||
config *tls.Config
|
||||
logger log.Logger
|
||||
acmeService adapter.Service
|
||||
certificate []byte
|
||||
key []byte
|
||||
certificatePath string
|
||||
keyPath string
|
||||
watcher *fsnotify.Watcher
|
||||
}
|
||||
|
||||
func NewServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) {
|
||||
if !options.Enabled {
|
||||
return nil, nil
|
||||
}
|
||||
var tlsConfig *tls.Config
|
||||
var acmeService adapter.Service
|
||||
var err error
|
||||
if options.ACME != nil && len(options.ACME.Domain) > 0 {
|
||||
tlsConfig, acmeService, err = startACME(ctx, common.PtrValueOrDefault(options.ACME))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
tlsConfig = &tls.Config{}
|
||||
}
|
||||
if options.ServerName != "" {
|
||||
tlsConfig.ServerName = options.ServerName
|
||||
}
|
||||
if len(options.ALPN) > 0 {
|
||||
tlsConfig.NextProtos = append(tlsConfig.NextProtos, options.ALPN...)
|
||||
}
|
||||
if options.MinVersion != "" {
|
||||
minVersion, err := ParseTLSVersion(options.MinVersion)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse min_version")
|
||||
}
|
||||
tlsConfig.MinVersion = minVersion
|
||||
}
|
||||
if options.MaxVersion != "" {
|
||||
maxVersion, err := ParseTLSVersion(options.MaxVersion)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse max_version")
|
||||
}
|
||||
tlsConfig.MaxVersion = maxVersion
|
||||
}
|
||||
if options.CipherSuites != nil {
|
||||
find:
|
||||
for _, cipherSuite := range options.CipherSuites {
|
||||
for _, tlsCipherSuite := range tls.CipherSuites() {
|
||||
if cipherSuite == tlsCipherSuite.Name {
|
||||
tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID)
|
||||
continue find
|
||||
}
|
||||
}
|
||||
return nil, E.New("unknown cipher_suite: ", cipherSuite)
|
||||
}
|
||||
}
|
||||
var certificate []byte
|
||||
var key []byte
|
||||
if acmeService == nil {
|
||||
if options.Certificate != "" {
|
||||
certificate = []byte(options.Certificate)
|
||||
} else if options.CertificatePath != "" {
|
||||
content, err := os.ReadFile(options.CertificatePath)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read certificate")
|
||||
}
|
||||
certificate = content
|
||||
}
|
||||
if options.Key != "" {
|
||||
key = []byte(options.Key)
|
||||
} else if options.KeyPath != "" {
|
||||
content, err := os.ReadFile(options.KeyPath)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read key")
|
||||
}
|
||||
key = content
|
||||
}
|
||||
if certificate == nil {
|
||||
return nil, E.New("missing certificate")
|
||||
}
|
||||
if key == nil {
|
||||
return nil, E.New("missing key")
|
||||
}
|
||||
keyPair, err := tls.X509KeyPair(certificate, key)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "parse x509 key pair")
|
||||
}
|
||||
tlsConfig.Certificates = []tls.Certificate{keyPair}
|
||||
}
|
||||
return &STDServerConfig{
|
||||
config: tlsConfig,
|
||||
logger: logger,
|
||||
acmeService: acmeService,
|
||||
certificate: certificate,
|
||||
key: key,
|
||||
certificatePath: options.CertificatePath,
|
||||
keyPath: options.KeyPath,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) Config() (*STDConfig, error) {
|
||||
return c.config, nil
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) Client(conn net.Conn) Conn {
|
||||
return tls.Client(conn, c.config)
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) Server(conn net.Conn) Conn {
|
||||
return tls.Server(conn, c.config)
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) Start() error {
|
||||
if c.acmeService != nil {
|
||||
return c.acmeService.Start()
|
||||
} else {
|
||||
if c.certificatePath == "" && c.keyPath == "" {
|
||||
return nil
|
||||
}
|
||||
err := c.startWatcher()
|
||||
if err != nil {
|
||||
c.logger.Warn("create fsnotify watcher: ", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) startWatcher() error {
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c.certificatePath != "" {
|
||||
err = watcher.Add(c.certificatePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if c.keyPath != "" {
|
||||
err = watcher.Add(c.keyPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
c.watcher = watcher
|
||||
go c.loopUpdate()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) loopUpdate() {
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-c.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if event.Op&fsnotify.Write != fsnotify.Write {
|
||||
continue
|
||||
}
|
||||
err := c.reloadKeyPair()
|
||||
if err != nil {
|
||||
c.logger.Error(E.Cause(err, "reload TLS key pair"))
|
||||
}
|
||||
case err, ok := <-c.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
c.logger.Error(E.Cause(err, "fsnotify error"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) reloadKeyPair() error {
|
||||
if c.certificatePath != "" {
|
||||
certificate, err := os.ReadFile(c.certificatePath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "reload certificate from ", c.certificatePath)
|
||||
}
|
||||
c.certificate = certificate
|
||||
}
|
||||
if c.keyPath != "" {
|
||||
key, err := os.ReadFile(c.keyPath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "reload key from ", c.keyPath)
|
||||
}
|
||||
c.key = key
|
||||
}
|
||||
keyPair, err := tls.X509KeyPair(c.certificate, c.key)
|
||||
if err != nil {
|
||||
return E.Cause(err, "reload key pair")
|
||||
}
|
||||
c.config.Certificates = []tls.Certificate{keyPair}
|
||||
c.logger.Info("reloaded TLS certificate")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *STDServerConfig) Close() error {
|
||||
if c.acmeService != nil {
|
||||
return c.acmeService.Close()
|
||||
}
|
||||
if c.watcher != nil {
|
||||
return c.watcher.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user