1package agent
2
import (
	"bytes"
	"cmp"
	"context"
	"encoding/json"
	"errors"
	"io"
	"log/slog"
	"maps"
	"slices"
	"strings"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/agent/prompt"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/history"
	"github.com/charmbracelet/crush/internal/log"
	"github.com/charmbracelet/crush/internal/lsp"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/anthropic"
	"github.com/charmbracelet/fantasy/azure"
	"github.com/charmbracelet/fantasy/google"
	"github.com/charmbracelet/fantasy/openai"
	"github.com/charmbracelet/fantasy/openaicompat"
	"github.com/charmbracelet/fantasy/openrouter"
	"github.com/qjebbs/go-jsons"
)
34
35type Coordinator interface {
36 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
37 // SetMainAgent(string)
38 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error)
39 Cancel(sessionID string)
40 CancelAll()
41 IsSessionBusy(sessionID string) bool
42 IsBusy() bool
43 QueuedPrompts(sessionID string) int
44 ClearQueue(sessionID string)
45 Summarize(context.Context, string) error
46 Model() Model
47 UpdateModels() error
48}
49
50type coordinator struct {
51 cfg *config.Config
52 sessions session.Service
53 messages message.Service
54 permissions permission.Service
55 history history.Service
56 lspClients *csync.Map[string, *lsp.Client]
57
58 currentAgent SessionAgent
59 agents map[string]SessionAgent
60}
61
62func NewCoordinator(
63 cfg *config.Config,
64 sessions session.Service,
65 messages message.Service,
66 permissions permission.Service,
67 history history.Service,
68 lspClients *csync.Map[string, *lsp.Client],
69) (Coordinator, error) {
70 c := &coordinator{
71 cfg: cfg,
72 sessions: sessions,
73 messages: messages,
74 permissions: permissions,
75 history: history,
76 lspClients: lspClients,
77 agents: make(map[string]SessionAgent),
78 }
79
80 agentCfg, ok := cfg.Agents[config.AgentCoder]
81 if !ok {
82 return nil, errors.New("coder agent not configured")
83 }
84
85 // TODO: make this dynamic when we support multiple agents
86 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
87 if err != nil {
88 return nil, err
89 }
90
91 agent, err := c.buildAgent(prompt, agentCfg)
92 if err != nil {
93 return nil, err
94 }
95 c.currentAgent = agent
96 c.agents[config.AgentCoder] = agent
97 return c, nil
98}
99
100// Run implements Coordinator.
101func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error) {
102 model := c.currentAgent.Model()
103 maxTokens := model.CatwalkCfg.DefaultMaxTokens
104 if model.ModelCfg.MaxTokens != 0 {
105 maxTokens = model.ModelCfg.MaxTokens
106 }
107
108 if !model.CatwalkCfg.SupportsImages && attachments != nil {
109 attachments = nil
110 }
111
112 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model)
113
114 return c.currentAgent.Run(ctx, SessionAgentCall{
115 SessionID: sessionID,
116 Prompt: prompt,
117 Attachments: attachments,
118 MaxOutputTokens: maxTokens,
119 ProviderOptions: mergedOptions,
120 Temperature: temp,
121 TopP: topP,
122 TopK: topK,
123 FrequencyPenalty: freqPenalty,
124 PresencePenalty: presPenalty,
125 })
126}
127
128func getProviderOptions(model Model) ai.ProviderOptions {
129 options := ai.ProviderOptions{}
130
131 cfgOpts := []byte("{}")
132 catwalkOpts := []byte("{}")
133
134 if model.ModelCfg.ProviderOptions != nil {
135 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
136 if err == nil {
137 cfgOpts = data
138 }
139 }
140
141 if model.CatwalkCfg.Options.ProviderOptions != nil {
142 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
143 if err == nil {
144 catwalkOpts = data
145 }
146 }
147
148 readers := []io.Reader{
149 bytes.NewReader(catwalkOpts),
150 bytes.NewReader(cfgOpts),
151 }
152
153 got, err := jsons.Merge(readers)
154 if err != nil {
155 slog.Error("Could not merge call config", "err", err)
156 return options
157 }
158
159 mergedOptions := make(map[string]any)
160
161 err = json.Unmarshal([]byte(got), &mergedOptions)
162 if err != nil {
163 slog.Error("Could not create config for call", "err", err)
164 return options
165 }
166
167 switch model.Model.Provider() {
168 case openai.Name:
169 parsed, err := openai.ParseOptions(mergedOptions)
170 if err == nil {
171 options[openai.Name] = parsed
172 }
173 case anthropic.Name:
174 parsed, err := anthropic.ParseOptions(mergedOptions)
175 if err == nil {
176 options[anthropic.Name] = parsed
177 }
178 case openrouter.Name:
179 parsed, err := openrouter.ParseOptions(mergedOptions)
180 if err == nil {
181 options[openrouter.Name] = parsed
182 }
183 case google.Name:
184 parsed, err := google.ParseOptions(mergedOptions)
185 if err == nil {
186 options[google.Name] = parsed
187 }
188 case openaicompat.Name:
189 parsed, err := openaicompat.ParseOptions(mergedOptions)
190 if err == nil {
191 options[openaicompat.Name] = parsed
192 }
193 }
194
195 return options
196}
197
198func mergeCallOptions(model Model) (ai.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
199 modelOptions := getProviderOptions(model)
200 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
201 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
202 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
203 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
204 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
205 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
206}
207
208func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
209 large, small, err := c.buildAgentModels()
210 if err != nil {
211 return nil, err
212 }
213
214 systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
215 if err != nil {
216 return nil, err
217 }
218
219 tools, err := c.buildTools(agent)
220 if err != nil {
221 return nil, err
222 }
223 return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
224}
225
226func (c *coordinator) buildTools(agent config.Agent) ([]ai.AgentTool, error) {
227 var allTools []ai.AgentTool
228 if slices.Contains(agent.AllowedTools, AgentToolName) {
229 agentTool, err := c.agentTool()
230 if err != nil {
231 return nil, err
232 }
233 allTools = append(allTools, agentTool)
234 }
235
236 allTools = append(allTools,
237 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
238 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
239 tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
240 tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
241 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
242 tools.NewGlobTool(c.cfg.WorkingDir()),
243 tools.NewGrepTool(c.cfg.WorkingDir()),
244 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
245 tools.NewSourcegraphTool(nil),
246 tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
247 tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
248 )
249
250 var filteredTools []ai.AgentTool
251 for _, tool := range allTools {
252 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
253 filteredTools = append(filteredTools, tool)
254 }
255 }
256
257 mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
258
259 for _, mcpTool := range mcpTools {
260 if agent.AllowedMCP == nil {
261 // No MCP restrictions
262 filteredTools = append(filteredTools, mcpTool)
263 } else if len(agent.AllowedMCP) == 0 {
264 // no mcps allowed
265 break
266 }
267
268 for mcp, tools := range agent.AllowedMCP {
269 if mcp == mcpTool.MCP() {
270 if len(tools) == 0 {
271 filteredTools = append(filteredTools, mcpTool)
272 }
273 for _, t := range tools {
274 if t == mcpTool.MCPToolName() {
275 filteredTools = append(filteredTools, mcpTool)
276 }
277 }
278 break
279 }
280 }
281 }
282
283 return filteredTools, nil
284}
285
286// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
287func (c *coordinator) buildAgentModels() (Model, Model, error) {
288 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
289 if !ok {
290 return Model{}, Model{}, errors.New("large model not selected")
291 }
292 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
293 if !ok {
294 return Model{}, Model{}, errors.New("small model not selected")
295 }
296
297 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
298 if !ok {
299 return Model{}, Model{}, errors.New("large model provider not configured")
300 }
301
302 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
303 if err != nil {
304 return Model{}, Model{}, err
305 }
306
307 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
308 if !ok {
309 return Model{}, Model{}, errors.New("large model provider not configured")
310 }
311
312 smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
313 if err != nil {
314 return Model{}, Model{}, err
315 }
316
317 var largeCatwalkModel *catwalk.Model
318 var smallCatwalkModel *catwalk.Model
319
320 for _, m := range largeProviderCfg.Models {
321 if m.ID == largeModelCfg.Model {
322 largeCatwalkModel = &m
323 }
324 }
325 for _, m := range smallProviderCfg.Models {
326 if m.ID == smallModelCfg.Model {
327 smallCatwalkModel = &m
328 }
329 }
330
331 if largeCatwalkModel == nil {
332 return Model{}, Model{}, errors.New("large model not found in provider config")
333 }
334
335 if smallCatwalkModel == nil {
336 return Model{}, Model{}, errors.New("snall model not found in provider config")
337 }
338
339 largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
340 if err != nil {
341 return Model{}, Model{}, err
342 }
343 smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
344 if err != nil {
345 return Model{}, Model{}, err
346 }
347
348 return Model{
349 Model: largeModel,
350 CatwalkCfg: *largeCatwalkModel,
351 ModelCfg: largeModelCfg,
352 }, Model{
353 Model: smallModel,
354 CatwalkCfg: *smallCatwalkModel,
355 ModelCfg: smallModelCfg,
356 }, nil
357}
358
359func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
360 hasBearerAuth := false
361 for key := range headers {
362 if strings.ToLower(key) == "authorization" {
363 hasBearerAuth = true
364 break
365 }
366 }
367 if hasBearerAuth {
368 apiKey = "" // clear apiKey to avoid using X-Api-Key header
369 }
370
371 var opts []anthropic.Option
372
373 if apiKey != "" {
374 // Use standard X-Api-Key header
375 opts = append(opts, anthropic.WithAPIKey(apiKey))
376 }
377
378 if len(headers) > 0 {
379 opts = append(opts, anthropic.WithHeaders(headers))
380 }
381
382 if baseURL != "" {
383 opts = append(opts, anthropic.WithBaseURL(baseURL))
384 }
385
386 if c.cfg.Options.Debug {
387 httpClient := log.NewHTTPClient()
388 opts = append(opts, anthropic.WithHTTPClient(httpClient))
389 }
390
391 return anthropic.New(opts...)
392}
393
394func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
395 opts := []openai.Option{
396 openai.WithAPIKey(apiKey),
397 }
398 if c.cfg.Options.Debug {
399 httpClient := log.NewHTTPClient()
400 opts = append(opts, openai.WithHTTPClient(httpClient))
401 }
402 if len(headers) > 0 {
403 opts = append(opts, openai.WithHeaders(headers))
404 }
405 if baseURL != "" {
406 opts = append(opts, openai.WithBaseURL(baseURL))
407 }
408 return openai.New(opts...)
409}
410
411func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) ai.Provider {
412 opts := []openrouter.Option{
413 openrouter.WithAPIKey(apiKey),
414 }
415 if c.cfg.Options.Debug {
416 httpClient := log.NewHTTPClient()
417 opts = append(opts, openrouter.WithHTTPClient(httpClient))
418 }
419 if len(headers) > 0 {
420 opts = append(opts, openrouter.WithHeaders(headers))
421 }
422 return openrouter.New(opts...)
423}
424
425func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
426 opts := []openaicompat.Option{
427 openaicompat.WithBaseURL(baseURL),
428 openaicompat.WithAPIKey(apiKey),
429 }
430 if c.cfg.Options.Debug {
431 httpClient := log.NewHTTPClient()
432 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
433 }
434 if len(headers) > 0 {
435 opts = append(opts, openaicompat.WithHeaders(headers))
436 }
437
438 return openaicompat.New(opts...)
439}
440
441func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) ai.Provider {
442 opts := []azure.Option{
443 azure.WithBaseURL(baseURL),
444 azure.WithAPIKey(apiKey),
445 }
446 if c.cfg.Options.Debug {
447 httpClient := log.NewHTTPClient()
448 opts = append(opts, azure.WithHTTPClient(httpClient))
449 }
450 if options == nil {
451 options = make(map[string]string)
452 }
453 if apiVersion, ok := options["apiVersion"]; ok {
454 opts = append(opts, azure.WithAPIVersion(apiVersion))
455 }
456 if len(headers) > 0 {
457 opts = append(opts, azure.WithHeaders(headers))
458 }
459
460 return azure.New(opts...)
461}
462
463// TODO: add baseURL for google
464func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
465 opts := []google.Option{
466 google.WithAPIKey(apiKey),
467 }
468 if c.cfg.Options.Debug {
469 httpClient := log.NewHTTPClient()
470 opts = append(opts, google.WithHTTPClient(httpClient))
471 }
472 if len(headers) > 0 {
473 opts = append(opts, google.WithHeaders(headers))
474 }
475 return google.New(opts...)
476}
477
478func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
479 if model.Think {
480 return true
481 }
482
483 if model.ProviderOptions == nil {
484 return false
485 }
486
487 opts, err := anthropic.ParseOptions(model.ProviderOptions)
488 if err != nil {
489 return false
490 }
491 if opts.Thinking != nil {
492 return true
493 }
494 return false
495}
496
497func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (ai.Provider, error) {
498 headers := providerCfg.ExtraHeaders
499
500 // handle special headers for anthropic
501 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
502 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
503 }
504
505 // TODO: make sure we have
506 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
507 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
508 var provider ai.Provider
509 switch providerCfg.Type {
510 case openai.Name:
511 provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
512 case anthropic.Name:
513 provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
514 case openrouter.Name:
515 provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
516 case azure.Name:
517 provider = c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
518 case google.Name:
519 provider = c.buildGoogleProvider(baseURL, apiKey, headers)
520 case openaicompat.Name:
521 provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
522 default:
523 return nil, errors.New("provider type not supported")
524 }
525 return provider, nil
526}
527
528func (c *coordinator) Cancel(sessionID string) {
529 c.currentAgent.Cancel(sessionID)
530}
531
532func (c *coordinator) CancelAll() {
533 c.currentAgent.CancelAll()
534}
535
536func (c *coordinator) ClearQueue(sessionID string) {
537 c.currentAgent.ClearQueue(sessionID)
538}
539
540func (c *coordinator) IsBusy() bool {
541 return c.currentAgent.IsBusy()
542}
543
544func (c *coordinator) IsSessionBusy(sessionID string) bool {
545 return c.currentAgent.IsSessionBusy(sessionID)
546}
547
548func (c *coordinator) Model() Model {
549 return c.currentAgent.Model()
550}
551
552func (c *coordinator) UpdateModels() error {
553 // build the models again so we make sure we get the latest config
554 large, small, err := c.buildAgentModels()
555 if err != nil {
556 return err
557 }
558 c.currentAgent.SetModels(large, small)
559
560 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
561 if !ok {
562 return errors.New("coder agent not configured")
563 }
564
565 tools, err := c.buildTools(agentCfg)
566 if err != nil {
567 return err
568 }
569 c.currentAgent.SetTools(tools)
570 return nil
571}
572
573func (c *coordinator) QueuedPrompts(sessionID string) int {
574 return c.currentAgent.QueuedPrompts(sessionID)
575}
576
577func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
578 return c.currentAgent.Summarize(ctx, sessionID)
579}