package engine

import (
	"testing"

	"git.flytoex.net/yuanwei/flyto-agent/pkg/config"
	"git.flytoex.net/yuanwei/flyto-agent/pkg/query"
)

// testObserver records emitted events so tests can assert on them.
type testObserver struct {
	events []observedEvent
}

// observedEvent is one captured Event call: its name and its data payload.
type observedEvent struct {
	name string
	data map[string]any
}

// Event appends the event to o.events in call order.
func (o *testObserver) Event(name string, data map[string]any) {
	o.events = append(o.events, observedEvent{name: name, data: data})
}

// Error is a no-op; reported errors are discarded.
func (o *testObserver) Error(err error, ctx map[string]any) {}

// newTestBudgetManager returns a TokenBudgetManager backed by a fresh model
// registry holding the two models the tests rely on. obs is passed through
// to NewTokenBudgetManager unchanged and may be nil.
func newTestBudgetManager(obs EventObserver) *TokenBudgetManager {
	registry := config.NewModelRegistry()
	// Register the test models explicitly (DefaultModels has been emptied).
	registry.Register("claude-opus-4-6", &config.ModelConfig{
		ID:               "claude-opus-4-6",
		ContextWindow:    200000,
		MaxOutputTokens:  32000,
		InputPricePer1M:  15.0,
		OutputPricePer1M: 75.0,
	})
	registry.Register("claude-sonnet-4-6", &config.ModelConfig{
		ID:                   "claude-sonnet-4-6",
		ContextWindow:        200000,
		MaxOutputTokens:      16384,
		InputPricePer1M:      3.0,
		OutputPricePer1M:     15.0,
		CacheReadPricePer1M:  0.3,
		CacheWritePricePer1M: 3.75,
	})
	return NewTokenBudgetManager(registry, obs)
}

// --- EstimateCurrentUsage ---

func TestEstimateCurrentUsage_NoMessages(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	result := mgr.EstimateCurrentUsage(nil)
	if result != 0 {
		t.Errorf("expected 0 for nil messages, got %d", result)
	}
}

func TestEstimateCurrentUsage_NoAnchor(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	messages := []query.Message{
		{Role: query.RoleUser, Content: []query.Content{{Type: query.ContentText, Text: "hello world"}}},
		{Role: query.RoleAssistant, Content: []query.Content{{Type: query.ContentText, Text: "hi there"}}},
	}
	result := mgr.EstimateCurrentUsage(messages)
	// Without an anchor everything is rough-estimated; the result must be > 0.
	if result <= 0 {
		t.Errorf("expected > 0 for messages without anchor, got %d", result)
	}
}

func TestEstimateCurrentUsage_WithAnchor(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	messages := []query.Message{
		{Role: query.RoleUser, Content: []query.Content{{Type: query.ContentText, Text: "hello"}}},
		{
			Role:    query.RoleAssistant,
			Content: []query.Content{{Type: query.ContentText, Text: "response"}},
			Metadata: map[string]any{
				"usage": &query.Usage{
					InputTokens:  1000,
					OutputTokens: 200,
				},
			},
		},
		// New message after the anchor.
		{Role: query.RoleUser, Content: []query.Content{{Type: query.ContentText, Text: "follow up question"}}},
	}
	result := mgr.EstimateCurrentUsage(messages)
	// Exact anchor value = 1000 + 200 = 1200, plus a rough estimate for the
	// messages that follow it.
	if result < 1200 {
		t.Errorf("expected at least 1200 (anchor value), got %d", result)
	}
	if result > 2000 {
		t.Errorf("result unexpectedly high: %d", result)
	}
}

func TestEstimateCurrentUsage_WithCachedAnchor(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	messages := []query.Message{
		{Role: query.RoleUser, Content: []query.Content{{Type: query.ContentText, Text: "hello"}}},
		{
			Role:    query.RoleAssistant,
			Content: []query.Content{{Type: query.ContentText, Text: "response"}},
			Metadata: map[string]any{
				"usage": &query.Usage{
					InputTokens:  500,
					OutputTokens: 200,
					Cache:        query.CacheTokens{Read: 300, Written: 100},
				},
			},
		},
	}
	result := mgr.EstimateCurrentUsage(messages)
	// Exact anchor value = 500 + 200 + 300 + 100 = 1100.
	if result != 1100 {
		t.Errorf("expected 1100 (anchor with cache), got %d", result)
	}
}

