Add claude code multiplexer service

This commit is contained in:
世界
2025-10-21 10:34:05 +08:00
parent d87c9fd242
commit 0f5cda4169
30 changed files with 1807 additions and 369 deletions

136
service/ccm/credential.go Normal file
View File

@@ -0,0 +1,136 @@
package ccm
import (
"bytes"
"encoding/json"
"io"
"net/http"
"os"
"os/user"
"path/filepath"
"time"
E "github.com/sagernet/sing/common/exceptions"
)
const (
oauth2ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
oauth2TokenURL = "https://console.anthropic.com/v1/oauth/token"
claudeAPIBaseURL = "https://api.anthropic.com"
tokenRefreshBufferMs = 60000
anthropicBetaOAuthValue = "oauth-2025-04-20"
)
func getRealUser() (*user.User, error) {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
sudoUserInfo, err := user.Lookup(sudoUser)
if err == nil {
return sudoUserInfo, nil
}
}
return user.Current()
}
func getDefaultCredentialsPath() (string, error) {
userInfo, err := getRealUser()
if err != nil {
return "", err
}
return filepath.Join(userInfo.HomeDir, ".claude", ".credentials.json"), nil
}
func readCredentialsFromFile(path string) (*oauthCredentials, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var credentialsContainer struct {
ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"`
}
err = json.Unmarshal(data, &credentialsContainer)
if err != nil {
return nil, err
}
if credentialsContainer.ClaudeAIAuth == nil {
return nil, E.New("claudeAiOauth field not found in credentials")
}
return credentialsContainer.ClaudeAIAuth, nil
}
func writeCredentialsToFile(oauthCredentials *oauthCredentials, path string) error {
data, err := json.MarshalIndent(map[string]any{
"claudeAiOauth": oauthCredentials,
}, "", " ")
if err != nil {
return err
}
return os.WriteFile(path, data, 0o600)
}
type oauthCredentials struct {
AccessToken string `json:"accessToken"`
RefreshToken string `json:"refreshToken"`
ExpiresAt int64 `json:"expiresAt"`
Scopes []string `json:"scopes,omitempty"`
SubscriptionType string `json:"subscriptionType,omitempty"`
IsMax bool `json:"isMax,omitempty"`
}
func (c *oauthCredentials) needsRefresh() bool {
if c.ExpiresAt == 0 {
return false
}
return time.Now().UnixMilli() >= c.ExpiresAt-tokenRefreshBufferMs
}
func refreshToken(httpClient *http.Client, credentials *oauthCredentials) (*oauthCredentials, error) {
if credentials.RefreshToken == "" {
return nil, E.New("refresh token is empty")
}
requestBody, err := json.Marshal(map[string]string{
"grant_type": "refresh_token",
"refresh_token": credentials.RefreshToken,
"client_id": oauth2ClientID,
})
if err != nil {
return nil, E.Cause(err, "marshal request")
}
request, err := http.NewRequest("POST", oauth2TokenURL, bytes.NewReader(requestBody))
if err != nil {
return nil, err
}
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json")
response, err := httpClient.Do(request)
if err != nil {
return nil, err
}
defer response.Body.Close()
if response.StatusCode != http.StatusOK {
body, _ := io.ReadAll(response.Body)
return nil, E.New("refresh failed: ", response.Status, " ", string(body))
}
var tokenResponse struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
}
err = json.NewDecoder(response.Body).Decode(&tokenResponse)
if err != nil {
return nil, E.Cause(err, "decode response")
}
newCredentials := *credentials
newCredentials.AccessToken = tokenResponse.AccessToken
if tokenResponse.RefreshToken != "" {
newCredentials.RefreshToken = tokenResponse.RefreshToken
}
newCredentials.ExpiresAt = time.Now().UnixMilli() + int64(tokenResponse.ExpiresIn)*1000
return &newCredentials, nil
}

View File

