Add resolver for inbound

This commit is contained in:
世界
2022-07-07 23:36:32 +08:00
parent 538a1f5909
commit 9c256afc1a
22 changed files with 261 additions and 173 deletions

View File

@@ -10,16 +10,18 @@ import (
)
func New(router adapter.Router, options option.DialerOptions) N.Dialer {
domainStrategy := C.DomainStrategy(options.DomainStrategy)
var dialer N.Dialer
if options.Detour == "" {
dialer = NewDefault(options)
dialer = NewResolveDialer(router, dialer, domainStrategy)
return NewDefault(options)
} else {
dialer = NewDetour(router, options.Detour)
if domainStrategy != C.DomainStrategyAsIS {
dialer = NewResolveDialer(router, dialer, domainStrategy)
}
return NewDetour(router, options.Detour)
}
}
func NewOutbound(router adapter.Router, options option.OutboundDialerOptions) N.Dialer {
dialer := New(router, options.DialerOptions)
domainStrategy := C.DomainStrategy(options.DomainStrategy)
if domainStrategy != C.DomainStrategyAsIS || options.Detour == "" && !C.CGO_ENABLED {
dialer = NewResolveDialer(router, dialer, domainStrategy)
}
if options.OverrideOptions.IsValid() {
dialer = NewOverride(dialer, common.PtrValueOrDefault(options.OverrideOptions))

View File

@@ -5,7 +5,6 @@ import (
"net"
"net/netip"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
@@ -41,16 +40,7 @@ func (d *ResolveDialer) DialContext(ctx context.Context, network string, destina
if err != nil {
return nil, err
}
var conn net.Conn
var connErrors []error
for _, address := range addresses {
conn, err = d.dialer.DialContext(ctx, network, M.SocksaddrFromAddrPort(address, destination.Port))
if err != nil {
connErrors = append(connErrors, err)
}
return conn, nil
}
return nil, E.Errors(connErrors...)
return DialSerial(ctx, d.dialer, network, destination, addresses)
}
func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
@@ -67,16 +57,7 @@ func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
if err != nil {
return nil, err
}
var conn net.PacketConn
var connErrors []error
for _, address := range addresses {
conn, err = d.dialer.ListenPacket(ctx, M.SocksaddrFromAddrPort(address, destination.Port))
if err != nil {
connErrors = append(connErrors, err)
}
return conn, nil
}
return nil, E.Errors(connErrors...)
return ListenSerial(ctx, d.dialer, destination, addresses)
}
func (d *ResolveDialer) Upstream() any {

39
common/dialer/serial.go Normal file
View File

@@ -0,0 +1,39 @@
package dialer
import (
"context"
"net"
"net/netip"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func DialSerial(ctx context.Context, dialer N.Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) {
var conn net.Conn
var err error
var connErrors []error
for _, address := range destinationAddresses {
conn, err = dialer.DialContext(ctx, network, M.SocksaddrFromAddrPort(address, destination.Port))
if err != nil {
connErrors = append(connErrors, err)
}
return conn, nil
}
return nil, E.Errors(connErrors...)
}
func ListenSerial(ctx context.Context, dialer N.Dialer, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.PacketConn, error) {
var conn net.PacketConn
var err error
var connErrors []error
for _, address := range destinationAddresses {
conn, err = dialer.ListenPacket(ctx, M.SocksaddrFromAddrPort(address, destination.Port))
if err != nil {
connErrors = append(connErrors, err)
}
return conn, nil
}
return nil, E.Errors(connErrors...)
}

View File

@@ -77,9 +77,14 @@ func (r *Reader) readMetadata() error {
}
func (r *Reader) Read(code string) ([]Item, error) {
if _, exists := r.domainIndex[code]; !exists {
index, exists := r.domainIndex[code]
if !exists {
return nil, E.New("code ", code, " not exists!")
}
_, err := r.reader.Seek(int64(index), io.SeekCurrent)
if err != nil {
return nil, err
}
counter := &rw.ReadCounter{Reader: r.reader}
domain := make([]Item, r.domainLength[code])
for i := range domain {
@@ -97,7 +102,7 @@ func (r *Reader) Read(code string) ([]Item, error) {
}
domain[i] = item
}
_, err := r.reader.Seek(int64(r.domainIndex[code])-counter.Count(), io.SeekCurrent)
_, err = r.reader.Seek(int64(-index)-counter.Count(), io.SeekCurrent)
return domain, err
}