1package agent
2
3import (
4 "context"
5 _ "embed"
6 "errors"
7 "fmt"
8 "net/http"
9 "os"
10 "time"
11
12 "charm.land/fantasy"
13
14 "github.com/charmbracelet/crush/internal/agent/prompt"
15 "github.com/charmbracelet/crush/internal/agent/tools"
16 "github.com/charmbracelet/crush/internal/permission"
17)
18
// agenticFetchToolDescription holds the contents of templates/agentic_fetch.md.
// It is converted to a string and passed as the tool description when
// agenticFetchTool registers the agentic fetch tool.
//
//go:embed templates/agentic_fetch.md
var agenticFetchToolDescription []byte
21
// agenticFetchValidationResult carries the identifiers that
// validateAgenticFetchParams extracts from the tool call context once a
// request has been accepted.
type agenticFetchValidationResult struct {
	// SessionID is the session that issued the tool call; the fetch
	// sub-agent's task session is created as its child.
	SessionID string
	// AgentMessageID is the agent message the tool call belongs to; it is
	// combined with the tool call ID to derive the sub-agent session ID.
	AgentMessageID string
}
27
28// validateAgenticFetchParams validates the tool call parameters and extracts required context values.
29func validateAgenticFetchParams(ctx context.Context, params tools.AgenticFetchParams) (agenticFetchValidationResult, error) {
30 if params.Prompt == "" {
31 return agenticFetchValidationResult{}, errors.New("prompt is required")
32 }
33
34 sessionID := tools.GetSessionFromContext(ctx)
35 if sessionID == "" {
36 return agenticFetchValidationResult{}, errors.New("session id missing from context")
37 }
38
39 agentMessageID := tools.GetMessageFromContext(ctx)
40 if agentMessageID == "" {
41 return agenticFetchValidationResult{}, errors.New("agent message id missing from context")
42 }
43
44 return agenticFetchValidationResult{
45 SessionID: sessionID,
46 AgentMessageID: agentMessageID,
47 }, nil
48}
49
// agenticFetchPromptTmpl holds the contents of
// templates/agentic_fetch_prompt.md.tpl. agenticFetchTool turns it into a
// prompt via prompt.NewPrompt and builds the fetch sub-agent's system prompt
// from it.
//
//go:embed templates/agentic_fetch_prompt.md.tpl
var agenticFetchPromptTmpl []byte
52
53func (c *coordinator) agenticFetchTool(_ context.Context, client *http.Client) (fantasy.AgentTool, error) {
54 if client == nil {
55 transport := http.DefaultTransport.(*http.Transport).Clone()
56 transport.MaxIdleConns = 100
57 transport.MaxIdleConnsPerHost = 10
58 transport.IdleConnTimeout = 90 * time.Second
59
60 client = &http.Client{
61 Timeout: 30 * time.Second,
62 Transport: transport,
63 }
64 }
65
66 return fantasy.NewParallelAgentTool(
67 tools.AgenticFetchToolName,
68 string(agenticFetchToolDescription),
69 func(ctx context.Context, params tools.AgenticFetchParams, call fantasy.ToolCall) (fantasy.ToolResponse, error) {
70 validationResult, err := validateAgenticFetchParams(ctx, params)
71 if err != nil {
72 return fantasy.NewTextErrorResponse(err.Error()), nil
73 }
74
75 // Determine description based on mode.
76 var description string
77 if params.URL != "" {
78 description = fmt.Sprintf("Fetch and analyze content from URL: %s", params.URL)
79 } else {
80 description = "Search the web and analyze results"
81 }
82
83 p, err := c.permissions.Request(ctx,
84 permission.CreatePermissionRequest{
85 SessionID: validationResult.SessionID,
86 Path: c.cfg.WorkingDir(),
87 ToolCallID: call.ID,
88 ToolName: tools.AgenticFetchToolName,
89 Action: "fetch",
90 Description: description,
91 Params: tools.AgenticFetchPermissionsParams(params),
92 },
93 )
94 if err != nil {
95 return fantasy.ToolResponse{}, err
96 }
97 if !p {
98 return fantasy.ToolResponse{}, permission.ErrorPermissionDenied
99 }
100
101 tmpDir, err := os.MkdirTemp(c.cfg.Options.DataDirectory, "crush-fetch-*")
102 if err != nil {
103 return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to create temporary directory: %s", err)), nil
104 }
105 defer os.RemoveAll(tmpDir)
106
107 var fullPrompt string
108
109 if params.URL != "" {
110 // URL mode: fetch the URL content first.
111 content, err := tools.FetchURLAndConvert(ctx, client, params.URL)
112 if err != nil {
113 return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to fetch URL: %s", err)), nil
114 }
115
116 hasLargeContent := len(content) > tools.LargeContentThreshold
117
118 if hasLargeContent {
119 tempFile, err := os.CreateTemp(tmpDir, "page-*.md")
120 if err != nil {
121 return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to create temporary file: %s", err)), nil
122 }
123 tempFilePath := tempFile.Name()
124
125 if _, err := tempFile.WriteString(content); err != nil {
126 tempFile.Close()
127 return fantasy.NewTextErrorResponse(fmt.Sprintf("Failed to write content to file: %s", err)), nil
128 }
129 tempFile.Close()
130
131 fullPrompt = fmt.Sprintf("%s\n\nThe web page from %s has been saved to: %s\n\nUse the view and grep tools to analyze this file and extract the requested information.", params.Prompt, params.URL, tempFilePath)
132 } else {
133 fullPrompt = fmt.Sprintf("%s\n\nWeb page URL: %s\n\n<webpage_content>\n%s\n</webpage_content>", params.Prompt, params.URL, content)
134 }
135 } else {
136 // Search mode: let the sub-agent search and fetch as needed.
137 fullPrompt = fmt.Sprintf("%s\n\nUse the web_search tool to find relevant information. Break down the question into smaller, focused searches if needed. After searching, use web_fetch to get detailed content from the most relevant results.", params.Prompt)
138 }
139
140 promptOpts := []prompt.Option{
141 prompt.WithWorkingDir(tmpDir),
142 }
143
144 promptTemplate, err := prompt.NewPrompt("agentic_fetch", string(agenticFetchPromptTmpl), promptOpts...)
145 if err != nil {
146 return fantasy.ToolResponse{}, fmt.Errorf("error creating prompt: %s", err)
147 }
148
149 _, small, err := c.buildAgentModels(ctx, true)
150 if err != nil {
151 return fantasy.ToolResponse{}, fmt.Errorf("error building models: %s", err)
152 }
153
154 systemPrompt, err := promptTemplate.Build(ctx, small.Model.Provider(), small.Model.Model(), *c.cfg)
155 if err != nil {
156 return fantasy.ToolResponse{}, fmt.Errorf("error building system prompt: %s", err)
157 }
158
159 smallProviderCfg, ok := c.cfg.Providers.Get(small.ModelCfg.Provider)
160 if !ok {
161 return fantasy.ToolResponse{}, errors.New("small model provider not configured")
162 }
163
164 webFetchTool := tools.NewWebFetchTool(tmpDir, client)
165 webSearchTool := tools.NewWebSearchTool(client)
166 fetchTools := []fantasy.AgentTool{
167 webFetchTool,
168 webSearchTool,
169 tools.NewGlobTool(tmpDir),
170 tools.NewGrepTool(tmpDir),
171 tools.NewSourcegraphTool(client),
172 tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, tmpDir),
173 }
174
175 agent := NewSessionAgent(SessionAgentOptions{
176 LargeModel: small, // Use small model for both (fetch doesn't need large)
177 SmallModel: small,
178 SystemPromptPrefix: smallProviderCfg.SystemPromptPrefix,
179 SystemPrompt: systemPrompt,
180 DisableAutoSummarize: c.cfg.Options.DisableAutoSummarize,
181 IsYolo: c.permissions.SkipRequests(),
182 Sessions: c.sessions,
183 Messages: c.messages,
184 Tools: fetchTools,
185 })
186
187 agentToolSessionID := c.sessions.CreateAgentToolSessionID(validationResult.AgentMessageID, call.ID)
188 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, validationResult.SessionID, "Fetch Analysis")
189 if err != nil {
190 return fantasy.ToolResponse{}, fmt.Errorf("error creating session: %s", err)
191 }
192
193 c.permissions.AutoApproveSession(session.ID)
194
195 // Use small model for web content analysis (faster and cheaper)
196 maxTokens := small.CatwalkCfg.DefaultMaxTokens
197 if small.ModelCfg.MaxTokens != 0 {
198 maxTokens = small.ModelCfg.MaxTokens
199 }
200
201 result, err := agent.Run(ctx, SessionAgentCall{
202 SessionID: session.ID,
203 Prompt: fullPrompt,
204 MaxOutputTokens: maxTokens,
205 ProviderOptions: getProviderOptions(small, smallProviderCfg),
206 Temperature: small.ModelCfg.Temperature,
207 TopP: small.ModelCfg.TopP,
208 TopK: small.ModelCfg.TopK,
209 FrequencyPenalty: small.ModelCfg.FrequencyPenalty,
210 PresencePenalty: small.ModelCfg.PresencePenalty,
211 })
212 if err != nil {
213 return fantasy.NewTextErrorResponse("error generating response"), nil
214 }
215
216 updatedSession, err := c.sessions.Get(ctx, session.ID)
217 if err != nil {
218 return fantasy.ToolResponse{}, fmt.Errorf("error getting session: %s", err)
219 }
220 parentSession, err := c.sessions.Get(ctx, validationResult.SessionID)
221 if err != nil {
222 return fantasy.ToolResponse{}, fmt.Errorf("error getting parent session: %s", err)
223 }
224
225 parentSession.Cost += updatedSession.Cost
226
227 _, err = c.sessions.Save(ctx, parentSession)
228 if err != nil {
229 return fantasy.ToolResponse{}, fmt.Errorf("error saving parent session: %s", err)
230 }
231
232 return fantasy.NewTextResponse(result.Response.Content.Text()), nil
233 }), nil
234}