1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10 "time"
11 "unicode/utf8"
12
13 "charm.land/fantasy"
14 md "github.com/JohannesKaufmann/html-to-markdown"
15 "github.com/PuerkitoBio/goquery"
16 "github.com/charmbracelet/crush/internal/permission"
17)
18
// fetchTool groups an HTTP client, a permission service and a working
// directory — the same three values NewFetchTool receives as parameters.
//
// NOTE(review): nothing in the visible part of this file constructs or
// references this type; NewFetchTool captures those values in a closure
// instead. Confirm it is still needed elsewhere in the package.
type fetchTool struct {
	// client is an HTTP client; presumably the one used to perform fetches.
	client *http.Client
	// permissions is the permission service; presumably consulted before a fetch.
	permissions permission.Service
	// workingDir is a directory path; presumably the session's working directory.
	workingDir string
}
24
// FetchToolName is the name the fetch tool is registered under. It is also
// reported as the ToolName of the permission requests the tool issues.
const FetchToolName = "fetch"
26
// fetchDescription holds the contents of fetch.md, embedded at build time.
// NewFetchTool passes it, converted to a string, as the tool's description.
//
//go:embed fetch.md
var fetchDescription []byte
29
30func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
31 if client == nil {
32 client = &http.Client{
33 Timeout: 30 * time.Second,
34 Transport: &http.Transport{
35 MaxIdleConns: 100,
36 MaxIdleConnsPerHost: 10,
37 IdleConnTimeout: 90 * time.Second,
38 },
39 }
40 }
41
42 return fantasy.NewAgentTool(
43 FetchToolName,
44 string(fetchDescription),
45 func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
46 if params.URL == "" {
47 return fantasy.NewTextErrorResponse("URL parameter is required"), nil
48 }
49
50 format := strings.ToLower(params.Format)
51 if format != "text" && format != "markdown" && format != "html" {
52 return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
53 }
54
55 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
56 return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
57 }
58
59 sessionID := GetSessionFromContext(ctx)
60 if sessionID == "" {
61 return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
62 }
63
64 p := permissions.Request(
65 permission.CreatePermissionRequest{
66 SessionID: sessionID,
67 Path: workingDir,
68 ToolCallID: call.ID,
69 ToolName: FetchToolName,
70 Action: "fetch",
71 Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
72 Params: FetchPermissionsParams(params),
73 },
74 )
75
76 if !p {
77 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
78 }
79
80 // Handle timeout with context
81 requestCtx := ctx
82 if params.Timeout > 0 {
83 maxTimeout := 120 // 2 minutes
84 if params.Timeout > maxTimeout {
85 params.Timeout = maxTimeout
86 }
87 var cancel context.CancelFunc
88 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
89 defer cancel()
90 }
91
92 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
93 if err != nil {
94 return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
95 }
96
97 req.Header.Set("User-Agent", "crush/1.0")
98
99 resp, err := client.Do(req)
100 if err != nil {
101 return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
102 }
103 defer resp.Body.Close()
104
105 if resp.StatusCode != http.StatusOK {
106 return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
107 }
108
109 maxSize := int64(5 * 1024 * 1024) // 5MB
110 body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
111 if err != nil {
112 return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
113 }
114
115 content := string(body)
116
117 isValidUt8 := utf8.ValidString(content)
118 if !isValidUt8 {
119 return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
120 }
121 contentType := resp.Header.Get("Content-Type")
122
123 switch format {
124 case "text":
125 if strings.Contains(contentType, "text/html") {
126 text, err := extractTextFromHTML(content)
127 if err != nil {
128 return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
129 }
130 content = text
131 }
132
133 case "markdown":
134 if strings.Contains(contentType, "text/html") {
135 markdown, err := convertHTMLToMarkdown(content)
136 if err != nil {
137 return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
138 }
139 content = markdown
140 }
141
142 content = "```\n" + content + "\n```"
143
144 case "html":
145 // return only the body of the HTML document
146 if strings.Contains(contentType, "text/html") {
147 doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
148 if err != nil {
149 return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
150 }
151 body, err := doc.Find("body").Html()
152 if err != nil {
153 return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
154 }
155 if body == "" {
156 return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
157 }
158 content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
159 }
160 }
161 // calculate byte size of content
162 contentSize := int64(len(content))
163 if contentSize > MaxReadSize {
164 content = content[:MaxReadSize]
165 content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
166 }
167
168 return fantasy.NewTextResponse(content), nil
169 })
170}
171
172func extractTextFromHTML(html string) (string, error) {
173 doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
174 if err != nil {
175 return "", err
176 }
177
178 text := doc.Find("body").Text()
179 text = strings.Join(strings.Fields(text), " ")
180
181 return text, nil
182}
183
184func convertHTMLToMarkdown(html string) (string, error) {
185 converter := md.NewConverter("", true, nil)
186
187 markdown, err := converter.ConvertString(html)
188 if err != nil {
189 return "", err
190 }
191
192 return markdown, nil
193}