diff --git a/common/interrupt/conn.go b/common/interrupt/conn.go index 6a6d31c6..b8235d6a 100644 --- a/common/interrupt/conn.go +++ b/common/interrupt/conn.go @@ -3,6 +3,7 @@ package interrupt import ( "net" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" ) @@ -73,3 +74,32 @@ func (c *PacketConn) WriterReplaceable() bool { func (c *PacketConn) Upstream() any { return c.PacketConn } + +type SingPacketConn struct { + N.PacketConn + group *Group + element *list.Element[*groupConnItem] +} + +/*func (c *SingPacketConn) MarkAsInternal() { + c.element.Value.internal = true +}*/ + +func (c *SingPacketConn) Close() error { + c.group.access.Lock() + defer c.group.access.Unlock() + c.group.connections.Remove(c.element) + return c.PacketConn.Close() +} + +func (c *SingPacketConn) ReaderReplaceable() bool { + return true +} + +func (c *SingPacketConn) WriterReplaceable() bool { + return true +} + +func (c *SingPacketConn) Upstream() any { + return c.PacketConn +} diff --git a/common/interrupt/group.go b/common/interrupt/group.go index ba2e7f73..bd3fbb0a 100644 --- a/common/interrupt/group.go +++ b/common/interrupt/group.go @@ -5,6 +5,7 @@ import ( "net" "sync" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" ) @@ -36,6 +37,13 @@ func (g *Group) NewPacketConn(conn net.PacketConn, isExternal bool) net.PacketCo return &PacketConn{PacketConn: conn, group: g, element: item} } +func (g *Group) NewSingPacketConn(conn N.PacketConn, isExternal bool) N.PacketConn { + g.access.Lock() + defer g.access.Unlock() + item := g.connections.PushBack(&groupConnItem{conn, isExternal}) + return &SingPacketConn{PacketConn: conn, group: g, element: item} +} + func (g *Group) Interrupt(interruptExternalConnections bool) { g.access.Lock() defer g.access.Unlock() diff --git a/protocol/group/selector.go b/protocol/group/selector.go index 9806e033..23526d19 100644 --- a/protocol/group/selector.go +++ b/protocol/group/selector.go @@ -157,6 +157,7 @@ func (s *Selector) ListenPacket(ctx context.Context, destination M.Socksaddr) (n func (s *Selector) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = interrupt.ContextWithIsExternalConnection(ctx) selected := s.selected.Load() + conn = s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)) if outboundHandler, isHandler := selected.(adapter.ConnectionHandlerEx); isHandler { outboundHandler.NewConnectionEx(ctx, conn, metadata, onClose) } else { @@ -167,6 +168,7 @@ func (s *Selector) NewConnectionEx(ctx context.Context, conn net.Conn, metadata func (s *Selector) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = interrupt.ContextWithIsExternalConnection(ctx) selected := s.selected.Load() + conn = s.interruptGroup.NewSingPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)) if outboundHandler, isHandler := selected.(adapter.PacketConnectionHandlerEx); isHandler { outboundHandler.NewPacketConnectionEx(ctx, conn, metadata, onClose) } else { diff --git a/protocol/group/urltest.go b/protocol/group/urltest.go index c1a5c597..5746e0cf 100644 --- a/protocol/group/urltest.go +++ b/protocol/group/urltest.go @@ -162,11 +162,13 @@ func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (ne func (s *URLTest) NewConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = interrupt.ContextWithIsExternalConnection(ctx) + conn = s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)) s.connection.NewConnection(ctx, s, conn, metadata, onClose) } func (s *URLTest) NewPacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) { ctx = interrupt.ContextWithIsExternalConnection(ctx) + conn = s.group.interruptGroup.NewSingPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)) s.connection.NewPacketConnection(ctx, s, conn, metadata, onClose) }