func TestEstimateCurrentUsage_AnchorAtEnd(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	messages := []query.Message{
		{Role: query.RoleUser, Content: []query.Content{{Type: query.ContentText, Text: "hello"}}},
		{
			Role:    query.RoleAssistant,
			Content: []query.Content{{Type: query.ContentText, Text: "response"}},
			Metadata: map[string]any{
				"usage": &query.Usage{
					InputTokens:  2000,
					OutputTokens: 500,
				},
			},
		},
	}
	result := mgr.EstimateCurrentUsage(messages)
	// The anchor is the last message, so nothing follows it and the result
	// must equal the exact anchor value.
	if result != 2500 {
		t.Errorf("expected 2500 (anchor at end), got %d", result)
	}
}

// --- Sibling rewind ---

func TestSiblingRewind(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	messages := []query.Message{
		{Role: query.RoleUser, Content: []query.Content{{Type: query.ContentText, Text: "start"}}},
		// Parallel tool calls: two assistant messages share the same
		// api_response_id.
		{
			Role:    query.RoleAssistant,
			Content: []query.Content{{Type: query.ContentToolUse, Name: "tool_a"}},
			Metadata: map[string]any{
				"api_response_id": "resp-123",
			},
		},
		{Role: query.RoleUser, Content: []query.Content{{Type: query.ContentToolResult, Text: "result_a"}}},
		{
			Role:    query.RoleAssistant,
			Content: []query.Content{{Type: query.ContentToolUse, Name: "tool_b"}},
			Metadata: map[string]any{
				"api_response_id": "resp-123",
				"usage": &query.Usage{
					InputTokens:  5000,
					OutputTokens: 300,
				},
			},
		},
		{Role: query.RoleUser, Content: []query.Content{{Type: query.ContentToolResult, Text: "result_b"}}},
	}
	// Rewind starting from index 3 (the last assistant message).
	result := mgr.siblingRewind(messages, 3)
	if result != 1 {
		t.Errorf("expected sibling rewind to index 1, got %d", result)
	}
}

func TestSiblingRewind_NoResponseID(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	messages := []query.Message{
		{Role: query.RoleUser, Content: []query.Content{{Type: query.ContentText, Text: "start"}}},
		{
			Role:    query.RoleAssistant,
			Content: []query.Content{{Type: query.ContentText, Text: "resp"}},
		},
	}
	result := mgr.siblingRewind(messages, 1)
	// No response ID, so no rewind.
	if result != 1 {
		t.Errorf("expected no rewind (index 1), got %d", result)
	}
}

func TestSiblingRewind_DifferentResponseIDs(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	messages := []query.Message{
		{
			Role:    query.RoleAssistant,
			Content: []query.Content{{Type: query.ContentText, Text: "first"}},
			Metadata: map[string]any{
				"api_response_id": "resp-aaa",
			},
		},
		{Role: query.RoleUser, Content: []query.Content{{Type: query.ContentText, Text: "user"}}},
		{
			Role:    query.RoleAssistant,
			Content: []query.Content{{Type: query.ContentText, Text: "second"}},
			Metadata: map[string]any{
				"api_response_id": "resp-bbb",
			},
		},
	}
	result := mgr.siblingRewind(messages, 2)
	// Different response IDs, so no rewind.
	if result != 2 {
		t.Errorf("expected no rewind (index 2), got %d", result)
	}
}

// --- GetTokenCountFromUsage ---

func TestGetTokenCountFromUsage_Nil(t *testing.T) {
	if GetTokenCountFromUsage(nil) != 0 {
		t.Error("expected 0 for nil usage")
	}
}

