1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"strings"
 10	"time"
 11	"unicode/utf8"
 12
 13	md "github.com/JohannesKaufmann/html-to-markdown"
 14	"github.com/PuerkitoBio/goquery"
 15	"github.com/charmbracelet/crush/internal/permission"
 16)
 17
// FetchParams is the JSON input of a fetch tool call, decoded by Run.
// URL must begin with http:// or https://. Format must be "text",
// "markdown", or "html" (matched case-insensitively). Timeout is an optional
// per-request limit in seconds; Run clamps it to 120, and values <= 0 apply
// no per-request timeout.
type FetchParams struct {
	URL     string `json:"url"`
	Format  string `json:"format"`
	Timeout int    `json:"timeout,omitempty"`
}
 23
// FetchPermissionsParams carries the fetch arguments attached to the
// permission request raised by Run. Run builds it with a direct type
// conversion, FetchPermissionsParams(params). Its fields must therefore stay
// identical (names, types, and order) to those of FetchParams.
type FetchPermissionsParams struct {
	URL     string `json:"url"`
	Format  string `json:"format"`
	Timeout int    `json:"timeout,omitempty"`
}
 29
// fetchTool implements the "fetch" tool. After obtaining permission, it
// downloads a URL over HTTP(S) and returns the content as text, Markdown,
// or HTML.
type fetchTool struct {
	client      *http.Client       // shared HTTP client used for every fetch
	permissions permission.Service // asked to approve each fetch before it runs
	workingDir  string             // sent as the Path of each permission request
}
 35
