fetch.go

  1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"html/template"
  8	"io"
  9	"net/http"
 10	"strings"
 11	"time"
 12	"unicode/utf8"
 13
 14	"charm.land/fantasy"
 15	md "github.com/JohannesKaufmann/html-to-markdown"
 16	"github.com/PuerkitoBio/goquery"
 17	"github.com/charmbracelet/crush/internal/permission"
 18)
 19
 20const (
 21	FetchToolName = "fetch"
 22	MaxFetchSize  = 100 * 1024 // 100KB
 23)
 24
 25//go:embed fetch.md.tpl
 26var fetchDescriptionTmpl []byte
 27
 28var fetchDescriptionTpl = template.Must(
 29	template.New("fetchDescription").
 30		Parse(string(fetchDescriptionTmpl)),
 31)
 32
 33func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
 34	if client == nil {
 35		transport := http.DefaultTransport.(*http.Transport).Clone()
 36		transport.MaxIdleConns = 100
 37		transport.MaxIdleConnsPerHost = 10
 38		transport.IdleConnTimeout = 90 * time.Second
 39
 40		client = &http.Client{
 41			Timeout:   30 * time.Second,
 42			Transport: transport,
 43		}
 44	}
 45
 46	return fantasy.NewParallelAgentTool(
 47		FetchToolName,
 48		renderToolDescription(fetchDescriptionTpl),
 49		func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 50			if params.URL == "" {
 51				return fantasy.NewTextErrorResponse("URL parameter is required"), nil
 52			}
 53
 54			format := strings.ToLower(params.Format)
 55			if format != "text" && format != "markdown" && format != "html" {
 56				return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
 57			}
 58
 59			if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
 60				return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
 61			}
 62
 63			sessionID := GetSessionFromContext(ctx)
 64			if sessionID == "" {
 65				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
 66			}
 67
 68			p, err := permissions.Request(ctx,
 69				permission.CreatePermissionRequest{
 70					SessionID:   sessionID,
 71					Path:        workingDir,
 72					ToolCallID:  call.ID,
 73					ToolName:    FetchToolName,
 74					Action:      "fetch",
 75					Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
 76					Params:      FetchPermissionsParams(params),
 77				},
 78			)
 79			if err != nil {
 80				return fantasy.ToolResponse{}, err
 81			}
 82			if !p {
 83				return NewPermissionDeniedResponse(), nil
 84			}
 85
 86			// maxFetchTimeoutSeconds is the maximum allowed timeout for fetch requests (2 minutes)
 87			const maxFetchTimeoutSeconds = 120
 88
 89			// Handle timeout with context
 90			requestCtx := ctx
 91			if params.Timeout > 0 {
 92				if params.Timeout > maxFetchTimeoutSeconds {
 93					params.Timeout = maxFetchTimeoutSeconds
 94				}
 95				var cancel context.CancelFunc
 96				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
 97				defer cancel()
 98			}
 99
100			req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
101			if err != nil {
102				return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
103			}
104
105			req.Header.Set("User-Agent", "crush/1.0")
106
107			resp, err := client.Do(req)
108			if err != nil {
109				return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
110			}
111			defer resp.Body.Close()
112
113			if resp.StatusCode != http.StatusOK {
114				return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
115			}
116
117			body, err := io.ReadAll(io.LimitReader(resp.Body, MaxFetchSize))
118			if err != nil {
119				return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
120			}
121
122			content := string(body)
123
124			validUTF8 := utf8.ValidString(content)
125			if !validUTF8 {
126				return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
127			}
128			contentType := resp.Header.Get("Content-Type")
129
130			switch format {
131			case "text":
132				if strings.Contains(contentType, "text/html") {
133					text, err := extractTextFromHTML(content)
134					if err != nil {
135						return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
136					}
137					content = text
138				}
139
140			case "markdown":
141				if strings.Contains(contentType, "text/html") {
142					markdown, err := convertHTMLToMarkdown(content)
143					if err != nil {
144						return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
145					}
146					content = markdown
147				}
148
149				content = "```\n" + content + "\n```"
150
151			case "html":
152				// return only the body of the HTML document
153				if strings.Contains(contentType, "text/html") {
154					doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
155					if err != nil {
156						return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
157					}
158					body, err := doc.Find("body").Html()
159					if err != nil {
160						return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
161					}
162					if body == "" {
163						return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
164					}
165					content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
166				}
167			}
168			// truncate content if it exceeds max read size
169			if int64(len(content)) >= MaxFetchSize {
170				content = content[:MaxFetchSize]
171				content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxFetchSize)
172			}
173
174			return fantasy.NewTextResponse(content), nil
175		})
176}
177
178func extractTextFromHTML(html string) (string, error) {
179	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
180	if err != nil {
181		return "", err
182	}
183
184	text := doc.Find("body").Text()
185	text = strings.Join(strings.Fields(text), " ")
186
187	return text, nil
188}
189
190func convertHTMLToMarkdown(html string) (string, error) {
191	converter := md.NewConverter("", true, nil)
192
193	markdown, err := converter.ConvertString(html)
194	if err != nil {
195		return "", err
196	}
197
198	return markdown, nil
199}