1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10 "time"
11 "unicode/utf8"
12
13 "charm.land/fantasy"
14 md "github.com/JohannesKaufmann/html-to-markdown"
15 "github.com/PuerkitoBio/goquery"
16 "github.com/charmbracelet/crush/internal/permission"
17)
18
// FetchToolName is the name the fetch tool is registered under; it is also
// used as the ToolName of the permission requests the tool issues.
const FetchToolName = "fetch"
20
// fetchDescription holds the contents of fetch.md, embedded at build time.
// It is passed (as a string) as the description of the fetch tool.
//
//go:embed fetch.md
var fetchDescription []byte
23
24func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
25 if client == nil {
26 client = &http.Client{
27 Timeout: 30 * time.Second,
28 Transport: &http.Transport{
29 MaxIdleConns: 100,
30 MaxIdleConnsPerHost: 10,
31 IdleConnTimeout: 90 * time.Second,
32 },
33 }
34 }
35
36 return fantasy.NewParallelAgentTool(
37 FetchToolName,
38 string(fetchDescription),
39 func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
40 if params.URL == "" {
41 return fantasy.NewTextErrorResponse("URL parameter is required"), nil
42 }
43
44 format := strings.ToLower(params.Format)
45 if format != "text" && format != "markdown" && format != "html" {
46 return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
47 }
48
49 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
50 return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
51 }
52
53 sessionID := GetSessionFromContext(ctx)
54 if sessionID == "" {
55 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
56 }
57
58 p, err := permissions.Request(ctx,
59 permission.CreatePermissionRequest{
60 SessionID: sessionID,
61 Path: workingDir,
62 ToolCallID: call.ID,
63 ToolName: FetchToolName,
64 Action: "fetch",
65 Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
66 Params: FetchPermissionsParams(params),
67 },
68 )
69 if err != nil {
70 return fantasy.ToolResponse{}, err
71 }
72 if !p.Granted {
73 if p.Message != "" {
74 return fantasy.NewTextErrorResponse("User denied permission." + permission.UserCommentaryTag(p.Message)), nil
75 }
76 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
77 }
78
79 // maxFetchTimeoutSeconds is the maximum allowed timeout for fetch requests (2 minutes)
80 const maxFetchTimeoutSeconds = 120
81
82 // Handle timeout with context
83 requestCtx := ctx
84 if params.Timeout > 0 {
85 if params.Timeout > maxFetchTimeoutSeconds {
86 params.Timeout = maxFetchTimeoutSeconds
87 }
88 var cancel context.CancelFunc
89 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
90 defer cancel()
91 }
92
93 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
94 if err != nil {
95 return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
96 }
97
98 req.Header.Set("User-Agent", "crush/1.0")
99
100 resp, err := client.Do(req)
101 if err != nil {
102 return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
103 }
104 defer resp.Body.Close()
105
106 if resp.StatusCode != http.StatusOK {
107 return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
108 }
109
110 // maxFetchResponseSizeBytes is the maximum size of response body to read (5MB)
111 const maxFetchResponseSizeBytes = int64(5 * 1024 * 1024)
112
113 maxSize := maxFetchResponseSizeBytes
114 body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
115 if err != nil {
116 return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
117 }
118
119 content := string(body)
120
121 validUTF8 := utf8.ValidString(content)
122 if !validUTF8 {
123 return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
124 }
125 contentType := resp.Header.Get("Content-Type")
126
127 switch format {
128 case "text":
129 if strings.Contains(contentType, "text/html") {
130 text, err := extractTextFromHTML(content)
131 if err != nil {
132 return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
133 }
134 content = text
135 }
136
137 case "markdown":
138 if strings.Contains(contentType, "text/html") {
139 markdown, err := convertHTMLToMarkdown(content)
140 if err != nil {
141 return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
142 }
143 content = markdown
144 }
145
146 content = "```\n" + content + "\n```"
147
148 case "html":
149 // return only the body of the HTML document
150 if strings.Contains(contentType, "text/html") {
151 doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
152 if err != nil {
153 return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
154 }
155 body, err := doc.Find("body").Html()
156 if err != nil {
157 return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
158 }
159 if body == "" {
160 return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
161 }
162 content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
163 }
164 }
165 // truncate content if it exceeds max read size
166 if int64(len(content)) > MaxReadSize {
167 content = content[:MaxReadSize]
168 content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
169 }
170
171 return fantasy.NewTextResponse(p.AppendCommentary(content)), nil
172 })
173}
174
175func extractTextFromHTML(html string) (string, error) {
176 doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
177 if err != nil {
178 return "", err
179 }
180
181 text := doc.Find("body").Text()
182 text = strings.Join(strings.Fields(text), " ")
183
184 return text, nil
185}
186
187func convertHTMLToMarkdown(html string) (string, error) {
188 converter := md.NewConverter("", true, nil)
189
190 markdown, err := converter.ConvertString(html)
191 if err != nil {
192 return "", err
193 }
194
195 return markdown, nil
196}