Files
sing-box-extended/protocol/failover/conn.go

233 lines
4.2 KiB
Go

package failover
import (
"bytes"
"context"
"encoding/binary"
"io"
"net"
"sync"
"time"
C "github.com/sagernet/sing-box/constant"
)
type (
	// dial opens a fresh connection to the peer. awaitConn calls it in a
	// loop to replace a failed connection; returning SessionNotFound makes
	// awaitConn give up instead of retrying.
	dial func() (net.Conn, error)

	// failoverConn wraps a net.Conn with a length-prefixed framing protocol
	// so that the session can outlive the underlying connection: recently
	// written frames are kept in a ring and replayed to the peer when
	// RestoreConn installs a replacement connection.
	failoverConn struct {
		// Current underlying connection; swapped by RestoreConn while
		// holding mtx exclusively.
		net.Conn
		// Cancels the waits performed in awaitConn.
		ctx context.Context
		// Active reconnect function; when nil, awaitConn instead waits for
		// an external RestoreConn call.
		dial dial
		// Optional callback invoked once by Close.
		onClose func()

		// Read position that RestoreConn sends to the peer.
		readIndex uint32
		// Payload of the frame currently being handed out by Read.
		readBuffer *bytes.Buffer

		// Count of frames written by Write; RestoreConn compares the
		// peer's read position against it.
		writeIndex uint32
		// Ring holding the most recent frames, indexed by
		// writeIndex % BufferSize; RestoreConn replays from it.
		writeBuffers [BufferSize][]byte

		// Closed by RestoreConn to wake an awaitConn that has no dial.
		await chan struct{}
		// Serializes awaitConn callers.
		awaitMtx sync.Mutex
		// Terminal session error; once set, awaitConn returns it at once.
		err error
		// Ensures Close performs its shutdown only once.
		once sync.Once
		// Read/Write hold it shared; RestoreConn holds it exclusively.
		mtx sync.RWMutex
	}
)
func NewFailoverConn(ctx context.Context, conn net.Conn, dial dial, onClose func()) *failoverConn {
var writeBuffers [BufferSize][]byte
for i := range BufferSize {
writeBuffers[i] = make([]byte, 0, 1000)
}
return &failoverConn{
Conn: conn,
ctx: ctx,
dial: dial,
readBuffer: bytes.NewBuffer(make([]byte, 0, 1000)),
writeBuffers: writeBuffers,
onClose: onClose,
}
}
// Read returns application data from the current underlying connection.
// When that connection fails, Read waits for a replacement (awaitConn) and
// then retries on it.
//
// c.readIndex is the read position that RestoreConn sends to the peer. The
// peer compares it against its writeIndex, which advances once per frame
// written. readIndex must therefore advance once per frame received, not
// once per Read call, because a single frame may be handed out over several
// Reads.
//
// A zero-length frame from the peer (SessionClosed) ends the session: the
// underlying connection is closed, c.err is set to io.EOF and io.EOF is
// returned.
func (c *failoverConn) Read(b []byte) (int, error) {
	for {
		c.mtx.RLock()
		conn := c.Conn
		// read only pulls a new frame off the wire when the buffer is drained.
		newFrame := c.readBuffer.Len() == 0
		n, err := c.read(conn, b)
		if err == nil {
			if newFrame {
				c.readIndex++
			}
			c.mtx.RUnlock()
			return n, nil
		}
		if err == SessionClosed {
			c.err = io.EOF
			conn.Close()
			c.mtx.RUnlock()
			return 0, io.EOF
		}
		c.mtx.RUnlock()
		if err = c.awaitConn(conn); err != nil {
			return 0, err
		}
	}
}
// Write sends b to the peer as length-prefixed frames on the current
// underlying connection and returns the number of payload bytes sent.
//
// The frame header produced by write is a 16-bit length. b is therefore
// split into frames of at most 0xFFFF bytes; a longer single frame would
// have its length truncated in the header and desynchronize the stream.
// Each frame that is written successfully is copied into the writeBuffers
// ring and counted in writeIndex, so that RestoreConn can replay it if the
// peer did not receive it.
//
// When writing a frame fails, Write waits for a replacement connection
// (awaitConn) and retries that same frame. If recovery fails, Write returns
// the bytes written so far together with awaitConn's error.
//
// An empty b is sent as a single zero-length frame, which a failoverConn
// peer's read reports as SessionClosed; Close relies on this.
func (c *failoverConn) Write(b []byte) (int, error) {
	// Largest payload that fits in the uint16 frame header.
	const maxFrameSize = 0xFFFF

	written := 0
	for {
		frame := b[written:]
		if len(frame) > maxFrameSize {
			frame = frame[:maxFrameSize]
		}

		c.mtx.RLock()
		conn := c.Conn
		if _, err := c.write(conn, frame); err != nil {
			c.mtx.RUnlock()
			if err = c.awaitConn(conn); err != nil {
				return written, err
			}
			continue // retry the same frame on the restored connection
		}
		slot := c.writeIndex % BufferSize
		c.writeBuffers[slot] = append(c.writeBuffers[slot][:0], frame...)
		c.writeIndex++
		c.mtx.RUnlock()

		written += len(frame)
		if written == len(b) {
			return written, nil
		}
	}
}
// RestoreConn resynchronizes the stream with the peer over conn and, on
// success, installs conn as the new underlying connection.
//
// The exchange on conn is:
//  1. send our readIndex as a 4-byte big-endian integer;
//  2. read the peer's read position in the same format;
//  3. resend every frame held in writeBuffers from that position up to our
//     writeIndex.
//
// It returns SessionBroken when more than BufferSize frames would have to be
// resent, meaning the ring no longer holds them. I/O errors on conn are
// returned unchanged. On any failure c.Conn is left as it was and conn is not
// closed. On success, any awaitConn waiting on c.await is woken.
//
// The previous connection is closed before taking the exclusive lock. Read
// and Write keep the shared lock while doing I/O on it, and closing it
// unblocks them so that Lock below can be acquired.
//
// NOTE(review): the handshake runs under the exclusive lock with no
// deadline on conn; a peer that never answers stalls every Read/Write.
func (c *failoverConn) RestoreConn(conn net.Conn) error {
	c.Conn.Close()

	c.mtx.Lock()
	defer c.mtx.Unlock()

	var ours [4]byte
	binary.BigEndian.PutUint32(ours[:], c.readIndex)
	if _, err := conn.Write(ours[:]); err != nil {
		return err
	}

	var theirs [4]byte
	if _, err := io.ReadFull(conn, theirs[:]); err != nil {
		return err
	}
	peerRead := binary.BigEndian.Uint32(theirs[:])

	// Unsigned subtraction copes with wrap-around of the 32-bit counters.
	pending := c.writeIndex - peerRead
	if pending > BufferSize {
		return SessionBroken
	}
	for i := uint32(0); i < pending; i++ {
		if _, err := c.write(conn, c.writeBuffers[(peerRead+i)%BufferSize]); err != nil {
			return err
		}
	}

	c.Conn = conn
	if c.await != nil {
		close(c.await)
		c.await = nil
	}
	return nil
}
// Close ends the session. Only the first call has any effect, and every
// call returns nil.
//
// The first call does the following, in order:
//   - runs the onClose callback, if any;
//   - records io.EOF as the terminal error;
//   - sends a zero-length frame through Write, which a failoverConn peer's
//     read reports as SessionClosed. The result of that Write is ignored.
//
// NOTE(review): Close does not itself close the underlying c.Conn; confirm
// that onClose (or the owner of the connection) takes care of it.
func (c *failoverConn) Close() error {
	shutdown := func() {
		c.mtx.RLock()
		if callback := c.onClose; callback != nil {
			callback()
		}
		c.err = io.EOF
		c.mtx.RUnlock()

		// Best effort: the session is over regardless of whether the close
		// frame reaches the peer.
		_, _ = c.Write([]byte{})
	}
	c.once.Do(shutdown)
	return nil
}
// read copies buffered frame payload from c.readBuffer into b. When the
// buffer is empty it first reads one frame from conn: a 2-byte big-endian
// length followed by that many payload bytes.
//
// A zero length returns SessionClosed. Any I/O error from conn is returned
// as-is with 0 bytes read.
//
// If conn fails partway through a frame's payload, the partial payload is
// discarded. Otherwise the prefix would be handed to the caller after
// reconnection, and the frame is expected to be received again in full via
// RestoreConn's replay, duplicating those bytes.
func (c *failoverConn) read(conn net.Conn, b []byte) (int, error) {
	if c.readBuffer.Len() == 0 {
		c.readBuffer.Reset()
		var header [2]byte
		if _, err := io.ReadFull(conn, header[:]); err != nil {
			return 0, err
		}
		size := binary.BigEndian.Uint16(header[:])
		if size == 0 {
			return 0, SessionClosed
		}
		if _, err := io.CopyN(c.readBuffer, conn, int64(size)); err != nil {
			c.readBuffer.Reset()
			return 0, err
		}
	}
	return c.readBuffer.Read(b)
}
// write sends b to conn as a single frame: a 2-byte big-endian length
// followed by the payload, issued as one conn.Write call.
//
// It returns the number of payload bytes written, excluding the header and
// never negative, together with conn.Write's error.
//
// len(b) must not exceed 0xFFFF: the length is stored as a uint16, and a
// longer b would have its length silently truncated in the header.
func (c *failoverConn) write(conn net.Conn, b []byte) (int, error) {
	frame := make([]byte, 2+len(b))
	binary.BigEndian.PutUint16(frame, uint16(len(b)))
	copy(frame[2:], b)
	n, err := conn.Write(frame)
	// conn.Write counts the header too. If it failed before the header was
	// fully sent, no payload went out, so report 0 rather than a negative
	// count.
	n -= 2
	if n < 0 {
		n = 0
	}
	return n, err
}
func (c *failoverConn) awaitConn(oldConn net.Conn) error {
c.awaitMtx.Lock()
defer c.awaitMtx.Unlock()
if c.err != nil {
return c.err
}
if c.Conn != oldConn {
return c.ctx.Err()
}
oldConn.Close()
timer := time.NewTimer(C.TCPConnectTimeout)
defer timer.Stop()
if c.dial != nil {
for {
select {
case <-c.ctx.Done():
return c.ctx.Err()
case <-timer.C:
c.err = SessionExpired
return c.err
default:
}
conn, err := c.dial()
if err != nil {
if err == SessionNotFound {
c.err = err
return err
}
continue
}
err = c.RestoreConn(conn)
if err != nil {
if err == SessionBroken {
c.err = err
return err
}
continue
}
return nil
}
} else {
c.await = make(chan struct{})
select {
case <-c.await:
case <-timer.C:
c.err = SessionExpired
return c.err
case <-c.ctx.Done():
return c.ctx.Err()
}
}
return nil
}