fetch_tool.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"net/http"
 10	"os"
 11	"time"
 12
 13	"charm.land/fantasy"
 14
 15	"github.com/charmbracelet/crush/internal/agent/prompt"
 16	"github.com/charmbracelet/crush/internal/agent/tools"
 17	"github.com/charmbracelet/crush/internal/permission"
 18)
 19
// fetchToolDescription is the text of templates/fetch.md, embedded at build
// time and passed to fantasy.NewAgentTool as the fetch tool's description.
//
//go:embed templates/fetch.md
var fetchToolDescription []byte

// fetchPromptTmpl is the text of templates/fetch_prompt.md.tpl, embedded at
// build time and handed to prompt.NewPrompt to build the system prompt of the
// fetch sub-agent.
//
//go:embed templates/fetch_prompt.md.tpl
var fetchPromptTmpl []byte
 25
 26func (c *coordinator) fetchTool(_ context.Context, client *http.Client) (fantasy.AgentTool, error) {
 27	if client == nil {
 28		client = &http.Client{
 29			Timeout: 30 * time.Second,
 30			Transport: &http.Transport{
 31				MaxIdleConns:        100,
 32				MaxIdleConnsPerHost: 10,
 33				IdleConnTimeout:     90 * time.Second,
 34			},
 35		}
 36	}
 37
 38	return fantasy.NewAgentTool(
 39		tools.FetchToolName,
 40		string(fetchToolDescription),
 41		func(ctx context.Context, params tools.FetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
 42			if err := json.Unmarshal([]byte(call.Input), &params); err != nil {
 43				return fantasy.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
 44			}
 45
 46			if params.URL == "" {
 47				return fantasy.NewTextErrorResponse("url is required"), nil
 48			}
 49
 50			if params.Prompt == "" {
 51				return fantasy.NewTextErrorResponse("prompt is required"), nil
 52			}
 53
 54			sessionID := tools.GetSessionFromContext(ctx)
 55			if sessionID == "" {
 56				return fantasy.ToolResponse{}, errors.New("session id missing from context")
 57			}
 58
 59			agentMessageID := tools.GetMessageFromContext(ctx)
 60			if agentMessageID == "" {
 61				return fantasy.ToolResponse{}, errors.New("agent message id missing from context")
 62			}
 63
 64			p := c.permissions.Request(
 65				permission.CreatePermissionRequest{
 66					SessionID:   sessionID,
 67					Path:        c.cfg.WorkingDir(),
 68					ToolCallID:  call.ID,
 69					ToolName:    tools.FetchToolName,
 70					Action:      "fetch",
 71					Description: fmt.Sprintf("Fetch and analyze content from URL: %s", params.URL),
 72					Params:      tools.FetchPermissionsParams(params),
 73				},
 74			)
 75
 76			if !p {
 77				return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
 78			}
 79
 80			content, err := tools.FetchURLAndConvert(ctx, client, params.URL)
 81			if err != nil {
 82				return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to fetch URL: %s", err)), nil
 83			}
 84
 85			tmpDir, err := os.MkdirTemp(c.cfg.Options.DataDirectory, "crush-fetch-*")
 86			if err != nil {
 87				return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to create temporary directory: %s", err)), nil
 88			}
 89			defer os.RemoveAll(tmpDir)
 90
 91			hasLargeContent := len(content) > tools.LargeContentThreshold
 92			var fullPrompt string
 93
 94			if hasLargeContent {
 95				tempFile, err := os.CreateTemp(tmpDir, "page-*.md")
 96				if err != nil {
 97					return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to create temporary file: %s", err)), nil
 98				}
 99				tempFilePath := tempFile.Name()
100
101				if _, err := tempFile.WriteString(content); err != nil {
102					tempFile.Close()
103					return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to write content to file: %s", err)), nil
104				}
105				tempFile.Close()
106
107				fullPrompt = fmt.Sprintf("%s\n\nThe web page from %s has been saved to: %s\n\nUse the view and grep tools to analyze this file and extract the requested information.", params.Prompt, params.URL, tempFilePath)
108			} else {
109				fullPrompt = fmt.Sprintf("%s\n\nWeb page URL: %s\n\n<webpage_content>\n%s\n</webpage_content>", params.Prompt, params.URL, content)
110			}
111
112			promptOpts := []prompt.Option{
113				prompt.WithWorkingDir(tmpDir),
114			}
115
116			promptTemplate, err := prompt.NewPrompt("fetch", string(fetchPromptTmpl), promptOpts...)
117			if err != nil {
118				return fantasy.ToolResponse{}, fmt.Errorf("error creating prompt: %s", err)
119			}
120
121			_, small, err := c.buildAgentModels(ctx)
122			if err != nil {
123				return fantasy.ToolResponse{}, fmt.Errorf("error building models: %s", err)
124			}
125
126			systemPrompt, err := promptTemplate.Build(ctx, small.Model.Provider(), small.Model.Model(), *c.cfg)
127			if err != nil {
128				return fantasy.ToolResponse{}, fmt.Errorf("error building system prompt: %s", err)
129			}
130
131			smallProviderCfg, ok := c.cfg.Providers.Get(small.ModelCfg.Provider)
132			if !ok {
133				return fantasy.ToolResponse{}, errors.New("small model provider not configured")
134			}
135
136			webFetchTool := tools.NewWebFetchTool(tmpDir, client)
137			fetchTools := []fantasy.AgentTool{
138				webFetchTool,
139				tools.NewGlobTool(tmpDir),
140				tools.NewGrepTool(tmpDir),
141				tools.NewViewTool(c.lspClients, c.permissions, tmpDir),
142			}
143
144			agent := NewSessionAgent(SessionAgentOptions{
145				LargeModel:           small, // Use small model for both (fetch doesn't need large)
146				SmallModel:           small,
147				SystemPromptPrefix:   smallProviderCfg.SystemPromptPrefix,
148				SystemPrompt:         systemPrompt,
149				DisableAutoSummarize: c.cfg.Options.DisableAutoSummarize,
150				IsYolo:               c.permissions.SkipRequests(),
151				Sessions:             c.sessions,
152				Messages:             c.messages,
153				Tools:                fetchTools,
154			})
155
156			agentToolSessionID := c.sessions.CreateAgentToolSessionID(agentMessageID, call.ID)
157			session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, sessionID, "Fetch Analysis")
158			if err != nil {
159				return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
160			}
161
162			c.permissions.AutoApproveSession(session.ID)
163
164			// Use small model for web content analysis (faster and cheaper)
165			maxTokens := small.CatwalkCfg.DefaultMaxTokens
166			if small.ModelCfg.MaxTokens != 0 {
167				maxTokens = small.ModelCfg.MaxTokens
168			}
169
170			result, err := agent.Run(ctx, SessionAgentCall{
171				SessionID:        session.ID,
172				Prompt:           fullPrompt,
173				MaxOutputTokens:  maxTokens,
174				ProviderOptions:  getProviderOptions(small, smallProviderCfg),
175				Temperature:      small.ModelCfg.Temperature,
176				TopP:             small.ModelCfg.TopP,
177				TopK:             small.ModelCfg.TopK,
178				FrequencyPenalty: small.ModelCfg.FrequencyPenalty,
179				PresencePenalty:  small.ModelCfg.PresencePenalty,
180			})
181			if err != nil {
182				return fantasy.NewTextErrorResponse("error generating response"), nil
183			}
184
185			updatedSession, err := c.sessions.Get(ctx, session.ID)
186			if err != nil {
187				return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
188			}
189			parentSession, err := c.sessions.Get(ctx, sessionID)
190			if err != nil {
191				return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
192			}
193
194			parentSession.Cost += updatedSession.Cost
195
196			_, err = c.sessions.Save(ctx, parentSession)
197			if err != nil {
198				return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
199			}
200
201			return fantasy.NewTextResponse(result.Response.Content.Text()), nil
202		}), nil
203}