1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "html/template"
8 "io"
9 "net/http"
10 "strings"
11 "time"
12 "unicode/utf8"
13
14 "charm.land/fantasy"
15 md "github.com/JohannesKaufmann/html-to-markdown"
16 "github.com/PuerkitoBio/goquery"
17 "github.com/charmbracelet/crush/internal/permission"
18)
19
const (
	// FetchToolName is the identifier the fetch tool is registered under.
	FetchToolName = "fetch"

	// MaxFetchSize bounds, in bytes, how much of a response body the fetch
	// tool reads and returns (100 KiB).
	MaxFetchSize = 100 << 10
)
24
// fetchDescriptionTmpl holds the raw contents of fetch.md.tpl, embedded into
// the binary at build time.
//
//go:embed fetch.md.tpl
var fetchDescriptionTmpl []byte

// fetchDescriptionTpl is fetchDescriptionTmpl parsed as a template named
// "fetchDescription"; NewFetchTool passes it to renderToolDescription to build
// the tool's description. Parsing runs once at package initialization, and
// template.Must panics there if the embedded text is not a valid template.
// NOTE(review): this uses html/template, which HTML-escapes interpolated
// values; if the description is plain text/markdown, text/template may be
// what was intended — confirm against renderToolDescription.
var fetchDescriptionTpl = template.Must(
	template.New("fetchDescription").
		Parse(string(fetchDescriptionTmpl)),
)
32
33func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
34 if client == nil {
35 transport := http.DefaultTransport.(*http.Transport).Clone()
36 transport.MaxIdleConns = 100
37 transport.MaxIdleConnsPerHost = 10
38 transport.IdleConnTimeout = 90 * time.Second
39
40 client = &http.Client{
41 Timeout: 30 * time.Second,
42 Transport: transport,
43 }
44 }
45
46 return fantasy.NewParallelAgentTool(
47 FetchToolName,
48 renderToolDescription(fetchDescriptionTpl),
49 func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
50 if params.URL == "" {
51 return fantasy.NewTextErrorResponse("URL parameter is required"), nil
52 }
53
54 format := strings.ToLower(params.Format)
55 if format != "text" && format != "markdown" && format != "html" {
56 return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
57 }
58
59 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
60 return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
61 }
62
63 sessionID := GetSessionFromContext(ctx)
64 if sessionID == "" {
65 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
66 }
67
68 p, err := permissions.Request(ctx,
69 permission.CreatePermissionRequest{
70 SessionID: sessionID,
71 Path: workingDir,
72 ToolCallID: call.ID,
73 ToolName: FetchToolName,
74 Action: "fetch",
75 Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
76 Params: FetchPermissionsParams(params),
77 },
78 )
79 if err != nil {
80 return fantasy.ToolResponse{}, err
81 }
82 if !p {
83 return NewPermissionDeniedResponse(), nil
84 }
85
86 // maxFetchTimeoutSeconds is the maximum allowed timeout for fetch requests (2 minutes)
87 const maxFetchTimeoutSeconds = 120
88
89 // Handle timeout with context
90 requestCtx := ctx
91 if params.Timeout > 0 {
92 if params.Timeout > maxFetchTimeoutSeconds {
93 params.Timeout = maxFetchTimeoutSeconds
94 }
95 var cancel context.CancelFunc
96 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
97 defer cancel()
98 }
99
100 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
101 if err != nil {
102 return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
103 }
104
105 req.Header.Set("User-Agent", "crush/1.0")
106
107 resp, err := client.Do(req)
108 if err != nil {
109 return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
110 }
111 defer resp.Body.Close()
112
113 if resp.StatusCode != http.StatusOK {
114 return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
115 }
116
117 body, err := io.ReadAll(io.LimitReader(resp.Body, MaxFetchSize))
118 if err != nil {
119 return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
120 }
121
122 content := string(body)
123
124 validUTF8 := utf8.ValidString(content)
125 if !validUTF8 {
126 return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
127 }
128 contentType := resp.Header.Get("Content-Type")
129
130 switch format {
131 case "text":
132 if strings.Contains(contentType, "text/html") {
133 text, err := extractTextFromHTML(content)
134 if err != nil {
135 return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
136 }
137 content = text
138 }
139
140 case "markdown":
141 if strings.Contains(contentType, "text/html") {
142 markdown, err := convertHTMLToMarkdown(content)
143 if err != nil {
144 return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
145 }
146 content = markdown
147 }
148
149 content = "```\n" + content + "\n```"
150
151 case "html":
152 // return only the body of the HTML document
153 if strings.Contains(contentType, "text/html") {
154 doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
155 if err != nil {
156 return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
157 }
158 body, err := doc.Find("body").Html()
159 if err != nil {
160 return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
161 }
162 if body == "" {
163 return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
164 }
165 content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
166 }
167 }
168 // truncate content if it exceeds max read size
169 if int64(len(content)) >= MaxFetchSize {
170 content = content[:MaxFetchSize]
171 content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxFetchSize)
172 }
173
174 return fantasy.NewTextResponse(content), nil
175 })
176}
177
178func extractTextFromHTML(html string) (string, error) {
179 doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
180 if err != nil {
181 return "", err
182 }
183
184 text := doc.Find("body").Text()
185 text = strings.Join(strings.Fields(text), " ")
186
187 return text, nil
188}
189
190func convertHTMLToMarkdown(html string) (string, error) {
191 converter := md.NewConverter("", true, nil)
192
193 markdown, err := converter.ConvertString(html)
194 if err != nil {
195 return "", err
196 }
197
198 return markdown, nil
199}