package api import ( "net/http" "testing" ) // ============================================================ // DefaultClassifier 测试 // ============================================================ func TestDefaultClassifier_StatusCodes(t *testing.T) { c := &DefaultClassifier{} tests := []struct { name string status int body string wantCat ErrorCategory wantRetry bool }{ {"400 generic", 400, `{"error":{"type":"invalid_request_error","message":"bad request"}}`, ErrInvalidRequest, false}, {"401 auth", 401, `{"error":{"type":"auth","message":"invalid key"}}`, ErrAuthentication, false}, {"403 forbidden", 403, `{"error":{"type":"auth","message":"forbidden"}}`, ErrAuthentication, false}, {"404 not found", 404, `{"error":{"type":"not_found","message":"model not found"}}`, ErrModelNotFound, false}, {"408 timeout", 408, ``, ErrTimeout, true}, {"413 too large", 413, `body too large`, ErrRequestTooLarge, false}, {"429 rate limit", 429, `{"error":{"type":"rate_limit","message":"too many"}}`, ErrRateLimit, true}, {"529 overloaded", 529, `{"error":{"type":"overloaded","message":"busy"}}`, ErrOverloaded, true}, {"500 server", 500, `Internal Server Error`, ErrServerError, true}, {"502 gateway", 502, `Bad Gateway`, ErrServerError, true}, {"200 unknown", 200, ``, ErrUnknown, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := c.Classify(tt.status, nil, []byte(tt.body), nil) if result.ErrCategory != tt.wantCat { t.Errorf("Category = %v, want %v", result.ErrCategory, tt.wantCat) } if result.IsRetryable() != tt.wantRetry { t.Errorf("IsRetryable() = %v, want %v", result.IsRetryable(), tt.wantRetry) } }) } } func TestDefaultClassifier_BadRequest_SubTypes(t *testing.T) { c := &DefaultClassifier{} tests := []struct { name string body string wantCat ErrorCategory }{ { "prompt too long", `{"error":{"type":"invalid_request_error","message":"prompt is too long: 137500 tokens > 135000 maximum"}}`, ErrPromptTooLong, }, { "prompt too long uppercase", `{"error":{"type":"invalid_request_error","message":"Prompt is too long: 200000 tokens > 128000"}}`, ErrPromptTooLong, }, { "image too large", `{"error":{"type":"invalid_request_error","message":"image exceeds 5 MB maximum: 6000000 bytes"}}`, ErrMediaTooLarge, }, { "many image dimensions", `{"error":{"type":"invalid_request_error","message":"image dimensions exceed limit for many-image requests"}}`, ErrMediaTooLarge, }, { "PDF page limit", `{"error":{"type":"invalid_request_error","message":"maximum of 100 PDF pages exceeded"}}`, ErrMediaTooLarge, }, { "tool_use mismatch", `{"error":{"type":"invalid_request_error","message":"tool_use` + "`" + ` ids were found without ` + "`" + `tool_result` + "`" + ` blocks immediately after"}}`, ErrToolMismatch, }, { "unexpected tool_result", `{"error":{"type":"invalid_request_error","message":"unexpected ` + "`" + `tool_use_id` + "`" + ` found in ` + "`" + `tool_result` + "`" + `"}}`, ErrUnexpectedTool, }, { "duplicate tool_use ID", `{"error":{"type":"invalid_request_error","message":"` + "`" + `tool_use` + "`" + ` ids must be unique"}}`, ErrDuplicateToolID, }, { "invalid model", `{"error":{"type":"invalid_request_error","message":"Invalid model name: foo-bar"}}`, ErrInvalidModel, }, { "credit balance", `{"error":{"type":"billing_error","message":"Your credit balance is too low"}}`, ErrBilling, }, { "generic 400", `{"error":{"type":"invalid_request_error","message":"something else"}}`, ErrInvalidRequest, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := c.Classify(400, nil, []byte(tt.body), nil) if result.ErrCategory != tt.wantCat { t.Errorf("Category = %v, want %v", result.ErrCategory, tt.wantCat) } }) } } func TestDefaultClassifier_ConnectionError(t *testing.T) { c := &DefaultClassifier{} // 超时错误 result := c.Classify(0, nil, nil, &mockTimeoutError{msg: "dial tcp: i/o timeout"}) if result.ErrCategory != ErrTimeout { t.Errorf("timeout error: Category = %v, want ErrTimeout", result.ErrCategory) } if !result.IsRetryable() { t.Error("timeout should be retryable") } } // ============================================================ // AnthropicClassifier 测试 // ============================================================ func TestAnthropicClassifier_OverloadedDetection(t *testing.T) { c := &AnthropicClassifier{} // SDK 有时丢失 529 状态码,但消息包含 overloaded_error body := []byte(`{"type":"overloaded_error","message":"Overloaded"}`) result := c.Classify(400, nil, body, nil) // 应该被重新分类为 ErrOverloaded(而非 ErrInvalidRequest) if result.ErrCategory != ErrOverloaded { t.Errorf("overloaded detection: Category = %v, want ErrOverloaded", result.ErrCategory) } if !result.IsRetryable() { t.Error("overloaded should be retryable") } } func TestAnthropicClassifier_ShouldRetryHeader(t *testing.T) { c := &AnthropicClassifier{} // x-should-retry: true headers := http.Header{} headers.Set("x-should-retry", "true") result := c.Classify(500, headers, []byte(`server error`), nil) if result.Retry == nil || result.Retry.ServerSaid == nil || !*result.Retry.ServerSaid { t.Error("should parse x-should-retry: true") } // x-should-retry: false headers2 := http.Header{} headers2.Set("x-should-retry", "false") result2 := c.Classify(429, headers2, []byte(`rate limited`), nil) if result2.Retry == nil || result2.Retry.ServerSaid == nil || *result2.Retry.ServerSaid { t.Error("should parse x-should-retry: false") } if result2.IsRetryable() { t.Error("should NOT be retryable when server says false") } } func TestAnthropicClassifier_TokenGapExtraction(t *testing.T) { c := &AnthropicClassifier{} body := []byte(`{"error":{"type":"invalid_request_error","message":"prompt is too long: 137500 tokens > 135000 maximum"}}`) result := c.Classify(400, nil, body, nil) if result.ErrCategory != ErrPromptTooLong { t.Errorf("Category = %v, want ErrPromptTooLong", result.ErrCategory) } if result.TokenGap != 2500 { t.Errorf("TokenGap = %d, want 2500", result.TokenGap) } } func TestAnthropicClassifier_RetryAfterHeader(t *testing.T) { c := &AnthropicClassifier{} headers := http.Header{} headers.Set("Retry-After", "30") result := c.Classify(429, headers, []byte(`rate limited`), nil) if result.Retry == nil { t.Fatal("Retry should not be nil for 429") } if int(result.Retry.After.Seconds()) != 30 { t.Errorf("Retry.After = %v, want 30s", result.Retry.After) } } func TestAnthropicClassifier_DiagnosticHint(t *testing.T) { c := &AnthropicClassifier{Hinter: &DefaultHinter{}} // SSL 错误应该有提示 result := c.Classify(0, nil, nil, &mockError{msg: "tls: failed to verify certificate"}) if result.ErrCategory != ErrSSL { t.Errorf("Category = %v, want ErrSSL", result.ErrCategory) } if result.Hint == "" { t.Error("SSL error should have diagnostic hint") } } // ============================================================ // CompositeClassifier 测试 // ============================================================ func TestCompositeClassifier_PriorityOrder(t *testing.T) { // 第一个分类器返回 ErrUnknown,第二个返回 ErrRateLimit first := &stubClassifier{category: ErrUnknown} second := &stubClassifier{category: ErrRateLimit} comp := NewCompositeClassifier(first, second) result := comp.Classify(429, nil, nil, nil) if result.ErrCategory != ErrRateLimit { t.Errorf("Category = %v, want ErrRateLimit (from second classifier)", result.ErrCategory) } } func TestCompositeClassifier_FirstWins(t *testing.T) { first := &stubClassifier{category: ErrOverloaded} second := &stubClassifier{category: ErrRateLimit} comp := NewCompositeClassifier(first, second) result := comp.Classify(529, nil, nil, nil) if result.ErrCategory != ErrOverloaded { t.Errorf("Category = %v, want ErrOverloaded (from first classifier)", result.ErrCategory) } } func TestCompositeClassifier_Add(t *testing.T) { comp := NewCompositeClassifier() comp.Add(&stubClassifier{category: ErrBilling}) result := comp.Classify(400, nil, nil, nil) if result.ErrCategory != ErrBilling { t.Errorf("Category = %v, want ErrBilling", result.ErrCategory) } } func TestCompositeClassifier_Fallback(t *testing.T) { // 全部返回 ErrUnknown comp := NewCompositeClassifier(&stubClassifier{category: ErrUnknown}) result := comp.Classify(999, nil, []byte("weird"), nil) if result.ErrCategory != ErrUnknown { t.Errorf("Category = %v, want ErrUnknown", result.ErrCategory) } } // ============================================================ // 测试辅助 // ============================================================ type stubClassifier struct { category ErrorCategory } func (s *stubClassifier) Classify(statusCode int, headers http.Header, body []byte, cause error) *APIError { return &APIError{ ErrCategory: s.category, StatusCode: statusCode, Msg: string(body), } } type mockTimeoutError struct { msg string } func (e *mockTimeoutError) Error() string { return e.msg } func (e *mockTimeoutError) Timeout() bool { return true } func (e *mockTimeoutError) Temporary() bool { return true } type mockError struct { msg string } func (e *mockError) Error() string { return e.msg }