func TestGetTokenCountFromUsage_Complete(t *testing.T) {
	usage := &query.Usage{
		InputTokens:  1000,
		OutputTokens: 500,
		Cache:        query.CacheTokens{Read: 200, Written: 100},
	}
	result := GetTokenCountFromUsage(usage)
	if result != 1800 {
		t.Errorf("expected 1800, got %d", result)
	}
}

func TestGetTokenCountFromUsage_NoCaching(t *testing.T) {
	usage := &query.Usage{
		InputTokens:  1000,
		OutputTokens: 500,
	}
	result := GetTokenCountFromUsage(usage)
	if result != 1500 {
		t.Errorf("expected 1500, got %d", result)
	}
}

// --- GetBillingTokens ---

func TestGetBillingTokens_Nil(t *testing.T) {
	if GetBillingTokens(nil, nil) != 0 {
		t.Error("expected 0 for nil usage")
	}
}

func TestGetBillingTokens_SonnetPricing(t *testing.T) {
	usage := &query.Usage{
		InputTokens:  10000,
		OutputTokens: 5000,
		Cache:        query.CacheTokens{Read: 2000, Written: 1000},
	}
	pricing := &config.ModelConfig{
		InputPricePer1M:      3.0,
		OutputPricePer1M:     15.0,
		CacheReadPricePer1M:  0.3,
		CacheWritePricePer1M: 3.75,
	}
	result := GetBillingTokens(usage, pricing)
	// input:       10000 * 3.0  / 1M = 0.03
	// output:      5000  * 15.0 / 1M = 0.075
	// cache_read:  2000  * 0.3  / 1M = 0.0006
	// cache_write: 1000  * 3.75 / 1M = 0.00375
	expected := 0.03 + 0.075 + 0.0006 + 0.00375
	if abs(result-expected) > 0.0001 {
		t.Errorf("expected ~%.6f, got %.6f", expected, result)
	}
}

// --- GetFinalContextTokens ---

func TestGetFinalContextTokens_Nil(t *testing.T) {
	if GetFinalContextTokens(nil) != 0 {
		t.Error("expected 0 for nil usage")
	}
}

func TestGetFinalContextTokens_IgnoresCache(t *testing.T) {
	usage := &query.Usage{
		InputTokens:  1000,
		OutputTokens: 500,
		Cache:        query.CacheTokens{Read: 200, Written: 100},
	}
	result := GetFinalContextTokens(usage)
	// Only input + output are counted; cache tokens are ignored.
	if result != 1500 {
		t.Errorf("expected 1500 (ignoring cache), got %d", result)
	}
}

// --- EffectiveContextWindow ---

func TestEffectiveContextWindow_Sonnet(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	result := mgr.EffectiveContextWindow("claude-sonnet-4-6")
	// Sonnet: 200000 - min(16384, 20000) = 200000 - 16384 = 183616
	expected := 200000 - 16384
	if result != expected {
		t.Errorf("expected %d, got %d", expected, result)
	}
}

func TestEffectiveContextWindow_Opus(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	result := mgr.EffectiveContextWindow("claude-opus-4-6")
	// Opus: 200000 - min(32000, 20000) = 200000 - 20000 = 180000
	expected := 200000 - 20000
	if result != expected {
		t.Errorf("expected %d, got %d", expected, result)
	}
}

func TestEffectiveContextWindow_UnknownModel(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	result := mgr.EffectiveContextWindow("unknown-model")
	// Unknown model: 200000 - min(16384, 20000) = 200000 - 16384 = 183616
	expected := 200000 - 16384
	if result != expected {
		t.Errorf("expected %d for unknown model, got %d", expected, result)
	}
}

// --- AutoCompactThreshold ---

func TestAutoCompactThreshold_Sonnet(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	effective := mgr.EffectiveContextWindow("claude-sonnet-4-6")
	expected := effective - AutoCompactBufferTokens
	result := mgr.AutoCompactThreshold("claude-sonnet-4-6")
	if result != expected {
		t.Errorf("expected %d, got %d", expected, result)
	}
}

// --- EffectiveContextWindowWithThinking ---

