Compare commits

..

31 Commits

Author SHA1 Message Date
世界
003423745e mitm: Refactor & Add url 2025-02-04 15:00:59 +08:00
世界
fb3007fa80 mitm: Minor fixes 2025-02-03 11:32:37 +08:00
世界
5361d2acec mitm: Add /mitm/mobileconfig and /mitm/certificate clash api endpoints 2025-02-03 09:09:41 +08:00
世界
1fe983a81b mitm: Fix HTTP2 support & Add print 2025-02-03 08:20:26 +08:00
世界
b01fe5d364 Fix override address 2025-02-02 23:17:31 +08:00
世界
74920b44ac mitm: Add HTTP2 support 2025-02-02 21:36:09 +08:00
世界
5e28a80e63 Add Surge MITM and scripts 2025-02-02 17:27:29 +08:00
世界
b55bfca7de documentation: Bump version 2025-02-02 07:21:34 +08:00
世界
a0dc394c8f Fix WireGuard panic 2025-02-02 07:21:33 +08:00
世界
87f3423d7e Fix query options leakage 2025-02-02 07:21:33 +08:00
世界
244243f206 Fix domain resolver for DNS server 2025-02-02 07:21:33 +08:00
世界
89855ff548 documentation: Bump version 2025-02-02 07:21:27 +08:00
世界
8a388e9c90 documentation: Fix fakeip example 2025-02-02 07:21:27 +08:00
世界
6bf39156ec release: Skip testflight when another build in review 2025-02-02 07:21:27 +08:00
世界
ab021bee74 Update quic-go to v0.49.0-beta.1 2025-02-02 07:21:27 +08:00
世界
6fd95e157a Fix fakeip not started 2025-02-02 07:21:27 +08:00
世界
6c0e2bf526 Fix missing default domain resolver in route 2025-02-02 07:21:15 +08:00
世界
4f61fc20e0 Fix missing default store value 2025-02-02 07:21:15 +08:00
世界
438405bbf1 documentation: Bump version 2025-02-02 07:21:14 +08:00
世界
b0a2ed9f7e documentation: Remove outdated icons 2025-02-02 07:21:14 +08:00
世界
f4d5823bb2 documentation: Certificate store 2025-02-02 07:21:14 +08:00
世界
e8b43a97f7 documentation: TLS fragment 2025-02-02 07:21:13 +08:00
世界
00460f15f6 documentation: Outbound domain resolver 2025-02-02 07:21:13 +08:00
世界
ca33246c9e documentation: Refactor DNS 2025-02-02 07:21:12 +08:00
世界
e04a4181f4 Add certificate store 2025-02-02 07:21:12 +08:00
世界
ff4f455f25 Add TLS fragment support 2025-02-02 07:21:11 +08:00
世界
9a7c0d9136 refactor: Outbound domain resolver 2025-02-02 07:21:11 +08:00
世界
57de88b9c9 refactor: DNS 2025-02-02 07:21:10 +08:00
世界
7f79458b4f Minor updates 2025-02-01 19:49:33 +08:00
世界
9b4c11ba95 Fix rule-set not closed 2025-02-01 19:49:33 +08:00
世界
27c31eac5d Fix local rule-set not updated 2025-02-01 19:42:21 +08:00
90 changed files with 7325 additions and 16002 deletions

View File

