// scorer_test.go -- 评分器接口及实现的单元测试. // // 覆盖场景: // - TextScorer 基础评分 // - CompositeScorer 加权融合 // - CompositeScorer 动态 Add/Remove // - ExternalScorer 序列化/反序列化(mock,不启动真实进程) // - SelectRelevant 传 nil scorer 使用默认 // - SelectRelevant 传自定义 scorer package memory import ( "encoding/json" "testing" "time" ) // TestTextScorer_BasicScore 测试 TextScorer 基础评分 func TestTextScorer_BasicScore(t *testing.T) { scorer := &TextScorer{} if scorer.Name() != "text" { t.Errorf("Name() 应为 'text', 实际: %q", scorer.Name()) } header := &MemoryHeader{ Frontmatter: Frontmatter{ Name: "golang-setup", Description: "golang project setup and configuration", Type: TypeProject, }, ModTime: time.Now(), } score := scorer.Score("golang project", header) if score <= 0 { t.Errorf("相关查询分数应 > 0, 实际: %f", score) } } // TestTextScorer_ExactMatch 测试完全匹配 func TestTextScorer_ExactMatch(t *testing.T) { scorer := &TextScorer{} header := &MemoryHeader{ Frontmatter: Frontmatter{ Name: "golang testing", Description: "golang testing", }, } score := scorer.Score("golang testing", header) if score < 0.5 { t.Errorf("完全匹配分数应 >= 0.5, 实际: %f", score) } } // TestTextScorer_NoMatch 测试完全不相关 func TestTextScorer_NoMatch(t *testing.T) { scorer := &TextScorer{} header := &MemoryHeader{ Frontmatter: Frontmatter{ Name: "piano-music", Description: "piano music concert performance", }, } score := scorer.Score("golang testing", header) if score > 0.3 { t.Errorf("不相关查询分数应 < 0.3, 实际: %f", score) } } // TestTextScorer_EmptyQuery 测试空查询 func TestTextScorer_EmptyQuery(t *testing.T) { scorer := &TextScorer{} header := &MemoryHeader{ Frontmatter: Frontmatter{ Name: "test", Description: "test description", }, } score := scorer.Score("", header) if score != 0 { t.Errorf("空查询分数应为 0, 实际: %f", score) } } // mockScorer 是用于测试的 mock 评分器,总是返回固定分数. type mockScorer struct { name string fixedScore float64 } func (m *mockScorer) Name() string { return m.name } func (m *mockScorer) Score(query string, header *MemoryHeader) float64 { return m.fixedScore } // TestCompositeScorer_WeightedAverage 测试加权平均 func TestCompositeScorer_WeightedAverage(t *testing.T) { cs := NewCompositeScorer( WeightedScorer{Scorer: &mockScorer{name: "high", fixedScore: 1.0}, Weight: 0.7}, WeightedScorer{Scorer: &mockScorer{name: "low", fixedScore: 0.0}, Weight: 0.3}, ) if cs.Name() != "composite" { t.Errorf("Name() 应为 'composite', 实际: %q", cs.Name()) } header := &MemoryHeader{ Frontmatter: Frontmatter{Name: "test", Description: "test"}, } score := cs.Score("anything", header) // 期望: (0.7 * 1.0 + 0.3 * 0.0) / (0.7 + 0.3) = 0.7 expected := 0.7 if diff := score - expected; diff > 0.001 || diff < -0.001 { t.Errorf("加权平均分数应为 %f, 实际: %f", expected, score) } } // TestCompositeScorer_EqualWeights 测试等权重 func TestCompositeScorer_EqualWeights(t *testing.T) { cs := NewCompositeScorer( WeightedScorer{Scorer: &mockScorer{name: "a", fixedScore: 0.8}, Weight: 1.0}, WeightedScorer{Scorer: &mockScorer{name: "b", fixedScore: 0.4}, Weight: 1.0}, ) header := &MemoryHeader{ Frontmatter: Frontmatter{Name: "test", Description: "test"}, } score := cs.Score("anything", header) // 期望: (1.0 * 0.8 + 1.0 * 0.4) / (1.0 + 1.0) = 0.6 expected := 0.6 if diff := score - expected; diff > 0.001 || diff < -0.001 { t.Errorf("等权重平均分数应为 %f, 实际: %f", expected, score) } } // TestCompositeScorer_Empty 测试空组合评分器 func TestCompositeScorer_Empty(t *testing.T) { cs := NewCompositeScorer() header := &MemoryHeader{ Frontmatter: Frontmatter{Name: "test", Description: "test"}, } score := cs.Score("anything", header) if score != 0 { t.Errorf("空组合评分器应返回 0, 实际: %f", score) } } // TestCompositeScorer_Add 测试动态添加 func TestCompositeScorer_Add(t *testing.T) { cs := NewCompositeScorer( WeightedScorer{Scorer: &mockScorer{name: "a", fixedScore: 0.5}, Weight: 1.0}, ) header := &MemoryHeader{ Frontmatter: Frontmatter{Name: "test", Description: "test"}, } // 添加前只有 a score1 := cs.Score("q", header) if diff := score1 - 0.5; diff > 0.001 || diff < -0.001 { t.Errorf("添加前分数应为 0.5, 实际: %f", score1) } // 动态添加 b cs.Add(WeightedScorer{Scorer: &mockScorer{name: "b", fixedScore: 1.0}, Weight: 1.0}) score2 := cs.Score("q", header) // 期望: (1.0 * 0.5 + 1.0 * 1.0) / (1.0 + 1.0) = 0.75 expected := 0.75 if diff := score2 - expected; diff > 0.001 || diff < -0.001 { t.Errorf("添加后分数应为 %f, 实际: %f", expected, score2) } } // TestCompositeScorer_Remove 测试按名称移除 func TestCompositeScorer_Remove(t *testing.T) { cs := NewCompositeScorer( WeightedScorer{Scorer: &mockScorer{name: "a", fixedScore: 0.5}, Weight: 1.0}, WeightedScorer{Scorer: &mockScorer{name: "b", fixedScore: 1.0}, Weight: 1.0}, ) header := &MemoryHeader{ Frontmatter: Frontmatter{Name: "test", Description: "test"}, } // 移除 b removed := cs.Remove("b") if !removed { t.Error("Remove 应返回 true") } score := cs.Score("q", header) // 只剩 a,分数应为 0.5 if diff := score - 0.5; diff > 0.001 || diff < -0.001 { t.Errorf("移除后分数应为 0.5, 实际: %f", score) } // 移除不存在的 removed = cs.Remove("nonexistent") if removed { t.Error("移除不存在的评分器应返回 false") } } // TestExternalScorer_RequestSerialization 测试请求序列化格式 func TestExternalScorer_RequestSerialization(t *testing.T) { req := externalScorerRequest{ Query: "数据库配置", Name: "db_config", Description: "database configuration and connection pooling", Type: "project", } data, err := json.Marshal(req) if err != nil { t.Fatalf("序列化失败: %v", err) } // 反序列化验证 var parsed externalScorerRequest if err := json.Unmarshal(data, &parsed); err != nil { t.Fatalf("反序列化失败: %v", err) } if parsed.Query != req.Query { t.Errorf("Query 不匹配: %q vs %q", parsed.Query, req.Query) } if parsed.Name != req.Name { t.Errorf("Name 不匹配: %q vs %q", parsed.Name, req.Name) } if parsed.Description != req.Description { t.Errorf("Description 不匹配: %q vs %q", parsed.Description, req.Description) } if parsed.Type != req.Type { t.Errorf("Type 不匹配: %q vs %q", parsed.Type, req.Type) } } // TestExternalScorer_ResponseDeserialization 测试响应反序列化 func TestExternalScorer_ResponseDeserialization(t *testing.T) { tests := []struct { name string json string expected float64 wantErr bool }{ {"正常分数", `{"score": 0.85}`, 0.85, false}, {"零分", `{"score": 0}`, 0, false}, {"满分", `{"score": 1.0}`, 1.0, false}, {"无效 JSON", `not json`, 0, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var resp externalScorerResponse err := json.Unmarshal([]byte(tt.json), &resp) if tt.wantErr { if err == nil { t.Error("应该报错") } return } if err != nil { t.Fatalf("反序列化失败: %v", err) } if diff := resp.Score - tt.expected; diff > 0.001 || diff < -0.001 { t.Errorf("Score = %f, 期望 %f", resp.Score, tt.expected) } }) } } // TestSelectRelevant_NilScorer 测试传 nil scorer 使用默认 func TestSelectRelevant_NilScorer(t *testing.T) { headers := []MemoryHeader{ {Frontmatter: Frontmatter{Name: "go-testing", Description: "golang testing patterns"}, ModTime: time.Now()}, {Frontmatter: Frontmatter{Name: "python-ml", Description: "python machine learning"}, ModTime: time.Now()}, } // 不传 scorer(向后兼容) result1 := SelectRelevant("golang testing", headers, 5) // 显式传 nil result2 := SelectRelevant("golang testing", headers, 5, nil) if len(result1) != len(result2) { t.Errorf("nil scorer 和无 scorer 结果数量应相同: %d vs %d", len(result1), len(result2)) } if len(result1) == 0 { t.Fatal("应有相关结果") } if result1[0].Frontmatter.Name != result2[0].Frontmatter.Name { t.Error("nil scorer 和无 scorer 结果应相同") } } // TestSelectRelevant_CustomScorer 测试传自定义 scorer func TestSelectRelevant_CustomScorer(t *testing.T) { headers := []MemoryHeader{ {Frontmatter: Frontmatter{Name: "a", Description: "first"}, ModTime: time.Now()}, {Frontmatter: Frontmatter{Name: "b", Description: "second"}, ModTime: time.Now()}, {Frontmatter: Frontmatter{Name: "c", Description: "third"}, ModTime: time.Now()}, } // 自定义评分器:只给 name="b" 的打高分 customScorer := &nameMatchScorer{targetName: "b"} result := SelectRelevant("any query", headers, 5, customScorer) if len(result) != 1 { t.Fatalf("期望 1 个结果, 实际 %d", len(result)) } if result[0].Frontmatter.Name != "b" { t.Errorf("最相关的应为 'b', 实际: %q", result[0].Frontmatter.Name) } } // nameMatchScorer 只给指定名称的记忆打高分,用于测试. type nameMatchScorer struct { targetName string } func (s *nameMatchScorer) Name() string { return "name-match" } func (s *nameMatchScorer) Score(query string, header *MemoryHeader) float64 { if header.Frontmatter.Name == s.targetName { return 1.0 } return 0 } // TestSelectRelevant_WithCompositeScorer 测试组合评分器集成 func TestSelectRelevant_WithCompositeScorer(t *testing.T) { headers := []MemoryHeader{ {Frontmatter: Frontmatter{Name: "go-testing", Description: "golang testing"}, ModTime: time.Now()}, {Frontmatter: Frontmatter{Name: "target", Description: "unrelated content"}, ModTime: time.Now()}, } cs := NewCompositeScorer( // 文本评分器权重 0.3 WeightedScorer{Scorer: &TextScorer{}, Weight: 0.3}, // name-match 评分器权重 0.7,只给 "target" 高分 WeightedScorer{Scorer: &nameMatchScorer{targetName: "target"}, Weight: 0.7}, ) result := SelectRelevant("golang testing", headers, 5, cs) if len(result) == 0 { t.Fatal("应有相关结果") } // name-match 权重更高,target 应排第一 if result[0].Frontmatter.Name != "target" { t.Errorf("组合评分器下最相关的应为 'target', 实际: %q", result[0].Frontmatter.Name) } } // TestFileStoreWithScorer 测试带评分器的 fileStore func TestFileStoreWithScorer(t *testing.T) { // 使用自定义评分器创建 store customScorer := &nameMatchScorer{targetName: "go-test"} dir := t.TempDir() baseDir := dir + "/memory" store := &fileStore{ cwd: dir, baseDir: baseDir, scorer: customScorer, } ctx := t.Context() // 保存记忆 store.Save(ctx, &Entry{Name: "go-test", Description: "golang testing", Type: TypeProject, Content: "test"}) store.Save(ctx, &Entry{Name: "python-ml", Description: "python ml", Type: TypeProject, Content: "ml"}) // 自定义评分器只给 go-test 高分 results, err := store.FindRelevant(ctx, "anything", 5) if err != nil { t.Fatalf("查找失败: %v", err) } if len(results) == 0 { t.Fatal("应有结果") } if results[0].Name != "go-test" { t.Errorf("应为 'go-test', 实际: %q", results[0].Name) } }