@@ -0,0 +1,116 @@
//go:build darwin && cgo
package ccm
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"os"
"path/filepath"
E "github.com/sagernet/sing/common/exceptions"
"github.com/keybase/go-keychain"
)
func getKeychainServiceName() string {
configDirectory := os.Getenv("CLAUDE_CONFIG_DIR")
if configDirectory == "" {
return "Claude Code-credentials"
}
userInfo, err := getRealUser()
if err != nil {
return "Claude Code-credentials"
}
defaultConfigDirectory := filepath.Join(userInfo.HomeDir, ".claude")
if configDirectory == defaultConfigDirectory {
return "Claude Code-credentials"
}
hash := sha256.Sum256([]byte(configDirectory))
return "Claude Code-credentials-" + hex.EncodeToString(hash[:])[:8]
}
func platformReadCredentials(customPath string) (*oauthCredentials, error) {
if customPath != "" {
return readCredentialsFromFile(customPath)
}
userInfo, err := getRealUser()
if err == nil {
query := keychain.NewItem()
query.SetSecClass(keychain.SecClassGenericPassword)
query.SetService(getKeychainServiceName())
query.SetAccount(userInfo.Username)
query.SetMatchLimit(keychain.MatchLimitOne)
query.SetReturnData(true)
results, err := keychain.QueryItem(query)
if err == nil && len(results) == 1 {
var container struct {
ClaudeAIAuth *oauthCredentials `json:"claudeAiOauth,omitempty"`
}
unmarshalErr := json.Unmarshal(results[0].Data, &container)
if unmarshalErr == nil && container.ClaudeAIAuth != nil {
return container.ClaudeAIAuth, nil
}
}
if err != nil && err != keychain.ErrorItemNotFound {
return nil, E.Cause(err, "query keychain")
}
}
defaultPath, err := getDefaultCredentialsPath()
if err != nil {
return nil, err
}
return readCredentialsFromFile(defaultPath)
}
func platformWriteCredentials(oauthCredentials *oauthCredentials, customPath string) error {
if customPath != "" {
return writeCredentialsToFile(oauthCredentials, customPath)
}
userInfo, err := getRealUser()
if err == nil {
data, err := json.Marshal(map[string]any{"claudeAiOauth": oauthCredentials})
if err == nil {
serviceName := getKeychainServiceName()
item := keychain.NewItem()
item.SetSecClass(keychain.SecClassGenericPassword)
item.SetService(serviceName)
item.SetAccount(userInfo.Username)
item.SetData(data)
item.SetAccessible(keychain.AccessibleWhenUnlocked)
err = keychain.AddItem(item)
if err == nil {
return nil
}
if err == keychain.ErrorDuplicateItem {
query := keychain.NewItem()
query.SetSecClass(keychain.SecClassGenericPassword)
query.SetService(serviceName)
query.SetAccount(userInfo.Username)
updateItem := keychain.NewItem()
updateItem.SetData(data)
updateErr := keychain.UpdateItem(query, updateItem)
if updateErr == nil {
return nil
}
}
}
}
defaultPath, err := getDefaultCredentialsPath()
if err != nil {
return err
}
return writeCredentialsToFile(oauthCredentials, defaultPath)
}

View File

@@ -0,0 +1,25 @@
//go:build !darwin
package ccm
func platformReadCredentials(customPath string) (*oauthCredentials, error) {
if customPath == "" {
var err error
customPath, err = getDefaultCredentialsPath()
if err != nil {
return nil, err
}
}
return readCredentialsFromFile(customPath)
}
func platformWriteCredentials(oauthCredentials *oauthCredentials, customPath string) error {
if customPath == "" {
var err error
customPath, err = getDefaultCredentialsPath()
if err != nil {
return err
}
}
return writeCredentialsToFile(oauthCredentials, customPath)
}

541
service/ccm/service.go Normal file
View File

