Handle TUN loopback in direct outbound

This commit is contained in:
世界
2026-06-03 10:37:53 +08:00
parent 1086ab2563
commit 761b7f4e12
8 changed files with 77 additions and 19 deletions

View File

@@ -20,6 +20,7 @@ import (
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterOutbound(registry *outbound.Registry) {
@@ -37,10 +38,12 @@ type Outbound struct {
outbound.Adapter
ctx context.Context
logger logger.ContextLogger
network adapter.NetworkManager
dialer dialer.ParallelInterfaceDialer
domainStrategy C.DomainStrategy
fallbackDelay time.Duration
isEmpty bool
myAddresses common.TypedValue[[]netip.Prefix]
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.DirectOutboundOptions) (adapter.Outbound, error) {
@@ -61,6 +64,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, tag, []string{N.NetworkTCP, N.NetworkUDP, N.NetworkICMP}, options.DialerOptions),
ctx: ctx,
logger: logger,
network: service.FromContext[adapter.NetworkManager](ctx),
//nolint:staticcheck
domainStrategy: C.DomainStrategy(options.DomainStrategy),
fallbackDelay: time.Duration(options.FallbackDelay),
@@ -74,7 +78,48 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return outbound, nil
}
func (h *Outbound) Start(stage adapter.StartStage) error {
switch stage {
case adapter.StartStatePostStart, adapter.StartStateStarted:
h.fetchMyAddresses()
}
return nil
}
func (h *Outbound) fetchMyAddresses() {
if len(h.myAddresses.Load()) > 0 {
return
}
myInterfaceNames := h.network.InterfaceMonitor().MyInterfaces()
if len(myInterfaceNames) == 0 {
return
}
var myAddresses []netip.Prefix
for _, myInterfaceName := range myInterfaceNames {
myInterface, err := h.network.InterfaceFinder().ByName(myInterfaceName)
if err != nil {
continue
}
myAddresses = append(myAddresses, myInterface.Addresses...)
}
h.myAddresses.Store(myAddresses)
}
func (h *Outbound) isMyLoopbackAddress(addresses ...netip.Addr) bool {
for _, prefix := range h.myAddresses.Load() {
for _, address := range addresses {
if prefix.Addr() != address && prefix.Contains(address) {
return true
}
}
}
return false
}
func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
if h.isMyLoopbackAddress(destination.Addr) {
return nil, E.New("loopback connection to TUN range")
}
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Outbound = h.Tag()
metadata.Destination = destination
@@ -89,6 +134,9 @@ func (h *Outbound) DialContext(ctx context.Context, network string, destination
}
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
if h.isMyLoopbackAddress(destination.Addr) {
return nil, E.New("loopback connection to TUN range")
}
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Outbound = h.Tag()
metadata.Destination = destination
@@ -111,6 +159,9 @@ func (h *Outbound) NewDirectRouteConnection(metadata adapter.InboundContext, rou
}
func (h *Outbound) DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) {
if h.isMyLoopbackAddress(destinationAddresses...) {
return nil, E.New("loopback connection to TUN range")
}
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Outbound = h.Tag()
metadata.Destination = destination
@@ -125,6 +176,9 @@ func (h *Outbound) DialParallel(ctx context.Context, network string, destination
}
func (h *Outbound) DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy *C.NetworkStrategy, networkType []C.InterfaceType, fallbackNetworkType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
if h.isMyLoopbackAddress(destinationAddresses...) {
return nil, E.New("loopback connection to TUN range")
}
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Outbound = h.Tag()
metadata.Destination = destination
@@ -139,6 +193,9 @@ func (h *Outbound) DialParallelNetwork(ctx context.Context, network string, dest
}
func (h *Outbound) ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy *C.NetworkStrategy, networkType []C.InterfaceType, fallbackNetworkType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) {
if h.isMyLoopbackAddress(destinationAddresses...) {
return nil, netip.Addr{}, E.New("loopback connection to TUN range")
}
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Outbound = h.Tag()
metadata.Destination = destination