package engine import ( "context" "sync" "time" "git.flytoex.net/yuanwei/flyto-agent/pkg/query" ) // ELEVATED: Session 生命周期设计说明 // // 早期实现:Session 是个简单对象,没有显式关闭语义,GC 负责回收. // Go 版本升华: // 1. closeOnce - 幂等关闭,防止多次 Close() 触发重复 observer 事件或 double-close panic // 2. done channel - Close() 时广播,trackEvents goroutine 可感知 session 已关闭并 drain rawCh // 3. pendingPermissions 清理 - Close() 向所有 pending channel 发送 false(拒绝), // 唤醒所有阻塞在 WaitForPermission() 的 goroutine,防止永久泄漏 // // 替代方案:<只用 s.closed bool 标志 + 在 WaitForPermission/trackEvents 里轮询> // - 否决:轮询有延迟且浪费 CPU;done channel 是 Go 惯用零开销广播模式. // Session 是一个有状态的多轮对话会话. // // 对应原项目中 REPL 的会话概念,但完全剥离了终端 UI. // Session 维护消息历史,支持多轮对话,上下文压缩,会话恢复. // // 增强功能: // - 自动追踪消息:Send() 完成后自动将用户消息和助手回复追加到历史 // - Token 统计:维护 inputTokens, outputTokens, costUSD 累计统计 // - 会话元数据:标题,创建时间,最后活跃时间 // - 权限回复:通过 pendingPermissions channel map 实现异步权限决策 // - Compact 集成:内置 compact 触发检查 // // 将会话状态集中管理, // 和 QueryEngine 的闭包变量里,耦合严重. // Go 版本将会话状态集中管理,干净独立. type Session struct { id string engine EngineRef messages []query.Message mu sync.Mutex closed bool closeOnce sync.Once // 保证 Close() 幂等 done chan struct{} // Close() 时关闭,广播给所有等待者 // 会话统计 inputTokens int // 累计输入 token 数 outputTokens int // 累计输出 token 数 costUSD float64 // 累计花费(美元) turnCount int // 总轮次数 // lastCostThresholdEmitted float64 // Highest CostUSD threshold already emitted for this session; prevents re-emission at same threshold, 0 means none crossed yet. 本会话已 emit 过的最高 CostUSD 档位, 防同档重复发送, 0 表示还没跨过任何档. lastCostThresholdEmitted float64 // 会话元数据 title string // 会话标题 createdAt time.Time // 创建时间 lastActiveAt time.Time // 最后活跃时间 // 权限回复通道 // key 是 permission request ID,value 是等待回复的 channel // 当 Engine 遇到权限请求时,创建 channel 放入 map; // 消费层调用 ResolvePermission 时,向 channel 发送决策. // Close() 时遍历 map 向每个 channel 发送 false,唤醒所有阻塞的 WaitForPermission goroutine. pendingPermissions map[string]chan bool } func newSession(id string, engine EngineRef) *Session { now := time.Now() return &Session{ id: id, engine: engine, messages: make([]query.Message, 0), pendingPermissions: make(map[string]chan bool), done: make(chan struct{}), createdAt: now, lastActiveAt: now, } } // ID 返回会话 ID. func (s *Session) ID() string { return s.id } // Send 在会话中发送一条消息,返回流式事件 channel. // 自动携带历史消息上下文. // // 增强:事件流完成后自动将用户消息和助手回复追加到 session.messages, // 并更新 token 统计和最后活跃时间. // // 可选 opts 透传给底层 Engine.Run,用于 per-Send 运行时覆盖: // 典型场景是 WithModel (运行时模型切换) 和 WithCheckpointHandler // (per-Send 不可逆操作确认回调). // // ELEVATED: opts 顺序 invariant. // 原方案: Send 不接受 opts, 内部硬编码 WithMessages(history), 把所有 // Run-level 可选能力都封死. // 新方案: Send 接受 opts ...RunOption, 但在调 engine.Run 时把 opts 放前面, // Session 自己的 WithMessages(history) append 在**最后**. 由于 RunOption 是 // 函数式 runConfig 应用,最后一个 option 覆盖前面的,因此: // - caller 传的 WithMessages 会被 Session 的 history 快照强制覆盖, // 锁死 Session "自动追踪历史" 的核心语义 -- 这是刻意的安全兜底. // - 其他所有 option (WithModel / WithMaxTurns / WithSecret 等) 正常生效. // // 替代方案: <在入口处扫描 opts 拒绝 WithMessages> -- 否决: RunOption 是 // 不透明的函数值,无法反射识别,只能靠应用顺序约定. // // 用法: // // session := agent.Session("my-session") // events := session.Send(ctx, "先看看项目结构") // // ... 处理事件 ... // events = session.Send(ctx, "然后修复那个 bug", WithModel("claude-sonnet-4-6")) // // 第二次调用自动包含第一轮的上下文,并切换到指定模型 func (s *Session) Send(ctx context.Context, prompt string, opts ...RunOption) <-chan Event { s.mu.Lock() if s.closed { s.mu.Unlock() ch := make(chan Event, 1) ch <- &ErrorEvent{ Err: context.Canceled, Code: string(ErrSessionClosed), Suggestion: "会话已关闭,请创建新的会话", } close(ch) return ch } // 快照当前消息历史用于本次查询 history := make([]query.Message, len(s.messages)) copy(history, s.messages) s.mu.Unlock() // caller opts 在前, Session 的 WithMessages(history) 在后, // 由最后一个 option 覆盖规则,强制锁死历史快照 (见方法 doc 的 invariant). runOpts := append(opts, WithMessages(history)) rawEvents := s.engine.Run(ctx, prompt, runOpts...) // 包装事件流,在完成时自动追踪消息和统计 wrappedCh := make(chan Event, 64) go s.trackEvents(ctx, prompt, rawEvents, wrappedCh) return wrappedCh } // trackEvents 从原始事件流读取事件,转发给消费层, // 同时收集助手回复文本和统计信息,在完成时自动追加到会话历史. // // 升华改进(ELEVATED): 三层退出保护 // 1. ctx.Done() - 消费层取消时停止转发(原有) // 2. s.done - Session.Close() 时立即退出并 drain rawCh,防止上游 goroutine 阻塞 // 3. rawCh 关闭 - 引擎正常结束时自然退出(原有) // // drain rawCh 的必要性: // // Close() 时引擎侧可能还在生产事件.如果不 drain,rawCh 的生产侧 goroutine 会因 // channel 满(buffered)或无消费者而永久阻塞,形成 goroutine 泄漏. // drain 后生产侧 goroutine 能正常结束. // // 替代方案:<只用 ctx.Done()> // - 否决:Close() 和 ctx 取消是独立的生命周期;ctx 是请求级,session.done 是会话级. // 替代方案:<在 drain 时也处理消息历史> // - 否决:session 已 closed,写 history 无意义且需要额外锁竞争. func (s *Session) trackEvents(ctx context.Context, prompt string, rawCh <-chan Event, outCh chan<- Event) { defer close(outCh) var assistantText string var turnInputTokens, turnOutputTokens int var turnCost float64 for evt := range rawCh { // 收集统计信息 switch e := evt.(type) { case *TextEvent: assistantText += e.Text case *TurnEndEvent: turnInputTokens += e.InputTokens turnOutputTokens += e.OutputTokens turnCost += e.CostUSD case *DoneEvent: // 使用 DoneEvent 中的最终统计 turnInputTokens = e.TotalInputTokens turnOutputTokens = e.TotalOutputTokens turnCost = e.TotalCostUSD } // 转发事件给消费层:ctx 取消或 session 关闭时退出,避免 goroutine 泄漏 select { case outCh <- evt: case <-ctx.Done(): // ctx 取消:drain rawCh 让上游生产 goroutine 能正常退出 go func() { for range rawCh { } }() return case <-s.done: // session 关闭:drain rawCh 让上游生产 goroutine 能正常退出 go func() { for range rawCh { } }() return } } // 事件流结束后,追加消息到历史并更新统计 + 跨阈值 emit s.applyTurn(prompt, assistantText, turnInputTokens, turnOutputTokens, turnCost) } // costThresholdsUSD is the preset ascending ladder of cumulative CostUSD // alarm thresholds for a session. Hardcoded (not Config-exposed) to keep the // public API surface stable; change the ladder here if product requirements // shift, no consumer-facing deprecation needed. // // costThresholdsUSD 是会话累计 CostUSD 告警的预设升序档位. 写死 (不经 Config // 暴露) 以保持公共 API 面稳定; 产品需求变化直接改本处, 不需要对消费者做 // deprecation 走查. var costThresholdsUSD = []float64{1, 5, 10, 50, 100} // crossedCostThresholds returns thresholds from costThresholdsUSD that lie // in the half-open interval (last, cost]. Ordered ascending so callers can // emit one event per crossed tier in canonical order, and so the caller can // update its guard to the last element of the slice to cover multi-tier // crossings in a single turn (e.g. $0 -> $15 yields [1, 5, 10]). // // crossedCostThresholds 返回 costThresholdsUSD 中落在半开区间 (last, cost] // 内的档位. 升序以便调用方按规范顺序每档 emit 一次, 并以 slice 末元素更新 // guard 覆盖单轮跨多档情形 (例: $0 -> $15 得 [1, 5, 10]). func crossedCostThresholds(last, cost float64) []float64 { var out []float64 for _, t := range costThresholdsUSD { if t > last && t <= cost { out = append(out, t) } } return out } // applyTurn appends the user prompt and assistant reply to history, accumulates // token / cost / turn / activity stats, and emits // session_cost_threshold_crossed for every preset threshold crossed this turn. // Extracted from trackEvents so tests can drive the stats path directly // without plumbing a full Engine.Run through a fake EngineRef. // // Locking: holds s.mu while mutating state and building event payloads, then // releases before calling Observer().Event -- mirrors Close()'s pattern // (session.go ~L395) of unlocking before observer emit to avoid holding the // session mutex across a potentially slow consumer callback. // // applyTurn 把用户提示和助手回复追加到历史, 累加 token / cost / turn / 活跃 // 时间统计, 并对本轮跨过的每个预设档位各 emit 一次 // session_cost_threshold_crossed. 从 trackEvents 抽出是为了测试能直接驱动 // 统计路径, 无需经 fake EngineRef 搭一整条 Engine.Run 通路. // // 锁策略: 在 s.mu 保护下改状态并构造事件 payload, 再释放锁调 // Observer().Event -- 对齐 Close() (session.go ~L395) 的 "解锁后 emit" 模式, // 避免持会话锁跨越可能慢的消费者回调. func (s *Session) applyTurn(prompt, assistantText string, turnInputTokens, turnOutputTokens int, turnCost float64) { s.mu.Lock() if s.closed { s.mu.Unlock() return } // 追加用户消息 s.messages = append(s.messages, query.Message{ Role: query.RoleUser, Content: []query.Content{ {Type: query.ContentText, Text: prompt}, }, Time: time.Now(), }) // 追加助手回复(如果有文本内容) if assistantText != "" { s.messages = append(s.messages, query.Message{ Role: query.RoleAssistant, Content: []query.Content{ {Type: query.ContentText, Text: assistantText}, }, Time: time.Now(), }) } // 更新统计 s.inputTokens += turnInputTokens s.outputTokens += turnOutputTokens s.costUSD += turnCost s.turnCount++ s.lastActiveAt = time.Now() // CLEVER: SessionStats snapshot built **inside** the lock so the field // reads under s.mu are consistent with the turn we just applied; then // each emitted payload reads back via stats.X SelectorExpr, which is // also what promotes SessionStats.TurnCount / InputTokens / OutputTokens // / CostUSD / MessageCount from scanner-dead to scanner-alive (declared // field becomes a real runtime read site, not a CompositeLit write). // // CLEVER: SessionStats 快照在锁内构造, 保证 s.mu 保护下的字段读和本轮累加 // 一致; 随后每个 payload 通过 stats.X SelectorExpr 读回, 这也是把 // SessionStats.TurnCount / InputTokens / OutputTokens / CostUSD / // MessageCount 5 字段从 scanner 视角"声明未读"升为"真运行时读"的依据 // (声明字段真正成为运行时 read site, 而非仅 CompositeLit 的 write). crossed := crossedCostThresholds(s.lastCostThresholdEmitted, s.costUSD) var payloads []map[string]any if len(crossed) > 0 { stats := SessionStats{ TurnCount: s.turnCount, InputTokens: s.inputTokens, OutputTokens: s.outputTokens, CostUSD: s.costUSD, MessageCount: len(s.messages), } for _, t := range crossed { payloads = append(payloads, map[string]any{ "session_id": s.id, "threshold_usd": t, "turn_count": stats.TurnCount, "input_tokens": stats.InputTokens, "output_tokens": stats.OutputTokens, "cost_usd": stats.CostUSD, "message_count": stats.MessageCount, }) } s.lastCostThresholdEmitted = crossed[len(crossed)-1] } s.mu.Unlock() // Emit outside the lock; mirrors Close()'s unlock-then-emit pattern. // engine may be nil in degenerate unit tests; guard defensively. // // 锁外 emit; 对齐 Close() 的 unlock-then-emit 模式. 退化单测可能 // engine 为 nil, 防御性守护. if len(payloads) == 0 || s.engine == nil { return } obs := s.engine.Observer() if obs == nil { return } for _, p := range payloads { obs.Event("session_cost_threshold_crossed", p) } } // ResolvePermission 回复一个权限请求. // 消费层收到 PermissionRequestEvent 后,通过此方法告知引擎用户的决策. // // 工作机制: // - Engine 的 runLoop 在需要权限时调用 WaitForPermission(requestID), // 它创建一个 channel 放入 pendingPermissions map,然后阻塞等待. // - 消费层(HTTP Server,CLI 等)收到 PermissionRequestEvent 后, // 调用 ResolvePermission 向 channel 发送决策,唤醒阻塞的 runLoop. // // 对应原项目中 control_request / control_response 协议. func (s *Session) ResolvePermission(requestID string, allow bool) { s.mu.Lock() ch, ok := s.pendingPermissions[requestID] if ok { delete(s.pendingPermissions, requestID) } s.mu.Unlock() if ok { // 非阻塞发送:channel 缓冲区为 1,不会死锁 select { case ch <- allow: default: } } } // WaitForPermission 等待消费层对指定权限请求的回复. // 被 Engine 的 runLoop 调用 -- 当 PermissionHandler 为 nil(HTTP Server 模式)时, // 引擎通过此方法异步等待消费层的权限决策. // // 返回 true 表示允许,false 表示拒绝. // ctx 取消或 session 关闭时返回 false(视为拒绝). // // 精妙之处(CLEVER): 三路 select:ch(正常回复)/ ctx.Done()(请求取消)/ s.done(会话关闭). // 后两路都执行 delete(pendingPermissions, requestID) 清理 map entry, // 防止 Close() 再次尝试向已无消费者的 channel 发送(double-send 死代码路径). // Close() 自己也会清理 map,但两处 delete 同一 key 是幂等的,不会 panic. func (s *Session) WaitForPermission(ctx context.Context, requestID string) bool { ch := make(chan bool, 1) s.mu.Lock() if s.closed { s.mu.Unlock() return false } s.pendingPermissions[requestID] = ch s.mu.Unlock() select { case allow := <-ch: // 正常路径:ResolvePermission 已在发送前 delete(map, requestID),无需再清理 return allow case <-ctx.Done(): // 请求取消:清理 pending entry,避免 Close() 向无消费者的 channel 发送 s.mu.Lock() delete(s.pendingPermissions, requestID) s.mu.Unlock() return false case <-s.done: // session 关闭:清理 pending entry(Close() 会清理所有剩余 entry,这里也幂等) s.mu.Lock() delete(s.pendingPermissions, requestID) s.mu.Unlock() return false } } // Messages 返回当前会话的消息历史(只读副本). func (s *Session) Messages() []query.Message { s.mu.Lock() defer s.mu.Unlock() out := make([]query.Message, len(s.messages)) copy(out, s.messages) return out } // Stats returns the session's cumulative statistics snapshot. // // Shape: pull. Complementary to the push-side // `session_cost_threshold_crossed` observer event — pull for on-demand // snapshots, push for cost-tier alerting; the two are orthogonal. // // Stats 返回会话的累计统计信息快照. // // 形态: 调取 (pull). 和 push 侧 `session_cost_threshold_crossed` observer // 事件正交互补 -- 要按需快照走 pull, 要成本跨档告警走 push. func (s *Session) Stats() SessionStats { s.mu.Lock() defer s.mu.Unlock() return SessionStats{ TurnCount: s.turnCount, InputTokens: s.inputTokens, OutputTokens: s.outputTokens, CostUSD: s.costUSD, MessageCount: len(s.messages), } } // SessionStats is the cumulative statistics of a session, exposed on both // a pull API and a push event, so consumers can pick the shape that fits: // // - Pull: call Session.Stats() at any time for a current snapshot (cheap, // synchronous, holds the session mutex briefly). Suitable for UIs that // render a "session summary" on demand or CLIs that print at exit. // - Push: subscribe to observer event `session_cost_threshold_crossed`. // The engine emits this exactly once per crossed threshold from the // preset ladder ($1/$5/$10/$50/$100) the first time cumulative CostUSD // passes it in this session. The event payload carries the full 5 // fields plus `threshold_usd` + `session_id`, so downstream cost-alert // wiring (Slack/PagerDuty/dashboard) needs no additional pull. // // These two paths are same-source: the snapshot shipped in the event payload // is constructed inside the same mutex region as the field accumulators, so // push payloads and a concurrent Stats() pull cannot diverge for a given // turn. If you need "current full snapshot" use pull; if you need "somebody // just crossed a cost tier" use push; the two do not overlap by design. // // SessionStats 是会话的累计统计信息, 同时暴露 pull API 和 push 事件两条消费 // 路径, 消费者按场景选形: // // - Pull: 任意时刻调 Session.Stats() 拿当前快照 (便宜, 同步, 仅短暂持会话 // 锁). 适合 UI 按需渲染"会话摘要"或 CLI 退出时打印. // - Push: 订阅 observer 事件 `session_cost_threshold_crossed`. 当会话累计 // CostUSD 首次跨过预设档位 ($1/$5/$10/$50/$100) 任一档时, 引擎对每个 // 跨过的档位各 emit 一次. payload 含完整 5 字段 + `threshold_usd` + // `session_id`, 下游成本告警接线 (Slack/PagerDuty/dashboard) 无需再 pull. // // 两条路径同源: 事件 payload 的快照在 mutex 保护区内和字段累加器同一区域 // 构造, 因此同一轮的 push payload 和并发 pull 的 Stats() 不会偏离. 要"当前 // 完整快照"走 pull, 要"成本跨档告警"走 push, 两者刻意不重叠. type SessionStats struct { TurnCount int // Total turn count. 总轮次数. InputTokens int // Cumulative input tokens. 累计输入 token 数. OutputTokens int // Cumulative output tokens. 累计输出 token 数. CostUSD float64 // Cumulative cost in USD. 累计花费 (美元). MessageCount int // Total message count in session history. 会话历史消息数量. } // Title 返回会话标题. func (s *Session) Title() string { s.mu.Lock() defer s.mu.Unlock() return s.title } // SetTitle 设置会话标题. func (s *Session) SetTitle(title string) { s.mu.Lock() defer s.mu.Unlock() s.title = title } // CreatedAt 返回会话创建时间. func (s *Session) CreatedAt() time.Time { return s.createdAt } // LastActiveAt 返回会话最后活跃时间. func (s *Session) LastActiveAt() time.Time { s.mu.Lock() defer s.mu.Unlock() return s.lastActiveAt } // IsClosed 返回会话是否已关闭. func (s *Session) IsClosed() bool { s.mu.Lock() defer s.mu.Unlock() return s.closed } // Close 关闭会话,释放资源. // // 幂等:多次调用安全,内部通过 closeOnce 保证只执行一次. // // 关闭顺序: // 1. 设置 s.closed = true(Send/WaitForPermission 快速失败) // 2. 关闭 s.done(广播给 trackEvents goroutine 和所有 WaitForPermission goroutine) // 3. 遍历 pendingPermissions,向每个 channel 发送 false 并清空 map // (唤醒任何在 WaitForPermission 中已过 closed 检查,但尚未收到 s.done 的 goroutine) // 4. 发送 observer 事件 // // 替代方案:<只 close(s.done),不向 pendingPermissions 发送 false> // - 否决:WaitForPermission 的 select 已监听 s.done,理论上够用; // 但向 ch 发送 false 是防御性设计--万一有竞争窗口(如 WaitForPermission 在 s.done 关闭 // 之前刚把 ch 写入 map 但还没进 select),额外的 false 发送能确保其最终被唤醒. func (s *Session) Close() { s.closeOnce.Do(func() { s.mu.Lock() s.closed = true // 关闭 done channel:广播给所有监听者(trackEvents,WaitForPermission) close(s.done) // 收集所有 pending permission channels,然后统一发送 false 并清空 map // 精妙之处(CLEVER): 在 mu 保护下取出所有 ch,mu 解锁后再发送-- // 避免持锁期间 ch <- false(虽然 buffered=1 通常不阻塞,但防御更严谨). pendingChs := make([]chan bool, 0, len(s.pendingPermissions)) for id, ch := range s.pendingPermissions { pendingChs = append(pendingChs, ch) delete(s.pendingPermissions, id) } s.mu.Unlock() // 向所有 pending channel 发送 false(拒绝),唤醒阻塞的 WaitForPermission goroutine for _, ch := range pendingChs { select { case ch <- false: default: // channel 已有值(极少见):不阻塞,WaitForPermission 会读到 s.done } } s.engine.Observer().Event("session_closed", map[string]any{ "session_id": s.id, }) }) }