package wire

import (
	"encoding/json"
	"strings"
	"testing"

	"git.flytoex.net/yuanwei/flyto-agent/pkg/flyto"
)

// --- buildUsageEvent ---

// TestBuildUsageEvent_NilUsage checks that a chunk without a usage object
// still produces an event carrying the stop reason, with zero token counts.
func TestBuildUsageEvent_NilUsage(t *testing.T) {
	var chunk openaiChunk // Usage is nil
	evt := buildUsageEvent(chunk, "end_turn")
	if evt.StopReason != "end_turn" {
		t.Errorf("StopReason = %q, want %q", evt.StopReason, "end_turn")
	}
	if evt.InputTokens != 0 || evt.OutputTokens != 0 {
		t.Errorf("tokens should be 0 when usage is nil: in=%d out=%d", evt.InputTokens, evt.OutputTokens)
	}
}

// TestBuildUsageEvent_WithUsage checks that prompt, completion and
// cache-write token counts are copied from the chunk's usage object.
func TestBuildUsageEvent_WithUsage(t *testing.T) {
	// openaiChunk.Usage is an embedded anonymous struct -- set it via JSON.
	raw := `{"usage":{"prompt_tokens":100,"completion_tokens":50,"cache_write_tokens":20}}`
	var chunk openaiChunk
	if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	evt := buildUsageEvent(chunk, "tool_use")
	if evt.StopReason != "tool_use" {
		t.Errorf("StopReason = %q, want %q", evt.StopReason, "tool_use")
	}
	if evt.InputTokens != 100 {
		t.Errorf("InputTokens = %d, want 100", evt.InputTokens)
	}
	if evt.OutputTokens != 50 {
		t.Errorf("OutputTokens = %d, want 50", evt.OutputTokens)
	}
	if evt.CacheCreationTokens != 20 {
		t.Errorf("CacheCreationTokens = %d, want 20", evt.CacheCreationTokens)
	}
}

// TestBuildUsageEvent_WithCacheRead checks that
// usage.prompt_tokens_details.cached_tokens is surfaced as CacheReadTokens.
func TestBuildUsageEvent_WithCacheRead(t *testing.T) {
	raw := `{"usage":{"prompt_tokens":200,"completion_tokens":30,"prompt_tokens_details":{"cached_tokens":150}}}`
	var chunk openaiChunk
	if err := json.Unmarshal([]byte(raw), &chunk); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	evt := buildUsageEvent(chunk, "end_turn")
	if evt.CacheReadTokens != 150 {
		t.Errorf("CacheReadTokens = %d, want 150", evt.CacheReadTokens)
	}
}

// --- flytoMessagesToOpenAI ---

// TestFlytoMessagesToOpenAI_SystemPrompt checks that a non-empty system
// prompt is emitted as the first message with role "system".
func TestFlytoMessagesToOpenAI_SystemPrompt(t *testing.T) {
	const prompt = "You are helpful."
	msgs := []flyto.Message{
		flyto.UserText("hello"),
	}
	result := flytoMessagesToOpenAI(msgs, prompt, false)
	if len(result) < 2 {
		t.Fatalf("expected at least 2 messages (system + user), got %d", len(result))
	}
	if result[0].Role != "system" {
		t.Errorf("first message role = %q, want system", result[0].Role)
	}
	// The content may be encoded as a plain JSON string or as a structured
	// value. Check it either way rather than silently skipping the assertion
	// when it is not a string.
	var content string
	if err := json.Unmarshal(result[0].Content, &content); err == nil {
		if content != prompt {
			t.Errorf("system content = %q, want %q", content, prompt)
		}
	} else if !strContains(string(result[0].Content), prompt) {
		t.Errorf("system content %s does not contain %q", string(result[0].Content), prompt)
	}
}

// TestFlytoMessagesToOpenAI_CacheSystem checks that cacheSystem=true encodes
// the system prompt as a non-empty array of content blocks.
func TestFlytoMessagesToOpenAI_CacheSystem(t *testing.T) {
	result := flytoMessagesToOpenAI(nil, "cached system", true)
	if len(result) == 0 {
		t.Fatal("expected system message")
	}
	if result[0].Role != "system" {
		t.Errorf("first message role = %q, want system", result[0].Role)
	}
	// When cacheSystem=true, content is an array with cache_control.
	var blocks []json.RawMessage
	if err := json.Unmarshal(result[0].Content, &blocks); err != nil {
		t.Fatalf("cacheSystem content should be array: %v", err)
	}
	if len(blocks) == 0 {
		t.Fatal("cache blocks empty")
	}
	if !strContains(string(result[0].Content), "cached system") {
		t.Errorf("cache blocks %s do not contain the system prompt text", string(result[0].Content))
	}
}

// TestFlytoMessagesToOpenAI_NoSystemPrompt checks that an empty system prompt
// adds no message and that message order and roles are preserved.
func TestFlytoMessagesToOpenAI_NoSystemPrompt(t *testing.T) {
	msgs := []flyto.Message{
		flyto.UserText("hi"),
		flyto.AssistantText("hello"),
	}
	result := flytoMessagesToOpenAI(msgs, "", false)
	if len(result) != 2 {
		t.Fatalf("expected 2 messages, got %d", len(result))
	}
	if result[0].Role != "user" {
		t.Errorf("first role = %q, want user", result[0].Role)
	}
	if result[1].Role != "assistant" {
		t.Errorf("second role = %q, want assistant", result[1].Role)
	}
}

