package pricing

import (
	"os"
	"path/filepath"
	"testing"

	"git.flytoex.net/yuanwei/flyto-agent/pkg/config"
)

// TestAsInt checks that asInt converts float64, int and int64 values and
// reports (0, false) for nil, strings and bools.
func TestAsInt(t *testing.T) {
	cases := []struct {
		name string
		in   any
		want int
		ok   bool
	}{
		{"nil", nil, 0, false},
		{"float64", float64(200000), 200000, true},
		{"int", 42, 42, true},
		{"int64", int64(123), 123, true},
		{"string", "100", 0, false},
		{"bool", true, 0, false},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, ok := asInt(c.in)
			if got != c.want || ok != c.ok {
				t.Errorf("asInt(%v): got (%d, %v), want (%d, %v)", c.in, got, ok, c.want, c.ok)
			}
		})
	}
}

// TestAsFloat checks that asFloat converts float64, int and int64 values and
// reports (0, false) for nil and strings.
func TestAsFloat(t *testing.T) {
	cases := []struct {
		name string
		in   any
		want float64
		ok   bool
	}{
		{"nil", nil, 0, false},
		{"float64", 3.0, 3.0, true},
		{"int", 5, 5.0, true},
		{"int64", int64(7), 7.0, true},
		{"string", "3.0", 0, false},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, ok := asFloat(c.in)
			if got != c.want || ok != c.ok {
				t.Errorf("asFloat(%v): got (%f, %v), want (%f, %v)", c.in, got, ok, c.want, c.ok)
			}
		})
	}
}

// TestAsBool checks that asBool only accepts real bool values; nil, strings
// and ints yield (false, false).
func TestAsBool(t *testing.T) {
	cases := []struct {
		name string
		in   any
		want bool
		ok   bool
	}{
		{"nil", nil, false, false},
		{"true", true, true, true},
		{"false", false, false, true},
		{"string", "true", false, false},
		{"int", 1, false, false},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			got, ok := asBool(c.in)
			if got != c.want || ok != c.ok {
				t.Errorf("asBool(%v): got (%v, %v), want (%v, %v)", c.in, got, ok, c.want, c.ok)
			}
		})
	}
}

// TestToModelInfo_NilInput checks that a nil *ModelCapabilities maps to nil.
func TestToModelInfo_NilInput(t *testing.T) {
	if got := toModelInfo(nil); got != nil {
		t.Errorf("toModelInfo(nil): got %+v, want nil", got)
	}
}

// TestToModelInfo_EmptyModelID checks that capabilities without a model ID
// are rejected (nil result).
func TestToModelInfo_EmptyModelID(t *testing.T) {
	caps := &ModelCapabilities{Provider: "anthropic", Model: ""}
	if got := toModelInfo(caps); got != nil {
		t.Errorf("toModelInfo(empty model): got %+v, want nil", got)
	}
}

// TestToModelInfo_FullConversion checks that every populated capability is
// copied onto the resulting ModelInfo.
func TestToModelInfo_FullConversion(t *testing.T) {
	// Integer-valued fields are given as float64, mimicking what JSON decoding
	// into an `any` produces.
	caps := &ModelCapabilities{
		Provider:             "anthropic",
		Model:                "claude-sonnet-4-6",
		ContextWindow:        Capability{Value: float64(200000), Source: "documented"},
		MaxOutputTokens:      Capability{Value: float64(16384), Source: "documented"},
		InputPricePer1M:      Capability{Value: 3.0, Source: "documented"},
		OutputPricePer1M:     Capability{Value: 15.0, Source: "documented"},
		CacheReadPricePer1M:  Capability{Value: 0.3, Source: "documented"},
		CacheWritePricePer1M: Capability{Value: 3.75, Source: "documented"},
		Thinking:             Capability{Value: true, Source: "documented"},
		Caching:              Capability{Value: true, Source: "probed"},
		Vision:               Capability{Value: true, Source: "documented"},
	}
	info := toModelInfo(caps)
	if info == nil {
		t.Fatal("expected non-nil ModelInfo")
	}
	if info.ID != "claude-sonnet-4-6" {
		t.Errorf("ID: got %q, want claude-sonnet-4-6", info.ID)
	}
	if info.Provider != "anthropic" {
		t.Errorf("Provider: got %q, want anthropic", info.Provider)
	}
	if info.ContextWindow != 200000 {
		t.Errorf("ContextWindow: got %d, want 200000", info.ContextWindow)
	}
	if info.MaxOutputTokens != 16384 {
		t.Errorf("MaxOutputTokens: got %d, want 16384", info.MaxOutputTokens)
	}
	if info.InputPricePer1M != 3.0 {
		t.Errorf("InputPricePer1M: got %f, want 3.0", info.InputPricePer1M)
	}
	if info.OutputPricePer1M != 15.0 {
		t.Errorf("OutputPricePer1M: got %f, want 15.0", info.OutputPricePer1M)
	}
	if info.CacheReadPricePer1M != 0.3 {
		t.Errorf("CacheReadPricePer1M: got %f, want 0.3", info.CacheReadPricePer1M)
	}
	if info.CacheWritePricePer1M != 3.75 {
		t.Errorf("CacheWritePricePer1M: got %f, want 3.75", info.CacheWritePricePer1M)
	}
	if !info.SupportsThinking {
		t.Error("SupportsThinking: expected true")
	}
	if !info.SupportsCaching {
		t.Error("SupportsCaching: expected true")
	}
	if !info.SupportsVision {
		t.Error("SupportsVision: expected true")
	}
}

