Files
sing-box-extended/protocol/bond/outbound.go

153 lines
4.2 KiB
Go

package bond
import (
"context"
"errors"
"net"
"sync"
"github.com/gofrs/uuid/v5"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/outbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/service"
)
// RegisterOutbound adds the bond outbound to registry under the C.TypeBond
// type, built by NewOutbound from option.BondOutboundOptions.
func RegisterOutbound(registry *outbound.Registry) {
	constructor := NewOutbound
	outbound.Register[option.BondOutboundOptions](registry, C.TypeBond, constructor)
}
// Outbound is the bond outbound. For every TCP connection it opens one link
// through each sub-outbound and combines the links with NewBondedConn; UDP is
// carried over TCP through uotClient.
type Outbound struct {
	outbound.Adapter

	// Sub-outbounds of the bond; downloadRatios[i] and uploadRatios[i] are the
	// ratios configured for outbounds[i].
	outbounds      []adapter.Outbound
	downloadRatios []uint8
	uploadRatios   []uint8

	// UDP-over-TCP client whose Dialer is this Outbound.
	uotClient *uot.Client

	ctx    context.Context
	logger logger.ContextLogger
}
// NewOutbound builds a bond outbound from options.
//
// Each entry of options.Outbounds is instantiated Count times through the
// outbound registry found in ctx; a zero Count is treated as 1. Every instance
// becomes one link of the bond and takes that entry's DownloadRatio and
// UploadRatio. Across all links, the download ratios and the upload ratios must
// each sum to exactly 100. At most 255 links are allowed, because DialContext
// sends the link index and link count as single bytes.
//
// If construction fails, every sub-outbound created so far is closed again
// (best effort) before the error is returned.
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.BondOutboundOptions) (adapter.Outbound, error) {
	// Request.Index and Request.Count are single bytes (see DialContext).
	const maxLinks = 255

	outboundRegistry := service.FromContext[adapter.OutboundRegistry](ctx)
	outbounds := make([]adapter.Outbound, 0, len(options.Outbounds))
	downloadRatios := make([]uint8, 0, len(options.Outbounds))
	uploadRatios := make([]uint8, 0, len(options.Outbounds))
	// Sum in int rather than uint8. uint8 addition wraps modulo 256, so ratios
	// such as 200 and 156 would wrap to 100 and pass validation.
	var totalDownloadRatio, totalUploadRatio int

	// fail closes the sub-outbounds created so far and returns err, so that an
	// aborted construction does not leak them.
	fail := func(err error) (adapter.Outbound, error) {
		for _, created := range outbounds {
			// Best effort: the construction error is what the caller needs.
			_ = common.Close(created)
		}
		return nil, err
	}

	for _, outboundOptions := range options.Outbounds {
		count := outboundOptions.Count
		if count == 0 {
			count = 1
		}
		for range count {
			if len(outbounds) >= maxLinks {
				return fail(E.New("too many bond outbounds: at most ", maxLinks, " are supported"))
			}
			link, err := outboundRegistry.UnsafeCreateOutbound(ctx, router, logger, outboundOptions.Outbound.Tag, outboundOptions.Outbound.Type, outboundOptions.Outbound.Options)
			if err != nil {
				return fail(E.Cause(err, "create bond link ", len(outbounds)))
			}
			outbounds = append(outbounds, link)
			downloadRatios = append(downloadRatios, outboundOptions.DownloadRatio)
			uploadRatios = append(uploadRatios, outboundOptions.UploadRatio)
			totalDownloadRatio += int(outboundOptions.DownloadRatio)
			totalUploadRatio += int(outboundOptions.UploadRatio)
		}
	}
	if totalDownloadRatio != 100 || totalUploadRatio != 100 {
		return fail(E.New("invalid ratios: download ratios sum to ", totalDownloadRatio, ", upload ratios sum to ", totalUploadRatio, ", both must be 100"))
	}

	h := &Outbound{
		// Report the same type this outbound is registered under in
		// RegisterOutbound.
		Adapter:        outbound.NewAdapter(C.TypeBond, tag, []string{N.NetworkTCP, N.NetworkUDP}, []string{}),
		ctx:            ctx,
		outbounds:      outbounds,
		downloadRatios: downloadRatios,
		uploadRatios:   uploadRatios,
		logger:         logger,
	}
	// UDP is tunnelled over TCP: the UoT client dials through h itself.
	h.uotClient = &uot.Client{
		Dialer:  h,
		Version: uot.Version,
	}
	return h, nil
}
// DialContext opens a connection to destination through the bond.
//
// UDP is delegated to the UoT client. For any other network, every
// sub-outbound is dialed concurrently to Destination. Each resulting link is
// then sent a Request carrying:
//   - a UUID shared by all links,
//   - the link's index and the total link count,
//   - the link's ratios,
//   - the real destination.
//
// The links are combined with NewBondedConn. If any link fails to dial or to
// send its request, every link that did succeed is closed. All failures are
// then returned joined, each wrapped with the failing link's index.
func (h *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
	if N.NetworkName(network) == N.NetworkUDP {
		return h.uotClient.DialContext(ctx, network, destination)
	}
	// Request.Index and Request.Count are single bytes; more links would wrap
	// silently and corrupt the handshake.
	if len(h.outbounds) > 255 {
		return nil, E.New("too many bond outbounds: ", len(h.outbounds), ", at most 255 are supported")
	}
	connUUID, err := uuid.NewV4()
	if err != nil {
		return nil, err
	}

	conns := make([]net.Conn, len(h.outbounds))
	errs := make([]error, 0, len(h.outbounds))
	var (
		mtx sync.Mutex
		wg  sync.WaitGroup
	)
	// recordErr is called from the dialing goroutines, so errs is guarded by mtx.
	recordErr := func(err error) {
		mtx.Lock()
		errs = append(errs, err)
		mtx.Unlock()
	}

	for i, link := range h.outbounds {
		wg.Go(func() {
			// Destination (capitalized) is declared elsewhere in this package.
			// NOTE(review): presumably it is the bond server endpoint; the real
			// target travels inside the Request below.
			conn, err := link.DialContext(ctx, network, Destination)
			if err != nil {
				recordErr(E.Cause(err, "dial bond link ", i))
				return
			}
			err = WriteRequest(conn, &Request{
				UUID:  connUUID,
				Index: byte(i),
				Count: byte(len(h.outbounds)),
				// NOTE(review): download and upload ratios are crossed here,
				// presumably because the Request is interpreted from the peer's
				// side (our upload is its download). Confirm against the inbound.
				DownloadRatio: h.uploadRatios[i],
				UploadRatio:   h.downloadRatios[i],
				Destination:   destination,
			})
			if err != nil {
				conn.Close()
				recordErr(E.Cause(err, "write bond request on link ", i))
				return
			}
			// Each goroutine writes only its own slot; wg.Wait below orders
			// these writes before the reads.
			conns[i] = conn
		})
	}
	wg.Wait()

	if len(errs) != 0 {
		for _, conn := range conns {
			if conn != nil {
				conn.Close()
			}
		}
		return nil, errors.Join(errs...)
	}
	return NewBondedConn(conns, h.downloadRatios, h.uploadRatios), nil
}
// ListenPacket opens a UDP packet connection to destination. The bond has no
// native UDP path, so this goes through the UDP-over-TCP client.
func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
	client := h.uotClient
	return client.ListenPacket(ctx, destination)
}
// Close closes every sub-outbound of the bond. It attempts all of them even if
// some fail, and returns the failures joined, or nil when every close succeeded.
func (h *Outbound) Close() error {
	var failures []error
	for i := range h.outbounds {
		if closeErr := common.Close(h.outbounds[i]); closeErr != nil {
			failures = append(failures, closeErr)
		}
	}
	// errors.Join returns nil when failures is empty.
	return errors.Join(failures...)
}