// HTTPTransport:通过 HTTP Streamable 与 MCP 服务器通信. // // 对应 MCP 2025-03-26 规范的 Streamable HTTP 传输模式: // // 客户端 服务端 // | POST /mcp | // | Content-Type: application/json // | Accept: application/json, text/event-stream // | ─────────────────────────> | // | | // | 响应 A(单条): | // | 200 OK | // | Content-Type: application/json // | {"jsonrpc":"2.0",...} | // | <───────────────────────── | // | | // | 响应 B(流式): | // | 200 OK | // | Content-Type: text/event-stream // | data: {...} | ← 多条消息 // | data: {...} | // | <───────────────────────── | // | | // | 202 Accepted(通知 fire-and-forget) // | <───────────────────────── | // // Session 管理: // - 服务端可在响应头中返回 Mcp-Session-Id // - 后续请求需在请求头携带该 Session-Id // // 升华改进(ELEVATED): 相较于 SSE 传输(两条持久连接),HTTP Streamable 更简洁-- // 每次请求/响应是独立 HTTP 事务,无需维持长连接,对 CDN / 负载均衡更友好. // SSE 传输兼容旧版 MCP 服务器;HTTP Streamable 面向 2025-03-26 及以后的新服务器. // 替代方案:<复用 SSE 传输+扩展 POST 端点> - 否决,因为两个规范的会话管理模型不同. package mcp import ( "bufio" "bytes" "context" "fmt" "io" "net" "net/http" "strings" "sync" "sync/atomic" "time" ) // HTTPTransport 通过 HTTP Streamable 实现 Transport 接口. type HTTPTransport struct { url string auth AuthProvider // sessionID 由服务端在 initialize 阶段通过响应头分配. // 精妙之处(CLEVER): 用 atomic.Value 而非 sync.Mutex 保护 sessionID-- // sessionID 只在 initialize(首次请求)时写入一次,之后只读; // atomic.Value 的 Store/Load 无锁,比 Mutex 轻量,且语义更清晰(不可变一次性赋值). sessionID atomic.Value // stores string // 精妙之处(CLEVER): recvCh 使用有界缓冲(32 条)而非无界-- // 背压天然传导: 若消费者(Recv)慢, Send/readSSEStream 阻塞在 recvCh 发送端, // 不会静默丢弃消息(select 无 default 分支). 与 StdioTransport 设计一致. recvCh chan []byte // 异步流式响应消息 → Recv() recvErr chan error // 流读取错误(容量 1) done chan struct{} once sync.Once httpClient *http.Client } // NewHTTPTransport 创建 HTTP Streamable 传输实例. // // url 是服务端端点,例如 "https://mcp.example.com/mcp". // auth 是鉴权提供者;传 nil 则使用 NoopAuth(匿名访问). // // 与 SSETransport.NewSSETransport 不同,本构造函数不建立持久连接, // 连接在第一次 Send() 时才发生,是惰性的. func NewHTTPTransport(url string, auth AuthProvider) (*HTTPTransport, error) { if url == "" { return nil, fmt.Errorf("mcp: HTTP transport: url is empty") } if auth == nil { auth = &NoopAuth{} } t := &HTTPTransport{ url: url, auth: auth, recvCh: make(chan []byte, 32), recvErr: make(chan error, 1), done: make(chan struct{}), httpClient: &http.Client{ // 允许最长 5 分钟的流式响应(代码生成,大文件读取场景). // 单次请求如需更短超时,由调用方在 ctx 中控制. Timeout: 5 * time.Minute, // 升华改进(ELEVATED): 加 SSRF 防护 Transport-- // 早期方案只有裸 &http.Client{Timeout: ...},可请求 169.254.169.254(AWS metadata) // 或内网数据库.Transport 层拦截(DialContext)比 URL 字符串检查更彻底, // 即使经过 HTTP 重定向到私网也会在 TCP 握手前被阻断. // 逻辑与 pkg/tools/builtin/webfetch.go 的 safeDialContext 一致. // 替代方案:<提取到共享 internal/ssrf 包> - 否决:过度抽象,两处用量不值得新包. Transport: &http.Transport{ DialContext: mcpSafeDialContext, }, }, } return t, nil } // Send 发送 JSON-RPC 消息并处理响应. // // 根据服务端返回的 Content-Type 分两种处理路径: // - application/json:同步读取单条响应,投递到 recvCh // - text/event-stream:异步解析 SSE 流,每条事件投递到 recvCh func (t *HTTPTransport) Send(ctx context.Context, msg []byte) error { select { case <-t.done: return fmt.Errorf("mcp: HTTP transport closed") default: } req, err := http.NewRequestWithContext(ctx, http.MethodPost, t.url, bytes.NewReader(msg)) if err != nil { return fmt.Errorf("mcp: HTTP post: create request: %w", err) } req.Header.Set("Content-Type", "application/json") // 声明双重 Accept:允许服务端按需选择响应格式 req.Header.Set("Accept", "application/json, text/event-stream") // 携带 session ID(initialize 后生效) if sid, ok := t.sessionID.Load().(string); ok && sid != "" { req.Header.Set("Mcp-Session-Id", sid) } // 附加鉴权头(每次请求调用,支持 token 动态刷新) headers, err := t.auth.Headers(ctx) if err != nil { return fmt.Errorf("mcp: HTTP auth headers: %w", err) } for k, v := range headers { req.Header.Set(k, v) } resp, err := t.httpClient.Do(req) if err != nil { return fmt.Errorf("mcp: HTTP post to %q: %w", t.url, err) } // 提取 session ID(服务端首次分配,后续请求需携带) if sid := resp.Header.Get("Mcp-Session-Id"); sid != "" { t.sessionID.Store(sid) } ct := resp.Header.Get("Content-Type") switch { case strings.Contains(ct, "text/event-stream"): // 流式响应:异步读取,不阻塞 Send 返回 go t.readSSEStream(resp.Body) return nil case resp.StatusCode == http.StatusAccepted: // 202 Accepted:fire-and-forget 通知,无响应体 resp.Body.Close() return nil default: defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024)) return fmt.Errorf("mcp: HTTP post: status %d: %s", resp.StatusCode, string(body)) } // 单条 JSON 响应(最多 maxResponseSize 字节) body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize)) if err != nil { return fmt.Errorf("mcp: HTTP post: read response body: %w", err) } if len(body) > 0 { select { case t.recvCh <- body: case <-t.done: } } return nil } } // Recv 阻塞等待一条来自服务器的消息(单条响应或流中的一条事件). func (t *HTTPTransport) Recv(ctx context.Context) ([]byte, error) { select { case msg, ok := <-t.recvCh: if !ok { return nil, io.EOF } return msg, nil case err := <-t.recvErr: return nil, err case <-t.done: return nil, io.EOF case <-ctx.Done(): return nil, ctx.Err() } } // Close 关闭 HTTP 传输. func (t *HTTPTransport) Close() error { t.once.Do(func() { close(t.done) }) return nil } // readSSEStream 解析 SSE 格式的流式响应体,在后台 goroutine 中运行. // // 与 SSETransport.sseReadLoop 类似,但只处理 data 事件(无 endpoint 协商). func (t *HTTPTransport) readSSEStream(body io.ReadCloser) { defer body.Close() scanner := bufio.NewScanner(body) scanner.Buffer(make([]byte, 64*1024), 1*1024*1024) var dataLines []string for scanner.Scan() { select { case <-t.done: return default: } line := scanner.Text() if line == "" { // 空行:分发当前事件 if len(dataLines) > 0 { data := strings.Join(dataLines, "\n") if data != "" && data != "[DONE]" { // 忽略 OpenAI 风格的结束标记 select { case t.recvCh <- []byte(data): case <-t.done: return } } } dataLines = nil continue } if after, ok := strings.CutPrefix(line, "data:"); ok { dataLines = append(dataLines, strings.TrimSpace(after)) } // event: / id: / retry: 字段暂不处理 } if err := scanner.Err(); err != nil { select { case t.recvErr <- fmt.Errorf("mcp: HTTP SSE stream read error: %w", err): default: } } } // ── SSRF 防护 ────────────────────────────────────────────────────────────── // // 历史包袱(LEGACY): 原始实现只有裸 http.Client{Timeout: ...}, // 不阻止请求内网地址(169.254.169.254 / 10.x / 192.168.x 等). // 以下三个函数与 pkg/tools/builtin/webfetch.go 逻辑一致,因跨包关系无法复用, // 独立维护.若日后提取到 internal/ssrf 包,两处同步修改. // mcpPrivateIPNets 列出所有需要屏蔽的私网/本地地址段. var mcpPrivateIPNets []*net.IPNet func init() { for _, cidr := range []string{ "127.0.0.0/8", // IPv4 loopback "::1/128", // IPv6 loopback "169.254.0.0/16", // IPv4 link-local(AWS/GCP metadata 服务) "fe80::/10", // IPv6 link-local "10.0.0.0/8", // RFC 1918 私网 "172.16.0.0/12", // RFC 1918 私网 "192.168.0.0/16", // RFC 1918 私网 "100.64.0.0/10", // RFC 6598 共享地址空间(Tailscale 等 VPN) "fc00::/7", // IPv6 唯一本地地址 "0.0.0.0/8", // "本网络"地址 } { _, network, err := net.ParseCIDR(cidr) if err == nil { mcpPrivateIPNets = append(mcpPrivateIPNets, network) } } } // mcpIsPrivateIP 判断 IP 是否属于私网/本地范围. func mcpIsPrivateIP(ip net.IP) bool { for _, network := range mcpPrivateIPNets { if network.Contains(ip) { return true } } return false } // mcpSafeDialContext 是防 SSRF 的自定义 DialContext-- // 在 TCP 连接前解析 hostname,拦截指向私网的请求. // // 精妙之处(CLEVER): 在 DialContext 层拦截而非 URL 层-- // 即使经过 HTTP 重定向到私网,也会在真正建立 TCP 连接前被阻断. // 替代方案:<检查 URL hostname 字符串> - 否决:hostname 可解析到任意 IP,字符串检查可绕过. func mcpSafeDialContext(ctx context.Context, network, addr string) (net.Conn, error) { host, port, err := net.SplitHostPort(addr) if err != nil { return nil, fmt.Errorf("mcp: SSRF guard: invalid address %q: %w", addr, err) } // 如果 host 直接是 IP,立即检查(不需要 DNS 解析) if ip := net.ParseIP(host); ip != nil { if mcpIsPrivateIP(ip) { return nil, fmt.Errorf("mcp: SSRF guard: connection to private/internal address %s blocked", ip) } var d net.Dialer return d.DialContext(ctx, network, net.JoinHostPort(host, port)) } // hostname 需要 DNS 解析,检查所有解析结果(防 DNS rebinding) addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host) if err != nil { return nil, fmt.Errorf("mcp: SSRF guard: DNS lookup failed for %q: %w", host, err) } for _, a := range addrs { if mcpIsPrivateIP(a.IP) { return nil, fmt.Errorf("mcp: SSRF guard: DNS resolved %q to private/internal address %s blocked", host, a.IP) } } var d net.Dialer return d.DialContext(ctx, network, net.JoinHostPort(host, port)) }