Files
sing-box-extended/transport/trojan/service.go

188 lines
4.6 KiB
Go

package trojan
import (
"context"
"encoding/binary"
"net"
"sync"
"github.com/sagernet/sing/common/auth"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/rw"
)
// Handler is the callback set a Service dispatches authenticated requests to.
// It combines the TCP and the UDP connection handler interfaces, so a single
// value can receive both stream and packet connections.
type Handler interface {
N.TCPConnectionHandlerEx
N.UDPConnectionHandlerEx
}
// Service is the server side of the Trojan transport: it authenticates
// inbound connections against a user table and forwards them to a Handler.
// K is the caller-chosen user identifier type.
type Service[K comparable] struct {
// users maps a user identifier to its 56-byte key.
users map[K][56]byte
// keys is the reverse index of users, used to look up the user for the
// key received on a new connection.
keys map[[56]byte]K
// handler receives authenticated TCP, UDP and mux connections.
handler Handler
// fallbackHandler, when non-nil, receives connections that did not present
// a known key.
fallbackHandler N.TCPConnectionHandlerEx
logger logger.ContextLogger
// conns tracks the live connections of each user so that UpdateUsers can
// close the connections of users that were removed.
conns map[K][]net.Conn
// mu guards users, keys and conns.
mu sync.RWMutex
}
// NewService returns a Service with an empty user table.
//
// handler receives authenticated connections; fallbackHandler receives the
// ones that fail authentication and may be nil to disable fallback. Users are
// installed afterwards with UpdateUsers.
func NewService[K comparable](handler Handler, fallbackHandler N.TCPConnectionHandlerEx, logger logger.ContextLogger) *Service[K] {
	service := new(Service[K])

	// Collaborators supplied by the caller.
	service.handler = handler
	service.fallbackHandler = fallbackHandler
	service.logger = logger

	// All maps start empty but non-nil so they can be written to immediately.
	service.users = map[K][56]byte{}
	service.keys = map[[56]byte]K{}
	service.conns = map[K][]net.Conn{}

	return service
}
// ErrUserExists is returned by UpdateUsers when the same user appears more
// than once in the list; it is also the base of the error returned when two
// users share the same password.
var ErrUserExists = E.New("user already exists")
// UpdateUsers replaces the whole user table with the given users and their
// passwords, where passwordList[i] is the password of userList[i].
//
// The new table is built and validated first; the live table is only swapped
// (under the write lock) once validation succeeded, so a failed call leaves
// the service unchanged. Connections that belong to users missing from the new
// table are closed after the lock has been released.
//
// It returns an error when passwordList has fewer entries than userList,
// ErrUserExists when a user is listed twice, and an error extending
// ErrUserExists when two users resolve to the same key.
func (s *Service[K]) UpdateUsers(userList []K, passwordList []string) error {
	// Guard the passwordList[i] access below: a short list would otherwise
	// panic with an index-out-of-range error.
	if len(passwordList) < len(userList) {
		return E.New("password list shorter than user list")
	}

	users := make(map[K][56]byte, len(userList))
	keys := make(map[[56]byte]K, len(userList))
	for i, user := range userList {
		if _, loaded := users[user]; loaded {
			return ErrUserExists
		}
		key := Key(passwordList[i])
		if oldUser, loaded := keys[key]; loaded {
			return E.Extend(ErrUserExists, "password used by ", oldUser)
		}
		users[user] = key
		keys[key] = user
	}

	s.mu.Lock()
	s.users = users
	s.keys = keys
	// Collect the connections of removed users while holding the lock, but
	// close them only afterwards so Close never runs under s.mu.
	var closedConns []net.Conn
	for user, conns := range s.conns {
		if _, exists := users[user]; !exists {
			closedConns = append(closedConns, conns...)
			delete(s.conns, user)
		}
	}
	s.mu.Unlock()

	for _, conn := range closedConns {
		// Best effort: the connection is being dropped, its Close error is
		// of no use to the caller.
		conn.Close()
	}
	return nil
}
// NewConnection authenticates an inbound connection and dispatches it.
//
// The key is obtained with a single Read of KeyLength bytes. A short read or
// an unknown key is not treated as a Trojan request: the bytes read so far are
// handed to fallback together with the connection.
//
// For a known key the user is attached to ctx, the connection is recorded in
// s.conns (so UpdateUsers can close it if the user is removed) and onClose is
// wrapped so that the record is dropped when the handler reports the
// connection closed. The rest of the request header is then parsed and the
// connection is passed to the TCP, UDP or mux handler according to the
// command.
//
// An error is returned when reading fails, when the header is malformed, or
// when the mux handler returns one. On header errors the connection record is
// removed again; onClose itself is not invoked by this method on those paths.
func (s *Service[K]) NewConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, onClose N.CloseHandlerFunc) error {
	var key [KeyLength]byte
	n, err := conn.Read(key[:])
	if err != nil {
		return err
	} else if n != KeyLength {
		return s.fallback(ctx, conn, source, key[:n], E.New("bad request size"), onClose)
	}

	// Look the key up and register the connection in one critical section.
	// Releasing the lock in between would let a concurrent UpdateUsers remove
	// the user after the lookup, leaving a connection registered for a user
	// that no longer exists and that nobody would close.
	s.mu.Lock()
	user, loaded := s.keys[key]
	if loaded {
		s.conns[user] = append(s.conns[user], conn)
	}
	s.mu.Unlock()
	if !loaded {
		return s.fallback(ctx, conn, source, key[:], E.New("bad request"), onClose)
	}
	ctx = auth.ContextWithUser(ctx, user)

	originalOnClose := onClose
	onClose = N.OnceClose(func(it error) {
		s.untrackConn(user, conn)
		if originalOnClose != nil {
			originalOnClose(it)
		}
	})

	command, destination, err := s.readRequest(conn)
	if err != nil {
		// The wrapped onClose was never handed to a handler, so nothing else
		// would remove the record; drop it here to avoid leaking it.
		s.untrackConn(user, conn)
		return err
	}

	switch command {
	case CommandTCP:
		s.handler.NewConnectionEx(ctx, conn, source, destination, onClose)
	case CommandUDP:
		s.handler.NewPacketConnectionEx(ctx, &PacketConn{Conn: conn}, source, destination, onClose)
	default:
		// readRequest only lets CommandTCP, CommandUDP and CommandMux
		// through, so this branch is CommandMux.
		return HandleMuxConnection(ctx, conn, source, s.handler, s.logger, onClose)
	}
	return nil
}

