1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"fmt"
  8	"io"
  9	"net/http"
 10	"strings"
 11	"time"
 12	"unicode/utf8"
 13
 14	md "github.com/JohannesKaufmann/html-to-markdown"
 15	"github.com/PuerkitoBio/goquery"
 16	"github.com/charmbracelet/crush/internal/permission"
 17	"github.com/charmbracelet/crush/internal/proto"
 18)
 19
// Aliases for request types declared in the proto package.
//
// FetchParams is the type the fetch tool decodes its JSON call input into;
// FetchPermissionsParams is the type of the payload attached to the permission
// request made before a fetch.
type (
	FetchParams            = proto.FetchParams
	FetchPermissionsParams = proto.FetchPermissionsParams
)
 24
// fetchTool is the tool that downloads the content of an http(s) URL and
// returns it as text, markdown, or HTML.
type fetchTool struct {
	// client performs the HTTP requests.
	client      *http.Client
	// permissions is asked for approval before any request is sent.
	permissions permission.Service
	// workingDir is reported as the Path of each permission request.
	workingDir  string
}
 30
// FetchToolName is the name of the fetch tool, as reported by Name and Info.
// It is the same constant as proto.FetchToolName.
const FetchToolName = proto.FetchToolName
 32
// fetchDescription holds the contents of fetch.md, embedded at build time. It
// is returned (as a string) as the tool description by Info.
//
//go:embed fetch.md
var fetchDescription []byte
 35
 36func NewFetchTool(permissions permission.Service, workingDir string) BaseTool {
 37	return &fetchTool{
 38		client: &http.Client{
 39			Timeout: 30 * time.Second,
 40			Transport: &http.Transport{
 41				MaxIdleConns:        100,
 42				MaxIdleConnsPerHost: 10,
 43				IdleConnTimeout:     90 * time.Second,
 44			},
 45		},
 46		permissions: permissions,
 47		workingDir:  workingDir,
 48	}
 49}
 50
// Name returns the tool's name, FetchToolName.
func (t *fetchTool) Name() string {
	return FetchToolName
}
 54
 55func (t *fetchTool) Info() ToolInfo {
 56	return ToolInfo{
 57		Name:        FetchToolName,
 58		Description: string(fetchDescription),
 59		Parameters: map[string]any{
 60			"url": map[string]any{
 61				"type":        "string",
 62				"description": "The URL to fetch content from",
 63			},
 64			"format": map[string]any{
 65				"type":        "string",
 66				"description": "The format to return the content in (text, markdown, or html)",
 67				"enum":        []string{"text", "markdown", "html"},
 68			},
 69			"timeout": map[string]any{
 70				"type":        "number",
 71				"description": "Optional timeout in seconds (max 120)",
 72			},
 73		},
 74		Required: []string{"url", "format"},
 75	}
 76}
 77
 78func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 79	var params FetchParams
 80	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
 81		return NewTextErrorResponse("Failed to parse fetch parameters: " + err.Error()), nil
 82	}
 83
 84	if params.URL == "" {
 85		return NewTextErrorResponse("URL parameter is required"), nil
 86	}
 87
 88	format := strings.ToLower(params.Format)
 89	if format != "text" && format != "markdown" && format != "html" {
 90		return NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
 91	}
 92
 93	if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
 94		return NewTextErrorResponse("URL must start with http:// or https://"), nil
 95	}
 96
 97	sessionID, messageID := GetContextValues(ctx)
 98	if sessionID == "" || messageID == "" {
 99		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
100	}
101
102	p := t.permissions.Request(
103		permission.CreatePermissionRequest{
104			SessionID:   sessionID,
105			Path:        t.workingDir,
106			ToolCallID:  call.ID,
107			ToolName:    FetchToolName,
108			Action:      "fetch",
109			Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
110			Params:      FetchPermissionsParams(params),
111		},
112	)
113
114	if !p {
115		return ToolResponse{}, permission.ErrorPermissionDenied
116	}
117
118	// Handle timeout with context
119	requestCtx := ctx
120	if params.Timeout > 0 {
121		maxTimeout := 120 // 2 minutes
122		if params.Timeout > maxTimeout {
123			params.Timeout = maxTimeout
124		}
125		var cancel context.CancelFunc
126		requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
127		defer cancel()
128	}
129
130	req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
131	if err != nil {
132		return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
133	}
134
135	req.Header.Set("User-Agent", "crush/1.0")
136
137	resp, err := t.client.Do(req)
138	if err != nil {
139		return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
140	}
141	defer resp.Body.Close()
142
143	if resp.StatusCode != http.StatusOK {
144		return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
145	}
146
147	maxSize := int64(5 * 1024 * 1024) // 5MB
148	body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
149	if err != nil {
150		return NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
151	}
152
153	content := string(body)
154
155	isValidUt8 := utf8.ValidString(content)
156	if !isValidUt8 {
157		return NewTextErrorResponse("Response content is not valid UTF-8"), nil
158	}
159	contentType := resp.Header.Get("Content-Type")
160
161	switch format {
162	case "text":
163		if strings.Contains(contentType, "text/html") {
164			text, err := extractTextFromHTML(content)
165			if err != nil {
166				return NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
167			}
168			content = text
169		}
170
171	case "markdown":
172		if strings.Contains(contentType, "text/html") {
173			markdown, err := convertHTMLToMarkdown(content)
174			if err != nil {
175				return NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
176			}
177			content = markdown
178		}
179
180		content = "```\n" + content + "\n```"
181
182	case "html":
183		// return only the body of the HTML document
184		if strings.Contains(contentType, "text/html") {
185			doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
186			if err != nil {
187				return NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
188			}
189			body, err := doc.Find("body").Html()
190			if err != nil {
191				return NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
192			}
193			if body == "" {
194				return NewTextErrorResponse("No body content found in HTML"), nil
195			}
196			content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
197		}
198	}
199	// calculate byte size of content
200	contentSize := int64(len(content))
201	if contentSize > MaxReadSize {
202		content = content[:MaxReadSize]
203		content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
204	}
205
206	return NewTextResponse(content), nil
207}
208
209func extractTextFromHTML(html string) (string, error) {
210	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
211	if err != nil {
212		return "", err
213	}
214
215	text := doc.Find("body").Text()
216	text = strings.Join(strings.Fields(text), " ")
217
218	return text, nil
219}
220
221func convertHTMLToMarkdown(html string) (string, error) {
222	converter := md.NewConverter("", true, nil)
223
224	markdown, err := converter.ConvertString(html)
225	if err != nil {
226		return "", err
227	}
228
229	return markdown, nil
230}