mirror of
https://github.com/shtorm-7/sing-box-extended.git
synced 2026-06-12 22:38:15 +03:00
Update sing-box core
This commit is contained in:
@@ -281,11 +281,11 @@ func (s *Service) getAccessToken() (string, error) {
|
||||
return newCredentials.AccessToken, nil
|
||||
}
|
||||
|
||||
func detectContextWindow(betaHeader string, inputTokens int64) int {
|
||||
if inputTokens > premiumContextThreshold {
|
||||
func detectContextWindow(betaHeader string, totalInputTokens int64) int {
|
||||
if totalInputTokens > premiumContextThreshold {
|
||||
features := strings.Split(betaHeader, ",")
|
||||
for _, feature := range features {
|
||||
if strings.TrimSpace(feature) == "context-1m" {
|
||||
if strings.HasPrefix(strings.TrimSpace(feature), "context-1m") {
|
||||
return contextWindowPremium
|
||||
}
|
||||
}
|
||||
@@ -362,6 +362,13 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
serviceOverridesAcceptEncoding := len(s.httpHeaders.Values("Accept-Encoding")) > 0
|
||||
if s.usageTracker != nil && !serviceOverridesAcceptEncoding {
|
||||
// Strip Accept-Encoding so Go Transport adds it automatically
|
||||
// and transparently decompresses the response for correct usage counting.
|
||||
proxyRequest.Header.Del("Accept-Encoding")
|
||||
}
|
||||
|
||||
anthropicBetaHeader := proxyRequest.Header.Get("anthropic-beta")
|
||||
if anthropicBetaHeader != "" {
|
||||
proxyRequest.Header.Set("anthropic-beta", anthropicBetaOAuthValue+","+anthropicBetaHeader)
|
||||
@@ -447,7 +454,8 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
|
||||
if usage.InputTokens > 0 || usage.OutputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, usage.InputTokens)
|
||||
totalInputTokens := usage.InputTokens + usage.CacheCreationInputTokens + usage.CacheReadInputTokens
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
@@ -547,7 +555,8 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
|
||||
if accumulatedUsage.InputTokens > 0 || accumulatedUsage.OutputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, accumulatedUsage.InputTokens)
|
||||
totalInputTokens := accumulatedUsage.InputTokens + accumulatedUsage.CacheCreationInputTokens + accumulatedUsage.CacheReadInputTokens
|
||||
contextWindow := detectContextWindow(anthropicBetaHeader, totalInputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
|
||||
@@ -65,9 +65,10 @@ type CostCombinationJSON struct {
|
||||
}
|
||||
|
||||
type CostsSummaryJSON struct {
|
||||
TotalUSD float64 `json:"total_usd"`
|
||||
ByUser map[string]float64 `json:"by_user"`
|
||||
ByWeek map[string]float64 `json:"by_week,omitempty"`
|
||||
TotalUSD float64 `json:"total_usd"`
|
||||
ByUser map[string]float64 `json:"by_user"`
|
||||
ByWeek map[string]float64 `json:"by_week,omitempty"`
|
||||
ByUserAndWeek map[string]map[string]float64 `json:"by_user_and_week,omitempty"`
|
||||
}
|
||||
|
||||
type AggregatedUsageJSON struct {
|
||||
@@ -492,6 +493,31 @@ func buildByWeekCost(combinations []CostCombination) map[string]float64 {
|
||||
return byWeek
|
||||
}
|
||||
|
||||
func buildByUserAndWeekCost(combinations []CostCombination) map[string]map[string]float64 {
|
||||
byUserAndWeek := make(map[string]map[string]float64)
|
||||
for _, combination := range combinations {
|
||||
if combination.WeekStartUnix <= 0 {
|
||||
continue
|
||||
}
|
||||
weekStartAt := time.Unix(combination.WeekStartUnix, 0).UTC()
|
||||
weekKey := formatWeekStartKey(weekStartAt)
|
||||
for user, userStats := range combination.ByUser {
|
||||
userWeeks, exists := byUserAndWeek[user]
|
||||
if !exists {
|
||||
userWeeks = make(map[string]float64)
|
||||
byUserAndWeek[user] = userWeeks
|
||||
}
|
||||
userWeeks[weekKey] += calculateCost(userStats, combination.Model, combination.ContextWindow)
|
||||
}
|
||||
}
|
||||
for _, weekCosts := range byUserAndWeek {
|
||||
for weekKey, cost := range weekCosts {
|
||||
weekCosts[weekKey] = roundCost(cost)
|
||||
}
|
||||
}
|
||||
return byUserAndWeek
|
||||
}
|
||||
|
||||
func deriveWeekStartUnix(cycleHint *WeeklyCycleHint) int64 {
|
||||
if cycleHint == nil || cycleHint.WindowMinutes <= 0 || cycleHint.ResetAt.IsZero() {
|
||||
return 0
|
||||
@@ -522,6 +548,11 @@ func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON {
|
||||
result.Costs.ByWeek = nil
|
||||
}
|
||||
|
||||
result.Costs.ByUserAndWeek = buildByUserAndWeekCost(u.Combinations)
|
||||
if len(result.Costs.ByUserAndWeek) == 0 {
|
||||
result.Costs.ByUserAndWeek = nil
|
||||
}
|
||||
|
||||
for user, cost := range result.Costs.ByUser {
|
||||
result.Costs.ByUser[user] = roundCost(cost)
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
//go:build with_gvisor
|
||||
|
||||
package derp
|
||||
|
||||
import (
|
||||
|
||||
@@ -130,6 +130,7 @@ type Service struct {
|
||||
credentialPath string
|
||||
credentials *oauthCredentials
|
||||
users []option.OCMUser
|
||||
dialer N.Dialer
|
||||
httpClient *http.Client
|
||||
httpHeaders http.Header
|
||||
listener *listener.Listener
|
||||
@@ -138,7 +139,9 @@ type Service struct {
|
||||
userManager *UserManager
|
||||
accessMutex sync.RWMutex
|
||||
usageTracker *AggregatedUsage
|
||||
trackingGroup sync.WaitGroup
|
||||
webSocketMutex sync.Mutex
|
||||
webSocketGroup sync.WaitGroup
|
||||
webSocketConns map[*webSocketSession]struct{}
|
||||
shuttingDown bool
|
||||
}
|
||||
|
||||
@@ -187,6 +190,7 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
|
||||
logger: logger,
|
||||
credentialPath: options.CredentialPath,
|
||||
users: options.Users,
|
||||
dialer: serviceDialer,
|
||||
httpClient: httpClient,
|
||||
httpHeaders: options.Headers.Build(),
|
||||
listener: listener.New(listener.Options{
|
||||
@@ -195,8 +199,9 @@ func NewService(ctx context.Context, logger log.ContextLogger, tag string, optio
|
||||
Network: []string{N.NetworkTCP},
|
||||
Listen: options.ListenOptions,
|
||||
}),
|
||||
userManager: userManager,
|
||||
usageTracker: usageTracker,
|
||||
userManager: userManager,
|
||||
usageTracker: usageTracker,
|
||||
webSocketConns: make(map[*webSocketSession]struct{}),
|
||||
}
|
||||
|
||||
if options.TLS != nil {
|
||||
@@ -356,6 +361,11 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
if strings.EqualFold(r.Header.Get("Upgrade"), "websocket") && strings.HasPrefix(path, "/v1/responses") {
|
||||
s.handleWebSocket(w, r, proxyPath, username)
|
||||
return
|
||||
}
|
||||
|
||||
var requestModel string
|
||||
|
||||
if s.usageTracker != nil && r.Body != nil {
|
||||
@@ -497,8 +507,10 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
responseModel = requestModel
|
||||
}
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cachedTokens,
|
||||
@@ -606,8 +618,10 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cachedTokens,
|
||||
@@ -624,11 +638,17 @@ func (s *Service) handleResponseWithTracking(writer http.ResponseWriter, respons
|
||||
}
|
||||
|
||||
func (s *Service) Close() error {
|
||||
webSocketSessions := s.startWebSocketShutdown()
|
||||
|
||||
err := common.Close(
|
||||
common.PtrOrNil(s.httpServer),
|
||||
common.PtrOrNil(s.listener),
|
||||
s.tlsConfig,
|
||||
)
|
||||
for _, session := range webSocketSessions {
|
||||
session.Close()
|
||||
}
|
||||
s.webSocketGroup.Wait()
|
||||
|
||||
if s.usageTracker != nil {
|
||||
s.usageTracker.cancelPendingSave()
|
||||
@@ -640,3 +660,48 @@ func (s *Service) Close() error {
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Service) registerWebSocketSession(session *webSocketSession) bool {
|
||||
s.webSocketMutex.Lock()
|
||||
defer s.webSocketMutex.Unlock()
|
||||
|
||||
if s.shuttingDown {
|
||||
return false
|
||||
}
|
||||
|
||||
s.webSocketConns[session] = struct{}{}
|
||||
s.webSocketGroup.Add(1)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Service) unregisterWebSocketSession(session *webSocketSession) {
|
||||
s.webSocketMutex.Lock()
|
||||
_, loaded := s.webSocketConns[session]
|
||||
if loaded {
|
||||
delete(s.webSocketConns, session)
|
||||
}
|
||||
s.webSocketMutex.Unlock()
|
||||
|
||||
if loaded {
|
||||
s.webSocketGroup.Done()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) isShuttingDown() bool {
|
||||
s.webSocketMutex.Lock()
|
||||
defer s.webSocketMutex.Unlock()
|
||||
return s.shuttingDown
|
||||
}
|
||||
|
||||
func (s *Service) startWebSocketShutdown() []*webSocketSession {
|
||||
s.webSocketMutex.Lock()
|
||||
defer s.webSocketMutex.Unlock()
|
||||
|
||||
s.shuttingDown = true
|
||||
|
||||
webSocketSessions := make([]*webSocketSession, 0, len(s.webSocketConns))
|
||||
for session := range s.webSocketConns {
|
||||
webSocketSessions = append(webSocketSessions, session)
|
||||
}
|
||||
return webSocketSessions
|
||||
}
|
||||
|
||||
@@ -46,6 +46,7 @@ func (u *UsageStats) UnmarshalJSON(data []byte) error {
|
||||
type CostCombination struct {
|
||||
Model string `json:"model"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
ContextWindow int `json:"context_window"`
|
||||
WeekStartUnix int64 `json:"week_start_unix,omitempty"`
|
||||
Total UsageStats `json:"total"`
|
||||
ByUser map[string]UsageStats `json:"by_user"`
|
||||
@@ -74,15 +75,17 @@ type UsageStatsJSON struct {
|
||||
type CostCombinationJSON struct {
|
||||
Model string `json:"model"`
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
ContextWindow int `json:"context_window"`
|
||||
WeekStartUnix int64 `json:"week_start_unix,omitempty"`
|
||||
Total UsageStatsJSON `json:"total"`
|
||||
ByUser map[string]UsageStatsJSON `json:"by_user"`
|
||||
}
|
||||
|
||||
type CostsSummaryJSON struct {
|
||||
TotalUSD float64 `json:"total_usd"`
|
||||
ByUser map[string]float64 `json:"by_user"`
|
||||
ByWeek map[string]float64 `json:"by_week,omitempty"`
|
||||
TotalUSD float64 `json:"total_usd"`
|
||||
ByUser map[string]float64 `json:"by_user"`
|
||||
ByWeek map[string]float64 `json:"by_week,omitempty"`
|
||||
ByUserAndWeek map[string]map[string]float64 `json:"by_user_and_week,omitempty"`
|
||||
}
|
||||
|
||||
type AggregatedUsageJSON struct {
|
||||
@@ -103,8 +106,9 @@ type ModelPricing struct {
|
||||
}
|
||||
|
||||
type modelFamily struct {
|
||||
pattern *regexp.Regexp
|
||||
pricing ModelPricing
|
||||
pattern *regexp.Regexp
|
||||
pricing ModelPricing
|
||||
premiumPricing *ModelPricing
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -115,6 +119,12 @@ const (
|
||||
serviceTierScale = "scale"
|
||||
)
|
||||
|
||||
const (
|
||||
contextWindowStandard = 272000
|
||||
contextWindowPremium = 1050000
|
||||
premiumContextThreshold = 272000
|
||||
)
|
||||
|
||||
var (
|
||||
gpt52Pricing = ModelPricing{
|
||||
InputPrice: 1.75,
|
||||
@@ -158,6 +168,30 @@ var (
|
||||
CachedInputPrice: 0.025,
|
||||
}
|
||||
|
||||
gpt54StandardPricing = ModelPricing{
|
||||
InputPrice: 2.5,
|
||||
OutputPrice: 15.0,
|
||||
CachedInputPrice: 0.25,
|
||||
}
|
||||
|
||||
gpt54PremiumPricing = ModelPricing{
|
||||
InputPrice: 5.0,
|
||||
OutputPrice: 22.5,
|
||||
CachedInputPrice: 0.5,
|
||||
}
|
||||
|
||||
gpt54ProPricing = ModelPricing{
|
||||
InputPrice: 30.0,
|
||||
OutputPrice: 180.0,
|
||||
CachedInputPrice: 30.0,
|
||||
}
|
||||
|
||||
gpt54ProPremiumPricing = ModelPricing{
|
||||
InputPrice: 60.0,
|
||||
OutputPrice: 270.0,
|
||||
CachedInputPrice: 60.0,
|
||||
}
|
||||
|
||||
gpt52ProPricing = ModelPricing{
|
||||
InputPrice: 21.0,
|
||||
OutputPrice: 168.0,
|
||||
@@ -170,6 +204,30 @@ var (
|
||||
CachedInputPrice: 15.0,
|
||||
}
|
||||
|
||||
gpt54FlexPricing = ModelPricing{
|
||||
InputPrice: 1.25,
|
||||
OutputPrice: 7.5,
|
||||
CachedInputPrice: 0.125,
|
||||
}
|
||||
|
||||
gpt54PremiumFlexPricing = ModelPricing{
|
||||
InputPrice: 2.5,
|
||||
OutputPrice: 11.25,
|
||||
CachedInputPrice: 0.25,
|
||||
}
|
||||
|
||||
gpt54ProFlexPricing = ModelPricing{
|
||||
InputPrice: 15.0,
|
||||
OutputPrice: 90.0,
|
||||
CachedInputPrice: 15.0,
|
||||
}
|
||||
|
||||
gpt54ProPremiumFlexPricing = ModelPricing{
|
||||
InputPrice: 30.0,
|
||||
OutputPrice: 135.0,
|
||||
CachedInputPrice: 30.0,
|
||||
}
|
||||
|
||||
gpt52FlexPricing = ModelPricing{
|
||||
InputPrice: 0.875,
|
||||
OutputPrice: 7.0,
|
||||
@@ -194,6 +252,18 @@ var (
|
||||
CachedInputPrice: 0.0025,
|
||||
}
|
||||
|
||||
gpt54PriorityPricing = ModelPricing{
|
||||
InputPrice: 5.0,
|
||||
OutputPrice: 30.0,
|
||||
CachedInputPrice: 0.5,
|
||||
}
|
||||
|
||||
gpt54PremiumPriorityPricing = ModelPricing{
|
||||
InputPrice: 10.0,
|
||||
OutputPrice: 45.0,
|
||||
CachedInputPrice: 1.0,
|
||||
}
|
||||
|
||||
gpt52PriorityPricing = ModelPricing{
|
||||
InputPrice: 3.5,
|
||||
OutputPrice: 28.0,
|
||||
@@ -381,6 +451,16 @@ var (
|
||||
}
|
||||
|
||||
standardModelFamilies = []modelFamily{
|
||||
{
|
||||
pattern: regexp.MustCompile(`^gpt-5\.4-pro(?:$|-)`),
|
||||
pricing: gpt54ProPricing,
|
||||
premiumPricing: &gpt54ProPremiumPricing,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^gpt-5\.4(?:$|-)`),
|
||||
pricing: gpt54StandardPricing,
|
||||
premiumPricing: &gpt54PremiumPricing,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^gpt-5\.3-codex(?:$|-)`),
|
||||
pricing: gpt52CodexPricing,
|
||||
@@ -524,6 +604,16 @@ var (
|
||||
}
|
||||
|
||||
flexModelFamilies = []modelFamily{
|
||||
{
|
||||
pattern: regexp.MustCompile(`^gpt-5\.4-pro(?:$|-)`),
|
||||
pricing: gpt54ProFlexPricing,
|
||||
premiumPricing: &gpt54ProPremiumFlexPricing,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^gpt-5\.4(?:$|-)`),
|
||||
pricing: gpt54FlexPricing,
|
||||
premiumPricing: &gpt54PremiumFlexPricing,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^gpt-5-mini(?:$|-)`),
|
||||
pricing: gpt5MiniFlexPricing,
|
||||
@@ -555,6 +645,11 @@ var (
|
||||
}
|
||||
|
||||
priorityModelFamilies = []modelFamily{
|
||||
{
|
||||
pattern: regexp.MustCompile(`^gpt-5\.4(?:$|-)`),
|
||||
pricing: gpt54PriorityPricing,
|
||||
premiumPricing: &gpt54PremiumPriorityPricing,
|
||||
},
|
||||
{
|
||||
pattern: regexp.MustCompile(`^gpt-5\.3-codex(?:$|-)`),
|
||||
pricing: gpt52CodexPriorityPricing,
|
||||
@@ -637,15 +732,28 @@ func modelFamiliesForTier(serviceTier string) []modelFamily {
|
||||
}
|
||||
}
|
||||
|
||||
func findPricingInFamilies(model string, modelFamilies []modelFamily) (ModelPricing, bool) {
|
||||
func findPricingInFamilies(model string, contextWindow int, modelFamilies []modelFamily) (ModelPricing, bool) {
|
||||
isPremium := contextWindow >= contextWindowPremium
|
||||
for _, family := range modelFamilies {
|
||||
if family.pattern.MatchString(model) {
|
||||
if isPremium && family.premiumPricing != nil {
|
||||
return *family.premiumPricing, true
|
||||
}
|
||||
return family.pricing, true
|
||||
}
|
||||
}
|
||||
return ModelPricing{}, false
|
||||
}
|
||||
|
||||
func hasPremiumPricingInFamilies(model string, modelFamilies []modelFamily) bool {
|
||||
for _, family := range modelFamilies {
|
||||
if family.pattern.MatchString(model) {
|
||||
return family.premiumPricing != nil
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeServiceTier(serviceTier string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(serviceTier)) {
|
||||
case "", serviceTierAuto, serviceTierDefault:
|
||||
@@ -662,27 +770,27 @@ func normalizeServiceTier(serviceTier string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func getPricing(model string, serviceTier string) ModelPricing {
|
||||
func getPricing(model string, serviceTier string, contextWindow int) ModelPricing {
|
||||
normalizedServiceTier := normalizeServiceTier(serviceTier)
|
||||
modelFamilies := modelFamiliesForTier(normalizedServiceTier)
|
||||
families := modelFamiliesForTier(normalizedServiceTier)
|
||||
|
||||
if pricing, found := findPricingInFamilies(model, modelFamilies); found {
|
||||
if pricing, found := findPricingInFamilies(model, contextWindow, families); found {
|
||||
return pricing
|
||||
}
|
||||
|
||||
normalizedModel := normalizeGPT5Model(model)
|
||||
if normalizedModel != model {
|
||||
if pricing, found := findPricingInFamilies(normalizedModel, modelFamilies); found {
|
||||
if pricing, found := findPricingInFamilies(normalizedModel, contextWindow, families); found {
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
|
||||
if normalizedServiceTier != serviceTierDefault {
|
||||
if pricing, found := findPricingInFamilies(model, standardModelFamilies); found {
|
||||
if pricing, found := findPricingInFamilies(model, contextWindow, standardModelFamilies); found {
|
||||
return pricing
|
||||
}
|
||||
if normalizedModel != model {
|
||||
if pricing, found := findPricingInFamilies(normalizedModel, standardModelFamilies); found {
|
||||
if pricing, found := findPricingInFamilies(normalizedModel, contextWindow, standardModelFamilies); found {
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
@@ -691,6 +799,30 @@ func getPricing(model string, serviceTier string) ModelPricing {
|
||||
return gpt4oPricing
|
||||
}
|
||||
|
||||
func detectContextWindow(model string, serviceTier string, inputTokens int64) int {
|
||||
if inputTokens <= premiumContextThreshold {
|
||||
return contextWindowStandard
|
||||
}
|
||||
normalizedServiceTier := normalizeServiceTier(serviceTier)
|
||||
families := modelFamiliesForTier(normalizedServiceTier)
|
||||
if hasPremiumPricingInFamilies(model, families) {
|
||||
return contextWindowPremium
|
||||
}
|
||||
normalizedModel := normalizeGPT5Model(model)
|
||||
if normalizedModel != model && hasPremiumPricingInFamilies(normalizedModel, families) {
|
||||
return contextWindowPremium
|
||||
}
|
||||
if normalizedServiceTier != serviceTierDefault {
|
||||
if hasPremiumPricingInFamilies(model, standardModelFamilies) {
|
||||
return contextWindowPremium
|
||||
}
|
||||
if normalizedModel != model && hasPremiumPricingInFamilies(normalizedModel, standardModelFamilies) {
|
||||
return contextWindowPremium
|
||||
}
|
||||
}
|
||||
return contextWindowStandard
|
||||
}
|
||||
|
||||
func normalizeGPT5Model(model string) string {
|
||||
if !strings.HasPrefix(model, "gpt-5.") {
|
||||
return model
|
||||
@@ -706,18 +838,18 @@ func normalizeGPT5Model(model string) string {
|
||||
case strings.Contains(model, "-chat-latest"):
|
||||
return "gpt-5.2-chat-latest"
|
||||
case strings.Contains(model, "-pro"):
|
||||
return "gpt-5.2-pro"
|
||||
return "gpt-5.4-pro"
|
||||
case strings.Contains(model, "-mini"):
|
||||
return "gpt-5-mini"
|
||||
case strings.Contains(model, "-nano"):
|
||||
return "gpt-5-nano"
|
||||
default:
|
||||
return "gpt-5.2"
|
||||
return "gpt-5.4"
|
||||
}
|
||||
}
|
||||
|
||||
func calculateCost(stats UsageStats, model string, serviceTier string) float64 {
|
||||
pricing := getPricing(model, serviceTier)
|
||||
func calculateCost(stats UsageStats, model string, serviceTier string, contextWindow int) float64 {
|
||||
pricing := getPricing(model, serviceTier, contextWindow)
|
||||
|
||||
regularInputTokens := stats.InputTokens - stats.CachedTokens
|
||||
if regularInputTokens < 0 {
|
||||
@@ -738,13 +870,16 @@ func roundCost(cost float64) float64 {
|
||||
func normalizeCombinations(combinations []CostCombination) {
|
||||
for index := range combinations {
|
||||
combinations[index].ServiceTier = normalizeServiceTier(combinations[index].ServiceTier)
|
||||
if combinations[index].ContextWindow <= 0 {
|
||||
combinations[index].ContextWindow = contextWindowStandard
|
||||
}
|
||||
if combinations[index].ByUser == nil {
|
||||
combinations[index].ByUser = make(map[string]UsageStats)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func addUsageToCombinations(combinations *[]CostCombination, model string, serviceTier string, weekStartUnix int64, user string, inputTokens, outputTokens, cachedTokens int64) {
|
||||
func addUsageToCombinations(combinations *[]CostCombination, model string, serviceTier string, contextWindow int, weekStartUnix int64, user string, inputTokens, outputTokens, cachedTokens int64) {
|
||||
var matchedCombination *CostCombination
|
||||
for index := range *combinations {
|
||||
combination := &(*combinations)[index]
|
||||
@@ -752,7 +887,7 @@ func addUsageToCombinations(combinations *[]CostCombination, model string, servi
|
||||
if combination.ServiceTier != combinationServiceTier {
|
||||
combination.ServiceTier = combinationServiceTier
|
||||
}
|
||||
if combination.Model == model && combinationServiceTier == serviceTier && combination.WeekStartUnix == weekStartUnix {
|
||||
if combination.Model == model && combinationServiceTier == serviceTier && combination.ContextWindow == contextWindow && combination.WeekStartUnix == weekStartUnix {
|
||||
matchedCombination = combination
|
||||
break
|
||||
}
|
||||
@@ -762,6 +897,7 @@ func addUsageToCombinations(combinations *[]CostCombination, model string, servi
|
||||
newCombination := CostCombination{
|
||||
Model: model,
|
||||
ServiceTier: serviceTier,
|
||||
ContextWindow: contextWindow,
|
||||
WeekStartUnix: weekStartUnix,
|
||||
Total: UsageStats{},
|
||||
ByUser: make(map[string]UsageStats),
|
||||
@@ -790,12 +926,13 @@ func buildCombinationJSON(combinations []CostCombination, aggregateUserCosts map
|
||||
var totalCost float64
|
||||
|
||||
for index, combination := range combinations {
|
||||
combinationTotalCost := calculateCost(combination.Total, combination.Model, combination.ServiceTier)
|
||||
combinationTotalCost := calculateCost(combination.Total, combination.Model, combination.ServiceTier, combination.ContextWindow)
|
||||
totalCost += combinationTotalCost
|
||||
|
||||
combinationJSON := CostCombinationJSON{
|
||||
Model: combination.Model,
|
||||
ServiceTier: combination.ServiceTier,
|
||||
ContextWindow: combination.ContextWindow,
|
||||
WeekStartUnix: combination.WeekStartUnix,
|
||||
Total: UsageStatsJSON{
|
||||
RequestCount: combination.Total.RequestCount,
|
||||
@@ -808,7 +945,7 @@ func buildCombinationJSON(combinations []CostCombination, aggregateUserCosts map
|
||||
}
|
||||
|
||||
for user, userStats := range combination.ByUser {
|
||||
userCost := calculateCost(userStats, combination.Model, combination.ServiceTier)
|
||||
userCost := calculateCost(userStats, combination.Model, combination.ServiceTier, combination.ContextWindow)
|
||||
if aggregateUserCosts != nil {
|
||||
aggregateUserCosts[user] += userCost
|
||||
}
|
||||
@@ -856,7 +993,7 @@ func buildByWeekCost(combinations []CostCombination) map[string]float64 {
|
||||
}
|
||||
weekStartAt := time.Unix(combination.WeekStartUnix, 0).UTC()
|
||||
weekKey := formatWeekStartKey(weekStartAt)
|
||||
byWeek[weekKey] += calculateCost(combination.Total, combination.Model, combination.ServiceTier)
|
||||
byWeek[weekKey] += calculateCost(combination.Total, combination.Model, combination.ServiceTier, combination.ContextWindow)
|
||||
}
|
||||
for weekKey, weekCost := range byWeek {
|
||||
byWeek[weekKey] = roundCost(weekCost)
|
||||
@@ -864,6 +1001,31 @@ func buildByWeekCost(combinations []CostCombination) map[string]float64 {
|
||||
return byWeek
|
||||
}
|
||||
|
||||
func buildByUserAndWeekCost(combinations []CostCombination) map[string]map[string]float64 {
|
||||
byUserAndWeek := make(map[string]map[string]float64)
|
||||
for _, combination := range combinations {
|
||||
if combination.WeekStartUnix <= 0 {
|
||||
continue
|
||||
}
|
||||
weekStartAt := time.Unix(combination.WeekStartUnix, 0).UTC()
|
||||
weekKey := formatWeekStartKey(weekStartAt)
|
||||
for user, userStats := range combination.ByUser {
|
||||
userWeeks, exists := byUserAndWeek[user]
|
||||
if !exists {
|
||||
userWeeks = make(map[string]float64)
|
||||
byUserAndWeek[user] = userWeeks
|
||||
}
|
||||
userWeeks[weekKey] += calculateCost(userStats, combination.Model, combination.ServiceTier, combination.ContextWindow)
|
||||
}
|
||||
}
|
||||
for _, weekCosts := range byUserAndWeek {
|
||||
for weekKey, cost := range weekCosts {
|
||||
weekCosts[weekKey] = roundCost(cost)
|
||||
}
|
||||
}
|
||||
return byUserAndWeek
|
||||
}
|
||||
|
||||
func deriveWeekStartUnix(cycleHint *WeeklyCycleHint) int64 {
|
||||
if cycleHint == nil || cycleHint.WindowMinutes <= 0 || cycleHint.ResetAt.IsZero() {
|
||||
return 0
|
||||
@@ -894,6 +1056,11 @@ func (u *AggregatedUsage) ToJSON() *AggregatedUsageJSON {
|
||||
result.Costs.ByWeek = nil
|
||||
}
|
||||
|
||||
result.Costs.ByUserAndWeek = buildByUserAndWeekCost(u.Combinations)
|
||||
if len(result.Costs.ByUserAndWeek) == 0 {
|
||||
result.Costs.ByUserAndWeek = nil
|
||||
}
|
||||
|
||||
for user, cost := range result.Costs.ByUser {
|
||||
result.Costs.ByUser[user] = roundCost(cost)
|
||||
}
|
||||
@@ -956,14 +1123,17 @@ func (u *AggregatedUsage) Save() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (u *AggregatedUsage) AddUsage(model string, inputTokens, outputTokens, cachedTokens int64, serviceTier string, user string) error {
|
||||
return u.AddUsageWithCycleHint(model, inputTokens, outputTokens, cachedTokens, serviceTier, user, time.Now(), nil)
|
||||
func (u *AggregatedUsage) AddUsage(model string, contextWindow int, inputTokens, outputTokens, cachedTokens int64, serviceTier string, user string) error {
|
||||
return u.AddUsageWithCycleHint(model, contextWindow, inputTokens, outputTokens, cachedTokens, serviceTier, user, time.Now(), nil)
|
||||
}
|
||||
|
||||
func (u *AggregatedUsage) AddUsageWithCycleHint(model string, inputTokens, outputTokens, cachedTokens int64, serviceTier string, user string, observedAt time.Time, cycleHint *WeeklyCycleHint) error {
|
||||
func (u *AggregatedUsage) AddUsageWithCycleHint(model string, contextWindow int, inputTokens, outputTokens, cachedTokens int64, serviceTier string, user string, observedAt time.Time, cycleHint *WeeklyCycleHint) error {
|
||||
if model == "" {
|
||||
return E.New("model cannot be empty")
|
||||
}
|
||||
if contextWindow <= 0 {
|
||||
return E.New("contextWindow must be positive")
|
||||
}
|
||||
|
||||
normalizedServiceTier := normalizeServiceTier(serviceTier)
|
||||
if observedAt.IsZero() {
|
||||
@@ -976,7 +1146,7 @@ func (u *AggregatedUsage) AddUsageWithCycleHint(model string, inputTokens, outpu
|
||||
u.LastUpdated = observedAt
|
||||
weekStartUnix := deriveWeekStartUnix(cycleHint)
|
||||
|
||||
addUsageToCombinations(&u.Combinations, model, normalizedServiceTier, weekStartUnix, user, inputTokens, outputTokens, cachedTokens)
|
||||
addUsageToCombinations(&u.Combinations, model, normalizedServiceTier, contextWindow, weekStartUnix, user, inputTokens, outputTokens, cachedTokens)
|
||||
|
||||
go u.scheduleSave()
|
||||
|
||||
|
||||
285
service/ocm/service_websocket.go
Normal file
285
service/ocm/service_websocket.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package ocm
|
||||
|
||||
import (
|
||||
"context"
|
||||
stdTLS "crypto/tls"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/ntp"
|
||||
"github.com/sagernet/ws"
|
||||
"github.com/sagernet/ws/wsutil"
|
||||
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
)
|
||||
|
||||
type webSocketSession struct {
|
||||
clientConn net.Conn
|
||||
upstreamConn net.Conn
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func (s *webSocketSession) Close() {
|
||||
s.closeOnce.Do(func() {
|
||||
s.clientConn.Close()
|
||||
s.upstreamConn.Close()
|
||||
})
|
||||
}
|
||||
|
||||
func buildUpstreamWebSocketURL(baseURL string, proxyPath string) string {
|
||||
upstreamURL := baseURL
|
||||
if strings.HasPrefix(upstreamURL, "https://") {
|
||||
upstreamURL = "wss://" + upstreamURL[len("https://"):]
|
||||
} else if strings.HasPrefix(upstreamURL, "http://") {
|
||||
upstreamURL = "ws://" + upstreamURL[len("http://"):]
|
||||
}
|
||||
return upstreamURL + proxyPath
|
||||
}
|
||||
|
||||
func isForwardableResponseHeader(key string) bool {
|
||||
lowerKey := strings.ToLower(key)
|
||||
switch {
|
||||
case strings.HasPrefix(lowerKey, "x-codex-"):
|
||||
return true
|
||||
case strings.HasPrefix(lowerKey, "x-reasoning"):
|
||||
return true
|
||||
case lowerKey == "openai-model":
|
||||
return true
|
||||
case strings.Contains(lowerKey, "-secondary-"):
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isForwardableWebSocketRequestHeader(key string) bool {
|
||||
if isHopByHopHeader(key) {
|
||||
return false
|
||||
}
|
||||
|
||||
lowerKey := strings.ToLower(key)
|
||||
switch {
|
||||
case lowerKey == "authorization":
|
||||
return false
|
||||
case strings.HasPrefix(lowerKey, "sec-websocket-"):
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleWebSocket(w http.ResponseWriter, r *http.Request, proxyPath string, username string) {
|
||||
accessToken, err := s.getAccessToken()
|
||||
if err != nil {
|
||||
s.logger.Error("get access token for websocket: ", err)
|
||||
writeJSONError(w, r, http.StatusUnauthorized, "authentication_error", "authentication failed")
|
||||
return
|
||||
}
|
||||
|
||||
upstreamURL := buildUpstreamWebSocketURL(s.getBaseURL(), proxyPath)
|
||||
if r.URL.RawQuery != "" {
|
||||
upstreamURL += "?" + r.URL.RawQuery
|
||||
}
|
||||
|
||||
upstreamHeaders := make(http.Header)
|
||||
for key, values := range r.Header {
|
||||
if isForwardableWebSocketRequestHeader(key) {
|
||||
upstreamHeaders[key] = values
|
||||
}
|
||||
}
|
||||
for key, values := range s.httpHeaders {
|
||||
upstreamHeaders.Del(key)
|
||||
upstreamHeaders[key] = values
|
||||
}
|
||||
upstreamHeaders.Set("Authorization", "Bearer "+accessToken)
|
||||
if accountID := s.getAccountID(); accountID != "" {
|
||||
upstreamHeaders.Set("ChatGPT-Account-Id", accountID)
|
||||
}
|
||||
|
||||
upstreamResponseHeaders := make(http.Header)
|
||||
upstreamDialer := ws.Dialer{
|
||||
NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return s.dialer.DialContext(ctx, network, M.ParseSocksaddr(addr))
|
||||
},
|
||||
TLSConfig: &stdTLS.Config{
|
||||
RootCAs: adapter.RootPoolFromContext(s.ctx),
|
||||
Time: ntp.TimeFuncFromContext(s.ctx),
|
||||
},
|
||||
Header: ws.HandshakeHeaderHTTP(upstreamHeaders),
|
||||
OnHeader: func(key, value []byte) error {
|
||||
upstreamResponseHeaders.Add(string(key), string(value))
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
upstreamConn, upstreamBufferedReader, _, err := upstreamDialer.Dial(r.Context(), upstreamURL)
|
||||
if err != nil {
|
||||
s.logger.Error("dial upstream websocket: ", err)
|
||||
writeJSONError(w, r, http.StatusBadGateway, "api_error", "upstream websocket connection failed")
|
||||
return
|
||||
}
|
||||
|
||||
weeklyCycleHint := extractWeeklyCycleHint(upstreamResponseHeaders)
|
||||
|
||||
clientResponseHeaders := make(http.Header)
|
||||
for key, values := range upstreamResponseHeaders {
|
||||
if isForwardableResponseHeader(key) {
|
||||
clientResponseHeaders[key] = values
|
||||
}
|
||||
}
|
||||
|
||||
clientUpgrader := ws.HTTPUpgrader{
|
||||
Header: clientResponseHeaders,
|
||||
}
|
||||
if s.isShuttingDown() {
|
||||
upstreamConn.Close()
|
||||
writeJSONError(w, r, http.StatusServiceUnavailable, "api_error", "service is shutting down")
|
||||
return
|
||||
}
|
||||
clientConn, _, _, err := clientUpgrader.Upgrade(r, w)
|
||||
if err != nil {
|
||||
s.logger.Error("upgrade client websocket: ", err)
|
||||
upstreamConn.Close()
|
||||
return
|
||||
}
|
||||
session := &webSocketSession{
|
||||
clientConn: clientConn,
|
||||
upstreamConn: upstreamConn,
|
||||
}
|
||||
if !s.registerWebSocketSession(session) {
|
||||
session.Close()
|
||||
return
|
||||
}
|
||||
defer s.unregisterWebSocketSession(session)
|
||||
|
||||
var upstreamReadWriter io.ReadWriter
|
||||
if upstreamBufferedReader != nil {
|
||||
upstreamReadWriter = struct {
|
||||
io.Reader
|
||||
io.Writer
|
||||
}{upstreamBufferedReader, upstreamConn}
|
||||
} else {
|
||||
upstreamReadWriter = upstreamConn
|
||||
}
|
||||
|
||||
modelChannel := make(chan string, 1)
|
||||
var waitGroup sync.WaitGroup
|
||||
|
||||
waitGroup.Add(2)
|
||||
go func() {
|
||||
defer waitGroup.Done()
|
||||
defer session.Close()
|
||||
s.proxyWebSocketClientToUpstream(clientConn, upstreamConn, modelChannel)
|
||||
}()
|
||||
go func() {
|
||||
defer waitGroup.Done()
|
||||
defer session.Close()
|
||||
s.proxyWebSocketUpstreamToClient(upstreamReadWriter, clientConn, modelChannel, username, weeklyCycleHint)
|
||||
}()
|
||||
waitGroup.Wait()
|
||||
}
|
||||
|
||||
func (s *Service) proxyWebSocketClientToUpstream(clientConn net.Conn, upstreamConn net.Conn, modelChannel chan<- string) {
|
||||
for {
|
||||
data, opCode, err := wsutil.ReadClientData(clientConn)
|
||||
if err != nil {
|
||||
if !E.IsClosedOrCanceled(err) {
|
||||
s.logger.Debug("read client websocket: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if opCode == ws.OpText && s.usageTracker != nil {
|
||||
var request struct {
|
||||
Type string `json:"type"`
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if json.Unmarshal(data, &request) == nil && request.Type == "response.create" && request.Model != "" {
|
||||
select {
|
||||
case modelChannel <- request.Model:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = wsutil.WriteClientMessage(upstreamConn, opCode, data)
|
||||
if err != nil {
|
||||
if !E.IsClosedOrCanceled(err) {
|
||||
s.logger.Debug("write upstream websocket: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) proxyWebSocketUpstreamToClient(upstreamReadWriter io.ReadWriter, clientConn net.Conn, modelChannel <-chan string, username string, weeklyCycleHint *WeeklyCycleHint) {
|
||||
var requestModel string
|
||||
for {
|
||||
data, opCode, err := wsutil.ReadServerData(upstreamReadWriter)
|
||||
if err != nil {
|
||||
if !E.IsClosedOrCanceled(err) {
|
||||
s.logger.Debug("read upstream websocket: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if opCode == ws.OpText && s.usageTracker != nil {
|
||||
select {
|
||||
case model := <-modelChannel:
|
||||
requestModel = model
|
||||
default:
|
||||
}
|
||||
|
||||
var event struct {
|
||||
Type string `json:"type"`
|
||||
}
|
||||
if json.Unmarshal(data, &event) == nil && event.Type == "response.completed" {
|
||||
var streamEvent responses.ResponseStreamEventUnion
|
||||
if json.Unmarshal(data, &streamEvent) == nil {
|
||||
completedEvent := streamEvent.AsResponseCompleted()
|
||||
responseModel := string(completedEvent.Response.Model)
|
||||
serviceTier := string(completedEvent.Response.ServiceTier)
|
||||
inputTokens := completedEvent.Response.Usage.InputTokens
|
||||
outputTokens := completedEvent.Response.Usage.OutputTokens
|
||||
cachedTokens := completedEvent.Response.Usage.InputTokensDetails.CachedTokens
|
||||
|
||||
if inputTokens > 0 || outputTokens > 0 {
|
||||
if responseModel == "" {
|
||||
responseModel = requestModel
|
||||
}
|
||||
if responseModel != "" {
|
||||
contextWindow := detectContextWindow(responseModel, serviceTier, inputTokens)
|
||||
s.usageTracker.AddUsageWithCycleHint(
|
||||
responseModel,
|
||||
contextWindow,
|
||||
inputTokens,
|
||||
outputTokens,
|
||||
cachedTokens,
|
||||
serviceTier,
|
||||
username,
|
||||
time.Now(),
|
||||
weeklyCycleHint,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = wsutil.WriteServerMessage(clientConn, opCode, data)
|
||||
if err != nil {
|
||||
if !E.IsClosedOrCanceled(err) {
|
||||
s.logger.Debug("write client websocket: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user