// session_snapshot_test.go - 会话快照(断点续传)的单元测试(1.6). // // 覆盖场景: // - BuildSnapshot: 字段正确 // - FileSnapshotStore.Save: 原子写入(tmp + rename) // - FileSnapshotStore.Load: 正常读取 // - FileSnapshotStore.Load: 文件不存在返回 false(非 error) // - FileSnapshotStore.Delete: 幂等删除 // - FileSnapshotStore.Save: 目录不存在自动创建 // - ResumeConversation: 快照不存在时返回 error // - 快照 JSON 包含正确字段(round-trip) // - PartialToolUse 序列化/反序列化 // - 路径穿越防御(conversationID 含 "/") package engine import ( "context" "encoding/json" "os" "path/filepath" "strings" "testing" "time" "git.flytoex.net/yuanwei/flyto-agent/pkg/query" ) // --- BuildSnapshot --- // TestBuildSnapshot_Fields 验证 BuildSnapshot 字段正确 func TestBuildSnapshot_Fields(t *testing.T) { msgs := []query.Message{ {Role: query.RoleUser, Content: []query.Content{{Type: query.ContentText, Text: "hello"}}}, } snap := BuildSnapshot("conv-001", msgs, 3) if snap.ConversationID != "conv-001" { t.Errorf("ConversationID: %q", snap.ConversationID) } if len(snap.Messages) != 1 { t.Errorf("Messages len: %d", len(snap.Messages)) } if snap.TurnIndex != 3 { t.Errorf("TurnIndex: %d", snap.TurnIndex) } if snap.SavedAt.IsZero() { t.Error("SavedAt 不应为零值") } if snap.SavedAt.Location() != time.UTC { t.Error("SavedAt 应为 UTC 时区") } } // --- FileSnapshotStore --- // TestFileSnapshotStore_SaveLoad 保存后加载 func TestFileSnapshotStore_SaveLoad(t *testing.T) { dir := t.TempDir() store := NewFileSnapshotStore(dir) snap := BuildSnapshot("test-conv", []query.Message{ {Role: query.RoleUser, Content: []query.Content{{Type: query.ContentText, Text: "what time is it?"}}}, {Role: query.RoleAssistant, Content: []query.Content{{Type: query.ContentText, Text: "12:00"}}}, }, 1) ctx := context.Background() if err := store.Save(ctx, snap); err != nil { t.Fatalf("Save 失败: %v", err) } loaded, found, err := store.Load(ctx, "test-conv") if err != nil { t.Fatalf("Load 失败: %v", err) } if !found { t.Fatal("应找到快照,got found=false") } if loaded.ConversationID != "test-conv" { t.Errorf("ConversationID: %q", loaded.ConversationID) } if len(loaded.Messages) != 2 { t.Errorf("Messages len: %d", len(loaded.Messages)) } } // TestFileSnapshotStore_NotFound 不存在时返回 (zero, false, nil) func TestFileSnapshotStore_NotFound(t *testing.T) { dir := t.TempDir() store := NewFileSnapshotStore(dir) _, found, err := store.Load(context.Background(), "nonexistent") if err != nil { t.Errorf("不存在时 err 应为 nil,got: %v", err) } if found { t.Error("不存在时 found 应为 false") } } // TestFileSnapshotStore_Delete_Idempotent 删除幂等 func TestFileSnapshotStore_Delete_Idempotent(t *testing.T) { dir := t.TempDir() store := NewFileSnapshotStore(dir) ctx := context.Background() // 保存后删除 snap := BuildSnapshot("del-test", nil, 0) store.Save(ctx, snap) if err := store.Delete(ctx, "del-test"); err != nil { t.Errorf("第一次删除失败: %v", err) } // 再次删除(文件已不存在)不报错 if err := store.Delete(ctx, "del-test"); err != nil { t.Errorf("重复删除应幂等,got: %v", err) } } // TestFileSnapshotStore_AutoCreateDir 目录不存在时自动创建 func TestFileSnapshotStore_AutoCreateDir(t *testing.T) { base := t.TempDir() dir := filepath.Join(base, "deep", "nested", "dir") store := NewFileSnapshotStore(dir) snap := BuildSnapshot("auto-dir", nil, 0) if err := store.Save(context.Background(), snap); err != nil { t.Fatalf("深层目录自动创建失败: %v", err) } if _, err := os.Stat(dir); err != nil { t.Errorf("目录应已创建: %v", err) } } // TestFileSnapshotStore_AtomicWrite 原子写入:tmp 文件不残留 func TestFileSnapshotStore_AtomicWrite(t *testing.T) { dir := t.TempDir() store := NewFileSnapshotStore(dir) snap := BuildSnapshot("atomic", nil, 0) if err := store.Save(context.Background(), snap); err != nil { t.Fatalf("Save 失败: %v", err) } // 目录中不应有 .tmp 文件 entries, _ := os.ReadDir(dir) for _, e := range entries { if strings.HasSuffix(e.Name(), ".tmp") { t.Errorf("发现残留 tmp 文件: %s", e.Name()) } } // 应有正式的 .json 文件 found := false for _, e := range entries { if e.Name() == "atomic.json" { found = true } } if !found { t.Error("未找到 atomic.json 快照文件") } } // TestFileSnapshotStore_Overwrite 覆盖写入 func TestFileSnapshotStore_Overwrite(t *testing.T) { dir := t.TempDir() store := NewFileSnapshotStore(dir) ctx := context.Background() snap1 := BuildSnapshot("conv-x", []query.Message{{Role: query.RoleUser}}, 1) store.Save(ctx, snap1) snap2 := BuildSnapshot("conv-x", []query.Message{{Role: query.RoleUser}, {Role: query.RoleAssistant}}, 2) store.Save(ctx, snap2) loaded, _, _ := store.Load(ctx, "conv-x") if loaded.TurnIndex != 2 { t.Errorf("应覆盖为新快照(TurnIndex=2),got %d", loaded.TurnIndex) } if len(loaded.Messages) != 2 { t.Errorf("应有 2 条消息,got %d", len(loaded.Messages)) } } // TestFileSnapshotStore_PathTraversal conversationID 含 "/" 时路径穿越防御 func TestFileSnapshotStore_PathTraversal(t *testing.T) { dir := t.TempDir() store := NewFileSnapshotStore(dir) // 含 "/" 的 ID:filepath.Base 取最后一段,防止写到 dir 外 snap := BuildSnapshot("../../etc/passwd", nil, 0) if err := store.Save(context.Background(), snap); err != nil { // 可能因为文件名含特殊字符失败,但不应写到 dir 外 return } // 验证写入的文件在 dir 内,而不是 /etc/passwd entries, _ := os.ReadDir(dir) for _, e := range entries { if strings.Contains(e.Name(), "etc") { t.Errorf("路径穿越防御失败,写入了危险路径: %s", e.Name()) } } } // TestFileSnapshotStore_RoundTrip JSON round-trip:PartialToolUse 保留 func TestFileSnapshotStore_RoundTrip(t *testing.T) { dir := t.TempDir() store := NewFileSnapshotStore(dir) ctx := context.Background() snap := SessionSnapshot{ ConversationID: "roundtrip", TurnIndex: 5, SavedAt: time.Now().UTC().Truncate(time.Millisecond), Messages: []query.Message{ {Role: query.RoleUser, Content: []query.Content{{Type: query.ContentText, Text: "test"}}}, }, PartialToolUse: &PartialToolUse{ ID: "tu-123", Name: "Bash", Input: `{"command": "ls `, }, } store.Save(ctx, snap) loaded, found, err := store.Load(ctx, "roundtrip") if err != nil || !found { t.Fatalf("Load 失败: err=%v found=%v", err, found) } if loaded.PartialToolUse == nil { t.Fatal("PartialToolUse 应保留") } if loaded.PartialToolUse.ID != "tu-123" { t.Errorf("PartialToolUse.ID: %q", loaded.PartialToolUse.ID) } if loaded.PartialToolUse.Input != `{"command": "ls ` { t.Errorf("PartialToolUse.Input: %q", loaded.PartialToolUse.Input) } if len(loaded.Messages) != 1 { t.Errorf("Messages len: %d", len(loaded.Messages)) } } // TestSessionSnapshot_JSONFormat 验证快照 JSON 字段名符合约定 func TestSessionSnapshot_JSONFormat(t *testing.T) { snap := SessionSnapshot{ ConversationID: "json-test", TurnIndex: 2, SavedAt: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), } data, err := json.Marshal(snap) if err != nil { t.Fatalf("Marshal 失败: %v", err) } s := string(data) expectedFields := []string{"conversation_id", "turn_index", "saved_at", "messages"} for _, f := range expectedFields { if !strings.Contains(s, `"`+f+`"`) { t.Errorf("JSON 应包含字段 %q,got: %s", f, s) } } // PartialToolUse 为 nil 时不应出现在 JSON 中(omitempty) if strings.Contains(s, "partial_tool_use") { t.Errorf("nil PartialToolUse 不应出现在 JSON 中,got: %s", s) } } // TestResumeConversation_NotFound 快照不存在时返回 error func TestResumeConversation_NotFound(t *testing.T) { dir := t.TempDir() store := NewFileSnapshotStore(dir) // 需要一个真实的 Engine 才能测试 ResumeConversation // 但构建 Engine 依赖 API key 等配置,所以这里只测试 store.Load 的 not-found 路径 // (ResumeConversation 的完整集成测试需要 mock API) _, found, err := store.Load(context.Background(), "nonexistent-conv") if err != nil { t.Fatalf("Load error: %v", err) } if found { t.Fatal("不应找到快照") } } // TestDefaultSnapshotDir 验证默认目录在 ~/.flyto/snapshots func TestDefaultSnapshotDir(t *testing.T) { dir := defaultSnapshotDir() if !strings.Contains(dir, "flyto") || !strings.Contains(dir, "snapshots") { t.Errorf("默认目录应包含 flyto/snapshots,got: %q", dir) } } // TestFileSnapshotStore_ZeroDir 空 dir 使用默认目录 func TestFileSnapshotStore_ZeroDir(t *testing.T) { store := NewFileSnapshotStore("") if store.Dir == "" { t.Error("Dir 不应为空") } if !strings.Contains(store.Dir, "snapshots") { t.Errorf("Dir 应包含 snapshots,got: %q", store.Dir) } } // TestSnapshotMessages_Preservation 消息内容完整保存和恢复 func TestSnapshotMessages_Preservation(t *testing.T) { dir := t.TempDir() store := NewFileSnapshotStore(dir) ctx := context.Background() // 模拟一个中断在 tool_result 末尾的会话 msgs := []query.Message{ {Role: query.RoleUser, Content: []query.Content{{Type: query.ContentText, Text: "list files"}}}, { Role: query.RoleAssistant, Content: []query.Content{ {Type: query.ContentToolUse, ID: "t1", Name: "Bash", Input: map[string]any{"command": "ls"}}, }, }, { Role: query.RoleUser, Content: []query.Content{ {Type: query.ContentToolResult, ToolUseID: "t1", Text: "file1.go\nfile2.go"}, }, }, } snap := BuildSnapshot("interrupted", msgs, 1) store.Save(ctx, snap) loaded, _, _ := store.Load(ctx, "interrupted") // 验证消息列表末尾是 tool_result(中断点状态) lastMsg := loaded.Messages[len(loaded.Messages)-1] if lastMsg.Role != query.RoleUser { t.Errorf("末尾消息应为 user(tool_result),got %s", lastMsg.Role) } if len(lastMsg.Content) == 0 || lastMsg.Content[0].Type != query.ContentToolResult { t.Error("末尾消息应包含 tool_result") } // 注入哨兵后:应增加一条哨兵消息 resumed := maybeInjectResumeSentinel(loaded.Messages) if len(resumed) != len(msgs)+1 { t.Errorf("恢复后消息数应为 %d,got %d", len(msgs)+1, len(resumed)) } }