@@ -0,0 +1,541 @@
package ccm
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"mime"
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
boxService "github.com/sagernet/sing-box/adapter/service"
"github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/common/listener"
"github.com/sagernet/sing-box/common/tls"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
aTLS "github.com/sagernet/sing/common/tls"
"github.com/anthropics/anthropic-sdk-go"
"github.com/go-chi/chi/v5"
"golang.org/x/net/http2"
)
const (
contextWindowStandard = 200000
contextWindowPremium = 1000000
premiumContextThreshold = 200000
)
func RegisterService(registry *boxService.Registry) {
boxService.Register[option.CCMServiceOptions](registry, C.TypeCCM, NewService)
}
type errorResponse struct {
Type string `json:"type"`
Error errorDetails `json:"error"`
RequestID string `json:"request_id,omitempty"`
}
type errorDetails struct {
Type string `json:"type"`
Message string `json:"message"`
}
func writeJSONError(w http.ResponseWriter, r *http.Request, statusCode int, errorType string, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(statusCode)
json.NewEncoder(w).Encode(errorResponse{
Type: "error",
Error: errorDetails{
Type: errorType,
Message: message,
},
RequestID: r.Header.Get("Request-Id"),
})
}
func isHopByHopHeader(header string) bool {
switch strings.ToLower(header) {
case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailers", "transfer-encoding", "upgrade", "host":
return true
default:
return false
}
}
type Service struct {
boxService.Adapter
ctx context.Context
logger log.ContextLogger
credentialPath string
credentials *oauthCredentials
users []option.CCMUser
httpClient *http.Client
httpHeaders http.Header
listener *listener.Listener
tlsConfig tls.ServerConfig
httpServer *http.Server
userManager *UserManager
accessMutex sync.RWMutex
usageTracker *AggregatedUsage
trackingGroup sync.WaitGroup
shuttingDown bool
}
func NewService(ctx context.Context, logger log.ContextLogger, tag string, options option.CCMServiceOptions) (adapter.Service, error) {
serviceDialer, err := dialer.NewWithOptions(dialer.Options{
Context: ctx,
Options: option.DialerOptions{
Detour: options.Detour,
},
RemoteIsDomain: true,
})
if err != nil {
return nil, E.Cause(err, "create dialer")
}
httpClient := &http.Client{
Transport: &http.Transport{
ForceAttemptHTTP2: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return serviceDialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
},
},
}
userManager := &UserManager{
tokenMap: make(map[string]string),
}
var usageTracker *AggregatedUsage
if options.UsagesPath != "" {
usageTracker = &AggregatedUsage{
LastUpdated: time.Now(),
Combinations: make([]CostCombination, 0),
filePath: options.UsagesPath,
logger: logger,
}
}
service := &Service{
Adapter: boxService.NewAdapter(C.TypeCCM, tag),
ctx: ctx,
logger: logger,
credentialPath: options.CredentialPath,
users: options.Users,
httpClient: httpClient,
httpHeaders: options.Headers.Build(),
listener: listener.New(listener.Options{
Context: ctx,
Logger: logger,
Network: []string{N.NetworkTCP},
Listen: options.ListenOptions,
}),
userManager: userManager,
usageTracker: usageTracker,
}
if options.TLS != nil {
tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS))
if err != nil {
return nil, err
}
service.tlsConfig = tlsConfig
}
return service, nil
}
func (s *Service) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
s.userManager.UpdateUsers(s.users)
credentials, err := platformReadCredentials(s.credentialPath)
if err != nil {
return E.Cause(err, "read credentials")
}
s.credentials = credentials
if s.usageTracker != nil {
err = s.usageTracker.Load()
if err != nil {
s.logger.Warn("load usage statistics: ", err)
}
}
router := chi.NewRouter()
router.Mount("/", s)
s.httpServer = &http.Server{Handler: router}
if s.tlsConfig != nil {
err = s.tlsConfig.Start()
if err != nil {
return E.Cause(err, "create TLS config")
}
}
tcpListener, err := s.listener.ListenTCP()
if err != nil {
return err
}
if s.tlsConfig != nil {
if !common.Contains(s.tlsConfig.NextProtos(), http2.NextProtoTLS) {
s.tlsConfig.SetNextProtos(append([]string{"h2"}, s.tlsConfig.NextProtos()...))
}
tcpListener = aTLS.NewListener(tcpListener, s.tlsConfig)
}
go func() {
serveErr := s.httpServer.Serve(tcpListener)
if serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
s.logger.Error("serve error: ", serveErr)
}
}()
return nil
}
func (s *Service) getAccessToken() (string, error) {
s.accessMutex.RLock()
if !s.credentials.needsRefresh() {
token := s.credentials.AccessToken
s.accessMutex.RUnlock()
return token, nil
}
s.accessMutex.RUnlock()
s.accessMutex.Lock()
defer s.accessMutex.Unlock()
if !s.credentials.needsRefresh() {
return s.credentials.AccessToken, nil
}
newCredentials, err := refreshToken(s.httpClient, s.credentials)
if err != nil {
return "", err
}
s.credentials = newCredentials
err = platformWriteCredentials(newCredentials, s.credentialPath)
if err != nil {
s.logger.Warn("persist refreshed token: ", err)
}
return newCredentials.AccessToken, nil
}
func detectContextWindow(betaHeader string, inputTokens int64) int {
if inputTokens > premiumContextThreshold {
features := strings.Split(betaHeader, ",")
for _, feature := range features {
if strings.TrimSpace(feature) == "context-1m" {
return contextWindowPremium
}
}
}
return contextWindowStandard
}
func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if !strings.HasPrefix(r.URL.Path, "/v1/") {
writeJSONError(w, r, http.StatusNotFound, "not_found_error", "Not found")
return
}
var username string
if len(s.users) > 0 {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": missing Authorization header")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "missing api key")
return
}
clientToken := strings.TrimPrefix(authHeader, "Bearer ")
if clientToken == authHeader {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": invalid Authorization format")
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key format")
return
}
var ok bool
username, ok = s.userManager.Authenticate(clientToken)
if !ok {
s.logger.Warn("authentication failed for request from ", r.RemoteAddr, ": unknown key: ", clientToken)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "invalid api key")
return
}
}
var requestModel string
var messagesCount int
if s.usageTracker != nil && r.Body != nil {
bodyBytes, err := io.ReadAll(r.Body)
if err == nil {
var request struct {
Model string `json:"model"`
Messages []anthropic.MessageParam `json:"messages"`
}
err := json.Unmarshal(bodyBytes, &request)
if err == nil {
requestModel = request.Model
messagesCount = len(request.Messages)
}
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}
}
accessToken, err := s.getAccessToken()
if err != nil {
s.logger.Error("get access token: ", err)
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "Authentication failed")
return
}
proxyURL := claudeAPIBaseURL + r.URL.RequestURI()
proxyRequest, err := http.NewRequestWithContext(r.Context(), r.Method, proxyURL, r.Body)
if err != nil {
s.logger.Error("create proxy request: ", err)
writeJSONError(w, r, http.StatusInternalServerError, "api_error", "Internal server error")
return
}
for key, values := range r.Header {
if !isHopByHopHeader(key) && key != "Authorization" {
proxyRequest.Header[key] = values
}
}
anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta")
if anthropicBetaHeader != "" {
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader)
} else {
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue)
}
for key, values := range s.httpHeaders {
proxyRequest.Header.Del(key)
proxyRequest.Header[key] = values
}
proxyRequest.Header.Set("Authorization", "Bearer "+accessToken)
response, err := s.httpClient.Do(proxyRequest)
if err != nil {
writeJSONError(w, r, http.StatusBadGateway, "api_error", err.Error())
return
}
defer response.Body.Close()
for key, values := range response.Header {
if !isHopByHopHeader(key) {
w.Header()[key] = values
}
}
w.WriteHeader(response.StatusCode)
if s.usageTracker != nil && response.StatusCode == http.StatusOK {
s.handleResponseWithTracking(w, response, requestModel, anthropicBetaHeader, messagesCount, username)
} else {
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
if err == nil && mediaType != "text/event-stream" {
_, _ = io.Copy(w, response.Body)
return
}
flusher, ok := w.(http.Flusher)
if !ok {
s.logger.Error("streaming not supported")
return
}
buffer := make([]byte, buf.BufferSize)
for {
n, err := response.Body.Read(buffer)
if n > 0 {
_, writeError := w.Write(buffer[:n])
if writeError != nil {
s.logger.Error("write streaming response: ", writeError)
return
}
flusher.Flush()
}
if err != nil {
return
}
}
}
}
func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, response *http.Response, requestModel string, anthropicBetaHeader string, messagesCount int, username string) {
mediaType, _, err := mime.ParseMediaType(response.Header.Get("Content-Type"))
isStreaming := err == nil && mediaType == "text/event-stream"
if !isStreaming {
bodyBytes, err := io.ReadAll(response.Body)
if err != nil {
s.logger.Error("read response body: ", err)
return
}
var message anthropic.Message
var usage anthropic.Usage
var responseModel string
err = json.Unmarshal(bodyBytes, &message)
if err == nil {
responseModel = string(message.Model)
usage = message.Usage
}
if responseModel == "" {
responseModel = requestModel
}
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
if responseModel != "" {
contextWindow := detectContextWindow(anthropicBetaHeader, usage.InputTokens)
s.usageTracker.AddUsage(
responseModel,
contextWindow,
messagesCount,
usage.InputTokens,
usage.OutputTokens,
usage.CacheReadInputTokens,
usage.CacheCreationInputTokens,
username,
)
}
}
_, _ = writer.Write(bodyBytes)
return
}
flusher, ok := writer.(http.Flusher)
if !ok {
s.logger.Error("streaming not supported")
return
}
var accumulatedUsage anthropic.Usage
var responseModel string
buffer := make([]byte, buf.BufferSize)
var leftover []byte
for {
n, err := response.Body.Read(buffer)
if n > 0 {
data := append(leftover, buffer[:n]...)
lines := bytes.Split(data, []byte("\n"))
if err == nil {
leftover = lines[len(lines)-1]
lines = lines[:len(lines)-1]
} else {
leftover = nil
}
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 {
continue
}
if bytes.HasPrefix(line, []byte("data: ")) {
eventData := bytes.TrimPrefix(line, []byte("data: "))
if bytes.Equal(eventData, []byte("[DONE]")) {
continue
}
var event anthropic.MessageStreamEventUnion
err := json.Unmarshal(eventData, &event)
if err != nil {
continue
}
switch event.Type {
case "message_start":
messageStart := event.AsMessageStart()
if messageStart.Message.Model != "" {
responseModel = string(messageStart.Message.Model)
}
if messageStart.Message.Usage.InputTokens > 0 {
accumulatedUsage.InputTokens = messageStart.Message.Usage.InputTokens
accumulatedUsage.CacheReadInputTokens = messageStart.Message.Usage.CacheReadInputTokens
accumulatedUsage.CacheCreationInputTokens = messageStart.Message.Usage.CacheCreationInputTokens
}
case "message_delta":
messageDelta := event.AsMessageDelta()
if messageDelta.Usage.OutputTokens > 0 {
accumulatedUsage.OutputTokens = messageDelta.Usage.OutputTokens
}
}
}
}
_, writeError := writer.Write(buffer[:n])
if writeError != nil {
s.logger.Error("write streaming response: ", writeError)
return
}
flusher.Flush()
}
if err != nil {
if responseModel == "" {
responseModel = requestModel
}
if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 {
if responseModel != "" {
contextWindow := detectContextWindow(anthropicBetaHeader, accumulatedUsage.InputTokens)
s.usageTracker.AddUsage(
responseModel,
contextWindow,
messagesCount,
accumulatedUsage.InputTokens,
accumulatedUsage.OutputTokens,
accumulatedUsage.CacheReadInputTokens,
accumulatedUsage.CacheCreationInputTokens,
username,
)
}
}
return
}
}
}
func (s *Service) Close() error {
err := common.Close(
common.PtrOrNil(s.httpServer),
common.PtrOrNil(s.listener),
s.tlsConfig,
)
if s.usageTracker != nil {
s.usageTracker.cancelPendingSave()
saveErr := s.usageTracker.Save()
if saveErr != nil {
s.logger.Error("save usage statistics: ", saveErr)
}
}
return err
}

