fetch.go

  1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"strings"
 10	"time"
 11	"unicode/utf8"
 12
 13	"charm.land/fantasy"
 14	md "github.com/JohannesKaufmann/html-to-markdown"
 15	"github.com/PuerkitoBio/goquery"
 16	"github.com/charmbracelet/crush/internal/permission"
 17)
 18
// FetchToolName is the name the fetch tool is created under; it is also
// reported as ToolName in the permission requests the tool issues.
const FetchToolName = "fetch"

// fetchDescription is the tool description passed to
// fantasy.NewParallelAgentTool, embedded from fetch.md at build time.
//
//go:embed fetch.md
var fetchDescription []byte
 23
 24func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
 25	if client == nil {
 26		transport := http.DefaultTransport.(*http.Transport).Clone()
 27		transport.MaxIdleConns = 100
 28		transport.MaxIdleConnsPerHost = 10
 29		transport.IdleConnTimeout = 90 * time.Second
 30
 31		client = &http.Client{
 32			Timeout:   30 * time.Second,
 33			Transport: transport,
 34		}
 35	}
 36
 37	return fantasy.NewParallelAgentTool(
 38		FetchToolName,
 39		string(fetchDescription),
 40		func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 41			if params.URL == "" {
 42				return fantasy.NewTextErrorResponse("URL parameter is required"), nil
 43			}
 44
 45			format := strings.ToLower(params.Format)
 46			if format != "text" && format != "markdown" && format != "html" {
 47				return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
 48			}
 49
 50			if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
 51				return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
 52			}
 53
 54			sessionID := GetSessionFromContext(ctx)
 55			if sessionID == "" {
 56				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
 57			}
 58
 59			p, err := permissions.Request(ctx,
 60				permission.CreatePermissionRequest{
 61					SessionID:   sessionID,
 62					Path:        workingDir,
 63					ToolCallID:  call.ID,
 64					ToolName:    FetchToolName,
 65					Action:      "fetch",
 66					Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
 67					Params:      FetchPermissionsParams(params),
 68				},
 69			)
 70			if err != nil {
 71				return fantasy.ToolResponse{}, err
 72			}
 73			if !p {
 74				return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
 75			}
 76
 77			// maxFetchTimeoutSeconds is the maximum allowed timeout for fetch requests (2 minutes)
 78			const maxFetchTimeoutSeconds = 120
 79
 80			// Handle timeout with context
 81			requestCtx := ctx
 82			if params.Timeout > 0 {
 83				if params.Timeout > maxFetchTimeoutSeconds {
 84					params.Timeout = maxFetchTimeoutSeconds
 85				}
 86				var cancel context.CancelFunc
 87				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
 88				defer cancel()
 89			}
 90
 91			req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
 92			if err != nil {
 93				return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
 94			}
 95
 96			req.Header.Set("User-Agent", "crush/1.0")
 97
 98			resp, err := client.Do(req)
 99			if err != nil {
100				return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
101			}
102			defer resp.Body.Close()
103
104			if resp.StatusCode != http.StatusOK {
105				return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
106			}
107
108			// maxFetchResponseSizeBytes is the maximum size of response body to read (5MB)
109			const maxFetchResponseSizeBytes = int64(5 * 1024 * 1024)
110
111			maxSize := maxFetchResponseSizeBytes
112			body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
113			if err != nil {
114				return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
115			}
116
117			content := string(body)
118
119			validUTF8 := utf8.ValidString(content)
120			if !validUTF8 {
121				return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
122			}
123			contentType := resp.Header.Get("Content-Type")
124
125			switch format {
126			case "text":
127				if strings.Contains(contentType, "text/html") {
128					text, err := extractTextFromHTML(content)
129					if err != nil {
130						return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
131					}
132					content = text
133				}
134
135			case "markdown":
136				if strings.Contains(contentType, "text/html") {
137					markdown, err := convertHTMLToMarkdown(content)
138					if err != nil {
139						return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
140					}
141					content = markdown
142				}
143
144				content = "```\n" + content + "\n```"
145
146			case "html":
147				// return only the body of the HTML document
148				if strings.Contains(contentType, "text/html") {
149					doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
150					if err != nil {
151						return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
152					}
153					body, err := doc.Find("body").Html()
154					if err != nil {
155						return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
156					}
157					if body == "" {
158						return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
159					}
160					content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
161				}
162			}
163			// truncate content if it exceeds max read size
164			if int64(len(content)) > MaxReadSize {
165				content = content[:MaxReadSize]
166				content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
167			}
168
169			return fantasy.NewTextResponse(content), nil
170		})
171}
172
173func extractTextFromHTML(html string) (string, error) {
174	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
175	if err != nil {
176		return "", err
177	}
178
179	text := doc.Find("body").Text()
180	text = strings.Join(strings.Fields(text), " ")
181
182	return text, nil
183}
184
185func convertHTMLToMarkdown(html string) (string, error) {
186	converter := md.NewConverter("", true, nil)
187
188	markdown, err := converter.ConvertString(html)
189	if err != nil {
190		return "", err
191	}
192
193	return markdown, nil
194}