package api import ( "errors" "fmt" "net/http" "testing" ) // ============================================================ // ErrorCategory 测试 // ============================================================ func TestErrorCategory_String(t *testing.T) { tests := []struct { cat ErrorCategory want string }{ {ErrUnknown, "unknown"}, {ErrAborted, "aborted"}, {ErrTimeout, "api_timeout"}, {ErrRateLimit, "rate_limit"}, {ErrOverloaded, "server_overload"}, {ErrPromptTooLong, "prompt_too_long"}, {ErrMediaTooLarge, "media_too_large"}, {ErrRequestTooLarge, "request_too_large"}, {ErrInvalidRequest, "invalid_request"}, {ErrAuthentication, "auth_error"}, {ErrModelNotFound, "model_not_found"}, {ErrBilling, "billing_error"}, {ErrServerError, "server_error"}, {ErrConnection, "connection_error"}, {ErrSSL, "ssl_cert_error"}, {ErrToolMismatch, "tool_use_mismatch"}, {ErrUnexpectedTool, "unexpected_tool_result"}, {ErrDuplicateToolID, "duplicate_tool_use_id"}, {ErrInvalidModel, "invalid_model"}, {ErrContentPolicy, "content_policy"}, } for _, tt := range tests { if got := tt.cat.String(); got != tt.want { t.Errorf("ErrorCategory(%d).String() = %q, want %q", tt.cat, got, tt.want) } } } func TestErrorCategory_IsRetryableByDefault(t *testing.T) { retryable := []ErrorCategory{ErrTimeout, ErrOverloaded, ErrServerError, ErrConnection} notRetryable := []ErrorCategory{ ErrUnknown, ErrAborted, ErrRateLimit, ErrPromptTooLong, ErrMediaTooLarge, ErrInvalidRequest, ErrAuthentication, ErrModelNotFound, ErrBilling, ErrSSL, ErrToolMismatch, } for _, cat := range retryable { if !cat.IsRetryableByDefault() { t.Errorf("%s should be retryable by default", cat) } } for _, cat := range notRetryable { if cat.IsRetryableByDefault() { t.Errorf("%s should NOT be retryable by default", cat) } } } func TestErrorCategory_Unknown(t *testing.T) { // 超出范围的值应该返回 "unknown" cat := ErrorCategory(999) if got := cat.String(); got != "unknown" { t.Errorf("out-of-range category.String() = %q, want %q", got, "unknown") } } // ============================================================ // APIError 测试 // ============================================================ func TestAPIError_Error(t *testing.T) { // 有状态码 e1 := &APIError{ErrCategory: ErrRateLimit, StatusCode: 429, Msg: "too many requests"} got1 := e1.Error() if got1 != "api: HTTP 429 [rate_limit]: too many requests" { t.Errorf("Error() = %q", got1) } // 无状态码(连接错误) e2 := &APIError{ErrCategory: ErrConnection, Msg: "connection refused"} got2 := e2.Error() if got2 != "api: [connection_error]: connection refused" { t.Errorf("Error() = %q", got2) } } func TestAPIError_Unwrap(t *testing.T) { cause := fmt.Errorf("underlying cause") e := &APIError{ErrCategory: ErrConnection, Cause: cause} if !errors.Is(e, cause) { t.Error("errors.Is should find the cause") } } func TestAPIError_ErrorsAs(t *testing.T) { // 验证 errors.As 可以从包装错误中提取 *APIError inner := &APIError{ErrCategory: ErrRateLimit, StatusCode: 429, Msg: "rate limited"} wrapped := fmt.Errorf("wrapped: %w", inner) var apiErr *APIError if !errors.As(wrapped, &apiErr) { t.Fatal("errors.As should extract *APIError") } if apiErr.ErrCategory != ErrRateLimit { t.Errorf("Category = %v, want ErrRateLimit", apiErr.ErrCategory) } } func TestAPIError_IsRetryable(t *testing.T) { // RetryInfo 优先 e1 := &APIError{ ErrCategory: ErrRateLimit, // 默认不可重试 Retry: &RetryInfo{Retryable: true}, } if !e1.IsRetryable() { t.Error("should be retryable when RetryInfo says so") } // Retry 覆盖默认 e2 := &APIError{ ErrCategory: ErrTimeout, // 默认可重试 Retry: &RetryInfo{Retryable: false}, } if e2.IsRetryable() { t.Error("should NOT be retryable when RetryInfo says no") } // 无 Retry 用默认 e3 := &APIError{ErrCategory: ErrServerError} if !e3.IsRetryable() { t.Error("ServerError should be retryable by default") } } func TestAPIError_AnalyticsTag(t *testing.T) { e := &APIError{ErrCategory: ErrOverloaded} if got := e.AnalyticsTag(); got != "server_overload" { t.Errorf("AnalyticsTag() = %q, want %q", got, "server_overload") } } // ============================================================ // ParseTokenGap 测试 // ============================================================ func TestParseTokenGap(t *testing.T) { tests := []struct { msg string want int }{ // 标准格式 {"prompt is too long: 137500 tokens > 135000 maximum", 2500}, // 大写(Vertex) {"Prompt is too long: 200000 tokens > 128000 maximum", 72000}, // 无匹配 {"some other error", 0}, // actual < limit(不应该发生,但要容错) {"prompt is too long: 100 tokens > 200 maximum", 0}, // 带前缀的消息 {"400 Bad Request prompt is too long: 10000 tokens > 8000 maximum", 2000}, // 单数 token {"prompt is too long: 1 token > 0 maximum", 1}, } for _, tt := range tests { got := ParseTokenGap(tt.msg) if got != tt.want { t.Errorf("ParseTokenGap(%q) = %d, want %d", tt.msg, got, tt.want) } } } // ============================================================ // SanitizeErrorHTML 测试 // ============================================================ func TestSanitizeErrorHTML(t *testing.T) { tests := []struct { name string input string want string }{ { "plain text passthrough", "some error message", "some error message", }, { "CloudFlare HTML with title", `502 Bad Gateway

