package transport import ( "context" "net" "sync" "sync/atomic" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/dns" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" mDNS "github.com/miekg/dns" ) var _ adapter.DNSTransport = (*UDPTransport)(nil) func RegisterUDP(registry *dns.TransportRegistry) { dns.RegisterTransport[option.RemoteDNSServerOptions](registry, C.DNSTypeUDP, NewUDP) } type UDPTransport struct { dns.TransportAdapter logger logger.ContextLogger dialer N.Dialer serverAddr M.Socksaddr udpSize atomic.Int32 connection *ConnPool[net.Conn] callbackAccess sync.RWMutex queryId uint16 callbacks map[uint16]*udpCallback } type udpCallback struct { access sync.Mutex response *mDNS.Msg done chan struct{} } func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) { transportDialer, err := dns.NewRemoteDialer(ctx, options) if err != nil { return nil, err } serverAddr := options.DNSServerAddressOptions.Build() if serverAddr.Port == 0 { serverAddr.Port = 53 } if !serverAddr.IsValid() { return nil, E.New("invalid server address: ", serverAddr) } return NewUDPRaw(logger, dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeUDP, tag, options), transportDialer, serverAddr), nil } func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialerInstance N.Dialer, serverAddr M.Socksaddr) *UDPTransport { t := &UDPTransport{ TransportAdapter: adapter, logger: logger, dialer: dialerInstance, serverAddr: serverAddr, callbacks: make(map[uint16]*udpCallback), connection: NewConnPool(ConnPoolOptions[net.Conn]{ Mode: ConnPoolSingle, IsAlive: func(conn net.Conn) bool { return conn != nil }, Close: func(conn net.Conn, cause error) { conn.Close() }, }), } t.udpSize.Store(2048) return t } func (t *UDPTransport) Start(stage adapter.StartStage) error { if stage != adapter.StartStateStart { return nil } return dialer.InitializeDetour(t.dialer) } func (t *UDPTransport) Close() error { return t.connection.Close() } func (t *UDPTransport) Reset() { t.connection.Reset() } func (t *UDPTransport) nextAvailableQueryId() (uint16, error) { start := t.queryId for { t.queryId++ if _, exists := t.callbacks[t.queryId]; !exists { return t.queryId, nil } if t.queryId == start { return 0, E.New("no available query ID") } } } func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { response, err := t.exchange(ctx, message) if err != nil { return nil, err } if response.Truncated { t.logger.InfoContext(ctx, "response truncated, retrying with TCP") return t.exchangeTCP(ctx, message) } return response, nil } func (t *UDPTransport) exchangeTCP(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr) if err != nil { return nil, E.Cause(err, "dial TCP connection") } defer conn.Close() err = WriteMessage(conn, message.Id, message) if err != nil { return nil, E.Cause(err, "write request") } response, err := ReadMessage(conn) if err != nil { return nil, E.Cause(err, "read response") } return response, nil } func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { if edns0Opt := message.IsEdns0(); edns0Opt != nil { udpSize := int32(edns0Opt.UDPSize()) for { current := t.udpSize.Load() if udpSize <= current { break } if t.udpSize.CompareAndSwap(current, udpSize) { t.Reset() break } } } conn, connCtx, created, err := t.connection.AcquireShared(ctx, func(ctx context.Context) (net.Conn, error) { rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) if err != nil { return nil, E.Cause(err, "dial UDP connection") } return rawConn, nil }) if err != nil { return nil, err } if created { go t.recvLoop(conn) } callback := &udpCallback{ done: make(chan struct{}), } t.callbackAccess.Lock() queryId, err := t.nextAvailableQueryId() if err != nil { t.callbackAccess.Unlock() t.connection.Release(conn, true) return nil, err } t.callbacks[queryId] = callback t.callbackAccess.Unlock() defer func() { t.callbackAccess.Lock() delete(t.callbacks, queryId) t.callbackAccess.Unlock() }() buffer := buf.NewSize(1 + message.Len()) defer buffer.Release() exMessage := *message exMessage.Compress = true originalId := message.Id exMessage.Id = queryId rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes()) if err != nil { return nil, err } _, err = conn.Write(rawMessage) if err != nil { t.connection.Invalidate(conn, err) return nil, E.Cause(err, "write request") } select { case <-callback.done: t.connection.Release(conn, true) callback.response.Id = originalId return callback.response, nil case <-connCtx.Done(): return nil, context.Cause(connCtx) case <-ctx.Done(): t.connection.Release(conn, true) return nil, ctx.Err() } } func (t *UDPTransport) recvLoop(conn net.Conn) { for { buffer := buf.NewSize(int(t.udpSize.Load())) _, err := buffer.ReadOnceFrom(conn) if err != nil { buffer.Release() t.connection.Invalidate(conn, err) return } var message mDNS.Msg err = message.Unpack(buffer.Bytes()) buffer.Release() if err != nil { t.logger.Debug("discarded malformed UDP response: ", err) continue } t.callbackAccess.RLock() callback, loaded := t.callbacks[message.Id] t.callbackAccess.RUnlock() if !loaded { continue } callback.access.Lock() select { case <-callback.done: default: callback.response = &message close(callback.done) } callback.access.Unlock() } }