1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10 "time"
11 "unicode/utf8"
12
13 "charm.land/fantasy"
14 md "github.com/JohannesKaufmann/html-to-markdown"
15 "github.com/PuerkitoBio/goquery"
16 "github.com/charmbracelet/crush/internal/permission"
17)
18
// Identity and size limit of the fetch tool.
const (
	// FetchToolName is the name the fetch tool is registered under and the
	// tool name reported in its permission requests.
	FetchToolName = "fetch"

	// MaxFetchSize caps, in bytes, how much of a response is read and how
	// much content is returned to the caller (1 MiB).
	MaxFetchSize = 1 << 20
)
23
// fetchDescription holds the contents of fetch.md, embedded at build time.
// NewFetchTool converts it to a string and passes it as the tool description.
//
//go:embed fetch.md
var fetchDescription []byte
26
27func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
28 if client == nil {
29 transport := http.DefaultTransport.(*http.Transport).Clone()
30 transport.MaxIdleConns = 100
31 transport.MaxIdleConnsPerHost = 10
32 transport.IdleConnTimeout = 90 * time.Second
33
34 client = &http.Client{
35 Timeout: 30 * time.Second,
36 Transport: transport,
37 }
38 }
39
40 return fantasy.NewParallelAgentTool(
41 FetchToolName,
42 string(fetchDescription),
43 func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
44 if params.URL == "" {
45 return fantasy.NewTextErrorResponse("URL parameter is required"), nil
46 }
47
48 format := strings.ToLower(params.Format)
49 if format != "text" && format != "markdown" && format != "html" {
50 return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
51 }
52
53 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
54 return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
55 }
56
57 sessionID := GetSessionFromContext(ctx)
58 if sessionID == "" {
59 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
60 }
61
62 p, err := permissions.Request(ctx,
63 permission.CreatePermissionRequest{
64 SessionID: sessionID,
65 Path: workingDir,
66 ToolCallID: call.ID,
67 ToolName: FetchToolName,
68 Action: "fetch",
69 Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
70 Params: FetchPermissionsParams(params),
71 },
72 )
73 if err != nil {
74 return fantasy.ToolResponse{}, err
75 }
76 if !p {
77 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
78 }
79
80 // maxFetchTimeoutSeconds is the maximum allowed timeout for fetch requests (2 minutes)
81 const maxFetchTimeoutSeconds = 120
82
83 // Handle timeout with context
84 requestCtx := ctx
85 if params.Timeout > 0 {
86 if params.Timeout > maxFetchTimeoutSeconds {
87 params.Timeout = maxFetchTimeoutSeconds
88 }
89 var cancel context.CancelFunc
90 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
91 defer cancel()
92 }
93
94 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
95 if err != nil {
96 return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
97 }
98
99 req.Header.Set("User-Agent", "crush/1.0")
100
101 resp, err := client.Do(req)
102 if err != nil {
103 return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
104 }
105 defer resp.Body.Close()
106
107 if resp.StatusCode != http.StatusOK {
108 return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
109 }
110
111 body, err := io.ReadAll(io.LimitReader(resp.Body, MaxFetchSize))
112 if err != nil {
113 return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
114 }
115
116 content := string(body)
117
118 validUTF8 := utf8.ValidString(content)
119 if !validUTF8 {
120 return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
121 }
122 contentType := resp.Header.Get("Content-Type")
123
124 switch format {
125 case "text":
126 if strings.Contains(contentType, "text/html") {
127 text, err := extractTextFromHTML(content)
128 if err != nil {
129 return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
130 }
131 content = text
132 }
133
134 case "markdown":
135 if strings.Contains(contentType, "text/html") {
136 markdown, err := convertHTMLToMarkdown(content)
137 if err != nil {
138 return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
139 }
140 content = markdown
141 }
142
143 content = "```\n" + content + "\n```"
144
145 case "html":
146 // return only the body of the HTML document
147 if strings.Contains(contentType, "text/html") {
148 doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
149 if err != nil {
150 return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
151 }
152 body, err := doc.Find("body").Html()
153 if err != nil {
154 return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
155 }
156 if body == "" {
157 return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
158 }
159 content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
160 }
161 }
162 // truncate content if it exceeds max read size
163 if int64(len(content)) > MaxFetchSize {
164 content = content[:MaxFetchSize]
165 content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxFetchSize)
166 }
167
168 return fantasy.NewTextResponse(content), nil
169 })
170}
171
172func extractTextFromHTML(html string) (string, error) {
173 doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
174 if err != nil {
175 return "", err
176 }
177
178 text := doc.Find("body").Text()
179 text = strings.Join(strings.Fields(text), " ")
180
181 return text, nil
182}
183
184func convertHTMLToMarkdown(html string) (string, error) {
185 converter := md.NewConverter("", true, nil)
186
187 markdown, err := converter.ConvertString(html)
188 if err != nil {
189 return "", err
190 }
191
192 return markdown, nil
193}