1package agent
2
import (
	"context"
	"errors"
	"maps"
	"slices"
	"strings"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/agent/prompt"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/history"
	"github.com/charmbracelet/crush/internal/log"
	"github.com/charmbracelet/crush/internal/lsp"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/anthropic"
	"github.com/charmbracelet/fantasy/google"
	"github.com/charmbracelet/fantasy/openai"
	"github.com/charmbracelet/fantasy/openaicompat"
	"github.com/charmbracelet/fantasy/openrouter"
)
27
28type Coordinator interface {
29 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
30 // SetMainAgent(string)
31 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error)
32 Cancel(sessionID string)
33 CancelAll()
34 IsSessionBusy(sessionID string) bool
35 IsBusy() bool
36 QueuedPrompts(sessionID string) int
37 ClearQueue(sessionID string)
38 Summarize(context.Context, string) error
39 Model() Model
40 UpdateModels() error
41}
42
// coordinator is the Coordinator implementation. It keeps the config and the
// services needed to build agents, and forwards Coordinator calls to
// currentAgent.
type coordinator struct {
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]

	// currentAgent is the agent every Coordinator method delegates to.
	currentAgent SessionAgent
	// agents holds built agents keyed by agent name; NewCoordinator currently
	// registers only config.AgentCoder.
	agents map[string]SessionAgent
}
54
55func NewCoordinator(
56 cfg *config.Config,
57 sessions session.Service,
58 messages message.Service,
59 permissions permission.Service,
60 history history.Service,
61 lspClients *csync.Map[string, *lsp.Client],
62) (Coordinator, error) {
63 c := &coordinator{
64 cfg: cfg,
65 sessions: sessions,
66 messages: messages,
67 permissions: permissions,
68 history: history,
69 lspClients: lspClients,
70 agents: make(map[string]SessionAgent),
71 }
72
73 agentCfg, ok := cfg.Agents[config.AgentCoder]
74 if !ok {
75 return nil, errors.New("coder agent not configured")
76 }
77
78 // TODO: make this dynamic when we support multiple agents
79 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
80 if err != nil {
81 return nil, err
82 }
83
84 agent, err := c.buildAgent(prompt, agentCfg)
85 if err != nil {
86 return nil, err
87 }
88 c.currentAgent = agent
89 c.agents[config.AgentCoder] = agent
90 return c, nil
91}
92
93// Run implements Coordinator.
94func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error) {
95 model := c.currentAgent.Model()
96 maxTokens := model.CatwalkCfg.DefaultMaxTokens
97 if model.ModelCfg.MaxTokens != 0 {
98 maxTokens = model.ModelCfg.MaxTokens
99 }
100
101 if !model.CatwalkCfg.SupportsImages && attachments != nil {
102 attachments = nil
103 }
104
105 return c.currentAgent.Run(ctx, SessionAgentCall{
106 SessionID: sessionID,
107 Prompt: prompt,
108 Attachments: attachments,
109 MaxOutputTokens: maxTokens,
110 ProviderOptions: c.getProviderOptions(model),
111 Temperature: model.ModelCfg.Temperature,
112 TopP: model.ModelCfg.TopP,
113 TopK: model.ModelCfg.TopK,
114 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
115 PresencePenalty: model.ModelCfg.PresencePenalty,
116 })
117}
118
119func (c *coordinator) getProviderOptions(model Model) ai.ProviderOptions {
120 options := ai.ProviderOptions{}
121
122 switch model.Model.Provider() {
123 case openai.Name:
124 parsed, err := openai.ParseOptions(model.ModelCfg.ProviderOptions)
125 if err == nil {
126 options[openai.Name] = parsed
127 }
128 case anthropic.Name:
129 parsed, err := anthropic.ParseOptions(model.ModelCfg.ProviderOptions)
130 if err == nil {
131 options[anthropic.Name] = parsed
132 }
133 case openrouter.Name:
134 parsed, err := openrouter.ParseOptions(model.ModelCfg.ProviderOptions)
135 if err == nil {
136 options[openrouter.Name] = parsed
137 }
138 case google.Name:
139 parsed, err := google.ParseOptions(model.ModelCfg.ProviderOptions)
140 if err == nil {
141 options[google.Name] = parsed
142 }
143 case openaicompat.Name:
144 parsed, err := openaicompat.ParseOptions(model.ModelCfg.ProviderOptions)
145 if err == nil {
146 options[openaicompat.Name] = parsed
147 }
148 }
149
150 return options
151}
152
153func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
154 large, small, err := c.buildAgentModels()
155 if err != nil {
156 return nil, err
157 }
158
159 systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
160 if err != nil {
161 return nil, err
162 }
163
164 tools, err := c.buildTools(agent)
165 if err != nil {
166 return nil, err
167 }
168 return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
169}
170
171func (c *coordinator) buildTools(agent config.Agent) ([]ai.AgentTool, error) {
172 var allTools []ai.AgentTool
173 if slices.Contains(agent.AllowedTools, AgentToolName) {
174 agentTool, err := c.agentTool()
175 if err != nil {
176 return nil, err
177 }
178 allTools = append(allTools, agentTool)
179 }
180
181 allTools = append(allTools,
182 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
183 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
184 tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
185 tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
186 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
187 tools.NewGlobTool(c.cfg.WorkingDir()),
188 tools.NewGrepTool(c.cfg.WorkingDir()),
189 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
190 tools.NewSourcegraphTool(nil),
191 tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
192 tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
193 )
194
195 var filteredTools []ai.AgentTool
196 for _, tool := range allTools {
197 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
198 filteredTools = append(filteredTools, tool)
199 }
200 }
201
202 mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
203
204 for _, mcpTool := range mcpTools {
205 if agent.AllowedMCP == nil {
206 // No MCP restrictions
207 filteredTools = append(filteredTools, mcpTool)
208 } else if len(agent.AllowedMCP) == 0 {
209 // no mcps allowed
210 break
211 }
212
213 for mcp, tools := range agent.AllowedMCP {
214 if mcp == mcpTool.MCP() {
215 if len(tools) == 0 {
216 filteredTools = append(filteredTools, mcpTool)
217 }
218 for _, t := range tools {
219 if t == mcpTool.MCPToolName() {
220 filteredTools = append(filteredTools, mcpTool)
221 }
222 }
223 break
224 }
225 }
226 }
227
228 return filteredTools, nil
229}
230
231// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
232func (c *coordinator) buildAgentModels() (Model, Model, error) {
233 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
234 if !ok {
235 return Model{}, Model{}, errors.New("large model not selected")
236 }
237 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
238 if !ok {
239 return Model{}, Model{}, errors.New("small model not selected")
240 }
241
242 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
243 if !ok {
244 return Model{}, Model{}, errors.New("large model provider not configured")
245 }
246
247 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
248 if err != nil {
249 return Model{}, Model{}, err
250 }
251
252 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
253 if !ok {
254 return Model{}, Model{}, errors.New("large model provider not configured")
255 }
256
257 smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
258 if err != nil {
259 return Model{}, Model{}, err
260 }
261
262 var largeCatwalkModel *catwalk.Model
263 var smallCatwalkModel *catwalk.Model
264
265 for _, m := range largeProviderCfg.Models {
266 if m.ID == largeModelCfg.Model {
267 largeCatwalkModel = &m
268 }
269 }
270 for _, m := range smallProviderCfg.Models {
271 if m.ID == smallModelCfg.Model {
272 smallCatwalkModel = &m
273 }
274 }
275
276 if largeCatwalkModel == nil {
277 return Model{}, Model{}, errors.New("large model not found in provider config")
278 }
279
280 if smallCatwalkModel == nil {
281 return Model{}, Model{}, errors.New("snall model not found in provider config")
282 }
283
284 largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
285 if err != nil {
286 return Model{}, Model{}, err
287 }
288 smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
289 if err != nil {
290 return Model{}, Model{}, err
291 }
292
293 return Model{
294 Model: largeModel,
295 CatwalkCfg: *largeCatwalkModel,
296 ModelCfg: largeModelCfg,
297 }, Model{
298 Model: smallModel,
299 CatwalkCfg: *smallCatwalkModel,
300 ModelCfg: smallModelCfg,
301 }, nil
302}
303
304func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
305 hasBearerAuth := false
306 for key := range headers {
307 if strings.ToLower(key) == "authorization" {
308 hasBearerAuth = true
309 break
310 }
311 }
312 if hasBearerAuth {
313 apiKey = "" // clear apiKey to avoid using X-Api-Key header
314 }
315
316 var opts []anthropic.Option
317
318 if apiKey != "" {
319 // Use standard X-Api-Key header
320 opts = append(opts, anthropic.WithAPIKey(apiKey))
321 }
322
323 if len(headers) > 0 {
324 opts = append(opts, anthropic.WithHeaders(headers))
325 }
326
327 if baseURL != "" {
328 opts = append(opts, anthropic.WithBaseURL(baseURL))
329 }
330
331 if c.cfg.Options.Debug {
332 httpClient := log.NewHTTPClient()
333 opts = append(opts, anthropic.WithHTTPClient(httpClient))
334 }
335
336 return anthropic.New(opts...)
337}
338
339func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
340 opts := []openai.Option{
341 openai.WithAPIKey(apiKey),
342 }
343 if c.cfg.Options.Debug {
344 httpClient := log.NewHTTPClient()
345 opts = append(opts, openai.WithHTTPClient(httpClient))
346 }
347 if len(headers) > 0 {
348 opts = append(opts, openai.WithHeaders(headers))
349 }
350 if baseURL != "" {
351 opts = append(opts, openai.WithBaseURL(baseURL))
352 }
353 return openai.New(opts...)
354}
355
356func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) ai.Provider {
357 opts := []openrouter.Option{
358 openrouter.WithAPIKey(apiKey),
359 }
360 if c.cfg.Options.Debug {
361 httpClient := log.NewHTTPClient()
362 opts = append(opts, openrouter.WithHTTPClient(httpClient))
363 }
364 if len(headers) > 0 {
365 opts = append(opts, openrouter.WithHeaders(headers))
366 }
367 return openrouter.New(opts...)
368}
369
370func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
371 opts := []openaicompat.Option{
372 openaicompat.WithAPIKey(apiKey),
373 }
374 if c.cfg.Options.Debug {
375 httpClient := log.NewHTTPClient()
376 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
377 }
378 if len(headers) > 0 {
379 opts = append(opts, openaicompat.WithHeaders(headers))
380 }
381
382 return openaicompat.New(baseURL, opts...)
383}
384
385// TODO: add baseURL for google
386func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
387 opts := []google.Option{
388 google.WithAPIKey(apiKey),
389 }
390 if c.cfg.Options.Debug {
391 httpClient := log.NewHTTPClient()
392 opts = append(opts, google.WithHTTPClient(httpClient))
393 }
394 if len(headers) > 0 {
395 opts = append(opts, google.WithHeaders(headers))
396 }
397 return google.New(opts...)
398}
399
400func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
401 if model.Think {
402 return true
403 }
404
405 if model.ProviderOptions == nil {
406 return false
407 }
408
409 opts, err := anthropic.ParseOptions(model.ProviderOptions)
410 if err != nil {
411 return false
412 }
413 if opts.Thinking != nil {
414 return true
415 }
416 return false
417}
418
419func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (ai.Provider, error) {
420 headers := providerCfg.ExtraHeaders
421
422 // handle special headers for anthropic
423 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
424 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
425 }
426
427 // TODO: make sure we have
428 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
429 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
430 var provider ai.Provider
431 switch providerCfg.Type {
432 case openai.Name:
433 provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
434 case anthropic.Name:
435 provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
436 case openrouter.Name:
437 provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
438 case google.Name:
439 provider = c.buildGoogleProvider(baseURL, apiKey, headers)
440 case openaicompat.Name:
441 provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
442 default:
443 return nil, errors.New("provider type not supported")
444 }
445 return provider, nil
446}
447
448func (c *coordinator) Cancel(sessionID string) {
449 c.currentAgent.Cancel(sessionID)
450}
451
452func (c *coordinator) CancelAll() {
453 c.currentAgent.CancelAll()
454}
455
456func (c *coordinator) ClearQueue(sessionID string) {
457 c.currentAgent.ClearQueue(sessionID)
458}
459
460func (c *coordinator) IsBusy() bool {
461 return c.currentAgent.IsBusy()
462}
463
464func (c *coordinator) IsSessionBusy(sessionID string) bool {
465 return c.currentAgent.IsSessionBusy(sessionID)
466}
467
468func (c *coordinator) Model() Model {
469 return c.currentAgent.Model()
470}
471
472func (c *coordinator) UpdateModels() error {
473 // build the models again so we make sure we get the latest config
474 large, small, err := c.buildAgentModels()
475 if err != nil {
476 return err
477 }
478 c.currentAgent.SetModels(large, small)
479
480 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
481 if !ok {
482 return errors.New("coder agent not configured")
483 }
484
485 tools, err := c.buildTools(agentCfg)
486 if err != nil {
487 return err
488 }
489 c.currentAgent.SetTools(tools)
490 return nil
491}
492
493func (c *coordinator) QueuedPrompts(sessionID string) int {
494 return c.currentAgent.QueuedPrompts(sessionID)
495}
496
497func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
498 return c.currentAgent.Summarize(ctx, sessionID)
499}