1package tools
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"fmt"
  8	"io"
  9	"net/http"
 10	"strings"
 11	"time"
 12	"unicode/utf8"
 13
 14	md "github.com/JohannesKaufmann/html-to-markdown"
 15	"github.com/PuerkitoBio/goquery"
 16	"github.com/charmbracelet/crush/internal/permission"
 17)
 18
// FetchParams holds the arguments of a fetch tool call, decoded by Run from
// the call's JSON input.
//
//   - URL: the address to GET; Run requires it to be non-empty and to start
//     with http:// or https://.
//   - Format: "text", "markdown" or "html"; Run lower-cases it before checking.
//   - Timeout: optional timeout in seconds; Run ignores values <= 0 and clamps
//     values above 120 to 120.
//
// The field sequence must stay identical to FetchPermissionsParams, because
// Run converts one into the other with a plain type conversion.
type FetchParams struct {
	URL     string `json:"url"`
	Format  string `json:"format"`
	Timeout int    `json:"timeout,omitempty"`
}
 24
// FetchPermissionsParams is the Params payload attached to the permission
// request that Run issues before fetching.
//
// It mirrors FetchParams field for field (names, types and order), so Run can
// build it with a direct conversion, FetchPermissionsParams(params). Keep the
// two structs in sync.
type FetchPermissionsParams struct {
	URL     string `json:"url"`
	Format  string `json:"format"`
	Timeout int    `json:"timeout,omitempty"`
}
 30
// fetchTool implements the "fetch" tool.
//
// client performs the HTTP requests. permissions is asked to approve every
// fetch. workingDir is sent as the Path of those permission requests.
type fetchTool struct {
	client      *http.Client
	permissions permission.Service
	workingDir  string
}
 36
// FetchToolName is the name reported by the fetch tool's Name and Info methods
// and used as ToolName in its permission requests.
const FetchToolName = "fetch"

