package daemon import ( "context" "errors" "sync/atomic" "testing" "time" ) func TestCrashRecovery_NormalExit(t *testing.T) { cr := NewCrashRecovery(DefaultCrashRecoveryConfig()) calls := 0 err := cr.RunWithRecovery(context.Background(), "s1", func() error { calls++ return nil // 正常退出 }) if err != nil { t.Errorf("normal exit should return nil, got %v", err) } if calls != 1 { t.Errorf("fn should be called exactly once, got %d", calls) } } func TestCrashRecovery_RetriesOnError(t *testing.T) { cr := NewCrashRecovery(CrashRecoveryConfig{ MaxRetries: 2, InitialDelay: 1 * time.Millisecond, MaxDelay: 10 * time.Millisecond, Multiplier: 2.0, }) var calls atomic.Int32 testErr := errors.New("crash") err := cr.RunWithRecovery(context.Background(), "s1", func() error { calls.Add(1) return testErr }) // MaxRetries=2 → 初始尝试 + 2 次重试 = 3 次调用 if calls.Load() != 3 { t.Errorf("expected 3 calls (1 initial + 2 retries), got %d", calls.Load()) } if !errors.Is(err, testErr) { t.Errorf("should return original error, got %v", err) } } func TestCrashRecovery_SucceedsAfterRetry(t *testing.T) { cr := NewCrashRecovery(CrashRecoveryConfig{ MaxRetries: 3, InitialDelay: 1 * time.Millisecond, MaxDelay: 10 * time.Millisecond, Multiplier: 2.0, }) var calls atomic.Int32 err := cr.RunWithRecovery(context.Background(), "s1", func() error { n := calls.Add(1) if n < 3 { return errors.New("not ready yet") } return nil // 第 3 次成功 }) if err != nil { t.Errorf("should succeed on 3rd attempt, got %v", err) } if calls.Load() != 3 { t.Errorf("expected 3 calls, got %d", calls.Load()) } } func TestCrashRecovery_OnCrashCallback(t *testing.T) { var crashCount atomic.Int32 var giveUpCount atomic.Int32 cr := NewCrashRecovery(CrashRecoveryConfig{ MaxRetries: 2, InitialDelay: 1 * time.Millisecond, MaxDelay: 5 * time.Millisecond, Multiplier: 2.0, OnCrash: func(_ string, _ int, _ error) { crashCount.Add(1) }, OnGiveUp: func(_ string, _ int) { giveUpCount.Add(1) }, }) cr.RunWithRecovery(context.Background(), "s1", func() error { return errors.New("crash") }) // 3 次崩溃(1 initial + 2 retries),OnCrash 调用 3 次 if crashCount.Load() != 3 { t.Errorf("expected 3 OnCrash calls, got %d", crashCount.Load()) } if giveUpCount.Load() != 1 { t.Errorf("expected 1 OnGiveUp call, got %d", giveUpCount.Load()) } } func TestCrashRecovery_CtxCancelStopsRetry(t *testing.T) { cr := NewCrashRecovery(CrashRecoveryConfig{ MaxRetries: 100, InitialDelay: 50 * time.Millisecond, // 每次重试等 50ms MaxDelay: 200 * time.Millisecond, Multiplier: 1.0, // 固定间隔便于测试 }) ctx, cancel := context.WithCancel(context.Background()) var calls atomic.Int32 go func() { time.Sleep(120 * time.Millisecond) cancel() }() err := cr.RunWithRecovery(ctx, "s1", func() error { calls.Add(1) return errors.New("always fail") }) // ctx 取消后应停止 if !errors.Is(err, context.Canceled) { t.Errorf("expected context.Canceled, got %v", err) } // 调用次数应该 < 100(ctx 取消了) if calls.Load() >= 10 { t.Errorf("too many calls (%d), ctx cancel should have stopped retries", calls.Load()) } } func TestCrashRecovery_MaxRetryZeroNoRetry(t *testing.T) { cr := NewCrashRecovery(CrashRecoveryConfig{ MaxRetries: 0, // 不重试 InitialDelay: 1 * time.Millisecond, MaxDelay: 10 * time.Millisecond, Multiplier: 2.0, }) var calls atomic.Int32 cr.RunWithRecovery(context.Background(), "s1", func() error { calls.Add(1) return errors.New("crash") }) if calls.Load() != 1 { t.Errorf("MaxRetries=0 should call fn exactly once, got %d", calls.Load()) } } func TestCrashRecovery_BackoffDelay(t *testing.T) { cr := NewCrashRecovery(CrashRecoveryConfig{ MaxRetries: 5, InitialDelay: 10 * time.Millisecond, MaxDelay: 100 * time.Millisecond, Multiplier: 2.0, }) // attempt=1: 10ms // attempt=2: 20ms // attempt=3: 40ms // attempt=4: 80ms // attempt=5: 100ms(上限) tests := []struct { attempt int want time.Duration }{ {1, 10 * time.Millisecond}, {2, 20 * time.Millisecond}, {3, 40 * time.Millisecond}, {4, 80 * time.Millisecond}, {5, 100 * time.Millisecond}, // 达到上限 {6, 100 * time.Millisecond}, // 超出上限,仍为 100ms } for _, tt := range tests { got := cr.backoffDelay(tt.attempt) if got != tt.want { t.Errorf("backoffDelay(%d) = %v, want %v", tt.attempt, got, tt.want) } } }