package wire import ( "encoding/json" "net/http" "strings" "testing" "git.flytoex.net/yuanwei/flyto-agent/pkg/flyto" ) // geminiGenReq 是解析后的 Gemini GenerateContentRequest 结构(用于断言). type geminiGenReq struct { GenerationConfig struct { MaxOutputTokens int `json:"maxOutputTokens,omitempty"` ResponseMimeType string `json:"responseMimeType,omitempty"` ResponseSchema json.RawMessage `json:"responseSchema,omitempty"` ThinkingConfig *struct { ThinkingBudget int `json:"thinkingBudget"` IncludeThoughts bool `json:"includeThoughts"` } `json:"thinkingConfig,omitempty"` } `json:"generationConfig"` } // parseGeminiReq 解析 Gemini 请求 JSON 用于断言. func parseGeminiReq(t *testing.T, data []byte) geminiGenReq { t.Helper() var req geminiGenReq if err := json.Unmarshal(data, &req); err != nil { t.Fatalf("failed to unmarshal Gemini request: %v", err) } return req } // TestBuildRequest_ResponseFormat_JSONObject 验证 json_object 模式. func TestBuildRequest_ResponseFormat_JSONObject(t *testing.T) { c := NewGeminiClient("fake-key", "https://fake.api") req := &StreamRequest{ Model: "gemini-2.0-flash", MaxTokens: 1024, ResponseFormat: &flyto.ResponseFormat{Type: "json_object"}, } got, err := c.buildRequest(req) if err != nil { t.Fatalf("buildRequest failed: %v", err) } r := parseGeminiReq(t, got) if r.GenerationConfig.ResponseMimeType != "application/json" { t.Errorf("responseMimeType = %q, want %q", r.GenerationConfig.ResponseMimeType, "application/json") } // responseSchema should be absent or empty if len(r.GenerationConfig.ResponseSchema) > 0 { t.Errorf("responseSchema should be empty for json_object, got %s", r.GenerationConfig.ResponseSchema) } } // TestBuildRequest_ResponseFormat_JSONSchema 验证 json_schema 模式. func TestBuildRequest_ResponseFormat_JSONSchema(t *testing.T) { c := NewGeminiClient("fake-key", "https://fake.api") schema := json.RawMessage(`{"type":"object","properties":{"name":{"type":"string"}}}`) req := &StreamRequest{ Model: "gemini-2.0-flash", MaxTokens: 1024, ResponseFormat: &flyto.ResponseFormat{Type: "json_schema", JSONSchema: schema}, } got, err := c.buildRequest(req) if err != nil { t.Fatalf("buildRequest failed: %v", err) } r := parseGeminiReq(t, got) if r.GenerationConfig.ResponseMimeType != "application/json" { t.Errorf("responseMimeType = %q, want %q", r.GenerationConfig.ResponseMimeType, "application/json") } if len(r.GenerationConfig.ResponseSchema) == 0 { t.Fatal("responseSchema should not be empty for json_schema") } // Verify schema content var gotSchema map[string]any if err := json.Unmarshal(r.GenerationConfig.ResponseSchema, &gotSchema); err != nil { t.Fatalf("responseSchema is not valid JSON: %v", err) } if gotSchema["type"] != "object" { t.Errorf("responseSchema.type = %v, want object", gotSchema["type"]) } } // TestBuildRequest_ResponseFormat_Nil 验证 nil ResponseFormat. func TestBuildRequest_ResponseFormat_Nil(t *testing.T) { c := NewGeminiClient("fake-key", "https://fake.api") req := &StreamRequest{ Model: "gemini-2.0-flash", MaxTokens: 1024, ResponseFormat: nil, } got, err := c.buildRequest(req) if err != nil { t.Fatalf("buildRequest failed: %v", err) } r := parseGeminiReq(t, got) if r.GenerationConfig.ResponseMimeType != "" { t.Errorf("responseMimeType = %q, want empty", r.GenerationConfig.ResponseMimeType) } } // TestBuildRequest_ThinkingBudget_PerRequest 验证 per-request thinking budget. func TestBuildRequest_ThinkingBudget_PerRequest(t *testing.T) { c := NewGeminiClient("fake-key", "https://fake.api") req := &StreamRequest{ Model: "gemini-2.0-flash", MaxTokens: 1024, ThinkingBudget: 4000, } got, err := c.buildRequest(req) if err != nil { t.Fatalf("buildRequest failed: %v", err) } r := parseGeminiReq(t, got) if r.GenerationConfig.ThinkingConfig == nil { t.Fatal("thinkingConfig should not be nil when ThinkingBudget > 0") } if r.GenerationConfig.ThinkingConfig.ThinkingBudget != 4000 { t.Errorf("thinkingBudget = %d, want 4000", r.GenerationConfig.ThinkingConfig.ThinkingBudget) } if !r.GenerationConfig.ThinkingConfig.IncludeThoughts { t.Error("includeThoughts should be true") } } // TestBuildRequest_ThinkingBudget_ClientFallback 验证 client-level thinking budget 回退. func TestBuildRequest_ThinkingBudget_ClientFallback(t *testing.T) { c := NewGeminiClient("fake-key", "https://fake.api", GeminiWithThinkingBudget(8000)) req := &StreamRequest{ Model: "gemini-2.0-flash", MaxTokens: 1024, ThinkingBudget: 0, // per-request 为 0,使用 client 配置 } got, err := c.buildRequest(req) if err != nil { t.Fatalf("buildRequest failed: %v", err) } r := parseGeminiReq(t, got) if r.GenerationConfig.ThinkingConfig == nil { t.Fatal("thinkingConfig should not be nil when client has thinkingBudget") } if r.GenerationConfig.ThinkingConfig.ThinkingBudget != 8000 { t.Errorf("thinkingBudget = %d, want 8000 (client fallback)", r.GenerationConfig.ThinkingConfig.ThinkingBudget) } } // TestBuildRequest_ThinkingBudget_None 验证 thinking 禁用. func TestBuildRequest_ThinkingBudget_None(t *testing.T) { c := NewGeminiClient("fake-key", "https://fake.api") req := &StreamRequest{ Model: "gemini-2.0-flash", MaxTokens: 1024, ThinkingBudget: 0, } got, err := c.buildRequest(req) if err != nil { t.Fatalf("buildRequest failed: %v", err) } r := parseGeminiReq(t, got) if r.GenerationConfig.ThinkingConfig != nil { t.Errorf("thinkingConfig should be nil when both req and client have budget=0, got %+v", r.GenerationConfig.ThinkingConfig) } } // --- flytoMessagesToGemini --- func TestFlytoMessagesToGemini_TextOnly(t *testing.T) { msgs := []flyto.Message{ flyto.UserText("hello"), flyto.AssistantText("hi there"), } result := flytoMessagesToGemini(msgs) if len(result) != 2 { t.Fatalf("expected 2 contents, got %d", len(result)) } if result[0].Role != "user" { t.Errorf("first role = %q, want user", result[0].Role) } if result[1].Role != "model" { t.Errorf("second role = %q, want model", result[1].Role) } if result[0].Parts[0].Text != "hello" { t.Errorf("user text = %q", result[0].Parts[0].Text) } if result[1].Parts[0].Text != "hi there" { t.Errorf("model text = %q", result[1].Parts[0].Text) } } func TestFlytoMessagesToGemini_ToolUseAndResult(t *testing.T) { msgs := []flyto.Message{ { Role: flyto.RoleAssistant, Blocks: []flyto.Block{ flyto.ToolUseBlock("call_1", "Bash", map[string]any{"command": "ls"}), }, }, { Role: flyto.RoleUser, Blocks: []flyto.Block{ flyto.ToolResultBlock("call_1", "file1.go", false), }, }, } result := flytoMessagesToGemini(msgs) if len(result) != 2 { t.Fatalf("expected 2 contents, got %d", len(result)) } // assistant message should have functionCall modelParts := result[0].Parts if len(modelParts) != 1 || modelParts[0].FunctionCall == nil { t.Fatal("model message should have a FunctionCall part") } if modelParts[0].FunctionCall.Name != "Bash" { t.Errorf("FunctionCall.Name = %q", modelParts[0].FunctionCall.Name) } // user message should have functionResponse userParts := result[1].Parts if len(userParts) != 1 || userParts[0].FunctionResp == nil { t.Fatal("user message should have a FunctionResp part") } if userParts[0].FunctionResp.Name != "Bash" { t.Errorf("FunctionResp.Name = %q, want Bash (looked up from tool_use)", userParts[0].FunctionResp.Name) } if userParts[0].FunctionResp.Response["result"] != "file1.go" { t.Errorf("FunctionResp.Response = %v", userParts[0].FunctionResp.Response) } } func TestFlytoMessagesToGemini_ToolResultError(t *testing.T) { msgs := []flyto.Message{ { Role: flyto.RoleAssistant, Blocks: []flyto.Block{ flyto.ToolUseBlock("call_e", "Read", map[string]any{"path": "/nope"}), }, }, { Role: flyto.RoleUser, Blocks: []flyto.Block{ flyto.ToolResultBlock("call_e", "file not found", true), }, }, } result := flytoMessagesToGemini(msgs) resp := result[1].Parts[0].FunctionResp if resp.Response["error"] != "file not found" { t.Errorf("error field = %v", resp.Response["error"]) } if resp.Response["result"] != "" { t.Errorf("result should be empty on error, got %v", resp.Response["result"]) } } func TestFlytoMessagesToGemini_ToolUseNilInput(t *testing.T) { msgs := []flyto.Message{ { Role: flyto.RoleAssistant, Blocks: []flyto.Block{ {Type: flyto.BlockToolUse, ToolUseID: "t1", ToolName: "Glob", ToolInput: nil}, }, }, } result := flytoMessagesToGemini(msgs) if len(result) != 1 { t.Fatalf("expected 1 content, got %d", len(result)) } fc := result[0].Parts[0].FunctionCall if fc == nil { t.Fatal("FunctionCall should not be nil") } if fc.Args == nil { t.Error("Args should be empty map, not nil") } } func TestFlytoMessagesToGemini_EmptyMessages(t *testing.T) { result := flytoMessagesToGemini(nil) if len(result) != 0 { t.Errorf("nil messages should return empty, got %d", len(result)) } } func TestFlytoMessagesToGemini_EmptyTextSkipped(t *testing.T) { msgs := []flyto.Message{ {Role: flyto.RoleUser, Blocks: []flyto.Block{{Type: flyto.BlockText, Text: ""}}}, } result := flytoMessagesToGemini(msgs) if len(result) != 0 { t.Errorf("empty text blocks should be skipped, got %d contents", len(result)) } } // --- GeminiClient options --- func TestGeminiWithHTTPClient(t *testing.T) { custom := &http.Client{} c := NewGeminiClient("key", "", GeminiWithHTTPClient(custom)) if c.HTTPClient() != custom { t.Error("GeminiWithHTTPClient did not set the client") } } func TestGeminiWithBearerToken(t *testing.T) { c := NewGeminiClient("", "", GeminiWithBearerToken("my-token")) if c.bearerToken != "my-token" { t.Errorf("bearerToken = %q, want my-token", c.bearerToken) } } func TestGeminiHTTPClient_Default(t *testing.T) { c := NewGeminiClient("key", "") if c.HTTPClient() == nil { t.Error("default HTTPClient should not be nil") } } // --- geminiGenerateID --- func TestGeminiGenerateID_Format(t *testing.T) { id := geminiGenerateID() if !strings.HasPrefix(id, "gtool_") { t.Errorf("ID should start with gtool_, got %q", id) } // "gtool_" (6) + 24 hex chars (12 bytes) if len(id) != 30 { t.Errorf("ID length = %d, want 30", len(id)) } } func TestGeminiGenerateID_Unique(t *testing.T) { ids := make(map[string]bool) for range 100 { id := geminiGenerateID() if ids[id] { t.Fatalf("duplicate ID: %s", id) } ids[id] = true } }