View File

@@ -0,0 +1,407 @@
package ccm
import (
"encoding/json"
"math"
"os"
"regexp"
"sync"
"time"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
)
type UsageStats struct {
RequestCount int `json:"request_count"`
MessagesCount int `json:"messages_count"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheReadInputTokens int64 `json:"cache_read_input_tokens"`
CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"`
}
type CostCombination struct {
Model string `json:"model"`
ContextWindow int `json:"context_window"`
Total UsageStats `json:"total"`
ByUser map[string]UsageStats `json:"by_user"`
}
type AggregatedUsage struct {
LastUpdated time.Time `json:"last_updated"`
Combinations []CostCombination `json:"combinations"`
mutex sync.Mutex
filePath string
logger log.ContextLogger
lastSaveTime time.Time
pendingSave bool
saveTimer *time.Timer
saveMutex sync.Mutex
}
type UsageStatsJSON struct {
RequestCount int `json:"request_count"`
MessagesCount int `json:"messages_count"`
InputTokens int64 `json:"input_tokens"`
OutputTokens int64 `json:"output_tokens"`
CacheReadInputTokens int64 `json:"cache_read_input_tokens"`
CacheCreationInputTokens int64 `json:"cache_creation_input_tokens"`
CostUSD float64 `json:"cost_usd"`
}
type CostCombinationJSON struct {
Model string `json:"model"`
ContextWindow int `json:"context_window"`
Total UsageStatsJSON `json:"total"`
ByUser map[string]UsageStatsJSON `json:"by_user"`
}
type CostsSummaryJSON struct {
TotalUSD float64 `json:"total_usd"`
ByUser map[string]float64 `json:"by_user"`
}
type AggregatedUsageJSON struct {
LastUpdated time.Time `json:"last_updated"`
Costs CostsSummaryJSON `json:"costs"`
Combinations []CostCombinationJSON `json:"combinations"`
}
type ModelPricing struct {
InputPrice float64
OutputPrice float64
CacheReadPrice float64
CacheWritePrice float64
}
type modelFamily struct {
pattern *regexp.Regexp
standardPricing ModelPricing
premiumPricing *ModelPricing
}
var (
opus4Pricing = ModelPricing{
InputPrice: 15.0,
OutputPrice: 75.0,
CacheReadPrice: 1.5,
CacheWritePrice: 18.75,
}
sonnet4StandardPricing = ModelPricing{
InputPrice: 3.0,
OutputPrice: 15.0,
CacheReadPrice: 0.3,
CacheWritePrice: 3.75,
}
sonnet4PremiumPricing = ModelPricing{
InputPrice: 6.0,
OutputPrice: 22.5,
CacheReadPrice: 0.6,
CacheWritePrice: 7.5,
}
haiku4Pricing = ModelPricing{
InputPrice: 1.0,
OutputPrice: 5.0,
CacheReadPrice: 0.1,
CacheWritePrice: 1.25,
}
haiku35Pricing = ModelPricing{
InputPrice: 0.8,
OutputPrice: 4.0,
CacheReadPrice: 0.08,
CacheWritePrice: 1.0,
}
sonnet35Pricing = ModelPricing{
InputPrice: 3.0,
OutputPrice: 15.0,
CacheReadPrice: 0.3,
CacheWritePrice: 3.75,
}
modelFamilies = []modelFamily{
{
pattern: regexp.MustCompile(`^claude-(?:opus-4-|4-opus-|opus-4-1-)`),
standardPricing: opus4Pricing,
premiumPricing: nil,
},
{
pattern: regexp.MustCompile(`^claude-3-7-sonnet-`),
standardPricing: sonnet4StandardPricing,
premiumPricing: &sonnet4PremiumPricing,
},
{
pattern: regexp.MustCompile(`^claude-(?:sonnet-4-|4-sonnet-)`),
standardPricing: sonnet4StandardPricing,
premiumPricing: &sonnet4PremiumPricing,
},
{
pattern: regexp.MustCompile(`^claude-haiku-4-`),
standardPricing: haiku4Pricing,
premiumPricing: nil,
},
{
pattern: regexp.MustCompile(`^claude-3-5-haiku-`),
standardPricing: haiku35Pricing,
premiumPricing: nil,
},
{
pattern: regexp.MustCompile(`^claude-3-5-sonnet-`),
standardPricing: sonnet35Pricing,
premiumPricing: nil,
},
}
)
func getPricing(model string, contextWindow int) ModelPricing {
isPremium := contextWindow >= contextWindowPremium
for _, family := range modelFamilies {
if family.pattern.MatchString(model) {
if isPremium && family.premiumPricing != nil {
return *family.premiumPricing
}
return family.standardPricing
}
}
return sonnet4StandardPricing
}
func calculateCost(stats UsageStats, model string, contextWindow int) float64 {
pricing := getPricing(model, contextWindow)
cost := (float64(stats.InputTokens)*pricing.InputPrice +
float64(stats.OutputTokens)*pricing.OutputPrice +
float64(stats.CacheReadInputTokens)*pricing.CacheReadPrice +
float64(stats.CacheCreationInputTokens)*pricing.CacheWritePrice) / 1_000_000
return math.Round(cost*100) / 100
}
func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON {
u.mutex.Lock()
defer u.mutex.Unlock()
result := &AggregatedUsageJSON{
LastUpdated: u.LastUpdated,
Combinations: make([]CostCombinationJSON, len(u.Combinations)),
Costs: CostsSummaryJSON{
TotalUSD: 0,
ByUser: make(map[string]float64),
},
}
for i, combo := range u.Combinations {
totalCost := calculateCost(combo.Total, combo.Model, combo.ContextWindow)
result.Costs.TotalUSD += totalCost
comboJSON := CostCombinationJSON{
Model: combo.Model,
ContextWindow: combo.ContextWindow,
Total: UsageStatsJSON{
RequestCount: combo.Total.RequestCount,
MessagesCount: combo.Total.MessagesCount,
InputTokens: combo.Total.InputTokens,
OutputTokens: combo.Total.OutputTokens,
CacheReadInputTokens: combo.Total.CacheReadInputTokens,
CacheCreationInputTokens: combo.Total.CacheCreationInputTokens,
CostUSD: totalCost,
},
ByUser: make(map[string]UsageStatsJSON),
}
for user, userStats := range combo.ByUser {
userCost := calculateCost(userStats, combo.Model, combo.ContextWindow)
result.Costs.ByUser[user] += userCost
comboJSON.ByUser[user] = UsageStatsJSON{
RequestCount: userStats.RequestCount,
MessagesCount: userStats.MessagesCount,
InputTokens: userStats.InputTokens,
OutputTokens: userStats.OutputTokens,
CacheReadInputTokens: userStats.CacheReadInputTokens,
CacheCreationInputTokens: userStats.CacheCreationInputTokens,
CostUSD: userCost,
}
}
result.Combinations[i] = comboJSON
}
result.Costs.TotalUSD = math.Round(result.Costs.TotalUSD*100) / 100
for user, cost := range result.Costs.ByUser {
result.Costs.ByUser[user] = math.Round(cost*100) / 100
}
return result
}
func (u *AggregatedUsage) Load() error {
u.mutex.Lock()
defer u.mutex.Unlock()
data, err := os.ReadFile(u.filePath)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
var temp struct {
LastUpdated time.Time `json:"last_updated"`
Combinations []CostCombination `json:"combinations"`
}
err = json.Unmarshal(data, &temp)
if err != nil {
return err
}
u.LastUpdated = temp.LastUpdated
u.Combinations = temp.Combinations
for i := range u.Combinations {
if u.Combinations[i].ByUser == nil {
u.Combinations[i].ByUser = make(map[string]UsageStats)
}
}
return nil
}
func (u *AggregatedUsage) Save() error {
jsonData := u.ToJSON()
data, err := json.MarshalIndent(jsonData, "", " ")
if err != nil {
return err
}
tmpFile := u.filePath + ".tmp"
err = os.WriteFile(tmpFile, data, 0o644)
if err != nil {
return err
}
defer os.Remove(tmpFile)
err = os.Rename(tmpFile, u.filePath)
if err == nil {
u.saveMutex.Lock()
u.lastSaveTime = time.Now()
u.saveMutex.Unlock()
}
return err
}
func (u *AggregatedUsage) AddUsage(model string, contextWindow int, messagesCount int, inputTokens, outputTokens, cacheReadTokens, cacheCreationTokens int64, user string) error {
if model == "" {
return E.New("model cannot be empty")
}
if contextWindow <= 0 {
return E.New("contextWindow must be positive")
}
u.mutex.Lock()
defer u.mutex.Unlock()
u.LastUpdated = time.Now()
// Find or create combination
var combo *CostCombination
for i := range u.Combinations {
if u.Combinations[i].Model == model && u.Combinations[i].ContextWindow == contextWindow {
combo = &u.Combinations[i]
break
}
}
if combo == nil {
newCombo := CostCombination{
Model: model,
ContextWindow: contextWindow,
Total: UsageStats{},
ByUser: make(map[string]UsageStats),
}
u.Combinations = append(u.Combinations, newCombo)
combo = &u.Combinations[len(u.Combinations)-1]
}
// Update total stats
combo.Total.RequestCount++
combo.Total.MessagesCount += messagesCount
combo.Total.InputTokens += inputTokens
combo.Total.OutputTokens += outputTokens
combo.Total.CacheReadInputTokens += cacheReadTokens
combo.Total.CacheCreationInputTokens += cacheCreationTokens
// Update per-user stats if user is specified
if user != "" {
userStats := combo.ByUser[user]
userStats.RequestCount++
userStats.MessagesCount += messagesCount
userStats.InputTokens += inputTokens
userStats.OutputTokens += outputTokens
userStats.CacheReadInputTokens += cacheReadTokens
userStats.CacheCreationInputTokens += cacheCreationTokens
combo.ByUser[user] = userStats
}
go u.scheduleSave()
return nil
}
func (u *AggregatedUsage) scheduleSave() {
const saveInterval = time.Minute
u.saveMutex.Lock()
defer u.saveMutex.Unlock()
timeSinceLastSave := time.Since(u.lastSaveTime)
if timeSinceLastSave >= saveInterval {
go u.saveAsync()
return
}
if u.pendingSave {
return
}
u.pendingSave = true
remainingTime := saveInterval - timeSinceLastSave
u.saveTimer = time.AfterFunc(remainingTime, func() {
u.saveMutex.Lock()
u.pendingSave = false
u.saveMutex.Unlock()
u.saveAsync()
})
}
func (u *AggregatedUsage) saveAsync() {
err := u.Save()
if err != nil {
if u.logger != nil {
u.logger.Error("save usage statistics: ", err)
}
}
}
func (u *AggregatedUsage) cancelPendingSave() {
u.saveMutex.Lock()
defer u.saveMutex.Unlock()
if u.saveTimer != nil {
u.saveTimer.Stop()
u.saveTimer = nil
}
u.pendingSave = false
}

View File

@@ -0,0 +1,29 @@
package ccm
import (
"sync"
"github.com/sagernet/sing-box/option"
)
type UserManager struct {
accessMutex sync.RWMutex
tokenMap map[string]string
}
func (m *UserManager) UpdateUsers(users []option.CCMUser) {
m.accessMutex.Lock()
defer m.accessMutex.Unlock()
tokenMap := make(map[string]string, len(users))
for _, user := range users {
tokenMap[user.Token] = user.Name
}
m.tokenMap = tokenMap
}
func (m *UserManager) Authenticate(token string) (string, bool) {
m.accessMutex.RLock()
username, found := m.tokenMap[token]
m.accessMutex.RUnlock()
return username, found
}