// sampling_test.go -- pre/post-sampling hook 的单元测试(模块 9.2). // // 覆盖场景: // - BuildPreSamplingEnv 环境变量正确性(HOOK_TYPE / MODEL / TURN / MESSAGE_COUNT) // - BuildPostSamplingEnv 环境变量正确性(INPUT_TOKENS / OUTPUT_TOKENS / STOP_REASON) // - RESPONSE_PREVIEW 截断到 500 字节 // - HookPreSampling / HookPostSampling 在 AllHookTypes() 中存在 // - Manager 可注册和执行 pre/post-sampling hook // - pre-sampling exit 0 放行 / exit 2 阻止(ParseStopHookResponse 语义复用) // - post-sampling 异步执行,不阻塞调用方 package hooks import ( "context" "fmt" "runtime" "strings" "sync" "testing" "time" "git.flytoex.net/yuanwei/flyto-agent/pkg/execenv" ) // TestBuildPreSamplingEnv 测试 pre_sampling 环境变量构建. func TestBuildPreSamplingEnv(t *testing.T) { env := BuildPreSamplingEnv("claude-sonnet-4-6", 3, 7, "/project") checks := map[string]string{ "HOOK_TYPE": string(HookPreSampling), "MODEL": "claude-sonnet-4-6", "TURN": "3", "MESSAGE_COUNT": "7", "PROJECT_ROOT": "/project", "PLATFORM": runtime.GOOS, } for key, want := range checks { if got := env[key]; got != want { t.Errorf("BuildPreSamplingEnv: %s = %q, 期望 %q", key, got, want) } } // 确认没有 MESSAGES_JSON(防止 env 过大,只传元数据) if _, ok := env["MESSAGES_JSON"]; ok { t.Error("BuildPreSamplingEnv 不应包含 MESSAGES_JSON(防止 env 超 ARG_MAX)") } } // TestBuildPostSamplingEnv 测试 post_sampling 环境变量构建. func TestBuildPostSamplingEnv(t *testing.T) { env := BuildPostSamplingEnv( "claude-opus-4-6", 2, 1234, 567, "end_turn", "这是模型的回复文本", "/project", ) checks := map[string]string{ "HOOK_TYPE": string(HookPostSampling), "MODEL": "claude-opus-4-6", "TURN": "2", "INPUT_TOKENS": "1234", "OUTPUT_TOKENS": "567", "STOP_REASON": "end_turn", "RESPONSE_PREVIEW": "这是模型的回复文本", "PROJECT_ROOT": "/project", } for key, want := range checks { if got := env[key]; got != want { t.Errorf("BuildPostSamplingEnv: %s = %q, 期望 %q", key, got, want) } } } // TestBuildPostSamplingEnv_PreviewTruncation 测试 RESPONSE_PREVIEW 截断到 500 字节. 
// // 精妙之处(CLEVER): 用 500 字节的 ASCII 字符串 + 多字节 UTF-8 后缀来验证截断-- // 确认截断边界在 500 字节处(不是字符数), // 同时验证超长响应不会把 env 撑到危险大小. func TestBuildPostSamplingEnv_PreviewTruncation(t *testing.T) { // 600 字节的 ASCII 字符串(每字符 1 字节) longResponse := strings.Repeat("x", 600) env := BuildPostSamplingEnv("model", 1, 0, 0, "end_turn", longResponse, "/p") preview := env["RESPONSE_PREVIEW"] if len(preview) != 500 { t.Errorf("RESPONSE_PREVIEW 应截断到 500 字节, 实际 %d 字节", len(preview)) } // 短响应不应截断 shortResponse := "hello world" env2 := BuildPostSamplingEnv("model", 1, 0, 0, "end_turn", shortResponse, "/p") if env2["RESPONSE_PREVIEW"] != shortResponse { t.Errorf("短响应不应截断: %q", env2["RESPONSE_PREVIEW"]) } } // TestAllHookTypes_ContainsSamplingHooks 验证 AllHookTypes 包含两个新类型. // 如果漏加,Manager.Register 会返回 "unknown hook type" 错误. func TestAllHookTypes_ContainsSamplingHooks(t *testing.T) { all := AllHookTypes() found := map[HookType]bool{} for _, ht := range all { found[ht] = true } for _, want := range []HookType{HookPreSampling, HookPostSampling} { if !found[want] { t.Errorf("AllHookTypes() 缺少 %q", want) } } } // TestManager_PreSampling_Allow 测试 pre-sampling hook exit 0 放行. // // 验证:注册了 pre-sampling hook 但 exit 0,ParseStopHookResponse 返回 shouldStop=false. 
func TestManager_PreSampling_Allow(t *testing.T) { m := NewManager(nil, execenv.DefaultExecutor{}) m.Disable() // 禁用真实执行,用 CallbackHandler 代替 allowCalled := false err := m.Register(HookPreSampling, HookDef{ Handler: HookHandlerFunc(func(ctx context.Context, hookType HookType, env map[string]string) *HookResult { allowCalled = true // 验证 env 包含必要字段 if env["HOOK_TYPE"] != string(HookPreSampling) { return &HookResult{ExitCode: 1, Stderr: "wrong hook type"} } if env["MODEL"] == "" { return &HookResult{ExitCode: 1, Stderr: "MODEL missing"} } return &HookResult{ExitCode: 0} // 放行 }), }) if err != nil { t.Fatalf("Register: %v", err) } // 重新启用执行(Handler 路径不走 shell) m.Enable() env := BuildPreSamplingEnv("claude-sonnet-4-6", 1, 3, "/project") results, execErr := m.Execute(context.Background(), HookPreSampling, env) if execErr != nil { t.Fatalf("Execute: %v", execErr) } if !allowCalled { t.Error("pre-sampling handler 应该被调用") } shouldStop, reason := ParseStopHookResponse(results) if shouldStop { t.Errorf("exit 0 不应阻止: reason=%q", reason) } } // TestManager_PreSampling_Block 测试 pre-sampling hook exit 2 阻止. // // 精妙之处(CLEVER): 复用 ParseStopHookResponse 而非新建解析函数-- // pre_sampling 的"阻止"和 stop hook 的"阻止继续"语义相同(非零退出 = 停止). // 测试也复用同样的解析路径,保证行为一致. 
func TestManager_PreSampling_Block(t *testing.T) {
	m := NewManager(nil, execenv.DefaultExecutor{})

	const blockReason = "quota exceeded: monthly limit reached"
	err := m.Register(HookPreSampling, HookDef{
		Handler: HookHandlerFunc(func(ctx context.Context, hookType HookType, env map[string]string) *HookResult {
			return &HookResult{
				ExitCode: ExitCodeBlock, // exit 2 = deliberate block
				Stderr:   blockReason,
			}
		}),
	})
	if err != nil {
		t.Fatalf("Register: %v", err)
	}

	env := BuildPreSamplingEnv("model", 1, 0, "/p")
	results, execErr := m.Execute(context.Background(), HookPreSampling, env)
	if execErr != nil {
		// Deliberately not fatal: the block decision is read from results via
		// ParseStopHookResponse, not from the error. Logged so an unexpected
		// failure is still visible with -v.
		// NOTE(review): confirm whether Execute is expected to return an error for exit 2.
		t.Logf("Execute 返回错误(已容忍): %v", execErr)
	}

	shouldStop, reason := ParseStopHookResponse(results)
	if !shouldStop {
		t.Error("exit 2 应该触发阻止")
	}
	if reason != blockReason {
		t.Errorf("阻止原因应为 %q, 实际 %q", blockReason, reason)
	}
}

// TestManager_PostSampling_AsyncNonBlocking verifies that post-sampling hooks run
// asynchronously: ExecuteAsync returns a channel immediately while the hook runs
// in the background, so the caller never waits for it (fire-and-forget).
func TestManager_PostSampling_AsyncNonBlocking(t *testing.T) {
	m := NewManager(nil, execenv.DefaultExecutor{})

	// The handler sleeps this long; ExecuteAsync must return well before that.
	const hookDelay = 50 * time.Millisecond

	hookDone := make(chan struct{})
	var mu sync.Mutex
	calledAt := time.Time{}

	err := m.Register(HookPostSampling, HookDef{
		Async: true, // mark as asynchronous
		Handler: HookHandlerFunc(func(ctx context.Context, hookType HookType, env map[string]string) *HookResult {
			// Simulate a slow hook.
			time.Sleep(hookDelay)
			mu.Lock()
			calledAt = time.Now()
			mu.Unlock()
			close(hookDone)
			return &HookResult{ExitCode: 0}
		}),
	})
	if err != nil {
		t.Fatalf("Register: %v", err)
	}

	// Build the env before starting the clock so only ExecuteAsync itself is timed.
	env := BuildPostSamplingEnv("model", 1, 100, 50, "end_turn", "hello", "/p")

	startTime := time.Now()
	resultCh := m.ExecuteAsync(HookPostSampling, env) // must not wait for the hook
	elapsed := time.Since(startTime)

	// Returning before the hook's sleep finishes (< 40ms) counts as "immediate".
	if elapsed > 40*time.Millisecond {
		t.Errorf("ExecuteAsync 应立即返回,实际等待了 %v", elapsed)
	}

	// Wait for the hook to actually finish (at most 500ms).
	select {
	case <-hookDone:
		// ok
	case <-time.After(500 * time.Millisecond):
		t.Error("hook 应在 500ms 内完成")
	}

	// Drain the result channel so the producing goroutine is not left blocked.
	select {
	case <-resultCh:
	case <-time.After(100 * time.Millisecond):
		t.Error("resultCh 应在 hook 完成后关闭")
	}

	mu.Lock()
	defer mu.Unlock()
	if calledAt.IsZero() {
		t.Error("hook handler 应该被调用")
	}
}

// TestManager_PostSampling_CallbackReceivesCorrectEnv verifies that a registered
// post-sampling handler receives the variables produced by BuildPostSamplingEnv.
func TestManager_PostSampling_CallbackReceivesCorrectEnv(t *testing.T) {
	m := NewManager(nil, execenv.DefaultExecutor{})

	var receivedEnv map[string]string
	var mu sync.Mutex

	err := m.Register(HookPostSampling, HookDef{
		Handler: HookHandlerFunc(func(ctx context.Context, hookType HookType, env map[string]string) *HookResult {
			mu.Lock()
			// Copy the env: the map must not be shared with the manager.
			receivedEnv = make(map[string]string, len(env))
			for k, v := range env {
				receivedEnv[k] = v
			}
			mu.Unlock()
			return &HookResult{ExitCode: 0}
		}),
	})
	if err != nil {
		t.Fatalf("Register: %v", err)
	}

	// Short enough (well under 500 bytes) that RESPONSE_PREVIEW is not truncated.
	const response = "调用了 Bash 工具"
	env := BuildPostSamplingEnv(
		"claude-haiku-4-5", 5, 2000, 300, "tool_use",
		response, "/project",
	)
	_, execErr := m.Execute(context.Background(), HookPostSampling, env)
	if execErr != nil {
		t.Fatalf("Execute: %v", execErr)
	}

	mu.Lock()
	defer mu.Unlock()
	if receivedEnv == nil {
		t.Fatal("handler 未被调用")
	}

	checks := map[string]string{
		"HOOK_TYPE":        string(HookPostSampling),
		"MODEL":            "claude-haiku-4-5",
		"TURN":             "5",
		"INPUT_TOKENS":     "2000",
		"OUTPUT_TOKENS":    "300",
		"STOP_REASON":      "tool_use",
		"RESPONSE_PREVIEW": response,
	}
	for key, want := range checks {
		if got := receivedEnv[key]; got != want {
			t.Errorf("env[%q] = %q, 期望 %q", key, got, want)
		}
	}
}

// TestBuildPostSamplingEnv_AllStopReasons verifies that the env builder passes
// stop_reason through verbatim without validating it: every known API value, an
// unknown future value, and the empty string must all round-trip unchanged.
func TestBuildPostSamplingEnv_AllStopReasons(t *testing.T) {
	reasons := []string{
		"end_turn", "tool_use", "max_tokens", "stop_sequence",
		"pause_turn", "refusal",
		"future_unknown_reason", // values the engine does not know yet must pass through too
		"",
	}
	for _, reason := range reasons {
		t.Run(fmt.Sprintf("stop_reason=%s", reason), func(t *testing.T) {
			env := BuildPostSamplingEnv("model", 1, 0, 0, reason, "", "/p")
			if env["STOP_REASON"] != reason {
				t.Errorf("STOP_REASON = %q, 期望 %q", env["STOP_REASON"], reason)
			}
		})
	}
}