1package agent
2
3import (
4 "context"
5 "errors"
6 "slices"
7 "strings"
8
9 "github.com/charmbracelet/catwalk/pkg/catwalk"
10 "github.com/charmbracelet/crush/internal/agent/prompt"
11 "github.com/charmbracelet/crush/internal/agent/tools"
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/csync"
14 "github.com/charmbracelet/crush/internal/history"
15 "github.com/charmbracelet/crush/internal/log"
16 "github.com/charmbracelet/crush/internal/lsp"
17 "github.com/charmbracelet/crush/internal/message"
18 "github.com/charmbracelet/crush/internal/permission"
19 "github.com/charmbracelet/crush/internal/session"
20 "github.com/charmbracelet/fantasy/ai"
21 "github.com/charmbracelet/fantasy/anthropic"
22 "github.com/charmbracelet/fantasy/azure"
23 "github.com/charmbracelet/fantasy/google"
24 "github.com/charmbracelet/fantasy/openai"
25 "github.com/charmbracelet/fantasy/openaicompat"
26 "github.com/charmbracelet/fantasy/openrouter"
27)
28
29type Coordinator interface {
30 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
31 // SetMainAgent(string)
32 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error)
33 Cancel(sessionID string)
34 CancelAll()
35 IsSessionBusy(sessionID string) bool
36 IsBusy() bool
37 QueuedPrompts(sessionID string) int
38 ClearQueue(sessionID string)
39 Summarize(context.Context, string) error
40 Model() Model
41 UpdateModels() error
42}
43
44type coordinator struct {
45 cfg *config.Config
46 sessions session.Service
47 messages message.Service
48 permissions permission.Service
49 history history.Service
50 lspClients *csync.Map[string, *lsp.Client]
51
52 currentAgent SessionAgent
53 agents map[string]SessionAgent
54}
55
// NewCoordinator creates a Coordinator backed by the given services and
// builds the coder agent, which becomes the current agent.
//
// It fails if config.AgentCoder is missing from cfg.Agents, or if the coder
// prompt, the models, the system prompt, or the tools cannot be built.
func NewCoordinator(
	cfg *config.Config,
	sessions session.Service,
	messages message.Service,
	permissions permission.Service,
	history history.Service,
	lspClients *csync.Map[string, *lsp.Client],
) (Coordinator, error) {
	coderCfg, found := cfg.Agents[config.AgentCoder]
	if !found {
		return nil, errors.New("coder agent not configured")
	}

	c := &coordinator{
		cfg:         cfg,
		sessions:    sessions,
		messages:    messages,
		permissions: permissions,
		history:     history,
		lspClients:  lspClients,
		agents:      make(map[string]SessionAgent),
	}

	// TODO: make this dynamic when we support multiple agents
	// (named coderTmpl so the imported prompt package is not shadowed).
	coderTmpl, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
	if err != nil {
		return nil, err
	}

	coder, err := c.buildAgent(coderTmpl, coderCfg)
	if err != nil {
		return nil, err
	}

	c.agents[config.AgentCoder] = coder
	c.currentAgent = coder
	return c, nil
}
93
// Run implements Coordinator.
//
// It forwards the prompt to the current agent together with the selected
// model's sampling settings. The output-token limit is the model config's
// MaxTokens when non-zero, otherwise the catwalk DefaultMaxTokens.
// Attachments are dropped when the model does not support images.
func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error) {
	model := c.currentAgent.Model()
	modelCfg := model.ModelCfg

	maxTokens := modelCfg.MaxTokens
	if maxTokens == 0 {
		maxTokens = model.CatwalkCfg.DefaultMaxTokens
	}

	if !model.CatwalkCfg.SupportsImages {
		attachments = nil
	}

	call := SessionAgentCall{
		SessionID:        sessionID,
		Prompt:           prompt,
		Attachments:      attachments,
		MaxOutputTokens:  maxTokens,
		ProviderOptions:  c.getProviderOptions(model),
		Temperature:      modelCfg.Temperature,
		TopP:             modelCfg.TopP,
		TopK:             modelCfg.TopK,
		FrequencyPenalty: modelCfg.FrequencyPenalty,
		PresencePenalty:  modelCfg.PresencePenalty,
	}
	return c.currentAgent.Run(ctx, call)
}
119
120func (c *coordinator) getProviderOptions(model Model) ai.ProviderOptions {
121 options := ai.ProviderOptions{}
122
123 switch model.Model.Provider() {
124 case openai.Name:
125 parsed, err := openai.ParseOptions(model.ModelCfg.ProviderOptions)
126 if err == nil {
127 options[openai.Name] = parsed
128 }
129 case anthropic.Name:
130 parsed, err := anthropic.ParseOptions(model.ModelCfg.ProviderOptions)
131 if err == nil {
132 options[anthropic.Name] = parsed
133 }
134 case openrouter.Name:
135 parsed, err := openrouter.ParseOptions(model.ModelCfg.ProviderOptions)
136 if err == nil {
137 options[openrouter.Name] = parsed
138 }
139 case google.Name:
140 parsed, err := google.ParseOptions(model.ModelCfg.ProviderOptions)
141 if err == nil {
142 options[google.Name] = parsed
143 }
144 case openaicompat.Name:
145 parsed, err := openaicompat.ParseOptions(model.ModelCfg.ProviderOptions)
146 if err == nil {
147 options[openaicompat.Name] = parsed
148 }
149 }
150
151 return options
152}
153
154func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
155 large, small, err := c.buildAgentModels()
156 if err != nil {
157 return nil, err
158 }
159
160 systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
161 if err != nil {
162 return nil, err
163 }
164
165 tools, err := c.buildTools(agent)
166 if err != nil {
167 return nil, err
168 }
169 return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
170}
171
// buildTools returns the tools the given agent may use.
//
// Built-in tools are kept when their name is listed in agent.AllowedTools.
// The agent tool is only constructed when AgentToolName is allowed, since
// constructing it can fail. MCP tools follow, filtered by agent.AllowedMCP
// (see mcpToolAllowed). Each MCP tool is added at most once.
func (c *coordinator) buildTools(agent config.Agent) ([]ai.AgentTool, error) {
	var allTools []ai.AgentTool
	if slices.Contains(agent.AllowedTools, AgentToolName) {
		agentTool, err := c.agentTool()
		if err != nil {
			return nil, err
		}
		allTools = append(allTools, agentTool)
	}

	allTools = append(allTools,
		tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
		tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
		tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
		tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
		tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
		tools.NewGlobTool(c.cfg.WorkingDir()),
		tools.NewGrepTool(c.cfg.WorkingDir()),
		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
		tools.NewSourcegraphTool(nil),
		tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
		tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
	)

	var filteredTools []ai.AgentTool
	for _, tool := range allTools {
		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
			filteredTools = append(filteredTools, tool)
		}
	}

	// GetMCPTools is always called (as before), even if the allow-list ends
	// up rejecting every MCP tool.
	for _, mcpTool := range tools.GetMCPTools(context.Background(), c.permissions, c.cfg) {
		if mcpToolAllowed(agent.AllowedMCP, mcpTool.MCP(), mcpTool.MCPToolName()) {
			filteredTools = append(filteredTools, mcpTool)
		}
	}

	return filteredTools, nil
}

