1package tools
2
3import (
4 "context"
5 _ "embed"
6 "encoding/json"
7 "fmt"
8 "io"
9 "net/http"
10 "strings"
11 "time"
12 "unicode/utf8"
13
14 md "github.com/JohannesKaufmann/html-to-markdown"
15 "github.com/PuerkitoBio/goquery"
16 "github.com/charmbracelet/crush/internal/permission"
17)
18
// FetchParams is the JSON input of a fetch tool call, decoded by Run.
//
// URL is the address to retrieve; Run rejects it unless it starts with
// "http://" or "https://". Format selects the output form and must be "text",
// "markdown", or "html" (Run lower-cases it before checking). Timeout is an
// optional per-request timeout in seconds; Run ignores non-positive values
// and caps larger ones at 120.
type FetchParams struct {
	URL     string `json:"url"`
	Format  string `json:"format"`
	Timeout int    `json:"timeout,omitempty"`
}
24
// FetchPermissionsParams is the parameter payload attached to the permission
// request that Run issues before fetching.
//
// Run produces it with a direct type conversion from FetchParams, so the two
// structs must keep the same fields, in the same order, with the same tags.
type FetchPermissionsParams struct {
	URL     string `json:"url"`
	Format  string `json:"format"`
	Timeout int    `json:"timeout,omitempty"`
}
30
// fetchTool implements the fetch tool.
//
// client performs the HTTP requests, permissions is asked for approval before
// each fetch, and workingDir is reported as the Path of that permission
// request.
type fetchTool struct {
	client      *http.Client
	permissions permission.Service
	workingDir  string
}
36
// FetchToolName is the name the fetch tool reports from Name and Info and
// records as ToolName on its permission requests.
const FetchToolName = "fetch"
38
// fetchDescription holds the contents of fetch.md, embedded at build time.
// Info returns it as the tool description.
//
//go:embed fetch.md
var fetchDescription []byte
41
42func NewFetchTool(permissions permission.Service, workingDir string) BaseTool {
43 return &fetchTool{
44 client: &http.Client{
45 Timeout: 30 * time.Second,
46 Transport: &http.Transport{
47 MaxIdleConns: 100,
48 MaxIdleConnsPerHost: 10,
49 IdleConnTimeout: 90 * time.Second,
50 },
51 },
52 permissions: permissions,
53 workingDir: workingDir,
54 }
55}
56
// Name returns FetchToolName.
func (t *fetchTool) Name() string {
	return FetchToolName
}
60
61func (t *fetchTool) Info() ToolInfo {
62 return ToolInfo{
63 Name: FetchToolName,
64 Description: string(fetchDescription),
65 Parameters: map[string]any{
66 "url": map[string]any{
67 "type": "string",
68 "description": "The URL to fetch content from",
69 },
70 "format": map[string]any{
71 "type": "string",
72 "description": "The format to return the content in (text, markdown, or html)",
73 "enum": []string{"text", "markdown", "html"},
74 },
75 "timeout": map[string]any{
76 "type": "number",
77 "description": "Optional timeout in seconds (max 120)",
78 },
79 },
80 Required: []string{"url", "format"},
81 }
82}
83
84func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
85 var params FetchParams
86 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
87 return NewTextErrorResponse("Failed to parse fetch parameters: " + err.Error()), nil
88 }
89
90 if params.URL == "" {
91 return NewTextErrorResponse("URL parameter is required"), nil
92 }
93
94 format := strings.ToLower(params.Format)
95 if format != "text" && format != "markdown" && format != "html" {
96 return NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
97 }
98
99 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
100 return NewTextErrorResponse("URL must start with http:// or https://"), nil
101 }
102
103 sessionID, messageID := GetContextValues(ctx)
104 if sessionID == "" || messageID == "" {
105 return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
106 }
107
108 p := t.permissions.Request(
109 permission.CreatePermissionRequest{
110 SessionID: sessionID,
111 Path: t.workingDir,
112 ToolCallID: call.ID,
113 ToolName: FetchToolName,
114 Action: "fetch",
115 Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
116 Params: FetchPermissionsParams(params),
117 },
118 )
119
120 if !p {
121 return ToolResponse{}, permission.ErrorPermissionDenied
122 }
123
124 // Handle timeout with context
125 requestCtx := ctx
126 if params.Timeout > 0 {
127 maxTimeout := 120 // 2 minutes
128 if params.Timeout > maxTimeout {
129 params.Timeout = maxTimeout
130 }
131 var cancel context.CancelFunc
132 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
133 defer cancel()
134 }
135
136 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
137 if err != nil {
138 return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
139 }
140
141 req.Header.Set("User-Agent", "crush/1.0")
142
143 resp, err := t.client.Do(req)
144 if err != nil {
145 return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
146 }
147 defer resp.Body.Close()
148
149 if resp.StatusCode != http.StatusOK {
150 return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
151 }
152
153 maxSize := int64(5 * 1024 * 1024) // 5MB
154 body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
155 if err != nil {
156 return NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
157 }
158
159 content := string(body)
160
161 isValidUt8 := utf8.ValidString(content)
162 if !isValidUt8 {
163 return NewTextErrorResponse("Response content is not valid UTF-8"), nil
164 }
165 contentType := resp.Header.Get("Content-Type")
166
167 switch format {
168 case "text":
169 if strings.Contains(contentType, "text/html") {
170 text, err := extractTextFromHTML(content)
171 if err != nil {
172 return NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
173 }
174 content = text
175 }
176
177 case "markdown":
178 if strings.Contains(contentType, "text/html") {
179 markdown, err := convertHTMLToMarkdown(content)
180 if err != nil {
181 return NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
182 }
183 content = markdown
184 }
185
186 content = "```\n" + content + "\n```"
187
188 case "html":
189 // return only the body of the HTML document
190 if strings.Contains(contentType, "text/html") {
191 doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
192 if err != nil {
193 return NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
194 }
195 body, err := doc.Find("body").Html()
196 if err != nil {
197 return NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
198 }
199 if body == "" {
200 return NewTextErrorResponse("No body content found in HTML"), nil
201 }
202 content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
203 }
204 }
205 // calculate byte size of content
206 contentSize := int64(len(content))
207 if contentSize > MaxReadSize {
208 content = content[:MaxReadSize]
209 content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
210 }
211
212 return NewTextResponse(content), nil
213}
214
215func extractTextFromHTML(html string) (string, error) {
216 doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
217 if err != nil {
218 return "", err
219 }
220
221 text := doc.Find("body").Text()
222 text = strings.Join(strings.Fields(text), " ")
223
224 return text, nil
225}
226
227func convertHTMLToMarkdown(html string) (string, error) {
228 converter := md.NewConverter("", true, nil)
229
230 markdown, err := converter.ConvertString(html)
231 if err != nil {
232 return "", err
233 }
234
235 return markdown, nil
236}