mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-06-12 06:18:16 +03:00
Fix connector canceled dial cleanup
This commit is contained in:
@@ -55,6 +55,12 @@ type contextKeyConnecting struct{}
|
|||||||
|
|
||||||
var errRecursiveConnectorDial = E.New("recursive connector dial")
|
var errRecursiveConnectorDial = E.New("recursive connector dial")
|
||||||
|
|
||||||
|
type connectorDialResult[T any] struct {
|
||||||
|
connection T
|
||||||
|
cancel context.CancelFunc
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Connector[T]) Get(ctx context.Context) (T, error) {
|
func (c *Connector[T]) Get(ctx context.Context) (T, error) {
|
||||||
var zero T
|
var zero T
|
||||||
for {
|
for {
|
||||||
@@ -100,41 +106,37 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) {
|
|||||||
return zero, err
|
return zero, err
|
||||||
}
|
}
|
||||||
|
|
||||||
c.connecting = make(chan struct{})
|
connecting := make(chan struct{})
|
||||||
|
c.connecting = connecting
|
||||||
|
dialContext := context.WithValue(ctx, contextKeyConnecting{}, c)
|
||||||
|
dialResult := make(chan connectorDialResult[T], 1)
|
||||||
c.access.Unlock()
|
c.access.Unlock()
|
||||||
|
|
||||||
dialContext := context.WithValue(ctx, contextKeyConnecting{}, c)
|
go func() {
|
||||||
connection, cancel, err := c.dialWithCancellation(dialContext)
|
connection, cancel, err := c.dialWithCancellation(dialContext)
|
||||||
|
dialResult <- connectorDialResult[T]{
|
||||||
|
connection: connection,
|
||||||
|
cancel: cancel,
|
||||||
|
err: err,
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
c.access.Lock()
|
select {
|
||||||
close(c.connecting)
|
case result := <-dialResult:
|
||||||
c.connecting = nil
|
return c.completeDial(ctx, connecting, result)
|
||||||
|
case <-ctx.Done():
|
||||||
if err != nil {
|
go func() {
|
||||||
c.access.Unlock()
|
result := <-dialResult
|
||||||
return zero, err
|
_, _ = c.completeDial(ctx, connecting, result)
|
||||||
}
|
}()
|
||||||
|
return zero, ctx.Err()
|
||||||
if c.closed {
|
case <-c.closeCtx.Done():
|
||||||
cancel()
|
go func() {
|
||||||
c.callbacks.Close(connection)
|
result := <-dialResult
|
||||||
c.access.Unlock()
|
_, _ = c.completeDial(ctx, connecting, result)
|
||||||
|
}()
|
||||||
return zero, ErrTransportClosed
|
return zero, ErrTransportClosed
|
||||||
}
|
}
|
||||||
if err = ctx.Err(); err != nil {
|
|
||||||
cancel()
|
|
||||||
c.callbacks.Close(connection)
|
|
||||||
c.access.Unlock()
|
|
||||||
return zero, err
|
|
||||||
}
|
|
||||||
|
|
||||||
c.connection = connection
|
|
||||||
c.hasConnection = true
|
|
||||||
c.connectionCancel = cancel
|
|
||||||
result := c.connection
|
|
||||||
c.access.Unlock()
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -143,6 +145,38 @@ func isRecursiveConnectorDial[T any](ctx context.Context, connector *Connector[T
|
|||||||
return loaded && dialConnector == connector
|
return loaded && dialConnector == connector
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Connector[T]) completeDial(ctx context.Context, connecting chan struct{}, result connectorDialResult[T]) (T, error) {
|
||||||
|
var zero T
|
||||||
|
|
||||||
|
c.access.Lock()
|
||||||
|
defer c.access.Unlock()
|
||||||
|
defer func() {
|
||||||
|
if c.connecting == connecting {
|
||||||
|
c.connecting = nil
|
||||||
|
}
|
||||||
|
close(connecting)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if result.err != nil {
|
||||||
|
return zero, result.err
|
||||||
|
}
|
||||||
|
if c.closed || c.closeCtx.Err() != nil {
|
||||||
|
result.cancel()
|
||||||
|
c.callbacks.Close(result.connection)
|
||||||
|
return zero, ErrTransportClosed
|
||||||
|
}
|
||||||
|
if err := ctx.Err(); err != nil {
|
||||||
|
result.cancel()
|
||||||
|
c.callbacks.Close(result.connection)
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.connection = result.connection
|
||||||
|
c.hasConnection = true
|
||||||
|
c.connectionCancel = result.cancel
|
||||||
|
return c.connection, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, context.CancelFunc, error) {
|
func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, context.CancelFunc, error) {
|
||||||
var zero T
|
var zero T
|
||||||
if err := ctx.Err(); err != nil {
|
if err := ctx.Err(); err != nil {
|
||||||
|
|||||||
@@ -188,13 +188,157 @@ func TestConnectorCanceledRequestDoesNotCacheConnection(t *testing.T) {
|
|||||||
err := <-result
|
err := <-result
|
||||||
require.ErrorIs(t, err, context.Canceled)
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
require.EqualValues(t, 1, dialCount.Load())
|
require.EqualValues(t, 1, dialCount.Load())
|
||||||
require.EqualValues(t, 1, closeCount.Load())
|
require.Eventually(t, func() bool {
|
||||||
|
return closeCount.Load() == 1
|
||||||
|
}, time.Second, 10*time.Millisecond)
|
||||||
|
|
||||||
_, err = connector.Get(context.Background())
|
_, err = connector.Get(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.EqualValues(t, 2, dialCount.Load())
|
require.EqualValues(t, 2, dialCount.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnectorCanceledRequestReturnsBeforeIgnoredDialCompletes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
dialCount atomic.Int32
|
||||||
|
closeCount atomic.Int32
|
||||||
|
)
|
||||||
|
dialStarted := make(chan struct{}, 1)
|
||||||
|
releaseDial := make(chan struct{})
|
||||||
|
|
||||||
|
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
||||||
|
dialCount.Add(1)
|
||||||
|
select {
|
||||||
|
case dialStarted <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
<-releaseDial
|
||||||
|
return &testConnectorConnection{}, nil
|
||||||
|
}, ConnectorCallbacks[*testConnectorConnection]{
|
||||||
|
IsClosed: func(connection *testConnectorConnection) bool {
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
Close: func(connection *testConnectorConnection) {
|
||||||
|
closeCount.Add(1)
|
||||||
|
},
|
||||||
|
Reset: func(connection *testConnectorConnection) {},
|
||||||
|
})
|
||||||
|
|
||||||
|
requestContext, cancel := context.WithCancel(context.Background())
|
||||||
|
result := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := connector.Get(requestContext)
|
||||||
|
result <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-dialStarted
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-result:
|
||||||
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("Get did not return after request cancel")
|
||||||
|
}
|
||||||
|
|
||||||
|
require.EqualValues(t, 1, dialCount.Load())
|
||||||
|
require.EqualValues(t, 0, closeCount.Load())
|
||||||
|
|
||||||
|
close(releaseDial)
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return closeCount.Load() == 1
|
||||||
|
}, time.Second, 10*time.Millisecond)
|
||||||
|
|
||||||
|
_, err := connector.Get(context.Background())
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 2, dialCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectorWaiterDoesNotStartNewDialBeforeCanceledDialCompletes(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var (
|
||||||
|
dialCount atomic.Int32
|
||||||
|
closeCount atomic.Int32
|
||||||
|
)
|
||||||
|
firstDialStarted := make(chan struct{}, 1)
|
||||||
|
secondDialStarted := make(chan struct{}, 1)
|
||||||
|
releaseFirstDial := make(chan struct{})
|
||||||
|
|
||||||
|
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
||||||
|
attempt := dialCount.Add(1)
|
||||||
|
switch attempt {
|
||||||
|
case 1:
|
||||||
|
select {
|
||||||
|
case firstDialStarted <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
<-releaseFirstDial
|
||||||
|
case 2:
|
||||||
|
select {
|
||||||
|
case secondDialStarted <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &testConnectorConnection{}, nil
|
||||||
|
}, ConnectorCallbacks[*testConnectorConnection]{
|
||||||
|
IsClosed: func(connection *testConnectorConnection) bool {
|
||||||
|
return false
|
||||||
|
},
|
||||||
|
Close: func(connection *testConnectorConnection) {
|
||||||
|
closeCount.Add(1)
|
||||||
|
},
|
||||||
|
Reset: func(connection *testConnectorConnection) {},
|
||||||
|
})
|
||||||
|
|
||||||
|
requestContext, cancel := context.WithCancel(context.Background())
|
||||||
|
firstResult := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := connector.Get(requestContext)
|
||||||
|
firstResult <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
<-firstDialStarted
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
secondResult := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := connector.Get(context.Background())
|
||||||
|
secondResult <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-secondDialStarted:
|
||||||
|
t.Fatal("second dial started before first dial completed")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-firstResult:
|
||||||
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("first Get did not return after request cancel")
|
||||||
|
}
|
||||||
|
|
||||||
|
close(releaseFirstDial)
|
||||||
|
|
||||||
|
require.Eventually(t, func() bool {
|
||||||
|
return closeCount.Load() == 1
|
||||||
|
}, time.Second, 10*time.Millisecond)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-secondDialStarted:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("second dial did not start after first dial completed")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := <-secondResult
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.EqualValues(t, 2, dialCount.Load())
|
||||||
|
}
|
||||||
|
|
||||||
func TestConnectorDialContextNotCanceledByRequestContextAfterDial(t *testing.T) {
|
func TestConnectorDialContextNotCanceledByRequestContextAfterDial(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user