func TestEffectiveContextWindowWithThinking_NoThinking(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	base := mgr.EffectiveContextWindow("claude-sonnet-4-6")
	result := mgr.EffectiveContextWindowWithThinking("claude-sonnet-4-6", 0)
	if result != base {
		t.Errorf("expected %d (no thinking), got %d", base, result)
	}
}

func TestEffectiveContextWindowWithThinking_WithBudget(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	base := mgr.EffectiveContextWindow("claude-sonnet-4-6")
	thinkingBudget := 10000
	result := mgr.EffectiveContextWindowWithThinking("claude-sonnet-4-6", thinkingBudget)
	expected := base - thinkingBudget
	if result != expected {
		t.Errorf("expected %d (with thinking), got %d", expected, result)
	}
}

func TestEffectiveContextWindowWithThinking_MinimumFloor(t *testing.T) {
	// Register a model with a very small window to exercise the minimum floor.
	registry := config.NewModelRegistry()
	registry.Register("tiny-model", &config.ModelConfig{
		ID:              "tiny-model",
		ContextWindow:   20000,
		MaxOutputTokens: 4096,
	})
	mgr := NewTokenBudgetManager(registry, nil)
	result := mgr.EffectiveContextWindowWithThinking("tiny-model", 20000)
	// The minimum-floor protection of 10000 should kick in.
	if result != 10000 {
		t.Errorf("expected minimum 10000, got %d", result)
	}
}

// --- CalculateWarningState ---

func TestCalculateWarningState_Normal(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	effective := mgr.EffectiveContextWindow("claude-sonnet-4-6")
	// Use 50% of the window.
	usage := effective / 2
	state := mgr.CalculateWarningState(usage, "claude-sonnet-4-6")
	if state.PercentUsed != 50 {
		t.Errorf("expected 50%% used, got %d%%", state.PercentUsed)
	}
	if state.PercentLeft != 50 {
		t.Errorf("expected 50%% left, got %d%%", state.PercentLeft)
	}
	if state.IsAboveWarningThreshold {
		t.Error("should not be above warning threshold at 50%")
	}
	if state.IsAboveErrorThreshold {
		t.Error("should not be above error threshold at 50%")
	}
	if state.IsAboveAutoCompactThreshold {
		t.Error("should not be above auto compact threshold at 50%")
	}
	if state.IsAtBlockingLimit {
		t.Error("should not be at blocking limit at 50%")
	}
}

func TestCalculateWarningState_AboveWarning(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	effective := mgr.EffectiveContextWindow("claude-sonnet-4-6")
	// Use everything except the last 10K.
	usage := effective - 10000
	state := mgr.CalculateWarningState(usage, "claude-sonnet-4-6")
	if !state.IsAboveWarningThreshold {
		t.Error("should be above warning threshold (only 10K left)")
	}
	if !state.IsAboveAutoCompactThreshold {
		t.Error("should be above auto compact threshold")
	}
}

func TestCalculateWarningState_AtBlockingLimit(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	effective := mgr.EffectiveContextWindow("claude-sonnet-4-6")
	state := mgr.CalculateWarningState(effective, "claude-sonnet-4-6")
	if !state.IsAtBlockingLimit {
		t.Error("should be at blocking limit when usage == effective window")
	}
	if state.PercentUsed != 100 {
		t.Errorf("expected 100%% used, got %d%%", state.PercentUsed)
	}
}

func TestCalculateWarningState_OverLimit(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	effective := mgr.EffectiveContextWindow("claude-sonnet-4-6")
	state := mgr.CalculateWarningState(effective+5000, "claude-sonnet-4-6")
	if !state.IsAtBlockingLimit {
		t.Error("should be at blocking limit when usage exceeds effective window")
	}
	if state.PercentUsed != 100 {
		t.Errorf("expected capped at 100%%, got %d%%", state.PercentUsed)
	}
}

// --- OnModelSwitch ---

