package engine

import (
	"context"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"git.flytoex.net/yuanwei/flyto-agent/pkg/security"
)

// ============================================================
// Close re-entrancy guard
// ============================================================

// TestEngine_Close_Idempotent verifies that Close can be called twice: both
// calls must return nil, the second one acting as a no-op re-entrancy guard.
func TestEngine_Close_Idempotent(t *testing.T) {
	e := &Engine{
		cfg:      testConfig(),
		observer: &NoopObserver{},
		sessionState: struct {
			mu       sync.RWMutex
			sessions map[string]*Session
		}{
			sessions: make(map[string]*Session),
		},
	}
	// Install the root context that Close is expected to tear down.
	e.rootCtx, e.rootCancel = testRootContext()

	// The first Close should succeed.
	if err := e.Close(); err != nil {
		t.Errorf("first Close() error: %v", err)
	}
	// The second Close should simply return nil (re-entrancy guard).
	if err := e.Close(); err != nil {
		t.Errorf("second Close() should return nil, got: %v", err)
	}
}

// TestEngine_Closed verifies that Closed reports false before Close and true
// after it.
func TestEngine_Closed(t *testing.T) {
	e := &Engine{
		cfg:      testConfig(),
		observer: &NoopObserver{},
		sessionState: struct {
			mu       sync.RWMutex
			sessions map[string]*Session
		}{
			sessions: make(map[string]*Session),
		},
	}
	e.rootCtx, e.rootCancel = testRootContext()

	if e.Closed() {
		t.Error("should not be closed before Close()")
	}

	// Check the error rather than discarding it: an engine with no resources
	// should shut down cleanly.
	if err := e.Close(); err != nil {
		t.Errorf("Close() error: %v", err)
	}

	if !e.Closed() {
		t.Error("should be closed after Close()")
	}
}

// ============================================================
// Close cancels the root context
// ============================================================

// TestEngine_Close_CancelsRootContext verifies that rootCtx is live before
// Close and has been cancelled by the time Close returns.
func TestEngine_Close_CancelsRootContext(t *testing.T) {
	e := &Engine{
		cfg:      testConfig(),
		observer: &NoopObserver{},
		sessionState: struct {
			mu       sync.RWMutex
			sessions map[string]*Session
		}{
			sessions: make(map[string]*Session),
		},
	}
	e.rootCtx, e.rootCancel = testRootContext()

	// The root context must still be active.
	select {
	case <-e.rootCtx.Done():
		t.Fatal("rootCtx should not be cancelled before Close()")
	default:
	}

	if err := e.Close(); err != nil {
		t.Errorf("Close() error: %v", err)
	}

	// The root context must now be cancelled (non-blocking check: Close is
	// expected to cancel it before returning).
	select {
	case <-e.rootCtx.Done():
		// expected
	default:
		t.Error("rootCtx should be cancelled after Close()")
	}
}