// readRequest parses the request header that follows the key: two bytes that
// are skipped, the command byte, the destination address, and two more skipped
// bytes. It rejects commands other than CommandTCP, CommandUDP and CommandMux
// before attempting to read the destination.
func (s *Service[K]) readRequest(conn net.Conn) (byte, M.Socksaddr, error) {
	if err := rw.SkipN(conn, 2); err != nil {
		return 0, M.Socksaddr{}, E.Cause(err, "skip crlf")
	}
	var command byte
	if err := binary.Read(conn, binary.BigEndian, &command); err != nil {
		return 0, M.Socksaddr{}, E.Cause(err, "read command")
	}
	switch command {
	case CommandTCP, CommandUDP, CommandMux:
	default:
		return 0, M.Socksaddr{}, E.New("unknown command ", command)
	}
	destination, err := M.SocksaddrSerializer.ReadAddrPort(conn)
	if err != nil {
		return 0, M.Socksaddr{}, E.Cause(err, "read destination")
	}
	if err = rw.SkipN(conn, 2); err != nil {
		return 0, M.Socksaddr{}, E.Cause(err, "skip crlf")
	}
	return command, destination, nil
}

// untrackConn removes conn from the connections recorded for user. It is a
// no-op when conn is not recorded (for example because UpdateUsers already
// dropped the user), so it is safe to call more than once.
func (s *Service[K]) untrackConn(user K, conn net.Conn) {
	s.mu.Lock()
	defer s.mu.Unlock()
	conns := s.conns[user]
	for i, c := range conns {
		if c != conn {
			continue
		}
		last := len(conns) - 1
		copy(conns[i:], conns[i+1:])
		// Clear the vacated slot so the backing array does not keep a
		// closed connection reachable.
		conns[last] = nil
		if last == 0 {
			delete(s.conns, user)
		} else {
			s.conns[user] = conns[:last]
		}
		return
	}
}
// fallback hands a connection that failed authentication to the fallback
// handler, if one is configured.
//
// header holds the bytes already read from conn; a copy of them is attached to
// the connection via bufio.NewCachedConn before it is passed on (presumably so
// the fallback handler sees the stream from its first byte — confirm against
// the bufio package). When no fallback handler is set, err is returned,
// extended with "fallback disabled". The fallback handler is called with an
// empty destination.
func (s *Service[K]) fallback(ctx context.Context, conn net.Conn, source M.Socksaddr, header []byte, err error, onClose N.CloseHandlerFunc) error {
	target := s.fallbackHandler
	if target == nil {
		return E.Extend(err, "fallback disabled")
	}

	// Copy the header: the caller's buffer is not owned by the new connection.
	consumed := buf.As(header).ToOwned()
	replayConn := bufio.NewCachedConn(conn, consumed)

	var noDestination M.Socksaddr
	target.NewConnectionEx(ctx, replayConn, source, noDestination, onClose)
	return nil
}
// PacketConn wraps a stream connection and exposes packet-oriented reads and
// writes on top of it through the package-level ReadPacket and WritePacket.
type PacketConn struct {
net.Conn
// readWaitOptions is not referenced in this part of the file.
// NOTE(review): presumably used by read-waiter code elsewhere in the
// package — confirm before removing.
readWaitOptions N.ReadWaitOptions
}
// ReadPacket reads one packet from the wrapped connection into buffer by
// delegating to the package-level ReadPacket, and returns the address and
// error that function reports.
func (c *PacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
return ReadPacket(c.Conn, buffer)
}
// WritePacket writes buffer as one packet addressed to destination on the
// wrapped connection by delegating to the package-level WritePacket.
func (c *PacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
return WritePacket(c.Conn, buffer, destination)
}
// FrontHeadroom returns the maximum serialized address length plus 4.
// NOTE(review): presumably the number of bytes callers should reserve in
// front of the payload for the packet header, with the extra 4 covering the
// length and CRLF fields — confirm against WritePacket.
func (c *PacketConn) FrontHeadroom() int {
return M.MaxSocksaddrLength + 4
}
// NeedAdditionalReadDeadline always reports true.
// NOTE(review): presumably signals to callers that read deadlines need extra
// handling for this connection type — confirm against the consumer of this
// method.
func (c *PacketConn) NeedAdditionalReadDeadline() bool {
return true
}
// Upstream returns the wrapped net.Conn.
func (c *PacketConn) Upstream() any {
return c.Conn
}