package validator import ( "context" "errors" "strings" "testing" "git.flytoex.net/yuanwei/flyto-agent/pkg/flyto" ) // mockModelProvider is a flyto.ModelProvider that emits a fixed event // sequence and captures the inbound Request for inspection. Shared by // llm_adapter_flyto_test.go and llm_validator_test.go. type mockModelProvider struct { name string events []flyto.Event err error // if non-nil, Stream returns this instead got *flyto.Request // captured request for assertion } func (m *mockModelProvider) Name() string { if m.name == "" { return "mock" } return m.name } func (m *mockModelProvider) Stream(ctx context.Context, req *flyto.Request) (<-chan flyto.Event, error) { m.got = req if m.err != nil { return nil, m.err } ch := make(chan flyto.Event, len(m.events)) for _, e := range m.events { ch <- e } close(ch) return ch, nil } func (m *mockModelProvider) Models(_ context.Context) ([]flyto.ModelInfo, error) { return nil, nil } // cancellableProvider returns a channel that never closes, to exercise // context cancellation paths. type cancellableProvider struct { ch <-chan flyto.Event } func (p *cancellableProvider) Name() string { return "cancellable" } func (p *cancellableProvider) Stream(_ context.Context, _ *flyto.Request) (<-chan flyto.Event, error) { return p.ch, nil } func (p *cancellableProvider) Models(_ context.Context) ([]flyto.ModelInfo, error) { return nil, nil } // -- FlytoLLMClient tests -- func TestNewFlytoLLMClient_NilProvider(t *testing.T) { _, err := NewFlytoLLMClient(nil, "", "") if err == nil { t.Errorf("expected error for nil provider") } } func TestFlytoLLMClient_Complete_TextEventAggregation(t *testing.T) { mp := &mockModelProvider{ events: []flyto.Event{ &flyto.TextDeltaEvent{Text: "hel"}, &flyto.TextDeltaEvent{Text: "lo "}, &flyto.TextEvent{Text: "hello world"}, }, } client, err := NewFlytoLLMClient(mp, "model-x", "be a reviewer") if err != nil { t.Fatalf("construction: %v", err) } resp, err := client.Complete(context.Background(), "review this", "", 100) if err != nil { t.Fatalf("unexpected error: %v", err) } if resp != "hello world" { t.Errorf("TextEvent should win, got %q", resp) } if mp.got.Model != "model-x" { t.Errorf("model fallback failed: got %q", mp.got.Model) } if mp.got.System != "be a reviewer" { t.Errorf("system prompt missing: got %q", mp.got.System) } if mp.got.MaxTokens != 100 { t.Errorf("max_tokens mismatch: got %d", mp.got.MaxTokens) } } func TestFlytoLLMClient_Complete_DeltaFallback(t *testing.T) { mp := &mockModelProvider{ events: []flyto.Event{ &flyto.TextDeltaEvent{Text: "part1 "}, &flyto.TextDeltaEvent{Text: "part2"}, }, } client, _ := NewFlytoLLMClient(mp, "model-x", "") resp, err := client.Complete(context.Background(), "hello", "", 0) if err != nil { t.Fatalf("unexpected error: %v", err) } if resp != "part1 part2" { t.Errorf("delta fallback failed, got %q", resp) } } func TestFlytoLLMClient_Complete_StreamError(t *testing.T) { mp := &mockModelProvider{err: errors.New("provider down")} client, _ := NewFlytoLLMClient(mp, "model-x", "") _, err := client.Complete(context.Background(), "hello", "", 0) if !errors.Is(err, ErrValidatorBackend) { t.Errorf("expected ErrValidatorBackend, got %v", err) } if !strings.Contains(err.Error(), "provider down") { t.Errorf("underlying error should surface, got %v", err) } } func TestFlytoLLMClient_Complete_ErrorEvent(t *testing.T) { mp := &mockModelProvider{ events: []flyto.Event{ &flyto.ErrorEvent{Err: errors.New("stream boom")}, }, } client, _ := NewFlytoLLMClient(mp, "model-x", "") _, err := client.Complete(context.Background(), "hello", "", 0) if !errors.Is(err, ErrValidatorBackend) { t.Errorf("ErrorEvent should wrap ErrValidatorBackend, got %v", err) } if !strings.Contains(err.Error(), "stream boom") { t.Errorf("underlying error should bubble up, got %v", err) } } func TestFlytoLLMClient_Complete_ContextCancelled(t *testing.T) { ch := make(chan flyto.Event) // never sends, never closes mp := &cancellableProvider{ch: ch} client, _ := NewFlytoLLMClient(mp, "model-x", "") ctx, cancel := context.WithCancel(context.Background()) cancel() _, err := client.Complete(ctx, "hello", "", 0) if !errors.Is(err, context.Canceled) { t.Errorf("expected context.Canceled, got %v", err) } } func TestFlytoLLMClient_Complete_ModelOverride(t *testing.T) { mp := &mockModelProvider{ events: []flyto.Event{&flyto.TextEvent{Text: "ok"}}, } client, _ := NewFlytoLLMClient(mp, "default-model", "") _, err := client.Complete(context.Background(), "hello", "override-model", 0) if err != nil { t.Fatalf("unexpected error: %v", err) } if mp.got.Model != "override-model" { t.Errorf("override model should win, got %q", mp.got.Model) } } func TestFlytoLLMClient_Complete_IgnoresNonTextEvents(t *testing.T) { // Non-text events should be silently ignored; only text aggregation // wins. mp := &mockModelProvider{ events: []flyto.Event{ &flyto.ToolUseEvent{ID: "t1", ToolName: "bash", Input: map[string]any{"cmd": "ls"}}, &flyto.ThinkingEvent{Text: "ignored thinking"}, &flyto.TextEvent{Text: "real answer"}, &flyto.DoneEvent{}, }, } client, _ := NewFlytoLLMClient(mp, "model-x", "") resp, err := client.Complete(context.Background(), "hello", "", 0) if err != nil { t.Fatalf("unexpected error: %v", err) } if resp != "real answer" { t.Errorf("only TextEvent should be aggregated, got %q", resp) } }