@@ -50,17 +50,21 @@ type CacheFile interface {
StoreSelected(group string, selected string) error StoreSelected(group string, selected string) error
LoadGroupExpand(group string) (isExpand bool, loaded bool) LoadGroupExpand(group string) (isExpand bool, loaded bool)
StoreGroupExpand(group string, expand bool) error StoreGroupExpand(group string, expand bool) error
LoadRuleSet(tag string) *SavedRuleSet LoadRuleSet(tag string) *SavedBinary
SaveRuleSet(tag string, set *SavedRuleSet) error SaveRuleSet(tag string, set *SavedBinary) error
LoadScript(tag string) *SavedBinary
SaveScript(tag string, script *SavedBinary) error
SurgePersistentStoreRead(key string) string
SurgePersistentStoreWrite(key string, value string) error
} }
type SavedRuleSet struct { type SavedBinary struct {
Content []byte Content []byte
LastUpdated time.Time LastUpdated time.Time
LastEtag string LastEtag string
} }
func (s *SavedRuleSet) MarshalBinary() ([]byte, error) { func (s *SavedBinary) MarshalBinary() ([]byte, error) {
var buffer bytes.Buffer var buffer bytes.Buffer
err := binary.Write(&buffer, binary.BigEndian, uint8(1)) err := binary.Write(&buffer, binary.BigEndian, uint8(1))
if err != nil { if err != nil {
@@ -81,7 +85,7 @@ func (s *SavedRuleSet) MarshalBinary() ([]byte, error) {
return buffer.Bytes(), nil return buffer.Bytes(), nil
} }
func (s *SavedRuleSet) UnmarshalBinary(data []byte) error { func (s *SavedBinary) UnmarshalBinary(data []byte) error {
reader := bytes.NewReader(data) reader := bytes.NewReader(data)
var version uint8 var version uint8
err := binary.Read(reader, binary.BigEndian, &version) err := binary.Read(reader, binary.BigEndian, &version)

View File

@@ -2,6 +2,8 @@ package adapter
import ( import (
"context" "context"
"crypto/tls"
"net/http"
"net/netip" "net/netip"
"time" "time"
@@ -57,6 +59,8 @@ type InboundContext struct {
Domain string Domain string
Client string Client string
SniffContext any SniffContext any
HTTPRequest *http.Request
ClientHello *tls.ClientHelloInfo
// cache // cache
@@ -73,6 +77,7 @@ type InboundContext struct {
UDPTimeout time.Duration UDPTimeout time.Duration
TLSFragment bool TLSFragment bool
TLSFragmentFallbackDelay time.Duration TLSFragmentFallbackDelay time.Duration
MITM *option.MITMRouteOptions
NetworkStrategy *C.NetworkStrategy NetworkStrategy *C.NetworkStrategy
NetworkType []C.InterfaceType NetworkType []C.InterfaceType

View File

@@ -1,6 +1,8 @@
package adapter package adapter
import E "github.com/sagernet/sing/common/exceptions" import (
E "github.com/sagernet/sing/common/exceptions"
)
type StartStage uint8 type StartStage uint8
@@ -45,6 +47,9 @@ type LifecycleService interface {
func Start(stage StartStage, services ...Lifecycle) error { func Start(stage StartStage, services ...Lifecycle) error {
for _, service := range services { for _, service := range services {
if service == nil {
continue
}
err := service.Start(stage) err := service.Start(stage)
if err != nil { if err != nil {
return err return err

15
adapter/mitm.go Normal file
View File

@@ -0,0 +1,15 @@
package adapter
import (
"context"
"crypto/x509"
"net"
N "github.com/sagernet/sing/common/network"
)
type MITMEngine interface {
Lifecycle
ExportCertificate() *x509.Certificate
NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata InboundContext, onClose N.CloseHandlerFunc)
}

54
adapter/script.go Normal file
View File

@@ -0,0 +1,54 @@
package adapter
import (
"context"
"net/http"
"sync"
"time"
)
type ScriptManager interface {
Lifecycle
Scripts() []Script
Script(name string) (Script, bool)
SurgeCache() *SurgeInMemoryCache
}
type SurgeInMemoryCache struct {
sync.RWMutex
Data map[string]string
}
type Script interface {
Type() string
Tag() string
StartContext(ctx context.Context, startContext *HTTPStartContext) error
PostStart() error
Close() error
}
type SurgeScript interface {
Script
ExecuteGeneric(ctx context.Context, scriptType string, timeout time.Duration, arguments []string) error
ExecuteHTTPRequest(ctx context.Context, timeout time.Duration, request *http.Request, body []byte, binaryBody bool, arguments []string) (*HTTPRequestScriptResult, error)
ExecuteHTTPResponse(ctx context.Context, timeout time.Duration, request *http.Request, response *http.Response, body []byte, binaryBody bool, arguments []string) (*HTTPResponseScriptResult, error)
}
type HTTPRequestScriptResult struct {
URL string
Headers http.Header
Body []byte
Response *HTTPRequestScriptResponse
}
type HTTPRequestScriptResponse struct {
Status int
Headers http.Header
Body []byte
}
type HTTPResponseScriptResult struct {
Status int
Headers http.Header
Body []byte
}

39
box.go
View File

@@ -23,9 +23,11 @@ import (
"github.com/sagernet/sing-box/experimental/cachefile" "github.com/sagernet/sing-box/experimental/cachefile"
"github.com/sagernet/sing-box/experimental/libbox/platform" "github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/mitm"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/protocol/direct" "github.com/sagernet/sing-box/protocol/direct"
"github.com/sagernet/sing-box/route" "github.com/sagernet/sing-box/route"
"github.com/sagernet/sing-box/script"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format" F "github.com/sagernet/sing/common/format"
@@ -48,6 +50,8 @@ type Box struct {
dnsRouter *dns.Router dnsRouter *dns.Router
connection *route.ConnectionManager connection *route.ConnectionManager
router *route.Router router *route.Router
script *script.Manager
mitm adapter.MITMEngine //*mitm.Engine
services []adapter.LifecycleService services []adapter.LifecycleService
done chan struct{} done chan struct{}
} }
@@ -173,7 +177,7 @@ func New(options Options) (*Box, error) {
return nil, E.Cause(err, "initialize network manager") return nil, E.Cause(err, "initialize network manager")
} }
service.MustRegister[adapter.NetworkManager](ctx, networkManager) service.MustRegister[adapter.NetworkManager](ctx, networkManager)
connectionManager := route.NewConnectionManager(logFactory.NewLogger("connection")) connectionManager := route.NewConnectionManager(ctx, logFactory.NewLogger("connection"))
service.MustRegister[adapter.ConnectionManager](ctx, connectionManager) service.MustRegister[adapter.ConnectionManager](ctx, connectionManager)
router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions) router := route.NewRouter(ctx, logFactory, routeOptions, dnsOptions)
service.MustRegister[adapter.Router](ctx, router) service.MustRegister[adapter.Router](ctx, router)
@@ -181,8 +185,8 @@ func New(options Options) (*Box, error) {
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize router") return nil, E.Cause(err, "initialize router")
} }
ntpOptions := common.PtrValueOrDefault(options.NTP)
var timeService *tls.TimeServiceWrapper var timeService *tls.TimeServiceWrapper
ntpOptions := common.PtrValueOrDefault(options.NTP)
if ntpOptions.Enabled { if ntpOptions.Enabled {
timeService = new(tls.TimeServiceWrapper) timeService = new(tls.TimeServiceWrapper)
service.MustRegister[ntp.TimeService](ctx, timeService) service.MustRegister[ntp.TimeService](ctx, timeService)
@@ -202,7 +206,7 @@ func New(options Options) (*Box, error) {
transportOptions.Options, transportOptions.Options,
) )
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize inbound[", i, "]") return nil, E.Cause(err, "initialize DNS server[", i, "]")
} }
} }
err = dnsRouter.Initialize(dnsOptions.Rules) err = dnsRouter.Initialize(dnsOptions.Rules)
@@ -225,7 +229,7 @@ func New(options Options) (*Box, error) {
endpointOptions.Options, endpointOptions.Options,
) )
if err != nil { if err != nil {
return nil, E.Cause(err, "initialize inbound[", i, "]") return nil, E.Cause(err, "initialize endpoint[", i, "]")
} }
} }
for i, inboundOptions := range options.Inbounds { for i, inboundOptions := range options.Inbounds {
@@ -289,6 +293,11 @@ func New(options Options) (*Box, error) {
"local", "local",
option.LocalDNSServerOptions{}, option.LocalDNSServerOptions{},
))) )))
scriptManager, err := script.NewManager(ctx, logFactory, options.Scripts)
if err != nil {
return nil, E.Cause(err, "initialize script manager")
}
service.MustRegister[adapter.ScriptManager](ctx, scriptManager)
if platformInterface != nil { if platformInterface != nil {
err = platformInterface.Initialize(networkManager) err = platformInterface.Initialize(networkManager)
if err != nil { if err != nil {
@@ -338,6 +347,16 @@ func New(options Options) (*Box, error) {
timeService.TimeService = ntpService timeService.TimeService = ntpService
services = append(services, adapter.NewLifecycleService(ntpService, "ntp service")) services = append(services, adapter.NewLifecycleService(ntpService, "ntp service"))
} }
mitmOptions := common.PtrValueOrDefault(options.MITM)
var mitmEngine adapter.MITMEngine
if mitmOptions.Enabled {
engine, err := mitm.NewEngine(ctx, logFactory.NewLogger("mitm"), mitmOptions)
if err != nil {
return nil, E.Cause(err, "create MITM engine")
}
service.MustRegister[adapter.MITMEngine](ctx, engine)
mitmEngine = engine
}
return &Box{ return &Box{
network: networkManager, network: networkManager,
endpoint: endpointManager, endpoint: endpointManager,
@@ -347,6 +366,8 @@ func New(options Options) (*Box, error) {
dnsRouter: dnsRouter, dnsRouter: dnsRouter,
connection: connectionManager, connection: connectionManager,
router: router, router: router,
script: scriptManager,
mitm: mitmEngine,
createdAt: createdAt, createdAt: createdAt,
logFactory: logFactory, logFactory: logFactory,
logger: logFactory.Logger(), logger: logFactory.Logger(),
@@ -405,11 +426,11 @@ func (s *Box) preStart() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint) err = adapter.Start(adapter.StartStateInitialize, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.script, s.mitm, s.outbound, s.inbound, s.endpoint)
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router) err = adapter.Start(adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router, s.script, s.mitm)
if err != nil { if err != nil {
return err return err
} }
@@ -433,7 +454,7 @@ func (s *Box) start() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.inbound, s.endpoint) err = adapter.Start(adapter.StartStatePostStart, s.outbound, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.script, s.mitm, s.inbound, s.endpoint)
if err != nil { if err != nil {
return err return err
} }
@@ -441,7 +462,7 @@ func (s *Box) start() error {
if err != nil { if err != nil {
return err return err
} }
err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.outbound, s.inbound, s.endpoint) err = adapter.Start(adapter.StartStateStarted, s.network, s.dnsTransport, s.dnsRouter, s.connection, s.router, s.script, s.mitm, s.outbound, s.inbound, s.endpoint)
if err != nil { if err != nil {
return err return err
} }
@@ -460,7 +481,7 @@ func (s *Box) Close() error {
close(s.done) close(s.done)
} }
err := common.Close( err := common.Close(
s.inbound, s.outbound, s.endpoint, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network, s.inbound, s.outbound, s.endpoint, s.mitm, s.script, s.router, s.connection, s.dnsRouter, s.dnsTransport, s.network,
) )
for _, lifecycleService := range s.services { for _, lifecycleService := range s.services {
err = E.Append(err, lifecycleService.Close(), func(err error) error { err = E.Append(err, lifecycleService.Close(), func(err error) error {

View File

@@ -30,7 +30,7 @@ func init() {
} }
func generateTLSKeyPair(serverName string) error { func generateTLSKeyPair(serverName string) error {
privateKeyPem, publicKeyPem, err := tls.GenerateKeyPair(time.Now, serverName, time.Now().AddDate(0, flagGenerateTLSKeyPairMonths, 0)) privateKeyPem, publicKeyPem, err := tls.GenerateCertificate(nil, nil, time.Now, serverName, time.Now().AddDate(0, flagGenerateTLSKeyPairMonths, 0))
if err != nil { if err != nil {
return err return err
} }

View File

@@ -35,7 +35,6 @@ type DefaultDialer struct {
udpListener net.ListenConfig udpListener net.ListenConfig
udpAddr4 string udpAddr4 string
udpAddr6 string udpAddr6 string
isWireGuardListener bool
networkManager adapter.NetworkManager networkManager adapter.NetworkManager
networkStrategy *C.NetworkStrategy networkStrategy *C.NetworkStrategy
defaultNetworkStrategy bool defaultNetworkStrategy bool
@@ -183,11 +182,6 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
} }
setMultiPathTCP(&dialer4) setMultiPathTCP(&dialer4)
} }
if options.IsWireGuardListener {
for _, controlFn := range WgControlFns {
listener.Control = control.Append(listener.Control, controlFn)
}
}
tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen) tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -204,7 +198,6 @@ func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDial
udpListener: listener, udpListener: listener,
udpAddr4: udpAddr4, udpAddr4: udpAddr4,
udpAddr6: udpAddr6, udpAddr6: udpAddr6,
isWireGuardListener: options.IsWireGuardListener,
networkManager: networkManager, networkManager: networkManager,
networkStrategy: networkStrategy, networkStrategy: networkStrategy,
defaultNetworkStrategy: defaultNetworkStrategy, defaultNetworkStrategy: defaultNetworkStrategy,

View File

@@ -29,16 +29,18 @@ func (d *DetourDialer) Start() error {
} }
func (d *DetourDialer) Dialer() (N.Dialer, error) { func (d *DetourDialer) Dialer() (N.Dialer, error) {
d.initOnce.Do(func() { d.initOnce.Do(d.init)
var loaded bool
d.dialer, loaded = d.outboundManager.Outbound(d.detour)
if !loaded {
d.initErr = E.New("outbound detour not found: ", d.detour)
}
})
return d.dialer, d.initErr return d.dialer, d.initErr
} }
func (d *DetourDialer) init() {
var loaded bool
d.dialer, loaded = d.outboundManager.Outbound(d.detour)
if !loaded {
d.initErr = E.New("outbound detour not found: ", d.detour)
}
}
func (d *DetourDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (d *DetourDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
dialer, err := d.Dialer() dialer, err := d.Dialer()
if err != nil { if err != nil {

View File

@@ -16,59 +16,82 @@ import (
"github.com/sagernet/sing/service" "github.com/sagernet/sing/service"
) )
type Options struct {
Context context.Context
Options option.DialerOptions
RemoteIsDomain bool
DirectResolver bool
ResolverOnDetour bool
}
// TODO: merge with NewWithOptions
func New(ctx context.Context, options option.DialerOptions, remoteIsDomain bool) (N.Dialer, error) { func New(ctx context.Context, options option.DialerOptions, remoteIsDomain bool) (N.Dialer, error) {
if options.IsWireGuardListener { return NewWithOptions(Options{
return NewDefault(ctx, options) Context: ctx,
} Options: options,
RemoteIsDomain: remoteIsDomain,
})
}
func NewWithOptions(options Options) (N.Dialer, error) {
dialOptions := options.Options
var ( var (
dialer N.Dialer dialer N.Dialer
err error err error
) )
if options.Detour != "" { if dialOptions.Detour != "" {
outboundManager := service.FromContext[adapter.OutboundManager](ctx) outboundManager := service.FromContext[adapter.OutboundManager](options.Context)
if outboundManager == nil { if outboundManager == nil {
return nil, E.New("missing outbound manager") return nil, E.New("missing outbound manager")
} }
dialer = NewDetour(outboundManager, options.Detour) dialer = NewDetour(outboundManager, dialOptions.Detour)
} else { } else {
dialer, err = NewDefault(ctx, options) dialer, err = NewDefault(options.Context, dialOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
if remoteIsDomain && options.Detour == "" { if options.RemoteIsDomain && (dialOptions.Detour == "" || options.ResolverOnDetour) {
networkManager := service.FromContext[adapter.NetworkManager](ctx) networkManager := service.FromContext[adapter.NetworkManager](options.Context)
dnsTransport := service.FromContext[adapter.DNSTransportManager](ctx) dnsTransport := service.FromContext[adapter.DNSTransportManager](options.Context)
var defaultOptions adapter.NetworkOptions var defaultOptions adapter.NetworkOptions
if networkManager != nil { if networkManager != nil {
defaultOptions = networkManager.DefaultOptions() defaultOptions = networkManager.DefaultOptions()
} }
var ( var (
server string
dnsQueryOptions adapter.DNSQueryOptions dnsQueryOptions adapter.DNSQueryOptions
resolveFallbackDelay time.Duration resolveFallbackDelay time.Duration
) )
if options.DomainResolver != nil && options.DomainResolver.Server != "" { if dialOptions.DomainResolver != nil && dialOptions.DomainResolver.Server != "" {
transport, loaded := dnsTransport.Transport(options.DomainResolver.Server) var transport adapter.DNSTransport
if !loaded { if !options.DirectResolver {
return nil, E.New("domain resolver not found: " + options.DomainResolver.Server) var loaded bool
transport, loaded = dnsTransport.Transport(dialOptions.DomainResolver.Server)
if !loaded {
return nil, E.New("domain resolver not found: " + dialOptions.DomainResolver.Server)
}
} }
var strategy C.DomainStrategy var strategy C.DomainStrategy
if options.DomainResolver.Strategy != option.DomainStrategy(C.DomainStrategyAsIS) { if dialOptions.DomainResolver.Strategy != option.DomainStrategy(C.DomainStrategyAsIS) {
strategy = C.DomainStrategy(options.DomainResolver.Strategy) strategy = C.DomainStrategy(dialOptions.DomainResolver.Strategy)
} else if } else if
//nolint:staticcheck //nolint:staticcheck
options.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) { dialOptions.DomainStrategy != option.DomainStrategy(C.DomainStrategyAsIS) {
//nolint:staticcheck //nolint:staticcheck
strategy = C.DomainStrategy(options.DomainStrategy) strategy = C.DomainStrategy(dialOptions.DomainStrategy)
} }
server = dialOptions.DomainResolver.Server
dnsQueryOptions = adapter.DNSQueryOptions{ dnsQueryOptions = adapter.DNSQueryOptions{
Transport: transport, Transport: transport,
Strategy: strategy, Strategy: strategy,
DisableCache: options.DomainResolver.DisableCache, DisableCache: dialOptions.DomainResolver.DisableCache,
RewriteTTL: options.DomainResolver.RewriteTTL, RewriteTTL: dialOptions.DomainResolver.RewriteTTL,
ClientSubnet: options.DomainResolver.ClientSubnet.Build(netip.Prefix{}), ClientSubnet: dialOptions.DomainResolver.ClientSubnet.Build(netip.Prefix{}),
} }
resolveFallbackDelay = time.Duration(options.FallbackDelay) resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay)
} else if options.DirectResolver {
return nil, E.New("missing domain resolver for domain server address")
} else if defaultOptions.DomainResolver != "" { } else if defaultOptions.DomainResolver != "" {
dnsQueryOptions = defaultOptions.DomainResolveOptions dnsQueryOptions = defaultOptions.DomainResolveOptions
transport, loaded := dnsTransport.Transport(defaultOptions.DomainResolver) transport, loaded := dnsTransport.Transport(defaultOptions.DomainResolver)
@@ -76,14 +99,15 @@ func New(ctx context.Context, options option.DialerOptions, remoteIsDomain bool)
return nil, E.New("default domain resolver not found: " + defaultOptions.DomainResolver) return nil, E.New("default domain resolver not found: " + defaultOptions.DomainResolver)
} }
dnsQueryOptions.Transport = transport dnsQueryOptions.Transport = transport
resolveFallbackDelay = time.Duration(options.FallbackDelay) resolveFallbackDelay = time.Duration(dialOptions.FallbackDelay)
} else { } else {
deprecated.Report(ctx, deprecated.OptionMissingDomainResolver) deprecated.Report(options.Context, deprecated.OptionMissingDomainResolver)
} }
dialer = NewResolveDialer( dialer = NewResolveDialer(
ctx, options.Context,
dialer, dialer,
options.Detour == "" && !options.TCPFastOpen, dialOptions.Detour == "" && !dialOptions.TCPFastOpen,
server,
dnsQueryOptions, dnsQueryOptions,
resolveFallbackDelay, resolveFallbackDelay,
) )

View File

@@ -3,12 +3,14 @@ package dialer
import ( import (
"context" "context"
"net" "net"
"sync"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service" "github.com/sagernet/sing/service"
@@ -30,20 +32,26 @@ type ParallelInterfaceResolveDialer interface {
} }
type resolveDialer struct { type resolveDialer struct {
transport adapter.DNSTransportManager
router adapter.DNSRouter router adapter.DNSRouter
dialer N.Dialer dialer N.Dialer
parallel bool parallel bool
server string
initOnce sync.Once
initErr error
queryOptions adapter.DNSQueryOptions queryOptions adapter.DNSQueryOptions
fallbackDelay time.Duration fallbackDelay time.Duration
} }
func NewResolveDialer(ctx context.Context, dialer N.Dialer, parallel bool, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ResolveDialer { func NewResolveDialer(ctx context.Context, dialer N.Dialer, parallel bool, server string, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ResolveDialer {
return &resolveDialer{ return &resolveDialer{
service.FromContext[adapter.DNSRouter](ctx), transport: service.FromContext[adapter.DNSTransportManager](ctx),
dialer, router: service.FromContext[adapter.DNSRouter](ctx),
parallel, dialer: dialer,
queryOptions, parallel: parallel,
fallbackDelay, server: server,
queryOptions: queryOptions,
fallbackDelay: fallbackDelay,
} }
} }
@@ -52,20 +60,43 @@ type resolveParallelNetworkDialer struct {
dialer ParallelInterfaceDialer dialer ParallelInterfaceDialer
} }
func NewResolveParallelInterfaceDialer(ctx context.Context, dialer ParallelInterfaceDialer, parallel bool, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ParallelInterfaceResolveDialer { func NewResolveParallelInterfaceDialer(ctx context.Context, dialer ParallelInterfaceDialer, parallel bool, server string, queryOptions adapter.DNSQueryOptions, fallbackDelay time.Duration) ParallelInterfaceResolveDialer {
return &resolveParallelNetworkDialer{ return &resolveParallelNetworkDialer{
resolveDialer{ resolveDialer{
service.FromContext[adapter.DNSRouter](ctx), transport: service.FromContext[adapter.DNSTransportManager](ctx),
dialer, router: service.FromContext[adapter.DNSRouter](ctx),
parallel, dialer: dialer,
queryOptions, parallel: parallel,
fallbackDelay, server: server,
queryOptions: queryOptions,
fallbackDelay: fallbackDelay,
}, },
dialer, dialer,
} }
} }
func (d *resolveDialer) initialize() error {
d.initOnce.Do(d.initServer)
return d.initErr
}
func (d *resolveDialer) initServer() {
if d.server == "" {
return
}
transport, loaded := d.transport.Transport(d.server)
if !loaded {
d.initErr = E.New("domain resolver not found: " + d.server)
return
}
d.queryOptions.Transport = transport
}
func (d *resolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (d *resolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() { if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination) return d.dialer.DialContext(ctx, network, destination)
} }
@@ -82,6 +113,10 @@ func (d *resolveDialer) DialContext(ctx context.Context, network string, destina
} }
func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() { if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination) return d.dialer.ListenPacket(ctx, destination)
} }
@@ -106,6 +141,10 @@ func (d *resolveDialer) Upstream() any {
} }
func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() { if !destination.IsFqdn() {
return d.dialer.DialContext(ctx, network, destination) return d.dialer.DialContext(ctx, network, destination)
} }
@@ -125,6 +164,10 @@ func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context
} }
func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) { func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) {
err := d.initialize()
if err != nil {
return nil, err
}
if !destination.IsFqdn() { if !destination.IsFqdn() {
return d.dialer.ListenPacket(ctx, destination) return d.dialer.ListenPacket(ctx, destination)
} }

View File

@@ -18,5 +18,6 @@ func HTTPHost(_ context.Context, metadata *adapter.InboundContext, reader io.Rea
} }
metadata.Protocol = C.ProtocolHTTP metadata.Protocol = C.ProtocolHTTP
metadata.Domain = M.ParseSocksaddr(request.Host).AddrString() metadata.Domain = M.ParseSocksaddr(request.Host).AddrString()
metadata.HTTPRequest = request
return nil return nil
} }

View File

@@ -21,6 +21,7 @@ func TLSClientHello(ctx context.Context, metadata *adapter.InboundContext, reade
if clientHello != nil { if clientHello != nil {
metadata.Protocol = C.ProtocolTLS metadata.Protocol = C.ProtocolTLS
metadata.Domain = clientHello.ServerName metadata.Domain = clientHello.ServerName
metadata.ClientHello = clientHello
return nil return nil
} }
return err return err

View File

@@ -8,11 +8,14 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/pem" "encoding/pem"
"math/big" "math/big"
"net"
"time" "time"
M "github.com/sagernet/sing/common/metadata"
) )
func GenerateCertificate(timeFunc func() time.Time, serverName string) (*tls.Certificate, error) { func GenerateKeyPair(parent *x509.Certificate, parentKey any, timeFunc func() time.Time, serverName string) (*tls.Certificate, error) {
privateKeyPem, publicKeyPem, err := GenerateKeyPair(timeFunc, serverName, timeFunc().Add(time.Hour)) privateKeyPem, publicKeyPem, err := GenerateCertificate(parent, parentKey, timeFunc, serverName, timeFunc().Add(time.Hour))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -23,7 +26,7 @@ func GenerateCertificate(timeFunc func() time.Time, serverName string) (*tls.Cer
return &certificate, err return &certificate, err
} }
func GenerateKeyPair(timeFunc func() time.Time, serverName string, expire time.Time) (privateKeyPem []byte, publicKeyPem []byte, err error) { func GenerateCertificate(parent *x509.Certificate, parentKey any, timeFunc func() time.Time, serverName string, expire time.Time) (privateKeyPem []byte, publicKeyPem []byte, err error) {
if timeFunc == nil { if timeFunc == nil {
timeFunc = time.Now timeFunc = time.Now
} }
@@ -35,19 +38,36 @@ func GenerateKeyPair(timeFunc func() time.Time, serverName string, expire time.T
if err != nil { if err != nil {
return return
} }
template := &x509.Certificate{ var template *x509.Certificate
SerialNumber: serialNumber, if serverAddress := M.ParseAddr(serverName); serverAddress.IsValid() {
NotBefore: timeFunc().Add(time.Hour * -1), template = &x509.Certificate{
NotAfter: expire, SerialNumber: serialNumber,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, IPAddresses: []net.IP{serverAddress.AsSlice()},
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, NotBefore: timeFunc().Add(time.Hour * -1),
BasicConstraintsValid: true, NotAfter: expire,
Subject: pkix.Name{ KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
CommonName: serverName, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}, BasicConstraintsValid: true,
DNSNames: []string{serverName}, }
} else {
template = &x509.Certificate{
SerialNumber: serialNumber,
NotBefore: timeFunc().Add(time.Hour * -1),
NotAfter: expire,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
Subject: pkix.Name{
CommonName: serverName,
},
DNSNames: []string{serverName},
}
} }
publicDer, err := x509.CreateCertificate(rand.Reader, template, template, key.Public(), key) if parent == nil {
parent = template
parentKey = key
}
publicDer, err := x509.CreateCertificate(rand.Reader, template, parent, key.Public(), parentKey)
if err != nil { if err != nil {
return return
} }

View File

@@ -222,7 +222,7 @@ func NewSTDServer(ctx context.Context, logger log.Logger, options option.Inbound
} }
if certificate == nil && key == nil && options.Insecure { if certificate == nil && key == nil && options.Insecure {
tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return GenerateCertificate(ntp.TimeFuncFromContext(ctx), info.ServerName) return GenerateKeyPair(nil, nil, ntp.TimeFuncFromContext(ctx), info.ServerName)
} }
} else { } else {
if certificate == nil { if certificate == nil {

View File

@@ -7,8 +7,9 @@ import (
"strings" "strings"
"time" "time"
"github.com/sagernet/sing/common"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"golang.org/x/net/publicsuffix"
) )
type Conn struct { type Conn struct {
@@ -42,30 +43,12 @@ func (c *Conn) Write(b []byte) (n int, err error) {
return return
} }
} }
splits := strings.Split(string(b[serverName.Index:serverName.Index+serverName.Length]), ".") splits := strings.Split(serverName.ServerName, ".")
currentIndex := serverName.Index currentIndex := serverName.Index
var striped bool if publicSuffix := publicsuffix.List.PublicSuffix(serverName.ServerName); publicSuffix != "" {
if len(splits) > 3 { splits = splits[:len(splits)-strings.Count(serverName.ServerName, ".")]
suffix := splits[len(splits)-3] + "." + splits[len(splits)-2] + "." + splits[len(splits)-1]
if publicSuffixMatcher().Match(suffix) {
splits = splits[:len(splits)-3]
}
striped = true
} }
if !striped && len(splits) > 2 { if len(splits) > 1 && splits[0] == "..." {
suffix := splits[len(splits)-2] + "." + splits[len(splits)-1]
if publicSuffixMatcher().Match(suffix) {
splits = splits[:len(splits)-2]
}
striped = true
}
if !striped && len(splits) > 1 {
suffix := splits[len(splits)-1]
if publicSuffixMatcher().Match(suffix) {
splits = splits[:len(splits)-1]
}
}
if len(splits) > 1 && common.Contains(publicPrefix, splits[0]) {
currentIndex += len(splits[0]) + 1 currentIndex += len(splits[0]) + 1
splits = splits[1:] splits = splits[1:]
} }

View File

@@ -23,9 +23,9 @@ const (
) )
type myServerName struct { type myServerName struct {
Index int Index int
Length int Length int
sex []byte ServerName string
} }
func indexTLSServerName(payload []byte) *myServerName { func indexTLSServerName(payload []byte) *myServerName {
@@ -119,9 +119,9 @@ func indexTLSServerNameFromExtensions(exs []byte) *myServerName {
sniLen := uint16(sex[3])<<8 | uint16(sex[4]) sniLen := uint16(sex[3])<<8 | uint16(sex[4])
sex = sex[sniExtensionHeaderLen:] sex = sex[sniExtensionHeaderLen:]
return &myServerName{ return &myServerName{
Index: currentIndex + extensionHeaderLen + sniExtensionHeaderLen, Index: currentIndex + extensionHeaderLen + sniExtensionHeaderLen,
Length: int(sniLen), Length: int(sniLen),
sex: sex, ServerName: string(sex),
} }
} }
exs = exs[4+exLen:] exs = exs[4+exLen:]

View File

@@ -1,55 +0,0 @@
package tf
import (
"bufio"
"bytes"
_ "embed"
"io"
"strings"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/domain"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
)
var publicPrefix = []string{
"www",
}
//go:generate wget -O public_suffix_list.dat https://publicsuffix.org/list/public_suffix_list.dat
//go:embed public_suffix_list.dat
var publicSuffix []byte
var publicSuffixMatcher = common.OnceValue(func() *domain.Matcher {
matcher, err := initPublicSuffixMatcher()
if err != nil {
panic(F.ToString("error in initialize public suffix matcher"))
}
return matcher
})
func initPublicSuffixMatcher() (*domain.Matcher, error) {
reader := bufio.NewReader(bytes.NewReader(publicSuffix))
var domainList []string
for {
line, isPrefix, err := reader.ReadLine()
if err != nil {
if err == io.EOF {
break
}
return nil, err
}
if isPrefix {
return nil, E.New("unexpected prefix line")
}
lineStr := string(line)
lineStr = strings.TrimSpace(lineStr)
if lineStr == "" || strings.HasPrefix(lineStr, "//") {
continue
}
domainList = append(domainList, lineStr)
}
return domain.NewMatcher(domainList, nil, false), nil
}

File diff suppressed because it is too large Load Diff

7
constant/script.go Normal file
View File

@@ -0,0 +1,7 @@
package constant
const (
ScriptTypeSurge = "surge"
ScriptSourceTypeLocal = "local"
ScriptSourceTypeRemote = "remote"
)

View File

@@ -252,7 +252,8 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
ruleIndex = -1 ruleIndex = -1
for { for {
dnsCtx := adapter.OverrideContext(ctx) dnsCtx := adapter.OverrideContext(ctx)
transport, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message), &options) dnsOptions := options
transport, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message), &dnsOptions)
if rule != nil { if rule != nil {
switch action := rule.Action().(type) { switch action := rule.Action().(type) {
case *R.RuleActionReject: case *R.RuleActionReject:
@@ -271,10 +272,10 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
return rule.MatchAddressLimit(metadata) return rule.MatchAddressLimit(metadata)
} }
} }
if options.Strategy == C.DomainStrategyAsIS { if dnsOptions.Strategy == C.DomainStrategyAsIS {
options.Strategy = r.defaultDomainStrategy dnsOptions.Strategy = r.defaultDomainStrategy
} }
response, err = r.client.Exchange(dnsCtx, transport, message, options, responseCheck) response, err = r.client.Exchange(dnsCtx, transport, message, dnsOptions, responseCheck)
var rejected bool var rejected bool
if err != nil { if err != nil {
if errors.Is(err, ErrResponseRejectedCached) { if errors.Is(err, ErrResponseRejectedCached) {

View File

@@ -27,9 +27,14 @@ func NewTransportAdapter(transportType string, transportTag string, dependencies
} }
func NewTransportAdapterWithLocalOptions(transportType string, transportTag string, localOptions option.LocalDNSServerOptions) TransportAdapter { func NewTransportAdapterWithLocalOptions(transportType string, transportTag string, localOptions option.LocalDNSServerOptions) TransportAdapter {
var dependencies []string
if localOptions.DomainResolver != nil && localOptions.DomainResolver.Server != "" {
dependencies = append(dependencies, localOptions.DomainResolver.Server)
}
return TransportAdapter{ return TransportAdapter{
transportType: transportType, transportType: transportType,
transportTag: transportTag, transportTag: transportTag,
dependencies: dependencies,
strategy: C.DomainStrategy(localOptions.LegacyStrategy), strategy: C.DomainStrategy(localOptions.LegacyStrategy),
clientSubnet: localOptions.LegacyClientSubnet, clientSubnet: localOptions.LegacyClientSubnet,
} }
@@ -37,8 +42,11 @@ func NewTransportAdapterWithLocalOptions(transportType string, transportTag stri
func NewTransportAdapterWithRemoteOptions(transportType string, transportTag string, remoteOptions option.RemoteDNSServerOptions) TransportAdapter { func NewTransportAdapterWithRemoteOptions(transportType string, transportTag string, remoteOptions option.RemoteDNSServerOptions) TransportAdapter {
var dependencies []string var dependencies []string
if remoteOptions.AddressResolver != "" { if remoteOptions.DomainResolver != nil && remoteOptions.DomainResolver.Server != "" {
dependencies = []string{remoteOptions.AddressResolver} dependencies = append(dependencies, remoteOptions.DomainResolver.Server)
}
if remoteOptions.LegacyAddressResolver != "" {
dependencies = append(dependencies, remoteOptions.LegacyAddressResolver)
} }
return TransportAdapter{ return TransportAdapter{
transportType: transportType, transportType: transportType,

View File

@@ -19,37 +19,39 @@ func NewLocalDialer(ctx context.Context, options option.LocalDNSServerOptions) (
if options.LegacyDefaultDialer { if options.LegacyDefaultDialer {
return dialer.NewDefaultOutbound(ctx), nil return dialer.NewDefaultOutbound(ctx), nil
} else { } else {
return dialer.New(ctx, options.DialerOptions, false) return dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: options.DialerOptions,
DirectResolver: true,
})
} }
} }
func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions) (N.Dialer, error) { func NewRemoteDialer(ctx context.Context, options option.RemoteDNSServerOptions) (N.Dialer, error) {
var (
transportDialer N.Dialer
err error
)
if options.LegacyDefaultDialer { if options.LegacyDefaultDialer {
transportDialer = dialer.NewDefaultOutbound(ctx) transportDialer := dialer.NewDefaultOutbound(ctx)
} else { if options.LegacyAddressResolver != "" {
transportDialer, err = dialer.New(ctx, options.DialerOptions, options.ServerIsDomain()) transport := service.FromContext[adapter.DNSTransportManager](ctx)
} resolverTransport, loaded := transport.Transport(options.LegacyAddressResolver)
if err != nil { if !loaded {
return nil, err return nil, E.New("address resolver not found: ", options.LegacyAddressResolver)
} }
if options.AddressResolver != "" { transportDialer = newTransportDialer(transportDialer, service.FromContext[adapter.DNSRouter](ctx), resolverTransport, C.DomainStrategy(options.LegacyAddressStrategy), time.Duration(options.LegacyAddressFallbackDelay))
transport := service.FromContext[adapter.DNSTransportManager](ctx) } else if options.ServerIsDomain() {
resolverTransport, loaded := transport.Transport(options.AddressResolver) return nil, E.New("missing address resolver for server: ", options.Server)
if !loaded {
return nil, E.New("address resolver not found: ", options.AddressResolver)
} }
transportDialer = NewTransportDialer(transportDialer, service.FromContext[adapter.DNSRouter](ctx), resolverTransport, C.DomainStrategy(options.AddressStrategy), time.Duration(options.AddressFallbackDelay)) return transportDialer, nil
} else if options.ServerIsDomain() { } else {
return nil, E.New("missing address resolver for server: ", options.Server) return dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: options.DialerOptions,
RemoteIsDomain: options.ServerIsDomain(),
DirectResolver: true,
})
} }
return transportDialer, nil
} }
type TransportDialer struct { type legacyTransportDialer struct {
dialer N.Dialer dialer N.Dialer
dnsRouter adapter.DNSRouter dnsRouter adapter.DNSRouter
transport adapter.DNSTransport transport adapter.DNSTransport
@@ -57,8 +59,8 @@ type TransportDialer struct {
fallbackDelay time.Duration fallbackDelay time.Duration
} }
func NewTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport adapter.DNSTransport, strategy C.DomainStrategy, fallbackDelay time.Duration) *TransportDialer { func newTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport adapter.DNSTransport, strategy C.DomainStrategy, fallbackDelay time.Duration) *legacyTransportDialer {
return &TransportDialer{ return &legacyTransportDialer{
dialer, dialer,
dnsRouter, dnsRouter,
transport, transport,
@@ -67,7 +69,7 @@ func NewTransportDialer(dialer N.Dialer, dnsRouter adapter.DNSRouter, transport
} }
} }
func (d *TransportDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (d *legacyTransportDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if destination.IsIP() { if destination.IsIP() {
return d.dialer.DialContext(ctx, network, destination) return d.dialer.DialContext(ctx, network, destination)
} }
@@ -81,7 +83,7 @@ func (d *TransportDialer) DialContext(ctx context.Context, network string, desti
return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay) return N.DialParallel(ctx, d.dialer, network, destination, addresses, d.strategy == C.DomainStrategyPreferIPv6, d.fallbackDelay)
} }
func (d *TransportDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (d *legacyTransportDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if destination.IsIP() { if destination.IsIP() {
return d.dialer.ListenPacket(ctx, destination) return d.dialer.ListenPacket(ctx, destination)
} }
@@ -96,6 +98,6 @@ func (d *TransportDialer) ListenPacket(ctx context.Context, destination M.Socksa
return conn, err return conn, err
} }
func (d *TransportDialer) Upstream() any { func (d *legacyTransportDialer) Upstream() any {
return d.dialer return d.dialer
} }

View File

@@ -2,6 +2,10 @@
icon: material/alert-decagram icon: material/alert-decagram
--- ---
#### 1.12.0-alpha.3
* Fixes and improvements
#### 1.12.0-alpha.2 #### 1.12.0-alpha.2
* Update quic-go to v0.49.0 * Update quic-go to v0.49.0

View File

@@ -19,10 +19,12 @@ import (
) )
var ( var (
bucketSelected = []byte("selected") bucketSelected = []byte("selected")
bucketExpand = []byte("group_expand") bucketExpand = []byte("group_expand")
bucketMode = []byte("clash_mode") bucketMode = []byte("clash_mode")
bucketRuleSet = []byte("rule_set") bucketRuleSet = []byte("rule_set")
bucketScript = []byte("script")
bucketSgPersistentStore = []byte("sg_persistent_store")
bucketNameList = []string{ bucketNameList = []string{
string(bucketSelected), string(bucketSelected),
@@ -284,8 +286,8 @@ func (c *CacheFile) StoreGroupExpand(group string, isExpand bool) error {
}) })
} }
func (c *CacheFile) LoadRuleSet(tag string) *adapter.SavedRuleSet { func (c *CacheFile) LoadRuleSet(tag string) *adapter.SavedBinary {
var savedSet adapter.SavedRuleSet var savedSet adapter.SavedBinary
err := c.DB.View(func(t *bbolt.Tx) error { err := c.DB.View(func(t *bbolt.Tx) error {
bucket := c.bucket(t, bucketRuleSet) bucket := c.bucket(t, bucketRuleSet)
if bucket == nil { if bucket == nil {
@@ -303,7 +305,7 @@ func (c *CacheFile) LoadRuleSet(tag string) *adapter.SavedRuleSet {
return &savedSet return &savedSet
} }
func (c *CacheFile) SaveRuleSet(tag string, set *adapter.SavedRuleSet) error { func (c *CacheFile) SaveRuleSet(tag string, set *adapter.SavedBinary) error {
return c.DB.Batch(func(t *bbolt.Tx) error { return c.DB.Batch(func(t *bbolt.Tx) error {
bucket, err := c.createBucket(t, bucketRuleSet) bucket, err := c.createBucket(t, bucketRuleSet)
if err != nil { if err != nil {
@@ -316,3 +318,70 @@ func (c *CacheFile) SaveRuleSet(tag string, set *adapter.SavedRuleSet) error {
return bucket.Put([]byte(tag), setBinary) return bucket.Put([]byte(tag), setBinary)
}) })
} }
func (c *CacheFile) LoadScript(tag string) *adapter.SavedBinary {
var savedSet adapter.SavedBinary
err := c.DB.View(func(t *bbolt.Tx) error {
bucket := c.bucket(t, bucketScript)
if bucket == nil {
return os.ErrNotExist
}
scriptBinary := bucket.Get([]byte(tag))
if len(scriptBinary) == 0 {
return os.ErrInvalid
}
return savedSet.UnmarshalBinary(scriptBinary)
})
if err != nil {
return nil
}
return &savedSet
}
func (c *CacheFile) SaveScript(tag string, set *adapter.SavedBinary) error {
return c.DB.Batch(func(t *bbolt.Tx) error {
bucket, err := c.createBucket(t, bucketScript)
if err != nil {
return err
}
scriptBinary, err := set.MarshalBinary()
if err != nil {
return err
}
return bucket.Put([]byte(tag), scriptBinary)
})
}
func (c *CacheFile) SurgePersistentStoreRead(key string) string {
var value string
_ = c.DB.View(func(t *bbolt.Tx) error {
bucket := c.bucket(t, bucketSgPersistentStore)
if bucket == nil {
return nil
}
valueBinary := bucket.Get([]byte(key))
if len(valueBinary) > 0 {
value = string(valueBinary)
}
return nil
})
return value
}
func (c *CacheFile) SurgePersistentStoreWrite(key string, value string) error {
return c.DB.Batch(func(t *bbolt.Tx) error {
if value != "" {
bucket, err := c.createBucket(t, bucketSgPersistentStore)
if err != nil {
return err
}
return bucket.Put([]byte(key), []byte(value))
} else {
bucket := c.bucket(t, bucketSgPersistentStore)
if bucket == nil {
return nil
}
return bucket.Delete([]byte(key))
}
})
}

View File

@@ -0,0 +1,84 @@
package clashapi
import (
"context"
"net/http"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/service"
"github.com/go-chi/chi/v5"
"github.com/go-chi/render"
"github.com/gofrs/uuid/v5"
"howett.net/plist"
)
func mitmRouter(ctx context.Context) http.Handler {
r := chi.NewRouter()
r.Get("/mobileconfig", getMobileConfig(ctx))
r.Get("/certificate", getCertificate(ctx))
return r
}
func getMobileConfig(ctx context.Context) http.HandlerFunc {
return func(writer http.ResponseWriter, request *http.Request) {
engine := service.FromContext[adapter.MITMEngine](ctx)
if engine == nil {
http.NotFound(writer, request)
render.PlainText(writer, request, "MITM not enabled")
return
}
certificate := engine.ExportCertificate()
if certificate == nil {
http.NotFound(writer, request)
render.PlainText(writer, request, "Certificate not configured")
return
}
writer.Header().Set("Content-Type", "application/x-apple-aspen-config")
uuidGen := common.Must1(uuid.NewV4()).String()
mobileConfig := map[string]interface{}{
"PayloadContent": []interface{}{
map[string]interface{}{
"PayloadCertificateFileName": "Certificates.cer",
"PayloadContent": certificate.Raw,
"PayloadDescription": "Adds a root certificate",
"PayloadDisplayName": certificate.Subject.CommonName,
"PayloadIdentifier": "com.apple.security.root." + uuidGen,
"PayloadType": "com.apple.security.root",
"PayloadUUID": uuidGen,
"PayloadVersion": 1,
},
},
"PayloadDisplayName": certificate.Subject.CommonName,
"PayloadIdentifier": "io.nekohasekai.sfa.ca.profile." + uuidGen,
"PayloadRemovalDisallowed": false,
"PayloadType": "Configuration",
"PayloadUUID": uuidGen,
"PayloadVersion": 1,
}
encoder := plist.NewEncoder(writer)
encoder.Indent("\t")
encoder.Encode(mobileConfig)
}
}
func getCertificate(ctx context.Context) http.HandlerFunc {
return func(writer http.ResponseWriter, request *http.Request) {
engine := service.FromContext[adapter.MITMEngine](ctx)
if engine == nil {
http.NotFound(writer, request)
render.PlainText(writer, request, "MITM not enabled")
return
}
certificate := engine.ExportCertificate()
if certificate == nil {
http.NotFound(writer, request)
render.PlainText(writer, request, "Certificate not configured")
return
}
writer.Header().Set("Content-Type", "application/x-x509-ca-cert")
writer.Header().Set("Content-Disposition", "attachment; filename=Certificate.crt")
writer.Write(certificate.Raw)
}
}

View File

@@ -124,6 +124,7 @@ func NewServer(ctx context.Context, logFactory log.ObservableFactory, options op
r.Mount("/profile", profileRouter()) r.Mount("/profile", profileRouter())
r.Mount("/cache", cacheRouter(ctx)) r.Mount("/cache", cacheRouter(ctx))
r.Mount("/dns", dnsRouter(s.dnsRouter)) r.Mount("/dns", dnsRouter(s.dnsRouter))
r.Mount("/mitm", mitmRouter(ctx))
s.setupMetaAPI(r) s.setupMetaAPI(r)
}) })

View File

@@ -32,4 +32,9 @@ type Notification struct {
Subtitle string Subtitle string
Body string Body string
OpenURL string OpenURL string
Clipboard string
MediaURL string
MediaData []byte
MediaType string
Timeout int
} }

View File

@@ -7,11 +7,13 @@ var (
type Locale struct { type Locale struct {
// deprecated messages for graphical clients // deprecated messages for graphical clients
Locale string
DeprecatedMessage string DeprecatedMessage string
DeprecatedMessageNoLink string DeprecatedMessageNoLink string
} }
var defaultLocal = &Locale{ var defaultLocal = &Locale{
Locale: "en_US",
DeprecatedMessage: "%s is deprecated in sing-box %s and will be removed in sing-box %s please checkout documentation for migration.", DeprecatedMessage: "%s is deprecated in sing-box %s and will be removed in sing-box %s please checkout documentation for migration.",
DeprecatedMessageNoLink: "%s is deprecated in sing-box %s and will be removed in sing-box %s.", DeprecatedMessageNoLink: "%s is deprecated in sing-box %s and will be removed in sing-box %s.",
} }

View File

@@ -4,6 +4,7 @@ var warningMessageForEndUsers = "\n\n如果您不明白此消息意味着什么
func init() { func init() {
localeRegistry["zh_CN"] = &Locale{ localeRegistry["zh_CN"] = &Locale{
Locale: "zh_CN",
DeprecatedMessage: "%s 已在 sing-box %s 中被弃用,且将在 sing-box %s 中被移除,请参阅迁移指南。" + warningMessageForEndUsers, DeprecatedMessage: "%s 已在 sing-box %s 中被弃用,且将在 sing-box %s 中被移除,请参阅迁移指南。" + warningMessageForEndUsers,
DeprecatedMessageNoLink: "%s 已在 sing-box %s 中被弃用,且将在 sing-box %s 中被移除。" + warningMessageForEndUsers, DeprecatedMessageNoLink: "%s 已在 sing-box %s 中被弃用,且将在 sing-box %s 中被移除。" + warningMessageForEndUsers,
} }

8
go.mod
View File

@@ -3,9 +3,11 @@ module github.com/sagernet/sing-box
go 1.20 go 1.20
require ( require (
github.com/adhocore/gronx v1.19.5
github.com/caddyserver/certmagic v0.20.0 github.com/caddyserver/certmagic v0.20.0
github.com/cloudflare/circl v1.3.7 github.com/cloudflare/circl v1.3.7
github.com/cretz/bine v0.2.0 github.com/cretz/bine v0.2.0
github.com/dop251/goja v0.0.0-20250125213203-5ef83b82af17
github.com/go-chi/chi/v5 v5.1.0 github.com/go-chi/chi/v5 v5.1.0
github.com/go-chi/render v1.0.3 github.com/go-chi/render v1.0.3
github.com/gofrs/uuid/v5 v5.3.0 github.com/gofrs/uuid/v5 v5.3.0
@@ -61,15 +63,17 @@ require (
github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 // indirect github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 // indirect
github.com/dlclark/regexp2 v1.11.4 // indirect
github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect
github.com/go-ole/go-ole v1.3.0 // indirect github.com/go-ole/go-ole v1.3.0 // indirect
github.com/go-sourcemap/sourcemap v2.1.4+incompatible // indirect
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect
github.com/gobwas/httphead v0.1.0 // indirect github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect github.com/gobwas/pool v0.2.1 // indirect
github.com/google/btree v1.1.3 // indirect github.com/google/btree v1.1.3 // indirect
github.com/google/go-cmp v0.6.0 // indirect github.com/google/go-cmp v0.6.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect github.com/google/go-querystring v1.1.0 // indirect
github.com/google/pprof v0.0.0-20231101202521-4ca4178f5c7a // indirect github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 // indirect
github.com/hashicorp/yamux v0.1.2 // indirect github.com/hashicorp/yamux v0.1.2 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/josharian/native v1.1.0 // indirect github.com/josharian/native v1.1.0 // indirect
@@ -78,7 +82,6 @@ require (
github.com/libdns/libdns v0.2.2 // indirect github.com/libdns/libdns v0.2.2 // indirect
github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect
github.com/mdlayher/socket v0.4.1 // indirect github.com/mdlayher/socket v0.4.1 // indirect
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect
github.com/onsi/ginkgo/v2 v2.9.7 // indirect github.com/onsi/ginkgo/v2 v2.9.7 // indirect
github.com/pierrec/lz4/v4 v4.1.14 // indirect github.com/pierrec/lz4/v4 v4.1.14 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
@@ -97,7 +100,6 @@ require (
golang.org/x/tools v0.24.0 // indirect golang.org/x/tools v0.24.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect
lukechampine.com/blake3 v1.3.0 // indirect lukechampine.com/blake3 v1.3.0 // indirect
) )

32
go.sum
View File

@@ -1,3 +1,6 @@
github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
github.com/adhocore/gronx v1.19.5 h1:cwIG4nT1v9DvadxtHBe6MzE+FZ1JDvAUC45U2fl4eSQ=
github.com/adhocore/gronx v1.19.5/go.mod h1:7oUY1WAU8rEJWmAxXR2DN0JaO4gi9khSgKjiRypqteg=
github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU=
github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY=
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI= github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
@@ -16,6 +19,10 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 h1:CaO/zOnF8VvUfEbhRatPcwKVWamvbYd8tQGRWacE9kU= github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1 h1:CaO/zOnF8VvUfEbhRatPcwKVWamvbYd8tQGRWacE9kU=
github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1/go.mod h1:+hnT3ywWDTAFrW5aE+u2Sa/wT555ZqwoCS+pk3p6ry4= github.com/dgrijalva/jwt-go/v4 v4.0.0-preview1/go.mod h1:+hnT3ywWDTAFrW5aE+u2Sa/wT555ZqwoCS+pk3p6ry4=
github.com/dlclark/regexp2 v1.11.4 h1:rPYF9/LECdNymJufQKmri9gV604RvvABwgOA8un7yAo=
github.com/dlclark/regexp2 v1.11.4/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dop251/goja v0.0.0-20250125213203-5ef83b82af17 h1:spJaibPy2sZNwo6Q0HjBVufq7hBUj5jNFOKRoogCBow=
github.com/dop251/goja v0.0.0-20250125213203-5ef83b82af17/go.mod h1:MxLav0peU43GgvwVgNbLAj1s/bSGboKkhuULvq/7hx4=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw=
@@ -25,6 +32,8 @@ github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIo
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
github.com/go-sourcemap/sourcemap v2.1.4+incompatible h1:a+iTbH5auLKxaNwQFg0B+TCYl6lbukKPc7b5x0n1s6Q=
github.com/go-sourcemap/sourcemap v2.1.4+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
@@ -41,8 +50,8 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/pprof v0.0.0-20231101202521-4ca4178f5c7a h1:fEBsGL/sjAuJrgah5XqmmYsTLzJp/TO9Lhy39gkverk= github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8 h1:FKHo8hFI3A+7w0aUQuYXQ+6EN5stWmeY/AZqtM8xk9k=
github.com/google/pprof v0.0.0-20231101202521-4ca4178f5c7a/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= github.com/google/pprof v0.0.0-20240727154555-813a5fbdbec8/go.mod h1:K1liHPHnj73Fdn/EKuT8nrFqBihUSKXoLYU0BuatOYo=
github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8=
github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
@@ -58,9 +67,6 @@ github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6K
github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=
github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg= github.com/klauspost/cpuid/v2 v2.2.5 h1:0E5MSMDEoAulmXNFquVs//DdoomxaoTY1kUhbc/qbZg=
github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/klauspost/cpuid/v2 v2.2.5/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/libdns/alidns v1.0.3 h1:LFHuGnbseq5+HCeGa1aW8awyX/4M2psB9962fdD2+yQ= github.com/libdns/alidns v1.0.3 h1:LFHuGnbseq5+HCeGa1aW8awyX/4M2psB9962fdD2+yQ=
github.com/libdns/alidns v1.0.3/go.mod h1:e18uAG6GanfRhcJj6/tps2rCMzQJaYVcGKT+ELjdjGE= github.com/libdns/alidns v1.0.3/go.mod h1:e18uAG6GanfRhcJj6/tps2rCMzQJaYVcGKT+ELjdjGE=
github.com/libdns/cloudflare v0.1.1 h1:FVPfWwP8zZCqj268LZjmkDleXlHPlFU9KC4OJ3yn054= github.com/libdns/cloudflare v0.1.1 h1:FVPfWwP8zZCqj268LZjmkDleXlHPlFU9KC4OJ3yn054=
@@ -80,8 +86,6 @@ github.com/mholt/acmez v1.2.0 h1:1hhLxSgY5FvH5HCnGUuwbKY2VQVo8IU7rxXKSnZ7F30=
github.com/mholt/acmez v1.2.0/go.mod h1:VT9YwH1xgNX1kmYY89gY8xPJC84BFAisjo8Egigt4kE= github.com/mholt/acmez v1.2.0/go.mod h1:VT9YwH1xgNX1kmYY89gY8xPJC84BFAisjo8Egigt4kE=
github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ= github.com/miekg/dns v1.1.62 h1:cN8OuEF1/x5Rq6Np+h1epln8OiyPWV+lROx9LxcGgIQ=
github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ= github.com/miekg/dns v1.1.62/go.mod h1:mvDlcItzm+br7MToIKqkglaGhlFMHJ9DTNNWONWXbNQ=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss= github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss=
github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0= github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0=
github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU= github.com/onsi/gomega v1.27.7 h1:fVih9JD6ogIiHUN6ePK7HJidyEDpWGVB5mzM7cWNXoU=
@@ -114,8 +118,6 @@ github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a h1:ObwtHN2VpqE0ZN
github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM= github.com/sagernet/netlink v0.0.0-20240612041022-b9a21c07ac6a/go.mod h1:xLnfdiJbSp8rNqYEdIW/6eDO4mVoogml14Bh2hSiFpM=
github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I= github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNenDW2zaXr8I=
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8= github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
github.com/sagernet/quic-go v0.48.2-beta.1 h1:W0plrLWa1XtOWDTdX3CJwxmQuxkya12nN5BRGZ87kEg=
github.com/sagernet/quic-go v0.48.2-beta.1/go.mod h1:1WgdDIVD1Gybp40JTWketeSfKA/+or9YMLaG5VeTk4k=
github.com/sagernet/quic-go v0.49.0-beta.1 h1:3LdoCzVVfYRibZns1tYWSIoB65fpTmrwy+yfK8DQ8Jk= github.com/sagernet/quic-go v0.49.0-beta.1 h1:3LdoCzVVfYRibZns1tYWSIoB65fpTmrwy+yfK8DQ8Jk=
github.com/sagernet/quic-go v0.49.0-beta.1/go.mod h1:uesWD1Ihrldq1M3XtjuEvIUqi8WHNsRs71b3Lt1+p/U= github.com/sagernet/quic-go v0.49.0-beta.1/go.mod h1:uesWD1Ihrldq1M3XtjuEvIUqi8WHNsRs71b3Lt1+p/U=
github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byLGkEnIYp6grlXfo1QYUfiYFGjewIdc= github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byLGkEnIYp6grlXfo1QYUfiYFGjewIdc=
@@ -172,8 +174,6 @@ go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M=
go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y=
golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8= golang.org/x/crypto v0.0.0-20210513164829-c07d793c2f9a/go.mod h1:P+XmwS30IXTQdn5tA2iutPOUgjI07+tq3H3K9MVA1s8=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc=
golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc= golang.org/x/crypto v0.32.0/go.mod h1:ZnnJkOaASj8g0AjIduWNlq2NRxL0PlBrbKVyZ6V/Ugc=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
@@ -182,8 +182,6 @@ golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0=
golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.31.0 h1:68CPQngjLL0r2AlUKiSxtQFKvzRVbnzLwMUn5SzcLHo=
golang.org/x/net v0.31.0/go.mod h1:P4fl1q7dY2hnZFxEk4pPSkDHF+QqjitcnDjUQyMM+pM=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
@@ -195,12 +193,10 @@ golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q= golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
@@ -222,10 +218,10 @@ google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM=
google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA= google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA=
google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI=
google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f h1:BLraFXnmrev5lT+xlilqcH8XK9/i0At2xKjWk4p6zsU=
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg= gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@@ -10,6 +10,10 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
) )
const (
DefaultTimeFormat = "-0700 2006-01-02 15:04:05"
)
type Options struct { type Options struct {
Context context.Context Context context.Context
Options option.LogOptions Options option.LogOptions
@@ -47,7 +51,7 @@ func New(options Options) (Factory, error) {
DisableColors: logOptions.DisableColor || logFilePath != "", DisableColors: logOptions.DisableColor || logFilePath != "",
DisableTimestamp: !logOptions.Timestamp && logFilePath != "", DisableTimestamp: !logOptions.Timestamp && logFilePath != "",
FullTimestamp: logOptions.Timestamp, FullTimestamp: logOptions.Timestamp,
TimestampFormat: "-0700 2006-01-02 15:04:05", TimestampFormat: DefaultTimeFormat,
} }
factory := NewDefaultFactory( factory := NewDefaultFactory(
options.Context, options.Context,

11
mitm/constants.go Normal file
View File

@@ -0,0 +1,11 @@
package mitm
import (
"encoding/base64"
"github.com/sagernet/sing/common"
)
var surgeTinyGif = common.OnceValue(func() []byte {
return common.Must1(base64.StdEncoding.DecodeString("R0lGODlhAQABAAAAACH5BAEAAAAALAAAAAABAAEAAAIBAAA="))
})

1121
mitm/engine.go Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -128,18 +128,34 @@ func (o *NewDNSServerOptions) Upgrade(ctx context.Context) error {
} else { } else {
serverType = C.DNSTypeUDP serverType = C.DNSTypeUDP
} }
remoteOptions := RemoteDNSServerOptions{ var remoteOptions RemoteDNSServerOptions
LocalDNSServerOptions: LocalDNSServerOptions{ if options.Detour == "" {
DialerOptions: DialerOptions{ remoteOptions = RemoteDNSServerOptions{
Detour: options.Detour, LocalDNSServerOptions: LocalDNSServerOptions{
LegacyStrategy: options.Strategy,
LegacyDefaultDialer: options.Detour == "",
LegacyClientSubnet: options.ClientSubnet.Build(netip.Prefix{}),
}, },
LegacyStrategy: options.Strategy, LegacyAddressResolver: options.AddressResolver,
LegacyDefaultDialer: options.Detour == "", LegacyAddressStrategy: options.AddressStrategy,
LegacyClientSubnet: options.ClientSubnet.Build(netip.Prefix{}), LegacyAddressFallbackDelay: options.AddressFallbackDelay,
}, }
AddressResolver: options.AddressResolver, } else {
AddressStrategy: options.AddressStrategy, remoteOptions = RemoteDNSServerOptions{
AddressFallbackDelay: options.AddressFallbackDelay, LocalDNSServerOptions: LocalDNSServerOptions{
DialerOptions: DialerOptions{
Detour: options.Detour,
DomainResolver: &DomainResolveOptions{
Server: options.AddressResolver,
Strategy: options.AddressStrategy,
},
FallbackDelay: options.AddressFallbackDelay,
},
LegacyStrategy: options.Strategy,
LegacyDefaultDialer: options.Detour == "",
LegacyClientSubnet: options.ClientSubnet.Build(netip.Prefix{}),
},
}
} }
switch serverType { switch serverType {
case C.DNSTypeUDP: case C.DNSTypeUDP:
@@ -274,9 +290,9 @@ type LocalDNSServerOptions struct {
type RemoteDNSServerOptions struct { type RemoteDNSServerOptions struct {
LocalDNSServerOptions LocalDNSServerOptions
ServerOptions ServerOptions
AddressResolver string `json:"address_resolver,omitempty"` LegacyAddressResolver string `json:"-"`
AddressStrategy DomainStrategy `json:"address_strategy,omitempty"` LegacyAddressStrategy DomainStrategy `json:"-"`
AddressFallbackDelay badoption.Duration `json:"address_fallback_delay,omitempty"` LegacyAddressFallbackDelay badoption.Duration `json:"-"`
} }
type RemoteTLSDNSServerOptions struct { type RemoteTLSDNSServerOptions struct {

38
option/mitm.go Normal file
View File

@@ -0,0 +1,38 @@
package option
import (
"github.com/sagernet/sing/common/json/badoption"
)
type MITMOptions struct {
Enabled bool `json:"enabled,omitempty"`
HTTP2Enabled bool `json:"http2_enabled,omitempty"`
TLSDecryptionOptions *TLSDecryptionOptions `json:"tls_decryption,omitempty"`
}
type TLSDecryptionOptions struct {
Enabled bool `json:"enabled,omitempty"`
KeyPair string `json:"key_pair_p12,omitempty"`
KeyPassword string `json:"key_password,omitempty"`
}
type MITMRouteOptions struct {
Enabled bool `json:"enabled,omitempty"`
Print bool `json:"print,omitempty"`
Script badoption.Listable[MITMRouteSurgeScriptOptions] `json:"sg_script,omitempty"`
SurgeURLRewrite badoption.Listable[SurgeURLRewriteLine] `json:"sg_url_rewrite,omitempty"`
SurgeHeaderRewrite badoption.Listable[SurgeHeaderRewriteLine] `json:"sg_header_rewrite,omitempty"`
SurgeBodyRewrite badoption.Listable[SurgeBodyRewriteLine] `json:"sg_body_rewrite,omitempty"`
SurgeMapLocal badoption.Listable[SurgeMapLocalLine] `json:"sg_map_local,omitempty"`
}
type MITMRouteSurgeScriptOptions struct {
Tag string `json:"tag"`
Type badoption.Listable[string] `json:"type"`
Pattern badoption.Listable[*badoption.Regexp] `json:"pattern"`
Timeout badoption.Duration `json:"timeout,omitempty"`
RequiresBody bool `json:"requires_body,omitempty"`
MaxSize int64 `json:"max_size,omitempty"`
BinaryBodyMode bool `json:"binary_body_mode,omitempty"`
Arguments badoption.Listable[string] `json:"arguments,omitempty"`
}

View File

@@ -0,0 +1,444 @@
package option
import (
"encoding/base64"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"unicode"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/json"
)
type SurgeURLRewriteLine struct {
Pattern *regexp.Regexp
Destination *url.URL
Redirect bool
Reject bool
}
func (l SurgeURLRewriteLine) String() string {
var fields []string
fields = append(fields, l.Pattern.String())
if l.Reject {
fields = append(fields, "_")
} else {
fields = append(fields, l.Destination.String())
}
switch {
case l.Redirect:
fields = append(fields, "302")
case l.Reject:
fields = append(fields, "reject")
default:
fields = append(fields, "header")
}
return encodeSurgeKeys(fields)
}
func (l SurgeURLRewriteLine) MarshalJSON() ([]byte, error) {
return json.Marshal(l.String())
}
func (l *SurgeURLRewriteLine) UnmarshalJSON(bytes []byte) error {
var stringValue string
err := json.Unmarshal(bytes, &stringValue)
if err != nil {
return err
}
fields, err := surgeFields(stringValue)
if err != nil {
return E.Cause(err, "invalid surge_url_rewrite line: ", stringValue)
} else if len(fields) < 2 || len(fields) > 3 {
return E.New("invalid surge_url_rewrite line: ", stringValue)
}
pattern, err := regexp.Compile(fields[0].Key)
if err != nil {
return E.Cause(err, "invalid surge_url_rewrite line: invalid pattern: ", stringValue)
}
l.Pattern = pattern
l.Destination, err = url.Parse(fields[1].Key)
if err != nil {
return E.Cause(err, "invalid surge_url_rewrite line: invalid destination: ", stringValue)
}
if len(fields) == 3 {
switch fields[2].Key {
case "header":
case "302":
l.Redirect = true
case "reject":
l.Reject = true
default:
return E.New("invalid surge_url_rewrite line: invalid action: ", stringValue)
}
}
return nil
}
type SurgeHeaderRewriteLine struct {
Response bool
Pattern *regexp.Regexp
Add bool
Delete bool
Replace bool
ReplaceRegex bool
Key string
Match *regexp.Regexp
Value string
}
func (l SurgeHeaderRewriteLine) String() string {
var fields []string
if !l.Response {
fields = append(fields, "http-request")
} else {
fields = append(fields, "http-response")
}
fields = append(fields, l.Pattern.String())
if l.Add {
fields = append(fields, "header-add")
} else if l.Delete {
fields = append(fields, "header-del")
} else if l.Replace {
fields = append(fields, "header-replace")
} else if l.ReplaceRegex {
fields = append(fields, "header-replace-regex")
}
fields = append(fields, l.Key)
if l.Add || l.Replace {
fields = append(fields, l.Value)
} else if l.ReplaceRegex {
fields = append(fields, l.Match.String(), l.Value)
}
return encodeSurgeKeys(fields)
}
func (l SurgeHeaderRewriteLine) MarshalJSON() ([]byte, error) {
return json.Marshal(l.String())
}
func (l *SurgeHeaderRewriteLine) UnmarshalJSON(bytes []byte) error {
var stringValue string
err := json.Unmarshal(bytes, &stringValue)
if err != nil {
return err
}
fields, err := surgeFields(stringValue)
if err != nil {
return E.Cause(err, "invalid surge_header_rewrite line: ", stringValue)
} else if len(fields) < 4 {
return E.New("invalid surge_header_rewrite line: ", stringValue)
}
switch fields[0].Key {
case "http-request":
case "http-response":
l.Response = true
default:
return E.New("invalid surge_header_rewrite line: invalid type: ", stringValue)
}
l.Pattern, err = regexp.Compile(fields[1].Key)
if err != nil {
return E.Cause(err, "invalid surge_header_rewrite line: invalid pattern: ", stringValue)
}
switch fields[2].Key {
case "header-add":
l.Add = true
if len(fields) != 5 {
return E.New("invalid surge_header_rewrite line: " + stringValue)
}
l.Key = fields[3].Key
l.Value = fields[4].Key
case "header-del":
l.Delete = true
l.Key = fields[3].Key
case "header-replace":
l.Replace = true
if len(fields) != 5 {
return E.New("invalid surge_header_rewrite line: " + stringValue)
}
l.Key = fields[3].Key
l.Value = fields[4].Key
case "header-replace-regex":
l.ReplaceRegex = true
if len(fields) != 6 {
return E.New("invalid surge_header_rewrite line: " + stringValue)
}
l.Key = fields[3].Key
l.Match, err = regexp.Compile(fields[4].Key)
if err != nil {
return E.Cause(err, "invalid surge_header_rewrite line: invalid match: ", stringValue)
}
l.Value = fields[5].Key
default:
return E.New("invalid surge_header_rewrite line: invalid action: ", stringValue)
}
return nil
}
type SurgeBodyRewriteLine struct {
Response bool
Pattern *regexp.Regexp
Match []*regexp.Regexp
Replace []string
}
func (l SurgeBodyRewriteLine) String() string {
var fields []string
if !l.Response {
fields = append(fields, "http-request")
} else {
fields = append(fields, "http-response")
}
for i := 0; i < len(l.Match); i += 2 {
fields = append(fields, l.Match[i].String(), l.Replace[i])
}
return strings.Join(fields, " ")
}
func (l SurgeBodyRewriteLine) MarshalJSON() ([]byte, error) {
return json.Marshal(l.String())
}
func (l *SurgeBodyRewriteLine) UnmarshalJSON(bytes []byte) error {
var stringValue string
err := json.Unmarshal(bytes, &stringValue)
if err != nil {
return err
}
fields, err := surgeFields(stringValue)
if err != nil {
return E.Cause(err, "invalid surge_body_rewrite line: ", stringValue)
} else if len(fields) < 4 {
return E.New("invalid surge_body_rewrite line: ", stringValue)
} else if len(fields)%2 != 0 {
return E.New("invalid surge_body_rewrite line: ", stringValue)
}
switch fields[0].Key {
case "http-request":
case "http-response":
l.Response = true
default:
return E.New("invalid surge_body_rewrite line: invalid type: ", stringValue)
}
l.Pattern, err = regexp.Compile(fields[1].Key)
for i := 2; i < len(fields); i += 2 {
var match *regexp.Regexp
match, err = regexp.Compile(fields[i].Key)
if err != nil {
return E.Cause(err, "invalid surge_body_rewrite line: invalid match: ", stringValue)
}
l.Match = append(l.Match, match)
l.Replace = append(l.Replace, fields[i+1].Key)
}
return nil
}
type SurgeMapLocalLine struct {
Pattern *regexp.Regexp
StatusCode int
File bool
Text bool
TinyGif bool
Base64 bool
Data string
Base64Data []byte
Headers http.Header
}
func (l SurgeMapLocalLine) String() string {
var fields []surgeField
fields = append(fields, surgeField{Key: l.Pattern.String()})
if l.File {
fields = append(fields, surgeField{Key: "data-type", Value: "file"})
fields = append(fields, surgeField{Key: "data", Value: l.Data})
} else if l.Text {
fields = append(fields, surgeField{Key: "data-type", Value: "text"})
fields = append(fields, surgeField{Key: "data", Value: l.Data})
} else if l.TinyGif {
fields = append(fields, surgeField{Key: "data-type", Value: "tiny-gif"})
} else if l.Base64 {
fields = append(fields, surgeField{Key: "data-type", Value: "base64"})
fields = append(fields, surgeField{Key: "data-type", Value: base64.StdEncoding.EncodeToString(l.Base64Data)})
}
fields = append(fields, surgeField{Key: "status-code", Value: F.ToString(l.StatusCode), ValueSet: true})
if len(l.Headers) > 0 {
var headers []string
for key, values := range l.Headers {
for _, value := range values {
headers = append(headers, key+":"+value)
}
}
fields = append(fields, surgeField{Key: "headers", Value: strings.Join(headers, "|")})
}
return encodeSurgeFields(fields)
}
func (l SurgeMapLocalLine) MarshalJSON() ([]byte, error) {
return json.Marshal(l.String())
}
func (l *SurgeMapLocalLine) UnmarshalJSON(bytes []byte) error {
var stringValue string
err := json.Unmarshal(bytes, &stringValue)
if err != nil {
return err
}
fields, err := surgeFields(stringValue)
if err != nil {
return E.Cause(err, "invalid surge_map_local line: ", stringValue)
} else if len(fields) < 1 {
return E.New("invalid surge_map_local line: ", stringValue)
}
l.Pattern, err = regexp.Compile(fields[0].Key)
if err != nil {
return E.Cause(err, "invalid surge_map_local line: invalid pattern: ", stringValue)
}
dataTypeField := common.Find(fields, func(it surgeField) bool {
return it.Key == "data-type"
})
if !dataTypeField.ValueSet {
return E.New("invalid surge_map_local line: missing data-type: ", stringValue)
}
switch dataTypeField.Value {
case "file":
l.File = true
case "text":
l.Text = true
case "tiny-gif":
l.TinyGif = true
case "base64":
l.Base64 = true
}
for i := 1; i < len(fields); i++ {
switch fields[i].Key {
case "data-type":
continue
case "data":
if l.File {
l.Data = fields[i].Value
} else if l.Text {
l.Data = fields[i].Value
} else if l.Base64 {
l.Base64Data, err = base64.StdEncoding.DecodeString(fields[i].Value)
if err != nil {
return E.New("invalid surge_map_local line: invalid base64 data: ", stringValue)
}
}
case "status-code":
statusCode, err := strconv.ParseInt(fields[i].Value, 10, 16)
if err != nil {
return E.New("invalid surge_map_local line: invalid status code: ", stringValue)
}
l.StatusCode = int(statusCode)
case "headers":
headers := make(http.Header)
for _, headerLine := range strings.Split(fields[i].Value, "|") {
if !strings.Contains(headerLine, ":") {
return E.New("invalid surge_map_local line: headers: missing `:` in item: ", stringValue, ": ", headerLine)
}
headers.Add(common.SubstringBefore(headerLine, ":"), common.SubstringAfter(headerLine, ":"))
}
l.Headers = headers
default:
return E.New("invalid surge_map_local line: unknown options: ", stringValue)
}
}
return nil
}
type surgeField struct {
Key string
Value string
ValueSet bool
}
func encodeSurgeKeys(keys []string) string {
keys = common.Map(keys, func(it string) string {
if strings.ContainsFunc(it, unicode.IsSpace) {
return "\"" + it + "\""
} else {
return it
}
})
return strings.Join(keys, " ")
}
func encodeSurgeFields(fields []surgeField) string {
return strings.Join(common.Map(fields, func(it surgeField) string {
if !it.ValueSet {
if strings.ContainsFunc(it.Key, unicode.IsSpace) {
return "\"" + it.Key + "\""
} else {
return it.Key
}
} else {
if strings.ContainsFunc(it.Value, unicode.IsSpace) {
return it.Key + "=\"" + it.Value + "\""
} else {
return it.Key + "=" + it.Value
}
}
}), " ")
}
func surgeFields(s string) ([]surgeField, error) {
var (
fields []surgeField
currentField *surgeField
)
for _, field := range strings.Fields(s) {
if currentField != nil {
field = " " + field
if strings.HasSuffix(field, "\"") {
field = field[:len(field)-1]
if !currentField.ValueSet {
currentField.Key += field
} else {
currentField.Value += field
}
fields = append(fields, *currentField)
currentField = nil
} else {
if !currentField.ValueSet {
currentField.Key += field
} else {
currentField.Value += " " + field
}
}
}
if !strings.Contains(field, "=") {
if strings.HasPrefix(field, "\"") {
field = field[1:]
if strings.HasSuffix(field, "\"") {
field = field[:len(field)-1]
} else {
currentField = &surgeField{Key: field}
continue
}
}
fields = append(fields, surgeField{Key: field})
} else {
key := common.SubstringBefore(field, "=")
value := common.SubstringAfter(field, "=")
if strings.HasPrefix(value, "\"") {
value = value[1:]
if strings.HasSuffix(field, "\"") {
value = value[:len(value)-1]
} else {
currentField = &surgeField{Key: key, Value: field, ValueSet: true}
continue
}
}
fields = append(fields, surgeField{Key: key, Value: value, ValueSet: true})
}
}
if currentField != nil {
return nil, E.New("invalid surge fields line: ", s)
}
return fields, nil
}

View File

@@ -12,13 +12,15 @@ type _Options struct {
Schema string `json:"$schema,omitempty"` Schema string `json:"$schema,omitempty"`
Log *LogOptions `json:"log,omitempty"` Log *LogOptions `json:"log,omitempty"`
DNS *DNSOptions `json:"dns,omitempty"` DNS *DNSOptions `json:"dns,omitempty"`
NTP *NTPOptions `json:"ntp,omitempty"`
Certificate *CertificateOptions `json:"certificate,omitempty"`
Endpoints []Endpoint `json:"endpoints,omitempty"` Endpoints []Endpoint `json:"endpoints,omitempty"`
Inbounds []Inbound `json:"inbounds,omitempty"` Inbounds []Inbound `json:"inbounds,omitempty"`
Outbounds []Outbound `json:"outbounds,omitempty"` Outbounds []Outbound `json:"outbounds,omitempty"`
Route *RouteOptions `json:"route,omitempty"` Route *RouteOptions `json:"route,omitempty"`
Experimental *ExperimentalOptions `json:"experimental,omitempty"` Experimental *ExperimentalOptions `json:"experimental,omitempty"`
NTP *NTPOptions `json:"ntp,omitempty"`
Certificate *CertificateOptions `json:"certificate,omitempty"`
MITM *MITMOptions `json:"mitm,omitempty"`
Scripts []Script `json:"scripts,omitempty"`
} }
type Options _Options type Options _Options

View File

@@ -153,6 +153,8 @@ type RawRouteOptionsActionOptions struct {
TLSFragment bool `json:"tls_fragment,omitempty"` TLSFragment bool `json:"tls_fragment,omitempty"`
TLSFragmentFallbackDelay badoption.Duration `json:"tls_fragment_fallback_delay,omitempty"` TLSFragmentFallbackDelay badoption.Duration `json:"tls_fragment_fallback_delay,omitempty"`
MITM *MITMRouteOptions `json:"mitm,omitempty"`
} }
type RouteOptionsActionOptions RawRouteOptionsActionOptions type RouteOptionsActionOptions RawRouteOptionsActionOptions

128
option/script.go Normal file
View File

@@ -0,0 +1,128 @@
package option
import (
C "github.com/sagernet/sing-box/constant"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badjson"
"github.com/sagernet/sing/common/json/badoption"
)
type _ScriptSourceOptions struct {
Source string `json:"source"`
LocalOptions LocalScriptSource `json:"-"`
RemoteOptions RemoteScriptSource `json:"-"`
}
type LocalScriptSource struct {
Path string `json:"path"`
}
type RemoteScriptSource struct {
URL string `json:"url"`
DownloadDetour string `json:"download_detour,omitempty"`
UpdateInterval badoption.Duration `json:"update_interval,omitempty"`
}
type ScriptSourceOptions _ScriptSourceOptions
func (o ScriptSourceOptions) MarshalJSON() ([]byte, error) {
var source any
switch o.Source {
case C.ScriptSourceTypeLocal:
source = o.LocalOptions
case C.ScriptSourceTypeRemote:
source = o.RemoteOptions
default:
return nil, E.New("unknown script source: ", o.Source)
}
return badjson.MarshallObjects((_ScriptSourceOptions)(o), source)
}
func (o *ScriptSourceOptions) UnmarshalJSON(bytes []byte) error {
err := json.Unmarshal(bytes, (*_ScriptSourceOptions)(o))
if err != nil {
return err
}
var source any
switch o.Source {
case C.ScriptSourceTypeLocal:
source = &o.LocalOptions
case C.ScriptSourceTypeRemote:
source = &o.RemoteOptions
default:
return E.New("unknown script source: ", o.Source)
}
return json.Unmarshal(bytes, source)
}
// TODO: make struct in order
type Script struct {
ScriptSourceOptions
ScriptOptions
}
func (s Script) MarshalJSON() ([]byte, error) {
return badjson.MarshallObjects(s.ScriptSourceOptions, s.ScriptOptions)
}
func (s *Script) UnmarshalJSON(bytes []byte) error {
err := json.Unmarshal(bytes, &s.ScriptSourceOptions)
if err != nil {
return err
}
return badjson.UnmarshallExcluded(bytes, &s.ScriptSourceOptions, &s.ScriptOptions)
}
type _ScriptOptions struct {
Type string `json:"type"`
Tag string `json:"tag"`
SurgeOptions SurgeScriptOptions `json:"-"`
}
type ScriptOptions _ScriptOptions
func (o ScriptOptions) MarshalJSON() ([]byte, error) {
var v any
switch o.Type {
case C.ScriptTypeSurge:
v = &o.SurgeOptions
default:
return nil, E.New("unknown script type: ", o.Type)
}
if v == nil {
return badjson.MarshallObjects((_ScriptOptions)(o))
}
return badjson.MarshallObjects((_ScriptOptions)(o), v)
}
func (o *ScriptOptions) UnmarshalJSON(bytes []byte) error {
err := json.Unmarshal(bytes, (*_ScriptOptions)(o))
if err != nil {
return err
}
var v any
switch o.Type {
case C.ScriptTypeSurge:
v = &o.SurgeOptions
case "":
return E.New("missing script type")
default:
return E.New("unknown script type: ", o.Type)
}
if v == nil {
// check unknown fields
return json.UnmarshalDisallowUnknownFields(bytes, &_ScriptOptions{})
}
return badjson.UnmarshallExcluded(bytes, (*_ScriptOptions)(o), v)
}
type SurgeScriptOptions struct {
CronOptions *CronScriptOptions `json:"cron,omitempty"`
}
type CronScriptOptions struct {
Expression string `json:"expression"`
Arguments []string `json:"arguments,omitempty"`
Timeout badoption.Duration `json:"timeout,omitempty"`
}

View File

@@ -53,7 +53,14 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL
if options.Detour == "" { if options.Detour == "" {
options.IsWireGuardListener = true options.IsWireGuardListener = true
} }
outboundDialer, err := dialer.New(ctx, options.DialerOptions, false) outboundDialer, err := dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: options.DialerOptions,
RemoteIsDomain: common.Any(options.Peers, func(it option.WireGuardPeer) bool {
return !M.ParseAddr(it.Address).IsValid()
}),
ResolverOnDetour: true,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -56,7 +56,14 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
} else if options.GSO { } else if options.GSO {
return nil, E.New("gso is conflict with detour") return nil, E.New("gso is conflict with detour")
} }
outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain()) outboundDialer, err := dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: options.DialerOptions,
RemoteIsDomain: options.ServerIsDomain() || common.Any(options.Peers, func(it option.LegacyWireGuardPeer) bool {
return it.ServerIsDomain()
}),
ResolverOnDetour: true,
})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -21,23 +21,31 @@ import (
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
) )
var _ adapter.ConnectionManager = (*ConnectionManager)(nil) var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
type ConnectionManager struct { type ConnectionManager struct {
ctx context.Context
logger logger.ContextLogger logger logger.ContextLogger
mitm adapter.MITMEngine
access sync.Mutex access sync.Mutex
connections list.List[io.Closer] connections list.List[io.Closer]
} }
func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager { func NewConnectionManager(ctx context.Context, logger logger.ContextLogger) *ConnectionManager {
return &ConnectionManager{ return &ConnectionManager{
ctx: ctx,
logger: logger, logger: logger,
} }
} }
func (m *ConnectionManager) Start(stage adapter.StartStage) error { func (m *ConnectionManager) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStateInitialize:
m.mitm = service.FromContext[adapter.MITMEngine](m.ctx)
}
return nil return nil
} }
@@ -52,6 +60,14 @@ func (m *ConnectionManager) Close() error {
} }
func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
if metadata.MITM != nil && metadata.MITM.Enabled {
if m.mitm == nil {
m.logger.WarnContext(ctx, "MITM disabled")
} else {
m.mitm.NewConnection(ctx, this, conn, metadata, onClose)
return
}
}
ctx = adapter.WithContext(ctx, &metadata) ctx = adapter.WithContext(ctx, &metadata)
var ( var (
remoteConn net.Conn remoteConn net.Conn
@@ -175,6 +191,12 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial
if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded { if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
natConn.UpdateDestination(destinationAddress) natConn.UpdateDestination(destinationAddress)
} }
} else if metadata.RouteOriginalDestination.IsValid() && metadata.RouteOriginalDestination != metadata.Destination {
if metadata.UDPDisableDomainUnmapping {
remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), metadata.Destination, metadata.RouteOriginalDestination)
} else {
remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), metadata.Destination, metadata.RouteOriginalDestination)
}
} }
var udpTimeout time.Duration var udpTimeout time.Duration
if metadata.UDPTimeout > 0 { if metadata.UDPTimeout > 0 {

View File

@@ -458,6 +458,9 @@ match:
metadata.TLSFragment = true metadata.TLSFragment = true
metadata.TLSFragmentFallbackDelay = routeOptions.TLSFragmentFallbackDelay metadata.TLSFragmentFallbackDelay = routeOptions.TLSFragmentFallbackDelay
} }
if routeOptions.MITM != nil && routeOptions.MITM.Enabled {
metadata.MITM = routeOptions.MITM
}
} }
switch action := currentRule.Action().(type) { switch action := currentRule.Action().(type) {
case *rule.RuleActionSniff: case *rule.RuleActionSniff:
@@ -594,7 +597,7 @@ func (r *Router) actionSniff(
return return
} }
} else { } else {
if !metadata.Destination.Addr.IsGlobalUnicast() { if !metadata.Destination.Addr.IsGlobalUnicast() && !metadata.RouteOriginalDestination.IsValid() {
metadata.Destination = destination metadata.Destination = destination
} }
if len(packetBuffers) > 0 { if len(packetBuffers) > 0 {

View File

@@ -162,7 +162,7 @@ func (r *Router) Start(stage adapter.StartStage) error {
r.started = true r.started = true
return nil return nil
case adapter.StartStateStarted: case adapter.StartStateStarted:
for _, ruleSet := range r.ruleSetMap { for _, ruleSet := range r.ruleSets {
ruleSet.Cleanup() ruleSet.Cleanup()
} }
runtime.GC() runtime.GC()
@@ -180,6 +180,13 @@ func (r *Router) Close() error {
}) })
monitor.Finish() monitor.Finish()
} }
for i, ruleSet := range r.ruleSets {
monitor.Start("close rule-set[", i, "]")
err = E.Append(err, ruleSet.Close(), func(err error) error {
return E.Cause(err, "close rule-set[", i, "]")
})
monitor.Finish()
}
return err return err
} }

View File

@@ -38,6 +38,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
UDPConnect: action.RouteOptions.UDPConnect, UDPConnect: action.RouteOptions.UDPConnect,
TLSFragment: action.RouteOptions.TLSFragment, TLSFragment: action.RouteOptions.TLSFragment,
TLSFragmentFallbackDelay: time.Duration(action.RouteOptions.TLSFragmentFallbackDelay), TLSFragmentFallbackDelay: time.Duration(action.RouteOptions.TLSFragmentFallbackDelay),
MITM: action.RouteOptions.MITM,
}, },
}, nil }, nil
case C.RuleActionTypeRouteOptions: case C.RuleActionTypeRouteOptions:
@@ -51,6 +52,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti
UDPTimeout: time.Duration(action.RouteOptionsOptions.UDPTimeout), UDPTimeout: time.Duration(action.RouteOptionsOptions.UDPTimeout),
TLSFragment: action.RouteOptionsOptions.TLSFragment, TLSFragment: action.RouteOptionsOptions.TLSFragment,
TLSFragmentFallbackDelay: time.Duration(action.RouteOptionsOptions.TLSFragmentFallbackDelay), TLSFragmentFallbackDelay: time.Duration(action.RouteOptionsOptions.TLSFragmentFallbackDelay),
MITM: action.RouteOptionsOptions.MITM,
}, nil }, nil
case C.RuleActionTypeDirect: case C.RuleActionTypeDirect:
directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions), false) directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions), false)
@@ -140,15 +142,7 @@ func (r *RuleActionRoute) Type() string {
func (r *RuleActionRoute) String() string { func (r *RuleActionRoute) String() string {
var descriptions []string var descriptions []string
descriptions = append(descriptions, r.Outbound) descriptions = append(descriptions, r.Outbound)
if r.UDPDisableDomainUnmapping { descriptions = append(descriptions, r.Descriptions()...)
descriptions = append(descriptions, "udp-disable-domain-unmapping")
}
if r.UDPConnect {
descriptions = append(descriptions, "udp-connect")
}
if r.TLSFragment {
descriptions = append(descriptions, "tls-fragment")
}
return F.ToString("route(", strings.Join(descriptions, ","), ")") return F.ToString("route(", strings.Join(descriptions, ","), ")")
} }
@@ -164,14 +158,33 @@ type RuleActionRouteOptions struct {
UDPTimeout time.Duration UDPTimeout time.Duration
TLSFragment bool TLSFragment bool
TLSFragmentFallbackDelay time.Duration TLSFragmentFallbackDelay time.Duration
MITM *option.MITMRouteOptions
} }
func (r *RuleActionRouteOptions) Type() string { func (r *RuleActionRouteOptions) Type() string {
return C.RuleActionTypeRouteOptions return C.RuleActionTypeRouteOptions
} }
func (r *RuleActionRouteOptions) String() string { func (r *RuleActionRouteOptions) Descriptions() []string {
var descriptions []string var descriptions []string
if r.OverrideAddress.IsValid() {
descriptions = append(descriptions, F.ToString("override-address=", r.OverrideAddress.AddrString()))
}
if r.OverridePort > 0 {
descriptions = append(descriptions, F.ToString("override-port=", r.OverridePort))
}
if r.NetworkStrategy != nil {
descriptions = append(descriptions, F.ToString("network-strategy=", r.NetworkStrategy))
}
if r.NetworkType != nil {
descriptions = append(descriptions, F.ToString("network-type=", strings.Join(common.Map(r.NetworkType, C.InterfaceType.String), ",")))
}
if r.FallbackNetworkType != nil {
descriptions = append(descriptions, F.ToString("fallback-network-type="+strings.Join(common.Map(r.NetworkType, C.InterfaceType.String), ",")))
}
if r.FallbackDelay > 0 {
descriptions = append(descriptions, F.ToString("fallback-delay=", r.FallbackDelay.String()))
}
if r.UDPDisableDomainUnmapping { if r.UDPDisableDomainUnmapping {
descriptions = append(descriptions, "udp-disable-domain-unmapping") descriptions = append(descriptions, "udp-disable-domain-unmapping")
} }
@@ -179,9 +192,22 @@ func (r *RuleActionRouteOptions) String() string {
descriptions = append(descriptions, "udp-connect") descriptions = append(descriptions, "udp-connect")
} }
if r.UDPTimeout > 0 { if r.UDPTimeout > 0 {
descriptions = append(descriptions, "udp-timeout") descriptions = append(descriptions, F.ToString("udp-timeout=", r.UDPTimeout))
} }
return F.ToString("route-options(", strings.Join(descriptions, ","), ")") if r.TLSFragment {
descriptions = append(descriptions, "tls-fragment")
if r.TLSFragmentFallbackDelay > 0 {
descriptions = append(descriptions, F.ToString("tls-fragment-fallbac-delay=", r.TLSFragmentFallbackDelay.String()))
}
}
if r.MITM != nil && r.MITM.Enabled {
descriptions = append(descriptions, "mitm")
}
return descriptions
}
func (r *RuleActionRouteOptions) String() string {
return F.ToString("route-options(", strings.Join(r.Descriptions(), ","), ")")
} }
type RuleActionDNSRoute struct { type RuleActionDNSRoute struct {

View File

@@ -5,6 +5,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"github.com/sagernet/fswatch" "github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
@@ -26,14 +27,16 @@ import (
var _ adapter.RuleSet = (*LocalRuleSet)(nil) var _ adapter.RuleSet = (*LocalRuleSet)(nil)
type LocalRuleSet struct { type LocalRuleSet struct {
ctx context.Context ctx context.Context
logger logger.Logger logger logger.Logger
tag string tag string
rules []adapter.HeadlessRule rules []adapter.HeadlessRule
metadata adapter.RuleSetMetadata metadata adapter.RuleSetMetadata
fileFormat string fileFormat string
watcher *fswatch.Watcher watcher *fswatch.Watcher
refs atomic.Int32 callbackAccess sync.Mutex
callbacks list.List[adapter.RuleSetUpdateCallback]
refs atomic.Int32
} }
func NewLocalRuleSet(ctx context.Context, logger logger.Logger, options option.RuleSet) (*LocalRuleSet, error) { func NewLocalRuleSet(ctx context.Context, logger logger.Logger, options option.RuleSet) (*LocalRuleSet, error) {
@@ -52,13 +55,12 @@ func NewLocalRuleSet(ctx context.Context, logger logger.Logger, options option.R
return nil, err return nil, err
} }
} else { } else {
err := ruleSet.reloadFile(filemanager.BasePath(ctx, options.LocalOptions.Path)) filePath := filemanager.BasePath(ctx, options.LocalOptions.Path)
filePath, _ = filepath.Abs(filePath)
err := ruleSet.reloadFile(filePath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
}
if options.Type == C.RuleSetTypeLocal {
filePath, _ := filepath.Abs(options.LocalOptions.Path)
watcher, err := fswatch.NewWatcher(fswatch.Options{ watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: []string{filePath}, Path: []string{filePath},
Callback: func(path string) { Callback: func(path string) {
@@ -141,6 +143,12 @@ func (s *LocalRuleSet) reloadRules(headlessRules []option.HeadlessRule) error {
metadata.ContainsIPCIDRRule = hasHeadlessRule(headlessRules, isIPCIDRHeadlessRule) metadata.ContainsIPCIDRRule = hasHeadlessRule(headlessRules, isIPCIDRHeadlessRule)
s.rules = rules s.rules = rules
s.metadata = metadata s.metadata = metadata
s.callbackAccess.Lock()
callbacks := s.callbacks.Array()
s.callbackAccess.Unlock()
for _, callback := range callbacks {
callback(s)
}
return nil return nil
} }
@@ -173,10 +181,15 @@ func (s *LocalRuleSet) Cleanup() {
} }
func (s *LocalRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] { func (s *LocalRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] {
return nil s.callbackAccess.Lock()
defer s.callbackAccess.Unlock()
return s.callbacks.PushBack(callback)
} }
func (s *LocalRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetUpdateCallback]) { func (s *LocalRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetUpdateCallback]) {
s.callbackAccess.Lock()
defer s.callbackAccess.Unlock()
s.callbacks.Remove(element)
} }
func (s *LocalRuleSet) Close() error { func (s *LocalRuleSet) Close() error {

View File

@@ -35,23 +35,23 @@ import (
var _ adapter.RuleSet = (*RemoteRuleSet)(nil) var _ adapter.RuleSet = (*RemoteRuleSet)(nil)
type RemoteRuleSet struct { type RemoteRuleSet struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
outboundManager adapter.OutboundManager logger logger.ContextLogger
logger logger.ContextLogger outbound adapter.OutboundManager
options option.RuleSet options option.RuleSet
metadata adapter.RuleSetMetadata metadata adapter.RuleSetMetadata
updateInterval time.Duration updateInterval time.Duration
dialer N.Dialer dialer N.Dialer
rules []adapter.HeadlessRule rules []adapter.HeadlessRule
lastUpdated time.Time lastUpdated time.Time
lastEtag string lastEtag string
updateTicker *time.Ticker updateTicker *time.Ticker
cacheFile adapter.CacheFile cacheFile adapter.CacheFile
pauseManager pause.Manager pauseManager pause.Manager
callbackAccess sync.Mutex callbackAccess sync.Mutex
callbacks list.List[adapter.RuleSetUpdateCallback] callbacks list.List[adapter.RuleSetUpdateCallback]
refs atomic.Int32 refs atomic.Int32
} }
func NewRemoteRuleSet(ctx context.Context, logger logger.ContextLogger, options option.RuleSet) *RemoteRuleSet { func NewRemoteRuleSet(ctx context.Context, logger logger.ContextLogger, options option.RuleSet) *RemoteRuleSet {
@@ -63,13 +63,13 @@ func NewRemoteRuleSet(ctx context.Context, logger logger.ContextLogger, options
updateInterval = 24 * time.Hour updateInterval = 24 * time.Hour
} }
return &RemoteRuleSet{ return &RemoteRuleSet{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
outboundManager: service.FromContext[adapter.OutboundManager](ctx), outbound: service.FromContext[adapter.OutboundManager](ctx),
logger: logger, logger: logger,
options: options, options: options,
updateInterval: updateInterval, updateInterval: updateInterval,
pauseManager: service.FromContext[pause.Manager](ctx), pauseManager: service.FromContext[pause.Manager](ctx),
} }
} }
@@ -85,13 +85,13 @@ func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext *adapter.
s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx) s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx)
var dialer N.Dialer var dialer N.Dialer
if s.options.RemoteOptions.DownloadDetour != "" { if s.options.RemoteOptions.DownloadDetour != "" {
outbound, loaded := s.outboundManager.Outbound(s.options.RemoteOptions.DownloadDetour) outbound, loaded := s.outbound.Outbound(s.options.RemoteOptions.DownloadDetour)
if !loaded { if !loaded {
return E.New("download_detour not found: ", s.options.RemoteOptions.DownloadDetour) return E.New("download detour not found: ", s.options.RemoteOptions.DownloadDetour)
} }
dialer = outbound dialer = outbound
} else { } else {
dialer = s.outboundManager.Default() dialer = s.outbound.Default()
} }
s.dialer = dialer s.dialer = dialer
if s.cacheFile != nil { if s.cacheFile != nil {
@@ -292,7 +292,7 @@ func (s *RemoteRuleSet) fetchOnce(ctx context.Context, startContext *adapter.HTT
} }
s.lastUpdated = time.Now() s.lastUpdated = time.Now()
if s.cacheFile != nil { if s.cacheFile != nil {
err = s.cacheFile.SaveRuleSet(s.options.Tag, &adapter.SavedRuleSet{ err = s.cacheFile.SaveRuleSet(s.options.Tag, &adapter.SavedBinary{
LastUpdated: s.lastUpdated, LastUpdated: s.lastUpdated,
Content: content, Content: content,
LastEtag: s.lastEtag, LastEtag: s.lastEtag,

23
script/jsc/array.go Normal file
View File

@@ -0,0 +1,23 @@
package jsc
import (
_ "unsafe"
"github.com/dop251/goja"
)
func NewUint8Array(runtime *goja.Runtime, data []byte) goja.Value {
buffer := runtime.NewArrayBuffer(data)
ctor, loaded := goja.AssertConstructor(runtimeGetUint8Array(runtime))
if !loaded {
panic(runtime.NewTypeError("missing UInt8Array constructor"))
}
array, err := ctor(nil, runtime.ToValue(buffer))
if err != nil {
panic(runtime.NewGoError(err))
}
return array
}
//go:linkname runtimeGetUint8Array github.com/dop251/goja.(*Runtime).getUint8Array
func runtimeGetUint8Array(r *goja.Runtime) *goja.Object

18
script/jsc/array_test.go Normal file
View File

@@ -0,0 +1,18 @@
package jsc_test
import (
"testing"
"github.com/sagernet/sing-box/script/jsc"
"github.com/dop251/goja"
"github.com/stretchr/testify/require"
)
func TestNewUInt8Array(t *testing.T) {
runtime := goja.New()
runtime.Set("hello", jsc.NewUint8Array(runtime, []byte("world")))
result, err := runtime.RunString("hello instanceof Uint8Array")
require.NoError(t, err)
require.True(t, result.ToBoolean())
}

124
script/jsc/assert.go Normal file
View File

@@ -0,0 +1,124 @@
package jsc
import (
"net/http"
F "github.com/sagernet/sing/common/format"
"github.com/dop251/goja"
)
func IsNil(value goja.Value) bool {
return value == nil || goja.IsUndefined(value) || goja.IsNull(value)
}
func AssertObject(vm *goja.Runtime, value goja.Value, name string, nilable bool) *goja.Object {
if IsNil(value) {
if nilable {
return nil
}
panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name)))
}
objectValue, isObject := value.(*goja.Object)
if !isObject {
panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected object, but got ", value)))
}
return objectValue
}
func AssertString(vm *goja.Runtime, value goja.Value, name string, nilable bool) string {
if IsNil(value) {
if nilable {
return ""
}
panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name)))
}
stringValue, isString := value.Export().(string)
if !isString {
panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected string, but got ", value)))
}
return stringValue
}
func AssertInt(vm *goja.Runtime, value goja.Value, name string, nilable bool) int64 {
if IsNil(value) {
if nilable {
return 0
}
panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name)))
}
integerValue, isNumber := value.Export().(int64)
if !isNumber {
panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected integer, but got ", value)))
}
return integerValue
}
func AssertBool(vm *goja.Runtime, value goja.Value, name string, nilable bool) bool {
if IsNil(value) {
if nilable {
return false
}
panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name)))
}
boolValue, isBool := value.Export().(bool)
if !isBool {
panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected boolean, but got ", value)))
}
return boolValue
}
func AssertBinary(vm *goja.Runtime, value goja.Value, name string, nilable bool) []byte {
if IsNil(value) {
if nilable {
return nil
}
panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name)))
}
switch exportedValue := value.Export().(type) {
case []byte:
return exportedValue
case goja.ArrayBuffer:
return exportedValue.Bytes()
default:
panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected Uint8Array or ArrayBuffer, but got ", value)))
}
}
func AssertStringBinary(vm *goja.Runtime, value goja.Value, name string, nilable bool) []byte {
if IsNil(value) {
if nilable {
return nil
}
panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name)))
}
switch exportedValue := value.Export().(type) {
case string:
return []byte(exportedValue)
case []byte:
return exportedValue
case goja.ArrayBuffer:
return exportedValue.Bytes()
default:
panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected string, Uint8Array or ArrayBuffer, but got ", value)))
}
}
func AssertFunction(vm *goja.Runtime, value goja.Value, name string) goja.Callable {
if IsNil(value) {
panic(vm.NewTypeError(F.ToString("invalid argument: missing ", name)))
}
functionValue, isFunction := goja.AssertFunction(value)
if !isFunction {
panic(vm.NewTypeError(F.ToString("invalid argument: ", name, ": expected function, but got ", value)))
}
return functionValue
}
func AssertHTTPHeader(vm *goja.Runtime, value goja.Value, name string) http.Header {
headersObject := AssertObject(vm, value, name, true)
if headersObject == nil {
return nil
}
return ObjectToHeaders(vm, headersObject, name)
}