func TestOnModelSwitch_NoOverflow(t *testing.T) {
	obs := &testObserver{}
	mgr := newTestBudgetManager(obs)
	// Low usage: switching models must not overflow.
	state := mgr.OnModelSwitch("claude-opus-4-6", "claude-sonnet-4-6", 50000)
	if state.IsAboveAutoCompactThreshold {
		t.Error("should not overflow with low usage")
	}
	// No overflow event should have been emitted.
	for _, evt := range obs.events {
		if evt.name == "token_budget_model_switch_overflow" {
			t.Error("should not emit overflow event")
		}
	}
}

func TestOnModelSwitch_WithOverflow(t *testing.T) {
	obs := &testObserver{}
	registry := config.NewModelRegistry()
	// Register a small-window model.
	// NOTE(review): the old model "claude-sonnet-4-6" is not registered in
	// this registry; presumably only the new model's limits matter here.
	// TODO confirm against OnModelSwitch.
	registry.Register("small-model", &config.ModelConfig{
		ID:              "small-model",
		ContextWindow:   50000,
		MaxOutputTokens: 4096,
	})
	mgr := NewTokenBudgetManager(registry, obs)
	// Current usage is 45000; switch to the small-window model.
	state := mgr.OnModelSwitch("claude-sonnet-4-6", "small-model", 45000)
	if !state.IsAboveAutoCompactThreshold {
		t.Error("should overflow when switching to small window model")
	}
	// An overflow event should have been emitted.
	found := false
	for _, evt := range obs.events {
		if evt.name == "token_budget_model_switch_overflow" {
			found = true
			if evt.data["old_model"] != "claude-sonnet-4-6" {
				t.Errorf("wrong old_model: %v", evt.data["old_model"])
			}
			if evt.data["new_model"] != "small-model" {
				t.Errorf("wrong new_model: %v", evt.data["new_model"])
			}
		}
	}
	if !found {
		t.Error("expected overflow event but none found")
	}
}

// --- Metadata helpers ---

func TestExtractUsageFromMetadata_MapFormat(t *testing.T) {
	msg := &query.Message{
		Metadata: map[string]any{
			"usage": map[string]any{
				"input_tokens":                float64(1000),
				"output_tokens":               float64(500),
				"cache_read_input_tokens":     float64(200),
				"cache_creation_input_tokens": float64(100),
			},
		},
	}
	usage := extractUsageFromMetadata(msg)
	if usage == nil {
		t.Fatal("expected non-nil usage")
	}
	if usage.InputTokens != 1000 {
		t.Errorf("expected 1000 input tokens, got %d", usage.InputTokens)
	}
	if usage.OutputTokens != 500 {
		t.Errorf("expected 500 output tokens, got %d", usage.OutputTokens)
	}
	if usage.Cache.Read != 200 {
		t.Errorf("expected 200 cache read tokens, got %d", usage.Cache.Read)
	}
	if usage.Cache.Written != 100 {
		t.Errorf("expected 100 cache creation tokens, got %d", usage.Cache.Written)
	}
}

func TestExtractUsageFromMetadata_NoMetadata(t *testing.T) {
	msg := &query.Message{}
	if extractUsageFromMetadata(msg) != nil {
		t.Error("expected nil for message without metadata")
	}
}

func TestExtractUsageFromMetadata_NoUsageKey(t *testing.T) {
	msg := &query.Message{
		Metadata: map[string]any{
			"other": "value",
		},
	}
	if extractUsageFromMetadata(msg) != nil {
		t.Error("expected nil for message without usage key")
	}
}

// --- Observer instrumentation ---

func TestEstimateCurrentUsage_EmitsEvent(t *testing.T) {
	obs := &testObserver{}
	mgr := newTestBudgetManager(obs)
	messages := []query.Message{
		{Role: query.RoleUser, Content: []query.Content{{Type: query.ContentText, Text: "hello"}}},
		{
			Role:    query.RoleAssistant,
			Content: []query.Content{{Type: query.ContentText, Text: "response"}},
			Metadata: map[string]any{
				"usage": &query.Usage{InputTokens: 100, OutputTokens: 50},
			},
		},
	}
	mgr.EstimateCurrentUsage(messages)
	found := false
	for _, evt := range obs.events {
		if evt.name == "token_budget_estimated" {
			found = true
			// Comma-ok assertion: a missing key or a non-int value is
			// reported as a test failure instead of panicking the run.
			got, ok := evt.data["anchor_tokens"].(int)
			if !ok || got != 150 {
				t.Errorf("expected anchor_tokens=150 (int), got %v (%T)",
					evt.data["anchor_tokens"], evt.data["anchor_tokens"])
			}
		}
	}
	if !found {
		t.Error("expected token_budget_estimated event")
	}
}