// TestFlytoMessagesToOpenAI_ToolUseAndResult checks that an assistant tool-use
// block becomes a tool call, and that the matching tool result carries the
// same call ID.
func TestFlytoMessagesToOpenAI_ToolUseAndResult(t *testing.T) {
	msgs := []flyto.Message{
		{
			Role: flyto.RoleAssistant,
			Blocks: []flyto.Block{
				flyto.ToolUseBlock("call_1", "Bash", map[string]any{"command": "ls"}),
			},
		},
		{
			Role: flyto.RoleUser,
			Blocks: []flyto.Block{
				flyto.ToolResultBlock("call_1", "file1.go\nfile2.go", false),
			},
		},
	}
	result := flytoMessagesToOpenAI(msgs, "", false)
	if len(result) < 2 {
		t.Fatalf("expected at least 2 messages, got %d", len(result))
	}
	// The assistant message must have tool_calls. Fatal, not Error: the next
	// check indexes ToolCalls[0] and would panic on an empty slice.
	if len(result[0].ToolCalls) == 0 {
		t.Fatal("assistant message should have tool_calls")
	}
	if result[0].ToolCalls[0].ID != "call_1" {
		t.Errorf("tool call ID = %q, want %q", result[0].ToolCalls[0].ID, "call_1")
	}
	// Tool result message.
	if result[1].ToolCallID != "call_1" {
		t.Errorf("tool result ToolCallID = %q, want %q", result[1].ToolCallID, "call_1")
	}
}

// --- parseNonSSEError ---

// TestParseNonSSEError_OpenAIFormat checks that an OpenAI-style
// {"error":{"message":...}} body yields an error containing the message.
func TestParseNonSSEError_OpenAIFormat(t *testing.T) {
	body := []byte(`{"error":{"message":"Invalid API key","type":"auth_error"}}`)
	err := parseNonSSEError(body, "application/json")
	if err == nil {
		t.Fatal("expected error")
	}
	if got := err.Error(); !strContains(got, "Invalid API key") {
		t.Errorf("error = %q, should contain 'Invalid API key'", got)
	}
}

// TestParseNonSSEError_MiniMaxBaseResp checks that a MiniMax-style base_resp
// body yields an error containing both the status code and the message.
func TestParseNonSSEError_MiniMaxBaseResp(t *testing.T) {
	body := []byte(`{"base_resp":{"status_code":1004,"status_msg":"rate limit exceeded"}}`)
	err := parseNonSSEError(body, "application/json")
	if err == nil {
		t.Fatal("expected error")
	}
	if got := err.Error(); !strContains(got, "1004") || !strContains(got, "rate limit") {
		t.Errorf("error = %q, should contain status code and message", got)
	}
}

// TestParseNonSSEError_NoError is a smoke test: a JSON body without any
// recognized error fields must not panic.
func TestParseNonSSEError_NoError(t *testing.T) {
	body := []byte(`{"ok":true}`)
	err := parseNonSSEError(body, "application/json")
	// No error fields present, so parsing falls through. Whether the result is
	// nil or a generic error is implementation-defined, so it is deliberately
	// not asserted; not crashing is the check.
	_ = err
}

// TestParseNonSSEError_InvalidJSON is a smoke test: a non-JSON body must not
// panic.
func TestParseNonSSEError_InvalidJSON(t *testing.T) {
	body := []byte(`not json at all`)
	err := parseNonSSEError(body, "text/plain")
	// The returned value is deliberately not asserted; not crashing is the check.
	_ = err
}

// --- parseOpenRouterPrice ---

// TestParseOpenRouterPrice_Normal checks the per-token to per-1M-tokens
// conversion.
func TestParseOpenRouterPrice_Normal(t *testing.T) {
	// OpenRouter prices are per-token; we multiply by 1M to get per-1M-tokens.
	got := parseOpenRouterPrice("0.000003") // $3 per 1M
	if got < 2.99 || got > 3.01 {
		t.Errorf("parseOpenRouterPrice(0.000003) = %f, want ~3.0", got)
	}
}

// TestParseOpenRouterPrice_Zero checks that "0" and the empty string both
// parse to 0.
func TestParseOpenRouterPrice_Zero(t *testing.T) {
	if got := parseOpenRouterPrice("0"); got != 0 {
		t.Errorf("got %f, want 0", got)
	}
	if got := parseOpenRouterPrice(""); got != 0 {
		t.Errorf("empty string: got %f, want 0", got)
	}
}

// TestParseOpenRouterPrice_NA checks that a non-numeric price parses to 0.
func TestParseOpenRouterPrice_NA(t *testing.T) {
	// OpenRouter sometimes returns "N/A".
	got := parseOpenRouterPrice("N/A")
	if got != 0 {
		t.Errorf("N/A: got %f, want 0", got)
	}
}

// strContains reports whether sub is within s. It is a thin wrapper over
// strings.Contains, kept under this name so existing call sites compile.
// The name avoids a conflict with schema_test.go's contains().
func strContains(s, sub string) bool {
	return strings.Contains(s, sub)
}