192
script/jsc/class.go Normal file
View File

@@ -0,0 +1,192 @@
package jsc
import (
"time"
"github.com/sagernet/sing/common"
"github.com/dop251/goja"
)
type Module interface {
Runtime() *goja.Runtime
}
type Class[M Module, C any] interface {
Module() M
Runtime() *goja.Runtime
DefineField(name string, getter func(this C) any, setter func(this C, value goja.Value))
DefineMethod(name string, method func(this C, call goja.FunctionCall) any)
DefineStaticMethod(name string, method func(c Class[M, C], call goja.FunctionCall) any)
DefineConstructor(constructor func(c Class[M, C], call goja.ConstructorCall) C)
ToValue() goja.Value
New(instance C) *goja.Object
Prototype() *goja.Object
Is(value goja.Value) bool
As(value goja.Value) C
}
func GetClass[M Module, C any](runtime *goja.Runtime, exports *goja.Object, className string) Class[M, C] {
objectValue := exports.Get(className)
if objectValue == nil {
panic(runtime.NewTypeError("Missing class: " + className))
}
object, isObject := objectValue.(*goja.Object)
if !isObject {
panic(runtime.NewTypeError("Invalid class: " + className))
}
classObject, isClass := object.Get("_class").(*goja.Object)
if !isClass {
panic(runtime.NewTypeError("Invalid class: " + className))
}
class, isClass := classObject.Export().(Class[M, C])
if !isClass {
panic(runtime.NewTypeError("Invalid class: " + className))
}
return class
}
type goClass[M Module, C any] struct {
m M
prototype *goja.Object
constructor goja.Value
}
func NewClass[M Module, C any](module M) Class[M, C] {
class := &goClass[M, C]{
m: module,
prototype: module.Runtime().NewObject(),
}
clazz := module.Runtime().ToValue(class).(*goja.Object)
clazz.Set("toString", module.Runtime().ToValue(func(call goja.FunctionCall) goja.Value {
return module.Runtime().ToValue("[sing-box Class]")
}))
class.prototype.DefineAccessorProperty("_class", class.Runtime().ToValue(func(call goja.FunctionCall) goja.Value { return clazz }), nil, goja.FLAG_FALSE, goja.FLAG_TRUE)
return class
}
func (c *goClass[M, C]) Module() M {
return c.m
}
func (c *goClass[M, C]) Runtime() *goja.Runtime {
return c.m.Runtime()
}
func (c *goClass[M, C]) DefineField(name string, getter func(this C) any, setter func(this C, value goja.Value)) {
var (
getterValue goja.Value
setterValue goja.Value
)
if getter != nil {
getterValue = c.Runtime().ToValue(func(call goja.FunctionCall) goja.Value {
this, isThis := call.This.Export().(C)
if !isThis {
panic(c.Runtime().NewTypeError("Illegal this value: " + call.This.ExportType().String()))
}
return c.toValue(getter(this), goja.Null())
})
}
if setter != nil {
setterValue = c.Runtime().ToValue(func(call goja.FunctionCall) goja.Value {
this, isThis := call.This.Export().(C)
if !isThis {
panic(c.Runtime().NewTypeError("Illegal this value: " + call.This.String()))
}
setter(this, call.Argument(0))
return goja.Undefined()
})
}
c.prototype.DefineAccessorProperty(name, getterValue, setterValue, goja.FLAG_FALSE, goja.FLAG_TRUE)
}
func (c *goClass[M, C]) DefineMethod(name string, method func(this C, call goja.FunctionCall) any) {
methodValue := c.Runtime().ToValue(func(call goja.FunctionCall) goja.Value {
this, isThis := call.This.Export().(C)
if !isThis {
panic(c.Runtime().NewTypeError("Illegal this value: " + call.This.String()))
}
return c.toValue(method(this, call), goja.Undefined())
})
c.prototype.Set(name, methodValue)
if name == "entries" {
c.prototype.DefineDataPropertySymbol(goja.SymIterator, methodValue, goja.FLAG_TRUE, goja.FLAG_FALSE, goja.FLAG_TRUE)
}
}
func (c *goClass[M, C]) DefineStaticMethod(name string, method func(c Class[M, C], call goja.FunctionCall) any) {
c.prototype.Set(name, c.Runtime().ToValue(func(call goja.FunctionCall) goja.Value {
return c.toValue(method(c, call), goja.Undefined())
}))
}
func (c *goClass[M, C]) DefineConstructor(constructor func(c Class[M, C], call goja.ConstructorCall) C) {
constructorObject := c.Runtime().ToValue(func(call goja.ConstructorCall) *goja.Object {
value := constructor(c, call)
object := c.toValue(value, goja.Undefined()).(*goja.Object)
object.SetPrototype(call.This.Prototype())
return object
}).(*goja.Object)
constructorObject.SetPrototype(c.prototype)
c.prototype.DefineDataProperty("constructor", constructorObject, goja.FLAG_FALSE, goja.FLAG_FALSE, goja.FLAG_FALSE)
c.constructor = constructorObject
}
func (c *goClass[M, C]) toValue(rawValue any, defaultValue goja.Value) goja.Value {
switch value := rawValue.(type) {
case nil:
return defaultValue
case time.Time:
return TimeToValue(c.Runtime(), value)
default:
return c.Runtime().ToValue(value)
}
}
func (c *goClass[M, C]) ToValue() goja.Value {
if c.constructor == nil {
constructorObject := c.Runtime().ToValue(func(call goja.ConstructorCall) *goja.Object {
panic(c.Runtime().NewTypeError("Illegal constructor call"))
}).(*goja.Object)
constructorObject.SetPrototype(c.prototype)
c.prototype.DefineDataProperty("constructor", constructorObject, goja.FLAG_FALSE, goja.FLAG_FALSE, goja.FLAG_FALSE)
c.constructor = constructorObject
}
return c.constructor
}
func (c *goClass[M, C]) New(instance C) *goja.Object {
object := c.Runtime().ToValue(instance).(*goja.Object)
object.SetPrototype(c.prototype)
return object
}
func (c *goClass[M, C]) Prototype() *goja.Object {
return c.prototype
}
func (c *goClass[M, C]) Is(value goja.Value) bool {
object, isObject := value.(*goja.Object)
if !isObject {
return false
}
prototype := object.Prototype()
for prototype != nil {
if prototype == c.prototype {
return true
}
prototype = prototype.Prototype()
}
return false
}
func (c *goClass[M, C]) As(value goja.Value) C {
object, isObject := value.(*goja.Object)
if !isObject {
return common.DefaultValue[C]()
}
if !c.Is(object) {
return common.DefaultValue[C]()
}
return object.Export().(C)
}

