package api import ( "errors" "fmt" "net/http" "testing" ) // ============================================================ // ErrorCategory 测试 // ============================================================ func TestErrorCategory_String(t *testing.T) { tests := []struct { cat ErrorCategory want string }{ {ErrUnknown, "unknown"}, {ErrAborted, "aborted"}, {ErrTimeout, "api_timeout"}, {ErrRateLimit, "rate_limit"}, {ErrOverloaded, "server_overload"}, {ErrPromptTooLong, "prompt_too_long"}, {ErrMediaTooLarge, "media_too_large"}, {ErrRequestTooLarge, "request_too_large"}, {ErrInvalidRequest, "invalid_request"}, {ErrAuthentication, "auth_error"}, {ErrModelNotFound, "model_not_found"}, {ErrBilling, "billing_error"}, {ErrServerError, "server_error"}, {ErrConnection, "connection_error"}, {ErrSSL, "ssl_cert_error"}, {ErrToolMismatch, "tool_use_mismatch"}, {ErrUnexpectedTool, "unexpected_tool_result"}, {ErrDuplicateToolID, "duplicate_tool_use_id"}, {ErrInvalidModel, "invalid_model"}, {ErrContentPolicy, "content_policy"}, } for _, tt := range tests { if got := tt.cat.String(); got != tt.want { t.Errorf("ErrorCategory(%d).String() = %q, want %q", tt.cat, got, tt.want) } } } func TestErrorCategory_IsRetryableByDefault(t *testing.T) { retryable := []ErrorCategory{ErrTimeout, ErrOverloaded, ErrServerError, ErrConnection} notRetryable := []ErrorCategory{ ErrUnknown, ErrAborted, ErrRateLimit, ErrPromptTooLong, ErrMediaTooLarge, ErrInvalidRequest, ErrAuthentication, ErrModelNotFound, ErrBilling, ErrSSL, ErrToolMismatch, } for _, cat := range retryable { if !cat.IsRetryableByDefault() { t.Errorf("%s should be retryable by default", cat) } } for _, cat := range notRetryable { if cat.IsRetryableByDefault() { t.Errorf("%s should NOT be retryable by default", cat) } } } func TestErrorCategory_Unknown(t *testing.T) { // 超出范围的值应该返回 "unknown" cat := ErrorCategory(999) if got := cat.String(); got != "unknown" { t.Errorf("out-of-range category.String() = %q, want %q", got, "unknown") } } // ============================================================ // APIError 测试 // ============================================================ func TestAPIError_Error(t *testing.T) { // 有状态码 e1 := &APIError{ErrCategory: ErrRateLimit, StatusCode: 429, Msg: "too many requests"} got1 := e1.Error() if got1 != "api: HTTP 429 [rate_limit]: too many requests" { t.Errorf("Error() = %q", got1) } // 无状态码(连接错误) e2 := &APIError{ErrCategory: ErrConnection, Msg: "connection refused"} got2 := e2.Error() if got2 != "api: [connection_error]: connection refused" { t.Errorf("Error() = %q", got2) } } func TestAPIError_Unwrap(t *testing.T) { cause := fmt.Errorf("underlying cause") e := &APIError{ErrCategory: ErrConnection, Cause: cause} if !errors.Is(e, cause) { t.Error("errors.Is should find the cause") } } func TestAPIError_ErrorsAs(t *testing.T) { // 验证 errors.As 可以从包装错误中提取 *APIError inner := &APIError{ErrCategory: ErrRateLimit, StatusCode: 429, Msg: "rate limited"} wrapped := fmt.Errorf("wrapped: %w", inner) var apiErr *APIError if !errors.As(wrapped, &apiErr) { t.Fatal("errors.As should extract *APIError") } if apiErr.ErrCategory != ErrRateLimit { t.Errorf("Category = %v, want ErrRateLimit", apiErr.ErrCategory) } } func TestAPIError_IsRetryable(t *testing.T) { // RetryInfo 优先 e1 := &APIError{ ErrCategory: ErrRateLimit, // 默认不可重试 Retry: &RetryInfo{Retryable: true}, } if !e1.IsRetryable() { t.Error("should be retryable when RetryInfo says so") } // Retry 覆盖默认 e2 := &APIError{ ErrCategory: ErrTimeout, // 默认可重试 Retry: &RetryInfo{Retryable: false}, } if e2.IsRetryable() { t.Error("should NOT be retryable when RetryInfo says no") } // 无 Retry 用默认 e3 := &APIError{ErrCategory: ErrServerError} if !e3.IsRetryable() { t.Error("ServerError should be retryable by default") } } func TestAPIError_AnalyticsTag(t *testing.T) { e := &APIError{ErrCategory: ErrOverloaded} if got := e.AnalyticsTag(); got != "server_overload" { t.Errorf("AnalyticsTag() = %q, want %q", got, "server_overload") } } // ============================================================ // ParseTokenGap 测试 // ============================================================ func TestParseTokenGap(t *testing.T) { tests := []struct { msg string want int }{ // 标准格式 {"prompt is too long: 137500 tokens > 135000 maximum", 2500}, // 大写(Vertex) {"Prompt is too long: 200000 tokens > 128000 maximum", 72000}, // 无匹配 {"some other error", 0}, // actual < limit(不应该发生,但要容错) {"prompt is too long: 100 tokens > 200 maximum", 0}, // 带前缀的消息 {"400 Bad Request prompt is too long: 10000 tokens > 8000 maximum", 2000}, // 单数 token {"prompt is too long: 1 token > 0 maximum", 1}, } for _, tt := range tests { got := ParseTokenGap(tt.msg) if got != tt.want { t.Errorf("ParseTokenGap(%q) = %d, want %d", tt.msg, got, tt.want) } } } // ============================================================ // SanitizeErrorHTML 测试 // ============================================================ func TestSanitizeErrorHTML(t *testing.T) { tests := []struct { name string input string want string }{ { "plain text passthrough", "some error message", "some error message", }, { "CloudFlare HTML with title", `