// fetchDescription holds the contents of fetch.md, embedded at build time and
// returned as the tool Description by Info.
//
//go:embed fetch.md
var fetchDescription []byte
 41
 42func NewFetchTool(permissions permission.Service, workingDir string) BaseTool {
 43	return &fetchTool{
 44		client: &http.Client{
 45			Timeout: 30 * time.Second,
 46			Transport: &http.Transport{
 47				MaxIdleConns:        100,
 48				MaxIdleConnsPerHost: 10,
 49				IdleConnTimeout:     90 * time.Second,
 50			},
 51		},
 52		permissions: permissions,
 53		workingDir:  workingDir,
 54	}
 55}
 56
 57func (t *fetchTool) Name() string {
 58	return FetchToolName
 59}
 60
 61func (t *fetchTool) Info() ToolInfo {
 62	return ToolInfo{
 63		Name:        FetchToolName,
 64		Description: string(fetchDescription),
 65		Parameters: map[string]any{
 66			"url": map[string]any{
 67				"type":        "string",
 68				"description": "The URL to fetch content from",
 69			},
 70			"format": map[string]any{
 71				"type":        "string",
 72				"description": "The format to return the content in (text, markdown, or html)",
 73				"enum":        []string{"text", "markdown", "html"},
 74			},
 75			"timeout": map[string]any{
 76				"type":        "number",
 77				"description": "Optional timeout in seconds (max 120)",
 78			},
 79		},
 80		Required: []string{"url", "format"},
 81	}
 82}
 83
 84func (t *fetchTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) {
 85	var params FetchParams
 86	if err := json.Unmarshal([]byte(call.Input), ¶ms); err != nil {
 87		return NewTextErrorResponse("Failed to parse fetch parameters: " + err.Error()), nil
 88	}
 89
 90	if params.URL == "" {
 91		return NewTextErrorResponse("URL parameter is required"), nil
 92	}
 93
 94	format := strings.ToLower(params.Format)
 95	if format != "text" && format != "markdown" && format != "html" {
 96		return NewTextErrorResponse("Format must be one of: text, markdown, html"), nil
 97	}
 98
 99	if !strings.HasPrefix(params.URL, "http://") && !strings.HasPrefix(params.URL, "https://") {
100		return NewTextErrorResponse("URL must start with http:// or https://"), nil
101	}
102
103	sessionID, messageID := GetContextValues(ctx)
104	if sessionID == "" || messageID == "" {
105		return ToolResponse{}, fmt.Errorf("session ID and message ID are required for creating a new file")
106	}
107
108	p := t.permissions.Request(
109		permission.CreatePermissionRequest{
110			SessionID:   sessionID,
111			Path:        t.workingDir,
112			ToolCallID:  call.ID,
113			ToolName:    FetchToolName,
114			Action:      "fetch",
115			Description: fmt.Sprintf("Fetch content from URL: %s", params.URL),
116			Params:      FetchPermissionsParams(params),
117		},
118	)
119
120	if !p {
121		return ToolResponse{}, permission.ErrorPermissionDenied
122	}
123
124	// Handle timeout with context
125	requestCtx := ctx
126	if params.Timeout > 0 {
127		maxTimeout := 120 // 2 minutes
128		if params.Timeout > maxTimeout {
129			params.Timeout = maxTimeout
130		}
131		var cancel context.CancelFunc
132		requestCtx, cancel = context.WithTimeout(ctx, time.Duration(params.Timeout)*time.Second)
133		defer cancel()
134	}
135
136	req, err := http.NewRequestWithContext(requestCtx, "GET", params.URL, nil)
137	if err != nil {
138		return ToolResponse{}, fmt.Errorf("failed to create request: %w", err)
139	}
140
141	req.Header.Set("User-Agent", "crush/1.0")
142
143	resp, err := t.client.Do(req)
144	if err != nil {
145		return ToolResponse{}, fmt.Errorf("failed to fetch URL: %w", err)
146	}
147	defer resp.Body.Close()
148
149	if resp.StatusCode != http.StatusOK {
150		return NewTextErrorResponse(fmt.Sprintf("Request failed with status code: %d", resp.StatusCode)), nil
151	}
152
153	maxSize := int64(5 * 1024 * 1024) // 5MB
154	body, err := io.ReadAll(io.LimitReader(resp.Body, maxSize))
155	if err != nil {
156		return NewTextErrorResponse("Failed to read response body: " + err.Error()), nil
157	}
158
159	content := string(body)
160
161	isValidUt8 := utf8.ValidString(content)
162	if !isValidUt8 {
163		return NewTextErrorResponse("Response content is not valid UTF-8"), nil
164	}
165	contentType := resp.Header.Get("Content-Type")
166
167	switch format {
168	case "text":
169		if strings.Contains(contentType, "text/html") {
170			text, err := extractTextFromHTML(content)
171			if err != nil {
172				return NewTextErrorResponse("Failed to extract text from HTML: " + err.Error()), nil
173			}
174			content = text
175		}
176
177	case "markdown":
178		if strings.Contains(contentType, "text/html") {
179			markdown, err := convertHTMLToMarkdown(content)
180			if err != nil {
181				return NewTextErrorResponse("Failed to convert HTML to Markdown: " + err.Error()), nil
182			}
183			content = markdown
184		}
185
186		content = "```\n" + content + "\n```"
187
188	case "html":
189		// return only the body of the HTML document
190		if strings.Contains(contentType, "text/html") {
191			doc, err := goquery.NewDocumentFromReader(strings.NewReader(content))
192			if err != nil {
193				return NewTextErrorResponse("Failed to parse HTML: " + err.Error()), nil
194			}
195			body, err := doc.Find("body").Html()
196			if err != nil {
197				return NewTextErrorResponse("Failed to extract body from HTML: " + err.Error()), nil
198			}
199			if body == "" {
200				return NewTextErrorResponse("No body content found in HTML"), nil
201			}
202			content = "<html>\n<body>\n" + body + "\n</body>\n</html>"
203		}
204	}
205	// calculate byte size of content
206	contentSize := int64(len(content))
207	if contentSize > MaxReadSize {
208		content = content[:MaxReadSize]
209		content += fmt.Sprintf("\n\n[Content truncated to %d bytes]", MaxReadSize)
210	}
211
212	return NewTextResponse(content), nil
213}
214
215func extractTextFromHTML(html string) (string, error) {
216	doc, err := goquery.NewDocumentFromReader(strings.NewReader(html))
217	if err != nil {
218		return "", err
219	}
220
221	text := doc.Find("body").Text()
222	text = strings.Join(strings.Fields(text), " ")
223
224	return text, nil
225}
226
227func convertHTMLToMarkdown(html string) (string, error) {
228	converter := md.NewConverter("", true, nil)
229
230	markdown, err := converter.ConvertString(html)
231	if err != nil {
232		return "", err
233	}
234
235	return markdown, nil
236}