mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-06-06 11:14:56 +03:00
Add OpenVPN, TrustTunnel, Sudoku, inbound managers. Fixes
This commit is contained in:
@@ -96,6 +96,12 @@ func (h *Inbound) Close() error {
|
||||
return common.Close(h.listener, h.tlsConfig)
|
||||
}
|
||||
|
||||
func (h *Inbound) UpdateUsers(users []option.AnyTLSUser) {
|
||||
h.service.UpdateUsers(common.Map(users, func(it option.AnyTLSUser) anytls.User {
|
||||
return anytls.User(it)
|
||||
}))
|
||||
}
|
||||
|
||||
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
if h.tlsConfig != nil {
|
||||
tlsConn, err := tls.ServerHandshake(ctx, conn, h.tlsConfig)
|
||||
|
||||
@@ -86,6 +86,10 @@ func (h *Inbound) Close() error {
|
||||
)
|
||||
}
|
||||
|
||||
func (h *Inbound) UpdateUsers(users []auth.User) {
|
||||
h.authenticator.UpdateUsers(users)
|
||||
}
|
||||
|
||||
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
if h.tlsConfig != nil {
|
||||
tlsConn, err := tls.ServerHandshake(ctx, conn, h.tlsConfig)
|
||||
|
||||
@@ -3,15 +3,17 @@ package bandwidth
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing-box/common/onclose"
|
||||
)
|
||||
|
||||
type connWithDownloadBandwidthLimiter struct {
|
||||
net.Conn
|
||||
ctx context.Context
|
||||
limiter Limiter
|
||||
limiter BandwidthLimiter
|
||||
}
|
||||
|
||||
func NewConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter Limiter) *connWithDownloadBandwidthLimiter {
|
||||
func NewConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter BandwidthLimiter) *connWithDownloadBandwidthLimiter {
|
||||
return &connWithDownloadBandwidthLimiter{conn, ctx, limiter}
|
||||
}
|
||||
|
||||
@@ -26,10 +28,10 @@ func (conn *connWithDownloadBandwidthLimiter) Write(p []byte) (n int, err error)
|
||||
type connWithUploadBandwidthLimiter struct {
|
||||
net.Conn
|
||||
ctx context.Context
|
||||
limiter Limiter
|
||||
limiter BandwidthLimiter
|
||||
}
|
||||
|
||||
func NewConnWithUploadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter Limiter) *connWithUploadBandwidthLimiter {
|
||||
func NewConnWithUploadBandwidthLimiter(ctx context.Context, conn net.Conn, limiter BandwidthLimiter) *connWithUploadBandwidthLimiter {
|
||||
return &connWithUploadBandwidthLimiter{conn, ctx, limiter}
|
||||
}
|
||||
|
||||
@@ -47,10 +49,10 @@ func (conn *connWithUploadBandwidthLimiter) Read(p []byte) (n int, err error) {
|
||||
|
||||
type connWithCloseHandler struct {
|
||||
net.Conn
|
||||
onClose CloseHandlerFunc
|
||||
onClose onclose.CloseHandlerFunc
|
||||
}
|
||||
|
||||
func NewConnWithCloseHandler(conn net.Conn, onClose CloseHandlerFunc) *connWithCloseHandler {
|
||||
func NewConnWithCloseHandler(conn net.Conn, onClose onclose.CloseHandlerFunc) *connWithCloseHandler {
|
||||
return &connWithCloseHandler{conn, onClose}
|
||||
}
|
||||
|
||||
@@ -62,10 +64,10 @@ func (conn *connWithCloseHandler) Close() error {
|
||||
type packetConnWithDownloadBandwidthLimiter struct {
|
||||
net.PacketConn
|
||||
ctx context.Context
|
||||
limiter Limiter
|
||||
limiter BandwidthLimiter
|
||||
}
|
||||
|
||||
func NewPacketConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter Limiter) *packetConnWithDownloadBandwidthLimiter {
|
||||
func NewPacketConnWithDownloadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter) *packetConnWithDownloadBandwidthLimiter {
|
||||
return &packetConnWithDownloadBandwidthLimiter{conn, ctx, limiter}
|
||||
}
|
||||
|
||||
@@ -80,10 +82,10 @@ func (conn *packetConnWithDownloadBandwidthLimiter) WriteTo(p []byte, addr net.A
|
||||
type packetConnWithUploadBandwidthLimiter struct {
|
||||
net.PacketConn
|
||||
ctx context.Context
|
||||
limiter Limiter
|
||||
limiter BandwidthLimiter
|
||||
}
|
||||
|
||||
func NewPacketConnWithUploadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter Limiter) *packetConnWithUploadBandwidthLimiter {
|
||||
func NewPacketConnWithUploadBandwidthLimiter(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter) *packetConnWithUploadBandwidthLimiter {
|
||||
return &packetConnWithUploadBandwidthLimiter{conn, ctx, limiter}
|
||||
}
|
||||
|
||||
@@ -101,10 +103,10 @@ func (conn *packetConnWithUploadBandwidthLimiter) ReadFrom(p []byte) (n int, add
|
||||
|
||||
type packetConnWithCloseHandler struct {
|
||||
net.PacketConn
|
||||
onClose CloseHandlerFunc
|
||||
onClose onclose.CloseHandlerFunc
|
||||
}
|
||||
|
||||
func NewPacketConnWithCloseHandler(conn net.PacketConn, onClose CloseHandlerFunc) *packetConnWithCloseHandler {
|
||||
func NewPacketConnWithCloseHandler(conn net.PacketConn, onClose onclose.CloseHandlerFunc) *packetConnWithCloseHandler {
|
||||
return &packetConnWithCloseHandler{conn, onClose}
|
||||
}
|
||||
|
||||
@@ -113,38 +115,38 @@ func (conn *packetConnWithCloseHandler) Close() error {
|
||||
return conn.PacketConn.Close()
|
||||
}
|
||||
|
||||
func connWithDownloadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter Limiter, reverse bool) net.Conn {
|
||||
func connWithDownloadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter BandwidthLimiter, reverse bool) net.Conn {
|
||||
if reverse {
|
||||
return NewConnWithUploadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
return NewConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
|
||||
func connWithUploadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter Limiter, reverse bool) net.Conn {
|
||||
func connWithUploadBandwidthWrapper(ctx context.Context, conn net.Conn, limiter BandwidthLimiter, reverse bool) net.Conn {
|
||||
if reverse {
|
||||
return NewConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
return NewConnWithUploadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
|
||||
func connWithBidirectionalBandwidthWrapper(ctx context.Context, conn net.Conn, limiter Limiter, reverse bool) net.Conn {
|
||||
func connWithBidirectionalBandwidthWrapper(ctx context.Context, conn net.Conn, limiter BandwidthLimiter, reverse bool) net.Conn {
|
||||
return NewConnWithUploadBandwidthLimiter(ctx, NewConnWithDownloadBandwidthLimiter(ctx, conn, limiter), limiter)
|
||||
}
|
||||
|
||||
func packetConnWithDownloadBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter Limiter, reverse bool) net.PacketConn {
|
||||
func packetConnWithDownloadBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter, reverse bool) net.PacketConn {
|
||||
if reverse {
|
||||
return NewPacketConnWithUploadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
return NewPacketConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
|
||||
func packetConnWithUploadBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter Limiter, reverse bool) net.PacketConn {
|
||||
func packetConnWithUploadBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter, reverse bool) net.PacketConn {
|
||||
if reverse {
|
||||
return NewPacketConnWithDownloadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
return NewPacketConnWithUploadBandwidthLimiter(ctx, conn, limiter)
|
||||
}
|
||||
|
||||
func packetConnWithBidirectionalBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter Limiter, reverse bool) net.PacketConn {
|
||||
func packetConnWithBidirectionalBandwidthWrapper(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter, reverse bool) net.PacketConn {
|
||||
return NewPacketConnWithUploadBandwidthLimiter(ctx, NewPacketConnWithDownloadBandwidthLimiter(ctx, conn, limiter), limiter)
|
||||
}
|
||||
|
||||
@@ -9,12 +9,13 @@ import (
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
)
|
||||
|
||||
type Limiter interface {
|
||||
type BandwidthLimiter interface {
|
||||
WaitN(ctx context.Context, n int) (err error)
|
||||
SetSpeed(speed uint64)
|
||||
}
|
||||
|
||||
type FlowKeysLimiter struct {
|
||||
limiter Limiter
|
||||
limiter BandwidthLimiter
|
||||
connIDGetter ConnIDGetter
|
||||
|
||||
waits map[string][]*wait
|
||||
@@ -25,7 +26,7 @@ type FlowKeysLimiter struct {
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewFlowKeysLimiter(connIDGetter ConnIDGetter, limiter Limiter) *FlowKeysLimiter {
|
||||
func NewFlowKeysLimiter(connIDGetter ConnIDGetter, limiter BandwidthLimiter) *FlowKeysLimiter {
|
||||
return &FlowKeysLimiter{
|
||||
limiter: limiter,
|
||||
connIDGetter: connIDGetter,
|
||||
@@ -36,6 +37,10 @@ func NewFlowKeysLimiter(connIDGetter ConnIDGetter, limiter Limiter) *FlowKeysLim
|
||||
}
|
||||
}
|
||||
|
||||
func (l *FlowKeysLimiter) SetSpeed(speed uint64) {
|
||||
l.limiter.SetSpeed(speed)
|
||||
}
|
||||
|
||||
func (l *FlowKeysLimiter) WaitN(ctx context.Context, n int) error {
|
||||
id, _ := l.connIDGetter(ctx, adapter.ContextFrom(ctx))
|
||||
mainWait := &wait{ctx, make(chan struct{}), n}
|
||||
|
||||
@@ -7,16 +7,16 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/onclose"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
type (
|
||||
CloseHandlerFunc = func()
|
||||
ConnIDGetter = func(context.Context, *adapter.InboundContext) (string, bool)
|
||||
ConnWrapper = func(ctx context.Context, conn net.Conn, limiter Limiter, reverse bool) net.Conn
|
||||
PacketConnWrapper = func(ctx context.Context, conn net.PacketConn, limiter Limiter, reverse bool) net.PacketConn
|
||||
ConnWrapper = func(ctx context.Context, conn net.Conn, limiter BandwidthLimiter, reverse bool) net.Conn
|
||||
PacketConnWrapper = func(ctx context.Context, conn net.PacketConn, limiter BandwidthLimiter, reverse bool) net.PacketConn
|
||||
)
|
||||
|
||||
type BandwidthStrategy interface {
|
||||
@@ -24,8 +24,12 @@ type BandwidthStrategy interface {
|
||||
wrapPacketConn(ctx context.Context, conn net.PacketConn, metadata *adapter.InboundContext, reverse bool) (net.PacketConn, error)
|
||||
}
|
||||
|
||||
type SpeedUpdater interface {
|
||||
SetSpeed(speed uint64)
|
||||
}
|
||||
|
||||
type BandwidthLimiterStrategy interface {
|
||||
getLimiter(ctx context.Context, metadata *adapter.InboundContext) (Limiter, CloseHandlerFunc, error)
|
||||
getLimiter(ctx context.Context, metadata *adapter.InboundContext) (BandwidthLimiter, onclose.CloseHandlerFunc, error)
|
||||
}
|
||||
|
||||
type DefaultWrapStrategy struct {
|
||||
@@ -54,8 +58,14 @@ func (s *DefaultWrapStrategy) wrapPacketConn(ctx context.Context, conn net.Packe
|
||||
return NewPacketConnWithCloseHandler(s.packetConnWrapper(ctx, conn, limiter, reverse), onClose), nil
|
||||
}
|
||||
|
||||
func (s *DefaultWrapStrategy) SetSpeed(speed uint64) {
|
||||
if updater, ok := s.limiterStrategy.(SpeedUpdater); ok {
|
||||
updater.SetSpeed(speed)
|
||||
}
|
||||
}
|
||||
|
||||
type GlobalBandwidthStrategy struct {
|
||||
limiter Limiter
|
||||
limiter BandwidthLimiter
|
||||
}
|
||||
|
||||
func NewGlobalBandwidthStrategy(speed uint64, flowKeys []string) (*GlobalBandwidthStrategy, error) {
|
||||
@@ -68,12 +78,16 @@ func NewGlobalBandwidthStrategy(speed uint64, flowKeys []string) (*GlobalBandwid
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *GlobalBandwidthStrategy) getLimiter(ctx context.Context, metadata *adapter.InboundContext) (Limiter, CloseHandlerFunc, error) {
|
||||
func (s *GlobalBandwidthStrategy) getLimiter(ctx context.Context, metadata *adapter.InboundContext) (BandwidthLimiter, onclose.CloseHandlerFunc, error) {
|
||||
return s.limiter, func() {}, nil
|
||||
}
|
||||
|
||||
func (s *GlobalBandwidthStrategy) SetSpeed(speed uint64) {
|
||||
s.limiter.SetSpeed(speed)
|
||||
}
|
||||
|
||||
type idBandwidthLimiter struct {
|
||||
limiter Limiter
|
||||
limiter BandwidthLimiter
|
||||
handles uint32
|
||||
}
|
||||
|
||||
@@ -94,7 +108,7 @@ func NewConnectionBandwidthStrategy(connIDGetter ConnIDGetter, speed uint64, flo
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ConnectionBandwidthStrategy) getLimiter(ctx context.Context, metadata *adapter.InboundContext) (Limiter, CloseHandlerFunc, error) {
|
||||
func (s *ConnectionBandwidthStrategy) getLimiter(ctx context.Context, metadata *adapter.InboundContext) (BandwidthLimiter, onclose.CloseHandlerFunc, error) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
id, ok := s.connIDGetter(ctx, metadata)
|
||||
@@ -126,6 +140,15 @@ func (s *ConnectionBandwidthStrategy) getLimiter(ctx context.Context, metadata *
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *ConnectionBandwidthStrategy) SetSpeed(speed uint64) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
s.speed = speed
|
||||
for _, limiter := range s.limiters {
|
||||
limiter.limiter.SetSpeed(speed)
|
||||
}
|
||||
}
|
||||
|
||||
type UsersBandwidthStrategy struct {
|
||||
strategies map[string]BandwidthStrategy
|
||||
mtx sync.Mutex
|
||||
@@ -167,20 +190,86 @@ func (s *UsersBandwidthStrategy) getStrategy(ctx context.Context, metadata *adap
|
||||
return nil, E.New("user strategy not found: ", user)
|
||||
}
|
||||
|
||||
type bwConnEntry struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
|
||||
|
||||
type ManagerBandwidthStrategy struct {
|
||||
*UsersBandwidthStrategy
|
||||
strategies map[string]BandwidthStrategy
|
||||
conns map[string][]*bwConnEntry
|
||||
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewManagerBandwidthStrategy() *ManagerBandwidthStrategy {
|
||||
return &ManagerBandwidthStrategy{
|
||||
UsersBandwidthStrategy: NewUsersBandwidthStrategy(map[string]BandwidthStrategy{}),
|
||||
strategies: make(map[string]BandwidthStrategy),
|
||||
conns: make(map[string][]*bwConnEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ManagerBandwidthStrategy) wrapConn(ctx context.Context, conn net.Conn, metadata *adapter.InboundContext, reverse bool) (net.Conn, error) {
|
||||
s.mtx.Lock()
|
||||
var user string
|
||||
if metadata != nil {
|
||||
user = metadata.User
|
||||
}
|
||||
strategy, ok := s.strategies[user]
|
||||
s.mtx.Unlock()
|
||||
if !ok {
|
||||
return nil, E.New("user strategy not found: ", user)
|
||||
}
|
||||
wrapped, err := strategy.wrapConn(ctx, conn, metadata, reverse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry := &bwConnEntry{conn: conn}
|
||||
s.mtx.Lock()
|
||||
s.conns[user] = append(s.conns[user], entry)
|
||||
s.mtx.Unlock()
|
||||
return onclose.NewConn(wrapped, func() {
|
||||
s.mtx.Lock()
|
||||
entries := s.conns[user]
|
||||
for i, e := range entries {
|
||||
if e == entry {
|
||||
s.conns[user] = append(entries[:i], entries[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
s.mtx.Unlock()
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *ManagerBandwidthStrategy) wrapPacketConn(ctx context.Context, conn net.PacketConn, metadata *adapter.InboundContext, reverse bool) (net.PacketConn, error) {
|
||||
s.mtx.Lock()
|
||||
var user string
|
||||
if metadata != nil {
|
||||
user = metadata.User
|
||||
}
|
||||
strategy, ok := s.strategies[user]
|
||||
s.mtx.Unlock()
|
||||
if !ok {
|
||||
return nil, E.New("user strategy not found: ", user)
|
||||
}
|
||||
return strategy.wrapPacketConn(ctx, conn, metadata, reverse)
|
||||
}
|
||||
|
||||
func (s *ManagerBandwidthStrategy) UpdateStrategies(strategies map[string]BandwidthStrategy) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
var closedEntries []*bwConnEntry
|
||||
for user, entries := range s.conns {
|
||||
if _, exists := strategies[user]; !exists {
|
||||
closedEntries = append(closedEntries, entries...)
|
||||
delete(s.conns, user)
|
||||
}
|
||||
}
|
||||
s.strategies = strategies
|
||||
s.mtx.Unlock()
|
||||
for _, entry := range closedEntries {
|
||||
entry.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
type BypassBandwidthStrategy struct{}
|
||||
@@ -263,8 +352,8 @@ func CreateStrategy(strategy string, mode string, connectionType string, speed u
|
||||
return NewDefaultWrapStrategy(limiterStrategy, connWrapper, packetConnWrapper), nil
|
||||
}
|
||||
|
||||
func createSpeedLimiter(speed uint64, flowKeys []string) (Limiter, error) {
|
||||
var limiter Limiter = rate.NewLimiter(rate.Limit(float64(speed)), 65536)
|
||||
func createSpeedLimiter(speed uint64, flowKeys []string) (BandwidthLimiter, error) {
|
||||
var limiter BandwidthLimiter = &speedLimiter{limiter: rate.NewLimiter(rate.Limit(float64(speed)), 65536)}
|
||||
for i := len(flowKeys) - 1; i >= 0; i-- {
|
||||
getter, err := flowKeysConnIDGetter(flowKeys[i])
|
||||
if err != nil {
|
||||
@@ -275,16 +364,24 @@ func createSpeedLimiter(speed uint64, flowKeys []string) (Limiter, error) {
|
||||
return limiter, nil
|
||||
}
|
||||
|
||||
type speedLimiter struct {
|
||||
limiter *rate.Limiter
|
||||
}
|
||||
|
||||
func (r *speedLimiter) WaitN(ctx context.Context, n int) error {
|
||||
return r.limiter.WaitN(ctx, n)
|
||||
}
|
||||
|
||||
func (r *speedLimiter) SetSpeed(speed uint64) {
|
||||
r.limiter.SetLimit(rate.Limit(float64(speed)))
|
||||
}
|
||||
|
||||
func flowKeysConnIDGetter(name string) (ConnIDGetter, error) {
|
||||
switch name {
|
||||
case "user":
|
||||
return func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
|
||||
return metadata.User, true
|
||||
}, nil
|
||||
case "destination":
|
||||
return func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
|
||||
return metadata.Destination.String(), true
|
||||
}, nil
|
||||
case "source_ip":
|
||||
return func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
|
||||
return metadata.Source.IPAddr().String(), true
|
||||
@@ -302,6 +399,14 @@ func flowKeysConnIDGetter(name string) (ConnIDGetter, error) {
|
||||
}
|
||||
return strconv.FormatUint(uint64(id.ID), 10), ok
|
||||
}, nil
|
||||
case "protocol":
|
||||
return func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
|
||||
return metadata.Protocol, metadata.Protocol != ""
|
||||
}, nil
|
||||
case "destination":
|
||||
return func(ctx context.Context, metadata *adapter.InboundContext) (string, bool) {
|
||||
return metadata.Destination.String(), true
|
||||
}, nil
|
||||
default:
|
||||
return nil, E.New("flow key not found: ", name)
|
||||
}
|
||||
|
||||
@@ -4,13 +4,14 @@ import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/common/onclose"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func NewDefaultLock(max uint32) LockIDGetter {
|
||||
locks := make(map[string]*uint32)
|
||||
mtx := sync.Mutex{}
|
||||
return func(id string) (CloseHandlerFunc, context.Context, error) {
|
||||
return func(id string) (onclose.CloseHandlerFunc, context.Context, error) {
|
||||
mtx.Lock()
|
||||
defer mtx.Unlock()
|
||||
handles, ok := locks[id]
|
||||
@@ -22,16 +23,13 @@ func NewDefaultLock(max uint32) LockIDGetter {
|
||||
locks[id] = handles
|
||||
}
|
||||
*handles++
|
||||
var once sync.Once
|
||||
return func() {
|
||||
once.Do(func() {
|
||||
mtx.Lock()
|
||||
defer mtx.Unlock()
|
||||
*handles--
|
||||
if *handles == 0 {
|
||||
delete(locks, id)
|
||||
}
|
||||
})
|
||||
mtx.Lock()
|
||||
defer mtx.Unlock()
|
||||
*handles--
|
||||
if *handles == 0 {
|
||||
delete(locks, id)
|
||||
}
|
||||
}, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/common/onclose"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/route"
|
||||
@@ -110,7 +111,7 @@ func (h *Outbound) DialContext(ctx context.Context, network string, destination
|
||||
onClose()
|
||||
return nil, err
|
||||
}
|
||||
conn = newConnWithCloseHandlerFunc(conn, onClose)
|
||||
conn = onclose.NewConn(conn, onClose)
|
||||
if lockCtx != nil {
|
||||
go connChecker(lockCtx, conn.Close)
|
||||
}
|
||||
@@ -127,7 +128,7 @@ func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
|
||||
onClose()
|
||||
return nil, err
|
||||
}
|
||||
conn = newPacketConnWithCloseHandlerFunc(conn, onClose)
|
||||
conn = onclose.NewPacketConn(conn, onClose)
|
||||
if lockCtx != nil {
|
||||
go connChecker(lockCtx, conn.Close)
|
||||
}
|
||||
@@ -141,7 +142,7 @@ func (h *Outbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata
|
||||
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||
return
|
||||
}
|
||||
conn = newConnWithCloseHandlerFunc(conn, limiterOnClose)
|
||||
conn = onclose.NewConn(conn, limiterOnClose)
|
||||
if lockCtx != nil {
|
||||
go connChecker(lockCtx, conn.Close)
|
||||
}
|
||||
@@ -158,7 +159,7 @@ func (h *Outbound) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn,
|
||||
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||
return
|
||||
}
|
||||
conn = bufio.NewPacketConn(newPacketConnWithCloseHandlerFunc(bufio.NewNetPacketConn(conn), limiterOnClose))
|
||||
conn = bufio.NewPacketConn(onclose.NewPacketConn(bufio.NewNetPacketConn(conn), limiterOnClose))
|
||||
if lockCtx != nil {
|
||||
go connChecker(lockCtx, conn.Close)
|
||||
}
|
||||
@@ -172,33 +173,7 @@ func (h *Outbound) GetStrategy() ConnectionStrategy {
|
||||
return h.strategy
|
||||
}
|
||||
|
||||
type connWithCloseHandlerFunc struct {
|
||||
net.Conn
|
||||
onClose CloseHandlerFunc
|
||||
}
|
||||
|
||||
func newConnWithCloseHandlerFunc(conn net.Conn, onClose CloseHandlerFunc) *connWithCloseHandlerFunc {
|
||||
return &connWithCloseHandlerFunc{conn, onClose}
|
||||
}
|
||||
|
||||
func (conn *connWithCloseHandlerFunc) Close() error {
|
||||
conn.onClose()
|
||||
return conn.Conn.Close()
|
||||
}
|
||||
|
||||
type packetConnWithCloseHandlerFunc struct {
|
||||
net.PacketConn
|
||||
onClose CloseHandlerFunc
|
||||
}
|
||||
|
||||
func newPacketConnWithCloseHandlerFunc(conn net.PacketConn, onClose CloseHandlerFunc) *packetConnWithCloseHandlerFunc {
|
||||
return &packetConnWithCloseHandlerFunc{conn, onClose}
|
||||
}
|
||||
|
||||
func (conn *packetConnWithCloseHandlerFunc) Close() error {
|
||||
conn.onClose()
|
||||
return conn.PacketConn.Close()
|
||||
}
|
||||
|
||||
func connChecker(ctx context.Context, closeFunc func() error) {
|
||||
<-ctx.Done()
|
||||
|
||||
@@ -6,18 +6,17 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/onclose"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type (
|
||||
CloseHandlerFunc = func()
|
||||
|
||||
ConnIDGetter = func(context.Context, *adapter.InboundContext) (string, bool)
|
||||
LockIDGetter = func(string) (CloseHandlerFunc, context.Context, error)
|
||||
LockIDGetter = func(string) (onclose.CloseHandlerFunc, context.Context, error)
|
||||
|
||||
ConnectionStrategy interface {
|
||||
request(ctx context.Context, metadata *adapter.InboundContext) (onClose CloseHandlerFunc, lockCtx context.Context, err error)
|
||||
request(ctx context.Context, metadata *adapter.InboundContext) (onClose onclose.CloseHandlerFunc, lockCtx context.Context, err error)
|
||||
}
|
||||
)
|
||||
|
||||
@@ -36,7 +35,7 @@ func NewDefaultConnectionStrategy(connIDGetter ConnIDGetter, lockIDGetter LockID
|
||||
return outbound
|
||||
}
|
||||
|
||||
func (s *DefaultConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (CloseHandlerFunc, context.Context, error) {
|
||||
func (s *DefaultConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
id, ok := s.connIDGetter(ctx, metadata)
|
||||
@@ -57,7 +56,7 @@ func NewUsersConnectionStrategy(strategies map[string]ConnectionStrategy) *Users
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UsersConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (CloseHandlerFunc, context.Context, error) {
|
||||
func (s *UsersConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
var user string
|
||||
@@ -71,20 +70,78 @@ func (s *UsersConnectionStrategy) request(ctx context.Context, metadata *adapter
|
||||
return nil, nil, E.New("user strategy not found: ", user)
|
||||
}
|
||||
|
||||
type cancelEntry struct {
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type ManagerConnectionStrategy struct {
|
||||
*UsersConnectionStrategy
|
||||
strategies map[string]ConnectionStrategy
|
||||
cancels map[string][]*cancelEntry
|
||||
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewManagerConnectionStrategy() *ManagerConnectionStrategy {
|
||||
return &ManagerConnectionStrategy{
|
||||
UsersConnectionStrategy: NewUsersConnectionStrategy(map[string]ConnectionStrategy{}),
|
||||
strategies: make(map[string]ConnectionStrategy),
|
||||
cancels: make(map[string][]*cancelEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ManagerConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
|
||||
s.mtx.Lock()
|
||||
var user string
|
||||
if metadata != nil {
|
||||
user = metadata.User
|
||||
}
|
||||
strategy, ok := s.strategies[user]
|
||||
if !ok {
|
||||
s.mtx.Unlock()
|
||||
return nil, nil, E.New("user strategy not found: ", user)
|
||||
}
|
||||
s.mtx.Unlock()
|
||||
onClose, _, err := strategy.request(ctx, metadata)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cancelCtx, cancel := context.WithCancel(context.Background())
|
||||
entry := &cancelEntry{cancel: cancel}
|
||||
s.mtx.Lock()
|
||||
s.cancels[user] = append(s.cancels[user], entry)
|
||||
s.mtx.Unlock()
|
||||
originalOnClose := onClose
|
||||
wrappedOnClose := func() {
|
||||
s.mtx.Lock()
|
||||
entries := s.cancels[user]
|
||||
for i, e := range entries {
|
||||
if e == entry {
|
||||
s.cancels[user] = append(entries[:i], entries[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
s.mtx.Unlock()
|
||||
cancel()
|
||||
if originalOnClose != nil {
|
||||
originalOnClose()
|
||||
}
|
||||
}
|
||||
return wrappedOnClose, cancelCtx, nil
|
||||
}
|
||||
|
||||
func (s *ManagerConnectionStrategy) UpdateStrategies(strategies map[string]ConnectionStrategy) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
var entries []*cancelEntry
|
||||
for user, cancels := range s.cancels {
|
||||
if _, exists := strategies[user]; !exists {
|
||||
entries = append(entries, cancels...)
|
||||
delete(s.cancels, user)
|
||||
}
|
||||
}
|
||||
s.strategies = strategies
|
||||
s.mtx.Unlock()
|
||||
for _, entry := range entries {
|
||||
entry.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
type BypassConnectionStrategy struct{}
|
||||
@@ -93,7 +150,7 @@ func NewBypassConnectionStrategy() *BypassConnectionStrategy {
|
||||
return &BypassConnectionStrategy{}
|
||||
}
|
||||
|
||||
func (s *BypassConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (CloseHandlerFunc, context.Context, error) {
|
||||
func (s *BypassConnectionStrategy) request(ctx context.Context, metadata *adapter.InboundContext) (onclose.CloseHandlerFunc, context.Context, error) {
|
||||
return func() {}, nil, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,11 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/onclose"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type (
|
||||
CloseHandlerFunc = func()
|
||||
ConnWrapper = func(ctx context.Context, conn net.Conn, limiter TrafficLimiter, reverse bool) net.Conn
|
||||
PacketConnWrapper = func(ctx context.Context, conn net.PacketConn, limiter TrafficLimiter, reverse bool) net.PacketConn
|
||||
)
|
||||
@@ -72,32 +72,60 @@ func (s *GlobalTrafficStrategy) getLimiter(ctx context.Context, metadata *adapte
|
||||
return s.limiter, nil
|
||||
}
|
||||
|
||||
type connEntry struct {
|
||||
conn net.Conn
|
||||
}
|
||||
|
||||
|
||||
|
||||
type ManagerTrafficStrategy struct {
|
||||
strategies map[string]TrafficStrategy
|
||||
mtx sync.Mutex
|
||||
conns map[string][]*connEntry
|
||||
|
||||
mtx sync.Mutex
|
||||
}
|
||||
|
||||
func NewManagerTrafficStrategy() *ManagerTrafficStrategy {
|
||||
return &ManagerTrafficStrategy{}
|
||||
return &ManagerTrafficStrategy{
|
||||
conns: make(map[string][]*connEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ManagerTrafficStrategy) wrapConn(ctx context.Context, conn net.Conn, metadata *adapter.InboundContext, reverse bool) (net.Conn, error) {
|
||||
strategy, err := s.getStrategy(ctx, metadata)
|
||||
strategy, user, err := s.getStrategy(ctx, metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return strategy.wrapConn(ctx, conn, metadata, reverse)
|
||||
wrapped, err := strategy.wrapConn(ctx, conn, metadata, reverse)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry := &connEntry{conn: conn}
|
||||
s.mtx.Lock()
|
||||
s.conns[user] = append(s.conns[user], entry)
|
||||
s.mtx.Unlock()
|
||||
return onclose.NewConn(wrapped, func() {
|
||||
s.mtx.Lock()
|
||||
entries := s.conns[user]
|
||||
for i, e := range entries {
|
||||
if e == entry {
|
||||
s.conns[user] = append(entries[:i], entries[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
s.mtx.Unlock()
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *ManagerTrafficStrategy) wrapPacketConn(ctx context.Context, conn net.PacketConn, metadata *adapter.InboundContext, reverse bool) (net.PacketConn, error) {
|
||||
strategy, err := s.getStrategy(ctx, metadata)
|
||||
strategy, _, err := s.getStrategy(ctx, metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return strategy.wrapPacketConn(ctx, conn, metadata, reverse)
|
||||
}
|
||||
|
||||
func (s *ManagerTrafficStrategy) getStrategy(ctx context.Context, metadata *adapter.InboundContext) (TrafficStrategy, error) {
|
||||
func (s *ManagerTrafficStrategy) getStrategy(ctx context.Context, metadata *adapter.InboundContext) (TrafficStrategy, string, error) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
var user string
|
||||
@@ -106,15 +134,25 @@ func (s *ManagerTrafficStrategy) getStrategy(ctx context.Context, metadata *adap
|
||||
}
|
||||
strategy, ok := s.strategies[user]
|
||||
if ok {
|
||||
return strategy, nil
|
||||
return strategy, user, nil
|
||||
}
|
||||
return nil, E.New("user strategy not found: ", user)
|
||||
return nil, user, E.New("user strategy not found: ", user)
|
||||
}
|
||||
|
||||
func (s *ManagerTrafficStrategy) UpdateStrategies(strategies map[string]TrafficStrategy) {
|
||||
s.mtx.Lock()
|
||||
defer s.mtx.Unlock()
|
||||
var closedEntries []*connEntry
|
||||
for user, entries := range s.conns {
|
||||
if _, exists := strategies[user]; !exists {
|
||||
closedEntries = append(closedEntries, entries...)
|
||||
delete(s.conns, user)
|
||||
}
|
||||
}
|
||||
s.strategies = strategies
|
||||
s.mtx.Unlock()
|
||||
for _, entry := range closedEntries {
|
||||
entry.conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
type BypassTrafficStrategy struct{}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
"github.com/sagernet/sing-box/common/cloudflare"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
@@ -99,7 +100,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
|
||||
logger.ErrorContext(ctx, E.New("failed to generate cert: ", err))
|
||||
return
|
||||
}
|
||||
tlsConfig, err := tls.NewMASQUEClient(ctx, logger, "consumer-masque.cloudflareclient.com", cert, privKey, peerPubKey, options.MASQUEOutboundTLSOptions)
|
||||
tlsConfig, err := tls.NewMASQUEClient(ctx, logger, "consumer-masque.cloudflareclient.com", cert, privKey, peerPubKey, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
logger.ErrorContext(ctx, E.New("failed to prepare TLS config: ", err))
|
||||
return
|
||||
|
||||
@@ -98,6 +98,10 @@ func (h *Inbound) Close() error {
|
||||
)
|
||||
}
|
||||
|
||||
func (h *Inbound) UpdateUsers(users []auth.User) {
|
||||
h.authenticator.UpdateUsers(users)
|
||||
}
|
||||
|
||||
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
err := h.newConnection(ctx, conn, metadata, onClose)
|
||||
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||
|
||||
@@ -147,6 +147,10 @@ func (n *Inbound) Close() error {
|
||||
)
|
||||
}
|
||||
|
||||
func (n *Inbound) UpdateUsers(users []auth.User) {
|
||||
n.authenticator.UpdateUsers(users)
|
||||
}
|
||||
|
||||
func (n *Inbound) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
ctx := log.ContextWithNewID(request.Context())
|
||||
if request.Method != "CONNECT" {
|
||||
|
||||
158
protocol/openvpn/outbound.go
Normal file
158
protocol/openvpn/outbound.go
Normal file
@@ -0,0 +1,158 @@
|
||||
//go:build with_openvpn
|
||||
|
||||
package openvpn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"time"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
ovpn "github.com/sagernet/sing-box/transport/openvpn"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/service"
|
||||
)
|
||||
|
||||
func RegisterOutbound(registry *outbound.Registry) {
|
||||
outbound.Register[option.OpenVPNOutboundOptions](registry, C.TypeOpenVPN, NewOutbound)
|
||||
}
|
||||
|
||||
type Outbound struct {
|
||||
outbound.Adapter
|
||||
ctx context.Context
|
||||
dnsRouter adapter.DNSRouter
|
||||
logger logger.ContextLogger
|
||||
tunnel *ovpn.Tunnel
|
||||
}
|
||||
|
||||
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.OpenVPNOutboundOptions) (adapter.Outbound, error) {
|
||||
tlsConfig, err := tls.NewOpenVPNClient(ctx, logger, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var tlsKey []byte
|
||||
keyDirection := -1
|
||||
if options.TLSAuth != "" || options.TLSAuthPath != "" {
|
||||
tlsAuth := options.TLSAuth
|
||||
if tlsAuth == "" {
|
||||
data, err := os.ReadFile(options.TLSAuthPath)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read tls_auth_path")
|
||||
}
|
||||
tlsAuth = string(data)
|
||||
}
|
||||
tlsKey = []byte(tlsAuth)
|
||||
keyDirection = options.KeyDirection
|
||||
} else {
|
||||
tlsCrypt := options.TLSCrypt
|
||||
if tlsCrypt == "" && options.TLSCryptPath != "" {
|
||||
data, err := os.ReadFile(options.TLSCryptPath)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "read tls_crypt_path")
|
||||
}
|
||||
tlsCrypt = string(data)
|
||||
}
|
||||
tlsKey = []byte(tlsCrypt)
|
||||
}
|
||||
clientConfig := &ovpn.ClientConfig{
|
||||
Proto: options.Proto,
|
||||
Cipher: options.Cipher,
|
||||
Auth: options.Auth,
|
||||
Username: options.Username,
|
||||
Password: options.Password,
|
||||
KeyDirection: keyDirection,
|
||||
TLSCrypt: tlsKey,
|
||||
TLSCryptV2: options.TLSCryptV2,
|
||||
}
|
||||
if err := clientConfig.Prepare(); err != nil {
|
||||
return nil, E.Cause(err, "invalid openvpn config")
|
||||
}
|
||||
outboundDialer, err := dialer.New(ctx, options.DialerOptions, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o := &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeOpenVPN, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
|
||||
ctx: ctx,
|
||||
dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
|
||||
logger: logger,
|
||||
}
|
||||
tunnel, err := ovpn.NewTunnel(ctx, logger, ovpn.TunnelOptions{
|
||||
Dialer: outboundDialer,
|
||||
Servers: options.Servers,
|
||||
TLSConfig: tlsConfig,
|
||||
Config: clientConfig,
|
||||
ReconnectDelay: time.Duration(options.ReconnectDelay),
|
||||
PingInterval: time.Duration(options.PingInterval),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o.tunnel = tunnel
|
||||
return o, nil
|
||||
}
|
||||
|
||||
func (o *Outbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStatePostStart {
|
||||
return nil
|
||||
}
|
||||
return o.tunnel.Start()
|
||||
}
|
||||
|
||||
func (o *Outbound) Close() error {
|
||||
if o.tunnel != nil {
|
||||
return o.tunnel.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
switch network {
|
||||
case N.NetworkTCP:
|
||||
o.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||
case N.NetworkUDP:
|
||||
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
}
|
||||
if destination.IsDomain() {
|
||||
destinationAddresses, err := o.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return N.DialSerial(ctx, o.tunnel, network, destination, destinationAddresses)
|
||||
}
|
||||
if !destination.Addr.IsValid() {
|
||||
return nil, E.New("invalid destination: ", destination)
|
||||
}
|
||||
return o.tunnel.DialContext(ctx, network, destination)
|
||||
}
|
||||
|
||||
func (o *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
if destination.IsDomain() {
|
||||
destinationAddresses, err := o.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
packetConn, destinationAddress, err := N.ListenSerial(ctx, o.tunnel, destination, destinationAddresses)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if destinationAddress.IsValid() && destination != M.SocksaddrFrom(destinationAddress, destination.Port) {
|
||||
return bufio.NewNATPacketConn(bufio.NewPacketConn(packetConn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil
|
||||
}
|
||||
return packetConn, nil
|
||||
}
|
||||
return o.tunnel.ListenPacket(ctx, destination)
|
||||
}
|
||||
@@ -70,6 +70,10 @@ func (h *Inbound) Close() error {
|
||||
return h.listener.Close()
|
||||
}
|
||||
|
||||
func (h *Inbound) UpdateUsers(users []auth.User) {
|
||||
h.authenticator.UpdateUsers(users)
|
||||
}
|
||||
|
||||
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
err := socks.HandleConnectionEx(ctx, conn, std_bufio.NewReader(conn), h.authenticator, adapter.NewUpstreamHandlerEx(metadata, h.newUserConnection, h.streamUserPacketConnection), h.listener, h.udpTimeout, metadata.Source, onClose)
|
||||
N.CloseOnHandshakeFailure(conn, onClose, err)
|
||||
|
||||
178
protocol/sudoku/inbound.go
Normal file
178
protocol/sudoku/inbound.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
"github.com/sagernet/sing-box/common/listener"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/transport/sudoku"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func RegisterInbound(registry *inbound.Registry) {
|
||||
inbound.Register[option.SudokuInboundOptions](registry, C.TypeSudoku, NewInbound)
|
||||
}
|
||||
|
||||
type Inbound struct {
|
||||
inbound.Adapter
|
||||
router adapter.ConnectionRouterEx
|
||||
logger logger.ContextLogger
|
||||
listener *listener.Listener
|
||||
protoConf sudoku.ProtocolConfig
|
||||
tunnelSrv *sudoku.HTTPMaskTunnelServer
|
||||
fallback string
|
||||
}
|
||||
|
||||
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SudokuInboundOptions) (adapter.Inbound, error) {
|
||||
defaultConf := sudoku.DefaultConfig()
|
||||
tableType, err := sudoku.NormalizeTableType(options.TableType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paddingMin, paddingMax := sudoku.ResolvePadding(options.PaddingMin, options.PaddingMax, defaultConf.PaddingMin, defaultConf.PaddingMax)
|
||||
enablePureDownlink := sudoku.DerefBool(options.EnablePureDownlink, defaultConf.EnablePureDownlink)
|
||||
handshakeTimeout := sudoku.DerefInt(options.HandshakeTimeout, defaultConf.HandshakeTimeoutSeconds)
|
||||
|
||||
tables, err := sudoku.NewServerTablesWithCustomPatterns(sudoku.ServerAEADSeed(options.Key), tableType, options.CustomTable, options.CustomTables)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protoConf := sudoku.ProtocolConfig{
|
||||
Key: options.Key,
|
||||
AEADMethod: defaultConf.AEADMethod,
|
||||
PaddingMin: paddingMin,
|
||||
PaddingMax: paddingMax,
|
||||
EnablePureDownlink: enablePureDownlink,
|
||||
HandshakeTimeoutSeconds: handshakeTimeout,
|
||||
DisableHTTPMask: options.DisableHTTPMask,
|
||||
HTTPMaskMode: options.HTTPMaskMode,
|
||||
HTTPMaskPathRoot: strings.TrimSpace(options.PathRoot),
|
||||
}
|
||||
if len(tables) == 1 {
|
||||
protoConf.Table = tables[0]
|
||||
} else {
|
||||
protoConf.Tables = tables
|
||||
}
|
||||
if options.AEADMethod != "" {
|
||||
protoConf.AEADMethod = options.AEADMethod
|
||||
}
|
||||
|
||||
in := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeSudoku, tag),
|
||||
router: router,
|
||||
logger: logger,
|
||||
protoConf: protoConf,
|
||||
fallback: strings.TrimSpace(options.Fallback),
|
||||
}
|
||||
if in.fallback != "" {
|
||||
in.tunnelSrv = sudoku.NewHTTPMaskTunnelServerWithFallback(&in.protoConf)
|
||||
} else {
|
||||
in.tunnelSrv = sudoku.NewHTTPMaskTunnelServer(&in.protoConf)
|
||||
}
|
||||
|
||||
in.listener = listener.New(listener.Options{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
Network: []string{N.NetworkTCP},
|
||||
Listen: options.ListenOptions,
|
||||
ConnectionHandler: in,
|
||||
})
|
||||
return in, nil
|
||||
}
|
||||
|
||||
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
return h.listener.Start()
|
||||
}
|
||||
|
||||
func (h *Inbound) Close() error {
|
||||
return h.listener.Close()
|
||||
}
|
||||
|
||||
func (h *Inbound) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
h.handleConn(ctx, conn, metadata, onClose)
|
||||
}
|
||||
|
||||
func (h *Inbound) handleConn(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
handshakeConn := conn
|
||||
handshakeCfg := &h.protoConf
|
||||
|
||||
if h.tunnelSrv != nil {
|
||||
c, cfg, done, err := h.tunnelSrv.WrapConn(conn)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
if done {
|
||||
return
|
||||
}
|
||||
if c != nil {
|
||||
handshakeConn = c
|
||||
}
|
||||
if cfg != nil {
|
||||
handshakeCfg = cfg
|
||||
}
|
||||
}
|
||||
|
||||
cConn, meta, err := sudoku.ServerHandshake(handshakeConn, handshakeCfg)
|
||||
if err != nil {
|
||||
h.logger.DebugContext(ctx, "handshake failed: ", err)
|
||||
conn.Close()
|
||||
if handshakeConn != conn {
|
||||
handshakeConn.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
session, err := sudoku.ReadServerSession(cConn, meta)
|
||||
if err != nil {
|
||||
h.logger.WarnContext(ctx, "read session failed: ", err)
|
||||
cConn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
switch session.Type {
|
||||
case sudoku.SessionTypeUoT:
|
||||
h.handleUoT(ctx, session.Conn, metadata, onClose)
|
||||
case sudoku.SessionTypeMultiplex:
|
||||
mux, err := sudoku.AcceptMultiplexServer(session.Conn)
|
||||
if err != nil {
|
||||
session.Conn.Close()
|
||||
return
|
||||
}
|
||||
defer mux.Close()
|
||||
for {
|
||||
stream, target, err := mux.AcceptTCP()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go h.routeTCP(ctx, stream, target, metadata)
|
||||
}
|
||||
default:
|
||||
h.routeTCP(ctx, session.Conn, session.Target, metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Inbound) routeTCP(ctx context.Context, conn net.Conn, target string, metadata adapter.InboundContext) {
|
||||
destination := M.ParseSocksaddr(target)
|
||||
metadata.Destination = destination
|
||||
h.logger.InfoContext(ctx, "inbound connection to ", destination)
|
||||
h.router.RouteConnectionEx(ctx, conn, metadata, nil)
|
||||
}
|
||||
|
||||
func (h *Inbound) handleUoT(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
|
||||
packetConn := sudoku.NewUoTPacketConn(conn)
|
||||
h.router.RoutePacketConnectionEx(ctx, bufio.NewPacketConn(packetConn), metadata, onClose)
|
||||
}
|
||||
401
protocol/sudoku/outbound.go
Normal file
401
protocol/sudoku/outbound.go
Normal file
@@ -0,0 +1,401 @@
|
||||
package sudoku
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/transport/sudoku"
|
||||
"github.com/sagernet/sing-box/transport/sudoku/obfs/httpmask"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
)
|
||||
|
||||
func RegisterOutbound(registry *outbound.Registry) {
|
||||
outbound.Register[option.SudokuOutboundOptions](registry, C.TypeSudoku, NewOutbound)
|
||||
}
|
||||
|
||||
type Outbound struct {
|
||||
outbound.Adapter
|
||||
logger logger.ContextLogger
|
||||
dialer N.Dialer
|
||||
tlsConfig tls.Config
|
||||
baseConf sudoku.ProtocolConfig
|
||||
|
||||
muxMu sync.Mutex
|
||||
muxClient *sudoku.MultiplexClient
|
||||
|
||||
httpMaskMu sync.Mutex
|
||||
httpMaskClient *httpmask.TunnelClient
|
||||
httpMaskKey string
|
||||
}
|
||||
|
||||
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SudokuOutboundOptions) (adapter.Outbound, error) {
|
||||
outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defaultConf := sudoku.DefaultConfig()
|
||||
tableType, err := sudoku.NormalizeTableType(options.TableType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paddingMin, paddingMax := sudoku.ResolvePadding(options.PaddingMin, options.PaddingMax, defaultConf.PaddingMin, defaultConf.PaddingMax)
|
||||
enablePureDownlink := sudoku.DerefBool(options.EnablePureDownlink, defaultConf.EnablePureDownlink)
|
||||
|
||||
serverAddr := options.ServerOptions.Build()
|
||||
|
||||
disableHTTPMask := defaultConf.DisableHTTPMask
|
||||
httpMaskMode := defaultConf.HTTPMaskMode
|
||||
var httpMaskHost string
|
||||
var pathRoot string
|
||||
httpMaskMultiplex := defaultConf.HTTPMaskMultiplex
|
||||
|
||||
if hm := options.HTTPMask; hm != nil {
|
||||
disableHTTPMask = !hm.Enabled
|
||||
if hm.Mode != "" {
|
||||
httpMaskMode = hm.Mode
|
||||
}
|
||||
httpMaskHost = hm.Host
|
||||
pathRoot = strings.TrimSpace(hm.PathRoot)
|
||||
if hm.Multiplex != "" {
|
||||
httpMaskMultiplex = hm.Multiplex
|
||||
}
|
||||
}
|
||||
|
||||
baseConf := sudoku.ProtocolConfig{
|
||||
ServerAddress: serverAddr.String(),
|
||||
Key: options.Key,
|
||||
AEADMethod: defaultConf.AEADMethod,
|
||||
PaddingMin: paddingMin,
|
||||
PaddingMax: paddingMax,
|
||||
EnablePureDownlink: enablePureDownlink,
|
||||
HandshakeTimeoutSeconds: defaultConf.HandshakeTimeoutSeconds,
|
||||
DisableHTTPMask: disableHTTPMask,
|
||||
HTTPMaskMode: httpMaskMode,
|
||||
HTTPMaskHost: httpMaskHost,
|
||||
HTTPMaskPathRoot: pathRoot,
|
||||
HTTPMaskMultiplex: httpMaskMultiplex,
|
||||
}
|
||||
if options.AEADMethod != "" {
|
||||
baseConf.AEADMethod = options.AEADMethod
|
||||
}
|
||||
|
||||
tables, err := sudoku.NewClientTablesWithCustomPatterns(sudoku.ClientAEADSeed(options.Key), tableType, options.CustomTable, options.CustomTables)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "build table(s)")
|
||||
}
|
||||
if len(tables) == 1 {
|
||||
baseConf.Table = tables[0]
|
||||
} else {
|
||||
baseConf.Tables = tables
|
||||
}
|
||||
|
||||
out := &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSudoku, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
|
||||
logger: logger,
|
||||
dialer: outboundDialer,
|
||||
baseConf: baseConf,
|
||||
}
|
||||
if hm := options.HTTPMask; !disableHTTPMask && hm != nil && hm.TLS != nil && hm.TLS.Enabled {
|
||||
tlsOptions := option.OutboundTLSOptions{
|
||||
Enabled: true,
|
||||
ServerName: options.Server,
|
||||
Fragment: hm.TLS.Fragment,
|
||||
FragmentFallbackDelay: hm.TLS.FragmentFallbackDelay,
|
||||
RecordFragment: hm.TLS.RecordFragment,
|
||||
KernelTx: hm.TLS.KernelTx,
|
||||
KernelRx: hm.TLS.KernelRx,
|
||||
}
|
||||
out.tlsConfig, err = tls.NewClientWithOptions(tls.ClientOptions{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
ServerAddress: options.Server,
|
||||
Options: tlsOptions,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkTCP:
|
||||
h.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||
case N.NetworkUDP:
|
||||
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
}
|
||||
ctx, metadata := adapter.ExtendContext(ctx)
|
||||
metadata.Outbound = h.Tag()
|
||||
metadata.Destination = destination
|
||||
|
||||
cfg := h.baseConf
|
||||
cfg.TargetAddress = destination.String()
|
||||
|
||||
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
|
||||
if muxMode == "on" && !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
|
||||
stream, err := h.dialMultiplex(ctx, cfg.TargetAddress)
|
||||
if err == nil {
|
||||
return stream, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c, err := h.dialAndHandshake(ctx, &cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addrBuf, err := sudoku.EncodeAddress(cfg.TargetAddress)
|
||||
if err != nil {
|
||||
c.Close()
|
||||
return nil, E.Cause(err, "encode target address")
|
||||
}
|
||||
if err = sudoku.WriteKIPMessage(c, sudoku.KIPTypeOpenTCP, addrBuf); err != nil {
|
||||
c.Close()
|
||||
return nil, E.Cause(err, "send target address")
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
ctx, metadata := adapter.ExtendContext(ctx)
|
||||
metadata.Outbound = h.Tag()
|
||||
metadata.Destination = destination
|
||||
|
||||
cfg := h.baseConf
|
||||
cfg.TargetAddress = destination.String()
|
||||
|
||||
c, err := h.dialAndHandshake(ctx, &cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = sudoku.WriteKIPMessage(c, sudoku.KIPTypeStartUoT, nil); err != nil {
|
||||
c.Close()
|
||||
return nil, E.Cause(err, "start uot")
|
||||
}
|
||||
|
||||
return bufio.NewBindPacketConn(sudoku.NewUoTPacketConn(c), destination), nil
|
||||
}
|
||||
|
||||
func (h *Outbound) Close() error {
|
||||
h.resetMuxClient()
|
||||
h.resetHTTPMaskClient()
|
||||
return common.Close(h.tlsConfig)
|
||||
}
|
||||
|
||||
func (h *Outbound) InterfaceUpdated() {
|
||||
h.resetMuxClient()
|
||||
h.resetHTTPMaskClient()
|
||||
}
|
||||
|
||||
func (h *Outbound) dialAndHandshake(ctx context.Context, cfg *sudoku.ProtocolConfig) (net.Conn, error) {
|
||||
handshakeCfg := *cfg
|
||||
if !handshakeCfg.DisableHTTPMask && httpTunnelModeEnabled(handshakeCfg.HTTPMaskMode) {
|
||||
handshakeCfg.DisableHTTPMask = true
|
||||
}
|
||||
|
||||
upgrade := func(raw net.Conn) (net.Conn, error) {
|
||||
return sudoku.ClientHandshake(raw, &handshakeCfg)
|
||||
}
|
||||
|
||||
var c net.Conn
|
||||
var err error
|
||||
var handshakeDone bool
|
||||
|
||||
if !cfg.DisableHTTPMask && httpTunnelModeEnabled(cfg.HTTPMaskMode) {
|
||||
muxMode := normalizeHTTPMaskMultiplex(cfg.HTTPMaskMultiplex)
|
||||
if muxMode == "auto" && strings.ToLower(strings.TrimSpace(cfg.HTTPMaskMode)) != "ws" {
|
||||
if client, cerr := h.getOrCreateHTTPMaskClient(cfg); cerr == nil && client != nil {
|
||||
c, err = client.DialTunnel(ctx, httpmask.TunnelDialOptions{
|
||||
Mode: cfg.HTTPMaskMode,
|
||||
TLSConfig: h.httpMaskTLSConfig(),
|
||||
HostOverride: cfg.HTTPMaskHost,
|
||||
PathRoot: cfg.HTTPMaskPathRoot,
|
||||
AuthKey: sudoku.ClientAEADSeed(cfg.Key),
|
||||
Upgrade: upgrade,
|
||||
Multiplex: cfg.HTTPMaskMultiplex,
|
||||
DialContext: h.dialRaw,
|
||||
})
|
||||
if err != nil {
|
||||
h.resetHTTPMaskClient()
|
||||
}
|
||||
}
|
||||
}
|
||||
if c == nil && err == nil {
|
||||
c, err = sudoku.DialHTTPMaskTunnel(ctx, cfg.ServerAddress, cfg, h.dialRaw, upgrade)
|
||||
}
|
||||
if err == nil && c != nil {
|
||||
handshakeDone = true
|
||||
}
|
||||
}
|
||||
if c == nil && err == nil {
|
||||
c, err = h.dialer.DialContext(ctx, N.NetworkTCP, M.ParseSocksaddr(cfg.ServerAddress))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "connect to ", cfg.ServerAddress)
|
||||
}
|
||||
|
||||
if !handshakeDone {
|
||||
c, err = sudoku.ClientHandshake(c, &handshakeCfg)
|
||||
if err != nil {
|
||||
common.Close(c)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (h *Outbound) dialRaw(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return h.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
}
|
||||
|
||||
func (h *Outbound) httpMaskTLSConfig() httpmask.TLSClientConfig {
|
||||
if h.tlsConfig == nil {
|
||||
return nil
|
||||
}
|
||||
return tlsConfigAdapter{h.tlsConfig}
|
||||
}
|
||||
|
||||
type tlsConfigAdapter struct {
|
||||
config tls.Config
|
||||
}
|
||||
|
||||
func (a tlsConfigAdapter) Client(conn net.Conn) (net.Conn, error) {
|
||||
return a.config.Client(conn)
|
||||
}
|
||||
|
||||
func (h *Outbound) dialMultiplex(ctx context.Context, targetAddress string) (net.Conn, error) {
|
||||
for attempt := 0; attempt < 2; attempt++ {
|
||||
client, err := h.getOrCreateMuxClient(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stream, err := client.Dial(ctx, targetAddress)
|
||||
if err != nil {
|
||||
h.resetMuxClient()
|
||||
continue
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
return nil, fmt.Errorf("multiplex open stream failed")
|
||||
}
|
||||
|
||||
func (h *Outbound) getOrCreateMuxClient(ctx context.Context) (*sudoku.MultiplexClient, error) {
|
||||
h.muxMu.Lock()
|
||||
defer h.muxMu.Unlock()
|
||||
|
||||
if h.muxClient != nil && !h.muxClient.IsClosed() {
|
||||
return h.muxClient, nil
|
||||
}
|
||||
|
||||
baseCfg := h.baseConf
|
||||
baseConn, err := h.dialAndHandshake(ctx, &baseCfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client, err := sudoku.StartMultiplexClient(baseConn)
|
||||
if err != nil {
|
||||
baseConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
h.muxClient = client
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (h *Outbound) resetMuxClient() {
|
||||
h.muxMu.Lock()
|
||||
defer h.muxMu.Unlock()
|
||||
if h.muxClient != nil {
|
||||
h.muxClient.Close()
|
||||
h.muxClient = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Outbound) getOrCreateHTTPMaskClient(cfg *sudoku.ProtocolConfig) (*httpmask.TunnelClient, error) {
|
||||
key := cfg.ServerAddress + "|" + fmt.Sprint(h.tlsConfig != nil) + "|" + strings.TrimSpace(cfg.HTTPMaskHost)
|
||||
|
||||
h.httpMaskMu.Lock()
|
||||
if h.httpMaskClient != nil && h.httpMaskKey == key {
|
||||
client := h.httpMaskClient
|
||||
h.httpMaskMu.Unlock()
|
||||
return client, nil
|
||||
}
|
||||
h.httpMaskMu.Unlock()
|
||||
|
||||
client, err := httpmask.NewTunnelClient(cfg.ServerAddress, httpmask.TunnelClientOptions{
|
||||
TLSConfig: h.httpMaskTLSConfig(),
|
||||
HostOverride: cfg.HTTPMaskHost,
|
||||
DialContext: h.dialRaw,
|
||||
MaxIdleConns: 32,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
h.httpMaskMu.Lock()
|
||||
defer h.httpMaskMu.Unlock()
|
||||
if h.httpMaskClient != nil && h.httpMaskKey == key {
|
||||
client.CloseIdleConnections()
|
||||
return h.httpMaskClient, nil
|
||||
}
|
||||
if h.httpMaskClient != nil {
|
||||
h.httpMaskClient.CloseIdleConnections()
|
||||
}
|
||||
h.httpMaskClient = client
|
||||
h.httpMaskKey = key
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (h *Outbound) resetHTTPMaskClient() {
|
||||
h.httpMaskMu.Lock()
|
||||
defer h.httpMaskMu.Unlock()
|
||||
if h.httpMaskClient != nil {
|
||||
h.httpMaskClient.CloseIdleConnections()
|
||||
h.httpMaskClient = nil
|
||||
h.httpMaskKey = ""
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeHTTPMaskMultiplex(mode string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case "", "off":
|
||||
return "off"
|
||||
case "auto":
|
||||
return "auto"
|
||||
case "on":
|
||||
return "on"
|
||||
default:
|
||||
return "off"
|
||||
}
|
||||
}
|
||||
|
||||
func httpTunnelModeEnabled(mode string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(mode)) {
|
||||
case "stream", "poll", "auto", "ws":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
190
protocol/trusttunnel/inbound.go
Normal file
190
protocol/trusttunnel/inbound.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package trusttunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/inbound"
|
||||
"github.com/sagernet/sing-box/common/listener"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/transport/trusttunnel"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/auth"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"golang.org/x/net/http2/h2c"
|
||||
)
|
||||
|
||||
func RegisterInbound(registry *inbound.Registry) {
|
||||
inbound.Register[option.TrustTunnelInboundOptions](registry, C.TypeTrustTunnel, NewInbound)
|
||||
}
|
||||
|
||||
type Inbound struct {
|
||||
inbound.Adapter
|
||||
ctx context.Context
|
||||
router adapter.Router
|
||||
logger logger.ContextLogger
|
||||
listener *listener.Listener
|
||||
tlsConfig tls.ServerConfig
|
||||
service *trusttunnel.Service
|
||||
httpServer *http.Server
|
||||
quicService *trusttunnel.QUICService
|
||||
network []string
|
||||
}
|
||||
|
||||
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TrustTunnelInboundOptions) (adapter.Inbound, error) {
|
||||
if options.TLS == nil || !options.TLS.Enabled {
|
||||
return nil, C.ErrTLSRequired
|
||||
}
|
||||
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
networkList := options.Network.Build()
|
||||
if len(networkList) == 0 {
|
||||
networkList = []string{N.NetworkTCP}
|
||||
}
|
||||
inbound := &Inbound{
|
||||
Adapter: inbound.NewAdapter(C.TypeTrustTunnel, tag),
|
||||
ctx: ctx,
|
||||
router: router,
|
||||
logger: logger,
|
||||
tlsConfig: tlsConfig,
|
||||
network: networkList,
|
||||
listener: listener.New(listener.Options{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
Listen: options.ListenOptions,
|
||||
}),
|
||||
}
|
||||
service := trusttunnel.NewService(trusttunnel.ServiceOptions{
|
||||
Ctx: ctx,
|
||||
Logger: logger,
|
||||
Handler: (*inboundHandler)(inbound),
|
||||
})
|
||||
userMap := make(map[string]string, len(options.Users))
|
||||
for _, u := range options.Users {
|
||||
userMap[u.Name] = u.Password
|
||||
}
|
||||
service.UpdateUsers(userMap)
|
||||
inbound.service = service
|
||||
if common.Contains(networkList, N.NetworkUDP) {
|
||||
inbound.quicService = trusttunnel.NewQUICService(service, options.CongestionController, options.CWND, options.BBRProfile)
|
||||
}
|
||||
return inbound, nil
|
||||
}
|
||||
|
||||
func (h *Inbound) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
if h.tlsConfig != nil {
|
||||
err := h.tlsConfig.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if common.Contains(h.network, N.NetworkTCP) {
|
||||
tcpListener, err := h.listener.ListenTCP()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
h.httpServer = &http.Server{
|
||||
Handler: h2c.NewHandler(h.service, &http2.Server{}),
|
||||
BaseContext: func(net.Listener) context.Context {
|
||||
return h.ctx
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
var l net.Listener = tcpListener
|
||||
if h.tlsConfig != nil {
|
||||
if len(h.tlsConfig.NextProtos()) == 0 {
|
||||
h.tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
|
||||
} else if !common.Contains(h.tlsConfig.NextProtos(), http2.NextProtoTLS) {
|
||||
h.tlsConfig.SetNextProtos(append([]string{http2.NextProtoTLS}, h.tlsConfig.NextProtos()...))
|
||||
}
|
||||
l = aTLS.NewListener(tcpListener, h.tlsConfig)
|
||||
}
|
||||
sErr := h.httpServer.Serve(l)
|
||||
if sErr != nil && !errors.Is(sErr, http.ErrServerClosed) {
|
||||
h.logger.Error("HTTP server error: ", sErr)
|
||||
}
|
||||
}()
|
||||
}
|
||||
if common.Contains(h.network, N.NetworkUDP) {
|
||||
udpConn, err := h.listener.ListenUDP()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = h.quicService.Start(h.ctx, udpConn, h.tlsConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Inbound) Close() error {
|
||||
return common.Close(
|
||||
h.listener,
|
||||
common.PtrOrNil(h.httpServer),
|
||||
h.quicService,
|
||||
h.tlsConfig,
|
||||
)
|
||||
}
|
||||
|
||||
func (h *Inbound) UpdateUsers(users []option.TrustTunnelUser) {
|
||||
userMap := make(map[string]string, len(users))
|
||||
for _, u := range users {
|
||||
userMap[u.Name] = u.Password
|
||||
}
|
||||
h.service.UpdateUsers(userMap)
|
||||
}
|
||||
|
||||
type inboundHandler Inbound
|
||||
|
||||
func (h *inboundHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error {
|
||||
var inboundCtx adapter.InboundContext
|
||||
inboundCtx.Inbound = h.Tag()
|
||||
inboundCtx.InboundType = h.Type()
|
||||
//nolint:staticcheck
|
||||
inboundCtx.InboundDetour = h.listener.ListenOptions().Detour
|
||||
inboundCtx.Source = metadata.Source
|
||||
inboundCtx.Destination = metadata.Destination
|
||||
if userName, _ := auth.UserFromContext[string](ctx); userName != "" {
|
||||
inboundCtx.User = userName
|
||||
h.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", inboundCtx.Destination)
|
||||
} else {
|
||||
h.logger.InfoContext(ctx, "inbound connection to ", inboundCtx.Destination)
|
||||
}
|
||||
h.router.RouteConnectionEx(ctx, conn, inboundCtx, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *inboundHandler) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata M.Metadata) error {
|
||||
var inboundCtx adapter.InboundContext
|
||||
inboundCtx.Inbound = h.Tag()
|
||||
inboundCtx.InboundType = h.Type()
|
||||
//nolint:staticcheck
|
||||
inboundCtx.InboundDetour = h.listener.ListenOptions().Detour
|
||||
inboundCtx.Source = metadata.Source
|
||||
inboundCtx.Destination = metadata.Destination
|
||||
if userName, _ := auth.UserFromContext[string](ctx); userName != "" {
|
||||
inboundCtx.User = userName
|
||||
h.logger.InfoContext(ctx, "[", userName, "] inbound packet connection to ", inboundCtx.Destination)
|
||||
} else {
|
||||
h.logger.InfoContext(ctx, "inbound packet connection to ", inboundCtx.Destination)
|
||||
}
|
||||
h.router.RoutePacketConnectionEx(ctx, conn, inboundCtx, nil)
|
||||
return nil
|
||||
}
|
||||
118
protocol/trusttunnel/outbound.go
Normal file
118
protocol/trusttunnel/outbound.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package trusttunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/adapter/outbound"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing-box/transport/trusttunnel"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/bufio"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
func RegisterOutbound(registry *outbound.Registry) {
|
||||
outbound.Register[option.TrustTunnelOutboundOptions](registry, C.TypeTrustTunnel, NewOutbound)
|
||||
}
|
||||
|
||||
type Outbound struct {
|
||||
outbound.Adapter
|
||||
logger logger.ContextLogger
|
||||
client trusttunnel.Dialer
|
||||
}
|
||||
|
||||
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TrustTunnelOutboundOptions) (adapter.Outbound, error) {
|
||||
if options.TLS == nil || !options.TLS.Enabled {
|
||||
return nil, C.ErrTLSRequired
|
||||
}
|
||||
outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serverAddr := options.ServerOptions.Build()
|
||||
networkList := options.Network.Build()
|
||||
clientOpts := trusttunnel.ClientOptions{
|
||||
Server: serverAddr,
|
||||
Username: options.Username,
|
||||
Password: options.Password,
|
||||
QUIC: options.QUIC,
|
||||
CongestionControl: options.CongestionController,
|
||||
CWND: options.CWND,
|
||||
BBRProfile: options.BBRProfile,
|
||||
HealthCheck: options.HealthCheck,
|
||||
}
|
||||
if options.QUIC {
|
||||
tlsConfig, err := tls.NewClient(ctx, logger, options.Server, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(tlsConfig.NextProtos()) == 0 {
|
||||
tlsConfig.SetNextProtos([]string{"h3"})
|
||||
}
|
||||
clientOpts.QUICDialer = outboundDialer
|
||||
clientOpts.QUICTLSConfig = tlsConfig
|
||||
} else {
|
||||
tlsConfig, err := tls.NewClient(ctx, logger, options.Server, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(tlsConfig.NextProtos()) == 0 {
|
||||
tlsConfig.SetNextProtos([]string{http2.NextProtoTLS})
|
||||
}
|
||||
clientOpts.TLSDialer = tls.NewDialer(outboundDialer, tlsConfig)
|
||||
}
|
||||
var client trusttunnel.Dialer
|
||||
if options.Multiplex != nil && options.Multiplex.Enabled {
|
||||
clientOpts.MaxConnections = options.Multiplex.MaxConnections
|
||||
clientOpts.MinStreams = options.Multiplex.MinStreams
|
||||
clientOpts.MaxStreams = options.Multiplex.MaxStreams
|
||||
client, err = trusttunnel.NewMultiplexClient(ctx, clientOpts)
|
||||
} else {
|
||||
client, err = trusttunnel.NewClient(ctx, clientOpts)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Outbound{
|
||||
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeTrustTunnel, tag, networkList, options.DialerOptions),
|
||||
logger: logger,
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
|
||||
switch N.NetworkName(network) {
|
||||
case N.NetworkTCP:
|
||||
h.logger.InfoContext(ctx, "outbound connection to ", destination)
|
||||
return h.client.Dial(ctx, destination.String())
|
||||
case N.NetworkUDP:
|
||||
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
conn, err := h.client.ListenPacket(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bufio.NewBindPacketConn(conn, destination), nil
|
||||
default:
|
||||
return nil, E.New("unsupported network: ", network)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
|
||||
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
|
||||
return h.client.ListenPacket(ctx)
|
||||
}
|
||||
|
||||
func (h *Outbound) Close() error {
|
||||
return h.client.Close()
|
||||
}
|
||||
Reference in New Issue
Block a user