// ============================================================
// Close flushes BufferedObserver
============================================================ func TestEngine_Close_FlushesBufferedObserver(t *testing.T) { var flushed int32 buf := &mockBufferedObserver{ onClose: func() { atomic.AddInt32(&flushed, 1) }, } e := &Engine{ cfg: func() *Config { c := testConfig(); c.CloseTimeout = 5 * time.Second; return c }(), observer: buf, sessionState: struct { mu sync.RWMutex sessions map[string]*Session }{ sessions: make(map[string]*Session), }, } e.rootCtx, e.rootCancel = testRootContext() e.Close() if atomic.LoadInt32(&flushed) != 1 { t.Error("BufferedObserver should have been flushed during Close()") } } func TestEngine_Close_FlushesCompositeObserver(t *testing.T) { var flushed int32 buf := &mockBufferedObserver{ onClose: func() { atomic.AddInt32(&flushed, 1) }, } comp := NewCompositeObserver(&NoopObserver{}, buf) e := &Engine{ cfg: func() *Config { c := testConfig(); c.CloseTimeout = 5 * time.Second; return c }(), observer: comp, sessionState: struct { mu sync.RWMutex sessions map[string]*Session }{ sessions: make(map[string]*Session), }, } e.rootCtx, e.rootCancel = testRootContext() e.Close() // 注意:CompositeObserver 中的 BufferedObserver 不会被直接识别, // 因为 observer 类型是 *CompositeObserver 不是 *BufferedObserver. // Close() 中有专门的 CompositeObserver 处理分支. 
if atomic.LoadInt32(&flushed) != 1 { t.Error("BufferedObserver inside CompositeObserver should have been flushed") } } // ============================================================ // Close 超时保护 // ============================================================ func TestEngine_Close_DefaultTimeout(t *testing.T) { e := &Engine{ cfg: testConfig(), // CloseTimeout=0 → 默认 10s observer: &NoopObserver{}, sessionState: struct { mu sync.RWMutex sessions map[string]*Session }{ sessions: make(map[string]*Session), }, } e.rootCtx, e.rootCancel = testRootContext() start := time.Now() err := e.Close() elapsed := time.Since(start) if err != nil { t.Errorf("Close() error: %v", err) } // 无挂起资源,应该很快完成(< 1s,包含 100ms 等待) if elapsed > 1*time.Second { t.Errorf("Close took too long: %v", elapsed) } } // ============================================================ // Close 发送 observer 事件 // ============================================================ func TestEngine_Close_EmitsEvents(t *testing.T) { var events []string obs := &eventCollector{ onEvent: func(name string, data map[string]any) { events = append(events, name) }, } e := &Engine{ cfg: testConfig(), observer: obs, sessionState: struct { mu sync.RWMutex sessions map[string]*Session }{ sessions: make(map[string]*Session), }, } e.rootCtx, e.rootCancel = testRootContext() e.Close() // 应该发出 engine_closing 和 engine_closed hasClosing := false hasClosed := false for _, name := range events { if name == "engine_closing" { hasClosing = true } if name == "engine_closed" { hasClosed = true } } if !hasClosing { t.Error("should emit engine_closing event") } if !hasClosed { t.Error("should emit engine_closed event") } } // ============================================================ // closeWithTimeout // ============================================================ func TestCloseWithTimeout_Success(t *testing.T) { ctx, cancel := testTimeoutContext(5 * time.Second) defer cancel() err := closeWithTimeout(ctx, "test", func() error { return nil }) 
if err != nil { t.Errorf("expected nil error, got: %v", err) } } func TestCloseWithTimeout_Timeout(t *testing.T) { ctx, cancel := testTimeoutContext(100 * time.Millisecond) defer cancel() err := closeWithTimeout(ctx, "hanging", func() error { time.Sleep(5 * time.Second) return nil }) if err == nil { t.Error("expected timeout error") } } // ============================================================ // Engine.Context() // ============================================================ func TestEngine_Context(t *testing.T) { e := &Engine{ cfg: testConfig(), observer: &NoopObserver{}, sessionState: struct { mu sync.RWMutex sessions map[string]*Session }{ sessions: make(map[string]*Session), }, } e.rootCtx, e.rootCancel = testRootContext() ctx := e.Context() if ctx == nil { t.Fatal("Context() should not return nil") } // 应该和 rootCtx 是同一个 select { case <-ctx.Done(): t.Error("context should not be cancelled") default: } e.rootCancel() select { case <-ctx.Done(): // 正确 default: t.Error("context should be cancelled after rootCancel") } } // ============================================================ // 测试辅助 // ============================================================ func testRootContext() (ctx_ context.Context, cancel_ context.CancelFunc) { return context.WithCancel(context.Background()) } func testTimeoutContext(d time.Duration) (ctx_ context.Context, cancel_ context.CancelFunc) { return context.WithTimeout(context.Background(), d) } // mockBufferedObserver 模拟 BufferedObserver(实现 Close 追踪). // 精妙之处(CLEVER): 不用真正的 BufferedObserver(需要 inner observer + goroutine), // 用 mock 只验证 Close 被调用. type mockBufferedObserver struct { NoopObserver onClose func() } func (m *mockBufferedObserver) Close() { if m.onClose != nil { m.onClose() } } // eventCollector 收集 observer 事件用于断言. // // 精妙之处(CLEVER): mu 保护 names 切片--observer.Event() 可能从 // time.AfterFunc goroutine(空闲定时器)或心跳 goroutine 并发调用, // 不加锁会触发 race detector 报 DATA RACE. 
type eventCollector struct { NoopObserver mu sync.Mutex names []string onEvent func(name string, data map[string]any) } func (c *eventCollector) Event(name string, data map[string]any) { c.mu.Lock() c.names = append(c.names, name) c.mu.Unlock() if c.onEvent != nil { c.onEvent(name, data) } } // EventNames 线程安全地返回已收集事件名称的快照. func (c *eventCollector) EventNames() []string { c.mu.Lock() defer c.mu.Unlock() cp := make([]string, len(c.names)) copy(cp, c.names) return cp } // ============================================================ // ActivityTracker refcount 修复验证(P0-1) // ============================================================ // TestActivityTracker_StopCalledAfterHookBlocksAllTools 验证当 pre_tool_use hook // 拦截所有工具时,ActivityToolExec 计数器正确归零(不泄漏). // // 注意:这个测试在单元测试层验证逻辑正确性-- // 通过直接模拟"hook 拦截后 continue"的情况,验证 Stop 一定被调用. func TestActivityTracker_RefcountNotLeaked(t *testing.T) { // 精妙之处(CLEVER): 用 atomic.Bool 而非 plain bool-- // OnIdle 回调在 time.AfterFunc goroutine 中写入,test goroutine 读取, // plain bool 会触发 -race 检测器报 DATA RACE. 
var idleCalled atomic.Bool cfg := &ActivityTrackerConfig{ HeartbeatInterval: 50 * time.Millisecond, IdleDelay: 50 * time.Millisecond, OnIdle: func(d time.Duration) { idleCalled.Store(true) }, } tracker := NewActivityTracker(cfg, &NoopObserver{}) defer tracker.Close() // 模拟 Start(工具批次开始) tracker.Start(ActivityToolExec) // 模拟"所有工具被 hook 拦截,continue 前正确调用 Stop" tracker.Stop(ActivityToolExec) // 等待空闲回调触发(确认 refcount 正确归零) time.Sleep(200 * time.Millisecond) // 验证引用计数为 0(不泄漏) tracker.mu.Lock() count := tracker.refcount tracker.mu.Unlock() if count != 0 { t.Errorf("expected refcount=0, got %d (refcount leaked)", count) } if !idleCalled.Load() { t.Error("OnIdle should have been called after Stop() (refcount reached 0)") } } // ============================================================ // AuditObserver 接线测试(P0-2) // ============================================================ // TestEngine_AuditSink_Wired 验证 Config.AuditSink 设置后, // operation_recorded 事件会被 AuditObserver 处理并写入 sink. func TestEngine_AuditSink_Wired(t *testing.T) { var written []string sink := &mockAuditSink{ onWrite: func(entry security.AuditEntry) { written = append(written, "written") }, } obs := NewAuditObserver(sink, "test-session") // 模拟 operation_recorded 事件 obs.Event("operation_recorded", map[string]any{ "tool": "Write", "status": "success", "resource": "/tmp/test.go", }) if len(written) != 1 { t.Errorf("expected 1 audit entry written, got %d", len(written)) } } // TestEngine_AuditObserver_SecretScanBlocked 验证 secret_scan_blocked 事件 // 被 AuditObserver 正确转换为审计条目. 
func TestEngine_AuditObserver_SecretScanBlocked(t *testing.T) { var entries []string sink := &mockAuditSink{ onWrite: func(entry security.AuditEntry) { entries = append(entries, "blocked") }, } obs := NewAuditObserver(sink, "") obs.Event("secret_scan_blocked", map[string]any{ "path": "/tmp/secrets.env", "rule_ids": "github-pat,aws-access-token", "count": 2, }) if len(entries) != 1 { t.Errorf("expected 1 audit entry for secret_scan_blocked, got %d", len(entries)) } } // mockAuditSink 用于测试的 AuditSink mock. type mockAuditSink struct { onWrite func(entry security.AuditEntry) } func (m *mockAuditSink) Write(entry security.AuditEntry) error { if m.onWrite != nil { m.onWrite(entry) } return nil } func (m *mockAuditSink) Close() error { return nil }