From 3312b8da50fb2fc073845950eb640577838a077e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 23 Apr 2026 01:33:38 +0800 Subject: [PATCH] Clean up DNS transports --- adapter/dns.go | 2 + dns/transport/base.go | 145 -------- dns/transport/conn_pool.go | 547 ++++++++++++++++++++++++++++ dns/transport/connector.go | 321 ---------------- dns/transport/connector_test.go | 407 --------------------- dns/transport/quic/quic.go | 100 +++-- dns/transport/tls.go | 90 ++--- dns/transport/udp.go | 80 ++-- protocol/tailscale/dns_transport.go | 75 +++- 9 files changed, 736 insertions(+), 1031 deletions(-) delete mode 100644 dns/transport/base.go create mode 100644 dns/transport/conn_pool.go delete mode 100644 dns/transport/connector.go delete mode 100644 dns/transport/connector_test.go diff --git a/adapter/dns.go b/adapter/dns.go index 8f065e2e..23fbc9de 100644 --- a/adapter/dns.go +++ b/adapter/dns.go @@ -68,6 +68,8 @@ type DNSTransport interface { Type() string Tag() string Dependencies() []string + // Reset closes the transport's existing connections so later requests use fresh connections. + // Exchanges that are currently using those connections may fail. Reset() Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) } diff --git a/dns/transport/base.go b/dns/transport/base.go deleted file mode 100644 index 06e41fd0..00000000 --- a/dns/transport/base.go +++ /dev/null @@ -1,145 +0,0 @@ -package transport - -import ( - "context" - "os" - "sync" - - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-box/dns" - E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" -) - -type TransportState int - -const ( - StateNew TransportState = iota - StateStarted - StateClosing - StateClosed -) - -var ( - ErrTransportClosed = os.ErrClosed - ErrConnectionReset = E.New("connection reset") -) - -type BaseTransport struct { - dns.TransportAdapter - Logger logger.ContextLogger - - mutex sync.Mutex - state TransportState - inFlight int32 - queriesComplete chan struct{} - closeCtx context.Context - closeCancel context.CancelFunc -} - -func NewBaseTransport(adapter dns.TransportAdapter, logger logger.ContextLogger) *BaseTransport { - ctx, cancel := context.WithCancel(context.Background()) - return &BaseTransport{ - TransportAdapter: adapter, - Logger: logger, - state: StateNew, - closeCtx: ctx, - closeCancel: cancel, - } -} - -func (t *BaseTransport) State() TransportState { - t.mutex.Lock() - defer t.mutex.Unlock() - return t.state -} - -func (t *BaseTransport) SetStarted() error { - t.mutex.Lock() - defer t.mutex.Unlock() - switch t.state { - case StateNew: - t.state = StateStarted - return nil - case StateStarted: - return nil - default: - return ErrTransportClosed - } -} - -func (t *BaseTransport) BeginQuery() bool { - t.mutex.Lock() - defer t.mutex.Unlock() - if t.state != StateStarted { - return false - } - t.inFlight++ - return true -} - -func (t *BaseTransport) EndQuery() { - t.mutex.Lock() - if t.inFlight > 0 { - t.inFlight-- - } - if t.inFlight == 0 && t.queriesComplete != nil { - close(t.queriesComplete) - t.queriesComplete = nil - } - t.mutex.Unlock() -} - -func (t *BaseTransport) CloseContext() context.Context { - return t.closeCtx -} - -func (t *BaseTransport) Shutdown(ctx context.Context) error { - t.mutex.Lock() - - if t.state >= StateClosing { - t.mutex.Unlock() - return nil - } - - if t.state == StateNew { - t.state = StateClosed - t.mutex.Unlock() - t.closeCancel() - return nil - } - - t.state = StateClosing - - if t.inFlight == 0 { - t.state = StateClosed - t.mutex.Unlock() - t.closeCancel() - return nil - } - - t.queriesComplete = make(chan struct{}) - queriesComplete := t.queriesComplete - t.mutex.Unlock() - - t.closeCancel() - - select { - case <-queriesComplete: - t.mutex.Lock() - t.state = StateClosed - t.mutex.Unlock() - return nil - case <-ctx.Done(): - t.mutex.Lock() - t.state = StateClosed - t.mutex.Unlock() - return ctx.Err() - } -} - -func (t *BaseTransport) Close() error { - ctx, cancel := context.WithTimeout(context.Background(), C.TCPTimeout) - defer cancel() - return t.Shutdown(ctx) -} diff --git a/dns/transport/conn_pool.go b/dns/transport/conn_pool.go new file mode 100644 index 00000000..6161e9bd --- /dev/null +++ b/dns/transport/conn_pool.go @@ -0,0 +1,547 @@ +package transport + +import ( + "context" + "net" + "sync" + "time" + + "github.com/sagernet/sing/common/x/list" +) + +type ConnPoolMode int + +const ( + ConnPoolSingle ConnPoolMode = iota + ConnPoolOrdered +) + +type ConnPoolOptions[T comparable] struct { + Mode ConnPoolMode + IsAlive func(T) bool + Close func(T, error) +} + +type ConnPool[T comparable] struct { + options ConnPoolOptions[T] + + access sync.Mutex + closed bool + state *connPoolState[T] +} + +type connPoolState[T comparable] struct { + ctx context.Context + cancel context.CancelCauseFunc + + all map[T]struct{} + + idle list.List[T] + idleElements map[T]*list.Element[T] + + shared T + hasShared bool + sharedClaimed bool + sharedCtx context.Context + sharedCancel context.CancelCauseFunc + + connecting *connPoolConnect[T] +} + +type connPoolConnect[T comparable] struct { + done chan struct{} + err error +} + +type connPoolDialContext struct { + context.Context + parent context.Context +} + +func (c connPoolDialContext) Deadline() (time.Time, bool) { + return c.parent.Deadline() +} + +func (c connPoolDialContext) Value(key any) any { + return c.parent.Value(key) +} + +func NewConnPool[T comparable](options ConnPoolOptions[T]) *ConnPool[T] { + return &ConnPool[T]{ + options: options, + state: newConnPoolState[T](options.Mode), + } +} + +func newConnPoolState[T comparable](mode ConnPoolMode) *connPoolState[T] { + ctx, cancel := context.WithCancelCause(context.Background()) + state := &connPoolState[T]{ + ctx: ctx, + cancel: cancel, + all: make(map[T]struct{}), + } + if mode == ConnPoolOrdered { + state.idleElements = make(map[T]*list.Element[T]) + } + return state +} + +func (p *ConnPool[T]) Acquire(ctx context.Context, dial func(context.Context) (T, error)) (T, bool, error) { + switch p.options.Mode { + case ConnPoolSingle: + conn, _, created, err := p.acquireShared(ctx, dial) + return conn, created, err + case ConnPoolOrdered: + return p.acquireOrdered(ctx, dial) + default: + var zero T + return zero, false, net.ErrClosed + } +} + +func (p *ConnPool[T]) AcquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) { + if p.options.Mode != ConnPoolSingle { + var zero T + return zero, nil, false, net.ErrClosed + } + return p.acquireShared(ctx, dial) +} + +func (p *ConnPool[T]) Release(conn T, reuse bool) { + var ( + closeConn bool + closeErr error + ) + + p.access.Lock() + if p.closed || p.state == nil { + closeConn = true + closeErr = net.ErrClosed + p.access.Unlock() + if closeConn { + p.options.Close(conn, closeErr) + } + return + } + + currentState := p.state + _, tracked := currentState.all[conn] + if !tracked { + closeConn = true + closeErr = p.closeCause(currentState) + p.access.Unlock() + if closeConn { + p.options.Close(conn, closeErr) + } + return + } + + if !reuse || !p.options.IsAlive(conn) { + delete(currentState.all, conn) + switch p.options.Mode { + case ConnPoolSingle: + if currentState.hasShared && currentState.shared == conn { + var zero T + currentState.shared = zero + currentState.hasShared = false + currentState.sharedClaimed = false + currentState.sharedCtx = nil + if currentState.sharedCancel != nil { + currentState.sharedCancel(net.ErrClosed) + currentState.sharedCancel = nil + } + } + case ConnPoolOrdered: + if element, loaded := currentState.idleElements[conn]; loaded { + currentState.idle.Remove(element) + delete(currentState.idleElements, conn) + } + } + closeConn = true + closeErr = net.ErrClosed + p.access.Unlock() + if closeConn { + p.options.Close(conn, closeErr) + } + return + } + + if p.options.Mode == ConnPoolOrdered { + if _, loaded := currentState.idleElements[conn]; !loaded { + currentState.idleElements[conn] = currentState.idle.PushBack(conn) + } + } + p.access.Unlock() +} + +func (p *ConnPool[T]) Invalidate(conn T, cause error) { + p.access.Lock() + if p.closed || p.state == nil { + p.access.Unlock() + p.options.Close(conn, cause) + return + } + + currentState := p.state + _, tracked := currentState.all[conn] + if !tracked { + p.access.Unlock() + return + } + + delete(currentState.all, conn) + switch p.options.Mode { + case ConnPoolSingle: + if currentState.hasShared && currentState.shared == conn { + var zero T + currentState.shared = zero + currentState.hasShared = false + currentState.sharedClaimed = false + currentState.sharedCtx = nil + if currentState.sharedCancel != nil { + currentState.sharedCancel(cause) + currentState.sharedCancel = nil + } + } + case ConnPoolOrdered: + if element, loaded := currentState.idleElements[conn]; loaded { + currentState.idle.Remove(element) + delete(currentState.idleElements, conn) + } + } + p.access.Unlock() + + p.options.Close(conn, cause) +} + +func (p *ConnPool[T]) Reset() { + p.access.Lock() + if p.closed { + p.access.Unlock() + return + } + + oldState := p.state + p.state = newConnPoolState[T](p.options.Mode) + p.access.Unlock() + + p.closeState(oldState, net.ErrClosed) +} + +func (p *ConnPool[T]) Close() error { + p.access.Lock() + if p.closed { + p.access.Unlock() + return nil + } + + p.closed = true + oldState := p.state + p.state = nil + p.access.Unlock() + + p.closeState(oldState, net.ErrClosed) + return nil +} + +func (p *ConnPool[T]) acquireOrdered(ctx context.Context, dial func(context.Context) (T, error)) (T, bool, error) { + var zero T + for { + var ( + staleConn T + hasStale bool + ) + + p.access.Lock() + if p.closed { + p.access.Unlock() + return zero, false, net.ErrClosed + } + + currentState := p.state + if element := currentState.idle.Front(); element != nil { + conn := currentState.idle.Remove(element) + delete(currentState.idleElements, conn) + if p.options.IsAlive(conn) { + p.access.Unlock() + return conn, false, nil + } + delete(currentState.all, conn) + staleConn = conn + hasStale = true + } + p.access.Unlock() + + if hasStale { + p.options.Close(staleConn, net.ErrClosed) + continue + } + + conn, err := p.dial(ctx, currentState, dial) + if err != nil { + return zero, false, err + } + + p.access.Lock() + if p.closed { + p.access.Unlock() + p.options.Close(conn, net.ErrClosed) + return zero, false, net.ErrClosed + } + if p.state != currentState { + cause := p.closeCause(currentState) + p.access.Unlock() + p.options.Close(conn, cause) + return zero, false, cause + } + currentState.all[conn] = struct{}{} + p.access.Unlock() + return conn, true, nil + } +} + +func (p *ConnPool[T]) acquireShared(ctx context.Context, dial func(context.Context) (T, error)) (T, context.Context, bool, error) { + var zero T + for { + var ( + staleConn T + hasStale bool + state *connPoolConnect[T] + current *connPoolState[T] + startDial bool + ) + + p.access.Lock() + if p.closed { + p.access.Unlock() + return zero, nil, false, net.ErrClosed + } + + current = p.state + if current.hasShared { + conn := current.shared + if p.options.IsAlive(conn) { + created := !current.sharedClaimed + current.sharedClaimed = true + connCtx := current.sharedCtx + p.access.Unlock() + return conn, connCtx, created, nil + } + delete(current.all, conn) + var zeroConn T + current.shared = zeroConn + current.hasShared = false + current.sharedClaimed = false + current.sharedCtx = nil + if current.sharedCancel != nil { + current.sharedCancel(net.ErrClosed) + current.sharedCancel = nil + } + staleConn = conn + hasStale = true + p.access.Unlock() + p.options.Close(staleConn, net.ErrClosed) + continue + } + + if current.connecting == nil { + current.connecting = &connPoolConnect[T]{ + done: make(chan struct{}), + } + startDial = true + } + state = current.connecting + p.access.Unlock() + + if hasStale { + continue + } + if startDial { + go p.connectSingle(current, state, ctx, dial) + } + + select { + case <-state.done: + conn, connCtx, created, retry, err := p.collectShared(current, state, startDial) + if retry { + continue + } + return conn, connCtx, created, err + case <-ctx.Done(): + return zero, nil, false, ctx.Err() + case <-current.ctx.Done(): + p.access.Lock() + closed := p.closed + p.access.Unlock() + if closed { + return zero, nil, false, net.ErrClosed + } + } + } +} + +func (p *ConnPool[T]) connectSingle(current *connPoolState[T], state *connPoolConnect[T], ctx context.Context, dial func(context.Context) (T, error)) { + conn, err := p.dial(ctx, current, dial) + if err != nil { + p.access.Lock() + if current.connecting == state { + current.connecting = nil + } + state.err = err + p.access.Unlock() + close(state.done) + return + } + + var closeErr error + + p.access.Lock() + if current.connecting == state { + current.connecting = nil + } + if p.closed { + closeErr = net.ErrClosed + state.err = closeErr + } else if p.state != current { + closeErr = p.closeCause(current) + state.err = closeErr + } else { + sharedCtx, sharedCancel := context.WithCancelCause(current.ctx) + current.shared = conn + current.hasShared = true + current.sharedClaimed = false + current.sharedCtx = sharedCtx + current.sharedCancel = sharedCancel + current.all[conn] = struct{}{} + } + p.access.Unlock() + + if closeErr != nil { + p.options.Close(conn, closeErr) + } + close(state.done) +} + +func (p *ConnPool[T]) collectShared(current *connPoolState[T], state *connPoolConnect[T], startDial bool) (T, context.Context, bool, bool, error) { + var zero T + + p.access.Lock() + if state.err != nil { + err := state.err + p.access.Unlock() + if startDial { + return zero, nil, false, false, err + } + return zero, nil, false, true, nil + } + if p.closed { + p.access.Unlock() + return zero, nil, false, false, net.ErrClosed + } + if p.state != current { + cause := p.closeCause(current) + p.access.Unlock() + return zero, nil, false, false, cause + } + if !current.hasShared { + p.access.Unlock() + return zero, nil, false, true, nil + } + + conn := current.shared + if !p.options.IsAlive(conn) { + delete(current.all, conn) + var zeroConn T + current.shared = zeroConn + current.hasShared = false + current.sharedClaimed = false + current.sharedCtx = nil + if current.sharedCancel != nil { + current.sharedCancel(net.ErrClosed) + current.sharedCancel = nil + } + p.access.Unlock() + p.options.Close(conn, net.ErrClosed) + return zero, nil, false, true, nil + } + + created := !current.sharedClaimed + current.sharedClaimed = true + connCtx := current.sharedCtx + p.access.Unlock() + return conn, connCtx, created, false, nil +} + +func (p *ConnPool[T]) dial(ctx context.Context, current *connPoolState[T], dial func(context.Context) (T, error)) (T, error) { + var zero T + + if err := ctx.Err(); err != nil { + return zero, err + } + if cause := context.Cause(current.ctx); cause != nil { + return zero, cause + } + + dialCtx, cancel := context.WithCancelCause(current.ctx) + var ( + stateAccess sync.Mutex + dialComplete bool + ) + stopCancel := context.AfterFunc(ctx, func() { + stateAccess.Lock() + if !dialComplete { + cancel(context.Cause(ctx)) + } + stateAccess.Unlock() + }) + + select { + case <-ctx.Done(): + stateAccess.Lock() + dialComplete = true + stateAccess.Unlock() + stopCancel() + cancel(context.Cause(ctx)) + return zero, ctx.Err() + default: + } + + conn, err := dial(connPoolDialContext{ + Context: dialCtx, + parent: ctx, + }) + stateAccess.Lock() + dialComplete = true + stateAccess.Unlock() + stopCancel() + if err != nil { + if cause := context.Cause(dialCtx); cause != nil { + return zero, cause + } + return zero, err + } + if cause := context.Cause(dialCtx); cause != nil { + p.options.Close(conn, cause) + return zero, cause + } + return conn, nil +} + +func (p *ConnPool[T]) closeState(state *connPoolState[T], cause error) { + if state == nil { + return + } + + state.cancel(cause) + if state.sharedCancel != nil { + state.sharedCancel(cause) + } + for conn := range state.all { + p.options.Close(conn, cause) + } +} + +func (p *ConnPool[T]) closeCause(state *connPoolState[T]) error { + _ = state + return net.ErrClosed +} diff --git a/dns/transport/connector.go b/dns/transport/connector.go deleted file mode 100644 index 3a87456d..00000000 --- a/dns/transport/connector.go +++ /dev/null @@ -1,321 +0,0 @@ -package transport - -import ( - "context" - "net" - "sync" - "time" - - E "github.com/sagernet/sing/common/exceptions" -) - -type ConnectorCallbacks[T any] struct { - IsClosed func(connection T) bool - Close func(connection T) - Reset func(connection T) -} - -type Connector[T any] struct { - dial func(ctx context.Context) (T, error) - callbacks ConnectorCallbacks[T] - - access sync.Mutex - connection T - hasConnection bool - connectionCancel context.CancelFunc - connecting chan struct{} - - closeCtx context.Context - closed bool -} - -func NewConnector[T any](closeCtx context.Context, dial func(context.Context) (T, error), callbacks ConnectorCallbacks[T]) *Connector[T] { - return &Connector[T]{ - dial: dial, - callbacks: callbacks, - closeCtx: closeCtx, - } -} - -func NewSingleflightConnector(closeCtx context.Context, dial func(context.Context) (*Connection, error)) *Connector[*Connection] { - return NewConnector(closeCtx, dial, ConnectorCallbacks[*Connection]{ - IsClosed: func(connection *Connection) bool { - return connection.IsClosed() - }, - Close: func(connection *Connection) { - connection.CloseWithError(ErrTransportClosed) - }, - Reset: func(connection *Connection) { - connection.CloseWithError(ErrConnectionReset) - }, - }) -} - -type contextKeyConnecting struct{} - -var errRecursiveConnectorDial = E.New("recursive connector dial") - -type connectorDialResult[T any] struct { - connection T - cancel context.CancelFunc - err error -} - -func (c *Connector[T]) Get(ctx context.Context) (T, error) { - var zero T - for { - c.access.Lock() - - if c.closed { - c.access.Unlock() - return zero, ErrTransportClosed - } - - if c.hasConnection && !c.callbacks.IsClosed(c.connection) { - connection := c.connection - c.access.Unlock() - return connection, nil - } - - c.hasConnection = false - if c.connectionCancel != nil { - c.connectionCancel() - c.connectionCancel = nil - } - if isRecursiveConnectorDial(ctx, c) { - c.access.Unlock() - return zero, errRecursiveConnectorDial - } - - if c.connecting != nil { - connecting := c.connecting - c.access.Unlock() - - select { - case <-connecting: - continue - case <-ctx.Done(): - return zero, ctx.Err() - case <-c.closeCtx.Done(): - return zero, ErrTransportClosed - } - } - - if err := ctx.Err(); err != nil { - c.access.Unlock() - return zero, err - } - - connecting := make(chan struct{}) - c.connecting = connecting - dialContext := context.WithValue(ctx, contextKeyConnecting{}, c) - dialResult := make(chan connectorDialResult[T], 1) - c.access.Unlock() - - go func() { - connection, cancel, err := c.dialWithCancellation(dialContext) - dialResult <- connectorDialResult[T]{ - connection: connection, - cancel: cancel, - err: err, - } - }() - - select { - case result := <-dialResult: - return c.completeDial(ctx, connecting, result) - case <-ctx.Done(): - go func() { - result := <-dialResult - _, _ = c.completeDial(ctx, connecting, result) - }() - return zero, ctx.Err() - case <-c.closeCtx.Done(): - go func() { - result := <-dialResult - _, _ = c.completeDial(ctx, connecting, result) - }() - return zero, ErrTransportClosed - } - } -} - -func isRecursiveConnectorDial[T any](ctx context.Context, connector *Connector[T]) bool { - dialConnector, loaded := ctx.Value(contextKeyConnecting{}).(*Connector[T]) - return loaded && dialConnector == connector -} - -func (c *Connector[T]) completeDial(ctx context.Context, connecting chan struct{}, result connectorDialResult[T]) (T, error) { - var zero T - - c.access.Lock() - defer c.access.Unlock() - defer func() { - if c.connecting == connecting { - c.connecting = nil - } - close(connecting) - }() - - if result.err != nil { - return zero, result.err - } - if c.closed || c.closeCtx.Err() != nil { - result.cancel() - c.callbacks.Close(result.connection) - return zero, ErrTransportClosed - } - if err := ctx.Err(); err != nil { - result.cancel() - c.callbacks.Close(result.connection) - return zero, err - } - - c.connection = result.connection - c.hasConnection = true - c.connectionCancel = result.cancel - return c.connection, nil -} - -func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, context.CancelFunc, error) { - var zero T - if err := ctx.Err(); err != nil { - return zero, nil, err - } - connCtx, cancel := context.WithCancel(c.closeCtx) - - var ( - stateAccess sync.Mutex - dialComplete bool - ) - stopCancel := context.AfterFunc(ctx, func() { - stateAccess.Lock() - if !dialComplete { - cancel() - } - stateAccess.Unlock() - }) - select { - case <-ctx.Done(): - stateAccess.Lock() - dialComplete = true - stateAccess.Unlock() - stopCancel() - cancel() - return zero, nil, ctx.Err() - default: - } - - connection, err := c.dial(valueContext{connCtx, ctx}) - stateAccess.Lock() - dialComplete = true - stateAccess.Unlock() - stopCancel() - if err != nil { - cancel() - return zero, nil, err - } - return connection, cancel, nil -} - -type valueContext struct { - context.Context - parent context.Context -} - -func (v valueContext) Value(key any) any { - return v.parent.Value(key) -} - -func (v valueContext) Deadline() (time.Time, bool) { - return v.parent.Deadline() -} - -func (c *Connector[T]) Close() error { - c.access.Lock() - defer c.access.Unlock() - - if c.closed { - return nil - } - c.closed = true - - if c.connectionCancel != nil { - c.connectionCancel() - c.connectionCancel = nil - } - if c.hasConnection { - c.callbacks.Close(c.connection) - c.hasConnection = false - } - - return nil -} - -func (c *Connector[T]) Reset() { - c.access.Lock() - defer c.access.Unlock() - - if c.connectionCancel != nil { - c.connectionCancel() - c.connectionCancel = nil - } - if c.hasConnection { - c.callbacks.Reset(c.connection) - c.hasConnection = false - } -} - -type Connection struct { - net.Conn - - closeOnce sync.Once - done chan struct{} - closeError error -} - -func WrapConnection(conn net.Conn) *Connection { - return &Connection{ - Conn: conn, - done: make(chan struct{}), - } -} - -func (c *Connection) Done() <-chan struct{} { - return c.done -} - -func (c *Connection) IsClosed() bool { - select { - case <-c.done: - return true - default: - return false - } -} - -func (c *Connection) CloseError() error { - select { - case <-c.done: - if c.closeError != nil { - return c.closeError - } - return ErrTransportClosed - default: - return nil - } -} - -func (c *Connection) Close() error { - return c.CloseWithError(ErrTransportClosed) -} - -func (c *Connection) CloseWithError(err error) error { - var returnError error - c.closeOnce.Do(func() { - c.closeError = err - returnError = c.Conn.Close() - close(c.done) - }) - return returnError -} diff --git a/dns/transport/connector_test.go b/dns/transport/connector_test.go deleted file mode 100644 index 309b28c8..00000000 --- a/dns/transport/connector_test.go +++ /dev/null @@ -1,407 +0,0 @@ -package transport - -import ( - "context" - "sync/atomic" - "testing" - "time" - - "github.com/stretchr/testify/require" -) - -type testConnectorConnection struct{} - -func TestConnectorRecursiveGetFailsFast(t *testing.T) { - t.Parallel() - - var ( - dialCount atomic.Int32 - closeCount atomic.Int32 - connector *Connector[*testConnectorConnection] - ) - - dial := func(ctx context.Context) (*testConnectorConnection, error) { - dialCount.Add(1) - _, err := connector.Get(ctx) - if err != nil { - return nil, err - } - return &testConnectorConnection{}, nil - } - - connector = NewConnector(context.Background(), dial, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) { - closeCount.Add(1) - }, - Reset: func(connection *testConnectorConnection) { - closeCount.Add(1) - }, - }) - - _, err := connector.Get(context.Background()) - require.ErrorIs(t, err, errRecursiveConnectorDial) - require.EqualValues(t, 1, dialCount.Load()) - require.EqualValues(t, 0, closeCount.Load()) -} - -func TestConnectorRecursiveGetAcrossConnectorsAllowed(t *testing.T) { - t.Parallel() - - var ( - outerDialCount atomic.Int32 - innerDialCount atomic.Int32 - outerConnector *Connector[*testConnectorConnection] - innerConnector *Connector[*testConnectorConnection] - ) - - innerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - innerDialCount.Add(1) - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) {}, - Reset: func(connection *testConnectorConnection) {}, - }) - - outerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - outerDialCount.Add(1) - _, err := innerConnector.Get(ctx) - if err != nil { - return nil, err - } - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) {}, - Reset: func(connection *testConnectorConnection) {}, - }) - - _, err := outerConnector.Get(context.Background()) - require.NoError(t, err) - require.EqualValues(t, 1, outerDialCount.Load()) - require.EqualValues(t, 1, innerDialCount.Load()) -} - -func TestConnectorDialContextPreservesValueAndDeadline(t *testing.T) { - t.Parallel() - - type contextKey struct{} - - var ( - dialValue any - dialDeadline time.Time - dialHasDeadline bool - ) - - connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - dialValue = ctx.Value(contextKey{}) - dialDeadline, dialHasDeadline = ctx.Deadline() - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) {}, - Reset: func(connection *testConnectorConnection) {}, - }) - - deadline := time.Now().Add(time.Minute) - requestContext, cancel := context.WithDeadline(context.WithValue(context.Background(), contextKey{}, "test-value"), deadline) - defer cancel() - - _, err := connector.Get(requestContext) - require.NoError(t, err) - require.Equal(t, "test-value", dialValue) - require.True(t, dialHasDeadline) - require.WithinDuration(t, deadline, dialDeadline, time.Second) -} - -func TestConnectorDialSkipsCanceledRequest(t *testing.T) { - t.Parallel() - - var dialCount atomic.Int32 - connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - dialCount.Add(1) - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) {}, - Reset: func(connection *testConnectorConnection) {}, - }) - - requestContext, cancel := context.WithCancel(context.Background()) - cancel() - - _, err := connector.Get(requestContext) - require.ErrorIs(t, err, context.Canceled) - require.EqualValues(t, 0, dialCount.Load()) -} - -func TestConnectorCanceledRequestDoesNotCacheConnection(t *testing.T) { - t.Parallel() - - var ( - dialCount atomic.Int32 - closeCount atomic.Int32 - ) - dialStarted := make(chan struct{}, 1) - releaseDial := make(chan struct{}) - - connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - dialCount.Add(1) - select { - case dialStarted <- struct{}{}: - default: - } - <-releaseDial - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) { - closeCount.Add(1) - }, - Reset: func(connection *testConnectorConnection) {}, - }) - - requestContext, cancel := context.WithCancel(context.Background()) - result := make(chan error, 1) - go func() { - _, err := connector.Get(requestContext) - result <- err - }() - - <-dialStarted - cancel() - close(releaseDial) - - err := <-result - require.ErrorIs(t, err, context.Canceled) - require.EqualValues(t, 1, dialCount.Load()) - require.Eventually(t, func() bool { - return closeCount.Load() == 1 - }, time.Second, 10*time.Millisecond) - - _, err = connector.Get(context.Background()) - require.NoError(t, err) - require.EqualValues(t, 2, dialCount.Load()) -} - -func TestConnectorCanceledRequestReturnsBeforeIgnoredDialCompletes(t *testing.T) { - t.Parallel() - - var ( - dialCount atomic.Int32 - closeCount atomic.Int32 - ) - dialStarted := make(chan struct{}, 1) - releaseDial := make(chan struct{}) - - connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - dialCount.Add(1) - select { - case dialStarted <- struct{}{}: - default: - } - <-releaseDial - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) { - closeCount.Add(1) - }, - Reset: func(connection *testConnectorConnection) {}, - }) - - requestContext, cancel := context.WithCancel(context.Background()) - result := make(chan error, 1) - go func() { - _, err := connector.Get(requestContext) - result <- err - }() - - <-dialStarted - cancel() - - select { - case err := <-result: - require.ErrorIs(t, err, context.Canceled) - case <-time.After(time.Second): - t.Fatal("Get did not return after request cancel") - } - - require.EqualValues(t, 1, dialCount.Load()) - require.EqualValues(t, 0, closeCount.Load()) - - close(releaseDial) - - require.Eventually(t, func() bool { - return closeCount.Load() == 1 - }, time.Second, 10*time.Millisecond) - - _, err := connector.Get(context.Background()) - require.NoError(t, err) - require.EqualValues(t, 2, dialCount.Load()) -} - -func TestConnectorWaiterDoesNotStartNewDialBeforeCanceledDialCompletes(t *testing.T) { - t.Parallel() - - var ( - dialCount atomic.Int32 - closeCount atomic.Int32 - ) - firstDialStarted := make(chan struct{}, 1) - secondDialStarted := make(chan struct{}, 1) - releaseFirstDial := make(chan struct{}) - - connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - attempt := dialCount.Add(1) - switch attempt { - case 1: - select { - case firstDialStarted <- struct{}{}: - default: - } - <-releaseFirstDial - case 2: - select { - case secondDialStarted <- struct{}{}: - default: - } - } - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) { - closeCount.Add(1) - }, - Reset: func(connection *testConnectorConnection) {}, - }) - - requestContext, cancel := context.WithCancel(context.Background()) - firstResult := make(chan error, 1) - go func() { - _, err := connector.Get(requestContext) - firstResult <- err - }() - - <-firstDialStarted - cancel() - - secondResult := make(chan error, 1) - go func() { - _, err := connector.Get(context.Background()) - secondResult <- err - }() - - select { - case <-secondDialStarted: - t.Fatal("second dial started before first dial completed") - case <-time.After(100 * time.Millisecond): - } - - select { - case err := <-firstResult: - require.ErrorIs(t, err, context.Canceled) - case <-time.After(time.Second): - t.Fatal("first Get did not return after request cancel") - } - - close(releaseFirstDial) - - require.Eventually(t, func() bool { - return closeCount.Load() == 1 - }, time.Second, 10*time.Millisecond) - - select { - case <-secondDialStarted: - case <-time.After(time.Second): - t.Fatal("second dial did not start after first dial completed") - } - - err := <-secondResult - require.NoError(t, err) - require.EqualValues(t, 2, dialCount.Load()) -} - -func TestConnectorDialContextNotCanceledByRequestContextAfterDial(t *testing.T) { - t.Parallel() - - var dialContext context.Context - connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - dialContext = ctx - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) {}, - Reset: func(connection *testConnectorConnection) {}, - }) - - requestContext, cancel := context.WithCancel(context.Background()) - _, err := connector.Get(requestContext) - require.NoError(t, err) - require.NotNil(t, dialContext) - - cancel() - - select { - case <-dialContext.Done(): - t.Fatal("dial context canceled by request context after successful dial") - case <-time.After(100 * time.Millisecond): - } - - err = connector.Close() - require.NoError(t, err) -} - -func TestConnectorDialContextCanceledOnClose(t *testing.T) { - t.Parallel() - - var dialContext context.Context - connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) { - dialContext = ctx - return &testConnectorConnection{}, nil - }, ConnectorCallbacks[*testConnectorConnection]{ - IsClosed: func(connection *testConnectorConnection) bool { - return false - }, - Close: func(connection *testConnectorConnection) {}, - Reset: func(connection *testConnectorConnection) {}, - }) - - _, err := connector.Get(context.Background()) - require.NoError(t, err) - require.NotNil(t, dialContext) - - select { - case <-dialContext.Done(): - t.Fatal("dial context canceled before connector close") - default: - } - - err = connector.Close() - require.NoError(t, err) - - select { - case <-dialContext.Done(): - case <-time.After(time.Second): - t.Fatal("dial context not canceled after connector close") - } -} diff --git a/dns/transport/quic/quic.go b/dns/transport/quic/quic.go index 26461006..3a7b6163 100644 --- a/dns/transport/quic/quic.go +++ b/dns/transport/quic/quic.go @@ -31,14 +31,13 @@ func RegisterTransport(registry *dns.TransportRegistry) { } type Transport struct { - *transport.BaseTransport + dns.TransportAdapter - ctx context.Context dialer N.Dialer serverAddr M.Socksaddr tlsConfig tls.Config - connector *transport.Connector[*quic.Conn] + connection *transport.ConnPool[*quic.Conn] } func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) { @@ -63,93 +62,76 @@ func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options return nil, E.New("invalid server address: ", serverAddr) } - t := &Transport{ - BaseTransport: transport.NewBaseTransport( - dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions), - logger, - ), - ctx: ctx, - dialer: transportDialer, - serverAddr: serverAddr, - tlsConfig: tlsConfig, - } - - t.connector = transport.NewConnector(t.CloseContext(), t.dial, transport.ConnectorCallbacks[*quic.Conn]{ - IsClosed: func(connection *quic.Conn) bool { - return common.Done(connection.Context()) - }, - Close: func(connection *quic.Conn) { - connection.CloseWithError(0, "") - }, - Reset: func(connection *quic.Conn) { - connection.CloseWithError(0, "") - }, - }) - - return t, nil -} - -func (t *Transport) dial(ctx context.Context) (*quic.Conn, error) { - conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) - if err != nil { - return nil, E.Cause(err, "dial UDP connection") - } - earlyConnection, err := sQUIC.DialEarly( - ctx, - bufio.NewUnbindPacketConn(conn), - t.serverAddr.UDPAddr(), - t.tlsConfig, - nil, - ) - if err != nil { - conn.Close() - return nil, E.Cause(err, "establish QUIC connection") - } - return earlyConnection, nil + return &Transport{ + TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions), + dialer: transportDialer, + serverAddr: serverAddr, + tlsConfig: tlsConfig, + connection: transport.NewConnPool(transport.ConnPoolOptions[*quic.Conn]{ + Mode: transport.ConnPoolSingle, + IsAlive: func(conn *quic.Conn) bool { + return conn != nil && !common.Done(conn.Context()) + }, + Close: func(conn *quic.Conn, _ error) { + conn.CloseWithError(0, "") + }, + }), + }, nil } func (t *Transport) Start(stage adapter.StartStage) error { if stage != adapter.StartStateStart { return nil } - err := t.SetStarted() - if err != nil { - return err - } return dialer.InitializeDetour(t.dialer) } func (t *Transport) Close() error { - return E.Errors(t.BaseTransport.Close(), t.connector.Close()) + return t.connection.Close() } func (t *Transport) Reset() { - t.connector.Reset() + t.connection.Reset() } func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - if !t.BeginQuery() { - return nil, transport.ErrTransportClosed - } - defer t.EndQuery() - var ( conn *quic.Conn err error response *mDNS.Msg ) for i := 0; i < 2; i++ { - conn, err = t.connector.Get(ctx) + conn, _, err = t.connection.Acquire(ctx, func(ctx context.Context) (*quic.Conn, error) { + rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) + if err != nil { + return nil, E.Cause(err, "dial UDP connection") + } + earlyConnection, err := sQUIC.DialEarly( + ctx, + bufio.NewUnbindPacketConn(rawConn), + t.serverAddr.UDPAddr(), + t.tlsConfig, + nil, + ) + if err != nil { + rawConn.Close() + return nil, E.Cause(err, "establish QUIC connection") + } + return earlyConnection, nil + }) if err != nil { return nil, err } response, err = t.exchange(ctx, message, conn) if err == nil { + t.connection.Release(conn, true) return response, nil } else if !isQUICRetryError(err) { + t.connection.Release(conn, true) return nil, err } else { - t.connector.Reset() + t.connection.Release(conn, true) + t.Reset() continue } } diff --git a/dns/transport/tls.go b/dns/transport/tls.go index 4d463296..43978b6f 100644 --- a/dns/transport/tls.go +++ b/dns/transport/tls.go @@ -2,7 +2,6 @@ package transport import ( "context" - "sync" "time" "github.com/sagernet/sing-box/adapter" @@ -17,7 +16,6 @@ import ( "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/x/list" mDNS "github.com/miekg/dns" ) @@ -29,13 +27,13 @@ func RegisterTLS(registry *dns.TransportRegistry) { } type TLSTransport struct { - *BaseTransport + dns.TransportAdapter + logger logger.ContextLogger dialer tls.Dialer serverAddr M.Socksaddr tlsConfig tls.Config - access sync.Mutex - connections list.List[*tlsDNSConn] + connections *ConnPool[*tlsDNSConn] } type tlsDNSConn struct { @@ -66,10 +64,20 @@ func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options o func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr, tlsConfig tls.Config) *TLSTransport { return &TLSTransport{ - BaseTransport: NewBaseTransport(adapter, logger), - dialer: tls.NewDialer(dialer, tlsConfig), - serverAddr: serverAddr, - tlsConfig: tlsConfig, + TransportAdapter: adapter, + logger: logger, + dialer: tls.NewDialer(dialer, tlsConfig), + serverAddr: serverAddr, + tlsConfig: tlsConfig, + connections: NewConnPool(ConnPoolOptions[*tlsDNSConn]{ + Mode: ConnPoolOrdered, + IsAlive: func(conn *tlsDNSConn) bool { + return conn != nil + }, + Close: func(conn *tlsDNSConn, _ error) { + conn.Close() + }, + }), } } @@ -77,53 +85,43 @@ func (t *TLSTransport) Start(stage adapter.StartStage) error { if stage != adapter.StartStateStart { return nil } - err := t.SetStarted() - if err != nil { - return err - } return dialer.InitializeDetour(t.dialer) } func (t *TLSTransport) Close() error { - t.access.Lock() - for connection := t.connections.Front(); connection != nil; connection = connection.Next() { - connection.Value.Close() - } - t.connections.Init() - t.access.Unlock() - return t.BaseTransport.Close() + return t.connections.Close() } func (t *TLSTransport) Reset() { - t.access.Lock() - defer t.access.Unlock() - for connection := t.connections.Front(); connection != nil; connection = connection.Next() { - connection.Value.Close() - } - t.connections.Init() + t.connections.Reset() } func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - if !t.BeginQuery() { - return nil, ErrTransportClosed - } - defer t.EndQuery() - - t.access.Lock() - conn := t.connections.PopFront() - t.access.Unlock() - if conn != nil { + var lastErr error + for attempt := 0; attempt < 2; attempt++ { + conn, created, err := t.connections.Acquire(ctx, func(ctx context.Context) (*tlsDNSConn, error) { + tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr) + if err != nil { + return nil, E.Cause(err, "dial TLS connection") + } + return &tlsDNSConn{Conn: tlsConn}, nil + }) + if err != nil { + return nil, err + } response, err := t.exchange(ctx, message, conn) if err == nil { + t.connections.Release(conn, true) return response, nil } - t.Logger.DebugContext(ctx, "discarded pooled connection: ", err) + lastErr = err + t.logger.DebugContext(ctx, "discarded pooled connection: ", err) + t.connections.Release(conn, false) + if created { + return nil, err + } } - tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr) - if err != nil { - return nil, E.Cause(err, "dial TLS connection") - } - return t.exchange(ctx, message, &tlsDNSConn{Conn: tlsConn}) + return nil, lastErr } func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) { @@ -133,22 +131,12 @@ func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tl conn.queryId++ err := WriteMessage(conn, conn.queryId, message) if err != nil { - conn.Close() return nil, E.Cause(err, "write request") } response, err := ReadMessage(conn) if err != nil { - conn.Close() return nil, E.Cause(err, "read response") } - t.access.Lock() - if t.State() >= StateClosing { - t.access.Unlock() - conn.Close() - return response, nil - } conn.SetDeadline(time.Time{}) - t.connections.PushBack(conn) - t.access.Unlock() return response, nil } diff --git a/dns/transport/udp.go b/dns/transport/udp.go index a7272545..c9f520e3 100644 --- a/dns/transport/udp.go +++ b/dns/transport/udp.go @@ -2,6 +2,7 @@ package transport import ( "context" + "net" "sync" "sync/atomic" @@ -27,13 +28,14 @@ func RegisterUDP(registry *dns.TransportRegistry) { } type UDPTransport struct { - *BaseTransport + dns.TransportAdapter + logger logger.ContextLogger dialer N.Dialer serverAddr M.Socksaddr udpSize atomic.Int32 - connector *Connector[*Connection] + connection *ConnPool[net.Conn] callbackAccess sync.RWMutex queryId uint16 @@ -63,43 +65,38 @@ func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options o func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialerInstance N.Dialer, serverAddr M.Socksaddr) *UDPTransport { t := &UDPTransport{ - BaseTransport: NewBaseTransport(adapter, logger), - dialer: dialerInstance, - serverAddr: serverAddr, - callbacks: make(map[uint16]*udpCallback), + TransportAdapter: adapter, + logger: logger, + dialer: dialerInstance, + serverAddr: serverAddr, + callbacks: make(map[uint16]*udpCallback), + connection: NewConnPool(ConnPoolOptions[net.Conn]{ + Mode: ConnPoolSingle, + IsAlive: func(conn net.Conn) bool { + return conn != nil + }, + Close: func(conn net.Conn, cause error) { + conn.Close() + }, + }), } t.udpSize.Store(2048) - t.connector = NewSingleflightConnector(t.CloseContext(), t.dial) return t } -func (t *UDPTransport) dial(ctx context.Context) (*Connection, error) { - rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) - if err != nil { - return nil, E.Cause(err, "dial UDP connection") - } - conn := WrapConnection(rawConn) - go t.recvLoop(conn) - return conn, nil -} - func (t *UDPTransport) Start(stage adapter.StartStage) error { if stage != adapter.StartStateStart { return nil } - err := t.SetStarted() - if err != nil { - return err - } return dialer.InitializeDetour(t.dialer) } func (t *UDPTransport) Close() error { - return E.Errors(t.BaseTransport.Close(), t.connector.Close()) + return t.connection.Close() } func (t *UDPTransport) Reset() { - t.connector.Reset() + t.connection.Reset() } func (t *UDPTransport) nextAvailableQueryId() (uint16, error) { @@ -116,17 +113,12 @@ func (t *UDPTransport) nextAvailableQueryId() (uint16, error) { } func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - if !t.BeginQuery() { - return nil, ErrTransportClosed - } - defer t.EndQuery() - response, err := t.exchange(ctx, message) if err != nil { return nil, err } if response.Truncated { - t.Logger.InfoContext(ctx, "response truncated, retrying with TCP") + t.logger.InfoContext(ctx, "response truncated, retrying with TCP") return t.exchangeTCP(ctx, message) } return response, nil @@ -158,16 +150,25 @@ func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M break } if t.udpSize.CompareAndSwap(current, udpSize) { - t.connector.Reset() + t.Reset() break } } } - conn, err := t.connector.Get(ctx) + conn, connCtx, created, err := t.connection.AcquireShared(ctx, func(ctx context.Context) (net.Conn, error) { + rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) + if err != nil { + return nil, E.Cause(err, "dial UDP connection") + } + return rawConn, nil + }) if err != nil { return nil, err } + if created { + go t.recvLoop(conn) + } callback := &udpCallback{ done: make(chan struct{}), @@ -177,6 +178,7 @@ func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M queryId, err := t.nextAvailableQueryId() if err != nil { t.callbackAccess.Unlock() + t.connection.Release(conn, true) return nil, err } t.callbacks[queryId] = callback @@ -203,30 +205,30 @@ func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M _, err = conn.Write(rawMessage) if err != nil { - conn.CloseWithError(err) + t.connection.Invalidate(conn, err) return nil, E.Cause(err, "write request") } select { case <-callback.done: + t.connection.Release(conn, true) callback.response.Id = originalId return callback.response, nil - case <-conn.Done(): - return nil, conn.CloseError() - case <-t.CloseContext().Done(): - return nil, ErrTransportClosed + case <-connCtx.Done(): + return nil, context.Cause(connCtx) case <-ctx.Done(): + t.connection.Release(conn, true) return nil, ctx.Err() } } -func (t *UDPTransport) recvLoop(conn *Connection) { +func (t *UDPTransport) recvLoop(conn net.Conn) { for { buffer := buf.NewSize(int(t.udpSize.Load())) _, err := buffer.ReadOnceFrom(conn) if err != nil { buffer.Release() - conn.CloseWithError(err) + t.connection.Invalidate(conn, err) return } @@ -234,7 +236,7 @@ func (t *UDPTransport) recvLoop(conn *Connection) { err = message.Unpack(buffer.Bytes()) buffer.Release() if err != nil { - t.Logger.Debug("discarded malformed UDP response: ", err) + t.logger.Debug("discarded malformed UDP response: ", err) continue } diff --git a/protocol/tailscale/dns_transport.go b/protocol/tailscale/dns_transport.go index 3a92a66b..4195235c 100644 --- a/protocol/tailscale/dns_transport.go +++ b/protocol/tailscale/dns_transport.go @@ -49,6 +49,7 @@ type DNSTransport struct { dnsRouter adapter.DNSRouter endpointManager adapter.EndpointManager endpoint *Endpoint + access sync.RWMutex routePrefixes []netip.Prefix routes map[string][]adapter.DNSTransport hosts map[string][]netip.Addr @@ -91,6 +92,12 @@ func (t *DNSTransport) Start(stage adapter.StartStage) error { } func (t *DNSTransport) Reset() { + t.access.RLock() + transports := t.collectResolversLocked() + t.access.RUnlock() + for _, transport := range transports { + transport.Reset() + } } func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *nDNS.Config) { @@ -101,7 +108,7 @@ func (t *DNSTransport) onReconfig(cfg *wgcfg.Config, routerCfg *router.Config, d } func (t *DNSTransport) updateDNSServers(routeConfig *router.Config, dnsConfig *nDNS.Config) error { - t.routePrefixes = buildRoutePrefixes(routeConfig) + routePrefixes := buildRoutePrefixes(routeConfig) directDialerOnce := sync.OnceValue(func() N.Dialer { directDialer := common.Must1(dialer.NewDefault(t.ctx, option.DialerOptions{})) return &DNSDialer{transport: t, fallbackDialer: directDialer} @@ -130,9 +137,19 @@ func (t *DNSTransport) updateDNSServers(routeConfig *router.Config, dnsConfig *n } defaultResolvers = append(defaultResolvers, myResolver) } + + t.access.Lock() + oldResolvers := t.collectResolversLocked() + t.routePrefixes = routePrefixes t.routes = routes t.hosts = hosts t.defaultResolvers = defaultResolvers + t.access.Unlock() + + for _, transport := range oldResolvers { + transport.Close() + } + if len(defaultResolvers) > 0 { t.logger.Info("updated ", len(routes), " routes, ", len(hosts), " hosts, default resolvers: ", strings.Join(common.Map(dnsConfig.DefaultResolvers, func(it *dnstype.Resolver) string { return it.Addr }), " ")) @@ -207,7 +224,22 @@ func buildRoutePrefixes(routeConfig *router.Config) []netip.Prefix { } func (t *DNSTransport) Close() error { - return nil + t.access.Lock() + transports := t.collectResolversLocked() + t.routePrefixes = nil + t.routes = nil + t.hosts = nil + t.defaultResolvers = nil + t.access.Unlock() + + var err error + for _, transport := range transports { + name := "resolver/" + transport.Type() + "[" + transport.Tag() + "]" + err = E.Append(err, transport.Close(), func(err error) error { + return E.Cause(err, "close ", name) + }) + } + return err } func (t *DNSTransport) Raw() bool { @@ -219,7 +251,15 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M return nil, os.ErrInvalid } question := message.Question[0] - addresses, hostsLoaded := t.hosts[question.Name] + + t.access.RLock() + hosts := t.hosts + routes := t.routes + defaultResolvers := t.defaultResolvers + acceptDefaultResolvers := t.acceptDefaultResolvers + t.access.RUnlock() + + addresses, hostsLoaded := hosts[question.Name] if hostsLoaded { switch question.Qtype { case mDNS.TypeA: @@ -238,7 +278,7 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M } } } - for domainSuffix, transports := range t.routes { + for domainSuffix, transports := range routes { if strings.HasSuffix(question.Name, domainSuffix) { if len(transports) == 0 { return &mDNS.Msg{ @@ -262,10 +302,10 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M return nil, lastErr } } - if t.acceptDefaultResolvers { - if len(t.defaultResolvers) > 0 { + if acceptDefaultResolvers { + if len(defaultResolvers) > 0 { var lastErr error - for _, resolver := range t.defaultResolvers { + for _, resolver := range defaultResolvers { response, err := resolver.Exchange(ctx, message) if err != nil { lastErr = err @@ -281,6 +321,15 @@ func (t *DNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.M return nil, dns.RcodeNameError } +func (t *DNSTransport) collectResolversLocked() []adapter.DNSTransport { + var transports []adapter.DNSTransport + for _, resolvers := range t.routes { + transports = append(transports, resolvers...) + } + transports = append(transports, t.defaultResolvers...) + return transports +} + type DNSDialer struct { transport *DNSTransport fallbackDialer N.Dialer @@ -290,7 +339,8 @@ func (d *DNSDialer) DialContext(ctx context.Context, network string, destination if destination.IsDomain() { panic("invalid request here") } - for _, prefix := range d.transport.routePrefixes { + routePrefixes := d.transport.routePrefixesSnapshot() + for _, prefix := range routePrefixes { if prefix.Contains(destination.Addr) { return d.transport.endpoint.DialContext(ctx, network, destination) } @@ -302,10 +352,17 @@ func (d *DNSDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) ( if destination.IsDomain() { panic("invalid request here") } - for _, prefix := range d.transport.routePrefixes { + routePrefixes := d.transport.routePrefixesSnapshot() + for _, prefix := range routePrefixes { if prefix.Contains(destination.Addr) { return d.transport.endpoint.ListenPacket(ctx, destination) } } return d.fallbackDialer.ListenPacket(ctx, destination) } + +func (t *DNSTransport) routePrefixesSnapshot() []netip.Prefix { + t.access.RLock() + defer t.access.RUnlock() + return append([]netip.Prefix(nil), t.routePrefixes...) +}