// mcpToolAllowed reports whether tool toolName from MCP server mcpName passes
// the allow-list. A nil allow-list imposes no restrictions; otherwise the server
// must be listed, and an empty tool list for it allows all of its tools while
// a non-empty one allows only the tools it names.
func mcpToolAllowed(allowed map[string][]string, mcpName, toolName string) bool {
	if allowed == nil {
		return true
	}
	toolNames, listed := allowed[mcpName]
	if !listed {
		return false
	}
	return len(toolNames) == 0 || slices.Contains(toolNames, toolName)
}
231
232// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
233func (c *coordinator) buildAgentModels() (Model, Model, error) {
234 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
235 if !ok {
236 return Model{}, Model{}, errors.New("large model not selected")
237 }
238 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
239 if !ok {
240 return Model{}, Model{}, errors.New("small model not selected")
241 }
242
243 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
244 if !ok {
245 return Model{}, Model{}, errors.New("large model provider not configured")
246 }
247
248 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
249 if err != nil {
250 return Model{}, Model{}, err
251 }
252
253 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
254 if !ok {
255 return Model{}, Model{}, errors.New("large model provider not configured")
256 }
257
258 smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
259 if err != nil {
260 return Model{}, Model{}, err
261 }
262
263 var largeCatwalkModel *catwalk.Model
264 var smallCatwalkModel *catwalk.Model
265
266 for _, m := range largeProviderCfg.Models {
267 if m.ID == largeModelCfg.Model {
268 largeCatwalkModel = &m
269 }
270 }
271 for _, m := range smallProviderCfg.Models {
272 if m.ID == smallModelCfg.Model {
273 smallCatwalkModel = &m
274 }
275 }
276
277 if largeCatwalkModel == nil {
278 return Model{}, Model{}, errors.New("large model not found in provider config")
279 }
280
281 if smallCatwalkModel == nil {
282 return Model{}, Model{}, errors.New("snall model not found in provider config")
283 }
284
285 largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
286 if err != nil {
287 return Model{}, Model{}, err
288 }
289 smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
290 if err != nil {
291 return Model{}, Model{}, err
292 }
293
294 return Model{
295 Model: largeModel,
296 CatwalkCfg: *largeCatwalkModel,
297 ModelCfg: largeModelCfg,
298 }, Model{
299 Model: smallModel,
300 CatwalkCfg: *smallCatwalkModel,
301 ModelCfg: smallModelCfg,
302 }, nil
303}
304
// buildAnthropicProvider creates an Anthropic provider.
//
// If headers already carry an Authorization header (matched case-insensitively
// via strings.ToLower), apiKey is discarded so the X-Api-Key header is not sent
// as well. Headers, base URL, and the debug logging HTTP client are added only
// when set.
func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
	for name := range headers {
		if strings.ToLower(name) == "authorization" {
			apiKey = "" // bearer auth wins; avoid also sending X-Api-Key
			break
		}
	}

	var opts []anthropic.Option
	if apiKey != "" {
		opts = append(opts, anthropic.WithAPIKey(apiKey))
	}
	if len(headers) > 0 {
		opts = append(opts, anthropic.WithHeaders(headers))
	}
	if baseURL != "" {
		opts = append(opts, anthropic.WithBaseURL(baseURL))
	}
	if c.cfg.Options.Debug {
		opts = append(opts, anthropic.WithHTTPClient(log.NewHTTPClient()))
	}
	return anthropic.New(opts...)
}
339
340func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
341 opts := []openai.Option{
342 openai.WithAPIKey(apiKey),
343 }
344 if c.cfg.Options.Debug {
345 httpClient := log.NewHTTPClient()
346 opts = append(opts, openai.WithHTTPClient(httpClient))
347 }
348 if len(headers) > 0 {
349 opts = append(opts, openai.WithHeaders(headers))
350 }
351 if baseURL != "" {
352 opts = append(opts, openai.WithBaseURL(baseURL))
353 }
354 return openai.New(opts...)
355}
356
357func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) ai.Provider {
358 opts := []openrouter.Option{
359 openrouter.WithAPIKey(apiKey),
360 }
361 if c.cfg.Options.Debug {
362 httpClient := log.NewHTTPClient()
363 opts = append(opts, openrouter.WithHTTPClient(httpClient))
364 }
365 if len(headers) > 0 {
366 opts = append(opts, openrouter.WithHeaders(headers))
367 }
368 return openrouter.New(opts...)
369}
370
371func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
372 opts := []openaicompat.Option{
373 openaicompat.WithBaseURL(baseURL),
374 openaicompat.WithAPIKey(apiKey),
375 }
376 if c.cfg.Options.Debug {
377 httpClient := log.NewHTTPClient()
378 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
379 }
380 if len(headers) > 0 {
381 opts = append(opts, openaicompat.WithHeaders(headers))
382 }
383
384 return openaicompat.New(opts...)
385}
386
387func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) ai.Provider {
388 opts := []azure.Option{
389 azure.WithBaseURL(baseURL),
390 azure.WithAPIKey(apiKey),
391 }
392 if c.cfg.Options.Debug {
393 httpClient := log.NewHTTPClient()
394 opts = append(opts, azure.WithHTTPClient(httpClient))
395 }
396 if options == nil {
397 options = make(map[string]string)
398 }
399 if apiVersion, ok := options["apiVersion"]; ok {
400 opts = append(opts, azure.WithAPIVersion(apiVersion))
401 }
402 if len(headers) > 0 {
403 opts = append(opts, azure.WithHeaders(headers))
404 }
405
406 return azure.New(opts...)
407}
408
409// TODO: add baseURL for google
410func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
411 opts := []google.Option{
412 google.WithAPIKey(apiKey),
413 }
414 if c.cfg.Options.Debug {
415 httpClient := log.NewHTTPClient()
416 opts = append(opts, google.WithHTTPClient(httpClient))
417 }
418 if len(headers) > 0 {
419 opts = append(opts, google.WithHeaders(headers))
420 }
421 return google.New(opts...)
422}
423
424func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
425 if model.Think {
426 return true
427 }
428
429 if model.ProviderOptions == nil {
430 return false
431 }
432
433 opts, err := anthropic.ParseOptions(model.ProviderOptions)
434 if err != nil {
435 return false
436 }
437 if opts.Thinking != nil {
438 return true
439 }
440 return false
441}
442
443func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (ai.Provider, error) {
444 headers := providerCfg.ExtraHeaders
445
446 // handle special headers for anthropic
447 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
448 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
449 }
450
451 // TODO: make sure we have
452 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
453 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
454 var provider ai.Provider
455 switch providerCfg.Type {
456 case openai.Name:
457 provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
458 case anthropic.Name:
459 provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
460 case openrouter.Name:
461 provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
462 case azure.Name:
463 provider = c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
464 case google.Name:
465 provider = c.buildGoogleProvider(baseURL, apiKey, headers)
466 case openaicompat.Name:
467 provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
468 default:
469 return nil, errors.New("provider type not supported")
470 }
471 return provider, nil
472}
473
474func (c *coordinator) Cancel(sessionID string) {
475 c.currentAgent.Cancel(sessionID)
476}
477
478func (c *coordinator) CancelAll() {
479 c.currentAgent.CancelAll()
480}
481
482func (c *coordinator) ClearQueue(sessionID string) {
483 c.currentAgent.ClearQueue(sessionID)
484}
485
486func (c *coordinator) IsBusy() bool {
487 return c.currentAgent.IsBusy()
488}
489
490func (c *coordinator) IsSessionBusy(sessionID string) bool {
491 return c.currentAgent.IsSessionBusy(sessionID)
492}
493
494func (c *coordinator) Model() Model {
495 return c.currentAgent.Model()
496}
497
// UpdateModels rebuilds the large/small models and the coder agent's tools
// from the latest config and installs them on the current agent.
//
// Everything is built before anything is applied, so a failure (missing
// coder agent config, model resolution, or tool construction) leaves the
// agent with its previous models and tools rather than a mismatched mix.
//
// NOTE(review): buildAgent renders the system prompt for the large model, but
// it is not rebuilt here; confirm SetModels handles that, or the prompt will
// keep reflecting the model the agent was created with.
func (c *coordinator) UpdateModels() error {
	agentCfg, ok := c.cfg.Agents[config.AgentCoder]
	if !ok {
		return errors.New("coder agent not configured")
	}

	// build the models again so we make sure we get the latest config
	large, small, err := c.buildAgentModels()
	if err != nil {
		return err
	}

	agentTools, err := c.buildTools(agentCfg)
	if err != nil {
		return err
	}

	c.currentAgent.SetModels(large, small)
	c.currentAgent.SetTools(agentTools)
	return nil
}
518
// QueuedPrompts returns the current agent's QueuedPrompts count for sessionID.
func (c *coordinator) QueuedPrompts(sessionID string) int {
	queued := c.currentAgent.QueuedPrompts(sessionID)
	return queued
}

// Summarize forwards to the current agent's Summarize for sessionID.
func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
	agent := c.currentAgent
	return agent.Summarize(ctx, sessionID)
}