1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10 "time"
11 "unicode/utf8"
12
13 md "github.com/JohannesKaufmann/html-to-markdown"
14 "github.com/PuerkitoBio/goquery"
15 "github.com/charmbracelet/crush/internal/permission"
16 "github.com/charmbracelet/fantasy/ai"
17)
18
// FetchParams are the model-facing arguments accepted by the fetch tool.
type FetchParams struct {
	// URL to download; must begin with http:// or https:// (validated in the handler).
	URL string `json:"url" description:"The URL to fetch content from"`
	// Format selects the output representation: "text", "markdown", or "html"
	// (matched case-insensitively by the handler).
	Format string `json:"format" description:"The format to return the content in (text, markdown, or html)"`
	// Timeout is an optional per-request deadline in seconds; the handler caps it at 120.
	Timeout int `json:"timeout,omitempty" description:"Optional timeout in seconds (max 120)"`
}
24
// FetchPermissionsParams mirrors FetchParams for display in a permission
// request; NewFetchTool converts between the two with a direct struct
// conversion, so the field set and order must stay identical to FetchParams.
type FetchPermissionsParams struct {
	URL     string `json:"url"`
	Format  string `json:"format"`
	Timeout int    `json:"timeout,omitempty"`
}
30
// fetchTool bundles the dependencies of the fetch tool.
//
// NOTE(review): this struct is not constructed anywhere in the visible part
// of this file — NewFetchTool captures its dependencies in a closure instead.
// Possibly a leftover from an earlier method-based implementation; confirm it
// has no other users before removing.
type fetchTool struct {
	client      *http.Client
	permissions permission.Service
	workingDir  string
}
36
// FetchToolName is the identifier under which the fetch tool is registered.
const FetchToolName = "fetch"

// fetchDescription is the tool's long-form description, embedded from fetch.md.
//
//go:embed fetch.md
var fetchDescription []byte
41
42func NewFetchTool(permissions permission.Service, workingDir string) ai.AgentTool {
43 client := &http.Client{
44 Timeout: 30 * time.Second,
45 Transport: &http.Transport{
46 MaxIdleConns: 100,
47 MaxIdleConnsPerHost: 10,
48 IdleConnTimeout: 90 * time.Second,
49 },
50 }
51
52 return ai.NewAgentTool(
53 FetchToolName,
54 string(fetchDescription),
55 func(ctx context.Context, params FetchParams, call ai.ToolCall) (ai.ToolResponse, error) {
56 if params.URL == "" {
57 return ai.NewTextErrorResponse("URL parameter is required"), nil
58 }
59
60 format := strings.ToLower(params.Format)
61 if format != "text" && format != "markdown" && format != "html" {
62 return ai.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
63 }
64
65 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
66 return ai.NewTextErrorResponse("URL must start with http:// or https://"), nil
67 }
68
69 sessionID := GetSessionFromContext(ctx)
70 if sessionID == "" {
71 return ai.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
72 }
73
74 p := permissions.Request(
75 permission.CreatePermissionRequest{
76 SessionID: sessionID,
77 Path: workingDir,
78 ToolCallID: call.ID,
79 ToolName: FetchToolName,
80 Action: "fetch",
81 Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
82 Params: FetchPermissionsParams(params),
83 },
84 )
85
86 if !p {
87 return ai.ToolResponse{}, permission.ErrorPermissionDenied
88 }
89
90 // Handle timeout with context
91 requestCtx := ctx
92 if params.Timeout > 0 {
93 maxTimeout := 120 // 2 minutes
94 if params.Timeout > maxTimeout {
95 params.Timeout = maxTimeout
96 }
97 var cancel context.CancelFunc
98 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
99 defer cancel()
100 }
101
102 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
103 if err != nil {
104 return ai.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
105 }
106
107 req.Header.Set("User-Agent", "crush/1.0")
108
109 resp, err := client.Do(req)
110 if err != nil {
111 return ai.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
112 }
113 defer resp.Body.Close()
114
115 if resp.StatusCode != http.StatusOK {
116 return ai.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
117 }
118
119 maxSize := int64(5 * 1024 * 1024) // 5MB
120 body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
121 if err != nil {
122 return ai.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
123 }
124
125 content := string(body)
126
127 isValidUt8 := utf8.ValidString(content)
128 if !isValidUt8 {
129 return ai.NewTextErrorResponse("Response content is not valid UTF-8"), nil
130 }
131 contentType := resp.Header.Get("Content-Type")
132
133 switch format {
134 case "text":
135 if strings.Contains(contentType, "text/html") {
136 text, err := extractTextFromHTML(content)
137 if err != nil {
138 return ai.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
139 }
140 content = text
141 }
142
143 case "markdown":
144 if strings.Contains(contentType, "text/html") {
145 markdown, err := convertHTMLToMarkdown(content)
146 if err != nil {
147 return ai.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
148 }
149 content = markdown
150 }
151
152 content = "```\n" + content + "\n```"
153
154 case "html":
155 // return only the body of the HTML document
156 if strings.Contains(contentType, "text/html") {
157 doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
158 if err != nil {
159 return ai.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
160 }
161 body, err := doc.Find("body").Html()
162 if err != nil {
163 return ai.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
164 }
165 if body == "" {
166 return ai.NewTextErrorResponse("No body content found in HTML"), nil
167 }
168 content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
169 }
170 }
171 // calculate byte size of content
172 contentSize := int64(len(content))
173 if contentSize > MaxReadSize {
174 content = content[:MaxReadSize]
175 content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
176 }
177
178 return ai.NewTextResponse(content), nil
179 })
180}
181
182func extractTextFromHTML(html string) (string, error) {
183 doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
184 if err != nil {
185 return "", err
186 }
187
188 text := doc.Find("body").Text()
189 text = strings.Join(strings.Fields(text), " ")
190
191 return text, nil
192}
193
194func convertHTMLToMarkdown(html string) (string, error) {
195 converter := md.NewConverter("", true, nil)
196
197 markdown, err := converter.ConvertString(html)
198 if err != nil {
199 return "", err
200 }
201
202 return markdown, nil
203}