fetch.go

  1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"strings"
 10	"time"
 11	"unicode/utf8"
 12
 13	"charm.land/fantasy"
 14	md "github.com/JohannesKaufmann/html-to-markdown"
 15	"github.com/PuerkitoBio/goquery"
 16	"github.com/charmbracelet/crush/internal/permission"
 17)
 18
 19const (
 20	FetchToolName = "fetch"
 21	MaxFetchSize  = 1 * 1024 * 1024 // 1MB
 22)
 23
// fetchDescription holds the contents of fetch.md, embedded at build time.
// It is passed (as a string) to fantasy as the fetch tool's description.
//
//go:embed fetch.md
var fetchDescription []byte
 26
 27func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
 28	if client == nil {
 29		transport := http.DefaultTransport.(*http.Transport).Clone()
 30		transport.MaxIdleConns = 100
 31		transport.MaxIdleConnsPerHost = 10
 32		transport.IdleConnTimeout = 90 * time.Second
 33
 34		client = &http.Client{
 35			Timeout:   30 * time.Second,
 36			Transport: transport,
 37		}
 38	}
 39
 40	return fantasy.NewParallelAgentTool(
 41		FetchToolName,
 42		string(fetchDescription),
 43		func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 44			if params.URL == "" {
 45				return fantasy.NewTextErrorResponse("URL parameter is required"), nil
 46			}
 47
 48			format := strings.ToLower(params.Format)
 49			if format != "text" && format != "markdown" && format != "html" {
 50				return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
 51			}
 52
 53			if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
 54				return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
 55			}
 56
 57			sessionID := GetSessionFromContext(ctx)
 58			if sessionID == "" {
 59				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
 60			}
 61
 62			p, err := permissions.Request(ctx,
 63				permission.CreatePermissionRequest{
 64					SessionID:   sessionID,
 65					Path:        workingDir,
 66					ToolCallID:  call.ID,
 67					ToolName:    FetchToolName,
 68					Action:      "fetch",
 69					Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
 70					Params:      FetchPermissionsParams(params),
 71				},
 72			)
 73			if err != nil {
 74				return fantasy.ToolResponse{}, err
 75			}
 76			if !p {
 77				return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
 78			}
 79
 80			// maxFetchTimeoutSeconds is the maximum allowed timeout for fetch requests (2 minutes)
 81			const maxFetchTimeoutSeconds = 120
 82
 83			// Handle timeout with context
 84			requestCtx := ctx
 85			if params.Timeout > 0 {
 86				if params.Timeout > maxFetchTimeoutSeconds {
 87					params.Timeout = maxFetchTimeoutSeconds
 88				}
 89				var cancel context.CancelFunc
 90				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
 91				defer cancel()
 92			}
 93
 94			req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
 95			if err != nil {
 96				return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
 97			}
 98
 99			req.Header.Set("User-Agent", "crush/1.0")
100
101			resp, err := client.Do(req)
102			if err != nil {
103				return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
104			}
105			defer resp.Body.Close()
106
107			if resp.StatusCode != http.StatusOK {
108				return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
109			}
110
111			body, err := io.ReadAll(io.LimitReader(resp.Body, MaxFetchSize))
112			if err != nil {
113				return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
114			}
115
116			content := string(body)
117
118			validUTF8 := utf8.ValidString(content)
119			if !validUTF8 {
120				return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
121			}
122			contentType := resp.Header.Get("Content-Type")
123
124			switch format {
125			case "text":
126				if strings.Contains(contentType, "text/html") {
127					text, err := extractTextFromHTML(content)
128					if err != nil {
129						return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
130					}
131					content = text
132				}
133
134			case "markdown":
135				if strings.Contains(contentType, "text/html") {
136					markdown, err := convertHTMLToMarkdown(content)
137					if err != nil {
138						return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
139					}
140					content = markdown
141				}
142
143				content = "```\n" + content + "\n```"
144
145			case "html":
146				// return only the body of the HTML document
147				if strings.Contains(contentType, "text/html") {
148					doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
149					if err != nil {
150						return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
151					}
152					body, err := doc.Find("body").Html()
153					if err != nil {
154						return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
155					}
156					if body == "" {
157						return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
158					}
159					content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
160				}
161			}
162			// truncate content if it exceeds max read size
163			if int64(len(content)) > MaxFetchSize {
164				content = content[:MaxFetchSize]
165				content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxFetchSize)
166			}
167
168			return fantasy.NewTextResponse(content), nil
169		})
170}
171
172func extractTextFromHTML(html string) (string, error) {
173	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
174	if err != nil {
175		return "", err
176	}
177
178	text := doc.Find("body").Text()
179	text = strings.Join(strings.Fields(text), " ")
180
181	return text, nil
182}
183
184func convertHTMLToMarkdown(html string) (string, error) {
185	converter := md.NewConverter("", true, nil)
186
187	markdown, err := converter.ConvertString(html)
188	if err != nil {
189		return "", err
190	}
191
192	return markdown, nil
193}