// timeout_test.go -- Handler 超时保护的单元测试(任务 5.5). // // 覆盖场景: // - Handler 在超时内正常返回 → 取决策结果 // - Handler 超时(阻塞不返回)→ 返回 DecisionDeny + reason="permission handler timeout" // - 上层 context 取消时 → 透传 context.Canceled,不当作 timeout 处理 // - 自定义超时(< DefaultPermissionTimeout)正确生效 // - NewEngineWithTimeout 零值超时回退到默认值 // // P1-1 并发安全测试(SetMode/Mode 原子操作): // - 并发读写 Mode 不触发 data race // - SetMode 对后续 Mode() 立即可见 package permission import ( "context" "errors" "sync" "testing" "time" ) // TestHandlerTimeout_NormalReturn 测试 Handler 在超时前正常返回. // 预期:返回 Handler 的决策结果,不触发超时逻辑. func TestHandlerTimeout_NormalReturn(t *testing.T) { // Handler 立即返回 Allow handler := func(ctx context.Context, req *Request) (*Response, error) { return &Response{Decision: DecisionAllow, Reason: "ok"}, nil } e := NewEngineWithTimeout(ModeDefault, handler, 5*time.Second) req := &Request{ToolName: "Bash", Input: map[string]any{"command": "echo hi"}} resp, err := e.Check(context.Background(), req) if err != nil { t.Fatalf("不应报错,实际: %v", err) } if resp.Decision != DecisionAllow { t.Errorf("预期 Allow,实际: %s (reason: %s)", resp.Decision, resp.Reason) } } // TestHandlerTimeout_HandlerBlocks 测试 Handler 阻塞超时. // 预期:超时后返回 DecisionDeny + reason 含 "permission handler timeout". func TestHandlerTimeout_HandlerBlocks(t *testing.T) { // Handler 永远阻塞(直到 ctx 被取消) handler := func(ctx context.Context, req *Request) (*Response, error) { <-ctx.Done() return nil, ctx.Err() } // 使用很短的超时,让测试快速结束 e := NewEngineWithTimeout(ModeDefault, handler, 50*time.Millisecond) req := &Request{ToolName: "Bash", Input: map[string]any{"command": "rm -rf /"}} start := time.Now() resp, err := e.Check(context.Background(), req) elapsed := time.Since(start) if err != nil { t.Fatalf("超时应返回 DecisionDeny 而非 error,实际 error: %v", err) } if resp.Decision != DecisionDeny { t.Errorf("超时应返回 DecisionDeny,实际: %s", resp.Decision) } if resp.Reason != "permission handler timeout" { t.Errorf("reason 应为 'permission handler timeout',实际: %q", resp.Reason) } // 超时时间应在合理范围内(50ms + 200ms 缓冲) if elapsed > 250*time.Millisecond { t.Errorf("超时等待时间过长: %v(应 < 250ms)", elapsed) } t.Logf("超时正确触发,耗时 %v,Decision=%s,Reason=%s", elapsed, resp.Decision, resp.Reason) } // TestHandlerTimeout_ContextCanceled 测试上层 context 取消时透传错误. // 预期:返回 context.Canceled 错误,而不是返回超时拒绝. func TestHandlerTimeout_ContextCanceled(t *testing.T) { // Handler 永远阻塞 handler := func(ctx context.Context, req *Request) (*Response, error) { <-ctx.Done() return nil, ctx.Err() } e := NewEngineWithTimeout(ModeDefault, handler, 5*time.Second) // 预先取消的 context ctx, cancel := context.WithCancel(context.Background()) cancel() // 立即取消 req := &Request{ToolName: "Bash", Input: map[string]any{"command": "echo hi"}} _, err := e.Check(ctx, req) if err == nil { t.Fatal("上层 context 取消时应返回错误") } if !errors.Is(err, context.Canceled) { t.Errorf("应返回 context.Canceled,实际: %v", err) } } // TestHandlerTimeout_HandlerError 测试 Handler 返回 error 时透传错误. // 预期:Handler 的错误被透传给调用方. func TestHandlerTimeout_HandlerError(t *testing.T) { expectedErr := errors.New("some handler error") handler := func(ctx context.Context, req *Request) (*Response, error) { return nil, expectedErr } e := NewEngineWithTimeout(ModeDefault, handler, 5*time.Second) req := &Request{ToolName: "Bash", Input: map[string]any{"command": "echo hi"}} _, err := e.Check(context.Background(), req) if !errors.Is(err, expectedErr) { t.Errorf("Handler error 应被透传,预期 %v,实际: %v", expectedErr, err) } } // TestHandlerTimeout_DefaultTimeout 测试 NewEngineWithTimeout 零值回退到默认值. func TestHandlerTimeout_DefaultTimeout(t *testing.T) { // 使用 0 超时应等同于 DefaultPermissionTimeout handler := func(ctx context.Context, req *Request) (*Response, error) { return &Response{Decision: DecisionAllow, Reason: "ok"}, nil } e := NewEngineWithTimeout(ModeDefault, handler, 0) eng := e.(*engine) if eng.handlerTimeout != DefaultPermissionTimeout { t.Errorf("零值超时应回退到 DefaultPermissionTimeout (%v),实际: %v", DefaultPermissionTimeout, eng.handlerTimeout) } } // TestHandlerTimeout_DefaultPermissionTimeoutConst 验证常量值正确. func TestHandlerTimeout_DefaultPermissionTimeoutConst(t *testing.T) { if DefaultPermissionTimeout != 5*time.Minute { t.Errorf("DefaultPermissionTimeout 应为 5min,实际: %v", DefaultPermissionTimeout) } } // TestHandlerTimeout_DenialTrackerUpdated 测试超时时拒绝追踪器被正确更新. // 即:超时触发 DecisionDeny,denial tracker 应记录该拒绝. func TestHandlerTimeout_DenialTrackerUpdated(t *testing.T) { // Handler 永远阻塞 handler := func(ctx context.Context, req *Request) (*Response, error) { <-ctx.Done() return nil, ctx.Err() } e := NewEngineWithTimeout(ModeDefault, handler, 20*time.Millisecond) eng := e.(*engine) req := &Request{ ToolName: "Bash", Input: map[string]any{"command": "dangerous_cmd"}, } resp, _ := e.Check(context.Background(), req) if resp.Decision != DecisionDeny { t.Fatalf("应超时拒绝,实际: %s", resp.Decision) } // denial tracker 应该记录了一次拒绝 stats := eng.denial.Stats() if stats.TotalDenials == 0 { t.Error("超时触发 DecisionDeny 时,denial tracker 应记录拒绝") } } // --- P1-1:SetMode/Mode 并发安全测试 --- // TestSetMode_ConcurrentReadWrite 验证 SetMode/Mode 在并发场景下不触发 data race. // 用 -race 标志运行可检测到并发问题(go test -race ./pkg/permission/). func TestSetMode_ConcurrentReadWrite(t *testing.T) { e := NewEngine(ModeDefault, func(_ context.Context, _ *Request) (*Response, error) { return &Response{Decision: DecisionAllow}, nil }) const goroutines = 50 var wg sync.WaitGroup wg.Add(goroutines * 2) // 一半 goroutine 并发写 for i := 0; i < goroutines; i++ { go func(i int) { defer wg.Done() if i%2 == 0 { e.SetMode(ModePlan) } else { e.SetMode(ModeDefault) } }(i) } // 另一半 goroutine 并发读 for i := 0; i < goroutines; i++ { go func() { defer wg.Done() m := e.Mode() // 读到的值必须是有效的 Mode(不能是零值或不合法值) if m != ModeDefault && m != ModeAcceptEdits && m != ModeBypass && m != ModePlan { t.Errorf("Mode() 返回了非法值: %q", m) } }() } wg.Wait() } // TestSetMode_Visibility 验证 SetMode 对 Mode() 的可见性. func TestSetMode_Visibility(t *testing.T) { e := NewEngine(ModeDefault, nil) e.SetMode(ModePlan) if got := e.Mode(); got != ModePlan { t.Errorf("SetMode(ModePlan) 后 Mode() 应返回 ModePlan,实际: %q", got) } e.SetMode(ModeBypass) if got := e.Mode(); got != ModeBypass { t.Errorf("SetMode(ModeBypass) 后 Mode() 应返回 ModeBypass,实际: %q", got) } } // TestNewEngine_DefaultMode 验证 NewEngine 空 mode 时默认为 ModeDefault. func TestNewEngine_DefaultMode(t *testing.T) { e := NewEngine("", nil) // 空字符串应回退到 ModeDefault if got := e.Mode(); got != ModeDefault { t.Errorf("空 mode 应回退到 ModeDefault,实际: %q", got) } }