fetch.go

  1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"strings"
 10	"time"
 11	"unicode/utf8"
 12
 13	"charm.land/fantasy"
 14	md "github.com/JohannesKaufmann/html-to-markdown"
 15	"github.com/PuerkitoBio/goquery"
 16	"github.com/charmbracelet/crush/internal/permission"
 17)
 18
 19type FetchParams struct {
 20	URL     string `json:"url" description:"The URL to fetch content from"`
 21	Format  string `json:"format" description:"The format to return the content in (text, markdown, or html)"`
 22	Timeout int    `json:"timeout,omitempty" description:"Optional timeout in seconds (max 120)"`
 23}
 24
 25type FetchPermissionsParams struct {
 26	URL     string `json:"url"`
 27	Format  string `json:"format"`
 28	Timeout int    `json:"timeout,omitempty"`
 29}
 30
 31const FetchToolName = "fetch"
 32
 33//go:embed fetch.md
 34var fetchDescription []byte
 35
 36func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
 37	if client == nil {
 38		client = &http.Client{
 39			Timeout: 30 * time.Second,
 40			Transport: &http.Transport{
 41				MaxIdleConns:        100,
 42				MaxIdleConnsPerHost: 10,
 43				IdleConnTimeout:     90 * time.Second,
 44			},
 45		}
 46	}
 47
 48	return fantasy.NewAgentTool(
 49		FetchToolName,
 50		string(fetchDescription),
 51		func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 52			if params.URL == "" {
 53				return fantasy.NewTextErrorResponse("URL parameter is required"), nil
 54			}
 55
 56			format := strings.ToLower(params.Format)
 57			if format != "text" && format != "markdown" && format != "html" {
 58				return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
 59			}
 60
 61			if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
 62				return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
 63			}
 64
 65			sessionID := GetSessionFromContext(ctx)
 66			if sessionID == "" {
 67				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
 68			}
 69
 70			p := permissions.Request(
 71				permission.CreatePermissionRequest{
 72					SessionID:   sessionID,
 73					Path:        workingDir,
 74					ToolCallID:  call.ID,
 75					ToolName:    FetchToolName,
 76					Action:      "fetch",
 77					Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
 78					Params:      FetchPermissionsParams(params),
 79				},
 80			)
 81
 82			if !p {
 83				return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
 84			}
 85
 86			// Handle timeout with context
 87			requestCtx := ctx
 88			if params.Timeout > 0 {
 89				maxTimeout := 120 // 2 minutes
 90				if params.Timeout > maxTimeout {
 91					params.Timeout = maxTimeout
 92				}
 93				var cancel context.CancelFunc
 94				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
 95				defer cancel()
 96			}
 97
 98			req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
 99			if err != nil {
100				return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
101			}
102
103			req.Header.Set("User-Agent", "crush/1.0")
104
105			resp, err := client.Do(req)
106			if err != nil {
107				return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
108			}
109			defer resp.Body.Close()
110
111			if resp.StatusCode != http.StatusOK {
112				return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
113			}
114
115			maxSize := int64(5 * 1024 * 1024) // 5MB
116			body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
117			if err != nil {
118				return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
119			}
120
121			content := string(body)
122
123			isValidUt8 := utf8.ValidString(content)
124			if !isValidUt8 {
125				return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
126			}
127			contentType := resp.Header.Get("Content-Type")
128
129			switch format {
130			case "text":
131				if strings.Contains(contentType, "text/html") {
132					text, err := extractTextFromHTML(content)
133					if err != nil {
134						return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
135					}
136					content = text
137				}
138
139			case "markdown":
140				if strings.Contains(contentType, "text/html") {
141					markdown, err := convertHTMLToMarkdown(content)
142					if err != nil {
143						return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
144					}
145					content = markdown
146				}
147
148				content = "```\n" + content + "\n```"
149
150			case "html":
151				// return only the body of the HTML document
152				if strings.Contains(contentType, "text/html") {
153					doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
154					if err != nil {
155						return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
156					}
157					body, err := doc.Find("body").Html()
158					if err != nil {
159						return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
160					}
161					if body == "" {
162						return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
163					}
164					content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
165				}
166			}
167			// calculate byte size of content
168			contentSize := int64(len(content))
169			if contentSize > MaxReadSize {
170				content = content[:MaxReadSize]
171				content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
172			}
173
174			return fantasy.NewTextResponse(content), nil
175		})
176}
177
178func extractTextFromHTML(html string) (string, error) {
179	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
180	if err != nil {
181		return "", err
182	}
183
184	text := doc.Find("body").Text()
185	text = strings.Join(strings.Fields(text), " ")
186
187	return text, nil
188}
189
190func convertHTMLToMarkdown(html string) (string, error) {
191	converter := md.NewConverter("", true, nil)
192
193	markdown, err := converter.ConvertString(html)
194	if err != nil {
195		return "", err
196	}
197
198	return markdown, nil
199}