Add OpenVPN, TrustTunnel, Sudoku, inbound managers. Fixes

This commit is contained in:
Shtorm
2026-06-04 01:47:50 +03:00
parent 9b3da79c32
commit 195a33379d
164 changed files with 16665 additions and 1332 deletions

View File

@@ -96,6 +96,12 @@ func (h *Inbound) Close() error {
return common.Close(h.listener, h.tlsConfig)
}
func (h *Inbound) UpdateUsers(users []option.AnyTLSUser) {
h.service.UpdateUsers(common.Map(users, func(it option.AnyTLSUser) anytls.User {
return anytls.User(it)
}))
}
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
if h.tlsConfig != nil {
tlsConn, err := tls.ServerHandshake(ctx, conn, h.tlsConfig)

View File

@@ -86,6 +86,10 @@ func (h *Inbound) Close() error {
)
}
func (h *Inbound) UpdateUsers(users []auth.User) {
h.authenticator.UpdateUsers(users)
}
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
if h.tlsConfig != nil {
tlsConn, err := tls.ServerHandshake(ctx, conn, h.tlsConfig)

View File

@@ -3,15 +3,17 @@ package bandwidth
import (
"context"
"net"
"github.com/sagernet/sing-box/common/onclose"
)
type connWithDownloadBandwidthLimiter struct {
net.Conn
ctx context.Context
limiter Limiter
limiter BandwidthLimiter
}
func NewConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter Limiter) *connWithDownloadBandwidthLimiter {
func NewConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter BandwidthLimiter) *connWithDownloadBandwidthLimiter {
return &connWithDownloadBandwidthLimiter{conn, ctx, limiter}
}
@@ -26,10 +28,10 @@ func (conn *connWithDownloadBandwidthLimiter) Write(p []byte) (n int, err error)
type connWithUploadBandwidthLimiter struct {
net.Conn
ctx context.Context
limiter Limiter
limiter BandwidthLimiter
}
func NewConnWithUploadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter Limiter) *connWithUploadBandwidthLimiter {
func NewConnWithUploadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter BandwidthLimiter) *connWithUploadBandwidthLimiter {
return &connWithUploadBandwidthLimiter{conn, ctx, limiter}
}
@@ -47,10 +49,10 @@ func (conn *connWithUploadBandwidthLimiter) Read(p []byte) (n int, err error) {
type connWithCloseHandler struct {
net.Conn
onClose CloseHandlerFunc
onClose onclose.CloseHandlerFunc
}
func NewConnWithCloseHandler(conn net.Conn, onClose CloseHandlerFunc) *connWithCloseHandler {
func NewConnWithCloseHandler(conn net.Conn, onClose onclose.CloseHandlerFunc) *connWithCloseHandler {
return &connWithCloseHandler{conn, onClose}
}
@@ -62,10 +64,10 @@ func (conn *connWithCloseHandler) Close() error {
type packetConnWithDownloadBandwidthLimiter struct {
net.PacketConn
ctx context.Context
limiter Limiter
limiter BandwidthLimiter
}
func NewPacketConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter Limiter) *packetConnWithDownloadBandwidthLimiter {
func NewPacketConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter) *packetConnWithDownloadBandwidthLimiter {
return &packetConnWithDownloadBandwidthLimiter{conn, ctx, limiter}
}
@@ -80,10 +82,10 @@ func (conn *packetConnWithDownloadBandwidthLimiter) WriteTo(p []byte, addr net.A
type packetConnWithUploadBandwidthLimiter struct {
net.PacketConn
ctx context.Context
limiter Limiter
limiter BandwidthLimiter
}
func NewPacketConnWithUploadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter Limiter) *packetConnWithUploadBandwidthLimiter {
func NewPacketConnWithUploadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter) *packetConnWithUploadBandwidthLimiter {
return &packetConnWithUploadBandwidthLimiter{conn, ctx, limiter}
}
@@ -101,10 +103,10 @@ func (conn *packetConnWithUploadBandwidthLimiter) ReadFrom(p []byte) (n int, add
type packetConnWithCloseHandler struct {
net.PacketConn
onClose CloseHandlerFunc
onClose onclose.CloseHandlerFunc
}
func NewPacketConnWithCloseHandler(conn net.PacketConn, onClose CloseHandlerFunc) *packetConnWithCloseHandler {
func NewPacketConnWithCloseHandler(conn net.PacketConn, onClose onclose.CloseHandlerFunc) *packetConnWithCloseHandler {
return &packetConnWithCloseHandler{conn, onClose}
}
@@ -113,38 +115,38 @@ func (conn *packetConnWithCloseHandler) Close() error {
return conn.PacketConn.Close()
}
func connWithDownloadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter Limiter, reverse bool) net.Conn {
func connWithDownloadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter BandwidthLimiter, reverse bool) net.Conn {
if reverse {
return NewConnWithUploadBandwidthLimiter(ctx, conn, limiter)
}
return NewConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
}
func connWithUploadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter Limiter, reverse bool) net.Conn {
func connWithUploadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter BandwidthLimiter, reverse bool) net.Conn {
if reverse {
return NewConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
}
return NewConnWithUploadBandwidthLimiter(ctx, conn, limiter)
}
func connWithBidirectionalBandwidthWrapper(ctx context.Context, conn net.Conn, limiter Limiter, reverse bool) net.Conn {
func connWithBidirectionalBandwidthWrapper(ctx context.Context, conn net.Conn, limiter BandwidthLimiter, reverse bool) net.Conn {
return NewConnWithUploadBandwidthLimiter(ctx, NewConnWithDownloadBandwidthLimiter(ctx, conn, limiter), limiter)
}
func packetConnWithDownloadBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter Limiter, reverse bool) net.PacketConn {
func packetConnWithDownloadBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter, reverse bool) net.PacketConn {
if reverse {
return NewPacketConnWithUploadBandwidthLimiter(ctx, conn, limiter)
}
return NewPacketConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
}
func packetConnWithUploadBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter Limiter, reverse bool) net.PacketConn {
func packetConnWithUploadBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter, reverse bool) net.PacketConn {
if reverse {
return NewPacketConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
}
return NewPacketConnWithUploadBandwidthLimiter(ctx, conn, limiter)
}
func packetConnWithBidirectionalBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter Limiter, reverse bool) net.PacketConn {
func packetConnWithBidirectionalBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter, reverse bool) net.PacketConn {
return NewPacketConnWithUploadBandwidthLimiter(ctx, NewPacketConnWithDownloadBandwidthLimiter(ctx, conn, limiter), limiter)
}

View File

