diff --git a/transport/v2rayxhttp/client.go b/transport/v2rayxhttp/client.go index e8f5fb12..37d3fa63 100644 --- a/transport/v2rayxhttp/client.go +++ b/transport/v2rayxhttp/client.go @@ -123,10 +123,10 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { httpClient, xmuxClient := c.getHTTPClient() httpClient2, xmuxClient2 := c.getHTTPClient2() if xmuxClient != nil { - xmuxClient.OpenUsage.Add(1) + xmuxClient.AddOpenUsage(1) } if xmuxClient2 != nil && xmuxClient2 != xmuxClient { - xmuxClient2.OpenUsage.Add(1) + xmuxClient2.AddOpenUsage(1) } var closed atomic.Int32 reader, writer := io.Pipe() @@ -137,10 +137,10 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) { return } if xmuxClient != nil { - xmuxClient.OpenUsage.Add(-1) + xmuxClient.AddOpenUsage(-1) } if xmuxClient2 != nil && xmuxClient2 != xmuxClient { - xmuxClient2.OpenUsage.Add(-1) + xmuxClient2.AddOpenUsage(-1) } }, } diff --git a/transport/v2rayxhttp/dialer.go b/transport/v2rayxhttp/dialer.go index 08cc4ddc..255042b1 100644 --- a/transport/v2rayxhttp/dialer.go +++ b/transport/v2rayxhttp/dialer.go @@ -8,19 +8,24 @@ import ( "net" "net/http" "net/http/httptrace" + "reflect" "strings" "sync" + "unsafe" + "github.com/sagernet/quic-go/http3" common "github.com/sagernet/sing-box/common/xray" "github.com/sagernet/sing-box/common/xray/buf" "github.com/sagernet/sing-box/common/xray/signal/done" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" + "golang.org/x/net/http2" ) // interface to abstract between use of browser dialer, vs net/http type DialerClient interface { IsClosed() bool + Close() // ctx, url, sessionId, body, uploadOnly OpenStream(context.Context, string, string, io.Reader, bool) (io.ReadCloser, net.Addr, net.Addr, error) @@ -38,9 +43,54 @@ type DefaultDialerClient struct { // pool of net.Conn, created using dialUploadConn uploadRawPool *sync.Pool dialUploadConn func(ctxInner context.Context) (net.Conn, error) + + mtx sync.RWMutex +} + +type clientConnPool struct { + t *http2.Transport + mu sync.Mutex + conns map[string][]*http2.ClientConn // key is host:port +} + +type efaceWords struct { + typ unsafe.Pointer + data unsafe.Pointer +} + +//go:linkname transportConnPool golang.org/x/net/http2.(*Transport).connPool +func transportConnPool(t *http2.Transport) http2.ClientConnPool + +func (c *DefaultDialerClient) Close() { + c.mtx.Lock() + defer c.mtx.Unlock() + if c.closed { + return + } + c.closed = true + switch transport := c.client.Transport.(type) { + case *http.Transport: + transport.CloseIdleConnections() + case *http2.Transport: + connPool := transportConnPool(transport) + p := (*clientConnPool)((*efaceWords)(unsafe.Pointer(&connPool)).data) + p.mu.Lock() + defer p.mu.Unlock() + for _, vv := range p.conns { + for _, cc := range vv { + cc.Close() + } + } + case *http3.Transport: + transport.Close() + default: + panic(E.New("unknown transport type: ", reflect.TypeOf(transport))) + } } func (c *DefaultDialerClient) IsClosed() bool { + c.mtx.RLock() + defer c.mtx.RUnlock() return c.closed } @@ -67,7 +117,7 @@ func (c *DefaultDialerClient) OpenStream(ctx context.Context, url string, sessio resp, err := c.client.Do(req) if err != nil { if !uploadOnly { // stream-down is enough - c.closed = true + c.Close() } gotConn.Close() common.Close(body) @@ -133,7 +183,7 @@ func (c *DefaultDialerClient) PostPacket(ctx context.Context, url string, sessio if h1UploadConn.UnreadedResponsesCount > 0 { resp, err := http.ReadResponse(h1UploadConn.RespBufReader, req) if err != nil { - c.closed = true + c.Close() return fmt.Errorf("error while reading response: %s", err.Error()) } io.Copy(io.Discard, resp.Body) diff --git a/transport/v2rayxhttp/mux.go b/transport/v2rayxhttp/mux.go index f753aed5..e134fdeb 100644 --- a/transport/v2rayxhttp/mux.go +++ b/transport/v2rayxhttp/mux.go @@ -13,15 +13,43 @@ import ( ) type XmuxConn interface { + Close() IsClosed() bool } type XmuxClient struct { XmuxConn XmuxConn - OpenUsage atomic.Int32 + openUsage int32 leftUsage int32 LeftRequests atomic.Int32 UnreusableAt time.Time + + closed bool + mtx sync.Mutex +} + +func (c *XmuxClient) Close() { + c.mtx.Lock() + defer c.mtx.Unlock() + c.closed = true + if c.openUsage <= 0 { + c.XmuxConn.Close() + } +} + +func (c *XmuxClient) AddOpenUsage(delta int32) { + c.mtx.Lock() + defer c.mtx.Unlock() + c.openUsage += delta + if c.closed && c.openUsage <= 0 { + c.XmuxConn.Close() + } +} + +func (c *XmuxClient) GetOpenUsage() int32 { + c.mtx.Lock() + defer c.mtx.Unlock() + return c.openUsage } type XmuxManager struct { @@ -65,6 +93,7 @@ func (m *XmuxManager) newXmuxClient() *XmuxClient { func (m *XmuxManager) GetXmuxClient(ctx context.Context) *XmuxClient { m.mtx.Lock() defer m.mtx.Unlock() + var evicted []*XmuxClient for i := 0; i < len(m.xmuxClients); { xmuxClient := m.xmuxClients[i] if xmuxClient.XmuxConn.IsClosed() || @@ -72,10 +101,14 @@ func (m *XmuxManager) GetXmuxClient(ctx context.Context) *XmuxClient { xmuxClient.LeftRequests.Load() <= 0 || (xmuxClient.UnreusableAt != time.Time{} && time.Now().After(xmuxClient.UnreusableAt)) { m.xmuxClients = append(m.xmuxClients[:i], m.xmuxClients[i+1:]...) + evicted = append(evicted, xmuxClient) } else { i++ } } + for _, c := range evicted { + c.Close() + } if len(m.xmuxClients) == 0 { return m.newXmuxClient() } @@ -85,7 +118,7 @@ func (m *XmuxManager) GetXmuxClient(ctx context.Context) *XmuxClient { xmuxClients := make([]*XmuxClient, 0) if m.concurrency > 0 { for _, xmuxClient := range m.xmuxClients { - if xmuxClient.OpenUsage.Load() < m.concurrency { + if xmuxClient.GetOpenUsage() < m.concurrency { xmuxClients = append(xmuxClients, xmuxClient) } }