package builtin
// WebFetch 工具 -- HTTP GET 请求获取网页内容.
//
// 这是 Agent 获取外部信息的能力:通过 HTTP GET 请求获取网页内容,
// 并将 HTML 转换为纯文本,方便模型阅读.
//
// 特性:
// - HTTP GET 请求获取网页
// - 智能 HTML-to-text 转换:保留段落结构(p → 双换行,br → 换行,li → "- ")
// - 提取 <title> 作为输出的第一行
// - 移除 script 和 style 标签的内容
// - 完整处理 HTML entities(&amp; &lt; &gt; &quot; 等)
// - 对 JSON 响应直接返回格式化的 JSON(不做 HTML 转换)
// - HTTP 状态码检查(非 2xx 返回错误信息)
// - 支持超时配置(默认 30 秒)
// - 限制响应大小(最大 1MB)
// - ConcurrencySafe: true,ReadOnly: true
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"

	"git.flytoex.net/yuanwei/flyto-agent/pkg/permission"
	"git.flytoex.net/yuanwei/flyto-agent/pkg/tools"
)
// privateIPNets lists the loopback, link-local, private and otherwise
// internal address ranges that WebFetch refuses to connect to (SSRF defense).
//
// ELEVATED: an earlier design had no SSRF protection at all, so a caller
// could request, for example:
//
//	url: "http://169.254.169.254/latest/meta-data/"  -> AWS instance credentials
//	url: "http://127.0.0.1:9200/_cat/indices"        -> probe a local Elasticsearch
//	url: "http://192.168.1.1/admin"                  -> reach an internal router
//
// The ranges are matched against the IP addresses a hostname resolves to,
// not against the hostname string in the URL. Alternative rejected: checking
// only the URL hostname string -- a hostname can resolve to any IP, so a
// string check is trivially bypassed.
var privateIPNets []*net.IPNet

// init parses the CIDR list once at package load. The list is a hard-coded
// constant, so a parse failure is a programmer error; panicking (instead of
// silently skipping the entry) guarantees a typo can never quietly drop a
// protected range.
func init() {
	for _, cidr := range []string{
		"127.0.0.0/8",    // IPv4 loopback
		"::1/128",        // IPv6 loopback
		"169.254.0.0/16", // IPv4 link-local (the AWS metadata service lives here)
		"fe80::/10",      // IPv6 link-local
		"10.0.0.0/8",     // RFC 1918 private
		"172.16.0.0/12",  // RFC 1918 private
		"192.168.0.0/16", // RFC 1918 private
		"100.64.0.0/10",  // RFC 6598 shared address space (Tailscale and other VPNs)
		"fc00::/7",       // IPv6 unique local addresses
		"0.0.0.0/8",      // IPv4 "this network"
		"::/128",         // IPv6 unspecified address; dialing [::] reaches localhost on common stacks
	} {
		_, network, err := net.ParseCIDR(cidr)
		if err != nil {
			panic(fmt.Sprintf("webfetch: invalid built-in CIDR %q: %v", cidr, err))
		}
		privateIPNets = append(privateIPNets, network)
	}
}
// isPrivateIP reports whether ip lies inside any range listed in
// privateIPNets.
func isPrivateIP(ip net.IP) bool {
	matched := false
	for _, block := range privateIPNets {
		if matched = block.Contains(ip); matched {
			break
		}
	}
	return matched
}
// safeDialContext is the SSRF-hardened DialContext for the WebFetch HTTP
// transport. It resolves the target host and refuses the connection if any
// resolved address is private/internal. It then dials one of the exact
// addresses it just vetted.
//
// CLEVER: filtering at the dial layer rather than the URL layer means a
// redirect pointing at an internal host is also stopped before any TCP
// connection is made.
//
// The dial targets the vetted IP literal, never the hostname. Dialing the
// hostname would perform a second, independent DNS lookup. A rebinding DNS
// server (short TTL) could then answer the check with a public IP and the
// dial with 127.0.0.1 -- a time-of-check/time-of-use bypass. TLS is
// unaffected: http.Transport takes the TLS server name from the request URL,
// not from the dialed address.
//
// Returns an error for a malformed addr, a failed or empty DNS lookup, any
// internal address, or when every candidate address fails to connect (the
// last dial error is returned).
func safeDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	host, port, err := net.SplitHostPort(addr)
	if err != nil {
		return nil, fmt.Errorf("webfetch: invalid address %q: %w", addr, err)
	}

	var candidates []net.IP
	if ip := net.ParseIP(host); ip != nil {
		// Literal IP: no DNS involved, check it directly.
		if isPrivateIP(ip) {
			return nil, fmt.Errorf("webfetch: connection to private/internal address %s blocked (SSRF defense)", ip)
		}
		candidates = []net.IP{ip}
	} else {
		addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
		if err != nil {
			return nil, fmt.Errorf("webfetch: DNS lookup failed for %q: %w", host, err)
		}
		// Reject the host outright if ANY answer is internal, so a mixed
		// public/private answer set cannot be used to reach internal hosts.
		for _, a := range addrs {
			if isPrivateIP(a.IP) {
				return nil, fmt.Errorf("webfetch: DNS resolved %q to private/internal address %s (SSRF defense)", host, a.IP)
			}
			candidates = append(candidates, a.IP)
		}
	}
	if len(candidates) == 0 {
		return nil, fmt.Errorf("webfetch: DNS lookup for %q returned no addresses", host)
	}

	var d net.Dialer
	var lastErr error
	for _, ip := range candidates {
		conn, err := d.DialContext(ctx, network, net.JoinHostPort(ip.String(), port))
		if err == nil {
			return conn, nil
		}
		lastErr = err
		if ctx.Err() != nil {
			// Cancelled or timed out: trying further addresses is pointless.
			break
		}
	}
	return nil, lastErr
}
// Request timeout bounds for WebFetch.
const (
	// webFetchDefaultTimeout applies when the caller omits "timeout" (30s).
	webFetchDefaultTimeout = 30 * time.Second
	// webFetchMaxTimeout is the ceiling a caller-supplied timeout is clamped to (120s).
	webFetchMaxTimeout = 120 * time.Second
)
// WebFetchTool is the web page fetching tool. It holds no state of its own.
type WebFetchTool struct{}