const (
	// FetchToolName is the identifier under which the fetch tool is exposed
	// (returned by Name and used as ToolInfo.Name).
	FetchToolName = "fetch"
	// fetchToolDescription is the usage text returned by Info as the tool's
	// Description (presumably what the model reads to decide how to call it).
	// NOTE(review): Run's "html" format returns only the document <body>,
	// re-wrapped in a minimal <html><body> skeleton, not the full raw HTML
	// that the TIPS section below suggests.
	fetchToolDescription = `Fetches content from a URL and returns it in the specified format.

WHEN TO USE THIS TOOL:
- Use when you need to download content from a URL
- Helpful for retrieving documentation, API responses, or web content
- Useful for getting external information to assist with tasks

HOW TO USE:
- Provide the URL to fetch content from
- Specify the desired output format (text, markdown, or html)
- Optionally set a timeout for the request

FEATURES:
- Supports three output formats: text, markdown, and html
- Automatically handles HTTP redirects
- Sets reasonable timeouts to prevent hanging
- Validates input parameters before making requests

LIMITATIONS:
- Maximum response size is 5MB
- Only supports HTTP and HTTPS protocols
- Cannot handle authentication or cookies
- Some websites may block automated requests

TIPS:
- Use text format for plain text content or simple API responses
- Use markdown format for content that should be rendered with formatting
- Use html format when you need the raw HTML structure
- Set appropriate timeouts for potentially slow websites`
)
 68
 69func NewFetchTool(permissions permission.Service, workingDir string) BaseTool {
 70	return &fetchTool{
 71		client: &http.Client{
 72			Timeout: 30 * time.Second,
 73			Transport: &http.Transport{
 74				MaxIdleConns:        100,
 75				MaxIdleConnsPerHost: 10,
 76				IdleConnTimeout:     90 * time.Second,
 77			},
 78		},
 79		permissions: permissions,
 80		workingDir:  workingDir,
 81	}
 82}
 83
 84func (t *fetchTool) Name() string {
 85	return FetchToolName
 86}
 87
 88func (t *fetchTool) Info() ToolInfo {
 89	return ToolInfo{
 90		Name:        FetchToolName,
 91		Description: fetchToolDescription,
 92		Parameters: map[string]any{
 93			"url": map[string]any{
 94				"type":        "string",
 95				"description": "The URL to fetch content from",
 96			},
 97			"format": map[string]any{
 98				"type":        "string",
 99				"description": "The format to return the content in (text, markdown, or html)",
100				"enum":        []string{"text", "markdown", "html"},
101			},
102			"timeout": map[string]any{
103				"type":        "number",
104				"description": "Optional timeout in seconds (max 120)",
105			},
106		},
107		Required: []string{"url", "format"},
108	}
109}
110
111func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
112	var params FetchParams
113	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
114		return NewTextErrorResponse("Failed to parse fetch parameters: " + err.Error()), nil
115	}
116
117	if params.URL == "" {
118		return NewTextErrorResponse("URL parameter is required"), nil
119	}
120
121	format := strings.ToLower(params.Format)
122	if format != "text" && format != "markdown" && format != "html" {
123		return NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
124	}
125
126	if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
127		return NewTextErrorResponse("URL must start with http:// or https://"), nil
128	}
129
130	sessionID, messageID := GetContextValues(ctx)
131	if sessionID == "" || messageID == "" {
132		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
133	}
134
135	p := t.permissions.Request(
136		permission.CreatePermissionRequest{
137			SessionID:   sessionID,
138			Path:        t.workingDir,
139			ToolCallID:  call.ID,
140			ToolName:    FetchToolName,
141			Action:      "fetch",
142			Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
143			Params:      FetchPermissionsParams(params),
144		},
145	)
146
147	if !p {
148		return ToolResponse{}, permission.ErrorPermissionDenied
149	}
150
151	// Handle timeout with context
152	requestCtx := ctx
153	if params.Timeout > 0 {
154		maxTimeout := 120 // 2 minutes
155		if params.Timeout > maxTimeout {
156			params.Timeout = maxTimeout
157		}
158		var cancel context.CancelFunc
159		requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
160		defer cancel()
161	}
162
163	req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
164	if err != nil {
165		return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
166	}
167
168	req.Header.Set("User-Agent", "crush/1.0")
169
170	resp, err := t.client.Do(req)
171	if err != nil {
172		return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
173	}
174	defer resp.Body.Close()
175
176	if resp.StatusCode != http.StatusOK {
177		return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
178	}
179
180	maxSize := int64(5 * 1024 * 1024) // 5MB
181	body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
182	if err != nil {
183		return NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
184	}
185
186	content := string(body)
187
188	isValidUt8 := utf8.ValidString(content)
189	if !isValidUt8 {
190		return NewTextErrorResponse("Response content is not valid UTF-8"), nil
191	}
192	contentType := resp.Header.Get("Content-Type")
193
194	switch format {
195	case "text":
196		if strings.Contains(contentType, "text/html") {
197			text, err := extractTextFromHTML(content)
198			if err != nil {
199				return NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
200			}
201			content = text
202		}
203
204	case "markdown":
205		if strings.Contains(contentType, "text/html") {
206			markdown, err := convertHTMLToMarkdown(content)
207			if err != nil {
208				return NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
209			}
210			content = markdown
211		}
212
213		content = "```\n" + content + "\n```"
214
215	case "html":
216		// return only the body of the HTML document
217		if strings.Contains(contentType, "text/html") {
218			doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
219			if err != nil {
220				return NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
221			}
222			body, err := doc.Find("body").Html()
223			if err != nil {
224				return NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
225			}
226			if body == "" {
227				return NewTextErrorResponse("No body content found in HTML"), nil
228			}
229			content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
230		}
231	}
232	// calculate byte size of content
233	contentSize := int64(len(content))
234	if contentSize > MaxReadSize {
235		content = content[:MaxReadSize]
236		content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
237	}
238
239	return NewTextResponse(content), nil
240}
241
242func extractTextFromHTML(html string) (string, error) {
243	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
244	if err != nil {
245		return "", err
246	}
247
248	text := doc.Find("body").Text()
249	text = strings.Join(strings.Fields(text), " ")
250
251	return text, nil
252}
253
254func convertHTMLToMarkdown(html string) (string, error) {
255	converter := md.NewConverter("", true, nil)
256
257	markdown, err := converter.ConvertString(html)
258	if err != nil {
259		return "", err
260	}
261
262	return markdown, nil
263}