// TestToModelInfo_MissingFields checks that capabilities carrying only
// Provider and Model still convert without panicking, with zero values for
// everything that was absent.
func TestToModelInfo_MissingFields(t *testing.T) {
	caps := &ModelCapabilities{
		Provider: "openai",
		Model:    "gpt-4o",
	}
	info := toModelInfo(caps)
	if info == nil {
		t.Fatal("expected non-nil ModelInfo")
	}
	if info.ID != "gpt-4o" {
		t.Errorf("ID: got %q, want gpt-4o", info.ID)
	}
	if info.Provider != "openai" {
		t.Errorf("Provider: got %q, want openai", info.Provider)
	}
	if info.ContextWindow != 0 {
		t.Errorf("ContextWindow: expected 0 for missing field, got %d", info.ContextWindow)
	}
	if info.InputPricePer1M != 0 {
		t.Errorf("InputPricePer1M: expected 0, got %f", info.InputPricePer1M)
	}
	if info.SupportsThinking {
		t.Error("SupportsThinking: expected false")
	}
}

// TestRegisterFromReport_NilSafe checks that a nil report or a nil registry
// registers nothing instead of panicking.
func TestRegisterFromReport_NilSafe(t *testing.T) {
	registry := config.NewModelRegistry()
	if n := RegisterFromReport(registry, nil); n != 0 {
		t.Errorf("RegisterFromReport(nil report): got %d, want 0", n)
	}
	if n := RegisterFromReport(nil, &CapabilityReport{}); n != 0 {
		t.Errorf("RegisterFromReport(nil registry): got %d, want 0", n)
	}
}

// TestRegisterFromReport_RegistersModels checks that every model in a report
// is registered, is retrievable by model ID, and feeds cost estimation.
func TestRegisterFromReport_RegistersModels(t *testing.T) {
	registry := config.NewModelRegistry()
	report := &CapabilityReport{
		SchemaVersion: "1.0",
		Models: map[string]*ModelCapabilities{
			"anthropic:claude-sonnet-4-6": {
				Provider:         "anthropic",
				Model:            "claude-sonnet-4-6",
				ContextWindow:    Capability{Value: float64(200000)},
				InputPricePer1M:  Capability{Value: 3.0},
				OutputPricePer1M: Capability{Value: 15.0},
			},
			"openai:gpt-4o": {
				Provider:         "openai",
				Model:            "gpt-4o",
				ContextWindow:    Capability{Value: float64(128000)},
				InputPricePer1M:  Capability{Value: 2.5},
				OutputPricePer1M: Capability{Value: 10.0},
			},
		},
	}
	n := RegisterFromReport(registry, report)
	if n != 2 {
		t.Errorf("expected 2 models registered, got %d", n)
	}

	// Verify the registry can look up both models by their bare model ID.
	if cfg := registry.GetConfig("claude-sonnet-4-6"); cfg == nil {
		t.Error("claude-sonnet-4-6 not found in registry")
	} else {
		if cfg.ContextWindow != 200000 {
			t.Errorf("ContextWindow: got %d, want 200000", cfg.ContextWindow)
		}
		if cfg.InputPricePer1M != 3.0 {
			t.Errorf("InputPricePer1M: got %f, want 3.0", cfg.InputPricePer1M)
		}
	}
	if cfg := registry.GetConfig("gpt-4o"); cfg == nil {
		t.Error("gpt-4o not found in registry")
	} else {
		if cfg.ContextWindow != 128000 {
			t.Errorf("ContextWindow: got %d, want 128000", cfg.ContextWindow)
		}
		if cfg.InputPricePer1M != 2.5 {
			t.Errorf("InputPricePer1M: got %f, want 2.5", cfg.InputPricePer1M)
		}
	}

	// Cost estimate: 1M input + 1M output = 3 + 15 = 18 USD.
	// Compared with a tolerance, matching the EstimateCost check below, so the
	// test does not depend on the exact float evaluation order.
	cost := registry.EstimateSimpleCost("claude-sonnet-4-6", 1_000_000, 1_000_000)
	const wantCost = 18.0
	if diff := cost - wantCost; diff > 1e-9 || diff < -1e-9 {
		t.Errorf("EstimateSimpleCost: got %f, want %f", cost, wantCost)
	}
}

