1package agent
2
import (
	"cmp"
	"context"
	"encoding/json"
	"errors"
	"log/slog"
	"maps"
	"slices"
	"strings"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/agent/prompt"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/history"
	"github.com/charmbracelet/crush/internal/log"
	"github.com/charmbracelet/crush/internal/lsp"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/anthropic"
	"github.com/charmbracelet/fantasy/azure"
	"github.com/charmbracelet/fantasy/google"
	"github.com/charmbracelet/fantasy/openai"
	"github.com/charmbracelet/fantasy/openaicompat"
	"github.com/charmbracelet/fantasy/openrouter"
	"github.com/qjebbs/go-jsons"
)
32
33type Coordinator interface {
34 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
35 // SetMainAgent(string)
36 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error)
37 Cancel(sessionID string)
38 CancelAll()
39 IsSessionBusy(sessionID string) bool
40 IsBusy() bool
41 QueuedPrompts(sessionID string) int
42 ClearQueue(sessionID string)
43 Summarize(context.Context, string) error
44 Model() Model
45 UpdateModels() error
46}
47
// coordinator is the Coordinator implementation returned by NewCoordinator.
// It keeps the configuration and services needed to build agents and their
// tools, and forwards every session operation to currentAgent.
type coordinator struct {
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]

	// currentAgent receives all Coordinator calls. agents indexes built
	// agents by name; NewCoordinator currently registers only
	// config.AgentCoder.
	currentAgent SessionAgent
	agents       map[string]SessionAgent
}
59
60func NewCoordinator(
61 cfg *config.Config,
62 sessions session.Service,
63 messages message.Service,
64 permissions permission.Service,
65 history history.Service,
66 lspClients *csync.Map[string, *lsp.Client],
67) (Coordinator, error) {
68 c := &coordinator{
69 cfg: cfg,
70 sessions: sessions,
71 messages: messages,
72 permissions: permissions,
73 history: history,
74 lspClients: lspClients,
75 agents: make(map[string]SessionAgent),
76 }
77
78 agentCfg, ok := cfg.Agents[config.AgentCoder]
79 if !ok {
80 return nil, errors.New("coder agent not configured")
81 }
82
83 // TODO: make this dynamic when we support multiple agents
84 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
85 if err != nil {
86 return nil, err
87 }
88
89 agent, err := c.buildAgent(prompt, agentCfg)
90 if err != nil {
91 return nil, err
92 }
93 c.currentAgent = agent
94 c.agents[config.AgentCoder] = agent
95 return c, nil
96}
97
98// Run implements Coordinator.
99func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error) {
100 model := c.currentAgent.Model()
101 maxTokens := model.CatwalkCfg.DefaultMaxTokens
102 if model.ModelCfg.MaxTokens != 0 {
103 maxTokens = model.ModelCfg.MaxTokens
104 }
105
106 if !model.CatwalkCfg.SupportsImages && attachments != nil {
107 attachments = nil
108 }
109
110 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model)
111
112 return c.currentAgent.Run(ctx, SessionAgentCall{
113 SessionID: sessionID,
114 Prompt: prompt,
115 Attachments: attachments,
116 MaxOutputTokens: maxTokens,
117 ProviderOptions: mergedOptions,
118 Temperature: temp,
119 TopP: topP,
120 TopK: topK,
121 FrequencyPenalty: freqPenalty,
122 PresencePenalty: presPenalty,
123 })
124}
125
126func getProviderOptions(model Model) ai.ProviderOptions {
127 options := ai.ProviderOptions{}
128
129 cfgOpts := "{}"
130 catwalkOpts := "{}"
131
132 if model.ModelCfg.ProviderOptions != nil {
133 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
134 if err == nil {
135 cfgOpts = string(data)
136 }
137 }
138
139 if model.CatwalkCfg.Options.ProviderOptions != nil {
140 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
141 if err == nil {
142 catwalkOpts = string(data)
143 }
144 }
145
146 got, err := jsons.Merge([]string{catwalkOpts, cfgOpts})
147 if err != nil {
148 slog.Error("Could not merge call config", "err", err)
149 return options
150 }
151
152 mergedOptions := make(map[string]any)
153
154 err = json.Unmarshal([]byte(got), &mergedOptions)
155 if err != nil {
156 slog.Error("Could not create config for call", "err", err)
157 return options
158 }
159
160 switch model.Model.Provider() {
161 case openai.Name:
162 parsed, err := openai.ParseOptions(mergedOptions)
163 if err == nil {
164 options[openai.Name] = parsed
165 }
166 case anthropic.Name:
167 parsed, err := anthropic.ParseOptions(mergedOptions)
168 if err == nil {
169 options[anthropic.Name] = parsed
170 }
171 case openrouter.Name:
172 parsed, err := openrouter.ParseOptions(mergedOptions)
173 if err == nil {
174 options[openrouter.Name] = parsed
175 }
176 case google.Name:
177 parsed, err := google.ParseOptions(mergedOptions)
178 if err == nil {
179 options[google.Name] = parsed
180 }
181 case openaicompat.Name:
182 parsed, err := openaicompat.ParseOptions(mergedOptions)
183 if err == nil {
184 options[openaicompat.Name] = parsed
185 }
186 }
187
188 return options
189}
190
191func mergeCallOptions(model Model) (ai.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
192 modelOptions := getProviderOptions(model)
193 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
194 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
195 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
196 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
197 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
198 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
199}
200
201func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
202 large, small, err := c.buildAgentModels()
203 if err != nil {
204 return nil, err
205 }
206
207 systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
208 if err != nil {
209 return nil, err
210 }
211
212 tools, err := c.buildTools(agent)
213 if err != nil {
214 return nil, err
215 }
216 return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
217}
218
219func (c *coordinator) buildTools(agent config.Agent) ([]ai.AgentTool, error) {
220 var allTools []ai.AgentTool
221 if slices.Contains(agent.AllowedTools, AgentToolName) {
222 agentTool, err := c.agentTool()
223 if err != nil {
224 return nil, err
225 }
226 allTools = append(allTools, agentTool)
227 }
228
229 allTools = append(allTools,
230 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
231 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
232 tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
233 tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
234 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
235 tools.NewGlobTool(c.cfg.WorkingDir()),
236 tools.NewGrepTool(c.cfg.WorkingDir()),
237 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
238 tools.NewSourcegraphTool(nil),
239 tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
240 tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
241 )
242
243 var filteredTools []ai.AgentTool
244 for _, tool := range allTools {
245 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
246 filteredTools = append(filteredTools, tool)
247 }
248 }
249
250 mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
251
252 for _, mcpTool := range mcpTools {
253 if agent.AllowedMCP == nil {
254 // No MCP restrictions
255 filteredTools = append(filteredTools, mcpTool)
256 } else if len(agent.AllowedMCP) == 0 {
257 // no mcps allowed
258 break
259 }
260
261 for mcp, tools := range agent.AllowedMCP {
262 if mcp == mcpTool.MCP() {
263 if len(tools) == 0 {
264 filteredTools = append(filteredTools, mcpTool)
265 }
266 for _, t := range tools {
267 if t == mcpTool.MCPToolName() {
268 filteredTools = append(filteredTools, mcpTool)
269 }
270 }
271 break
272 }
273 }
274 }
275
276 return filteredTools, nil
277}
278
279// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
280func (c *coordinator) buildAgentModels() (Model, Model, error) {
281 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
282 if !ok {
283 return Model{}, Model{}, errors.New("large model not selected")
284 }
285 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
286 if !ok {
287 return Model{}, Model{}, errors.New("small model not selected")
288 }
289
290 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
291 if !ok {
292 return Model{}, Model{}, errors.New("large model provider not configured")
293 }
294
295 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
296 if err != nil {
297 return Model{}, Model{}, err
298 }
299
300 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
301 if !ok {
302 return Model{}, Model{}, errors.New("large model provider not configured")
303 }
304
305 smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
306 if err != nil {
307 return Model{}, Model{}, err
308 }
309
310 var largeCatwalkModel *catwalk.Model
311 var smallCatwalkModel *catwalk.Model
312
313 for _, m := range largeProviderCfg.Models {
314 if m.ID == largeModelCfg.Model {
315 largeCatwalkModel = &m
316 }
317 }
318 for _, m := range smallProviderCfg.Models {
319 if m.ID == smallModelCfg.Model {
320 smallCatwalkModel = &m
321 }
322 }
323
324 if largeCatwalkModel == nil {
325 return Model{}, Model{}, errors.New("large model not found in provider config")
326 }
327
328 if smallCatwalkModel == nil {
329 return Model{}, Model{}, errors.New("snall model not found in provider config")
330 }
331
332 largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
333 if err != nil {
334 return Model{}, Model{}, err
335 }
336 smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
337 if err != nil {
338 return Model{}, Model{}, err
339 }
340
341 return Model{
342 Model: largeModel,
343 CatwalkCfg: *largeCatwalkModel,
344 ModelCfg: largeModelCfg,
345 }, Model{
346 Model: smallModel,
347 CatwalkCfg: *smallCatwalkModel,
348 ModelCfg: smallModelCfg,
349 }, nil
350}
351
352func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
353 hasBearerAuth := false
354 for key := range headers {
355 if strings.ToLower(key) == "authorization" {
356 hasBearerAuth = true
357 break
358 }
359 }
360 if hasBearerAuth {
361 apiKey = "" // clear apiKey to avoid using X-Api-Key header
362 }
363
364 var opts []anthropic.Option
365
366 if apiKey != "" {
367 // Use standard X-Api-Key header
368 opts = append(opts, anthropic.WithAPIKey(apiKey))
369 }
370
371 if len(headers) > 0 {
372 opts = append(opts, anthropic.WithHeaders(headers))
373 }
374
375 if baseURL != "" {
376 opts = append(opts, anthropic.WithBaseURL(baseURL))
377 }
378
379 if c.cfg.Options.Debug {
380 httpClient := log.NewHTTPClient()
381 opts = append(opts, anthropic.WithHTTPClient(httpClient))
382 }
383
384 return anthropic.New(opts...)
385}
386
387func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
388 opts := []openai.Option{
389 openai.WithAPIKey(apiKey),
390 }
391 if c.cfg.Options.Debug {
392 httpClient := log.NewHTTPClient()
393 opts = append(opts, openai.WithHTTPClient(httpClient))
394 }
395 if len(headers) > 0 {
396 opts = append(opts, openai.WithHeaders(headers))
397 }
398 if baseURL != "" {
399 opts = append(opts, openai.WithBaseURL(baseURL))
400 }
401 return openai.New(opts...)
402}
403
404func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) ai.Provider {
405 opts := []openrouter.Option{
406 openrouter.WithAPIKey(apiKey),
407 }
408 if c.cfg.Options.Debug {
409 httpClient := log.NewHTTPClient()
410 opts = append(opts, openrouter.WithHTTPClient(httpClient))
411 }
412 if len(headers) > 0 {
413 opts = append(opts, openrouter.WithHeaders(headers))
414 }
415 return openrouter.New(opts...)
416}
417
418func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
419 opts := []openaicompat.Option{
420 openaicompat.WithBaseURL(baseURL),
421 openaicompat.WithAPIKey(apiKey),
422 }
423 if c.cfg.Options.Debug {
424 httpClient := log.NewHTTPClient()
425 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
426 }
427 if len(headers) > 0 {
428 opts = append(opts, openaicompat.WithHeaders(headers))
429 }
430
431 return openaicompat.New(opts...)
432}
433
434func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) ai.Provider {
435 opts := []azure.Option{
436 azure.WithBaseURL(baseURL),
437 azure.WithAPIKey(apiKey),
438 }
439 if c.cfg.Options.Debug {
440 httpClient := log.NewHTTPClient()
441 opts = append(opts, azure.WithHTTPClient(httpClient))
442 }
443 if options == nil {
444 options = make(map[string]string)
445 }
446 if apiVersion, ok := options["apiVersion"]; ok {
447 opts = append(opts, azure.WithAPIVersion(apiVersion))
448 }
449 if len(headers) > 0 {
450 opts = append(opts, azure.WithHeaders(headers))
451 }
452
453 return azure.New(opts...)
454}
455
456// TODO: add baseURL for google
457func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
458 opts := []google.Option{
459 google.WithAPIKey(apiKey),
460 }
461 if c.cfg.Options.Debug {
462 httpClient := log.NewHTTPClient()
463 opts = append(opts, google.WithHTTPClient(httpClient))
464 }
465 if len(headers) > 0 {
466 opts = append(opts, google.WithHeaders(headers))
467 }
468 return google.New(opts...)
469}
470
471func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
472 if model.Think {
473 return true
474 }
475
476 if model.ProviderOptions == nil {
477 return false
478 }
479
480 opts, err := anthropic.ParseOptions(model.ProviderOptions)
481 if err != nil {
482 return false
483 }
484 if opts.Thinking != nil {
485 return true
486 }
487 return false
488}
489
490func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (ai.Provider, error) {
491 headers := providerCfg.ExtraHeaders
492
493 // handle special headers for anthropic
494 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
495 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
496 }
497
498 // TODO: make sure we have
499 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
500 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
501 var provider ai.Provider
502 switch providerCfg.Type {
503 case openai.Name:
504 provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
505 case anthropic.Name:
506 provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
507 case openrouter.Name:
508 provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
509 case azure.Name:
510 provider = c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
511 case google.Name:
512 provider = c.buildGoogleProvider(baseURL, apiKey, headers)
513 case openaicompat.Name:
514 provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
515 default:
516 return nil, errors.New("provider type not supported")
517 }
518 return provider, nil
519}
520
521func (c *coordinator) Cancel(sessionID string) {
522 c.currentAgent.Cancel(sessionID)
523}
524
525func (c *coordinator) CancelAll() {
526 c.currentAgent.CancelAll()
527}
528
529func (c *coordinator) ClearQueue(sessionID string) {
530 c.currentAgent.ClearQueue(sessionID)
531}
532
533func (c *coordinator) IsBusy() bool {
534 return c.currentAgent.IsBusy()
535}
536
537func (c *coordinator) IsSessionBusy(sessionID string) bool {
538 return c.currentAgent.IsSessionBusy(sessionID)
539}
540
541func (c *coordinator) Model() Model {
542 return c.currentAgent.Model()
543}
544
545func (c *coordinator) UpdateModels() error {
546 // build the models again so we make sure we get the latest config
547 large, small, err := c.buildAgentModels()
548 if err != nil {
549 return err
550 }
551 c.currentAgent.SetModels(large, small)
552
553 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
554 if !ok {
555 return errors.New("coder agent not configured")
556 }
557
558 tools, err := c.buildTools(agentCfg)
559 if err != nil {
560 return err
561 }
562 c.currentAgent.SetTools(tools)
563 return nil
564}
565
566func (c *coordinator) QueuedPrompts(sessionID string) int {
567 return c.currentAgent.QueuedPrompts(sessionID)
568}
569
570func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
571 return c.currentAgent.Summarize(ctx, sessionID)
572}