1package tools
2
3import (
4 "context"
5 _ "embed"
6 "fmt"
7 "io"
8 "net/http"
9 "strings"
10 "time"
11 "unicode/utf8"
12
13 md "github.com/JohannesKaufmann/html-to-markdown"
14 "github.com/PuerkitoBio/goquery"
15 "github.com/charmbracelet/crush/internal/permission"
16 "github.com/charmbracelet/fantasy/ai"
17)
18
// FetchParams are the model-facing arguments accepted by the fetch tool.
type FetchParams struct {
	// URL to download; must begin with http:// or https:// (validated in the handler).
	URL string `json:"url" description:"The URL to fetch content from"`
	// Format selects the output representation: "text", "markdown", or "html"
	// (matched case-insensitively by the handler).
	Format string `json:"format" description:"The format to return the content in (text, markdown, or html)"`
	// Timeout is an optional per-request deadline in seconds; the handler caps it at 120.
	Timeout int `json:"timeout,omitempty" description:"Optional timeout in seconds (max 120)"`
}
24
// FetchPermissionsParams mirrors FetchParams for display in a permission
// request; NewFetchTool converts between the two with a direct struct
// conversion, so the field set and order must stay identical to FetchParams.
type FetchPermissionsParams struct {
	URL     string `json:"url"`
	Format  string `json:"format"`
	Timeout int    `json:"timeout,omitempty"`
}
30
// fetchTool bundles the dependencies of the fetch tool.
//
// NOTE(review): this struct is not constructed anywhere in the visible part
// of this file — NewFetchTool captures its dependencies in a closure instead.
// Possibly a leftover from an earlier method-based implementation; confirm it
// has no other users before removing.
type fetchTool struct {
	client      *http.Client
	permissions permission.Service
	workingDir  string
}
36
// FetchToolName is the identifier under which the fetch tool is registered.
const FetchToolName = "fetch"

// fetchDescription is the tool's long-form description, embedded from fetch.md.
//
//go:embed fetch.md
var fetchDescription []byte
41
42func NewFetchTool(permissions permission.Service, workingDir string) ai.AgentTool {
43 client := &http.Client{
44 Timeout: 30 * time.Second,
45 Transport: &http.Transport{
46 MaxIdleConns: 100,
47 MaxIdleConnsPerHost: 10,
48 IdleConnTimeout: 90 * time.Second,
49 },
50 }
51
52 return ai.NewAgentTool(
53 FetchToolName,
54 string(fetchDescription),
55 func(ctx context.Context, params FetchParams, call ai.ToolCall) (ai.ToolResponse, error) {
56 if params.URL == "" {
57 return ai.NewTextErrorResponse("URL parameter is required"), nil
58 }
59
60 format := strings.ToLower(params.Format)
61 if format != "text" && format != "markdown" && format != "html" {
62 return ai.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
63 }
64
65 if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
66 return ai.NewTextErrorResponse("URL must start with http:// or https://"), nil
67 }
68
69 sessionID := GetSessionFromContext(ctx)
70 if sessionID == "" {
71 return ai.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
72 }
73
74 p := permissions.Request(
75 permission.CreatePermissionRequest{
76 SessionID: sessionID,
77 Path: workingDir,
78 ToolCallID: call.ID,
79 ToolName: FetchToolName,
80 Action: "fetch",
81 Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
82 Params: FetchPermissionsParams(params),
83 },
84 )
85
86 if !p {
87 return ai.ToolResponse{}, permission.ErrorPermissionDenied
88 }
89
90 // Handle timeout with context
91 requestCtx := ctx
92 if params.Timeout > 0 {
93 maxTimeout := 120 // 2 minutes
94 if params.Timeout > maxTimeout {
95 params.Timeout = maxTimeout
96 }
97 var cancel context.CancelFunc
98 requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
99 defer cancel()
100 }
101
102 req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
103 if err != nil {
104 return ai.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
105 }
106
107 req.Header.Set("User-Agent", "crush/1.0")
108
109 resp, err := client.Do(req)
110 if err != nil {
111 return ai.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
112 }
113 defer resp.Body.Close()
114
115 if resp.StatusCode != http.StatusOK {
116 return ai.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
117 }
118
119 maxSize := int64(5 * 1024 * 1024) // 5MB
120 body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
121 if err != nil {
122 return ai.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
123 }
124
125 content := string(body)
126
127 isValidUt8 := utf8.ValidString(content)
128 if !isValidUt8 {
129 return ai.NewTextErrorResponse("Response content is not valid UTF-8"), nil
130 }
131 contentType := resp.Header.Get("Content-Type")
132
133 switch format {
134 case "text":
135 if strings.Contains(contentType, "text/html") {
136 text, err := extractTextFromHTML(content)
137 if err != nil {
138 return ai.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
139 }
140 content = text
141 }
142
143 case "markdown":
144 if strings.Contains(contentType, "text/html") {
145 markdown, err := convertHTMLToMarkdown(content)
146 if err != nil {
147 return ai.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
148 }
149 content = markdown
150 }
151
152 content = "```\n" + content + "\n```"
153
154 case "html":
155 // return only the body of the HTML document
156 if strings.Contains(contentType, "text/html") {
157 doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
158 if err != nil {
159 return ai.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
160 }
161 body, err := doc.Find("body").Html()
162 if err != nil {
163 return ai.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
164 }
165 if body == "" {
166 return ai.NewTextErrorResponse("No body content found in HTML"), nil
167 }
168 content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
169 }
170 }
171 // calculate byte size of content
172 contentSize := int64(len(content))
173 if contentSize > MaxReadSize {
174 content = content[:MaxReadSize]
175 content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
176 }
177
178 return ai.NewTextResponse(content), nil
179 })
180}
181
182func extractTextFromHTML(html string) (string, error) {
183 doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
184 if err != nil {
185 return "", err
186 }
187
188 text := doc.Find("body").Text()
189 text = strings.Join(strings.Fields(text), " ")
190
191 return text, nil
192}
193
194func convertHTMLToMarkdown(html string) (string, error) {
195 converter := md.NewConverter("", true, nil)
196
197 markdown, err := converter.ConvertString(html)
198 if err != nil {
199 return "", err
200 }
201
202 return markdown, nil
203}