56
script/jsc/headers.go Normal file
View File

@@ -0,0 +1,56 @@
package jsc
import (
"net/http"
"reflect"
"github.com/sagernet/sing/common"
F "github.com/sagernet/sing/common/format"
"github.com/dop251/goja"
)
func HeadersToValue(runtime *goja.Runtime, headers http.Header) goja.Value {
object := runtime.NewObject()
for key, value := range headers {
if len(value) == 1 {
object.Set(key, value[0])
} else {
object.Set(key, ArrayToValue(runtime, value))
}
}
return object
}
func ArrayToValue[T any](runtime *goja.Runtime, values []T) goja.Value {
return runtime.NewArray(common.Map(values, func(it T) any { return it })...)
}
func ObjectToHeaders(vm *goja.Runtime, object *goja.Object, name string) http.Header {
headers := make(http.Header)
for _, key := range object.Keys() {
valueObject := object.Get(key)
switch headerValue := valueObject.(type) {
case goja.String:
headers.Set(key, headerValue.String())
case *goja.Object:
values := headerValue.Export()
valueArray, isArray := values.([]any)
if !isArray {
panic(vm.NewTypeError(F.ToString("invalid value: ", name, ".", key, "expected string or string array, got ", valueObject.String())))
}
newValues := make([]string, 0, len(valueArray))
for _, value := range valueArray {
stringValue, isString := value.(string)
if !isString {
panic(vm.NewTypeError(F.ToString("invalid value: ", name, ".", key, " expected string or string array, got array item type: ", reflect.TypeOf(value))))
}
newValues = append(newValues, stringValue)
}
headers[key] = newValues
default:
panic(vm.NewTypeError(F.ToString("invalid value: ", name, ".", key, " expected string or string array, got ", valueObject.String())))
}
}
return headers
}

View File

@@ -0,0 +1,31 @@
package jsc_test
import (
"fmt"
"net/http"
"reflect"
"testing"
"github.com/sagernet/sing-box/script/jsc"
"github.com/dop251/goja"
"github.com/stretchr/testify/require"
)
func TestHeaders(t *testing.T) {
runtime := goja.New()
runtime.Set("headers", jsc.HeadersToValue(runtime, http.Header{
"My-Header": []string{"My-Value1", "My-Value2"},
}))
headers := runtime.Get("headers").(*goja.Object).Get("My-Header").(*goja.Object)
fmt.Println(reflect.ValueOf(headers.Export()).Type().String())
}
func TestBody(t *testing.T) {
runtime := goja.New()
_, err := runtime.RunString(`
var responseBody = new Uint8Array([1, 2, 3, 4, 5])
`)
require.NoError(t, err)
fmt.Println(reflect.TypeOf(runtime.Get("responseBody").Export()))
}

36
script/jsc/iterator.go Normal file
View File

@@ -0,0 +1,36 @@
package jsc
import "github.com/dop251/goja"
type Iterator[M Module, T any] struct {
c Class[M, *Iterator[M, T]]
values []T
block func(this T) any
}
func NewIterator[M Module, T any](class Class[M, *Iterator[M, T]], values []T, block func(this T) any) goja.Value {
return class.New(&Iterator[M, T]{class, values, block})
}
func CreateIterator[M Module, T any](module M) Class[M, *Iterator[M, T]] {
class := NewClass[M, *Iterator[M, T]](module)
class.DefineMethod("next", (*Iterator[M, T]).next)
class.DefineMethod("toString", (*Iterator[M, T]).toString)
return class
}
func (i *Iterator[M, T]) next(call goja.FunctionCall) any {
result := i.c.Runtime().NewObject()
if len(i.values) == 0 {
result.Set("done", true)
} else {
result.Set("done", false)
result.Set("value", i.block(i.values[0]))
i.values = i.values[1:]
}
return result
}
func (i *Iterator[M, T]) toString(call goja.FunctionCall) any {
return "[sing-box Iterator]"
}

18
script/jsc/time.go Normal file
View File

@@ -0,0 +1,18 @@
package jsc
import (
"time"
_ "unsafe"
"github.com/dop251/goja"
)
func TimeToValue(runtime *goja.Runtime, time time.Time) goja.Value {
return runtimeNewDateObject(runtime, time, true, runtimeGetDatePrototype(runtime))
}
//go:linkname runtimeNewDateObject github.com/dop251/goja.(*Runtime).newDateObject
func runtimeNewDateObject(r *goja.Runtime, t time.Time, isSet bool, proto *goja.Object) *goja.Object
//go:linkname runtimeGetDatePrototype github.com/dop251/goja.(*Runtime).getDatePrototype
func runtimeGetDatePrototype(r *goja.Runtime) *goja.Object

20
script/jsc/time_test.go Normal file
View File

@@ -0,0 +1,20 @@
package jsc_test
import (
"testing"
"time"
"github.com/sagernet/sing-box/script/jsc"
"github.com/dop251/goja"
"github.com/stretchr/testify/require"
)
func TestTimeToValue(t *testing.T) {
t.Parallel()
runtime := goja.New()
now := time.Now()
err := runtime.Set("now", jsc.TimeToValue(runtime, now))
require.NoError(t, err)
println(runtime.Get("now").String())
}

83
script/jstest/assert.js Normal file
View File

@@ -0,0 +1,83 @@
'use strict';
const assert = {
_isSameValue(a, b) {
if (a === b) {
// Handle +/-0 vs. -/+0
return a !== 0 || 1 / a === 1 / b;
}
// Handle NaN vs. NaN
return a !== a && b !== b;
},
_toString(value) {
try {
if (value === 0 && 1 / value === -Infinity) {
return '-0';
}
return String(value);
} catch (err) {
if (err.name === 'TypeError') {
return Object.prototype.toString.call(value);
}
throw err;
}
},
sameValue(actual, expected, message) {
if (assert._isSameValue(actual, expected)) {
return;
}
if (message === undefined) {
message = '';
} else {
message += ' ';
}
message += 'Expected SameValue(«' + assert._toString(actual) + '», «' + assert._toString(expected) + '») to be true';
throw new Error(message);
},
throws(f, ctor, message) {
if (message === undefined) {
message = '';
} else {
message += ' ';
}
try {
f();
} catch (e) {
if (e.constructor !== ctor) {
throw new Error(message + "Wrong exception type was thrown: " + e.constructor.name);
}
return;
}
throw new Error(message + "No exception was thrown");
},
throwsNodeError(f, ctor, code, message) {
if (message === undefined) {
message = '';
} else {
message += ' ';
}
try {
f();
} catch (e) {
if (e.constructor !== ctor) {
throw new Error(message + "Wrong exception type was thrown: " + e.constructor.name);
}
if (e.code !== code) {
throw new Error(message + "Wrong exception code was thrown: " + e.code);
}
return;
}
throw new Error(message + "No exception was thrown");
}
}
module.exports = assert;

21
script/jstest/test.go Normal file
View File