// TestRegisterFromReport_SkipsEmptyModelID checks that report entries without
// a model ID are not registered or counted.
func TestRegisterFromReport_SkipsEmptyModelID(t *testing.T) {
	registry := config.NewModelRegistry()
	report := &CapabilityReport{
		Models: map[string]*ModelCapabilities{
			"bad-entry": {Provider: "test", Model: ""}, // empty model ID
		},
	}
	n := RegisterFromReport(registry, report)
	if n != 0 {
		t.Errorf("expected 0 models registered (empty ID skipped), got %d", n)
	}
}

// TestLoadAndRegisterFrom_FullFlow writes the testCapabilitiesJSON fixture to
// disk and runs the complete Load + Register path against it.
func TestLoadAndRegisterFrom_FullFlow(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "capabilities.json")
	if err := os.WriteFile(path, []byte(testCapabilitiesJSON), 0o644); err != nil {
		t.Fatalf("write fixture: %v", err)
	}

	registry := config.NewModelRegistry()
	n, err := LoadAndRegisterFrom(registry, path)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if n != 1 {
		t.Errorf("expected 1 model, got %d", n)
	}

	cfg := registry.GetConfig("claude-sonnet-4-6")
	if cfg == nil {
		t.Fatal("claude-sonnet-4-6 not registered")
	}
	if cfg.ContextWindow != 200000 {
		t.Errorf("ContextWindow: got %d, want 200000", cfg.ContextWindow)
	}
	if cfg.CacheReadPricePer1M != 0.3 {
		t.Errorf("CacheReadPricePer1M: got %f, want 0.3", cfg.CacheReadPricePer1M)
	}
	if !cfg.SupportsCaching {
		t.Error("SupportsCaching: expected true")
	}

	// EstimateCost must include the cache read/write costs:
	// 1M input + 1M output + 1M cache read + 1M cache write
	// = 3.0 + 15.0 + 0.3 + 3.75 = 22.05
	cost := registry.EstimateCost("claude-sonnet-4-6", 1_000_000, 1_000_000, 1_000_000, 1_000_000)
	const want = 22.05
	if diff := cost - want; diff > 1e-9 || diff < -1e-9 {
		t.Errorf("EstimateCost: got %f, want %f", cost, want)
	}
}

// TestLoadAndRegisterFrom_MissingFile checks that a nonexistent capabilities
// file is not an error and registers zero models.
func TestLoadAndRegisterFrom_MissingFile(t *testing.T) {
	registry := config.NewModelRegistry()
	path := filepath.Join(t.TempDir(), "does-not-exist.json")
	n, err := LoadAndRegisterFrom(registry, path)
	if err != nil {
		t.Errorf("expected nil error for missing file, got %v", err)
	}
	if n != 0 {
		t.Errorf("expected 0 models, got %d", n)
	}
}

// TestLoadAndRegister_DefaultPath exercises LoadAndRegister's path resolution
// by pointing FLYTO_CAPABILITIES_PATH at a fixture file (presumably the env
// var overrides the built-in default location).
func TestLoadAndRegister_DefaultPath(t *testing.T) {
	dir := t.TempDir()
	path := filepath.Join(dir, "capabilities.json")
	if err := os.WriteFile(path, []byte(testCapabilitiesJSON), 0o644); err != nil {
		t.Fatalf("write fixture: %v", err)
	}
	// t.Setenv restores the variable after the test; it also forbids
	// t.Parallel in this test.
	t.Setenv("FLYTO_CAPABILITIES_PATH", path)

	registry := config.NewModelRegistry()
	n, err := LoadAndRegister(registry)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if n != 1 {
		t.Errorf("expected 1 model, got %d", n)
	}
	// The count alone does not prove the fixture's model reached the registry.
	if cfg := registry.GetConfig("claude-sonnet-4-6"); cfg == nil {
		t.Error("claude-sonnet-4-6 not registered via FLYTO_CAPABILITIES_PATH")
	}
}