Files

346 lines
10 KiB
Go

package client
import (
"context"
"net"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
boxService "github.com/sagernet/sing-box/adapter/service"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
CM "github.com/sagernet/sing-box/service/manager/constant"
pb "github.com/sagernet/sing-box/service/node_manager_api/manager"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
)
// APIClient is a gRPC client for the node manager API. All RPC methods share
// a single cached *grpc.ClientConn obtained through getConn.
type APIClient struct {
boxService.Adapter
// ctx is the context used for every outgoing RPC; NewAPIClient attaches the
// API key to it as "authorization" metadata.
ctx context.Context
logger log.ContextLogger
// dialer opens the TCP connections that gRPC runs over (see createConn).
dialer N.Dialer
// creds is plaintext (insecure) by default, or TLS when options.TLS is set.
creds credentials.TransportCredentials
options option.NodeManagerAPIClientOptions
// conn is the cached client connection; it is read and replaced under mtx
// in getConn.
conn *grpc.ClientConn
mtx sync.Mutex
}
// NewAPIClient creates a node-manager API client registered under tag.
//
// options.APIKey is required and is attached to every outgoing RPC as the
// "authorization" metadata entry. The transport is plaintext unless
// options.TLS yields a TLS configuration. No connection is opened here; the
// gRPC connection is created lazily on first use.
func NewAPIClient(ctx context.Context, logger log.ContextLogger, tag string, options option.NodeManagerAPIClientOptions) (*APIClient, error) {
	if options.APIKey == "" {
		return nil, E.New("missing api key")
	}
	outboundDialer, err := dialer.New(ctx, options.DialerOptions, options.ServerIsDomain())
	if err != nil {
		return nil, E.Cause(err, "create dialer")
	}
	creds := insecure.NewCredentials()
	if options.TLS != nil {
		tlsConfig, err := tls.NewClient(ctx, logger, options.Server, common.PtrValueOrDefault(options.TLS))
		if err != nil {
			return nil, E.Cause(err, "create tls config")
		}
		// tls.NewClient may hand back a nil config (e.g. when a TLS block is
		// present but not enabled). Wrapping nil would only fail later, during
		// the first handshake, so fall back to plaintext instead.
		if tlsConfig != nil {
			creds = &tlsCreds{tlsConfig}
		}
	}
	return &APIClient{
		Adapter: boxService.NewAdapter(C.TypeManager, tag),
		ctx:     metadata.AppendToOutgoingContext(ctx, "authorization", options.APIKey),
		logger:  logger,
		dialer:  outboundDialer,
		creds:   creds,
		options: options,
	}, nil
}
// AddNode registers node with the manager under uuid and keeps it updated in
// the background. It always returns nil immediately. A goroutine opens the
// streaming AddNode RPC and applies every received update to node. On any
// failure it logs the error and retries after a 5 second delay, until the
// client's context is cancelled.
func (s *APIClient) AddNode(uuid string, node CM.ConnectedNode) error {
	go s.runNodeStream(uuid, node)
	return nil
}

// runNodeStream is the retry loop behind AddNode; it returns once s.ctx is done.
func (s *APIClient) runNodeStream(uuid string, node CM.ConnectedNode) {
	const retryDelay = 5 * time.Second
	for s.ctx.Err() == nil {
		err := s.openNodeStream(uuid, node)
		if s.ctx.Err() != nil {
			// Shutting down: the failure is a consequence of cancellation,
			// so it is not reported as an error.
			return
		}
		if err != nil {
			s.logger.Error(err)
		}
		select {
		case <-time.After(retryDelay):
		case <-s.ctx.Done():
			return
		}
	}
}

