1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10 "time"
11 "unicode/utf8"
12
13 "charm.land/fantasy"
14 md "github.com/JohannesKaufmann/html-to-markdown"
15 "github.com/PuerkitoBio/goquery"
16 "github.com/charmbracelet/crush/internal/permission"
17)
18
// FetchParams holds the arguments supplied to the fetch tool on each call.
//
// The `description` struct tags presumably feed the tool's input schema built
// by fantasy.NewAgentTool — TODO confirm. The handler in NewFetchTool requires
// URL to be non-empty and to start with "http://" or "https://". It accepts
// Format case-insensitively as "text", "markdown" or "html". It clamps
// Timeout values above 120 to 120.
//
// NOTE: NewFetchTool converts this type directly into FetchPermissionsParams.
// The two structs must therefore keep identical field names, types and order.
type FetchParams struct {
	URL     string `json:"url" description:"The URL to fetch content from"`
	Format  string `json:"format" description:"The format to return the content in (text, markdown, or html)"`
	Timeout int    `json:"timeout,omitempty" description:"Optional timeout in seconds (max 120)"`
}
24
// FetchPermissionsParams is the payload attached as Params to the permission
// request that NewFetchTool issues before fetching a URL.
//
// It mirrors FetchParams field-for-field, without the schema description tags.
// This lets the handler build it with a plain type conversion,
// FetchPermissionsParams(params). Keep both structs in sync, or that
// conversion stops compiling.
type FetchPermissionsParams struct {
	URL     string `json:"url"`
	Format  string `json:"format"`
	Timeout int    `json:"timeout,omitempty"`
}
30
// FetchToolName is the name the fetch tool is registered under. It is also
// used as the ToolName of its permission requests.
const FetchToolName = "fetch"

// fetchDescription is the tool description passed to fantasy.NewAgentTool,
// embedded at build time from fetch.md.
//
//go:embed fetch.md
var fetchDescription []byte
35
36func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
37 if client == nil {
38 client = &http.Client{
39 Timeout: 30 * time.Second,
40 Transport: &http.Transport{
41 MaxIdleConns: 100,
42 MaxIdleConnsPerHost: 10,
43 IdleConnTimeout: 90 * time.Second,
44 },
45 }
46 }
47
48 return fantasy.NewAgentTool(
49 FetchToolName,
50 string(fetchDescription),
51 func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
52 if params.URL == "" {
53 return fantasy.NewTextErrorResponse("URL parameter is required"), nil
54 }
55
56 format := strings.ToLower(params.Format)
57 if format != "text" && format != "markdown" && format != "html" {
58 return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
59 }
60
61 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
62 return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
63 }
64
65 sessionID := GetSessionFromContext(ctx)
66 if sessionID == "" {
67 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
68 }
69
70 p := permissions.Request(
71 permission.CreatePermissionRequest{
72 SessionID: sessionID,
73 Path: workingDir,
74 ToolCallID: call.ID,
75 ToolName: FetchToolName,
76 Action: "fetch",
77 Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
78 Params: FetchPermissionsParams(params),
79 },
80 )
81
82 if !p {
83 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
84 }
85
86 // Handle timeout with context
87 requestCtx := ctx
88 if params.Timeout > 0 {
89 maxTimeout := 120 // 2 minutes
90 if params.Timeout > maxTimeout {
91 params.Timeout = maxTimeout
92 }
93 var cancel context.CancelFunc
94 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
95 defer cancel()
96 }
97
98 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
99 if err != nil {
100 return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
101 }
102
103 req.Header.Set("User-Agent", "crush/1.0")
104
105 resp, err := client.Do(req)
106 if err != nil {
107 return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
108 }
109 defer resp.Body.Close()
110
111 if resp.StatusCode != http.StatusOK {
112 return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
113 }
114
115 maxSize := int64(5 * 1024 * 1024) // 5MB
116 body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
117 if err != nil {
118 return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
119 }
120
121 content := string(body)
122
123 isValidUt8 := utf8.ValidString(content)
124 if !isValidUt8 {
125 return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
126 }
127 contentType := resp.Header.Get("Content-Type")
128
129 switch format {
130 case "text":
131 if strings.Contains(contentType, "text/html") {
132 text, err := extractTextFromHTML(content)
133 if err != nil {
134 return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
135 }
136 content = text
137 }
138
139 case "markdown":
140 if strings.Contains(contentType, "text/html") {
141 markdown, err := convertHTMLToMarkdown(content)
142 if err != nil {
143 return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
144 }
145 content = markdown
146 }
147
148 content = "```\n" + content + "\n```"
149
150 case "html":
151 // return only the body of the HTML document
152 if strings.Contains(contentType, "text/html") {
153 doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
154 if err != nil {
155 return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
156 }
157 body, err := doc.Find("body").Html()
158 if err != nil {
159 return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
160 }
161 if body == "" {
162 return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
163 }
164 content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
165 }
166 }
167 // calculate byte size of content
168 contentSize := int64(len(content))
169 if contentSize > MaxReadSize {
170 content = content[:MaxReadSize]
171 content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
172 }
173
174 return fantasy.NewTextResponse(content), nil
175 })
176}
177
178func extractTextFromHTML(html string) (string, error) {
179 doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
180 if err != nil {
181 return "", err
182 }
183
184 text := doc.Find("body").Text()
185 text = strings.Join(strings.Fields(text), " ")
186
187 return text, nil
188}
189
190func convertHTMLToMarkdown(html string) (string, error) {
191 converter := md.NewConverter("", true, nil)
192
193 markdown, err := converter.ConvertString(html)
194 if err != nil {
195 return "", err
196 }
197
198 return markdown, nil
199}