1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10 "time"
11 "unicode/utf8"
12
13 "charm.land/fantasy"
14 md "github.com/JohannesKaufmann/html-to-markdown"
15 "github.com/PuerkitoBio/goquery"
16 "github.com/charmbracelet/crush/internal/permission"
17)
18
// FetchToolName identifies the fetch tool; it is used as the tool's name and
// as the ToolName of the permission requests the tool issues.
const FetchToolName = "fetch"

// fetchDescription is the fetch tool's description text, embedded at build
// time from fetch.md.
//
//go:embed fetch.md
var fetchDescription []byte
23
24func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
25 if client == nil {
26 transport := http.DefaultTransport.(*http.Transport).Clone()
27 transport.MaxIdleConns = 100
28 transport.MaxIdleConnsPerHost = 10
29 transport.IdleConnTimeout = 90 * time.Second
30
31 client = &http.Client{
32 Timeout: 30 * time.Second,
33 Transport: transport,
34 }
35 }
36
37 return fantasy.NewParallelAgentTool(
38 FetchToolName,
39 string(fetchDescription),
40 func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
41 if params.URL == "" {
42 return fantasy.NewTextErrorResponse("URL parameter is required"), nil
43 }
44
45 format := strings.ToLower(params.Format)
46 if format != "text" && format != "markdown" && format != "html" {
47 return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
48 }
49
50 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
51 return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
52 }
53
54 sessionID := GetSessionFromContext(ctx)
55 if sessionID == "" {
56 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
57 }
58
59 p, err := permissions.Request(ctx,
60 permission.CreatePermissionRequest{
61 SessionID: sessionID,
62 Path: workingDir,
63 ToolCallID: call.ID,
64 ToolName: FetchToolName,
65 Action: "fetch",
66 Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
67 Params: FetchPermissionsParams(params),
68 },
69 )
70 if err != nil {
71 return fantasy.ToolResponse{}, err
72 }
73 if !p {
74 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
75 }
76
77 // maxFetchTimeoutSeconds is the maximum allowed timeout for fetch requests (2 minutes)
78 const maxFetchTimeoutSeconds = 120
79
80 // Handle timeout with context
81 requestCtx := ctx
82 if params.Timeout > 0 {
83 if params.Timeout > maxFetchTimeoutSeconds {
84 params.Timeout = maxFetchTimeoutSeconds
85 }
86 var cancel context.CancelFunc
87 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
88 defer cancel()
89 }
90
91 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
92 if err != nil {
93 return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
94 }
95
96 req.Header.Set("User-Agent", "crush/1.0")
97
98 resp, err := client.Do(req)
99 if err != nil {
100 return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
101 }
102 defer resp.Body.Close()
103
104 if resp.StatusCode != http.StatusOK {
105 return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
106 }
107
108 // maxFetchResponseSizeBytes is the maximum size of response body to read (5MB)
109 const maxFetchResponseSizeBytes = int64(5 * 1024 * 1024)
110
111 maxSize := maxFetchResponseSizeBytes
112 body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
113 if err != nil {
114 return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
115 }
116
117 content := string(body)
118
119 validUTF8 := utf8.ValidString(content)
120 if !validUTF8 {
121 return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
122 }
123 contentType := resp.Header.Get("Content-Type")
124
125 switch format {
126 case "text":
127 if strings.Contains(contentType, "text/html") {
128 text, err := extractTextFromHTML(content)
129 if err != nil {
130 return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
131 }
132 content = text
133 }
134
135 case "markdown":
136 if strings.Contains(contentType, "text/html") {
137 markdown, err := convertHTMLToMarkdown(content)
138 if err != nil {
139 return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
140 }
141 content = markdown
142 }
143
144 content = "```\n" + content + "\n```"
145
146 case "html":
147 // return only the body of the HTML document
148 if strings.Contains(contentType, "text/html") {
149 doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
150 if err != nil {
151 return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
152 }
153 body, err := doc.Find("body").Html()
154 if err != nil {
155 return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
156 }
157 if body == "" {
158 return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
159 }
160 content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
161 }
162 }
163 // truncate content if it exceeds max read size
164 if int64(len(content)) > MaxReadSize {
165 content = content[:MaxReadSize]
166 content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
167 }
168
169 return fantasy.NewTextResponse(content), nil
170 })
171}
172
173func extractTextFromHTML(html string) (string, error) {
174 doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
175 if err != nil {
176 return "", err
177 }
178
179 text := doc.Find("body").Text()
180 text = strings.Join(strings.Fields(text), " ")
181
182 return text, nil
183}
184
185func convertHTMLToMarkdown(html string) (string, error) {
186 converter := md.NewConverter("", true, nil)
187
188 markdown, err := converter.ConvertString(html)
189 if err != nil {
190 return "", err
191 }
192
193 return markdown, nil
194}