// openNodeStream performs one connect/subscribe/consume cycle and returns the
// error that ended it.
func (s *APIClient) openNodeStream(uuid string, node CM.ConnectedNode) error {
	conn, err := s.getConn()
	if err != nil {
		return err
	}
	stream, err := pb.NewManagerClient(conn).AddNode(s.ctx, &pb.Node{Uuid: uuid})
	if err != nil {
		return err
	}
	return s.handler(node, stream)
}
// AcquireLock sends an AcquireLock RPC for (limiterId, id) and returns the
// HandleId from the reply. The handle is what RefreshLock and ReleaseLock take.
func (s *APIClient) AcquireLock(limiterId int, id string) (string, error) {
	conn, err := s.getConn()
	if err != nil {
		return "", err
	}
	request := &pb.AcquireLockRequest{
		LimiterId: int32(limiterId),
		Id:        id,
	}
	reply, err := pb.NewManagerClient(conn).AcquireLock(s.ctx, request)
	if err != nil {
		return "", err
	}
	return reply.HandleId, nil
}
// RefreshLock sends a RefreshLock RPC for the lock handle previously returned
// by AcquireLock for (limiterId, id), and reports any connection or RPC error.
func (s *APIClient) RefreshLock(limiterId int, id string, handleId string) error {
	conn, err := s.getConn()
	if err != nil {
		return err
	}
	lock := &pb.LockData{
		LimiterId: int32(limiterId),
		Id:        id,
		HandleId:  handleId,
	}
	if _, err = pb.NewManagerClient(conn).RefreshLock(s.ctx, lock); err != nil {
		return err
	}
	return nil
}
// ReleaseLock sends a ReleaseLock RPC for the lock handle previously returned
// by AcquireLock for (limiterId, id), and reports any connection or RPC error.
func (s *APIClient) ReleaseLock(limiterId int, id string, handleId string) error {
	conn, err := s.getConn()
	if err != nil {
		return err
	}
	lock := &pb.LockData{
		LimiterId: int32(limiterId),
		Id:        id,
		HandleId:  handleId,
	}
	if _, err = pb.NewManagerClient(conn).ReleaseLock(s.ctx, lock); err != nil {
		return err
	}
	return nil
}
// AddTrafficUsage reports n units of traffic against limiterId. It returns the
// Remaining value from the manager's reply, or 0 together with the error.
func (s *APIClient) AddTrafficUsage(limiterId int, n uint64) (uint64, error) {
	conn, err := s.getConn()
	if err != nil {
		return 0, err
	}
	usage := &pb.TrafficUsageRequest{
		LimiterId: int32(limiterId),
		N:         n,
	}
	reply, err := pb.NewManagerClient(conn).AddTrafficUsage(s.ctx, usage)
	if err != nil {
		return 0, err
	}
	return reply.Remaining, nil
}
// Start does nothing at any stage and always returns nil. The gRPC connection
// is created on demand by getConn rather than at start-up.
func (s *APIClient) Start(adapter.StartStage) error {
	return nil
}
// Close releases the cached gRPC connection, if one was created, and returns
// the error from closing it. Background streams started by AddNode are bound
// to the client's context, not to Close. If that context is still live, a
// running stream will reconnect through getConn.
func (s *APIClient) Close() error {
	s.mtx.Lock()
	defer s.mtx.Unlock()
	if s.conn == nil {
		return nil
	}
	err := s.conn.Close()
	s.conn = nil
	return err
}
// getConn returns the cached client connection, creating a new one when none
// exists yet or when the cached one is in Shutdown or TransientFailure. The
// cache is guarded by s.mtx, so concurrent callers share one connection.
func (s *APIClient) getConn() (*grpc.ClientConn, error) {
	s.mtx.Lock()
	defer s.mtx.Unlock()
	if s.conn != nil {
		state := s.conn.GetState()
		if state != connectivity.Shutdown && state != connectivity.TransientFailure {
			return s.conn, nil
		}
	}
	conn, err := s.createConn()
	if err != nil {
		// Keep the old connection (if any) so a later call can retry.
		return nil, err
	}
	if s.conn != nil {
		// Release the stale connection; otherwise its resources and
		// background reconnect goroutines leak. The error is ignored because
		// an already shut-down connection reports one here and nothing
		// useful can be done with it.
		_ = s.conn.Close()
	}
	s.conn = conn
	return conn, nil
}
// createConn builds a new gRPC client connection to the configured server.
// Transport connections are opened through s.dialer over TCP, and s.creds
// selects plaintext or TLS. grpc.NewClient performs no I/O itself; the
// connection is established when it is first used.
func (s *APIClient) createConn() (*grpc.ClientConn, error) {
	// grpc.NewClient defaults to the "dns" resolver, which would resolve the
	// server name with the system resolver and hand only IP addresses to the
	// dialer. The "passthrough" scheme forwards host:port untouched, so
	// s.dialer (created with ServerIsDomain in NewAPIClient) resolves the
	// name itself.
	target := "passthrough:///" + s.options.ServerOptions.Build().String()
	conn, err := grpc.NewClient(
		target,
		grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
			return s.dialer.DialContext(ctx, N.NetworkTCP, M.ParseSocksaddr(addr))
		}),
		grpc.WithTransportCredentials(s.creds),
	)
	if err != nil {
		return nil, err
	}
	return conn, nil
}
// handler reads NodeData messages from stream and applies each one to node.
// It returns only when stream.Recv fails, passing that error through.
// A message whose payload is missing or does not match its Op is logged and
// skipped. The payload arrives from the network, so it must not be allowed to
// panic the background goroutine and crash the process.
func (s *APIClient) handler(node CM.ConnectedNode, stream grpc.ServerStreamingClient[pb.NodeData]) error {
	for {
		data, err := stream.Recv()
		if err != nil {
			return err
		}
		if err := s.applyNodeData(node, data); err != nil {
			s.logger.ErrorContext(s.ctx, err)
		}
	}
}

