diff --git a/.goreleaser.yaml b/.goreleaser.yaml index 953e7325..dcb55e4a 100644 --- a/.goreleaser.yaml +++ b/.goreleaser.yaml @@ -20,6 +20,8 @@ builds: - with_acme - with_clash_api - with_tailscale + - with_masque + - with_mtproxy env: - CGO_ENABLED=0 - GOTOOLCHAIN=local @@ -61,6 +63,8 @@ builds: - with_acme - with_clash_api - with_tailscale + - with_masque + - with_mtproxy - with_manager - with_admin_panel env: @@ -97,6 +101,8 @@ builds: - with_acme - with_clash_api - with_tailscale + - with_masque + - with_mtproxy env: - CGO_ENABLED=0 targets: diff --git a/README.md b/README.md index cb6be2f2..bf1c2ff5 100644 --- a/README.md +++ b/README.md @@ -4,34 +4,37 @@ Sing-box with extended features. ## 🔥 Features -### 🌐 Outbounds -- **WARP** — Cloudflare WARP integration through WireGuard -- **Tunnel** — Protocol for creating tunnels across nodes -- **Bond** — Link aggregation for increased throughput -- **Mieru** — Secure, hard to classify, hard to probe network protocol -- **Failover** — Automatic outbound switching for high availability +### 🌐 Protocols +- **WARP** +- **Masque** +- **MTProxy** +- **Mieru** +- **VPN** +- **Bond** +- **Fallback** ### 🚦 Limiters -- **Bandwidth Limiter** — Upload / download rate limiting -- **Connection Limiter** — Concurrent connection control +- **Bandwidth Limiter** +- **Connection Limiter** ### 🛡 Encryption & Obfuscation -- **Amnezia 1.5** — WireGuard traffic obfuscation -- **VLESS encryption** — XRAY encryption for VLESS protocol +- **Amnezia 2.0** +- **VLESS encryption** ### 🔄 Transports -- **mKCP** — Reliable UDP-based transport -- **XHTTP** — Modern XRAY transport +- **mKCP** +- **XHTTP** ### 🛠 Services -- **Admin Panel** — Web-based management interface -- **Manager** — Management service for configuring squads, nodes, users, limiters -- **Node Manager** — Service for connecting nodes to remote manager +- **Admin Panel** +- **Manager** +- **Node Manager** ### ⚙ Miscellaneous -- **SDNS (DNSCrypt)** — Encrypted DNS queries for enhanced privacy -- **Extended WireGuard options** — Advanced configuration capabilities -- **Unified Delay** — Unified latency measurement +- **Link parser** +- **SDNS (DNSCrypt)** +- **Extended WireGuard options** +- **Unified Delay** ## 📚 Examples diff --git a/adapter/dns.go b/adapter/dns.go index 8f065e2e..23fbc9de 100644 --- a/adapter/dns.go +++ b/adapter/dns.go @@ -68,6 +68,8 @@ type DNSTransport interface { Type() string Tag() string Dependencies() []string + // Reset closes the transport's existing connections so later requests use fresh connections. + // Exchanges that are currently using those connections may fail. Reset() Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) } diff --git a/adapter/experimental.go b/adapter/experimental.go index 5409e163..d4b904ed 100644 --- a/adapter/experimental.go +++ b/adapter/experimental.go @@ -48,6 +48,7 @@ type CacheFile interface { RDRCStore StoreWARPConfig() bool + StoreMASQUEConfig() bool LoadMode() string StoreMode(mode string) error @@ -59,6 +60,10 @@ type CacheFile interface { SaveRuleSet(tag string, set *SavedBinary) error LoadWARPConfig(tag string) *SavedBinary SaveWARPConfig(tag string, set *SavedBinary) error + LoadMASQUEConfig(tag string) *SavedBinary + SaveMASQUEConfig(tag string, set *SavedBinary) error + LoadSubscription(tag string) *SavedBinary + SaveSubscription(tag string, sub *SavedBinary) error } type SavedBinary struct { diff --git a/adapter/inbound.go b/adapter/inbound.go index 73bc98cf..4ffdcc58 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -42,16 +42,15 @@ type InboundManager interface { } type InboundContext struct { - Inbound string - InboundType string - IPVersion uint8 - Network string - Source M.Socksaddr - Destination M.Socksaddr - TunnelSource string - TunnelDestination string - User string - Outbound string + Inbound string + InboundType string + IPVersion uint8 + Network string + Source M.Socksaddr + Destination M.Socksaddr + Gateway *netip.Addr + User string + Outbound string // sniffer diff --git a/adapter/platform.go b/adapter/platform.go index fa4cbc2e..df1f4471 100644 --- a/adapter/platform.go +++ b/adapter/platform.go @@ -1,6 +1,8 @@ package adapter import ( + "net/netip" + "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common/logger" @@ -36,6 +38,8 @@ type PlatformInterface interface { UsePlatformNotification() bool SendNotification(notification *Notification) error + + MyInterfaceAddress() []netip.Addr } type FindConnectionOwnerRequest struct { diff --git a/adapter/provider.go b/adapter/provider.go new file mode 100644 index 00000000..0bb88860 --- /dev/null +++ b/adapter/provider.go @@ -0,0 +1,51 @@ +package adapter + +import ( + "context" + "time" + + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/x/list" +) + +type Provider interface { + Type() string + Tag() string + Outbounds() []Outbound + Outbound(tag string) (Outbound, bool) + UpdatedAt() time.Time + HealthCheck(ctx context.Context) (map[string]uint16, error) + RegisterCallback(callback ProviderUpdateCallback) *list.Element[ProviderUpdateCallback] + UnregisterCallback(element *list.Element[ProviderUpdateCallback]) +} + +type ProviderUpdater interface { + Update() error +} + +type ProviderSubscriptionInfo interface { + SubscriptionInfo() SubscriptionInfo +} + +type ProviderRegistry interface { + option.ProviderOptionsRegistry + CreateProvider(ctx context.Context, router Router, logFactory log.Factory, tag string, providerType string, options any) (Provider, error) +} + +type ProviderManager interface { + Lifecycle + Providers() []Provider + Get(tag string) (Provider, bool) + Remove(tag string) error + Create(ctx context.Context, router Router, logFactory log.Factory, tag string, providerType string, options any) error +} + +type SubscriptionInfo struct { + Upload int64 + Download int64 + Total int64 + Expire int64 +} + +type ProviderUpdateCallback = func(tag string) error diff --git a/adapter/provider/adapter.go b/adapter/provider/adapter.go new file mode 100644 index 00000000..3c55783e --- /dev/null +++ b/adapter/provider/adapter.go @@ -0,0 +1,267 @@ +package provider + +import ( + "context" + "reflect" + "sync" + "sync/atomic" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/urltest" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/batch" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/x/list" + "github.com/sagernet/sing/service" +) + +type Adapter struct { + ctx context.Context + outbound adapter.OutboundManager + router adapter.Router + logFactory log.Factory + logger log.ContextLogger + providerType string + providerTag string + outbounds []adapter.Outbound + outboundsByTag map[string]adapter.Outbound + ticker *time.Ticker + checking atomic.Bool + history adapter.URLTestHistoryStorage + callbackAccess sync.Mutex + callbacks list.List[adapter.ProviderUpdateCallback] + + link string + enabled bool + timeout time.Duration + interval time.Duration +} + +func NewAdapter(ctx context.Context, router adapter.Router, outbound adapter.OutboundManager, logFactory log.Factory, logger log.ContextLogger, providerTag string, providerType string, options option.ProviderHealthCheckOptions) Adapter { + timeout := time.Duration(options.Timeout) + if timeout == 0 { + timeout = 3 * time.Second + } + interval := time.Duration(options.Interval) + if interval == 0 { + interval = 10 * time.Minute + } + if interval < time.Minute { + interval = time.Minute + } + return Adapter{ + ctx: ctx, + outbound: outbound, + router: router, + logFactory: logFactory, + logger: logger, + providerType: providerType, + providerTag: providerTag, + + enabled: options.Enabled, + link: options.URL, + timeout: timeout, + interval: interval, + } +} + +func (a *Adapter) Start() error { + a.history = service.FromContext[adapter.URLTestHistoryStorage](a.ctx) + if a.history == nil { + if clashServer := service.FromContext[adapter.ClashServer](a.ctx); clashServer != nil { + a.history = clashServer.HistoryStorage() + } else { + a.history = urltest.NewHistoryStorage() + } + } + go a.loopCheck() + return nil +} + +func (a *Adapter) Type() string { + return a.providerType +} + +func (a *Adapter) Tag() string { + return a.providerTag +} + +func (a *Adapter) Outbounds() []adapter.Outbound { + return a.outbounds +} + +func (a *Adapter) Outbound(tag string) (adapter.Outbound, bool) { + if a.outboundsByTag == nil { + return nil, false + } + detour, ok := a.outboundsByTag[tag] + return detour, ok +} + +func (a *Adapter) UpdateOutbounds(oldOpts []option.Outbound, newOpts []option.Outbound) { + a.removeUseless(newOpts) + var ( + oldOptByTag = make(map[string]option.Outbound) + outbounds = make([]adapter.Outbound, 0, len(newOpts)) + outboundsByTag = make(map[string]adapter.Outbound) + ) + for _, opt := range oldOpts { + oldOptByTag[opt.Tag] = opt + } + for i, opt := range newOpts { + var tag string + if opt.Tag != "" { + tag = F.ToString(a.providerTag, "/", opt.Tag) + } else { + tag = F.ToString(a.providerTag, "/", i) + } + outbound, exist := a.outbound.Outbound(tag) + if !exist || !reflect.DeepEqual(opt, oldOptByTag[opt.Tag]) { + err := a.outbound.Create( + adapter.WithContext(a.ctx, &adapter.InboundContext{ + Outbound: tag, + }), + a.router, + a.logFactory.NewLogger(F.ToString("outbound/", opt.Type, "[", tag, "]")), + tag, + opt.Type, + opt.Options, + ) + if err != nil { + a.logger.Warn(err, " in ", tag, ", skip create this outbound") + continue + } + outbound, _ = a.outbound.Outbound(tag) + } + outbounds = append(outbounds, outbound) + outboundsByTag[tag] = outbound + } + if a.enabled && a.history != nil { + go a.HealthCheck(a.ctx) + } + a.outbounds = outbounds + a.outboundsByTag = outboundsByTag +} + +func (a *Adapter) HealthCheck(ctx context.Context) (map[string]uint16, error) { + if a.ticker != nil { + a.ticker.Reset(a.interval) + } + return a.healthcheck(ctx) +} + +func (a *Adapter) RegisterCallback(callback adapter.ProviderUpdateCallback) *list.Element[adapter.ProviderUpdateCallback] { + a.callbackAccess.Lock() + defer a.callbackAccess.Unlock() + return a.callbacks.PushBack(callback) +} + +func (a *Adapter) UnregisterCallback(element *list.Element[adapter.ProviderUpdateCallback]) { + a.callbackAccess.Lock() + defer a.callbackAccess.Unlock() + a.callbacks.Remove(element) +} + +func (a *Adapter) UpdateGroups() { + for element := a.callbacks.Front(); element != nil; element = element.Next() { + element.Value(a.providerTag) + } +} + +func (a *Adapter) Close() error { + if a.ticker != nil { + a.ticker.Stop() + } + outbounds := a.outbounds + a.outbounds = nil + var err error + for _, ob := range outbounds { + if err2 := a.outbound.Remove(ob.Tag()); err2 != nil { + err = E.Append(err, err2, func(err error) error { + return E.Cause(err, "close outbound [", ob.Tag(), "]") + }) + } + } + return err +} + +func (a *Adapter) loopCheck() { + if !a.enabled { + return + } + a.ticker = time.NewTicker(a.interval) + a.healthcheck(a.ctx) + for { + select { + case <-a.ctx.Done(): + return + case <-a.ticker.C: + a.healthcheck(a.ctx) + } + } +} + +func (a *Adapter) healthcheck(ctx context.Context) (map[string]uint16, error) { + result := make(map[string]uint16) + if a.checking.Swap(true) { + return result, nil + } + defer a.checking.Store(false) + b, _ := batch.New(ctx, batch.WithConcurrencyNum[any](10)) + var resultAccess sync.Mutex + checked := make(map[string]bool) + for _, detour := range a.outbounds { + tag := detour.Tag() + if checked[tag] { + continue + } + checked[tag] = true + b.Go(tag, func() (any, error) { + ctx, cancel := context.WithTimeout(a.ctx, a.timeout) + defer cancel() + t, err := urltest.URLTest(ctx, a.link, detour) + if err != nil { + a.logger.Debug("outbound ", tag, " unavailable: ", err) + a.history.DeleteURLTestHistory(tag) + } else { + a.logger.Debug("outbound ", tag, " available: ", t, "ms") + a.history.StoreURLTestHistory(tag, &adapter.URLTestHistory{ + Time: time.Now(), + Delay: t, + }) + resultAccess.Lock() + result[tag] = t + resultAccess.Unlock() + } + return nil, nil + }) + } + b.Wait() + return result, nil +} + +func (a *Adapter) removeUseless(newOpts []option.Outbound) { + if len(a.outbounds) == 0 { + return + } + exists := make(map[string]bool) + for i, opt := range newOpts { + var tag string + if opt.Tag != "" { + tag = F.ToString(a.providerTag, "/", opt.Tag) + } else { + tag = F.ToString(a.providerTag, "/", i) + } + exists[tag] = true + } + for _, opt := range a.outbounds { + if !exists[opt.Tag()] { + if err := a.outbound.Remove(opt.Tag()); err != nil { + a.logger.Error(err, "close outbound [", opt.Tag(), "]") + } + } + } +} diff --git a/adapter/provider/manager.go b/adapter/provider/manager.go new file mode 100644 index 00000000..563df8da --- /dev/null +++ b/adapter/provider/manager.go @@ -0,0 +1,157 @@ +package provider + +import ( + "context" + "io" + "os" + "sync" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/taskmonitor" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" +) + +var _ adapter.ProviderManager = (*Manager)(nil) + +type Manager struct { + logger log.ContextLogger + registry adapter.ProviderRegistry + access sync.Mutex + started bool + stage adapter.StartStage + providers []adapter.Provider + providerByTag map[string]adapter.Provider + wg sync.WaitGroup +} + +func NewManager(logger logger.ContextLogger, registry adapter.ProviderRegistry) *Manager { + return &Manager{ + logger: logger, + registry: registry, + providerByTag: make(map[string]adapter.Provider), + } +} + +func (m *Manager) Initialize() { +} + +func (m *Manager) Start(stage adapter.StartStage) error { + m.access.Lock() + if m.started && m.stage >= stage { + panic("already started") + } + m.started = true + m.stage = stage + providers := m.providers + m.access.Unlock() + for _, provider := range providers { + err := adapter.LegacyStart(provider, stage) + if err != nil { + return E.Cause(err, stage, " provider/", provider.Type(), "[", provider.Tag(), "]") + } + } + return nil +} + +func (m *Manager) Close() error { + monitor := taskmonitor.New(m.logger, C.StopTimeout) + m.access.Lock() + if !m.started { + m.access.Unlock() + return nil + } + m.started = false + providers := m.providers + m.providers = nil + m.access.Unlock() + var err error + for _, provider := range providers { + if closer, isCloser := provider.(io.Closer); isCloser { + monitor.Start("close provider/", provider.Type(), "[", provider.Tag(), "]") + err = E.Append(err, closer.Close(), func(err error) error { + return E.Cause(err, "close provider/", provider.Type(), "[", provider.Tag(), "]") + }) + monitor.Finish() + } + } + return nil +} + +func (m *Manager) Providers() []adapter.Provider { + m.access.Lock() + defer m.access.Unlock() + return m.providers +} + +func (m *Manager) Get(tag string) (adapter.Provider, bool) { + m.access.Lock() + provider, found := m.providerByTag[tag] + m.access.Unlock() + return provider, found +} + +func (m *Manager) Remove(tag string) error { + m.access.Lock() + provider, found := m.providerByTag[tag] + if !found { + m.access.Unlock() + return os.ErrInvalid + } + delete(m.providerByTag, tag) + index := common.Index(m.providers, func(it adapter.Provider) bool { + return it == provider + }) + if index == -1 { + panic("invalid provider index") + } + m.providers = append(m.providers[:index], m.providers[index+1:]...) + started := m.started + m.access.Unlock() + if started { + return common.Close(provider) + } + return nil +} + +func (m *Manager) Create(ctx context.Context, router adapter.Router, logFactory log.Factory, tag string, providerType string, options any) error { + if tag == "" { + return os.ErrInvalid + } + + provider, err := m.registry.CreateProvider(ctx, router, logFactory, tag, providerType, options) + if err != nil { + return err + } + m.access.Lock() + defer m.access.Unlock() + if m.started { + for _, stage := range adapter.ListStartStages { + err = adapter.LegacyStart(provider, stage) + if err != nil { + return E.Cause(err, stage, " provider/", provider.Type(), "[", provider.Tag(), "]") + } + } + } + if existsProvider, loaded := m.providerByTag[tag]; loaded { + if m.started { + err = common.Close(existsProvider) + if err != nil { + return E.Cause(err, "close provider", provider.Type(), "[", existsProvider.Tag(), "]") + } + } + existsIndex := common.Index(m.providers, func(it adapter.Provider) bool { + return it == existsProvider + }) + if existsIndex == -1 { + panic("invalid provider index") + } + m.providers = append(m.providers[:existsIndex], m.providers[existsIndex+1:]...) + } + m.providers = append(m.providers, provider) + m.providerByTag[tag] = provider + return nil +} diff --git a/adapter/provider/registry.go b/adapter/provider/registry.go new file mode 100644 index 00000000..5a484754 --- /dev/null +++ b/adapter/provider/registry.go @@ -0,0 +1,72 @@ +package provider + +import ( + "context" + "sync" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +type ConstructorFunc[T any] func(ctx context.Context, router adapter.Router, logFactory log.Factory, tag string, options T) (adapter.Provider, error) + +func Register[Options any](registry *Registry, providerType string, constructor ConstructorFunc[Options]) { + registry.register(providerType, func() any { + return new(Options) + }, func(ctx context.Context, router adapter.Router, logFactory log.Factory, tag string, rawOptions any) (adapter.Provider, error) { + var options *Options + if rawOptions != nil { + options = rawOptions.(*Options) + } + return constructor(ctx, router, logFactory, tag, common.PtrValueOrDefault(options)) + }) +} + +var _ adapter.ProviderRegistry = (*Registry)(nil) + +type ( + optionsConstructorFunc func() any + constructorFunc func(ctx context.Context, router adapter.Router, logFactory log.Factory, tag string, options any) (adapter.Provider, error) +) + +type Registry struct { + access sync.Mutex + optionsType map[string]optionsConstructorFunc + constructors map[string]constructorFunc +} + +func NewRegistry() *Registry { + return &Registry{ + optionsType: make(map[string]optionsConstructorFunc), + constructors: make(map[string]constructorFunc), + } +} + +func (r *Registry) CreateOptions(providerType string) (any, bool) { + r.access.Lock() + defer r.access.Unlock() + optionsConstructor, loaded := r.optionsType[providerType] + if !loaded { + return nil, false + } + return optionsConstructor(), true +} + +func (r *Registry) CreateProvider(ctx context.Context, router adapter.Router, logFactory log.Factory, tag string, providerType string, options any) (adapter.Provider, error) { + r.access.Lock() + defer r.access.Unlock() + constructor, loaded := r.constructors[providerType] + if !loaded { + return nil, E.New("provider type not found: '" + providerType + "'") + } + return constructor(ctx, router, logFactory, tag, options) +} + +func (r *Registry) register(providerType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) { + r.access.Lock() + defer r.access.Unlock() + r.optionsType[providerType] = optionsConstructor + r.constructors[providerType] = constructor +} diff --git a/box.go b/box.go index 789b8b11..f99dbdb2 100644 --- a/box.go +++ b/box.go @@ -12,6 +12,7 @@ import ( "github.com/sagernet/sing-box/adapter/endpoint" "github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/outbound" + "github.com/sagernet/sing-box/adapter/provider" boxService "github.com/sagernet/sing-box/adapter/service" "github.com/sagernet/sing-box/common/certificate" "github.com/sagernet/sing-box/common/dialer" @@ -44,6 +45,7 @@ type Box struct { endpoint *endpoint.Manager inbound *inbound.Manager outbound *outbound.Manager + provider *provider.Manager service *boxService.Manager dnsTransport *dns.TransportManager dnsRouter *dns.Router @@ -64,6 +66,7 @@ func Context( inboundRegistry adapter.InboundRegistry, outboundRegistry adapter.OutboundRegistry, endpointRegistry adapter.EndpointRegistry, + providerRegistry adapter.ProviderRegistry, dnsTransportRegistry adapter.DNSTransportRegistry, serviceRegistry adapter.ServiceRegistry, ) context.Context { @@ -82,6 +85,11 @@ func Context( ctx = service.ContextWith[option.EndpointOptionsRegistry](ctx, endpointRegistry) ctx = service.ContextWith[adapter.EndpointRegistry](ctx, endpointRegistry) } + if service.FromContext[option.ProviderOptionsRegistry](ctx) == nil || + service.FromContext[adapter.ProviderRegistry](ctx) == nil { + ctx = service.ContextWith[option.ProviderOptionsRegistry](ctx, providerRegistry) + ctx = service.ContextWith[adapter.ProviderRegistry](ctx, providerRegistry) + } if service.FromContext[adapter.DNSTransportRegistry](ctx) == nil { ctx = service.ContextWith[option.DNSTransportOptionsRegistry](ctx, dnsTransportRegistry) ctx = service.ContextWith[adapter.DNSTransportRegistry](ctx, dnsTransportRegistry) @@ -104,6 +112,7 @@ func New(options Options) (*Box, error) { endpointRegistry := service.FromContext[adapter.EndpointRegistry](ctx) inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx) + providerRegistry := service.FromContext[adapter.ProviderRegistry](ctx) dnsTransportRegistry := service.FromContext[adapter.DNSTransportRegistry](ctx) serviceRegistry := service.FromContext[adapter.ServiceRegistry](ctx) @@ -116,6 +125,9 @@ func New(options Options) (*Box, error) { if outboundRegistry == nil { return nil, E.New("missing outbound registry in context") } + if providerRegistry == nil { + return nil, E.New("missing provider registry in context") + } if dnsTransportRegistry == nil { return nil, E.New("missing DNS transport registry in context") } @@ -181,11 +193,13 @@ func New(options Options) (*Box, error) { endpointManager := endpoint.NewManager(logFactory.NewLogger("endpoint"), endpointRegistry) inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry, endpointManager) outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, endpointManager, routeOptions.Final) + providerManager := provider.NewManager(logFactory.NewLogger("provider"), providerRegistry) dnsTransportManager := dns.NewTransportManager(logFactory.NewLogger("dns/transport"), dnsTransportRegistry, outboundManager, dnsOptions.Final) serviceManager := boxService.NewManager(logFactory.NewLogger("service"), serviceRegistry) service.MustRegister[adapter.EndpointManager](ctx, endpointManager) service.MustRegister[adapter.InboundManager](ctx, inboundManager) service.MustRegister[adapter.OutboundManager](ctx, outboundManager) + service.MustRegister[adapter.ProviderManager](ctx, providerManager) service.MustRegister[adapter.DNSTransportManager](ctx, dnsTransportManager) service.MustRegister[adapter.ServiceManager](ctx, serviceManager) dnsRouter := dns.NewRouter(ctx, logFactory, dnsOptions) @@ -276,6 +290,10 @@ func New(options Options) (*Box, error) { return nil, E.Cause(err, "initialize inbound[", i, "]") } } + options.Outbounds = append(options.Outbounds, option.Outbound{ + Tag: "Compatible", + Type: C.TypeDirect, + }) for i, outboundOptions := range options.Outbounds { var tag string if outboundOptions.Tag != "" { @@ -302,6 +320,25 @@ func New(options Options) (*Box, error) { return nil, E.Cause(err, "initialize outbound[", i, "]") } } + for i, providerOptions := range options.Providers { + var tag string + if providerOptions.Tag != "" { + tag = providerOptions.Tag + } else { + tag = F.ToString(i) + } + err = providerManager.Create( + ctx, + router, + logFactory, + tag, + providerOptions.Type, + providerOptions.Options, + ) + if err != nil { + return nil, E.Cause(err, "initialize provider[", i, "]") + } + } for i, serviceOptions := range options.Services { var tag string if serviceOptions.Tag != "" { @@ -392,6 +429,7 @@ func New(options Options) (*Box, error) { endpoint: endpointManager, inbound: inboundManager, outbound: outboundManager, + provider: providerManager, dnsTransport: dnsTransportManager, service: serviceManager, dnsRouter: dnsRouter, @@ -455,11 +493,11 @@ func (s *Box) preStart() error { if err != nil { return err } - err = adapter.Start(s.logger, adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service) + err = adapter.Start(s.logger, adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.provider, s.service) if err != nil { return err } - err = adapter.Start(s.logger, adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router) + err = adapter.Start(s.logger, adapter.StartStateStart, s.outbound, s.provider, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router) if err != nil { return err } @@ -479,7 +517,7 @@ func (s *Box) start() error { if err != nil { return err } - err = adapter.Start(s.logger, adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint, s.service) + err = adapter.Start(s.logger, adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint, s.provider, s.service) if err != nil { return err } @@ -487,7 +525,7 @@ func (s *Box) start() error { if err != nil { return err } - err = adapter.Start(s.logger, adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.service) + err = adapter.Start(s.logger, adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint, s.provider, s.service) if err != nil { return err } @@ -513,6 +551,7 @@ func (s *Box) Close() error { {"service", s.service}, {"endpoint", s.endpoint}, {"inbound", s.inbound}, + {"provider", s.provider}, {"outbound", s.outbound}, {"router", s.router}, {"connection", s.connection}, diff --git a/common/cloudflare/api.go b/common/cloudflare/api.go index 2bd63343..85d33252 100644 --- a/common/cloudflare/api.go +++ b/common/cloudflare/api.go @@ -1,15 +1,12 @@ package cloudflare import ( + "bytes" "context" "encoding/json" "fmt" - "io" "net/http" - "strings" "time" - - "github.com/tidwall/gjson" ) type CloudflareApi struct { @@ -25,50 +22,93 @@ func NewCloudflareApi(opts ...CloudflareApiOption) *CloudflareApi { } func (api *CloudflareApi) CreateProfile(ctx context.Context, publicKey string) (*CloudflareProfile, error) { - request, err := http.NewRequest("POST", "https://api.cloudflareclient.com/v0i1909051800/reg", strings.NewReader( - fmt.Sprintf( - "{\"install_id\":\"\",\"tos\":\"%s\",\"key\":\"%s\",\"fcm_token\":\"\",\"type\":\"ios\",\"locale\":\"en_US\"}", - time.Now().Format("2006-01-02T15:04:05.000Z"), - publicKey, - ), - )) + serial, err := GenerateRandomAndroidSerial() + if err != nil { + return nil, fmt.Errorf("failed to generate serial: %v", err) + } + data := Registration{ + Key: publicKey, + InstallID: "", + FcmToken: "", + Tos: TimeAsCfString(time.Now()), + Model: "PC", + Serial: serial, + OsVersion: "", + KeyType: KeyTypeWg, + TunType: TunTypeWg, + Locale: "en-US", + } + jsonData, err := json.Marshal(data) + if err != nil { + return nil, fmt.Errorf("failed to marshal json: %v", err) + } + request, err := http.NewRequest("POST", ApiUrl+"/"+ApiVersion+"/reg", bytes.NewBuffer(jsonData)) if err != nil { return nil, err } + for k, v := range Headers { + request.Header.Set(k, v) + } response, err := api.client.Do(request.WithContext(ctx)) if err != nil { return nil, err } defer response.Body.Close() - if response.StatusCode != 200 { - return nil, fmt.Errorf("status code is not 200") - } - content, err := io.ReadAll(response.Body) - if err != nil { - return nil, err + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to register: %v", response.StatusCode) } profile := new(CloudflareProfile) - return profile, json.NewDecoder(strings.NewReader(gjson.Get(string(content), "result").Raw)).Decode(profile) + return profile, json.NewDecoder(response.Body).Decode(profile) } -func (api *CloudflareApi) GetProfile(ctx context.Context, authToken string, id string) (*CloudflareProfile, error) { - request, err := http.NewRequest("GET", "https://api.cloudflareclient.com/v0i1909051800/reg/"+id, nil) +func (api *CloudflareApi) EnrollKey(ctx context.Context, authToken string, id string, keyType, tunType, publicKey string) (*CloudflareProfile, error) { + deviceUpdate := DeviceUpdate{ + Name: "PC", + Key: publicKey, + KeyType: keyType, + TunType: tunType, + } + jsonData, err := json.Marshal(deviceUpdate) + if err != nil { + return nil, fmt.Errorf("failed to marshal json: %v", err) + } + request, err := http.NewRequest("PATCH", ApiUrl+"/"+ApiVersion+"/reg/"+id, bytes.NewBuffer(jsonData)) if err != nil { return nil, err } + for k, v := range Headers { + request.Header.Set(k, v) + } request.Header.Set("Authorization", "Bearer "+authToken) response, err := api.client.Do(request.WithContext(ctx)) if err != nil { return nil, err } defer response.Body.Close() - if response.StatusCode != 200 { - return nil, fmt.Errorf("status code is not 200") + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to enroll key: %v", response.StatusCode) } - content, err := io.ReadAll(response.Body) + profile := new(CloudflareProfile) + return profile, json.NewDecoder(response.Body).Decode(profile) +} + +func (api *CloudflareApi) GetProfile(ctx context.Context, authToken string, id string) (*CloudflareProfile, error) { + request, err := http.NewRequest("GET", ApiUrl+"/"+ApiVersion+"/reg/"+id, nil) if err != nil { return nil, err } + for k, v := range Headers { + request.Header.Set(k, v) + } + request.Header.Set("Authorization", "Bearer "+authToken) + response, err := api.client.Do(request.WithContext(ctx)) + if err != nil { + return nil, err + } + defer response.Body.Close() + if response.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to get profile: %v", response.StatusCode) + } profile := new(CloudflareProfile) - return profile, json.NewDecoder(strings.NewReader(gjson.Get(string(content), "result").Raw)).Decode(profile) + return profile, json.NewDecoder(response.Body).Decode(profile) } diff --git a/common/cloudflare/constant.go b/common/cloudflare/constant.go new file mode 100644 index 00000000..e5108b07 --- /dev/null +++ b/common/cloudflare/constant.go @@ -0,0 +1,25 @@ +package cloudflare + +const ( + ApiUrl = "https://api.cloudflareclient.com" + ApiVersion = "v0a4471" + ConnectSNI = "consumer-masque.cloudflareclient.com" + // unused for now + ZeroTierSNI = "zt-masque.cloudflareclient.com" + ConnectURI = "https://cloudflareaccess.com" + DefaultModel = "PC" + KeyTypeWg = "curve25519" + TunTypeWg = "wireguard" + KeyTypeMasque = "secp256r1" + TunTypeMasque = "masque" + DefaultLocale = "en_US" + DefaultEndpointH2V4 = "162.159.198.2" + DefaultEndpointH2V6 = "" +) + +var Headers = map[string]string{ + "User-Agent": "WARP for Android", + "CF-Client-Version": "a-6.35-4471", + "Content-Type": "application/json; charset=UTF-8", + "Connection": "Keep-Alive", +} diff --git a/common/cloudflare/models.go b/common/cloudflare/models.go new file mode 100644 index 00000000..591b0338 --- /dev/null +++ b/common/cloudflare/models.go @@ -0,0 +1,132 @@ +package cloudflare + +import ( + "strings" +) + +type Registration struct { + Key string `json:"key"` + InstallID string `json:"install_id"` + FcmToken string `json:"fcm_token"` + Tos string `json:"tos"` + Model string `json:"model"` + Serial string `json:"serial_number"` + OsVersion string `json:"os_version"` + KeyType string `json:"key_type"` + TunType string `json:"tunnel_type"` + Locale string `json:"locale"` +} + +type CloudflareProfile struct { + ID string `json:"id"` + Type string `json:"type"` + Model string `json:"model"` + Name string `json:"name"` + Key string `json:"key"` + KeyType string `json:"key_type"` + TunType string `json:"tunnel_type"` + Account Account `json:"account"` + Config Config `json:"config"` + // WarpEnabled not set for ZeroTier + WarpEnabled bool `json:"warp_enabled,omitempty"` + // Waitlist not set for ZeroTier + Waitlist bool `json:"waitlist_enabled,omitempty"` + Created string `json:"created"` + Updated string `json:"updated"` + // Tos not set for ZeroTier + Tos string `json:"tos,omitempty"` + // Place not set for ZeroTier + Place int `json:"place,omitempty"` + Locale string `json:"locale"` + // Enabled not set for ZeroTier + Enabled bool `json:"enabled,omitempty"` + InstallID string `json:"install_id"` + // Token only set for /reg call + Token string `json:"token,omitempty"` + FcmToken string `json:"fcm_token"` + // SerialNumber not set for ZeroTier + SerialNumber string `json:"serial_number,omitempty"` + Policy Policy `json:"policy"` +} + +type Account struct { + ID string `json:"id"` + AccountType string `json:"account_type"` + // Created not set for ZeroTier + Created string `json:"created,omitempty"` + // Updated not set for ZeroTier + Updated string `json:"updated,omitempty"` + // Managed only set for ZeroTier + Managed string `json:"managed,omitempty"` + // Organization only set for ZeroTier + Organization string `json:"organization,omitempty"` + // PremiumData not set for ZeroTier + PremiumData int `json:"premium_data,omitempty"` + // Quota not set for ZeroTier + Quota int `json:"quota,omitempty"` + // WarpPlus not set for ZeroTier + WarpPlus bool `json:"warp_plus,omitempty"` + // ReferralCode not set for ZeroTier + ReferralCount int `json:"referral_count,omitempty"` + // ReferralRenewalCount not set for ZeroTier + ReferralRenewalCount int `json:"referral_renewal_countdown,omitempty"` + // Role not set for ZeroTier + Role string `json:"role,omitempty"` + // License not set for ZeroTier + License string `json:"license,omitempty"` +} + +type Config struct { + ClientID string `json:"client_id"` + Peers []Peer `json:"peers"` + Interface struct { + Addresses struct { + V4 string `json:"v4"` + V6 string `json:"v6"` + } `json:"addresses"` + } `json:"interface"` + Services struct { + HTTPProxy string `json:"http_proxy"` + } `json:"services"` +} + +type Peer struct { + PublicKey string `json:"public_key"` + Endpoint struct { + V4 string `json:"v4"` + V6 string `json:"v6"` + Host string `json:"host"` + Ports []int `json:"ports"` + } `json:"endpoint"` +} + +type Policy struct { + TunnelProtocol string `json:"tunnel_protocol"` +} + +type DeviceUpdate struct { + Key string `json:"key"` + KeyType string `json:"key_type"` + TunType string `json:"tunnel_type"` + Name string `json:"name,omitempty"` +} + +type APIError struct { + Result interface{} `json:"result"` + Success bool `json:"success"` + Errors []ErrorInfo `json:"errors"` + Messages []string `json:"messages"` +} + +type ErrorInfo struct { + Code int `json:"code"` + Message string `json:"message"` +} + +func (e *APIError) Error() string { + errors := make([]string, len(e.Errors)) + for i, err := range e.Errors { + errors[i] = err.Message + } + return strings.Join(errors, ",") +} diff --git a/common/cloudflare/option.go b/common/cloudflare/option.go index 929e91b7..1ba4f829 100644 --- a/common/cloudflare/option.go +++ b/common/cloudflare/option.go @@ -4,12 +4,14 @@ import ( "context" "net" "net/http" + "time" ) type CloudflareApiOption func(api *CloudflareApi) func WithDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) CloudflareApiOption { return func(api *CloudflareApi) { + api.client.Timeout = 30 * time.Second api.client.Transport = &http.Transport{ DialContext: dialContext, } diff --git a/common/cloudflare/profile.go b/common/cloudflare/profile.go deleted file mode 100644 index bc011e06..00000000 --- a/common/cloudflare/profile.go +++ /dev/null @@ -1,64 +0,0 @@ -package cloudflare - -import "time" - -type CloudflareProfile struct { - ID string `json:"id"` - Type string `json:"type"` - Name string `json:"name"` - Key string `json:"key"` - Account struct { - ID string `json:"id"` - AccountType string `json:"account_type"` - Created time.Time `json:"created"` - Updated time.Time `json:"updated"` - PremiumData int `json:"premium_data"` - Quota int `json:"quota"` - Usage int `json:"usage"` - WARPPlus bool `json:"warp_plus"` - ReferralCount int `json:"referral_count"` - ReferralRenewalCountdown int `json:"referral_renewal_countdown"` - Role string `json:"role"` - License string `json:"license"` - TTL time.Time `json:"ttl"` - } `json:"account"` - Config struct { - ClientID string `json:"client_id"` - Interface struct { - Addresses struct { - V4 string `json:"v4"` - V6 string `json:"v6"` - } `json:"addresses"` - } `json:"interface"` - Peers []struct { - PublicKey string `json:"public_key"` - Endpoint struct { - V4 string `json:"v4"` - V6 string `json:"v6"` - Host string `json:"host"` - Ports []int `json:"ports"` - } `json:"endpoint"` - } `json:"peers"` - Services struct { - HTTPProxy string `json:"http_proxy"` - } `json:"services"` - Metrics struct { - Ping int `json:"ping"` - Report int `json:"report"` - } `json:"metrics"` - } `json:"config"` - Token string `json:"token"` - WARPEnabled bool `json:"warp_enabled"` - WaitlistEnabled bool `json:"waitlist_enabled"` - Created time.Time `json:"created"` - Updated time.Time `json:"updated"` - Tos time.Time `json:"tos"` - Place int `json:"place"` - Locale string `json:"locale"` - Enabled bool `json:"enabled"` - InstallID string `json:"install_id"` - FcmToken string `json:"fcm_token"` - Policy struct { - TunnelProtocol string `json:"tunnel_protocol"` - } `json:"policy"` -} diff --git a/common/cloudflare/utils.go b/common/cloudflare/utils.go new file mode 100644 index 00000000..45aa1d41 --- /dev/null +++ b/common/cloudflare/utils.go @@ -0,0 +1,19 @@ +package cloudflare + +import ( + "crypto/rand" + "encoding/hex" + "time" +) + +func GenerateRandomAndroidSerial() (string, error) { + serial := make([]byte, 8) + if _, err := rand.Read(serial); err != nil { + return "", err + } + return hex.EncodeToString(serial), nil +} + +func TimeAsCfString(t time.Time) string { + return t.Format("2006-01-02T15:04:05.000-07:00") +} diff --git a/common/interrupt/context.go b/common/interrupt/context.go index 44726b2d..ba91601a 100644 --- a/common/interrupt/context.go +++ b/common/interrupt/context.go @@ -11,3 +11,13 @@ func ContextWithIsExternalConnection(ctx context.Context) context.Context { func IsExternalConnectionFromContext(ctx context.Context) bool { return ctx.Value(contextKeyIsExternalConnection{}) != nil } + +type contextKeyIsProviderConnection struct{} + +func ContextWithIsProviderConnection(ctx context.Context) context.Context { + return context.WithValue(ctx, contextKeyIsProviderConnection{}, true) +} + +func IsProviderConnectionFromContext(ctx context.Context) bool { + return ctx.Value(contextKeyIsProviderConnection{}) != nil +} diff --git a/common/interrupt/group.go b/common/interrupt/group.go index bd3fbb0a..ae9095f8 100644 --- a/common/interrupt/group.go +++ b/common/interrupt/group.go @@ -17,30 +17,31 @@ type Group struct { type groupConnItem struct { conn io.Closer isExternal bool + isProvider bool } func NewGroup() *Group { return &Group{} } -func (g *Group) NewConn(conn net.Conn, isExternal bool) net.Conn { +func (g *Group) NewConn(conn net.Conn, isExternal bool, isProvider bool) net.Conn { g.access.Lock() defer g.access.Unlock() - item := g.connections.PushBack(&groupConnItem{conn, isExternal}) + item := g.connections.PushBack(&groupConnItem{conn, isExternal, isProvider}) return &Conn{Conn: conn, group: g, element: item} } -func (g *Group) NewPacketConn(conn net.PacketConn, isExternal bool) net.PacketConn { +func (g *Group) NewPacketConn(conn net.PacketConn, isExternal bool, isProvider bool) net.PacketConn { g.access.Lock() defer g.access.Unlock() - item := g.connections.PushBack(&groupConnItem{conn, isExternal}) + item := g.connections.PushBack(&groupConnItem{conn, isExternal, isProvider}) return &PacketConn{PacketConn: conn, group: g, element: item} } -func (g *Group) NewSingPacketConn(conn N.PacketConn, isExternal bool) N.PacketConn { +func (g *Group) NewSingPacketConn(conn N.PacketConn, isExternal bool, isProvider bool) N.PacketConn { g.access.Lock() defer g.access.Unlock() - item := g.connections.PushBack(&groupConnItem{conn, isExternal}) + item := g.connections.PushBack(&groupConnItem{conn, isExternal, isProvider}) return &SingPacketConn{PacketConn: conn, group: g, element: item} } diff --git a/common/kmutex/mutex.go b/common/kmutex/mutex.go index 9767959f..6e2e4ec7 100644 --- a/common/kmutex/mutex.go +++ b/common/kmutex/mutex.go @@ -12,7 +12,6 @@ type klock struct { ref uint64 } -// Create new Kmutex func New[T comparable]() *Kmutex[T] { l := sync.Mutex{} return &Kmutex[T]{ @@ -21,7 +20,6 @@ func New[T comparable]() *Kmutex[T] { } } -// Unlock Kmutex by unique ID func (km *Kmutex[T]) Unlock(key T) { km.l.Lock() defer km.l.Unlock() @@ -37,7 +35,6 @@ func (km *Kmutex[T]) Unlock(key T) { kl.cond.Signal() } -// Lock Kmutex by unique ID func (km *Kmutex[T]) Lock(key T) { km.l.Lock() defer km.l.Unlock() diff --git a/common/tls/masque_client.go b/common/tls/masque_client.go new file mode 100644 index 00000000..d4e23940 --- /dev/null +++ b/common/tls/masque_client.go @@ -0,0 +1,74 @@ +package tls + +import ( + "context" + "crypto/ecdsa" + "crypto/tls" + "crypto/x509" + "time" + + "github.com/sagernet/quic-go/http3" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" +) + +func NewMASQUEClient(ctx context.Context, logger logger.ContextLogger, serverName string, cert [][]byte, privateKey *ecdsa.PrivateKey, peerPublicKey *ecdsa.PublicKey, options option.MASQUEOutboundTLSOptions) (Config, error) { + var tlsConfig tls.Config + tlsConfig.ServerName = serverName + tlsConfig.InsecureSkipVerify = true + tlsConfig.NextProtos = []string{http3.NextProtoH3} + tlsConfig.Certificates = []tls.Certificate{ + { + Certificate: cert, + PrivateKey: privateKey, + }, + } + if options.CipherSuites != nil { + find: + for _, cipherSuite := range options.CipherSuites { + for _, tlsCipherSuite := range tls.CipherSuites() { + if cipherSuite == tlsCipherSuite.Name { + tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID) + continue find + } + } + return nil, E.New("unknown cipher_suite: ", cipherSuite) + } + } + for _, curve := range options.CurvePreferences { + tlsConfig.CurvePreferences = append(tlsConfig.CurvePreferences, tls.CurveID(curve)) + } + if !options.Insecure { + tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { + if len(rawCerts) == 0 { + return nil + } + cert, err := x509.ParseCertificate(rawCerts[0]) + if err != nil { + return err + } + if _, ok := cert.PublicKey.(*ecdsa.PublicKey); !ok { + return x509.ErrUnsupportedAlgorithm + } + if !cert.PublicKey.(*ecdsa.PublicKey).Equal(peerPublicKey) { + return x509.CertificateInvalidError{Cert: cert, Reason: 10, Detail: "remote endpoint has a different public key than what we trust in config.json"} + } + return nil + } + } + var config Config = &STDClientConfig{ctx, &tlsConfig, options.Fragment, time.Duration(options.FragmentFallbackDelay), options.RecordFragment} + if options.KernelRx || options.KernelTx { + if !C.IsLinux { + return nil, E.New("kTLS is only supported on Linux") + } + config = &KTLSClientConfig{ + Config: config, + logger: logger, + kernelTx: options.KernelTx, + kernelRx: options.KernelRx, + } + } + return config, nil +} diff --git a/common/utils.go b/common/utils.go new file mode 100644 index 00000000..9b8a6ae6 --- /dev/null +++ b/common/utils.go @@ -0,0 +1,68 @@ +package common + +import ( + "encoding/base64" + "reflect" + "regexp" + "strconv" + "strings" + "time" + + "github.com/sagernet/sing/common/json/badoption" +) + +func StringToType[T any](str string) T { + var value T + v := reflect.ValueOf(&value).Elem() + switch any(value).(type) { + case badoption.Duration: + d, err := time.ParseDuration(str) + if err != nil { + v.SetInt(StringToType[int64](str)) + } else { + v.Set(reflect.ValueOf(d)) + } + return value + case badoption.HTTPHeader: + headers := badoption.HTTPHeader{} + reg := regexp.MustCompile(`^[ \t]*?(\S+?):[ \t]*?(\S+?)[ \t]*?$`) + for _, header := range strings.Split(str, "\n") { + result := reg.FindStringSubmatch(header) + if result != nil { + key := result[1] + headers[key] = strings.Split(result[2], ",") + } + } + v.Set(reflect.ValueOf(headers)) + return value + } + switch v.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i, _ := strconv.ParseInt(str, 10, 64) + v.SetInt(i) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + i, _ := strconv.ParseUint(str, 10, 64) + v.SetUint(i) + case reflect.Float32, reflect.Float64: + f, _ := strconv.ParseFloat(str, 64) + v.SetFloat(f) + case reflect.Bool: + b, _ := strconv.ParseBool(str) + v.SetBool(b) + default: + panic("unsupported type") + } + return value +} + +func DecodeBase64URLSafe(content string) (string, error) { + s := strings.ReplaceAll(content, " ", "-") + s = strings.ReplaceAll(s, "/", "_") + s = strings.ReplaceAll(s, "+", "-") + s = strings.ReplaceAll(s, "=", "") + result, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return content, nil + } + return string(result), nil +} diff --git a/constant/provider.go b/constant/provider.go new file mode 100644 index 00000000..252b1af5 --- /dev/null +++ b/constant/provider.go @@ -0,0 +1,20 @@ +package constant + +const ( + ProviderTypeInline = "inline" + ProviderTypeLocal = "local" + ProviderTypeRemote = "remote" +) + +func ProviderDisplayName(providerType string) string { + switch providerType { + case ProviderTypeInline: + return "Inline" + case ProviderTypeLocal: + return "Local" + case ProviderTypeRemote: + return "Remote" + default: + return "Unknown" + } +} diff --git a/constant/proxy.go b/constant/proxy.go index 527dbfb4..d8212baa 100644 --- a/constant/proxy.go +++ b/constant/proxy.go @@ -16,6 +16,9 @@ const ( TypeNaive = "naive" TypeWireGuard = "wireguard" TypeWARP = "warp" + TypeMASQUE = "masque" + TypeMTProxy = "mtproxy" + TypeParser = "parser" TypeHysteria = "hysteria" TypeTor = "tor" TypeSSH = "ssh" @@ -27,8 +30,8 @@ const ( TypeTUIC = "tuic" TypeHysteria2 = "hysteria2" TypeBond = "bond" - TypeTunnelServer = "tunnel-server" - TypeTunnelClient = "tunnel-client" + TypeVPNServer = "vpn-server" + TypeVPNClient = "vpn-client" TypeTailscale = "tailscale" TypeConnectionLimiter = "connection-limiter" TypeBandwidthLimiter = "bandwidth-limiter" @@ -47,7 +50,7 @@ const ( ) const ( - TypeFailover = "failover" + TypeFallback = "fallback" TypeSelector = "selector" TypeURLTest = "urltest" ) @@ -84,6 +87,12 @@ func ProxyDisplayName(proxyType string) string { return "WireGuard" case TypeWARP: return "WARP" + case TypeMASQUE: + return "MASQUE" + case TypeMTProxy: + return "MTProxy" + case TypeParser: + return "Parser" case TypeHysteria: return "Hysteria" case TypeTor: @@ -106,18 +115,18 @@ func ProxyDisplayName(proxyType string) string { return "Mieru" case TypeAnyTLS: return "AnyTLS" - case TypeFailover: - return "Failover" + case TypeFallback: + return "Fallback" case TypeTailscale: return "Tailscale" case TypeSelector: return "Selector" case TypeURLTest: return "URLTest" - case TypeTunnelClient: - return "Tunnel client" - case TypeTunnelServer: - return "Tunnel server" + case TypeVPNClient: + return "VPN Client" + case TypeVPNServer: + return "VPN Server" default: return "Unknown" } diff --git a/constant/warp.go b/constant/warp.go deleted file mode 100644 index 038ce346..00000000 --- a/constant/warp.go +++ /dev/null @@ -1,20 +0,0 @@ -package constant - -type WARPConfig struct { - PrivateKey string `json:"private_key"` - Interface struct { - Addresses struct { - V4 string `json:"v4"` - V6 string `json:"v6"` - } `json:"addresses"` - } `json:"interface"` - Peers []struct { - PublicKey string `json:"public_key"` - Endpoint struct { - V4 string `json:"v4"` - V6 string `json:"v6"` - Host string `json:"host"` - Ports []int `json:"ports"` - } `json:"endpoint"` - } `json:"peers"` -} diff --git a/dns/transport/base.go b/dns/transport/base.go deleted file mode 100644 index 06e41fd0..00000000 --- a/dns/transport/base.go +++ /dev/null @@ -1,145 +0,0 @@ -package transport - -import ( - "context" - "os" - "sync" - - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/dns" - E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" -) - -type TransportState int - -const ( - StateNew TransportState = iota - StateStarted - StateClosing - StateClosed -) - -var ( - ErrTransportClosed = os.ErrClosed - ErrConnectionReset = E.New("connection reset") -) - -type BaseTransport struct { - dns.TransportAdapter - Logger logger.ContextLogger - - mutex sync.Mutex - state TransportState - inFlight int32 - queriesComplete chan struct{} - closeCtx context.Context - closeCancel context.CancelFunc -} - -func NewBaseTransport(adapter dns.TransportAdapter, logger logger.ContextLogger) *BaseTransport { - ctx, cancel := context.WithCancel(context.Background()) - return &BaseTransport{ - TransportAdapter: adapter, - Logger: logger, - state: StateNew, - closeCtx: ctx, - closeCancel: cancel, - } -} - -func (t *BaseTransport) State() TransportState { - t.mutex.Lock() - defer t.mutex.Unlock() - return t.state -} - -func (t *BaseTransport) SetStarted() error { - t.mutex.Lock() - defer t.mutex.Unlock() - switch t.state { - case StateNew: - t.state = StateStarted - return nil - case StateStarted: - return nil - default: - return ErrTransportClosed - } -} - -func (t *BaseTransport) BeginQuery() bool { - t.mutex.Lock() - defer t.mutex.Unlock() - if t.state != StateStarted { - return false - } - t.inFlight++ - return true -} - -func (t *BaseTransport) EndQuery() { - t.mutex.Lock() - if t.inFlight > 0 { - t.inFlight-- - } - if t.inFlight == 0 && t.queriesComplete != nil { - close(t.queriesComplete) - t.queriesComplete = nil - } - t.mutex.Unlock() -} - -func (t *BaseTransport) CloseContext() context.Context { - return t.closeCtx -} - -func (t *BaseTransport) Shutdown(ctx context.Context) error { - t.mutex.Lock() - - if t.state >= StateClosing { - t.mutex.Unlock() - return nil - } - - if t.state == StateNew { - t.state = StateClosed - t.mutex.Unlock() - t.closeCancel() - return nil - } - - t.state = StateClosing - - if t.inFlight == 0 { - t.state = StateClosed - t.mutex.Unlock() - t.closeCancel() - return nil - } - - t.queriesComplete = make(chan struct{}) - queriesComplete := t.queriesComplete - t.mutex.Unlock() - - t.closeCancel() - - select { - case <-queriesComplete: - t.mutex.Lock() - t.state = StateClosed - t.mutex.Unlock() - return nil - case <-ctx.Done(): - t.mutex.Lock() - t.state = StateClosed - t.mutex.Unlock() - return ctx.Err() - } -} - -func (t *BaseTransport) Close() error { - ctx, cancel := context.WithTimeout(context.Background(), C.TCPTimeout) - defer cancel() - return t.Shutdown(ctx) -} diff --git a/dns/transport/conn_pool.go b/dns/transport/conn_pool.go new file mode 100644 index 00000000..6161e9bd --- /dev/null +++ b/dns/transport/conn_pool.go @@ -0,0 +1,547 @@ +package transport + +import ( + "context" + "net" + "sync" + "time" + + "github.com/sagernet/sing/common/x/list" +) + +type ConnPoolMode int + +const ( + ConnPoolSingle ConnPoolMode = iota + ConnPoolOrdered +) + +type ConnPoolOptions[T comparable] struct { + Mode ConnPoolMode + IsAlive func(T) bool + Close func(T, error) +} + +type ConnPool[T comparable] struct { + options ConnPoolOptions[T] + + access sync.Mutex + closed bool + state *connPoolState[T] +} + +type connPoolState[T comparable] struct { + ctx context.Context + cancel context.CancelCauseFunc + + all map[T]struct{} + + idle list.List[T] + idleElements map[T]*list.Element[T] + + shared T + hasShared bool + sharedClaimed bool + sharedCtx context.Context + sharedCancel context.CancelCauseFunc + + connecting *connPoolConnect[T] +} + +type connPoolConnect[T comparable] struct { + done chan struct{} + err error +} + +type connPoolDialContext struct { + context.Context + parent context.Context +} + +func (c connPoolDialContext) Deadline() (time.Time, bool) { + return c.parent.Deadline() +} + +func (c connPoolDialContext) Value(key any) any { + return c.parent.Value(key) +} + +func NewConnPool[T comparable](options ConnPoolOptions[T]) *ConnPool[T] { + return &ConnPool[T]{ + options: options, + state: newConnPoolState[T](options.Mode), + } +} + +func newConnPoolState[T comparable](mode ConnPoolMode) *connPoolState[T] { + ctx, cancel := context.WithCancelCause(context.Background()) + state := &connPoolState[T]{ + ctx: ctx, + cancel: cancel, + all: make(map[T]struct{}), + } + if mode == ConnPoolOrdered { + state.idleElements = make(map[T]*list.Element[T]) + } + return state +} + +func (p *ConnPool[T]) Acquire(ctx context.Context, dial func(context.Context) (T, error)) (T, bool, error) { + switch p.options.Mode { + case ConnPoolSingle: + conn, _, created, err := p.acquireShared(ctx, dial) + return conn, created, err + case ConnPoolOrdered: + return p.acquireOrdered(ctx, dial) + default: + var zero T + return zero, false, net.ErrClosed + } +} + +func (p *ConnPool[T]) AcquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) { + if p.options.Mode != ConnPoolSingle { + var zero T + return zero, nil, false, net.ErrClosed + } + return p.acquireShared(ctx, dial) +} + +func (p *ConnPool[T]) Release(conn T, reuse bool) { + var ( + closeConn bool + closeErr error + ) + + p.access.Lock() + if p.closed || p.state == nil { + closeConn = true + closeErr = net.ErrClosed + p.access.Unlock() + if closeConn { + p.options.Close(conn, closeErr) + } + return + } + + currentState := p.state + _, tracked := currentState.all[conn] + if !tracked { + closeConn = true + closeErr = p.closeCause(currentState) + p.access.Unlock() + if closeConn { + p.options.Close(conn, closeErr) + } + return + } + + if !reuse || !p.options.IsAlive(conn) { + delete(currentState.all, conn) + switch p.options.Mode { + case ConnPoolSingle: + if currentState.hasShared && currentState.shared == conn { + var zero T + currentState.shared = zero + currentState.hasShared = false + currentState.sharedClaimed = false + currentState.sharedCtx = nil + if currentState.sharedCancel != nil { + currentState.sharedCancel(net.ErrClosed) + currentState.sharedCancel = nil + } + } + case ConnPoolOrdered: + if element, loaded := currentState.idleElements[conn]; loaded { + currentState.idle.Remove(element) + delete(currentState.idleElements, conn) + } + } + closeConn = true + closeErr = net.ErrClosed + p.access.Unlock() + if closeConn { + p.options.Close(conn, closeErr) + } + return + } + + if p.options.Mode == ConnPoolOrdered { + if _, loaded := currentState.idleElements[conn]; !loaded { + currentState.idleElements[conn] = currentState.idle.PushBack(conn) + } + } + p.access.Unlock() +} + +func (p *ConnPool[T]) Invalidate(conn T, cause error) { + p.access.Lock() + if p.closed || p.state == nil { + p.access.Unlock() + p.options.Close(conn, cause) + return + } + + currentState := p.state + _, tracked := currentState.all[conn] + if !tracked { + p.access.Unlock() + return + } + + delete(currentState.all, conn) + switch p.options.Mode { + case ConnPoolSingle: + if currentState.hasShared && currentState.shared == conn { + var zero T + currentState.shared = zero + currentState.hasShared = false + currentState.sharedClaimed = false + currentState.sharedCtx = nil + if currentState.sharedCancel != nil { + currentState.sharedCancel(cause) + currentState.sharedCancel = nil + } + } + case ConnPoolOrdered: + if element, loaded := currentState.idleElements[conn]; loaded { + currentState.idle.Remove(element) + delete(currentState.idleElements, conn) + } + } + p.access.Unlock() + + p.options.Close(conn, cause) +} + +func (p *ConnPool[T]) Reset() { + p.access.Lock() + if p.closed { + p.access.Unlock() + return + } + + oldState := p.state + p.state = newConnPoolState[T](p.options.Mode) + p.access.Unlock() + + p.closeState(oldState, net.ErrClosed) +} + +func (p *ConnPool[T]) Close() error { + p.access.Lock() + if p.closed { + p.access.Unlock() + return nil + } + + p.closed = true + oldState := p.state + p.state = nil + p.access.Unlock() + + p.closeState(oldState, net.ErrClosed) + return nil +} + +func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Context) (T, error)) (T, bool, error) { + var zero T + for { + var ( + staleConn T + hasStale bool + ) + + p.access.Lock() + if p.closed { + p.access.Unlock() + return zero, false, net.ErrClosed + } + + currentState := p.state + if element := currentState.idle.Front(); element != nil { + conn := currentState.idle.Remove(element) + delete(currentState.idleElements, conn) + if p.options.IsAlive(conn) { + p.access.Unlock() + return conn, false, nil + } + delete(currentState.all, conn) + staleConn = conn + hasStale = true + } + p.access.Unlock() + + if hasStale { + p.options.Close(staleConn, net.ErrClosed) + continue + } + + conn, err := p.dial(ctx, currentState, dial) + if err != nil { + return zero, false, err + } + + p.access.Lock() + if p.closed { + p.access.Unlock() + p.options.Close(conn, net.ErrClosed) + return zero, false, net.ErrClosed + } + if p.state != currentState { + cause := p.closeCause(currentState) + p.access.Unlock() + p.options.Close(conn, cause) + return zero, false, cause + } + currentState.all[conn] = struct{}{} + p.access.Unlock() + return conn, true, nil + } +} + +func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) { + var zero T + for { + var ( + staleConn T + hasStale bool + state *connPoolConnect[T] + current *connPoolState[T] + startDial bool + ) + + p.access.Lock() + if p.closed { + p.access.Unlock() + return zero, nil, false, net.ErrClosed + } + + current = p.state + if current.hasShared { + conn := current.shared + if p.options.IsAlive(conn) { + created := !current.sharedClaimed + current.sharedClaimed = true + connCtx := current.sharedCtx + p.access.Unlock() + return conn, connCtx, created, nil + } + delete(current.all, conn) + var zeroConn T + current.shared = zeroConn + current.hasShared = false + current.sharedClaimed = false + current.sharedCtx = nil + if current.sharedCancel != nil { + current.sharedCancel(net.ErrClosed) + current.sharedCancel = nil + } + staleConn = conn + hasStale = true + p.access.Unlock() + p.options.Close(staleConn, net.ErrClosed) + continue + } + + if current.connecting == nil { + current.connecting = &connPoolConnect[T]{ + done: make(chan struct{}), + } + startDial = true + } + state = current.connecting + p.access.Unlock() + + if hasStale { + continue + } + if startDial { + go p.connectSingle(current, state, ctx, dial) + } + + select { + case <-state.done: + conn, connCtx, created, retry, err := p.collectShared(current, state, startDial) + if retry { + continue + } + return conn, connCtx, created, err + case <-ctx.Done(): + return zero, nil, false, ctx.Err() + case <-current.ctx.Done(): + p.access.Lock() + closed := p.closed + p.access.Unlock() + if closed { + return zero, nil, false, net.ErrClosed + } + } + } +} + +func (p *ConnPool[T]) connectSingle(current *connPoolState[T], state *connPoolConnect[T], ctx context.Context, dial func(context.Context) (T, error)) { + conn, err := p.dial(ctx, current, dial) + if err != nil { + p.access.Lock() + if current.connecting == state { + current.connecting = nil + } + state.err = err + p.access.Unlock() + close(state.done) + return + } + + var closeErr error + + p.access.Lock() + if current.connecting == state { + current.connecting = nil + } + if p.closed { + closeErr = net.ErrClosed + state.err = closeErr + } else if p.state != current { + closeErr = p.closeCause(current) + state.err = closeErr + } else { + sharedCtx, sharedCancel := context.WithCancelCause(current.ctx) + current.shared = conn + current.hasShared = true + current.sharedClaimed = false + current.sharedCtx = sharedCtx + current.sharedCancel = sharedCancel + current.all[conn] = struct{}{} + } + p.access.Unlock() + + if closeErr != nil { + p.options.Close(conn, closeErr) + } + close(state.done) +} + +func (p *ConnPool[T]) collectShared(current *connPoolState[T], state *connPoolConnect[T], startDial bool) (T, context.Context, bool, bool, error) { + var zero T + + p.access.Lock() + if state.err != nil { + err := state.err + p.access.Unlock() + if startDial { + return zero, nil, false, false, err + } + return zero, nil, false, true, nil + } + if p.closed { + p.access.Unlock() + return zero, nil, false, false, net.ErrClosed + } + if p.state != current { + cause := p.closeCause(current) + p.access.Unlock() + return zero, nil, false, false, cause + } + if !current.hasShared { + p.access.Unlock() + return zero, nil, false, true, nil + } + + conn := current.shared + if !p.options.IsAlive(conn) { + delete(current.all, conn) + var zeroConn T + current.shared = zeroConn + current.hasShared = false + current.sharedClaimed = false + current.sharedCtx = nil + if current.sharedCancel != nil { + current.sharedCancel(net.ErrClosed) + current.sharedCancel = nil + } + p.access.Unlock() + p.options.Close(conn, net.ErrClosed) + return zero, nil, false, true, nil + } + + created := !current.sharedClaimed + current.sharedClaimed = true + connCtx := current.sharedCtx + p.access.Unlock() + return conn, connCtx, created, false, nil +} + +func (p *ConnPool[T]) dial(ctx context.Context, current *connPoolState[T], dial func(context.Context) (T, error)) (T, error) { + var zero T + + if err := ctx.Err(); err != nil { + return zero, err + } + if cause := context.Cause(current.ctx); cause != nil { + return zero, cause + } + + dialCtx, cancel := context.WithCancelCause(current.ctx) + var ( + stateAccess sync.Mutex + dialComplete bool + ) + stopCancel := context.AfterFunc(ctx, func() { + stateAccess.Lock() + if !dialComplete { + cancel(context.Cause(ctx)) + } + stateAccess.Unlock() + }) + + select { + case <-ctx.Done(): + stateAccess.Lock() + dialComplete = true + stateAccess.Unlock() + stopCancel() + cancel(context.Cause(ctx)) + return zero, ctx.Err() + default: + } + + conn, err := dial(connPoolDialContext{ + Context: dialCtx, + parent: ctx, + }) + stateAccess.Lock() + dialComplete = true + stateAccess.Unlock() + stopCancel() + if err != nil { + if cause := context.Cause(dialCtx); cause != nil { + return zero, cause + } + return zero, err + } + if cause := context.Cause(dialCtx); cause != nil { + p.options.Close(conn, cause) + return zero, cause + } + return conn, nil +} + +func (p *ConnPool[T]) closeState(state *connPoolState[T], cause error) { + if state == nil { + return + } + + state.cancel(cause) + if state.sharedCancel != nil { + state.sharedCancel(cause) + } + for conn := range state.all { + p.options.Close(conn, cause) + } +} + +func (p *ConnPool[T]) closeCause(state *connPoolState[T]) error { + _ = state + return net.ErrClosed +} diff --git a/dns/transport/connector.go b/dns/transport/connector.go deleted file mode 100644 index 3a87456d..00000000 --- a/dns/transport/connector.go +++ /dev/null @@ -1,321 +0,0 @@ -package transport - -import ( - "context" - "net" - "sync" - "time" - - E "github.com/sagernet/sing/common/exceptions" -) - -type ConnectorCallbacks[T any] struct { - IsClosed func(connection T) bool - Close func(connection T) - Reset func(connection T) -} - -type Connector[T any] struct { - dial func(ctx context.Context) (T, error) - callbacks ConnectorCallbacks[T] - - access sync.Mutex - connection T - hasConnection bool - connectionCancel context.CancelFunc - connecting chan struct{} - - closeCtx context.Context - closed bool -} - -func NewConnector[T any](closeCtx context.Context, dial func(context.Context) (T, error), callbacks ConnectorCallbacks[T]) *Connector[T] { - return &Connector[T]{ - dial: dial, - callbacks: callbacks, - closeCtx: closeCtx, - } -} - -func NewSingleflightConnector(closeCtx context.Context, dial func(context.Context) (*Connection, error)) *Connector[*Connection] { - return NewConnector(closeCtx, dial, ConnectorCallbacks[*Connection]{ - IsClosed: func(connection *Connection) bool { - return connection.IsClosed() - }, - Close: func(connection *Connection) { - connection.CloseWithError(ErrTransportClosed) - }, - Reset: func(connection *Connection) { - connection.CloseWithError(ErrConnectionReset) - }, - }) -} - -type contextKeyConnecting struct{} - -var errRecursiveConnectorDial = E.New("recursive connector dial") - -type connectorDialResult[T any] struct { - connection T - cancel context.CancelFunc - err error -} - -func (c *Connector[T]) Get(ctx context.Context) (T, error) { - var zero T - for { - c.access.Lock() - - if c.closed { - c.access.Unlock() - return zero, ErrTransportClosed - } - - if c.hasConnection && !c.callbacks.IsClosed(c.connection) { - connection := c.connection - c.access.Unlock() - return connection, nil - } - - c.hasConnection = false - if c.connectionCancel != nil { - c.connectionCancel() - c.connectionCancel = nil - } - if isRecursiveConnectorDial(ctx, c) { - c.access.Unlock() - return zero, errRecursiveConnectorDial - } - - if c.connecting != nil { - connecting := c.connecting - c.access.Unlock() - - select { - case <-connecting: - continue - case <-ctx.Done(): - return zero, ctx.Err() - case <-c.closeCtx.Done(): - return zero, ErrTransportClosed - } - } - - if err := ctx.Err(); err != nil { - c.access.Unlock() - return zero, err - } - - connecting := make(chan struct{}) - c.connecting = connecting - dialContext := context.WithValue(ctx, contextKeyConnecting{}, c) - dialResult := make(chan connectorDialResult[T], 1) - c.access.Unlock() - - go func() { - connection, cancel, err := c.dialWithCancellation(dialContext) - dialResult <- connectorDialResult[T]{ - connection: connection, - cancel: cancel, - err: err, - } - }() - - select { - case result := <-dialResult: - return c.completeDial(ctx, connecting, result) - case <-ctx.Done(): - go func() { - result := <-dialResult - _, _ = c.completeDial(ctx, connecting, result) - }() - return zero, ctx.Err() - case <-c.closeCtx.Done(): - go func() { - result := <-dialResult - _, _ = c.completeDial(ctx, connecting, result) - }() - return zero, ErrTransportClosed - } - } -} - -func isRecursiveConnectorDial[T any](ctx context.Context, connector *Connector[T]) bool { - dialConnector, loaded := ctx.Value(contextKeyConnecting{}).(*Connector[T]) - return loaded && dialConnector == connector -} - -func (c *Connector[T]) completeDial(ctx context.Context, connecting chan struct{}, result connectorDialResult[T]) (T, error) { - var zero T - - c.access.Lock() - defer c.access.Unlock() - defer func() { - if c.connecting == connecting { - c.connecting = nil - } - close(connecting) - }() - - if result.err != nil { - return zero, result.err - } - if c.closed || c.closeCtx.Err() != nil { - result.cancel() - c.callbacks.Close(result.connection) - return zero, ErrTransportClosed - } - if err := ctx.Err(); err != nil { - result.cancel() - c.callbacks.Close(result.connection) - return zero, err - } - - c.connection = result.connection - c.hasConnection = true - c.connectionCancel = result.cancel - return c.connection, nil -} - -func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, context.CancelFunc, error) { - var zero T - if err := ctx.Err(); err != nil { - return zero, nil, err - } - connCtx, cancel := context.WithCancel(c.closeCtx) - - var ( - stateAccess sync.Mutex - dialComplete bool - ) - stopCancel := context.AfterFunc(ctx, func() { - stateAccess.Lock() - if !dialComplete { - cancel() - } - stateAccess.Unlock() - }) - select { - case <-ctx.Done(): - stateAccess.Lock() - dialComplete = true - stateAccess.Unlock() - stopCancel() - cancel() - return zero, nil, ctx.Err() - default: - } - - connection, err := c.dial(valueContext{connCtx, ctx}) - stateAccess.Lock() - dialComplete = true - stateAccess.Unlock() - stopCancel() - if err != nil { - cancel() - return zero, nil, err - } - return connection, cancel, nil -} - -type valueContext struct { - context.Context - parent context.Context -} - -func (v valueContext) Value(key any) any { - return v.parent.Value(key) -} - -func (v valueContext) Deadline() (time.Time, bool) { - return v.parent.Deadline() -} - -func (c *Connector[T]) Close() error { - c.access.Lock() - defer c.access.Unlock() - - if c.closed { - return nil - } - c.closed = true - - if c.connectionCancel != nil { - c.connectionCancel() - c.connectionCancel = nil - } - if c.hasConnection { - c.callbacks.Close(c.connection) - c.hasConnection = false - } - - return nil -} - -func (c *Connector[T]) Reset() { - c.access.Lock() - defer c.access.Unlock() - - if c.connectionCancel != nil { - c.connectionCancel() - c.connectionCancel = nil - } - if c.hasConnection { - c.callbacks.Reset(c.connection) - c.hasConnection = false - } -} - -type Connection struct { - net.Conn - - closeOnce sync.Once - done chan struct{} - closeError error -} - -func WrapConnection(conn net.Conn) *Connection { - return &Connection{ - Conn: conn, - done: make(chan struct{}), - } -} - -func (c *Connection) Done() <-chan struct{} { - return c.done -} - -func (c *Connection) IsClosed() bool { - select { - case <-c.done: - return true - default: - return false - } -} - -func (c *Connection) CloseError() error { - select { - case <-c.done: - if c.closeError != nil { - return c.closeError - } - return ErrTransportClosed - default: - return nil - } -} - -func (c *Connection) Close() error { - return c.CloseWithError(ErrTransportClosed) -} - -func (c *Connection) CloseWithError(err error) error { - var returnError error - c.closeOnce.Do(func() { - c.closeError = err - returnError = c.Conn.Close() - close(c.done) - }) - return returnError -} diff --git a/dns/transport/connector_test.go b/dns/transport/connector_test.go deleted file mode 100644 index 309b28c8..00000000 --- a/dns/transport/connector_test.go +++ /dev/null @@ -1,407 +0,0 @@ -package transport - -import ( - "context" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -type testConnectorConnection struct{} - -func TestConnectorRecursiveGetFailsFast(t *testing.T) { - t.Parallel() - - var ( - dialCount atomic.Int32 - closeCount atomic.Int32 - connector *Connector[*testConnectorConnection] - ) - - dial := func(ctx context.Context) (*testConnectorConnection, error) { - dialCount.Add(1) - _, err := connector.Get(ctx) - if err != nil { - return nil, err - } - return &testConnectorConnection{}, nil - } - - connector = NewConnector(context.Background(), dial, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) { - closeCount.Add(1) - }, - Reset: func(connection *testConnectorConnection) { - closeCount.Add(1) - }, - }) - - _, err := connector.Get(context.Background()) - require.ErrorIs(t, err, errRecursiveConnectorDial) - require.EqualValues(t, 1, dialCount.Load()) - require.EqualValues(t, 0, closeCount.Load()) -} - -func TestConnectorRecursiveGetAcrossConnectorsAllowed(t *testing.T) { - t.Parallel() - - var ( - outerDialCount atomic.Int32 - innerDialCount atomic.Int32 - outerConnector *Connector[*testConnectorConnection] - innerConnector *Connector[*testConnectorConnection] - ) - - innerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - innerDialCount.Add(1) - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) {}, - Reset: func(connection *testConnectorConnection) {}, - }) - - outerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - outerDialCount.Add(1) - _, err := innerConnector.Get(ctx) - if err != nil { - return nil, err - } - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) {}, - Reset: func(connection *testConnectorConnection) {}, - }) - - _, err := outerConnector.Get(context.Background()) - require.NoError(t, err) - require.EqualValues(t, 1, outerDialCount.Load()) - require.EqualValues(t, 1, innerDialCount.Load()) -} - -func TestConnectorDialContextPreservesValueAndDeadline(t *testing.T) { - t.Parallel() - - type contextKey struct{} - - var ( - dialValue any - dialDeadline time.Time - dialHasDeadline bool - ) - - connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - dialValue = ctx.Value(contextKey{}) - dialDeadline, dialHasDeadline = ctx.Deadline() - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) {}, - Reset: func(connection *testConnectorConnection) {}, - }) - - deadline := time.Now().Add(time.Minute) - requestContext, cancel := context.WithDeadline(context.WithValue(context.Background(), contextKey{}, "test-value"), deadline) - defer cancel() - - _, err := connector.Get(requestContext) - require.NoError(t, err) - require.Equal(t, "test-value", dialValue) - require.True(t, dialHasDeadline) - require.WithinDuration(t, deadline, dialDeadline, time.Second) -} - -func TestConnectorDialSkipsCanceledRequest(t *testing.T) { - t.Parallel() - - var dialCount atomic.Int32 - connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - dialCount.Add(1) - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) {}, - Reset: func(connection *testConnectorConnection) {}, - }) - - requestContext, cancel := context.WithCancel(context.Background()) - cancel() - - _, err := connector.Get(requestContext) - require.ErrorIs(t, err, context.Canceled) - require.EqualValues(t, 0, dialCount.Load()) -} - -func TestConnectorCanceledRequestDoesNotCacheConnection(t *testing.T) { - t.Parallel() - - var ( - dialCount atomic.Int32 - closeCount atomic.Int32 - ) - dialStarted := make(chan struct{}, 1) - releaseDial := make(chan struct{}) - - connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - dialCount.Add(1) - select { - case dialStarted <- struct{}{}: - default: - } - <-releaseDial - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) { - closeCount.Add(1) - }, - Reset: func(connection *testConnectorConnection) {}, - }) - - requestContext, cancel := context.WithCancel(context.Background()) - result := make(chan error, 1) - go func() { - _, err := connector.Get(requestContext) - result <- err - }() - - <-dialStarted - cancel() - close(releaseDial) - - err := <-result - require.ErrorIs(t, err, context.Canceled) - require.EqualValues(t, 1, dialCount.Load()) - require.Eventually(t, func() bool { - return closeCount.Load() == 1 - }, time.Second, 10*time.Millisecond) - - _, err = connector.Get(context.Background()) - require.NoError(t, err) - require.EqualValues(t, 2, dialCount.Load()) -} - -func TestConnectorCanceledRequestReturnsBeforeIgnoredDialCompletes(t *testing.T) { - t.Parallel() - - var ( - dialCount atomic.Int32 - closeCount atomic.Int32 - ) - dialStarted := make(chan struct{}, 1) - releaseDial := make(chan struct{}) - - connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - dialCount.Add(1) - select { - case dialStarted <- struct{}{}: - default: - } - <-releaseDial - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) { - closeCount.Add(1) - }, - Reset: func(connection *testConnectorConnection) {}, - }) - - requestContext, cancel := context.WithCancel(context.Background()) - result := make(chan error, 1) - go func() { - _, err := connector.Get(requestContext) - result <- err - }() - - <-dialStarted - cancel() - - select { - case err := <-result: - require.ErrorIs(t, err, context.Canceled) - case <-time.After(time.Second): - t.Fatal("Get did not return after request cancel") - } - - require.EqualValues(t, 1, dialCount.Load()) - require.EqualValues(t, 0, closeCount.Load()) - - close(releaseDial) - - require.Eventually(t, func() bool { - return closeCount.Load() == 1 - }, time.Second, 10*time.Millisecond) - - _, err := connector.Get(context.Background()) - require.NoError(t, err) - require.EqualValues(t, 2, dialCount.Load()) -} - -func TestConnectorWaiterDoesNotStartNewDialBeforeCanceledDialCompletes(t *testing.T) { - t.Parallel() - - var ( - dialCount atomic.Int32 - closeCount atomic.Int32 - ) - firstDialStarted := make(chan struct{}, 1) - secondDialStarted := make(chan struct{}, 1) - releaseFirstDial := make(chan struct{}) - - connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - attempt := dialCount.Add(1) - switch attempt { - case 1: - select { - case firstDialStarted <- struct{}{}: - default: - } - <-releaseFirstDial - case 2: - select { - case secondDialStarted <- struct{}{}: - default: - } - } - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) { - closeCount.Add(1) - }, - Reset: func(connection *testConnectorConnection) {}, - }) - - requestContext, cancel := context.WithCancel(context.Background()) - firstResult := make(chan error, 1) - go func() { - _, err := connector.Get(requestContext) - firstResult <- err - }() - - <-firstDialStarted - cancel() - - secondResult := make(chan error, 1) - go func() { - _, err := connector.Get(context.Background()) - secondResult <- err - }() - - select { - case <-secondDialStarted: - t.Fatal("second dial started before first dial completed") - case <-time.After(100 * time.Millisecond): - } - - select { - case err := <-firstResult: - require.ErrorIs(t, err, context.Canceled) - case <-time.After(time.Second): - t.Fatal("first Get did not return after request cancel") - } - - close(releaseFirstDial) - - require.Eventually(t, func() bool { - return closeCount.Load() == 1 - }, time.Second, 10*time.Millisecond) - - select { - case <-secondDialStarted: - case <-time.After(time.Second): - t.Fatal("second dial did not start after first dial completed") - } - - err := <-secondResult - require.NoError(t, err) - require.EqualValues(t, 2, dialCount.Load()) -} - -func TestConnectorDialContextNotCanceledByRequestContextAfterDial(t *testing.T) { - t.Parallel() - - var dialContext context.Context - connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - dialContext = ctx - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) {}, - Reset: func(connection *testConnectorConnection) {}, - }) - - requestContext, cancel := context.WithCancel(context.Background()) - _, err := connector.Get(requestContext) - require.NoError(t, err) - require.NotNil(t, dialContext) - - cancel() - - select { - case <-dialContext.Done(): - t.Fatal("dial context canceled by request context after successful dial") - case <-time.After(100 * time.Millisecond): - } - - err = connector.Close() - require.NoError(t, err) -} - -func TestConnectorDialContextCanceledOnClose(t *testing.T) { - t.Parallel() - - var dialContext context.Context - connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - dialContext = ctx - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) {}, - Reset: func(connection *testConnectorConnection) {}, - }) - - _, err := connector.Get(context.Background()) - require.NoError(t, err) - require.NotNil(t, dialContext) - - select { - case <-dialContext.Done(): - t.Fatal("dial context canceled before connector close") - default: - } - - err = connector.Close() - require.NoError(t, err) - - select { - case <-dialContext.Done(): - case <-time.After(time.Second): - t.Fatal("dial context not canceled after connector close") - } -} diff --git a/dns/transport/quic/quic.go b/dns/transport/quic/quic.go index 26461006..3a7b6163 100644 --- a/dns/transport/quic/quic.go +++ b/dns/transport/quic/quic.go @@ -31,14 +31,13 @@ func RegisterTransport(registry *dns.TransportRegistry) { } type Transport struct { - *transport.BaseTransport + dns.TransportAdapter - ctx context.Context dialer N.Dialer serverAddr M.Socksaddr tlsConfig tls.Config - connector *transport.Connector[*quic.Conn] + connection *transport.ConnPool[*quic.Conn] } func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) { @@ -63,93 +62,76 @@ func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options return nil, E.New("invalid server address: ", serverAddr) } - t := &Transport{ - BaseTransport: transport.NewBaseTransport( - dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions), - logger, - ), - ctx: ctx, - dialer: transportDialer, - serverAddr: serverAddr, - tlsConfig: tlsConfig, - } - - t.connector = transport.NewConnector(t.CloseContext(), t.dial, transport.ConnectorCallbacks[*quic.Conn]{ - IsClosed: func(connection *quic.Conn) bool { - return common.Done(connection.Context()) - }, - Close: func(connection *quic.Conn) { - connection.CloseWithError(0, "") - }, - Reset: func(connection *quic.Conn) { - connection.CloseWithError(0, "") - }, - }) - - return t, nil -} - -func (t *Transport) dial(ctx context.Context) (*quic.Conn, error) { - conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) - if err != nil { - return nil, E.Cause(err, "dial UDP connection") - } - earlyConnection, err := sQUIC.DialEarly( - ctx, - bufio.NewUnbindPacketConn(conn), - t.serverAddr.UDPAddr(), - t.tlsConfig, - nil, - ) - if err != nil { - conn.Close() - return nil, E.Cause(err, "establish QUIC connection") - } - return earlyConnection, nil + return &Transport{ + TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions), + dialer: transportDialer, + serverAddr: serverAddr, + tlsConfig: tlsConfig, + connection: transport.NewConnPool(transport.ConnPoolOptions[*quic.Conn]{ + Mode: transport.ConnPoolSingle, + IsAlive: func(conn *quic.Conn) bool { + return conn != nil && !common.Done(conn.Context()) + }, + Close: func(conn *quic.Conn, _ error) { + conn.CloseWithError(0, "") + }, + }), + }, nil } func (t *Transport) Start(stage adapter.StartStage) error { if stage != adapter.StartStateStart { return nil } - err := t.SetStarted() - if err != nil { - return err - } return dialer.InitializeDetour(t.dialer) } func (t *Transport) Close() error { - return E.Errors(t.BaseTransport.Close(), t.connector.Close()) + return t.connection.Close() } func (t *Transport) Reset() { - t.connector.Reset() + t.connection.Reset() } func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - if !t.BeginQuery() { - return nil, transport.ErrTransportClosed - } - defer t.EndQuery() - var ( conn *quic.Conn err error response *mDNS.Msg ) for i := 0; i < 2; i++ { - conn, err = t.connector.Get(ctx) + conn, _, err = t.connection.Acquire(ctx, func(ctx context.Context) (*quic.Conn, error) { + rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) + if err != nil { + return nil, E.Cause(err, "dial UDP connection") + } + earlyConnection, err := sQUIC.DialEarly( + ctx, + bufio.NewUnbindPacketConn(rawConn), + t.serverAddr.UDPAddr(), + t.tlsConfig, + nil, + ) + if err != nil { + rawConn.Close() + return nil, E.Cause(err, "establish QUIC connection") + } + return earlyConnection, nil + }) if err != nil { return nil, err } response, err = t.exchange(ctx, message, conn) if err == nil { + t.connection.Release(conn, true) return response, nil } else if !isQUICRetryError(err) { + t.connection.Release(conn, true) return nil, err } else { - t.connector.Reset() + t.connection.Release(conn, true) + t.Reset() continue } } diff --git a/dns/transport/tls.go b/dns/transport/tls.go index 4d463296..43978b6f 100644 --- a/dns/transport/tls.go +++ b/dns/transport/tls.go @@ -2,7 +2,6 @@ package transport import ( "context" - "sync" "time" "github.com/sagernet/sing-box/adapter" @@ -17,7 +16,6 @@ import ( "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/x/list" mDNS "github.com/miekg/dns" ) @@ -29,13 +27,13 @@ func RegisterTLS(registry *dns.TransportRegistry) { } type TLSTransport struct { - *BaseTransport + dns.TransportAdapter + logger logger.ContextLogger dialer tls.Dialer serverAddr M.Socksaddr tlsConfig tls.Config - access sync.Mutex - connections list.List[*tlsDNSConn] + connections *ConnPool[*tlsDNSConn] } type tlsDNSConn struct { @@ -66,10 +64,20 @@ func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options o func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr, tlsConfig tls.Config) *TLSTransport { return &TLSTransport{ - BaseTransport: NewBaseTransport(adapter, logger), - dialer: tls.NewDialer(dialer, tlsConfig), - serverAddr: serverAddr, - tlsConfig: tlsConfig, + TransportAdapter: adapter, + logger: logger, + dialer: tls.NewDialer(dialer, tlsConfig), + serverAddr: serverAddr, + tlsConfig: tlsConfig, + connections: NewConnPool(ConnPoolOptions[*tlsDNSConn]{ + Mode: ConnPoolOrdered, + IsAlive: func(conn *tlsDNSConn) bool { + return conn != nil + }, + Close: func(conn *tlsDNSConn, _ error) { + conn.Close() + }, + }), } } @@ -77,53 +85,43 @@ func (t *TLSTransport) Start(stage adapter.StartStage) error { if stage != adapter.StartStateStart { return nil } - err := t.SetStarted() - if err != nil { - return err - } return dialer.InitializeDetour(t.dialer) } func (t *TLSTransport) Close() error { - t.access.Lock() - for connection := t.connections.Front(); connection != nil; connection = connection.Next() { - connection.Value.Close() - } - t.connections.Init() - t.access.Unlock() - return t.BaseTransport.Close() + return t.connections.Close() } func (t *TLSTransport) Reset() { - t.access.Lock() - defer t.access.Unlock() - for connection := t.connections.Front(); connection != nil; connection = connection.Next() { - connection.Value.Close() - } - t.connections.Init() + t.connections.Reset() } func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - if !t.BeginQuery() { - return nil, ErrTransportClosed - } - defer t.EndQuery() - - t.access.Lock() - conn := t.connections.PopFront() - t.access.Unlock() - if conn != nil { + var lastErr error + for attempt := 0; attempt < 2; attempt++ { + conn, created, err := t.connections.Acquire(ctx, func(ctx context.Context) (*tlsDNSConn, error) { + tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr) + if err != nil { + return nil, E.Cause(err, "dial TLS connection") + } + return &tlsDNSConn{Conn: tlsConn}, nil + }) + if err != nil { + return nil, err + } response, err := t.exchange(ctx, message, conn) if err == nil { + t.connections.Release(conn, true) return response, nil } - t.Logger.DebugContext(ctx, "discarded pooled connection: ", err) + lastErr = err + t.logger.DebugContext(ctx, "discarded pooled connection: ", err) + t.connections.Release(conn, false) + if created { + return nil, err + } } - tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr) - if err != nil { - return nil, E.Cause(err, "dial TLS connection") - } - return t.exchange(ctx, message, &tlsDNSConn{Conn: tlsConn}) + return nil, lastErr } func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) { @@ -133,22 +131,12 @@ func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tl conn.queryId++ err := WriteMessage(conn, conn.queryId, message) if err != nil { - conn.Close() return nil, E.Cause(err, "write request") } response, err := ReadMessage(conn) if err != nil { - conn.Close() return nil, E.Cause(err, "read response") } - t.access.Lock() - if t.State() >= StateClosing { - t.access.Unlock() - conn.Close() - return response, nil - } conn.SetDeadline(time.Time{}) - t.connections.PushBack(conn) - t.access.Unlock() return response, nil } diff --git a/dns/transport/udp.go b/dns/transport/udp.go index a7272545..c9f520e3 100644 --- a/dns/transport/udp.go +++ b/dns/transport/udp.go @@ -2,6 +2,7 @@ package transport import ( "context" + "net" "sync" "sync/atomic" @@ -27,13 +28,14 @@ func RegisterUDP(registry *dns.TransportRegistry) { } type UDPTransport struct { - *BaseTransport + dns.TransportAdapter + logger logger.ContextLogger dialer N.Dialer serverAddr M.Socksaddr udpSize atomic.Int32 - connector *Connector[*Connection] + connection *ConnPool[net.Conn] callbackAccess sync.RWMutex queryId uint16 @@ -63,43 +65,38 @@ func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options o func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialerInstance N.Dialer, serverAddr M.Socksaddr) *UDPTransport { t := &UDPTransport{ - BaseTransport: NewBaseTransport(adapter, logger), - dialer: dialerInstance, - serverAddr: serverAddr, - callbacks: make(map[uint16]*udpCallback), + TransportAdapter: adapter, + logger: logger, + dialer: dialerInstance, + serverAddr: serverAddr, + callbacks: make(map[uint16]*udpCallback), + connection: NewConnPool(ConnPoolOptions[net.Conn]{ + Mode: ConnPoolSingle, + IsAlive: func(conn net.Conn) bool { + return conn != nil + }, + Close: func(conn net.Conn, cause error) { + conn.Close() + }, + }), } t.udpSize.Store(2048) - t.connector = NewSingleflightConnector(t.CloseContext(), t.dial) return t } -func (t *UDPTransport) dial(ctx context.Context) (*Connection, error) { - rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) - if err != nil { - return nil, E.Cause(err, "dial UDP connection") - } - conn := WrapConnection(rawConn) - go t.recvLoop(conn) - return conn, nil -} - func (t *UDPTransport) Start(stage adapter.StartStage) error { if stage != adapter.StartStateStart { return nil } - err := t.SetStarted() - if err != nil { - return err - } return dialer.InitializeDetour(t.dialer) } func (t *UDPTransport) Close() error { - return E.Errors(t.BaseTransport.Close(), t.connector.Close()) + return t.connection.Close() } func (t *UDPTransport) Reset() { - t.connector.Reset() + t.connection.Reset() } func (t *UDPTransport) nextAvailableQueryId() (uint16, error) { @@ -116,17 +113,12 @@ func (t *UDPTransport) nextAvailableQueryId() (uint16, error) { } func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - if !t.BeginQuery() { - return nil, ErrTransportClosed - } - defer t.EndQuery() - response, err := t.exchange(ctx, message) if err != nil { return nil, err } if response.Truncated { - t.Logger.InfoContext(ctx, "response truncated, retrying with TCP") + t.logger.InfoContext(ctx, "response truncated, retrying with TCP") return t.exchangeTCP(ctx, message) } return response, nil @@ -158,16 +150,25 @@ func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M break } if t.udpSize.CompareAndSwap(current, udpSize) { - t.connector.Reset() + t.Reset() break } } } - conn, err := t.connector.Get(ctx) + conn, connCtx, created, err := t.connection.AcquireShared(ctx, func(ctx context.Context) (net.Conn, error) { + rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) + if err != nil { + return nil, E.Cause(err, "dial UDP connection") + } + return rawConn, nil + }) if err != nil { return nil, err } + if created { + go t.recvLoop(conn) + } callback := &udpCallback{ done: make(chan struct{}), @@ -177,6 +178,7 @@ func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M queryId, err := t.nextAvailableQueryId() if err != nil { t.callbackAccess.Unlock() + t.connection.Release(conn, true) return nil, err } t.callbacks[queryId] = callback @@ -203,30 +205,30 @@ func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M _, err = conn.Write(rawMessage) if err != nil { - conn.CloseWithError(err) + t.connection.Invalidate(conn, err) return nil, E.Cause(err, "write request") } select { case <-callback.done: + t.connection.Release(conn, true) callback.response.Id = originalId return callback.response, nil - case <-conn.Done(): - return nil, conn.CloseError() - case <-t.CloseContext().Done(): - return nil, ErrTransportClosed + case <-connCtx.Done(): + return nil, context.Cause(connCtx) case <-ctx.Done(): + t.connection.Release(conn, true) return nil, ctx.Err() } } -func (t *UDPTransport) recvLoop(conn *Connection) { +func (t *UDPTransport) recvLoop(conn net.Conn) { for { buffer := buf.NewSize(int(t.udpSize.Load())) _, err := buffer.ReadOnceFrom(conn) if err != nil { buffer.Release() - conn.CloseWithError(err) + t.connection.Invalidate(conn, err) return } @@ -234,7 +236,7 @@ func (t *UDPTransport) recvLoop(conn *Connection) { err = message.Unpack(buffer.Bytes()) buffer.Release() if err != nil { - t.Logger.Debug("discarded malformed UDP response: ", err) + t.logger.Debug("discarded malformed UDP response: ", err) continue } diff --git a/docs/changelog.md b/docs/changelog.md index f9384cb2..3b30e932 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,11 @@ icon: material/alert-decagram --- +#### 1.13.11 + +* Fix process searcher failure introduced in 1.13.9 +* Fixes and improvements + #### 1.13.10 * Fix process searcher failure introduced in 1.13.9 diff --git a/examples/failover/client.json b/examples/fallback/client.json similarity index 92% rename from examples/failover/client.json rename to examples/fallback/client.json index b4998f13..f67e3951 100644 --- a/examples/failover/client.json +++ b/examples/fallback/client.json @@ -44,8 +44,8 @@ "uuid": "257f20d0-294a-4f07-9f2c-9efee9a37400" }, { - "type": "failover", - "tag": "failover-out", + "type": "fallback", + "tag": "fallback-out", "outbounds": [ "vless-1-out", "vless-2-out", @@ -54,7 +54,7 @@ } ], "route": { - "final": "failover-out", + "final": "fallback-out", "default_domain_resolver": "default", "auto_detect_interface": true } diff --git a/examples/manager/manager.json b/examples/manager/manager.json index d9320755..f13553f0 100644 --- a/examples/manager/manager.json +++ b/examples/manager/manager.json @@ -15,22 +15,14 @@ { "type": "direct", "tag": "direct-out" - }, - { - "type": "dns", - "tag": "dns-out" } ], "route": { "rules": [ { "protocol": "dns", - "outbound": "dns-out" - }, - { - "port": 53, - "outbound": "dns-out" - }, + "action": "hijack-dns" + } ], "final": "direct-out" }, diff --git a/examples/manager/node.json b/examples/manager/node.json index 0b330b72..6b487c81 100644 --- a/examples/manager/node.json +++ b/examples/manager/node.json @@ -26,10 +26,6 @@ "type": "direct", "tag": "direct-out" }, - { - "type": "dns", - "tag": "dns-out" - }, { "type": "bandwidth-limiter", "tag": "bandwidth-limiter", @@ -51,11 +47,7 @@ "rules": [ { "protocol": "dns", - "outbound": "dns-out" - }, - { - "port": 53, - "outbound": "dns-out" + "action": "hijack-dns" } ], "final": "connection-limiter" diff --git a/examples/masque/client.json b/examples/masque/client.json new file mode 100644 index 00000000..cd248287 --- /dev/null +++ b/examples/masque/client.json @@ -0,0 +1,58 @@ +{ + "log": { + "level": "error" + }, + "dns": { + "servers": [ + { + "type": "local", + "tag": "default" + } + ] + }, + "inbounds": [ + { + "type": "mixed", + "tag": "mixed-in", + "listen_port": 7897 + } + ], + "outbounds": [ + { + "type": "direct", + "tag": "direct" + }, + { + "type": "masque", + "tag": "masque-out", + "use_http2": false, + "use_ipv6": false, + "profile": { + "detour": "direct", + // For getting existing MASQUE device profile, else sing-box will create new profile + "id": "", + "auth_token": "" + }, + "udp_timeout": "5m0s", + "udp_keepalive_period": "30s", + "udp_initial_packet_size": 0, + "reconnect_delay": "5s", + "tls": { // https://sing-box.sagernet.org/configuration/shared/tls/#fields + "insecure": false, + "cipher_suites": [], + "curve_preferences": [], + "fragment": false, + "fragment_fallback_delay": "", + "record_fragment": false, + "kernel_tx": false, + "kernel_rx": false, + } + // Dial Fields + } + ], + "route": { + "final": "masque-out", + "default_domain_resolver": "default", + "auto_detect_interface": true + } +} \ No newline at end of file diff --git a/examples/mtproxy/server.json b/examples/mtproxy/server.json new file mode 100644 index 00000000..85d4e9bd --- /dev/null +++ b/examples/mtproxy/server.json @@ -0,0 +1,83 @@ +{ + "log": { + "level": "error" + }, + "dns": { + "servers": [ + { + "type": "local", + "tag": "default" + } + ] + }, + "inbounds": [ + { + "type": "mtproxy", + // https://sing-box.sagernet.org/configuration/shared/listen/ + "listen": "0.0.0.0", + "listen_port": 3128, + "users": [ + { + "name": "user1", + "secret": "7hBO-dCS4EBzenlKbdLFxyNnb29nbGUuY29t" + } + ], + // concurrency is a size of the worker pool for connection management. + "concurrency": 8192, + // domain_fronting_port is a port we use to connect to a fronting domain. + "domain_fronting_port": 443, + // domain_fronting_ip is an IP address to use when connecting to the fronting + // domain instead of resolving the hostname from the secret via DNS. + "domain_fronting_ip": "", + // domain_fronting_proxy_protocol is used if communication between upstream + // endpoint and sing-box supports proxy protocol. + "domain_fronting_proxy_protocol": false, + // prefer_ip defines an IP connectivity preference. Valid values are: + // 'prefer-ipv4', 'prefer-ipv6', 'only-ipv4', 'only-ipv6'. + "prefer_ip": "prefer-ipv4", + // auto_update defines if it is required to auto update proxy list from + // Telegram instead of relying on a hardcoded list. + "auto_update": false, + // allow_fallback_on_unknown_dc defines how proxy behaves if unknown DC was + // requested. If this setting is set to false, then such connection will be + // rejected. Otherwise, proxy will chose any DC. + "allow_fallback_on_unknown_dc": false, + // tolerate_time_skewness is a time boundary that defines a time range where + // faketls timestamp is acceptable. + "tolerate_time_skewness": "", + // idle_timeout is a timeout for relay when we have to break a stream. + "idle_timeout": "5m", + // handshake_timeout is a timeout during which all handshake ceremonies must + // be completed, otherwise this process will be aborted + "handshake_timeout": "10s", + // doppelganger_urls is a list of URLs that should be crawled by + // sing-box to calculate parameters for statistical distribution of a + // traffic for fronting domains. + "doppelganger_urls": [], + // doppelganger_per_raid defines how many time each URL from + // doppelganger_urls list should be crawled per raid. + "doppelganger_per_raid": 10, + // doppelganger_each defines a time period between each raid. We recommend + // to use hours here. + "doppelganger_each": "6h", + // doppelganger_drs defines if TLS Dynamic Record Sizing is active. + "doppelganger_drs": false, + // throttle_max_connections is the total connection limit. + "throttle_max_connections": 0, + // throttle_check_interval is how often the throttle recomputes per-user + // caps. + "throttle_check_interval": "5s" + } + ], + "outbounds": [ + { + "type": "direct", + "tag": "direct" + } + ], + "route": { + "final": "direct", + "default_domain_resolver": "default", + "auto_detect_interface": true + } +} diff --git a/examples/parser/client.json b/examples/parser/client.json new file mode 100644 index 00000000..5af54473 --- /dev/null +++ b/examples/parser/client.json @@ -0,0 +1,37 @@ +{ + "log": { + "level": "error" + }, + "dns": { + "servers": [ + { + "type": "local", + "tag": "default" + } + ] + }, + "inbounds": [ + { + "type": "mixed", + "tag": "mixed-in", + "listen_port": 7897 + } + ], + "outbounds": [ + { + "type": "direct", + "tag": "direct" + }, + { + "type": "parser", + "tag": "vless-out", + // Supported protocols: hysteria, hysteria2, shadowsocks, trojan, tuic, vless, vmess + "link": "vless://b5e41c8c-c437-4689-b863-76208a3efb4b@0.0.0.0:443?..." + } + ], + "route": { + "final": "vless-out", + "default_domain_resolver": "default", + "auto_detect_interface": true + } +} \ No newline at end of file diff --git a/examples/tunnel/client-server/server.json b/examples/tunnel/client-server/server.json deleted file mode 100644 index 282a2f37..00000000 --- a/examples/tunnel/client-server/server.json +++ /dev/null @@ -1,49 +0,0 @@ -{ - "log": { - "level": "info" - }, - "dns": { - "servers": [ - { - "type": "local", - "tag": "default" - } - ] - }, - "endpoints": [ - { - "type": "tunnel-server", - "tag": "tunnel", - "uuid": "f79f7678-55e7-432d-a15f-6e8ab2b7fe13", - "users": [ - { - "uuid": "9b65b7e1-04c8-4717-8f45-2aa61fd25937", - "key": "1c9b2ccf-b0c0-4c26-868d-a55a4edad3fe" - } - ], - "inbound": { - "type": "vless", - "tag": "vless-in", - "listen": "0.0.0.0", - "listen_port": 8000, - "users": [ - { - "name": "vless", - "uuid": "9b65b7e1-04c8-4717-8f45-2aa61fd25937" - } - ] - } - } - ], - "outbounds": [ - { - "type": "direct", - "tag": "direct-out" - } - ], - "route": { - "final": "direct-out", - "default_domain_resolver": "default", - "auto_detect_interface": true - } -} \ No newline at end of file diff --git a/examples/tunnel/client1-server-client2/server.json b/examples/tunnel/client1-server-client2/server.json deleted file mode 100644 index 0a83bd8f..00000000 --- a/examples/tunnel/client1-server-client2/server.json +++ /dev/null @@ -1,66 +0,0 @@ -{ - "log": { - "level": "error" - }, - "dns": { - "servers": [ - { - "type": "local", - "tag": "default" - } - ] - }, - "endpoints": [ - { - "type": "tunnel-server", - "tag": "tunnel", - "uuid": "f79f7678-55e7-432d-a15f-6e8ab2b7fe13", - "users": [ - { - "uuid": "9b65b7e1-04c8-4717-8f45-2aa61fd25937", - "key": "1c9b2ccf-b0c0-4c26-868d-a55a4edad3fe" - }, - { - "uuid": "487f6073-3300-4819-a07d-39652e45fb4d", - "key": "3d74d616-2502-4c17-9cc3-92c366550f4f" - } - ], - "inbound": { - "type": "vless", - "tag": "vless-in", - "listen": "0.0.0.0", - "listen_port": 8000, - "users": [ - { - "name": "vless", - "uuid": "9b65b7e1-04c8-4717-8f45-2aa61fd25937" - } - ] - } - } - ], - "outbounds": [ - { - "type": "direct", - "tag": "direct-out" - } - ], - "route": { - "rules": [ - { - "tunnel_source": [ - "9b65b7e1-04c8-4717-8f45-2aa61fd25937", - "487f6073-3300-4819-a07d-39652e45fb4d" - ], - "tunnel_destination": [ - "9b65b7e1-04c8-4717-8f45-2aa61fd25937", - "487f6073-3300-4819-a07d-39652e45fb4d" - ], - "outbound": "tunnel" - } - ], - "final": "direct-out", - "default_domain_resolver": "default", - "auto_detect_interface": true - } -} \ No newline at end of file diff --git a/examples/tunnel/client-server/client.json b/examples/vpn/client-server/client.json similarity index 60% rename from examples/tunnel/client-server/client.json rename to examples/vpn/client-server/client.json index 91941a82..81508ebf 100644 --- a/examples/tunnel/client-server/client.json +++ b/examples/vpn/client-server/client.json @@ -1,6 +1,6 @@ { "log": { - "level": "info" + "level": "error" }, "dns": { "servers": [ @@ -12,9 +12,9 @@ }, "endpoints": [ { - "type": "tunnel-client", - "tag": "tunnel", - "uuid": "9b65b7e1-04c8-4717-8f45-2aa61fd25937", + "type": "vpn-client", + "tag": "vpn", + "address": "10.0.0.2", "key": "1c9b2ccf-b0c0-4c26-868d-a55a4edad3fe", "outbound": { "type": "vless", @@ -30,33 +30,26 @@ { "type": "mixed", "tag": "mixed-in", - "listen_port": 10000 + "listen_port": 7897 } ], "outbounds": [ { "type": "direct", "tag": "direct-out" - }, - { - "type": "dns", - "tag": "dns-out" - }, - { - "type": "failover", - "tag": "f", - "outbounds": ["tunnel", "direct-out"], - "interrupt_exist_connections": false, } ], "route": { "rules": [ { - "outbound": "f", - "override_tunnel_destination": "f79f7678-55e7-432d-a15f-6e8ab2b7fe13" + "protocol": "dns", + "action": "hijack-dns" + }, + { + "outbound": "vpn", } ], - "final": "f", + "final": "direct-out", "default_domain_resolver": "default", "auto_detect_interface": true } diff --git a/examples/vpn/client-server/server.json b/examples/vpn/client-server/server.json new file mode 100644 index 00000000..1c68ed46 --- /dev/null +++ b/examples/vpn/client-server/server.json @@ -0,0 +1,51 @@ +{ + "log": { + "level": "error" + }, + "dns": { + "servers": [ + { + "type": "local", + "tag": "default" + } + ] + }, + "endpoints": [ + { + "type": "vpn-server", + "tag": "vpn", + "address": "10.0.0.1", + "users": [ + { + "address": "10.0.0.2", + "key": "1c9b2ccf-b0c0-4c26-868d-a55a4edad3fe" + } + ], + "inbounds": [ + { + "type": "vless", + "tag": "vless-in", + "listen": "0.0.0.0", + "listen_port": 8000, + "users": [ + { + "name": "vless", + "uuid": "9b65b7e1-04c8-4717-8f45-2aa61fd25937" + } + ] + } + ] + } + ], + "outbounds": [ + { + "type": "direct", + "tag": "direct-out" + } + ], + "route": { + "final": "direct-out", + "default_domain_resolver": "default", + "auto_detect_interface": true + } +} \ No newline at end of file diff --git a/examples/tunnel/client1-server-client2/client1.json b/examples/vpn/client1-server-client2/client1.json similarity index 77% rename from examples/tunnel/client1-server-client2/client1.json rename to examples/vpn/client1-server-client2/client1.json index 29b72784..70600225 100644 --- a/examples/tunnel/client1-server-client2/client1.json +++ b/examples/vpn/client1-server-client2/client1.json @@ -12,9 +12,9 @@ }, "endpoints": [ { - "type": "tunnel-client", - "tag": "tunnel", - "uuid": "9b65b7e1-04c8-4717-8f45-2aa61fd25937", + "type": "vpn-client", + "tag": "vpn", + "address": "10.0.0.2", "key": "1c9b2ccf-b0c0-4c26-868d-a55a4edad3fe", "outbound": { "type": "vless", @@ -42,8 +42,8 @@ "route": { "rules": [ { - "outbound": "tunnel", - "override_tunnel_destination": "487f6073-3300-4819-a07d-39652e45fb4d" + "outbound": "vpn", + "override_gateway": "10.0.0.3" } ], "final": "direct-out", diff --git a/examples/tunnel/client1-server-client2/client2.json b/examples/vpn/client1-server-client2/client2.json similarity index 85% rename from examples/tunnel/client1-server-client2/client2.json rename to examples/vpn/client1-server-client2/client2.json index ef13a2c1..907f638e 100644 --- a/examples/tunnel/client1-server-client2/client2.json +++ b/examples/vpn/client1-server-client2/client2.json @@ -12,9 +12,9 @@ }, "endpoints": [ { - "type": "tunnel-client", - "tag": "tunnel", - "uuid": "487f6073-3300-4819-a07d-39652e45fb4d", + "type": "vpn-client", + "tag": "vpn", + "address": "10.0.0.3", "key": "3d74d616-2502-4c17-9cc3-92c366550f4f", "outbound": { "type": "vless", diff --git a/examples/vpn/client1-server-client2/server.json b/examples/vpn/client1-server-client2/server.json new file mode 100644 index 00000000..bc6aafb5 --- /dev/null +++ b/examples/vpn/client1-server-client2/server.json @@ -0,0 +1,61 @@ +{ + "log": { + "level": "error" + }, + "dns": { + "servers": [ + { + "type": "local", + "tag": "default" + } + ] + }, + "endpoints": [ + { + "type": "vpn-server", + "tag": "vpn", + "address": "10.0.0.1", + "users": [ + { + "address": "10.0.0.2", + "key": "1c9b2ccf-b0c0-4c26-868d-a55a4edad3fe" + }, + { + "address": "10.0.0.3", + "key": "3d74d616-2502-4c17-9cc3-92c366550f4f" + } + ], + "inbounds": [ + { + "type": "vless", + "tag": "vless-in", + "listen": "0.0.0.0", + "listen_port": 8000, + "users": [ + { + "name": "vless", + "uuid": "9b65b7e1-04c8-4717-8f45-2aa61fd25937" + } + ] + } + ] + } + ], + "outbounds": [ + { + "type": "direct", + "tag": "direct-out" + } + ], + "route": { + "rules": [ + { + "source_ip_cidr": "10.0.0.0/24", + "outbound": "vpn" + } + ], + "final": "direct-out", + "default_domain_resolver": "default", + "auto_detect_interface": true + } +} \ No newline at end of file diff --git a/examples/tunnel/proxy_client-server-tunnel_client/proxy_client.json b/examples/vpn/proxy_client-server-tunnel_client/proxy_client.json similarity index 91% rename from examples/tunnel/proxy_client-server-tunnel_client/proxy_client.json rename to examples/vpn/proxy_client-server-tunnel_client/proxy_client.json index 390b73b3..eabb3192 100644 --- a/examples/tunnel/proxy_client-server-tunnel_client/proxy_client.json +++ b/examples/vpn/proxy_client-server-tunnel_client/proxy_client.json @@ -29,10 +29,6 @@ "server_port": 8000, "uuid": "9b65b7e1-04c8-4717-8f45-2aa61fd25937", "network": "tcp" - }, - { - "type": "dns", - "tag": "dns-out" } ], "route": { diff --git a/examples/tunnel/proxy_client-server-tunnel_client/server.json b/examples/vpn/proxy_client-server-tunnel_client/server.json similarity index 75% rename from examples/tunnel/proxy_client-server-tunnel_client/server.json rename to examples/vpn/proxy_client-server-tunnel_client/server.json index 6efc6efb..629eb072 100644 --- a/examples/tunnel/proxy_client-server-tunnel_client/server.json +++ b/examples/vpn/proxy_client-server-tunnel_client/server.json @@ -12,12 +12,12 @@ }, "endpoints": [ { - "type": "tunnel-server", - "tag": "tunnel", - "uuid": "f79f7678-55e7-432d-a15f-6e8ab2b7fe13", + "type": "vpn-server", + "tag": "vpn", + "address": "10.0.0.1", "users": [ { - "uuid": "9b65b7e1-04c8-4717-8f45-2aa61fd25937", + "address": "10.0.0.2", "key": "1c9b2ccf-b0c0-4c26-868d-a55a4edad3fe" } ], @@ -45,8 +45,8 @@ "rules": [ { "inbound": "vless-in", - "outbound": "tunnel", - "override_tunnel_destination": "9b65b7e1-04c8-4717-8f45-2aa61fd25937" + "outbound": "vpn", + "override_gateway": "10.0.0.2" } ], "final": "direct-out", diff --git a/examples/tunnel/proxy_client-server-tunnel_client/tunnel_client.json b/examples/vpn/proxy_client-server-tunnel_client/tunnel_client.json similarity index 85% rename from examples/tunnel/proxy_client-server-tunnel_client/tunnel_client.json rename to examples/vpn/proxy_client-server-tunnel_client/tunnel_client.json index d3d9d7d5..6df5d666 100644 --- a/examples/tunnel/proxy_client-server-tunnel_client/tunnel_client.json +++ b/examples/vpn/proxy_client-server-tunnel_client/tunnel_client.json @@ -12,9 +12,9 @@ }, "endpoints": [ { - "type": "tunnel-client", - "tag": "tunnel", - "uuid": "9b65b7e1-04c8-4717-8f45-2aa61fd25937", + "type": "vpn-client", + "tag": "vpn", + "address": "10.0.0.2", "key": "1c9b2ccf-b0c0-4c26-868d-a55a4edad3fe", "outbound": { "type": "vless", diff --git a/examples/tunnel/server-client/client.json b/examples/vpn/server-client/client.json similarity index 85% rename from examples/tunnel/server-client/client.json rename to examples/vpn/server-client/client.json index 95b146f8..95f369fc 100644 --- a/examples/tunnel/server-client/client.json +++ b/examples/vpn/server-client/client.json @@ -12,9 +12,9 @@ }, "endpoints": [ { - "type": "tunnel-client", - "tag": "tunnel", - "uuid": "9b65b7e1-04c8-4717-8f45-2aa61fd25937", + "type": "vpn-client", + "tag": "vpn", + "address": "10.0.0.2", "key": "1c9b2ccf-b0c0-4c26-868d-a55a4edad3fe", "outbound": { "type": "vless", diff --git a/examples/tunnel/server-client/server.json b/examples/vpn/server-client/server.json similarity index 76% rename from examples/tunnel/server-client/server.json rename to examples/vpn/server-client/server.json index 52a26613..c02f480a 100644 --- a/examples/tunnel/server-client/server.json +++ b/examples/vpn/server-client/server.json @@ -12,12 +12,12 @@ }, "endpoints": [ { - "type": "tunnel-server", - "tag": "tunnel", - "uuid": "f79f7678-55e7-432d-a15f-6e8ab2b7fe13", + "type": "vpn-server", + "tag": "vpn", + "address": "10.0.0.1", "users": [ { - "uuid": "9b65b7e1-04c8-4717-8f45-2aa61fd25937", + "address": "10.0.0.2", "key": "1c9b2ccf-b0c0-4c26-868d-a55a4edad3fe" } ], @@ -51,8 +51,8 @@ "route": { "rules": [ { - "outbound": "tunnel", - "override_tunnel_destination": "9b65b7e1-04c8-4717-8f45-2aa61fd25937" + "outbound": "vpn", + "override_gateway": "10.0.0.2" } ], "final": "direct-out", diff --git a/experimental/cachefile/cache.go b/experimental/cachefile/cache.go index 03ef055f..24eb5112 100644 --- a/experimental/cachefile/cache.go +++ b/experimental/cachefile/cache.go @@ -44,6 +44,7 @@ type CacheFile struct { storeFakeIP bool storeRDRC bool storeWARPConfig bool + storeMASQUEConfig bool rdrcTimeout time.Duration DB *bbolt.DB resetAccess sync.Mutex @@ -82,17 +83,18 @@ func New(ctx context.Context, options option.CacheFileOptions) *CacheFile { } } return &CacheFile{ - ctx: ctx, - path: filemanager.BasePath(ctx, path), - cacheID: cacheIDBytes, - storeFakeIP: options.StoreFakeIP, - storeRDRC: options.StoreRDRC, - storeWARPConfig: options.StoreWARPConfig, - rdrcTimeout: rdrcTimeout, - saveDomain: make(map[netip.Addr]string), - saveAddress4: make(map[string]netip.Addr), - saveAddress6: make(map[string]netip.Addr), - saveRDRC: make(map[saveRDRCCacheKey]bool), + ctx: ctx, + path: filemanager.BasePath(ctx, path), + cacheID: cacheIDBytes, + storeFakeIP: options.StoreFakeIP, + storeRDRC: options.StoreRDRC, + storeWARPConfig: options.StoreWARPConfig, + storeMASQUEConfig: options.StoreMASQUEConfig, + rdrcTimeout: rdrcTimeout, + saveDomain: make(map[netip.Addr]string), + saveAddress4: make(map[string]netip.Addr), + saveAddress6: make(map[string]netip.Addr), + saveRDRC: make(map[saveRDRCCacheKey]bool), } } @@ -366,6 +368,10 @@ func (c *CacheFile) StoreWARPConfig() bool { return c.storeWARPConfig } +func (c *CacheFile) StoreMASQUEConfig() bool { + return c.storeMASQUEConfig +} + func (c *CacheFile) LoadWARPConfig(tag string) *adapter.SavedBinary { var savedConfig adapter.SavedBinary err := c.DB.View(func(t *bbolt.Tx) error { @@ -398,3 +404,69 @@ func (c *CacheFile) SaveWARPConfig(tag string, set *adapter.SavedBinary) error { return bucket.Put([]byte(tag), configBinary) }) } + +func (c *CacheFile) LoadMASQUEConfig(tag string) *adapter.SavedBinary { + var savedConfig adapter.SavedBinary + err := c.DB.View(func(t *bbolt.Tx) error { + bucket := c.bucket(t, bucketRuleSet) + if bucket == nil { + return os.ErrNotExist + } + configBinary := bucket.Get([]byte(tag)) + if len(configBinary) == 0 { + return os.ErrInvalid + } + return savedConfig.UnmarshalBinary(configBinary) + }) + if err != nil { + return nil + } + return &savedConfig +} + +func (c *CacheFile) SaveMASQUEConfig(tag string, set *adapter.SavedBinary) error { + return c.DB.Batch(func(t *bbolt.Tx) error { + bucket, err := c.createBucket(t, bucketRuleSet) + if err != nil { + return err + } + configBinary, err := set.MarshalBinary() + if err != nil { + return err + } + return bucket.Put([]byte(tag), configBinary) + }) +} + +func (c *CacheFile) LoadSubscription(tag string) *adapter.SavedBinary { + var savedSet adapter.SavedBinary + err := c.DB.View(func(t *bbolt.Tx) error { + bucket := c.bucket(t, bucketRuleSet) + if bucket == nil { + return os.ErrNotExist + } + setBinary := bucket.Get([]byte(tag)) + if len(setBinary) == 0 { + return os.ErrInvalid + } + return savedSet.UnmarshalBinary(setBinary) + }) + if err != nil { + return nil + } + return &savedSet +} + +func (c *CacheFile) SaveSubscription(tag string, sub *adapter.SavedBinary) error { + return c.DB.Batch(func(t *bbolt.Tx) error { + bucket, err := c.createBucket(t, bucketRuleSet) + if err != nil { + return err + } + setBinary, err := sub.MarshalBinary() + if err != nil { + return err + } + return bucket.Put([]byte(tag), setBinary) + }) +} diff --git a/experimental/clashapi/provider.go b/experimental/clashapi/provider.go index 352b2894..f2487e49 100644 --- a/experimental/clashapi/provider.go +++ b/experimental/clashapi/provider.go @@ -4,48 +4,78 @@ import ( "context" "net/http" + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common/json/badjson" + "github.com/go-chi/chi/v5" "github.com/go-chi/render" ) -func proxyProviderRouter() http.Handler { +func proxyProviderRouter(server *Server) http.Handler { r := chi.NewRouter() - r.Get("/", getProviders) + r.Get("/", getProviders(server)) r.Route("/{name}", func(r chi.Router) { - r.Use(parseProviderName, findProviderByName) - r.Get("/", getProvider) + r.Use(parseProviderName, findProviderByName(server)) + r.Get("/", getProvider(server)) r.Put("/", updateProvider) r.Get("/healthcheck", healthCheckProvider) }) return r } -func getProviders(w http.ResponseWriter, r *http.Request) { - render.JSON(w, r, render.M{ - "providers": render.M{}, - }) +func getProviders(server *Server) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + providerMap := make(render.M) + for _, provider := range server.provider.Providers() { + providerMap[provider.Tag()] = providerInfo(server, provider) + } + render.JSON(w, r, render.M{ + "providers": providerMap, + }) + } } -func getProvider(w http.ResponseWriter, r *http.Request) { - /*provider := r.Context().Value(CtxKeyProvider).(provider.ProxyProvider) - render.JSON(w, r, provider)*/ - render.NoContent(w, r) +func getProvider(server *Server) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + provider := r.Context().Value(CtxKeyProvider).(adapter.Provider) + render.JSON(w, r, providerInfo(server, provider)) + } +} + +func providerInfo(server *Server, p adapter.Provider) *badjson.JSONObject { + var info badjson.JSONObject + proxies := make([]*badjson.JSONObject, 0) + for _, detour := range p.Outbounds() { + proxies = append(proxies, proxyInfo(server, detour)) + } + info.Put("type", "Proxy") // Proxy, Rule + info.Put("vehicleType", C.ProviderDisplayName(p.Type())) // HTTP, File, Compatible + info.Put("name", p.Tag()) + info.Put("proxies", proxies) + info.Put("updatedAt", p.UpdatedAt()) + if p, ok := p.(adapter.ProviderSubscriptionInfo); ok { + info.Put("subscriptionInfo", p.SubscriptionInfo()) + } + return &info } func updateProvider(w http.ResponseWriter, r *http.Request) { - /*provider := r.Context().Value(CtxKeyProvider).(provider.ProxyProvider) - if err := provider.Update(); err != nil { - render.Status(r, http.StatusServiceUnavailable) - render.JSON(w, r, newError(err.Error())) - return - }*/ + provider := r.Context().Value(CtxKeyProvider).(adapter.Provider) + if provider, isUpdater := provider.(adapter.ProviderUpdater); isUpdater { + if err := provider.Update(); err != nil { + render.Status(r, http.StatusServiceUnavailable) + render.JSON(w, r, newError(err.Error())) + return + } + } render.NoContent(w, r) } func healthCheckProvider(w http.ResponseWriter, r *http.Request) { - /*provider := r.Context().Value(CtxKeyProvider).(provider.ProxyProvider) - provider.HealthCheck()*/ + provider := r.Context().Value(CtxKeyProvider).(adapter.Provider) + provider.HealthCheck(r.Context()) render.NoContent(w, r) } @@ -57,18 +87,19 @@ func parseProviderName(next http.Handler) http.Handler { }) } -func findProviderByName(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - /*name := r.Context().Value(CtxKeyProviderName).(string) - providers := tunnel.ProxyProviders() - provider, exist := providers[name] - if !exist {*/ - render.Status(r, http.StatusNotFound) - render.JSON(w, r, ErrNotFound) - //return - //} +func findProviderByName(server *Server) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + name := r.Context().Value(CtxKeyProviderName).(string) + provider, exist := server.provider.Get(name) + if !exist { + render.Status(r, http.StatusNotFound) + render.JSON(w, r, ErrNotFound) + return + } - // ctx := context.WithValue(r.Context(), CtxKeyProvider, provider) - // next.ServeHTTP(w, r.WithContext(ctx)) - }) + ctx := context.WithValue(r.Context(), CtxKeyProvider, provider) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } } diff --git a/experimental/clashapi/server.go b/experimental/clashapi/server.go index ec40a95f..c5255314 100644 --- a/experimental/clashapi/server.go +++ b/experimental/clashapi/server.go @@ -46,6 +46,7 @@ type Server struct { dnsRouter adapter.DNSRouter outbound adapter.OutboundManager endpoint adapter.EndpointManager + provider adapter.ProviderManager logger log.Logger httpServer *http.Server trafficManager *trafficontrol.Manager @@ -71,6 +72,7 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op dnsRouter: service.FromContext[adapter.DNSRouter](ctx), outbound: service.FromContext[adapter.OutboundManager](ctx), endpoint: service.FromContext[adapter.EndpointManager](ctx), + provider: service.FromContext[adapter.ProviderManager](ctx), logger: logFactory.NewLogger("clash-api"), httpServer: &http.Server{ Addr: options.ExternalController, @@ -122,7 +124,7 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op r.Mount("/proxies", proxyRouter(s, s.router)) r.Mount("/rules", ruleRouter(s.router)) r.Mount("/connections", connectionRouter(s.ctx, s.router, trafficManager)) - r.Mount("/providers/proxies", proxyProviderRouter()) + r.Mount("/providers/proxies", proxyProviderRouter(s)) r.Mount("/providers/rules", ruleProviderRouter()) r.Mount("/script", scriptRouter()) r.Mount("/profile", profileRouter()) diff --git a/experimental/libbox/config.go b/experimental/libbox/config.go index 122425d2..45156f77 100644 --- a/experimental/libbox/config.go +++ b/experimental/libbox/config.go @@ -3,6 +3,7 @@ package libbox import ( "bytes" "context" + "net/netip" "os" box "github.com/sagernet/sing-box" @@ -33,7 +34,7 @@ func baseContext(platformInterface PlatformInterface) context.Context { } ctx := context.Background() ctx = filemanager.WithDefault(ctx, sWorkingPath, sTempPath, sUserID, sGroupID) - return box.Context(ctx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry(), dnsRegistry, include.ServiceRegistry()) + return box.Context(ctx, include.InboundRegistry(), include.OutboundRegistry(), include.EndpointRegistry(), include.ProviderRegistry(), dnsRegistry, include.ServiceRegistry()) } func parseConfig(ctx context.Context, configContent string) (option.Options, error) { @@ -144,6 +145,10 @@ func (s *platformInterfaceStub) SendNotification(notification *adapter.Notificat return nil } +func (s *platformInterfaceStub) MyInterfaceAddress() []netip.Addr { + return nil +} + func (s *platformInterfaceStub) UsePlatformLocalDNSTransport() bool { return false } diff --git a/experimental/libbox/service.go b/experimental/libbox/service.go index 7d0b3004..37fd56c9 100644 --- a/experimental/libbox/service.go +++ b/experimental/libbox/service.go @@ -29,6 +29,7 @@ type platformInterfaceWrapper struct { useProcFS bool networkManager adapter.NetworkManager myTunName string + myTunAddress []netip.Addr defaultInterfaceAccess sync.Mutex defaultInterface *control.Interface isExpensive bool @@ -78,9 +79,25 @@ func (w *platformInterfaceWrapper) OpenInterface(options *tun.Options, platformO } options.FileDescriptor = dupFd w.myTunName = options.Name + w.myTunAddress = myTunAddress(options) return tun.New(*options) } +func myTunAddress(options *tun.Options) []netip.Addr { + addresses := make([]netip.Addr, 0, len(options.Inet4Address)+len(options.Inet6Address)) + for _, prefix := range options.Inet4Address { + addresses = append(addresses, prefix.Addr()) + } + for _, prefix := range options.Inet6Address { + addresses = append(addresses, prefix.Addr()) + } + return addresses +} + +func (w *platformInterfaceWrapper) MyInterfaceAddress() []netip.Addr { + return w.myTunAddress +} + func (w *platformInterfaceWrapper) UsePlatformDefaultInterfaceMonitor() bool { return true } diff --git a/go.mod b/go.mod index 1b92c07b..f01b57e3 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,9 @@ module github.com/sagernet/sing-box -go 1.25.5 +go 1.26.1 require ( + github.com/Diniboy1123/connect-ip-go v0.0.0-20260409225322-8d7bb0a858a2 github.com/GoAdminGroup/go-admin v1.2.26 github.com/GoAdminGroup/themes v0.0.48 github.com/anthropics/anthropic-sdk-go v1.26.0 @@ -33,7 +34,7 @@ require ( github.com/miekg/dns v1.1.72 github.com/openai/openai-go/v3 v3.26.0 github.com/oschwald/maxminddb-golang v1.13.1 - github.com/patrickmn/go-cache v2.1.0+incompatible + github.com/patrickmn/go-cache/v2 v2.0.0-00010101000000-000000000000 github.com/sagernet/asc-go v0.0.0-20241217030726-d563060fe4e1 github.com/sagernet/bbolt v0.0.0-20231014093535-ea5cb2fe9f0a github.com/sagernet/cors v1.2.1 @@ -55,9 +56,11 @@ require ( github.com/sagernet/tailscale v1.92.4-sing-box-1.13-mod.7 github.com/sagernet/wireguard-go v0.0.2-beta.1.0.20260224074747-506b7631853c github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 + github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 github.com/spf13/cobra v1.10.2 github.com/stretchr/testify v1.11.1 github.com/vishvananda/netns v0.0.5 + github.com/yosida95/uritemplate/v3 v3.0.2 go.uber.org/zap v1.27.1 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba golang.org/x/crypto v0.49.0 @@ -72,7 +75,12 @@ require ( ) require ( - github.com/kr/pretty v0.3.1 // indirect + github.com/OneOfOne/xxhash v1.2.8 // indirect + github.com/dunglas/httpsfv v1.1.0 // indirect + github.com/panjf2000/ants/v2 v2.12.0 // indirect + github.com/quic-go/quic-go v0.59.0 // indirect + github.com/tylertreat/BoomFilters v0.0.0-20251117164519-53813c36cc1b // indirect + github.com/yl2chen/cidranger v1.0.2 // indirect gvisor.dev/gvisor v0.0.0-20260408064518-65a410b0d584 // indirect ) @@ -95,6 +103,7 @@ require ( github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa // indirect github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 // indirect + github.com/dolonet/mtg-multi v1.8.0 github.com/ebitengine/purego v0.9.1 // indirect github.com/florianl/go-nfqueue/v2 v2.0.2 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect @@ -127,7 +136,7 @@ require ( github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jsimonetti/rtnetlink v1.4.0 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/klauspost/compress v1.18.0 // indirect + github.com/klauspost/compress v1.18.3 // indirect github.com/klauspost/cpuid/v2 v2.3.0 github.com/leodido/go-urn v1.4.0 // indirect github.com/libdns/libdns v1.1.1 // indirect @@ -141,7 +150,7 @@ require ( github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect github.com/nxadm/tail v1.4.11 // indirect github.com/pierrec/lz4/v4 v4.1.21 // indirect - github.com/pires/go-proxyproto v0.8.1 // indirect + github.com/pires/go-proxyproto v0.11.0 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus-community/pro-bing v0.4.0 // indirect github.com/quic-go/qpack v0.6.0 // indirect @@ -186,7 +195,7 @@ require ( github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 // indirect github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc // indirect github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 // indirect - github.com/tidwall/gjson v1.18.0 + github.com/tidwall/gjson v1.18.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect @@ -210,13 +219,13 @@ require ( gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect + gopkg.in/yaml.v3 v3.0.1 lukechampine.com/blake3 v1.4.1 xorm.io/builder v0.3.7 // indirect xorm.io/xorm v1.0.2 // indirect ) -replace github.com/sagernet/wireguard-go => github.com/shtorm-7/wireguard-go v0.0.2-beta.1-extended-1.3.4 +replace github.com/sagernet/wireguard-go => github.com/shtorm-7/wireguard-go v0.0.2-beta.1-extended-1.4.0 replace github.com/sagernet/tailscale => github.com/shtorm-7/tailscale v1.92.4-sing-box-1.13-mod.7-extended-1.0.2 @@ -225,3 +234,9 @@ replace github.com/sagernet/sing-mux => github.com/shtorm-7/sing-mux v0.3.4-exte replace github.com/ameshkov/dnscrypt/v2 => github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0 replace github.com/sagernet/sing-vmess => github.com/starifly/sing-vmess v0.2.7-mod.9 + +replace github.com/patrickmn/go-cache/v2 => github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.0.2 + +replace github.com/dolonet/mtg-multi => github.com/shtorm-7/mtg-multi v1.8.0-extended-1.0.0 + +replace github.com/Diniboy1123/connect-ip-go => github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0 diff --git a/go.sum b/go.sum index d2f3f3c6..4ecaa45b 100644 --- a/go.sum +++ b/go.sum @@ -24,6 +24,8 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/NebulousLabs/fastrand v0.0.0-20181203155948-6fb6489aac4e h1:n+DcnTNkQnHlwpsrHoQtkrJIO7CBx029fw6oR4vIob4= github.com/NebulousLabs/fastrand v0.0.0-20181203155948-6fb6489aac4e/go.mod h1:Bdzq+51GR4/0DIhaICZEOm+OHvXGwwB2trKZ8B4Y6eQ= +github.com/OneOfOne/xxhash v1.2.8 h1:31czK/TI9sNkxIKfaUfGlU47BAxQ0ztGgd9vPyqimf8= +github.com/OneOfOne/xxhash v1.2.8/go.mod h1:eZbhyaAYD41SGSSsnmcpxVoRiQ/MPUTjUdIIOT9Um7Q= github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= @@ -43,6 +45,8 @@ github.com/anthropics/anthropic-sdk-go v1.26.0/go.mod h1:qUKmaW+uuPB64iy1l+4kOSv github.com/anytls/sing-anytls v0.0.11 h1:w8e9Uj1oP3m4zxkyZDewPk0EcQbvVxb7Nn+rapEx4fc= github.com/anytls/sing-anytls v0.0.11/go.mod h1:7rjN6IukwysmdusYsrV51Fgu1uW6vsrdd6ctjnEAln8= github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= +github.com/babolivier/go-doh-client v0.0.0-20201028162107-a76cff4cb8b6 h1:4NNbNM2Iq/k57qEu7WfL67UrbPq1uFWxW4qODCohi+0= +github.com/babolivier/go-doh-client v0.0.0-20201028162107-a76cff4cb8b6/go.mod h1:J29hk+f9lJrblVIfiJOtTFk+OblBawmib4uz/VdKzlg= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/caddyserver/certmagic v0.25.2 h1:D7xcS7ggX/WEY54x0czj7ioTkmDWKIgxtIi2OcQclUc= github.com/caddyserver/certmagic v0.25.2/go.mod h1:llW/CvsNmza8S6hmsuggsZeiX+uS27dkqY27wDIuBWg= @@ -64,9 +68,10 @@ github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmC github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 h1:8h5+bWd7R6AYUslN6c6iuZWTKsKxUFDlpnmilO6R2n0= github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/cretz/bine v0.2.0 h1:8GiDRGlTgz+o8H9DSnsl+5MeBK4HsExxgl6WgzOCuZo= github.com/cretz/bine v0.2.0/go.mod h1:WU4o9QR9wWp8AVKtTM1XD5vUHkEqnf2vVSo6dBqbetI= +github.com/d4l3k/messagediff v1.2.1 h1:ZcAIMYsUg0EAp9X+tt8/enBE/Q8Yd5kzPynLyKptt9U= +github.com/d4l3k/messagediff v1.2.1/go.mod h1:Oozbb1TVXFac9FtSIxHBMnBCq2qeH/2KkEQxENCrlLo= github.com/database64128/netx-go v0.1.1 h1:dT5LG7Gs7zFZBthFBbzWE6K8wAHjSNAaK7wCYZT7NzM= github.com/database64128/netx-go v0.1.1/go.mod h1:LNlYVipaYkQArRFDNNJ02VkNV+My9A5XR/IGS7sIBQc= github.com/database64128/tfo-go/v2 v2.3.2 h1:UhZMKiMq3swZGUiETkLBDzQnZBPSAeBMClpJGlnJ5Fw= @@ -94,6 +99,8 @@ github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/dunglas/httpsfv v1.1.0 h1:Jw76nAyKWKZKFrpMMcL76y35tOpYHqQPzHQiwDvpe54= +github.com/dunglas/httpsfv v1.1.0/go.mod h1:zID2mqw9mFsnt7YC3vYQ9/cjq30q41W+1AnDwH8TiMg= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= @@ -224,6 +231,8 @@ github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jarcoal/httpmock v1.0.8 h1:8kI16SoO6LQKgPE7PvQuV+YuD/inwHd7fOOe2zMbo4k= +github.com/jarcoal/httpmock v1.0.8/go.mod h1:ATjnClrvW/3tijVmpL/va5Z3aAyGvqU3gCT8nX0Txik= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jsimonetti/rtnetlink v1.4.0 h1:Z1BF0fRgcETPEa0Kt0MRk3yV5+kF1FWTni6KUFKrq2I= github.com/jsimonetti/rtnetlink v1.4.0/go.mod h1:5W1jDvWdnthFJ7fxYX1GMK07BUpI4oskfOqvPteYS6E= @@ -234,8 +243,8 @@ github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7V github.com/keybase/go-keychain v0.0.1 h1:way+bWYa6lDppZoZcgMbYsvC7GxljxrskdNInRtuthU= github.com/keybase/go-keychain v0.0.1/go.mod h1:PdEILRW3i9D8JcdM+FmY6RwkHGnhHxXwkPPMeUgOK1k= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= -github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= +github.com/klauspost/compress v1.18.3 h1:9PJRvfbmTabkOX8moIpXPbMMbYN60bWImDDU7L+/6zw= +github.com/klauspost/compress v1.18.3/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -268,6 +277,8 @@ github.com/logrusorgru/aurora v2.0.3+incompatible h1:tOpm7WcpBTn4fjmVfgpQq0EfczG github.com/logrusorgru/aurora v2.0.3+incompatible/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4= github.com/magiconair/properties v1.8.6 h1:5ibWZ6iY0NctNGWo87LalDlEZ6R41TqbbDamhfG/Qzo= github.com/magiconair/properties v1.8.6/go.mod h1:y3VJvCyxH9uVvJTWEGAELF3aiYNyPKd5NZ3oSwXrF60= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= @@ -320,14 +331,15 @@ github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2sz github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw= github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE= github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8= +github.com/panjf2000/ants/v2 v2.12.0 h1:u9JhESo83i/GkZnhfTNuFMMWcNt7mnV1bGJ6FT4wXH8= +github.com/panjf2000/ants/v2 v2.12.0/go.mod h1:tSQuaNQ6r6NRhPt+IZVUevvDyFMTs+eS4ztZc52uJTY= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= -github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0= -github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pires/go-proxyproto v0.11.0 h1:gUQpS85X/VJMdUsYyEgyn59uLJvGqPhJV5YvG68wXH4= +github.com/pires/go-proxyproto v0.11.0/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -345,10 +357,13 @@ github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/quic-go/qpack v0.6.0 h1:g7W+BMYynC1LbYLSqRt8PBg5Tgwxn214ZZR34VIOjz8= github.com/quic-go/qpack v0.6.0/go.mod h1:lUpLKChi8njB4ty2bFLX2x4gzDqXwUpaO1DP9qMDZII= +github.com/quic-go/quic-go v0.59.0 h1:OLJkp1Mlm/aS7dpKgTc6cnpynnD2Xg7C1pwL6vy/SAw= +github.com/quic-go/quic-go v0.59.0/go.mod h1:upnsH4Ju1YkqpLXC305eW3yDZ4NfnNbmQRCMWS58IKU= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/zerolog v1.35.0 h1:VD0ykx7HMiMJytqINBsKcbLS+BJ4WYjz+05us+LRTdI= +github.com/rs/zerolog v1.35.0/go.mod h1:EjML9kdfa/RMA7h/6z6pYmq1ykOuA8/mjWaEvGI+jcw= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/safchain/ethtool v0.3.0 h1:gimQJpsI6sc1yIqP/y8GYgiXn/NjgvpM0RNoWLVVmP0= github.com/safchain/ethtool v0.3.0/go.mod h1:SA9BwrgyAqNo7M+uaL6IYbxpm5wk3L7Mm6ocLW+CJUs= @@ -448,15 +463,23 @@ github.com/sagernet/smux v1.5.50-sing-box-mod.1 h1:XkJcivBC9V4wBjiGXIXZ229aZCU1h github.com/sagernet/smux v1.5.50-sing-box-mod.1/go.mod h1:NjhsCEWedJm7eFLyhuBgIEzwfhRmytrUoiLluxs5Sk8= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854 h1:6uUiZcDRnZSAegryaUGwPC/Fj13JSHwiTftrXhMmYOc= github.com/sagernet/ws v0.0.0-20231204124109-acfe8907c854/go.mod h1:LtfoSK3+NG57tvnVEHgcuBW9ujgE8enPSgzgwStwCAA= +github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0 h1:ws7BIsYLd31Wjifq88BYCHRVlgO+07iwil39s6ERba8= +github.com/shtorm-7/connect-ip-go v1.0.0-extended-1.0.0/go.mod h1:mRwx4w32qQxsWB2kThuHpbo7iNjJiq1jYWubgqEPjHA= github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0 h1:e5s7RKBd2rIPR0StbvZ2vTVtJ5jDTsTk5wtIIapZTRg= github.com/shtorm-7/dnscrypt/v2 v2.4.0-extended-1.0.0/go.mod h1:WpEFV2uhebXb8Jhes/5/fSdpmhGV8TL22RDaeWwV6hI= +github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.0.2 h1:+1tb8QNU0n2p/8Ct0A3/uHYImYXFhnN4lHOJoIdAV2s= +github.com/shtorm-7/go-cache/v2 v2.1.0-extended-1.0.2/go.mod h1:Ek4yz5OK6stwhLKgLsRRYDI+FA+ZWvRJiWLjsi/vMM4= +github.com/shtorm-7/mtg-multi v1.8.0-extended-1.0.0 h1:Gr0oINXDOAuQ+eoenfT53UWm1Y47QA7A4PLzgbVFNWo= +github.com/shtorm-7/mtg-multi v1.8.0-extended-1.0.0/go.mod h1:3rvdhwdPABkwKBdvgMt3VwMn9uSq8hpoHRezZ5jRJU0= github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0 h1:a5OoXr3e2ACbM6vDIaaGL44IdHQ6wPjcSoU13vfC0Sw= github.com/shtorm-7/sing-mux v0.3.4-extended-1.0.0/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk= github.com/shtorm-7/tailscale v1.92.4-sing-box-1.13-mod.7-extended-1.0.2 h1:hSMjh97OszszOd8HrzpaYUQH9dWRRBluJCbwQyz8ZOk= github.com/shtorm-7/tailscale v1.92.4-sing-box-1.13-mod.7-extended-1.0.2/go.mod h1:TYIIqO5sZpWq873rLIeO2usszSMUpR3h6WdqVVs65ug= -github.com/shtorm-7/wireguard-go v0.0.2-beta.1-extended-1.3.4 h1:t/2ZxRo8cwvydImFaKuUSDrcZYhX753JiXGe7411krI= -github.com/shtorm-7/wireguard-go v0.0.2-beta.1-extended-1.3.4/go.mod h1:Me2JlCDYHxnd0mnuX7L5LXAeDHCltI7vSKq3eTE6SVE= +github.com/shtorm-7/wireguard-go v0.0.2-beta.1-extended-1.4.0 h1:z25EapzvkpyLgaq2T0o7eeoshBR3U4AhqMOBq1gRtrA= +github.com/shtorm-7/wireguard-go v0.0.2-beta.1-extended-1.4.0/go.mod h1:Me2JlCDYHxnd0mnuX7L5LXAeDHCltI7vSKq3eTE6SVE= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= +github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8 h1:TG/diQgUe0pntT/2D9tmUCz4VNwm9MfrtPr0SU2qSX8= +github.com/songgao/water v0.0.0-20200317203138-2b4b6d7c09d8/go.mod h1:P5HUIBuIWKbyjl083/loAegFkfbFNx5i2qEP4CNbm7E= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= @@ -466,6 +489,8 @@ github.com/starifly/sing-vmess v0.2.7-mod.9 h1:xobAmejSbBQ0A3f/EtJ9cJd3m6gK7dDPc github.com/starifly/sing-vmess v0.2.7-mod.9/go.mod h1:5aYoOtYksAyS0NXDm0qKeTYW1yoE1bJVcv+XLcVoyJs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.2.3-0.20181224173747-660f15d67dbb/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= @@ -501,6 +526,12 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/txthinking/runnergroup v0.0.0-20250224021307-5864ffeb65ae h1:ArVM1jICfm7g4E4dBet+KHUFMLuxmj1Nxdp/tr3ByCU= +github.com/txthinking/runnergroup v0.0.0-20250224021307-5864ffeb65ae/go.mod h1:cldYm15/XHcGt7ndItnEWHwFZo7dinU+2QoyjfErhsI= +github.com/txthinking/socks5 v0.0.0-20251011041537-5c31f201a10e h1:xA7GVlbz6teIF4FdvuqwbX6C4tiqNk2PH7FRPIDerao= +github.com/txthinking/socks5 v0.0.0-20251011041537-5c31f201a10e/go.mod h1:ntmMHL/xPq1WLeKiw8p/eRATaae6PiVRNipHFJxI8PM= +github.com/tylertreat/BoomFilters v0.0.0-20251117164519-53813c36cc1b h1:p+bJ3v5uUdEVMCoeFUs+BNJPsqt+Y6BLbDaPfTcbcH8= +github.com/tylertreat/BoomFilters v0.0.0-20251117164519-53813c36cc1b/go.mod h1:OYRfF6eb5wY9VRFkXJH8FFBi3plw2v+giaIu7P054pM= github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM= github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= @@ -510,6 +541,10 @@ github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zd github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +github.com/yl2chen/cidranger v1.0.2 h1:lbOWZVCG1tCRX4u24kuM1Tb4nHqWkDxwLdoS+SevawU= +github.com/yl2chen/cidranger v1.0.2/go.mod h1:9U1yz7WPYDwf0vpNWFaeRh0bjwz5RVgRy/9UEQfHl0g= +github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= +github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY= github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/blake3 v0.2.4 h1:KYQPkhpRtcqh0ssGYcKLG1JYvddkEA8QwCM/yBqhaZI= @@ -534,6 +569,8 @@ go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6 go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.5.2 h1:LbtPTcP8A5k9WPXj54PPPbjcI4Y6lhyOZXn+VS7wNko= +go.uber.org/mock v0.5.2/go.mod h1:wLlUxC2vVTPTaE3UD51E0BGOAElKrILxhVSDYQLld5o= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= diff --git a/include/masque.go b/include/masque.go new file mode 100644 index 00000000..fdd75d98 --- /dev/null +++ b/include/masque.go @@ -0,0 +1,12 @@ +//go:build with_masque + +package include + +import ( + "github.com/sagernet/sing-box/adapter/outbound" + "github.com/sagernet/sing-box/protocol/masque" +) + +func registerMASQUEOutbound(registry *outbound.Registry) { + masque.RegisterOutbound(registry) +} diff --git a/include/masque_stub.go b/include/masque_stub.go new file mode 100644 index 00000000..fc31da68 --- /dev/null +++ b/include/masque_stub.go @@ -0,0 +1,20 @@ +//go:build !with_masque + +package include + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/outbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func registerMASQUEOutbound(registry *outbound.Registry) { + outbound.Register[option.MASQUEOutboundOptions](registry, C.TypeMASQUE, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MASQUEOutboundOptions) (adapter.Outbound, error) { + return nil, E.New(`MASQUE outbound is not included in this build, rebuild with -tags with_masque`) + }) +} diff --git a/include/mtproxy.go b/include/mtproxy.go new file mode 100644 index 00000000..2fba9693 --- /dev/null +++ b/include/mtproxy.go @@ -0,0 +1,12 @@ +//go:build with_mtproxy + +package include + +import ( + "github.com/sagernet/sing-box/adapter/inbound" + "github.com/sagernet/sing-box/protocol/mtproxy" +) + +func registerMTProxyInbound(registry *inbound.Registry) { + mtproxy.RegisterInbound(registry) +} diff --git a/include/mtproxy_stub.go b/include/mtproxy_stub.go new file mode 100644 index 00000000..c06a98fc --- /dev/null +++ b/include/mtproxy_stub.go @@ -0,0 +1,20 @@ +//go:build !with_mtproxy + +package include + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func registerMTProxyInbound(registry *inbound.Registry) { + inbound.Register[option.MTProxyInboundOptions](registry, C.TypeMTProxy, func(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MTProxyInboundOptions) (adapter.Inbound, error) { + return nil, E.New(`MTProxy is not included in this build, rebuild with -tags with_mtproxy`) + }) +} diff --git a/include/registry.go b/include/registry.go index d2b6bbe2..a34d8075 100644 --- a/include/registry.go +++ b/include/registry.go @@ -8,6 +8,7 @@ import ( "github.com/sagernet/sing-box/adapter/endpoint" "github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/adapter/outbound" + "github.com/sagernet/sing-box/adapter/provider" "github.com/sagernet/sing-box/adapter/service" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/dns" @@ -28,6 +29,7 @@ import ( "github.com/sagernet/sing-box/protocol/mieru" "github.com/sagernet/sing-box/protocol/mixed" "github.com/sagernet/sing-box/protocol/naive" + "github.com/sagernet/sing-box/protocol/parser" "github.com/sagernet/sing-box/protocol/redirect" "github.com/sagernet/sing-box/protocol/shadowsocks" "github.com/sagernet/sing-box/protocol/shadowtls" @@ -36,9 +38,11 @@ import ( "github.com/sagernet/sing-box/protocol/tor" "github.com/sagernet/sing-box/protocol/trojan" "github.com/sagernet/sing-box/protocol/tun" - "github.com/sagernet/sing-box/protocol/tunnel" "github.com/sagernet/sing-box/protocol/vless" "github.com/sagernet/sing-box/protocol/vmess" + "github.com/sagernet/sing-box/protocol/vpn" + localProvider "github.com/sagernet/sing-box/provider/local" + remoteProvider "github.com/sagernet/sing-box/provider/remote" "github.com/sagernet/sing-box/service/admin_panel" "github.com/sagernet/sing-box/service/manager" "github.com/sagernet/sing-box/service/node" @@ -50,7 +54,7 @@ import ( ) func Context(ctx context.Context) context.Context { - return box.Context(ctx, InboundRegistry(), OutboundRegistry(), EndpointRegistry(), DNSTransportRegistry(), ServiceRegistry()) + return box.Context(ctx, InboundRegistry(), OutboundRegistry(), EndpointRegistry(), ProviderRegistry(), DNSTransportRegistry(), ServiceRegistry()) } func InboundRegistry() *inbound.Registry { @@ -77,6 +81,7 @@ func InboundRegistry() *inbound.Registry { registerQUICInbounds(registry) registerStubForRemovedInbounds(registry) + registerMTProxyInbound(registry) return registry } @@ -88,7 +93,7 @@ func OutboundRegistry() *outbound.Registry { block.RegisterOutbound(registry) - group.RegisterFailover(registry) + group.RegisterFallback(registry) group.RegisterSelector(registry) group.RegisterURLTest(registry) @@ -104,12 +109,15 @@ func OutboundRegistry() *outbound.Registry { vless.RegisterOutbound(registry) mieru.RegisterOutbound(registry) anytls.RegisterOutbound(registry) + registerMASQUEOutbound(registry) bond.RegisterOutbound(registry) bandwidth.RegisterOutbound(registry) connection.RegisterOutbound(registry) + parser.RegisterOutbound(registry) + registerQUICOutbounds(registry) registerStubForRemovedOutbounds(registry) @@ -119,8 +127,8 @@ func OutboundRegistry() *outbound.Registry { func EndpointRegistry() *endpoint.Registry { registry := endpoint.NewRegistry() - tunnel.RegisterServerEndpoint(registry) - tunnel.RegisterClientEndpoint(registry) + vpn.RegisterServerEndpoint(registry) + vpn.RegisterClientEndpoint(registry) registerWireGuardEndpoint(registry) registerTailscaleEndpoint(registry) @@ -128,6 +136,16 @@ func EndpointRegistry() *endpoint.Registry { return registry } +func ProviderRegistry() *provider.Registry { + registry := provider.NewRegistry() + + localProvider.RegisterProviderInline(registry) + localProvider.RegisterProviderLocal(registry) + remoteProvider.RegisterProvider(registry) + + return registry +} + func DNSTransportRegistry() *dns.TransportRegistry { registry := dns.NewTransportRegistry() diff --git a/include/wireguard.go b/include/wireguard.go index 40f881d1..43b10200 100644 --- a/include/wireguard.go +++ b/include/wireguard.go @@ -4,10 +4,11 @@ package include import ( "github.com/sagernet/sing-box/adapter/endpoint" + "github.com/sagernet/sing-box/protocol/warp" "github.com/sagernet/sing-box/protocol/wireguard" ) func registerWireGuardEndpoint(registry *endpoint.Registry) { wireguard.RegisterEndpoint(registry) - wireguard.RegisterWARPEndpoint(registry) + warp.RegisterEndpoint(registry) } diff --git a/option/cloudflare.go b/option/cloudflare.go new file mode 100644 index 00000000..6fbd0805 --- /dev/null +++ b/option/cloudflare.go @@ -0,0 +1,9 @@ +package option + +type CloudflareProfile struct { + ID string `json:"id,omitempty"` + AuthToken string `json:"auth_token,omitempty"` + PrivateKey string `json:"private_key,omitempty"` + Recreate bool `json:"recreate,omitempty"` + Detour string `json:"detour,omitempty"` +} diff --git a/option/experimental.go b/option/experimental.go index 0487881b..031a8cbc 100644 --- a/option/experimental.go +++ b/option/experimental.go @@ -11,13 +11,14 @@ type ExperimentalOptions struct { } type CacheFileOptions struct { - Enabled bool `json:"enabled,omitempty"` - Path string `json:"path,omitempty"` - CacheID string `json:"cache_id,omitempty"` - StoreFakeIP bool `json:"store_fakeip,omitempty"` - StoreRDRC bool `json:"store_rdrc,omitempty"` - StoreWARPConfig bool `json:"store_warp_config,omitempty"` - RDRCTimeout badoption.Duration `json:"rdrc_timeout,omitempty"` + Enabled bool `json:"enabled,omitempty"` + Path string `json:"path,omitempty"` + CacheID string `json:"cache_id,omitempty"` + StoreFakeIP bool `json:"store_fakeip,omitempty"` + StoreRDRC bool `json:"store_rdrc,omitempty"` + StoreWARPConfig bool `json:"store_warp_config,omitempty"` + StoreMASQUEConfig bool `json:"store_masque_config,omitempty"` + RDRCTimeout badoption.Duration `json:"rdrc_timeout,omitempty"` } type ClashAPIOptions struct { diff --git a/option/failover.go b/option/failover.go new file mode 100644 index 00000000..73bdeb34 --- /dev/null +++ b/option/failover.go @@ -0,0 +1,9 @@ +package option + +type FailoverInboundOptions struct { + Inbounds []Inbound `json:"inbounds"` +} + +type FailoverOutboundOptions struct { + Outbounds []Outbound `json:"outbounds"` +} diff --git a/option/group.go b/option/group.go index d550b233..2fb8e65b 100644 --- a/option/group.go +++ b/option/group.go @@ -3,13 +3,13 @@ package option import "github.com/sagernet/sing/common/json/badoption" type SelectorOutboundOptions struct { - Outbounds []string `json:"outbounds"` - Default string `json:"default,omitempty"` - InterruptExistConnections bool `json:"interrupt_exist_connections,omitempty"` + GroupCommonOption + Default string `json:"default,omitempty"` + InterruptExistConnections bool `json:"interrupt_exist_connections,omitempty"` } type URLTestOutboundOptions struct { - Outbounds []string `json:"outbounds"` + GroupCommonOption URL string `json:"url,omitempty"` Interval badoption.Duration `json:"interval,omitempty"` Tolerance uint16 `json:"tolerance,omitempty"` @@ -17,6 +17,14 @@ type URLTestOutboundOptions struct { InterruptExistConnections bool `json:"interrupt_exist_connections,omitempty"` } -type FailoverOutboundOptions struct { +type FallbackOutboundOptions struct { Outbounds []string `json:"outbounds"` } + +type GroupCommonOption struct { + Outbounds []string `json:"outbounds"` + Providers []string `json:"providers"` + Exclude *badoption.Regexp `json:"exclude,omitempty"` + Include *badoption.Regexp `json:"include,omitempty"` + UseAllProviders bool `json:"use_all_providers,omitempty"` +} diff --git a/option/masque.go b/option/masque.go new file mode 100644 index 00000000..83fa849a --- /dev/null +++ b/option/masque.go @@ -0,0 +1,32 @@ +package option + +import ( + "github.com/sagernet/sing/common/json/badoption" +) + +type MASQUEOutboundOptions struct { + UseHTTP2 bool `json:"use_http2,omitempty"` + UseIPv6 bool `json:"use_ipv6,omitempty"` + Profile CloudflareProfile `json:"profile,omitempty"` + UDPTimeout badoption.Duration `json:"udp_timeout,omitempty"` + UDPKeepalivePeriod badoption.Duration `json:"udp_keepalive_period,omitempty"` + UDPInitialPacketSize uint16 `json:"udp_initial_packet_size,omitempty"` + ReconnectDelay badoption.Duration `json:"reconnect_delay,omitempty"` + MASQUEOutboundTLSOptions + DialerOptions +} + +type MASQUEOutboundTLSOptions struct { + Insecure bool `json:"insecure,omitempty"` + CipherSuites badoption.Listable[string] `json:"cipher_suites,omitempty"` + CurvePreferences badoption.Listable[CurvePreference] `json:"curve_preferences,omitempty"` + Fragment bool `json:"fragment,omitempty"` + FragmentFallbackDelay badoption.Duration `json:"fragment_fallback_delay,omitempty"` + RecordFragment bool `json:"record_fragment,omitempty"` + KernelTx bool `json:"kernel_tx,omitempty"` + KernelRx bool `json:"kernel_rx,omitempty"` +} + +type MASQUEOutboundTLSOptionsContainer struct { + TLS *OutboundTLSOptions `json:"tls,omitempty"` +} diff --git a/option/mtproxy.go b/option/mtproxy.go new file mode 100644 index 00000000..4e903a4a --- /dev/null +++ b/option/mtproxy.go @@ -0,0 +1,89 @@ +package option + +import ( + "time" + + "github.com/sagernet/sing/common/json/badoption" +) + +type MTProxyInboundOptions struct { + ListenOptions + Users []MTProxyUser `json:"users,omitempty"` + Concurrency uint `json:"concurrency,omitempty"` + DomainFrontingPort uint `json:"domain_fronting_port,omitempty"` + DomainFrontingIP string `json:"domain_fronting_ip,omitempty"` + DomainFrontingProxyProtocol bool `json:"domain_fronting_proxy_protocol,omitempty"` + PreferIP string `json:"prefer_ip,omitempty"` + AutoUpdate bool `json:"auto_update,omitempty"` + AllowFallbackOnUnknownDC bool `json:"allow_fallback_on_unknown_dc,omitempty"` + TolerateTimeSkewness badoption.Duration `json:"tolerate_time_skewness,omitempty"` + IdleTimeout badoption.Duration `json:"idle_timeout,omitempty"` + HandshakeTimeout badoption.Duration `json:"handshake_timeout,omitempty"` + DoppelGangerURLs []string `json:"doppelganger_urls,omitempty"` + DoppelGangerPerRaid uint `json:"doppelganger_per_raid,omitempty"` + DoppelGangerEach badoption.Duration `json:"doppelganger_each,omitempty"` + DoppelGangerDRS bool `json:"doppelganger_drs,omitempty"` + ThrottleMaxConnections uint `json:"throttle_max_connections,omitempty"` + ThrottleCheckInterval badoption.Duration `json:"throttle_check_interval,omitempty"` +} + +func (o *MTProxyInboundOptions) GetConcurrency() uint { + if o.Concurrency == 0 { + return 8192 + } + return o.Concurrency +} + +func (o *MTProxyInboundOptions) GetDomainFrontingPort() uint { + if o.DomainFrontingPort == 0 { + return 443 + } + return o.DomainFrontingPort +} + +func (o *MTProxyInboundOptions) GetPreferIP() string { + if o.PreferIP == "" { + return "prefer-ipv4" + } + return o.PreferIP +} + +func (o *MTProxyInboundOptions) GetIdleTimeout() time.Duration { + if o.IdleTimeout == 0 { + return 5 * time.Minute + } + return o.IdleTimeout.Build() +} + +func (o *MTProxyInboundOptions) GetHandshakeTimeout() time.Duration { + if o.HandshakeTimeout == 0 { + return 10 * time.Second + } + return o.HandshakeTimeout.Build() +} + +func (o *MTProxyInboundOptions) GetDoppelGangerPerRaid() uint { + if o.DoppelGangerPerRaid == 0 { + return 10 + } + return o.DoppelGangerPerRaid +} + +func (o *MTProxyInboundOptions) GetDoppelGangerEach() time.Duration { + if o.HandshakeTimeout == 0 { + return 6 * time.Hour + } + return o.DoppelGangerEach.Build() +} + +func (o *MTProxyInboundOptions) GetThrottleCheckInterval() time.Duration { + if o.ThrottleCheckInterval == 0 { + return 5 * time.Second + } + return o.ThrottleCheckInterval.Build() +} + +type MTProxyUser struct { + Name string `json:"name"` + Secret string `json:"secret"` +} diff --git a/option/options.go b/option/options.go index 8bebd48f..fcca94c3 100644 --- a/option/options.go +++ b/option/options.go @@ -19,6 +19,7 @@ type _Options struct { Endpoints []Endpoint `json:"endpoints,omitempty"` Inbounds []Inbound `json:"inbounds,omitempty"` Outbounds []Outbound `json:"outbounds,omitempty"` + Providers []Provider `json:"providers,omitempty"` Route *RouteOptions `json:"route,omitempty"` Services []Service `json:"services,omitempty"` Experimental *ExperimentalOptions `json:"experimental,omitempty"` diff --git a/option/parser.go b/option/parser.go new file mode 100644 index 00000000..db916c8d --- /dev/null +++ b/option/parser.go @@ -0,0 +1,6 @@ +package option + +type ParserOutboundOptions struct { + DialerOptions + Link string `json:"link"` +} diff --git a/option/provider.go b/option/provider.go new file mode 100644 index 00000000..656036e4 --- /dev/null +++ b/option/provider.go @@ -0,0 +1,75 @@ +package option + +import ( + "context" + + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badjson" + "github.com/sagernet/sing/common/json/badoption" + "github.com/sagernet/sing/service" +) + +type ProviderOptionsRegistry interface { + CreateOptions(providerType string) (any, bool) +} +type _Provider struct { + Type string `json:"type"` + Tag string `json:"tag,omitempty"` + Options any `json:"-"` +} + +type Provider _Provider + +func (h *Provider) MarshalJSONContext(ctx context.Context) ([]byte, error) { + return badjson.MarshallObjectsContext(ctx, (*_Provider)(h), h.Options) +} + +func (h *Provider) UnmarshalJSONContext(ctx context.Context, content []byte) error { + err := json.UnmarshalContext(ctx, content, (*_Provider)(h)) + if err != nil { + return err + } + registry := service.FromContext[ProviderOptionsRegistry](ctx) + if registry == nil { + return E.New("missing provider options registry in context") + } + options, loaded := registry.CreateOptions(h.Type) + if !loaded { + return E.New("unknown provider type: ", h.Type) + } + err = badjson.UnmarshallExcludedContext(ctx, content, (*_Provider)(h), options) + if err != nil { + return err + } + h.Options = options + return nil +} + +type ProviderLocalOptions struct { + Path string `json:"path"` + HealthCheck ProviderHealthCheckOptions `json:"health_check,omitempty"` +} + +type ProviderRemoteOptions struct { + URL string `json:"url"` + UserAgent string `json:"user_agent,omitempty"` + DownloadDetour string `json:"download_detour,omitempty"` + UpdateInterval badoption.Duration `json:"update_interval,omitempty"` + + Exclude *badoption.Regexp `json:"exclude,omitempty"` + Include *badoption.Regexp `json:"include,omitempty"` + HealthCheck ProviderHealthCheckOptions `json:"health_check,omitempty"` +} + +type ProviderInlineOptions struct { + Outbounds []Outbound `json:"outbounds,omitempty"` + HealthCheck ProviderHealthCheckOptions `json:"health_check,omitempty"` +} + +type ProviderHealthCheckOptions struct { + Enabled bool `json:"enabled,omitempty"` + URL string `json:"url,omitempty"` + Interval badoption.Duration `json:"interval,omitempty"` + Timeout badoption.Duration `json:"timeout,omitempty"` +} diff --git a/option/rule.go b/option/rule.go index ba732616..3e7fd877 100644 --- a/option/rule.go +++ b/option/rule.go @@ -88,8 +88,6 @@ type RawDefaultRule struct { SourcePortRange badoption.Listable[string] `json:"source_port_range,omitempty"` Port badoption.Listable[uint16] `json:"port,omitempty"` PortRange badoption.Listable[string] `json:"port_range,omitempty"` - TunnelSource badoption.Listable[string] `json:"tunnel_source,omitempty"` - TunnelDestination badoption.Listable[string] `json:"tunnel_destination,omitempty"` ProcessName badoption.Listable[string] `json:"process_name,omitempty"` ProcessPath badoption.Listable[string] `json:"process_path,omitempty"` ProcessPathRegex badoption.Listable[string] `json:"process_path_regex,omitempty"` diff --git a/option/rule_action.go b/option/rule_action.go index bfe12625..8ecb0dda 100644 --- a/option/rule_action.go +++ b/option/rule_action.go @@ -155,9 +155,10 @@ type RouteActionOptions struct { } type RawRouteOptionsActionOptions struct { - OverrideAddress string `json:"override_address,omitempty"` - OverridePort uint16 `json:"override_port,omitempty"` - OverrideTunnelDestination string `json:"override_tunnel_destination,omitempty"` + OverrideAddress string `json:"override_address,omitempty"` + OverridePort uint16 `json:"override_port,omitempty"` + + OverrideGateway string `json:"override_gateway,omitempty"` NetworkStrategy *NetworkStrategy `json:"network_strategy,omitempty"` FallbackDelay uint32 `json:"fallback_delay,omitempty"` diff --git a/option/rule_dns.go b/option/rule_dns.go index d34cba23..dbc16578 100644 --- a/option/rule_dns.go +++ b/option/rule_dns.go @@ -90,8 +90,6 @@ type RawDefaultDNSRule struct { SourcePortRange badoption.Listable[string] `json:"source_port_range,omitempty"` Port badoption.Listable[uint16] `json:"port,omitempty"` PortRange badoption.Listable[string] `json:"port_range,omitempty"` - TunnelSource badoption.Listable[string] `json:"tunnel_source,omitempty"` - TunnelDestination badoption.Listable[string] `json:"tunnel_destination,omitempty"` ProcessName badoption.Listable[string] `json:"process_name,omitempty"` ProcessPath badoption.Listable[string] `json:"process_path,omitempty"` ProcessPathRegex badoption.Listable[string] `json:"process_path_regex,omitempty"` diff --git a/option/rule_set.go b/option/rule_set.go index 8155055f..b0634228 100644 --- a/option/rule_set.go +++ b/option/rule_set.go @@ -194,8 +194,6 @@ type DefaultHeadlessRule struct { SourcePortRange badoption.Listable[string] `json:"source_port_range,omitempty"` Port badoption.Listable[uint16] `json:"port,omitempty"` PortRange badoption.Listable[string] `json:"port_range,omitempty"` - TunnelSource badoption.Listable[string] `json:"tunnel_source,omitempty"` - TunnelDestination badoption.Listable[string] `json:"tunnel_destination,omitempty"` ProcessName badoption.Listable[string] `json:"process_name,omitempty"` ProcessPath badoption.Listable[string] `json:"process_path,omitempty"` ProcessPathRegex badoption.Listable[string] `json:"process_path_regex,omitempty"` diff --git a/option/tunnel.go b/option/tunnel.go deleted file mode 100644 index cc1df36b..00000000 --- a/option/tunnel.go +++ /dev/null @@ -1,21 +0,0 @@ -package option - -import "github.com/sagernet/sing/common/json/badoption" - -type TunnelClientEndpointOptions struct { - UUID string `json:"uuid"` - Key string `json:"key"` - Outbound Outbound `json:"outbound"` -} - -type TunnelServerEndpointOptions struct { - UUID string `json:"uuid"` - Users []TunnelUser `json:"users"` - Inbound Inbound `json:"inbound"` - ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` -} - -type TunnelUser struct { - UUID string `json:"uuid"` - Key string `json:"key"` -} diff --git a/option/vpn.go b/option/vpn.go new file mode 100644 index 00000000..49139b9f --- /dev/null +++ b/option/vpn.go @@ -0,0 +1,25 @@ +package option + +import ( + "net/netip" + + "github.com/sagernet/sing/common/json/badoption" +) + +type VPNClientEndpointOptions struct { + Address netip.Addr `json:"address"` + Key string `json:"key"` + Outbound Outbound `json:"outbound"` +} + +type VPNServerEndpointOptions struct { + Address netip.Addr `json:"address"` + Users []VPNUser `json:"users"` + Inbounds []Inbound `json:"inbounds"` + ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` +} + +type VPNUser struct { + Address netip.Addr `json:"address"` + Key string `json:"key"` +} diff --git a/option/warp.go b/option/warp.go new file mode 100644 index 00000000..f1ede310 --- /dev/null +++ b/option/warp.go @@ -0,0 +1,18 @@ +package option + +import "github.com/sagernet/sing/common/json/badoption" + +type WARPEndpointOptions struct { + System bool `json:"system,omitempty"` + Name string `json:"name,omitempty"` + ListenPort uint16 `json:"listen_port,omitempty"` + UDPTimeout badoption.Duration `json:"udp_timeout,omitempty"` + PersistentKeepaliveInterval uint16 `json:"persistent_keepalive_interval,omitempty"` + Reserved []uint8 `json:"reserved,omitempty"` + Workers int `json:"workers,omitempty"` + PreallocatedBuffersPerPool uint32 `json:"preallocated_buffers_per_pool,omitempty"` + DisablePauses bool `json:"disable_pauses,omitempty"` + Amnezia *WireGuardAmnezia `json:"amnezia,omitempty"` + Profile CloudflareProfile `json:"profile,omitempty"` + DialerOptions +} diff --git a/option/wireguard.go b/option/wireguard.go index 0a1af69b..04f8b039 100644 --- a/option/wireguard.go +++ b/option/wireguard.go @@ -33,29 +33,6 @@ type WireGuardPeer struct { Reserved []uint8 `json:"reserved,omitempty"` } -type WireGuardWARPEndpointOptions struct { - System bool `json:"system,omitempty"` - Name string `json:"name,omitempty"` - ListenPort uint16 `json:"listen_port,omitempty"` - UDPTimeout badoption.Duration `json:"udp_timeout,omitempty"` - PersistentKeepaliveInterval uint16 `json:"persistent_keepalive_interval,omitempty"` - Reserved []uint8 `json:"reserved,omitempty"` - Workers int `json:"workers,omitempty"` - PreallocatedBuffersPerPool uint32 `json:"preallocated_buffers_per_pool,omitempty"` - DisablePauses bool `json:"disable_pauses,omitempty"` - Amnezia *WireGuardAmnezia `json:"amnezia,omitempty"` - Profile WARPProfile `json:"profile,omitempty"` - DialerOptions -} - -type WARPProfile struct { - ID string `json:"id,omitempty"` - PrivateKey string `json:"private_key,omitempty"` - AuthToken string `json:"auth_token,omitempty"` - Recreate bool `json:"recreate,omitempty"` - Detour string `json:"detour,omitempty"` -} - type WireGuardAmnezia struct { JC int `json:"jc,omitempty"` JMin int `json:"jmin,omitempty"` diff --git a/parser/clash/anytls.go b/parser/clash/anytls.go new file mode 100644 index 00000000..5f6356bf --- /dev/null +++ b/parser/clash/anytls.go @@ -0,0 +1,30 @@ +package clash + +import ( + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/json/badoption" +) + +type AnyTLSOption struct { + DialerOptions `yaml:",inline"` + ServerOptions `yaml:",inline"` + TLSOptions `yaml:",inline"` + Password string `yaml:"password"` + UDP bool `yaml:"udp,omitempty"` + IdleSessionCheckInterval int `yaml:"idle-session-check-interval,omitempty"` + IdleSessionTimeout int `yaml:"idle-session-timeout,omitempty"` + MinIdleSession int `yaml:"min-idle-session,omitempty"` +} + +func (a *AnyTLSOption) Build() any { + a.TLS = true + return &option.AnyTLSOutboundOptions{ + DialerOptions: a.DialerOptions.Build(), + ServerOptions: a.ServerOptions.Build(), + OutboundTLSOptionsContainer: clashTLSOptions(a.Server, &a.TLSOptions), + Password: a.Password, + IdleSessionCheckInterval: badoption.Duration(a.IdleSessionCheckInterval), + IdleSessionTimeout: badoption.Duration(a.IdleSessionTimeout), + MinIdleSession: a.MinIdleSession, + } +} diff --git a/parser/clash/base.go b/parser/clash/base.go new file mode 100644 index 00000000..cc275c3b --- /dev/null +++ b/parser/clash/base.go @@ -0,0 +1,181 @@ +package clash + +import ( + "encoding/base64" + "strings" + + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/json/badoption" +) + +type HTTPOptions struct { + Method string `yaml:"method,omitempty"` + Path []string `yaml:"path,omitempty"` + Headers badoption.HTTPHeader `yaml:"headers,omitempty"` +} + +type HTTP2Options struct { + Host []string `yaml:"host,omitempty"` + Path string `yaml:"path,omitempty"` +} + +type GrpcOptions struct { + GrpcServiceName string `yaml:"grpc-service-name,omitempty"` +} + +type WSOptions struct { + Path string `yaml:"path,omitempty"` + Headers map[string]string `yaml:"headers,omitempty"` + MaxEarlyData int `yaml:"max-early-data,omitempty"` + EarlyDataHeaderName string `yaml:"early-data-header-name,omitempty"` + V2rayHttpUpgrade bool `yaml:"v2ray-http-upgrade,omitempty"` +} + +type MuxOptions struct { + Enabled bool `yaml:"enabled,omitempty"` + Protocol string `yaml:"protocol,omitempty"` + MaxConnections int `yaml:"max-connections,omitempty"` + MinStreams int `yaml:"min-streams,omitempty"` + MaxStreams int `yaml:"max-streams,omitempty"` + Padding bool `yaml:"padding,omitempty"` + BrutalOpts *BrutalOptions `yaml:"brutal-opts,omitempty"` +} + +func (s *MuxOptions) Build() *option.OutboundMultiplexOptions { + if s == nil { + return nil + } + return &option.OutboundMultiplexOptions{ + Enabled: s.Enabled, + Protocol: s.Protocol, + MaxConnections: s.MaxConnections, + MinStreams: s.MinStreams, + MaxStreams: s.MaxStreams, + Padding: s.Padding, + Brutal: s.BrutalOpts.Build(), + } +} + +type BrutalOptions struct { + Enabled bool `yaml:"enabled,omitempty"` + Up string `yaml:"up,omitempty"` + Down string `yaml:"down,omitempty"` +} + +func (b *BrutalOptions) Build() *option.BrutalOptions { + if b == nil { + return nil + } + return &option.BrutalOptions{ + Enabled: b.Enabled, + UpMbps: clashSpeedToIntMbps(b.Up), + DownMbps: clashSpeedToIntMbps(b.Down), + } +} + +type RealityOptions struct { + PublicKey string `yaml:"public-key"` + ShortID string `yaml:"short-id"` +} + +func (r *RealityOptions) Build() *option.OutboundRealityOptions { + if r == nil { + return nil + } + return &option.OutboundRealityOptions{ + Enabled: true, + PublicKey: r.PublicKey, + ShortID: r.ShortID, + } +} + +type ECHOptions struct { + Enable bool `yaml:"enable,omitempty"` + Config string `yaml:"config,omitempty"` +} + +func (e *ECHOptions) Build() *option.OutboundECHOptions { + if e == nil { + return nil + } + list, err := base64.StdEncoding.DecodeString(e.Config) + if err != nil { + return nil + } + return &option.OutboundECHOptions{ + Enabled: e.Enable, + Config: trimStringArray(strings.Split(string(list), "\n")), + } +} + +type TLSOptions struct { + TLS bool `yaml:"tls,omitempty"` + SNI string `yaml:"sni,omitempty"` + SkipCertVerify bool `yaml:"skip-cert-verify,omitempty"` + ALPN []string `yaml:"alpn,omitempty"` + ClientFingerprint string `yaml:"client-fingerprint,omitempty"` + CustomCA string `yaml:"ca,omitempty"` + CustomCAString string `yaml:"ca-str,omitempty"` + Certificate string `yaml:"certificate,omitempty"` + PrivateKey string `yaml:"private-key,omitempty"` + ECHOpts *ECHOptions `yaml:"ech-opts,omitempty"` + RealityOpts *RealityOptions `yaml:"reality-opts,omitempty"` +} + +func (t *TLSOptions) Build() *option.OutboundTLSOptions { + if t == nil { + return nil + } + options := &option.OutboundTLSOptions{ + Enabled: t.TLS, + ServerName: t.SNI, + Insecure: t.SkipCertVerify, + ALPN: t.ALPN, + UTLS: clashClientFingerprint(t.ClientFingerprint), + Certificate: trimStringArray(strings.Split(t.CustomCAString, "\n")), + CertificatePath: t.CustomCA, + ECH: t.ECHOpts.Build(), + Reality: t.RealityOpts.Build(), + } + if strings.HasPrefix(t.Certificate, "-----BEGIN ") { + options.ClientCertificate = trimStringArray(strings.Split(t.Certificate, "\n")) + } else { + options.ClientCertificatePath = t.Certificate + } + if strings.HasPrefix(t.PrivateKey, "-----BEGIN ") { + options.ClientKey = trimStringArray(strings.Split(t.PrivateKey, "\n")) + } else { + options.ClientKeyPath = t.PrivateKey + } + return options +} + +type DialerOptions struct { + TFO bool `yaml:"tfo,omitempty"` + MPTCP bool `yaml:"mptcp,omitempty"` + Interface string `yaml:"interface-name,omitempty"` + RoutingMark int `yaml:"routing-mark,omitempty"` + DialerProxy string `yaml:"dialer-proxy,omitempty"` +} + +func (b *DialerOptions) Build() option.DialerOptions { + return option.DialerOptions{ + Detour: b.DialerProxy, + BindInterface: b.Interface, + TCPFastOpen: b.TFO, + TCPMultiPath: b.MPTCP, + RoutingMark: option.FwMark(b.RoutingMark), + } +} + +type ServerOptions struct { + Server string `yaml:"server"` + Port int `yaml:"port"` +} + +func (s *ServerOptions) Build() option.ServerOptions { + return option.ServerOptions{ + Server: s.Server, + ServerPort: uint16(s.Port), + } +} diff --git a/parser/clash/http.go b/parser/clash/http.go new file mode 100644 index 00000000..f176c3ed --- /dev/null +++ b/parser/clash/http.go @@ -0,0 +1,23 @@ +package clash + +import "github.com/sagernet/sing-box/option" + +type HttpOption struct { + DialerOptions `yaml:",inline"` + ServerOptions `yaml:",inline"` + *TLSOptions `yaml:",inline"` + UserName string `yaml:"username,omitempty"` + Password string `yaml:"password,omitempty"` + Headers map[string]string `yaml:"headers,omitempty"` +} + +func (h *HttpOption) Build() any { + return &option.HTTPOutboundOptions{ + DialerOptions: h.DialerOptions.Build(), + ServerOptions: h.ServerOptions.Build(), + Username: h.UserName, + Password: h.Password, + OutboundTLSOptionsContainer: clashTLSOptions(h.Server, h.TLSOptions), + Headers: clashHeaders(h.Headers), + } +} diff --git a/parser/clash/hysteria.go b/parser/clash/hysteria.go new file mode 100644 index 00000000..9d35ccdd --- /dev/null +++ b/parser/clash/hysteria.go @@ -0,0 +1,47 @@ +package clash + +import ( + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/json/badoption" +) + +type HysteriaOption struct { + DialerOptions `yaml:",inline"` + ServerOptions `yaml:",inline"` + TLSOptions `yaml:",inline"` + Ports string `yaml:"ports,omitempty"` + Up string `yaml:"up"` + UpSpeed int `yaml:"up-speed,omitempty"` // compatible with Stash + Down string `yaml:"down"` + DownSpeed int `yaml:"down-speed,omitempty"` // compatible with Stash + Auth string `yaml:"auth,omitempty"` + AuthString string `yaml:"auth-str,omitempty"` + Obfs string `yaml:"obfs,omitempty"` + ReceiveWindowConn int `yaml:"recv-window-conn,omitempty"` + ReceiveWindow int `yaml:"recv-window,omitempty"` + DisableMTUDiscovery bool `yaml:"disable-mtu-discovery,omitempty"` + FastOpen bool `yaml:"fast-open,omitempty"` + HopInterval int `yaml:"hop-interval,omitempty"` +} + +func (h *HysteriaOption) Build() any { + h.TLS = true + h.TFO = h.FastOpen + return &option.HysteriaOutboundOptions{ + DialerOptions: h.DialerOptions.Build(), + ServerOptions: h.ServerOptions.Build(), + ServerPorts: clashPorts(h.Ports), + HopInterval: badoption.Duration(h.HopInterval), + Up: clashSpeedToNetworkBytes(h.Up), + UpMbps: h.UpSpeed, + Down: clashSpeedToNetworkBytes(h.Down), + DownMbps: h.DownSpeed, + Obfs: h.Obfs, + Auth: []byte(h.Auth), + AuthString: h.AuthString, + ReceiveWindowConn: uint64(h.ReceiveWindowConn), + ReceiveWindow: uint64(h.ReceiveWindow), + DisableMTUDiscovery: h.DisableMTUDiscovery, + OutboundTLSOptionsContainer: clashTLSOptions(h.Server, &h.TLSOptions), + } +} diff --git a/parser/clash/hysteria2.go b/parser/clash/hysteria2.go new file mode 100644 index 00000000..e2f37ba2 --- /dev/null +++ b/parser/clash/hysteria2.go @@ -0,0 +1,34 @@ +package clash + +import ( + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/json/badoption" +) + +type Hysteria2Option struct { + DialerOptions `yaml:",inline"` + ServerOptions `yaml:",inline"` + TLSOptions `yaml:",inline"` + Ports string `yaml:"ports,omitempty"` + HopInterval int `yaml:"hop-interval,omitempty"` + Up string `yaml:"up,omitempty"` + Down string `yaml:"down,omitempty"` + Password string `yaml:"password,omitempty"` + Obfs string `yaml:"obfs,omitempty"` + ObfsPassword string `yaml:"obfs-password,omitempty"` +} + +func (h *Hysteria2Option) Build() any { + h.TLS = true + return &option.Hysteria2OutboundOptions{ + DialerOptions: h.DialerOptions.Build(), + ServerOptions: h.ServerOptions.Build(), + ServerPorts: clashPorts(h.Ports), + HopInterval: badoption.Duration(h.HopInterval), + UpMbps: clashSpeedToIntMbps(h.Up), + DownMbps: clashSpeedToIntMbps(h.Down), + Obfs: clashHysteria2Obfs(h.Obfs, h.ObfsPassword), + Password: h.Password, + OutboundTLSOptionsContainer: clashTLSOptions(h.Server, &h.TLSOptions), + } +} diff --git a/parser/clash/parser.go b/parser/clash/parser.go new file mode 100644 index 00000000..61efa865 --- /dev/null +++ b/parser/clash/parser.go @@ -0,0 +1,106 @@ +package clash + +import ( + "context" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + + "gopkg.in/yaml.v3" +) + +type ClashConfig struct { + Proxies []ClashProxy `yaml:"proxies"` +} + +type _ClashProxy struct { + Name string `yaml:"name"` + Type string `yaml:"type"` + Options Proxy `yaml:"-"` + + SingType string `yaml:"-"` +} +type ClashProxy _ClashProxy + +type Proxy interface { + Build() any +} + +func (c *ClashProxy) UnmarshalYAML(value *yaml.Node) error { + err := value.Decode((*_ClashProxy)(c)) + if err != nil { + return err + } + var options Proxy + switch c.Type { + case "ss": + c.SingType = C.TypeShadowsocks + options = &ShadowSocksOption{} + case "tuic": + c.SingType = C.TypeTUIC + options = &TuicOption{} + case "vmess": + c.SingType = C.TypeVMess + options = &VmessOption{} + case "vless": + c.SingType = C.TypeVLESS + options = &VlessOption{} + case "socks5": + c.SingType = C.TypeSOCKS + options = &Socks5Option{} + case "http": + c.SingType = C.TypeHTTP + options = &HttpOption{} + case "trojan": + c.SingType = C.TypeTrojan + options = &TrojanOption{} + case "hysteria": + c.SingType = C.TypeHysteria + options = &HysteriaOption{} + case "hysteria2": + c.SingType = C.TypeHysteria2 + options = &Hysteria2Option{} + case "ssh": + c.SingType = C.TypeSSH + options = &SSHOption{} + case "anytls": + c.SingType = C.TypeAnyTLS + options = &AnyTLSOption{} + default: + return nil + } + err = value.Decode(options) + if err != nil { + return err + } + c.Options = options + return nil +} + +func (c *ClashProxy) Build() option.Outbound { + outbound := option.Outbound{ + Tag: c.Name, + Type: c.SingType, + } + if c.Options != nil { + outbound.Options = c.Options.Build() + } + return outbound +} + +func ParseClashSubscription(_ context.Context, content string) ([]option.Outbound, error) { + config := &ClashConfig{} + err := yaml.Unmarshal([]byte(content), &config) + if err != nil { + return nil, E.Cause(err, "parse clash config") + } + outbounds := common.FilterIsInstance(config.Proxies, func(proxy ClashProxy) (option.Outbound, bool) { + if proxy.SingType == "" { + return option.Outbound{}, false + } + return proxy.Build(), true + }) + return outbounds, nil +} diff --git a/parser/clash/shadowsocks.go b/parser/clash/shadowsocks.go new file mode 100644 index 00000000..f45d59cf --- /dev/null +++ b/parser/clash/shadowsocks.go @@ -0,0 +1,51 @@ +package clash + +import ( + "strings" + + "github.com/sagernet/sing-box/option" + F "github.com/sagernet/sing/common/format" +) + +type ShadowSocksOption struct { + DialerOptions `yaml:",inline"` + ServerOptions `yaml:",inline"` + Password string `yaml:"password"` + Cipher string `yaml:"cipher"` + UDP bool `yaml:"udp,omitempty"` + Plugin string `yaml:"plugin,omitempty"` + PluginOpts map[string]any `yaml:"plugin-opts,omitempty"` + UDPOverTCP bool `yaml:"udp-over-tcp,omitempty"` + UDPOverTCPVersion int `yaml:"udp-over-tcp-version,omitempty"` + MuxOpts *MuxOptions `yaml:"smux,omitempty"` +} + +func (s *ShadowSocksOption) Build() any { + return &option.ShadowsocksOutboundOptions{ + DialerOptions: s.DialerOptions.Build(), + ServerOptions: s.ServerOptions.Build(), + Password: s.Password, + Method: clashShadowsocksCipher(s.Cipher), + Plugin: clashPluginName(s.Plugin), + PluginOptions: clashPluginOptions(s.Plugin, s.PluginOpts), + Network: clashNetworks(s.UDP), + UDPOverTCP: &option.UDPOverTCPOptions{ + Enabled: s.UDPOverTCP, + Version: uint8(s.UDPOverTCPVersion), + }, + Multiplex: s.MuxOpts.Build(), + } +} + +type shadowsocksPluginOptionsBuilder map[string]any + +func (o shadowsocksPluginOptionsBuilder) Build() string { + var opts []string + for key, value := range o { + if value == nil { + continue + } + opts = append(opts, F.ToString(key, "=", value)) + } + return strings.Join(opts, ";") +} diff --git a/parser/clash/socks5.go b/parser/clash/socks5.go new file mode 100644 index 00000000..7c1dd390 --- /dev/null +++ b/parser/clash/socks5.go @@ -0,0 +1,21 @@ +package clash + +import "github.com/sagernet/sing-box/option" + +type Socks5Option struct { + DialerOptions `yaml:",inline"` + ServerOptions `yaml:",inline"` + UserName string `yaml:"username,omitempty"` + Password string `yaml:"password,omitempty"` + UDP bool `yaml:"udp,omitempty"` +} + +func (s *Socks5Option) Build() any { + return &option.SOCKSOutboundOptions{ + DialerOptions: s.DialerOptions.Build(), + ServerOptions: s.ServerOptions.Build(), + Username: s.UserName, + Password: s.Password, + Network: clashNetworks(s.UDP), + } +} diff --git a/parser/clash/ssh.go b/parser/clash/ssh.go new file mode 100644 index 00000000..7010634b --- /dev/null +++ b/parser/clash/ssh.go @@ -0,0 +1,36 @@ +package clash + +import ( + "strings" + + "github.com/sagernet/sing-box/option" +) + +type SSHOption struct { + DialerOptions `yaml:",inline"` + ServerOptions `yaml:",inline"` + UserName string `yaml:"username"` + Password string `yaml:"password,omitempty"` + PrivateKey string `yaml:"private-key,omitempty"` + PrivateKeyPassphrase string `yaml:"private-key-passphrase,omitempty"` + HostKey []string `yaml:"host-key,omitempty"` + HostKeyAlgorithms []string `yaml:"host-key-algorithms,omitempty"` +} + +func (s *SSHOption) Build() any { + options := &option.SSHOutboundOptions{ + DialerOptions: s.DialerOptions.Build(), + ServerOptions: s.ServerOptions.Build(), + User: s.UserName, + Password: s.Password, + PrivateKeyPassphrase: s.PrivateKeyPassphrase, + HostKey: s.HostKey, + HostKeyAlgorithms: s.HostKeyAlgorithms, + } + if strings.Contains(s.PrivateKey, "PRIVATE KEY") { + options.PrivateKey = trimStringArray(strings.Split(s.PrivateKey, "\n")) + } else { + options.PrivateKeyPath = s.PrivateKey + } + return options +} diff --git a/parser/clash/trojan.go b/parser/clash/trojan.go new file mode 100644 index 00000000..5d391c7a --- /dev/null +++ b/parser/clash/trojan.go @@ -0,0 +1,28 @@ +package clash + +import "github.com/sagernet/sing-box/option" + +type TrojanOption struct { + DialerOptions `yaml:",inline"` + ServerOptions `yaml:",inline"` + TLSOptions `yaml:",inline"` + Password string `yaml:"password"` + UDP bool `yaml:"udp,omitempty"` + Network string `yaml:"network,omitempty"` + GrpcOpts GrpcOptions `yaml:"grpc-opts,omitempty"` + WSOpts WSOptions `yaml:"ws-opts,omitempty"` + MuxOpts *MuxOptions `yaml:"smux,omitempty"` +} + +func (t *TrojanOption) Build() any { + t.TLS = true + return &option.TrojanOutboundOptions{ + DialerOptions: t.DialerOptions.Build(), + ServerOptions: t.ServerOptions.Build(), + Password: t.Password, + Network: clashNetworks(t.UDP), + OutboundTLSOptionsContainer: clashTLSOptions(t.Server, &t.TLSOptions), + Multiplex: t.MuxOpts.Build(), + Transport: clashTransport(t.Network, HTTPOptions{}, HTTP2Options{}, t.GrpcOpts, t.WSOpts), + } +} diff --git a/parser/clash/tuic.go b/parser/clash/tuic.go new file mode 100644 index 00000000..fb2ebccf --- /dev/null +++ b/parser/clash/tuic.go @@ -0,0 +1,47 @@ +package clash + +import ( + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/json/badoption" +) + +type TuicOption struct { + DialerOptions `yaml:",inline"` + ServerOptions `yaml:",inline"` + TLSOptions `yaml:",inline"` + UUID string `yaml:"uuid,omitempty"` + Password string `yaml:"password,omitempty"` + Ip string `yaml:"ip,omitempty"` + HeartbeatInterval int `yaml:"heartbeat-interval,omitempty"` + DisableSni bool `yaml:"disable-sni,omitempty"` + ReduceRtt bool `yaml:"reduce-rtt,omitempty"` + UdpRelayMode string `yaml:"udp-relay-mode,omitempty"` + CongestionController string `yaml:"congestion-controller,omitempty"` + FastOpen bool `yaml:"fast-open,omitempty"` + DisableMTUDiscovery bool `yaml:"disable-mtu-discovery,omitempty"` + UDPOverStream bool `yaml:"udp-over-stream,omitempty"` +} + +func (t *TuicOption) Build() any { + t.TLS = true + t.TFO = t.FastOpen + options := &option.TUICOutboundOptions{ + DialerOptions: t.DialerOptions.Build(), + ServerOptions: t.ServerOptions.Build(), + UUID: t.UUID, + Password: t.Password, + CongestionControl: t.CongestionController, + UDPRelayMode: t.UdpRelayMode, + UDPOverStream: t.UDPOverStream, + ZeroRTTHandshake: t.ReduceRtt, + Heartbeat: badoption.Duration(t.HeartbeatInterval), + OutboundTLSOptionsContainer: clashTLSOptions(t.Server, &t.TLSOptions), + } + if t.Ip != "" { + options.Server = t.Ip + } + if t.DisableSni { + options.TLS.DisableSNI = true + } + return options +} diff --git a/parser/clash/utils.go b/parser/clash/utils.go new file mode 100644 index 00000000..9055532c --- /dev/null +++ b/parser/clash/utils.go @@ -0,0 +1,205 @@ +package clash + +import ( + "strconv" + "strings" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/byteformats" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/json/badoption" + N "github.com/sagernet/sing/common/network" +) + +func clashClientFingerprint(clientFingerprint string) *option.OutboundUTLSOptions { + if clientFingerprint == "" { + return nil + } + return &option.OutboundUTLSOptions{ + Enabled: true, + Fingerprint: clientFingerprint, + } +} + +func clashHeaders(headers map[string]string) map[string]badoption.Listable[string] { + if headers == nil { + return nil + } + result := make(map[string]badoption.Listable[string]) + for key, value := range headers { + result[key] = []string{value} + } + return result +} + +func clashHysteria2Obfs(obfs string, password string) *option.Hysteria2Obfs { + if obfs == "" { + return nil + } + return &option.Hysteria2Obfs{ + Type: obfs, + Password: password, + } +} + +func clashNetworks(udpEnabled bool) option.NetworkList { + if !udpEnabled { + return N.NetworkTCP + } + return "" +} + +func clashPluginName(plugin string) string { + switch plugin { + case "obfs": + return "obfs-local" + } + return plugin +} + +func clashPluginOptions(plugin string, opts map[string]any) string { + options := make(shadowsocksPluginOptionsBuilder) + switch plugin { + case "obfs": + options["obfs"] = opts["mode"] + options["obfs-host"] = opts["host"] + case "v2ray-plugin": + options["mode"] = opts["mode"] + options["tls"] = opts["tls"] + options["host"] = opts["host"] + options["path"] = opts["path"] + } + return options.Build() +} + +func clashPorts(ports string) badoption.Listable[string] { + if ports == "" { + return nil + } + serverPorts := badoption.Listable[string]{} + ports = strings.ReplaceAll(ports, "/", ",") + for _, port := range strings.Split(ports, ",") { + if port == "" { + continue + } + port = strings.Replace(port, "-", ":", 1) + serverPorts = append(serverPorts, port) + } + return serverPorts +} + +func clashShadowsocksCipher(cipher string) string { + switch cipher { + case "dummy": + return "none" + } + return cipher +} + +func clashStringList(list []string) string { + if len(list) > 0 { + return list[0] + } + return "" +} + +func clashSpeedToIntMbps(speed string) int { + if speed == "" { + return 0 + } + if num, err := strconv.Atoi(speed); err == nil { + return num + } + networkBytes := byteformats.NetworkBytesCompat{} + if err := networkBytes.UnmarshalJSON([]byte(speed)); err != nil { + return 0 + } + return int(networkBytes.Value() / byteformats.MByte * 8) +} + +func clashSpeedToNetworkBytes(speed string) *byteformats.NetworkBytesCompat { + if speed == "" { + return nil + } + networkBytes := &byteformats.NetworkBytesCompat{} + if num, err := strconv.Atoi(speed); err == nil { + speed = F.ToString(num, "Mbps") + } + if err := networkBytes.UnmarshalJSON([]byte(speed)); err != nil { + return nil + } + return networkBytes +} + +func clashTransport(network string, httpOpts HTTPOptions, h2Opts HTTP2Options, grpcOpts GrpcOptions, wsOpts WSOptions) *option.V2RayTransportOptions { + switch network { + case "http": + return &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeHTTP, + HTTPOptions: option.V2RayHTTPOptions{ + Method: httpOpts.Method, + Path: clashStringList(httpOpts.Path), + Headers: httpOpts.Headers, + }, + } + case "h2": + return &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeHTTP, + HTTPOptions: option.V2RayHTTPOptions{ + Path: h2Opts.Path, + Host: h2Opts.Host, + }, + } + case "grpc": + return &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeGRPC, + GRPCOptions: option.V2RayGRPCOptions{ + ServiceName: grpcOpts.GrpcServiceName, + }, + } + case "ws": + headers := clashHeaders(wsOpts.Headers) + if wsOpts.V2rayHttpUpgrade { + var host string + if headers != nil && headers["Host"] != nil { + host = headers["Host"][0] + } + return &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeHTTPUpgrade, + HTTPUpgradeOptions: option.V2RayHTTPUpgradeOptions{ + Host: host, + Path: wsOpts.Path, + Headers: headers, + }, + } + } + return &option.V2RayTransportOptions{ + Type: C.V2RayTransportTypeWebsocket, + WebsocketOptions: option.V2RayWebsocketOptions{ + Path: wsOpts.Path, + Headers: headers, + MaxEarlyData: uint32(wsOpts.MaxEarlyData), + EarlyDataHeaderName: wsOpts.EarlyDataHeaderName, + }, + } + default: + return nil + } +} + +func clashTLSOptions(server string, tlsOptions *TLSOptions) option.OutboundTLSOptionsContainer { + if tlsOptions != nil && tlsOptions.SNI == "" { + tlsOptions.SNI = server + } + return option.OutboundTLSOptionsContainer{ + TLS: tlsOptions.Build(), + } +} + +func trimStringArray(array []string) []string { + return common.Filter(array, func(it string) bool { + return strings.TrimSpace(it) != "" + }) +} diff --git a/parser/clash/vless.go b/parser/clash/vless.go new file mode 100644 index 00000000..d5494562 --- /dev/null +++ b/parser/clash/vless.go @@ -0,0 +1,49 @@ +package clash + +import "github.com/sagernet/sing-box/option" + +type VlessOption struct { + DialerOptions `yaml:",inline"` + ServerOptions `yaml:",inline"` + *TLSOptions `yaml:",inline"` + UUID string `yaml:"uuid"` + Flow string `yaml:"flow,omitempty"` + UDP bool `yaml:"udp,omitempty"` + PacketAddr bool `yaml:"packet-addr,omitempty"` + XUDP bool `yaml:"xudp,omitempty"` + PacketEncoding string `yaml:"packet-encoding,omitempty"` + Network string `yaml:"network,omitempty"` + ServerName string `yaml:"servername,omitempty"` + HTTPOpts HTTPOptions `yaml:"http-opts,omitempty"` + HTTP2Opts HTTP2Options `yaml:"h2-opts,omitempty"` + GrpcOpts GrpcOptions `yaml:"grpc-opts,omitempty"` + WSOpts WSOptions `yaml:"ws-opts,omitempty"` + MuxOpts *MuxOptions `yaml:"smux,omitempty"` +} + +func (v *VlessOption) Build() any { + if v.TLSOptions != nil { + v.SNI = v.ServerName + } + switch v.PacketEncoding { + case "": + if v.PacketAddr { + v.PacketEncoding = "packetaddr" + } else { + v.PacketEncoding = "xudp" + } + case "packet": + v.PacketEncoding = "packetaddr" + } + return &option.VLESSOutboundOptions{ + DialerOptions: v.DialerOptions.Build(), + ServerOptions: v.ServerOptions.Build(), + UUID: v.UUID, + Flow: v.Flow, + Network: clashNetworks(v.UDP), + OutboundTLSOptionsContainer: clashTLSOptions(v.Server, v.TLSOptions), + Multiplex: v.MuxOpts.Build(), + Transport: clashTransport(v.Network, v.HTTPOpts, v.HTTP2Opts, v.GrpcOpts, v.WSOpts), + PacketEncoding: &v.PacketEncoding, + } +} diff --git a/parser/clash/vmess.go b/parser/clash/vmess.go new file mode 100644 index 00000000..fccea09a --- /dev/null +++ b/parser/clash/vmess.go @@ -0,0 +1,55 @@ +package clash + +import "github.com/sagernet/sing-box/option" + +type VmessOption struct { + DialerOptions `yaml:",inline"` + ServerOptions `yaml:",inline"` + *TLSOptions `yaml:",inline"` + UUID string `yaml:"uuid"` + AlterID int `yaml:"alterId"` + Cipher string `yaml:"cipher"` + UDP bool `yaml:"udp,omitempty"` + Network string `yaml:"network,omitempty"` + ServerName string `yaml:"servername,omitempty"` + HTTPOpts HTTPOptions `yaml:"http-opts,omitempty"` + HTTP2Opts HTTP2Options `yaml:"h2-opts,omitempty"` + GrpcOpts GrpcOptions `yaml:"grpc-opts,omitempty"` + WSOpts WSOptions `yaml:"ws-opts,omitempty"` + PacketAddr bool `yaml:"packet-addr,omitempty"` + XUDP bool `yaml:"xudp,omitempty"` + PacketEncoding string `yaml:"packet-encoding,omitempty"` + GlobalPadding bool `yaml:"global-padding,omitempty"` + AuthenticatedLength bool `yaml:"authenticated-length,omitempty"` + MuxOpts *MuxOptions `yaml:"smux,omitempty"` +} + +func (v *VmessOption) Build() any { + if v.TLSOptions != nil { + v.SNI = v.ServerName + } + switch v.PacketEncoding { + case "": + if v.XUDP { + v.PacketEncoding = "xudp" + } else if v.PacketAddr { + v.PacketEncoding = "packetaddr" + } + case "packet": + v.PacketEncoding = "packetaddr" + } + return &option.VMessOutboundOptions{ + DialerOptions: v.DialerOptions.Build(), + ServerOptions: v.ServerOptions.Build(), + UUID: v.UUID, + Security: v.Cipher, + AlterId: v.AlterID, + GlobalPadding: v.GlobalPadding, + AuthenticatedLength: v.AuthenticatedLength, + Network: clashNetworks(v.UDP), + OutboundTLSOptionsContainer: clashTLSOptions(v.Server, v.TLSOptions), + PacketEncoding: v.PacketEncoding, + Multiplex: v.MuxOpts.Build(), + Transport: clashTransport(v.Network, v.HTTPOpts, v.HTTP2Opts, v.GrpcOpts, v.WSOpts), + } +} diff --git a/parser/link/hysteria.go b/parser/link/hysteria.go new file mode 100644 index 00000000..4e756af1 --- /dev/null +++ b/parser/link/hysteria.go @@ -0,0 +1,71 @@ +package link + +import ( + "net/url" + "strconv" + "strings" + + "github.com/sagernet/sing-box/common" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/byteformats" +) + +func parseHysteriaLink(link string) (option.Outbound, error) { + linkURL, err := url.Parse(link) + if err != nil { + return option.Outbound{}, err + } + var options option.HysteriaOutboundOptions + TLSOptions := option.OutboundTLSOptions{ + Enabled: true, + ECH: &option.OutboundECHOptions{}, + UTLS: &option.OutboundUTLSOptions{}, + Reality: &option.OutboundRealityOptions{}, + } + options.Server = linkURL.Hostname() + TLSOptions.ServerName = linkURL.Hostname() + options.ServerPort = common.StringToType[uint16](linkURL.Port()) + for key, values := range linkURL.Query() { + value := values[0] + switch key { + case "auth": + options.AuthString = value + case "peer", "sni": + TLSOptions.ServerName = value + case "alpn": + TLSOptions.ALPN = strings.Split(value, ",") + case "ca": + TLSOptions.CertificatePath = value + case "ca_str": + TLSOptions.Certificate = strings.Split(value, "\n") + case "up": + options.Up = &byteformats.NetworkBytesCompat{} + options.Up.UnmarshalJSON([]byte(value)) + case "up_mbps": + options.UpMbps, _ = strconv.Atoi(value) + case "down": + options.Down = &byteformats.NetworkBytesCompat{} + options.Down.UnmarshalJSON([]byte(value)) + case "down_mbps": + options.DownMbps, _ = strconv.Atoi(value) + case "obfs", "obfsParam": + options.Obfs = value + case "insecure", "skip-cert-verify": + if value == "1" || value == "true" { + TLSOptions.Insecure = true + } + case "tfo", "tcp-fast-open", "tcp_fast_open": + if value == "1" || value == "true" { + options.TCPFastOpen = true + } + } + } + outbound := option.Outbound{ + Type: C.TypeHysteria, + Tag: linkURL.Fragment, + } + options.TLS = &TLSOptions + outbound.Options = &options + return outbound, nil +} diff --git a/parser/link/hysteria2.go b/parser/link/hysteria2.go new file mode 100644 index 00000000..d9486064 --- /dev/null +++ b/parser/link/hysteria2.go @@ -0,0 +1,61 @@ +package link + +import ( + "net/url" + "strconv" + + "github.com/sagernet/sing-box/common" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" +) + +func parseHysteria2Link(link string) (option.Outbound, error) { + linkURL, err := url.Parse(link) + if err != nil { + return option.Outbound{}, err + } + var options option.Hysteria2OutboundOptions + TLSOptions := option.OutboundTLSOptions{ + Enabled: true, + ECH: &option.OutboundECHOptions{}, + UTLS: &option.OutboundUTLSOptions{}, + Reality: &option.OutboundRealityOptions{}, + } + Obfs := &option.Hysteria2Obfs{} + options.ServerPort = uint16(443) + options.Server = linkURL.Hostname() + TLSOptions.ServerName = linkURL.Hostname() + if linkURL.User != nil { + options.Password = linkURL.User.Username() + } + if linkURL.Port() != "" { + options.ServerPort = common.StringToType[uint16](linkURL.Port()) + } + for key, values := range linkURL.Query() { + value := values[0] + switch key { + case "up": + options.UpMbps, _ = strconv.Atoi(value) + case "down": + options.DownMbps, _ = strconv.Atoi(value) + case "obfs": + if value == "salamander" { + Obfs.Type = "salamander" + options.Obfs = Obfs + } + case "obfs-password": + Obfs.Password = value + case "insecure", "skip-cert-verify": + if value == "1" || value == "true" { + TLSOptions.Insecure = true + } + } + } + outbound := option.Outbound{ + Type: C.TypeHysteria2, + Tag: linkURL.Fragment, + } + options.TLS = &TLSOptions + outbound.Options = &options + return outbound, nil +} diff --git a/parser/link/parser.go b/parser/link/parser.go new file mode 100644 index 00000000..72e35ab6 --- /dev/null +++ b/parser/link/parser.go @@ -0,0 +1,42 @@ +package link + +import ( + "regexp" + "strings" + + "github.com/sagernet/sing-box/common" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func ParseSubscriptionLink(link string) (option.Outbound, error) { + reg := regexp.MustCompile(`^(.*?)(://)(.*?)([@?#].*)?$`) + result := reg.FindStringSubmatch(link) + if result == nil { + return option.Outbound{}, E.New("invalid link") + } + + scheme := result[1] + switch scheme { + case "tuic": + return parseTuicLink(link) + case "trojan": + return parseTrojanLink(link) + case "vless": + return parseVLESSLink(link) + case "hysteria": + return parseHysteriaLink(link) + case "hy2", "hysteria2": + return parseHysteria2Link(link) + } + result[3], _ = common.DecodeBase64URLSafe(result[3]) + link = strings.Join(result[1:], "") + switch scheme { + case "ss": + return parseShadowsocksLink(link) + case "vmess": + return parseVMessLink(link) + default: + return option.Outbound{}, E.New("unsupported scheme: ", scheme) + } +} diff --git a/parser/link/shadowsocks.go b/parser/link/shadowsocks.go new file mode 100644 index 00000000..d22df3f1 --- /dev/null +++ b/parser/link/shadowsocks.go @@ -0,0 +1,39 @@ +package link + +import ( + "net/url" + + "github.com/sagernet/sing-box/common" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" +) + +func parseShadowsocksLink(link string) (option.Outbound, error) { + linkURL, err := url.Parse(link) + if err != nil { + return option.Outbound{}, err + } + if linkURL.User == nil || linkURL.User.Username() == "" { + return option.Outbound{}, E.New("missing user info") + } + var options option.ShadowsocksOutboundOptions + options.ServerOptions.Server = linkURL.Hostname() + options.ServerOptions.ServerPort = common.StringToType[uint16](linkURL.Port()) + password, _ := linkURL.User.Password() + if password == "" { + return option.Outbound{}, E.New("bad user info") + } + options.Method = linkURL.User.Username() + options.Password = password + plugin := linkURL.Query().Get("plugin") + options.Plugin = shadowsocksPluginName(plugin) + options.PluginOptions = shadowsocksPluginOptions(plugin) + + outbound := option.Outbound{ + Type: C.TypeShadowsocks, + Tag: linkURL.Fragment, + } + outbound.Options = &options + return outbound, nil +} diff --git a/parser/link/trojan.go b/parser/link/trojan.go new file mode 100644 index 00000000..bca7ddf3 --- /dev/null +++ b/parser/link/trojan.go @@ -0,0 +1,89 @@ +package link + +import ( + "net/url" + "strings" + + "github.com/sagernet/sing-box/common" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json/badoption" +) + +func parseTrojanLink(link string) (option.Outbound, error) { + linkURL, err := url.Parse(link) + if err != nil { + return option.Outbound{}, err + } + if linkURL.User == nil || linkURL.User.Username() == "" { + return option.Outbound{}, E.New("missing password") + } + var options option.TrojanOutboundOptions + TLSOptions := option.OutboundTLSOptions{ + Enabled: true, + ECH: &option.OutboundECHOptions{}, + UTLS: &option.OutboundUTLSOptions{}, + Reality: &option.OutboundRealityOptions{}, + } + options.Server = linkURL.Hostname() + TLSOptions.ServerName = linkURL.Hostname() + options.ServerPort = common.StringToType[uint16](linkURL.Port()) + options.Password = linkURL.User.Username() + proxy := map[string]string{} + for key, values := range linkURL.Query() { + value := values[0] + proxy[key] = value + } + for key, value := range proxy { + switch key { + case "insecure", "allowInsecure", "skip-cert-verify": + if value == "1" || value == "true" { + TLSOptions.Insecure = true + } + case "serviceName", "sni", "peer": + TLSOptions.ServerName = value + case "alpn": + TLSOptions.ALPN = strings.Split(value, ",") + case "fp": + TLSOptions.UTLS.Enabled = true + TLSOptions.UTLS.Fingerprint = value + case "type": + Transport := option.V2RayTransportOptions{ + Type: "", + WebsocketOptions: option.V2RayWebsocketOptions{ + Headers: map[string]badoption.Listable[string]{}, + }, + HTTPOptions: option.V2RayHTTPOptions{ + Host: badoption.Listable[string]{}, + Headers: map[string]badoption.Listable[string]{}, + }, + GRPCOptions: option.V2RayGRPCOptions{}, + } + switch value { + case "ws": + Transport.Type = C.V2RayTransportTypeWebsocket + Transport.WebsocketOptions = v2rayTransportWs(proxy["host"], proxy["path"]) + case "grpc": + Transport.Type = C.V2RayTransportTypeGRPC + if serviceName, exists := proxy["grpc-service-name"]; exists && serviceName != "" { + Transport.GRPCOptions.ServiceName = serviceName + } + default: + continue + } + options.Transport = &Transport + case "tfo", "tcp-fast-open", "tcp_fast_open": + if value == "1" || value == "true" { + options.TCPFastOpen = true + } + } + } + outbound := option.Outbound{ + Type: C.TypeTrojan, + Tag: linkURL.Fragment, + } + options.TLS = &TLSOptions + outbound.Options = &options + return outbound, nil +} diff --git a/parser/link/tuic.go b/parser/link/tuic.go new file mode 100644 index 00000000..1cf1e29c --- /dev/null +++ b/parser/link/tuic.go @@ -0,0 +1,81 @@ +package link + +import ( + "net/url" + "strings" + + "github.com/sagernet/sing-box/common" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json/badoption" +) + +func parseTuicLink(link string) (option.Outbound, error) { + linkURL, err := url.Parse(link) + if err != nil { + return option.Outbound{}, err + } + if linkURL.User == nil || linkURL.User.Username() == "" { + return option.Outbound{}, E.New("missing uuid") + } + var options option.TUICOutboundOptions + TLSOptions := option.OutboundTLSOptions{ + Enabled: true, + ECH: &option.OutboundECHOptions{}, + UTLS: &option.OutboundUTLSOptions{}, + Reality: &option.OutboundRealityOptions{}, + } + options.UUID = linkURL.User.Username() + options.Password, _ = linkURL.User.Password() + options.ServerOptions.Server = linkURL.Hostname() + TLSOptions.ServerName = linkURL.Hostname() + options.ServerOptions.ServerPort = common.StringToType[uint16](linkURL.Port()) + for key, values := range linkURL.Query() { + value := values[0] + switch key { + case "congestion_control": + if value != "cubic" { + options.CongestionControl = value + } + case "udp_relay_mode": + options.UDPRelayMode = value + case "udp_over_stream": + if value == "true" || value == "1" { + options.UDPOverStream = true + } + case "zero_rtt_handshake", "reduce_rtt": + if value == "true" || value == "1" { + options.ZeroRTTHandshake = true + } + case "heartbeat_interval": + options.Heartbeat = common.StringToType[badoption.Duration](value) + case "sni": + TLSOptions.ServerName = value + case "insecure", "skip-cert-verify", "allow_insecure": + if value == "1" || value == "true" { + TLSOptions.Insecure = true + } + case "disable_sni": + if value == "1" || value == "true" { + TLSOptions.DisableSNI = true + } + case "tfo", "tcp-fast-open", "tcp_fast_open": + if value == "1" || value == "true" { + options.TCPFastOpen = true + } + case "alpn": + TLSOptions.ALPN = strings.Split(value, ",") + } + } + if options.UDPOverStream { + options.UDPRelayMode = "" + } + outbound := option.Outbound{ + Type: C.TypeTUIC, + Tag: linkURL.Fragment, + } + options.TLS = &TLSOptions + outbound.Options = &options + return outbound, nil +} diff --git a/parser/link/utils.go b/parser/link/utils.go new file mode 100644 index 00000000..bb72cd4d --- /dev/null +++ b/parser/link/utils.go @@ -0,0 +1,46 @@ +package link + +import ( + "regexp" + "strings" + + "github.com/sagernet/sing-box/common" + "github.com/sagernet/sing-box/option" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/json/badoption" +) + +func shadowsocksPluginName(plugin string) string { + if index := strings.Index(plugin, ";"); index != -1 { + return plugin[:index] + } + return plugin +} + +func shadowsocksPluginOptions(plugin string) string { + if index := strings.Index(plugin, ";"); index != -1 { + return plugin[index+1:] + } + return "" +} + +func v2rayTransportWsPath(WebsocketOptions *option.V2RayWebsocketOptions, path string) { + reg := regexp.MustCompile(`^(.*?)(?:\?ed=(\d*))?$`) + result := reg.FindStringSubmatch(path) + WebsocketOptions.Path = result[1] + if result[2] != "" { + WebsocketOptions.EarlyDataHeaderName = "Sec-WebSocket-Protocol" + WebsocketOptions.MaxEarlyData = common.StringToType[uint32](result[2]) + } +} + +func v2rayTransportWs(host string, path string) option.V2RayWebsocketOptions { + var WebsocketOptions option.V2RayWebsocketOptions + if host != "" { + WebsocketOptions.Headers = common.StringToType[badoption.HTTPHeader](F.ToString("Host: ", host)) + } + if path != "" { + v2rayTransportWsPath(&WebsocketOptions, path) + } + return WebsocketOptions +} diff --git a/parser/link/vless.go b/parser/link/vless.go new file mode 100644 index 00000000..80442e1e --- /dev/null +++ b/parser/link/vless.go @@ -0,0 +1,114 @@ +package link + +import ( + "net/url" + "strings" + + "github.com/sagernet/sing-box/common" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json/badoption" +) + +func parseVLESSLink(link string) (option.Outbound, error) { + linkURL, err := url.Parse(link) + if err != nil { + return option.Outbound{}, err + } + if linkURL.User == nil || linkURL.User.Username() == "" { + return option.Outbound{}, E.New("missing uuid") + } + var options option.VLESSOutboundOptions + TLSOptions := option.OutboundTLSOptions{ + ECH: &option.OutboundECHOptions{}, + UTLS: &option.OutboundUTLSOptions{}, + Reality: &option.OutboundRealityOptions{}, + } + options.UUID = linkURL.User.Username() + options.Server = linkURL.Hostname() + TLSOptions.ServerName = linkURL.Hostname() + options.ServerPort = common.StringToType[uint16](linkURL.Port()) + proxy := map[string]string{} + for key, values := range linkURL.Query() { + value := values[0] + switch key { + case "key", "alpn", "seed", "path", "host": + proxy[key] = value + default: + proxy[key] = value + } + } + for key, value := range proxy { + switch key { + case "type": + Transport := option.V2RayTransportOptions{ + HTTPOptions: option.V2RayHTTPOptions{ + Host: badoption.Listable[string]{}, + Headers: badoption.HTTPHeader{}, + }, + GRPCOptions: option.V2RayGRPCOptions{}, + } + switch value { + case "ws": + Transport.Type = C.V2RayTransportTypeWebsocket + Transport.WebsocketOptions = v2rayTransportWs(proxy["host"], proxy["path"]) + case "http": + Transport.Type = C.V2RayTransportTypeHTTP + if host, exists := proxy["host"]; exists && host != "" { + Transport.HTTPOptions.Host = strings.Split(host, ",") + } + if path, exists := proxy["path"]; exists && path != "" { + Transport.HTTPOptions.Path = path + } + case "grpc": + Transport.Type = C.V2RayTransportTypeGRPC + if serviceName, exists := proxy["serviceName"]; exists && serviceName != "" { + Transport.GRPCOptions.ServiceName = serviceName + } + default: + continue + } + options.Transport = &Transport + case "security": + if value == "tls" { + TLSOptions.Enabled = true + } else if value == "reality" { + TLSOptions.Enabled = true + TLSOptions.Reality.Enabled = true + } + case "insecure", "skip-cert-verify": + if value == "1" || value == "true" { + TLSOptions.Insecure = true + } + case "serviceName", "sni", "peer": + TLSOptions.ServerName = value + case "alpn": + TLSOptions.ALPN = strings.Split(value, ",") + case "fp": + TLSOptions.UTLS.Enabled = true + TLSOptions.UTLS.Fingerprint = value + case "flow": + if value == "xtls-rprx-vision" { + options.Flow = "xtls-rprx-vision" + } + case "pbk": + TLSOptions.Reality.PublicKey = value + case "sid": + TLSOptions.Reality.ShortID = value + case "tfo", "tcp-fast-open", "tcp_fast_open": + if value == "1" || value == "true" { + options.TCPFastOpen = true + } + } + } + outbound := option.Outbound{ + Type: C.TypeVLESS, + Tag: linkURL.Fragment, + } + if TLSOptions.Enabled { + options.TLS = &TLSOptions + } + outbound.Options = &options + return outbound, nil +} diff --git a/parser/link/vmess.go b/parser/link/vmess.go new file mode 100644 index 00000000..d3a09711 --- /dev/null +++ b/parser/link/vmess.go @@ -0,0 +1,160 @@ +package link + +import ( + "encoding/json" + "net/url" + "regexp" + "strconv" + + "github.com/sagernet/sing-box/common" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json/badoption" +) + +func parseVMessLink(link string) (option.Outbound, error) { + var proxy map[string]string + reg := regexp.MustCompile(`(\"[^:,]+?\"[ \t]*:[ \t]*)(\d+|true|false)`) + s := reg.ReplaceAllString(link, `$1"$2"`) + err := json.Unmarshal([]byte(s[8:]), &proxy) + if err != nil { + proxy = make(map[string]string) + linkURL, err := url.Parse(link) + if err != nil { + return option.Outbound{}, err + } + if linkURL.User == nil || linkURL.User.Username() == "" { + return option.Outbound{}, E.New("missing uuid") + } + proxy["id"] = linkURL.User.Username() + proxy["add"] = linkURL.Hostname() + proxy["port"] = linkURL.Port() + proxy["ps"] = linkURL.Fragment + for key, values := range linkURL.Query() { + value := values[0] + switch key { + case "type": + if value == "http" { + proxy["net"] = "tcp" + proxy["type"] = "http" + } + case "encryption": + proxy["scy"] = value + case "alterId": + proxy["aid"] = value + case "key", "alpn", "seed", "path", "host": + proxy[key] = value + default: + proxy[key] = value + } + } + } + outbound := option.Outbound{ + Type: C.TypeVMess, + } + options := option.VMessOutboundOptions{ + Security: "auto", + } + TLSOptions := option.OutboundTLSOptions{ + ECH: &option.OutboundECHOptions{}, + UTLS: &option.OutboundUTLSOptions{}, + Reality: &option.OutboundRealityOptions{}, + } + for key, value := range proxy { + switch key { + case "ps": + outbound.Tag = value + case "add": + options.Server = value + TLSOptions.ServerName = value + case "port": + options.ServerPort = common.StringToType[uint16](value) + case "id": + options.UUID = value + case "scy": + options.Security = value + case "aid": + options.AlterId, _ = strconv.Atoi(value) + case "packet_encoding": + options.PacketEncoding = value + case "xudp": + if value == "1" || value == "true" { + options.PacketEncoding = "xudp" + } + case "tls": + if value == "1" || value == "true" || value == "tls" { + TLSOptions.Enabled = true + } + case "insecure", "skip-cert-verify": + if value == "1" || value == "true" { + TLSOptions.Insecure = true + } + case "fp": + TLSOptions.UTLS.Enabled = true + TLSOptions.UTLS.Fingerprint = value + case "net": + Transport := option.V2RayTransportOptions{ + Type: "", + WebsocketOptions: option.V2RayWebsocketOptions{ + Headers: badoption.HTTPHeader{}, + }, + HTTPOptions: option.V2RayHTTPOptions{ + Host: badoption.Listable[string]{}, + Headers: map[string]badoption.Listable[string]{}, + }, + GRPCOptions: option.V2RayGRPCOptions{}, + } + switch value { + case "ws": + Transport.Type = C.V2RayTransportTypeWebsocket + Transport.WebsocketOptions = v2rayTransportWs(proxy["host"], proxy["path"]) + case "h2": + Transport.Type = C.V2RayTransportTypeHTTP + TLSOptions.Enabled = true + if host, exists := proxy["host"]; exists && host != "" { + Transport.HTTPOptions.Host = []string{host} + } + if path, exists := proxy["path"]; exists && path != "" { + Transport.HTTPOptions.Path = path + } + case "tcp": + if tType, exists := proxy["type"]; exists { + if tType != "http" { + continue + } + Transport.Type = C.V2RayTransportTypeHTTP + if method, exists := proxy["method"]; exists { + Transport.HTTPOptions.Method = method + } + if host, exists := proxy["host"]; exists && host != "" { + Transport.HTTPOptions.Host = []string{host} + } + if path, exists := proxy["path"]; exists && path != "" { + Transport.HTTPOptions.Path = path + } + if headers, exists := proxy["headers"]; exists { + Transport.HTTPOptions.Headers = common.StringToType[badoption.HTTPHeader](headers) + } + } + case "grpc": + Transport.Type = C.V2RayTransportTypeGRPC + if host, exists := proxy["host"]; exists && host != "" { + Transport.GRPCOptions.ServiceName = host + } + default: + continue + } + options.Transport = &Transport + case "tfo", "tcp-fast-open", "tcp_fast_open": + if value == "1" || value == "true" { + options.TCPFastOpen = true + } + } + } + if TLSOptions.Enabled { + options.TLS = &TLSOptions + } + outbound.Options = &options + return outbound, nil +} diff --git a/parser/parser.go b/parser/parser.go new file mode 100644 index 00000000..32da0064 --- /dev/null +++ b/parser/parser.go @@ -0,0 +1,31 @@ +package parser + +import ( + "context" + + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/parser/clash" + "github.com/sagernet/sing-box/parser/raw" + "github.com/sagernet/sing-box/parser/singbox" + "github.com/sagernet/sing-box/parser/sip008" + E "github.com/sagernet/sing/common/exceptions" +) + +var subscriptionParsers = []func(ctx context.Context, content string) ([]option.Outbound, error){ + singbox.ParseBoxSubscription, + clash.ParseClashSubscription, + sip008.ParseSIP008Subscription, + raw.ParseRawSubscription, +} + +func ParseSubscription(ctx context.Context, content string) ([]option.Outbound, error) { + var pErr error + for _, parser := range subscriptionParsers { + servers, err := parser(ctx, content) + if len(servers) > 0 { + return servers, nil + } + pErr = E.Errors(pErr, err) + } + return nil, E.Cause(pErr, "no servers found") +} diff --git a/parser/raw/parser.go b/parser/raw/parser.go new file mode 100644 index 00000000..4c8fe2d8 --- /dev/null +++ b/parser/raw/parser.go @@ -0,0 +1,50 @@ +package raw + +import ( + "context" + "encoding/base64" + "strings" + + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/parser/link" + E "github.com/sagernet/sing/common/exceptions" +) + +func ParseRawSubscription(ctx context.Context, content string) ([]option.Outbound, error) { + if base64Content, err := DecodeBase64URLSafe(content); err == nil { + servers, _ := parseRawSubscription(base64Content) + if len(servers) > 0 { + return servers, err + } + } + return parseRawSubscription(content) +} + +func parseRawSubscription(content string) ([]option.Outbound, error) { + var servers []option.Outbound + content = strings.ReplaceAll(content, "\r\n", "\n") + linkList := strings.Split(content, "\n") + for _, linkLine := range linkList { + server, err := link.ParseSubscriptionLink(linkLine) + if err != nil { + continue + } + servers = append(servers, server) + } + if len(servers) == 0 { + return nil, E.New("no servers found") + } + return servers, nil +} + +func DecodeBase64URLSafe(content string) (string, error) { + s := strings.ReplaceAll(content, " ", "-") + s = strings.ReplaceAll(s, "/", "_") + s = strings.ReplaceAll(s, "+", "-") + s = strings.ReplaceAll(s, "=", "") + result, err := base64.RawURLEncoding.DecodeString(s) + if err != nil { + return content, nil + } + return string(result), nil +} diff --git a/parser/singbox/parser.go b/parser/singbox/parser.go new file mode 100644 index 00000000..f5d3ed5a --- /dev/null +++ b/parser/singbox/parser.go @@ -0,0 +1,58 @@ +package singbox + +import ( + "context" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" + "github.com/sagernet/sing/common/json/badjson" +) + +type _SingBoxDocument struct { + Outbounds []option.Outbound `json:"outbounds"` +} +type SingBoxDocument _SingBoxDocument + +func (o *SingBoxDocument) UnmarshalJSONContext(ctx context.Context, inputContent []byte) error { + var content badjson.JSONObject + err := content.UnmarshalJSONContext(ctx, inputContent) + if err != nil { + return err + } + outbounds, ok := content.Get("outbounds") + if !ok { + return E.New("missing outbounds in sing-box configuration") + } + var outs badjson.JSONArray + for i, outbound := range outbounds.(badjson.JSONArray) { + typeVal, loaded := outbound.(*badjson.JSONObject).Get("type") + if !loaded { + return E.New("missing type in outbound[", i, "]") + } + switch typeVal.(string) { + case C.TypeDirect, C.TypeBlock, C.TypeDNS, C.TypeSelector, C.TypeURLTest: + continue + default: + outs = append(outs, outbound) + } + } + content.Put("outbounds", outs) + inputContent, err = content.MarshalJSONContext(ctx) + if err != nil { + return err + } + return json.UnmarshalContext(ctx, inputContent, (*_SingBoxDocument)(o)) +} + +func ParseBoxSubscription(ctx context.Context, content string) ([]option.Outbound, error) { + options, err := json.UnmarshalExtendedContext[SingBoxDocument](ctx, []byte(content)) + if err != nil { + return nil, err + } + if len(options.Outbounds) == 0 { + return nil, E.New("no servers found") + } + return options.Outbounds, nil +} diff --git a/parser/sip008/parser.go b/parser/sip008/parser.go new file mode 100644 index 00000000..ad55d062 --- /dev/null +++ b/parser/sip008/parser.go @@ -0,0 +1,53 @@ +package sip008 + +import ( + "context" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/json" +) + +type ShadowsocksDocument struct { + Version int `json:"version"` + Servers []ShadowsocksServerDocument `json:"servers"` +} + +type ShadowsocksServerDocument struct { + ID string `json:"id"` + Remarks string `json:"remarks"` + Server string `json:"server"` + ServerPort int `json:"server_port"` + Password string `json:"password"` + Method string `json:"method"` + Plugin string `json:"plugin"` + PluginOpts string `json:"plugin_opts"` +} + +func ParseSIP008Subscription(_ context.Context, content string) ([]option.Outbound, error) { + var document ShadowsocksDocument + err := json.Unmarshal([]byte(content), &document) + if err != nil { + return nil, E.Cause(err, "parse SIP008 document") + } + + var servers []option.Outbound + for _, server := range document.Servers { + servers = append(servers, option.Outbound{ + Type: C.TypeShadowsocks, + Tag: server.Remarks, + Options: &option.ShadowsocksOutboundOptions{ + ServerOptions: option.ServerOptions{ + Server: server.Server, + ServerPort: uint16(server.ServerPort), + }, + Password: server.Password, + Method: server.Method, + Plugin: server.Plugin, + PluginOptions: server.PluginOpts, + }, + }) + } + return servers, nil +} diff --git a/protocol/bond/conn.go b/protocol/bond/conn.go index 5ddeeda9..992e5187 100644 --- a/protocol/bond/conn.go +++ b/protocol/bond/conn.go @@ -1,6 +1,7 @@ package bond import ( + "bytes" "encoding/binary" "errors" "io" @@ -13,9 +14,7 @@ type bondedConn struct { downloadRatios []uint8 uploadRatios []uint8 - readBuffer []byte - readOffset int - readSize int + readBuffer *bytes.Buffer } func NewBondedConn(conns []net.Conn, downloadRatios, uploadRatios []uint8) *bondedConn { @@ -23,12 +22,13 @@ func NewBondedConn(conns []net.Conn, downloadRatios, uploadRatios []uint8) *bond conns: conns, downloadRatios: downloadRatios, uploadRatios: uploadRatios, - readBuffer: make([]byte, 65535), + readBuffer: bytes.NewBuffer(make([]byte, 0, 65536)), } } func (c *bondedConn) Read(b []byte) (n int, err error) { - if c.readOffset == c.readSize { + if c.readBuffer.Len() == 0 { + c.readBuffer.Reset() var header [2]byte _, err := io.ReadFull(c.conns[0], header[:]) if err != nil { @@ -41,19 +41,14 @@ func (c *bondedConn) Read(b []byte) (n int, err error) { if chunkLen == 0 { continue } - chunk := c.readBuffer[total : total+chunkLen] - n, err := io.ReadFull(c.conns[i], chunk) - total += n + n, err := io.CopyN(c.readBuffer, c.conns[i], int64(chunkLen)) + total += int(n) if err != nil { return total, err } } - c.readOffset = 0 - c.readSize = size } - n = copy(b, c.readBuffer[c.readOffset:c.readSize]) - c.readOffset += n - return n, nil + return c.readBuffer.Read(b) } func (c *bondedConn) Write(b []byte) (n int, err error) { diff --git a/protocol/bond/inbound.go b/protocol/bond/inbound.go index 6eac51c6..89796c4b 100644 --- a/protocol/bond/inbound.go +++ b/protocol/bond/inbound.go @@ -6,7 +6,8 @@ import ( "net" "time" - "github.com/patrickmn/go-cache" + "github.com/gofrs/uuid/v5" + "github.com/patrickmn/go-cache/v2" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter/inbound" "github.com/sagernet/sing-box/common/kmutex" @@ -29,9 +30,9 @@ type Inbound struct { logger logger.ContextLogger router adapter.ConnectionRouterEx inbounds []adapter.Inbound - conns *cache.Cache + conns *cache.Cache[uuid.UUID, map[uint8]*ratioConn] - mtx *kmutex.Kmutex[string] + mtx *kmutex.Kmutex[uuid.UUID] } func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.BondInboundOptions) (adapter.Inbound, error) { @@ -42,23 +43,23 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo Adapter: inbound.NewAdapter(C.TypeBond, tag), logger: logger, router: uot.NewRouter(router, logger), - conns: cache.New(C.TCPConnectTimeout, time.Second), - mtx: kmutex.New[string](), + conns: cache.New[uuid.UUID, map[uint8]*ratioConn](C.TCPConnectTimeout, time.Second), + mtx: kmutex.New[uuid.UUID](), } + router = NewRouter(router, logger, inbound.connHandler) inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) inbounds := make([]adapter.Inbound, len(options.Inbounds)) for i, inboundOptions := range options.Inbounds { - inbound, err := inboundRegistry.UnsafeCreate(ctx, NewRouter(router, logger, inbound.connHandler), logger, inboundOptions.Tag, inboundOptions.Type, inboundOptions.Options) + inbound, err := inboundRegistry.UnsafeCreate(ctx, router, logger, inboundOptions.Tag, inboundOptions.Type, inboundOptions.Options) if err != nil { return nil, err } inbounds[i] = inbound } inbound.inbounds = inbounds - inbound.conns.OnEvicted(func(s string, i interface{}) { + inbound.conns.OnEvicted(func(s uuid.UUID, ratioConns map[uint8]*ratioConn) { inbound.mtx.Lock(s) defer inbound.mtx.Unlock(s) - ratioConns := i.(map[uint8]*ratioConn) for _, ratioConn := range ratioConns { if ratioConn != nil { ratioConn.conn.Close() @@ -93,21 +94,15 @@ func (h *Inbound) Close() error { } func (h *Inbound) connHandler(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { - if metadata.Destination != Destination { - h.router.RouteConnectionEx(ctx, conn, metadata, onClose) - return nil - } request, err := ReadRequest(conn) if err != nil { return err } - requestUUID := request.UUID.String() + requestUUID := request.UUID h.mtx.Lock(requestUUID) var ratioConns map[uint8]*ratioConn - rawRatioConns, ok := h.conns.Get(requestUUID) - if ok { - ratioConns = rawRatioConns.(map[uint8]*ratioConn) - } else { + ratioConns, ok := h.conns.Get(requestUUID) + if !ok { ratioConns = make(map[uint8]*ratioConn, request.Count) h.conns.SetDefault(requestUUID, ratioConns) } diff --git a/protocol/bond/router.go b/protocol/bond/router.go index 04ea5a7d..f2ba1283 100644 --- a/protocol/bond/router.go +++ b/protocol/bond/router.go @@ -21,14 +21,24 @@ func NewRouter(router adapter.Router, logger logger.ContextLogger, handler func( } func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + if metadata.Destination != Destination { + return r.Router.RouteConnection(ctx, conn, metadata) + } return r.handler(ctx, conn, metadata, func(error) {}) } func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + if metadata.Destination != Destination { + return r.Router.RoutePacketConnection(ctx, conn, metadata) + } return os.ErrInvalid } func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + if metadata.Destination != Destination { + r.Router.RouteConnectionEx(ctx, conn, metadata, onClose) + return + } if err := r.handler(ctx, conn, metadata, onClose); err != nil { r.logger.ErrorContext(ctx, err) N.CloseOnHandshakeFailure(conn, onClose, err) @@ -36,6 +46,10 @@ func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata } func (r *Router) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + if metadata.Destination != Destination { + r.Router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) + return + } r.logger.ErrorContext(ctx, os.ErrInvalid) N.CloseOnHandshakeFailure(conn, onClose, os.ErrInvalid) } diff --git a/protocol/group/failover.go b/protocol/group/fallback.go similarity index 77% rename from protocol/group/failover.go rename to protocol/group/fallback.go index c0362163..57353362 100644 --- a/protocol/group/failover.go +++ b/protocol/group/fallback.go @@ -17,15 +17,15 @@ import ( "github.com/sagernet/sing/service" ) -func RegisterFailover(registry *outbound.Registry) { - outbound.Register[option.FailoverOutboundOptions](registry, C.TypeFailover, NewFailover) +func RegisterFallback(registry *outbound.Registry) { + outbound.Register[option.FallbackOutboundOptions](registry, C.TypeFallback, NewFallback) } var ( - _ adapter.OutboundGroup = (*Failover)(nil) + _ adapter.OutboundGroup = (*Fallback)(nil) ) -type Failover struct { +type Fallback struct { outbound.Adapter ctx context.Context outbound adapter.OutboundManager @@ -37,12 +37,12 @@ type Failover struct { mtx sync.Mutex } -func NewFailover(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.FailoverOutboundOptions) (adapter.Outbound, error) { +func NewFallback(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.FallbackOutboundOptions) (adapter.Outbound, error) { if len(options.Outbounds) == 0 { return nil, E.New("missing tags") } - outbound := &Failover{ - Adapter: outbound.NewAdapter(C.TypeFailover, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds), + outbound := &Fallback{ + Adapter: outbound.NewAdapter(C.TypeFallback, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds), ctx: ctx, outbound: service.FromContext[adapter.OutboundManager](ctx), logger: logger, @@ -53,7 +53,7 @@ func NewFailover(ctx context.Context, router adapter.Router, logger log.ContextL return outbound, nil } -func (s *Failover) Start() error { +func (s *Fallback) Start() error { for i, tag := range s.tags { outbound, loaded := s.outbound.Outbound(tag) if !loaded { @@ -64,17 +64,17 @@ func (s *Failover) Start() error { return nil } -func (s *Failover) Now() string { +func (s *Fallback) Now() string { s.mtx.Lock() defer s.mtx.Unlock() return s.lastUsedOutbound } -func (s *Failover) All() []string { +func (s *Fallback) All() []string { return s.tags } -func (s *Failover) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { +func (s *Fallback) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { var conn net.Conn var err error for _, outbound := range s.outbounds { @@ -91,7 +91,7 @@ func (s *Failover) DialContext(ctx context.Context, network string, destination return nil, err } -func (s *Failover) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { +func (s *Fallback) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { var conn net.PacketConn var err error for _, outbound := range s.outbounds { diff --git a/protocol/group/selector.go b/protocol/group/selector.go index 8a686e5b..29d56560 100644 --- a/protocol/group/selector.go +++ b/protocol/group/selector.go @@ -3,6 +3,7 @@ package group import ( "context" "net" + "regexp" "time" "github.com/sagernet/sing-box/adapter" @@ -42,11 +43,20 @@ type Selector struct { selected common.TypedValue[adapter.Outbound] interruptGroup *interrupt.Group interruptExternalConnections bool + + provider adapter.ProviderManager + providers map[string]adapter.Provider + outboundsCache map[string][]adapter.Outbound + + providerTags []string + exclude *regexp.Regexp + include *regexp.Regexp + useAllProviders bool } func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (adapter.Outbound, error) { outbound := &Selector{ - Adapter: outbound.NewAdapter(C.TypeSelector, tag, nil, options.Outbounds), + Adapter: outbound.NewAdapter(C.TypeSelector, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.Outbounds), ctx: ctx, outbound: service.FromContext[adapter.OutboundManager](ctx), connection: service.FromContext[adapter.ConnectionManager](ctx), @@ -56,9 +66,15 @@ func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextL outbounds: make(map[string]adapter.Outbound), interruptGroup: interrupt.NewGroup(), interruptExternalConnections: options.InterruptExistConnections, - } - if len(outbound.tags) == 0 { - return nil, E.New("missing tags") + + provider: service.FromContext[adapter.ProviderManager](ctx), + providers: make(map[string]adapter.Provider), + outboundsCache: make(map[string][]adapter.Outbound), + + providerTags: options.Providers, + exclude: (*regexp.Regexp)(options.Exclude), + include: (*regexp.Regexp)(options.Include), + useAllProviders: options.UseAllProviders, } return outbound, nil } @@ -72,6 +88,28 @@ func (s *Selector) Network() []string { } func (s *Selector) Start() error { + if s.useAllProviders { + var providerTags []string + for _, provider := range s.provider.Providers() { + providerTags = append(providerTags, provider.Tag()) + s.providers[provider.Tag()] = provider + provider.RegisterCallback(s.onProviderUpdated) + } + s.providerTags = providerTags + } else { + for i, tag := range s.providerTags { + provider, loaded := s.provider.Get(tag) + if !loaded { + return E.New("outbound provider ", i, " not found: ", tag) + } + s.providers[tag] = provider + provider.RegisterCallback(s.onProviderUpdated) + } + } + if len(s.tags)+len(s.providerTags) == 0 { + return E.New("missing outbound and provider tags") + } + for i, tag := range s.tags { detour, loaded := s.outbound.Outbound(tag) if !loaded { @@ -79,31 +117,16 @@ func (s *Selector) Start() error { } s.outbounds[tag] = detour } - - if s.Tag() != "" { - cacheFile := service.FromContext[adapter.CacheFile](s.ctx) - if cacheFile != nil { - selected := cacheFile.LoadSelected(s.Tag()) - if selected != "" { - detour, loaded := s.outbounds[selected] - if loaded { - s.selected.Store(detour) - return nil - } - } - } + if len(s.tags) == 0 { + detour, _ := s.outbound.Outbound("Compatible") + s.tags = append(s.tags, detour.Tag()) + s.outbounds[detour.Tag()] = detour } - - if s.defaultTag != "" { - detour, loaded := s.outbounds[s.defaultTag] - if !loaded { - return E.New("default outbound not found: ", s.defaultTag) - } - s.selected.Store(detour) - return nil + outbound, err := s.outboundSelect() + if err != nil { + return err } - - s.selected.Store(s.outbounds[s.tags[0]]) + s.selected.Store(outbound) return nil } @@ -145,7 +168,7 @@ func (s *Selector) DialContext(ctx context.Context, network string, destination if err != nil { return nil, err } - return s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil + return s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)), nil } func (s *Selector) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { @@ -153,13 +176,13 @@ func (s *Selector) ListenPacket(ctx context.Context, destination M.Socksaddr) (n if err != nil { return nil, err } - return s.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil + return s.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)), nil } func (s *Selector) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = interrupt.ContextWithIsExternalConnection(ctx) selected := s.selected.Load() - conn = s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)) + conn = s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)) if outboundHandler, isHandler := selected.(adapter.ConnectionHandlerEx); isHandler { outboundHandler.NewConnectionEx(ctx, conn, metadata, onClose) } else { @@ -170,7 +193,7 @@ func (s *Selector) NewConnectionEx(ctx context.Context, conn net.Conn, metadata func (s *Selector) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = interrupt.ContextWithIsExternalConnection(ctx) selected := s.selected.Load() - conn = s.interruptGroup.NewSingPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)) + conn = s.interruptGroup.NewSingPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)) if outboundHandler, isHandler := selected.(adapter.PacketConnectionHandlerEx); isHandler { outboundHandler.NewPacketConnectionEx(ctx, conn, metadata, onClose) } else { @@ -192,3 +215,77 @@ func RealTag(detour adapter.Outbound) string { } return detour.Tag() } + +func (s *Selector) onProviderUpdated(tag string) error { + _, loaded := s.providers[tag] + if !loaded { + return E.New(s.Tag(), ": ", "outbound provider not found: ", tag) + } + var ( + tags = s.Dependencies() + outboundByTag = make(map[string]adapter.Outbound) + ) + for _, tag := range tags { + outboundByTag[tag] = s.outbounds[tag] + } + for _, providerTag := range s.providerTags { + if providerTag != tag && s.outboundsCache[providerTag] != nil { + for _, detour := range s.outboundsCache[providerTag] { + tags = append(tags, detour.Tag()) + outboundByTag[detour.Tag()] = detour + } + continue + } + provider := s.providers[providerTag] + var cache []adapter.Outbound + for _, detour := range provider.Outbounds() { + tag := detour.Tag() + if s.exclude != nil && s.exclude.MatchString(tag) { + continue + } + if s.include != nil && !s.include.MatchString(tag) { + continue + } + tags = append(tags, tag) + cache = append(cache, detour) + outboundByTag[tag] = detour + } + s.outboundsCache[providerTag] = cache + } + if len(tags) == 0 { + detour, _ := s.outbound.Outbound("Compatible") + tags = append(tags, detour.Tag()) + outboundByTag[detour.Tag()] = detour + } + s.tags, s.outbounds = tags, outboundByTag + detour, _ := s.outboundSelect() + if s.selected.Swap(detour) != detour { + s.interruptGroup.Interrupt(s.interruptExternalConnections) + } + return nil +} + +func (s *Selector) outboundSelect() (adapter.Outbound, error) { + if s.Tag() != "" { + cacheFile := service.FromContext[adapter.CacheFile](s.ctx) + if cacheFile != nil { + selected := cacheFile.LoadSelected(s.Tag()) + if selected != "" { + detour, loaded := s.outbounds[selected] + if loaded { + return detour, nil + } + } + } + } + + if s.defaultTag != "" { + detour, loaded := s.outbounds[s.defaultTag] + if !loaded { + return nil, E.New("default outbound not found: ", s.defaultTag) + } + return detour, nil + } + + return s.outbounds[s.tags[0]], nil +} diff --git a/protocol/group/urltest.go b/protocol/group/urltest.go index 91964aa0..4b20c629 100644 --- a/protocol/group/urltest.go +++ b/protocol/group/urltest.go @@ -3,6 +3,7 @@ package group import ( "context" "net" + "regexp" "sync" "sync/atomic" "time" @@ -45,6 +46,16 @@ type URLTest struct { idleTimeout time.Duration group *URLTestGroup interruptExternalConnections bool + + provider adapter.ProviderManager + providers map[string]adapter.Provider + outboundsCache map[string][]adapter.Outbound + cancel context.CancelFunc + + providerTags []string + exclude *regexp.Regexp + include *regexp.Regexp + useAllProviders bool } func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (adapter.Outbound, error) { @@ -61,14 +72,42 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo tolerance: options.Tolerance, idleTimeout: time.Duration(options.IdleTimeout), interruptExternalConnections: options.InterruptExistConnections, - } - if len(outbound.tags) == 0 { - return nil, E.New("missing tags") + + provider: service.FromContext[adapter.ProviderManager](ctx), + providers: make(map[string]adapter.Provider), + outboundsCache: make(map[string][]adapter.Outbound), + + providerTags: options.Providers, + exclude: (*regexp.Regexp)(options.Exclude), + include: (*regexp.Regexp)(options.Include), + useAllProviders: options.UseAllProviders, } return outbound, nil } func (s *URLTest) Start() error { + if s.useAllProviders { + var providerTags []string + for _, provider := range s.provider.Providers() { + providerTags = append(providerTags, provider.Tag()) + s.providers[provider.Tag()] = provider + provider.RegisterCallback(s.onProviderUpdated) + } + s.providerTags = providerTags + } else { + for i, tag := range s.providerTags { + provider, loaded := s.provider.Get(tag) + if !loaded { + return E.New("outbound provider ", i, " not found: ", tag) + } + s.providers[tag] = provider + provider.RegisterCallback(s.onProviderUpdated) + } + } + if len(s.tags)+len(s.providerTags) == 0 { + return E.New("missing outbound and provider tags") + } + outbounds := make([]adapter.Outbound, 0, len(s.tags)) for i, tag := range s.tags { detour, loaded := s.outbound.Outbound(tag) @@ -77,6 +116,11 @@ func (s *URLTest) Start() error { } outbounds = append(outbounds, detour) } + if len(s.tags) == 0 { + detour, _ := s.outbound.Outbound("Compatible") + s.tags = append(s.tags, detour.Tag()) + outbounds = append(outbounds, detour) + } group, err := NewURLTestGroup(s.ctx, s.outbound, s.logger, outbounds, s.link, s.interval, s.tolerance, s.idleTimeout, s.interruptExternalConnections) if err != nil { return err @@ -136,7 +180,7 @@ func (s *URLTest) DialContext(ctx context.Context, network string, destination M } conn, err := outbound.DialContext(ctx, network, destination) if err == nil { - return s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil + return s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)), nil } s.logger.ErrorContext(ctx, err) s.group.history.DeleteURLTestHistory(outbound.Tag()) @@ -154,7 +198,7 @@ func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (ne } conn, err := outbound.ListenPacket(ctx, destination) if err == nil { - return s.group.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil + return s.group.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)), nil } s.logger.ErrorContext(ctx, err) s.group.history.DeleteURLTestHistory(outbound.Tag()) @@ -163,13 +207,13 @@ func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (ne func (s *URLTest) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = interrupt.ContextWithIsExternalConnection(ctx) - conn = s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)) + conn = s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)) s.connection.NewConnection(ctx, s, conn, metadata, onClose) } func (s *URLTest) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = interrupt.ContextWithIsExternalConnection(ctx) - conn = s.group.interruptGroup.NewSingPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)) + conn = s.group.interruptGroup.NewSingPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx), interrupt.IsProviderConnectionFromContext(ctx)) s.connection.NewPacketConnection(ctx, s, conn, metadata, onClose) } @@ -188,6 +232,63 @@ func (s *URLTest) NewDirectRouteConnection(metadata adapter.InboundContext, rout return selected.(adapter.DirectRouteOutbound).NewDirectRouteConnection(metadata, routeContext, timeout) } +func (s *URLTest) onProviderUpdated(tag string) error { + _, loaded := s.providers[tag] + if !loaded { + return E.New("outbound provider not found: ", tag) + } + var ( + tags = s.Dependencies() + outbounds []adapter.Outbound + ) + for _, tag := range tags { + detour, _ := s.outbound.Outbound(tag) + outbounds = append(outbounds, detour) + } + for _, providerTag := range s.providerTags { + if providerTag != tag && s.outboundsCache[providerTag] != nil { + for _, detour := range s.outboundsCache[providerTag] { + tags = append(tags, detour.Tag()) + outbounds = append(outbounds, detour) + } + continue + } + provider := s.providers[providerTag] + var cache []adapter.Outbound + for _, detour := range provider.Outbounds() { + tag := detour.Tag() + if s.exclude != nil && s.exclude.MatchString(tag) { + continue + } + if s.include != nil && !s.include.MatchString(tag) { + continue + } + tags = append(tags, tag) + cache = append(cache, detour) + } + outbounds = append(outbounds, cache...) + s.outboundsCache[providerTag] = cache + } + if len(tags) == 0 { + detour, _ := s.outbound.Outbound("Compatible") + tags = append(tags, detour.Tag()) + outbounds = append(outbounds, detour) + } + s.tags, s.group.outbounds = tags, outbounds + s.group.access.Lock() + if s.group.ticker != nil { + s.group.ticker.Reset(s.group.interval) + } + s.group.access.Unlock() + ctx, cancel := context.WithCancel(s.ctx) + if s.cancel != nil { + s.cancel() + } + s.cancel = cancel + s.URLTest(ctx) + return nil +} + type URLTestGroup struct { ctx context.Context router adapter.Router @@ -407,7 +508,11 @@ func (g *URLTestGroup) urlTest(ctx context.Context, force bool) (map[string]uint }) } b.Wait() - g.performUpdateCheck() + select { + case <-ctx.Done(): + default: + g.performUpdateCheck() + } return result, nil } diff --git a/protocol/limiter/bandwidth/conn.go b/protocol/limiter/bandwidth/conn.go new file mode 100644 index 00000000..06796e27 --- /dev/null +++ b/protocol/limiter/bandwidth/conn.go @@ -0,0 +1,119 @@ +package bandwidth + +import ( + "context" + "net" + + "golang.org/x/time/rate" +) + +type connWithDownloadBandwidthLimiter struct { + net.Conn + ctx context.Context + limiter Limiter +} + +func NewConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter *rate.Limiter) *connWithDownloadBandwidthLimiter { + return &connWithDownloadBandwidthLimiter{conn, ctx, limiter} +} + +func (conn *connWithDownloadBandwidthLimiter) Write(p []byte) (n int, err error) { + err = conn.limiter.WaitN(conn.ctx, len(p)) + if err != nil { + return + } + return conn.Conn.Write(p) +} + +type connWithUploadBandwidthLimiter struct { + net.Conn + ctx context.Context + limiter *rate.Limiter + burst int +} + +func NewConnWithUploadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter *rate.Limiter) *connWithUploadBandwidthLimiter { + return &connWithUploadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()} +} + +func (conn *connWithUploadBandwidthLimiter) Read(p []byte) (n int, err error) { + n, err = conn.Conn.Read(p) + if err != nil { + return + } + err = conn.limiter.WaitN(conn.ctx, n) + if err != nil { + return + } + return n, err +} + +type connWithCloseHandler struct { + net.Conn + onClose CloseHandlerFunc +} + +func NewConnWithCloseHandler(conn net.Conn, onClose CloseHandlerFunc) *connWithCloseHandler { + return &connWithCloseHandler{conn, onClose} +} + +func (conn *connWithCloseHandler) Close() error { + conn.onClose() + return conn.Conn.Close() +} + +type packetConnWithDownloadBandwidthLimiter struct { + net.PacketConn + ctx context.Context + limiter *rate.Limiter + burst int +} + +func NewPacketConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter) *packetConnWithDownloadBandwidthLimiter { + return &packetConnWithDownloadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()} +} + +func (conn *packetConnWithDownloadBandwidthLimiter) WriteTo(p []byte, addr net.Addr) (n int, err error) { + err = conn.limiter.WaitN(conn.ctx, len(p)) + if err != nil { + return + } + return conn.PacketConn.WriteTo(p, addr) +} + +type packetConnWithUploadBandwidthLimiter struct { + net.PacketConn + ctx context.Context + limiter Limiter + burst int +} + +func NewPacketConnWithUploadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter) *packetConnWithUploadBandwidthLimiter { + return &packetConnWithUploadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()} +} + +func (conn *packetConnWithUploadBandwidthLimiter) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, addr, err = conn.PacketConn.ReadFrom(p) + if err != nil { + return + } + err = conn.limiter.WaitN(conn.ctx, n) + if err != nil { + return + } + return +} + +type packetConnWithCloseHandler struct { + net.PacketConn + onClose CloseHandlerFunc +} + +func NewPacketConnWithCloseHandler(conn net.PacketConn, onClose CloseHandlerFunc) *packetConnWithCloseHandler { + return &packetConnWithCloseHandler{conn, onClose} +} + +func (conn *packetConnWithCloseHandler) Close() error { + conn.onClose() + return conn.PacketConn.Close() +} diff --git a/protocol/limiter/bandwidth/limiter.go b/protocol/limiter/bandwidth/limiter.go index 404b9c15..95655b3c 100644 --- a/protocol/limiter/bandwidth/limiter.go +++ b/protocol/limiter/bandwidth/limiter.go @@ -2,157 +2,8 @@ package bandwidth import ( "context" - "net" - - "golang.org/x/time/rate" ) -type connWithDownloadBandwidthLimiter struct { - net.Conn - ctx context.Context - limiter *rate.Limiter - burst int -} - -func NewConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter *rate.Limiter) *connWithDownloadBandwidthLimiter { - return &connWithDownloadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()} -} - -func (conn *connWithDownloadBandwidthLimiter) Write(p []byte) (n int, err error) { - var nn int - for { - end := len(p) - if end == 0 { - break - } - if conn.burst < len(p) { - end = conn.burst - } - err = conn.limiter.WaitN(conn.ctx, end) - if err != nil { - return - } - nn, err = conn.Conn.Write(p[:end]) - n += nn - if err != nil { - return - } - p = p[end:] - } - return -} - -type connWithUploadBandwidthLimiter struct { - net.Conn - ctx context.Context - limiter *rate.Limiter - burst int -} - -func NewConnWithUploadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter *rate.Limiter) *connWithUploadBandwidthLimiter { - return &connWithUploadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()} -} - -func (conn *connWithUploadBandwidthLimiter) Read(p []byte) (n int, err error) { - if conn.burst < len(p) { - p = p[:conn.burst] - } - n, err = conn.Conn.Read(p) - if err != nil { - return - } - err = conn.limiter.WaitN(conn.ctx, n) - if err != nil { - return - } - return -} - -type connWithCloseHandler struct { - net.Conn - onClose CloseHandlerFunc -} - -func NewConnWithCloseHandler(conn net.Conn, onClose CloseHandlerFunc) *connWithCloseHandler { - return &connWithCloseHandler{conn, onClose} -} - -func (conn *connWithCloseHandler) Close() error { - conn.onClose() - return conn.Conn.Close() -} - -type packetConnWithDownloadBandwidthLimiter struct { - net.PacketConn - ctx context.Context - limiter *rate.Limiter - burst int -} - -func NewPacketConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter) *packetConnWithDownloadBandwidthLimiter { - return &packetConnWithDownloadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()} -} - -func (conn *packetConnWithDownloadBandwidthLimiter) WriteTo(p []byte, addr net.Addr) (n int, err error) { - var nn int - for { - end := len(p) - if end == 0 { - break - } - if conn.burst < len(p) { - end = conn.burst - } - err = conn.limiter.WaitN(conn.ctx, end) - if err != nil { - return - } - nn, err = conn.PacketConn.WriteTo(p[:end], addr) - n += nn - if err != nil { - return - } - p = p[end:] - } - return -} - -type packetConnWithUploadBandwidthLimiter struct { - net.PacketConn - ctx context.Context - limiter *rate.Limiter - burst int -} - -func NewPacketConnWithUploadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter *rate.Limiter) *packetConnWithUploadBandwidthLimiter { - return &packetConnWithUploadBandwidthLimiter{conn, ctx, limiter, limiter.Burst()} -} - -func (conn *packetConnWithUploadBandwidthLimiter) ReadFrom(p []byte) (n int, addr net.Addr, err error) { - if conn.burst < len(p) { - p = p[:conn.burst] - } - n, addr, err = conn.PacketConn.ReadFrom(p) - if err != nil { - return - } - err = conn.limiter.WaitN(conn.ctx, n) - if err != nil { - return - } - return -} - -type packetConnWithCloseHandler struct { - net.PacketConn - onClose CloseHandlerFunc -} - -func NewPacketConnWithCloseHandler(conn net.PacketConn, onClose CloseHandlerFunc) *packetConnWithCloseHandler { - return &packetConnWithCloseHandler{conn, onClose} -} - -func (conn *packetConnWithCloseHandler) Close() error { - conn.onClose() - return conn.PacketConn.Close() +type Limiter interface { + WaitN(ctx context.Context, n int) (err error) } diff --git a/protocol/limiter/bandwidth/strategy.go b/protocol/limiter/bandwidth/strategy.go index bc46ee21..92db1f64 100644 --- a/protocol/limiter/bandwidth/strategy.go +++ b/protocol/limiter/bandwidth/strategy.go @@ -226,7 +226,7 @@ func CreateStrategy(strategy string, mode string, connectionType string, speed u } func createSpeedLimiter(speed uint64) *rate.Limiter { - return rate.NewLimiter(rate.Limit(float64(speed)), 10000) + return rate.NewLimiter(rate.Limit(float64(speed)), 65536) } func connWithDownloadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter *rate.Limiter, reverse bool) net.Conn { diff --git a/protocol/masque/config.go b/protocol/masque/config.go new file mode 100644 index 00000000..11aa52da --- /dev/null +++ b/protocol/masque/config.go @@ -0,0 +1,89 @@ +package masque + +import ( + "crypto/ecdsa" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "fmt" + "net" +) + +type Config struct { + PrivateKey string `json:"private_key"` // Base64-encoded ECDSA private key + EndpointV4 string `json:"endpoint_v4"` // IPv4 address of the endpoint + EndpointV6 string `json:"endpoint_v6"` // IPv6 address of the endpoint + EndpointH2V4 string `json:"endpoint_h2_v4"` // IPv4 address used in HTTP/2 mode + EndpointH2V6 string `json:"endpoint_h2_v6"` // IPv6 address used in HTTP/2 mode + EndpointPubKey string `json:"endpoint_pub_key"` // PEM-encoded ECDSA public key of the endpoint to verify against + License string `json:"license"` // Application license key + ID string `json:"id"` // Device unique identifier + AccessToken string `json:"access_token"` // Authentication token for API access + IPv4 string `json:"ipv4"` // Assigned IPv4 address + IPv6 string `json:"ipv6"` // Assigned IPv6 address +} + +func (c *Config) GetEcPrivateKey() (*ecdsa.PrivateKey, error) { + privKeyB64, err := base64.StdEncoding.DecodeString(c.PrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to decode private key: %v", err) + } + privKey, err := x509.ParseECPrivateKey(privKeyB64) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %v", err) + } + return privKey, nil +} + +func (c *Config) GetEcEndpointPublicKey() (*ecdsa.PublicKey, error) { + endpointPubKeyB64, _ := pem.Decode([]byte(c.EndpointPubKey)) + if endpointPubKeyB64 == nil { + return nil, fmt.Errorf("failed to decode endpoint public key") + } + + pubKey, err := x509.ParsePKIXPublicKey(endpointPubKeyB64.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %v", err) + } + + ecPubKey, ok := pubKey.(*ecdsa.PublicKey) + if !ok { + return nil, fmt.Errorf("failed to assert public key as ECDSA") + } + + return ecPubKey, nil +} + +func (c *Config) SelectEndpointFromConfig(useHTTP2 bool, useIPv6 bool, port int) (net.Addr, error) { + if useHTTP2 { + if useIPv6 { + if c.EndpointH2V6 == "" { + return nil, fmt.Errorf("--http2 with --ipv6 requires config endpoint_h2_v6 to be set") + } + ip := net.ParseIP(c.EndpointH2V6) + if ip == nil { + return nil, fmt.Errorf("invalid endpoint_h2_v6 value %q", c.EndpointH2V6) + } + + return &net.TCPAddr{IP: ip, Port: port}, nil + } + v4 := c.EndpointH2V4 + ip := net.ParseIP(v4) + if ip == nil { + return nil, fmt.Errorf("invalid endpoint_h2_v4 value %q") + } + return &net.TCPAddr{IP: ip, Port: port}, nil + } + if useIPv6 { + ip := net.ParseIP(c.EndpointV6) + if ip == nil { + return nil, fmt.Errorf("invalid endpoint_v6 value %q", c.EndpointV6) + } + return &net.UDPAddr{IP: ip, Port: port}, nil + } + ip := net.ParseIP(c.EndpointV4) + if ip == nil { + return nil, fmt.Errorf("invalid endpoint_v4 value %q", c.EndpointV4) + } + return &net.UDPAddr{IP: ip, Port: port}, nil +} diff --git a/protocol/masque/outbound.go b/protocol/masque/outbound.go new file mode 100644 index 00000000..7d64f24d --- /dev/null +++ b/protocol/masque/outbound.go @@ -0,0 +1,300 @@ +package masque + +import ( + "context" + "encoding/base64" + "encoding/json" + "net" + "net/netip" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/outbound" + "github.com/sagernet/sing-box/common/cloudflare" + "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/common/tls" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/transport/masque" + "github.com/sagernet/sing/common/bufio" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +func RegisterOutbound(registry *outbound.Registry) { + outbound.Register[option.MASQUEOutboundOptions](registry, C.TypeMASQUE, NewOutbound) +} + +type Outbound struct { + outbound.Adapter + ctx context.Context + dnsRouter adapter.DNSRouter + logger logger.ContextLogger + options option.MASQUEOutboundOptions + tunnel *masque.Tunnel + startHandler func() + + await chan struct{} +} + +func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MASQUEOutboundOptions) (adapter.Outbound, error) { + outbound := &Outbound{ + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeMASQUE, tag, []string{N.NetworkTCP, N.NetworkUDP, N.NetworkICMP}, options.DialerOptions), + ctx: ctx, + dnsRouter: service.FromContext[adapter.DNSRouter](ctx), + logger: logger, + options: options, + await: make(chan struct{}), + } + outbound.startHandler = func() { + defer close(outbound.await) + cacheFile := service.FromContext[adapter.CacheFile](ctx) + var appConfig *Config + var err error + if !options.Profile.Recreate && cacheFile != nil && cacheFile.StoreMASQUEConfig() { + savedProfile := cacheFile.LoadMASQUEConfig(tag) + if savedProfile != nil { + if err = json.Unmarshal(savedProfile.Content, &appConfig); err != nil { + logger.ErrorContext(ctx, err) + return + } + } + } + if appConfig == nil { + appConfig, err = outbound.createConfig() + if err != nil { + logger.ErrorContext(ctx, err) + return + } + if cacheFile != nil && cacheFile.StoreMASQUEConfig() { + content, err := json.Marshal(appConfig) + if err != nil { + logger.ErrorContext(ctx, err) + return + } + cacheFile.SaveMASQUEConfig(tag, &adapter.SavedBinary{ + LastUpdated: time.Now(), + Content: content, + LastEtag: "", + }) + } + } + privKey, err := appConfig.GetEcPrivateKey() + if err != nil { + logger.ErrorContext(ctx, E.New("failed to get private key: ", err)) + return + } + peerPubKey, err := appConfig.GetEcEndpointPublicKey() + if err != nil { + logger.ErrorContext(ctx, E.New("failed to get public key: ", err)) + return + } + cert, err := masque.GenerateCert(privKey, &privKey.PublicKey) + if err != nil { + logger.ErrorContext(ctx, E.New("failed to generate cert: ", err)) + return + } + tlsConfig, err := tls.NewMASQUEClient(ctx, logger, "consumer-masque.cloudflareclient.com", cert, privKey, peerPubKey, options.MASQUEOutboundTLSOptions) + if err != nil { + logger.ErrorContext(ctx, E.New("failed to prepare TLS config: ", err)) + return + } + endpoint, err := appConfig.SelectEndpointFromConfig(options.UseHTTP2, options.UseIPv6, 443) + if err != nil { + logger.ErrorContext(ctx, E.New("failed to select endpoint: ", err)) + return + } + var udpTimeout time.Duration + if options.UDPTimeout != 0 { + udpTimeout = time.Duration(options.UDPTimeout) + } else { + udpTimeout = C.UDPTimeout + } + var udpKeepalivePeriod time.Duration + if options.UDPKeepalivePeriod != 0 { + udpKeepalivePeriod = time.Duration(options.UDPKeepalivePeriod) + } else { + udpKeepalivePeriod = time.Second * 30 + } + outboundDialer, err := dialer.NewWithOptions(dialer.Options{ + Context: ctx, + Options: options.DialerOptions, + RemoteIsDomain: false, + ResolverOnDetour: true, + }) + if err != nil { + logger.ErrorContext(ctx, err) + return + } + tunnel, err := masque.NewTunnel( + ctx, + logger, + masque.TunnelOptions{ + Dialer: outboundDialer, + Address: []netip.Prefix{ + netip.MustParsePrefix(appConfig.IPv4 + "/32"), + netip.MustParsePrefix(appConfig.IPv6 + "/128"), + }, + Endpoint: endpoint, + TLSConfig: tlsConfig, + UseHTTP2: options.UseHTTP2, + UDPTimeout: udpTimeout, + UDPKeepalivePeriod: udpKeepalivePeriod, + UDPInitialPacketSize: options.UDPInitialPacketSize, + ReconnectDelay: options.ReconnectDelay.Build(), + }) + if err != nil { + logger.ErrorContext(ctx, err) + return + } + outbound.tunnel = tunnel + if err = outbound.tunnel.Start(false); err != nil { + logger.ErrorContext(ctx, err) + return + } + if err = outbound.tunnel.Start(true); err != nil { + logger.ErrorContext(ctx, err) + return + } + } + return outbound, nil +} + +func (w *Outbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStatePostStart { + return nil + } + go w.startHandler() + return nil +} + +func (w *Outbound) Close() error { + if err := w.isTunnelInitialized(w.ctx); err != nil { + return err + } + return w.tunnel.Close() +} + +func (w *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if err := w.isTunnelInitialized(ctx); err != nil { + return nil, err + } + switch network { + case N.NetworkTCP: + w.logger.InfoContext(ctx, "outbound connection to ", destination) + case N.NetworkUDP: + w.logger.InfoContext(ctx, "outbound packet connection to ", destination) + } + if destination.IsDomain() { + destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{}) + if err != nil { + return nil, err + } + return N.DialSerial(ctx, w.tunnel, network, destination, destinationAddresses) + } else if !destination.Addr.IsValid() { + return nil, E.New("invalid destination: ", destination) + } + return w.tunnel.DialContext(ctx, network, destination) +} + +func (w *Outbound) ListenPacketWithDestination(ctx context.Context, destination M.Socksaddr) (net.PacketConn, netip.Addr, error) { + if err := w.isTunnelInitialized(ctx); err != nil { + return nil, netip.Addr{}, err + } + w.logger.InfoContext(ctx, "outbound packet connection to ", destination) + if destination.IsDomain() { + destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{}) + if err != nil { + return nil, netip.Addr{}, err + } + return N.ListenSerial(ctx, w.tunnel, destination, destinationAddresses) + } + packetConn, err := w.tunnel.ListenPacket(ctx, destination) + if err != nil { + return nil, netip.Addr{}, err + } + if destination.IsIP() { + return packetConn, destination.Addr, nil + } + return packetConn, netip.Addr{}, nil +} + +func (w *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + packetConn, destinationAddress, err := w.ListenPacketWithDestination(ctx, destination) + if err != nil { + return nil, err + } + if destinationAddress.IsValid() && destination != M.SocksaddrFrom(destinationAddress, destination.Port) { + return bufio.NewNATPacketConn(bufio.NewPacketConn(packetConn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil + } + return packetConn, nil +} + +func (w *Outbound) isTunnelInitialized(ctx context.Context) error { + select { + case <-w.await: + case <-ctx.Done(): + return ctx.Err() + } + if w.tunnel == nil { + return E.New("tunnel not initialized") + } + return nil +} + +func (w *Outbound) createConfig() (*Config, error) { + opts := make([]cloudflare.CloudflareApiOption, 0, 1) + if w.options.Profile.Detour != "" { + detour, ok := service.FromContext[adapter.OutboundManager](w.ctx).Outbound(w.options.Profile.Detour) + if !ok { + return nil, E.New("outbound detour not found: ", w.options.Profile.Detour) + } + opts = append(opts, cloudflare.WithDialContext(func(ctx context.Context, network, addr string) (net.Conn, error) { + return detour.DialContext(ctx, network, M.ParseSocksaddr(addr)) + })) + } + api := cloudflare.NewCloudflareApi(opts...) + var profile *cloudflare.CloudflareProfile + var err error + if w.options.Profile.AuthToken != "" && w.options.Profile.ID != "" { + profile, err = api.GetProfile(w.ctx, w.options.Profile.AuthToken, w.options.Profile.ID) + if err != nil { + return nil, err + } + } else { + wgPrivateKey, err := wgtypes.GeneratePrivateKey() + if err != nil { + return nil, err + } + profile, err = api.CreateProfile(w.ctx, wgPrivateKey.PublicKey().String()) + if err != nil { + return nil, err + } + } + privateKey, publicKey, err := masque.GenerateEcKeyPair() + if err != nil { + return nil, E.New("failed to generate key pair: ", err) + } + updatedProfile, err := api.EnrollKey(w.ctx, profile.Token, profile.ID, cloudflare.KeyTypeMasque, cloudflare.TunTypeMasque, base64.StdEncoding.EncodeToString(publicKey)) + if err != nil { + return nil, err + } + return &Config{ + PrivateKey: base64.StdEncoding.EncodeToString(privateKey), + EndpointV4: updatedProfile.Config.Peers[0].Endpoint.V4[:len(updatedProfile.Config.Peers[0].Endpoint.V4)-2], + EndpointV6: updatedProfile.Config.Peers[0].Endpoint.V6[1 : len(updatedProfile.Config.Peers[0].Endpoint.V6)-3], + EndpointH2V4: cloudflare.DefaultEndpointH2V4, + EndpointH2V6: cloudflare.DefaultEndpointH2V6, + EndpointPubKey: updatedProfile.Config.Peers[0].PublicKey, + License: updatedProfile.Account.License, + ID: updatedProfile.ID, + AccessToken: profile.Token, + IPv4: updatedProfile.Config.Interface.Addresses.V4, + IPv6: updatedProfile.Config.Interface.Addresses.V6, + }, nil +} diff --git a/protocol/mtproxy/dialer.go b/protocol/mtproxy/dialer.go new file mode 100644 index 00000000..aba39f63 --- /dev/null +++ b/protocol/mtproxy/dialer.go @@ -0,0 +1,38 @@ +package mtproxy + +import ( + "context" + "net" + + "github.com/sagernet/sing-box/adapter" + M "github.com/sagernet/sing/common/metadata" +) + +type Dialer struct { + handler adapter.ConnectionHandlerFuncEx +} + +func NewDialer(handler adapter.ConnectionHandlerFuncEx) *Dialer { + return &Dialer{handler} +} + +func (d *Dialer) Dial(network, address string) (net.Conn, error) { + return d.DialContext(context.Background(), network, address) +} + +func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + inConn, outConn := net.Pipe() + var metadata adapter.InboundContext + if streamContext, ok := ctx.(streamContext); ok { + metadata.Source = M.SocksaddrFromNet(streamContext.ClientAddr()) + metadata.User = streamContext.SecretName() + } + metadata.Destination = M.ParseSocksaddr(address) + d.handler(ctx, inConn, metadata, func(error) {}) + return outConn, nil +} + +type streamContext interface { + ClientAddr() net.Addr + SecretName() string +} diff --git a/protocol/mtproxy/inbound.go b/protocol/mtproxy/inbound.go new file mode 100644 index 00000000..48f829a5 --- /dev/null +++ b/protocol/mtproxy/inbound.go @@ -0,0 +1,132 @@ +package mtproxy + +import ( + "context" + "net" + + "github.com/dolonet/mtg-multi/antireplay" + "github.com/dolonet/mtg-multi/events" + "github.com/dolonet/mtg-multi/ipblocklist" + "github.com/dolonet/mtg-multi/mtglib" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + "github.com/sagernet/sing-box/common/listener" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + N "github.com/sagernet/sing/common/network" +) + +func RegisterInbound(registry *inbound.Registry) { + inbound.Register[option.MTProxyInboundOptions](registry, C.TypeMTProxy, NewInbound) +} + +type Inbound struct { + inbound.Adapter + ctx context.Context + router adapter.ConnectionRouterEx + logger logger.ContextLogger + listener *listener.Listener + proxy *mtglib.Proxy +} + +func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.MTProxyInboundOptions) (adapter.Inbound, error) { + inbound := &Inbound{ + Adapter: inbound.NewAdapter(C.TypeMTProxy, tag), + ctx: ctx, + router: router, + logger: logger, + listener: listener.New(listener.Options{ + Context: ctx, + Logger: logger, + Listen: options.ListenOptions, + }), + } + mtgLogger := NewLoggerAdapter(logger) + secrets := make(map[string]mtglib.Secret, len(options.Users)) + for _, user := range options.Users { + secret := mtglib.Secret{} + err := secret.Set(user.Secret) + if err != nil { + return nil, err + } + secrets[user.Name] = secret + } + opts := mtglib.ProxyOpts{ + Logger: mtgLogger, + Network: NewNetworkAdapter(ctx, NewDialer(inbound.newConnection)), + AntiReplayCache: antireplay.NewNoop(), + IPBlocklist: ipblocklist.NewNoop(), + IPAllowlist: ipblocklist.NewNoop(), + EventStream: events.NewNoopStream(), + + Secrets: secrets, + Concurrency: options.GetConcurrency(), + DomainFrontingPort: options.GetDomainFrontingPort(), + DomainFrontingIP: options.DomainFrontingIP, + DomainFrontingProxyProtocol: options.DomainFrontingProxyProtocol, + PreferIP: options.GetPreferIP(), + AutoUpdate: options.AutoUpdate, + + AllowFallbackOnUnknownDC: options.AllowFallbackOnUnknownDC, + TolerateTimeSkewness: options.TolerateTimeSkewness.Build(), + IdleTimeout: options.GetIdleTimeout(), + HandshakeTimeout: options.GetHandshakeTimeout(), + + DoppelGangerURLs: options.DoppelGangerURLs, + DoppelGangerPerRaid: options.GetDoppelGangerPerRaid(), + DoppelGangerEach: options.GetDoppelGangerEach(), + DoppelGangerDRS: options.DoppelGangerDRS, + + ThrottleMaxConnections: options.ThrottleMaxConnections, + ThrottleCheckInterval: options.GetThrottleCheckInterval(), + } + proxy, err := mtglib.NewProxy(opts) + if err != nil { + return nil, E.New("cannot create a proxy: ", err) + } + inbound.proxy = proxy + return inbound, nil +} + +func (n *Inbound) Start(stage adapter.StartStage) error { + if stage != adapter.StartStateStart { + return nil + } + listener, err := n.listener.ListenTCP() + if err != nil { + return err + } + go n.proxy.Serve(listener) + return nil +} + +func (n *Inbound) Close() error { + n.proxy.Shutdown() + return common.Close( + &n.listener, + ) +} + +func (h *Inbound) UpdateUsers(users []option.MTProxyUser) { + secrets := make(map[string]mtglib.Secret, len(users)) + for _, user := range users { + secret := mtglib.Secret{} + err := secret.Set(user.Secret) + if err != nil { + return + } + secrets[user.Name] = secret + } + h.proxy.UpdateUsers(secrets) +} + +func (h *Inbound) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + metadata.Inbound = h.Tag() + metadata.InboundType = h.Type() + h.logger.InfoContext(ctx, "[", metadata.User, "] inbound connection to ", metadata.Destination) + h.router.RouteConnectionEx(ctx, conn, metadata, onClose) +} diff --git a/protocol/mtproxy/logger.go b/protocol/mtproxy/logger.go new file mode 100644 index 00000000..351fc6a5 --- /dev/null +++ b/protocol/mtproxy/logger.go @@ -0,0 +1,60 @@ +package mtproxy + +import ( + "fmt" + + "github.com/dolonet/mtg-multi/mtglib" + "github.com/sagernet/sing/common/logger" +) + +type LoggerAdapter struct { + logger logger.Logger +} + +func NewLoggerAdapter(logger logger.Logger) *LoggerAdapter { + return &LoggerAdapter{logger} +} + +func (l *LoggerAdapter) Named(name string) mtglib.Logger { + return l +} + +func (l *LoggerAdapter) BindInt(name string, value int) mtglib.Logger { + return l +} + +func (l *LoggerAdapter) BindStr(name, value string) mtglib.Logger { + return l +} + +func (l *LoggerAdapter) BindJSON(name, value string) mtglib.Logger { + return l +} + +func (l *LoggerAdapter) Printf(format string, args ...any) { + l.logger.Info(fmt.Sprintf(format, args...)) +} + +func (l *LoggerAdapter) Info(msg string) { + l.logger.Info(msg) +} + +func (l *LoggerAdapter) InfoError(msg string, err error) { + l.logger.Error(msg, err) +} + +func (l *LoggerAdapter) Warning(msg string) { + l.logger.Warn(msg) +} + +func (l *LoggerAdapter) WarningError(msg string, err error) { + l.logger.Warn(msg, err) +} + +func (l *LoggerAdapter) Debug(msg string) { + l.logger.Debug(msg) +} + +func (l *LoggerAdapter) DebugError(msg string, err error) { + l.logger.Debug(msg, err) +} diff --git a/protocol/mtproxy/network.go b/protocol/mtproxy/network.go new file mode 100644 index 00000000..a14f8f42 --- /dev/null +++ b/protocol/mtproxy/network.go @@ -0,0 +1,43 @@ +package mtproxy + +import ( + "context" + "net" + "net/http" + + "github.com/dolonet/mtg-multi/essentials" +) + +type NetworkAdapter struct { + ctx context.Context + dialer essentials.Dialer +} + +func NewNetworkAdapter(ctx context.Context, dialer essentials.Dialer) *NetworkAdapter { + return &NetworkAdapter{ctx, dialer} +} + +func (a *NetworkAdapter) Dial(network, address string) (essentials.Conn, error) { + return a.DialContext(a.ctx, network, address) +} + +func (a *NetworkAdapter) DialContext(ctx context.Context, network, address string) (essentials.Conn, error) { + conn, err := a.dialer.DialContext(ctx, network, address) + if err != nil { + return nil, err + } + return essentials.WrapNetConn(conn), nil +} + +func (a *NetworkAdapter) MakeHTTPClient(func(ctx context.Context, network, address string) (essentials.Conn, error)) *http.Client { + return &http.Client{ + Timeout: 10, + Transport: &http.Transport{DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return a.DialContext(ctx, network, addr) + }}, + } +} + +func (a *NetworkAdapter) NativeDialer() essentials.Dialer { + return a.dialer +} diff --git a/protocol/parser/outbound.go b/protocol/parser/outbound.go new file mode 100644 index 00000000..2f16f277 --- /dev/null +++ b/protocol/parser/outbound.go @@ -0,0 +1,38 @@ +package parser + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/outbound" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/parser/link" + + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/service" +) + +func RegisterOutbound(registry *outbound.Registry) { + outbound.Register[option.ParserOutboundOptions](registry, C.TypeParser, NewOutbound) +} + +func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.ParserOutboundOptions) (adapter.Outbound, error) { + if options.Link == "" { + return nil, E.New("missing link") + } + outboundOptions, err := link.ParseSubscriptionLink(options.Link) + if err != nil { + return nil, err + } + if dialerOptions, ok := outboundOptions.Options.(option.DialerOptionsWrapper); ok { + dialerOptions.ReplaceDialerOptions(options.DialerOptions) + } + outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx) + outbound, err := outboundRegistry.UnsafeCreateOutbound(ctx, router, logger, tag, outboundOptions.Type, outboundOptions.Options) + if err != nil { + return nil, err + } + return outbound, nil +} diff --git a/protocol/relay/outbound.go b/protocol/relay/outbound.go new file mode 100644 index 00000000..e69de29b diff --git a/protocol/tailscale/dns_transport.go b/protocol/tailscale/dns_transport.go index 3a92a66b..4195235c 100644 --- a/protocol/tailscale/dns_transport.go +++ b/protocol/tailscale/dns_transport.go @@ -49,6 +49,7 @@ type DNSTransport struct { dnsRouter adapter.DNSRouter endpointManager adapter.EndpointManager endpoint *Endpoint + access sync.RWMutex routePrefixes []netip.Prefix routes map[string][]adapter.DNSTransport hosts map[string][]netip.Addr @@ -91,6 +92,12 @@ func (t *DNSTransport) Start(stage adapter.StartStage) error { } func (t *DNSTransport) Reset() { + t.access.RLock() + transports := t.collectResolversLocked() + t.access.RUnlock() + for _, transport := range transports { + transport.Reset() + } } func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *nDNS.Config) { @@ -101,7 +108,7 @@ func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, d } func (t *DNSTransport) updateDNSServers(routeConfig *router.Config, dnsConfig *nDNS.Config) error { - t.routePrefixes = buildRoutePrefixes(routeConfig) + routePrefixes := buildRoutePrefixes(routeConfig) directDialerOnce := sync.OnceValue(func() N.Dialer { directDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{})) return &DNSDialer{transport: t, fallbackDialer: directDialer} @@ -130,9 +137,19 @@ func (t *DNSTransport) updateDNSServers(routeConfig *router.Config, dnsConfig *n } defaultResolvers = append(defaultResolvers, myResolver) } + + t.access.Lock() + oldResolvers := t.collectResolversLocked() + t.routePrefixes = routePrefixes t.routes = routes t.hosts = hosts t.defaultResolvers = defaultResolvers + t.access.Unlock() + + for _, transport := range oldResolvers { + transport.Close() + } + if len(defaultResolvers) > 0 { t.logger.Info("updated ", len(routes), " routes, ", len(hosts), " hosts, default resolvers: ", strings.Join(common.Map(dnsConfig.DefaultResolvers, func(it *dnstype.Resolver) string { return it.Addr }), " ")) @@ -207,7 +224,22 @@ func buildRoutePrefixes(routeConfig *router.Config) []netip.Prefix { } func (t *DNSTransport) Close() error { - return nil + t.access.Lock() + transports := t.collectResolversLocked() + t.routePrefixes = nil + t.routes = nil + t.hosts = nil + t.defaultResolvers = nil + t.access.Unlock() + + var err error + for _, transport := range transports { + name := "resolver/" + transport.Type() + "[" + transport.Tag() + "]" + err = E.Append(err, transport.Close(), func(err error) error { + return E.Cause(err, "close ", name) + }) + } + return err } func (t *DNSTransport) Raw() bool { @@ -219,7 +251,15 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M return nil, os.ErrInvalid } question := message.Question[0] - addresses, hostsLoaded := t.hosts[question.Name] + + t.access.RLock() + hosts := t.hosts + routes := t.routes + defaultResolvers := t.defaultResolvers + acceptDefaultResolvers := t.acceptDefaultResolvers + t.access.RUnlock() + + addresses, hostsLoaded := hosts[question.Name] if hostsLoaded { switch question.Qtype { case mDNS.TypeA: @@ -238,7 +278,7 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M } } } - for domainSuffix, transports := range t.routes { + for domainSuffix, transports := range routes { if strings.HasSuffix(question.Name, domainSuffix) { if len(transports) == 0 { return &mDNS.Msg{ @@ -262,10 +302,10 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M return nil, lastErr } } - if t.acceptDefaultResolvers { - if len(t.defaultResolvers) > 0 { + if acceptDefaultResolvers { + if len(defaultResolvers) > 0 { var lastErr error - for _, resolver := range t.defaultResolvers { + for _, resolver := range defaultResolvers { response, err := resolver.Exchange(ctx, message) if err != nil { lastErr = err @@ -281,6 +321,15 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M return nil, dns.RcodeNameError } +func (t *DNSTransport) collectResolversLocked() []adapter.DNSTransport { + var transports []adapter.DNSTransport + for _, resolvers := range t.routes { + transports = append(transports, resolvers...) + } + transports = append(transports, t.defaultResolvers...) + return transports +} + type DNSDialer struct { transport *DNSTransport fallbackDialer N.Dialer @@ -290,7 +339,8 @@ func (d *DNSDialer) DialContext(ctx context.Context, network string, destination if destination.IsDomain() { panic("invalid request here") } - for _, prefix := range d.transport.routePrefixes { + routePrefixes := d.transport.routePrefixesSnapshot() + for _, prefix := range routePrefixes { if prefix.Contains(destination.Addr) { return d.transport.endpoint.DialContext(ctx, network, destination) } @@ -302,10 +352,17 @@ func (d *DNSDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) ( if destination.IsDomain() { panic("invalid request here") } - for _, prefix := range d.transport.routePrefixes { + routePrefixes := d.transport.routePrefixesSnapshot() + for _, prefix := range routePrefixes { if prefix.Contains(destination.Addr) { return d.transport.endpoint.ListenPacket(ctx, destination) } } return d.fallbackDialer.ListenPacket(ctx, destination) } + +func (t *DNSTransport) routePrefixesSnapshot() []netip.Prefix { + t.access.RLock() + defer t.access.RUnlock() + return append([]netip.Prefix(nil), t.routePrefixes...) +} diff --git a/protocol/tunnel/protocol.go b/protocol/tunnel/protocol.go deleted file mode 100644 index 19e6f1cd..00000000 --- a/protocol/tunnel/protocol.go +++ /dev/null @@ -1,91 +0,0 @@ -package tunnel - -import ( - "encoding/binary" - "io" - - "github.com/gofrs/uuid/v5" - "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/buf" - E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" -) - -const ( - Version = 0 -) - -const ( - CommandInbound = 1 - CommandTCP = 2 -) - -var Destination = M.Socksaddr{ - Fqdn: "sp.tunnel.sing-box.arpa", - Port: 444, -} - -var AddressSerializer = M.NewSerializer( - M.AddressFamilyByte(0x01, M.AddressFamilyIPv4), - M.AddressFamilyByte(0x03, M.AddressFamilyIPv6), - M.AddressFamilyByte(0x02, M.AddressFamilyFqdn), - M.PortThenAddress(), -) - -type Request struct { - UUID uuid.UUID - Command byte - DestinationUUID uuid.UUID - Destination M.Socksaddr -} - -func ReadRequest(reader io.Reader) (*Request, error) { - var request Request - var version uint8 - err := binary.Read(reader, binary.BigEndian, &version) - if err != nil { - return nil, err - } - if version != Version { - return nil, E.New("unknown version: ", version) - } - _, err = io.ReadFull(reader, request.UUID[:]) - if err != nil { - return nil, err - } - err = binary.Read(reader, binary.BigEndian, &request.Command) - if err != nil { - return nil, err - } - _, err = io.ReadFull(reader, request.DestinationUUID[:]) - if err != nil { - return nil, err - } - request.Destination, err = AddressSerializer.ReadAddrPort(reader) - if err != nil { - return nil, err - } - return &request, nil -} - -func WriteRequest(writer io.Writer, request *Request) error { - var requestLen int - requestLen += 1 // version - requestLen += 16 // UUID - requestLen += 16 // destinationUUID - requestLen += 1 // command - requestLen += AddressSerializer.AddrPortLen(request.Destination) - buffer := buf.NewSize(requestLen) - defer buffer.Release() - common.Must( - buffer.WriteByte(Version), - common.Error(buffer.Write(request.UUID[:])), - buffer.WriteByte(request.Command), - common.Error(buffer.Write(request.DestinationUUID[:])), - ) - err := AddressSerializer.WriteAddrPort(buffer, request.Destination) - if err != nil { - return err - } - return common.Error(writer.Write(buffer.Bytes())) -} diff --git a/protocol/tunnel/server.go b/protocol/tunnel/server.go deleted file mode 100644 index a43254b7..00000000 --- a/protocol/tunnel/server.go +++ /dev/null @@ -1,201 +0,0 @@ -package tunnel - -import ( - "context" - "net" - "time" - - "github.com/gofrs/uuid/v5" - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/adapter/endpoint" - "github.com/sagernet/sing-box/adapter/outbound" - sbUot "github.com/sagernet/sing-box/common/uot" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/log" - "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing/common" - E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" - M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/uot" - "github.com/sagernet/sing/service" -) - -func RegisterServerEndpoint(registry *endpoint.Registry) { - endpoint.Register[option.TunnelServerEndpointOptions](registry, C.TypeTunnelServer, NewServerEndpoint) -} - -type ServerEndpoint struct { - outbound.Adapter - logger logger.ContextLogger - inbound adapter.Inbound - router adapter.ConnectionRouterEx - uuid uuid.UUID - users map[uuid.UUID]uuid.UUID - keys map[uuid.UUID]uuid.UUID - conns map[uuid.UUID]chan net.Conn - timeout time.Duration - uotClient *uot.Client -} - -func NewServerEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TunnelServerEndpointOptions) (adapter.Endpoint, error) { - serverUUID, err := uuid.FromString(options.UUID) - if err != nil { - return nil, err - } - server := &ServerEndpoint{ - Adapter: outbound.NewAdapter(C.TypeTunnelServer, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}), - logger: logger, - router: sbUot.NewRouter(router, logger), - uuid: serverUUID, - } - inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) - inbound, err := inboundRegistry.Create(ctx, NewRouter(router, logger, server.connHandler), logger, options.Inbound.Tag, options.Inbound.Type, options.Inbound.Options) - if err != nil { - return nil, err - } - server.inbound = inbound - server.users = make(map[uuid.UUID]uuid.UUID, len(options.Users)) - server.keys = make(map[uuid.UUID]uuid.UUID, len(options.Users)) - server.conns = make(map[uuid.UUID]chan net.Conn) - for _, user := range options.Users { - key, err := uuid.FromString(user.Key) - if err != nil { - return nil, err - } - uuid, err := uuid.FromString(user.UUID) - if err != nil { - return nil, err - } - server.users[key] = uuid - server.keys[uuid] = key - server.conns[uuid] = make(chan net.Conn, 10) - } - if options.ConnectTimeout != 0 { - server.timeout = time.Duration(options.ConnectTimeout) - } else { - server.timeout = C.TCPConnectTimeout - } - server.uotClient = &uot.Client{ - Dialer: server, - Version: uot.Version, - } - return server, nil -} - -func (s *ServerEndpoint) Start(stage adapter.StartStage) error { - return s.inbound.Start(stage) -} - -func (s *ServerEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - if N.NetworkName(network) == N.NetworkUDP { - return s.uotClient.DialContext(ctx, network, destination) - } - var sourceUUID *uuid.UUID - var ch chan net.Conn - if metadata := adapter.ContextFrom(ctx); metadata != nil { - if metadata.TunnelDestination != "" { - tunnelDestination, err := uuid.FromString(metadata.TunnelDestination) - if err != nil { - return nil, err - } - var ok bool - ch, ok = s.conns[tunnelDestination] - if !ok { - return nil, E.New("user ", metadata.TunnelDestination, " not found") - } - } - if metadata.TunnelSource != "" { - tunnelSource, err := uuid.FromString(metadata.TunnelSource) - if err != nil { - return nil, err - } - sourceUUID = &tunnelSource - } - } - if ch == nil { - return nil, E.New("tunnel destination not set") - } - if sourceUUID == nil { - sourceUUID = &s.uuid - } - ctx, cancel := context.WithTimeout(ctx, s.timeout) - defer cancel() - for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - } - select { - case conn := <-ch: - err := WriteRequest(conn, &Request{UUID: *sourceUUID, Command: CommandTCP, Destination: destination}) - if err != nil { - conn.Close() - s.logger.ErrorContext(ctx, err) - continue - } - return conn, nil - case <-ctx.Done(): - return nil, ctx.Err() - } - } -} - -func (s *ServerEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - return s.uotClient.ListenPacket(ctx, destination) -} - -func (s *ServerEndpoint) Close() error { - return common.Close(s.inbound) -} - -func (s *ServerEndpoint) connHandler(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { - if metadata.Destination != Destination { - s.router.RouteConnectionEx(ctx, conn, metadata, onClose) - return nil - } - request, err := ReadRequest(conn) - if err != nil { - return err - } - if request.Command == CommandInbound { - uuid, ok := s.users[request.UUID] - if !ok { - return E.New("key ", request.UUID.String(), " not found") - } - ch := s.conns[uuid] - select { - case ch <- conn: - default: - oldConn := <-ch - oldConn.Close() - ch <- conn - } - return nil - } - if request.Command == CommandTCP { - sourceUUID, ok := s.users[request.UUID] - if !ok { - return E.New("key ", request.UUID, " not found") - } - if sourceUUID == request.DestinationUUID { - return E.New("routing loop on ", sourceUUID) - } - if request.DestinationUUID != s.uuid { - _, ok = s.keys[request.DestinationUUID] - if !ok { - return E.New("user ", request.DestinationUUID, " not found") - } - } - metadata.Inbound = s.Tag() - metadata.InboundType = C.TypeTunnelServer - metadata.Destination = request.Destination - metadata.TunnelSource = sourceUUID.String() - metadata.TunnelDestination = request.DestinationUUID.String() - s.router.RouteConnectionEx(ctx, conn, metadata, onClose) - return nil - } - return E.New("command ", request.Command, " not found") -} diff --git a/protocol/tunnel/client.go b/protocol/vpn/client.go similarity index 65% rename from protocol/tunnel/client.go rename to protocol/vpn/client.go index d00cdcbf..15007081 100644 --- a/protocol/tunnel/client.go +++ b/protocol/vpn/client.go @@ -1,8 +1,9 @@ -package tunnel +package vpn import ( "context" "net" + "net/netip" "time" "github.com/gofrs/uuid/v5" @@ -23,7 +24,7 @@ import ( ) func RegisterClientEndpoint(registry *endpoint.Registry) { - endpoint.Register[option.TunnelClientEndpointOptions](registry, C.TypeTunnelClient, NewClientEndpoint) + endpoint.Register[option.VPNClientEndpointOptions](registry, C.TypeVPNClient, NewClientEndpoint) } type ClientEndpoint struct { @@ -32,27 +33,27 @@ type ClientEndpoint struct { outbound adapter.Outbound router adapter.ConnectionRouterEx logger logger.ContextLogger - uuid uuid.UUID + address IPv4 key uuid.UUID uotClient *uot.Client } -func NewClientEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TunnelClientEndpointOptions) (adapter.Endpoint, error) { - clientUUID, err := uuid.FromString(options.UUID) - if err != nil { - return nil, err +func NewClientEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VPNClientEndpointOptions) (adapter.Endpoint, error) { + address := options.Address + if !address.Is4() { + return nil, E.New("invalid address: ", address) } - clientKey, err := uuid.FromString(options.Key) + key, err := uuid.FromString(options.Key) if err != nil { return nil, err } client := &ClientEndpoint{ - Adapter: outbound.NewAdapter(C.TypeTunnelClient, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}), + Adapter: outbound.NewAdapter(C.TypeVPNClient, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}), ctx: ctx, router: sbUot.NewRouter(router, logger), logger: logger, - uuid: clientUUID, - key: clientKey, + address: address.As4(), + key: key, } outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx) outbound, err := outboundRegistry.CreateOutbound(ctx, router, logger, options.Outbound.Tag, options.Outbound.Type, options.Outbound.Options) @@ -94,27 +95,31 @@ func (c *ClientEndpoint) DialContext(ctx context.Context, network string, destin if N.NetworkName(network) == N.NetworkUDP { return c.uotClient.DialContext(ctx, network, destination) } - var destinationUUID *uuid.UUID - if metadata := adapter.ContextFrom(ctx); metadata != nil { - if metadata.TunnelDestination != "" { - uuid, err := uuid.FromString(metadata.TunnelDestination) - if err != nil { - return nil, err - } - destinationUUID = &uuid - } - } - if destinationUUID == nil { - return nil, E.New("tunnel destination not set") - } - if *destinationUUID == c.uuid { - return nil, E.New("routing loop") + if destination.Addr.Is4() && destination.Addr.As4() == c.address { + return nil, E.New("routing loop on ", destination.Addr) } conn, err := c.outbound.DialContext(ctx, N.NetworkTCP, Destination) if err != nil { return nil, err } - err = WriteRequest(conn, &Request{UUID: c.key, Command: CommandTCP, DestinationUUID: *destinationUUID, Destination: destination}) + gateway := Loopback.As4() + if metadata := adapter.ContextFrom(ctx); metadata != nil { + if metadata.Gateway != nil { + gateway = metadata.Gateway.As4() + if gateway == c.address { + return nil, E.New("routing loop on ", destination.Addr) + } + } + } + err = WriteClientRequest( + conn, + &ClientRequest{ + Key: c.key, + Command: CommandTCP, + Gateway: gateway, + Destination: destination, + }, + ) if err != nil { return nil, err } @@ -134,29 +139,22 @@ func (c *ClientEndpoint) startInboundConn() error { if err != nil { return err } - err = WriteRequest(conn, &Request{UUID: c.key, Command: CommandInbound, Destination: Destination}) + err = WriteClientRequest(conn, &ClientRequest{Key: c.key, Command: CommandInbound, Destination: Destination}) if err != nil { return err } - request, err := ReadRequest(conn) + request, err := ReadServerRequest(conn) if err != nil { return err } - go c.connHandler(conn, request) - return nil -} - -func (c *ClientEndpoint) connHandler(conn net.Conn, request *Request) { + if request.Source == c.address { + return E.New("routing loop") + } metadata := adapter.InboundContext{ Inbound: c.Tag(), - Source: M.ParseSocksaddr(conn.RemoteAddr().String()), + Source: M.Socksaddr{Addr: netip.AddrFrom4(request.Source)}, Destination: request.Destination, } - if request.UUID == c.uuid { - c.logger.ErrorContext(c.ctx, "routing loop") - conn.Close() - return - } - metadata.TunnelSource = request.UUID.String() - c.router.RouteConnectionEx(c.ctx, conn, metadata, func(it error) {}) + go c.router.RouteConnectionEx(c.ctx, conn, metadata, func(it error) {}) + return nil } diff --git a/protocol/vpn/protocol.go b/protocol/vpn/protocol.go new file mode 100644 index 00000000..f6b3e210 --- /dev/null +++ b/protocol/vpn/protocol.go @@ -0,0 +1,124 @@ +package vpn + +import ( + "encoding/binary" + "io" + "net/netip" + + "github.com/gofrs/uuid/v5" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" +) + +const ( + Version = 0 +) + +const ( + CommandInbound = 1 + CommandTCP = 2 +) + +type IPv4 [4]byte + +var Destination = M.Socksaddr{ + Fqdn: "sp.vpn.sing-box.arpa", + Port: 444, +} + +var Loopback = netip.AddrFrom4([4]byte{127, 0, 0, 1}) + +var AddressSerializer = M.NewSerializer( + M.AddressFamilyByte(0x01, M.AddressFamilyIPv4), + M.AddressFamilyByte(0x03, M.AddressFamilyIPv6), + M.AddressFamilyByte(0x02, M.AddressFamilyFqdn), + M.PortThenAddress(), +) + +type ClientRequest struct { + Key uuid.UUID + Command byte + Gateway IPv4 + Destination M.Socksaddr +} + +func ReadClientRequest(reader io.Reader) (*ClientRequest, error) { + var request ClientRequest + var version uint8 + err := binary.Read(reader, binary.BigEndian, &version) + if err != nil { + return nil, err + } + if version != Version { + return nil, E.New("unknown version: ", version) + } + _, err = io.ReadFull(reader, request.Key[:]) + if err != nil { + return nil, err + } + err = binary.Read(reader, binary.BigEndian, &request.Command) + if err != nil { + return nil, err + } + _, err = io.ReadFull(reader, request.Gateway[:]) + if err != nil { + return nil, err + } + request.Destination, err = AddressSerializer.ReadAddrPort(reader) + if err != nil { + return nil, err + } + return &request, nil +} + +func WriteClientRequest(writer io.Writer, request *ClientRequest) error { + var requestLen int + requestLen += 1 // version + requestLen += 16 // key + requestLen += 1 // command + requestLen += 4 // gateway + requestLen += AddressSerializer.AddrPortLen(request.Destination) + buffer := buf.NewSize(requestLen) + defer buffer.Release() + common.Must( + buffer.WriteByte(Version), + common.Error(buffer.Write(request.Key[:])), + buffer.WriteByte(request.Command), + common.Error(buffer.Write(request.Gateway[:])), + AddressSerializer.WriteAddrPort(buffer, request.Destination), + ) + return common.Error(writer.Write(buffer.Bytes())) +} + +type ServerRequest struct { + Source IPv4 + Destination M.Socksaddr +} + +func ReadServerRequest(reader io.Reader) (*ServerRequest, error) { + var request ServerRequest + _, err := io.ReadFull(reader, request.Source[:]) + if err != nil { + return nil, err + } + request.Destination, err = AddressSerializer.ReadAddrPort(reader) + if err != nil { + return nil, err + } + return &request, nil +} + +func WriteServerRequest(writer io.Writer, request *ServerRequest) error { + var requestLen int + requestLen += 4 // source + requestLen += AddressSerializer.AddrPortLen(request.Destination) + buffer := buf.NewSize(requestLen) + defer buffer.Release() + common.Must( + common.Error(buffer.Write(request.Source[:])), + AddressSerializer.WriteAddrPort(buffer, request.Destination), + ) + return common.Error(writer.Write(buffer.Bytes())) +} diff --git a/protocol/tunnel/router.go b/protocol/vpn/router.go similarity index 75% rename from protocol/tunnel/router.go rename to protocol/vpn/router.go index e2e708b5..5e0d2a87 100644 --- a/protocol/tunnel/router.go +++ b/protocol/vpn/router.go @@ -1,4 +1,4 @@ -package tunnel +package vpn import ( "context" @@ -21,14 +21,24 @@ func NewRouter(router adapter.Router, logger logger.ContextLogger, handler func( } func (r *Router) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + if metadata.Destination != Destination { + return r.Router.RouteConnection(ctx, conn, metadata) + } return r.handler(ctx, conn, metadata, func(error) {}) } func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + if metadata.Destination != Destination { + return r.Router.RoutePacketConnection(ctx, conn, metadata) + } return os.ErrInvalid } func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + if metadata.Destination != Destination { + r.Router.RouteConnectionEx(ctx, conn, metadata, onClose) + return + } if err := r.handler(ctx, conn, metadata, onClose); err != nil { r.logger.ErrorContext(ctx, err) N.CloseOnHandshakeFailure(conn, onClose, err) @@ -36,6 +46,10 @@ func (r *Router) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata } func (r *Router) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { + if metadata.Destination != Destination { + r.Router.RoutePacketConnectionEx(ctx, conn, metadata, onClose) + return + } r.logger.ErrorContext(ctx, os.ErrInvalid) N.CloseOnHandshakeFailure(conn, onClose, os.ErrInvalid) } diff --git a/protocol/vpn/server.go b/protocol/vpn/server.go new file mode 100644 index 00000000..872dd83e --- /dev/null +++ b/protocol/vpn/server.go @@ -0,0 +1,235 @@ +package vpn + +import ( + "context" + "errors" + "net" + "net/netip" + "time" + + "github.com/gofrs/uuid/v5" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/endpoint" + "github.com/sagernet/sing-box/adapter/outbound" + sbUot "github.com/sagernet/sing-box/common/uot" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/uot" + "github.com/sagernet/sing/service" +) + +func RegisterServerEndpoint(registry *endpoint.Registry) { + endpoint.Register[option.VPNServerEndpointOptions](registry, C.TypeVPNServer, NewServerEndpoint) +} + +type ServerEndpoint struct { + outbound.Adapter + logger logger.ContextLogger + inbounds []adapter.Inbound + router adapter.ConnectionRouterEx + address IPv4 + addresses map[uuid.UUID]IPv4 + keys map[IPv4]uuid.UUID + conns map[IPv4]chan net.Conn + timeout time.Duration + uotClient *uot.Client +} + +func NewServerEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VPNServerEndpointOptions) (adapter.Endpoint, error) { + address := options.Address + if !address.Is4() { + return nil, E.New("invalid address: ", address) + } + server := &ServerEndpoint{ + Adapter: outbound.NewAdapter(C.TypeVPNServer, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}), + logger: logger, + router: sbUot.NewRouter(router, logger), + address: address.As4(), + } + router = NewRouter(router, logger, server.connHandler) + inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) + inbounds := make([]adapter.Inbound, len(options.Inbounds)) + for i, inboundOptions := range options.Inbounds { + inbound, err := inboundRegistry.Create(ctx, router, logger, inboundOptions.Tag, inboundOptions.Type, inboundOptions.Options) + if err != nil { + return nil, err + } + inbounds[i] = inbound + } + server.inbounds = inbounds + server.addresses = make(map[uuid.UUID]IPv4, len(options.Users)) + server.keys = make(map[IPv4]uuid.UUID, len(options.Users)) + server.conns = make(map[IPv4]chan net.Conn) + for _, user := range options.Users { + key, err := uuid.FromString(user.Key) + if err != nil { + return nil, err + } + if !user.Address.Is4() { + return nil, E.New("invalid address: ", user.Address) + } + address := user.Address.As4() + server.addresses[key] = address + server.keys[address] = key + server.conns[address] = make(chan net.Conn, 10) + } + if options.ConnectTimeout != 0 { + server.timeout = time.Duration(options.ConnectTimeout) + } else { + server.timeout = C.TCPConnectTimeout + } + server.uotClient = &uot.Client{ + Dialer: server, + Version: uot.Version, + } + return server, nil +} + +func (s *ServerEndpoint) Start(stage adapter.StartStage) error { + for _, inbound := range s.inbounds { + err := inbound.Start(stage) + if err != nil { + return err + } + } + return nil +} + +func (s *ServerEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if N.NetworkName(network) == N.NetworkUDP { + return s.uotClient.DialContext(ctx, network, destination) + } + source := s.address + var gateway *netip.Addr + if metadata := adapter.ContextFrom(ctx); metadata != nil { + if metadata.Source.IsIPv4() { + address := metadata.Source.Addr.As4() + if _, ok := s.conns[address]; ok { + source = address + } + } + if metadata.Gateway != nil { + gateway = metadata.Gateway + } + } + if gateway == nil { + if destination.IsIPv4() { + gateway = &destination.Addr + destination = M.Socksaddr{ + Addr: Loopback, + Port: destination.Port, + } + } else { + return nil, E.New("missing gateway") + } + } else if destination.Addr.Compare(*gateway) == 0 { + destination = M.Socksaddr{ + Addr: Loopback, + Port: destination.Port, + } + } + if gateway.Compare(Loopback) == 0 { + return nil, E.New("invalid gateway") + } + ch, ok := s.conns[gateway.As4()] + if !ok { + return nil, E.New("user with address ", gateway, " not found") + } + ctx, cancel := context.WithTimeout(ctx, s.timeout) + defer cancel() + for { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + select { + case conn := <-ch: + err := WriteServerRequest(conn, &ServerRequest{Source: source, Destination: destination}) + if err != nil { + conn.Close() + s.logger.ErrorContext(ctx, err) + continue + } + return conn, nil + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +func (s *ServerEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + return s.uotClient.ListenPacket(ctx, destination) +} + +func (s *ServerEndpoint) Close() error { + errs := make([]error, 0) + for _, inbound := range s.inbounds { + err := inbound.Close() + if err != nil { + errs = append(errs, err) + } + } + if len(errs) != 0 { + return errors.Join(errs...) + } + return nil +} + +func (s *ServerEndpoint) connHandler(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) error { + if metadata.Destination != Destination { + s.router.RouteConnectionEx(ctx, conn, metadata, onClose) + return nil + } + request, err := ReadClientRequest(conn) + if err != nil { + return err + } + if request.Command == CommandInbound { + address, ok := s.addresses[request.Key] + if !ok { + return E.New("key ", request.Key.String(), " not found") + } + ch := s.conns[address] + select { + case ch <- conn: + default: + oldConn := <-ch + oldConn.Close() + ch <- conn + } + return nil + } + if request.Command == CommandTCP { + source, ok := s.addresses[request.Key] + if !ok { + return E.New("key ", request.Key, " not found") + } + if request.Destination.Addr.Is4() && source == request.Destination.Addr.As4() { + return E.New("routing loop on ", request.Destination) + } + metadata.Inbound = s.Tag() + metadata.InboundType = C.TypeVPNServer + metadata.Source = M.Socksaddr{Addr: netip.AddrFrom4(source)} + if request.Destination.Addr.Is4() && request.Destination.Addr.As4() == s.address { + metadata.Destination = M.Socksaddr{ + Addr: Loopback, + Port: request.Destination.Port, + } + } else { + metadata.Destination = request.Destination + if request.Gateway != s.address && request.Gateway != Loopback.As4() { + addr := netip.AddrFrom4(request.Gateway) + metadata.Gateway = &addr + } + } + s.router.RouteConnectionEx(ctx, conn, metadata, onClose) + return nil + } + return E.New("command ", request.Command, " not found") +} diff --git a/protocol/warp/config.go b/protocol/warp/config.go new file mode 100644 index 00000000..eadd9824 --- /dev/null +++ b/protocol/warp/config.go @@ -0,0 +1,14 @@ +package warp + +import "github.com/sagernet/sing-box/common/cloudflare" + +type Config struct { + PrivateKey string `json:"private_key"` + Interface struct { + Addresses struct { + V4 string `json:"v4"` + V6 string `json:"v6"` + } `json:"addresses"` + } `json:"interface"` + Peers []cloudflare.Peer +} diff --git a/protocol/wireguard/endpoint_warp.go b/protocol/warp/endpoint.go similarity index 53% rename from protocol/wireguard/endpoint_warp.go rename to protocol/warp/endpoint.go index d44e3f34..8b417813 100644 --- a/protocol/wireguard/endpoint_warp.go +++ b/protocol/warp/endpoint.go @@ -1,4 +1,4 @@ -package wireguard +package warp import ( "context" @@ -7,7 +7,6 @@ import ( "net" "net/netip" "strings" - "sync" "time" "github.com/sagernet/sing-box/adapter" @@ -16,6 +15,7 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/protocol/wireguard" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json/badoption" @@ -25,19 +25,21 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -func RegisterWARPEndpoint(registry *endpoint.Registry) { - endpoint.Register[option.WireGuardWARPEndpointOptions](registry, C.TypeWARP, NewWARPEndpoint) +func RegisterEndpoint(registry *endpoint.Registry) { + endpoint.Register[option.WARPEndpointOptions](registry, C.TypeWARP, NewEndpoint) } -type WARPEndpoint struct { +type Endpoint struct { endpoint.Adapter + ctx context.Context + options option.WARPEndpointOptions endpoint adapter.Endpoint startHandler func() - mtx sync.Mutex + await chan struct{} } -func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardWARPEndpointOptions) (adapter.Endpoint, error) { +func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WARPEndpointOptions) (adapter.Endpoint, error) { var dependencies []string if options.Detour != "" { dependencies = append(dependencies, options.Detour) @@ -45,14 +47,16 @@ func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.Cont if options.Profile.Detour != "" { dependencies = append(dependencies, options.Profile.Detour) } - warpEndpoint := &WARPEndpoint{ + endpoint := &Endpoint{ Adapter: endpoint.NewAdapter(C.TypeWARP, tag, []string{N.NetworkTCP, N.NetworkUDP}, dependencies), + ctx: ctx, + options: options, + await: make(chan struct{}), } - warpEndpoint.mtx.Lock() - warpEndpoint.startHandler = func() { - defer warpEndpoint.mtx.Unlock() + endpoint.startHandler = func() { + defer close(endpoint.await) cacheFile := service.FromContext[adapter.CacheFile](ctx) - var config *C.WARPConfig + var config *Config var err error if !options.Profile.Recreate && cacheFile != nil && cacheFile.StoreWARPConfig() { savedProfile := cacheFile.LoadWARPConfig(tag) @@ -64,50 +68,10 @@ func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.Cont } } if config == nil { - var privateKey wgtypes.Key - if options.Profile.PrivateKey != "" { - privateKey, err = wgtypes.ParseKey(options.Profile.PrivateKey) - if err != nil { - logger.ErrorContext(ctx, err) - return - } - } else { - privateKey, err = wgtypes.GeneratePrivateKey() - if err != nil { - logger.ErrorContext(ctx, err) - return - } - } - opts := make([]cloudflare.CloudflareApiOption, 0, 1) - if options.Profile.Detour != "" { - detour, ok := service.FromContext[adapter.OutboundManager](ctx).Outbound(options.Profile.Detour) - if !ok { - logger.ErrorContext(ctx, E.New("outbound detour not found: ", options.Profile.Detour)) - return - } - opts = append(opts, cloudflare.WithDialContext(func(ctx context.Context, network, addr string) (net.Conn, error) { - return detour.DialContext(ctx, network, M.ParseSocksaddr(addr)) - })) - } - api := cloudflare.NewCloudflareApi(opts...) - var profile *cloudflare.CloudflareProfile - if options.Profile.AuthToken != "" && options.Profile.ID != "" { - profile, err = api.GetProfile(ctx, options.Profile.AuthToken, options.Profile.ID) - if err != nil { - logger.ErrorContext(ctx, err) - return - } - } else { - profile, err = api.CreateProfile(ctx, privateKey.PublicKey().String()) - if err != nil { - logger.ErrorContext(ctx, err) - return - } - } - config = &C.WARPConfig{ - PrivateKey: privateKey.String(), - Interface: profile.Config.Interface, - Peers: profile.Config.Peers, + config, err := endpoint.createConfig() + if err != nil { + logger.ErrorContext(ctx, err) + return } if cacheFile != nil && cacheFile.StoreWARPConfig() { content, err := json.Marshal(config) @@ -124,7 +88,7 @@ func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.Cont } peer := config.Peers[0] hostParts := strings.Split(peer.Endpoint.Host, ":") - warpEndpoint.endpoint, err = NewEndpoint( + endpoint.endpoint, err = wireguard.NewEndpoint( ctx, router, logger, @@ -165,19 +129,19 @@ func NewWARPEndpoint(ctx context.Context, router adapter.Router, logger log.Cont logger.ErrorContext(ctx, err) return } - if err = warpEndpoint.endpoint.Start(adapter.StartStateStart); err != nil { + if err = endpoint.endpoint.Start(adapter.StartStateStart); err != nil { logger.ErrorContext(ctx, err) return } - if err = warpEndpoint.endpoint.Start(adapter.StartStatePostStart); err != nil { + if err = endpoint.endpoint.Start(adapter.StartStatePostStart); err != nil { logger.ErrorContext(ctx, err) return } } - return warpEndpoint, nil + return endpoint, nil } -func (w *WARPEndpoint) Start(stage adapter.StartStage) error { +func (w *Endpoint) Start(stage adapter.StartStage) error { if stage != adapter.StartStatePostStart { return nil } @@ -185,26 +149,79 @@ func (w *WARPEndpoint) Start(stage adapter.StartStage) error { return nil } -func (w *WARPEndpoint) Close() error { +func (w *Endpoint) Close() error { + if err := w.isEndpointInitialized(w.ctx); err != nil { + return err + } return common.Close(w.endpoint) } -func (w *WARPEndpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - if ok := w.isEndpointInitialized(); !ok { - return nil, E.New("endpoint not initialized") +func (w *Endpoint) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if err := w.isEndpointInitialized(ctx); err != nil { + return nil, err } return w.endpoint.DialContext(ctx, network, destination) } -func (w *WARPEndpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - if ok := w.isEndpointInitialized(); !ok { - return nil, E.New("endpoint not initialized") +func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + if err := w.isEndpointInitialized(ctx); err != nil { + return nil, err } return w.endpoint.ListenPacket(ctx, destination) } -func (w *WARPEndpoint) isEndpointInitialized() bool { - w.mtx.Lock() - defer w.mtx.Unlock() - return w.endpoint != nil +func (w *Endpoint) isEndpointInitialized(ctx context.Context) error { + select { + case <-w.await: + case <-ctx.Done(): + return ctx.Err() + } + if w.endpoint == nil { + return E.New("endpoint not initialized") + } + return nil +} + +func (w *Endpoint) createConfig() (*Config, error) { + var privateKey wgtypes.Key + var err error + if w.options.Profile.PrivateKey != "" { + privateKey, err = wgtypes.ParseKey(w.options.Profile.PrivateKey) + if err != nil { + return nil, err + } + } else { + privateKey, err = wgtypes.GeneratePrivateKey() + if err != nil { + return nil, err + } + } + opts := make([]cloudflare.CloudflareApiOption, 0, 1) + if w.options.Profile.Detour != "" { + detour, ok := service.FromContext[adapter.OutboundManager](w.ctx).Outbound(w.options.Profile.Detour) + if !ok { + return nil, E.New("outbound detour not found: ", w.options.Profile.Detour) + } + opts = append(opts, cloudflare.WithDialContext(func(ctx context.Context, network, addr string) (net.Conn, error) { + return detour.DialContext(ctx, network, M.ParseSocksaddr(addr)) + })) + } + api := cloudflare.NewCloudflareApi(opts...) + var profile *cloudflare.CloudflareProfile + if w.options.Profile.AuthToken != "" && w.options.Profile.ID != "" { + profile, err = api.GetProfile(w.ctx, w.options.Profile.AuthToken, w.options.Profile.ID) + if err != nil { + return nil, err + } + } else { + profile, err = api.CreateProfile(w.ctx, privateKey.PublicKey().String()) + if err != nil { + return nil, err + } + } + return &Config{ + PrivateKey: privateKey.String(), + Interface: profile.Config.Interface, + Peers: profile.Config.Peers, + }, nil } diff --git a/provider/local/provider.go b/provider/local/provider.go new file mode 100644 index 00000000..a8a9f405 --- /dev/null +++ b/provider/local/provider.go @@ -0,0 +1,129 @@ +package provider + +import ( + "context" + "os" + "path/filepath" + "time" + + "github.com/sagernet/fswatch" + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/provider" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/parser" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/service" + "github.com/sagernet/sing/service/filemanager" +) + +func RegisterProviderLocal(registry *provider.Registry) { + provider.Register[option.ProviderLocalOptions](registry, C.ProviderTypeLocal, NewProviderLocal) +} + +func RegisterProviderInline(registry *provider.Registry) { + provider.Register[option.ProviderInlineOptions](registry, C.ProviderTypeInline, NewProviderInline) +} + +var _ adapter.Provider = (*ProviderLocal)(nil) + +type ProviderLocal struct { + provider.Adapter + ctx context.Context + logger log.ContextLogger + provider adapter.ProviderManager + path string + lastOutOpts []option.Outbound + lastUpdated time.Time + watcher *fswatch.Watcher +} + +func NewProviderInline(ctx context.Context, router adapter.Router, logFactory log.Factory, tag string, options option.ProviderInlineOptions) (adapter.Provider, error) { + var ( + outbound = service.FromContext[adapter.OutboundManager](ctx) + logger = logFactory.NewLogger(F.ToString("provider/inline", "[", tag, "]")) + ) + provider := &ProviderLocal{ + Adapter: provider.NewAdapter(ctx, router, outbound, logFactory, logger, tag, C.ProviderTypeInline, options.HealthCheck), + ctx: ctx, + logger: logger, + } + provider.UpdateOutbounds(nil, options.Outbounds) + return provider, nil +} + +func NewProviderLocal(ctx context.Context, router adapter.Router, logFactory log.Factory, tag string, options option.ProviderLocalOptions) (adapter.Provider, error) { + if options.Path == "" { + return nil, E.New("provider path is required") + } + var ( + outbound = service.FromContext[adapter.OutboundManager](ctx) + logger = logFactory.NewLogger(F.ToString("provider/local", "[", tag, "]")) + ) + provider := &ProviderLocal{ + Adapter: provider.NewAdapter(ctx, router, outbound, logFactory, logger, tag, C.ProviderTypeLocal, options.HealthCheck), + ctx: ctx, + logger: logger, + provider: service.FromContext[adapter.ProviderManager](ctx), + } + filePath := filemanager.BasePath(ctx, options.Path) + provider.path, _ = filepath.Abs(filePath) + watcher, err := fswatch.NewWatcher(fswatch.Options{ + Path: []string{filePath}, + Callback: func(path string) { + uErr := provider.reloadFile(path) + if uErr != nil { + logger.Error(E.Cause(uErr, "reload provider ", tag)) + } + provider.UpdateGroups() + }, + }) + if err != nil { + return nil, err + } + provider.watcher = watcher + return provider, nil +} + +func (s *ProviderLocal) Start() error { + err := s.reloadFile(s.path) + if err != nil { + return err + } + s.UpdateGroups() + if s.watcher != nil { + err := s.watcher.Start() + if err != nil { + s.logger.Error(E.Cause(err, "watch provider file")) + } + } + return s.Adapter.Start() +} + +func (s *ProviderLocal) UpdatedAt() time.Time { + return s.lastUpdated +} + +func (s *ProviderLocal) reloadFile(path string) error { + if fileInfo, err := os.Stat(path); err == nil { + s.lastUpdated = fileInfo.ModTime() + } + content, err := os.ReadFile(path) + if err != nil { + return err + } + outboundOpts, err := parser.ParseSubscription(s.ctx, string(content)) + if err != nil { + return err + } + s.UpdateOutbounds(s.lastOutOpts, outboundOpts) + s.lastOutOpts = outboundOpts + return nil +} + +func (s *ProviderLocal) Close() error { + return common.Close(&s.Adapter, common.PtrOrNil(s.watcher)) +} diff --git a/provider/remote/provider.go b/provider/remote/provider.go new file mode 100644 index 00000000..906c333e --- /dev/null +++ b/provider/remote/provider.go @@ -0,0 +1,338 @@ +package provider + +import ( + "bytes" + "context" + "crypto/tls" + "io" + "net" + "net/http" + "regexp" + "runtime" + "strings" + "sync/atomic" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/provider" + boxCommon "github.com/sagernet/sing-box/common" + "github.com/sagernet/sing-box/common/interrupt" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/parser" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/common/json" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/ntp" + "github.com/sagernet/sing/service" +) + +func RegisterProvider(registry *provider.Registry) { + provider.Register[option.ProviderRemoteOptions](registry, C.ProviderTypeRemote, NewProviderRemote) +} + +var _ adapter.Provider = (*ProviderRemote)(nil) + +type ProviderRemote struct { + provider.Adapter + ctx context.Context + cancel context.CancelFunc + logger log.ContextLogger + outbound adapter.OutboundManager + provider adapter.ProviderManager + cacheFile adapter.CacheFile + dialer N.Dialer + lastEtag string + lastOutOpts []option.Outbound + lastUpdated time.Time + subscriptionInfo adapter.SubscriptionInfo + ticker *time.Ticker + updating atomic.Bool + + url string + userAgent string + downloadDetour string + updateInterval time.Duration + exclude *regexp.Regexp + include *regexp.Regexp +} + +func NewProviderRemote(ctx context.Context, router adapter.Router, logFactory log.Factory, tag string, options option.ProviderRemoteOptions) (adapter.Provider, error) { + if options.URL == "" { + return nil, E.New("provider URL is required") + } + updateInterval := time.Duration(options.UpdateInterval) + if updateInterval <= 0 { + updateInterval = 24 * time.Hour + } + if updateInterval < time.Minute { + updateInterval = time.Minute + } + var userAgent string + if options.UserAgent == "" { + userAgent = "sing-box " + C.Version + } else { + userAgent = options.UserAgent + } + ctx, cancel := context.WithCancel(ctx) + outbound := service.FromContext[adapter.OutboundManager](ctx) + logger := logFactory.NewLogger(F.ToString("provider/remote", "[", tag, "]")) + updateChan := make(chan struct{}) + close(updateChan) + return &ProviderRemote{ + Adapter: provider.NewAdapter(ctx, router, outbound, logFactory, logger, tag, C.ProviderTypeRemote, options.HealthCheck), + ctx: ctx, + cancel: cancel, + logger: logger, + outbound: outbound, + provider: service.FromContext[adapter.ProviderManager](ctx), + + url: options.URL, + userAgent: userAgent, + downloadDetour: options.DownloadDetour, + updateInterval: updateInterval, + exclude: (*regexp.Regexp)(options.Exclude), + include: (*regexp.Regexp)(options.Include), + }, nil +} + +func (s *ProviderRemote) Start() error { + s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx) + if s.cacheFile != nil { + if saveSub := s.cacheFile.LoadSubscription(s.Tag()); saveSub != nil { + content, _ := boxCommon.DecodeBase64URLSafe(string(saveSub.Content)) + firstLine, others := getFirstLine(content) + if info, ok := parseInfo(firstLine); ok { + s.subscriptionInfo = info + content, _ = boxCommon.DecodeBase64URLSafe(others) + } + if err := s.updateProviderFromContent(content); err != nil { + return E.Cause(err, "restore cached outbound provider") + } + s.UpdateGroups() + s.lastUpdated, s.lastEtag = saveSub.LastUpdated, saveSub.LastEtag + } + } + if s.downloadDetour != "" { + outbound, loaded := s.outbound.Outbound(s.downloadDetour) + if !loaded { + return E.New("detour outbound not found: ", s.downloadDetour) + } + s.dialer = outbound + } else { + s.dialer = s.outbound.Default() + } + + go s.loopUpdate() + return s.Adapter.Start() +} + +func (s *ProviderRemote) Update() error { + if s.ticker != nil { + s.ticker.Reset(s.updateInterval) + } + ctx := interrupt.ContextWithIsProviderConnection(s.ctx) + return s.fetch(ctx) +} + +func (s *ProviderRemote) UpdatedAt() time.Time { + return s.lastUpdated +} + +func (s *ProviderRemote) SubscriptionInfo() adapter.SubscriptionInfo { + return s.subscriptionInfo +} + +func (s *ProviderRemote) Close() error { + s.cancel() + if s.ticker != nil { + s.ticker.Stop() + } + return common.Close(&s.Adapter) +} + +func (s *ProviderRemote) updateOnce() { + ctx := interrupt.ContextWithIsProviderConnection(s.ctx) + if err := s.fetch(ctx); err != nil { + s.logger.Error("update outbound provider: ", err) + } +} + +func (s *ProviderRemote) fetch(ctx context.Context) error { + if s.updating.Swap(true) { + return E.New("provider is updating") + } + defer s.updating.Store(false) + s.logger.Debug("updating outbound provider ", s.Tag(), " from URL: ", s.url) + client := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSHandshakeTimeout: C.TCPTimeout, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + TLSClientConfig: &tls.Config{ + Time: ntp.TimeFuncFromContext(ctx), + RootCAs: adapter.RootPoolFromContext(ctx), + }, + }, + } + req, err := http.NewRequest(http.MethodGet, s.url, nil) + if err != nil { + return err + } + if s.lastEtag != "" { + req.Header.Set("If-None-Match", s.lastEtag) + } + req.Header.Set("User-Agent", s.userAgent) + resp, err := client.Do(req.WithContext(ctx)) + if err != nil { + return err + } + infoStr := resp.Header.Get("subscription-userinfo") + info, hasInfo := parseInfo(infoStr) + switch resp.StatusCode { + case http.StatusOK: + case http.StatusNotModified: + s.subscriptionInfo = info + s.lastUpdated = time.Now() + if s.cacheFile != nil { + saveSub := s.cacheFile.LoadSubscription(s.Tag()) + if saveSub != nil { + if hasInfo { + index := bytes.IndexByte(saveSub.Content, '\n') + if index != -1 { + saveSub.Content = append([]byte(infoStr+"\n"), saveSub.Content[index+1:]...) + } + } + saveSub.LastUpdated = s.lastUpdated + err := s.cacheFile.SaveSubscription(s.Tag(), saveSub) + if err != nil { + s.logger.Error("save outbound provider cache file: ", err) + } + } + } + s.logger.Info("update outbound provider ", s.Tag(), ": not modified") + return nil + default: + return E.New("unexpected status: ", resp.Status) + } + defer resp.Body.Close() + contentRaw, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + eTagHeader := resp.Header.Get("Etag") + if eTagHeader != "" { + s.lastEtag = eTagHeader + } + content, _ := boxCommon.DecodeBase64URLSafe(string(contentRaw)) + if !hasInfo { + firstLine, others := getFirstLine(content) + if info, hasInfo = parseInfo(firstLine); hasInfo { + infoStr = firstLine + content, _ = boxCommon.DecodeBase64URLSafe(others) + } + } + if err := s.updateProviderFromContent(content); err != nil { + return err + } + s.UpdateGroups() + s.subscriptionInfo = info + s.lastUpdated = time.Now() + if s.cacheFile != nil { + content, _ := json.Marshal(option.Options{ + Outbounds: s.lastOutOpts, + }) + if hasInfo { + content = append([]byte(infoStr+"\n"), content...) + } + err = s.cacheFile.SaveSubscription(s.Tag(), &adapter.SavedBinary{ + LastUpdated: s.lastUpdated, + Content: content, + LastEtag: s.lastEtag, + }) + if err != nil { + s.logger.Error("save outbound provider cache file: ", err) + } + } + s.logger.Info("updated outbound provider ", s.Tag()) + return nil +} + +func (s *ProviderRemote) loopUpdate() { + if time.Since(s.lastUpdated) < s.updateInterval { + select { + case <-s.ctx.Done(): + return + case <-time.After(time.Until(s.lastUpdated.Add(s.updateInterval))): + s.updateOnce() + } + } else { + s.updateOnce() + } + s.ticker = time.NewTicker(s.updateInterval) + for { + runtime.GC() + select { + case <-s.ctx.Done(): + return + case <-s.ticker.C: + s.updateOnce() + } + } +} + +func (s *ProviderRemote) updateProviderFromContent(content string) error { + outboundOpts, err := parser.ParseSubscription(s.ctx, content) + if err != nil { + return err + } + outboundOpts = common.Filter(outboundOpts, func(it option.Outbound) bool { + return (s.exclude == nil || !s.exclude.MatchString(it.Tag)) && (s.include == nil || s.include.MatchString(it.Tag)) + }) + s.UpdateOutbounds(s.lastOutOpts, outboundOpts) + s.lastOutOpts = outboundOpts + return nil +} + +func getFirstLine(content string) (string, string) { + lines := strings.Split(content, "\n") + if len(lines) == 1 { + return lines[0], "" + } + others := strings.Join(lines[1:], "\n") + return lines[0], others +} + +func parseInfo(infoStr string) (adapter.SubscriptionInfo, bool) { + info := adapter.SubscriptionInfo{} + if infoStr == "" { + return info, false + } + reg := regexp.MustCompile(`(upload|download|total|expire)[\s\t]*=[\s\t]*(-?\d*);?`) + matches := reg.FindAllStringSubmatch(infoStr, 4) + if len(matches) == 0 { + return info, false + } + for _, match := range matches { + key, value := match[1], match[2] + switch key { + case "upload": + info.Upload = boxCommon.StringToType[int64](value) + case "download": + info.Download = boxCommon.StringToType[int64](value) + case "total": + info.Total = boxCommon.StringToType[int64](value) + case "expire": + info.Expire = boxCommon.StringToType[int64](value) + default: + return info, false + } + } + return info, true +} diff --git a/route/process_cache.go b/route/process_cache.go index 01b477c4..44ee3fcf 100644 --- a/route/process_cache.go +++ b/route/process_cache.go @@ -74,16 +74,19 @@ func (r *Router) searchProcessInfo(ctx context.Context, metadata *adapter.Inboun } func (r *Router) isLocalSource(source netip.Addr) bool { - if !source.IsValid() { - return false - } - source = source.Unmap() if source.IsLoopback() { return true } + if r.platformInterface != nil { + for _, addr := range r.platformInterface.MyInterfaceAddress() { + if addr == source { + return true + } + } + } for _, netInterface := range r.network.InterfaceFinder().Interfaces() { for _, prefix := range netInterface.Addresses { - if prefix.Addr().Unmap() == source { + if prefix.Addr() == source { return true } } diff --git a/route/route.go b/route/route.go index 4be15d79..67027337 100644 --- a/route/route.go +++ b/route/route.go @@ -485,8 +485,8 @@ match: Fqdn: metadata.Destination.Fqdn, } } - if routeOptions.OverrideTunnelDestination != "" { - metadata.TunnelDestination = routeOptions.OverrideTunnelDestination + if routeOptions.OverrideGateway.IsValid() { + metadata.Gateway = routeOptions.OverrideGateway } if routeOptions.NetworkStrategy != nil { metadata.NetworkStrategy = routeOptions.NetworkStrategy diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index c671f367..fb60a4d7 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -29,12 +29,13 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti case "": return nil, nil case C.RuleActionTypeRoute: + overrideGateway := M.ParseAddr(action.RouteOptions.OverrideGateway) return &RuleActionRoute{ Outbound: action.RouteOptions.Outbound, RuleActionRouteOptions: RuleActionRouteOptions{ OverrideAddress: M.ParseSocksaddrHostPort(action.RouteOptions.OverrideAddress, 0), OverridePort: action.RouteOptions.OverridePort, - OverrideTunnelDestination: action.RouteOptions.OverrideTunnelDestination, + OverrideGateway: &overrideGateway, NetworkStrategy: (*C.NetworkStrategy)(action.RouteOptions.NetworkStrategy), FallbackDelay: time.Duration(action.RouteOptions.FallbackDelay), UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping, @@ -196,7 +197,7 @@ func (r *RuleActionBypass) String() string { type RuleActionRouteOptions struct { OverrideAddress M.Socksaddr OverridePort uint16 - OverrideTunnelDestination string + OverrideGateway *netip.Addr NetworkStrategy *C.NetworkStrategy NetworkType []C.InterfaceType FallbackNetworkType []C.InterfaceType @@ -225,8 +226,8 @@ func (r *RuleActionRouteOptions) Descriptions() []string { if r.OverridePort > 0 { descriptions = append(descriptions, F.ToString("override-port=", r.OverridePort)) } - if r.OverrideTunnelDestination != "" { - descriptions = append(descriptions, F.ToString("override-tunnel-destination=", r.OverrideTunnelDestination)) + if r.OverrideGateway != nil { + descriptions = append(descriptions, F.ToString("override-gateway=", r.OverrideGateway.String())) } if r.NetworkStrategy != nil { descriptions = append(descriptions, F.ToString("network-strategy=", r.NetworkStrategy)) diff --git a/route/rule/rule_default.go b/route/rule/rule_default.go index 1eef862d..b921c8b2 100644 --- a/route/rule/rule_default.go +++ b/route/rule/rule_default.go @@ -186,16 +186,6 @@ func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options optio rule.destinationPortItems = append(rule.destinationPortItems, item) rule.allItems = append(rule.allItems, item) } - if len(options.TunnelSource) > 0 { - item := NewTunnelSourceItem(options.TunnelSource) - rule.items = append(rule.items, item) - rule.allItems = append(rule.allItems, item) - } - if len(options.TunnelDestination) > 0 { - item := NewTunnelDestinationItem(options.TunnelDestination) - rule.items = append(rule.items, item) - rule.allItems = append(rule.allItems, item) - } if len(options.ProcessName) > 0 { item := NewProcessItem(options.ProcessName) rule.items = append(rule.items, item) diff --git a/route/rule/rule_dns.go b/route/rule/rule_dns.go index dad49503..04f0f236 100644 --- a/route/rule/rule_dns.go +++ b/route/rule/rule_dns.go @@ -182,16 +182,6 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op rule.destinationPortItems = append(rule.destinationPortItems, item) rule.allItems = append(rule.allItems, item) } - if len(options.TunnelSource) > 0 { - item := NewTunnelSourceItem(options.TunnelSource) - rule.items = append(rule.items, item) - rule.allItems = append(rule.allItems, item) - } - if len(options.TunnelDestination) > 0 { - item := NewTunnelDestinationItem(options.TunnelDestination) - rule.items = append(rule.items, item) - rule.allItems = append(rule.allItems, item) - } if len(options.ProcessName) > 0 { item := NewProcessItem(options.ProcessName) rule.items = append(rule.items, item) diff --git a/route/rule/rule_headless.go b/route/rule/rule_headless.go index f11d1126..c5146318 100644 --- a/route/rule/rule_headless.go +++ b/route/rule/rule_headless.go @@ -130,16 +130,6 @@ func NewDefaultHeadlessRule(ctx context.Context, options option.DefaultHeadlessR rule.destinationPortItems = append(rule.destinationPortItems, item) rule.allItems = append(rule.allItems, item) } - if len(options.TunnelSource) > 0 { - item := NewTunnelSourceItem(options.TunnelSource) - rule.items = append(rule.items, item) - rule.allItems = append(rule.allItems, item) - } - if len(options.TunnelDestination) > 0 { - item := NewTunnelDestinationItem(options.TunnelDestination) - rule.items = append(rule.items, item) - rule.allItems = append(rule.allItems, item) - } if len(options.ProcessName) > 0 { item := NewProcessItem(options.ProcessName) rule.items = append(rule.items, item) diff --git a/route/rule/rule_item_tunnel_destination.go b/route/rule/rule_item_tunnel_destination.go deleted file mode 100644 index 34f711d6..00000000 --- a/route/rule/rule_item_tunnel_destination.go +++ /dev/null @@ -1,35 +0,0 @@ -package rule - -import ( - "strings" - - "github.com/sagernet/sing-box/adapter" - F "github.com/sagernet/sing/common/format" -) - -var _ RuleItem = (*TunnelDestinationItem)(nil) - -type TunnelDestinationItem struct { - destinations []string - destinationMap map[string]bool -} - -func NewTunnelDestinationItem(destinations []string) *TunnelDestinationItem { - rule := &TunnelDestinationItem{destinations, make(map[string]bool)} - for _, destination := range destinations { - rule.destinationMap[destination] = true - } - return rule -} - -func (r *TunnelDestinationItem) Match(metadata *adapter.InboundContext) bool { - return r.destinationMap[metadata.TunnelDestination] -} - -func (r *TunnelDestinationItem) String() string { - if len(r.destinations) == 1 { - return F.ToString("tunnel_destination=", r.destinations[0]) - } else { - return F.ToString("tunnel_destination=[", strings.Join(r.destinations, " "), "]") - } -} diff --git a/route/rule/rule_item_tunnel_source.go b/route/rule/rule_item_tunnel_source.go deleted file mode 100644 index 6a2f01cb..00000000 --- a/route/rule/rule_item_tunnel_source.go +++ /dev/null @@ -1,35 +0,0 @@ -package rule - -import ( - "strings" - - "github.com/sagernet/sing-box/adapter" - F "github.com/sagernet/sing/common/format" -) - -var _ RuleItem = (*TunnelSourceItem)(nil) - -type TunnelSourceItem struct { - sources []string - sourceMap map[string]bool -} - -func NewTunnelSourceItem(sources []string) *TunnelSourceItem { - rule := &TunnelSourceItem{sources, make(map[string]bool)} - for _, source := range sources { - rule.sourceMap[source] = true - } - return rule -} - -func (r *TunnelSourceItem) Match(metadata *adapter.InboundContext) bool { - return r.sourceMap[metadata.TunnelSource] -} - -func (r *TunnelSourceItem) String() string { - if len(r.sources) == 1 { - return F.ToString("tunnel_source=", r.sources[0]) - } else { - return F.ToString("tunnel_source=[", strings.Join(r.sources, " "), "]") - } -} diff --git a/service/admin_panel/tables/user.go b/service/admin_panel/tables/user.go index 23087c4f..6bc36767 100644 --- a/service/admin_panel/tables/user.go +++ b/service/admin_panel/tables/user.go @@ -77,6 +77,7 @@ func UserTableFactory(manager CM.Manager, logger log.Logger) func(ctx *context.C Options: types.FieldOptions{ {Text: "Hysteria", Value: "hysteria"}, {Text: "Hysteria2", Value: "hysteria2"}, + {Text: "MTProxy", Value: "mtproxy"}, {Text: "Trojan", Value: "trojan"}, {Text: "TUIC", Value: "tuic"}, {Text: "VLESS", Value: "vless"}, @@ -183,16 +184,18 @@ func UserTableFactory(manager CM.Manager, logger log.Logger) func(ctx *context.C FieldOptions(types.FieldOptions{ {Text: "Hysteria", Value: "hysteria"}, {Text: "Hysteria2", Value: "hysteria2"}, + {Text: "MTProxy", Value: "mtproxy"}, {Text: "Trojan", Value: "trojan"}, {Text: "TUIC", Value: "tuic"}, {Text: "VLESS", Value: "vless"}, {Text: "VMess", Value: "vmess"}, }). FieldOnChooseOptionsHide([]string{""}, "inbound"). - FieldOnChooseOptionsHide([]string{"", "hysteria", "hysteria2", "shadowsocks", "trojan", "tuic"}, "uuid"). - FieldOnChooseOptionsHide([]string{"", "vless", "vmess"}, "password"). - FieldOnChooseOptionsHide([]string{"", "hysteria", "hysteria2", "shadowsocks", "trojan", "tuic", "vmess"}, "flow"). - FieldOnChooseOptionsHide([]string{"", "hysteria", "hysteria2", "shadowsocks", "trojan", "tuic", "vless"}, "alter_id") + FieldOnChooseOptionsHide([]string{"", "hysteria", "hysteria2", "mtproxy", "shadowsocks", "trojan", "tuic"}, "uuid"). + FieldOnChooseOptionsHide([]string{"", "mtproxy", "vless", "vmess"}, "password"). + FieldOnChooseOptionsHide([]string{"", "hysteria", "hysteria2", "shadowsocks", "trojan", "tuic", "vless", "vmess"}, "secret"). + FieldOnChooseOptionsHide([]string{"", "hysteria", "hysteria2", "mtproxy", "shadowsocks", "trojan", "tuic", "vmess"}, "flow"). + FieldOnChooseOptionsHide([]string{"", "hysteria", "hysteria2", "mtproxy", "shadowsocks", "trojan", "tuic", "vless"}, "alter_id") formList.AddField("Inbound", "inbound", db.Varchar, form.Text). FieldMust(). FieldDisplayButCanNotEditWhenUpdate(). @@ -203,6 +206,7 @@ func UserTableFactory(manager CM.Manager, logger log.Logger) func(ctx *context.C }) formList.AddField("UUID", "uuid", db.Varchar, form.Text) formList.AddField("Password", "password", db.Varchar, form.Text) + formList.AddField("Secret", "secret", db.Varchar, form.Text) formList.AddField("Flow", "flow", db.Varchar, form.SelectSingle). FieldOptions(types.FieldOptions{ {Text: "xtls-rprx-vision", Value: "xtls-rprx-vision"}, @@ -233,6 +237,7 @@ func UserTableFactory(manager CM.Manager, logger log.Logger) func(ctx *context.C Inbound: values.Get("inbound"), UUID: values.Get("uuid"), Password: values.Get("password"), + Secret: values.Get("secret"), Flow: values.Get("flow"), AlterID: alterId, }) @@ -269,6 +274,7 @@ func UserTableFactory(manager CM.Manager, logger log.Logger) func(ctx *context.C _, err = manager.UpdateUser(id, CM.UserUpdate{ UUID: values.Get("uuid"), Password: values.Get("password"), + Secret: values.Get("secret"), Flow: values.Get("flow"), AlterID: alterId, }) diff --git a/service/manager/constant/dto.go b/service/manager/constant/dto.go index 0aae5914..c8988c73 100644 --- a/service/manager/constant/dto.go +++ b/service/manager/constant/dto.go @@ -48,6 +48,7 @@ type User struct { Inbound string `json:"inbound" validate:"required"` UUID string `json:"uuid" validate:"required"` Password string `json:"password" validate:"required"` + Secret string `json:"secret" validate:"required"` Flow string `json:"flow" validate:"required"` AlterID int `json:"alter_id" validate:"required"` CreatedAt time.Time `json:"created_at" validate:"required"` @@ -57,10 +58,11 @@ type User struct { type UserCreate struct { SquadIDs []int `json:"squad_ids" validate:"required"` Username string `json:"username" validate:"required"` - Type string `json:"type" validate:"required,oneof=hysteria hysteria2 trojan tuic vless vmess"` + Type string `json:"type" validate:"required,oneof=hysteria hysteria2 mtproxy trojan tuic vless vmess"` Inbound string `json:"inbound" validate:"required"` UUID string `json:"uuid" validate:"omitempty,uuid4"` Password string `json:"password" validate:"omitempty"` + Secret string `json:"secret" validate:"omitempty"` Flow string `json:"flow" validate:"omitempty"` AlterID int `json:"alter_id" validate:"omitempty"` } @@ -68,6 +70,7 @@ type UserCreate struct { type UserUpdate struct { UUID string `json:"uuid" validate:"omitempty,uuid4"` Password string `json:"password" validate:"omitempty"` + Secret string `json:"secret" validate:"omitempty"` Flow string `json:"flow" validate:"omitempty"` AlterID int `json:"alter_id" validate:"omitempty"` } @@ -75,6 +78,7 @@ type UserUpdate struct { type BaseUser struct { UUID string `json:"uuid" validate:"omitempty,uuid4"` Password string `json:"password" validate:"omitempty"` + Secret string `json:"secret" validate:"omitempty"` Flow string `json:"flow" validate:"omitempty"` AlterID int `json:"alter_id" validate:"omitempty"` } diff --git a/service/manager/repository/postgresql/migration.go b/service/manager/repository/postgresql/migration.go index 4a426e38..2c08baaa 100644 --- a/service/manager/repository/postgresql/migration.go +++ b/service/manager/repository/postgresql/migration.go @@ -40,6 +40,7 @@ var migrations = map[string]string{ inbound TEXT NOT NULL, uuid TEXT NOT NULL, password TEXT NOT NULL, + secret TEXT NOT NULL, flow TEXT NOT NULL, alter_id INTEGER NOT NULL, created_at TIMESTAMP NOT NULL, diff --git a/service/manager/repository/postgresql/repository.go b/service/manager/repository/postgresql/repository.go index 29a1a526..c75eff8e 100644 --- a/service/manager/repository/postgresql/repository.go +++ b/service/manager/repository/postgresql/repository.go @@ -391,12 +391,13 @@ func (r *PostgreSQLRepository) CreateUser(user constant.UserCreate) (constant.Us inbound, uuid, password, + secret, flow, alter_id, created_at, updated_at ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, username, @@ -404,6 +405,7 @@ func (r *PostgreSQLRepository) CreateUser(user constant.UserCreate) (constant.Us inbound, uuid, password, + secret, flow, alter_id, created_at, @@ -414,6 +416,7 @@ func (r *PostgreSQLRepository) CreateUser(user constant.UserCreate) (constant.Us user.Inbound, user.UUID, user.Password, + user.Secret, user.Flow, user.AlterID, now, @@ -425,11 +428,15 @@ func (r *PostgreSQLRepository) CreateUser(user constant.UserCreate) (constant.Us &u.Inbound, &u.UUID, &u.Password, + &u.Secret, &u.Flow, &u.AlterID, &u.CreatedAt, &u.UpdatedAt, ) + if err != nil { + return u, err + } rows := make([][]any, len(user.SquadIDs)) for i, squadID := range user.SquadIDs { rows[i] = []any{u.ID, squadID} @@ -465,6 +472,7 @@ func (r *PostgreSQLRepository) GetUsers(filters map[string][]string) ([]constant "inbound", "uuid", "password", + "secret", "flow", "alter_id", "created_at", @@ -495,6 +503,7 @@ func (r *PostgreSQLRepository) GetUsers(filters map[string][]string) ([]constant &u.Inbound, &u.UUID, &u.Password, + &u.Secret, &u.Flow, &u.AlterID, &u.CreatedAt, @@ -539,6 +548,7 @@ func (r *PostgreSQLRepository) GetUser(id int) (constant.User, error) { inbound, uuid, password, + secret, flow, alter_id, created_at, @@ -553,6 +563,7 @@ func (r *PostgreSQLRepository) GetUser(id int) (constant.User, error) { &u.Inbound, &u.UUID, &u.Password, + &u.Secret, &u.Flow, &u.AlterID, &u.CreatedAt, @@ -568,10 +579,11 @@ func (r *PostgreSQLRepository) UpdateUser(id int, user constant.UserUpdate) (con SET uuid = $1, password = $2, - flow = $3, - alter_id = $4, - updated_at = $5 - WHERE id = $6 + secret = $3, + flow = $4, + alter_id = $5, + updated_at = $6 + WHERE id = $7 RETURNING id, ARRAY( @@ -584,6 +596,7 @@ func (r *PostgreSQLRepository) UpdateUser(id int, user constant.UserUpdate) (con inbound, uuid, password, + secret, flow, alter_id, created_at, @@ -591,6 +604,7 @@ func (r *PostgreSQLRepository) UpdateUser(id int, user constant.UserUpdate) (con `, user.UUID, user.Password, + user.Secret, user.Flow, user.AlterID, time.Now(), @@ -603,6 +617,7 @@ func (r *PostgreSQLRepository) UpdateUser(id int, user constant.UserUpdate) (con &u.Inbound, &u.UUID, &u.Password, + &u.Secret, &u.Flow, &u.AlterID, &u.CreatedAt, @@ -628,6 +643,7 @@ func (r *PostgreSQLRepository) DeleteUser(id int) (constant.User, error) { inbound, uuid, password, + secret, flow, alter_id, created_at, @@ -640,6 +656,7 @@ func (r *PostgreSQLRepository) DeleteUser(id int) (constant.User, error) { &u.Inbound, &u.UUID, &u.Password, + &u.Secret, &u.Flow, &u.AlterID, &u.CreatedAt, diff --git a/service/manager/service.go b/service/manager/service.go index b73c75b0..a8cb3408 100644 --- a/service/manager/service.go +++ b/service/manager/service.go @@ -10,7 +10,7 @@ import ( "github.com/go-playground/validator/v10" "github.com/gofrs/uuid/v5" - "github.com/patrickmn/go-cache" + "github.com/patrickmn/go-cache/v2" "github.com/sagernet/sing-box/adapter" boxService "github.com/sagernet/sing-box/adapter/service" C "github.com/sagernet/sing-box/constant" @@ -32,7 +32,7 @@ type Service struct { repository constant.Repository nodes map[string]constant.ConnectedNode - limiterLocks map[int]map[string]*cache.Cache + limiterLocks map[int]map[string]*cache.Cache[string, struct{}] userValidator *validator.Validate defaultValidator *validator.Validate @@ -79,6 +79,10 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio if user.Password == "" { sl.ReportError(user.Password, "password", "Password", "required", "") } + case "mtproxy": + if user.Secret == "" { + sl.ReportError(user.Secret, "secret", "Secret", "required", "") + } } }, constant.UserCreate{}) return &Service{ @@ -87,7 +91,7 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio logger: logger, repository: repository, nodes: make(map[string]constant.ConnectedNode, 0), - limiterLocks: make(map[int]map[string]*cache.Cache), + limiterLocks: make(map[int]map[string]*cache.Cache[string, struct{}]), userValidator: userValidator, defaultValidator: validator.New(), }, nil @@ -519,7 +523,7 @@ func (s *Service) AcquireLock(limiterId int, id string) (string, error) { } locks, ok := s.limiterLocks[limiterId] if !ok { - locks = make(map[string]*cache.Cache) + locks = make(map[string]*cache.Cache[string, struct{}]) s.limiterLocks[limiter.ID] = locks } lock, ok := locks[id] @@ -527,8 +531,8 @@ func (s *Service) AcquireLock(limiterId int, id string) (string, error) { if len(locks) == int(limiter.Count) { return "", E.New("not enough free locks") } - lock = cache.New(time.Second*30, time.Second) - lock.OnEvicted(func(_ string, _ interface{}) { + lock = cache.New[string, struct{}](time.Second*30, time.Second) + lock.OnEvicted(func(_ string, _ struct{}) { s.connLockMtx.Lock() defer s.connLockMtx.Unlock() if lock.ItemCount() == 0 { @@ -541,7 +545,7 @@ func (s *Service) AcquireLock(limiterId int, id string) (string, error) { if err != nil { return "", err } - lock.SetDefault(handleID.String(), new(struct{})) + lock.SetDefault(handleID.String(), struct{}{}) return handleID.String(), nil } @@ -556,7 +560,7 @@ func (s *Service) RefreshLock(limiterId int, id string, handleId string) error { if !ok { return E.New("lock not found") } - err := lock.Replace(handleId, new(struct{}), time.Second*30) + err := lock.Replace(handleId, struct{}{}, time.Second*30) return err } diff --git a/service/node/inbound/mtproxy.go b/service/node/inbound/mtproxy.go new file mode 100644 index 00000000..7aa24c41 --- /dev/null +++ b/service/node/inbound/mtproxy.go @@ -0,0 +1,88 @@ +package inbound + +import ( + "sync" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-box/protocol/mtproxy" + CM "github.com/sagernet/sing-box/service/manager/constant" + "github.com/sagernet/sing-box/service/node/constant" +) + +type MTProxyManager struct { + access sync.Mutex + inbounds map[string]*MTProxyUserManager +} + +func NewMTProxyManager() *MTProxyManager { + return &MTProxyManager{ + inbounds: make(map[string]*MTProxyUserManager), + } +} + +func (m *MTProxyManager) AddUserManager(inbound adapter.Inbound) error { + m.access.Lock() + defer m.access.Unlock() + m.inbounds[inbound.Tag()] = &MTProxyUserManager{ + inbound: inbound.(*mtproxy.Inbound), + usersMap: make(map[string]option.MTProxyUser), + } + return nil +} + +func (m *MTProxyManager) GetUserManager(tag string) (constant.UserManager, bool) { + m.access.Lock() + defer m.access.Unlock() + inbound, ok := m.inbounds[tag] + return inbound, ok +} + +func (m *MTProxyManager) GetUserManagerTags() []string { + m.access.Lock() + defer m.access.Unlock() + tags := make([]string, 0, len(m.inbounds)) + for tag, _ := range m.inbounds { + tags = append(tags, tag) + } + return tags +} + +type MTProxyUserManager struct { + inbound *mtproxy.Inbound + usersMap map[string]option.MTProxyUser + + mtx sync.Mutex +} + +func (i *MTProxyUserManager) postUpdate() { + users := make([]option.MTProxyUser, 0, len(i.usersMap)) + for _, user := range i.usersMap { + users = append(users, user) + } + i.inbound.UpdateUsers(users) +} + +func (i *MTProxyUserManager) UpdateUser(user CM.User) { + i.mtx.Lock() + defer i.mtx.Unlock() + i.usersMap[user.Username] = option.MTProxyUser{Name: user.Username, Secret: user.Secret} + i.postUpdate() +} + +func (i *MTProxyUserManager) UpdateUsers(users []CM.User) { + i.mtx.Lock() + defer i.mtx.Unlock() + clear(i.usersMap) + for _, user := range users { + i.usersMap[user.Username] = option.MTProxyUser{Name: user.Username, Secret: user.Secret} + } + i.postUpdate() +} + +func (i *MTProxyUserManager) DeleteUser(username string) { + i.mtx.Lock() + defer i.mtx.Unlock() + delete(i.usersMap, username) + i.postUpdate() +} diff --git a/transport/masque/adapter.go b/transport/masque/adapter.go new file mode 100644 index 00000000..f72748f7 --- /dev/null +++ b/transport/masque/adapter.go @@ -0,0 +1,82 @@ +package masque + +import ( + "sync" + + "github.com/sagernet/wireguard-go/tun" + "github.com/songgao/water" +) + +type NetstackAdapter struct { + dev tun.Device + tunnelBufPool sync.Pool + tunnelSizesPool sync.Pool +} + +func (n *NetstackAdapter) ReadPacket(buf []byte) (int, error) { + packetBufsPtr := n.tunnelBufPool.Get().(*[][]byte) + sizesPtr := n.tunnelSizesPool.Get().(*[]int) + + defer func() { + (*packetBufsPtr)[0] = nil + n.tunnelBufPool.Put(packetBufsPtr) + n.tunnelSizesPool.Put(sizesPtr) + }() + + (*packetBufsPtr)[0] = buf + (*sizesPtr)[0] = 0 + + _, err := n.dev.Read(*packetBufsPtr, *sizesPtr, 0) + if err != nil { + return 0, err + } + + return (*sizesPtr)[0], nil +} + +func (n *NetstackAdapter) WritePacket(pkt []byte) error { + // Write expects a slice of packet buffers. + _, err := n.dev.Write([][]byte{pkt}, 0) + return err +} + +// NewNetstackAdapter creates a new NetstackAdapter. +func NewNetstackAdapter(dev tun.Device) TunnelDevice { + return &NetstackAdapter{ + dev: dev, + tunnelBufPool: sync.Pool{ + New: func() interface{} { + buf := make([][]byte, 1) + return &buf + }, + }, + tunnelSizesPool: sync.Pool{ + New: func() interface{} { + sizes := make([]int, 1) + return &sizes + }, + }, + } +} + +type WaterAdapter struct { + iface *water.Interface +} + +func (w *WaterAdapter) ReadPacket(buf []byte) (int, error) { + n, err := w.iface.Read(buf) + if err != nil { + return 0, err + } + + return n, nil +} + +func (w *WaterAdapter) WritePacket(pkt []byte) error { + _, err := w.iface.Write(pkt) + return err +} + +func NewWaterAdapter(iface *water.Interface) TunnelDevice { + return &WaterAdapter{iface: iface} +} diff --git a/transport/masque/buffer.go b/transport/masque/buffer.go new file mode 100644 index 00000000..267f494f --- /dev/null +++ b/transport/masque/buffer.go @@ -0,0 +1,34 @@ +package masque + +import "sync" + +type NetBuffer struct { + capacity uint32 + buf sync.Pool +} + +func (n *NetBuffer) Get() []byte { + return *(n.buf.Get().(*[]byte)) +} + +func (n *NetBuffer) Put(buf []byte) { + if cap(buf) != int(n.capacity) { + return + } + n.buf.Put(&buf) +} + +func NewNetBuffer(capacity uint32) *NetBuffer { + if capacity == 0 { + panic("capacity must be greater than 0") + } + return &NetBuffer{ + capacity: capacity, + buf: sync.Pool{ + New: func() interface{} { + b := make([]byte, capacity) + return &b + }, + }, + } +} diff --git a/transport/masque/device.go b/transport/masque/device.go new file mode 100644 index 00000000..d6f71597 --- /dev/null +++ b/transport/masque/device.go @@ -0,0 +1,33 @@ +package masque + +import ( + "context" + "net/netip" + "time" + + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/logger" + N "github.com/sagernet/sing/common/network" + wgTun "github.com/sagernet/wireguard-go/tun" +) + +type Device interface { + wgTun.Device + N.Dialer + Start() error +} + +type DeviceOptions struct { + Context context.Context + Logger logger.ContextLogger + Handler tun.Handler + UDPTimeout time.Duration + CreateDialer func(interfaceName string) N.Dialer + Name string + MTU uint32 + Address []netip.Prefix +} + +func NewDevice(options DeviceOptions) (Device, error) { + return newStackDevice(options) +} diff --git a/transport/masque/device_stack.go b/transport/masque/device_stack.go new file mode 100644 index 00000000..a25115c0 --- /dev/null +++ b/transport/masque/device_stack.go @@ -0,0 +1,307 @@ +package masque + +import ( + "context" + "net" + "net/netip" + "os" + + "github.com/sagernet/gvisor/pkg/buffer" + "github.com/sagernet/gvisor/pkg/tcpip" + "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" + "github.com/sagernet/gvisor/pkg/tcpip/header" + "github.com/sagernet/gvisor/pkg/tcpip/network/ipv4" + "github.com/sagernet/gvisor/pkg/tcpip/network/ipv6" + "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/transport/icmp" + "github.com/sagernet/gvisor/pkg/tcpip/transport/tcp" + "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/transport/wireguard" + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/buf" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + wgTun "github.com/sagernet/wireguard-go/tun" +) + +type stackDevice struct { + ctx context.Context + logger log.ContextLogger + stack *stack.Stack + mtu uint32 + events chan wgTun.Event + wgTun.Device + outbound chan *stack.PacketBuffer + packetOutbound chan *buf.Buffer + done chan struct{} + dispatcher stack.NetworkDispatcher + inet4Address netip.Addr + inet6Address netip.Addr +} + +func newStackDevice(options DeviceOptions) (*stackDevice, error) { + tunDevice := &stackDevice{ + ctx: options.Context, + logger: options.Logger, + mtu: options.MTU, + events: make(chan wgTun.Event, 1), + outbound: make(chan *stack.PacketBuffer, 256), + packetOutbound: make(chan *buf.Buffer, 256), + done: make(chan struct{}), + } + ipStack, err := tun.NewGVisorStackWithOptions((*wireEndpoint)(tunDevice), stack.NICOptions{}, true) + if err != nil { + return nil, err + } + var ( + inet4Address netip.Addr + inet6Address netip.Addr + ) + for _, prefix := range options.Address { + addr := tun.AddressFromAddr(prefix.Addr()) + protoAddr := tcpip.ProtocolAddress{ + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: addr, + PrefixLen: prefix.Bits(), + }, + } + if prefix.Addr().Is4() { + inet4Address = prefix.Addr() + tunDevice.inet4Address = inet4Address + protoAddr.Protocol = ipv4.ProtocolNumber + } else { + inet6Address = prefix.Addr() + tunDevice.inet6Address = inet6Address + protoAddr.Protocol = ipv6.ProtocolNumber + } + gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{}) + if gErr != nil { + return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String()) + } + } + tunDevice.stack = ipStack + if options.Handler != nil { + ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket) + icmpForwarder := tun.NewICMPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber4, icmpForwarder.HandlePacket) + ipStack.SetTransportProtocolHandler(icmp.ProtocolNumber6, icmpForwarder.HandlePacket) + } + return tunDevice, nil +} + +func (w *stackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + addr := tcpip.FullAddress{ + NIC: tun.DefaultNIC, + Port: destination.Port, + Addr: tun.AddressFromAddr(destination.Addr), + } + bind := tcpip.FullAddress{ + NIC: tun.DefaultNIC, + } + var networkProtocol tcpip.NetworkProtocolNumber + if destination.IsIPv4() { + if !w.inet4Address.IsValid() { + return nil, E.New("missing IPv4 local address") + } + networkProtocol = header.IPv4ProtocolNumber + bind.Addr = tun.AddressFromAddr(w.inet4Address) + } else { + if !w.inet6Address.IsValid() { + return nil, E.New("missing IPv6 local address") + } + networkProtocol = header.IPv6ProtocolNumber + bind.Addr = tun.AddressFromAddr(w.inet6Address) + } + switch N.NetworkName(network) { + case N.NetworkTCP: + tcpConn, err := wireguard.DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol) + if err != nil { + return nil, err + } + return tcpConn, nil + case N.NetworkUDP: + udpConn, err := gonet.DialUDP(w.stack, &bind, &addr, networkProtocol) + if err != nil { + return nil, err + } + return udpConn, nil + default: + return nil, E.Extend(N.ErrUnknownNetwork, network) + } +} + +func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + bind := tcpip.FullAddress{ + NIC: tun.DefaultNIC, + } + var networkProtocol tcpip.NetworkProtocolNumber + if destination.IsIPv4() { + networkProtocol = header.IPv4ProtocolNumber + bind.Addr = tun.AddressFromAddr(w.inet4Address) + } else { + networkProtocol = header.IPv6ProtocolNumber + bind.Addr = tun.AddressFromAddr(w.inet4Address) + } + udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol) + if err != nil { + return nil, err + } + return udpConn, nil +} + +func (w *stackDevice) Start() error { + w.events <- wgTun.EventUp + return nil +} + +func (w *stackDevice) File() *os.File { + return nil +} + +func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) { + select { + case packet, ok := <-w.outbound: + if !ok { + return 0, os.ErrClosed + } + defer packet.DecRef() + var copyN int + /*rangeIterate(packet.Data().AsRange(), func(view *buffer.View) { + copyN += copy(bufs[0][offset+copyN:], view.AsSlice()) + })*/ + for _, view := range packet.AsSlices() { + copyN += copy(bufs[0][offset+copyN:], view) + } + sizes[0] = copyN + return 1, nil + case packet := <-w.packetOutbound: + defer packet.Release() + sizes[0] = copy(bufs[0][offset:], packet.Bytes()) + return 1, nil + case <-w.done: + return 0, os.ErrClosed + } +} + +func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) { + for _, b := range bufs { + b = b[offset:] + if len(b) == 0 { + continue + } + var networkProtocol tcpip.NetworkProtocolNumber + switch header.IPVersion(b) { + case header.IPv4Version: + networkProtocol = header.IPv4ProtocolNumber + case header.IPv6Version: + networkProtocol = header.IPv6ProtocolNumber + } + packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(b), + }) + w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer) + packetBuffer.DecRef() + count++ + } + return +} + +func (w *stackDevice) Flush() error { + return nil +} + +func (w *stackDevice) MTU() (int, error) { + return int(w.mtu), nil +} + +func (w *stackDevice) Name() (string, error) { + return "sing-box", nil +} + +func (w *stackDevice) Events() <-chan wgTun.Event { + return w.events +} + +func (w *stackDevice) Close() error { + close(w.done) + close(w.events) + w.stack.Close() + for _, endpoint := range w.stack.CleanupEndpoints() { + endpoint.Abort() + } + w.stack.Wait() + return nil +} + +func (w *stackDevice) BatchSize() int { + return 1 +} + +var _ stack.LinkEndpoint = (*wireEndpoint)(nil) + +type wireEndpoint stackDevice + +func (ep *wireEndpoint) MTU() uint32 { + return ep.mtu +} + +func (ep *wireEndpoint) SetMTU(mtu uint32) { +} + +func (ep *wireEndpoint) MaxHeaderLength() uint16 { + return 0 +} + +func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +func (ep *wireEndpoint) SetLinkAddress(addr tcpip.LinkAddress) { +} + +func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities { + return stack.CapabilityRXChecksumOffload +} + +func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) { + ep.dispatcher = dispatcher +} + +func (ep *wireEndpoint) IsAttached() bool { + return ep.dispatcher != nil +} + +func (ep *wireEndpoint) Wait() { +} + +func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareNone +} + +func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) { +} + +func (ep *wireEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool { + return true +} + +func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) { + for _, packetBuffer := range list.AsSlice() { + packetBuffer.IncRef() + select { + case <-ep.done: + return 0, &tcpip.ErrClosedForSend{} + case ep.outbound <- packetBuffer: + } + } + return list.Len(), nil +} + +func (ep *wireEndpoint) Close() { +} + +func (ep *wireEndpoint) SetOnCloseAction(f func()) { +} diff --git a/transport/masque/masque.go b/transport/masque/masque.go new file mode 100644 index 00000000..62b90fd7 --- /dev/null +++ b/transport/masque/masque.go @@ -0,0 +1,166 @@ +package masque + +import ( + "context" + "crypto/tls" + "errors" + "fmt" + "net" + "net/http" + "net/netip" + "net/url" + "strings" + + connectip "github.com/Diniboy1123/connect-ip-go" + "github.com/sagernet/quic-go" + "github.com/sagernet/quic-go/http3" + qtls "github.com/sagernet/sing-quic" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + aTLS "github.com/sagernet/sing/common/tls" + "github.com/yosida95/uritemplate/v3" + "golang.org/x/net/http2" +) + +type ( + DialContext func(ctx context.Context, network, address string) (net.Conn, error) + ListenPacket func(network string, address string) (net.PacketConn, error) +) + +func ConnectTunnel(ctx context.Context, dialer N.Dialer, tlsConfig aTLS.Config, quicConfig *quic.Config, connectUri string, endpoint net.Addr, useHTTP2 bool) (net.PacketConn, *http3.Transport, *connectip.Conn, *http.Response, error) { + template := uritemplate.MustNew(connectUri) + additionalHeaders := http.Header{ + "User-Agent": []string{""}, + } + if useHTTP2 { + h2Endpoint, ok := endpoint.(*net.TCPAddr) + if !ok || h2Endpoint == nil { + return nil, nil, nil, nil, errors.New("missing HTTP/2 TCP endpoint") + } + h2Headers := additionalHeaders.Clone() + h2Headers.Set("cf-connect-proto", "cf-connect-ip") + h2Headers.Set("pq-enabled", "false") + h2Client, err := newHTTP2Client(dialer, tlsConfig, h2Endpoint, connectUri) + if err != nil { + return nil, nil, nil, nil, fmt.Errorf("failed to create HTTP/2 client: %w", err) + } + ipConn, rsp, err := connectip.DialH2(ctx, h2Client, template, h2Headers) + if err != nil { + if strings.Contains(err.Error(), "tls: access denied") { + return nil, nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service") + } + return nil, nil, nil, nil, fmt.Errorf("failed to dial connect-ip over HTTP/2: %w", err) + } + return nil, nil, ipConn, rsp, nil + } + quicEndpoint, ok := endpoint.(*net.UDPAddr) + if !ok || quicEndpoint == nil { + return nil, nil, nil, nil, errors.New("missing HTTP/3 UDP endpoint") + } + udpConn, err := dialer.ListenPacket(ctx, M.SocksaddrFromNetIP(quicEndpoint.AddrPort())) + if err != nil { + return nil, nil, nil, nil, err + } + conn, err := qtls.Dial( + ctx, + udpConn, + quicEndpoint, + tlsConfig, + quicConfig, + ) + if err != nil { + return nil, nil, nil, nil, err + } + tr := &http3.Transport{ + EnableDatagrams: true, + AdditionalSettings: map[uint64]uint64{ + // official client still sends this out as well, even though + // it's deprecated, see https://datatracker.ietf.org/doc/draft-ietf-masque-h3-datagram/00/ + // SETTINGS_H3_DATAGRAM_00 = 0x0000000000000276 + // https://github.com/cloudflare/quiche/blob/7c66757dbc55b8d0c3653d4b345c6785a181f0b7/quiche/src/h3/frame.rs#L46 + 0x276: 1, + }, + DisableCompression: true, + } + hconn := tr.NewClientConn(conn) + ipConn, rsp, err := connectip.Dial(ctx, hconn, template, "cf-connect-ip", additionalHeaders, true) + if err != nil { + if err.Error() == "CRYPTO_ERROR 0x131 (remote): tls: access denied" { + return udpConn, nil, nil, nil, errors.New("login failed! Please double-check if your tls key and cert is enrolled in the Cloudflare Access service") + } + return udpConn, nil, nil, nil, fmt.Errorf("failed to dial connect-ip: %w", err) + } + err = ipConn.AdvertiseRoute(ctx, []connectip.IPRoute{ + { + IPProtocol: 0, + StartIP: netip.AddrFrom4([4]byte{}), + EndIP: netip.AddrFrom4([4]byte{255, 255, 255, 255}), + }, + { + IPProtocol: 0, + StartIP: netip.AddrFrom16([16]byte{}), + EndIP: netip.AddrFrom16([16]byte{ + 255, 255, 255, 255, + 255, 255, 255, 255, + 255, 255, 255, 255, + 255, 255, 255, 255, + }), + }, + }) + if err != nil { + return udpConn, nil, nil, nil, err + } + return udpConn, tr, ipConn, rsp, nil +} + +func newHTTP2Client(dialer N.Dialer, baseTLSConfig aTLS.Config, endpoint *net.TCPAddr, connectURI string) (*http.Client, error) { + if endpoint == nil { + return nil, errors.New("missing HTTP/2 endpoint") + } + tlsConfig := baseTLSConfig.Clone() + tlsConfig.SetNextProtos([]string{"h2"}) + return &http.Client{ + Transport: &http2.Transport{ + DialTLSContext: func(ctx context.Context, network, _ string, _ *tls.Config) (net.Conn, error) { + conn, err := dialer.DialContext(ctx, network, M.SocksaddrFromNetIP(endpoint.AddrPort())) + if err != nil { + return nil, err + } + tlsConn, err := tlsConfig.Client(conn) + if err != nil { + return nil, err + } + if err := tlsConn.HandshakeContext(ctx); err != nil { + _ = conn.Close() + return nil, err + } + return tlsConn, nil + }, + }, + }, nil +} + +func authorityWithDefaultPort(u *url.URL, defaultPort string) string { + if u == nil { + return "" + } + + host := u.Hostname() + if host == "" { + return u.Host + } + + port := u.Port() + if port == "" { + port = defaultPort + } + + return net.JoinHostPort(host, port) +} + +func proxyDefaultPort(u *url.URL) string { + if u != nil && u.Scheme == "https" { + return "443" + } + return "80" +} diff --git a/transport/masque/options.go b/transport/masque/options.go new file mode 100644 index 00000000..b2722436 --- /dev/null +++ b/transport/masque/options.go @@ -0,0 +1,24 @@ +package masque + +import ( + "net" + "net/netip" + "time" + + tun "github.com/sagernet/sing-tun" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/tls" +) + +type TunnelOptions struct { + Handler tun.Handler + Dialer N.Dialer + Address []netip.Prefix + Endpoint net.Addr + TLSConfig tls.Config + UseHTTP2 bool + UDPTimeout time.Duration + UDPKeepalivePeriod time.Duration + UDPInitialPacketSize uint16 + ReconnectDelay time.Duration +} diff --git a/transport/masque/tunnel.go b/transport/masque/tunnel.go new file mode 100644 index 00000000..c5f65443 --- /dev/null +++ b/transport/masque/tunnel.go @@ -0,0 +1,200 @@ +package masque + +import ( + "context" + "errors" + "fmt" + "net" + "os" + "time" + + connectip "github.com/Diniboy1123/connect-ip-go" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" +) + +type TunnelDevice interface { + ReadPacket(buf []byte) (int, error) + WritePacket(pkt []byte) error +} + +type Tunnel struct { + ctx context.Context + logger logger.ContextLogger + options TunnelOptions + tunDevice Device + tunnelDevice TunnelDevice +} + +func NewTunnel(ctx context.Context, logger logger.ContextLogger, options TunnelOptions) (*Tunnel, error) { + deviceOptions := DeviceOptions{ + Context: ctx, + Logger: logger, + Handler: options.Handler, + UDPTimeout: options.UDPTimeout, + MTU: 1280, + Address: options.Address, + } + tunDevice, err := NewDevice(deviceOptions) + if err != nil { + return nil, E.Cause(err, "create MASQUE device") + } + return &Tunnel{ + ctx: ctx, + logger: logger, + options: options, + tunDevice: tunDevice, + tunnelDevice: NewNetstackAdapter(tunDevice), + }, nil +} + +func (e *Tunnel) Start(resolve bool) error { + if resolve { + err := e.tunDevice.Start() + if err != nil { + return err + } + go e.MaintainTunnel() + } + return nil +} + +func (e *Tunnel) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + if !destination.Addr.IsValid() { + return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination") + } + return e.tunDevice.DialContext(ctx, network, destination) +} + +func (e *Tunnel) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + if !destination.Addr.IsValid() { + return nil, E.Cause(os.ErrInvalid, "invalid non-IP destination") + } + return e.tunDevice.ListenPacket(ctx, destination) +} + +func (e *Tunnel) Close() error { + return e.tunDevice.Close() +} + +func (e *Tunnel) MaintainTunnel() { + packetBufferPool := NewNetBuffer(1280) + timer := time.NewTimer(0) + defer timer.Stop() + for { + select { + case <-e.ctx.Done(): + return + default: + } + e.logger.InfoContext(e.ctx, fmt.Errorf("Establishing MASQUE connection to %s", e.options.Endpoint)) + udpConn, tr, ipConn, rsp, err := ConnectTunnel( + e.ctx, + e.options.Dialer, + e.options.TLSConfig, + DefaultQuicConfig(e.options.UDPKeepalivePeriod, e.options.UDPInitialPacketSize), + "https://cloudflareaccess.com", + e.options.Endpoint, + e.options.UseHTTP2, + ) + if err != nil { + e.logger.InfoContext(e.ctx, fmt.Errorf("Failed to connect tunnel: %v", err)) + timer.Reset(e.options.ReconnectDelay) + select { + case <-e.ctx.Done(): + return + case <-timer.C: + } + continue + } + if rsp.StatusCode != 200 { + e.logger.InfoContext(e.ctx, fmt.Errorf("Tunnel connection failed: %s", rsp.Status)) + ipConn.Close() + if udpConn != nil { + udpConn.Close() + } + if tr != nil { + tr.Close() + } + timer.Reset(e.options.ReconnectDelay) + select { + case <-e.ctx.Done(): + return + case <-timer.C: + } + continue + } + e.logger.InfoContext(e.ctx, "Connected to MASQUE server") + errChan := make(chan error, 2) + go func() { + for { + buf := packetBufferPool.Get() + n, err := e.tunnelDevice.ReadPacket(buf) + if err != nil { + packetBufferPool.Put(buf) + errChan <- fmt.Errorf("failed to read from TUN device: %w", err) + return + } + icmp, err := ipConn.WritePacket(buf[:n]) + if err != nil { + packetBufferPool.Put(buf) + if errors.As(err, new(*connectip.CloseError)) { + errChan <- fmt.Errorf("connection closed while writing to IP connection: %w", err) + return + } + e.logger.InfoContext(e.ctx, fmt.Errorf("Error writing to IP connection: %v, continuing...", err)) + continue + } + packetBufferPool.Put(buf) + if len(icmp) > 0 { + if err := e.tunnelDevice.WritePacket(icmp); err != nil { + if errors.As(err, new(*connectip.CloseError)) { + errChan <- fmt.Errorf("connection closed while writing ICMP to TUN device: %w", err) + return + } + e.logger.InfoContext(e.ctx, fmt.Errorf("Error writing ICMP to TUN device: %v, continuing...", err)) + } + } + } + }() + go func() { + buf := packetBufferPool.Get() + defer packetBufferPool.Put(buf) + for { + n, err := ipConn.ReadPacket(buf, true) + if err != nil { + if e.options.UseHTTP2 { + errChan <- fmt.Errorf("connection closed while reading from IP connection: %w", err) + return + } + if errors.As(err, new(*connectip.CloseError)) { + errChan <- fmt.Errorf("connection closed while reading from IP connection: %w", err) + return + } + e.logger.InfoContext(e.ctx, fmt.Errorf("Error reading from IP connection: %v, continuing...", err)) + continue + } + if err := e.tunnelDevice.WritePacket(buf[:n]); err != nil { + errChan <- fmt.Errorf("failed to write to TUN device: %w", err) + return + } + } + }() + err = <-errChan + e.logger.InfoContext(e.ctx, fmt.Errorf("Tunnel connection lost: %v. Reconnecting...", err)) + ipConn.Close() + if udpConn != nil { + udpConn.Close() + } + if tr != nil { + tr.Close() + } + timer.Reset(e.options.ReconnectDelay) + select { + case <-e.ctx.Done(): + return + case <-timer.C: + } + } +} diff --git a/transport/masque/utils.go b/transport/masque/utils.go new file mode 100644 index 00000000..b99b4459 --- /dev/null +++ b/transport/masque/utils.go @@ -0,0 +1,326 @@ +package masque + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/base64" + "encoding/hex" + "errors" + "log" + "math/big" + "net" + "strconv" + "strings" + "time" + + "github.com/sagernet/quic-go" +) + +// PortMapping represents a network port forwarding rule. +type PortMapping struct { + BindAddress string // The address to bind the local port. + LocalPort int // The local port number. + RemoteIP string // The remote destination IP address. + RemotePort int // The remote destination port number. +} + +// GenerateRandomAndroidSerial generates a random 8-byte Android-like device identifier +// and returns it as a hexadecimal string. +// +// Returns: +// - string: A randomly generated 16-character hexadecimal serial number. +// - error: An error if random data generation fails. +func GenerateRandomAndroidSerial() (string, error) { + serial := make([]byte, 8) + if _, err := rand.Read(serial); err != nil { + return "", err + } + return hex.EncodeToString(serial), nil +} + +// GenerateRandomWgPubkey generates a random 32-byte WireGuard like public key +// and returns it as a base64-encoded string. +// +// Returns: +// - string: A randomly generated WireGuard like public key in base64 format. +// - error: An error if random data generation fails. +func GenerateRandomWgPubkey() (string, error) { + publicKey := make([]byte, 32) + if _, err := rand.Read(publicKey); err != nil { + return "", err + } + return base64.StdEncoding.EncodeToString(publicKey), nil +} + +// TimeAsCfString formats a given time.Time into a Cloudflare-compatible string format. +// +// The format follows the standard: "YYYY-MM-DDTHH:MM:SS.sss-07:00". +// +// Parameters: +// - t: time.Time to format. +// +// Returns: +// - string: The formatted time string. +func TimeAsCfString(t time.Time) string { + return t.Format("2006-01-02T15:04:05.000-07:00") +} + +// GenerateEcKeyPair generates a new ECDSA key pair using the P-256 curve. +// +// Returns: +// - []byte: The marshalled private key in ASN.1 DER format. +// - []byte: The marshalled public key in PKIX format. +// - error: An error if key generation or marshalling fails. +func GenerateEcKeyPair() ([]byte, []byte, error) { + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, nil, err + } + + marshalledPrivKey, err := x509.MarshalECPrivateKey(privKey) + if err != nil { + return nil, nil, err + } + + marshalledPubKey, err := x509.MarshalPKIXPublicKey(&privKey.PublicKey) + if err != nil { + return nil, nil, err + } + + return marshalledPrivKey, marshalledPubKey, nil +} + +// GenerateCert creates a self-signed certificate using the provided ECDSA private and public keys. +// +// The certificate is valid for 24 hours. +// +// Parameters: +// - privKey: *ecdsa.PrivateKey - The private key to sign the certificate. +// - pubKey: *ecdsa.PublicKey - The public key to include in the certificate. +// +// Returns: +// - [][]byte: A slice containing the certificate in DER format. +// - error: An error if certificate generation fails. +func GenerateCert(privKey *ecdsa.PrivateKey, pubKey *ecdsa.PublicKey) ([][]byte, error) { + cert, err := x509.CreateCertificate(rand.Reader, &x509.Certificate{ + SerialNumber: big.NewInt(0), + NotBefore: time.Now(), + NotAfter: time.Now().Add(1 * 24 * time.Hour), + }, &x509.Certificate{}, &privKey.PublicKey, privKey) + if err != nil { + return nil, err + } + + return [][]byte{cert}, nil +} + +// DefaultQuicConfig returns a MASQUE-compatible default QUIC configuration. +// +// When initialPacketSize is 0, Path MTU Discovery remains enabled. +// +// Parameters: +// - keepalivePeriod: time.Duration - The duration for sending QUIC keep-alive packets. +// - initialPacketSize: uint16 - The custom initial size of QUIC packets (0 = auto with PMTU discovery). +// +// Returns: +// - *quic.Config: A pointer to a configured QUIC configuration object. +func DefaultQuicConfig(keepalivePeriod time.Duration, initialPacketSize uint16) *quic.Config { + cfg := &quic.Config{ + EnableDatagrams: true, + KeepAlivePeriod: keepalivePeriod, + } + + if initialPacketSize > 0 { + cfg.InitialPacketSize = initialPacketSize + cfg.DisablePathMTUDiscovery = true + } + + return cfg +} + +// parsePortMapping is an internal helper function that parses a port mapping string into its components. +// +// It handles IPv6 addresses enclosed in brackets and various format edge cases. +// +// Parameters: +// - port: string - The port mapping string. +// +// Returns: +// - string: The bind address. +// - int: The local port. +// - string: The remote hostname/IP. +// - int: The remote port. +// - error: An error if parsing fails. +func parsePortMapping(port string) (bindAddress string, localPort int, remoteHost string, remotePort int, err error) { + parts := strings.Split(port, ":") + + // Handle IPv6 addresses (which are enclosed in brackets) + if len(parts) >= 4 && strings.HasPrefix(parts[0], "[") && strings.Contains(parts[0], "]") { + bindAddress = parts[0] + parts = parts[1:] // Shift parts forward + } else if len(parts) == 3 { + bindAddress = "localhost" // Default to localhost + } else if len(parts) == 4 { + bindAddress = parts[0] + parts = parts[1:] // Shift forward + } else { + return "", 0, "", 0, errors.New("invalid port mapping format (expected format: [bind_address:]local_port:remote_host:remote_port)") + } + + // Parse local port + localPort, err = strconv.Atoi(parts[0]) + if err != nil || localPort <= 0 || localPort > 65535 { + return "", 0, "", 0, errors.New("invalid local port") + } + + // Validate remote host (allow both hostnames and IPs) + remoteHost = parts[1] + if net.ParseIP(remoteHost) == nil && !isValidHostname(remoteHost) { + return "", 0, "", 0, errors.New("invalid remote hostname/IP") + } + + // Parse remote port + remotePort, err = strconv.Atoi(parts[2]) + if err != nil || remotePort <= 0 || remotePort > 65535 { + return "", 0, "", 0, errors.New("invalid remote port") + } + + // If bindAddress is an IPv6 address, remove brackets for proper binding + if strings.HasPrefix(bindAddress, "[") && strings.HasSuffix(bindAddress, "]") { + bindAddress = strings.Trim(bindAddress, "[]") + } + + // Convert "localhost" or hostnames to actual addresses + if bindAddress == "*" { + bindAddress = "0.0.0.0" // Allow all interfaces + } + + // Validate bind address (support both IPs and hostnames) + bindAddress, err = resolveBindAddress(bindAddress) + if err != nil { + return "", 0, "", 0, errors.New("invalid local address: " + err.Error()) + } + + remoteHost, err = resolveBindAddress(remoteHost) + if err != nil { + return "", 0, "", 0, errors.New("invalid remote address: " + err.Error()) + } + + return bindAddress, localPort, remoteHost, remotePort, nil +} + +// ParsePortMapping parses a port mapping string into a structured PortMapping. +// +// The expected format is: `[bind_address:]local_port:remote_host:remote_port`. +// +// Parameters: +// - port: string - The port mapping string. +// +// Returns: +// - PortMapping: A structured representation of the parsed port mapping. +// - error: An error if the parsing fails. +func ParsePortMapping(port string) (PortMapping, error) { + bindAddress, localPort, remoteHost, remotePort, err := parsePortMapping(port) + if err != nil { + return PortMapping{}, err + } + + return PortMapping{ + BindAddress: bindAddress, + LocalPort: localPort, + RemoteIP: remoteHost, + RemotePort: remotePort, + }, nil +} + +// resolveBindAddress resolves a hostname or IP to its string representation. +// +// Parameters: +// - addr: string - The hostname or IP. +// +// Returns: +// - string: The resolved IP address. +// - error: An error if resolution fails. +func resolveBindAddress(addr string) (string, error) { + tcpAddr, err := net.ResolveTCPAddr("tcp", addr+":0") // Resolve the address + if err != nil { + return "", err + } + return tcpAddr.IP.String(), nil // Return resolved IP +} + +// isValidHostname checks if a given hostname is valid. +// Pretty ugly for now, needs to be refactored. +// +// Parameters: +// - hostname: string - The hostname to validate. +// +// Returns: +// - bool: True if valid, false otherwise. +func isValidHostname(hostname string) bool { + // Must contain at least one dot (.) unless it's "localhost" + if hostname == "localhost" { + return true + } + return strings.Contains(hostname, ".") +} + +// LoginToBase64 encodes a username and password into a base64-encoded string in "username:password" format. +// This is commonly used for HTTP Basic Authentication. +// +// Parameters: +// - username: string - The username to encode. +// - password: string - The password to encode. +// +// Returns: +// - string: The base64-encoded "username:password" string. +func LoginToBase64(username, password string) string { + return base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) +} + +// CheckIfname validates a network interface name according to the following rules: +// - Must not be empty. +// - Should not exceed 15 characters (warning if it does). +// - Should not contain non-ASCII characters (warning if it does). +// - Should not contain invalid characters: '/', whitespace, or control characters. +// +// Parameters: +// - name: string - The interface name to validate. +// +// Returns: +// - error: An error if the name is invalid, or nil if valid. +func CheckIfname(name string) error { + if name == "" { + return errors.New("interface name cannot be empty") + } + + if len(name) >= 16 { + log.Printf("Warning: interface name '%s' is longer than %d characters", name, 16-1) + } + + var invalidChar bool + var hasWhitespace bool + + for _, r := range name { + if r > 127 { + invalidChar = true + break + } + if r == '/' || r == ' ' || strings.ContainsRune("\t\n\v\f\r", r) { + hasWhitespace = true + break + } + } + + if invalidChar { + log.Printf("Warning: interface name contains non-ASCII character") + } + + if hasWhitespace { + return errors.New("interface name contains invalid character: '/' or whitespace") + } + + return nil +}