// grep_test.go -- Grep 工具的单元测试(双引擎版). // // 覆盖场景: // - 基本正则搜索 // - 大小写不敏感搜索 // - 不同输出模式(content, files_with_matches, count) // - 无效正则报错 // - 空模式报错 // - 搜索路径不存在报错 // - truncateLine 辅助函数 // - isBinaryExtension 辅助函数 // - 上下文行数(-A/-B/-C 分离控制) // - 双引擎切换(rg 可用 / 不可用) // - 文件类型过滤 // - 多行匹配 // - head_limit(0=默认 250) // - offset 偏移 // - Glob 逗号分割过滤 // - files_with_matches 排序 // - 二进制文件跳过 // - VCS 目录排除 package builtin import ( "context" "encoding/json" "os" "os/exec" "path/filepath" "strings" "testing" "git.flytoex.net/yuanwei/flyto-agent/pkg/execenv" ) // TestGrepTool_BasicSearch 测试基本正则搜索 func TestGrepTool_BasicSearch(t *testing.T) { tool := NewGrepTool(execenv.DefaultExecutor{}) dir := t.TempDir() os.WriteFile(filepath.Join(dir, "hello.txt"), []byte("hello world\ngoodbye world\nhello again"), 0644) input, _ := json.Marshal(grepInput{Pattern: "hello", Path: dir, OutputMode: "content"}) result, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("执行失败: %v", err) } if result.IsError { t.Fatalf("不应标记为错误: %s", result.Output) } // 应找到包含 hello 的行 if !strings.Contains(result.Output, "hello") { t.Errorf("应包含 hello: %s", result.Output) } } // TestGrepTool_CaseInsensitive 测试大小写不敏感搜索 func TestGrepTool_CaseInsensitive(t *testing.T) { tool := NewGrepTool(execenv.DefaultExecutor{}) dir := t.TempDir() os.WriteFile(filepath.Join(dir, "test.txt"), []byte("Hello World\nhello world\nHELLO WORLD"), 0644) input, _ := json.Marshal(grepInput{ Pattern: "hello", Path: dir, OutputMode: "count", CaseInsensitive: true, }) result, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("执行失败: %v", err) } // 应匹配到 3 行 if !strings.Contains(result.Output, ":3") { t.Errorf("大小写不敏感应匹配 3 行: %s", result.Output) } } // TestGrepTool_FilesWithMatches 测试 files_with_matches 输出模式 func TestGrepTool_FilesWithMatches(t *testing.T) { tool := NewGrepTool(execenv.DefaultExecutor{}) dir := t.TempDir() os.WriteFile(filepath.Join(dir, "a.txt"), []byte("hello"), 0644) os.WriteFile(filepath.Join(dir, "b.txt"), []byte("world"), 0644) os.WriteFile(filepath.Join(dir, "c.txt"), []byte("hello world"), 0644) input, _ := json.Marshal(grepInput{ Pattern: "hello", Path: dir, OutputMode: "files_with_matches", }) result, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("执行失败: %v", err) } // 应匹配到 a.txt 和 c.txt if !strings.Contains(result.Output, "a.txt") { t.Errorf("应匹配到 a.txt: %s", result.Output) } if !strings.Contains(result.Output, "c.txt") { t.Errorf("应匹配到 c.txt: %s", result.Output) } if strings.Contains(result.Output, "b.txt") { t.Errorf("不应匹配到 b.txt: %s", result.Output) } } // TestGrepTool_NoMatch 测试无匹配 func TestGrepTool_NoMatch(t *testing.T) { tool := NewGrepTool(execenv.DefaultExecutor{}) dir := t.TempDir() os.WriteFile(filepath.Join(dir, "test.txt"), []byte("hello world"), 0644) input, _ := json.Marshal(grepInput{Pattern: "zzzzz", Path: dir}) result, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("执行失败: %v", err) } if !strings.Contains(result.Output, "No matches") { t.Errorf("无匹配应有提示: %s", result.Output) } } // TestGrepTool_InvalidRegex 测试无效正则表达式 func TestGrepTool_InvalidRegex(t *testing.T) { tool := NewGrepTool(execenv.DefaultExecutor{}) dir := t.TempDir() input, _ := json.Marshal(grepInput{Pattern: "[invalid", Path: dir}) result, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("不应返回 Go error: %v", err) } if !result.IsError { t.Error("无效正则应报错") } if !strings.Contains(result.Output, "invalid regex") { t.Errorf("错误信息不匹配: %s", result.Output) } } // TestGrepTool_EmptyPattern 测试空模式 func TestGrepTool_EmptyPattern(t *testing.T) { tool := NewGrepTool(execenv.DefaultExecutor{}) input, _ := json.Marshal(grepInput{Pattern: ""}) result, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("不应返回 Go error: %v", err) } if !result.IsError { t.Error("空模式应报错") } } // TestGrepTool_SearchSingleFile 测试搜索单个文件 func TestGrepTool_SearchSingleFile(t *testing.T) { tool := NewGrepTool(execenv.DefaultExecutor{}) dir := t.TempDir() filePath := filepath.Join(dir, "single.txt") os.WriteFile(filePath, []byte("line one\nline two\nline three"), 0644) input, _ := json.Marshal(grepInput{Pattern: "two", Path: filePath, OutputMode: "content"}) result, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("执行失败: %v", err) } if !strings.Contains(result.Output, "two") { t.Errorf("应包含匹配: %s", result.Output) } } // TestTruncateLine 测试行截断 func TestTruncateLine(t *testing.T) { tests := []struct { name string line string maxLen int want string }{ {"短行不截断", "hello", 10, "hello"}, {"精确长度", "hello", 5, "hello"}, {"超长截断", "hello world", 5, "hello... [line truncated]"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := truncateLine(tt.line, tt.maxLen) if got != tt.want { t.Errorf("truncateLine(%q, %d) = %q, 期望 %q", tt.line, tt.maxLen, got, tt.want) } }) } } // TestIsBinaryExtension 测试二进制文件扩展名检测 func TestIsBinaryExtension(t *testing.T) { tests := []struct { ext string want bool }{ {".exe", true}, {".png", true}, {".zip", true}, {".pdf", true}, {".go", false}, {".txt", false}, {".js", false}, {".EXE", true}, } for _, tt := range tests { got := isBinaryExtension(tt.ext) if got != tt.want { t.Errorf("isBinaryExtension(%q) = %v, 期望 %v", tt.ext, got, tt.want) } } } // TestGrepTool_Metadata 测试工具元数据 func TestGrepTool_Metadata(t *testing.T) { tool := NewGrepTool(execenv.DefaultExecutor{}) meta := tool.Metadata() if !meta.ConcurrencySafe { t.Error("Grep 应标记为 ConcurrencySafe") } if !meta.ReadOnly { t.Error("Grep 应标记为 ReadOnly") } if tool.Name() != "Grep" { t.Errorf("期望名称 'Grep', 实际 %q", tool.Name()) } } // ---------- 双引擎测试 ---------- // TestGrepEngine_BuiltinEngine 测试内置 Go 引擎 func TestGrepEngine_BuiltinEngine(t *testing.T) { dir := t.TempDir() os.WriteFile(filepath.Join(dir, "code.go"), []byte("func main() {\n\tfmt.Println(\"hello\")\n}\n"), 0644) engine := NewBuiltinGrepEngine() if engine.Name() != "builtin" { t.Errorf("引擎名称应为 'builtin', 实际 %q", engine.Name()) } result, err := engine.Search(context.Background(), &GrepParams{ Pattern: "hello", SearchPath: dir, SearchDir: dir, OutputMode: "content", HeadLimit: 250, }) if err != nil { t.Fatalf("搜索失败: %v", err) } if result.TotalMatches == 0 { t.Error("应找到匹配") } if !strings.Contains(result.Output, "hello") { t.Errorf("输出应包含 hello: %s", result.Output) } } // TestGrepEngine_RipgrepEngine 测试 Ripgrep 引擎 func TestGrepEngine_RipgrepEngine(t *testing.T) { rgPath, err := exec.LookPath("rg") if err != nil { t.Skip("rg 不可用,跳过 Ripgrep 引擎测试") } dir := t.TempDir() os.WriteFile(filepath.Join(dir, "test.go"), []byte("func main() {\n\tfmt.Println(\"hello\")\n}\n"), 0644) os.WriteFile(filepath.Join(dir, "other.txt"), []byte("no match here"), 0644) engine := NewRipgrepEngine(rgPath, execenv.DefaultExecutor{}) if engine.Name() != "ripgrep" { t.Errorf("引擎名称应为 'ripgrep', 实际 %q", engine.Name()) } result, err := engine.Search(context.Background(), &GrepParams{ Pattern: "hello", SearchPath: dir, SearchDir: dir, OutputMode: "files_with_matches", HeadLimit: 250, }) if err != nil { t.Fatalf("搜索失败: %v", err) } if result.TotalMatches == 0 { t.Error("应找到匹配") } if !strings.Contains(result.Output, "test.go") { t.Errorf("输出应包含 test.go: %s", result.Output) } } // TestGrepEngine_DetectEngine 测试引擎自动检测 func TestGrepEngine_DetectEngine(t *testing.T) { engine := DetectGrepEngine(execenv.DefaultExecutor{}) name := engine.Name() if name != "ripgrep" && name != "builtin" { t.Errorf("引擎名称应为 'ripgrep' 或 'builtin', 实际 %q", name) } // 如果系统有 rg,应选择 ripgrep if _, err := exec.LookPath("rg"); err == nil { if name != "ripgrep" { t.Errorf("系统有 rg 应选择 ripgrep 引擎, 实际 %q", name) } } } // ---------- 新功能测试 ---------- // TestGrepTool_FileType 测试文件类型过滤 func TestGrepTool_FileType(t *testing.T) { tool := NewGrepTool(execenv.DefaultExecutor{}) dir := t.TempDir() os.WriteFile(filepath.Join(dir, "main.go"), []byte("func hello()"), 0644) os.WriteFile(filepath.Join(dir, "app.js"), []byte("function hello()"), 0644) os.WriteFile(filepath.Join(dir, "lib.py"), []byte("def hello():"), 0644) input, _ := json.Marshal(grepInput{ Pattern: "hello", Path: dir, OutputMode: "files_with_matches", FileType: "go", }) result, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("执行失败: %v", err) } if !strings.Contains(result.Output, "main.go") { t.Errorf("应匹配到 main.go: %s", result.Output) } // 注意:ripgrep 引擎会按 rg 内置类型过滤,内置引擎按我们的映射表过滤 // 两者对 .go 的定义一致 } // TestGrepTool_Multiline 测试多行匹配 func TestGrepTool_Multiline(t *testing.T) { tool := NewGrepTool(execenv.DefaultExecutor{}) dir := t.TempDir() os.WriteFile(filepath.Join(dir, "multi.txt"), []byte("start\nmiddle\nend"), 0644) input, _ := json.Marshal(grepInput{ Pattern: "start.*end", Path: dir, OutputMode: "content", Multiline: true, }) result, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("执行失败: %v", err) } if result.IsError { t.Fatalf("不应标记为错误: %s", result.Output) } // 多行模式下 . 匹配换行符,所以 start.*end 应该匹配 if strings.Contains(result.Output, "No matches") { t.Errorf("多行模式应匹配跨行内容: %s", result.Output) } } // TestGrepTool_HeadLimitDefault 测试 head_limit 默认值(未传 = 250) func TestGrepTool_HeadLimitDefault(t *testing.T) { tool := NewGrepTool(execenv.DefaultExecutor{}) dir := t.TempDir() // 创建大量匹配 var content strings.Builder for i := 0; i < 300; i++ { content.WriteString("hello line\n") } os.WriteFile(filepath.Join(dir, "big.txt"), []byte(content.String()), 0644) // 不传 head_limit,默认 250 input, _ := json.Marshal(grepInput{ Pattern: "hello", Path: dir, OutputMode: "content", }) result, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("执行失败: %v", err) } // 输出应被截断 if !strings.Contains(result.Output, "truncated") { t.Errorf("300 行匹配应被截断: %s", result.Output[:min(200, len(result.Output))]) } } // TestGrepTool_PathNotFound 测试路径不存在 func TestGrepTool_PathNotFound(t *testing.T) { tool := NewGrepTool(execenv.DefaultExecutor{}) input, _ := json.Marshal(grepInput{ Pattern: "hello", Path: "/nonexistent/path/to/nothing", }) result, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("不应返回 Go error: %v", err) } if !result.IsError { t.Error("不存在的路径应报错") } } // TestFileMatchesType 测试文件类型匹配 func TestFileMatchesType(t *testing.T) { tests := []struct { filePath string fileType string want bool }{ {"main.go", "go", true}, {"main.ts", "go", false}, {"app.js", "js", true}, {"app.jsx", "js", true}, {"lib.py", "py", true}, {"lib.pyi", "py", true}, {"main.rs", "rust", true}, {"Dockerfile", "dockerfile", true}, {"Makefile", "make", true}, {"test.txt", "unknown_type", true}, // 未知类型不过滤 } for _, tt := range tests { got := fileMatchesType(tt.filePath, tt.fileType) if got != tt.want { t.Errorf("fileMatchesType(%q, %q) = %v, 期望 %v", tt.filePath, tt.fileType, got, tt.want) } } } // TestRelativizePath 测试路径相对化 func TestRelativizePath(t *testing.T) { tests := []struct { line string baseDir string want string }{ {"/home/user/project/main.go", "/home/user/project", "main.go"}, {"/home/user/project/main.go:10:hello", "/home/user/project", "main.go:10:hello"}, {"already/relative.go", "/home/user/project", "already/relative.go"}, {"/other/path/file.go", "/home/user/project", "/other/path/file.go"}, } for _, tt := range tests { got := relativizePath(tt.line, tt.baseDir) if got != tt.want { t.Errorf("relativizePath(%q, %q) = %q, 期望 %q", tt.line, tt.baseDir, got, tt.want) } } } // TestCountMatchedFiles 测试匹配文件计数 func TestCountMatchedFiles(t *testing.T) { tests := []struct { name string lines []string mode string want int }{ { "files_with_matches", []string{"a.go", "b.go", "c.go"}, "files_with_matches", 3, }, { "count", []string{"a.go:5", "b.go:3"}, "count", 2, }, { "content", []string{"a.go:1:hello", "a.go:5:world", "b.go:1:test"}, "content", 2, }, { "空行和分隔符", []string{"a.go:1:hello", "--", "b.go:1:test", ""}, "content", 2, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := countMatchedFiles(tt.lines, tt.mode) if got != tt.want { t.Errorf("countMatchedFiles(..., %q) = %d, 期望 %d", tt.mode, got, tt.want) } }) } } // TestSortFilesByMtime 测试文件按 mtime 排序 func TestSortFilesByMtime(t *testing.T) { dir := t.TempDir() // 创建文件 var paths []string for i := 0; i < 5; i++ { p := filepath.Join(dir, string(rune('a'+i))+".txt") os.WriteFile(p, []byte("data"), 0644) paths = append(paths, p) } sorted := sortFilesByMtime(context.Background(), paths) if len(sorted) != 5 { t.Fatalf("应返回 5 个文件, 实际 %d", len(sorted)) } } // TestToLowerASCII 测试 ASCII 小写转换 func TestToLowerASCII(t *testing.T) { tests := []struct { input string want string }{ {".EXE", ".exe"}, {".go", ".go"}, {".PNG", ".png"}, {"", ""}, {"HELLO", "hello"}, {"MiXeD", "mixed"}, } for _, tt := range tests { got := toLowerASCII(tt.input) if got != tt.want { t.Errorf("toLowerASCII(%q) = %q, 期望 %q", tt.input, got, tt.want) } } } // TestGrepEngine_BuiltinEngine_Multiline 测试内置引擎多行搜索 func TestGrepEngine_BuiltinEngine_Multiline(t *testing.T) { dir := t.TempDir() os.WriteFile(filepath.Join(dir, "multi.txt"), []byte("func main() {\n\treturn\n}\n"), 0644) engine := NewBuiltinGrepEngine() result, err := engine.Search(context.Background(), &GrepParams{ Pattern: "func.*}", SearchPath: dir, SearchDir: dir, OutputMode: "content", HeadLimit: 250, Multiline: true, }) if err != nil { t.Fatalf("搜索失败: %v", err) } if result.TotalMatches == 0 { t.Error("多行模式应匹配跨行内容") } } // TestGrepEngine_BuiltinEngine_ContextLines 测试内置引擎上下文行数 func TestGrepEngine_BuiltinEngine_ContextLines(t *testing.T) { dir := t.TempDir() os.WriteFile(filepath.Join(dir, "ctx.txt"), []byte("line1\nline2\nTARGET\nline4\nline5\n"), 0644) engine := NewBuiltinGrepEngine() result, err := engine.Search(context.Background(), &GrepParams{ Pattern: "TARGET", SearchPath: dir, SearchDir: dir, OutputMode: "content", HeadLimit: 250, ContextBefore: 1, ContextAfter: 1, }) if err != nil { t.Fatalf("搜索失败: %v", err) } if result.TotalMatches == 0 { t.Fatal("应找到匹配") } // 输出应包含上下文行 if !strings.Contains(result.Output, "line2") { t.Errorf("应包含前一行 line2: %s", result.Output) } if !strings.Contains(result.Output, "line4") { t.Errorf("应包含后一行 line4: %s", result.Output) } } // TestGrepEngine_BuiltinEngine_NoMatch 测试内置引擎无匹配 func TestGrepEngine_BuiltinEngine_NoMatch(t *testing.T) { dir := t.TempDir() os.WriteFile(filepath.Join(dir, "test.txt"), []byte("hello"), 0644) engine := NewBuiltinGrepEngine() result, err := engine.Search(context.Background(), &GrepParams{ Pattern: "zzzzz", SearchPath: dir, SearchDir: dir, OutputMode: "content", HeadLimit: 250, }) if err != nil { t.Fatalf("搜索失败: %v", err) } if result.TotalMatches != 0 { t.Errorf("不应有匹配, 实际 %d", result.TotalMatches) } } // TestGrepTool_GlobFilter 测试 glob 过滤(逗号分割) func TestGrepTool_GlobFilter(t *testing.T) { tool := NewGrepTool(execenv.DefaultExecutor{}) dir := t.TempDir() os.WriteFile(filepath.Join(dir, "main.go"), []byte("func hello()"), 0644) os.WriteFile(filepath.Join(dir, "app.ts"), []byte("function hello()"), 0644) os.WriteFile(filepath.Join(dir, "lib.py"), []byte("def hello():"), 0644) // 逗号分割 glob input, _ := json.Marshal(grepInput{ Pattern: "hello", Path: dir, OutputMode: "files_with_matches", Glob: "*.go,*.ts", }) result, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("执行失败: %v", err) } if !strings.Contains(result.Output, "main.go") { t.Errorf("应匹配到 main.go: %s", result.Output) } if !strings.Contains(result.Output, "app.ts") { t.Errorf("应匹配到 app.ts: %s", result.Output) } if strings.Contains(result.Output, "lib.py") { t.Errorf("不应匹配到 lib.py: %s", result.Output) } } // ---------- symlink 安全测试 ---------- // TestGrepTool_SymlinkEscape 验证 Grep 的 symlink 路径逃逸漏洞已修复. // // 攻击模型:同 Glob -- cwd 内有 `ln -s /etc/passwd link`, // collectFilesForGrep 若不校验 symlink 目标,会把 /etc/passwd 加入搜索列表. func TestGrepTool_SymlinkEscape(t *testing.T) { tool := NewGrepTool(execenv.DefaultExecutor{}) dir := t.TempDir() // 合法文件 os.WriteFile(filepath.Join(dir, "safe.txt"), []byte("hello safe"), 0644) // 在 dir 外创建包含搜索关键字的 "秘密" 文件 outside := t.TempDir() os.WriteFile(filepath.Join(outside, "secret.txt"), []byte("hello secret"), 0644) // 创建越界 symlink symlinkPath := filepath.Join(dir, "evil_link.txt") if err := os.Symlink(filepath.Join(outside, "secret.txt"), symlinkPath); err != nil { t.Skipf("无法创建 symlink,跳过: %v", err) } input, _ := json.Marshal(grepInput{ Pattern: "hello", Path: dir, OutputMode: "files_with_matches", }) result, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("执行失败: %v", err) } // 合法文件应出现 if !strings.Contains(result.Output, "safe.txt") { t.Errorf("应匹配到 safe.txt: %s", result.Output) } // 越界 symlink 不应出现 if strings.Contains(result.Output, "evil_link.txt") { t.Errorf("越界 symlink 不应出现在 grep 结果中(symlink 逃逸漏洞): %s", result.Output) } } // TestGrepTool_BinarySkip 测试二进制文件跳过 func TestGrepTool_BinarySkip(t *testing.T) { tool := NewGrepTool(execenv.DefaultExecutor{}) dir := t.TempDir() // 创建包含 null 字节的"二进制"文件 os.WriteFile(filepath.Join(dir, "binary.dat"), []byte("hello\x00world"), 0644) os.WriteFile(filepath.Join(dir, "text.txt"), []byte("hello world"), 0644) input, _ := json.Marshal(grepInput{ Pattern: "hello", Path: dir, OutputMode: "files_with_matches", }) result, err := tool.Execute(context.Background(), input, nil) if err != nil { t.Fatalf("执行失败: %v", err) } // 应该找到文本文件但跳过二进制文件 if !strings.Contains(result.Output, "text.txt") { t.Errorf("应匹配到 text.txt: %s", result.Output) } }