// applyNodeData dispatches a single NodeData message to the matching node
// method. It returns an error, without touching node, when the oneof payload
// required by data.Op is absent. Unknown operations are ignored.
func (s *APIClient) applyNodeData(node CM.ConnectedNode, data *pb.NodeData) error {
	switch data.Op {
	case pb.OpType_updateUser:
		s.logger.DebugContext(s.ctx, "update user")
		user := data.GetUser()
		if user == nil {
			return nodeDataPayloadError(data.Op)
		}
		node.UpdateUser(s.convertUser(user))
	case pb.OpType_updateUsers:
		s.logger.DebugContext(s.ctx, "update users")
		users := data.GetUsers()
		if users == nil {
			return nodeDataPayloadError(data.Op)
		}
		node.UpdateUsers(convertNodeDataValues(users.GetValues(), s.convertUser))
	case pb.OpType_deleteUser:
		s.logger.DebugContext(s.ctx, "delete user")
		user := data.GetUser()
		if user == nil {
			return nodeDataPayloadError(data.Op)
		}
		node.DeleteUser(s.convertUser(user))
	case pb.OpType_updateConnectionLimiter:
		s.logger.DebugContext(s.ctx, "update connection limiter")
		limiter := data.GetConnectionLimiter()
		if limiter == nil {
			return nodeDataPayloadError(data.Op)
		}
		node.UpdateConnectionLimiter(s.convertConnectionLimiter(limiter))
	case pb.OpType_updateConnectionLimiters:
		s.logger.DebugContext(s.ctx, "update connection limiters")
		limiters := data.GetConnectionLimiters()
		if limiters == nil {
			return nodeDataPayloadError(data.Op)
		}
		node.UpdateConnectionLimiters(convertNodeDataValues(limiters.GetValues(), s.convertConnectionLimiter))
	case pb.OpType_deleteConnectionLimiter:
		s.logger.DebugContext(s.ctx, "delete connection limiter")
		limiter := data.GetConnectionLimiter()
		if limiter == nil {
			return nodeDataPayloadError(data.Op)
		}
		node.DeleteConnectionLimiter(s.convertConnectionLimiter(limiter))
	case pb.OpType_updateBandwidthLimiter:
		s.logger.DebugContext(s.ctx, "update bandwidth limiter")
		limiter := data.GetBandwidthLimiter()
		if limiter == nil {
			return nodeDataPayloadError(data.Op)
		}
		node.UpdateBandwidthLimiter(s.convertBandwidthLimiter(limiter))
	case pb.OpType_updateBandwidthLimiters:
		s.logger.DebugContext(s.ctx, "update bandwidth limiters")
		limiters := data.GetBandwidthLimiters()
		if limiters == nil {
			return nodeDataPayloadError(data.Op)
		}
		node.UpdateBandwidthLimiters(convertNodeDataValues(limiters.GetValues(), s.convertBandwidthLimiter))
	case pb.OpType_deleteBandwidthLimiter:
		s.logger.DebugContext(s.ctx, "delete bandwidth limiter")
		limiter := data.GetBandwidthLimiter()
		if limiter == nil {
			return nodeDataPayloadError(data.Op)
		}
		node.DeleteBandwidthLimiter(s.convertBandwidthLimiter(limiter))
	case pb.OpType_updateTrafficLimiter:
		s.logger.DebugContext(s.ctx, "update traffic limiter")
		limiter := data.GetTrafficLimiter()
		if limiter == nil {
			return nodeDataPayloadError(data.Op)
		}
		node.UpdateTrafficLimiter(s.convertTrafficLimiter(limiter))
	case pb.OpType_updateTrafficLimiters:
		s.logger.DebugContext(s.ctx, "update traffic limiters")
		limiters := data.GetTrafficLimiters()
		if limiters == nil {
			return nodeDataPayloadError(data.Op)
		}
		node.UpdateTrafficLimiters(convertNodeDataValues(limiters.GetValues(), s.convertTrafficLimiter))
	case pb.OpType_deleteTrafficLimiter:
		s.logger.DebugContext(s.ctx, "delete traffic limiter")
		limiter := data.GetTrafficLimiter()
		if limiter == nil {
			return nodeDataPayloadError(data.Op)
		}
		node.DeleteTrafficLimiter(s.convertTrafficLimiter(limiter))
	case pb.OpType_updateRateLimiter:
		s.logger.DebugContext(s.ctx, "update rate limiter")
		limiter := data.GetRateLimiter()
		if limiter == nil {
			return nodeDataPayloadError(data.Op)
		}
		node.UpdateRateLimiter(s.convertRateLimiter(limiter))
	case pb.OpType_updateRateLimiters:
		s.logger.DebugContext(s.ctx, "update rate limiters")
		limiters := data.GetRateLimiters()
		if limiters == nil {
			return nodeDataPayloadError(data.Op)
		}
		node.UpdateRateLimiters(convertNodeDataValues(limiters.GetValues(), s.convertRateLimiter))
	case pb.OpType_deleteRateLimiter:
		s.logger.DebugContext(s.ctx, "delete rate limiter")
		limiter := data.GetRateLimiter()
		if limiter == nil {
			return nodeDataPayloadError(data.Op)
		}
		node.DeleteRateLimiter(s.convertRateLimiter(limiter))
	default:
		s.logger.DebugContext(s.ctx, "ignore unknown operation ", data.Op)
	}
	return nil
}

