package context import ( "context" "strings" "testing" ) // --------------------------------------------------------------------------- // BundleOverlay 测试 // --------------------------------------------------------------------------- func TestBundleOverlay_NoOverrides(t *testing.T) { base := NewDefaultBundle() overlay := NewBundleOverlay(base) // 无覆盖时,静态/动态 sections 与 base 完全相同 baseSt := base.StaticSections() overlaySt := overlay.StaticSections() if len(baseSt) != len(overlaySt) { t.Fatalf("static sections length mismatch: base=%d overlay=%d", len(baseSt), len(overlaySt)) } for i, s := range baseSt { if s.Name != overlaySt[i].Name { t.Errorf("static[%d]: base=%s overlay=%s", i, s.Name, overlaySt[i].Name) } } baseDy := base.DynamicSections() overlayDy := overlay.DynamicSections() if len(baseDy) != len(overlayDy) { t.Fatalf("dynamic sections length mismatch: base=%d overlay=%d", len(baseDy), len(overlayDy)) } } func TestBundleOverlay_OverrideStatic(t *testing.T) { base := NewDefaultBundle() const newIntro = "你是 Flyto,一个智能助手。" overlay := NewBundleOverlay(base).OverrideStatic("intro", newIntro) sections := overlay.StaticSections() found := false for _, s := range sections { if s.Name == "intro" { found = true if s.Text != newIntro { t.Errorf("intro text: got %q, want %q", s.Text, newIntro) } if !s.Static { t.Error("overridden static section should have Static=true") } } } if !found { t.Fatal("intro section not found in overlay") } // 其他 section 的 Text 不受影响(与 base 内容相同) baseDoingTasksText := "" for _, bs := range base.StaticSections() { if bs.Name == "doing_tasks" { baseDoingTasksText = bs.Text } } for _, s := range sections { if s.Name == "doing_tasks" { if s.Text != baseDoingTasksText { t.Error("doing_tasks text should be same as base (not replaced)") } } } } func TestBundleOverlay_OverrideDynamic(t *testing.T) { base := NewDefaultBundle() const customEnv = "# 自定义环境信息" overlay := NewBundleOverlay(base).OverrideDynamic("env_info", func(_ context.Context) string { return customEnv }) ctx := context.Background() reg := NewSectionRegistry() sections := overlay.DynamicSections() found := false for _, s := range sections { if s.Name == "env_info" { found = true got := reg.Compute(ctx, s) if got != customEnv { t.Errorf("env_info: got %q, want %q", got, customEnv) } } } if !found { t.Fatal("env_info section not found in overlay dynamic sections") } } func TestBundleOverlay_OverrideVolatile(t *testing.T) { base := NewDefaultBundle() callCount := 0 overlay := NewBundleOverlay(base).OverrideVolatile("env_info", func(_ context.Context) string { callCount++ return "volatile content" }, "test: env changes every turn") ctx := context.Background() reg := NewSectionRegistry() // 找到被覆盖的 env_info for _, s := range overlay.DynamicSections() { if s.Name == "env_info" { if !s.CacheBreak { t.Error("volatile override should have CacheBreak=true") } // 多次调用,每次都应该重算(CacheBreak 跳过缓存) reg.Compute(ctx, s) reg.Compute(ctx, s) if callCount != 2 { t.Errorf("volatile section should be called each time, got %d calls", callCount) } } } } func TestBundleOverlay_ChainedOverrides(t *testing.T) { base := NewDefaultBundle() overlay := NewBundleOverlay(base). OverrideStatic("intro", "intro override"). OverrideStatic("system", "system override"). OverrideStatic("doing_tasks", "doing_tasks override") sections := overlay.StaticSections() overridden := map[string]bool{} for _, s := range sections { switch s.Name { case "intro": if s.Text != "intro override" { t.Errorf("intro: %q", s.Text) } overridden["intro"] = true case "system": if s.Text != "system override" { t.Errorf("system: %q", s.Text) } overridden["system"] = true case "doing_tasks": if s.Text != "doing_tasks override" { t.Errorf("doing_tasks: %q", s.Text) } overridden["doing_tasks"] = true } } for _, name := range []string{"intro", "system", "doing_tasks"} { if !overridden[name] { t.Errorf("section %q not found or not overridden", name) } } } func TestBundleOverlay_UnknownNameIgnored(t *testing.T) { // 覆盖一个 base 中不存在的 section 名称--应被静默忽略(不产生新 section) base := NewDefaultBundle() overlay := NewBundleOverlay(base).OverrideStatic("this_section_does_not_exist", "ghost") baseStatic := base.StaticSections() overlayStatic := overlay.StaticSections() if len(baseStatic) != len(overlayStatic) { t.Errorf("unknown override name should not add new sections: base=%d overlay=%d", len(baseStatic), len(overlayStatic)) } } func TestBundleOverlay_ImplementsPromptBundle(t *testing.T) { // 编译时断言:BundleOverlay 实现了 PromptBundle 接口 var _ PromptBundle = NewBundleOverlay(NewDefaultBundle()) } // --------------------------------------------------------------------------- // BundleFromFunc 测试 // --------------------------------------------------------------------------- func TestNewBundleFromFunc_Basic(t *testing.T) { called := map[string]int{} bundle := NewBundleFromFunc( func() []*Section { called["static"]++ return []*Section{StaticSection("custom_intro", "hello")} }, func() []*Section { called["dynamic"]++ return []*Section{DynamicSection("custom_env", func(_ context.Context) string { return "env" })} }, ) st := bundle.StaticSections() if len(st) != 1 || st[0].Name != "custom_intro" { t.Errorf("static: got %v", st) } dy := bundle.DynamicSections() if len(dy) != 1 || dy[0].Name != "custom_env" { t.Errorf("dynamic: got %v", dy) } if called["static"] != 1 || called["dynamic"] != 1 { t.Errorf("call counts: %v", called) } } func TestNewBundleFromFunc_NilFuncs(t *testing.T) { bundle := NewBundleFromFunc(nil, nil) if bundle.StaticSections() != nil { t.Error("nil staticFn should return nil") } if bundle.DynamicSections() != nil { t.Error("nil dynamicFn should return nil") } } // --------------------------------------------------------------------------- // PromptLanguage context helpers 测试 // --------------------------------------------------------------------------- func TestPromptLanguageFromCtx(t *testing.T) { ctx := context.Background() // 未设置时返回空字符串 if lang := PromptLanguageFromCtx(ctx); lang != "" { t.Errorf("empty ctx: got %q", lang) } // 设置后可以读回 ctx = WithPromptLanguage(ctx, "zh-CN") if lang := PromptLanguageFromCtx(ctx); lang != "zh-CN" { t.Errorf("after set: got %q, want zh-CN", lang) } } // --------------------------------------------------------------------------- // BundleOverlay 与 BuildPromptBlocks 集成测试 // --------------------------------------------------------------------------- func TestBundleOverlay_IntegrationWithBuildPromptBlocks(t *testing.T) { const customIntro = "自定义介绍段落" overlay := NewBundleOverlay(NewDefaultBundle()). OverrideStatic("intro", customIntro) ctx := context.Background() reg := NewSectionRegistry() blocks := BuildPromptBlocks(ctx, overlay, reg, false) if len(blocks) == 0 { t.Fatal("BuildPromptBlocks returned empty") } combined := BlocksToString(blocks) if combined == "" { t.Fatal("combined prompt is empty") } // 自定义 intro 应出现在最终提示词中 found := false for _, block := range blocks { if strings.Contains(block.Text, customIntro) { found = true break } } if !found { t.Errorf("customIntro %q not found in blocks", customIntro) } }