fetch.go

  1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"strings"
 10	"time"
 11	"unicode/utf8"
 12
 13	"charm.land/fantasy"
 14	md "github.com/JohannesKaufmann/html-to-markdown"
 15	"github.com/PuerkitoBio/goquery"
 16	"github.com/charmbracelet/crush/internal/permission"
 17)
 18
// FetchToolName is the tool name passed to fantasy.NewParallelAgentTool and
// reported as ToolName in the permission requests made by the fetch tool.
 19const FetchToolName = "fetch"
 20
// fetchDescription holds the contents of fetch.md, embedded by the go:embed
// directive below; NewFetchTool passes it (as a string) as the tool description.
 21//go:embed fetch.md
 22var fetchDescription []byte
 23
 24func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
 25	if client == nil {
 26		client = &http.Client{
 27			Timeout: 30 * time.Second,
 28			Transport: &http.Transport{
 29				MaxIdleConns:        100,
 30				MaxIdleConnsPerHost: 10,
 31				IdleConnTimeout:     90 * time.Second,
 32			},
 33		}
 34	}
 35
 36	return fantasy.NewParallelAgentTool(
 37		FetchToolName,
 38		string(fetchDescription),
 39		func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 40			if params.URL == "" {
 41				return fantasy.NewTextErrorResponse("URL parameter is required"), nil
 42			}
 43
 44			format := strings.ToLower(params.Format)
 45			if format != "text" && format != "markdown" && format != "html" {
 46				return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
 47			}
 48
 49			if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
 50				return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
 51			}
 52
 53			sessionID := GetSessionFromContext(ctx)
 54			if sessionID == "" {
 55				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
 56			}
 57
 58			p, err := permissions.Request(ctx,
 59				permission.CreatePermissionRequest{
 60					SessionID:   sessionID,
 61					Path:        workingDir,
 62					ToolCallID:  call.ID,
 63					ToolName:    FetchToolName,
 64					Action:      "fetch",
 65					Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
 66					Params:      FetchPermissionsParams(params),
 67				},
 68			)
 69			if err != nil {
 70				return fantasy.ToolResponse{}, err
 71			}
 72			if !p.Granted {
 73				if p.Message != "" {
 74					return fantasy.NewTextErrorResponse("User denied permission." + permission.UserCommentaryTag(p.Message)), nil
 75				}
 76				return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
 77			}
 78
 79			// maxFetchTimeoutSeconds is the maximum allowed timeout for fetch requests (2 minutes)
 80			const maxFetchTimeoutSeconds = 120
 81
 82			// Handle timeout with context
 83			requestCtx := ctx
 84			if params.Timeout > 0 {
 85				if params.Timeout > maxFetchTimeoutSeconds {
 86					params.Timeout = maxFetchTimeoutSeconds
 87				}
 88				var cancel context.CancelFunc
 89				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
 90				defer cancel()
 91			}
 92
 93			req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
 94			if err != nil {
 95				return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
 96			}
 97
 98			req.Header.Set("User-Agent", "crush/1.0")
 99
100			resp, err := client.Do(req)
101			if err != nil {
102				return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
103			}
104			defer resp.Body.Close()
105
106			if resp.StatusCode != http.StatusOK {
107				return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
108			}
109
110			// maxFetchResponseSizeBytes is the maximum size of response body to read (5MB)
111			const maxFetchResponseSizeBytes = int64(5 * 1024 * 1024)
112
113			maxSize := maxFetchResponseSizeBytes
114			body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
115			if err != nil {
116				return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
117			}
118
119			content := string(body)
120
121			validUTF8 := utf8.ValidString(content)
122			if !validUTF8 {
123				return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
124			}
125			contentType := resp.Header.Get("Content-Type")
126
127			switch format {
128			case "text":
129				if strings.Contains(contentType, "text/html") {
130					text, err := extractTextFromHTML(content)
131					if err != nil {
132						return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
133					}
134					content = text
135				}
136
137			case "markdown":
138				if strings.Contains(contentType, "text/html") {
139					markdown, err := convertHTMLToMarkdown(content)
140					if err != nil {
141						return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
142					}
143					content = markdown
144				}
145
146				content = "```\n" + content + "\n```"
147
148			case "html":
149				// return only the body of the HTML document
150				if strings.Contains(contentType, "text/html") {
151					doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
152					if err != nil {
153						return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
154					}
155					body, err := doc.Find("body").Html()
156					if err != nil {
157						return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
158					}
159					if body == "" {
160						return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
161					}
162					content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
163				}
164			}
165			// truncate content if it exceeds max read size
166			if int64(len(content)) > MaxReadSize {
167				content = content[:MaxReadSize]
168				content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
169			}
170
171			return fantasy.NewTextResponse(p.AppendCommentary(content)), nil
172		})
173}
174
175func extractTextFromHTML(html string) (string, error) {
176	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
177	if err != nil {
178		return "", err
179	}
180
181	text := doc.Find("body").Text()
182	text = strings.Join(strings.Fields(text), " ")
183
184	return text, nil
185}
186
187func convertHTMLToMarkdown(html string) (string, error) {
188	converter := md.NewConverter("", true, nil)
189
190	markdown, err := converter.ConvertString(html)
191	if err != nil {
192		return "", err
193	}
194
195	return markdown, nil
196}