// nodeDataPayloadError reports a NodeData message whose payload is missing or
// of the wrong kind for op.
func nodeDataPayloadError(op pb.OpType) error {
	return E.New("missing or mismatched payload for operation ", op)
}

// convertNodeDataValues applies convert to every element of values. The
// result has the same length and is non-nil even when values is empty.
func convertNodeDataValues[From, To any](values []From, convert func(From) To) []To {
	converted := make([]To, len(values))
	for i, value := range values {
		converted[i] = convert(value)
	}
	return converted
}
// convertUser copies a protobuf User into a CM.User, converting the ID and
// AlterID fields to int. user must be non-nil.
func (s *APIClient) convertUser(user *pb.User) CM.User {
	var converted CM.User
	converted.ID = int(user.Id)
	converted.Username = user.Username
	converted.Inbound = user.Inbound
	converted.Type = user.Type
	converted.UUID = user.Uuid
	converted.Password = user.Password
	converted.Secret = user.Secret
	converted.Flow = user.Flow
	converted.AlterID = int(user.AlterId)
	return converted
}
// convertBandwidthLimiter copies a protobuf BandwidthLimiter into a
// CM.BandwidthLimiter, converting ID to int. limiter must be non-nil.
func (s *APIClient) convertBandwidthLimiter(limiter *pb.BandwidthLimiter) CM.BandwidthLimiter {
	var converted CM.BandwidthLimiter
	converted.ID = int(limiter.Id)
	converted.Username = limiter.Username
	converted.Outbound = limiter.Outbound
	converted.Strategy = limiter.Strategy
	converted.ConnectionType = limiter.ConnectionType
	converted.Mode = limiter.Mode
	converted.FlowKeys = limiter.FlowKeys
	converted.Speed = limiter.Speed
	converted.RawSpeed = limiter.RawSpeed
	return converted
}
// convertConnectionLimiter copies a protobuf ConnectionLimiter into a
// CM.ConnectionLimiter, converting ID to int. limiter must be non-nil.
func (s *APIClient) convertConnectionLimiter(limiter *pb.ConnectionLimiter) CM.ConnectionLimiter {
	var converted CM.ConnectionLimiter
	converted.ID = int(limiter.Id)
	converted.Username = limiter.Username
	converted.Outbound = limiter.Outbound
	converted.Strategy = limiter.Strategy
	converted.ConnectionType = limiter.ConnectionType
	converted.LockType = limiter.LockType
	converted.Count = limiter.Count
	return converted
}
// convertTrafficLimiter copies a protobuf TrafficLimiter into a
// CM.TrafficLimiter, converting ID to int. limiter must be non-nil.
func (s *APIClient) convertTrafficLimiter(limiter *pb.TrafficLimiter) CM.TrafficLimiter {
	var converted CM.TrafficLimiter
	converted.ID = int(limiter.Id)
	converted.Username = limiter.Username
	converted.Outbound = limiter.Outbound
	converted.Strategy = limiter.Strategy
	converted.Mode = limiter.Mode
	converted.RawUsed = limiter.RawUsed
	converted.Quota = limiter.Quota
	converted.RawQuota = limiter.RawQuota
	return converted
}
// convertRateLimiter copies a protobuf RateLimiter into a CM.RateLimiter,
// converting ID to int. limiter must be non-nil.
func (s *APIClient) convertRateLimiter(limiter *pb.RateLimiter) CM.RateLimiter {
	var converted CM.RateLimiter
	converted.ID = int(limiter.Id)
	converted.Username = limiter.Username
	converted.Outbound = limiter.Outbound
	converted.Strategy = limiter.Strategy
	converted.ConnectionType = limiter.ConnectionType
	converted.Count = limiter.Count
	converted.Interval = limiter.Interval
	return converted
}