1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10 "time"
11 "unicode/utf8"
12
13 "charm.land/fantasy"
14 md "github.com/JohannesKaufmann/html-to-markdown"
15 "github.com/PuerkitoBio/goquery"
16 "github.com/charmbracelet/crush/internal/permission"
17)
18
// FetchParams holds the arguments supplied when the fetch tool is invoked.
//
// Format is compared case-insensitively against "text", "markdown", and
// "html". Timeout is in seconds: values above 120 are clamped to 120, and
// zero or a negative value means no per-request timeout is applied on top of
// the HTTP client's own limit. The description tags are presumably used to
// build the tool's parameter schema — TODO confirm against fantasy.
type FetchParams struct {
	URL     string `json:"url" description:"The URL to fetch content from"`
	Format  string `json:"format" description:"The format to return the content in (text, markdown, or html)"`
	Timeout int    `json:"timeout,omitempty" description:"Optional timeout in seconds (max 120)"`
}
24
// FetchPermissionsParams is the payload attached to a fetch permission
// request.
//
// It mirrors FetchParams field-for-field (names, types, and order) so that
// NewFetchTool can build it with a direct conversion,
// FetchPermissionsParams(params). Keep the two structs in sync, or that
// conversion stops compiling.
type FetchPermissionsParams struct {
	URL     string `json:"url"`
	Format  string `json:"format"`
	Timeout int    `json:"timeout,omitempty"`
}
30
// fetchTool groups the dependencies of the fetch tool.
//
// NOTE(review): nothing in this file uses it, because NewFetchTool captures
// these values in a closure instead. It may be referenced elsewhere in the
// package, so confirm that before removing it.
type fetchTool struct {
	client      *http.Client
	permissions permission.Service
	workingDir  string
}
36
// FetchToolName is the name the fetch tool is exposed under. It is passed to
// fantasy.NewAgentTool and reported as ToolName in permission requests.
const FetchToolName = "fetch"

// fetchDescription is the tool description given to fantasy.NewAgentTool. It
// is embedded from fetch.md at build time.
//
//go:embed fetch.md
var fetchDescription []byte
41
42func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
43 if client == nil {
44 client = &http.Client{
45 Timeout: 30 * time.Second,
46 Transport: &http.Transport{
47 MaxIdleConns: 100,
48 MaxIdleConnsPerHost: 10,
49 IdleConnTimeout: 90 * time.Second,
50 },
51 }
52 }
53
54 return fantasy.NewAgentTool(
55 FetchToolName,
56 string(fetchDescription),
57 func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
58 if params.URL == "" {
59 return fantasy.NewTextErrorResponse("URL parameter is required"), nil
60 }
61
62 format := strings.ToLower(params.Format)
63 if format != "text" && format != "markdown" && format != "html" {
64 return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
65 }
66
67 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
68 return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
69 }
70
71 sessionID := GetSessionFromContext(ctx)
72 if sessionID == "" {
73 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
74 }
75
76 p := permissions.Request(
77 permission.CreatePermissionRequest{
78 SessionID: sessionID,
79 Path: workingDir,
80 ToolCallID: call.ID,
81 ToolName: FetchToolName,
82 Action: "fetch",
83 Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
84 Params: FetchPermissionsParams(params),
85 },
86 )
87
88 if !p {
89 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
90 }
91
92 // Handle timeout with context
93 requestCtx := ctx
94 if params.Timeout > 0 {
95 maxTimeout := 120 // 2 minutes
96 if params.Timeout > maxTimeout {
97 params.Timeout = maxTimeout
98 }
99 var cancel context.CancelFunc
100 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
101 defer cancel()
102 }
103
104 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
105 if err != nil {
106 return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
107 }
108
109 req.Header.Set("User-Agent", "crush/1.0")
110
111 resp, err := client.Do(req)
112 if err != nil {
113 return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
114 }
115 defer resp.Body.Close()
116
117 if resp.StatusCode != http.StatusOK {
118 return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
119 }
120
121 maxSize := int64(5 * 1024 * 1024) // 5MB
122 body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
123 if err != nil {
124 return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
125 }
126
127 content := string(body)
128
129 isValidUt8 := utf8.ValidString(content)
130 if !isValidUt8 {
131 return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
132 }
133 contentType := resp.Header.Get("Content-Type")
134
135 switch format {
136 case "text":
137 if strings.Contains(contentType, "text/html") {
138 text, err := extractTextFromHTML(content)
139 if err != nil {
140 return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
141 }
142 content = text
143 }
144
145 case "markdown":
146 if strings.Contains(contentType, "text/html") {
147 markdown, err := convertHTMLToMarkdown(content)
148 if err != nil {
149 return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
150 }
151 content = markdown
152 }
153
154 content = "```\n" + content + "\n```"
155
156 case "html":
157 // return only the body of the HTML document
158 if strings.Contains(contentType, "text/html") {
159 doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
160 if err != nil {
161 return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
162 }
163 body, err := doc.Find("body").Html()
164 if err != nil {
165 return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
166 }
167 if body == "" {
168 return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
169 }
170 content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
171 }
172 }
173 // calculate byte size of content
174 contentSize := int64(len(content))
175 if contentSize > MaxReadSize {
176 content = content[:MaxReadSize]
177 content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
178 }
179
180 return fantasy.NewTextResponse(content), nil
181 })
182}
183
184func extractTextFromHTML(html string) (string, error) {
185 doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
186 if err != nil {
187 return "", err
188 }
189
190 text := doc.Find("body").Text()
191 text = strings.Join(strings.Fields(text), " ")
192
193 return text, nil
194}
195
196func convertHTMLToMarkdown(html string) (string, error) {
197 converter := md.NewConverter("", true, nil)
198
199 markdown, err := converter.ConvertString(html)
200 if err != nil {
201 return "", err
202 }
203
204 return markdown, nil
205}