// sing-box-extended/transport/v2raykcp/connection.go

package v2raykcp
import (
"bytes"
"io"
"net"
"runtime"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/sing/common/buf"
)
// PacketWriter is the sink for outgoing KCP packets, modelled on v2ray-core's
// kcp.PacketWriter. Implementations are expected to wrap each packet with an
// obfuscating header and AEAD before sending it over UDP.
type PacketWriter interface {
	io.Writer

	// Overhead reports how many bytes the writer adds to every packet; it is
	// subtracted from the MTU when sizing segment payloads.
	Overhead() int
}
// State is a connection's position in its close/terminate lifecycle. It is
// stored in the connection as an int32 and read/written atomically.
type State int32

// Lifecycle states, numbered 0 through 5.
const (
	StateActive          State = iota // 0: normal operation
	StateReadyToClose                 // 1: closed locally; reads return EOF
	StatePeerClosed                   // 2: peer signalled close; writes are rejected
	StateTerminating                  // 3: sending terminate commands to the peer
	StatePeerTerminating              // 4: peer sent a terminate command
	StateTerminated                   // 5: teardown has been scheduled
)

// Is reports whether s equals any of the candidate states. With no
// candidates it reports false.
func (s State) Is(states ...State) bool {
	for i := 0; i < len(states); i++ {
		if states[i] == s {
			return true
		}
	}
	return false
}
// nowMillisec returns the current wall-clock time as whole milliseconds since
// the Unix epoch.
func nowMillisec() int64 {
	return time.Now().UnixMilli()
}
// RoundTripInfo estimates the round-trip time and derives a retransmission
// timeout (RTO) from it. It uses exponentially weighted averages with
// RFC 6298-style weights: 1/8 for the smoothed RTT and 1/4 for its variation.
// Every field is guarded by mu, so all methods may be called concurrently.
type RoundTripInfo struct {
	mu               sync.RWMutex
	variation        uint32 // smoothed absolute deviation of samples from srtt
	srtt             uint32 // smoothed round-trip time; 0 until the first sample
	rto              uint32 // current retransmission timeout
	minRtt           uint32 // floor applied to srtt after the first sample
	updatedTimestamp uint32 // caller-supplied time of the last rto change
}

// UpdatePeerRTO adopts rto as the timeout, but only when at least 3000 units
// have passed since the last change. The elapsed time is computed with
// wrapping uint32 subtraction.
func (info *RoundTripInfo) UpdatePeerRTO(rto uint32, current uint32) {
	info.mu.Lock()
	defer info.mu.Unlock()
	if sinceLast := current - info.updatedTimestamp; sinceLast >= 3000 {
		info.updatedTimestamp = current
		info.rto = rto
	}
}

// Update folds a new RTT sample into the estimate and recomputes the timeout.
// Samples above 0x7FFFFFFF are ignored. The unscaled timeout is capped at
// 10000 and then multiplied by 5/4.
func (info *RoundTripInfo) Update(rtt uint32, current uint32) {
	const maxSample = 0x7FFFFFFF
	if rtt > maxSample {
		return
	}

	info.mu.Lock()
	defer info.mu.Unlock()

	switch {
	case info.srtt == 0:
		// The first sample seeds both estimators.
		info.srtt = rtt
		info.variation = rtt / 2
	default:
		diff := rtt - info.srtt
		if rtt < info.srtt {
			diff = info.srtt - rtt
		}
		info.variation = (3*info.variation + diff) / 4
		info.srtt = (7*info.srtt + rtt) / 8
		if info.minRtt > info.srtt {
			info.srtt = info.minRtt
		}
	}

	// Use four variations of headroom unless that is no larger than minRtt.
	margin := info.variation
	if 4*info.variation > info.minRtt {
		margin = 4 * info.variation
	}
	timeout := info.srtt + margin
	if timeout > 10000 {
		timeout = 10000
	}
	info.rto = timeout * 5 / 4
	info.updatedTimestamp = current
}

// Timeout returns the current retransmission timeout, or 100 when none has
// been established.
func (info *RoundTripInfo) Timeout() uint32 {
	info.mu.RLock()
	rto := info.rto
	info.mu.RUnlock()
	if rto != 0 {
		return rto
	}
	return 100
}