502

`, "502 Bad Gateway", }, { "HTML without title", `Error`, "", }, { "empty string", "", "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := SanitizeErrorHTML(tt.input) if got != tt.want { t.Errorf("SanitizeErrorHTML() = %q, want %q", got, tt.want) } }) } } // ============================================================ // ParseAPIErrorBody 测试 // ============================================================ func TestParseAPIErrorBody(t *testing.T) { tests := []struct { name string body string wantType string wantMsg string }{ { "standard API error", `{"type":"error","error":{"type":"invalid_request_error","message":"prompt is too long"}}`, "invalid_request_error", "prompt is too long", }, { "flat error", `{"type":"overloaded_error","message":"server busy"}`, "overloaded_error", "server busy", }, { "empty body", "", "", "", }, { "HTML body", `503 Service Unavailable`, "", "503 Service Unavailable", }, { "plain text body", "Internal Server Error", "", "Internal Server Error", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { gotType, gotMsg := ParseAPIErrorBody([]byte(tt.body)) if gotType != tt.wantType { t.Errorf("type = %q, want %q", gotType, tt.wantType) } if gotMsg != tt.wantMsg { t.Errorf("message = %q, want %q", gotMsg, tt.wantMsg) } }) } } // ============================================================ // ParseRetryAfter / ParseShouldRetry 测试 // ============================================================ func TestParseRetryAfter(t *testing.T) { tests := []struct { value string want int // seconds }{ {"30", 30}, {"0", 0}, {"-1", 0}, {"", 0}, {"abc", 0}, {" 5 ", 5}, } for _, tt := range tests { got := ParseRetryAfter(tt.value) wantDuration := 0 if tt.want > 0 { wantDuration = tt.want } if int(got.Seconds()) != wantDuration { t.Errorf("ParseRetryAfter(%q) = %v, want %ds", tt.value, got, wantDuration) } } } func TestParseShouldRetry(t *testing.T) { // true v := ParseShouldRetry("true") if v == nil || !*v { t.Error("ParseShouldRetry('true') should return *true") } // false v = ParseShouldRetry("false") if v == nil || *v { t.Error("ParseShouldRetry('false') should return *false") } // empty = nil v = ParseShouldRetry("") if v != nil { t.Error("ParseShouldRetry('') should return nil") } // unknown = nil v = ParseShouldRetry("maybe") if v != nil { t.Error("ParseShouldRetry('maybe') should return nil") } // case insensitive v = ParseShouldRetry("True") if v == nil || !*v { t.Error("ParseShouldRetry('True') should return *true") } } // ============================================================ // APIError 作为 http.Header 的载体 // ============================================================ func TestAPIError_HeadersPreserved(t *testing.T) { h := http.Header{} h.Set("X-Request-Id", "abc123") h.Set("Retry-After", "30") e := &APIError{ ErrCategory: ErrRateLimit, StatusCode: 429, RespHeaders: h, } if got := e.RespHeaders.Get("X-Request-Id"); got != "abc123" { t.Errorf("X-Request-Id = %q, want %q", got, "abc123") } }