// NewWebFetchTool returns a new WebFetch tool instance.
func NewWebFetchTool() *WebFetchTool {
	tool := new(WebFetchTool)
	return tool
}
// webFetchInput holds the decoded JSON arguments of a WebFetch call.
type webFetchInput struct {
	// URL is the address to fetch; Execute requires an http:// or https:// prefix.
	URL string `json:"url"`
	// Timeout is the request timeout in seconds; 0 (or omitted) selects the
	// 30-second default.
	Timeout int `json:"timeout,omitempty"`
}
// Name returns the tool's name, "WebFetch".
func (t *WebFetchTool) Name() string {
	const toolName = "WebFetch"
	return toolName
}
// Description returns the tool's description text. The context argument is
// not used.
func (t *WebFetchTool) Description(ctx context.Context) string {
	sentences := []string{
		"Fetches content from a URL via HTTP GET. ",
		"HTML is intelligently converted to plain text (preserving paragraph structure, lists, title). ",
		"JSON responses are returned formatted. ",
		"Non-2xx HTTP status codes are reported as errors. ",
		"Supports timeout configuration (default 30 seconds). ",
		"Response size is limited to 1MB.",
	}
	return strings.Join(sentences, "")
}
// InputSchema returns the JSON Schema for the tool's arguments: a required
// string "url" and an optional integer "timeout" (seconds).
func (t *WebFetchTool) InputSchema() json.RawMessage {
	const schema = `{
"type": "object",
"properties": {
"url": {
"type": "string",
"description": "The URL to fetch content from"
},
"timeout": {
"type": "integer",
"description": "Timeout in seconds (default 30, max 120)"
}
},
"required": ["url"]
}`
	return json.RawMessage(schema)
}
// Metadata returns the tool's scheduling and audit metadata. WebFetch is
// flagged read-only, non-destructive and concurrency-safe, uses the
// web-fetch permission class, and is audited as a "read" operation.
func (t *WebFetchTool) Metadata() tools.Metadata {
	return tools.Metadata{
		PermissionClass: permission.PermClassWebFetch,
		AuditOperation:  "read",
		SearchHint:      "web fetch http url page download",
		ReadOnly:        true,
		ConcurrencySafe: true,
		Destructive:     false,
	}
}
// Execute fetches the page at params.URL with an HTTP GET and returns its
// content as a tools.Result.
//
// Error reporting:
//   - Validation failures (empty URL, non-http(s) prefix, negative timeout),
//     request/transport/read failures and non-2xx responses all come back as
//     a Result with IsError set and a nil Go error.
//   - Only input that is not valid JSON produces a non-nil error.
//
// Behavior: the timeout defaults to 30s and is clamped to 120s. At most 1MB
// of the body is read. JSON bodies (by Content-Type or by shape) are
// pretty-printed, and HTML bodies are converted to text. The progress
// callback is not used.
//
// NOTE(review): this function is corrupted in two places (see the notes
// below) and does not compile as written.
func (t *WebFetchTool) Execute(ctx context.Context, input json.RawMessage, progress tools.ProgressFunc) (*tools.Result, error) {
var params webFetchInput
// NOTE(review): "¶ms" on the next line looks like "&params" mangled by an
// HTML-entity decoder ("&para" -> "¶"); it should read "&params".
if err := json.Unmarshal(input, ¶ms); err != nil {
return nil, fmt.Errorf("webfetch: invalid input: %w", err)
}
if params.URL == "" {
return &tools.Result{
Output: "error: url is required",
IsError: true,
}, nil
}
// Only URLs starting with "http://" or "https://" are accepted (the prefix
// check is case-sensitive).
if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
return &tools.Result{
Output: "error: url must start with http:// or https://",
IsError: true,
}, nil
}
// Resolve the timeout: 0 selects the default, negative is rejected, values
// above webFetchMaxTimeout are clamped.
// NOTE(review): time.Duration(params.Timeout) * time.Second can overflow
// int64 for huge inputs (around 9.2e9 seconds and up) and wrap negative,
// which slips past the clamp; a non-positive http.Client.Timeout means
// "no timeout". Compare params.Timeout against the max in seconds before
// multiplying.
timeout := webFetchDefaultTimeout
if params.Timeout < 0 {
return &tools.Result{
Output: "error: timeout must be a positive number of seconds",
IsError: true,
}, nil
}
if params.Timeout > 0 {
timeout = time.Duration(params.Timeout) * time.Second
if timeout > webFetchMaxTimeout {
timeout = webFetchMaxTimeout
}
}
// Per-call HTTP client with an overall timeout, the SSRF-filtering dialer
// and a redirect limit.
// NOTE(review): the http.Transport built here is never closed, so its idle
// keep-alive connections outlive this call (per net/http docs, a zero
// IdleConnTimeout means no limit). Consider holding the transport in a
// variable and deferring CloseIdleConnections(), or setting
// DisableKeepAlives.
client := &http.Client{
Timeout: timeout,
Transport: &http.Transport{
DialContext: safeDialContext,
},
// Stop redirect chains: the check fails once 5 requests have already been
// made, so at most 4 redirects are actually followed (the error text says
// "max 5").
// CLEVER: CheckRedirect only counts hops and never inspects the redirect
// target. Redirects to internal hosts are stopped one layer down, because
// every new connection goes through safeDialContext.
CheckRedirect: func(req *http.Request, via []*http.Request) error {
if len(via) >= 5 {
return fmt.Errorf("webfetch: too many redirects (max 5)")
}
return nil
},
}
// Build the request bound to ctx so cancellation also aborts the fetch.
req, err := http.NewRequestWithContext(ctx, "GET", params.URL, nil)
if err != nil {
return &tools.Result{
Output: fmt.Sprintf("error creating request: %v", err),
IsError: true,
}, nil
}
req.Header.Set("User-Agent", "AgentEngine/1.0")
req.Header.Set("Accept", "text/html, text/plain, application/json, */*")
// Send the request.
resp, err := client.Do(req)
if err != nil {
if resp != nil && resp.Body != nil {
resp.Body.Close()
}
return &tools.Result{
Output: fmt.Sprintf("error fetching URL: %v", err),
IsError: true,
}, nil
}
defer resp.Body.Close()
// Read at most 1MB of the body.
const maxSize = 1024 * 1024
// CLEVER: LimitReader(maxSize+1) rather than LimitReader(maxSize). The one
// extra byte reveals whether the body exceeded the limit: if maxSize+1 bytes
// come back, the response was too large and the output is marked
// "[content truncated]". With LimitReader(maxSize), a body of exactly
// maxSize bytes would also be flagged as truncated.
limitedReader := io.LimitReader(resp.Body, maxSize+1)
body, err := io.ReadAll(limitedReader)
if err != nil {
return &tools.Result{
Output: fmt.Sprintf("error reading response: %v", err),
IsError: true,
}, nil
}
truncated := len(body) > maxSize
if truncated {
// NOTE(review): byte-level cut; may split a multi-byte UTF-8 character.
body = body[:maxSize]
}
content := string(body)
// Any non-2xx status is reported as an error, with up to 1000 bytes of
// the body attached.
// NOTE(review): resp.Status already begins with the numeric code (e.g.
// "404 Not Found"), so this renders as "HTTP 404 404 Not Found".
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return &tools.Result{
Output: fmt.Sprintf("HTTP %d %s\n\n%s", resp.StatusCode, resp.Status, truncateWebText(content, 1000)),
IsError: true,
}, nil
}
contentType := resp.Header.Get("Content-Type")
// JSON (by Content-Type or by shape): return it pretty-printed, with no
// HTML conversion.
if strings.Contains(contentType, "application/json") || isJSONContent(content) {
formatted := tryFormatJSON(content)
if truncated {
formatted += "\n... [content truncated, exceeded 1MB]"
}
return &tools.Result{
Output: fmt.Sprintf("URL: %s\nStatus: %d\nContent-Type: %s\n\n%s", params.URL, resp.StatusCode, contentType, formatted),
IsError: false,
}, nil
}
// HTML: convert to plain text.
// NOTE(review): the next line is corrupted. Everything between the opening
// quote after "content, " and "maxOutput {" is missing; it looks like an
// HTML tag stripper removed a "<...>" span. The lost code presumably
// finished the HTML check, ran the HTML-to-text conversion and declared
// maxOutput (the message below says 256KB). Restore it from version control
// rather than guessing. The content[:maxOutput] cut below is also
// byte-level and may split a UTF-8 character.
if strings.Contains(contentType, "text/html") || strings.Contains(content, " maxOutput {
content = content[:maxOutput] + "\n... [content truncated, exceeded 256KB]"
}
if truncated {
content += "\n... [response truncated, exceeded 1MB]"
}
return &tools.Result{
Output: fmt.Sprintf("URL: %s\nStatus: %d\n\n%s", params.URL, resp.StatusCode, content),
IsError: false,
}, nil
}
// isJSONContent is a cheap shape heuristic. After trimming surrounding
// whitespace, the text must start with '{' and end with '}', or start with
// '[' and end with ']'. It does not validate the JSON itself.
func isJSONContent(content string) bool {
	s := strings.TrimSpace(content)
	if len(s) < 2 {
		return false
	}
	first, last := s[0], s[len(s)-1]
	switch first {
	case '{':
		return last == '}'
	case '[':
		return last == ']'
	default:
		return false
	}
}
// tryFormatJSON pretty-prints JSON text for display. It returns content
// unchanged if it is not valid JSON.
//
// It re-indents the original bytes with json.Indent instead of decoding into
// `any` and re-marshalling. That round trip distorts the very data it is
// meant to display:
//   - every number becomes a float64, so integers beyond 2^53 (IDs, etc.)
//     lose precision;
//   - object keys come back sorted rather than in source order;
//   - <, > and & inside strings are rewritten as \u003c-style escapes.
//
// Surrounding whitespace is trimmed before indenting.
func tryFormatJSON(content string) string {
	var buf bytes.Buffer
	if err := json.Indent(&buf, []byte(strings.TrimSpace(content)), "", " "); err != nil {
		return content
	}
	return buf.String()
}
// extractTitle returns the text inside the HTML <title> element, locating
// the tag in a lowercased copy of html so the match is case-insensitive.
//
// NOTE(review): from here to the end of this span the file is corrupted. An
// HTML tag stripper appears to have removed everything between a "<" inside
// a string literal and the next ">". As a result, the end of extractTitle
// and the opening of the following HTML-to-text converter are missing. The
// code is kept byte-for-byte; restore the lost lines from version control.
func extractTitle(html string) string {
lower := strings.ToLower(html)
// NOTE(review): corrupted -- the search literal (presumably "<title") and
// the lines that followed it were lost.
startIdx := strings.Index(lower, " 标签
// Find the ">" that closes the opening title tag (it may carry attributes).
gtIdx := strings.Index(html[startIdx:], ">")
if gtIdx == -1 {
return ""
}
titleStart := startIdx + gtIdx + 1
// NOTE(review): corrupted -- this line fuses extractTitle's closing-tag
// search with the '>' branch of the HTML-to-text converter's byte loop.
// Missing: the rest of extractTitle, plus the converter's signature, local
// state (result, inTag, buildingTag, tagName, inScript, inStyle) and loop
// header.
endIdx := strings.Index(lower[titleStart:], "' {
// End of a tag: reset the tag-parsing state, then act on the tag name.
inTag = false
buildingTag = false
lowerTag := strings.ToLower(tagName)
// Keep only the tag name (strip attributes).
cleanTag := lowerTag
if spaceIdx := strings.IndexAny(cleanTag, " \t\n\r"); spaceIdx >= 0 {
cleanTag = cleanTag[:spaceIdx]
}
// Track entering/leaving <script> and <style>; their closing tags emit
// nothing.
if cleanTag == "script" {
inScript = true
} else if cleanTag == "/script" {
inScript = false
continue
} else if cleanTag == "style" {
inStyle = true
} else if cleanTag == "/style" {
inStyle = false
continue
}
// Outside script/style, map block-level tags to plain-text layout.
if !inScript && !inStyle {
switch cleanTag {
case "p", "/p", "div", "/div":
result.WriteString("\n\n")
case "br", "br/":
result.WriteByte('\n')
case "hr":
result.WriteString("\n---\n")
case "h1", "h2", "h3", "h4", "h5", "h6":
result.WriteString("\n\n")
case "/h1", "/h2", "/h3", "/h4", "/h5", "/h6":
result.WriteString("\n\n")
case "li":
result.WriteString("\n- ")
case "/li":
// Closing </li> adds nothing.
case "tr":
result.WriteByte('\n')
case "td", "th":
result.WriteString("\t")
case "blockquote":
result.WriteString("\n> ")
case "/blockquote":
result.WriteByte('\n')
case "pre":
result.WriteString("\n```\n")
case "/pre":
result.WriteString("\n```\n")
}
}
continue
}
// Inside a tag: accumulate the tag name until the first whitespace; the
// remainder of the tag (attributes) is skipped.
if inTag {
if buildingTag {
if ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r' {
buildingTag = false
} else {
tagName += string(ch)
}
}
continue
}
// Text inside <script>/<style> is dropped.
if inScript || inStyle {
continue
}
// Decode an HTML character reference: an '&' whose ';' lies within the next
// 11 bytes, provided decodeHTMLEntity recognizes it.
if ch == '&' && i+1 < len(html) {
end := strings.IndexByte(html[i:], ';')
if end > 0 && end < 12 {
entity := html[i : i+end+1]
decoded := decodeHTMLEntity(entity)
if decoded != "" {
result.WriteString(decoded)
// Advance to the ';'; presumably the (not visible) loop header's
// increment then steps past it.
i += end
continue
}
}
}
// Ordinary text byte: copied through unchanged.
result.WriteByte(ch)
}
return result.String()
}
// decodeHTMLEntities replaces every recognizable HTML character reference in
// s with its decoded text. An '&' starts a reference only when a ';' follows
// within the next 11 bytes and decodeHTMLEntity recognizes the span.
// Everything else is copied through byte-for-byte.
func decodeHTMLEntities(s string) string {
	var out strings.Builder
	for pos := 0; pos < len(s); {
		if s[pos] == '&' {
			if semi := strings.IndexByte(s[pos:], ';'); semi > 0 && semi < 12 {
				if text := decodeHTMLEntity(s[pos : pos+semi+1]); text != "" {
					out.WriteString(text)
					pos += semi + 1
					continue
				}
			}
		}
		out.WriteByte(s[pos])
		pos++
	}
	return out.String()
}
// namedHTMLEntities maps the named character references decodeHTMLEntity
// understands to plain-text replacements. Several are deliberately folded to
// ASCII approximations ("&copy;" -> "(c)", "&deg;" -> "deg") so the output
// stays readable as plain text.
//
// Keys are the literal reference text, including '&' and ';'. Numeric forms
// such as "&#39;" are handled by the numeric path, not listed here. An empty
// replacement cannot be expressed, because "" is the "not an entity" result
// callers test for.
//
// The table is built once at package load rather than on every call, since
// the decoder runs for every '&' in a page.
var namedHTMLEntities = map[string]string{
	"&amp;":    "&",
	"&lt;":     "<",
	"&gt;":     ">",
	"&quot;":   "\"",
	"&apos;":   "'",
	"&nbsp;":   " ",
	"&mdash;":  "--",
	"&ndash;":  "-",
	"&copy;":   "(c)",
	"&reg;":    "(R)",
	"&trade;":  "(TM)",
	"&laquo;":  "<<",
	"&raquo;":  ">>",
	"&hellip;": "...",
	"&bull;":   "*",
	"&middot;": ".",
	"&ensp;":   " ",
	"&emsp;":   " ",
	"&thinsp;": " ",
	"&lsquo;":  "'",
	"&rsquo;":  "'",
	"&ldquo;":  "\"",
	"&rdquo;":  "\"",
	"&minus;":  "-",
	"&times;":  "x",
	"&divide;": "/",
	"&deg;":    "deg",
	"&para;":   "P",
	"&sect;":   "S",
}

