Files
sing-box-extended/protocol/bond/conn.go

160 lines
3.1 KiB
Go

package bond
import (
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"time"
)
type bondedConn struct {
conns []net.Conn
downloadRatios []uint8
uploadRatios []uint8
readBuffer *bytes.Buffer
}
// NewBondedConn combines conns into a single bonded stream. downloadRatios
// and uploadRatios give, per connection and as percentages, how each received
// and sent frame is split across conns. The slices are stored as given; they
// are neither copied nor validated here.
func NewBondedConn(conns []net.Conn, downloadRatios, uploadRatios []uint8) *bondedConn {
	bc := new(bondedConn)
	bc.conns = conns
	bc.downloadRatios = downloadRatios
	bc.uploadRatios = uploadRatios
	// Frame lengths travel as a uint16, so 64 KiB of capacity holds any
	// single frame without the buffer having to grow.
	bc.readBuffer = bytes.NewBuffer(make([]byte, 0, 65536))
	return bc
}
// Read implements io.Reader for the bonded stream.
//
// The peer sends data in frames. Each frame is a 2-byte big-endian payload
// length on conns[0], followed by the payload split across the connections
// according to downloadRatios (chunk i is read from conns[i], in order). A
// whole frame is buffered in readBuffer first and then handed out over as
// many Read calls as the caller needs.
//
// Zero-length frames are skipped instead of surfacing as a spurious io.EOF
// from the empty buffer. If an error occurs while a frame is being received,
// Read returns 0 and the error (no bytes were copied into b) and discards the
// partial frame. A peer that closes mid-frame yields io.ErrUnexpectedEOF.
//
// NOTE(review): a read deadline that fires mid-frame still leaves the framing
// out of sync with the peer; resuming would require tracking per-frame
// progress across calls.
func (c *bondedConn) Read(b []byte) (n int, err error) {
	for c.readBuffer.Len() == 0 {
		c.readBuffer.Reset()
		var header [2]byte
		if _, err := io.ReadFull(c.conns[0], header[:]); err != nil {
			return 0, err
		}
		size := int(binary.BigEndian.Uint16(header[:]))
		for i, chunkLen := range splitByRatios(size, c.downloadRatios) {
			if chunkLen <= 0 {
				continue
			}
			if _, err := io.CopyN(c.readBuffer, c.conns[i], int64(chunkLen)); err != nil {
				// The bytes copied so far went into readBuffer, not b, so they
				// must not be reported to the caller. Drop the partial frame.
				c.readBuffer.Reset()
				if errors.Is(err, io.EOF) {
					// The header promised more data, so this is a truncated
					// frame rather than a clean end of stream.
					err = io.ErrUnexpectedEOF
				}
				return 0, err
			}
		}
	}
	return c.readBuffer.Read(b)
}
// Write implements io.Writer for the bonded stream.
//
// b is sent as one or more frames of at most 65535 bytes, because each
// frame's length travels in a 2-byte big-endian header written to conns[0].
// Each frame's payload is split according to uploadRatios, and chunk i is
// written to conns[i] in order. An empty b sends nothing: an empty frame
// carries no data. The returned n counts payload bytes only (headers are
// excluded), as io.Writer requires.
func (c *bondedConn) Write(b []byte) (n int, err error) {
	// Largest payload a uint16 length header can describe. Bigger writes are
	// split so the header is never truncated and the peer stays in sync.
	const maxFrameSize = 1<<16 - 1
	for len(b) > 0 {
		frame := b
		if len(frame) > maxFrameSize {
			frame = frame[:maxFrameSize]
		}
		var header [2]byte
		binary.BigEndian.PutUint16(header[:], uint16(len(frame)))
		if _, err := c.conns[0].Write(header[:]); err != nil {
			return n, err
		}
		offset := 0
		for i, chunkLen := range splitByRatios(len(frame), c.uploadRatios) {
			if chunkLen <= 0 {
				continue
			}
			// net.Conn.Write should not return a short count without an
			// error, but loop defensively until the chunk is fully written.
			for end := offset + chunkLen; offset < end; {
				written, err := c.conns[i].Write(frame[offset:end])
				offset += written
				if err != nil {
					return n + offset, err
				}
				if written == 0 {
					return n + offset, io.ErrUnexpectedEOF
				}
			}
		}
		n += len(frame)
		b = b[len(frame):]
	}
	return n, nil
}
// Close closes every underlying connection, continuing past failures. The
// failures are combined with errors.Join; nil is returned when all succeed.
func (c *bondedConn) Close() error {
	var closeErrs []error
	for idx := range c.conns {
		if e := c.conns[idx].Close(); e != nil {
			closeErrs = append(closeErrs, e)
		}
	}
	// errors.Join returns nil for an empty list.
	return errors.Join(closeErrs...)
}
// LocalAddr returns the local address of the primary link, conns[0] (the
// connection that carries the frame headers). It returns nil only when there
// are no underlying connections. A nil net.Addr would make callers that call
// String() or Network() on the result panic.
func (c *bondedConn) LocalAddr() net.Addr {
	if len(c.conns) == 0 {
		return nil
	}
	return c.conns[0].LocalAddr()
}
// RemoteAddr returns the remote address of the primary link, conns[0] (the
// connection that carries the frame headers). It returns nil only when there
// are no underlying connections. A nil net.Addr would make callers that call
// String() or Network() on the result panic.
func (c *bondedConn) RemoteAddr() net.Addr {
	if len(c.conns) == 0 {
		return nil
	}
	return c.conns[0].RemoteAddr()
}
// SetDeadline applies t as the read and write deadline of every underlying
// connection. Every connection is attempted; failures are combined with
// errors.Join, and nil is returned when all calls succeed.
func (c *bondedConn) SetDeadline(t time.Time) error {
	var failures []error
	for idx := 0; idx < len(c.conns); idx++ {
		if e := c.conns[idx].SetDeadline(t); e != nil {
			failures = append(failures, e)
		}
	}
	return errors.Join(failures...)
}
// SetReadDeadline applies t as the read deadline of every underlying
// connection. Every connection is attempted; failures are combined with
// errors.Join, and nil is returned when all calls succeed.
func (c *bondedConn) SetReadDeadline(t time.Time) error {
	var failures []error
	for _, link := range c.conns {
		e := link.SetReadDeadline(t)
		if e == nil {
			continue
		}
		failures = append(failures, e)
	}
	return errors.Join(failures...)
}
// SetWriteDeadline applies t as the write deadline of every underlying
// connection. Every connection is attempted; failures are combined with
// errors.Join, and nil is returned when all calls succeed.
func (c *bondedConn) SetWriteDeadline(t time.Time) error {
	var failures []error
	for idx := range c.conns {
		if e := c.conns[idx].SetWriteDeadline(t); e != nil {
			failures = append(failures, e)
		}
	}
	return errors.Join(failures...)
}
// splitByRatios divides the non-negative number into len(ratios) parts.
// Part i (for every part but the last) is ratios[i] percent of number,
// rounded down. The last part receives whatever remains, so the parts always
// sum to number.
//
// If the leading ratios add up to more than 100, each part is clamped to what
// is still unassigned, so no part is ever negative. Negative parts would make
// callers slice their payload out of range.
//
// ratios must be non-empty; an empty slice panics with an index out of range.
func splitByRatios(number int, ratios []uint8) []int {
	result := make([]int, len(ratios))
	remaining := number
	last := len(ratios) - 1
	for i := 0; i < last; i++ {
		part := number * int(ratios[i]) / 100
		if part > remaining {
			part = remaining
		}
		result[i] = part
		remaining -= part
	}
	result[last] = remaining
	return result
}