package engine

import "testing"

// --- stop_reason fallback tests ---

// TestIsMaxOutputTokensError checks that isMaxOutputTokensError reports true
// only for the "max_tokens" stop reason and false for "tool_use", "end_turn"
// and the empty string.
func TestIsMaxOutputTokensError(t *testing.T) {
	cases := []struct {
		stopReason string
		want       bool
	}{
		{stopReason: "max_tokens", want: true},
		{stopReason: "tool_use", want: false},
		{stopReason: "end_turn", want: false},
		{stopReason: "", want: false},
	}
	for _, c := range cases {
		if got := isMaxOutputTokensError(c.stopReason); got != c.want {
			t.Errorf("isMaxOutputTokensError(%q) = %v, want %v", c.stopReason, got, c.want)
		}
	}
}

// TestExtractToolCalls_WithBlocks checks that extractToolCalls returns one call
// per "tool_use" block, each carrying that block's id and name, and ignores
// the "text" block.
func TestExtractToolCalls_WithBlocks(t *testing.T) {
	blocks := map[int]*blockState{
		0: {blockType: "text", text: "hello"},
		1: {blockType: "tool_use", id: "tool_1", name: "bash", partialJSON: `{"command":"ls"}`},
		2: {blockType: "tool_use", id: "tool_2", name: "grep", partialJSON: `{"pattern":"foo"}`},
	}

	calls := extractToolCalls(blocks)
	if len(calls) != 2 {
		t.Fatalf("expected 2 tool calls, got %d", len(calls))
	}

	// Index the results by ID so the assertions below do not depend on the
	// order in which extractToolCalls emits the calls.
	nameByID := make(map[string]string, len(calls))
	for _, call := range calls {
		nameByID[call.ID] = call.Name
	}
	if nameByID["tool_1"] != "bash" {
		t.Error("tool_1 should be 'bash'")
	}
	if nameByID["tool_2"] != "grep" {
		t.Error("tool_2 should be 'grep'")
	}
}

// TestExtractToolCalls_Empty checks that blocks of type "text" and "thinking"
// produce no tool calls.
func TestExtractToolCalls_Empty(t *testing.T) {
	blocks := map[int]*blockState{
		0: {blockType: "text", text: "just text"},
		1: {blockType: "thinking", text: "thinking..."},
	}
	if n := len(extractToolCalls(blocks)); n != 0 {
		t.Errorf("expected 0 tool calls, got %d", n)
	}
}

// TestExtractToolCalls_EmptyJSON checks that a "tool_use" block whose
// partialJSON is empty still yields exactly one call, with Input equal to
// the empty JSON object "{}".
func TestExtractToolCalls_EmptyJSON(t *testing.T) {
	blocks := map[int]*blockState{
		0: {blockType: "tool_use", id: "tool_1", name: "bash", partialJSON: ""},
	}

	calls := extractToolCalls(blocks)
	if len(calls) != 1 {
		t.Fatalf("expected 1 tool call, got %d", len(calls))
	}
	if input := string(calls[0].Input); input != "{}" {
		t.Errorf("empty partialJSON should produce {}, got %s", input)
	}
}

// TestBuildAssistantBlocks checks that buildAssistantBlocks turns a text block
// and a tool_use block into two assistant content blocks. The test expects
// them in ascending block-index order: the text block first, then the
// tool_use block.
func TestBuildAssistantBlocks(t *testing.T) {
	blocks := map[int]*blockState{
		0: {blockType: "text", text: "hello"},
		1: {blockType: "tool_use", id: "t1", name: "bash", partialJSON: `{"cmd":"ls"}`},
	}

	out := buildAssistantBlocks(blocks)
	if len(out) != 2 {
		t.Fatalf("expected 2 blocks, got %d", len(out))
	}
	if out[0].Type != "text" || out[0].Text != "hello" {
		t.Error("first block should be text")
	}
	if out[1].Type != "tool_use" || out[1].ID != "t1" {
		t.Error("second block should be tool_use")
	}
}

// --- max_tokens escalation logic tests ---

// TestMaxTokensConstants sanity-checks the relationship between the default
// and the escalated max_tokens values. The default must be below the
// escalated value and at least 1024, and the escalated value must be at
// most 200000.
//
// NOTE(review): the two values are local literals rather than the package's
// real constants, so this test cannot detect a change to the production
// values. Consider referencing the actual constants once their names are
// confirmed.
func TestMaxTokensConstants(t *testing.T) {
	const (
		defaultMax   = 8192
		escalatedMax = 64000
	)

	if defaultMax >= escalatedMax {
		t.Error("default max_tokens should be less than escalated")
	}
	if defaultMax < 1024 {
		t.Error("default max_tokens too small")
	}
	if escalatedMax > 200000 {
		t.Error("escalated max_tokens too large")
	}
}

