package retry import ( "context" "errors" "net/http" "testing" "time" ) func TestRetryer_SuccessOnFirstAttempt(t *testing.T) { r := &Retryer{Policy: DefaultExponentialBackoff()} rctx := &RetryContext{IsForeground: true} err := r.Do(context.Background(), rctx, func(attempt int, rctx *RetryContext) error { return nil }) if err != nil { t.Errorf("expected nil error, got %v", err) } } func TestRetryer_SuccessAfterRetry(t *testing.T) { r := &Retryer{ Policy: &ExponentialBackoff{ BaseDelay: 1 * time.Millisecond, // 快速测试 MaxRetries: 5, JitterPct: 0, }, } rctx := &RetryContext{IsForeground: true} callCount := 0 err := r.Do(context.Background(), rctx, func(attempt int, rctx *RetryContext) error { callCount++ if callCount < 3 { return &mockRetryError{ category: "server_error", retryable: true, retryInfo: &RetryInfo{Retryable: true}, } } return nil }) if err != nil { t.Errorf("expected success after retries, got %v", err) } if callCount != 3 { t.Errorf("expected 3 calls, got %d", callCount) } } func TestRetryer_NonAPIErrorNoRetry(t *testing.T) { r := &Retryer{Policy: DefaultExponentialBackoff()} rctx := &RetryContext{IsForeground: true} err := r.Do(context.Background(), rctx, func(attempt int, rctx *RetryContext) error { return errors.New("not an API error") }) if err == nil || err.Error() != "not an API error" { t.Errorf("non-API errors should not be retried, got %v", err) } } func TestRetryer_MaxRetriesExhausted(t *testing.T) { r := &Retryer{ Policy: &ExponentialBackoff{ BaseDelay: 1 * time.Millisecond, MaxRetries: 2, JitterPct: 0, }, } rctx := &RetryContext{IsForeground: true} callCount := 0 err := r.Do(context.Background(), rctx, func(attempt int, rctx *RetryContext) error { callCount++ return &mockRetryError{ category: "server_error", retryable: true, retryInfo: &RetryInfo{Retryable: true}, } }) var cannotRetry *CannotRetryError if !errors.As(err, &cannotRetry) { t.Fatalf("expected CannotRetryError, got %T: %v", err, err) } if callCount != 3 { // 1 initial + 2 retries t.Errorf("expected 3 calls (1+2 retries), got %d", callCount) } } func TestRetryer_ContextCancellation(t *testing.T) { r := &Retryer{ Policy: &ExponentialBackoff{ BaseDelay: 1 * time.Second, // 长延迟 MaxRetries: 10, JitterPct: 0, }, } rctx := &RetryContext{IsForeground: true} ctx, cancel := context.WithCancel(context.Background()) callCount := 0 go func() { time.Sleep(10 * time.Millisecond) cancel() }() err := r.Do(ctx, rctx, func(attempt int, rctx *RetryContext) error { callCount++ return &mockRetryError{ category: "server_error", retryable: true, retryInfo: &RetryInfo{Retryable: true}, } }) var cannotRetry *CannotRetryError if !errors.As(err, &cannotRetry) { t.Fatalf("expected CannotRetryError on cancellation, got %T", err) } } func TestRetryer_OnRetryCallback(t *testing.T) { var retryAttempts []int r := &Retryer{ Policy: &ExponentialBackoff{ BaseDelay: 1 * time.Millisecond, MaxRetries: 5, JitterPct: 0, }, OnRetry: func(err RetryError, attempt int, delay time.Duration, reason string) { retryAttempts = append(retryAttempts, attempt) }, } rctx := &RetryContext{IsForeground: true} callCount := 0 _ = r.Do(context.Background(), rctx, func(attempt int, rctx *RetryContext) error { callCount++ if callCount < 4 { return &mockRetryError{ category: "server_error", retryable: true, retryInfo: &RetryInfo{Retryable: true}, } } return nil }) if len(retryAttempts) != 3 { t.Errorf("expected 3 retry callbacks, got %d", len(retryAttempts)) } } func TestRetryer_ConsecutiveCountTracking(t *testing.T) { r := &Retryer{ Policy: &ExponentialBackoff{ BaseDelay: 1 * time.Millisecond, MaxRetries: 10, JitterPct: 0, }, } rctx := &RetryContext{IsForeground: true} // 两次 529,然后一次 429,再一次 529 sequence := []string{ "server_overload", "server_overload", "rate_limit", "server_overload", } idx := 0 _ = r.Do(context.Background(), rctx, func(attempt int, rctx *RetryContext) error { if idx >= len(sequence) { return nil } cat := sequence[idx] idx++ return &mockRetryError{ category: cat, retryable: true, retryInfo: &RetryInfo{Retryable: true}, } }) // 最后一次 529 的计数应该是 1(被 429 打断了) if rctx.ConsecutiveCounts["server_overload"] != 1 { t.Errorf("consecutive overloaded = %d, want 1 (interrupted by rate_limit)", rctx.ConsecutiveCounts["server_overload"]) } } func TestRetryer_OverflowCorrection(t *testing.T) { r := &Retryer{ Policy: &ExponentialBackoff{ BaseDelay: 1 * time.Millisecond, MaxRetries: 5, JitterPct: 0, }, OverflowHandler: DefaultOverflowHandler(), } rctx := &RetryContext{IsForeground: true} callCount := 0 err := r.Do(context.Background(), rctx, func(attempt int, rctx *RetryContext) error { callCount++ if callCount == 1 { // 第一次返回溢出错误 return &mockRetryError{ category: "invalid_request", retryable: false, retryInfo: &RetryInfo{Retryable: false}, // message 通过 Message() 方法返回,用于 ParseOverflow } } // 第二次检查 max_tokens 被修正了 if rctx.MaxTokensOverride == 0 { t.Error("MaxTokensOverride should have been set") } return nil }) // 注意:overflow 测试需要 Message() 返回含溢出信息的字符串 // 这个测试验证的是溢出 handler 不会崩溃(message 不含溢出信息时跳过) _ = err if callCount < 1 { t.Errorf("expected at least 1 call, got %d", callCount) } } func TestRetryer_OverflowCorrection_WithMessage(t *testing.T) { r := &Retryer{ Policy: &ExponentialBackoff{ BaseDelay: 1 * time.Millisecond, MaxRetries: 5, JitterPct: 0, }, OverflowHandler: DefaultOverflowHandler(), } rctx := &RetryContext{IsForeground: true} callCount := 0 err := r.Do(context.Background(), rctx, func(attempt int, rctx *RetryContext) error { callCount++ if callCount == 1 { // 第一次返回含溢出信息的错误 return &mockOverflowError{ msg: "input length and `max_tokens` exceed context limit: 188059 + 20000 > 200000", } } // 第二次检查 max_tokens 被修正了 if rctx.MaxTokensOverride == 0 { t.Error("MaxTokensOverride should have been set") } return nil }) if err != nil { t.Errorf("expected success after overflow correction, got %v", err) } if callCount != 2 { t.Errorf("expected 2 calls, got %d", callCount) } } func TestRetryer_FallbackTriggered(t *testing.T) { r := &Retryer{ Policy: NewCompositeRetryPolicy( &ModelFallback{ConsecutiveThreshold: 2}, &ExponentialBackoff{ BaseDelay: 1 * time.Millisecond, MaxRetries: 10, JitterPct: 0, }, ), } rctx := &RetryContext{ IsForeground: true, Model: "opus-4", FallbackModel: "sonnet-4", } err := r.Do(context.Background(), rctx, func(attempt int, rctx *RetryContext) error { return &mockRetryError{ category: "server_overload", retryable: true, retryInfo: &RetryInfo{Retryable: true}, } }) var fallback *FallbackTriggeredError if !errors.As(err, &fallback) { t.Fatalf("expected FallbackTriggeredError, got %T: %v", err, err) } if fallback.OriginalModel != "opus-4" || fallback.FallbackModel != "sonnet-4" { t.Errorf("fallback models wrong: %s -> %s", fallback.OriginalModel, fallback.FallbackModel) } } // mockOverflowError 是携带 overflow 错误消息的 RetryError mock. // 用于测试 ContextOverflowHandler 从错误消息中提取溢出参数的逻辑. type mockOverflowError struct { msg string } func (m *mockOverflowError) Error() string { return m.msg } func (m *mockOverflowError) Category() string { return "invalid_request" } func (m *mockOverflowError) IsRetryable() bool { return false } func (m *mockOverflowError) RetryDelay() time.Duration { return 0 } func (m *mockOverflowError) Message() string { return m.msg } func (m *mockOverflowError) Headers() http.Header { return nil } func (m *mockOverflowError) RetryInfo() *RetryInfo { return nil } // --------------------------------------------------------------------------- // QuerySource wire regression guards // --------------------------------------------------------------------------- // TestWithQuerySource_RoundTrip guards the ctx injection helper pair: // writing a source via WithQuerySource must be readable via // QuerySourceFromCtx; absent injection returns empty string without error. // // TestWithQuerySource_RoundTrip 保护 ctx 注入 helper 对: 经 WithQuerySource // 写入的 source 必须可通过 QuerySourceFromCtx 读出; 未注入时返回空串不报错. func TestWithQuerySource_RoundTrip(t *testing.T) { base := context.Background() if got := QuerySourceFromCtx(base); got != "" { t.Errorf("empty ctx: got %q, want empty", got) } ctx := WithQuerySource(base, "main_thread") if got := QuerySourceFromCtx(ctx); got != "main_thread" { t.Errorf("round trip: got %q, want main_thread", got) } // Derived ctx must inherit the label (standard context semantics). // 派生 ctx 必须继承标签 (标准 context 语义). child, cancel := context.WithCancel(ctx) defer cancel() if got := QuerySourceFromCtx(child); got != "main_thread" { t.Errorf("derived ctx: got %q, want main_thread", got) } } // TestRetryer_Do_EmbedsQuerySourceInCannotRetryReason is the load-bearing // regression: RetryContext.QuerySource is the godoc-promised "request // source label" only if a caller reading a retry failure can see it. // Here we set up a retry that exhausts (policy says "no"), and assert the // CannotRetryError.Reason contains "source=summary", proving the field // flows from rctx into the error message that operators read. // // TestRetryer_Do_EmbedsQuerySourceInCannotRetryReason 是承载性回归: // RetryContext.QuerySource 只有在读 retry 失败的调用方能看到它时才兑现 // godoc 承诺的"请求来源标识". 此处构造一个耗尽的 retry (策略拒绝), // 断言 CannotRetryError.Reason 含 "source=summary", 证明字段从 rctx // 流到运维看到的错误消息. func TestRetryer_Do_EmbedsQuerySourceInCannotRetryReason(t *testing.T) { // Policy that always denies: forces immediate CannotRetryError. // 始终拒绝的策略: 强制立刻 CannotRetryError. denyAll := &stubPolicy{decision: &RetryDecision{Retry: false, Reason: "test denied"}} r := &Retryer{Policy: denyAll} rctx := &RetryContext{IsForeground: false, QuerySource: "summary"} err := r.Do(context.Background(), rctx, func(attempt int, rctx *RetryContext) error { return &mockRetryError{category: "server_error", retryable: true, retryInfo: &RetryInfo{Retryable: true}} }) var cre *CannotRetryError if !errors.As(err, &cre) { t.Fatalf("expected CannotRetryError, got %T: %v", err, err) } if !containsSourceLabel(cre.Reason, "summary") { t.Errorf("Reason = %q, want containing source=summary", cre.Reason) } } // TestRetryer_Do_EmptyQuerySource_ReasonUnchanged guards the graceful // degradation: when no QuerySource is set, the CannotRetryError.Reason // string is clean (no stray "(source=)" suffix), so uninstrumented call // sites still produce readable errors. // // TestRetryer_Do_EmptyQuerySource_ReasonUnchanged 保护优雅降级: 未设 // QuerySource 时 CannotRetryError.Reason 干净 (不带多余 "(source=)" 尾巴), // 未插桩 call site 仍产出可读错误. func TestRetryer_Do_EmptyQuerySource_ReasonUnchanged(t *testing.T) { denyAll := &stubPolicy{decision: &RetryDecision{Retry: false, Reason: "test denied"}} r := &Retryer{Policy: denyAll} rctx := &RetryContext{IsForeground: false} // no QuerySource err := r.Do(context.Background(), rctx, func(attempt int, rctx *RetryContext) error { return &mockRetryError{category: "server_error", retryable: true, retryInfo: &RetryInfo{Retryable: true}} }) var cre *CannotRetryError if !errors.As(err, &cre) { t.Fatalf("expected CannotRetryError, got %T: %v", err, err) } if cre.Reason != "test denied" { t.Errorf("Reason = %q, want exactly 'test denied' (no source suffix)", cre.Reason) } } // containsSourceLabel reports whether reason ends with "(source=