fetch.go

  1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"html/template"
  8	"io"
  9	"net/http"
 10	"strings"
 11	"time"
 12	"unicode/utf8"
 13
 14	"charm.land/fantasy"
 15	md "github.com/JohannesKaufmann/html-to-markdown"
 16	"github.com/PuerkitoBio/goquery"
 17	"github.com/charmbracelet/crush/internal/permission"
 18)
 19
 20const (
 21	FetchToolName = "fetch"
 22	MaxFetchSize  = 100 * 1024 // 100KB
 23)
 24
 25//go:embed fetch.md.tpl
 26var fetchDescriptionTmpl []byte
 27
 28var fetchDescriptionTpl = template.Must(
 29	template.New("fetchDescription").
 30		Parse(string(fetchDescriptionTmpl)),
 31)
 32
 33type fetchDescriptionData struct {
 34	GhAvailable    bool
 35	MaxFetchSizeKB int
 36}
 37
 38func fetchDescription() string {
 39	return renderTemplate(fetchDescriptionTpl, fetchDescriptionData{
 40		GhAvailable:    ghAvailable,
 41		MaxFetchSizeKB: MaxFetchSize / 1024,
 42	})
 43}
 44
 45func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
 46	if client == nil {
 47		transport := http.DefaultTransport.(*http.Transport).Clone()
 48		transport.MaxIdleConns = 100
 49		transport.MaxIdleConnsPerHost = 10
 50		transport.IdleConnTimeout = 90 * time.Second
 51
 52		client = &http.Client{
 53			Timeout:   30 * time.Second,
 54			Transport: transport,
 55		}
 56	}
 57
 58	return fantasy.NewParallelAgentTool(
 59		FetchToolName,
 60		fetchDescription(),
 61		func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 62			if params.URL == "" {
 63				return fantasy.NewTextErrorResponse("URL parameter is required"), nil
 64			}
 65
 66			format := strings.ToLower(params.Format)
 67			if format != "text" && format != "markdown" && format != "html" {
 68				return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
 69			}
 70
 71			if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
 72				return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
 73			}
 74
 75			sessionID := GetSessionFromContext(ctx)
 76			if sessionID == "" {
 77				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
 78			}
 79
 80			p, err := permissions.Request(
 81				ctx,
 82				permission.CreatePermissionRequest{
 83					SessionID:   sessionID,
 84					Path:        workingDir,
 85					ToolCallID:  call.ID,
 86					ToolName:    FetchToolName,
 87					Action:      "fetch",
 88					Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
 89					Params:      FetchPermissionsParams(params),
 90				},
 91			)
 92			if err != nil {
 93				return fantasy.ToolResponse{}, err
 94			}
 95			if !p {
 96				return NewPermissionDeniedResponse(), nil
 97			}
 98
 99			// maxFetchTimeoutSeconds is the maximum allowed timeout for fetch requests (2 minutes)
100			const maxFetchTimeoutSeconds = 120
101
102			// Handle timeout with context
103			requestCtx := ctx
104			if params.Timeout > 0 {
105				if params.Timeout > maxFetchTimeoutSeconds {
106					params.Timeout = maxFetchTimeoutSeconds
107				}
108				var cancel context.CancelFunc
109				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
110				defer cancel()
111			}
112
113			req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
114			if err != nil {
115				return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
116			}
117
118			req.Header.Set("User-Agent", "crush/1.0")
119
120			resp, err := client.Do(req)
121			if err != nil {
122				return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
123			}
124			defer resp.Body.Close()
125
126			if resp.StatusCode != http.StatusOK {
127				return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
128			}
129
130			body, err := io.ReadAll(io.LimitReader(resp.Body, MaxFetchSize))
131			if err != nil {
132				return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
133			}
134
135			content := string(body)
136
137			validUTF8 := utf8.ValidString(content)
138			if !validUTF8 {
139				return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
140			}
141			contentType := resp.Header.Get("Content-Type")
142
143			switch format {
144			case "text":
145				if strings.Contains(contentType, "text/html") {
146					text, err := extractTextFromHTML(content)
147					if err != nil {
148						return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
149					}
150					content = text
151				}
152
153			case "markdown":
154				if strings.Contains(contentType, "text/html") {
155					markdown, err := convertHTMLToMarkdown(content)
156					if err != nil {
157						return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
158					}
159					content = markdown
160				}
161
162				content = "```\n" + content + "\n```"
163
164			case "html":
165				// return only the body of the HTML document
166				if strings.Contains(contentType, "text/html") {
167					doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
168					if err != nil {
169						return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
170					}
171					body, err := doc.Find("body").Html()
172					if err != nil {
173						return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
174					}
175					if body == "" {
176						return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
177					}
178					content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
179				}
180			}
181			// truncate content if it exceeds max read size
182			if int64(len(content)) >= MaxFetchSize {
183				content = content[:MaxFetchSize]
184				content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxFetchSize)
185			}
186
187			return fantasy.NewTextResponse(content), nil
188		},
189	)
190}
191
192func extractTextFromHTML(html string) (string, error) {
193	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
194	if err != nil {
195		return "", err
196	}
197
198	text := doc.Find("body").Text()
199	text = strings.Join(strings.Fields(text), " ")
200
201	return text, nil
202}
203
204func convertHTMLToMarkdown(html string) (string, error) {
205	converter := md.NewConverter("", true, nil)
206
207	markdown, err := converter.ConvertString(html)
208	if err != nil {
209		return "", err
210	}
211
212	return markdown, nil
213}