1package agent
2
3import (
4 "context"
5 "errors"
6 "slices"
7 "strings"
8
9 "github.com/charmbracelet/catwalk/pkg/catwalk"
10 "github.com/charmbracelet/crush/internal/agent/prompt"
11 "github.com/charmbracelet/crush/internal/agent/tools"
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/csync"
14 "github.com/charmbracelet/crush/internal/history"
15 "github.com/charmbracelet/crush/internal/log"
16 "github.com/charmbracelet/crush/internal/lsp"
17 "github.com/charmbracelet/crush/internal/message"
18 "github.com/charmbracelet/crush/internal/permission"
19 "github.com/charmbracelet/crush/internal/session"
20 "github.com/charmbracelet/fantasy/ai"
21 "github.com/charmbracelet/fantasy/anthropic"
22 "github.com/charmbracelet/fantasy/google"
23 "github.com/charmbracelet/fantasy/openai"
24 "github.com/charmbracelet/fantasy/openaicompat"
25 "github.com/charmbracelet/fantasy/openrouter"
26)
27
28type Coordinator interface {
29 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
30 // SetMainAgent(string)
31 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error)
32 Cancel(sessionID string)
33 CancelAll()
34 IsSessionBusy(sessionID string) bool
35 IsBusy() bool
36 QueuedPrompts(sessionID string) int
37 ClearQueue(sessionID string)
38 Summarize(context.Context, string) error
39 Model() Model
40 UpdateModels() error
41}
42
43type coordinator struct {
44 cfg *config.Config
45 sessions session.Service
46 messages message.Service
47 permissions permission.Service
48 history history.Service
49 lspClients *csync.Map[string, *lsp.Client]
50
51 currentAgent SessionAgent
52 agents map[string]SessionAgent
53}
54
55func NewCoordinator(
56 cfg *config.Config,
57 sessions session.Service,
58 messages message.Service,
59 permissions permission.Service,
60 history history.Service,
61 lspClients *csync.Map[string, *lsp.Client],
62) (Coordinator, error) {
63 c := &coordinator{
64 cfg: cfg,
65 sessions: sessions,
66 messages: messages,
67 permissions: permissions,
68 history: history,
69 lspClients: lspClients,
70 agents: make(map[string]SessionAgent),
71 }
72
73 agentCfg, ok := cfg.Agents[config.AgentCoder]
74 if !ok {
75 return nil, errors.New("coder agent not configured")
76 }
77
78 // TODO: make this dynamic when we support multiple agents
79 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
80 if err != nil {
81 return nil, err
82 }
83
84 agent, err := c.buildAgent(prompt, agentCfg)
85 if err != nil {
86 return nil, err
87 }
88 c.currentAgent = agent
89 c.agents[config.AgentCoder] = agent
90 return c, nil
91}
92
93// Run implements Coordinator.
94func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error) {
95 model := c.currentAgent.Model()
96 maxTokens := model.CatwalkCfg.DefaultMaxTokens
97 if model.ModelCfg.MaxTokens != 0 {
98 maxTokens = model.ModelCfg.MaxTokens
99 }
100
101 if !model.CatwalkCfg.SupportsImages && attachments != nil {
102 attachments = nil
103 }
104
105 return c.currentAgent.Run(ctx, SessionAgentCall{
106 SessionID: sessionID,
107 Prompt: prompt,
108 Attachments: attachments,
109 MaxOutputTokens: maxTokens,
110 ProviderOptions: c.getProviderOptions(model),
111 Temperature: model.ModelCfg.Temperature,
112 TopP: model.ModelCfg.TopP,
113 TopK: model.ModelCfg.TopK,
114 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
115 PresencePenalty: model.ModelCfg.PresencePenalty,
116 })
117}
118
119func (c *coordinator) getProviderOptions(model Model) ai.ProviderOptions {
120 options := ai.ProviderOptions{}
121
122 switch model.Model.Provider() {
123 case openai.Name:
124 parsed, err := openai.ParseOptions(model.ModelCfg.ProviderOptions)
125 if err == nil {
126 options[openai.Name] = parsed
127 }
128 case anthropic.Name:
129 parsed, err := anthropic.ParseOptions(model.ModelCfg.ProviderOptions)
130 if err == nil {
131 options[anthropic.Name] = parsed
132 }
133 case openrouter.Name:
134 parsed, err := openrouter.ParseOptions(model.ModelCfg.ProviderOptions)
135 if err == nil {
136 options[openrouter.Name] = parsed
137 }
138 case google.Name:
139 parsed, err := google.ParseOptions(model.ModelCfg.ProviderOptions)
140 if err == nil {
141 options[google.Name] = parsed
142 }
143 case openaicompat.Name:
144 parsed, err := openaicompat.ParseOptions(model.ModelCfg.ProviderOptions)
145 if err == nil {
146 options[openaicompat.Name] = parsed
147 }
148 }
149
150 return options
151}
152
153func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
154 large, small, err := c.buildAgentModels()
155 if err != nil {
156 return nil, err
157 }
158
159 systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
160 if err != nil {
161 return nil, err
162 }
163
164 tools, err := c.buildTools(agent)
165 if err != nil {
166 return nil, err
167 }
168 return NewSessionAgent(large, small, systemPrompt, c.sessions, c.messages, tools...), nil
169}
170
171func (c *coordinator) buildTools(agent config.Agent) ([]ai.AgentTool, error) {
172 var allTools []ai.AgentTool
173 if slices.Contains(agent.AllowedTools, AgentToolName) {
174 agentTool, err := c.agentTool()
175 if err != nil {
176 return nil, err
177 }
178 allTools = append(allTools, agentTool)
179 }
180
181 allTools = append(allTools,
182 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
183 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
184 tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
185 tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
186 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
187 tools.NewGlobTool(c.cfg.WorkingDir()),
188 tools.NewGrepTool(c.cfg.WorkingDir()),
189 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
190 tools.NewSourcegraphTool(nil),
191 tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
192 tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
193 )
194
195 var filteredTools []ai.AgentTool
196 for _, tool := range allTools {
197 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
198 filteredTools = append(filteredTools, tool)
199 }
200 }
201
202 mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
203
204 for _, mcpTool := range mcpTools {
205 if agent.AllowedMCP == nil {
206 // No MCP restrictions
207 filteredTools = append(filteredTools, mcpTool)
208 } else if len(agent.AllowedMCP) == 0 {
209 // no mcps allowed
210 break
211 }
212
213 for mcp, tools := range agent.AllowedMCP {
214 if mcp == mcpTool.MCP() {
215 if len(tools) == 0 {
216 filteredTools = append(filteredTools, mcpTool)
217 }
218 for _, t := range tools {
219 if t == mcpTool.MCPToolName() {
220 filteredTools = append(filteredTools, mcpTool)
221 }
222 }
223 break
224 }
225 }
226 }
227
228 return filteredTools, nil
229}
230
231// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
232func (c *coordinator) buildAgentModels() (Model, Model, error) {
233 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
234 if !ok {
235 return Model{}, Model{}, errors.New("large model not selected")
236 }
237 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
238 if !ok {
239 return Model{}, Model{}, errors.New("small model not selected")
240 }
241
242 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
243 if !ok {
244 return Model{}, Model{}, errors.New("large model provider not configured")
245 }
246
247 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
248 if err != nil {
249 return Model{}, Model{}, err
250 }
251
252 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
253 if !ok {
254 return Model{}, Model{}, errors.New("large model provider not configured")
255 }
256
257 smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
258 if err != nil {
259 return Model{}, Model{}, err
260 }
261
262 var largeCatwalkModel *catwalk.Model
263 var smallCatwalkModel *catwalk.Model
264
265 for _, m := range largeProviderCfg.Models {
266 if m.ID == largeModelCfg.Model {
267 largeCatwalkModel = &m
268 }
269 }
270 for _, m := range smallProviderCfg.Models {
271 if m.ID == smallModelCfg.Model {
272 smallCatwalkModel = &m
273 }
274 }
275
276 if largeCatwalkModel == nil {
277 return Model{}, Model{}, errors.New("large model not found in provider config")
278 }
279
280 if smallCatwalkModel == nil {
281 return Model{}, Model{}, errors.New("snall model not found in provider config")
282 }
283
284 largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
285 if err != nil {
286 return Model{}, Model{}, err
287 }
288 smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
289 if err != nil {
290 return Model{}, Model{}, err
291 }
292
293 return Model{
294 Model: largeModel,
295 CatwalkCfg: *largeCatwalkModel,
296 ModelCfg: largeModelCfg,
297 }, Model{
298 Model: smallModel,
299 CatwalkCfg: *smallCatwalkModel,
300 ModelCfg: smallModelCfg,
301 }, nil
302}
303
304func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
305 hasBearerAuth := false
306 for key := range headers {
307 if strings.ToLower(key) == "authorization" {
308 hasBearerAuth = true
309 break
310 }
311 }
312 if hasBearerAuth {
313 apiKey = "" // clear apiKey to avoid using X-Api-Key header
314 }
315
316 var opts []anthropic.Option
317
318 if apiKey != "" {
319 // Use standard X-Api-Key header
320 opts = append(opts, anthropic.WithAPIKey(apiKey))
321 }
322
323 if len(headers) > 0 {
324 opts = append(opts, anthropic.WithHeaders(headers))
325 }
326
327 if baseURL != "" {
328 opts = append(opts, anthropic.WithBaseURL(baseURL))
329 }
330
331 if c.cfg.Options.Debug {
332 httpClient := log.NewHTTPClient()
333 opts = append(opts, anthropic.WithHTTPClient(httpClient))
334 }
335
336 return anthropic.New(opts...)
337}
338
339func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
340 opts := []openai.Option{
341 openai.WithAPIKey(apiKey),
342 }
343 if c.cfg.Options.Debug {
344 httpClient := log.NewHTTPClient()
345 opts = append(opts, openai.WithHTTPClient(httpClient))
346 }
347 if len(headers) > 0 {
348 opts = append(opts, openai.WithHeaders(headers))
349 }
350 if baseURL != "" {
351 opts = append(opts, openai.WithBaseURL(baseURL))
352 }
353 return openai.New(opts...)
354}
355
356func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) ai.Provider {
357 opts := []openrouter.Option{
358 openrouter.WithAPIKey(apiKey),
359 openrouter.WithLanguageUniqueToolCallIds(),
360 }
361 if c.cfg.Options.Debug {
362 httpClient := log.NewHTTPClient()
363 opts = append(opts, openrouter.WithHTTPClient(httpClient))
364 }
365 if len(headers) > 0 {
366 opts = append(opts, openrouter.WithHeaders(headers))
367 }
368 return openrouter.New(opts...)
369}
370
371func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
372 opts := []openaicompat.Option{
373 openaicompat.WithAPIKey(apiKey),
374 }
375 if c.cfg.Options.Debug {
376 httpClient := log.NewHTTPClient()
377 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
378 }
379 if len(headers) > 0 {
380 opts = append(opts, openaicompat.WithHeaders(headers))
381 }
382
383 return openaicompat.New(baseURL, opts...)
384}
385
386// TODO: add baseURL for google
387func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
388 opts := []google.Option{
389 google.WithAPIKey(apiKey),
390 }
391 if c.cfg.Options.Debug {
392 httpClient := log.NewHTTPClient()
393 opts = append(opts, google.WithHTTPClient(httpClient))
394 }
395 if len(headers) > 0 {
396 opts = append(opts, google.WithHeaders(headers))
397 }
398 return google.New(opts...)
399}
400
401func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
402 if model.Think {
403 return true
404 }
405
406 if model.ProviderOptions == nil {
407 return false
408 }
409
410 opts, err := anthropic.ParseOptions(model.ProviderOptions)
411 if err != nil {
412 return false
413 }
414 if opts.Thinking != nil {
415 return true
416 }
417 return false
418}
419
420func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (ai.Provider, error) {
421 headers := providerCfg.ExtraHeaders
422
423 // handle special headers for anthropic
424 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
425 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
426 }
427
428 // TODO: make sure we have
429 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
430 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
431 var provider ai.Provider
432 switch providerCfg.Type {
433 case openai.Name:
434 provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
435 case anthropic.Name:
436 provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
437 case openrouter.Name:
438 provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
439 case google.Name:
440 provider = c.buildGoogleProvider(baseURL, apiKey, headers)
441 case openaicompat.Name:
442 provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
443 default:
444 return nil, errors.New("provider type not supported")
445 }
446 return provider, nil
447}
448
449func (c *coordinator) Cancel(sessionID string) {
450 c.currentAgent.Cancel(sessionID)
451}
452
453func (c *coordinator) CancelAll() {
454 c.currentAgent.CancelAll()
455}
456
457func (c *coordinator) ClearQueue(sessionID string) {
458 c.currentAgent.ClearQueue(sessionID)
459}
460
461func (c *coordinator) IsBusy() bool {
462 return c.currentAgent.IsBusy()
463}
464
465func (c *coordinator) IsSessionBusy(sessionID string) bool {
466 return c.currentAgent.IsSessionBusy(sessionID)
467}
468
469func (c *coordinator) Model() Model {
470 return c.currentAgent.Model()
471}
472
473func (c *coordinator) UpdateModels() error {
474 // build the models again so we make sure we get the latest config
475 large, small, err := c.buildAgentModels()
476 if err != nil {
477 return err
478 }
479 c.currentAgent.SetModels(large, small)
480
481 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
482 if !ok {
483 return errors.New("coder agent not configured")
484 }
485
486 tools, err := c.buildTools(agentCfg)
487 if err != nil {
488 return err
489 }
490 c.currentAgent.SetTools(tools)
491 return nil
492}
493
494func (c *coordinator) QueuedPrompts(sessionID string) int {
495 return c.currentAgent.QueuedPrompts(sessionID)
496}
497
498func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
499 return c.currentAgent.Summarize(ctx, sessionID)
500}