@@ -0,0 +1,21 @@
package jstest
import (
_ "embed"
"github.com/sagernet/sing-box/script/modules/require"
)
//go:embed assert.js
var assertJS []byte
func NewRegistry() *require.Registry {
return require.NewRegistry(require.WithFsEnable(true), require.WithLoader(func(path string) ([]byte, error) {
switch path {
case "assert.js":
return assertJS, nil
default:
return require.DefaultSourceLoader(path)
}
}))
}

116
script/manager.go Normal file
View File

@@ -0,0 +1,116 @@
package script
import (
"context"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/task"
)
var _ adapter.ScriptManager = (*Manager)(nil)
type Manager struct {
ctx context.Context
logger logger.ContextLogger
scripts []adapter.Script
scriptByName map[string]adapter.Script
surgeCache *adapter.SurgeInMemoryCache
}
func NewManager(ctx context.Context, logFactory log.Factory, scripts []option.Script) (*Manager, error) {
manager := &Manager{
ctx: ctx,
logger: logFactory.NewLogger("script"),
scriptByName: make(map[string]adapter.Script),
}
for _, scriptOptions := range scripts {
script, err := NewScript(ctx, logFactory.NewLogger(F.ToString("script/", scriptOptions.Type, "[", scriptOptions.Tag, "]")), scriptOptions)
if err != nil {
return nil, E.Cause(err, "initialize script: ", scriptOptions.Tag)
}
manager.scripts = append(manager.scripts, script)
manager.scriptByName[scriptOptions.Tag] = script
}
return manager, nil
}
func (m *Manager) Start(stage adapter.StartStage) error {
monitor := taskmonitor.New(m.logger, C.StartTimeout)
switch stage {
case adapter.StartStateStart:
var cacheContext *adapter.HTTPStartContext
if len(m.scripts) > 0 {
monitor.Start("initialize rule-set")
cacheContext = adapter.NewHTTPStartContext(m.ctx)
var scriptStartGroup task.Group
for _, script := range m.scripts {
scriptInPlace := script
scriptStartGroup.Append0(func(ctx context.Context) error {
err := scriptInPlace.StartContext(ctx, cacheContext)
if err != nil {
return E.Cause(err, "initialize script/", scriptInPlace.Type(), "[", scriptInPlace.Tag(), "]")
}
return nil
})
}
scriptStartGroup.Concurrency(5)
scriptStartGroup.FastFail()
err := scriptStartGroup.Run(m.ctx)
monitor.Finish()
if err != nil {
return err
}
}
if cacheContext != nil {
cacheContext.Close()
}
case adapter.StartStatePostStart:
for _, script := range m.scripts {
monitor.Start(F.ToString("post start script/", script.Type(), "[", script.Tag(), "]"))
err := script.PostStart()
monitor.Finish()
if err != nil {
return E.Cause(err, "post start script/", script.Type(), "[", script.Tag(), "]")
}
}
}
return nil
}
func (m *Manager) Close() error {
monitor := taskmonitor.New(m.logger, C.StopTimeout)
var err error
for _, script := range m.scripts {
monitor.Start(F.ToString("close start script/", script.Type(), "[", script.Tag(), "]"))
err = E.Append(err, script.Close(), func(err error) error {
return E.Cause(err, "close script/", script.Type(), "[", script.Tag(), "]")
})
monitor.Finish()
}
return err
}
func (m *Manager) Scripts() []adapter.Script {
return m.scripts
}
func (m *Manager) Script(name string) (adapter.Script, bool) {
script, loaded := m.scriptByName[name]
return script, loaded
}
func (m *Manager) SurgeCache() *adapter.SurgeInMemoryCache {
if m.surgeCache == nil {
m.surgeCache = &adapter.SurgeInMemoryCache{
Data: make(map[string]string),
}
}
return m.surgeCache
}

View File

@@ -0,0 +1,50 @@
package boxctx
import (
"context"
"time"
"github.com/sagernet/sing-box/script/jsc"
"github.com/sagernet/sing/common/logger"
"github.com/dop251/goja"
)
type Context struct {
class jsc.Class[*Module, *Context]
Context context.Context
Logger logger.ContextLogger
Tag string
StartedAt time.Time
ErrorHandler func(error)
}
func FromRuntime(runtime *goja.Runtime) *Context {
contextValue := runtime.Get("context")
if contextValue == nil {
return nil
}
context, isContext := contextValue.Export().(*Context)
if !isContext {
return nil
}
return context
}
func MustFromRuntime(runtime *goja.Runtime) *Context {
context := FromRuntime(runtime)
if context == nil {
panic(runtime.NewTypeError("Missing sing-box context"))
}
return context
}
func createContext(module *Module) jsc.Class[*Module, *Context] {
class := jsc.NewClass[*Module, *Context](module)
class.DefineMethod("toString", (*Context).toString)
return class
}
func (c *Context) toString(call goja.FunctionCall) any {
return "[sing-box Context]"
}

View File

@@ -0,0 +1,35 @@
package boxctx
import (
"github.com/sagernet/sing-box/script/jsc"
"github.com/sagernet/sing-box/script/modules/require"
"github.com/dop251/goja"
)
const ModuleName = "context"
type Module struct {
runtime *goja.Runtime
classContext jsc.Class[*Module, *Context]
}
func Require(runtime *goja.Runtime, module *goja.Object) {
m := &Module{
runtime: runtime,
}
m.classContext = createContext(m)
exports := module.Get("exports").(*goja.Object)
exports.Set("Context", m.classContext.ToValue())
}
func Enable(runtime *goja.Runtime, context *Context) {
exports := require.Require(runtime, ModuleName).ToObject(runtime)
classContext := jsc.GetClass[*Module, *Context](runtime, exports, "Context")
context.class = classContext
runtime.Set("context", classContext.New(context))
}
func (m *Module) Runtime() *goja.Runtime {
return m.runtime
}

View File

@@ -0,0 +1,281 @@
package console
import (
"bytes"
"context"
"encoding/xml"
"sync"
"time"
sLog "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/script/jsc"
"github.com/sagernet/sing-box/script/modules/boxctx"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
"github.com/dop251/goja"
)
type Console struct {
class jsc.Class[*Module, *Console]
access sync.Mutex
countMap map[string]int
timeMap map[string]time.Time
}
func NewConsole(class jsc.Class[*Module, *Console]) goja.Value {
return class.New(&Console{
class: class,
countMap: make(map[string]int),
timeMap: make(map[string]time.Time),
})
}
func createConsole(m *Module) jsc.Class[*Module, *Console] {
class := jsc.NewClass[*Module, *Console](m)
class.DefineMethod("assert", (*Console).assert)
class.DefineMethod("clear", (*Console).clear)
class.DefineMethod("count", (*Console).count)
class.DefineMethod("countReset", (*Console).countReset)
class.DefineMethod("debug", (*Console).debug)
class.DefineMethod("dir", (*Console).dir)
class.DefineMethod("dirxml", (*Console).dirxml)
class.DefineMethod("error", (*Console).error)
class.DefineMethod("group", (*Console).stub)
class.DefineMethod("groupCollapsed", (*Console).stub)
class.DefineMethod("groupEnd", (*Console).stub)
class.DefineMethod("info", (*Console).info)
class.DefineMethod("log", (*Console)._log)
class.DefineMethod("profile", (*Console).stub)
class.DefineMethod("profileEnd", (*Console).profileEnd)
class.DefineMethod("table", (*Console).table)
class.DefineMethod("time", (*Console).time)
class.DefineMethod("timeEnd", (*Console).timeEnd)
class.DefineMethod("timeLog", (*Console).timeLog)
class.DefineMethod("timeStamp", (*Console).stub)
class.DefineMethod("trace", (*Console).trace)
class.DefineMethod("warn", (*Console).warn)
return class
}
func (c *Console) stub(call goja.FunctionCall) any {
return goja.Undefined()
}
func (c *Console) assert(call goja.FunctionCall) any {
assertion := call.Argument(0).ToBoolean()
if !assertion {
return c.log(logger.ContextLogger.ErrorContext, call.Arguments[1:])
}
return goja.Undefined()
}
func (c *Console) clear(call goja.FunctionCall) any {
return nil
}
func (c *Console) count(call goja.FunctionCall) any {
label := jsc.AssertString(c.class.Runtime(), call.Argument(0), "label", true)
if label == "" {
label = "default"
}
c.access.Lock()
newValue := c.countMap[label] + 1
c.countMap[label] = newValue
c.access.Unlock()
writeLog(c.class.Runtime(), logger.ContextLogger.InfoContext, F.ToString(label, ": ", newValue))
return goja.Undefined()
}
func (c *Console) countReset(call goja.FunctionCall) any {
label := jsc.AssertString(c.class.Runtime(), call.Argument(0), "label", true)
if label == "" {
label = "default"
}
c.access.Lock()
delete(c.countMap, label)
c.access.Unlock()
return goja.Undefined()
}
func (c *Console) log(logFunc func(logger.ContextLogger, context.Context, ...any), args []goja.Value) any {
var buffer bytes.Buffer
var formatString string
if len(args) > 0 {
formatString = args[0].String()
}
format(c.class.Runtime(), &buffer, formatString, args[1:]...)
writeLog(c.class.Runtime(), logFunc, buffer.String())
return goja.Undefined()
}
func (c *Console) debug(call goja.FunctionCall) any {
return c.log(logger.ContextLogger.DebugContext, call.Arguments)
}
func (c *Console) dir(call goja.FunctionCall) any {
object := jsc.AssertObject(c.class.Runtime(), call.Argument(0), "object", false)
var buffer bytes.Buffer
for _, key := range object.Keys() {
value := object.Get(key)
buffer.WriteString(key)
buffer.WriteString(": ")
buffer.WriteString(value.String())
buffer.WriteString("\n")
}
writeLog(c.class.Runtime(), logger.ContextLogger.InfoContext, buffer.String())
return goja.Undefined()
}
func (c *Console) dirxml(call goja.FunctionCall) any {
var buffer bytes.Buffer
encoder := xml.NewEncoder(&buffer)
encoder.Indent("", " ")
encoder.Encode(call.Argument(0).Export())
writeLog(c.class.Runtime(), logger.ContextLogger.InfoContext, buffer.String())
return goja.Undefined()
}
func (c *Console) error(call goja.FunctionCall) any {
return c.log(logger.ContextLogger.ErrorContext, call.Arguments)
}
func (c *Console) info(call goja.FunctionCall) any {
return c.log(logger.ContextLogger.InfoContext, call.Arguments)
}
func (c *Console) _log(call goja.FunctionCall) any {
return c.log(logger.ContextLogger.InfoContext, call.Arguments)
}
func (c *Console) profileEnd(call goja.FunctionCall) any {
return goja.Undefined()
}
func (c *Console) table(call goja.FunctionCall) any {
return c.dir(call)
}
func (c *Console) time(call goja.FunctionCall) any {
label := jsc.AssertString(c.class.Runtime(), call.Argument(0), "label", true)
if label == "" {
label = "default"
}
c.access.Lock()
c.timeMap[label] = time.Now()
c.access.Unlock()
return goja.Undefined()
}
func (c *Console) timeEnd(call goja.FunctionCall) any {
label := jsc.AssertString(c.class.Runtime(), call.Argument(0), "label", true)
if label == "" {
label = "default"
}
c.access.Lock()
startTime, ok := c.timeMap[label]
if !ok {
c.access.Unlock()
return goja.Undefined()
}
delete(c.timeMap, label)
c.access.Unlock()
writeLog(c.class.Runtime(), logger.ContextLogger.InfoContext, F.ToString(label, ": ", time.Since(startTime).String(), " - - timer ended"))
return goja.Undefined()
}
func (c *Console) timeLog(call goja.FunctionCall) any {
label := jsc.AssertString(c.class.Runtime(), call.Argument(0), "label", true)
if label == "" {
label = "default"
}
c.access.Lock()
startTime, ok := c.timeMap[label]
c.access.Unlock()
if !ok {
writeLog(c.class.Runtime(), logger.ContextLogger.ErrorContext, F.ToString("Timer \"", label, "\" doesn't exist."))
return goja.Undefined()
}
writeLog(c.class.Runtime(), logger.ContextLogger.InfoContext, F.ToString(label, ": ", time.Since(startTime)))
return goja.Undefined()
}
func (c *Console) trace(call goja.FunctionCall) any {
return c.log(logger.ContextLogger.TraceContext, call.Arguments)
}
func (c *Console) warn(call goja.FunctionCall) any {
return c.log(logger.ContextLogger.WarnContext, call.Arguments)
}
func writeLog(runtime *goja.Runtime, logFunc func(logger.ContextLogger, context.Context, ...any), message string) {
var (
ctx context.Context
sLogger logger.ContextLogger
)
boxCtx := boxctx.FromRuntime(runtime)
if boxCtx != nil {
ctx = boxCtx.Context
sLogger = boxCtx.Logger
} else {
ctx = context.Background()
sLogger = sLog.StdLogger()
}
logFunc(sLogger, ctx, message)
}
func format(runtime *goja.Runtime, b *bytes.Buffer, f string, args ...goja.Value) {
pct := false
argNum := 0
for _, chr := range f {
if pct {
if argNum < len(args) {
if format1(runtime, chr, args[argNum], b) {
argNum++
}
} else {
b.WriteByte('%')
b.WriteRune(chr)
}
pct = false
} else {
if chr == '%' {
pct = true
} else {
b.WriteRune(chr)
}
}
}
for _, arg := range args[argNum:] {
b.WriteByte(' ')
b.WriteString(arg.String())
}
}
func format1(runtime *goja.Runtime, f rune, val goja.Value, w *bytes.Buffer) bool {
switch f {
case 's':
w.WriteString(val.String())
case 'd':
w.WriteString(val.ToNumber().String())
case 'j':
if json, ok := runtime.Get("JSON").(*goja.Object); ok {
if stringify, ok := goja.AssertFunction(json.Get("stringify")); ok {
res, err := stringify(json, val)
if err != nil {
panic(err)
}
w.WriteString(res.String())
}
}
case '%':
w.WriteByte('%')
return false
default:
w.WriteByte('%')
w.WriteRune(f)
return false
}
return true
}

View File

@@ -0,0 +1,3 @@
package console
type Context struct{}

View File

@@ -0,0 +1,34 @@
package console
import (
"github.com/sagernet/sing-box/script/jsc"
"github.com/sagernet/sing-box/script/modules/require"
"github.com/dop251/goja"
)
const ModuleName = "console"
type Module struct {
runtime *goja.Runtime
console jsc.Class[*Module, *Console]
}
func Require(runtime *goja.Runtime, module *goja.Object) {
m := &Module{
runtime: runtime,
}
m.console = createConsole(m)
exports := module.Get("exports").(*goja.Object)
exports.Set("Console", m.console.ToValue())
}
func Enable(runtime *goja.Runtime) {
exports := require.Require(runtime, ModuleName).ToObject(runtime)
classConsole := jsc.GetClass[*Module, *Console](runtime, exports, "Console")
runtime.Set("console", NewConsole(classConsole))
}
func (m *Module) Runtime() *goja.Runtime {
return m.runtime
}

View File

@@ -0,0 +1,489 @@
package eventloop
import (
"sync"
"sync/atomic"
"time"
"github.com/dop251/goja"
)
type job struct {
cancel func() bool
fn func()
idx int
cancelled bool
}
type Timer struct {
job
timer *time.Timer
}
type Interval struct {
job
ticker *time.Ticker
stopChan chan struct{}
}
type Immediate struct {
job
}
type EventLoop struct {
vm *goja.Runtime
jobChan chan func()
jobs []*job
jobCount int32
canRun int32
auxJobsLock sync.Mutex
wakeupChan chan struct{}
auxJobsSpare, auxJobs []func()
stopLock sync.Mutex
stopCond *sync.Cond
running bool
terminated bool
errorHandler func(error)
}
func Enable(runtime *goja.Runtime, errorHandler func(error)) *EventLoop {
loop := &EventLoop{
vm: runtime,
jobChan: make(chan func()),
wakeupChan: make(chan struct{}, 1),
errorHandler: errorHandler,
}
loop.stopCond = sync.NewCond(&loop.stopLock)
runtime.Set("setTimeout", loop.setTimeout)
runtime.Set("setInterval", loop.setInterval)
runtime.Set("setImmediate", loop.setImmediate)
runtime.Set("clearTimeout", loop.clearTimeout)
runtime.Set("clearInterval", loop.clearInterval)
runtime.Set("clearImmediate", loop.clearImmediate)
return loop
}
func (loop *EventLoop) schedule(call goja.FunctionCall, repeating bool) goja.Value {
if fn, ok := goja.AssertFunction(call.Argument(0)); ok {
delay := call.Argument(1).ToInteger()
var args []goja.Value
if len(call.Arguments) > 2 {
args = append(args, call.Arguments[2:]...)
}
f := func() {
_, err := fn(nil, args...)
if err != nil {
loop.errorHandler(err)
}
}
loop.jobCount++
var job *job
var ret goja.Value
if repeating {
interval := loop.newInterval(f)
interval.start(loop, time.Duration(delay)*time.Millisecond)
job = &interval.job
ret = loop.vm.ToValue(interval)
} else {
timeout := loop.newTimeout(f)
timeout.start(loop, time.Duration(delay)*time.Millisecond)
job = &timeout.job
ret = loop.vm.ToValue(timeout)
}
job.idx = len(loop.jobs)
loop.jobs = append(loop.jobs, job)
return ret
}
return nil
}
func (loop *EventLoop) setTimeout(call goja.FunctionCall) goja.Value {
return loop.schedule(call, false)
}
func (loop *EventLoop) setInterval(call goja.FunctionCall) goja.Value {
return loop.schedule(call, true)
}
func (loop *EventLoop) setImmediate(call goja.FunctionCall) goja.Value {
if fn, ok := goja.AssertFunction(call.Argument(0)); ok {
var args []goja.Value
if len(call.Arguments) > 1 {
args = append(args, call.Arguments[1:]...)
}
f := func() {
_, err := fn(nil, args...)
if err != nil {
loop.errorHandler(err)
}
}
loop.jobCount++
return loop.vm.ToValue(loop.addImmediate(f))
}
return nil
}
// SetTimeout schedules to run the specified function in the context
// of the loop as soon as possible after the specified timeout period.
// SetTimeout returns a Timer which can be passed to ClearTimeout.
// The instance of goja.Runtime that is passed to the function and any Values derived
// from it must not be used outside the function. SetTimeout is
// safe to call inside or outside the loop.
// If the loop is terminated (see Terminate()) returns nil.
func (loop *EventLoop) SetTimeout(fn func(*goja.Runtime), timeout time.Duration) *Timer {
t := loop.newTimeout(func() { fn(loop.vm) })
if loop.addAuxJob(func() {
t.start(loop, timeout)
loop.jobCount++
t.idx = len(loop.jobs)
loop.jobs = append(loop.jobs, &t.job)
}) {
return t
}
return nil
}
// ClearTimeout cancels a Timer returned by SetTimeout if it has not run yet.
// ClearTimeout is safe to call inside or outside the loop.
func (loop *EventLoop) ClearTimeout(t *Timer) {
loop.addAuxJob(func() {
loop.clearTimeout(t)
})
}
// SetInterval schedules to repeatedly run the specified function in
// the context of the loop as soon as possible after every specified
// timeout period. SetInterval returns an Interval which can be
// passed to ClearInterval. The instance of goja.Runtime that is passed to the
// function and any Values derived from it must not be used outside
// the function. SetInterval is safe to call inside or outside the
// loop.
// If the loop is terminated (see Terminate()) returns nil.
func (loop *EventLoop) SetInterval(fn func(*goja.Runtime), timeout time.Duration) *Interval {
i := loop.newInterval(func() { fn(loop.vm) })
if loop.addAuxJob(func() {
i.start(loop, timeout)
loop.jobCount++
i.idx = len(loop.jobs)
loop.jobs = append(loop.jobs, &i.job)
}) {
return i
}
return nil
}
// ClearInterval cancels an Interval returned by SetInterval.
// ClearInterval is safe to call inside or outside the loop.
func (loop *EventLoop) ClearInterval(i *Interval) {
loop.addAuxJob(func() {
loop.clearInterval(i)
})
}
func (loop *EventLoop) setRunning() {
loop.stopLock.Lock()
defer loop.stopLock.Unlock()
if loop.running {
panic("Loop is already started")
}
loop.running = true
atomic.StoreInt32(&loop.canRun, 1)
loop.auxJobsLock.Lock()
loop.terminated = false
loop.auxJobsLock.Unlock()
}
// Run calls the specified function, starts the event loop and waits until there are no more delayed jobs to run
// after which it stops the loop and returns.
// The instance of goja.Runtime that is passed to the function and any Values derived from it must not be used
// outside the function.
// Do NOT use this function while the loop is already running. Use RunOnLoop() instead.
// If the loop is already started it will panic.
func (loop *EventLoop) Run(fn func(*goja.Runtime)) {
loop.setRunning()
fn(loop.vm)
loop.run(false)
}
// Start the event loop in the background. The loop continues to run until Stop() is called.
// If the loop is already started it will panic.
func (loop *EventLoop) Start() {
loop.setRunning()
go loop.run(true)
}
// StartInForeground starts the event loop in the current goroutine. The loop continues to run until Stop() is called.
// If the loop is already started it will panic.
// Use this instead of Start if you want to recover from panics that may occur while calling native Go functions from
// within setInterval and setTimeout callbacks.
func (loop *EventLoop) StartInForeground() {
loop.setRunning()
loop.run(true)
}
// Stop the loop that was started with Start(). After this function returns there will be no more jobs executed
// by the loop. It is possible to call Start() or Run() again after this to resume the execution.
// Note, it does not cancel active timeouts (use Terminate() instead if you want this).
// It is not allowed to run Start() (or Run()) and Stop() or Terminate() concurrently.
// Calling Stop() on a non-running loop has no effect.
// It is not allowed to call Stop() from the loop, because it is synchronous and cannot complete until the loop
// is not running any jobs. Use StopNoWait() instead.
// return number of jobs remaining
func (loop *EventLoop) Stop() int {
loop.stopLock.Lock()
for loop.running {
atomic.StoreInt32(&loop.canRun, 0)
loop.wakeup()
loop.stopCond.Wait()
}
loop.stopLock.Unlock()
return int(loop.jobCount)
}
// StopNoWait tells the loop to stop and returns immediately. Can be used inside the loop. Calling it on a
// non-running loop has no effect.
func (loop *EventLoop) StopNoWait() {
loop.stopLock.Lock()
if loop.running {
atomic.StoreInt32(&loop.canRun, 0)
loop.wakeup()
}
loop.stopLock.Unlock()
}
// Terminate stops the loop and clears all active timeouts and intervals. After it returns there are no
// active timers or goroutines associated with the loop. Any attempt to submit a task (by using RunOnLoop(),
// SetTimeout() or SetInterval()) will not succeed.
// After being terminated the loop can be restarted again by using Start() or Run().
// This method must not be called concurrently with Stop*(), Start(), or Run().
func (loop *EventLoop) Terminate() {
loop.Stop()
loop.auxJobsLock.Lock()
loop.terminated = true
loop.auxJobsLock.Unlock()
loop.runAux()
for i := 0; i < len(loop.jobs); i++ {
job := loop.jobs[i]
if !job.cancelled {
job.cancelled = true
if job.cancel() {
loop.removeJob(job)
i--
}
}
}
for len(loop.jobs) > 0 {
(<-loop.jobChan)()
}
}
// RunOnLoop schedules to run the specified function in the context of the loop as soon as possible.
// The order of the runs is preserved (i.e. the functions will be called in the same order as calls to RunOnLoop())
// The instance of goja.Runtime that is passed to the function and any Values derived from it must not be used
// outside the function. It is safe to call inside or outside the loop.
// Returns true on success or false if the loop is terminated (see Terminate()).
func (loop *EventLoop) RunOnLoop(fn func(*goja.Runtime)) bool {
return loop.addAuxJob(func() { fn(loop.vm) })
}
func (loop *EventLoop) runAux() {
loop.auxJobsLock.Lock()
jobs := loop.auxJobs
loop.auxJobs = loop.auxJobsSpare
loop.auxJobsLock.Unlock()
for i, job := range jobs {
job()
jobs[i] = nil
}
loop.auxJobsSpare = jobs[:0]
}
func (loop *EventLoop) run(inBackground bool) {
loop.runAux()
if inBackground {
loop.jobCount++
}
LOOP:
for loop.jobCount > 0 {
select {
case job := <-loop.jobChan:
job()
case <-loop.wakeupChan:
loop.runAux()
if atomic.LoadInt32(&loop.canRun) == 0 {
break LOOP
}
}
}
if inBackground {
loop.jobCount--
}
loop.stopLock.Lock()
loop.running = false
loop.stopLock.Unlock()
loop.stopCond.Broadcast()
}
func (loop *EventLoop) wakeup() {
select {
case loop.wakeupChan <- struct{}{}:
default:
}
}
func (loop *EventLoop) addAuxJob(fn func()) bool {
loop.auxJobsLock.Lock()
if loop.terminated {
loop.auxJobsLock.Unlock()
return false
}
loop.auxJobs = append(loop.auxJobs, fn)
loop.auxJobsLock.Unlock()
loop.wakeup()
return true
}
func (loop *EventLoop) newTimeout(f func()) *Timer {
t := &Timer{
job: job{fn: f},
}
t.cancel = t.doCancel
return t
}
func (t *Timer) start(loop *EventLoop, timeout time.Duration) {
t.timer = time.AfterFunc(timeout, func() {
loop.jobChan <- func() {
loop.doTimeout(t)
}
})
}
func (loop *EventLoop) newInterval(f func()) *Interval {
i := &Interval{
job: job{fn: f},
stopChan: make(chan struct{}),
}
i.cancel = i.doCancel
return i
}
func (i *Interval) start(loop *EventLoop, timeout time.Duration) {
// https://nodejs.org/api/timers.html#timers_setinterval_callback_delay_args
if timeout <= 0 {
timeout = time.Millisecond
}
i.ticker = time.NewTicker(timeout)
go i.run(loop)
}
func (loop *EventLoop) addImmediate(f func()) *Immediate {
i := &Immediate{
job: job{fn: f},
}
loop.addAuxJob(func() {
loop.doImmediate(i)
})
return i
}
func (loop *EventLoop) doTimeout(t *Timer) {
loop.removeJob(&t.job)
if !t.cancelled {
t.cancelled = true
loop.jobCount--
t.fn()
}
}
func (loop *EventLoop) doInterval(i *Interval) {
if !i.cancelled {
i.fn()
}
}
func (loop *EventLoop) doImmediate(i *Immediate) {
if !i.cancelled {
i.cancelled = true
loop.jobCount--
i.fn()
}
}
func (loop *EventLoop) clearTimeout(t *Timer) {
if t != nil && !t.cancelled {
t.cancelled = true
loop.jobCount--
if t.doCancel() {
loop.removeJob(&t.job)
}
}
}
func (loop *EventLoop) clearInterval(i *Interval) {
if i != nil && !i.cancelled {
i.cancelled = true
loop.jobCount--
i.doCancel()
}
}
func (loop *EventLoop) removeJob(job *job) {
idx := job.idx
if idx < 0 {
return
}
if idx < len(loop.jobs)-1 {
loop.jobs[idx] = loop.jobs[len(loop.jobs)-1]
loop.jobs[idx].idx = idx
}
loop.jobs[len(loop.jobs)-1] = nil
loop.jobs = loop.jobs[:len(loop.jobs)-1]
job.idx = -1
}
func (loop *EventLoop) clearImmediate(i *Immediate) {
if i != nil && !i.cancelled {
i.cancelled = true
loop.jobCount--
}
}
func (i *Interval) doCancel() bool {
close(i.stopChan)
return false
}
func (t *Timer) doCancel() bool {
return t.timer.Stop()
}
func (i *Interval) run(loop *EventLoop) {
L:
for {
select {
case <-i.stopChan:
i.ticker.Stop()
break L
case <-i.ticker.C:
loop.jobChan <- func() {
loop.doInterval(i)
}
}
}
loop.jobChan <- func() {
loop.removeJob(&i.job)
}
}

View File

