Improve loopback detector

This commit is contained in:
世界
2024-04-12 09:24:49 +08:00
parent 64a05a27a2
commit d9f2d31147
7 changed files with 48 additions and 92 deletions

View File

@@ -5,21 +5,24 @@ import (
"net/netip"
"sync"
"github.com/sagernet/sing-box/adapter"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type loopBackDetector struct {
router adapter.Router
connAccess sync.RWMutex
packetConnAccess sync.RWMutex
connMap map[netip.AddrPort]bool
packetConnMap map[netip.AddrPort]bool
packetConnMap map[uint16]bool
}
func newLoopBackDetector() *loopBackDetector {
func newLoopBackDetector(router adapter.Router) *loopBackDetector {
return &loopBackDetector{
router: router,
connMap: make(map[netip.AddrPort]bool),
packetConnMap: make(map[netip.AddrPort]bool),
packetConnMap: make(map[uint16]bool),
}
}
@@ -29,10 +32,16 @@ func (l *loopBackDetector) NewConn(conn net.Conn) net.Conn {
return conn
}
if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn {
if !connAddr.Addr().IsLoopback() {
_, err := l.router.InterfaceFinder().InterfaceByAddr(connAddr.Addr())
if err != nil {
return conn
}
}
l.packetConnAccess.Lock()
l.packetConnMap[connAddr] = true
l.packetConnMap[connAddr.Port()] = true
l.packetConnAccess.Unlock()
return &loopBackDetectUDPWrapper{abstractUDPConn: udpConn, detector: l, connAddr: connAddr}
return &loopBackDetectUDPWrapper{abstractUDPConn: udpConn, detector: l, connPort: connAddr.Port()}
} else {
l.connAccess.Lock()
l.connMap[connAddr] = true
@@ -46,10 +55,16 @@ func (l *loopBackDetector) NewPacketConn(conn N.NetPacketConn) N.NetPacketConn {
if !connAddr.IsValid() {
return conn
}
if !connAddr.Addr().IsLoopback() {
_, err := l.router.InterfaceFinder().InterfaceByAddr(connAddr.Addr())
if err != nil {
return conn
}
}
l.packetConnAccess.Lock()
l.packetConnMap[connAddr] = true
l.packetConnMap[connAddr.Port()] = true
l.packetConnAccess.Unlock()
return &loopBackDetectPacketWrapper{NetPacketConn: conn, detector: l, connAddr: connAddr}
return &loopBackDetectPacketWrapper{NetPacketConn: conn, detector: l, connPort: connAddr.Port()}
}
func (l *loopBackDetector) CheckConn(connAddr netip.AddrPort) bool {
@@ -59,9 +74,18 @@ func (l *loopBackDetector) CheckConn(connAddr netip.AddrPort) bool {
}
func (l *loopBackDetector) CheckPacketConn(connAddr netip.AddrPort) bool {
if !connAddr.IsValid() || !connAddr.Addr().IsLoopback() {
return false
}
if !connAddr.Addr().IsLoopback() {
_, err := l.router.InterfaceFinder().InterfaceByAddr(connAddr.Addr())
if err != nil {
return false
}
}
l.packetConnAccess.RLock()
defer l.packetConnAccess.RUnlock()
return l.packetConnMap[connAddr]
return l.packetConnMap[connAddr.Port()]
}
type loopBackDetectWrapper struct {
@@ -95,14 +119,14 @@ func (w *loopBackDetectWrapper) Upstream() any {
type loopBackDetectPacketWrapper struct {
N.NetPacketConn
detector *loopBackDetector
connAddr netip.AddrPort
connPort uint16
closeOnce sync.Once
}
func (w *loopBackDetectPacketWrapper) Close() error {
w.closeOnce.Do(func() {
w.detector.packetConnAccess.Lock()
delete(w.detector.packetConnMap, w.connAddr)
delete(w.detector.packetConnMap, w.connPort)
w.detector.packetConnAccess.Unlock()
})
return w.NetPacketConn.Close()
@@ -128,14 +152,14 @@ type abstractUDPConn interface {
type loopBackDetectUDPWrapper struct {
abstractUDPConn
detector *loopBackDetector
connAddr netip.AddrPort
connPort uint16
closeOnce sync.Once
}
func (w *loopBackDetectUDPWrapper) Close() error {
w.closeOnce.Do(func() {
w.detector.packetConnAccess.Lock()
delete(w.detector.packetConnMap, w.connAddr)
delete(w.detector.packetConnMap, w.connPort)
w.detector.packetConnAccess.Unlock()
})
return w.abstractUDPConn.Close()