1package tools
2
3import (
4 "context"
5 _ "embed"
6 "encoding/json"
7 "fmt"
8 "io"
9 "net/http"
10 "strings"
11 "time"
12 "unicode/utf8"
13
14 md "github.com/JohannesKaufmann/html-to-markdown"
15 "github.com/PuerkitoBio/goquery"
16 "github.com/charmbracelet/crush/internal/permission"
17 "github.com/charmbracelet/crush/internal/proto"
18)
19
20type (
21 FetchParams = proto.FetchParams
22 FetchPermissionsParams = proto.FetchPermissionsParams
23)
24
25type fetchTool struct {
26 client *http.Client
27 permissions permission.Service
28 workingDir string
29}
30
31const FetchToolName = proto.FetchToolName
32
33//go:embed fetch.md
34var fetchDescription []byte
35
36func NewFetchTool(permissions permission.Service, workingDir string) BaseTool {
37 return &fetchTool{
38 client: &http.Client{
39 Timeout: 30 * time.Second,
40 Transport: &http.Transport{
41 MaxIdleConns: 100,
42 MaxIdleConnsPerHost: 10,
43 IdleConnTimeout: 90 * time.Second,
44 },
45 },
46 permissions: permissions,
47 workingDir: workingDir,
48 }
49}
50
51func (t *fetchTool) Name() string {
52 return FetchToolName
53}
54
55func (t *fetchTool) Info() ToolInfo {
56 return ToolInfo{
57 Name: FetchToolName,
58 Description: string(fetchDescription),
59 Parameters: map[string]any{
60 "url": map[string]any{
61 "type": "string",
62 "description": "The URL to fetch content from",
63 },
64 "format": map[string]any{
65 "type": "string",
66 "description": "The format to return the content in (text, markdown, or html)",
67 "enum": []string{"text", "markdown", "html"},
68 },
69 "timeout": map[string]any{
70 "type": "number",
71 "description": "Optional timeout in seconds (max 120)",
72 },
73 },
74 Required: []string{"url", "format"},
75 }
76}
77
78func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
79 var params FetchParams
80 if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
81 return NewTextErrorResponse("Failed to parse fetch parameters: " + err.Error()), nil
82 }
83
84 if params.URL == "" {
85 return NewTextErrorResponse("URL parameter is required"), nil
86 }
87
88 format := strings.ToLower(params.Format)
89 if format != "text" && format != "markdown" && format != "html" {
90 return NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
91 }
92
93 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
94 return NewTextErrorResponse("URL must start with http:// or https://"), nil
95 }
96
97 sessionID, messageID := GetContextValues(ctx)
98 if sessionID == "" || messageID == "" {
99 return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
100 }
101
102 p := t.permissions.Request(
103 permission.CreatePermissionRequest{
104 SessionID: sessionID,
105 Path: t.workingDir,
106 ToolCallID: call.ID,
107 ToolName: FetchToolName,
108 Action: "fetch",
109 Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
110 Params: FetchPermissionsParams(params),
111 },
112 )
113
114 if !p {
115 return ToolResponse{}, permission.ErrorPermissionDenied
116 }
117
118 // Handle timeout with context
119 requestCtx := ctx
120 if params.Timeout > 0 {
121 maxTimeout := 120 // 2 minutes
122 if params.Timeout > maxTimeout {
123 params.Timeout = maxTimeout
124 }
125 var cancel context.CancelFunc
126 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
127 defer cancel()
128 }
129
130 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
131 if err != nil {
132 return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
133 }
134
135 req.Header.Set("User-Agent", "crush/1.0")
136
137 resp, err := t.client.Do(req)
138 if err != nil {
139 return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
140 }
141 defer resp.Body.Close()
142
143 if resp.StatusCode != http.StatusOK {
144 return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
145 }
146
147 maxSize := int64(5 * 1024 * 1024) // 5MB
148 body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
149 if err != nil {
150 return NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
151 }
152
153 content := string(body)
154
155 isValidUt8 := utf8.ValidString(content)
156 if !isValidUt8 {
157 return NewTextErrorResponse("Response content is not valid UTF-8"), nil
158 }
159 contentType := resp.Header.Get("Content-Type")
160
161 switch format {
162 case "text":
163 if strings.Contains(contentType, "text/html") {
164 text, err := extractTextFromHTML(content)
165 if err != nil {
166 return NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
167 }
168 content = text
169 }
170
171 case "markdown":
172 if strings.Contains(contentType, "text/html") {
173 markdown, err := convertHTMLToMarkdown(content)
174 if err != nil {
175 return NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
176 }
177 content = markdown
178 }
179
180 content = "```\n" + content + "\n```"
181
182 case "html":
183 // return only the body of the HTML document
184 if strings.Contains(contentType, "text/html") {
185 doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
186 if err != nil {
187 return NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
188 }
189 body, err := doc.Find("body").Html()
190 if err != nil {
191 return NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
192 }
193 if body == "" {
194 return NewTextErrorResponse("No body content found in HTML"), nil
195 }
196 content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
197 }
198 }
199 // calculate byte size of content
200 contentSize := int64(len(content))
201 if contentSize > MaxReadSize {
202 content = content[:MaxReadSize]
203 content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
204 }
205
206 return NewTextResponse(content), nil
207}
208
209func extractTextFromHTML(html string) (string, error) {
210 doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
211 if err != nil {
212 return "", err
213 }
214
215 text := doc.Find("body").Text()
216 text = strings.Join(strings.Fields(text), " ")
217
218 return text, nil
219}
220
221func convertHTMLToMarkdown(html string) (string, error) {
222 converter := md.NewConverter("", true, nil)
223
224 markdown, err := converter.ConvertString(html)
225 if err != nil {
226 return "", err
227 }
228
229 return markdown, nil
230}