// --- isRetryableError backward-compatibility tests ---

// TestIsRetryableError_BackwardCompat exercises the legacy isRetryableError
// entry point, which takes no source argument. It is expected to behave like
// the foreground (MainThread) source: HTTP 429 and 529 are retryable, and
// HTTP 500 is not.
func TestIsRetryableError_BackwardCompat(t *testing.T) {
	checks := []struct {
		errStr string
		want   bool
		msg    string
	}{
		{"api: HTTP 429: rate limited", true, "429 should be retryable"},
		{"api: HTTP 529: overloaded", true, "529 should be retryable for default (MainThread) source"},
		{"api: HTTP 500: error", false, "500 should NOT be retryable"},
	}
	for _, c := range checks {
		if isRetryableError(c.errStr) != c.want {
			t.Error(c.msg)
		}
	}
}

// --- isPartialStream empty-response fallback tests ---

// TestIsPartialStream verifies that isPartialStream is true only when all
// three conditions hold together:
//   - a message start was received,
//   - no content block arrived,
//   - stop_reason is empty.
//
// It must be false whenever any one of those conditions is relaxed.
func TestIsPartialStream(t *testing.T) { tests := []struct { name string receivedMessageStart bool hasAnyContentBlock bool stopReason string want bool }{ { name: "典型 proxy 截断:有消息头、无内容、无 stop_reason", receivedMessageStart: true, hasAnyContentBlock: false, stopReason: "", want: true, }, { name: "正常回复:有消息头、有内容、有 stop_reason", receivedMessageStart: true, hasAnyContentBlock: true, stopReason: "end_turn", want: false, }, { name: "有内容时不算 partial-stream(即使无 stop_reason)", receivedMessageStart: true, hasAnyContentBlock: true, stopReason: "", want: false, }, { name: "无消息头:HTTP 层面失败,不是 partial-stream", receivedMessageStart: false, hasAnyContentBlock: false, stopReason: "", want: false, }, { name: "有 stop_reason 时不算截断(极少见的合法空响应)", receivedMessageStart: true, hasAnyContentBlock: false, stopReason: "end_turn", want: false, }, { name: "tool_use stop_reason 有内容:正常工具调用", receivedMessageStart: true, hasAnyContentBlock: true, stopReason: "tool_use", want: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := isPartialStream(tt.receivedMessageStart, tt.hasAnyContentBlock, tt.stopReason) if got != tt.want { t.Errorf("isPartialStream(%v, %v, %q) = %v, want %v", tt.receivedMessageStart, tt.hasAnyContentBlock, tt.stopReason, got, tt.want) } }) } } // TestErrStreamTruncated_Registered 验证新错误码在各注册表中都存在. 
func TestErrStreamTruncated_Registered(t *testing.T) { // 确认错误码有默认 suggestion if _, ok := defaultSuggestions[ErrStreamTruncated]; !ok { t.Error("ErrStreamTruncated 缺少 defaultSuggestions 条目") } // 确认错误码有 retryable 标记 if _, ok := defaultRetryable[ErrStreamTruncated]; !ok { t.Error("ErrStreamTruncated 缺少 defaultRetryable 条目") } // partial-stream 是瞬态故障,应该标记为可重试 if !defaultRetryable[ErrStreamTruncated] { t.Error("ErrStreamTruncated 应为可重试(瞬态代理故障)") } // 验证 NewEngineError 能正确构建该错误 err := NewEngineError(ErrStreamTruncated, "test truncation", nil) if err.Code != ErrStreamTruncated { t.Errorf("unexpected code: %s", err.Code) } if !err.Retryable { t.Error("constructed error should be retryable") } } // --- TestClassifyAPIError_Integration 测试错误分类 --- // TestClassifyAPIError_Integration 测试错误分类 func TestClassifyAPIError_Integration(t *testing.T) { tests := []struct { errStr string want ErrorCode }{ {"api: HTTP 401: unauthorized", ErrAPIAuth}, {"api: HTTP 429: rate limited", ErrAPIRateLimit}, {"api: HTTP 529: overloaded", ErrAPIOverloaded}, {"api: HTTP 400: bad request", ErrAPIBadRequest}, {"api: HTTP 500: internal error", ErrInternal}, } for _, tt := range tests { got := ClassifyAPIError(tt.errStr) if got != tt.want { t.Errorf("ClassifyAPIError(%q) = %v, want %v", tt.errStr, got, tt.want) } } }