// SmoothedTime returns the smoothed round-trip time (0 before any sample).
func (info *RoundTripInfo) SmoothedTime() uint32 {
	info.mu.RLock()
	srtt := info.srtt
	info.mu.RUnlock()
	return srtt
}
// ConnMetadata carries a connection's endpoint addresses together with the
// conversation ID that tags every segment belonging to it.
type ConnMetadata struct {
	LocalAddr, RemoteAddr net.Addr // endpoints reported by the connection
	Conversation          uint16   // segments with a different ID are not processed
}
// Connection is one KCP session carried over a packet transport. It provides
// Read, Write, Close, LocalAddr, RemoteAddr and the deadline setters, i.e. the
// net.Conn method set.
type Connection struct {
	meta   ConnMetadata // endpoint addresses and conversation ID
	closer io.Closer    // closed by Terminate

	rd, wd time.Time // read/write deadlines; the zero value means none

	since int64 // creation time in Unix milliseconds; Elapsed counts from here

	// One-slot wake-up signals for readers and writers blocked waiting for data.
	dataInput, dataOutput chan struct{}

	Config *Config

	// Accessed with sync/atomic. Times are milliseconds as returned by Elapsed.
	state            int32
	stateBeginTime   uint32
	lastIncomingTime uint32
	lastPingTime     uint32

	mss uint32 // maximum payload carried by a single data segment

	roundTrip       *RoundTripInfo
	receivingWorker *ReceivingWorker
	sendingWorker   *SendingWorker
	output          SegmentWriter
	dataUpdater     *Updater
	pingUpdater     *Updater
}
// NewConnection creates a KCP connection for meta.Conversation whose segments
// are written through writer; closer is closed later by Terminate.
//
// The payload limit per data segment (mss) is the configured MTU minus the
// writer's per-packet overhead and DataSegmentOverhead. The round-trip
// estimator starts with an RTO of 100 and uses the configured TTI as its
// minimum RTT.
//
// Two updaters both run conn.updateTask. dataUpdater uses the TTI as its
// interval; pingUpdater uses 5000 (presumably milliseconds). The ping updater
// is woken before returning, presumably so the first tick runs immediately.
// NOTE(review): NewUpdater's parameter meanings (interval, "should update"
// predicate, "should stop" predicate, task) are inferred from the arguments
// passed here — confirm against its definition.
func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, config *Config) *Connection {
	conn := &Connection{
		meta:       meta,
		closer:     closer,
		since:      nowMillisec(),
		dataInput:  make(chan struct{}, 1),
		dataOutput: make(chan struct{}, 1),
		Config:     config,
		output:     NewSegmentWriter(writer),
		mss:        config.GetMTUValue() - uint32(writer.Overhead()) - uint32(DataSegmentOverhead),
		roundTrip: &RoundTripInfo{
			rto:    100,
			minRtt: config.GetTTIValue(),
		},
	}
	// The worker constructors take conn, so they run after it exists.
	conn.receivingWorker = NewReceivingWorker(conn)
	conn.sendingWorker = NewSendingWorker(conn)
	// True once termination has begun (Terminating or Terminated).
	isTerminating := func() bool {
		return conn.State().Is(StateTerminating, StateTerminated)
	}
	// True only after the final Terminated state has been reached.
	isTerminated := func() bool {
		return conn.State() == StateTerminated
	}
	// Data updater: ticks at the TTI while either worker reports pending work
	// and termination has not begun.
	conn.dataUpdater = NewUpdater(
		config.GetTTIValue(),
		func() bool {
			return !isTerminating() && (conn.sendingWorker.UpdateNecessary() || conn.receivingWorker.UpdateNecessary())
		},
		isTerminating,
		conn.updateTask,
	)
	// Ping updater: keeps ticking until the connection is fully terminated
	// (SetState shortens its interval to one second once closing starts).
	conn.pingUpdater = NewUpdater(
		5000,
		func() bool { return !isTerminated() },
		isTerminated,
		conn.updateTask,
	)
	conn.pingUpdater.WakeUp()
	return conn
}
// Elapsed returns the milliseconds since the connection was created,
// truncated to 32 bits. Every internal timestamp is expressed on this scale.
func (c *Connection) Elapsed() uint32 {
	delta := nowMillisec() - c.since
	return uint32(delta)
}
// State atomically loads the connection's current lifecycle state.
func (c *Connection) State() State {
	raw := atomic.LoadInt32(&c.state)
	return State(raw)
}
// SetState atomically moves the connection into state, records the time of the
// transition (see Elapsed) and performs the side effects of entering that
// state:
//   - ReadyToClose: closes the receiving side.
//   - PeerClosed: closes the sending side.
//   - Terminating: closes both sides and drops the ping interval to 1s.
//   - PeerTerminating: closes the sending side and drops the ping interval to 1s.
//   - Terminated: closes both sides, drops the ping interval to 1s, wakes both
//     updaters and starts Terminate in a new goroutine.
//
// The state is swapped rather than stored so that the previous value is known
// atomically. Both updateTask and Input move Terminating to Terminated. If two
// such calls race, only the first performs the Terminated teardown. A second
// Terminate would close the underlying closer and release the workers twice.
func (c *Connection) SetState(state State) {
	current := c.Elapsed()
	previous := State(atomic.SwapInt32(&c.state, int32(state)))
	atomic.StoreUint32(&c.stateBeginTime, current)
	switch state {
	case StateReadyToClose:
		c.receivingWorker.CloseRead()
	case StatePeerClosed:
		c.sendingWorker.CloseWrite()
	case StateTerminating:
		c.receivingWorker.CloseRead()
		c.sendingWorker.CloseWrite()
		c.pingUpdater.SetInterval(time.Second)
	case StatePeerTerminating:
		c.sendingWorker.CloseWrite()
		c.pingUpdater.SetInterval(time.Second)
	case StateTerminated:
		if previous == StateTerminated {
			// Teardown was already scheduled by an earlier transition.
			return
		}
		c.receivingWorker.CloseRead()
		c.sendingWorker.CloseWrite()
		c.pingUpdater.SetInterval(time.Second)
		c.dataUpdater.WakeUp()
		c.pingUpdater.WakeUp()
		go c.Terminate()
	}
}
// Terminate performs the final teardown. It waits 8 seconds, then closes the
// underlying closer (its error is ignored) and releases the sending and
// receiving workers, skipping any that are nil. SetState starts it in its own
// goroutine on entering StateTerminated. A nil *Connection is a no-op.
// The reason for the fixed 8-second delay is not visible in this file.
func (c *Connection) Terminate() {
	if c == nil {
		return
	}
	time.Sleep(8 * time.Second)
	if closer := c.closer; closer != nil {
		_ = closer.Close()
	}
	if sender := c.sendingWorker; sender != nil {
		sender.Release()
	}
	if receiver := c.receivingWorker; receiver != nil {
		receiver.Release()
	}
}
// HandleOption reacts to the option bits carried by an incoming segment. Only
// SegmentOptionClose is acted on; it is forwarded to OnPeerClosed.
func (c *Connection) HandleOption(opt SegmentOption) {
	if opt&SegmentOptionClose != SegmentOptionClose {
		return
	}
	c.OnPeerClosed()
}
// OnPeerClosed records that the peer has closed its side. An active
// connection becomes PeerClosed. A connection already closed locally
// (ReadyToClose) starts Terminating. Any other state is left unchanged.
func (c *Connection) OnPeerClosed() {
	current := c.State()
	if current == StateActive {
		c.SetState(StatePeerClosed)
	} else if current == StateReadyToClose {
		c.SetState(StateTerminating)
	}
}
// Input processes segments decoded from incoming packets. It first records the
// arrival time, which updateTask uses for its 30-second idle check. It then
// handles the segments in order and stops at the first one whose conversation
// ID does not match this connection:
//   - DataSegment: applies its options, hands it to the receiving worker,
//     signals a waiting reader if data became available, and wakes the data
//     updater.
//   - AckSegment: applies its options, lets the sending worker process it
//     with the current RTO, signals a waiting writer, and wakes the data
//     updater.
//   - CmdOnlySegment: applies its options and advances the state machine on
//     CommandTerminate. On close or terminate it wakes both readers and
//     writers. It then syncs the peer's sequence numbers and RTO, and
//     releases the segment.
//   - any other segment is released without processing.
func (c *Connection) Input(segments []Segment) {
	current := c.Elapsed()
	atomic.StoreUint32(&c.lastIncomingTime, current)
	for _, s := range segments {
		if s.Conversation() != c.meta.Conversation {
			// NOTE(review): the remaining segments are neither processed nor
			// released here — confirm the caller releases them.
			break
		}
		switch seg := s.(type) {
		case *DataSegment:
			c.HandleOption(seg.Option)
			c.receivingWorker.ProcessSegment(seg)
			if c.receivingWorker.IsDataAvailable() {
				// Non-blocking: the channel holds at most one pending wake-up.
				select {
				case c.dataInput <- struct{}{}:
				default:
				}
			}
			c.dataUpdater.WakeUp()
		case *AckSegment:
			c.HandleOption(seg.Option)
			c.sendingWorker.ProcessSegment(current, seg, c.roundTrip.Timeout())
			// Acknowledgements may free room for a blocked Write.
			select {
			case c.dataOutput <- struct{}{}:
			default:
			}
			c.dataUpdater.WakeUp()
		case *CmdOnlySegment:
			c.HandleOption(seg.Option)
			if seg.Command() == CommandTerminate {
				// Peer asked to terminate: advance according to our own state.
				switch c.State() {
				case StateActive, StatePeerClosed:
					c.SetState(StatePeerTerminating)
				case StateReadyToClose:
					c.SetState(StateTerminating)
				case StateTerminating:
					c.SetState(StateTerminated)
				}
			}
			// NOTE(review): this compares Option for equality, while HandleOption
			// tests the SegmentOptionClose bit; they differ if other option bits
			// are ever set.
			if seg.Option == SegmentOptionClose || seg.Command() == CommandTerminate {
				// Wake blocked Read/Write so they re-check the state.
				select {
				case c.dataInput <- struct{}{}:
				default:
				}
				select {
				case c.dataOutput <- struct{}{}:
				default:
				}
			}
			c.sendingWorker.ProcessReceivingNext(seg.ReceivingNext)
			c.receivingWorker.ProcessSendingNext(seg.SendingNext)
			c.roundTrip.UpdatePeerRTO(seg.PeerRTO, current)
			seg.Release()
		default:
			s.Release()
		}
	}
}
// waitForDataInput blocks until a reader wake-up arrives on dataInput, the
// read deadline passes, or a 16-second fallback interval elapses.
//
// It first polls the signal channel up to 16 times, yielding the processor
// between attempts, so no timer is armed when a wake-up arrives almost at
// once. It returns ErrIOTimeout if the read deadline has passed or passes
// while waiting, and nil otherwise. A nil return does not guarantee data is
// available, so callers re-check in a loop.
func (c *Connection) waitForDataInput() error {
	for i := 0; i < 16; i++ {
		select {
		case <-c.dataInput:
			return nil
		default:
			runtime.Gosched()
		}
	}
	duration := time.Second * 16
	if !c.rd.IsZero() {
		duration = time.Until(c.rd)
		if duration < 0 {
			return ErrIOTimeout
		}
	}
	// An explicit timer stopped on return replaces time.After. time.After keeps
	// its timer alive until it fires (up to 16s, or until a distant deadline)
	// even when dataInput wins. On Go versions before 1.23 this piles up timers
	// under frequent reads.
	timer := time.NewTimer(duration)
	defer timer.Stop()
	select {
	case <-c.dataInput:
		return nil
	case <-timer.C:
		// rd is re-read because the deadline may have changed while waiting.
		if !c.rd.IsZero() && c.rd.Before(time.Now()) {
			return ErrIOTimeout
		}
		return nil
	}
}
// Read copies data delivered by the receiving worker into b, blocking until
// some is available.
//
// Return values:
//   - io.EOF once the connection is closing locally (ReadyToClose,
//     Terminating, Terminated), on a nil *Connection, or when the peer is
//     terminating and nothing is buffered.
//   - the error from waitForDataInput (ErrIOTimeout) when the read deadline
//     passes.
//   - (0, nil) for an empty b, as io.Reader permits.
//
// After a successful read the data updater is woken.
func (c *Connection) Read(b []byte) (int, error) {
	if c == nil {
		return 0, io.EOF
	}
	for {
		if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) {
			return 0, io.EOF
		}
		// An empty buffer can never receive bytes. Waiting for data would block
		// until the connection closes.
		if len(b) == 0 {
			return 0, nil
		}
		nBytes := c.receivingWorker.Read(b)
		if nBytes > 0 {
			c.dataUpdater.WakeUp()
			return nBytes, nil
		}
		if c.State() == StatePeerTerminating {
			return 0, io.EOF
		}
		if err := c.waitForDataInput(); err != nil {
			return 0, err
		}
	}
}
// waitForDataOutput blocks until a writer wake-up arrives on dataOutput, the
// write deadline passes, or a 16-second fallback interval elapses.
//
// It first polls the signal channel up to 16 times, yielding the processor
// between attempts, so no timer is armed when a wake-up arrives almost at
// once. It returns ErrIOTimeout if the write deadline has passed or passes
// while waiting, and nil otherwise. A nil return does not guarantee the
// sending worker will accept data, so callers re-check in a loop.
func (c *Connection) waitForDataOutput() error {
	for i := 0; i < 16; i++ {
		select {
		case <-c.dataOutput:
			return nil
		default:
			runtime.Gosched()
		}
	}
	duration := time.Second * 16
	if !c.wd.IsZero() {
		duration = time.Until(c.wd)
		if duration < 0 {
			return ErrIOTimeout
		}
	}
	// An explicit timer stopped on return replaces time.After. time.After keeps
	// its timer alive until it fires even when dataOutput wins. On Go versions
	// before 1.23 this piles up timers under frequent writes.
	timer := time.NewTimer(duration)
	defer timer.Stop()
	select {
	case <-c.dataOutput:
		return nil
	case <-timer.C:
		// wd is re-read because the deadline may have changed while waiting.
		if !c.wd.IsZero() && c.wd.Before(time.Now()) {
			return ErrIOTimeout
		}
		return nil
	}
}
// Write splits b into chunks of at most mss bytes and hands each one to the
// sending worker. While Push refuses a chunk, Write wakes the data updater and
// waits for a writer wake-up.
//
// It returns io.ErrClosedPipe if the connection is not active, either on entry
// or while waiting. It returns the waitForDataOutput error (ErrIOTimeout) when
// the write deadline passes. In both cases it reports the bytes queued so far.
// A chunk that could not be queued is released.
func (c *Connection) Write(b []byte) (int, error) {
	if c.State() != StateActive {
		return 0, io.ErrClosedPipe
	}
	var written int
	for src := bytes.NewReader(b); src.Len() > 0; {
		chunk := buf.New()
		size, _ := chunk.ReadFrom(io.LimitReader(src, int64(c.mss)))
		if size == 0 {
			chunk.Release()
			break
		}
		for {
			if c.sendingWorker.Push(chunk) {
				break
			}
			if c.State() != StateActive {
				chunk.Release()
				return written, io.ErrClosedPipe
			}
			c.dataUpdater.WakeUp()
			if err := c.waitForDataOutput(); err != nil {
				chunk.Release()
				return written, err
			}
		}
		written += int(size)
	}
	c.dataUpdater.WakeUp()
	return written, nil
}
// updateTask is the periodic tick run by both the data and the ping updater.
// It drives the close/terminate state machine and, unless termination is in
// progress, flushes both workers and sends keep-alive pings. All times are
// milliseconds from Elapsed. Steps, in order:
//   - do nothing once StateTerminated is reached;
//   - close an active connection that has received nothing for 30000 ms;
//   - move ReadyToClose to Terminating once the sending worker is empty;
//   - in Terminating, send a CommandTerminate ping each tick until the state is
//     older than 8000 ms, then move to Terminated; nothing else runs that tick;
//   - move PeerTerminating to Terminating after 4000 ms, and ReadyToClose to
//     Terminating after 15000 ms even if unsent data remains;
//   - flush both workers, ping if 3000 ms have passed since the last ping, and
//     signal dataOutput so a blocked Write re-checks.
func (c *Connection) updateTask() {
	current := c.Elapsed()
	if c.State() == StateTerminated {
		return
	}
	// Idle timeout: nothing received for 30 s.
	if c.State() == StateActive && current-atomic.LoadUint32(&c.lastIncomingTime) >= 30000 {
		_ = c.Close()
	}
	if c.State() == StateReadyToClose && c.sendingWorker.IsEmpty() {
		c.SetState(StateTerminating)
	}
	if c.State() == StateTerminating {
		if current-atomic.LoadUint32(&c.stateBeginTime) > 8000 {
			c.SetState(StateTerminated)
		} else {
			c.Ping(current, CommandTerminate)
		}
		return
	}
	if c.State() == StatePeerTerminating && current-atomic.LoadUint32(&c.stateBeginTime) > 4000 {
		c.SetState(StateTerminating)
	}
	// Give up waiting for unsent data after 15 s in ReadyToClose.
	if c.State() == StateReadyToClose && current-atomic.LoadUint32(&c.stateBeginTime) > 15000 {
		c.SetState(StateTerminating)
	}
	c.receivingWorker.Flush(current)
	c.sendingWorker.Flush(current)
	if current-atomic.LoadUint32(&c.lastPingTime) >= 3000 {
		c.Ping(current, CommandPing)
	}
	// Non-blocking wake-up for a writer waiting in waitForDataOutput.
	select {
	case c.dataOutput <- struct{}{}:
	default:
	}
}
// Close begins an orderly shutdown. It first wakes any blocked Read or Write.
// It then advances the state machine:
//   - Active → ReadyToClose
//   - PeerClosed → Terminating
//   - PeerTerminating → Terminated
//
// Closing a nil connection, or one that is already ReadyToClose, Terminating
// or Terminated, returns ErrClosedConnection.
func (c *Connection) Close() error {
	if c == nil {
		return ErrClosedConnection
	}
	for _, signal := range [...]chan struct{}{c.dataInput, c.dataOutput} {
		select {
		case signal <- struct{}{}:
		default:
		}
	}
	var next State
	switch c.State() {
	case StateReadyToClose, StateTerminating, StateTerminated:
		return ErrClosedConnection
	case StateActive:
		next = StateReadyToClose
	case StatePeerClosed:
		next = StateTerminating
	case StatePeerTerminating:
		next = StateTerminated
	default:
		return nil
	}
	c.SetState(next)
	return nil
}
// LocalAddr returns the local endpoint from the connection metadata, or nil
// for a nil *Connection.
func (c *Connection) LocalAddr() net.Addr {
	var addr net.Addr
	if c != nil {
		addr = c.meta.LocalAddr
	}
	return addr
}
// RemoteAddr returns the peer endpoint from the connection metadata, or nil
// for a nil *Connection.
func (c *Connection) RemoteAddr() net.Addr {
	var addr net.Addr
	if c != nil {
		addr = c.meta.RemoteAddr
	}
	return addr
}
// SetDeadline applies t as both the read and the write deadline. If setting
// the read deadline fails, the write deadline is left untouched. The zero time
// clears both.
func (c *Connection) SetDeadline(t time.Time) error {
	if err := c.SetReadDeadline(t); err != nil {
		return err
	}
	return c.SetWriteDeadline(t)
}
// SetReadDeadline stores the absolute time after which waiting for incoming
// data reports ErrIOTimeout. The zero time disables it. A nil *Connection
// returns ErrClosedConnection.
// NOTE(review): rd is written without synchronization while a concurrent Read
// may be reading it; this is a data race under -race.
func (c *Connection) SetReadDeadline(t time.Time) error {
	if c != nil {
		c.rd = t
		return nil
	}
	return ErrClosedConnection
}
// SetWriteDeadline stores the absolute time after which waiting for send
// capacity reports ErrIOTimeout. The zero time disables it. A nil *Connection
// returns ErrClosedConnection.
// NOTE(review): wd is written without synchronization while a concurrent Write
// may be reading it; this is a data race under -race.
func (c *Connection) SetWriteDeadline(t time.Time) error {
	if c != nil {
		c.wd = t
		return nil
	}
	return ErrClosedConnection
}
// Ping sends a command-only segment carrying cmd and this side's state: the
// first unacknowledged sending number, the next expected receiving number,
// and the current RTO. The segment is flagged with SegmentOptionClose while
// the connection is ReadyToClose. current is recorded as the last ping time,
// and the segment is released after it has been written.
func (c *Connection) Ping(current uint32, cmd Command) {
	seg := NewCmdOnlySegment()
	seg.Conv, seg.Cmd = c.meta.Conversation, cmd
	seg.SendingNext = c.sendingWorker.FirstUnacknowledged()
	seg.ReceivingNext = c.receivingWorker.NextNumber()
	seg.PeerRTO = c.roundTrip.Timeout()
	closingLocally := c.State() == StateReadyToClose
	if closingLocally {
		seg.Option = SegmentOptionClose
	}
	c.output.Write(seg)
	atomic.StoreUint32(&c.lastPingTime, current)
	seg.Release()
}