// decodeHTMLEntity decodes a single HTML character reference such as
// "&amp;", "&#65;" or "&#x41;". entity must be the complete reference,
// including the leading '&' and trailing ';'.
//
// It returns "" when entity is neither a known named reference nor a valid
// numeric one. Callers rely on "" meaning "leave the original text alone".
//
// Numeric references must consist only of decimal digits (or hex digits
// after x/X). They must also name a valid Unicode scalar value:
// 1..0x10FFFF, excluding UTF-16 surrogates.
func decodeHTMLEntity(entity string) string {
	if v, ok := namedHTMLEntities[entity]; ok {
		return v
	}
	// Shortest numeric reference is "&#N;" (4 bytes).
	if len(entity) < 4 || !strings.HasPrefix(entity, "&#") || !strings.HasSuffix(entity, ";") {
		return ""
	}
	digits := entity[2 : len(entity)-1]
	base := 10
	if digits[0] == 'x' || digits[0] == 'X' {
		base = 16
		digits = digits[1:]
	}
	// Unlike fmt.Sscanf, which stops at the first non-digit (so "&#12ab;"
	// would decode as "&#12;"), ParseUint rejects signs, spaces and trailing
	// junk.
	n, err := strconv.ParseUint(digits, base, 32)
	if err != nil {
		return ""
	}
	r := rune(n)
	if r == 0 || !utf8.ValidRune(r) {
		return ""
	}
	return string(r)
}
// cleanWhitespace trims surrounding whitespace from every line of text and
// collapses each run of blank lines to at most two. The result is joined
// with "\n".
func cleanWhitespace(text string) string {
	var kept []string
	blankRun := 0
	for _, raw := range strings.Split(text, "\n") {
		line := strings.TrimSpace(raw)
		if line != "" {
			blankRun = 0
			kept = append(kept, line)
			continue
		}
		blankRun++
		if blankRun > 2 {
			continue
		}
		kept = append(kept, "")
	}
	return strings.Join(kept, "\n")
}
// truncateWebText keeps at most maxLen bytes of text. When anything was cut,
// it appends "...".
//
// The cut point moves back to the nearest UTF-8 character boundary, so a
// multi-byte character (CJK, emoji, accented Latin) is never split into
// invalid bytes. A negative maxLen is treated as 0 instead of panicking.
func truncateWebText(text string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(text) <= maxLen {
		return text
	}
	cut := maxLen
	// text[cut] exists because len(text) > maxLen. Step back over UTF-8
	// continuation bytes until cut sits on the first byte of a rune.
	for cut > 0 && !utf8.RuneStart(text[cut]) {
		cut--
	}
	return text[:cut] + "..."
}