@@ -0,0 +1,231 @@
package require
import (
"errors"
"io"
"io/fs"
"os"
"path"
"path/filepath"
"runtime"
"sync"
"syscall"
"text/template"
js "github.com/dop251/goja"
"github.com/dop251/goja/parser"
)
type ModuleLoader func(*js.Runtime, *js.Object)
// SourceLoader represents a function that returns a file data at a given path.
// The function should return ModuleFileDoesNotExistError if the file either doesn't exist or is a directory.
// This error will be ignored by the resolver and the search will continue. Any other errors will be propagated.
type SourceLoader func(path string) ([]byte, error)
var (
InvalidModuleError = errors.New("Invalid module")
IllegalModuleNameError = errors.New("Illegal module name")
NoSuchBuiltInModuleError = errors.New("No such built-in module")
ModuleFileDoesNotExistError = errors.New("module file does not exist")
)
// Registry contains a cache of compiled modules which can be used by multiple Runtimes
type Registry struct {
sync.Mutex
native map[string]ModuleLoader
builtin map[string]ModuleLoader
compiled map[string]*js.Program
srcLoader SourceLoader
globalFolders []string
fsEnabled bool
}
type RequireModule struct {
r *Registry
runtime *js.Runtime
modules map[string]*js.Object
nodeModules map[string]*js.Object
}
func NewRegistry(opts ...Option) *Registry {
r := &Registry{}
for _, opt := range opts {
opt(r)
}
return r
}
type Option func(*Registry)
// WithLoader sets a function which will be called by the require() function in order to get a source code for a
// module at the given path. The same function will be used to get external source maps.
// Note, this only affects the modules loaded by the require() function. If you need to use it as a source map
// loader for code parsed in a different way (such as runtime.RunString() or eval()), use (*Runtime).SetParserOptions()
func WithLoader(srcLoader SourceLoader) Option {
return func(r *Registry) {
r.srcLoader = srcLoader
}
}
// WithGlobalFolders appends the given paths to the registry's list of
// global folders to search if the requested module is not found
// elsewhere. By default, a registry's global folders list is empty.
// In the reference Node.js implementation, the default global folders
// list is $NODE_PATH, $HOME/.node_modules, $HOME/.node_libraries and
// $PREFIX/lib/node, see
// https://nodejs.org/api/modules.html#modules_loading_from_the_global_folders.
func WithGlobalFolders(globalFolders ...string) Option {
return func(r *Registry) {
r.globalFolders = globalFolders
}
}
func WithFsEnable(enabled bool) Option {
return func(r *Registry) {
r.fsEnabled = enabled
}
}
// Enable adds the require() function to the specified runtime.
func (r *Registry) Enable(runtime *js.Runtime) *RequireModule {
rrt := &RequireModule{
r: r,
runtime: runtime,
modules: make(map[string]*js.Object),
nodeModules: make(map[string]*js.Object),
}
runtime.Set("require", rrt.require)
return rrt
}
func (r *Registry) RegisterNodeModule(name string, loader ModuleLoader) {
r.Lock()
defer r.Unlock()
if r.builtin == nil {
r.builtin = make(map[string]ModuleLoader)
}
name = filepathClean(name)
r.builtin[name] = loader
}
func (r *Registry) RegisterNativeModule(name string, loader ModuleLoader) {
r.Lock()
defer r.Unlock()
if r.native == nil {
r.native = make(map[string]ModuleLoader)
}
name = filepathClean(name)
r.native[name] = loader
}
// DefaultSourceLoader is used if none was set (see WithLoader()). It simply loads files from the host's filesystem.
func DefaultSourceLoader(filename string) ([]byte, error) {
fp := filepath.FromSlash(filename)
f, err := os.Open(fp)
if err != nil {
if errors.Is(err, fs.ErrNotExist) {
err = ModuleFileDoesNotExistError
} else if runtime.GOOS == "windows" {
if errors.Is(err, syscall.Errno(0x7b)) { // ERROR_INVALID_NAME, The filename, directory name, or volume label syntax is incorrect.
err = ModuleFileDoesNotExistError
}
}
return nil, err
}
defer f.Close()
// On some systems (e.g. plan9 and FreeBSD) it is possible to use the standard read() call on directories
// which means we cannot rely on read() returning an error, we have to do stat() instead.
if fi, err := f.Stat(); err == nil {
if fi.IsDir() {
return nil, ModuleFileDoesNotExistError
}
} else {
return nil, err
}
return io.ReadAll(f)
}
func (r *Registry) getSource(p string) ([]byte, error) {
srcLoader := r.srcLoader
if srcLoader == nil {
srcLoader = DefaultSourceLoader
}
return srcLoader(p)
}
func (r *Registry) getCompiledSource(p string) (*js.Program, error) {
r.Lock()
defer r.Unlock()
prg := r.compiled[p]
if prg == nil {
buf, err := r.getSource(p)
if err != nil {
return nil, err
}
s := string(buf)
if path.Ext(p) == ".json" {
s = "module.exports = JSON.parse('" + template.JSEscapeString(s) + "')"
}
source := "(function(exports, require, module) {" + s + "\n})"
parsed, err := js.Parse(p, source, parser.WithSourceMapLoader(r.srcLoader))
if err != nil {
return nil, err
}
prg, err = js.CompileAST(parsed, false)
if err == nil {
if r.compiled == nil {
r.compiled = make(map[string]*js.Program)
}
r.compiled[p] = prg
}
return prg, err
}
return prg, nil
}
func (r *RequireModule) require(call js.FunctionCall) js.Value {
ret, err := r.Require(call.Argument(0).String())
if err != nil {
if _, ok := err.(*js.Exception); !ok {
panic(r.runtime.NewGoError(err))
}
panic(err)
}
return ret
}
func filepathClean(p string) string {
return path.Clean(p)
}
// Require can be used to import modules from Go source (similar to JS require() function).
func (r *RequireModule) Require(p string) (ret js.Value, err error) {
module, err := r.resolve(p)
if err != nil {
return
}
ret = module.Get("exports")
return
}
func Require(runtime *js.Runtime, name string) js.Value {
if r, ok := js.AssertFunction(runtime.Get("require")); ok {
mod, err := r(js.Undefined(), runtime.ToValue(name))
if err != nil {
panic(err)
}
return mod
}
panic(runtime.NewTypeError("Please enable require for this runtime using new(require.Registry).Enable(runtime)"))
}

View File

