1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "io"
10 "log/slog"
11 "slices"
12 "strings"
13
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15 "github.com/charmbracelet/crush/internal/agent/prompt"
16 "github.com/charmbracelet/crush/internal/agent/tools"
17 "github.com/charmbracelet/crush/internal/config"
18 "github.com/charmbracelet/crush/internal/csync"
19 "github.com/charmbracelet/crush/internal/history"
20 "github.com/charmbracelet/crush/internal/log"
21 "github.com/charmbracelet/crush/internal/lsp"
22 "github.com/charmbracelet/crush/internal/message"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/fantasy/ai"
26 "github.com/charmbracelet/fantasy/anthropic"
27 "github.com/charmbracelet/fantasy/azure"
28 "github.com/charmbracelet/fantasy/google"
29 "github.com/charmbracelet/fantasy/openai"
30 "github.com/charmbracelet/fantasy/openaicompat"
31 "github.com/charmbracelet/fantasy/openrouter"
32 "github.com/qjebbs/go-jsons"
33)
34
35type Coordinator interface {
36 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
37 // SetMainAgent(string)
38 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error)
39 Cancel(sessionID string)
40 CancelAll()
41 IsSessionBusy(sessionID string) bool
42 IsBusy() bool
43 QueuedPrompts(sessionID string) int
44 ClearQueue(sessionID string)
45 Summarize(context.Context, string) error
46 Model() Model
47 UpdateModels() error
48}
49
// coordinator is the single-agent implementation of Coordinator. Every
// Coordinator method is served by currentAgent.
type coordinator struct {
	// Shared services and configuration handed to the agents and tools it builds.
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]

	// currentAgent receives every request; NewCoordinator sets it to the
	// coder agent.
	currentAgent SessionAgent
	// agents holds the built agents keyed by agent name; today it only
	// contains config.AgentCoder.
	agents map[string]SessionAgent
}
61
62func NewCoordinator(
63 cfg *config.Config,
64 sessions session.Service,
65 messages message.Service,
66 permissions permission.Service,
67 history history.Service,
68 lspClients *csync.Map[string, *lsp.Client],
69) (Coordinator, error) {
70 c := &coordinator{
71 cfg: cfg,
72 sessions: sessions,
73 messages: messages,
74 permissions: permissions,
75 history: history,
76 lspClients: lspClients,
77 agents: make(map[string]SessionAgent),
78 }
79
80 agentCfg, ok := cfg.Agents[config.AgentCoder]
81 if !ok {
82 return nil, errors.New("coder agent not configured")
83 }
84
85 // TODO: make this dynamic when we support multiple agents
86 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
87 if err != nil {
88 return nil, err
89 }
90
91 agent, err := c.buildAgent(prompt, agentCfg)
92 if err != nil {
93 return nil, err
94 }
95 c.currentAgent = agent
96 c.agents[config.AgentCoder] = agent
97 return c, nil
98}
99
100// Run implements Coordinator.
101func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*ai.AgentResult, error) {
102 model := c.currentAgent.Model()
103 maxTokens := model.CatwalkCfg.DefaultMaxTokens
104 if model.ModelCfg.MaxTokens != 0 {
105 maxTokens = model.ModelCfg.MaxTokens
106 }
107
108 if !model.CatwalkCfg.SupportsImages && attachments != nil {
109 attachments = nil
110 }
111
112 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
113 if !ok {
114 return nil, errors.New("model provider not configured")
115 }
116
117 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg.Type)
118
119 return c.currentAgent.Run(ctx, SessionAgentCall{
120 SessionID: sessionID,
121 Prompt: prompt,
122 Attachments: attachments,
123 MaxOutputTokens: maxTokens,
124 ProviderOptions: mergedOptions,
125 Temperature: temp,
126 TopP: topP,
127 TopK: topK,
128 FrequencyPenalty: freqPenalty,
129 PresencePenalty: presPenalty,
130 })
131}
132
133func getProviderOptions(model Model, tp catwalk.Type) ai.ProviderOptions {
134 options := ai.ProviderOptions{}
135
136 cfgOpts := []byte("{}")
137 catwalkOpts := []byte("{}")
138
139 if model.ModelCfg.ProviderOptions != nil {
140 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
141 if err == nil {
142 cfgOpts = data
143 }
144 }
145
146 if model.CatwalkCfg.Options.ProviderOptions != nil {
147 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
148 if err == nil {
149 catwalkOpts = data
150 }
151 }
152
153 readers := []io.Reader{
154 bytes.NewReader(catwalkOpts),
155 bytes.NewReader(cfgOpts),
156 }
157
158 got, err := jsons.Merge(readers)
159 if err != nil {
160 slog.Error("Could not merge call config", "err", err)
161 return options
162 }
163
164 mergedOptions := make(map[string]any)
165
166 err = json.Unmarshal([]byte(got), &mergedOptions)
167 if err != nil {
168 slog.Error("Could not create config for call", "err", err)
169 return options
170 }
171
172 switch tp {
173 case openai.Name:
174 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
175 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
176 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
177 }
178 parsed, err := openai.ParseOptions(mergedOptions)
179 if err == nil {
180 options[openai.Name] = parsed
181 }
182 case anthropic.Name:
183 _, hasThink := mergedOptions["thinking"]
184 if !hasThink && model.ModelCfg.Think {
185 mergedOptions["thinking"] = map[string]any{
186 // TODO: kujtim see if we need to make this dynamic
187 "budget_tokens": 2000,
188 }
189 }
190 parsed, err := anthropic.ParseOptions(mergedOptions)
191 if err == nil {
192 options[anthropic.Name] = parsed
193 }
194
195 case openrouter.Name:
196 _, hasReasoning := mergedOptions["reasoning"]
197 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
198 mergedOptions["reasoning"] = map[string]any{
199 "enabled": true,
200 "effort": model.ModelCfg.ReasoningEffort,
201 }
202 }
203 parsed, err := openrouter.ParseOptions(mergedOptions)
204 if err == nil {
205 options[openrouter.Name] = parsed
206 }
207 case google.Name:
208 parsed, err := google.ParseOptions(mergedOptions)
209 if err == nil {
210 options[google.Name] = parsed
211 }
212 case azure.Name:
213 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
214 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
215 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
216 }
217 // azure uses the same options as openaicompat
218 parsed, err := openaicompat.ParseOptions(mergedOptions)
219 if err == nil {
220 options[azure.Name] = parsed
221 }
222 case openaicompat.Name:
223 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
224 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
225 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
226 }
227 parsed, err := openaicompat.ParseOptions(mergedOptions)
228 if err == nil {
229 options[openaicompat.Name] = parsed
230 }
231 }
232
233 return options
234}
235
236func mergeCallOptions(model Model, tp catwalk.Type) (ai.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
237 modelOptions := getProviderOptions(model, tp)
238 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
239 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
240 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
241 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
242 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
243 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
244}
245
246func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
247 large, small, err := c.buildAgentModels()
248 if err != nil {
249 return nil, err
250 }
251
252 systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
253 if err != nil {
254 return nil, err
255 }
256
257 tools, err := c.buildTools(agent)
258 if err != nil {
259 return nil, err
260 }
261 return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
262}
263
264func (c *coordinator) buildTools(agent config.Agent) ([]ai.AgentTool, error) {
265 var allTools []ai.AgentTool
266 if slices.Contains(agent.AllowedTools, AgentToolName) {
267 agentTool, err := c.agentTool()
268 if err != nil {
269 return nil, err
270 }
271 allTools = append(allTools, agentTool)
272 }
273
274 allTools = append(allTools,
275 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
276 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
277 tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
278 tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
279 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
280 tools.NewGlobTool(c.cfg.WorkingDir()),
281 tools.NewGrepTool(c.cfg.WorkingDir()),
282 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
283 tools.NewSourcegraphTool(nil),
284 tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
285 tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
286 )
287
288 var filteredTools []ai.AgentTool
289 for _, tool := range allTools {
290 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
291 filteredTools = append(filteredTools, tool)
292 }
293 }
294
295 mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
296
297 for _, mcpTool := range mcpTools {
298 if agent.AllowedMCP == nil {
299 // No MCP restrictions
300 filteredTools = append(filteredTools, mcpTool)
301 } else if len(agent.AllowedMCP) == 0 {
302 // no mcps allowed
303 break
304 }
305
306 for mcp, tools := range agent.AllowedMCP {
307 if mcp == mcpTool.MCP() {
308 if len(tools) == 0 {
309 filteredTools = append(filteredTools, mcpTool)
310 }
311 for _, t := range tools {
312 if t == mcpTool.MCPToolName() {
313 filteredTools = append(filteredTools, mcpTool)
314 }
315 }
316 break
317 }
318 }
319 }
320
321 return filteredTools, nil
322}
323
324// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
325func (c *coordinator) buildAgentModels() (Model, Model, error) {
326 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
327 if !ok {
328 return Model{}, Model{}, errors.New("large model not selected")
329 }
330 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
331 if !ok {
332 return Model{}, Model{}, errors.New("small model not selected")
333 }
334
335 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
336 if !ok {
337 return Model{}, Model{}, errors.New("large model provider not configured")
338 }
339
340 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
341 if err != nil {
342 return Model{}, Model{}, err
343 }
344
345 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
346 if !ok {
347 return Model{}, Model{}, errors.New("large model provider not configured")
348 }
349
350 smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
351 if err != nil {
352 return Model{}, Model{}, err
353 }
354
355 var largeCatwalkModel *catwalk.Model
356 var smallCatwalkModel *catwalk.Model
357
358 for _, m := range largeProviderCfg.Models {
359 if m.ID == largeModelCfg.Model {
360 largeCatwalkModel = &m
361 }
362 }
363 for _, m := range smallProviderCfg.Models {
364 if m.ID == smallModelCfg.Model {
365 smallCatwalkModel = &m
366 }
367 }
368
369 if largeCatwalkModel == nil {
370 return Model{}, Model{}, errors.New("large model not found in provider config")
371 }
372
373 if smallCatwalkModel == nil {
374 return Model{}, Model{}, errors.New("snall model not found in provider config")
375 }
376
377 largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
378 if err != nil {
379 return Model{}, Model{}, err
380 }
381 smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
382 if err != nil {
383 return Model{}, Model{}, err
384 }
385
386 return Model{
387 Model: largeModel,
388 CatwalkCfg: *largeCatwalkModel,
389 ModelCfg: largeModelCfg,
390 }, Model{
391 Model: smallModel,
392 CatwalkCfg: *smallCatwalkModel,
393 ModelCfg: smallModelCfg,
394 }, nil
395}
396
397func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
398 hasBearerAuth := false
399 for key := range headers {
400 if strings.ToLower(key) == "authorization" {
401 hasBearerAuth = true
402 break
403 }
404 }
405 if hasBearerAuth {
406 apiKey = "" // clear apiKey to avoid using X-Api-Key header
407 }
408
409 var opts []anthropic.Option
410
411 if apiKey != "" {
412 // Use standard X-Api-Key header
413 opts = append(opts, anthropic.WithAPIKey(apiKey))
414 }
415
416 if len(headers) > 0 {
417 opts = append(opts, anthropic.WithHeaders(headers))
418 }
419
420 if baseURL != "" {
421 opts = append(opts, anthropic.WithBaseURL(baseURL))
422 }
423
424 if c.cfg.Options.Debug {
425 httpClient := log.NewHTTPClient()
426 opts = append(opts, anthropic.WithHTTPClient(httpClient))
427 }
428
429 return anthropic.New(opts...)
430}
431
432func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
433 opts := []openai.Option{
434 openai.WithAPIKey(apiKey),
435 }
436 if c.cfg.Options.Debug {
437 httpClient := log.NewHTTPClient()
438 opts = append(opts, openai.WithHTTPClient(httpClient))
439 }
440 if len(headers) > 0 {
441 opts = append(opts, openai.WithHeaders(headers))
442 }
443 if baseURL != "" {
444 opts = append(opts, openai.WithBaseURL(baseURL))
445 }
446 return openai.New(opts...)
447}
448
449func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) ai.Provider {
450 opts := []openrouter.Option{
451 openrouter.WithAPIKey(apiKey),
452 }
453 if c.cfg.Options.Debug {
454 httpClient := log.NewHTTPClient()
455 opts = append(opts, openrouter.WithHTTPClient(httpClient))
456 }
457 if len(headers) > 0 {
458 opts = append(opts, openrouter.WithHeaders(headers))
459 }
460 return openrouter.New(opts...)
461}
462
463func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
464 opts := []openaicompat.Option{
465 openaicompat.WithBaseURL(baseURL),
466 openaicompat.WithAPIKey(apiKey),
467 }
468 if c.cfg.Options.Debug {
469 httpClient := log.NewHTTPClient()
470 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
471 }
472 if len(headers) > 0 {
473 opts = append(opts, openaicompat.WithHeaders(headers))
474 }
475
476 return openaicompat.New(opts...)
477}
478
479func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) ai.Provider {
480 opts := []azure.Option{
481 azure.WithBaseURL(baseURL),
482 azure.WithAPIKey(apiKey),
483 }
484 if c.cfg.Options.Debug {
485 httpClient := log.NewHTTPClient()
486 opts = append(opts, azure.WithHTTPClient(httpClient))
487 }
488 if options == nil {
489 options = make(map[string]string)
490 }
491 if apiVersion, ok := options["apiVersion"]; ok {
492 opts = append(opts, azure.WithAPIVersion(apiVersion))
493 }
494 if len(headers) > 0 {
495 opts = append(opts, azure.WithHeaders(headers))
496 }
497
498 return azure.New(opts...)
499}
500
501func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) ai.Provider {
502 opts := []google.Option{
503 google.WithBaseURL(baseURL),
504 google.WithGeminiAPIKey(apiKey),
505 }
506 if c.cfg.Options.Debug {
507 httpClient := log.NewHTTPClient()
508 opts = append(opts, google.WithHTTPClient(httpClient))
509 }
510 if len(headers) > 0 {
511 opts = append(opts, google.WithHeaders(headers))
512 }
513 return google.New(opts...)
514}
515
516func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) ai.Provider {
517 opts := []google.Option{}
518 if c.cfg.Options.Debug {
519 httpClient := log.NewHTTPClient()
520 opts = append(opts, google.WithHTTPClient(httpClient))
521 }
522 if len(headers) > 0 {
523 opts = append(opts, google.WithHeaders(headers))
524 }
525
526 project := options["project"]
527 location := options["location"]
528
529 opts = append(opts, google.WithVertex(project, location))
530
531 return google.New(opts...)
532}
533
534func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
535 if model.Think {
536 return true
537 }
538
539 if model.ProviderOptions == nil {
540 return false
541 }
542
543 opts, err := anthropic.ParseOptions(model.ProviderOptions)
544 if err != nil {
545 return false
546 }
547 if opts.Thinking != nil {
548 return true
549 }
550 return false
551}
552
553func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (ai.Provider, error) {
554 headers := providerCfg.ExtraHeaders
555
556 // handle special headers for anthropic
557 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
558 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
559 }
560
561 // TODO: make sure we have
562 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
563 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
564 var provider ai.Provider
565 switch providerCfg.Type {
566 case openai.Name:
567 provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
568 case anthropic.Name:
569 provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
570 case openrouter.Name:
571 provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
572 case azure.Name:
573 provider = c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
574 case google.Name:
575 provider = c.buildGoogleProvider(baseURL, apiKey, headers)
576 // this is not in fantasy since its just the google provider with extra stuff
577 case "google-vertex":
578 provider = c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
579 case openaicompat.Name:
580 provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
581 default:
582 return nil, errors.New("provider type not supported")
583 }
584 return provider, nil
585}
586
587func (c *coordinator) Cancel(sessionID string) {
588 c.currentAgent.Cancel(sessionID)
589}
590
591func (c *coordinator) CancelAll() {
592 c.currentAgent.CancelAll()
593}
594
595func (c *coordinator) ClearQueue(sessionID string) {
596 c.currentAgent.ClearQueue(sessionID)
597}
598
599func (c *coordinator) IsBusy() bool {
600 return c.currentAgent.IsBusy()
601}
602
603func (c *coordinator) IsSessionBusy(sessionID string) bool {
604 return c.currentAgent.IsSessionBusy(sessionID)
605}
606
607func (c *coordinator) Model() Model {
608 return c.currentAgent.Model()
609}
610
611func (c *coordinator) UpdateModels() error {
612 // build the models again so we make sure we get the latest config
613 large, small, err := c.buildAgentModels()
614 if err != nil {
615 return err
616 }
617 c.currentAgent.SetModels(large, small)
618
619 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
620 if !ok {
621 return errors.New("coder agent not configured")
622 }
623
624 tools, err := c.buildTools(agentCfg)
625 if err != nil {
626 return err
627 }
628 c.currentAgent.SetTools(tools)
629 return nil
630}
631
632func (c *coordinator) QueuedPrompts(sessionID string) int {
633 return c.currentAgent.QueuedPrompts(sessionID)
634}
635
636func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
637 return c.currentAgent.Summarize(ctx, sessionID)
638}