diff --git a/route/conn.go b/route/conn.go index 59afe539..9fdc6cda 100644 --- a/route/conn.go +++ b/route/conn.go @@ -13,6 +13,7 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" + "github.com/sagernet/sing-box/common/sniff" "github.com/sagernet/sing-box/common/tlsfragment" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing/common" @@ -128,11 +129,12 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co if metadata.TLSFragment || metadata.TLSRecordFragment { remoteConn = tf.NewConn(remoteConn, ctx, metadata.TLSFragment, metadata.TLSRecordFragment, metadata.TLSFragmentFallbackDelay) } + serverFirst := sniff.Skip(&metadata) var done atomic.Bool - if m.kickWriteHandshake(ctx, conn, remoteConn, false, &done, onClose) { + if m.kickWriteHandshake(ctx, conn, remoteConn, serverFirst, false, &done, onClose) { return } - if m.kickWriteHandshake(ctx, remoteConn, conn, true, &done, onClose) { + if m.kickWriteHandshake(ctx, remoteConn, conn, serverFirst, true, &done, onClose) { return } go m.connectionCopy(ctx, conn, remoteConn, false, &done, onClose) @@ -293,37 +295,43 @@ func (m *ConnectionManager) connectionCopy(ctx context.Context, source net.Conn, } } -func (m *ConnectionManager) kickWriteHandshake(ctx context.Context, source net.Conn, destination net.Conn, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) bool { +func (m *ConnectionManager) kickWriteHandshake(ctx context.Context, source net.Conn, destination net.Conn, serverFirst bool, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) bool { if !N.NeedHandshakeForWrite(destination) { return false } var ( - cachedBuffer *buf.Buffer + err error wrotePayload bool ) - sourceReader, readCounters := N.UnwrapCountReader(source, nil) - destinationWriter, writeCounters := N.UnwrapCountWriter(destination, nil) - if cachedReader, ok := sourceReader.(N.CachedReader); ok { - cachedBuffer = cachedReader.ReadCached() - } - var err error - if cachedBuffer != nil { - wrotePayload = true - dataLen := cachedBuffer.Len() - _, err = destinationWriter.Write(cachedBuffer.Bytes()) - cachedBuffer.Release() - if err == nil { - for _, counter := range readCounters { - counter(int64(dataLen)) - } - for _, counter := range writeCounters { - counter(int64(dataLen)) - } - } - } else { + if serverFirst { _ = destination.SetWriteDeadline(time.Now().Add(C.ReadPayloadTimeout)) - _, err = destinationWriter.Write(nil) + _, err = destination.Write(nil) _ = destination.SetWriteDeadline(time.Time{}) + } else { + var cachedBuffer *buf.Buffer + sourceReader, readCounters := N.UnwrapCountReader(source, nil) + destinationWriter, writeCounters := N.UnwrapCountWriter(destination, nil) + if cachedReader, ok := sourceReader.(N.CachedReader); ok { + cachedBuffer = cachedReader.ReadCached() + } + if cachedBuffer != nil { + wrotePayload = true + dataLen := cachedBuffer.Len() + _, err = destinationWriter.Write(cachedBuffer.Bytes()) + cachedBuffer.Release() + if err == nil { + for _, counter := range readCounters { + counter(int64(dataLen)) + } + for _, counter := range writeCounters { + counter(int64(dataLen)) + } + } + } else { + _ = destination.SetWriteDeadline(time.Now().Add(C.ReadPayloadTimeout)) + _, err = destinationWriter.Write(nil) + _ = destination.SetWriteDeadline(time.Time{}) + } } if err == nil { return false