@@ -0,0 +1,277 @@
package require
import (
"encoding/json"
"errors"
"path"
"path/filepath"
"runtime"
"strings"
js "github.com/dop251/goja"
)
const NodePrefix = "node:"
// NodeJS module search algorithm described by
// https://nodejs.org/api/modules.html#modules_all_together
func (r *RequireModule) resolve(modpath string) (module *js.Object, err error) {
origPath, modpath := modpath, filepathClean(modpath)
if modpath == "" {
return nil, IllegalModuleNameError
}
var start string
err = nil
if path.IsAbs(origPath) {
start = "/"
} else {
start = r.getCurrentModulePath()
}
p := path.Join(start, modpath)
if isFileOrDirectoryPath(origPath) && r.r.fsEnabled {
if module = r.modules[p]; module != nil {
return
}
module, err = r.loadAsFileOrDirectory(p)
if err == nil && module != nil {
r.modules[p] = module
}
} else {
module, err = r.loadNative(origPath)
if err == nil {
return
} else {
if err == InvalidModuleError {
err = nil
} else {
return
}
}
if module = r.nodeModules[p]; module != nil {
return
}
if r.r.fsEnabled {
module, err = r.loadNodeModules(modpath, start)
if err == nil && module != nil {
r.nodeModules[p] = module
}
}
}
if module == nil && err == nil {
err = InvalidModuleError
}
return
}
func (r *RequireModule) loadNative(path string) (*js.Object, error) {
module := r.modules[path]
if module != nil {
return module, nil
}
var ldr ModuleLoader
if r.r.native != nil {
ldr = r.r.native[path]
}
var isBuiltIn, withPrefix bool
if ldr == nil {
if r.r.builtin != nil {
ldr = r.r.builtin[path]
}
if ldr == nil && strings.HasPrefix(path, NodePrefix) {
ldr = r.r.builtin[path[len(NodePrefix):]]
if ldr == nil {
return nil, NoSuchBuiltInModuleError
}
withPrefix = true
}
isBuiltIn = true
}
if ldr != nil {
module = r.createModuleObject()
r.modules[path] = module
if isBuiltIn {
if withPrefix {
r.modules[path[len(NodePrefix):]] = module
} else {
if !strings.HasPrefix(path, NodePrefix) {
r.modules[NodePrefix+path] = module
}
}
}
ldr(r.runtime, module)
return module, nil
}
return nil, InvalidModuleError
}
func (r *RequireModule) loadAsFileOrDirectory(path string) (module *js.Object, err error) {
if module, err = r.loadAsFile(path); module != nil || err != nil {
return
}
return r.loadAsDirectory(path)
}
func (r *RequireModule) loadAsFile(path string) (module *js.Object, err error) {
if module, err = r.loadModule(path); module != nil || err != nil {
return
}
p := path + ".js"
if module, err = r.loadModule(p); module != nil || err != nil {
return
}
p = path + ".json"
return r.loadModule(p)
}
func (r *RequireModule) loadIndex(modpath string) (module *js.Object, err error) {
p := path.Join(modpath, "index.js")
if module, err = r.loadModule(p); module != nil || err != nil {
return
}
p = path.Join(modpath, "index.json")
return r.loadModule(p)
}
func (r *RequireModule) loadAsDirectory(modpath string) (module *js.Object, err error) {
p := path.Join(modpath, "package.json")
buf, err := r.r.getSource(p)
if err != nil {
return r.loadIndex(modpath)
}
var pkg struct {
Main string
}
err = json.Unmarshal(buf, &pkg)
if err != nil || len(pkg.Main) == 0 {
return r.loadIndex(modpath)
}
m := path.Join(modpath, pkg.Main)
if module, err = r.loadAsFile(m); module != nil || err != nil {
return
}
return r.loadIndex(m)
}
func (r *RequireModule) loadNodeModule(modpath, start string) (*js.Object, error) {
return r.loadAsFileOrDirectory(path.Join(start, modpath))
}
func (r *RequireModule) loadNodeModules(modpath, start string) (module *js.Object, err error) {
for _, dir := range r.r.globalFolders {
if module, err = r.loadNodeModule(modpath, dir); module != nil || err != nil {
return
}
}
for {
var p string
if path.Base(start) != "node_modules" {
p = path.Join(start, "node_modules")
} else {
p = start
}
if module, err = r.loadNodeModule(modpath, p); module != nil || err != nil {
return
}
if start == ".." { // Dir('..') is '.'
break
}
parent := path.Dir(start)
if parent == start {
break
}
start = parent
}
return
}
func (r *RequireModule) getCurrentModulePath() string {
var buf [2]js.StackFrame
frames := r.runtime.CaptureCallStack(2, buf[:0])
if len(frames) < 2 {
return "."
}
return path.Dir(frames[1].SrcName())
}
func (r *RequireModule) createModuleObject() *js.Object {
module := r.runtime.NewObject()
module.Set("exports", r.runtime.NewObject())
return module
}
func (r *RequireModule) loadModule(path string) (*js.Object, error) {
module := r.modules[path]
if module == nil {
module = r.createModuleObject()
r.modules[path] = module
err := r.loadModuleFile(path, module)
if err != nil {
module = nil
delete(r.modules, path)
if errors.Is(err, ModuleFileDoesNotExistError) {
err = nil
}
}
return module, err
}
return module, nil
}
func (r *RequireModule) loadModuleFile(path string, jsModule *js.Object) error {
prg, err := r.r.getCompiledSource(path)
if err != nil {
return err
}
f, err := r.runtime.RunProgram(prg)
if err != nil {
return err
}
if call, ok := js.AssertFunction(f); ok {
jsExports := jsModule.Get("exports")
jsRequire := r.runtime.Get("require")
// Run the module source, with "jsExports" as "this",
// "jsExports" as the "exports" variable, "jsRequire"
// as the "require" variable and "jsModule" as the
// "module" variable (Nodejs capable).
_, err = call(jsExports, jsExports, jsRequire, jsModule)
if err != nil {
return err
}
} else {
return InvalidModuleError
}
return nil
}
func isFileOrDirectoryPath(path string) bool {
result := path == "." || path == ".." ||
strings.HasPrefix(path, "/") ||
strings.HasPrefix(path, "./") ||
strings.HasPrefix(path, "../")
if runtime.GOOS == "windows" {
result = result ||
strings.HasPrefix(path, `.\`) ||
strings.HasPrefix(path, `..\`) ||
filepath.IsAbs(path)
}
return result
}

View File

@@ -0,0 +1,111 @@
package sgnotification
import (
"context"
"encoding/base64"
"strings"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/script/jsc"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service"
"github.com/dop251/goja"
)
type SurgeNotification struct {
vm *goja.Runtime
logger logger.Logger
platformInterface platform.Interface
scriptTag string
}
func Enable(vm *goja.Runtime, ctx context.Context, logger logger.Logger) {
platformInterface := service.FromContext[platform.Interface](ctx)
notification := &SurgeNotification{
vm: vm,
logger: logger,
platformInterface: platformInterface,
}
notificationObject := vm.NewObject()
notificationObject.Set("post", notification.js_post)
vm.Set("$notification", notificationObject)
}
func (s *SurgeNotification) js_post(call goja.FunctionCall) goja.Value {
var (
title string
subtitle string
body string
openURL string
clipboard string
mediaURL string
mediaData []byte
mediaType string
autoDismiss int
)
title = jsc.AssertString(s.vm, call.Argument(0), "title", true)
subtitle = jsc.AssertString(s.vm, call.Argument(1), "subtitle", true)
body = jsc.AssertString(s.vm, call.Argument(2), "body", true)
options := jsc.AssertObject(s.vm, call.Argument(3), "options", true)
if options != nil {
action := jsc.AssertString(s.vm, options.Get("action"), "options.action", true)
switch action {
case "open-url":
openURL = jsc.AssertString(s.vm, options.Get("url"), "options.url", false)
case "clipboard":
clipboard = jsc.AssertString(s.vm, options.Get("clipboard"), "options.clipboard", false)
}
mediaURL = jsc.AssertString(s.vm, options.Get("media-url"), "options.media-url", true)
mediaBase64 := jsc.AssertString(s.vm, options.Get("media-base64"), "options.media-base64", true)
if mediaBase64 != "" {
mediaBinary, err := base64.StdEncoding.DecodeString(mediaBase64)
if err != nil {
panic(s.vm.NewGoError(E.Cause(err, "decode media-base64")))
}
mediaData = mediaBinary
mediaType = jsc.AssertString(s.vm, options.Get("media-base64-mime"), "options.media-base64-mime", false)
}
autoDismiss = int(jsc.AssertInt(s.vm, options.Get("auto-dismiss"), "options.auto-dismiss", true))
}
if title != "" && subtitle == "" && body == "" {
body = title
title = ""
} else if title != "" && subtitle != "" && body == "" {
body = subtitle
subtitle = ""
}
var builder strings.Builder
if title != "" {
builder.WriteString("[")
builder.WriteString(title)
if subtitle != "" {
builder.WriteString(" - ")
builder.WriteString(subtitle)
}
builder.WriteString("]: ")
}
builder.WriteString(body)
s.logger.Info("notification: " + builder.String())
if s.platformInterface != nil {
err := s.platformInterface.SendNotification(&platform.Notification{
Identifier: "surge-script-notification-" + s.scriptTag,
TypeName: "Surge Script Notification (" + s.scriptTag + ")",
TypeID: 11,
Title: title,
Subtitle: subtitle,
Body: body,
OpenURL: openURL,
Clipboard: clipboard,
MediaURL: mediaURL,
MediaData: mediaData,
MediaType: mediaType,
Timeout: autoDismiss,
})
if err != nil {
s.logger.Error(E.Cause(err, "send notification"))
}
}
return goja.Undefined()
}

View File

@@ -0,0 +1,65 @@
package surge
import (
"runtime"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/locale"
"github.com/sagernet/sing-box/script/jsc"
"github.com/dop251/goja"
)
type Environment struct {
class jsc.Class[*Module, *Environment]
}
func createEnvironment(module *Module) jsc.Class[*Module, *Environment] {
class := jsc.NewClass[*Module, *Environment](module)
class.DefineField("system", (*Environment).getSystem, nil)
class.DefineField("surge-build", (*Environment).getSurgeBuild, nil)
class.DefineField("surge-version", (*Environment).getSurgeVersion, nil)
class.DefineField("language", (*Environment).getLanguage, nil)
class.DefineField("device-model", (*Environment).getDeviceModel, nil)
class.DefineMethod("toString", (*Environment).toString)
return class
}
func (e *Environment) getSystem() any {
switch runtime.GOOS {
case "ios":
return "iOS"
case "darwin":
return "macOS"
case "tvos":
return "tvOS"
case "linux":
return "Linux"
case "android":
return "Android"
case "windows":
return "Windows"
default:
return runtime.GOOS
}
}
func (e *Environment) getSurgeBuild() any {
return "N/A"
}
func (e *Environment) getSurgeVersion() any {
return "sing-box " + C.Version
}
func (e *Environment) getLanguage() any {
return locale.Current().Locale
}
func (e *Environment) getDeviceModel() any {
return "N/A"
}
func (e *Environment) toString(call goja.FunctionCall) any {
return "[sing-box Surge environment"
}

View File

@@ -0,0 +1,150 @@
package surge
import (
"bytes"
"crypto/tls"
"io"
"net/http"
"net/http/cookiejar"
"time"
"github.com/sagernet/sing-box/script/jsc"
"github.com/sagernet/sing-box/script/modules/boxctx"
"github.com/sagernet/sing/common"
F "github.com/sagernet/sing/common/format"
"github.com/dop251/goja"
"golang.org/x/net/publicsuffix"
)
type HTTP struct {
class jsc.Class[*Module, *HTTP]
cookieJar *cookiejar.Jar
httpTransport *http.Transport
}
func createHTTP(module *Module) jsc.Class[*Module, *HTTP] {
class := jsc.NewClass[*Module, *HTTP](module)
class.DefineConstructor(newHTTP)
class.DefineMethod("get", httpRequest(http.MethodGet))
class.DefineMethod("post", httpRequest(http.MethodPost))
class.DefineMethod("put", httpRequest(http.MethodPut))
class.DefineMethod("delete", httpRequest(http.MethodDelete))
class.DefineMethod("head", httpRequest(http.MethodHead))
class.DefineMethod("options", httpRequest(http.MethodOptions))
class.DefineMethod("patch", httpRequest(http.MethodPatch))
class.DefineMethod("trace", httpRequest(http.MethodTrace))
class.DefineMethod("toString", (*HTTP).toString)
return class
}
func newHTTP(class jsc.Class[*Module, *HTTP], call goja.ConstructorCall) *HTTP {
return &HTTP{
class: class,
cookieJar: common.Must1(cookiejar.New(&cookiejar.Options{
PublicSuffixList: publicsuffix.List,
})),
httpTransport: &http.Transport{
ForceAttemptHTTP2: true,
TLSClientConfig: &tls.Config{},
},
}
}
func httpRequest(method string) func(s *HTTP, call goja.FunctionCall) any {
return func(s *HTTP, call goja.FunctionCall) any {
if len(call.Arguments) != 2 {
panic(s.class.Runtime().NewTypeError("invalid arguments"))
}
context := boxctx.MustFromRuntime(s.class.Runtime())
var (
url string
headers http.Header
body []byte
timeout = 5 * time.Second
insecure bool
autoCookie bool = true
autoRedirect bool
// policy string
binaryMode bool
)
switch optionsValue := call.Argument(0).(type) {
case goja.String:
url = optionsValue.String()
case *goja.Object:
url = jsc.AssertString(s.class.Runtime(), optionsValue.Get("url"), "options.url", false)
headers = jsc.AssertHTTPHeader(s.class.Runtime(), optionsValue.Get("headers"), "option.headers")
body = jsc.AssertStringBinary(s.class.Runtime(), optionsValue.Get("body"), "options.body", true)
timeoutInt := jsc.AssertInt(s.class.Runtime(), optionsValue.Get("timeout"), "options.timeout", true)
if timeoutInt > 0 {
timeout = time.Duration(timeoutInt) * time.Second
}
insecure = jsc.AssertBool(s.class.Runtime(), optionsValue.Get("insecure"), "options.insecure", true)
autoCookie = jsc.AssertBool(s.class.Runtime(), optionsValue.Get("auto-cookie"), "options.auto-cookie", true)
autoRedirect = jsc.AssertBool(s.class.Runtime(), optionsValue.Get("auto-redirect"), "options.auto-redirect", true)
// policy = jsc.AssertString(s.class.Runtime(), optionsValue.Get("policy"), "options.policy", true)
binaryMode = jsc.AssertBool(s.class.Runtime(), optionsValue.Get("binary-mode"), "options.binary-mode", true)
default:
panic(s.class.Runtime().NewTypeError(F.ToString("invalid argument: options: expected string or object, but got ", optionsValue)))
}
callback := jsc.AssertFunction(s.class.Runtime(), call.Argument(1), "callback")
s.httpTransport.TLSClientConfig.InsecureSkipVerify = insecure
httpClient := &http.Client{
Timeout: timeout,
Transport: s.httpTransport,
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if autoRedirect {
return nil
}
return http.ErrUseLastResponse
},
}
if autoCookie {
httpClient.Jar = s.cookieJar
}
request, err := http.NewRequestWithContext(context.Context, method, url, bytes.NewReader(body))
if host := headers.Get("Host"); host != "" {
request.Host = host
headers.Del("Host")
}
request.Header = headers
if err != nil {
panic(s.class.Runtime().NewGoError(err))
}
go func() {
defer s.httpTransport.CloseIdleConnections()
response, executeErr := httpClient.Do(request)
if err != nil {
_, err = callback(nil, s.class.Runtime().NewGoError(executeErr), nil, nil)
if err != nil {
context.ErrorHandler(err)
}
return
}
defer response.Body.Close()
var content []byte
content, err = io.ReadAll(response.Body)
if err != nil {
_, err = callback(nil, s.class.Runtime().NewGoError(err), nil, nil)
if err != nil {
context.ErrorHandler(err)
}
}
responseObject := s.class.Runtime().NewObject()
responseObject.Set("status", response.StatusCode)
responseObject.Set("headers", jsc.HeadersToValue(s.class.Runtime(), response.Header))
var bodyValue goja.Value
if binaryMode {
bodyValue = jsc.NewUint8Array(s.class.Runtime(), content)
} else {
bodyValue = s.class.Runtime().ToValue(string(content))
}
_, err = callback(nil, nil, responseObject, bodyValue)
}()
return nil
}
}
func (h *HTTP) toString(call goja.FunctionCall) any {
return "[sing-box Surge HTTP]"
}

View File

@@ -0,0 +1,63 @@
package surge
import (
"github.com/sagernet/sing-box/script/jsc"
"github.com/sagernet/sing-box/script/modules/require"
"github.com/sagernet/sing/common"
"github.com/dop251/goja"
)
const ModuleName = "surge"
type Module struct {
runtime *goja.Runtime
classScript jsc.Class[*Module, *Script]
classEnvironment jsc.Class[*Module, *Environment]
classPersistentStore jsc.Class[*Module, *PersistentStore]
classHTTP jsc.Class[*Module, *HTTP]
classUtils jsc.Class[*Module, *Utils]
classNotification jsc.Class[*Module, *Notification]
}
func Require(runtime *goja.Runtime, module *goja.Object) {
m := &Module{
runtime: runtime,
}
m.classScript = createScript(m)
m.classEnvironment = createEnvironment(m)
m.classPersistentStore = createPersistentStore(m)
m.classHTTP = createHTTP(m)
m.classUtils = createUtils(m)
m.classNotification = createNotification(m)
exports := module.Get("exports").(*goja.Object)
exports.Set("Script", m.classScript.ToValue())
exports.Set("Environment", m.classEnvironment.ToValue())
exports.Set("PersistentStore", m.classPersistentStore.ToValue())
exports.Set("HTTP", m.classHTTP.ToValue())
exports.Set("Utils", m.classUtils.ToValue())
exports.Set("Notification", m.classNotification.ToValue())
}
func Enable(runtime *goja.Runtime, scriptType string, args []string) {
exports := require.Require(runtime, ModuleName).ToObject(runtime)
classScript := jsc.GetClass[*Module, *Script](runtime, exports, "Script")
classEnvironment := jsc.GetClass[*Module, *Environment](runtime, exports, "Environment")
classPersistentStore := jsc.GetClass[*Module, *PersistentStore](runtime, exports, "PersistentStore")
classHTTP := jsc.GetClass[*Module, *HTTP](runtime, exports, "HTTP")
classUtils := jsc.GetClass[*Module, *Utils](runtime, exports, "Utils")
classNotification := jsc.GetClass[*Module, *Notification](runtime, exports, "Notification")
runtime.Set("$script", classScript.New(&Script{class: classScript, ScriptType: scriptType}))
runtime.Set("$environment", classEnvironment.New(&Environment{class: classEnvironment}))
runtime.Set("$persistentStore", newPersistentStore(classPersistentStore))
runtime.Set("$http", classHTTP.New(newHTTP(classHTTP, goja.ConstructorCall{})))
runtime.Set("$utils", classUtils.New(&Utils{class: classUtils}))
runtime.Set("$notification", newNotification(classNotification))
runtime.Set("$argument", runtime.NewArray(common.Map(args, func(it string) any {
return it
})...))
}
func (m *Module) Runtime() *goja.Runtime {
return m.runtime
}

View File

@@ -0,0 +1,120 @@
package surge
import (
"encoding/base64"
"strings"
"github.com/sagernet/sing-box/experimental/libbox/platform"
"github.com/sagernet/sing-box/script/jsc"
"github.com/sagernet/sing-box/script/modules/boxctx"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service"
"github.com/dop251/goja"
)
type Notification struct {
class jsc.Class[*Module, *Notification]
logger logger.ContextLogger
tag string
platformInterface platform.Interface
}
func createNotification(module *Module) jsc.Class[*Module, *Notification] {
class := jsc.NewClass[*Module, *Notification](module)
class.DefineMethod("post", (*Notification).post)
class.DefineMethod("toString", (*Notification).toString)
return class
}
func newNotification(class jsc.Class[*Module, *Notification]) goja.Value {
context := boxctx.MustFromRuntime(class.Runtime())
return class.New(&Notification{
class: class,
logger: context.Logger,
tag: context.Tag,
platformInterface: service.FromContext[platform.Interface](context.Context),
})
}
func (s *Notification) post(call goja.FunctionCall) any {
var (
title string
subtitle string
body string
openURL string
clipboard string
mediaURL string
mediaData []byte
mediaType string
autoDismiss int
)
title = jsc.AssertString(s.class.Runtime(), call.Argument(0), "title", true)
subtitle = jsc.AssertString(s.class.Runtime(), call.Argument(1), "subtitle", true)
body = jsc.AssertString(s.class.Runtime(), call.Argument(2), "body", true)
options := jsc.AssertObject(s.class.Runtime(), call.Argument(3), "options", true)
if options != nil {
action := jsc.AssertString(s.class.Runtime(), options.Get("action"), "options.action", true)
switch action {
case "open-url":
openURL = jsc.AssertString(s.class.Runtime(), options.Get("url"), "options.url", false)
case "clipboard":
clipboard = jsc.AssertString(s.class.Runtime(), options.Get("clipboard"), "options.clipboard", false)
}
mediaURL = jsc.AssertString(s.class.Runtime(), options.Get("media-url"), "options.media-url", true)
mediaBase64 := jsc.AssertString(s.class.Runtime(), options.Get("media-base64"), "options.media-base64", true)
if mediaBase64 != "" {
mediaBinary, err := base64.StdEncoding.DecodeString(mediaBase64)
if err != nil {
panic(s.class.Runtime().NewGoError(E.Cause(err, "decode media-base64")))
}
mediaData = mediaBinary
mediaType = jsc.AssertString(s.class.Runtime(), options.Get("media-base64-mime"), "options.media-base64-mime", false)
}
autoDismiss = int(jsc.AssertInt(s.class.Runtime(), options.Get("auto-dismiss"), "options.auto-dismiss", true))
}
if title != "" && subtitle == "" && body == "" {
body = title
title = ""
} else if title != "" && subtitle != "" && body == "" {
body = subtitle
subtitle = ""
}
var builder strings.Builder
if title != "" {
builder.WriteString("[")
builder.WriteString(title)
if subtitle != "" {
builder.WriteString(" - ")
builder.WriteString(subtitle)
}
builder.WriteString("]: ")
}
builder.WriteString(body)
s.logger.Info("notification: " + builder.String())
if s.platformInterface != nil {
err := s.platformInterface.SendNotification(&platform.Notification{
Identifier: "surge-script-notification-" + s.tag,
TypeName: "Surge Script Notification (" + s.tag + ")",
TypeID: 11,
Title: title,
Subtitle: subtitle,
Body: body,
OpenURL: openURL,
Clipboard: clipboard,
MediaURL: mediaURL,
MediaData: mediaData,
MediaType: mediaType,
Timeout: autoDismiss,
})
if err != nil {
s.logger.Error(E.Cause(err, "send notification"))
}
}
return nil
}
func (s *Notification) toString(call goja.FunctionCall) any {
return "[sing-box Surge notification]"
}

View File

@@ -0,0 +1,78 @@
package surge
import (
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/script/jsc"
"github.com/sagernet/sing-box/script/modules/boxctx"
"github.com/sagernet/sing/service"
"github.com/dop251/goja"
)
type PersistentStore struct {
class jsc.Class[*Module, *PersistentStore]
cacheFile adapter.CacheFile
inMemoryCache *adapter.SurgeInMemoryCache
tag string
}
func createPersistentStore(module *Module) jsc.Class[*Module, *PersistentStore] {
class := jsc.NewClass[*Module, *PersistentStore](module)
class.DefineMethod("get", (*PersistentStore).get)
class.DefineMethod("set", (*PersistentStore).set)
class.DefineMethod("toString", (*PersistentStore).toString)
return class
}
func newPersistentStore(class jsc.Class[*Module, *PersistentStore]) goja.Value {
boxCtx := boxctx.MustFromRuntime(class.Runtime())
return class.New(&PersistentStore{
class: class,
cacheFile: service.FromContext[adapter.CacheFile](boxCtx.Context),
inMemoryCache: service.FromContext[adapter.ScriptManager](boxCtx.Context).SurgeCache(),
tag: boxCtx.Tag,
})
}
func (s *PersistentStore) get(call goja.FunctionCall) any {
key := jsc.AssertString(s.class.Runtime(), call.Argument(0), "key", true)
if key == "" {
key = s.tag
}
var value string
if s.cacheFile != nil {
value = s.cacheFile.SurgePersistentStoreRead(key)
} else {
s.inMemoryCache.RLock()
value = s.inMemoryCache.Data[key]
s.inMemoryCache.RUnlock()
}
if value == "" {
return goja.Null()
} else {
return value
}
}
func (s *PersistentStore) set(call goja.FunctionCall) any {
data := jsc.AssertString(s.class.Runtime(), call.Argument(0), "data", true)
key := jsc.AssertString(s.class.Runtime(), call.Argument(1), "key", true)
if key == "" {
key = s.tag
}
if s.cacheFile != nil {
err := s.cacheFile.SurgePersistentStoreWrite(key, data)
if err != nil {
panic(s.class.Runtime().NewGoError(err))
}
} else {
s.inMemoryCache.Lock()
s.inMemoryCache.Data[key] = data
s.inMemoryCache.Unlock()
}
return goja.Undefined()
}
func (s *PersistentStore) toString(call goja.FunctionCall) any {
return "[sing-box Surge persistentStore]"
}

View File

@@ -0,0 +1,32 @@
package surge
import (
"github.com/sagernet/sing-box/script/jsc"
"github.com/sagernet/sing-box/script/modules/boxctx"
F "github.com/sagernet/sing/common/format"
)
type Script struct {
class jsc.Class[*Module, *Script]
ScriptType string
}
func createScript(module *Module) jsc.Class[*Module, *Script] {
class := jsc.NewClass[*Module, *Script](module)
class.DefineField("name", (*Script).getName, nil)
class.DefineField("type", (*Script).getType, nil)
class.DefineField("startTime", (*Script).getStartTime, nil)
return class
}
func (s *Script) getName() any {
return F.ToString("script:", boxctx.MustFromRuntime(s.class.Runtime()).Tag)
}
func (s *Script) getType() any {
return s.ScriptType
}
func (s *Script) getStartTime() any {
return boxctx.MustFromRuntime(s.class.Runtime()).StartedAt
}

View File

@@ -0,0 +1,50 @@
package surge
import (
"bytes"
"compress/gzip"
"io"
"github.com/sagernet/sing-box/script/jsc"
E "github.com/sagernet/sing/common/exceptions"
"github.com/dop251/goja"
)
type Utils struct {
class jsc.Class[*Module, *Utils]
}
func createUtils(module *Module) jsc.Class[*Module, *Utils] {
class := jsc.NewClass[*Module, *Utils](module)
class.DefineMethod("geoip", (*Utils).stub)
class.DefineMethod("ipasn", (*Utils).stub)
class.DefineMethod("ipaso", (*Utils).stub)
class.DefineMethod("ungzip", (*Utils).ungzip)
class.DefineMethod("toString", (*Utils).toString)
return class
}
func (u *Utils) stub(call goja.FunctionCall) any {
return nil
}
func (u *Utils) ungzip(call goja.FunctionCall) any {
if len(call.Arguments) != 1 {
panic(u.class.Runtime().NewGoError(E.New("invalid argument")))
}
binary := jsc.AssertBinary(u.class.Runtime(), call.Argument(0), "binary", false)
reader, err := gzip.NewReader(bytes.NewReader(binary))
if err != nil {
panic(u.class.Runtime().NewGoError(err))
}
binary, err = io.ReadAll(reader)
if err != nil {
panic(u.class.Runtime().NewGoError(err))
}
return jsc.NewUint8Array(u.class.Runtime(), binary)
}
func (u *Utils) toString(call goja.FunctionCall) any {
return "[sing-box Surge utils]"
}

View File

@@ -0,0 +1,55 @@
package url
import "strings"
var tblEscapeURLQuery = [128]byte{
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,
}
// The code below is mostly borrowed from the standard Go url package
const upperhex = "0123456789ABCDEF"
func escape(s string, table *[128]byte, spaceToPlus bool) string {
spaceCount, hexCount := 0, 0
for i := 0; i < len(s); i++ {
c := s[i]
if c > 127 || table[c] == 0 {
if c == ' ' && spaceToPlus {
spaceCount++
} else {
hexCount++
}
}
}
if spaceCount == 0 && hexCount == 0 {
return s
}
var sb strings.Builder
hexBuf := [3]byte{'%', 0, 0}
sb.Grow(len(s) + 2*hexCount)
for i := 0; i < len(s); i++ {
switch c := s[i]; {
case c == ' ' && spaceToPlus:
sb.WriteByte('+')
case c > 127 || table[c] == 0:
hexBuf[1] = upperhex[c>>4]
hexBuf[2] = upperhex[c&15]
sb.Write(hexBuf[:])
default:
sb.WriteByte(c)
}
}
return sb.String()
}

View File

@@ -0,0 +1,41 @@
package url
import (
"github.com/sagernet/sing-box/script/jsc"
"github.com/sagernet/sing-box/script/modules/require"
"github.com/dop251/goja"
)
const ModuleName = "url"
var _ jsc.Module = (*Module)(nil)
type Module struct {
runtime *goja.Runtime
classURL jsc.Class[*Module, *URL]
classURLSearchParams jsc.Class[*Module, *URLSearchParams]
classURLSearchParamsIterator jsc.Class[*Module, *jsc.Iterator[*Module, searchParam]]
}
func Require(runtime *goja.Runtime, module *goja.Object) {
m := &Module{
runtime: runtime,
}
m.classURL = createURL(m)
m.classURLSearchParams = createURLSearchParams(m)
m.classURLSearchParamsIterator = jsc.CreateIterator[*Module, searchParam](m)
exports := module.Get("exports").(*goja.Object)
exports.Set("URL", m.classURL.ToValue())
exports.Set("URLSearchParams", m.classURLSearchParams.ToValue())
}
func Enable(runtime *goja.Runtime) {
exports := require.Require(runtime, ModuleName).ToObject(runtime)
runtime.Set("URL", exports.Get("URL"))
runtime.Set("URLSearchParams", exports.Get("URLSearchParams"))
}
func (m *Module) Runtime() *goja.Runtime {
return m.runtime
}

View File

@@ -0,0 +1,37 @@
package url_test
import (
_ "embed"
"testing"
"github.com/sagernet/sing-box/script/jstest"
"github.com/sagernet/sing-box/script/modules/url"
"github.com/dop251/goja"
)
var (
//go:embed testdata/url_test.js
urlTest string
//go:embed testdata/url_search_params_test.js
urlSearchParamsTest string
)
func TestURL(t *testing.T) {
registry := jstest.NewRegistry()
registry.RegisterNodeModule(url.ModuleName, url.Require)
vm := goja.New()
registry.Enable(vm)
url.Enable(vm)
vm.RunScript("url_test.js", urlTest)
}
func TestURLSearchParams(t *testing.T) {
registry := jstest.NewRegistry()
registry.RegisterNodeModule(url.ModuleName, url.Require)
vm := goja.New()
registry.Enable(vm)
url.Enable(vm)
vm.RunScript("url_search_params_test.js", urlSearchParamsTest)
}

View File

@@ -0,0 +1,385 @@
"use strict";
const assert = require("assert.js");
let params;
function testCtor(value, expected) {
assert.sameValue(new URLSearchParams(value).toString(), expected);
}
testCtor("user=abc&query=xyz", "user=abc&query=xyz");
testCtor("?user=abc&query=xyz", "user=abc&query=xyz");
testCtor(
{
num: 1,
user: "abc",
query: ["first", "second"],
obj: { prop: "value" },
b: true,
},
"num=1&user=abc&query=first%2Csecond&obj=%5Bobject+Object%5D&b=true"
);
const map = new Map();
map.set("user", "abc");
map.set("query", "xyz");
testCtor(map, "user=abc&query=xyz");
testCtor(
[
["user", "abc"],
["query", "first"],
["query", "second"],
],
"user=abc&query=first&query=second"
);
// Each key-value pair must have exactly two elements
assert.throwsNodeError(() => new URLSearchParams([["single_value"]]), TypeError, "ERR_INVALID_TUPLE");
assert.throwsNodeError(() => new URLSearchParams([["too", "many", "values"]]), TypeError, "ERR_INVALID_TUPLE");
params = new URLSearchParams("a=b&cc=d");
params.forEach((value, name, searchParams) => {
if (name === "a") {
assert.sameValue(value, "b");
}
if (name === "cc") {
assert.sameValue(value, "d");
}
assert.sameValue(searchParams, params);
});
params.forEach((value, name, searchParams) => {
if (name === "a") {
assert.sameValue(value, "b");
searchParams.set("cc", "d1");
}
if (name === "cc") {
assert.sameValue(value, "d1");
}
assert.sameValue(searchParams, params);
});
assert.throwsNodeError(() => params.forEach(123), TypeError, "ERR_INVALID_ARG_TYPE");
assert.throwsNodeError(() => params.forEach.call(1, 2), TypeError, "ERR_INVALID_THIS");
params = new URLSearchParams("a=1=2&b=3");
assert.sameValue(params.size, 2);
assert.sameValue(params.get("a"), "1=2");
assert.sameValue(params.get("b"), "3");
params = new URLSearchParams("&");
assert.sameValue(params.size, 0);
params = new URLSearchParams("& ");
assert.sameValue(params.size, 1);
assert.sameValue(params.get(" "), "");
params = new URLSearchParams(" &");
assert.sameValue(params.size, 1);
assert.sameValue(params.get(" "), "");
params = new URLSearchParams("=");
assert.sameValue(params.size, 1);
assert.sameValue(params.get(""), "");
params = new URLSearchParams("&=2");
assert.sameValue(params.size, 1);
assert.sameValue(params.get(""), "2");
params = new URLSearchParams("?user=abc");
assert.throwsNodeError(() => params.append(), TypeError, "ERR_MISSING_ARGS");
params.append("query", "first");
assert.sameValue(params.toString(), "user=abc&query=first");
params = new URLSearchParams("first=one&second=two&third=three");
assert.throwsNodeError(() => params.delete(), TypeError, "ERR_MISSING_ARGS");
params.delete("second", "fake-value");
assert.sameValue(params.toString(), "first=one&second=two&third=three");
params.delete("third", "three");
assert.sameValue(params.toString(), "first=one&second=two");
params.delete("second");
assert.sameValue(params.toString(), "first=one");
params = new URLSearchParams("user=abc&query=xyz");
assert.throwsNodeError(() => params.get(), TypeError, "ERR_MISSING_ARGS");
assert.sameValue(params.get("user"), "abc");
assert.sameValue(params.get("non-existant"), null);
params = new URLSearchParams("query=first&query=second");
assert.throwsNodeError(() => params.getAll(), TypeError, "ERR_MISSING_ARGS");
const all = params.getAll("query");
assert.sameValue(all.includes("first"), true);
assert.sameValue(all.includes("second"), true);
assert.sameValue(all.length, 2);
const getAllUndefined = params.getAll(undefined);
assert.sameValue(getAllUndefined.length, 0);
const getAllNonExistant = params.getAll("does_not_exists");
assert.sameValue(getAllNonExistant.length, 0);
params = new URLSearchParams("user=abc&query=xyz");
assert.throwsNodeError(() => params.has(), TypeError, "ERR_MISSING_ARGS");
assert.sameValue(params.has(undefined), false);
assert.sameValue(params.has("user"), true);
assert.sameValue(params.has("user", "abc"), true);
assert.sameValue(params.has("user", "abc", "extra-param"), true);
assert.sameValue(params.has("user", "efg"), false);
assert.sameValue(params.has("user", undefined), true);
params = new URLSearchParams();
params.append("foo", "bar");
params.append("foo", "baz");
params.append("abc", "def");
assert.sameValue(params.toString(), "foo=bar&foo=baz&abc=def");
params.set("foo", "def");
params.set("xyz", "opq");
assert.sameValue(params.toString(), "foo=def&abc=def&xyz=opq");
params = new URLSearchParams("query=first&query=second&user=abc&double=first,second");
const URLSearchIteratorPrototype = params.entries().__proto__;
assert.sameValue(typeof URLSearchIteratorPrototype, "object");
assert.sameValue(params[Symbol.iterator], params.entries);
{
const entries = params.entries();
assert.sameValue(entries.toString(), "[object URLSearchParams Iterator]");
assert.sameValue(entries.__proto__, URLSearchIteratorPrototype);
let item = entries.next();
assert.sameValue(item.value.toString(), ["query", "first"].toString());
assert.sameValue(item.done, false);
item = entries.next();
assert.sameValue(item.value.toString(), ["query", "second"].toString());
assert.sameValue(item.done, false);
item = entries.next();
assert.sameValue(item.value.toString(), ["user", "abc"].toString());
assert.sameValue(item.done, false);
item = entries.next();
assert.sameValue(item.value.toString(), ["double", "first,second"].toString());
assert.sameValue(item.done, false);
item = entries.next();
assert.sameValue(item.value, undefined);
assert.sameValue(item.done, true);
}
params = new URLSearchParams("query=first&query=second&user=abc");
{
const keys = params.keys();
assert.sameValue(keys.__proto__, URLSearchIteratorPrototype);
let item = keys.next();
assert.sameValue(item.value, "query");
assert.sameValue(item.done, false);
item = keys.next();
assert.sameValue(item.value, "query");
assert.sameValue(item.done, false);
item = keys.next();
assert.sameValue(item.value, "user");
assert.sameValue(item.done, false);
item = keys.next();
assert.sameValue(item.value, undefined);
assert.sameValue(item.done, true);
}
params = new URLSearchParams("query=first&query=second&user=abc");
{
const values = params.values();
assert.sameValue(values.__proto__, URLSearchIteratorPrototype);
let item = values.next();
assert.sameValue(item.value, "first");
assert.sameValue(item.done, false);
item = values.next();
assert.sameValue(item.value, "second");
assert.sameValue(item.done, false);
item = values.next();
assert.sameValue(item.value, "abc");
assert.sameValue(item.done, false);
item = values.next();
assert.sameValue(item.value, undefined);
assert.sameValue(item.done, true);
}
params = new URLSearchParams("query[]=abc&type=search&query[]=123");
params.sort();
assert.sameValue(params.toString(), "query%5B%5D=abc&query%5B%5D=123&type=search");
params = new URLSearchParams("query=first&query=second&user=abc");
assert.sameValue(params.size, 3);
params = new URLSearchParams("%");
assert.sameValue(params.has("%"), true);
assert.sameValue(params.toString(), "%25=");
{
const params = new URLSearchParams("");
assert.sameValue(params.size, 0);
assert.sameValue(params.toString(), "");
assert.sameValue(params.get(undefined), null);
params.set(undefined, true);
assert.sameValue(params.has(undefined), true);
assert.sameValue(params.has("undefined"), true);
assert.sameValue(params.get("undefined"), "true");
assert.sameValue(params.get(undefined), "true");
assert.sameValue(params.getAll(undefined).toString(), ["true"].toString());
params.delete(undefined);
assert.sameValue(params.has(undefined), false);
assert.sameValue(params.has("undefined"), false);
assert.sameValue(params.has(null), false);
params.set(null, "nullval");
assert.sameValue(params.has(null), true);
assert.sameValue(params.has("null"), true);
assert.sameValue(params.get(null), "nullval");
assert.sameValue(params.get("null"), "nullval");
params.delete(null);
assert.sameValue(params.has(null), false);
assert.sameValue(params.has("null"), false);
}
function* functionGeneratorExample() {
yield ["user", "abc"];
yield ["query", "first"];
yield ["query", "second"];
}
params = new URLSearchParams(functionGeneratorExample());
assert.sameValue(params.toString(), "user=abc&query=first&query=second");
assert.sameValue(params.__proto__.constructor, URLSearchParams);
assert.sameValue(params instanceof URLSearchParams, true);
{
const params = new URLSearchParams("1=2&1=3");
assert.sameValue(params.get(1), "2");
assert.sameValue(params.getAll(1).toString(), ["2", "3"].toString());
assert.sameValue(params.getAll("x").toString(), [].toString());
}
// Sync
{
const url = new URL("https://test.com/");
const params = url.searchParams;
assert.sameValue(params.size, 0);
url.search = "a=1";
assert.sameValue(params.size, 1);
assert.sameValue(params.get("a"), "1");
}
{
const url = new URL("https://test.com/?a=1");
const params = url.searchParams;
assert.sameValue(params.size, 1);
url.search = "";
assert.sameValue(params.size, 0);
url.search = "b=2";
assert.sameValue(params.size, 1);
}
{
const url = new URL("https://test.com/");
const params = url.searchParams;
params.append("a", "1");
assert.sameValue(url.toString(), "https://test.com/?a=1");
}
{
const url = new URL("https://test.com/");
url.searchParams.append("a", "1");
url.searchParams.append("b", "1");
assert.sameValue(url.toString(), "https://test.com/?a=1&b=1");
}
{
const url = new URL("https://test.com/");
const params = url.searchParams;
url.searchParams.append("a", "1");
assert.sameValue(url.search, "?a=1");
}
{
const url = new URL("https://test.com/?a=1");
const params = url.searchParams;
params.append("a", "2");
assert.sameValue(url.search, "?a=1&a=2");
}
{
const url = new URL("https://test.com/");
const params = url.searchParams;
params.set("a", "1");
assert.sameValue(url.search, "?a=1");
}
{
const url = new URL("https://test.com/");
url.searchParams.set("a", "1");
url.searchParams.set("b", "1");
assert.sameValue(url.toString(), "https://test.com/?a=1&b=1");
}
{
const url = new URL("https://test.com/?a=1&b=2");
const params = url.searchParams;
params.delete("a");
assert.sameValue(url.search, "?b=2");
}
{
const url = new URL("https://test.com/?b=2&a=1");
const params = url.searchParams;
params.sort();
assert.sameValue(url.search, "?a=1&b=2");
}
{
const url = new URL("https://test.com/?a=1");
const params = url.searchParams;
params.delete("a");
assert.sameValue(url.search, "");
params.set("a", 2);
assert.sameValue(url.search, "?a=2");
}
// FAILING: no custom properties on wrapped Go structs
/*
{
const params = new URLSearchParams("");
assert.sameValue(Object.isExtensible(params), true);
assert.sameValue(Reflect.defineProperty(params, "customField", {value: 42, configurable: true}), true);
assert.sameValue(params.customField, 42);
const desc = Reflect.getOwnPropertyDescriptor(params, "customField");
assert.sameValue(desc.value, 42);
assert.sameValue(desc.writable, false);
assert.sameValue(desc.enumerable, false);
assert.sameValue(desc.configurable, true);
}
*/
// Escape
{
const myURL = new URL('https://example.org/abc?fo~o=~ba r%z');
assert.sameValue(myURL.search, "?fo~o=~ba%20r%z");
// Modify the URL via searchParams...
myURL.searchParams.sort();
assert.sameValue(myURL.search, "?fo%7Eo=%7Eba+r%25z");
}

229
script/modules/url/testdata/url_test.js vendored Normal file
View File

@@ -0,0 +1,229 @@
"use strict";
const assert = require("assert.js");
function testURLCtor(str, expected) {
assert.sameValue(new URL(str).toString(), expected);
}
function testURLCtorBase(ref, base, expected, message) {
assert.sameValue(new URL(ref, base).toString(), expected, message);
}
testURLCtorBase("https://example.org/", undefined, "https://example.org/");
testURLCtorBase("/foo", "https://example.org/", "https://example.org/foo");
testURLCtorBase("http://Example.com/", "https://example.org/", "http://example.com/");
testURLCtorBase("https://Example.com/", "https://example.org/", "https://example.com/");
testURLCtorBase("foo://Example.com/", "https://example.org/", "foo://Example.com/");
testURLCtorBase("foo:Example.com/", "https://example.org/", "foo:Example.com/");
testURLCtorBase("#hash", "https://example.org/", "https://example.org/#hash");
testURLCtor("HTTP://test.com", "http://test.com/");
testURLCtor("HTTPS://á.com", "https://xn--1ca.com/");
testURLCtor("HTTPS://á.com:123", "https://xn--1ca.com:123/");
testURLCtor("https://test.com#asdfá", "https://test.com/#asdf%C3%A1");
testURLCtor("HTTPS://á.com:123/á", "https://xn--1ca.com:123/%C3%A1");
testURLCtor("fish://á.com", "fish://%C3%A1.com");
testURLCtor("https://test.com/?a=1 /2", "https://test.com/?a=1%20/2");
testURLCtor("https://test.com/á=1?á=1&ü=2#é", "https://test.com/%C3%A1=1?%C3%A1=1&%C3%BC=2#%C3%A9");
assert.throws(() => new URL("test"), TypeError);
assert.throws(() => new URL("ssh://EEE:ddd"), TypeError);
{
let u = new URL("https://example.org/");
assert.sameValue(u.__proto__.constructor, URL);
assert.sameValue(u instanceof URL, true);
}
{
let u = new URL("https://example.org/");
assert.sameValue(u.searchParams, u.searchParams);
}
let myURL;
// Hash
myURL = new URL("https://example.org/foo#bar");
myURL.hash = "baz";
assert.sameValue(myURL.href, "https://example.org/foo#baz");
myURL.hash = "#baz";
assert.sameValue(myURL.href, "https://example.org/foo#baz");
myURL.hash = "#á=1 2";
assert.sameValue(myURL.href, "https://example.org/foo#%C3%A1=1%202");
myURL.hash = "#a/#b";
// FAILING: the second # gets escaped
//assert.sameValue(myURL.href, "https://example.org/foo#a/#b");
assert.sameValue(myURL.search, "");
// FAILING: the second # gets escaped
//assert.sameValue(myURL.hash, "#a/#b");
// Host
myURL = new URL("https://example.org:81/foo");
myURL.host = "example.com:82";
assert.sameValue(myURL.href, "https://example.com:82/foo");
// Hostname
myURL = new URL("https://example.org:81/foo");
myURL.hostname = "example.com:82";
assert.sameValue(myURL.href, "https://example.org:81/foo");
myURL.hostname = "á.com";
assert.sameValue(myURL.href, "https://xn--1ca.com:81/foo");
// href
myURL = new URL("https://example.org/foo");
myURL.href = "https://example.com/bar";
assert.sameValue(myURL.href, "https://example.com/bar");
// Password
myURL = new URL("https://abc:xyz@example.com");
myURL.password = "123";
assert.sameValue(myURL.href, "https://abc:123@example.com/");
// pathname
myURL = new URL("https://example.org/abc/xyz?123");
myURL.pathname = "/abcdef";
assert.sameValue(myURL.href, "https://example.org/abcdef?123");
myURL.pathname = "";
assert.sameValue(myURL.href, "https://example.org/?123");
myURL.pathname = "á";
assert.sameValue(myURL.pathname, "/%C3%A1");
assert.sameValue(myURL.href, "https://example.org/%C3%A1?123");
// port
myURL = new URL("https://example.org:8888");
assert.sameValue(myURL.port, "8888");
function testSetPort(port, expected) {
const url = new URL("https://example.org:8888");
url.port = port;
assert.sameValue(url.port, expected);
}
testSetPort(0, "0");
testSetPort(-0, "0");
// Default ports are automatically transformed to the empty string
// (HTTPS protocol's default port is 443)
testSetPort("443", "");
testSetPort(443, "");
// Empty string is the same as default port
testSetPort("", "");
// Completely invalid port strings are ignored
testSetPort("abcd", "8888");
testSetPort("-123", "");
testSetPort(-123, "");
testSetPort(-123.45, "");
testSetPort(undefined, "8888");
testSetPort(null, "8888");
testSetPort(+Infinity, "8888");
testSetPort(-Infinity, "8888");
testSetPort(NaN, "8888");
// Leading numbers are treated as a port number
testSetPort("5678abcd", "5678");
testSetPort("a5678abcd", "");
// Non-integers are truncated
testSetPort(1234.5678, "1234");
// Out-of-range numbers which are not represented in scientific notation
// will be ignored.
testSetPort(1e10, "8888");
testSetPort("123456", "8888");
testSetPort(123456, "8888");
testSetPort(4.567e21, "4");
// toString() takes precedence over valueOf(), even if it returns a valid integer
testSetPort(
{
toString() {
return "2";
},
valueOf() {
return 1;
},
},
"2"
);
// Protocol
function testSetProtocol(url, protocol, expected) {
url.protocol = protocol;
assert.sameValue(url.protocol, expected);
}
testSetProtocol(new URL("https://example.org"), "ftp", "ftp:");
testSetProtocol(new URL("https://example.org"), "ftp:", "ftp:");
testSetProtocol(new URL("https://example.org"), "FTP:", "ftp:");
testSetProtocol(new URL("https://example.org"), "ftp: blah", "ftp:");
// special to non-special
testSetProtocol(new URL("https://example.org"), "foo", "https:");
// non-special to special
testSetProtocol(new URL("fish://example.org"), "https", "fish:");
// Search
myURL = new URL("https://example.org/abc?123");
myURL.search = "abc=xyz";
assert.sameValue(myURL.href, "https://example.org/abc?abc=xyz");
myURL.search = "a=1 2";
assert.sameValue(myURL.href, "https://example.org/abc?a=1%202");
myURL.search = "á=ú";
assert.sameValue(myURL.search, "?%C3%A1=%C3%BA");
assert.sameValue(myURL.href, "https://example.org/abc?%C3%A1=%C3%BA");
myURL.hash = "hash";
myURL.search = "a=#b";
assert.sameValue(myURL.href, "https://example.org/abc?a=%23b#hash");
assert.sameValue(myURL.search, "?a=%23b");
assert.sameValue(myURL.hash, "#hash");
// Username
myURL = new URL("https://abc:xyz@example.com/");
myURL.username = "123";
assert.sameValue(myURL.href, "https://123:xyz@example.com/");
// Origin, read-only
assert.throws(() => {
myURL.origin = "abc";
}, TypeError);
// href
myURL = new URL("https://example.org");
myURL.href = "https://example.com";
assert.sameValue(myURL.href, "https://example.com/");
assert.throws(() => {
myURL.href = "test";
}, TypeError);
// Search Params
myURL = new URL("https://example.com/");
myURL.searchParams.append("user", "abc");
assert.sameValue(myURL.toString(), "https://example.com/?user=abc");
myURL.searchParams.append("first", "one");
assert.sameValue(myURL.toString(), "https://example.com/?user=abc&first=one");
myURL.searchParams.delete("user");
assert.sameValue(myURL.toString(), "https://example.com/?first=one");
{
const url = require("url");
assert.sameValue(url.domainToASCII('español.com'), "xn--espaol-zwa.com");
assert.sameValue(url.domainToASCII('中文.com'), "xn--fiq228c.com");
assert.sameValue(url.domainToASCII('xn--iñvalid.com'), "");
assert.sameValue(url.domainToUnicode('xn--espaol-zwa.com'), "español.com");
assert.sameValue(url.domainToUnicode('xn--fiq228c.com'), "中文.com");
assert.sameValue(url.domainToUnicode('xn--iñvalid.com'), "");
}

315
script/modules/url/url.go Normal file
View File

@@ -0,0 +1,315 @@
package url
import (
"net"
"net/url"
"strings"
"github.com/sagernet/sing-box/script/jsc"
E "github.com/sagernet/sing/common/exceptions"
"github.com/dop251/goja"
"golang.org/x/net/idna"
)
type URL struct {
class jsc.Class[*Module, *URL]
url *url.URL
params *URLSearchParams
paramsValue goja.Value
}
func newURL(c jsc.Class[*Module, *URL], call goja.ConstructorCall) *URL {
var (
u, base *url.URL
err error
)
switch argURL := call.Argument(0).Export().(type) {
case *URL:
u = argURL.url
default:
u, err = parseURL(call.Argument(0).String())
if err != nil {
panic(c.Runtime().NewGoError(E.Cause(err, "parse URL")))
}
}
if len(call.Arguments) == 2 {
switch argBaseURL := call.Argument(1).Export().(type) {
case *URL:
base = argBaseURL.url
default:
base, err = parseURL(call.Argument(1).String())
if err != nil {
panic(c.Runtime().NewGoError(E.Cause(err, "parse base URL")))
}
}
}
if base != nil {
u = base.ResolveReference(u)
}
return &URL{class: c, url: u}
}
func createURL(module *Module) jsc.Class[*Module, *URL] {
class := jsc.NewClass[*Module, *URL](module)
class.DefineConstructor(newURL)
class.DefineField("hash", (*URL).getHash, (*URL).setHash)
class.DefineField("host", (*URL).getHost, (*URL).setHost)
class.DefineField("hostname", (*URL).getHostName, (*URL).setHostName)
class.DefineField("href", (*URL).getHref, (*URL).setHref)
class.DefineField("origin", (*URL).getOrigin, nil)
class.DefineField("password", (*URL).getPassword, (*URL).setPassword)
class.DefineField("pathname", (*URL).getPathname, (*URL).setPathname)
class.DefineField("port", (*URL).getPort, (*URL).setPort)
class.DefineField("protocol", (*URL).getProtocol, (*URL).setProtocol)
class.DefineField("search", (*URL).getSearch, (*URL).setSearch)
class.DefineField("searchParams", (*URL).getSearchParams, (*URL).setSearchParams)
class.DefineField("username", (*URL).getUsername, (*URL).setUsername)
class.DefineMethod("toString", (*URL).toString)
class.DefineMethod("toJSON", (*URL).toJSON)
class.DefineStaticMethod("canParse", canParse)
// class.DefineStaticMethod("createObjectURL", createObjectURL)
class.DefineStaticMethod("parse", parse)
// class.DefineStaticMethod("revokeObjectURL", revokeObjectURL)
return class
}
func canParse(class jsc.Class[*Module, *URL], call goja.FunctionCall) any {
switch call.Argument(0).Export().(type) {
case *URL:
default:
_, err := parseURL(call.Argument(0).String())
if err != nil {
return false
}
}
if len(call.Arguments) == 2 {
switch call.Argument(1).Export().(type) {
case *URL:
default:
_, err := parseURL(call.Argument(1).String())
if err != nil {
return false
}
}
}
return true
}
func parse(class jsc.Class[*Module, *URL], call goja.FunctionCall) any {
var (
u, base *url.URL
err error
)
switch argURL := call.Argument(0).Export().(type) {
case *URL:
u = argURL.url
default:
u, err = parseURL(call.Argument(0).String())
if err != nil {
return goja.Null()
}
}
if len(call.Arguments) == 2 {
switch argBaseURL := call.Argument(1).Export().(type) {
case *URL:
base = argBaseURL.url
default:
base, err = parseURL(call.Argument(1).String())
if err != nil {
return goja.Null()
}
}
}
if base != nil {
u = base.ResolveReference(u)
}
return &URL{class: class, url: u}
}
func (r *URL) getHash() any {
if r.url.Fragment != "" {
return "#" + r.url.EscapedFragment()
}
return ""
}
func (r *URL) setHash(value goja.Value) {
r.url.RawFragment = strings.TrimPrefix(value.String(), "#")
}
func (r *URL) getHost() any {
return r.url.Host
}
func (r *URL) setHost(value goja.Value) {
r.url.Host = strings.TrimSuffix(value.String(), ":")
}
func (r *URL) getHostName() any {
return r.url.Hostname()
}
func (r *URL) setHostName(value goja.Value) {
r.url.Host = joinHostPort(value.String(), r.url.Port())
}
func (r *URL) getHref() any {
return r.url.String()
}
func (r *URL) setHref(value goja.Value) {
newURL, err := url.Parse(value.String())
if err != nil {
panic(r.class.Runtime().NewGoError(err))
}
r.url = newURL
r.params = nil
}
func (r *URL) getOrigin() any {
return r.url.Scheme + "://" + r.url.Host
}
func (r *URL) getPassword() any {
if r.url.User != nil {
password, _ := r.url.User.Password()
return password
}
return ""
}
func (r *URL) setPassword(value goja.Value) {
if r.url.User == nil {
r.url.User = url.UserPassword("", value.String())
} else {
r.url.User = url.UserPassword(r.url.User.Username(), value.String())
}
}
func (r *URL) getPathname() any {
return r.url.EscapedPath()
}
func (r *URL) setPathname(value goja.Value) {
r.url.RawPath = value.String()
}
func (r *URL) getPort() any {
return r.url.Port()
}
func (r *URL) setPort(value goja.Value) {
r.url.Host = joinHostPort(r.url.Hostname(), value.String())
}
func (r *URL) getProtocol() any {
return r.url.Scheme + ":"
}
func (r *URL) setProtocol(value goja.Value) {
r.url.Scheme = strings.TrimSuffix(value.String(), ":")
}
func (r *URL) getSearch() any {
if r.params != nil {
if len(r.params.params) > 0 {
return "?" + generateQuery(r.params.params)
}
} else if r.url.RawQuery != "" {
return "?" + r.url.RawQuery
}
return ""
}
func (r *URL) setSearch(value goja.Value) {
params, err := parseQuery(value.String())
if err == nil {
if r.params != nil {
r.params.params = params
} else {
r.url.RawQuery = generateQuery(params)
}
}
}
func (r *URL) getSearchParams() any {
var params []searchParam
if r.url.RawQuery != "" {
params, _ = parseQuery(r.url.RawQuery)
}
if r.params == nil {
r.params = &URLSearchParams{
class: r.class.Module().classURLSearchParams,
params: params,
}
r.paramsValue = r.class.Module().classURLSearchParams.New(r.params)
}
return r.paramsValue
}
func (r *URL) setSearchParams(value goja.Value) {
if params, ok := value.Export().(*URLSearchParams); ok {
r.params = params
r.paramsValue = value
}
}
func (r *URL) getUsername() any {
if r.url.User != nil {
return r.url.User.Username()
}
return ""
}
func (r *URL) setUsername(value goja.Value) {
if r.url.User == nil {
r.url.User = url.User(value.String())
} else {
password, _ := r.url.User.Password()
r.url.User = url.UserPassword(value.String(), password)
}
}
func (r *URL) toString(call goja.FunctionCall) any {
if r.params != nil {
r.url.RawQuery = generateQuery(r.params.params)
}
return r.url.String()
}
func (r *URL) toJSON(call goja.FunctionCall) any {
return r.toString(call)
}
func parseURL(s string) (*url.URL, error) {
u, err := url.Parse(s)
if err != nil {
return nil, E.Cause(err, "invalid URL")
}
switch u.Scheme {
case "https", "http", "ftp", "wss", "ws":
if u.Path == "" {
u.Path = "/"
}
hostname := u.Hostname()
asciiHostname, err := idna.Punycode.ToASCII(strings.ToLower(hostname))
if err != nil {
return nil, E.Cause(err, "invalid hostname")
}
if asciiHostname != hostname {
u.Host = joinHostPort(asciiHostname, u.Port())
}
}
if u.RawQuery != "" {
u.RawQuery = escape(u.RawQuery, &tblEscapeURLQuery, false)
}
return u, nil
}
func joinHostPort(hostname, port string) string {
if port == "" {
return hostname
}
return net.JoinHostPort(hostname, port)
}

View File

@@ -0,0 +1,244 @@
package url
import (
"fmt"
"net/url"
"sort"
"strings"
"github.com/sagernet/sing-box/script/jsc"
F "github.com/sagernet/sing/common/format"
"github.com/dop251/goja"
)
type URLSearchParams struct {
class jsc.Class[*Module, *URLSearchParams]
params []searchParam
}
func createURLSearchParams(module *Module) jsc.Class[*Module, *URLSearchParams] {
class := jsc.NewClass[*Module, *URLSearchParams](module)
class.DefineConstructor(newURLSearchParams)
class.DefineField("size", (*URLSearchParams).getSize, nil)
class.DefineMethod("append", (*URLSearchParams).append)
class.DefineMethod("delete", (*URLSearchParams).delete)
class.DefineMethod("entries", (*URLSearchParams).entries)
class.DefineMethod("forEach", (*URLSearchParams).forEach)
class.DefineMethod("get", (*URLSearchParams).get)
class.DefineMethod("getAll", (*URLSearchParams).getAll)
class.DefineMethod("has", (*URLSearchParams).has)
class.DefineMethod("keys", (*URLSearchParams).keys)
class.DefineMethod("set", (*URLSearchParams).set)
class.DefineMethod("sort", (*URLSearchParams).sort)
class.DefineMethod("toString", (*URLSearchParams).toString)
class.DefineMethod("values", (*URLSearchParams).values)
return class
}
func newURLSearchParams(class jsc.Class[*Module, *URLSearchParams], call goja.ConstructorCall) *URLSearchParams {
var (
params []searchParam
err error
)
switch argInit := call.Argument(0).Export().(type) {
case *URLSearchParams:
params = argInit.params
case string:
params, err = parseQuery(argInit)
if err != nil {
panic(class.Runtime().NewGoError(err))
}
case [][]string:
for _, pair := range argInit {
if len(pair) != 2 {
panic(class.Runtime().NewTypeError("Each query pair must be an iterable [name, value] tuple"))
}
params = append(params, searchParam{pair[0], pair[1]})
}
case map[string]any:
for name, value := range argInit {
stringValue, isString := value.(string)
if !isString {
panic(class.Runtime().NewTypeError("Invalid query value"))
}
params = append(params, searchParam{name, stringValue})
}
}
return &URLSearchParams{class, params}
}
func (s *URLSearchParams) getSize() any {
return len(s.params)
}
func (s *URLSearchParams) append(call goja.FunctionCall) any {
name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false)
value := call.Argument(1).String()
s.params = append(s.params, searchParam{name, value})
return goja.Undefined()
}
func (s *URLSearchParams) delete(call goja.FunctionCall) any {
name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false)
argValue := call.Argument(1)
if !jsc.IsNil(argValue) {
value := argValue.String()
for i, param := range s.params {
if param.Key == name && param.Value == value {
s.params = append(s.params[:i], s.params[i+1:]...)
break
}
}
} else {
for i, param := range s.params {
if param.Key == name {
s.params = append(s.params[:i], s.params[i+1:]...)
break
}
}
}
return goja.Undefined()
}
func (s *URLSearchParams) entries(call goja.FunctionCall) any {
return jsc.NewIterator[*Module, searchParam](s.class.Module().classURLSearchParamsIterator, s.params, func(this searchParam) any {
return s.class.Runtime().NewArray(this.Key, this.Value)
})
}
func (s *URLSearchParams) forEach(call goja.FunctionCall) any {
callback := jsc.AssertFunction(s.class.Runtime(), call.Argument(0), "callbackFn")
thisValue := call.Argument(1)
for _, param := range s.params {
for _, value := range param.Value {
_, err := callback(thisValue, s.class.Runtime().ToValue(value), s.class.Runtime().ToValue(param.Key), call.This)
if err != nil {
panic(s.class.Runtime().NewGoError(err))
}
}
}
return goja.Undefined()
}
func (s *URLSearchParams) get(call goja.FunctionCall) any {
name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false)
for _, param := range s.params {
if param.Key == name {
return param.Value
}
}
return goja.Null()
}
func (s *URLSearchParams) getAll(call goja.FunctionCall) any {
name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false)
var values []any
for _, param := range s.params {
if param.Key == name {
values = append(values, param.Value)
}
}
return s.class.Runtime().NewArray(values...)
}
func (s *URLSearchParams) has(call goja.FunctionCall) any {
name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false)
argValue := call.Argument(1)
if !jsc.IsNil(argValue) {
value := argValue.String()
for _, param := range s.params {
if param.Key == name && param.Value == value {
return true
}
}
} else {
for _, param := range s.params {
if param.Key == name {
return true
}
}
}
return false
}
func (s *URLSearchParams) keys(call goja.FunctionCall) any {
return jsc.NewIterator[*Module, searchParam](s.class.Module().classURLSearchParamsIterator, s.params, func(this searchParam) any {
return this.Key
})
}
func (s *URLSearchParams) set(call goja.FunctionCall) any {
name := jsc.AssertString(s.class.Runtime(), call.Argument(0), "name", false)
value := call.Argument(1).String()
for i, param := range s.params {
if param.Key == name {
s.params[i].Value = value
return goja.Undefined()
}
}
s.params = append(s.params, searchParam{name, value})
return goja.Undefined()
}
func (s *URLSearchParams) sort(call goja.FunctionCall) any {
sort.SliceStable(s.params, func(i, j int) bool {
return s.params[i].Key < s.params[j].Key
})
return goja.Undefined()
}
func (s *URLSearchParams) toString(call goja.FunctionCall) any {
return generateQuery(s.params)
}
func (s *URLSearchParams) values(call goja.FunctionCall) any {
return jsc.NewIterator[*Module, searchParam](s.class.Module().classURLSearchParamsIterator, s.params, func(this searchParam) any {
return this.Value
})
}
type searchParam struct {
Key string
Value string
}
func parseQuery(query string) (params []searchParam, err error) {
query = strings.TrimPrefix(query, "?")
for query != "" {
var key string
key, query, _ = strings.Cut(query, "&")
if strings.Contains(key, ";") {
err = fmt.Errorf("invalid semicolon separator in query")
continue
}
if key == "" {
continue
}
key, value, _ := strings.Cut(key, "=")
key, err1 := url.QueryUnescape(key)
if err1 != nil {
if err == nil {
err = err1
}
continue
}
value, err1 = url.QueryUnescape(value)
if err1 != nil {
if err == nil {
err = err1
}
continue
}
params = append(params, searchParam{key, value})
}
return
}
func generateQuery(params []searchParam) string {
var parts []string
for _, param := range params {
parts = append(parts, F.ToString(param.Key, "=", url.QueryEscape(param.Value)))
}
return strings.Join(parts, "&")
}

47
script/runtime.go Normal file
View File

@@ -0,0 +1,47 @@
package script
import (
"context"
"github.com/sagernet/sing-box/script/modules/boxctx"
"github.com/sagernet/sing-box/script/modules/console"
"github.com/sagernet/sing-box/script/modules/eventloop"
"github.com/sagernet/sing-box/script/modules/require"
"github.com/sagernet/sing-box/script/modules/surge"
"github.com/sagernet/sing-box/script/modules/url"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/common/ntp"
"github.com/dop251/goja"
"github.com/dop251/goja/parser"
)
func NewRuntime(ctx context.Context, cancel context.CancelCauseFunc) *goja.Runtime {
vm := goja.New()
if timeFunc := ntp.TimeFuncFromContext(ctx); timeFunc != nil {
vm.SetTimeSource(timeFunc)
}
vm.SetParserOptions(parser.WithDisableSourceMaps)
registry := require.NewRegistry(require.WithLoader(func(path string) ([]byte, error) {
return nil, E.New("unsupported usage")
}))
registry.Enable(vm)
registry.RegisterNodeModule(console.ModuleName, console.Require)
registry.RegisterNodeModule(url.ModuleName, url.Require)
registry.RegisterNativeModule(boxctx.ModuleName, boxctx.Require)
registry.RegisterNativeModule(surge.ModuleName, surge.Require)
console.Enable(vm)
url.Enable(vm)
eventloop.Enable(vm, cancel)
return vm
}
func SetModules(runtime *goja.Runtime, ctx context.Context, logger logger.ContextLogger, errorHandler func(error), tag string) {
boxctx.Enable(runtime, &boxctx.Context{
Context: ctx,
Logger: logger,
Tag: tag,
ErrorHandler: errorHandler,
})
}

20
script/script.go Normal file
View File

@@ -0,0 +1,20 @@
package script
import (
"context"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
func NewScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (adapter.Script, error) {
switch options.Type {
case C.ScriptTypeSurge:
return NewSurgeScript(ctx, logger, options)
default:
return nil, E.New("unknown script type: ", options.Type)
}
}

345
script/script_surge.go Normal file
View File

@@ -0,0 +1,345 @@
package script
import (
"context"
"net/http"
"sync"
"time"
"unsafe"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/script/jsc"
"github.com/sagernet/sing-box/script/modules/surge"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
"github.com/adhocore/gronx"
"github.com/dop251/goja"
)
const defaultSurgeScriptTimeout = 10 * time.Second
var _ adapter.SurgeScript = (*SurgeScript)(nil)
type SurgeScript struct {
ctx context.Context
logger logger.ContextLogger
tag string
source Source
cronExpression string
cronTimeout time.Duration
cronArguments []string
cronTimer *time.Timer
cronDone chan struct{}
}
func NewSurgeScript(ctx context.Context, logger logger.ContextLogger, options option.Script) (adapter.Script, error) {
source, err := NewSource(ctx, logger, options)
if err != nil {
return nil, err
}
cronOptions := common.PtrValueOrDefault(options.SurgeOptions.CronOptions)
if cronOptions.Expression != "" {
if !gronx.IsValid(cronOptions.Expression) {
return nil, E.New("invalid cron expression: ", cronOptions.Expression)
}
}
return &SurgeScript{
ctx: ctx,
logger: logger,
tag: options.Tag,
source: source,
cronExpression: cronOptions.Expression,
cronTimeout: time.Duration(cronOptions.Timeout),
cronArguments: cronOptions.Arguments,
cronDone: make(chan struct{}),
}, nil
}
func (s *SurgeScript) Type() string {
return C.ScriptTypeSurge
}
func (s *SurgeScript) Tag() string {
return s.tag
}
func (s *SurgeScript) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error {
return s.source.StartContext(ctx, startContext)
}
func (s *SurgeScript) PostStart() error {
err := s.source.PostStart()
if err != nil {
return err
}
if s.cronExpression != "" {
go s.loopCronEvents()
}
return nil
}
func (s *SurgeScript) loopCronEvents() {
s.logger.Debug("starting event")
err := s.ExecuteGeneric(s.ctx, "cron", s.cronTimeout, s.cronArguments)
if err != nil {
s.logger.Error(E.Cause(err, "running event"))
}
nextTick, err := gronx.NextTick(s.cronExpression, false)
if err != nil {
s.logger.Error(E.Cause(err, "determine next tick"))
return
}
s.cronTimer = time.NewTimer(nextTick.Sub(time.Now()))
s.logger.Debug("next event at: ", nextTick.Format(log.DefaultTimeFormat))
for {
select {
case <-s.ctx.Done():
return
case <-s.cronDone:
return
case <-s.cronTimer.C:
s.logger.Debug("starting event")
err = s.ExecuteGeneric(s.ctx, "cron", s.cronTimeout, s.cronArguments)
if err != nil {
s.logger.Error(E.Cause(err, "running event"))
}
nextTick, err = gronx.NextTick(s.cronExpression, false)
if err != nil {
s.logger.Error(E.Cause(err, "determine next tick"))
return
}
s.cronTimer.Reset(nextTick.Sub(time.Now()))
s.logger.Debug("configured next event at: ", nextTick)
}
}
}
func (s *SurgeScript) Close() error {
err := s.source.Close()
if s.cronTimer != nil {
s.cronTimer.Stop()
close(s.cronDone)
}
return err
}
func (s *SurgeScript) ExecuteGeneric(ctx context.Context, scriptType string, timeout time.Duration, arguments []string) error {
program := s.source.Program()
if program == nil {
return E.New("invalid script")
}
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
runtime := NewRuntime(ctx, cancel)
SetModules(runtime, ctx, s.logger, cancel, s.tag)
surge.Enable(runtime, scriptType, arguments)
if timeout == 0 {
timeout = defaultSurgeScriptTimeout
}
ctx, timeoutCancel := context.WithTimeout(ctx, timeout)
defer timeoutCancel()
done := make(chan struct{})
doneFunc := common.OnceFunc(func() {
close(done)
})
runtime.Set("done", func(call goja.FunctionCall) goja.Value {
doneFunc()
return goja.Undefined()
})
var (
access sync.Mutex
scriptErr error
)
go func() {
_, err := runtime.RunProgram(program)
if err != nil {
access.Lock()
scriptErr = err
access.Unlock()
doneFunc()
}
}()
select {
case <-ctx.Done():
runtime.Interrupt(ctx.Err())
return ctx.Err()
case <-done:
access.Lock()
defer access.Unlock()
if scriptErr != nil {
runtime.Interrupt(scriptErr)
} else {
runtime.Interrupt("script done")
}
}
return scriptErr
}
func (s *SurgeScript) ExecuteHTTPRequest(ctx context.Context, timeout time.Duration, request *http.Request, body []byte, binaryBody bool, arguments []string) (*adapter.HTTPRequestScriptResult, error) {
program := s.source.Program()
if program == nil {
return nil, E.New("invalid script")
}
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
runtime := NewRuntime(ctx, cancel)
SetModules(runtime, ctx, s.logger, cancel, s.tag)
surge.Enable(runtime, "http-request", arguments)
if timeout == 0 {
timeout = defaultSurgeScriptTimeout
}
ctx, timeoutCancel := context.WithTimeout(ctx, timeout)
defer timeoutCancel()
runtime.ClearInterrupt()
requestObject := runtime.NewObject()
requestObject.Set("url", request.URL.String())
requestObject.Set("method", request.Method)
requestObject.Set("headers", jsc.HeadersToValue(runtime, request.Header))
if !binaryBody {
requestObject.Set("body", string(body))
} else {
requestObject.Set("body", jsc.NewUint8Array(runtime, body))
}
requestObject.Set("id", F.ToString(uintptr(unsafe.Pointer(request))))
runtime.Set("request", requestObject)
done := make(chan struct{})
doneFunc := common.OnceFunc(func() {
close(done)
})
var (
access sync.Mutex
result adapter.HTTPRequestScriptResult
scriptErr error
)
runtime.Set("done", func(call goja.FunctionCall) goja.Value {
defer doneFunc()
resultObject := jsc.AssertObject(runtime, call.Argument(0), "done() argument", true)
if resultObject == nil {
panic(runtime.NewGoError(E.New("request rejected by script")))
}
access.Lock()
defer access.Unlock()
result.URL = jsc.AssertString(runtime, resultObject.Get("url"), "url", true)
result.Headers = jsc.AssertHTTPHeader(runtime, resultObject.Get("headers"), "headers")
result.Body = jsc.AssertStringBinary(runtime, resultObject.Get("body"), "body", true)
responseObject := jsc.AssertObject(runtime, resultObject.Get("response"), "response", true)
if responseObject != nil {
result.Response = &adapter.HTTPRequestScriptResponse{
Status: int(jsc.AssertInt(runtime, responseObject.Get("status"), "status", true)),
Headers: jsc.AssertHTTPHeader(runtime, responseObject.Get("headers"), "headers"),
Body: jsc.AssertStringBinary(runtime, responseObject.Get("body"), "body", true),
}
}
return goja.Undefined()
})
go func() {
_, err := runtime.RunProgram(program)
if err != nil {
access.Lock()
scriptErr = err
access.Unlock()
doneFunc()
}
}()
select {
case <-ctx.Done():
runtime.Interrupt(ctx.Err())
return nil, ctx.Err()
case <-done:
access.Lock()
defer access.Unlock()
if scriptErr != nil {
runtime.Interrupt(scriptErr)
} else {
runtime.Interrupt("script done")
}
}
return &result, scriptErr
}
func (s *SurgeScript) ExecuteHTTPResponse(ctx context.Context, timeout time.Duration, request *http.Request, response *http.Response, body []byte, binaryBody bool, arguments []string) (*adapter.HTTPResponseScriptResult, error) {
program := s.source.Program()
if program == nil {
return nil, E.New("invalid script")
}
ctx, cancel := context.WithCancelCause(ctx)
defer cancel(nil)
runtime := NewRuntime(ctx, cancel)
SetModules(runtime, ctx, s.logger, cancel, s.tag)
surge.Enable(runtime, "http-response", arguments)
if timeout == 0 {
timeout = defaultSurgeScriptTimeout
}
ctx, timeoutCancel := context.WithTimeout(ctx, timeout)
defer timeoutCancel()
runtime.ClearInterrupt()
requestObject := runtime.NewObject()
requestObject.Set("url", request.URL.String())
requestObject.Set("method", request.Method)
requestObject.Set("headers", jsc.HeadersToValue(runtime, request.Header))
requestObject.Set("id", F.ToString(uintptr(unsafe.Pointer(request))))
runtime.Set("request", requestObject)
responseObject := runtime.NewObject()
responseObject.Set("status", response.StatusCode)
responseObject.Set("headers", jsc.HeadersToValue(runtime, response.Header))
if !binaryBody {
responseObject.Set("body", string(body))
} else {
responseObject.Set("body", jsc.NewUint8Array(runtime, body))
}
runtime.Set("response", responseObject)
done := make(chan struct{})
doneFunc := common.OnceFunc(func() {
close(done)
})
var (
access sync.Mutex
result adapter.HTTPResponseScriptResult
scriptErr error
)
runtime.Set("done", func(call goja.FunctionCall) goja.Value {
resultObject := jsc.AssertObject(runtime, call.Argument(0), "done() argument", true)
if resultObject == nil {
panic(runtime.NewGoError(E.New("response rejected by script")))
}
access.Lock()
defer access.Unlock()
result.Status = int(jsc.AssertInt(runtime, resultObject.Get("status"), "status", true))
result.Headers = jsc.AssertHTTPHeader(runtime, resultObject.Get("headers"), "headers")
result.Body = jsc.AssertStringBinary(runtime, resultObject.Get("body"), "body", true)
doneFunc()
return goja.Undefined()
})
go func() {
_, err := runtime.RunProgram(program)
if err != nil {
access.Lock()
scriptErr = err
access.Unlock()
doneFunc()
}
}()
select {
case <-ctx.Done():
runtime.Interrupt(ctx.Err())
return nil, ctx.Err()
case <-done:
access.Lock()
defer access.Unlock()
if scriptErr != nil {
runtime.Interrupt(scriptErr)
} else {
runtime.Interrupt("script done")
}
return &result, scriptErr
}
}

31
script/source.go Normal file
View File

@@ -0,0 +1,31 @@
package script
import (
"context"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/dop251/goja"
)
type Source interface {
StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error
PostStart() error
Program() *goja.Program
Close() error
}
func NewSource(ctx context.Context, logger logger.Logger, options option.Script) (Source, error) {
switch options.Source {
case C.ScriptSourceTypeLocal:
return NewLocalSource(ctx, logger, options)
case C.ScriptSourceTypeRemote:
return NewRemoteSource(ctx, logger, options)
default:
return nil, E.New("unknown source type: ", options.Source)
}
}

92
script/source_local.go Normal file
View File

@@ -0,0 +1,92 @@
package script
import (
"context"
"os"
"path/filepath"
"github.com/sagernet/fswatch"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service/filemanager"
"github.com/dop251/goja"
)
var _ Source = (*LocalSource)(nil)
type LocalSource struct {
ctx context.Context
logger logger.Logger
tag string
program *goja.Program
watcher *fswatch.Watcher
}
func NewLocalSource(ctx context.Context, logger logger.Logger, options option.Script) (*LocalSource, error) {
script := &LocalSource{
ctx: ctx,
logger: logger,
tag: options.Tag,
}
filePath := filemanager.BasePath(ctx, options.LocalOptions.Path)
filePath, _ = filepath.Abs(options.LocalOptions.Path)
err := script.reloadFile(filePath)
if err != nil {
return nil, err
}
watcher, err := fswatch.NewWatcher(fswatch.Options{
Path: []string{filePath},
Callback: func(path string) {
uErr := script.reloadFile(path)
if uErr != nil {
logger.Error(E.Cause(uErr, "reload script ", path))
}
},
})
if err != nil {
return nil, err
}
script.watcher = watcher
return script, nil
}
func (s *LocalSource) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error {
if s.watcher != nil {
err := s.watcher.Start()
if err != nil {
s.logger.Error(E.Cause(err, "watch script file"))
}
}
return nil
}
func (s *LocalSource) reloadFile(path string) error {
content, err := os.ReadFile(path)
if err != nil {
return err
}
program, err := goja.Compile("script:"+s.tag, string(content), false)
if err != nil {
return E.Cause(err, "compile ", path)
}
if s.program != nil {
s.logger.Info("reloaded from ", path)
}
s.program = program
return nil
}
func (s *LocalSource) PostStart() error {
return nil
}
func (s *LocalSource) Program() *goja.Program {
return s.program
}
func (s *LocalSource) Close() error {
return s.watcher.Close()
}

224
script/source_remote.go Normal file
View File

@@ -0,0 +1,224 @@
package script
import (
"context"
"crypto/tls"
"io"
"net"
"net/http"
"runtime"
"time"
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"github.com/dop251/goja"
)
var _ Source = (*RemoteSource)(nil)
type RemoteSource struct {
ctx context.Context
cancel context.CancelFunc
logger logger.Logger
outbound adapter.OutboundManager
options option.Script
updateInterval time.Duration
dialer N.Dialer
program *goja.Program
lastUpdated time.Time
lastEtag string
updateTicker *time.Ticker
cacheFile adapter.CacheFile
pauseManager pause.Manager
}
func NewRemoteSource(ctx context.Context, logger logger.Logger, options option.Script) (*RemoteSource, error) {
ctx, cancel := context.WithCancel(ctx)
var updateInterval time.Duration
if options.RemoteOptions.UpdateInterval > 0 {
updateInterval = time.Duration(options.RemoteOptions.UpdateInterval)
} else {
updateInterval = 24 * time.Hour
}
return &RemoteSource{
ctx: ctx,
cancel: cancel,
logger: logger,
outbound: service.FromContext[adapter.OutboundManager](ctx),
options: options,
updateInterval: updateInterval,
pauseManager: service.FromContext[pause.Manager](ctx),
}, nil
}
func (s *RemoteSource) StartContext(ctx context.Context, startContext *adapter.HTTPStartContext) error {
s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx)
var dialer N.Dialer
if s.options.RemoteOptions.DownloadDetour != "" {
outbound, loaded := s.outbound.Outbound(s.options.RemoteOptions.DownloadDetour)
if !loaded {
return E.New("download detour not found: ", s.options.RemoteOptions.DownloadDetour)
}
dialer = outbound
} else {
dialer = s.outbound.Default()
}
s.dialer = dialer
if s.cacheFile != nil {
if savedSet := s.cacheFile.LoadScript(s.options.Tag); savedSet != nil {
err := s.loadBytes(savedSet.Content)
if err != nil {
return E.Cause(err, "restore cached rule-set")
}
s.lastUpdated = savedSet.LastUpdated
s.lastEtag = savedSet.LastEtag
}
}
if s.lastUpdated.IsZero() {
err := s.fetchOnce(ctx, startContext)
if err != nil {
return E.Cause(err, "initial rule-set: ", s.options.Tag)
}
}
s.updateTicker = time.NewTicker(s.updateInterval)
return nil
}
func (s *RemoteSource) PostStart() error {
go s.loopUpdate()
return nil
}
func (s *RemoteSource) Program() *goja.Program {
return s.program
}
func (s *RemoteSource) loadBytes(content []byte) error {
program, err := goja.Compile(F.ToString("script:", s.options.Tag), string(content), false)
if err != nil {
return err
}
s.program = program
return nil
}
func (s *RemoteSource) loopUpdate() {
if time.Since(s.lastUpdated) > s.updateInterval {
err := s.fetchOnce(s.ctx, nil)
if err != nil {
s.logger.Error("fetch rule-set ", s.options.Tag, ": ", err)
}
}
for {
runtime.GC()
select {
case <-s.ctx.Done():
return
case <-s.updateTicker.C:
s.pauseManager.WaitActive()
err := s.fetchOnce(s.ctx, nil)
if err != nil {
s.logger.Error("fetch rule-set ", s.options.Tag, ": ", err)
}
}
}
}
func (s *RemoteSource) fetchOnce(ctx context.Context, startContext *adapter.HTTPStartContext) error {
s.logger.Debug("updating script ", s.options.Tag, " from URL: ", s.options.RemoteOptions.URL)
var httpClient *http.Client
if startContext != nil {
httpClient = startContext.HTTPClient(s.options.RemoteOptions.DownloadDetour, s.dialer)
} else {
httpClient = &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: true,
TLSHandshakeTimeout: C.TCPTimeout,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
TLSClientConfig: &tls.Config{
Time: ntp.TimeFuncFromContext(s.ctx),
RootCAs: adapter.RootPoolFromContext(s.ctx),
},
},
}
}
request, err := http.NewRequest("GET", s.options.RemoteOptions.URL, nil)
if err != nil {
return err
}
if s.lastEtag != "" {
request.Header.Set("If-None-Match", s.lastEtag)
}
response, err := httpClient.Do(request.WithContext(ctx))
if err != nil {
return err
}
switch response.StatusCode {
case http.StatusOK:
case http.StatusNotModified:
s.lastUpdated = time.Now()
if s.cacheFile != nil {
savedRuleSet := s.cacheFile.LoadScript(s.options.Tag)
if savedRuleSet != nil {
savedRuleSet.LastUpdated = s.lastUpdated
err = s.cacheFile.SaveScript(s.options.Tag, savedRuleSet)
if err != nil {
s.logger.Error("save script updated time: ", err)
return nil
}
}
}
s.logger.Info("update script ", s.options.Tag, ": not modified")
return nil
default:
return E.New("unexpected status: ", response.Status)
}
content, err := io.ReadAll(response.Body)
if err != nil {
response.Body.Close()
return err
}
err = s.loadBytes(content)
if err != nil {
response.Body.Close()
return err
}
response.Body.Close()
eTagHeader := response.Header.Get("Etag")
if eTagHeader != "" {
s.lastEtag = eTagHeader
}
s.lastUpdated = time.Now()
if s.cacheFile != nil {
err = s.cacheFile.SaveScript(s.options.Tag, &adapter.SavedBinary{
LastUpdated: s.lastUpdated,
Content: content,
LastEtag: s.lastEtag,
})
if err != nil {
s.logger.Error("save script cache: ", err)
}
}
s.logger.Info("updated script ", s.options.Tag)
return nil
}
func (s *RemoteSource) Close() error {
if s.updateTicker != nil {
s.updateTicker.Stop()
}
s.cancel()
return nil
}