1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"fmt"
  7	"io"
  8	"net/http"
  9	"strings"
 10	"time"
 11	"unicode/utf8"
 12
 13	md "github.com/JohannesKaufmann/html-to-markdown"
 14	"github.com/PuerkitoBio/goquery"
 15	"github.com/charmbracelet/crush/internal/permission"
 16	"github.com/charmbracelet/fantasy/ai"
 17)
 18
 19type FetchParams struct {
 20	URL     string `json:"url" description:"The URL to fetch content from"`
 21	Format  string `json:"format" description:"The format to return the content in (text, markdown, or html)"`
 22	Timeout int    `json:"timeout,omitempty" description:"Optional timeout in seconds (max 120)"`
 23}
 24
 25type FetchPermissionsParams struct {
 26	URL     string `json:"url"`
 27	Format  string `json:"format"`
 28	Timeout int    `json:"timeout,omitempty"`
 29}
 30
 31type fetchTool struct {
 32	client      *http.Client
 33	permissions permission.Service
 34	workingDir  string
 35}
 36
 37const FetchToolName = "fetch"
 38
 39//go:embed fetch.md
 40var fetchDescription []byte
 41
 42func NewFetchTool(permissions permission.Service, workingDir string, client *http.Client) ai.AgentTool {
 43	if client == nil {
 44		client = &http.Client{
 45			Timeout: 30 * time.Second,
 46			Transport: &http.Transport{
 47				MaxIdleConns:        100,
 48				MaxIdleConnsPerHost: 10,
 49				IdleConnTimeout:     90 * time.Second,
 50			},
 51		}
 52	}
 53
 54	return ai.NewAgentTool(
 55		FetchToolName,
 56		string(fetchDescription),
 57		func(ctx context.Context, params FetchParams, call ai.ToolCall) (ai.ToolResponse, error) {
 58			if params.URL == "" {
 59				return ai.NewTextErrorResponse("URL parameter is required"), nil
 60			}
 61
 62			format := strings.ToLower(params.Format)
 63			if format != "text" && format != "markdown" && format != "html" {
 64				return ai.NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
 65			}
 66
 67			if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
 68				return ai.NewTextErrorResponse("URL must start with http:// or https://"), nil
 69			}
 70
 71			sessionID := GetSessionFromContext(ctx)
 72			if sessionID == "" {
 73				return ai.ToolResponse{}, fmt.Errorf("session ID is required for creating a new file")
 74			}
 75
 76			p := permissions.Request(
 77				permission.CreatePermissionRequest{
 78					SessionID:   sessionID,
 79					Path:        workingDir,
 80					ToolCallID:  call.ID,
 81					ToolName:    FetchToolName,
 82					Action:      "fetch",
 83					Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
 84					Params:      FetchPermissionsParams(params),
 85				},
 86			)
 87
 88			if !p {
 89				return ai.ToolResponse{}, permission.ErrorPermissionDenied
 90			}
 91
 92			// Handle timeout with context
 93			requestCtx := ctx
 94			if params.Timeout > 0 {
 95				maxTimeout := 120 // 2 minutes
 96				if params.Timeout > maxTimeout {
 97					params.Timeout = maxTimeout
 98				}
 99				var cancel context.CancelFunc
100				requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
101				defer cancel()
102			}
103
104			req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
105			if err != nil {
106				return ai.ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
107			}
108
109			req.Header.Set("User-Agent", "crush/1.0")
110
111			resp, err := client.Do(req)
112			if err != nil {
113				return ai.ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
114			}
115			defer resp.Body.Close()
116
117			if resp.StatusCode != http.StatusOK {
118				return ai.NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
119			}
120
121			maxSize := int64(5 * 1024 * 1024) // 5MB
122			body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
123			if err != nil {
124				return ai.NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
125			}
126
127			content := string(body)
128
129			isValidUt8 := utf8.ValidString(content)
130			if !isValidUt8 {
131				return ai.NewTextErrorResponse("Response content is not valid UTF-8"), nil
132			}
133			contentType := resp.Header.Get("Content-Type")
134
135			switch format {
136			case "text":
137				if strings.Contains(contentType, "text/html") {
138					text, err := extractTextFromHTML(content)
139					if err != nil {
140						return ai.NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
141					}
142					content = text
143				}
144
145			case "markdown":
146				if strings.Contains(contentType, "text/html") {
147					markdown, err := convertHTMLToMarkdown(content)
148					if err != nil {
149						return ai.NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
150					}
151					content = markdown
152				}
153
154				content = "```\n" + content + "\n```"
155
156			case "html":
157				// return only the body of the HTML document
158				if strings.Contains(contentType, "text/html") {
159					doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
160					if err != nil {
161						return ai.NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
162					}
163					body, err := doc.Find("body").Html()
164					if err != nil {
165						return ai.NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
166					}
167					if body == "" {
168						return ai.NewTextErrorResponse("No body content found in HTML"), nil
169					}
170					content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
171				}
172			}
173			// calculate byte size of content
174			contentSize := int64(len(content))
175			if contentSize > MaxReadSize {
176				content = content[:MaxReadSize]
177				content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
178			}
179
180			return ai.NewTextResponse(content), nil
181		})
182}
183
184func extractTextFromHTML(html string) (string, error) {
185	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
186	if err != nil {
187		return "", err
188	}
189
190	text := doc.Find("body").Text()
191	text = strings.Join(strings.Fields(text), " ")
192
193	return text, nil
194}
195
196func convertHTMLToMarkdown(html string) (string, error) {
197	converter := md.NewConverter("", true, nil)
198
199	markdown, err := converter.ConvertString(html)
200	if err != nil {
201		return "", err
202	}
203
204	return markdown, nil
205}