mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-05-22 20:29:51 +03:00
Update sing-box core
This commit is contained in:
139
service/ccm/credential.go
Normal file
139
service/ccm/credential.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const (
|
||||
oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
oauth2TokenURL = "https://console.anthropic.com/v1/oauth/token"
|
||||
claudeAPIBaseURL = "https://api.anthropic.com"
|
||||
tokenRefreshBufferMs = 60000
|
||||
anthropicBetaOAuthValue = "oauth-2025-04-20"
|
||||
)
|
||||
|
||||
func getRealUser() (*user.User, error) {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
sudoUserInfo, err := user.Lookup(sudoUser)
|
||||
if err == nil {
|
||||
return sudoUserInfo, nil
|
||||
}
|
||||
}
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
func getDefaultCredentialsPath() (string, error) {
|
||||
if configDir := os.Getenv("CLAUDE_CONFIG_DIR"); configDir != "" {
|
||||
return filepath.Join(configDir, ".credentials.json"), nil
|
||||
}
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(userInfo.HomeDir, ".claude", ".credentials.json"), nil
|
||||
}
|
||||
|
||||
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var credentialsContainer struct {
|
||||
ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"`
|
||||
}
|
||||
err = json.Unmarshal(data, &credentialsContainer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if credentialsContainer.ClaudeAIAuth == nil {
|
||||
return nil, E.New("claudeAiOauth field not found in credentials")
|
||||
}
|
||||
return credentialsContainer.ClaudeAIAuth, nil
|
||||
}
|
||||
|
||||
func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error {
|
||||
data, err := json.MarshalIndent(map[string]any{
|
||||
"claudeAiOauth": oauthCredentials,
|
||||
}, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o600)
|
||||
}
|
||||
|
||||
type oauthCredentials struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpiresAt int64 `json:"expiresAt"`
|
||||
Scopes []string `json:"scopes,omitempty"`
|
||||
SubscriptionType string `json:"subscriptionType,omitempty"`
|
||||
IsMax bool `json:"isMax,omitempty"`
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) needsRefresh() bool {
|
||||
if c.ExpiresAt == 0 {
|
||||
return false
|
||||
}
|
||||
return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs
|
||||
}
|
||||
|
||||
func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
|
||||
if credentials.RefreshToken == "" {
|
||||
return nil, E.New("refresh token is empty")
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.RefreshToken,
|
||||
"client_id": oauth2ClientID,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "marshal request")
|
||||
}
|
||||
|
||||
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json")
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
|
||||
}
|
||||
|
||||
var tokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode response")
|
||||
}
|
||||
|
||||
newCredentials := *credentials
|
||||
newCredentials.AccessToken = tokenResponse.AccessToken
|
||||
if tokenResponse.RefreshToken != "" {
|
||||
newCredentials.RefreshToken = tokenResponse.RefreshToken
|
||||
}
|
||||
newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000
|
||||
|
||||
return &newCredentials, nil
|
||||
}
|
||||
116
service/ccm/credential_darwin.go
Normal file
116
service/ccm/credential_darwin.go
Normal file
@@ -0,0 +1,116 @@
|
||||
//go:build darwin && cgo
|
||||
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
|
||||
"github.com/keybase/go-keychain"
|
||||
)
|
||||
|
||||
func getKeychainServiceName() string {
|
||||
configDirectory := os.Getenv("CLAUDE_CONFIG_DIR")
|
||||
if configDirectory == "" {
|
||||
return "Claude Code-credentials"
|
||||
}
|
||||
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return "Claude Code-credentials"
|
||||
}
|
||||
defaultConfigDirectory := filepath.Join(userInfo.HomeDir, ".claude")
|
||||
if configDirectory == defaultConfigDirectory {
|
||||
return "Claude Code-credentials"
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(configDirectory))
|
||||
return "Claude Code-credentials-" + hex.EncodeToString(hash[:])[:8]
|
||||
}
|
||||
|
||||
func platformReadCredentials(customPath string) (*oauthCredentials, error) {
|
||||
if customPath != "" {
|
||||
return readCredentialsFromFile(customPath)
|
||||
}
|
||||
|
||||
userInfo, err := getRealUser()
|
||||
if err == nil {
|
||||
query := keychain.NewItem()
|
||||
query.SetSecClass(keychain.SecClassGenericPassword)
|
||||
query.SetService(getKeychainServiceName())
|
||||
query.SetAccount(userInfo.Username)
|
||||
query.SetMatchLimit(keychain.MatchLimitOne)
|
||||
query.SetReturnData(true)
|
||||
|
||||
results, err := keychain.QueryItem(query)
|
||||
if err == nil && len(results) == 1 {
|
||||
var container struct {
|
||||
ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"`
|
||||
}
|
||||
unmarshalErr := json.Unmarshal(results[0].Data, &container)
|
||||
if unmarshalErr == nil && container.ClaudeAIAuth != nil {
|
||||
return container.ClaudeAIAuth, nil
|
||||
}
|
||||
}
|
||||
if err != nil && err != keychain.ErrorItemNotFound {
|
||||
return nil, E.Cause(err, "query keychain")
|
||||
}
|
||||
}
|
||||
|
||||
defaultPath, err := getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return readCredentialsFromFile(defaultPath)
|
||||
}
|
||||
|
||||
func platformWriteCredentials(oauthCredentials *oauthCredentials, customPath string) error {
|
||||
if customPath != "" {
|
||||
return writeCredentialsToFile(oauthCredentials, customPath)
|
||||
}
|
||||
|
||||
userInfo, err := getRealUser()
|
||||
if err == nil {
|
||||
data, err := json.Marshal(map[string]any{"claudeAiOauth": oauthCredentials})
|
||||
if err == nil {
|
||||
serviceName := getKeychainServiceName()
|
||||
item := keychain.NewItem()
|
||||
item.SetSecClass(keychain.SecClassGenericPassword)
|
||||
item.SetService(serviceName)
|
||||
item.SetAccount(userInfo.Username)
|
||||
item.SetData(data)
|
||||
item.SetAccessible(keychain.AccessibleWhenUnlocked)
|
||||
|
||||
err = keychain.AddItem(item)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err == keychain.ErrorDuplicateItem {
|
||||
query := keychain.NewItem()
|
||||
query.SetSecClass(keychain.SecClassGenericPassword)
|
||||
query.SetService(serviceName)
|
||||
query.SetAccount(userInfo.Username)
|
||||
|
||||
updateItem := keychain.NewItem()
|
||||
updateItem.SetData(data)
|
||||
|
||||
updateErr := keychain.UpdateItem(query, updateItem)
|
||||
if updateErr == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
defaultPath, err := getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeCredentialsToFile(oauthCredentials, defaultPath)
|
||||
}
|
||||
25
service/ccm/credential_other.go
Normal file
25
service/ccm/credential_other.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build !darwin
|
||||
|
||||
package ccm
|
||||
|
||||
func platformReadCredentials(customPath string) (*oauthCredentials, error) {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return readCredentialsFromFile(customPath)
|
||||
}
|
||||
|
||||
func platformWriteCredentials(oauthCredentials *oauthCredentials, customPath string) error {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return writeCredentialsToFile(oauthCredentials, customPath)
|
||||
}
|
||||
588
service/ccm/service.go
Normal file
588
service/ccm/service.go
Normal file
@@ -0,0 +1,588 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
boxService "github.com/sagernet/sing-box/adapter/service"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/common/listener"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
const (
|
||||
contextWindowStandard = 200000
|
||||
contextWindowPremium = 1000000
|
||||
premiumContextThreshold = 200000
|
||||
)
|
||||
|
||||
func RegisterService(registry *boxService.Registry) {
|
||||
boxService.Register[option.CCMServiceOptions](registry, C.TypeCCM, NewService)
|
||||
}
|
||||
|
||||
type errorResponse struct {
|
||||
Type string `json:"type"`
|
||||
Error errorDetails `json:"error"`
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
}
|
||||
|
||||
type errorDetails struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
json.NewEncoder(w).Encode(errorResponse{
|
||||
Type: "error",
|
||||
Error: errorDetails{
|
||||
Type: errorType,
|
||||
Message: message,
|
||||
},
|
||||
RequestID: r.Header.Get("Request-Id"),
|
||||
})
|
||||
}
|
||||
|
||||
func isHopByHopHeader(header string) bool {
|
||||
switch strings.ToLower(header) {
|
||||
case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
weeklyWindowSeconds = 604800
|
||||
weeklyWindowMinutes = weeklyWindowSeconds / 60
|
||||
)
|
||||
|
||||
func parseInt64Header(headers http.Header, headerName string) (int64, bool) {
|
||||
headerValue := strings.TrimSpace(headers.Get(headerName))
|
||||
if headerValue == "" {
|
||||
return 0, false
|
||||
}
|
||||
parsedValue, parseError := strconv.ParseInt(headerValue, 10, 64)
|
||||
if parseError != nil {
|
||||
return 0, false
|
||||
}
|
||||
return parsedValue, true
|
||||
}
|
||||
|
||||
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
|
||||
resetAtUnix, hasResetAt := parseInt64Header(headers, "anthropic-ratelimit-unified-7d-reset")
|
||||
if !hasResetAt || resetAtUnix <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &WeeklyCycleHint{
|
||||
WindowMinutes: weeklyWindowMinutes,
|
||||
ResetAt: time.Unix(resetAtUnix, 0).UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
boxService.Adapter
|
||||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
credentialPath string
|
||||
credentials *oauthCredentials
|
||||
users []option.CCMUser
|
||||
httpClient *http.Client
|
||||
httpHeaders http.Header
|
||||
listener *listener.Listener
|
||||
tlsConfig tls.ServerConfig
|
||||
httpServer *http.Server
|
||||
userManager *UserManager
|
||||
accessMutex sync.RWMutex
|
||||
usageTracker *AggregatedUsage
|
||||
trackingGroup sync.WaitGroup
|
||||
shuttingDown bool
|
||||
}
|
||||
|
||||
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) {
|
||||
serviceDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: options.Detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create dialer")
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSClientConfig: &stdTLS.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
},
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
userManager := &UserManager{
|
||||
tokenMap: make(map[string]string),
|
||||
}
|
||||
|
||||
var usageTracker *AggregatedUsage
|
||||
if options.UsagesPath != "" {
|
||||
usageTracker = &AggregatedUsage{
|
||||
LastUpdated: time.Now(),
|
||||
Combinations: make([]CostCombination, 0),
|
||||
filePath: options.UsagesPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
service := &Service{
|
||||
Adapter: boxService.NewAdapter(C.TypeCCM, tag),
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
credentialPath: options.CredentialPath,
|
||||
users: options.Users,
|
||||
httpClient: httpClient,
|
||||
httpHeaders: options.Headers.Build(),
|
||||
listener: listener.New(listener.Options{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
Network: []string{N.NetworkTCP},
|
||||
Listen: options.ListenOptions,
|
||||
}),
|
||||
userManager: userManager,
|
||||
usageTracker: usageTracker,
|
||||
}
|
||||
|
||||
if options.TLS != nil {
|
||||
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
service.tlsConfig = tlsConfig
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *Service) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.userManager.UpdateUsers(s.users)
|
||||
|
||||
credentials, err := platformReadCredentials(s.credentialPath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read credentials")
|
||||
}
|
||||
s.credentials = credentials
|
||||
|
||||
if s.usageTracker != nil {
|
||||
err = s.usageTracker.Load()
|
||||
if err != nil {
|
||||
s.logger.Warn("load usage statistics: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Mount("/", s)
|
||||
|
||||
s.httpServer = &http.Server{Handler: router}
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
err = s.tlsConfig.Start()
|
||||
if err != nil {
|
||||
return E.Cause(err, "create TLS config")
|
||||
}
|
||||
}
|
||||
|
||||
tcpListener, err := s.listener.ListenTCP()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) {
|
||||
s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...))
|
||||
}
|
||||
tcpListener = aTLS.NewListener(tcpListener, s.tlsConfig)
|
||||
}
|
||||
|
||||
go func() {
|
||||
serveErr := s.httpServer.Serve(tcpListener)
|
||||
if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
|
||||
s.logger.Error("serve error: ", serveErr)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) getAccessToken() (string, error) {
|
||||
s.accessMutex.RLock()
|
||||
if !s.credentials.needsRefresh() {
|
||||
token := s.credentials.AccessToken
|
||||
s.accessMutex.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
s.accessMutex.RUnlock()
|
||||
|
||||
s.accessMutex.Lock()
|
||||
defer s.accessMutex.Unlock()
|
||||
|
||||
if !s.credentials.needsRefresh() {
|
||||
return s.credentials.AccessToken, nil
|
||||
}
|
||||
|
||||
newCredentials, err := refreshToken(s.httpClient, s.credentials)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.credentials = newCredentials
|
||||
|
||||
err = platformWriteCredentials(newCredentials, s.credentialPath)
|
||||
if err != nil {
|
||||
s.logger.Warn("persist refreshed token: ", err)
|
||||
}
|
||||
|
||||
return newCredentials.AccessToken, nil
|
||||
}
|
||||
|
||||
func detectContextWindow(betaHeader string, inputTokens int64) int {
|
||||
if inputTokens > premiumContextThreshold {
|
||||
features := strings.Split(betaHeader, ",")
|
||||
for _, feature := range features {
|
||||
if strings.TrimSpace(feature) == "context-1m" {
|
||||
return contextWindowPremium
|
||||
}
|
||||
}
|
||||
}
|
||||
return contextWindowStandard
|
||||
}
|
||||
|
||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if !strings.HasPrefix(r.URL.Path, "/v1/") {
|
||||
writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
|
||||
return
|
||||
}
|
||||
|
||||
var username string
|
||||
if len(s.users) > 0 {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
username, ok = s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var requestModel string
|
||||
var messagesCount int
|
||||
|
||||
if s.usageTracker != nil && r.Body != nil {
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err == nil {
|
||||
var request struct {
|
||||
Model string `json:"model"`
|
||||
Messages []anthropic.MessageParam `json:"messages"`
|
||||
}
|
||||
err := json.Unmarshal(bodyBytes, &request)
|
||||
if err == nil {
|
||||
requestModel = request.Model
|
||||
messagesCount = len(request.Messages)
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, err := s.getAccessToken()
|
||||
if err != nil {
|
||||
s.logger.Error("get access token: ", err)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
|
||||
return
|
||||
}
|
||||
|
||||
proxyURL := claudeAPIBaseURL + r.URL.RequestURI()
|
||||
proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("create proxy request: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
for key, values := range r.Header {
|
||||
if !isHopByHopHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta")
|
||||
if anthropicBetaHeader != "" {
|
||||
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader)
|
||||
} else {
|
||||
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
|
||||
}
|
||||
|
||||
for key, values := range s.httpHeaders {
|
||||
proxyRequest.Header.Del(key)
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
response, err := s.httpClient.Do(proxyRequest)
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) {
|
||||
w.Header()[key] = values
|
||||
}
|
||||
}
|
||||
w.WriteHeader(response.StatusCode)
|
||||
|
||||
if s.usageTracker != nil && response.StatusCode == http.StatusOK {
|
||||
s.handleResponseWithTracking(w, response, requestModel, anthropicBetaHeader, messagesCount, username)
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
if err == nil && mediaType != "text/event-stream" {
|
||||
_, _ = io.Copy(w, response.Body)
|
||||
return
|
||||
}
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
return
|
||||
}
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
_, writeError := w.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
|
||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
isStreaming := err == nil && mediaType == "text/event-stream"
|
||||
|
||||
if !isStreaming {
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("read response body: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
var message anthropic.Message
|
||||
var usage anthropic.Usage
|
||||
var responseModel string
|
||||
err = json.Unmarshal(bodyBytes, &message)
|
||||
if err == nil {
|
||||
responseModel = string(message.Model)
|
||||
usage = message.Usage
|
||||
}
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
|
||||
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, usage.InputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
messagesCount,
|
||||
usage.InputTokens,
|
||||
usage.OutputTokens,
|
||||
usage.CacheReadInputTokens,
|
||||
usage.CacheCreationInputTokens,
|
||||
usage.CacheCreation.Ephemeral5mInputTokens,
|
||||
usage.CacheCreation.Ephemeral1hInputTokens,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = writer.Write(bodyBytes)
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
var accumulatedUsage anthropic.Usage
|
||||
var responseModel string
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
var leftover []byte
|
||||
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
data := append(leftover, buffer[:n]...)
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
|
||||
if err == nil {
|
||||
leftover = lines[len(lines)-1]
|
||||
lines = lines[:len(lines)-1]
|
||||
} else {
|
||||
leftover = nil
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
eventData := bytes.TrimPrefix(line, []byte("data: "))
|
||||
if bytes.Equal(eventData, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
|
||||
var event anthropic.MessageStreamEventUnion
|
||||
err := json.Unmarshal(eventData, &event)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
switch event.Type {
|
||||
case "message_start":
|
||||
messageStart := event.AsMessageStart()
|
||||
if messageStart.Message.Model != "" {
|
||||
responseModel = string(messageStart.Message.Model)
|
||||
}
|
||||
if messageStart.Message.Usage.InputTokens > 0 {
|
||||
accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens
|
||||
accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens
|
||||
accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens
|
||||
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral5mInputTokens
|
||||
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens = messageStart.Message.Usage.CacheCreation.Ephemeral1hInputTokens
|
||||
}
|
||||
case "message_delta":
|
||||
messageDelta := event.AsMessageDelta()
|
||||
if messageDelta.Usage.OutputTokens > 0 {
|
||||
accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, writeError := writer.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
|
||||
if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, accumulatedUsage.InputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
messagesCount,
|
||||
accumulatedUsage.InputTokens,
|
||||
accumulatedUsage.OutputTokens,
|
||||
accumulatedUsage.CacheReadInputTokens,
|
||||
accumulatedUsage.CacheCreationInputTokens,
|
||||
accumulatedUsage.CacheCreation.Ephemeral5mInputTokens,
|
||||
accumulatedUsage.CacheCreation.Ephemeral1hInputTokens,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
err := common.Close(
|
||||
common.PtrOrNil(s.httpServer),
|
||||
common.PtrOrNil(s.listener),
|
||||
s.tlsConfig,
|
||||
)
|
||||
|
||||
if s.usageTracker != nil {
|
||||
s.usageTracker.cancelPendingSave()
|
||||
saveErr := s.usageTracker.Save()
|
||||
if saveErr != nil {
|
||||
s.logger.Error("save usage statistics: ", saveErr)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
675
service/ccm/service_usage.go
Normal file
675
service/ccm/service_usage.go
Normal file
@@ -0,0 +1,675 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/log"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type UsageStats struct {
|
||||
RequestCount int `json:"request_count"`
|
||||
MessagesCount int `json:"messages_count"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheReadInputTokens int64 `json:"cache_read_input_tokens"`
|
||||
CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"`
|
||||
CacheCreation5MinuteInputTokens int64 `json:"cache_creation_5m_input_tokens,omitempty"`
|
||||
CacheCreation1HourInputTokens int64 `json:"cache_creation_1h_input_tokens,omitempty"`
|
||||
}
|
||||
|
||||
type CostCombination struct {
|
||||
Model string `json:"model"`
|
||||
ContextWindow int `json:"context_window"`
|
||||
WeekStartUnix int64 `json:"week_start_unix,omitempty"`
|
||||
Total UsageStats `json:"total"`
|
||||
ByUser map[string]UsageStats `json:"by_user"`
|
||||
}
|
||||
|
||||
type AggregatedUsage struct {
|
||||
LastUpdated time.Time `json:"last_updated"`
|
||||
Combinations []CostCombination `json:"combinations"`
|
||||
mutex sync.Mutex
|
||||
filePath string
|
||||
logger log.ContextLogger
|
||||
lastSaveTime time.Time
|
||||
pendingSave bool
|
||||
saveTimer *time.Timer
|
||||
saveMutex sync.Mutex
|
||||
}
|
||||
|
||||
type UsageStatsJSON struct {
|
||||
RequestCount int `json:"request_count"`
|
||||
MessagesCount int `json:"messages_count"`
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
CacheReadInputTokens int64 `json:"cache_read_input_tokens"`
|
||||
CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"`
|
||||
CacheCreation5MinuteInputTokens int64 `json:"cache_creation_5m_input_tokens,omitempty"`
|
||||
CacheCreation1HourInputTokens int64 `json:"cache_creation_1h_input_tokens,omitempty"`
|
||||
CostUSD float64 `json:"cost_usd"`
|
||||
}
|
||||
|
||||
type CostCombinationJSON struct {
|
||||
Model string `json:"model"`
|
||||
ContextWindow int `json:"context_window"`
|
||||
WeekStartUnix int64 `json:"week_start_unix,omitempty"`
|
||||
Total UsageStatsJSON `json:"total"`
|
||||
ByUser map[string]UsageStatsJSON `json:"by_user"`
|
||||
}
|
||||
|
||||
type CostsSummaryJSON struct {
|
||||
TotalUSD float64 `json:"total_usd"`
|
||||
ByUser map[string]float64 `json:"by_user"`
|
||||
ByWeek map[string]float64 `json:"by_week,omitempty"`
|
||||
}
|
||||
|
||||
type AggregatedUsageJSON struct {
|
||||
LastUpdated time.Time `json:"last_updated"`
|
||||
Costs CostsSummaryJSON `json:"costs"`
|
||||
Combinations []CostCombinationJSON `json:"combinations"`
|
||||
}
|
||||
|
||||
type WeeklyCycleHint struct {
|
||||
WindowMinutes int64
|
||||
ResetAt time.Time
|
||||
}
|
||||
|
||||
type ModelPricing struct {
|
||||
InputPrice float64
|
||||
OutputPrice float64
|
||||
CacheReadPrice float64
|
||||
CacheWritePrice5Minute float64
|
||||
CacheWritePrice1Hour float64
|
||||
}
|
||||
|
||||
type modelFamily struct {
|
||||
pattern *regexp.Regexp
|
||||
standardPricing ModelPricing
|
||||
premiumPricing *ModelPricing
|
||||
}
|
||||
|
||||
var (
|
||||
opus46StandardPricing = ModelPricing{
|
||||
InputPrice: 5.0,
|
||||
OutputPrice: 25.0,
|
||||
CacheReadPrice: 0.5,
|
||||
CacheWritePrice5Minute: 6.25,
|
||||
CacheWritePrice1Hour: 10.0,
|
||||
}
|
||||
|
||||
opus46PremiumPricing = ModelPricing{
|
||||
InputPrice: 10.0,
|
||||
OutputPrice: 37.5,
|
||||
CacheReadPrice: 1.0,
|
||||
CacheWritePrice5Minute: 12.5,
|
||||
CacheWritePrice1Hour: 20.0,
|
||||
}
|
||||
|
||||
opus45Pricing = ModelPricing{
|
||||
InputPrice: 5.0,
|
||||
OutputPrice: 25.0,
|
||||
CacheReadPrice: 0.5,
|
||||
CacheWritePrice5Minute: 6.25,
|
||||
CacheWritePrice1Hour: 10.0,
|
||||
}
|
||||
|
||||
opus4Pricing = ModelPricing{
|
||||
InputPrice: 15.0,
|
||||
OutputPrice: 75.0,
|
||||
CacheReadPrice: 1.5,
|
||||
CacheWritePrice5Minute: 18.75,
|
||||
CacheWritePrice1Hour: 30.0,
|
||||
}
|
||||
|
||||
sonnet46StandardPricing = ModelPricing{
|
||||
InputPrice: 3.0,
|
||||
OutputPrice: 15.0,
|
||||
CacheReadPrice: 0.3,
|
||||
CacheWritePrice5Minute: 3.75,
|
||||
CacheWritePrice1Hour: 6.0,
|
||||
}
|
||||
|
||||
sonnet46PremiumPricing = ModelPricing{
|
||||
InputPrice: 6.0,
|
||||
OutputPrice: 22.5,
|
||||
CacheReadPrice: 0.6,
|
||||
CacheWritePrice5Minute: 7.5,
|
||||
CacheWritePrice1Hour: 12.0,
|
||||
}
|
||||
|
||||
sonnet45StandardPricing = ModelPricing{
|
||||
InputPrice: 3.0,
|
||||
OutputPrice: 15.0,
|
||||
CacheReadPrice: 0.3,
|
||||
CacheWritePrice5Minute: 3.75,
|
||||
CacheWritePrice1Hour: 6.0,
|
||||
}
|
||||
|
||||
sonnet45PremiumPricing = ModelPricing{
|
||||
InputPrice: 6.0,
|
||||
OutputPrice: 22.5,
|
||||
CacheReadPrice: 0.6,
|
||||
CacheWritePrice5Minute: 7.5,
|
||||
CacheWritePrice1Hour: 12.0,
|
||||
}
|
||||
|
||||
sonnet4StandardPricing = ModelPricing{
|
||||
InputPrice: 3.0,
|
||||
OutputPrice: 15.0,
|
||||
CacheReadPrice: 0.3,
|
||||
CacheWritePrice5Minute: 3.75,
|
||||
CacheWritePrice1Hour: 6.0,
|
||||
}
|
||||
|
||||
sonnet4PremiumPricing = ModelPricing{
|
||||
InputPrice: 6.0,
|
||||
OutputPrice: 22.5,
|
||||
CacheReadPrice: 0.6,
|
||||
CacheWritePrice5Minute: 7.5,
|
||||
CacheWritePrice1Hour: 12.0,
|
||||
}
|
||||
|
||||
sonnet37Pricing = ModelPricing{
|
||||
InputPrice: 3.0,
|
||||
OutputPrice: 15.0,
|
||||
CacheReadPrice: 0.3,
|
||||
CacheWritePrice5Minute: 3.75,
|
||||
CacheWritePrice1Hour: 6.0,
|
||||
}
|
||||
|
||||
sonnet35Pricing = ModelPricing{
|
||||
InputPrice: 3.0,
|
||||
OutputPrice: 15.0,
|
||||
CacheReadPrice: 0.3,
|
||||
CacheWritePrice5Minute: 3.75,
|
||||
CacheWritePrice1Hour: 6.0,
|
||||
}
|
||||
|
||||
haiku45Pricing = ModelPricing{
|
||||
InputPrice: 1.0,
|
||||
OutputPrice: 5.0,
|
||||
CacheReadPrice: 0.1,
|
||||
CacheWritePrice5Minute: 1.25,
|
||||
CacheWritePrice1Hour: 2.0,
|
||||
}
|
||||
|
||||
haiku4Pricing = ModelPricing{
|
||||
InputPrice: 1.0,
|
||||
OutputPrice: 5.0,
|
||||
CacheReadPrice: 0.1,
|
||||
CacheWritePrice5Minute: 1.25,
|
||||
CacheWritePrice1Hour: 2.0,
|
||||
}
|
||||
|
||||
haiku35Pricing = ModelPricing{
|
||||
InputPrice: 0.8,
|
||||
OutputPrice: 4.0,
|
||||
CacheReadPrice: 0.08,
|
||||
CacheWritePrice5Minute: 1.0,
|
||||
CacheWritePrice1Hour: 1.6,
|
||||
}
|
||||
|
||||
haiku3Pricing = ModelPricing{
|
||||
InputPrice: 0.25,
|
||||
OutputPrice: 1.25,
|
||||
CacheReadPrice: 0.03,
|
||||
CacheWritePrice5Minute: 0.3,
|
||||
CacheWritePrice1Hour: 0.5,
|
||||
}
|
||||
|
||||
opus3Pricing = ModelPricing{
|
||||
InputPrice: 15.0,
|
||||
OutputPrice: 75.0,
|
||||
CacheReadPrice: 1.5,
|
||||
CacheWritePrice5Minute: 18.75,
|
||||
CacheWritePrice1Hour: 30.0,
|
||||
}
|
||||
|
||||
modelFamilies = []modelFamily{
|
||||
{
|
||||
pattern: regexp.MustCompile(`^claude-opus-4-6(?:-|$)`),
|
||||
standardPricing: opus46StandardPricing,
|
||||
premiumPricing: &opus46PremiumPricing,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^claude-opus-4-5(?:-|$)`),
|
||||
standardPricing: opus45Pricing,
|
||||
premiumPricing: nil,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^claude-(?:opus-4(?:-|$)|4-opus-)`),
|
||||
standardPricing: opus4Pricing,
|
||||
premiumPricing: nil,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^claude-(?:opus-3(?:-|$)|3-opus-)`),
|
||||
standardPricing: opus3Pricing,
|
||||
premiumPricing: nil,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^claude-(?:sonnet-4-6(?:-|$)|4-6-sonnet-)`),
|
||||
standardPricing: sonnet46StandardPricing,
|
||||
premiumPricing: &sonnet46PremiumPricing,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^claude-(?:sonnet-4-5(?:-|$)|4-5-sonnet-)`),
|
||||
standardPricing: sonnet45StandardPricing,
|
||||
premiumPricing: &sonnet45PremiumPricing,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^claude-(?:sonnet-4(?:-|$)|4-sonnet-)`),
|
||||
standardPricing: sonnet4StandardPricing,
|
||||
premiumPricing: &sonnet4PremiumPricing,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^claude-3-7-sonnet(?:-|$)`),
|
||||
standardPricing: sonnet37Pricing,
|
||||
premiumPricing: nil,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^claude-3-5-sonnet(?:-|$)`),
|
||||
standardPricing: sonnet35Pricing,
|
||||
premiumPricing: nil,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^claude-(?:haiku-4-5(?:-|$)|4-5-haiku-)`),
|
||||
standardPricing: haiku45Pricing,
|
||||
premiumPricing: nil,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^claude-haiku-4(?:-|$)`),
|
||||
standardPricing: haiku4Pricing,
|
||||
premiumPricing: nil,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^claude-3-5-haiku(?:-|$)`),
|
||||
standardPricing: haiku35Pricing,
|
||||
premiumPricing: nil,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^claude-3-haiku(?:-|$)`),
|
||||
standardPricing: haiku3Pricing,
|
||||
premiumPricing: nil,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func getPricing(model string, contextWindow int) ModelPricing {
|
||||
isPremium := contextWindow >= contextWindowPremium
|
||||
|
||||
for _, family := range modelFamilies {
|
||||
if family.pattern.MatchString(model) {
|
||||
if isPremium && family.premiumPricing != nil {
|
||||
return *family.premiumPricing
|
||||
}
|
||||
return family.standardPricing
|
||||
}
|
||||
}
|
||||
|
||||
return sonnet4StandardPricing
|
||||
}
|
||||
|
||||
func calculateCost(stats UsageStats, model string, contextWindow int) float64 {
|
||||
pricing := getPricing(model, contextWindow)
|
||||
|
||||
cacheCreationCost := 0.0
|
||||
if stats.CacheCreation5MinuteInputTokens > 0 || stats.CacheCreation1HourInputTokens > 0 {
|
||||
cacheCreationCost = float64(stats.CacheCreation5MinuteInputTokens)*pricing.CacheWritePrice5Minute +
|
||||
float64(stats.CacheCreation1HourInputTokens)*pricing.CacheWritePrice1Hour
|
||||
} else {
|
||||
// Backward compatibility for usage files generated before TTL split tracking.
|
||||
cacheCreationCost = float64(stats.CacheCreationInputTokens) * pricing.CacheWritePrice5Minute
|
||||
}
|
||||
|
||||
cost := (float64(stats.InputTokens)*pricing.InputPrice +
|
||||
float64(stats.OutputTokens)*pricing.OutputPrice +
|
||||
float64(stats.CacheReadInputTokens)*pricing.CacheReadPrice +
|
||||
cacheCreationCost) / 1_000_000
|
||||
|
||||
return math.Round(cost*100) / 100
|
||||
}
|
||||
|
||||
func roundCost(cost float64) float64 {
|
||||
return math.Round(cost*100) / 100
|
||||
}
|
||||
|
||||
func normalizeCombinations(combinations []CostCombination) {
|
||||
for index := range combinations {
|
||||
if combinations[index].ByUser == nil {
|
||||
combinations[index].ByUser = make(map[string]UsageStats)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addUsageToCombinations(
|
||||
combinations *[]CostCombination,
|
||||
model string,
|
||||
contextWindow int,
|
||||
weekStartUnix int64,
|
||||
messagesCount int,
|
||||
inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens, cacheCreation5MinuteTokens, cacheCreation1HourTokens int64,
|
||||
user string,
|
||||
) {
|
||||
var matchedCombination *CostCombination
|
||||
for index := range *combinations {
|
||||
combination := &(*combinations)[index]
|
||||
if combination.Model == model && combination.ContextWindow == contextWindow && combination.WeekStartUnix == weekStartUnix {
|
||||
matchedCombination = combination
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if matchedCombination == nil {
|
||||
newCombination := CostCombination{
|
||||
Model: model,
|
||||
ContextWindow: contextWindow,
|
||||
WeekStartUnix: weekStartUnix,
|
||||
Total: UsageStats{},
|
||||
ByUser: make(map[string]UsageStats),
|
||||
}
|
||||
*combinations = append(*combinations, newCombination)
|
||||
matchedCombination = &(*combinations)[len(*combinations)-1]
|
||||
}
|
||||
|
||||
if cacheCreationTokens == 0 {
|
||||
cacheCreationTokens = cacheCreation5MinuteTokens + cacheCreation1HourTokens
|
||||
}
|
||||
|
||||
matchedCombination.Total.RequestCount++
|
||||
matchedCombination.Total.MessagesCount += messagesCount
|
||||
matchedCombination.Total.InputTokens += inputTokens
|
||||
matchedCombination.Total.OutputTokens += outputTokens
|
||||
matchedCombination.Total.CacheReadInputTokens += cacheReadTokens
|
||||
matchedCombination.Total.CacheCreationInputTokens += cacheCreationTokens
|
||||
matchedCombination.Total.CacheCreation5MinuteInputTokens += cacheCreation5MinuteTokens
|
||||
matchedCombination.Total.CacheCreation1HourInputTokens += cacheCreation1HourTokens
|
||||
|
||||
if user != "" {
|
||||
userStats := matchedCombination.ByUser[user]
|
||||
userStats.RequestCount++
|
||||
userStats.MessagesCount += messagesCount
|
||||
userStats.InputTokens += inputTokens
|
||||
userStats.OutputTokens += outputTokens
|
||||
userStats.CacheReadInputTokens += cacheReadTokens
|
||||
userStats.CacheCreationInputTokens += cacheCreationTokens
|
||||
userStats.CacheCreation5MinuteInputTokens += cacheCreation5MinuteTokens
|
||||
userStats.CacheCreation1HourInputTokens += cacheCreation1HourTokens
|
||||
matchedCombination.ByUser[user] = userStats
|
||||
}
|
||||
}
|
||||
|
||||
func buildCombinationJSON(combinations []CostCombination, aggregateUserCosts map[string]float64) ([]CostCombinationJSON, float64) {
|
||||
result := make([]CostCombinationJSON, len(combinations))
|
||||
var totalCost float64
|
||||
|
||||
for index, combination := range combinations {
|
||||
combinationTotalCost := calculateCost(combination.Total, combination.Model, combination.ContextWindow)
|
||||
totalCost += combinationTotalCost
|
||||
|
||||
combinationJSON := CostCombinationJSON{
|
||||
Model: combination.Model,
|
||||
ContextWindow: combination.ContextWindow,
|
||||
WeekStartUnix: combination.WeekStartUnix,
|
||||
Total: UsageStatsJSON{
|
||||
RequestCount: combination.Total.RequestCount,
|
||||
MessagesCount: combination.Total.MessagesCount,
|
||||
InputTokens: combination.Total.InputTokens,
|
||||
OutputTokens: combination.Total.OutputTokens,
|
||||
CacheReadInputTokens: combination.Total.CacheReadInputTokens,
|
||||
CacheCreationInputTokens: combination.Total.CacheCreationInputTokens,
|
||||
CacheCreation5MinuteInputTokens: combination.Total.CacheCreation5MinuteInputTokens,
|
||||
CacheCreation1HourInputTokens: combination.Total.CacheCreation1HourInputTokens,
|
||||
CostUSD: combinationTotalCost,
|
||||
},
|
||||
ByUser: make(map[string]UsageStatsJSON),
|
||||
}
|
||||
|
||||
for user, userStats := range combination.ByUser {
|
||||
userCost := calculateCost(userStats, combination.Model, combination.ContextWindow)
|
||||
if aggregateUserCosts != nil {
|
||||
aggregateUserCosts[user] += userCost
|
||||
}
|
||||
|
||||
combinationJSON.ByUser[user] = UsageStatsJSON{
|
||||
RequestCount: userStats.RequestCount,
|
||||
MessagesCount: userStats.MessagesCount,
|
||||
InputTokens: userStats.InputTokens,
|
||||
OutputTokens: userStats.OutputTokens,
|
||||
CacheReadInputTokens: userStats.CacheReadInputTokens,
|
||||
CacheCreationInputTokens: userStats.CacheCreationInputTokens,
|
||||
CacheCreation5MinuteInputTokens: userStats.CacheCreation5MinuteInputTokens,
|
||||
CacheCreation1HourInputTokens: userStats.CacheCreation1HourInputTokens,
|
||||
CostUSD: userCost,
|
||||
}
|
||||
}
|
||||
|
||||
result[index] = combinationJSON
|
||||
}
|
||||
|
||||
return result, roundCost(totalCost)
|
||||
}
|
||||
|
||||
func formatUTCOffsetLabel(timestamp time.Time) string {
|
||||
_, offsetSeconds := timestamp.Zone()
|
||||
sign := "+"
|
||||
if offsetSeconds < 0 {
|
||||
sign = "-"
|
||||
offsetSeconds = -offsetSeconds
|
||||
}
|
||||
offsetHours := offsetSeconds / 3600
|
||||
offsetMinutes := (offsetSeconds % 3600) / 60
|
||||
if offsetMinutes == 0 {
|
||||
return fmt.Sprintf("UTC%s%d", sign, offsetHours)
|
||||
}
|
||||
return fmt.Sprintf("UTC%s%d:%02d", sign, offsetHours, offsetMinutes)
|
||||
}
|
||||
|
||||
func formatWeekStartKey(cycleStartAt time.Time) string {
|
||||
localCycleStart := cycleStartAt.In(time.Local)
|
||||
return fmt.Sprintf("%s %s", localCycleStart.Format("2006-01-02 15:04:05"), formatUTCOffsetLabel(localCycleStart))
|
||||
}
|
||||
|
||||
func buildByWeekCost(combinations []CostCombination) map[string]float64 {
|
||||
byWeek := make(map[string]float64)
|
||||
for _, combination := range combinations {
|
||||
if combination.WeekStartUnix <= 0 {
|
||||
continue
|
||||
}
|
||||
weekStartAt := time.Unix(combination.WeekStartUnix, 0).UTC()
|
||||
weekKey := formatWeekStartKey(weekStartAt)
|
||||
byWeek[weekKey] += calculateCost(combination.Total, combination.Model, combination.ContextWindow)
|
||||
}
|
||||
for weekKey, weekCost := range byWeek {
|
||||
byWeek[weekKey] = roundCost(weekCost)
|
||||
}
|
||||
return byWeek
|
||||
}
|
||||
|
||||
func deriveWeekStartUnix(cycleHint *WeeklyCycleHint) int64 {
|
||||
if cycleHint == nil || cycleHint.WindowMinutes <= 0 || cycleHint.ResetAt.IsZero() {
|
||||
return 0
|
||||
}
|
||||
windowDuration := time.Duration(cycleHint.WindowMinutes) * time.Minute
|
||||
return cycleHint.ResetAt.UTC().Add(-windowDuration).Unix()
|
||||
}
|
||||
|
||||
func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
result := &AggregatedUsageJSON{
|
||||
LastUpdated: u.LastUpdated,
|
||||
Costs: CostsSummaryJSON{
|
||||
TotalUSD: 0,
|
||||
ByUser: make(map[string]float64),
|
||||
ByWeek: make(map[string]float64),
|
||||
},
|
||||
}
|
||||
|
||||
globalCombinationsJSON, totalCost := buildCombinationJSON(u.Combinations, result.Costs.ByUser)
|
||||
result.Combinations = globalCombinationsJSON
|
||||
result.Costs.TotalUSD = totalCost
|
||||
result.Costs.ByWeek = buildByWeekCost(u.Combinations)
|
||||
|
||||
if len(result.Costs.ByWeek) == 0 {
|
||||
result.Costs.ByWeek = nil
|
||||
}
|
||||
|
||||
for user, cost := range result.Costs.ByUser {
|
||||
result.Costs.ByUser[user] = roundCost(cost)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func (u *AggregatedUsage) Load() error {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
u.LastUpdated = time.Time{}
|
||||
u.Combinations = nil
|
||||
|
||||
data, err := os.ReadFile(u.filePath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var temp struct {
|
||||
LastUpdated time.Time `json:"last_updated"`
|
||||
Combinations []CostCombination `json:"combinations"`
|
||||
}
|
||||
|
||||
err = json.Unmarshal(data, &temp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.LastUpdated = temp.LastUpdated
|
||||
u.Combinations = temp.Combinations
|
||||
normalizeCombinations(u.Combinations)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *AggregatedUsage) Save() error {
|
||||
jsonData := u.ToJSON()
|
||||
|
||||
data, err := json.MarshalIndent(jsonData, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmpFile := u.filePath + ".tmp"
|
||||
err = os.WriteFile(tmpFile, data, 0o644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.Remove(tmpFile)
|
||||
err = os.Rename(tmpFile, u.filePath)
|
||||
if err == nil {
|
||||
u.saveMutex.Lock()
|
||||
u.lastSaveTime = time.Now()
|
||||
u.saveMutex.Unlock()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *AggregatedUsage) AddUsage(
|
||||
model string,
|
||||
contextWindow int,
|
||||
messagesCount int,
|
||||
inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens, cacheCreation5MinuteTokens, cacheCreation1HourTokens int64,
|
||||
user string,
|
||||
) error {
|
||||
return u.AddUsageWithCycleHint(model, contextWindow, messagesCount, inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens, cacheCreation5MinuteTokens, cacheCreation1HourTokens, user, time.Now(), nil)
|
||||
}
|
||||
|
||||
func (u *AggregatedUsage) AddUsageWithCycleHint(
|
||||
model string,
|
||||
contextWindow int,
|
||||
messagesCount int,
|
||||
inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens, cacheCreation5MinuteTokens, cacheCreation1HourTokens int64,
|
||||
user string,
|
||||
observedAt time.Time,
|
||||
cycleHint *WeeklyCycleHint,
|
||||
) error {
|
||||
if model == "" {
|
||||
return E.New("model cannot be empty")
|
||||
}
|
||||
if contextWindow <= 0 {
|
||||
return E.New("contextWindow must be positive")
|
||||
}
|
||||
if observedAt.IsZero() {
|
||||
observedAt = time.Now()
|
||||
}
|
||||
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
u.LastUpdated = observedAt
|
||||
weekStartUnix := deriveWeekStartUnix(cycleHint)
|
||||
|
||||
addUsageToCombinations(&u.Combinations, model, contextWindow, weekStartUnix, messagesCount, inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens, cacheCreation5MinuteTokens, cacheCreation1HourTokens, user)
|
||||
|
||||
go u.scheduleSave()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *AggregatedUsage) scheduleSave() {
|
||||
const saveInterval = time.Minute
|
||||
|
||||
u.saveMutex.Lock()
|
||||
defer u.saveMutex.Unlock()
|
||||
|
||||
timeSinceLastSave := time.Since(u.lastSaveTime)
|
||||
|
||||
if timeSinceLastSave >= saveInterval {
|
||||
go u.saveAsync()
|
||||
return
|
||||
}
|
||||
|
||||
if u.pendingSave {
|
||||
return
|
||||
}
|
||||
|
||||
u.pendingSave = true
|
||||
remainingTime := saveInterval - timeSinceLastSave
|
||||
|
||||
u.saveTimer = time.AfterFunc(remainingTime, func() {
|
||||
u.saveMutex.Lock()
|
||||
u.pendingSave = false
|
||||
u.saveMutex.Unlock()
|
||||
u.saveAsync()
|
||||
})
|
||||
}
|
||||
|
||||
func (u *AggregatedUsage) saveAsync() {
|
||||
err := u.Save()
|
||||
if err != nil {
|
||||
if u.logger != nil {
|
||||
u.logger.Error("save usage statistics: ", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *AggregatedUsage) cancelPendingSave() {
|
||||
u.saveMutex.Lock()
|
||||
defer u.saveMutex.Unlock()
|
||||
|
||||
if u.saveTimer != nil {
|
||||
u.saveTimer.Stop()
|
||||
u.saveTimer = nil
|
||||
}
|
||||
u.pendingSave = false
|
||||
}
|
||||
29
service/ccm/service_user.go
Normal file
29
service/ccm/service_user.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package ccm
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
type UserManager struct {
|
||||
accessMutex sync.RWMutex
|
||||
tokenMap map[string]string
|
||||
}
|
||||
|
||||
func (m *UserManager) UpdateUsers(users []option.CCMUser) {
|
||||
m.accessMutex.Lock()
|
||||
defer m.accessMutex.Unlock()
|
||||
tokenMap := make(map[string]string, len(users))
|
||||
for _, user := range users {
|
||||
tokenMap[user.Token] = user.Name
|
||||
}
|
||||
m.tokenMap = tokenMap
|
||||
}
|
||||
|
||||
func (m *UserManager) Authenticate(token string) (string, bool) {
|
||||
m.accessMutex.RLock()
|
||||
username, found := m.tokenMap[token]
|
||||
m.accessMutex.RUnlock()
|
||||
return username, found
|
||||
}
|
||||
@@ -36,9 +36,10 @@ import (
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
"github.com/sagernet/sing/service"
|
||||
"github.com/sagernet/sing/service/filemanager"
|
||||
"github.com/sagernet/tailscale/client/tailscale"
|
||||
"github.com/sagernet/tailscale/client/local"
|
||||
"github.com/sagernet/tailscale/derp"
|
||||
"github.com/sagernet/tailscale/derp/derphttp"
|
||||
"github.com/sagernet/tailscale/derp/derpserver"
|
||||
"github.com/sagernet/tailscale/net/netmon"
|
||||
"github.com/sagernet/tailscale/net/stun"
|
||||
"github.com/sagernet/tailscale/net/wsconn"
|
||||
@@ -62,7 +63,7 @@ type Service struct {
|
||||
listener *listener.Listener
|
||||
stunListener *listener.Listener
|
||||
tlsConfig tls.ServerConfig
|
||||
server *derp.Server
|
||||
server *derpserver.Server
|
||||
configPath string
|
||||
verifyClientEndpoint []string
|
||||
verifyClientURL []*option.DERPVerifyClientURLOptions
|
||||
@@ -141,7 +142,7 @@ func (d *Service) Start(stage adapter.StartStage) error {
|
||||
return err
|
||||
}
|
||||
|
||||
server := derp.NewServer(config.PrivateKey, func(format string, args ...any) {
|
||||
server := derpserver.New(config.PrivateKey, func(format string, args ...any) {
|
||||
d.logger.Debug(fmt.Sprintf(format, args...))
|
||||
})
|
||||
|
||||
@@ -193,7 +194,7 @@ func (d *Service) Start(stage adapter.StartStage) error {
|
||||
d.server = server
|
||||
|
||||
derpMux := http.NewServeMux()
|
||||
derpHandler := derphttp.Handler(server)
|
||||
derpHandler := derpserver.Handler(server)
|
||||
derpHandler = addWebSocketSupport(server, derpHandler)
|
||||
derpMux.Handle("/derp", derpHandler)
|
||||
|
||||
@@ -202,8 +203,8 @@ func (d *Service) Start(stage adapter.StartStage) error {
|
||||
return E.New("invalid home value: ", d.home)
|
||||
}
|
||||
|
||||
derpMux.HandleFunc("/derp/probe", derphttp.ProbeHandler)
|
||||
derpMux.HandleFunc("/derp/latency-check", derphttp.ProbeHandler)
|
||||
derpMux.HandleFunc("/derp/probe", derpserver.ProbeHandler)
|
||||
derpMux.HandleFunc("/derp/latency-check", derpserver.ProbeHandler)
|
||||
derpMux.HandleFunc("/bootstrap-dns", tsweb.BrowserHeaderHandlerFunc(handleBootstrapDNS(d.ctx)))
|
||||
derpMux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
tsweb.AddBrowserHeaders(w)
|
||||
@@ -213,7 +214,7 @@ func (d *Service) Start(stage adapter.StartStage) error {
|
||||
tsweb.AddBrowserHeaders(w)
|
||||
io.WriteString(w, "User-agent: *\nDisallow: /\n")
|
||||
}))
|
||||
derpMux.Handle("/generate_204", http.HandlerFunc(derphttp.ServeNoContent))
|
||||
derpMux.Handle("/generate_204", http.HandlerFunc(derpserver.ServeNoContent))
|
||||
|
||||
err = d.tlsConfig.Start()
|
||||
if err != nil {
|
||||
@@ -244,7 +245,7 @@ func (d *Service) Start(stage adapter.StartStage) error {
|
||||
}
|
||||
case adapter.StartStatePostStart:
|
||||
if len(d.verifyClientEndpoint) > 0 {
|
||||
var endpoints []*tailscale.LocalClient
|
||||
var endpoints []*local.Client
|
||||
endpointManager := service.FromContext[adapter.EndpointManager](d.ctx)
|
||||
for _, endpointTag := range d.verifyClientEndpoint {
|
||||
endpoint, loaded := endpointManager.Get(endpointTag)
|
||||
@@ -289,7 +290,7 @@ func checkMeshKey(meshKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Service) startMeshWithHost(derpServer *derp.Server, server *option.DERPMeshOptions) error {
|
||||
func (d *Service) startMeshWithHost(derpServer *derpserver.Server, server *option.DERPMeshOptions) error {
|
||||
meshDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: d.ctx,
|
||||
Options: server.DialerOptions,
|
||||
@@ -307,11 +308,11 @@ func (d *Service) startMeshWithHost(derpServer *derp.Server, server *option.DERP
|
||||
}
|
||||
var stdConfig *tls.STDConfig
|
||||
if server.TLS != nil && server.TLS.Enabled {
|
||||
tlsConfig, err := tls.NewClient(d.ctx, hostname, common.PtrValueOrDefault(server.TLS))
|
||||
tlsConfig, err := tls.NewClient(d.ctx, d.logger, hostname, common.PtrValueOrDefault(server.TLS))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stdConfig, err = tlsConfig.Config()
|
||||
stdConfig, err = tlsConfig.STDConfig()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -343,7 +344,8 @@ func (d *Service) startMeshWithHost(derpServer *derp.Server, server *option.DERP
|
||||
})
|
||||
add := func(m derp.PeerPresentMessage) { derpServer.AddPacketForwarder(m.Key, meshClient) }
|
||||
remove := func(m derp.PeerGoneMessage) { derpServer.RemovePacketForwarder(m.Peer, meshClient) }
|
||||
go meshClient.RunWatchConnectionLoop(context.Background(), derpServer.PublicKey(), logf, add, remove)
|
||||
notifyError := func(err error) { d.logger.Error(err) }
|
||||
go meshClient.RunWatchConnectionLoop(context.Background(), derpServer.PublicKey(), logf, add, remove, notifyError)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -399,7 +401,7 @@ func getHomeHandler(val string) (_ http.Handler, ok bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func addWebSocketSupport(s *derp.Server, base http.Handler) http.Handler {
|
||||
func addWebSocketSupport(s *derpserver.Server, base http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
up := strings.ToLower(r.Header.Get("Upgrade"))
|
||||
|
||||
|
||||
@@ -50,7 +50,7 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
|
||||
}
|
||||
creds := insecure.NewCredentials()
|
||||
if options.TLS != nil {
|
||||
tlsConfig, err := tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS))
|
||||
tlsConfig, err := tls.NewClient(ctx, logger, options.Server, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
173
service/ocm/credential.go
Normal file
173
service/ocm/credential.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
const (
|
||||
oauth2ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
oauth2TokenURL = "https://auth.openai.com/oauth/token"
|
||||
openaiAPIBaseURL = "https://api.openai.com"
|
||||
chatGPTBackendURL = "https://chatgpt.com/backend-api/codex"
|
||||
tokenRefreshIntervalDays = 8
|
||||
)
|
||||
|
||||
func getRealUser() (*user.User, error) {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
sudoUserInfo, err := user.Lookup(sudoUser)
|
||||
if err == nil {
|
||||
return sudoUserInfo, nil
|
||||
}
|
||||
}
|
||||
return user.Current()
|
||||
}
|
||||
|
||||
func getDefaultCredentialsPath() (string, error) {
|
||||
if codexHome := os.Getenv("CODEX_HOME"); codexHome != "" {
|
||||
return filepath.Join(codexHome, "auth.json"), nil
|
||||
}
|
||||
userInfo, err := getRealUser()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(userInfo.HomeDir, ".codex", "auth.json"), nil
|
||||
}
|
||||
|
||||
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var credentials oauthCredentials
|
||||
err = json.Unmarshal(data, &credentials)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &credentials, nil
|
||||
}
|
||||
|
||||
func writeCredentialsToFile(credentials *oauthCredentials, path string) error {
|
||||
data, err := json.MarshalIndent(credentials, "", " ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.WriteFile(path, data, 0o600)
|
||||
}
|
||||
|
||||
type oauthCredentials struct {
|
||||
APIKey string `json:"OPENAI_API_KEY,omitempty"`
|
||||
Tokens *tokenData `json:"tokens,omitempty"`
|
||||
LastRefresh *time.Time `json:"last_refresh,omitempty"`
|
||||
}
|
||||
|
||||
type tokenData struct {
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) isAPIKeyMode() bool {
|
||||
return c.APIKey != ""
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) getAccessToken() string {
|
||||
if c.APIKey != "" {
|
||||
return c.APIKey
|
||||
}
|
||||
if c.Tokens != nil {
|
||||
return c.Tokens.AccessToken
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) getAccountID() string {
|
||||
if c.Tokens != nil {
|
||||
return c.Tokens.AccountID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *oauthCredentials) needsRefresh() bool {
|
||||
if c.APIKey != "" {
|
||||
return false
|
||||
}
|
||||
if c.Tokens == nil || c.Tokens.RefreshToken == "" {
|
||||
return false
|
||||
}
|
||||
if c.LastRefresh == nil {
|
||||
return true
|
||||
}
|
||||
return time.Since(*c.LastRefresh) >= time.Duration(tokenRefreshIntervalDays)*24*time.Hour
|
||||
}
|
||||
|
||||
func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
|
||||
if credentials.Tokens == nil || credentials.Tokens.RefreshToken == "" {
|
||||
return nil, E.New("refresh token is empty")
|
||||
}
|
||||
|
||||
requestBody, err := json.Marshal(map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": credentials.Tokens.RefreshToken,
|
||||
"client_id": oauth2ClientID,
|
||||
"scope": "openid profile email",
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "marshal request")
|
||||
}
|
||||
|
||||
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json")
|
||||
|
||||
response, err := httpClient.Do(request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(response.Body)
|
||||
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
|
||||
}
|
||||
|
||||
var tokenResponse struct {
|
||||
IDToken string `json:"id_token"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "decode response")
|
||||
}
|
||||
|
||||
newCredentials := *credentials
|
||||
if newCredentials.Tokens == nil {
|
||||
newCredentials.Tokens = &tokenData{}
|
||||
}
|
||||
if tokenResponse.IDToken != "" {
|
||||
newCredentials.Tokens.IDToken = tokenResponse.IDToken
|
||||
}
|
||||
if tokenResponse.AccessToken != "" {
|
||||
newCredentials.Tokens.AccessToken = tokenResponse.AccessToken
|
||||
}
|
||||
if tokenResponse.RefreshToken != "" {
|
||||
newCredentials.Tokens.RefreshToken = tokenResponse.RefreshToken
|
||||
}
|
||||
now := time.Now()
|
||||
newCredentials.LastRefresh = &now
|
||||
|
||||
return &newCredentials, nil
|
||||
}
|
||||
25
service/ocm/credential_darwin.go
Normal file
25
service/ocm/credential_darwin.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build darwin
|
||||
|
||||
package ocm
|
||||
|
||||
func platformReadCredentials(customPath string) (*oauthCredentials, error) {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return readCredentialsFromFile(customPath)
|
||||
}
|
||||
|
||||
func platformWriteCredentials(credentials *oauthCredentials, customPath string) error {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return writeCredentialsToFile(credentials, customPath)
|
||||
}
|
||||
25
service/ocm/credential_other.go
Normal file
25
service/ocm/credential_other.go
Normal file
@@ -0,0 +1,25 @@
|
||||
//go:build !darwin
|
||||
|
||||
package ocm
|
||||
|
||||
func platformReadCredentials(customPath string) (*oauthCredentials, error) {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return readCredentialsFromFile(customPath)
|
||||
}
|
||||
|
||||
func platformWriteCredentials(credentials *oauthCredentials, customPath string) error {
|
||||
if customPath == "" {
|
||||
var err error
|
||||
customPath, err = getDefaultCredentialsPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return writeCredentialsToFile(credentials, customPath)
|
||||
}
|
||||
642
service/ocm/service.go
Normal file
642
service/ocm/service.go
Normal file
@@ -0,0 +1,642 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"mime"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
boxService "github.com/sagernet/sing-box/adapter/service"
|
||||
"github.com/sagernet/sing-box/common/dialer"
|
||||
"github.com/sagernet/sing-box/common/listener"
|
||||
"github.com/sagernet/sing-box/common/tls"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
"github.com/sagernet/sing/common/buf"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
aTLS "github.com/sagernet/sing/common/tls"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
func RegisterService(registry *boxService.Registry) {
|
||||
boxService.Register[option.OCMServiceOptions](registry, C.TypeOCM, NewService)
|
||||
}
|
||||
|
||||
type errorResponse struct {
|
||||
Error errorDetails `json:"error"`
|
||||
}
|
||||
|
||||
type errorDetails struct {
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code,omitempty"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
json.NewEncoder(w).Encode(errorResponse{
|
||||
Error: errorDetails{
|
||||
Type: errorType,
|
||||
Message: message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func isHopByHopHeader(header string) bool {
|
||||
switch strings.ToLower(header) {
|
||||
case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeRateLimitIdentifier(limitIdentifier string) string {
|
||||
trimmedIdentifier := strings.TrimSpace(strings.ToLower(limitIdentifier))
|
||||
if trimmedIdentifier == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.ReplaceAll(trimmedIdentifier, "_", "-")
|
||||
}
|
||||
|
||||
func parseInt64Header(headers http.Header, headerName string) (int64, bool) {
|
||||
headerValue := strings.TrimSpace(headers.Get(headerName))
|
||||
if headerValue == "" {
|
||||
return 0, false
|
||||
}
|
||||
parsedValue, parseError := strconv.ParseInt(headerValue, 10, 64)
|
||||
if parseError != nil {
|
||||
return 0, false
|
||||
}
|
||||
return parsedValue, true
|
||||
}
|
||||
|
||||
func weeklyCycleHintForLimit(headers http.Header, limitIdentifier string) *WeeklyCycleHint {
|
||||
normalizedLimitIdentifier := normalizeRateLimitIdentifier(limitIdentifier)
|
||||
if normalizedLimitIdentifier == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
windowHeader := "x-" + normalizedLimitIdentifier + "-secondary-window-minutes"
|
||||
resetHeader := "x-" + normalizedLimitIdentifier + "-secondary-reset-at"
|
||||
|
||||
windowMinutes, hasWindowMinutes := parseInt64Header(headers, windowHeader)
|
||||
resetAtUnix, hasResetAt := parseInt64Header(headers, resetHeader)
|
||||
if !hasWindowMinutes || !hasResetAt || windowMinutes <= 0 || resetAtUnix <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &WeeklyCycleHint{
|
||||
WindowMinutes: windowMinutes,
|
||||
ResetAt: time.Unix(resetAtUnix, 0).UTC(),
|
||||
}
|
||||
}
|
||||
|
||||
func extractWeeklyCycleHint(headers http.Header) *WeeklyCycleHint {
|
||||
activeLimitIdentifier := normalizeRateLimitIdentifier(headers.Get("x-codex-active-limit"))
|
||||
if activeLimitIdentifier != "" {
|
||||
if activeHint := weeklyCycleHintForLimit(headers, activeLimitIdentifier); activeHint != nil {
|
||||
return activeHint
|
||||
}
|
||||
}
|
||||
return weeklyCycleHintForLimit(headers, "codex")
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
boxService.Adapter
|
||||
ctx context.Context
|
||||
logger log.ContextLogger
|
||||
credentialPath string
|
||||
credentials *oauthCredentials
|
||||
users []option.OCMUser
|
||||
httpClient *http.Client
|
||||
httpHeaders http.Header
|
||||
listener *listener.Listener
|
||||
tlsConfig tls.ServerConfig
|
||||
httpServer *http.Server
|
||||
userManager *UserManager
|
||||
accessMutex sync.RWMutex
|
||||
usageTracker *AggregatedUsage
|
||||
trackingGroup sync.WaitGroup
|
||||
shuttingDown bool
|
||||
}
|
||||
|
||||
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OCMServiceOptions) (adapter.Service, error) {
|
||||
serviceDialer, err := dialer.NewWithOptions(dialer.Options{
|
||||
Context: ctx,
|
||||
Options: option.DialerOptions{
|
||||
Detour: options.Detour,
|
||||
},
|
||||
RemoteIsDomain: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, E.Cause(err, "create dialer")
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: &http.Transport{
|
||||
ForceAttemptHTTP2: true,
|
||||
TLSClientConfig: &stdTLS.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(ctx),
|
||||
Time: ntp.TimeFuncFromContext(ctx),
|
||||
},
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
userManager := &UserManager{
|
||||
tokenMap: make(map[string]string),
|
||||
}
|
||||
|
||||
var usageTracker *AggregatedUsage
|
||||
if options.UsagesPath != "" {
|
||||
usageTracker = &AggregatedUsage{
|
||||
LastUpdated: time.Now(),
|
||||
Combinations: make([]CostCombination, 0),
|
||||
filePath: options.UsagesPath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
service := &Service{
|
||||
Adapter: boxService.NewAdapter(C.TypeOCM, tag),
|
||||
ctx: ctx,
|
||||
logger: logger,
|
||||
credentialPath: options.CredentialPath,
|
||||
users: options.Users,
|
||||
httpClient: httpClient,
|
||||
httpHeaders: options.Headers.Build(),
|
||||
listener: listener.New(listener.Options{
|
||||
Context: ctx,
|
||||
Logger: logger,
|
||||
Network: []string{N.NetworkTCP},
|
||||
Listen: options.ListenOptions,
|
||||
}),
|
||||
userManager: userManager,
|
||||
usageTracker: usageTracker,
|
||||
}
|
||||
|
||||
if options.TLS != nil {
|
||||
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
service.tlsConfig = tlsConfig
|
||||
}
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *Service) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.userManager.UpdateUsers(s.users)
|
||||
|
||||
credentials, err := platformReadCredentials(s.credentialPath)
|
||||
if err != nil {
|
||||
return E.Cause(err, "read credentials")
|
||||
}
|
||||
s.credentials = credentials
|
||||
|
||||
if s.usageTracker != nil {
|
||||
err = s.usageTracker.Load()
|
||||
if err != nil {
|
||||
s.logger.Warn("load usage statistics: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
router.Mount("/", s)
|
||||
|
||||
s.httpServer = &http.Server{Handler: router}
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
err = s.tlsConfig.Start()
|
||||
if err != nil {
|
||||
return E.Cause(err, "create TLS config")
|
||||
}
|
||||
}
|
||||
|
||||
tcpListener, err := s.listener.ListenTCP()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.tlsConfig != nil {
|
||||
if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) {
|
||||
s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...))
|
||||
}
|
||||
tcpListener = aTLS.NewListener(tcpListener, s.tlsConfig)
|
||||
}
|
||||
|
||||
go func() {
|
||||
serveErr := s.httpServer.Serve(tcpListener)
|
||||
if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
|
||||
s.logger.Error("serve error: ", serveErr)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) getAccessToken() (string, error) {
|
||||
s.accessMutex.RLock()
|
||||
if !s.credentials.needsRefresh() {
|
||||
token := s.credentials.getAccessToken()
|
||||
s.accessMutex.RUnlock()
|
||||
return token, nil
|
||||
}
|
||||
s.accessMutex.RUnlock()
|
||||
|
||||
s.accessMutex.Lock()
|
||||
defer s.accessMutex.Unlock()
|
||||
|
||||
if !s.credentials.needsRefresh() {
|
||||
return s.credentials.getAccessToken(), nil
|
||||
}
|
||||
|
||||
newCredentials, err := refreshToken(s.httpClient, s.credentials)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
s.credentials = newCredentials
|
||||
|
||||
err = platformWriteCredentials(newCredentials, s.credentialPath)
|
||||
if err != nil {
|
||||
s.logger.Warn("persist refreshed token: ", err)
|
||||
}
|
||||
|
||||
return newCredentials.getAccessToken(), nil
|
||||
}
|
||||
|
||||
func (s *Service) getAccountID() string {
|
||||
s.accessMutex.RLock()
|
||||
defer s.accessMutex.RUnlock()
|
||||
return s.credentials.getAccountID()
|
||||
}
|
||||
|
||||
func (s *Service) isAPIKeyMode() bool {
|
||||
s.accessMutex.RLock()
|
||||
defer s.accessMutex.RUnlock()
|
||||
return s.credentials.isAPIKeyMode()
|
||||
}
|
||||
|
||||
func (s *Service) getBaseURL() string {
|
||||
if s.isAPIKeyMode() {
|
||||
return openaiAPIBaseURL
|
||||
}
|
||||
return chatGPTBackendURL
|
||||
}
|
||||
|
||||
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
path := r.URL.Path
|
||||
if !strings.HasPrefix(path, "/v1/") {
|
||||
writeJSONError(w, r, http.StatusNotFound, "invalid_request_error", "path must start with /v1/")
|
||||
return
|
||||
}
|
||||
|
||||
var proxyPath string
|
||||
if s.isAPIKeyMode() {
|
||||
proxyPath = path
|
||||
} else {
|
||||
if path == "/v1/chat/completions" {
|
||||
writeJSONError(w, r, http.StatusBadRequest, "invalid_request_error",
|
||||
"chat completions endpoint is only available in API key mode")
|
||||
return
|
||||
}
|
||||
proxyPath = strings.TrimPrefix(path, "/v1")
|
||||
}
|
||||
|
||||
var username string
|
||||
if len(s.users) > 0 {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
|
||||
return
|
||||
}
|
||||
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if clientToken == authHeader {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
|
||||
return
|
||||
}
|
||||
var ok bool
|
||||
username, ok = s.userManager.Authenticate(clientToken)
|
||||
if !ok {
|
||||
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var requestModel string
|
||||
|
||||
if s.usageTracker != nil && r.Body != nil {
|
||||
bodyBytes, err := io.ReadAll(r.Body)
|
||||
if err == nil {
|
||||
var request struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
err := json.Unmarshal(bodyBytes, &request)
|
||||
if err == nil {
|
||||
requestModel = request.Model
|
||||
}
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
accessToken, err := s.getAccessToken()
|
||||
if err != nil {
|
||||
s.logger.Error("get access token: ", err)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
|
||||
return
|
||||
}
|
||||
|
||||
proxyURL := s.getBaseURL() + proxyPath
|
||||
if r.URL.RawQuery != "" {
|
||||
proxyURL += "?" + r.URL.RawQuery
|
||||
}
|
||||
proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("create proxy request: ", err)
|
||||
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
|
||||
return
|
||||
}
|
||||
|
||||
for key, values := range r.Header {
|
||||
if !isHopByHopHeader(key) && key != "Authorization" {
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
for key, values := range s.httpHeaders {
|
||||
proxyRequest.Header.Del(key)
|
||||
proxyRequest.Header[key] = values
|
||||
}
|
||||
|
||||
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
|
||||
if accountID := s.getAccountID(); accountID != "" {
|
||||
proxyRequest.Header.Set("ChatGPT-Account-Id", accountID)
|
||||
}
|
||||
|
||||
response, err := s.httpClient.Do(proxyRequest)
|
||||
if err != nil {
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
|
||||
return
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
for key, values := range response.Header {
|
||||
if !isHopByHopHeader(key) {
|
||||
w.Header()[key] = values
|
||||
}
|
||||
}
|
||||
w.WriteHeader(response.StatusCode)
|
||||
|
||||
trackUsage := s.usageTracker != nil && response.StatusCode == http.StatusOK &&
|
||||
(path == "/v1/chat/completions" || strings.HasPrefix(path, "/v1/responses"))
|
||||
if trackUsage {
|
||||
s.handleResponseWithTracking(w, response, path, requestModel, username)
|
||||
} else {
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
if err == nil && mediaType != "text/event-stream" {
|
||||
_, _ = io.Copy(w, response.Body)
|
||||
return
|
||||
}
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
return
|
||||
}
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
_, writeError := w.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, path string, requestModel string, username string) {
|
||||
isChatCompletions := path == "/v1/chat/completions"
|
||||
weeklyCycleHint := extractWeeklyCycleHint(response.Header)
|
||||
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
|
||||
isStreaming := err == nil && mediaType == "text/event-stream"
|
||||
if !isStreaming && !isChatCompletions && response.Header.Get("Content-Type") == "" {
|
||||
isStreaming = true
|
||||
}
|
||||
if !isStreaming {
|
||||
bodyBytes, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("read response body: ", err)
|
||||
return
|
||||
}
|
||||
|
||||
var responseModel, serviceTier string
|
||||
var inputTokens, outputTokens, cachedTokens int64
|
||||
|
||||
if isChatCompletions {
|
||||
var chatCompletion openai.ChatCompletion
|
||||
if json.Unmarshal(bodyBytes, &chatCompletion) == nil {
|
||||
responseModel = chatCompletion.Model
|
||||
serviceTier = string(chatCompletion.ServiceTier)
|
||||
inputTokens = chatCompletion.Usage.PromptTokens
|
||||
outputTokens = chatCompletion.Usage.CompletionTokens
|
||||
cachedTokens = chatCompletion.Usage.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
} else {
|
||||
var responsesResponse responses.Response
|
||||
if json.Unmarshal(bodyBytes, &responsesResponse) == nil {
|
||||
responseModel = string(responsesResponse.Model)
|
||||
serviceTier = string(responsesResponse.ServiceTier)
|
||||
inputTokens = responsesResponse.Usage.InputTokens
|
||||
outputTokens = responsesResponse.Usage.OutputTokens
|
||||
cachedTokens = responsesResponse.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
}
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
if responseModel != "" {
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cachedTokens,
|
||||
serviceTier,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
_, _ = writer.Write(bodyBytes)
|
||||
return
|
||||
}
|
||||
|
||||
flusher, ok := writer.(http.Flusher)
|
||||
if !ok {
|
||||
s.logger.Error("streaming not supported")
|
||||
return
|
||||
}
|
||||
|
||||
var inputTokens, outputTokens, cachedTokens int64
|
||||
var responseModel, serviceTier string
|
||||
buffer := make([]byte, buf.BufferSize)
|
||||
var leftover []byte
|
||||
|
||||
for {
|
||||
n, err := response.Body.Read(buffer)
|
||||
if n > 0 {
|
||||
data := append(leftover, buffer[:n]...)
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
|
||||
if err == nil {
|
||||
leftover = lines[len(lines)-1]
|
||||
lines = lines[:len(lines)-1]
|
||||
} else {
|
||||
leftover = nil
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
eventData := bytes.TrimPrefix(line, []byte("data: "))
|
||||
if bytes.Equal(eventData, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
|
||||
if isChatCompletions {
|
||||
var chatChunk openai.ChatCompletionChunk
|
||||
if json.Unmarshal(eventData, &chatChunk) == nil {
|
||||
if chatChunk.Model != "" {
|
||||
responseModel = chatChunk.Model
|
||||
}
|
||||
if chatChunk.ServiceTier != "" {
|
||||
serviceTier = string(chatChunk.ServiceTier)
|
||||
}
|
||||
if chatChunk.Usage.PromptTokens > 0 {
|
||||
inputTokens = chatChunk.Usage.PromptTokens
|
||||
cachedTokens = chatChunk.Usage.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
if chatChunk.Usage.CompletionTokens > 0 {
|
||||
outputTokens = chatChunk.Usage.CompletionTokens
|
||||
}
|
||||
}
|
||||
} else {
|
||||
var streamEvent responses.ResponseStreamEventUnion
|
||||
if json.Unmarshal(eventData, &streamEvent) == nil {
|
||||
if streamEvent.Type == "response.completed" {
|
||||
completedEvent := streamEvent.AsResponseCompleted()
|
||||
if string(completedEvent.Response.Model) != "" {
|
||||
responseModel = string(completedEvent.Response.Model)
|
||||
}
|
||||
if completedEvent.Response.ServiceTier != "" {
|
||||
serviceTier = string(completedEvent.Response.ServiceTier)
|
||||
}
|
||||
if completedEvent.Response.Usage.InputTokens > 0 {
|
||||
inputTokens = completedEvent.Response.Usage.InputTokens
|
||||
cachedTokens = completedEvent.Response.Usage.InputTokensDetails.CachedTokens
|
||||
}
|
||||
if completedEvent.Response.Usage.OutputTokens > 0 {
|
||||
outputTokens = completedEvent.Response.Usage.OutputTokens
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_, writeError := writer.Write(buffer[:n])
|
||||
if writeError != nil {
|
||||
s.logger.Error("write streaming response: ", writeError)
|
||||
return
|
||||
}
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cachedTokens,
|
||||
serviceTier,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
err := common.Close(
|
||||
common.PtrOrNil(s.httpServer),
|
||||
common.PtrOrNil(s.listener),
|
||||
s.tlsConfig,
|
||||
)
|
||||
|
||||
if s.usageTracker != nil {
|
||||
s.usageTracker.cancelPendingSave()
|
||||
saveErr := s.usageTracker.Save()
|
||||
if saveErr != nil {
|
||||
s.logger.Error("save usage statistics: ", saveErr)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
1032
service/ocm/service_usage.go
Normal file
1032
service/ocm/service_usage.go
Normal file
File diff suppressed because it is too large
Load Diff
29
service/ocm/service_user.go
Normal file
29
service/ocm/service_user.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/option"
|
||||
)
|
||||
|
||||
type UserManager struct {
|
||||
accessMutex sync.RWMutex
|
||||
tokenMap map[string]string
|
||||
}
|
||||
|
||||
func (m *UserManager) UpdateUsers(users []option.OCMUser) {
|
||||
m.accessMutex.Lock()
|
||||
defer m.accessMutex.Unlock()
|
||||
tokenMap := make(map[string]string, len(users))
|
||||
for _, user := range users {
|
||||
tokenMap[user.Token] = user.Name
|
||||
}
|
||||
m.tokenMap = tokenMap
|
||||
}
|
||||
|
||||
func (m *UserManager) Authenticate(token string) (string, bool) {
|
||||
m.accessMutex.RLock()
|
||||
username, found := m.tokenMap[token]
|
||||
m.accessMutex.RUnlock()
|
||||
return username, found
|
||||
}
|
||||
51
service/oomkiller/config.go
Normal file
51
service/oomkiller/config.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package oomkiller
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
func buildTimerConfig(options option.OOMKillerServiceOptions, memoryLimit uint64, useAvailable bool) (timerConfig, error) {
|
||||
safetyMargin := uint64(defaultSafetyMargin)
|
||||
if options.SafetyMargin != nil && options.SafetyMargin.Value() > 0 {
|
||||
safetyMargin = options.SafetyMargin.Value()
|
||||
}
|
||||
|
||||
minInterval := defaultMinInterval
|
||||
if options.MinInterval != 0 {
|
||||
minInterval = time.Duration(options.MinInterval.Build())
|
||||
if minInterval <= 0 {
|
||||
return timerConfig{}, E.New("min_interval must be greater than 0")
|
||||
}
|
||||
}
|
||||
|
||||
maxInterval := defaultMaxInterval
|
||||
if options.MaxInterval != 0 {
|
||||
maxInterval = time.Duration(options.MaxInterval.Build())
|
||||
if maxInterval <= 0 {
|
||||
return timerConfig{}, E.New("max_interval must be greater than 0")
|
||||
}
|
||||
}
|
||||
if maxInterval < minInterval {
|
||||
return timerConfig{}, E.New("max_interval must be greater than or equal to min_interval")
|
||||
}
|
||||
|
||||
checksBeforeLimit := defaultChecksBeforeLimit
|
||||
if options.ChecksBeforeLimit != 0 {
|
||||
checksBeforeLimit = options.ChecksBeforeLimit
|
||||
if checksBeforeLimit <= 0 {
|
||||
return timerConfig{}, E.New("checks_before_limit must be greater than 0")
|
||||
}
|
||||
}
|
||||
|
||||
return timerConfig{
|
||||
memoryLimit: memoryLimit,
|
||||
safetyMargin: safetyMargin,
|
||||
minInterval: minInterval,
|
||||
maxInterval: maxInterval,
|
||||
checksBeforeLimit: checksBeforeLimit,
|
||||
useAvailable: useAvailable,
|
||||
}, nil
|
||||
}
|
||||
192
service/oomkiller/service.go
Normal file
192
service/oomkiller/service.go
Normal file
@@ -0,0 +1,192 @@
|
||||
//go:build darwin && cgo
|
||||
|
||||
package oomkiller
|
||||
|
||||
/*
|
||||
#include <dispatch/dispatch.h>
|
||||
|
||||
static dispatch_source_t memoryPressureSource;
|
||||
|
||||
extern void goMemoryPressureCallback(unsigned long status);
|
||||
|
||||
static void startMemoryPressureMonitor() {
|
||||
memoryPressureSource = dispatch_source_create(
|
||||
DISPATCH_SOURCE_TYPE_MEMORYPRESSURE,
|
||||
0,
|
||||
DISPATCH_MEMORYPRESSURE_WARN | DISPATCH_MEMORYPRESSURE_CRITICAL,
|
||||
dispatch_get_global_queue(QOS_CLASS_DEFAULT, 0)
|
||||
);
|
||||
dispatch_source_set_event_handler(memoryPressureSource, ^{
|
||||
unsigned long status = dispatch_source_get_data(memoryPressureSource);
|
||||
goMemoryPressureCallback(status);
|
||||
});
|
||||
dispatch_activate(memoryPressureSource);
|
||||
}
|
||||
|
||||
static void stopMemoryPressureMonitor() {
|
||||
if (memoryPressureSource) {
|
||||
dispatch_source_cancel(memoryPressureSource);
|
||||
memoryPressureSource = NULL;
|
||||
}
|
||||
}
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"context"
|
||||
runtimeDebug "runtime/debug"
|
||||
"sync"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
boxService "github.com/sagernet/sing-box/adapter/service"
|
||||
boxConstant "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common/memory"
|
||||
"github.com/sagernet/sing/service"
|
||||
)
|
||||
|
||||
func RegisterService(registry *boxService.Registry) {
|
||||
boxService.Register[option.OOMKillerServiceOptions](registry, boxConstant.TypeOOMKiller, NewService)
|
||||
}
|
||||
|
||||
var (
|
||||
globalAccess sync.Mutex
|
||||
globalServices []*Service
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
boxService.Adapter
|
||||
logger log.ContextLogger
|
||||
router adapter.Router
|
||||
memoryLimit uint64
|
||||
hasTimerMode bool
|
||||
useAvailable bool
|
||||
timerConfig timerConfig
|
||||
adaptiveTimer *adaptiveTimer
|
||||
}
|
||||
|
||||
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OOMKillerServiceOptions) (adapter.Service, error) {
|
||||
s := &Service{
|
||||
Adapter: boxService.NewAdapter(boxConstant.TypeOOMKiller, tag),
|
||||
logger: logger,
|
||||
router: service.FromContext[adapter.Router](ctx),
|
||||
}
|
||||
|
||||
if options.MemoryLimit != nil {
|
||||
s.memoryLimit = options.MemoryLimit.Value()
|
||||
if s.memoryLimit > 0 {
|
||||
s.hasTimerMode = true
|
||||
}
|
||||
}
|
||||
|
||||
config, err := buildTimerConfig(options, s.memoryLimit, s.useAvailable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.timerConfig = config
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Service) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.hasTimerMode {
|
||||
s.adaptiveTimer = newAdaptiveTimer(s.logger, s.router, s.timerConfig)
|
||||
if s.memoryLimit > 0 {
|
||||
s.logger.Info("started memory monitor with limit: ", s.memoryLimit/(1024*1024), " MiB")
|
||||
} else {
|
||||
s.logger.Info("started memory monitor with available memory detection")
|
||||
}
|
||||
} else {
|
||||
s.logger.Info("started memory pressure monitor")
|
||||
}
|
||||
|
||||
globalAccess.Lock()
|
||||
isFirst := len(globalServices) == 0
|
||||
globalServices = append(globalServices, s)
|
||||
globalAccess.Unlock()
|
||||
|
||||
if isFirst {
|
||||
C.startMemoryPressureMonitor()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
if s.adaptiveTimer != nil {
|
||||
s.adaptiveTimer.stop()
|
||||
}
|
||||
globalAccess.Lock()
|
||||
for i, svc := range globalServices {
|
||||
if svc == s {
|
||||
globalServices = append(globalServices[:i], globalServices[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
isLast := len(globalServices) == 0
|
||||
globalAccess.Unlock()
|
||||
if isLast {
|
||||
C.stopMemoryPressureMonitor()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//export goMemoryPressureCallback
|
||||
func goMemoryPressureCallback(status C.ulong) {
|
||||
globalAccess.Lock()
|
||||
services := make([]*Service, len(globalServices))
|
||||
copy(services, globalServices)
|
||||
globalAccess.Unlock()
|
||||
if len(services) == 0 {
|
||||
return
|
||||
}
|
||||
criticalFlag := C.ulong(C.DISPATCH_MEMORYPRESSURE_CRITICAL)
|
||||
warnFlag := C.ulong(C.DISPATCH_MEMORYPRESSURE_WARN)
|
||||
isCritical := status&criticalFlag != 0
|
||||
isWarning := status&warnFlag != 0
|
||||
var level string
|
||||
switch {
|
||||
case isCritical:
|
||||
level = "critical"
|
||||
case isWarning:
|
||||
level = "warning"
|
||||
default:
|
||||
level = "normal"
|
||||
}
|
||||
var freeOSMemory bool
|
||||
for _, s := range services {
|
||||
usage := memory.Total()
|
||||
if s.hasTimerMode {
|
||||
if isCritical {
|
||||
s.logger.Warn("memory pressure: ", level, ", usage: ", usage/(1024*1024), " MiB")
|
||||
if s.adaptiveTimer != nil {
|
||||
s.adaptiveTimer.startNow()
|
||||
}
|
||||
} else if isWarning {
|
||||
s.logger.Warn("memory pressure: ", level, ", usage: ", usage/(1024*1024), " MiB")
|
||||
} else {
|
||||
s.logger.Debug("memory pressure: ", level, ", usage: ", usage/(1024*1024), " MiB")
|
||||
if s.adaptiveTimer != nil {
|
||||
s.adaptiveTimer.stop()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if isCritical {
|
||||
s.logger.Error("memory pressure: ", level, ", usage: ", usage/(1024*1024), " MiB, resetting network")
|
||||
s.router.ResetNetwork()
|
||||
freeOSMemory = true
|
||||
} else if isWarning {
|
||||
s.logger.Warn("memory pressure: ", level, ", usage: ", usage/(1024*1024), " MiB")
|
||||
} else {
|
||||
s.logger.Debug("memory pressure: ", level, ", usage: ", usage/(1024*1024), " MiB")
|
||||
}
|
||||
}
|
||||
}
|
||||
if freeOSMemory {
|
||||
runtimeDebug.FreeOSMemory()
|
||||
}
|
||||
}
|
||||
81
service/oomkiller/service_stub.go
Normal file
81
service/oomkiller/service_stub.go
Normal file
@@ -0,0 +1,81 @@
|
||||
//go:build !darwin || !cgo
|
||||
|
||||
package oomkiller
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
boxService "github.com/sagernet/sing-box/adapter/service"
|
||||
boxConstant "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/memory"
|
||||
"github.com/sagernet/sing/service"
|
||||
)
|
||||
|
||||
func RegisterService(registry *boxService.Registry) {
|
||||
boxService.Register[option.OOMKillerServiceOptions](registry, boxConstant.TypeOOMKiller, NewService)
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
boxService.Adapter
|
||||
logger log.ContextLogger
|
||||
router adapter.Router
|
||||
adaptiveTimer *adaptiveTimer
|
||||
timerConfig timerConfig
|
||||
hasTimerMode bool
|
||||
useAvailable bool
|
||||
memoryLimit uint64
|
||||
}
|
||||
|
||||
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.OOMKillerServiceOptions) (adapter.Service, error) {
|
||||
s := &Service{
|
||||
Adapter: boxService.NewAdapter(boxConstant.TypeOOMKiller, tag),
|
||||
logger: logger,
|
||||
router: service.FromContext[adapter.Router](ctx),
|
||||
}
|
||||
|
||||
if options.MemoryLimit != nil {
|
||||
s.memoryLimit = options.MemoryLimit.Value()
|
||||
}
|
||||
if s.memoryLimit > 0 {
|
||||
s.hasTimerMode = true
|
||||
} else if memory.AvailableSupported() {
|
||||
s.useAvailable = true
|
||||
s.hasTimerMode = true
|
||||
}
|
||||
|
||||
config, err := buildTimerConfig(options, s.memoryLimit, s.useAvailable)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.timerConfig = config
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *Service) Start(stage adapter.StartStage) error {
|
||||
if stage != adapter.StartStateStart {
|
||||
return nil
|
||||
}
|
||||
if !s.hasTimerMode {
|
||||
return E.New("memory pressure monitoring is not available on this platform without memory_limit")
|
||||
}
|
||||
s.adaptiveTimer = newAdaptiveTimer(s.logger, s.router, s.timerConfig)
|
||||
s.adaptiveTimer.start(0)
|
||||
if s.useAvailable {
|
||||
s.logger.Info("started memory monitor with available memory detection")
|
||||
} else {
|
||||
s.logger.Info("started memory monitor with limit: ", s.memoryLimit/(1024*1024), " MiB")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
if s.adaptiveTimer != nil {
|
||||
s.adaptiveTimer.stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
158
service/oomkiller/service_timer.go
Normal file
158
service/oomkiller/service_timer.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package oomkiller
|
||||
|
||||
import (
|
||||
runtimeDebug "runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing/common/memory"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultChecksBeforeLimit = 4
|
||||
defaultMinInterval = 500 * time.Millisecond
|
||||
defaultMaxInterval = 10 * time.Second
|
||||
defaultSafetyMargin = 5 * 1024 * 1024
|
||||
)
|
||||
|
||||
type adaptiveTimer struct {
|
||||
logger log.ContextLogger
|
||||
router adapter.Router
|
||||
memoryLimit uint64
|
||||
safetyMargin uint64
|
||||
minInterval time.Duration
|
||||
maxInterval time.Duration
|
||||
checksBeforeLimit int
|
||||
useAvailable bool
|
||||
|
||||
access sync.Mutex
|
||||
timer *time.Timer
|
||||
previousUsage uint64
|
||||
lastInterval time.Duration
|
||||
}
|
||||
|
||||
type timerConfig struct {
|
||||
memoryLimit uint64
|
||||
safetyMargin uint64
|
||||
minInterval time.Duration
|
||||
maxInterval time.Duration
|
||||
checksBeforeLimit int
|
||||
useAvailable bool
|
||||
}
|
||||
|
||||
func newAdaptiveTimer(logger log.ContextLogger, router adapter.Router, config timerConfig) *adaptiveTimer {
|
||||
return &adaptiveTimer{
|
||||
logger: logger,
|
||||
router: router,
|
||||
memoryLimit: config.memoryLimit,
|
||||
safetyMargin: config.safetyMargin,
|
||||
minInterval: config.minInterval,
|
||||
maxInterval: config.maxInterval,
|
||||
checksBeforeLimit: config.checksBeforeLimit,
|
||||
useAvailable: config.useAvailable,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *adaptiveTimer) start(_ uint64) {
|
||||
t.access.Lock()
|
||||
defer t.access.Unlock()
|
||||
t.startLocked()
|
||||
}
|
||||
|
||||
func (t *adaptiveTimer) startNow() {
|
||||
t.access.Lock()
|
||||
t.startLocked()
|
||||
t.access.Unlock()
|
||||
t.poll()
|
||||
}
|
||||
|
||||
func (t *adaptiveTimer) startLocked() {
|
||||
if t.timer != nil {
|
||||
return
|
||||
}
|
||||
t.previousUsage = memory.Total()
|
||||
t.lastInterval = t.minInterval
|
||||
t.timer = time.AfterFunc(t.minInterval, t.poll)
|
||||
}
|
||||
|
||||
func (t *adaptiveTimer) stop() {
|
||||
t.access.Lock()
|
||||
defer t.access.Unlock()
|
||||
t.stopLocked()
|
||||
}
|
||||
|
||||
func (t *adaptiveTimer) stopLocked() {
|
||||
if t.timer != nil {
|
||||
t.timer.Stop()
|
||||
t.timer = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *adaptiveTimer) running() bool {
|
||||
t.access.Lock()
|
||||
defer t.access.Unlock()
|
||||
return t.timer != nil
|
||||
}
|
||||
|
||||
func (t *adaptiveTimer) poll() {
|
||||
t.access.Lock()
|
||||
defer t.access.Unlock()
|
||||
if t.timer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage := memory.Total()
|
||||
delta := int64(usage) - int64(t.previousUsage)
|
||||
t.previousUsage = usage
|
||||
|
||||
var remaining uint64
|
||||
var triggered bool
|
||||
|
||||
if t.memoryLimit > 0 {
|
||||
if usage >= t.memoryLimit {
|
||||
remaining = 0
|
||||
triggered = true
|
||||
} else {
|
||||
remaining = t.memoryLimit - usage
|
||||
}
|
||||
} else if t.useAvailable {
|
||||
available := memory.Available()
|
||||
if available <= t.safetyMargin {
|
||||
remaining = 0
|
||||
triggered = true
|
||||
} else {
|
||||
remaining = available - t.safetyMargin
|
||||
}
|
||||
} else {
|
||||
remaining = 0
|
||||
}
|
||||
|
||||
if triggered {
|
||||
t.logger.Error("memory threshold reached, usage: ", usage/(1024*1024), " MiB, resetting network")
|
||||
t.router.ResetNetwork()
|
||||
runtimeDebug.FreeOSMemory()
|
||||
}
|
||||
|
||||
var interval time.Duration
|
||||
if triggered {
|
||||
interval = t.maxInterval
|
||||
} else if delta <= 0 {
|
||||
interval = t.maxInterval
|
||||
} else if t.checksBeforeLimit <= 0 {
|
||||
interval = t.maxInterval
|
||||
} else {
|
||||
timeToLimit := time.Duration(float64(remaining) / float64(delta) * float64(t.lastInterval))
|
||||
interval = timeToLimit / time.Duration(t.checksBeforeLimit)
|
||||
if interval < t.minInterval {
|
||||
interval = t.minInterval
|
||||
}
|
||||
if interval > t.maxInterval {
|
||||
interval = t.maxInterval
|
||||
}
|
||||
}
|
||||
|
||||
t.lastInterval = interval
|
||||
t.timer.Reset(interval)
|
||||
}
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"syscall"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
"github.com/sagernet/sing-box/common/process"
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/dns"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
@@ -111,7 +110,7 @@ func (t *resolve1Manager) createMetadata(sender dbus.Sender) adapter.InboundCont
|
||||
if err != nil {
|
||||
return metadata
|
||||
}
|
||||
var processInfo process.Info
|
||||
var processInfo adapter.ConnectionOwner
|
||||
metadata.ProcessInfo = &processInfo
|
||||
processInfo.ProcessID = uint32(senderPid)
|
||||
|
||||
@@ -140,7 +139,7 @@ func (t *resolve1Manager) createMetadata(sender dbus.Sender) adapter.InboundCont
|
||||
processInfo.UserId = int32(uid)
|
||||
uidFound = true
|
||||
if osUser, _ := user.LookupId(F.ToString(uid)); osUser != nil {
|
||||
processInfo.User = osUser.Username
|
||||
processInfo.UserName = osUser.Username
|
||||
}
|
||||
break
|
||||
}
|
||||
@@ -159,8 +158,8 @@ func (t *resolve1Manager) log(sender dbus.Sender, message ...any) {
|
||||
var prefix string
|
||||
if metadata.ProcessInfo.ProcessPath != "" {
|
||||
prefix = filepath.Base(metadata.ProcessInfo.ProcessPath)
|
||||
} else if metadata.ProcessInfo.User != "" {
|
||||
prefix = F.ToString("user:", metadata.ProcessInfo.User)
|
||||
} else if metadata.ProcessInfo.UserName != "" {
|
||||
prefix = F.ToString("user:", metadata.ProcessInfo.UserName)
|
||||
} else if metadata.ProcessInfo.UserId != 0 {
|
||||
prefix = F.ToString("uid:", metadata.ProcessInfo.UserId)
|
||||
}
|
||||
@@ -177,8 +176,8 @@ func (t *resolve1Manager) logRequest(sender dbus.Sender, message ...any) context
|
||||
var prefix string
|
||||
if metadata.ProcessInfo.ProcessPath != "" {
|
||||
prefix = filepath.Base(metadata.ProcessInfo.ProcessPath)
|
||||
} else if metadata.ProcessInfo.User != "" {
|
||||
prefix = F.ToString("user:", metadata.ProcessInfo.User)
|
||||
} else if metadata.ProcessInfo.UserName != "" {
|
||||
prefix = F.ToString("user:", metadata.ProcessInfo.UserName)
|
||||
} else if metadata.ProcessInfo.UserId != 0 {
|
||||
prefix = F.ToString("uid:", metadata.ProcessInfo.UserId)
|
||||
}
|
||||
|
||||
@@ -110,6 +110,16 @@ func (t *Transport) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *Transport) Reset() {
|
||||
t.linkAccess.RLock()
|
||||
defer t.linkAccess.RUnlock()
|
||||
for _, servers := range t.linkServers {
|
||||
for _, server := range servers.Servers {
|
||||
server.Reset()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Transport) updateTransports(link *TransportLink) error {
|
||||
t.linkAccess.Lock()
|
||||
defer t.linkAccess.Unlock()
|
||||
@@ -129,7 +139,7 @@ func (t *Transport) updateTransports(link *TransportLink) error {
|
||||
return os.ErrInvalid
|
||||
}
|
||||
if link.dnsOverTLS {
|
||||
tlsConfig := common.Must1(tls.NewClient(t.ctx, serverAddr.String(), option.OutboundTLSOptions{
|
||||
tlsConfig := common.Must1(tls.NewClient(t.ctx, t.logger, serverAddr.String(), option.OutboundTLSOptions{
|
||||
Enabled: true,
|
||||
ServerName: serverAddr.String(),
|
||||
}))
|
||||
@@ -151,7 +161,7 @@ func (t *Transport) updateTransports(link *TransportLink) error {
|
||||
} else {
|
||||
serverName = serverAddr.String()
|
||||
}
|
||||
tlsConfig := common.Must1(tls.NewClient(t.ctx, serverAddr.String(), option.OutboundTLSOptions{
|
||||
tlsConfig := common.Must1(tls.NewClient(t.ctx, t.logger, serverAddr.String(), option.OutboundTLSOptions{
|
||||
Enabled: true,
|
||||
ServerName: serverName,
|
||||
}))
|
||||
|
||||
Reference in New Issue
Block a user