package evolve import ( "context" "errors" "testing" "time" "git.flytoex.net/yuanwei/flyto-agent/pkg/flyto" ) // fakeProvider implements flyto.ModelProvider for adapter tests, replaying a // preset event script. When blockAfterEvents is true it holds the stream open // (blocks on ctx.Done) after emitting the scripted events, letting tests // exercise context cancellation mid-stream. type fakeProvider struct { events []flyto.Event streamErr error blockAfterEvents bool lastReq *flyto.Request } func (p *fakeProvider) Name() string { return "fake" } func (p *fakeProvider) Stream(ctx context.Context, req *flyto.Request) (<-chan flyto.Event, error) { p.lastReq = req if p.streamErr != nil { return nil, p.streamErr } ch := make(chan flyto.Event, len(p.events)+1) go func() { defer close(ch) for _, e := range p.events { select { case <-ctx.Done(): return case ch <- e: } } if p.blockAfterEvents { <-ctx.Done() } }() return ch, nil } func (p *fakeProvider) Models(ctx context.Context) ([]flyto.ModelInfo, error) { return nil, nil } func newTestClient(t *testing.T, p *fakeProvider) *FlytoLLMClient { t.Helper() c, err := NewFlytoLLMClient(p, "default-model", "sys prompt") if err != nil { t.Fatalf("NewFlytoLLMClient: %v", err) } return c } func TestFlytoLLMClient_NewNilProvider(t *testing.T) { if _, err := NewFlytoLLMClient(nil, "m", ""); err == nil { t.Fatal("expected error for nil provider, got nil") } } func TestFlytoLLMClient_TextDeltaAggregation(t *testing.T) { p := &fakeProvider{events: []flyto.Event{ &flyto.TextDeltaEvent{Text: "hel"}, &flyto.TextDeltaEvent{Text: "lo "}, &flyto.TextDeltaEvent{Text: "world"}, }} got, err := newTestClient(t, p).Complete(context.Background(), "p", LLMCallOpts{}) if err != nil { t.Fatalf("Complete: %v", err) } if got != "hello world" { t.Errorf("got %q, want %q", got, "hello world") } } // TextEvent is authoritative: when provider emits both Delta and TextEvent // (anthropic behavior), aggregation must not double-count. func TestFlytoLLMClient_TextEventAuthoritative(t *testing.T) { p := &fakeProvider{events: []flyto.Event{ &flyto.TextDeltaEvent{Text: "hel"}, &flyto.TextDeltaEvent{Text: "lo world"}, &flyto.TextEvent{Text: "hello world"}, }} got, err := newTestClient(t, p).Complete(context.Background(), "p", LLMCallOpts{}) if err != nil { t.Fatalf("Complete: %v", err) } if got != "hello world" { t.Errorf("got %q, want %q (deltas must not double-count)", got, "hello world") } } func TestFlytoLLMClient_MultipleTextEvents(t *testing.T) { p := &fakeProvider{events: []flyto.Event{ &flyto.TextEvent{Text: "part1 "}, &flyto.TextEvent{Text: "part2"}, }} got, err := newTestClient(t, p).Complete(context.Background(), "p", LLMCallOpts{}) if err != nil { t.Fatalf("Complete: %v", err) } if got != "part1 part2" { t.Errorf("got %q, want %q", got, "part1 part2") } } func TestFlytoLLMClient_ErrorEventWrapsErrLLMFailed(t *testing.T) { upstream := errors.New("upstream boom") p := &fakeProvider{events: []flyto.Event{ &flyto.TextDeltaEvent{Text: "partial"}, &flyto.ErrorEvent{Err: upstream}, }} _, err := newTestClient(t, p).Complete(context.Background(), "p", LLMCallOpts{}) if err == nil { t.Fatal("expected error, got nil") } if !errors.Is(err, ErrLLMFailed) { t.Errorf("errors.Is(err, ErrLLMFailed) = false, want true; err=%v", err) } } func TestFlytoLLMClient_StreamErrWrapsErrLLMFailed(t *testing.T) { p := &fakeProvider{streamErr: errors.New("dial fail")} _, err := newTestClient(t, p).Complete(context.Background(), "p", LLMCallOpts{}) if !errors.Is(err, ErrLLMFailed) { t.Errorf("errors.Is(err, ErrLLMFailed) = false, want true; err=%v", err) } } func TestFlytoLLMClient_EmptyStream(t *testing.T) { p := &fakeProvider{events: nil} got, err := newTestClient(t, p).Complete(context.Background(), "p", LLMCallOpts{}) if err != nil { t.Fatalf("Complete: %v", err) } if got != "" { t.Errorf("got %q, want empty", got) } } func TestFlytoLLMClient_CtxCancel(t *testing.T) { p := &fakeProvider{blockAfterEvents: true} ctx, cancel := context.WithCancel(context.Background()) done := make(chan error, 1) go func() { _, err := newTestClient(t, p).Complete(ctx, "p", LLMCallOpts{}) done <- err }() time.Sleep(20 * time.Millisecond) cancel() select { case err := <-done: if !errors.Is(err, context.Canceled) { t.Errorf("want context.Canceled, got %v", err) } case <-time.After(2 * time.Second): t.Fatal("Complete did not return within 2s after cancel") } } // Non-text events (ToolUse, Thinking, Usage, Done) must be dropped, never // surfaced as text nor as errors. func TestFlytoLLMClient_IgnoresNonTextEvents(t *testing.T) { p := &fakeProvider{events: []flyto.Event{ &flyto.ToolUseEvent{ID: "t1", ToolName: "search", Input: map[string]any{"q": "x"}}, &flyto.ThinkingDeltaEvent{Text: "let me think"}, &flyto.ThinkingEvent{Text: "thought through"}, &flyto.UsageEvent{InputTokens: 10, OutputTokens: 3}, &flyto.TextEvent{Text: "final answer"}, &flyto.DoneEvent{TotalInputTokens: 10}, }} got, err := newTestClient(t, p).Complete(context.Background(), "p", LLMCallOpts{}) if err != nil { t.Fatalf("Complete: %v", err) } if got != "final answer" { t.Errorf("got %q, want %q", got, "final answer") } } func TestFlytoLLMClient_OptsPropagation(t *testing.T) { cases := []struct { name string defaultModel string opts LLMCallOpts wantReqModel string wantReqMax int }{ { name: "opts model overrides default", defaultModel: "default-model", opts: LLMCallOpts{Model: "override-model", MaxTokens: 512, Temperature: 0.7}, wantReqModel: "override-model", wantReqMax: 512, }, { name: "empty opts model falls back to default", defaultModel: "default-model", opts: LLMCallOpts{MaxTokens: 128}, wantReqModel: "default-model", wantReqMax: 128, }, { name: "zero max tokens propagates zero (provider default)", defaultModel: "default-model", opts: LLMCallOpts{Model: "m"}, wantReqModel: "m", wantReqMax: 0, }, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { p := &fakeProvider{events: []flyto.Event{&flyto.TextEvent{Text: "ok"}}} c, err := NewFlytoLLMClient(p, tc.defaultModel, "sys") if err != nil { t.Fatalf("NewFlytoLLMClient: %v", err) } if _, err := c.Complete(context.Background(), "prompt text", tc.opts); err != nil { t.Fatalf("Complete: %v", err) } if p.lastReq.Model != tc.wantReqModel { t.Errorf("Request.Model = %q, want %q", p.lastReq.Model, tc.wantReqModel) } if p.lastReq.MaxTokens != tc.wantReqMax { t.Errorf("Request.MaxTokens = %d, want %d", p.lastReq.MaxTokens, tc.wantReqMax) } if p.lastReq.System != "sys" { t.Errorf("Request.System = %q, want %q", p.lastReq.System, "sys") } if len(p.lastReq.Messages) != 1 || p.lastReq.Messages[0].Role != flyto.RoleUser { t.Errorf("Request.Messages = %+v, want single user message", p.lastReq.Messages) } }) } }