// --- AutoCompactThresholdWithThinking ---

func TestAutoCompactThresholdWithThinking(t *testing.T) {
	mgr := newTestBudgetManager(nil)
	withoutThinking := mgr.AutoCompactThreshold("claude-sonnet-4-6")
	withThinking := mgr.AutoCompactThresholdWithThinking("claude-sonnet-4-6", 10000)
	if withThinking >= withoutThinking {
		t.Errorf("thinking budget should reduce threshold: %d >= %d", withThinking, withoutThinking)
	}
	expected := withoutThinking - 10000
	if withThinking != expected {
		t.Errorf("expected %d, got %d", expected, withThinking)
	}
}

// --- IsAtBlockingLimit & PickWarningCode (dead-field wire) ---

// TestCalculateWarningState_BlockingImpliesCriticalAndWarning locks the
// implication chain documented on IsAtBlockingLimit's godoc: when usage
// reaches 100% of the effective window, blocking MUST also set
// IsAboveErrorThreshold and IsAboveWarningThreshold true. This chain is
// what justifies PickWarningCode's severity-first switch (blocking case
// short-circuits both weaker branches).
func TestCalculateWarningState_BlockingImpliesCriticalAndWarning(t *testing.T) { mgr := newTestBudgetManager(nil) effective := mgr.EffectiveContextWindow("claude-sonnet-4-6") state := mgr.CalculateWarningState(effective, "claude-sonnet-4-6") if !state.IsAtBlockingLimit { t.Fatal("usage == effective 应 IsAtBlockingLimit=true") } if !state.IsAboveErrorThreshold { t.Error("blocking 应同时 IsAboveErrorThreshold=true (implication chain 断裂)") } if !state.IsAboveWarningThreshold { t.Error("blocking 应同时 IsAboveWarningThreshold=true (implication chain 断裂)") } } // TestPickWarningCode_SeverityOrder exercises all four state combinations // against PickWarningCode: nothing-set → "", warning-only → // context_window_warning, up-to-critical → context_window_critical, // up-to-blocking → context_window_blocked. The last case is the regression // test for the original bug: an else-if chain from weakest threshold first // swallowed blocking into the critical branch, leaving IsAtBlockingLimit // unread. // // TestPickWarningCode_SeverityOrder 对 PickWarningCode 四种状态组合各测一次: // 全 false → 空串; 仅 warning → context_window_warning; warning+critical → // context_window_critical; 全 true → context_window_blocked. 最后一 case 正是 // 原 bug 的回归测试: 原 else-if 链从最弱阈值起判把 blocking 吞进 critical, // 令 IsAtBlockingLimit 从未被读. 
func TestPickWarningCode_SeverityOrder(t *testing.T) { cases := []struct { name string state *TokenWarningState want string }{ {"nil", nil, ""}, {"empty", &TokenWarningState{}, ""}, {"warning_only", &TokenWarningState{IsAboveWarningThreshold: true}, "context_window_warning"}, {"critical", &TokenWarningState{IsAboveWarningThreshold: true, IsAboveErrorThreshold: true}, "context_window_critical"}, {"blocking", &TokenWarningState{IsAboveWarningThreshold: true, IsAboveErrorThreshold: true, IsAtBlockingLimit: true}, "context_window_blocked"}, } for _, tc := range cases { got := PickWarningCode(tc.state) if got != tc.want { t.Errorf("%s: PickWarningCode = %q, want %q", tc.name, got, tc.want) } } } // --- 辅助 --- func abs(x float64) float64 { if x < 0 { return -x } return x }