fetch.go

  1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"strings"
 10	"time"
 11	"unicode/utf8"
 12
 13	"charm.land/fantasy"
 14	md "github.com/JohannesKaufmann/html-to-markdown"
 15	"github.com/PuerkitoBio/goquery"
 16	"github.com/charmbracelet/crush/internal/permission"
 17)
 18
// fetchTool holds an HTTP client, a permission service, and a working
// directory.
//
// NOTE(review): the fields appear to mirror the parameters of NewFetchTool,
// but NewFetchTool does not construct this type; confirm it is still
// referenced elsewhere in the package or remove it.
type fetchTool struct {
	client      *http.Client
	permissions permission.Service
	workingDir  string
}
 24
// FetchToolName is the name the fetch tool is registered under with
// fantasy.NewAgentTool; it is also reported as ToolName in permission requests.
const FetchToolName = "fetch"

// fetchDescription holds the contents of fetch.md, embedded at build time and
// passed to fantasy.NewAgentTool as the tool's description.
//
//go:embed fetch.md
var fetchDescription []byte
 29
 30func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) fantasy.AgentTool {
 31	if client == nil {
 32		client = &http.Client{
 33			Timeout: 30 * time.Second,
 34			Transport: &http.Transport{
 35				MaxIdleConns:        100,
 36				MaxIdleConnsPerHost: 10,
 37				IdleConnTimeout:     90 * time.Second,
 38			},
 39		}
 40	}
 41
 42	return fantasy.NewAgentTool(
 43		FetchToolName,
 44		string(fetchDescription),
 45		func(ctx context.Context, params FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 46			if params.URL == "" {
 47				return fantasy.NewTextErrorResponse("URL parameter is required"), nil
 48			}
 49
 50			format := strings.ToLower(params.Format)
 51			if format != "text" && format != "markdown" && format != "html" {
 52				return fantasy.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
 53			}
 54
 55			if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
 56				return fantasy.NewTextErrorResponse("URL must start with http:// or https://"), nil
 57			}
 58
 59			sessionID := GetSessionFromContext(ctx)
 60			if sessionID == "" {
 61				return fantasy.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
 62			}
 63
 64			p := permissions.Request(
 65				permission.CreatePermissionRequest{
 66					SessionID:   sessionID,
 67					Path:        workingDir,
 68					ToolCallID:  call.ID,
 69					ToolName:    FetchToolName,
 70					Action:      "fetch",
 71					Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
 72					Params:      FetchPermissionsParams(params),
 73				},
 74			)
 75
 76			if !p {
 77				return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
 78			}
 79
 80			// Handle timeout with context
 81			requestCtx := ctx
 82			if params.Timeout > 0 {
 83				maxTimeout := 120 // 2 minutes
 84				if params.Timeout > maxTimeout {
 85					params.Timeout = maxTimeout
 86				}
 87				var cancel context.CancelFunc
 88				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
 89				defer cancel()
 90			}
 91
 92			req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
 93			if err != nil {
 94				return fantasy.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
 95			}
 96
 97			req.Header.Set("User-Agent", "crush/1.0")
 98
 99			resp, err := client.Do(req)
100			if err != nil {
101				return fantasy.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
102			}
103			defer resp.Body.Close()
104
105			if resp.StatusCode != http.StatusOK {
106				return fantasy.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
107			}
108
109			maxSize := int64(5 * 1024 * 1024) // 5MB
110			body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
111			if err != nil {
112				return fantasy.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
113			}
114
115			content := string(body)
116
117			isValidUt8 := utf8.ValidString(content)
118			if !isValidUt8 {
119				return fantasy.NewTextErrorResponse("Response content is not valid UTF-8"), nil
120			}
121			contentType := resp.Header.Get("Content-Type")
122
123			switch format {
124			case "text":
125				if strings.Contains(contentType, "text/html") {
126					text, err := extractTextFromHTML(content)
127					if err != nil {
128						return fantasy.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
129					}
130					content = text
131				}
132
133			case "markdown":
134				if strings.Contains(contentType, "text/html") {
135					markdown, err := convertHTMLToMarkdown(content)
136					if err != nil {
137						return fantasy.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
138					}
139					content = markdown
140				}
141
142				content = "```\n" + content + "\n```"
143
144			case "html":
145				// return only the body of the HTML document
146				if strings.Contains(contentType, "text/html") {
147					doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
148					if err != nil {
149						return fantasy.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
150					}
151					body, err := doc.Find("body").Html()
152					if err != nil {
153						return fantasy.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
154					}
155					if body == "" {
156						return fantasy.NewTextErrorResponse("No body content found in HTML"), nil
157					}
158					content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
159				}
160			}
161			// calculate byte size of content
162			contentSize := int64(len(content))
163			if contentSize > MaxReadSize {
164				content = content[:MaxReadSize]
165				content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
166			}
167
168			return fantasy.NewTextResponse(content), nil
169		})
170}
171
172func extractTextFromHTML(html string) (string, error) {
173	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
174	if err != nil {
175		return "", err
176	}
177
178	text := doc.Find("body").Text()
179	text = strings.Join(strings.Fields(text), " ")
180
181	return text, nil
182}
183
184func convertHTMLToMarkdown(html string) (string, error) {
185	converter := md.NewConverter("", true, nil)
186
187	markdown, err := converter.ConvertString(html)
188	if err != nil {
189		return "", err
190	}
191
192	return markdown, nil
193}