@@ -9,12 +9,13 @@ import (
"github.com/sagernet/sing-box/adapter"
)
type Limiter interface {
type BandwidthLimiter interface {
WaitN(ctx context.Context, n int) (err error)
SetSpeed(speed uint64)
}
type FlowKeysLimiter struct {
limiter Limiter
limiter BandwidthLimiter
connIDGetter ConnIDGetter
waits map[string][]*wait
@@ -25,7 +26,7 @@ type FlowKeysLimiter struct {
mtx sync.Mutex
}
func NewFlowKeysLimiter(connIDGetter ConnIDGetter, limiter Limiter) *FlowKeysLimiter {
func NewFlowKeysLimiter(connIDGetter ConnIDGetter, limiter BandwidthLimiter) *FlowKeysLimiter {
return &FlowKeysLimiter{
limiter: limiter,
connIDGetter: connIDGetter,
@@ -36,6 +37,10 @@ func NewFlowKeysLimiter(connIDGetter ConnIDGetter, limiter Limiter) *FlowKeysLim
}
}
func (l *FlowKeysLimiter) SetSpeed(speed uint64) {
l.limiter.SetSpeed(speed)
}
func (l *FlowKeysLimiter) WaitN(ctx context.Context, n int) error {
id, _ := l.connIDGetter(ctx, adapter.ContextFrom(ctx))
mainWait := &wait{ctx, make(chan struct{}), n}

View File

@@ -7,16 +7,16 @@ import (
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/onclose"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
"golang.org/x/time/rate"
)
type (
CloseHandlerFunc = func()
ConnIDGetter = func(context.Context, *adapter.InboundContext) (string, bool)
ConnWrapper = func(ctx context.Context, conn net.Conn, limiter Limiter, reverse bool) net.Conn
PacketConnWrapper = func(ctx context.Context, conn net.PacketConn, limiter Limiter, reverse bool) net.PacketConn
ConnWrapper = func(ctx context.Context, conn net.Conn, limiter BandwidthLimiter, reverse bool) net.Conn
PacketConnWrapper = func(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter, reverse bool) net.PacketConn
)
type BandwidthStrategy interface {
@@ -24,8 +24,12 @@ type BandwidthStrategy interface {
wrapPacketConn(ctx context.Context, conn net.PacketConn, metadata *adapter.InboundContext, reverse bool) (net.PacketConn, error)
}
type SpeedUpdater interface {
SetSpeed(speed uint64)
}
type BandwidthLimiterStrategy interface {
getLimiter(ctx context.Context, metadata *adapter.InboundContext) (Limiter, CloseHandlerFunc, error)
getLimiter(ctx context.Context, metadata *adapter.InboundContext) (BandwidthLimiter, onclose.CloseHandlerFunc, error)
}
type DefaultWrapStrategy struct {
@@ -54,8 +58,14 @@ func (s *DefaultWrapStrategy) wrapPacketConn(ctx context.Context, conn net.Packe
return NewPacketConnWithCloseHandler(s.packetConnWrapper(ctx, conn, limiter, reverse), onClose), nil
}
func (s *DefaultWrapStrategy) SetSpeed(speed uint64) {
if updater, ok := s.limiterStrategy.(SpeedUpdater); ok {
updater.SetSpeed(speed)
}
}
type GlobalBandwidthStrategy struct {
limiter Limiter
limiter BandwidthLimiter
}
func NewGlobalBandwidthStrategy(speed uint64, flowKeys []string) (*GlobalBandwidthStrategy, error) {
@@ -68,12 +78,16 @@ func NewGlobalBandwidthStrategy(speed uint64, flowKeys []string) (*GlobalBandwid
}, nil
}
func (s *GlobalBandwidthStrategy) getLimiter(ctx context.Context, metadata *adapter.InboundContext) (Limiter, CloseHandlerFunc, error) {
func (s *GlobalBandwidthStrategy) getLimiter(ctx context.Context, metadata *adapter.InboundContext) (BandwidthLimiter, onclose.CloseHandlerFunc, error) {
return s.limiter, func() {}, nil
}
func (s *GlobalBandwidthStrategy) SetSpeed(speed uint64) {
s.limiter.SetSpeed(speed)
}
type idBandwidthLimiter struct {
limiter Limiter
limiter BandwidthLimiter
handles uint32
}
@@ -94,7 +108,7 @@ func NewConnectionBandwidthStrategy(connIDGetter ConnIDGetter, speed uint64, flo
}
}
func (s *ConnectionBandwidthStrategy) getLimiter(ctx context.Context, metadata *adapter.InboundContext) (Limiter, CloseHandlerFunc, error) {
func (s *ConnectionBandwidthStrategy) getLimiter(ctx context.Context, metadata *adapter.InboundContext) (BandwidthLimiter, onclose.CloseHandlerFunc, error) {
s.mtx.Lock()
defer s.mtx.Unlock()
id, ok := s.connIDGetter(ctx, metadata)
@@ -126,6 +140,15 @@ func (s *ConnectionBandwidthStrategy) getLimiter(ctx context.Context, metadata *
}, nil
}
func (s *ConnectionBandwidthStrategy) SetSpeed(speed uint64) {
s.mtx.Lock()
defer s.mtx.Unlock()
s.speed = speed
for _, limiter := range s.limiters {
limiter.limiter.SetSpeed(speed)
}
}
type UsersBandwidthStrategy struct {
strategies map[string]BandwidthStrategy
mtx sync.Mutex
@@ -167,20 +190,86 @@ func (s *UsersBandwidthStrategy) getStrategy(ctx context.Context, metadata *adap
return nil, E.New("user strategy not found: ", user)
}
type bwConnEntry struct {
conn net.Conn
}
type ManagerBandwidthStrategy struct {
*UsersBandwidthStrategy
strategies map[string]BandwidthStrategy
conns map[string][]*bwConnEntry
mtx sync.Mutex
}
func NewManagerBandwidthStrategy() *ManagerBandwidthStrategy {
return &ManagerBandwidthStrategy{
UsersBandwidthStrategy: NewUsersBandwidthStrategy(map[string]BandwidthStrategy{}),
strategies: make(map[string]BandwidthStrategy),
conns: make(map[string][]*bwConnEntry),
}
}
func (s *ManagerBandwidthStrategy) wrapConn(ctx context.Context, conn net.Conn, metadata *adapter.InboundContext, reverse bool) (net.Conn, error) {
s.mtx.Lock()
var user string
if metadata != nil {
user = metadata.User
}
strategy, ok := s.strategies[user]
s.mtx.Unlock()
if !ok {
return nil, E.New("user strategy not found: ", user)
}
wrapped, err := strategy.wrapConn(ctx, conn, metadata, reverse)
if err != nil {
return nil, err
}
entry := &bwConnEntry{conn: conn}
s.mtx.Lock()
s.conns[user] = append(s.conns[user], entry)
s.mtx.Unlock()
return onclose.NewConn(wrapped, func() {
s.mtx.Lock()
entries := s.conns[user]
for i, e := range entries {
if e == entry {
s.conns[user] = append(entries[:i], entries[i+1:]...)
break
}
}
s.mtx.Unlock()
}), nil
}
func (s *ManagerBandwidthStrategy) wrapPacketConn(ctx context.Context, conn net.PacketConn, metadata *adapter.InboundContext, reverse bool) (net.PacketConn, error) {
s.mtx.Lock()
var user string
if metadata != nil {
user = metadata.User
}
strategy, ok := s.strategies[user]
s.mtx.Unlock()
if !ok {
return nil, E.New("user strategy not found: ", user)
}
return strategy.wrapPacketConn(ctx, conn, metadata, reverse)
}
func (s *ManagerBandwidthStrategy) UpdateStrategies(strategies map[string]BandwidthStrategy) {
s.mtx.Lock()
defer s.mtx.Unlock()
var closedEntries []*bwConnEntry
for user, entries := range s.conns {
if _, exists := strategies[user]; !exists {
closedEntries = append(closedEntries, entries...)
delete(s.conns, user)
}
}
s.strategies = strategies
s.mtx.Unlock()
for _, entry := range closedEntries {
entry.conn.Close()
}
}
type BypassBandwidthStrategy struct{}
@@ -263,8 +352,8 @@ func CreateStrategy(strategy string, mode string, connectionType string, speed u
return NewDefaultWrapStrategy(limiterStrategy, connWrapper, packetConnWrapper), nil
}
func createSpeedLimiter(speed uint64, flowKeys []string) (Limiter, error) {
var limiter Limiter = rate.NewLimiter(rate.Limit(float64(speed)), 65536)
func createSpeedLimiter(speed uint64, flowKeys []string) (BandwidthLimiter, error) {
var limiter BandwidthLimiter = &speedLimiter{limiter: rate.NewLimiter(rate.Limit(float64(speed)), 65536)}
for i := len(flowKeys) - 1; i >= 0; i-- {
getter, err := flowKeysConnIDGetter(flowKeys[i])
if err != nil {
@@ -275,16 +364,24 @@ func createSpeedLimiter(speed uint64, flowKeys []string) (Limiter, error) {
return limiter, nil
}
type speedLimiter struct {
limiter *rate.Limiter
}
func (r *speedLimiter) WaitN(ctx context.Context, n int) error {
return r.limiter.WaitN(ctx, n)
}
func (r *speedLimiter) SetSpeed(speed uint64) {
r.limiter.SetLimit(rate.Limit(float64(speed)))
}
func flowKeysConnIDGetter(name string) (ConnIDGetter, error) {
switch name {
case "user":
return func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
return metadata.User, true
}, nil
case "destination":
return func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
return metadata.Destination.String(), true
}, nil
case "source_ip":
return func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
return metadata.Source.IPAddr().String(), true
@@ -302,6 +399,14 @@ func flowKeysConnIDGetter(name string) (ConnIDGetter, error) {
}
return strconv.FormatUint(uint64(id.ID), 10), ok
}, nil
case "protocol":
return func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
return metadata.Protocol, metadata.Protocol != ""
}, nil
case "destination":
return func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
return metadata.Destination.String(), true
}, nil
default:
return nil, E.New("flow key not found: ", name)
}

View File

@@ -4,13 +4,14 @@ import (
"context"
"sync"
"github.com/sagernet/sing-box/common/onclose"
E "github.com/sagernet/sing/common/exceptions"
)
func NewDefaultLock(max uint32) LockIDGetter {
locks := make(map[string]*uint32)
mtx := sync.Mutex{}
return func(id string) (CloseHandlerFunc, context.Context, error) {
return func(id string) (onclose.CloseHandlerFunc, context.Context, error) {
mtx.Lock()
defer mtx.Unlock()
handles, ok := locks[id]
@@ -22,16 +23,13 @@ func NewDefaultLock(max uint32) LockIDGetter {
locks[id] = handles
}
*handles++
var once sync.Once
return func() {
once.Do(func() {
mtx.Lock()
defer mtx.Unlock()
*handles--
if *handles == 0 {
delete(locks, id)
}
})
mtx.Lock()
defer mtx.Unlock()
*handles--
if *handles == 0 {
delete(locks, id)
}
}, nil, nil
}
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/common/onclose"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/route"
@@ -110,7 +111,7 @@ func (h *Outbound) DialContext(ctx context.Context, network string, destination
onClose()
return nil, err
}
conn = newConnWithCloseHandlerFunc(conn, onClose)
conn = onclose.NewConn(conn, onClose)
if lockCtx != nil {
go connChecker(lockCtx, conn.Close)
}
@@ -127,7 +128,7 @@ func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
onClose()
return nil, err
}
conn = newPacketConnWithCloseHandlerFunc(conn, onClose)
conn = onclose.NewPacketConn(conn, onClose)
if lockCtx != nil {
go connChecker(lockCtx, conn.Close)
}
@@ -141,7 +142,7 @@ func (h *Outbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata
N.CloseOnHandshakeFailure(conn, onClose, err)
return
}
conn = newConnWithCloseHandlerFunc(conn, limiterOnClose)
conn = onclose.NewConn(conn, limiterOnClose)
if lockCtx != nil {
go connChecker(lockCtx, conn.Close)
}
@@ -158,7 +159,7 @@ func (h *Outbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
N.CloseOnHandshakeFailure(conn, onClose, err)
return
}
conn = bufio.NewPacketConn(newPacketConnWithCloseHandlerFunc(bufio.NewNetPacketConn(conn), limiterOnClose))
conn = bufio.NewPacketConn(onclose.NewPacketConn(bufio.NewNetPacketConn(conn), limiterOnClose))
if lockCtx != nil {
go connChecker(lockCtx, conn.Close)
}
@@ -172,33 +173,7 @@ func (h *Outbound) GetStrategy() ConnectionStrategy {
return h.strategy
}
type connWithCloseHandlerFunc struct {
net.Conn
onClose CloseHandlerFunc
}
func newConnWithCloseHandlerFunc(conn net.Conn, onClose CloseHandlerFunc) *connWithCloseHandlerFunc {
return &connWithCloseHandlerFunc{conn, onClose}
}
func (conn *connWithCloseHandlerFunc) Close() error {
conn.onClose()
return conn.Conn.Close()
}
type packetConnWithCloseHandlerFunc struct {
net.PacketConn
onClose CloseHandlerFunc
}
func newPacketConnWithCloseHandlerFunc(conn net.PacketConn, onClose CloseHandlerFunc) *packetConnWithCloseHandlerFunc {
return &packetConnWithCloseHandlerFunc{conn, onClose}
}
func (conn *packetConnWithCloseHandlerFunc) Close() error {
conn.onClose()
return conn.PacketConn.Close()
}
func connChecker(ctx context.Context, closeFunc func() error) {
<-ctx.Done()

View File

@@ -6,18 +6,17 @@ import (
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/onclose"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
)
type (
CloseHandlerFunc = func()
ConnIDGetter = func(context.Context, *adapter.InboundContext) (string, bool)
LockIDGetter = func(string) (CloseHandlerFunc, context.Context, error)
LockIDGetter = func(string) (onclose.CloseHandlerFunc, context.Context, error)
ConnectionStrategy interface {
request(ctx context.Context, metadata *adapter.InboundContext) (onClose CloseHandlerFunc, lockCtx context.Context, err error)
request(ctx context.Context, metadata *adapter.InboundContext) (onClose onclose.CloseHandlerFunc, lockCtx context.Context, err error)
}
)
@@ -36,7 +35,7 @@ func NewDefaultConnectionStrategy(connIDGetter ConnIDGetter, lockIDGetter LockID
return outbound
}
func (s *DefaultConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (CloseHandlerFunc, context.Context, error) {
func (s *DefaultConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
s.mtx.Lock()
defer s.mtx.Unlock()
id, ok := s.connIDGetter(ctx, metadata)
@@ -57,7 +56,7 @@ func NewUsersConnectionStrategy(strategies map[string]ConnectionStrategy) *Users
}
}
func (s *UsersConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (CloseHandlerFunc, context.Context, error) {
func (s *UsersConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
s.mtx.Lock()
defer s.mtx.Unlock()
var user string
@@ -71,20 +70,78 @@ func (s *UsersConnectionStrategy) request(ctx context.Context, metadata *adapter
return nil, nil, E.New("user strategy not found: ", user)
}
type cancelEntry struct {
cancel context.CancelFunc
}
type ManagerConnectionStrategy struct {
*UsersConnectionStrategy
strategies map[string]ConnectionStrategy
cancels map[string][]*cancelEntry
mtx sync.Mutex
}
func NewManagerConnectionStrategy() *ManagerConnectionStrategy {
return &ManagerConnectionStrategy{
UsersConnectionStrategy: NewUsersConnectionStrategy(map[string]ConnectionStrategy{}),
strategies: make(map[string]ConnectionStrategy),
cancels: make(map[string][]*cancelEntry),
}
}
func (s *ManagerConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
s.mtx.Lock()
var user string
if metadata != nil {
user = metadata.User
}
strategy, ok := s.strategies[user]
if !ok {
s.mtx.Unlock()
return nil, nil, E.New("user strategy not found: ", user)
}
s.mtx.Unlock()
onClose, _, err := strategy.request(ctx, metadata)
if err != nil {
return nil, nil, err
}
cancelCtx, cancel := context.WithCancel(context.Background())
entry := &cancelEntry{cancel: cancel}
s.mtx.Lock()
s.cancels[user] = append(s.cancels[user], entry)
s.mtx.Unlock()
originalOnClose := onClose
wrappedOnClose := func() {
s.mtx.Lock()
entries := s.cancels[user]
for i, e := range entries {
if e == entry {
s.cancels[user] = append(entries[:i], entries[i+1:]...)
break
}
}
s.mtx.Unlock()
cancel()
if originalOnClose != nil {
originalOnClose()
}
}
return wrappedOnClose, cancelCtx, nil
}
func (s *ManagerConnectionStrategy) UpdateStrategies(strategies map[string]ConnectionStrategy) {
s.mtx.Lock()
defer s.mtx.Unlock()
var entries []*cancelEntry
for user, cancels := range s.cancels {
if _, exists := strategies[user]; !exists {
entries = append(entries, cancels...)
delete(s.cancels, user)
}
}
s.strategies = strategies
s.mtx.Unlock()
for _, entry := range entries {
entry.cancel()
}
}
type BypassConnectionStrategy struct{}
@@ -93,7 +150,7 @@ func NewBypassConnectionStrategy() *BypassConnectionStrategy {
return &BypassConnectionStrategy{}
}
func (s *BypassConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (CloseHandlerFunc, context.Context, error) {
func (s *BypassConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
return func() {}, nil, nil
}

View File

@@ -6,11 +6,11 @@ import (
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/onclose"
E "github.com/sagernet/sing/common/exceptions"
)
type (
CloseHandlerFunc = func()
ConnWrapper = func(ctx context.Context, conn net.Conn, limiter TrafficLimiter, reverse bool) net.Conn
PacketConnWrapper = func(ctx context.Context, conn net.PacketConn, limiter TrafficLimiter, reverse bool) net.PacketConn
)
@@ -72,32 +72,60 @@ func (s *GlobalTrafficStrategy) getLimiter(ctx context.Context, metadata *adapte
return s.limiter, nil
}
type connEntry struct {
conn net.Conn
}
type ManagerTrafficStrategy struct {
strategies map[string]TrafficStrategy
mtx sync.Mutex
conns map[string][]*connEntry
mtx sync.Mutex
}
func NewManagerTrafficStrategy() *ManagerTrafficStrategy {
return &ManagerTrafficStrategy{}
return &ManagerTrafficStrategy{
conns: make(map[string][]*connEntry),
}
}
func (s *ManagerTrafficStrategy) wrapConn(ctx context.Context, conn net.Conn, metadata *adapter.InboundContext, reverse bool) (net.Conn, error) {
strategy, err := s.getStrategy(ctx, metadata)
strategy, user, err := s.getStrategy(ctx, metadata)
if err != nil {
return nil, err
}
return strategy.wrapConn(ctx, conn, metadata, reverse)
wrapped, err := strategy.wrapConn(ctx, conn, metadata, reverse)
if err != nil {
return nil, err
}
entry := &connEntry{conn: conn}
s.mtx.Lock()
s.conns[user] = append(s.conns[user], entry)
s.mtx.Unlock()
return onclose.NewConn(wrapped, func() {
s.mtx.Lock()
entries := s.conns[user]
for i, e := range entries {
if e == entry {
s.conns[user] = append(entries[:i], entries[i+1:]...)
break
}
}
s.mtx.Unlock()
}), nil
}
func (s *ManagerTrafficStrategy) wrapPacketConn(ctx context.Context, conn net.PacketConn, metadata *adapter.InboundContext, reverse bool) (net.PacketConn, error) {
strategy, err := s.getStrategy(ctx, metadata)
strategy, _, err := s.getStrategy(ctx, metadata)
if err != nil {
return nil, err
}
return strategy.wrapPacketConn(ctx, conn, metadata, reverse)
}
func (s *ManagerTrafficStrategy) getStrategy(ctx context.Context, metadata *adapter.InboundContext) (TrafficStrategy, error) {
func (s *ManagerTrafficStrategy) getStrategy(ctx context.Context, metadata *adapter.InboundContext) (TrafficStrategy, string, error) {
s.mtx.Lock()
defer s.mtx.Unlock()
var user string
@@ -106,15 +134,25 @@ func (s *ManagerTrafficStrategy) getStrategy(ctx context.Context, metadata *adap
}
strategy, ok := s.strategies[user]
if ok {
return strategy, nil
return strategy, user, nil
}
return nil, E.New("user strategy not found: ", user)
return nil, user, E.New("user strategy not found: ", user)
}
func (s *ManagerTrafficStrategy) UpdateStrategies(strategies map[string]TrafficStrategy) {
s.mtx.Lock()
defer s.mtx.Unlock()
var closedEntries []*connEntry
for user, entries := range s.conns {
if _, exists := strategies[user]; !exists {
closedEntries = append(closedEntries, entries...)
delete(s.conns, user)
}
}
s.strategies = strategies
s.mtx.Unlock()
for _, entry := range closedEntries {
entry.conn.Close()
}
}
type BypassTrafficStrategy struct{}

View File

@@ -9,6 +9,7 @@ import (
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/cloudflare"
"github.com/sagernet/sing-box/common/dialer"
@@ -99,7 +100,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
logger.ErrorContext(ctx, E.New("failed to generate cert: ", err))
return
}
tlsConfig, err := tls.NewMASQUEClient(ctx, logger, "consumer-masque.cloudflareclient.com", cert, privKey, peerPubKey, options.MASQUEOutboundTLSOptions)
tlsConfig, err := tls.NewMASQUEClient(ctx, logger, "consumer-masque.cloudflareclient.com", cert, privKey, peerPubKey, common.PtrValueOrDefault(options.TLS))
if err != nil {
logger.ErrorContext(ctx, E.New("failed to prepare TLS config: ", err))
return

View File

@@ -98,6 +98,10 @@ func (h *Inbound) Close() error {
)
}
func (h *Inbound) UpdateUsers(users []auth.User) {
h.authenticator.UpdateUsers(users)
}
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
err := h.newConnection(ctx, conn, metadata, onClose)
N.CloseOnHandshakeFailure(conn, onClose, err)

View File

@@ -147,6 +147,10 @@ func (n *Inbound) Close() error {
)
}
func (n *Inbound) UpdateUsers(users []auth.User) {
n.authenticator.UpdateUsers(users)
}
func (n *Inbound) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
ctx := log.ContextWithNewID(request.Context())
if request.Method != "CONNECT" {

View File

@@ -0,0 +1,158 @@
//go:build with_openvpn
package openvpn
import (
"context"
"os"
"time"
"net"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
ovpn "github.com/sagernet/sing-box/transport/openvpn"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.OpenVPNOutboundOptions](registry, C.TypeOpenVPN, NewOutbound)
}
type Outbound struct {
outbound.Adapter
ctx context.Context
dnsRouter adapter.DNSRouter
logger logger.ContextLogger
tunnel *ovpn.Tunnel
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.OpenVPNOutboundOptions) (adapter.Outbound, error) {
tlsConfig, err := tls.NewOpenVPNClient(ctx, logger, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
var tlsKey []byte
keyDirection := -1
if options.TLSAuth != "" || options.TLSAuthPath != "" {
tlsAuth := options.TLSAuth
if tlsAuth == "" {
data, err := os.ReadFile(options.TLSAuthPath)
if err != nil {
return nil, E.Cause(err, "read tls_auth_path")
}
tlsAuth = string(data)
}
tlsKey = []byte(tlsAuth)
keyDirection = options.KeyDirection
} else {
tlsCrypt := options.TLSCrypt
if tlsCrypt == "" && options.TLSCryptPath != "" {
data, err := os.ReadFile(options.TLSCryptPath)
if err != nil {
return nil, E.Cause(err, "read tls_crypt_path")
}
tlsCrypt = string(data)
}
tlsKey = []byte(tlsCrypt)
}
clientConfig := &ovpn.ClientConfig{
Proto: options.Proto,
Cipher: options.Cipher,
Auth: options.Auth,
Username: options.Username,
Password: options.Password,
KeyDirection: keyDirection,
TLSCrypt: tlsKey,
TLSCryptV2: options.TLSCryptV2,
}
if err := clientConfig.Prepare(); err != nil {
return nil, E.Cause(err, "invalid openvpn config")
}
outboundDialer, err := dialer.New(ctx, options.DialerOptions, true)
if err != nil {
return nil, err
}
o := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeOpenVPN, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
ctx: ctx,
dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
logger: logger,
}
tunnel, err := ovpn.NewTunnel(ctx, logger, ovpn.TunnelOptions{
Dialer: outboundDialer,
Servers: options.Servers,
TLSConfig: tlsConfig,
Config: clientConfig,
ReconnectDelay: time.Duration(options.ReconnectDelay),
PingInterval: time.Duration(options.PingInterval),
})
if err != nil {
return nil, err
}
o.tunnel = tunnel
return o, nil
}
func (o *Outbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStatePostStart {
return nil
}
return o.tunnel.Start()
}
func (o *Outbound) Close() error {
if o.tunnel != nil {
return o.tunnel.Close()
}
return nil
}
func (o *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch network {
case N.NetworkTCP:
o.logger.InfoContext(ctx, "outbound connection to ", destination)
case N.NetworkUDP:
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
if destination.IsDomain() {
destinationAddresses, err := o.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}
return N.DialSerial(ctx, o.tunnel, network, destination, destinationAddresses)
}
if !destination.Addr.IsValid() {
return nil, E.New("invalid destination: ", destination)
}
return o.tunnel.DialContext(ctx, network, destination)
}
func (o *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
if destination.IsDomain() {
destinationAddresses, err := o.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}
packetConn, destinationAddress, err := N.ListenSerial(ctx, o.tunnel, destination, destinationAddresses)
if err != nil {
return nil, err
}
if destinationAddress.IsValid() && destination != M.SocksaddrFrom(destinationAddress, destination.Port) {
return bufio.NewNATPacketConn(bufio.NewPacketConn(packetConn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
}
return packetConn, nil
}
return o.tunnel.ListenPacket(ctx, destination)
}

View File

@@ -70,6 +70,10 @@ func (h *Inbound) Close() error {
return h.listener.Close()
}
func (h *Inbound) UpdateUsers(users []auth.User) {
h.authenticator.UpdateUsers(users)
}
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
err := socks.HandleConnectionEx(ctx, conn, std_bufio.NewReader(conn), h.authenticator, adapter.NewUpstreamHandlerEx(metadata, h.newUserConnection, h.streamUserPacketConnection), h.listener, h.udpTimeout, metadata.Source, onClose)
N.CloseOnHandshakeFailure(conn, onClose, err)

178
protocol/sudoku/inbound.go Normal file
View File

@@ -0,0 +1,178 @@
package sudoku
import (
"context"
"net"
"strings"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/common/listener"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/sudoku"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func RegisterInbound(registry *inbound.Registry) {
inbound.Register[option.SudokuInboundOptions](registry, C.TypeSudoku, NewInbound)
}
type Inbound struct {
inbound.Adapter
router adapter.ConnectionRouterEx
logger logger.ContextLogger
listener *listener.Listener
protoConf sudoku.ProtocolConfig
tunnelSrv *sudoku.HTTPMaskTunnelServer
fallback string
}
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SudokuInboundOptions) (adapter.Inbound, error) {
defaultConf := sudoku.DefaultConfig()
tableType, err := sudoku.NormalizeTableType(options.TableType)
if err != nil {
return nil, err
}
paddingMin, paddingMax := sudoku.ResolvePadding(options.PaddingMin, options.PaddingMax, defaultConf.PaddingMin, defaultConf.PaddingMax)
enablePureDownlink := sudoku.DerefBool(options.EnablePureDownlink, defaultConf.EnablePureDownlink)
handshakeTimeout := sudoku.DerefInt(options.HandshakeTimeout, defaultConf.HandshakeTimeoutSeconds)
tables, err := sudoku.NewServerTablesWithCustomPatterns(sudoku.ServerAEADSeed(options.Key), tableType, options.CustomTable, options.CustomTables)
if err != nil {
return nil, err
}
protoConf := sudoku.ProtocolConfig{
Key: options.Key,
AEADMethod: defaultConf.AEADMethod,
PaddingMin: paddingMin,
PaddingMax: paddingMax,
EnablePureDownlink: enablePureDownlink,
HandshakeTimeoutSeconds: handshakeTimeout,
DisableHTTPMask: options.DisableHTTPMask,
HTTPMaskMode: options.HTTPMaskMode,
HTTPMaskPathRoot: strings.TrimSpace(options.PathRoot),
}
if len(tables) == 1 {
protoConf.Table = tables[0]
} else {
protoConf.Tables = tables
}
if options.AEADMethod != "" {
protoConf.AEADMethod = options.AEADMethod
}
in := &Inbound{
Adapter: inbound.NewAdapter(C.TypeSudoku, tag),
router: router,
logger: logger,
protoConf: protoConf,
fallback: strings.TrimSpace(options.Fallback),
}
if in.fallback != "" {
in.tunnelSrv = sudoku.NewHTTPMaskTunnelServerWithFallback(&in.protoConf)
} else {
in.tunnelSrv = sudoku.NewHTTPMaskTunnelServer(&in.protoConf)
}
in.listener = listener.New(listener.Options{
Context: ctx,
Logger: logger,
Network: []string{N.NetworkTCP},
Listen: options.ListenOptions,
ConnectionHandler: in,
})
return in, nil
}
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
return h.listener.Start()
}
func (h *Inbound) Close() error {
return h.listener.Close()
}
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
h.handleConn(ctx, conn, metadata, onClose)
}
func (h *Inbound) handleConn(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
handshakeConn := conn
handshakeCfg := &h.protoConf
if h.tunnelSrv != nil {
c, cfg, done, err := h.tunnelSrv.WrapConn(conn)
if err != nil {
conn.Close()
return
}
if done {
return
}
if c != nil {
handshakeConn = c
}
if cfg != nil {
handshakeCfg = cfg
}
}
cConn, meta, err := sudoku.ServerHandshake(handshakeConn, handshakeCfg)
if err != nil {
h.logger.DebugContext(ctx, "handshake failed: ", err)
conn.Close()
if handshakeConn != conn {
handshakeConn.Close()
}
return
}
session, err := sudoku.ReadServerSession(cConn, meta)
if err != nil {
h.logger.WarnContext(ctx, "read session failed: ", err)
cConn.Close()
return
}
switch session.Type {
case sudoku.SessionTypeUoT:
h.handleUoT(ctx, session.Conn, metadata, onClose)
case sudoku.SessionTypeMultiplex:
mux, err := sudoku.AcceptMultiplexServer(session.Conn)
if err != nil {
session.Conn.Close()
return
}
defer mux.Close()
for {
stream, target, err := mux.AcceptTCP()
if err != nil {
return
}
go h.routeTCP(ctx, stream, target, metadata)
}
default:
h.routeTCP(ctx, session.Conn, session.Target, metadata)
}
}
func (h *Inbound) routeTCP(ctx context.Context, conn net.Conn, target string, metadata adapter.InboundContext) {
destination := M.ParseSocksaddr(target)
metadata.Destination = destination
h.logger.InfoContext(ctx, "inbound connection to ", destination)
h.router.RouteConnectionEx(ctx, conn, metadata, nil)
}
func (h *Inbound) handleUoT(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
packetConn := sudoku.NewUoTPacketConn(conn)
h.router.RoutePacketConnectionEx(ctx, bufio.NewPacketConn(packetConn), metadata, onClose)
}

401
protocol/sudoku/outbound.go Normal file
View File

@@ -0,0 +1,401 @@
package sudoku
import (
"context"
"fmt"
"net"
"strings"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/sudoku"
"github.com/sagernet/sing-box/transport/sudoku/obfs/httpmask"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.SudokuOutboundOptions](registry, C.TypeSudoku, NewOutbound)
}
type Outbound struct {
outbound.Adapter
logger logger.ContextLogger
dialer N.Dialer
tlsConfig tls.Config
baseConf sudoku.ProtocolConfig
muxMu sync.Mutex
muxClient *sudoku.MultiplexClient
httpMaskMu sync.Mutex
httpMaskClient *httpmask.TunnelClient
httpMaskKey string
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SudokuOutboundOptions) (adapter.Outbound, error) {
outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain())
if err != nil {
return nil, err
}
defaultConf := sudoku.DefaultConfig()
tableType, err := sudoku.NormalizeTableType(options.TableType)
if err != nil {
return nil, err
}
paddingMin, paddingMax := sudoku.ResolvePadding(options.PaddingMin, options.PaddingMax, defaultConf.PaddingMin, defaultConf.PaddingMax)
enablePureDownlink := sudoku.DerefBool(options.EnablePureDownlink, defaultConf.EnablePureDownlink)
serverAddr := options.ServerOptions.Build()
disableHTTPMask := defaultConf.DisableHTTPMask
httpMaskMode := defaultConf.HTTPMaskMode
var httpMaskHost string
var pathRoot string
httpMaskMultiplex := defaultConf.HTTPMaskMultiplex
if hm := options.HTTPMask; hm != nil {
disableHTTPMask = !hm.Enabled
if hm.Mode != "" {
httpMaskMode = hm.Mode
}
httpMaskHost = hm.Host
pathRoot = strings.TrimSpace(hm.PathRoot)
if hm.Multiplex != "" {
httpMaskMultiplex = hm.Multiplex
}
}
baseConf := sudoku.ProtocolConfig{
ServerAddress: serverAddr.String(),
Key: options.Key,
AEADMethod: defaultConf.AEADMethod,
PaddingMin: paddingMin,
PaddingMax: paddingMax,
EnablePureDownlink: enablePureDownlink,
HandshakeTimeoutSeconds: defaultConf.HandshakeTimeoutSeconds,
DisableHTTPMask: disableHTTPMask,
HTTPMaskMode: httpMaskMode,
HTTPMaskHost: httpMaskHost,
HTTPMaskPathRoot: pathRoot,
HTTPMaskMultiplex: httpMaskMultiplex,
}
if options.AEADMethod != "" {
baseConf.AEADMethod = options.AEADMethod
}
tables, err := sudoku.NewClientTablesWithCustomPatterns(sudoku.ClientAEADSeed(options.Key), tableType, options.CustomTable, options.CustomTables)
if err != nil {
return nil, E.Cause(err, "build table(s)")
}
if len(tables) == 1 {
baseConf.Table = tables[0]
} else {
baseConf.Tables = tables
}
out := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSudoku, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
logger: logger,
dialer: outboundDialer,
baseConf: baseConf,
}
if hm := options.HTTPMask; !disableHTTPMask && hm != nil && hm.TLS != nil && hm.TLS.Enabled {
tlsOptions := option.OutboundTLSOptions{
Enabled: true,
ServerName: options.Server,
Fragment: hm.TLS.Fragment,
FragmentFallbackDelay: hm.TLS.FragmentFallbackDelay,
RecordFragment: hm.TLS.RecordFragment,
KernelTx: hm.TLS.KernelTx,
KernelRx: hm.TLS.KernelRx,
}
out.tlsConfig, err = tls.NewClientWithOptions(tls.ClientOptions{
Context: ctx,
Logger: logger,
ServerAddress: options.Server,
Options: tlsOptions,
})
if err != nil {
return nil, err
}
}
return out, nil
}
func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
h.logger.InfoContext(ctx, "outbound connection to ", destination)
case N.NetworkUDP:
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Outbound = h.Tag()
metadata.Destination = destination
cfg := h.baseConf
cfg.TargetAddress = destination.String()
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
if muxMode == "on" && !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
stream, err := h.dialMultiplex(ctx, cfg.TargetAddress)
if err == nil {
return stream, nil
}
return nil, err
}
c, err := h.dialAndHandshake(ctx, &cfg)
if err != nil {
return nil, err
}
addrBuf, err := sudoku.EncodeAddress(cfg.TargetAddress)
if err != nil {
c.Close()
return nil, E.Cause(err, "encode target address")
}
if err = sudoku.WriteKIPMessage(c, sudoku.KIPTypeOpenTCP, addrBuf); err != nil {
c.Close()
return nil, E.Cause(err, "send target address")
}
return c, nil
}
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Outbound = h.Tag()
metadata.Destination = destination
cfg := h.baseConf
cfg.TargetAddress = destination.String()
c, err := h.dialAndHandshake(ctx, &cfg)
if err != nil {
return nil, err
}
if err = sudoku.WriteKIPMessage(c, sudoku.KIPTypeStartUoT, nil); err != nil {
c.Close()
return nil, E.Cause(err, "start uot")
}
return bufio.NewBindPacketConn(sudoku.NewUoTPacketConn(c), destination), nil
}
func (h *Outbound) Close() error {
h.resetMuxClient()
h.resetHTTPMaskClient()
return common.Close(h.tlsConfig)
}
func (h *Outbound) InterfaceUpdated() {
h.resetMuxClient()
h.resetHTTPMaskClient()
}
func (h *Outbound) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfig) (net.Conn, error) {
handshakeCfg := *cfg
if !handshakeCfg.DisableHTTPMask && httpTunnelModeEnabled(handshakeCfg.HTTPMaskMode) {
handshakeCfg.DisableHTTPMask = true
}
upgrade := func(raw net.Conn) (net.Conn, error) {
return sudoku.ClientHandshake(raw, &handshakeCfg)
}
var c net.Conn
var err error
var handshakeDone bool
if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
if muxMode == "auto" && strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) != "ws" {
if client, cerr := h.getOrCreateHTTPMaskClient(cfg); cerr == nil && client != nil {
c, err = client.DialTunnel(ctx, httpmask.TunnelDialOptions{
Mode: cfg.HTTPMaskMode,
TLSConfig: h.httpMaskTLSConfig(),
HostOverride: cfg.HTTPMaskHost,
PathRoot: cfg.HTTPMaskPathRoot,
AuthKey: sudoku.ClientAEADSeed(cfg.Key),
Upgrade: upgrade,
Multiplex: cfg.HTTPMaskMultiplex,
DialContext: h.dialRaw,
})
if err != nil {
h.resetHTTPMaskClient()
}
}
}
if c == nil && err == nil {
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, h.dialRaw, upgrade)
}
if err == nil && c != nil {
handshakeDone = true
}
}
if c == nil && err == nil {
c, err = h.dialer.DialContext(ctx, N.NetworkTCP, M.ParseSocksaddr(cfg.ServerAddress))
}
if err != nil {
return nil, E.Cause(err, "connect to ", cfg.ServerAddress)
}
if !handshakeDone {
c, err = sudoku.ClientHandshake(c, &handshakeCfg)
if err != nil {
common.Close(c)
return nil, err
}
}
return c, nil
}
func (h *Outbound) dialRaw(ctx context.Context, network, addr string) (net.Conn, error) {
return h.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
}
func (h *Outbound) httpMaskTLSConfig() httpmask.TLSClientConfig {
if h.tlsConfig == nil {
return nil
}
return tlsConfigAdapter{h.tlsConfig}
}
type tlsConfigAdapter struct {
config tls.Config
}
func (a tlsConfigAdapter) Client(conn net.Conn) (net.Conn, error) {
return a.config.Client(conn)
}
func (h *Outbound) dialMultiplex(ctx context.Context, targetAddress string) (net.Conn, error) {
for attempt := 0; attempt < 2; attempt++ {
client, err := h.getOrCreateMuxClient(ctx)
if err != nil {
return nil, err
}
stream, err := client.Dial(ctx, targetAddress)
if err != nil {
h.resetMuxClient()
continue
}
return stream, nil
}
return nil, fmt.Errorf("multiplex open stream failed")
}
func (h *Outbound) getOrCreateMuxClient(ctx context.Context) (*sudoku.MultiplexClient, error) {
h.muxMu.Lock()
defer h.muxMu.Unlock()
if h.muxClient != nil && !h.muxClient.IsClosed() {
return h.muxClient, nil
}
baseCfg := h.baseConf
baseConn, err := h.dialAndHandshake(ctx, &baseCfg)
if err != nil {
return nil, err
}
client, err := sudoku.StartMultiplexClient(baseConn)
if err != nil {
baseConn.Close()
return nil, err
}
h.muxClient = client
return client, nil
}
func (h *Outbound) resetMuxClient() {
h.muxMu.Lock()
defer h.muxMu.Unlock()
if h.muxClient != nil {
h.muxClient.Close()
h.muxClient = nil
}
}
func (h *Outbound) getOrCreateHTTPMaskClient(cfg *sudoku.ProtocolConfig) (*httpmask.TunnelClient, error) {
key := cfg.ServerAddress + "|" + fmt.Sprint(h.tlsConfig != nil) + "|" + strings.TrimSpace(cfg.HTTPMaskHost)
h.httpMaskMu.Lock()
if h.httpMaskClient != nil && h.httpMaskKey == key {
client := h.httpMaskClient
h.httpMaskMu.Unlock()
return client, nil
}
h.httpMaskMu.Unlock()
client, err := httpmask.NewTunnelClient(cfg.ServerAddress, httpmask.TunnelClientOptions{
TLSConfig: h.httpMaskTLSConfig(),
HostOverride: cfg.HTTPMaskHost,
DialContext: h.dialRaw,
MaxIdleConns: 32,
})
if err != nil {
return nil, err
}
h.httpMaskMu.Lock()
defer h.httpMaskMu.Unlock()
if h.httpMaskClient != nil && h.httpMaskKey == key {
client.CloseIdleConnections()
return h.httpMaskClient, nil
}
if h.httpMaskClient != nil {
h.httpMaskClient.CloseIdleConnections()
}
h.httpMaskClient = client
h.httpMaskKey = key
return client, nil
}
func (h *Outbound) resetHTTPMaskClient() {
h.httpMaskMu.Lock()
defer h.httpMaskMu.Unlock()
if h.httpMaskClient != nil {
h.httpMaskClient.CloseIdleConnections()
h.httpMaskClient = nil
h.httpMaskKey = ""
}
}
func normalizeHTTPMaskMultiplex(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case "", "off":
return "off"
case "auto":
return "auto"
case "on":
return "on"
default:
return "off"
}
}
func httpTunnelModeEnabled(mode string) bool {
switch strings.ToLower(strings.TrimSpace(mode)) {
case "stream", "poll", "auto", "ws":
return true
default:
return false
}
}

View File

@@ -0,0 +1,190 @@
package trusttunnel
import (
"context"
"errors"
"net"
"net/http"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/common/listener"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/trusttunnel"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/auth"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
func RegisterInbound(registry *inbound.Registry) {
inbound.Register[option.TrustTunnelInboundOptions](registry, C.TypeTrustTunnel, NewInbound)
}
type Inbound struct {
inbound.Adapter
ctx context.Context
router adapter.Router
logger logger.ContextLogger
listener *listener.Listener
tlsConfig tls.ServerConfig
service *trusttunnel.Service
httpServer *http.Server
quicService *trusttunnel.QUICService
network []string
}
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TrustTunnelInboundOptions) (adapter.Inbound, error) {
if options.TLS == nil || !options.TLS.Enabled {
return nil, C.ErrTLSRequired
}
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
networkList := options.Network.Build()
if len(networkList) == 0 {
networkList = []string{N.NetworkTCP}
}
inbound := &Inbound{
Adapter: inbound.NewAdapter(C.TypeTrustTunnel, tag),
ctx: ctx,
router: router,
logger: logger,
tlsConfig: tlsConfig,
network: networkList,
listener: listener.New(listener.Options{
Context: ctx,
Logger: logger,
Listen: options.ListenOptions,
}),
}
service := trusttunnel.NewService(trusttunnel.ServiceOptions{
Ctx: ctx,
Logger: logger,
Handler: (*inboundHandler)(inbound),
})
userMap := make(map[string]string, len(options.Users))
for _, u := range options.Users {
userMap[u.Name] = u.Password
}
service.UpdateUsers(userMap)
inbound.service = service
if common.Contains(networkList, N.NetworkUDP) {
inbound.quicService = trusttunnel.NewQUICService(service, options.CongestionController, options.CWND, options.BBRProfile)
}
return inbound, nil
}
func (h *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
if h.tlsConfig != nil {
err := h.tlsConfig.Start()
if err != nil {
return err
}
}
if common.Contains(h.network, N.NetworkTCP) {
tcpListener, err := h.listener.ListenTCP()
if err != nil {
return err
}
h.httpServer = &http.Server{
Handler: h2c.NewHandler(h.service, &http2.Server{}),
BaseContext: func(net.Listener) context.Context {
return h.ctx
},
}
go func() {
var l net.Listener = tcpListener
if h.tlsConfig != nil {
if len(h.tlsConfig.NextProtos()) == 0 {
h.tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
} else if !common.Contains(h.tlsConfig.NextProtos(), http2.NextProtoTLS) {
h.tlsConfig.SetNextProtos(append([]string{http2.NextProtoTLS}, h.tlsConfig.NextProtos()...))
}
l = aTLS.NewListener(tcpListener, h.tlsConfig)
}
sErr := h.httpServer.Serve(l)
if sErr != nil && !errors.Is(sErr, http.ErrServerClosed) {
h.logger.Error("HTTP server error: ", sErr)
}
}()
}
if common.Contains(h.network, N.NetworkUDP) {
udpConn, err := h.listener.ListenUDP()
if err != nil {
return err
}
err = h.quicService.Start(h.ctx, udpConn, h.tlsConfig)
if err != nil {
return err
}
}
return nil
}
func (h *Inbound) Close() error {
return common.Close(
h.listener,
common.PtrOrNil(h.httpServer),
h.quicService,
h.tlsConfig,
)
}
func (h *Inbound) UpdateUsers(users []option.TrustTunnelUser) {
userMap := make(map[string]string, len(users))
for _, u := range users {
userMap[u.Name] = u.Password
}
h.service.UpdateUsers(userMap)
}
type inboundHandler Inbound
func (h *inboundHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
var inboundCtx adapter.InboundContext
inboundCtx.Inbound = h.Tag()
inboundCtx.InboundType = h.Type()
//nolint:staticcheck
inboundCtx.InboundDetour = h.listener.ListenOptions().Detour
inboundCtx.Source = metadata.Source
inboundCtx.Destination = metadata.Destination
if userName, _ := auth.UserFromContext[string](ctx); userName != "" {
inboundCtx.User = userName
h.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", inboundCtx.Destination)
} else {
h.logger.InfoContext(ctx, "inbound connection to ", inboundCtx.Destination)
}
h.router.RouteConnectionEx(ctx, conn, inboundCtx, nil)
return nil
}
func (h *inboundHandler) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
var inboundCtx adapter.InboundContext
inboundCtx.Inbound = h.Tag()
inboundCtx.InboundType = h.Type()
//nolint:staticcheck
inboundCtx.InboundDetour = h.listener.ListenOptions().Detour
inboundCtx.Source = metadata.Source
inboundCtx.Destination = metadata.Destination
if userName, _ := auth.UserFromContext[string](ctx); userName != "" {
inboundCtx.User = userName
h.logger.InfoContext(ctx, "[", userName, "] inbound packet connection to ", inboundCtx.Destination)
} else {
h.logger.InfoContext(ctx, "inbound packet connection to ", inboundCtx.Destination)
}
h.router.RoutePacketConnectionEx(ctx, conn, inboundCtx, nil)
return nil
}

View File

@@ -0,0 +1,118 @@
package trusttunnel
import (
"context"
"net"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/trusttunnel"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"golang.org/x/net/http2"
)
func RegisterOutbound(registry *outbound.Registry) {
outbound.Register[option.TrustTunnelOutboundOptions](registry, C.TypeTrustTunnel, NewOutbound)
}
type Outbound struct {
outbound.Adapter
logger logger.ContextLogger
client trusttunnel.Dialer
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TrustTunnelOutboundOptions) (adapter.Outbound, error) {
if options.TLS == nil || !options.TLS.Enabled {
return nil, C.ErrTLSRequired
}
outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain())
if err != nil {
return nil, err
}
serverAddr := options.ServerOptions.Build()
networkList := options.Network.Build()
clientOpts := trusttunnel.ClientOptions{
Server: serverAddr,
Username: options.Username,
Password: options.Password,
QUIC: options.QUIC,
CongestionControl: options.CongestionController,
CWND: options.CWND,
BBRProfile: options.BBRProfile,
HealthCheck: options.HealthCheck,
}
if options.QUIC {
tlsConfig, err := tls.NewClient(ctx, logger, options.Server, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
if len(tlsConfig.NextProtos()) == 0 {
tlsConfig.SetNextProtos([]string{"h3"})
}
clientOpts.QUICDialer = outboundDialer
clientOpts.QUICTLSConfig = tlsConfig
} else {
tlsConfig, err := tls.NewClient(ctx, logger, options.Server, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
if len(tlsConfig.NextProtos()) == 0 {
tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
}
clientOpts.TLSDialer = tls.NewDialer(outboundDialer, tlsConfig)
}
var client trusttunnel.Dialer
if options.Multiplex != nil && options.Multiplex.Enabled {
clientOpts.MaxConnections = options.Multiplex.MaxConnections
clientOpts.MinStreams = options.Multiplex.MinStreams
clientOpts.MaxStreams = options.Multiplex.MaxStreams
client, err = trusttunnel.NewMultiplexClient(ctx, clientOpts)
} else {
client, err = trusttunnel.NewClient(ctx, clientOpts)
}
if err != nil {
return nil, err
}
return &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTrustTunnel, tag, networkList, options.DialerOptions),
logger: logger,
client: client,
}, nil
}
func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
switch N.NetworkName(network) {
case N.NetworkTCP:
h.logger.InfoContext(ctx, "outbound connection to ", destination)
return h.client.Dial(ctx, destination.String())
case N.NetworkUDP:
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
conn, err := h.client.ListenPacket(ctx)
if err != nil {
return nil, err
}
return bufio.NewBindPacketConn(conn, destination), nil
default:
return nil, E.New("unsupported network: ", network)
}
}
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
return h.client.ListenPacket(ctx)
}
func (h *Outbound) Close() error {
return h.client.Close()
}