// extraction_test.go - 记忆提取相关单元测试(模块 10.3). // // 测试范围: // - isUnderDir:路径目录限制辅助函数 // - SubAgent.canUseTool:MemoryDirRestrict 路径检查 // - Engine.hasMemoryWritesSince:消息游标扫描 // - scheduleMemoryExtraction:单飞+后置补跑逻辑 // - SubAgent.historyMessages:历史消息预置 package engine import ( "context" "encoding/json" "sync" "sync/atomic" "testing" "time" "git.flytoex.net/yuanwei/flyto-agent/pkg/memory" "git.flytoex.net/yuanwei/flyto-agent/pkg/query" "git.flytoex.net/yuanwei/flyto-agent/pkg/tools" ) // ───────────────────────────────────────────────────────────────────────────── // isUnderDir // ───────────────────────────────────────────────────────────────────────────── func TestIsUnderDir_Basic(t *testing.T) { cases := []struct { path, dir string want bool }{ {"/foo/bar/baz.md", "/foo/bar", true}, {"/foo/bar", "/foo/bar", true}, // 相等算在目录内 {"/foo/barbaz", "/foo/bar", false}, // 不是 /foo/bar/ 的子路径 {"/foo/bar/", "/foo/bar", true}, // 尾部斜杠 {"", "/foo/bar", false}, // 空 path {"/foo/bar", "", false}, // 空 dir {"/other/path", "/foo/bar", false}, // 完全不同的路径 {"/foo/bar/a/b/c.md", "/foo/bar", true}, // 深层子路径 } for _, c := range cases { got := isUnderDir(c.path, c.dir) if got != c.want { t.Errorf("isUnderDir(%q, %q) = %v, want %v", c.path, c.dir, got, c.want) } } } // ───────────────────────────────────────────────────────────────────────────── // SubAgent.canUseTool - MemoryDirRestrict // ───────────────────────────────────────────────────────────────────────────── func TestSubAgent_CanUseTool_MemoryDirRestrict_AllowsInsideDir(t *testing.T) { sa := &SubAgent{ allowedTools: map[string]bool{"Edit": true, "Write": true, "Read": true}, memoryDirRestrict: "/home/user/.flyto/memory", } input := json.RawMessage(`{"file_path":"/home/user/.flyto/memory/user.md"}`) if !sa.canUseTool("Edit", input) { t.Error("Edit to path inside memoryDirRestrict should be allowed") } } func TestSubAgent_CanUseTool_MemoryDirRestrict_DeniesOutsideDir(t *testing.T) { sa := &SubAgent{ allowedTools: map[string]bool{"Edit": true, "Write": true}, memoryDirRestrict: "/home/user/.flyto/memory", } input := json.RawMessage(`{"file_path":"/home/user/project/main.go"}`) if sa.canUseTool("Edit", input) { t.Error("Edit to path outside memoryDirRestrict should be denied") } } func TestSubAgent_CanUseTool_MemoryDirRestrict_ReadUnrestricted(t *testing.T) { sa := &SubAgent{ allowedTools: map[string]bool{"Read": true}, memoryDirRestrict: "/home/user/.flyto/memory", } // Read 不受 memoryDirRestrict 限制(只限制写工具) input := json.RawMessage(`{"file_path":"/home/user/project/main.go"}`) if !sa.canUseTool("Read", input) { t.Error("Read should not be restricted by memoryDirRestrict") } } func TestSubAgent_CanUseTool_MemoryDirRestrict_EmptyPathDenied(t *testing.T) { sa := &SubAgent{ allowedTools: map[string]bool{"Write": true}, memoryDirRestrict: "/home/user/.flyto/memory", } // 没有 file_path 字段 input := json.RawMessage(`{}`) if sa.canUseTool("Write", input) { t.Error("Write with no file_path should be denied by memoryDirRestrict") } } func TestSubAgent_CanUseTool_NoRestrict_AllowsAnyPath(t *testing.T) { sa := &SubAgent{ allowedTools: map[string]bool{"Edit": true}, memoryDirRestrict: "", // 空 = 不限制 } input := json.RawMessage(`{"file_path":"/any/path/file.go"}`) if !sa.canUseTool("Edit", input) { t.Error("Edit without memoryDirRestrict should allow any path") } } // ───────────────────────────────────────────────────────────────────────────── // Engine.hasMemoryWritesSince // ───────────────────────────────────────────────────────────────────────────── // makeAssistantWithWrite 构造一个包含 Edit tool_use 的 assistant 消息. func makeAssistantWithWrite(filePath string) query.Message { return query.Message{ Role: query.RoleAssistant, Content: []query.Content{ { Type: query.ContentToolUse, Name: "Edit", Input: map[string]any{"file_path": filePath}, }, }, } } // makeUserMsg 构造一个普通用户消息(不含写操作). func makeUserMsg(text string) query.Message { return query.NewTextMsg(query.RoleUser, text) } func newEngineForExtraction(t *testing.T, memDir string) *Engine { t.Helper() store := memory.NewFileStore(memDir) // baseDir 会被 memoryDirForProject 转换 // 直接用一个带真实 Dir() 的 fileStore;我们用 memory.NewFileStoreWithBaseDir 如果存在, // 否则直接测 hasMemoryWritesSince(只用 e.mem.Dir()) e := &Engine{ mem: store, observer: &NoopObserver{}, } return e } func TestEngine_HasMemoryWritesSince_WrittenAfterIdx(t *testing.T) { memDir := t.TempDir() e := &Engine{ mem: newMemStoreWithDir(memDir), observer: &NoopObserver{}, } messages := []query.Message{ makeUserMsg("hello"), makeAssistantWithWrite(memDir + "/user.md"), // idx=1:在 memDir 内写了文件 } if !e.hasMemoryWritesSince(messages, 0) { t.Error("hasMemoryWritesSince should detect write to memory dir") } } func TestEngine_HasMemoryWritesSince_WrittenOutsideMemDir(t *testing.T) { memDir := t.TempDir() e := &Engine{ mem: newMemStoreWithDir(memDir), observer: &NoopObserver{}, } messages := []query.Message{ makeUserMsg("hello"), makeAssistantWithWrite("/other/path/file.go"), // 写了,但不在 memDir 内 } if e.hasMemoryWritesSince(messages, 0) { t.Error("hasMemoryWritesSince should not flag writes outside memory dir") } } func TestEngine_HasMemoryWritesSince_SinceIdxSkipsEarlier(t *testing.T) { memDir := t.TempDir() e := &Engine{ mem: newMemStoreWithDir(memDir), observer: &NoopObserver{}, } messages := []query.Message{ makeAssistantWithWrite(memDir + "/old.md"), // idx=0:早于游标 makeUserMsg("some user message"), } // sinceIdx=1 时 idx=0 的写操作被跳过 if e.hasMemoryWritesSince(messages, 1) { t.Error("hasMemoryWritesSince with sinceIdx=1 should skip message at idx=0") } } func TestEngine_HasMemoryWritesSince_EmptyMessages(t *testing.T) { memDir := t.TempDir() e := &Engine{ mem: newMemStoreWithDir(memDir), observer: &NoopObserver{}, } if e.hasMemoryWritesSince(nil, 0) { t.Error("hasMemoryWritesSince on nil messages should return false") } } // ───────────────────────────────────────────────────────────────────────────── // scheduleMemoryExtraction - 单飞 + 后置补跑 // ───────────────────────────────────────────────────────────────────────────── // countingExtractor 计数 BuildPrompt 调用次数,用于测试单飞逻辑. type countingExtractor struct { mu sync.Mutex buildCount atomic.Int32 blockCh chan struct{} // 关闭时解除阻塞(模拟慢提取) } func newCountingExtractor() *countingExtractor { return &countingExtractor{blockCh: make(chan struct{})} } func (c *countingExtractor) Name() string { return "counting" } func (c *countingExtractor) ShouldExtract(_, _ int) bool { return true } func (c *countingExtractor) AllowedTools() []string { return nil } func (c *countingExtractor) MaxTurns() int { return 1 } func (c *countingExtractor) BuildPrompt(_ []*memory.Entry, _ int) string { c.buildCount.Add(1) // 等待 blockCh 关闭(慢提取模拟) <-c.blockCh return "extract prompt" } // newMemStoreWithDir 用自定义 baseDir 创建 fileStore(测试用). func newMemStoreWithDir(dir string) memory.Store { return memory.NewFileStoreWithBaseDir(dir) } // ───────────────────────────────────────────────────────────────────────────── // SubAgent.historyMessages 预置 // ───────────────────────────────────────────────────────────────────────────── func TestSubAgentConfig_HistoryMessages_StoredOnSubAgent(t *testing.T) { parent := &Engine{ cfg: testConfig(), tools: tools.NewRegistry(), observer: &NoopObserver{}, } histMsgs := []query.Message{ makeUserMsg("first message"), makeUserMsg("second message"), } sa := SpawnSubAgent(parent, &SubAgentConfig{ Description: "test", HistoryMessages: histMsgs, }) if len(sa.historyMessages) != 2 { t.Errorf("historyMessages should have 2 entries, got %d", len(sa.historyMessages)) } } func TestSubAgentConfig_MemoryDirRestrict_StoredOnSubAgent(t *testing.T) { parent := &Engine{ cfg: testConfig(), tools: tools.NewRegistry(), observer: &NoopObserver{}, } sa := SpawnSubAgent(parent, &SubAgentConfig{ Description: "test", MemoryDirRestrict: "/tmp/memory", }) if sa.memoryDirRestrict != "/tmp/memory" { t.Errorf("memoryDirRestrict = %q, want %q", sa.memoryDirRestrict, "/tmp/memory") } } // ───────────────────────────────────────────────────────────────────────────── // memory.Store.Dir() // ───────────────────────────────────────────────────────────────────────────── func TestMemoryStore_Dir_ReturnsNonEmpty(t *testing.T) { dir := t.TempDir() store := memory.NewFileStoreWithBaseDir(dir) if store.Dir() != dir { t.Errorf("Dir() = %q, want %q", store.Dir(), dir) } } // ───────────────────────────────────────────────────────────────────────────── // scheduleMemoryExtraction 单飞测试(无 API 调用,用 blockCh 模拟慢提取) // ───────────────────────────────────────────────────────────────────────────── func TestScheduleMemoryExtraction_SingleFlight(t *testing.T) { // 这个测试验证:当 extractInProgress=true 时,第二次 schedule 会存入 pending // 而不是立即运行第二个 goroutine. e := &Engine{ mem: newMemStoreWithDir(t.TempDir()), observer: &NoopObserver{}, } // 手动设置 inProgress,模拟已有提取正在运行 e.extractState.mu.Lock() e.extractState.inProgress = true e.extractState.mu.Unlock() msgs := []query.Message{makeUserMsg("hello")} ctx := context.Background() // 第二次调用,应存入 pending 而非立即运行 e.scheduleMemoryExtraction(ctx, msgs, 5) e.extractState.mu.Lock() hasPending := e.extractState.pending != nil e.extractState.mu.Unlock() if !hasPending { t.Error("scheduleMemoryExtraction while in-progress should stash into extractPending") } // 恢复状态 e.extractState.mu.Lock() e.extractState.inProgress = false e.extractState.pending = nil e.extractState.mu.Unlock() } func TestScheduleMemoryExtraction_PendingOverwrite(t *testing.T) { // 验证:同一个 inProgress 期间多次 schedule,pending 只保留最后一次 e := &Engine{ mem: newMemStoreWithDir(t.TempDir()), observer: &NoopObserver{}, } e.extractState.mu.Lock() e.extractState.inProgress = true e.extractState.mu.Unlock() ctx := context.Background() msgs1 := []query.Message{makeUserMsg("first")} msgs2 := []query.Message{makeUserMsg("second")} e.scheduleMemoryExtraction(ctx, msgs1, 5) e.scheduleMemoryExtraction(ctx, msgs2, 6) // 应覆盖 msgs1 e.extractState.mu.Lock() pending := e.extractState.pending e.extractState.mu.Unlock() if pending == nil { t.Fatal("extractPending should not be nil") } if pending.turnCount != 6 { t.Errorf("pending should have turnCount=6 (latest), got %d", pending.turnCount) } // 验证是 msgs2(最新) if len(pending.messages) != 1 { t.Errorf("pending messages length = %d, want 1", len(pending.messages)) } // 恢复 e.extractState.mu.Lock() e.extractState.inProgress = false e.extractState.pending = nil e.extractState.mu.Unlock() } func TestScheduleMemoryExtraction_NoExtractorSkips(t *testing.T) { // extractor=nil 时,ShouldExtract 不会被调用,engine.go 的调用者已判断 // 这里只验证 scheduleMemoryExtraction 不 panic e := &Engine{ mem: newMemStoreWithDir(t.TempDir()), observer: &NoopObserver{}, // extractor = nil } // 直接设 inProgress,避免真正执行(避免 nil panic on extractor) e.extractState.mu.Lock() e.extractState.inProgress = true e.extractState.mu.Unlock() done := make(chan struct{}) go func() { defer close(done) e.scheduleMemoryExtraction(context.Background(), nil, 0) }() select { case <-done: case <-time.After(time.Second): t.Error("scheduleMemoryExtraction timed out") } }