1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "io"
10 "log/slog"
11 "slices"
12 "strings"
13
14 "charm.land/fantasy"
15 "github.com/charmbracelet/catwalk/pkg/catwalk"
16 "github.com/charmbracelet/crush/internal/agent/prompt"
17 "github.com/charmbracelet/crush/internal/agent/tools"
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/csync"
20 "github.com/charmbracelet/crush/internal/history"
21 "github.com/charmbracelet/crush/internal/log"
22 "github.com/charmbracelet/crush/internal/lsp"
23 "github.com/charmbracelet/crush/internal/message"
24 "github.com/charmbracelet/crush/internal/permission"
25 "github.com/charmbracelet/crush/internal/session"
26
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/azure"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openaicompat"
32 "charm.land/fantasy/providers/openrouter"
33 "github.com/qjebbs/go-jsons"
34)
35
36type Coordinator interface {
37 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
38 // SetMainAgent(string)
39 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
40 Cancel(sessionID string)
41 CancelAll()
42 IsSessionBusy(sessionID string) bool
43 IsBusy() bool
44 QueuedPrompts(sessionID string) int
45 ClearQueue(sessionID string)
46 Summarize(context.Context, string) error
47 Model() Model
48 UpdateModels() error
49}
50
// coordinator is the Coordinator implementation. It holds the services needed
// to construct agents and their tools, and forwards Coordinator calls to
// currentAgent.
type coordinator struct {
	// cfg supplies agent, model, provider, and tool configuration.
	cfg *config.Config
	// sessions and messages are handed to each SessionAgent built by buildAgent.
	sessions session.Service
	messages message.Service
	// permissions, history, and lspClients are passed to the built-in tools.
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]

	// currentAgent receives every call delegated by the Coordinator methods.
	currentAgent SessionAgent
	// agents maps agent names (e.g. config.AgentCoder) to their built agents.
	// NOTE(review): only written in NewCoordinator in this file; not read here yet.
	agents map[string]SessionAgent
}
62
63func NewCoordinator(
64 cfg *config.Config,
65 sessions session.Service,
66 messages message.Service,
67 permissions permission.Service,
68 history history.Service,
69 lspClients *csync.Map[string, *lsp.Client],
70) (Coordinator, error) {
71 c := &coordinator{
72 cfg: cfg,
73 sessions: sessions,
74 messages: messages,
75 permissions: permissions,
76 history: history,
77 lspClients: lspClients,
78 agents: make(map[string]SessionAgent),
79 }
80
81 agentCfg, ok := cfg.Agents[config.AgentCoder]
82 if !ok {
83 return nil, errors.New("coder agent not configured")
84 }
85
86 // TODO: make this dynamic when we support multiple agents
87 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
88 if err != nil {
89 return nil, err
90 }
91
92 agent, err := c.buildAgent(prompt, agentCfg)
93 if err != nil {
94 return nil, err
95 }
96 c.currentAgent = agent
97 c.agents[config.AgentCoder] = agent
98 return c, nil
99}
100
101// Run implements Coordinator.
102func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
103 model := c.currentAgent.Model()
104 maxTokens := model.CatwalkCfg.DefaultMaxTokens
105 if model.ModelCfg.MaxTokens != 0 {
106 maxTokens = model.ModelCfg.MaxTokens
107 }
108
109 if !model.CatwalkCfg.SupportsImages && attachments != nil {
110 attachments = nil
111 }
112
113 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
114 if !ok {
115 return nil, errors.New("model provider not configured")
116 }
117
118 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg.Type)
119
120 return c.currentAgent.Run(ctx, SessionAgentCall{
121 SessionID: sessionID,
122 Prompt: prompt,
123 Attachments: attachments,
124 MaxOutputTokens: maxTokens,
125 ProviderOptions: mergedOptions,
126 Temperature: temp,
127 TopP: topP,
128 TopK: topK,
129 FrequencyPenalty: freqPenalty,
130 PresencePenalty: presPenalty,
131 })
132}
133
134func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
135 options := fantasy.ProviderOptions{}
136
137 cfgOpts := []byte("{}")
138 catwalkOpts := []byte("{}")
139
140 if model.ModelCfg.ProviderOptions != nil {
141 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
142 if err == nil {
143 cfgOpts = data
144 }
145 }
146
147 if model.CatwalkCfg.Options.ProviderOptions != nil {
148 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
149 if err == nil {
150 catwalkOpts = data
151 }
152 }
153
154 readers := []io.Reader{
155 bytes.NewReader(catwalkOpts),
156 bytes.NewReader(cfgOpts),
157 }
158
159 got, err := jsons.Merge(readers)
160 if err != nil {
161 slog.Error("Could not merge call config", "err", err)
162 return options
163 }
164
165 mergedOptions := make(map[string]any)
166
167 err = json.Unmarshal([]byte(got), &mergedOptions)
168 if err != nil {
169 slog.Error("Could not create config for call", "err", err)
170 return options
171 }
172
173 switch tp {
174 case openai.Name:
175 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
176 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
177 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
178 }
179 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
180 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
181 mergedOptions["reasoning_summary"] = "auto"
182 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
183 }
184 parsed, err := openai.ParseResponsesOptions(mergedOptions)
185 if err == nil {
186 options[openai.Name] = parsed
187 }
188 } else {
189 parsed, err := openai.ParseOptions(mergedOptions)
190 if err == nil {
191 options[openai.Name] = parsed
192 }
193 }
194 case anthropic.Name:
195 _, hasThink := mergedOptions["thinking"]
196 if !hasThink && model.ModelCfg.Think {
197 mergedOptions["thinking"] = map[string]any{
198 // TODO: kujtim see if we need to make this dynamic
199 "budget_tokens": 2000,
200 }
201 }
202 parsed, err := anthropic.ParseOptions(mergedOptions)
203 if err == nil {
204 options[anthropic.Name] = parsed
205 }
206
207 case openrouter.Name:
208 _, hasReasoning := mergedOptions["reasoning"]
209 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
210 mergedOptions["reasoning"] = map[string]any{
211 "enabled": true,
212 "effort": model.ModelCfg.ReasoningEffort,
213 }
214 }
215 parsed, err := openrouter.ParseOptions(mergedOptions)
216 if err == nil {
217 options[openrouter.Name] = parsed
218 }
219 case google.Name:
220 parsed, err := google.ParseOptions(mergedOptions)
221 if err == nil {
222 options[google.Name] = parsed
223 }
224 case azure.Name:
225 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
226 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
227 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
228 }
229 // azure uses the same options as openaicompat
230 parsed, err := openaicompat.ParseOptions(mergedOptions)
231 if err == nil {
232 options[azure.Name] = parsed
233 }
234 case openaicompat.Name:
235 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
236 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
237 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
238 }
239 parsed, err := openaicompat.ParseOptions(mergedOptions)
240 if err == nil {
241 options[openaicompat.Name] = parsed
242 }
243 }
244
245 return options
246}
247
248func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
249 modelOptions := getProviderOptions(model, tp)
250 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
251 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
252 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
253 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
254 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
255 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
256}
257
258func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
259 large, small, err := c.buildAgentModels()
260 if err != nil {
261 return nil, err
262 }
263
264 systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
265 if err != nil {
266 return nil, err
267 }
268
269 tools, err := c.buildTools(agent)
270 if err != nil {
271 return nil, err
272 }
273 return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
274}
275
276func (c *coordinator) buildTools(agent config.Agent) ([]fantasy.AgentTool, error) {
277 var allTools []fantasy.AgentTool
278 if slices.Contains(agent.AllowedTools, AgentToolName) {
279 agentTool, err := c.agentTool()
280 if err != nil {
281 return nil, err
282 }
283 allTools = append(allTools, agentTool)
284 }
285
286 allTools = append(allTools,
287 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
288 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
289 tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
290 tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
291 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
292 tools.NewGlobTool(c.cfg.WorkingDir()),
293 tools.NewGrepTool(c.cfg.WorkingDir()),
294 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
295 tools.NewSourcegraphTool(nil),
296 tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
297 tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
298 )
299
300 var filteredTools []fantasy.AgentTool
301 for _, tool := range allTools {
302 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
303 filteredTools = append(filteredTools, tool)
304 }
305 }
306
307 mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
308
309 for _, mcpTool := range mcpTools {
310 if agent.AllowedMCP == nil {
311 // No MCP restrictions
312 filteredTools = append(filteredTools, mcpTool)
313 } else if len(agent.AllowedMCP) == 0 {
314 // no mcps allowed
315 break
316 }
317
318 for mcp, tools := range agent.AllowedMCP {
319 if mcp == mcpTool.MCP() {
320 if len(tools) == 0 {
321 filteredTools = append(filteredTools, mcpTool)
322 }
323 for _, t := range tools {
324 if t == mcpTool.MCPToolName() {
325 filteredTools = append(filteredTools, mcpTool)
326 }
327 }
328 break
329 }
330 }
331 }
332
333 return filteredTools, nil
334}
335
336// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
337func (c *coordinator) buildAgentModels() (Model, Model, error) {
338 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
339 if !ok {
340 return Model{}, Model{}, errors.New("large model not selected")
341 }
342 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
343 if !ok {
344 return Model{}, Model{}, errors.New("small model not selected")
345 }
346
347 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
348 if !ok {
349 return Model{}, Model{}, errors.New("large model provider not configured")
350 }
351
352 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
353 if err != nil {
354 return Model{}, Model{}, err
355 }
356
357 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
358 if !ok {
359 return Model{}, Model{}, errors.New("large model provider not configured")
360 }
361
362 smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
363 if err != nil {
364 return Model{}, Model{}, err
365 }
366
367 var largeCatwalkModel *catwalk.Model
368 var smallCatwalkModel *catwalk.Model
369
370 for _, m := range largeProviderCfg.Models {
371 if m.ID == largeModelCfg.Model {
372 largeCatwalkModel = &m
373 }
374 }
375 for _, m := range smallProviderCfg.Models {
376 if m.ID == smallModelCfg.Model {
377 smallCatwalkModel = &m
378 }
379 }
380
381 if largeCatwalkModel == nil {
382 return Model{}, Model{}, errors.New("large model not found in provider config")
383 }
384
385 if smallCatwalkModel == nil {
386 return Model{}, Model{}, errors.New("snall model not found in provider config")
387 }
388
389 largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
390 if err != nil {
391 return Model{}, Model{}, err
392 }
393 smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
394 if err != nil {
395 return Model{}, Model{}, err
396 }
397
398 return Model{
399 Model: largeModel,
400 CatwalkCfg: *largeCatwalkModel,
401 ModelCfg: largeModelCfg,
402 }, Model{
403 Model: smallModel,
404 CatwalkCfg: *smallCatwalkModel,
405 ModelCfg: smallModelCfg,
406 }, nil
407}
408
409func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
410 hasBearerAuth := false
411 for key := range headers {
412 if strings.ToLower(key) == "authorization" {
413 hasBearerAuth = true
414 break
415 }
416 }
417 if hasBearerAuth {
418 apiKey = "" // clear apiKey to avoid using X-Api-Key header
419 }
420
421 var opts []anthropic.Option
422
423 if apiKey != "" {
424 // Use standard X-Api-Key header
425 opts = append(opts, anthropic.WithAPIKey(apiKey))
426 }
427
428 if len(headers) > 0 {
429 opts = append(opts, anthropic.WithHeaders(headers))
430 }
431
432 if baseURL != "" {
433 opts = append(opts, anthropic.WithBaseURL(baseURL))
434 }
435
436 if c.cfg.Options.Debug {
437 httpClient := log.NewHTTPClient()
438 opts = append(opts, anthropic.WithHTTPClient(httpClient))
439 }
440
441 return anthropic.New(opts...)
442}
443
444func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
445 opts := []openai.Option{
446 openai.WithAPIKey(apiKey),
447 openai.WithUseResponsesAPI(),
448 }
449 if c.cfg.Options.Debug {
450 httpClient := log.NewHTTPClient()
451 opts = append(opts, openai.WithHTTPClient(httpClient))
452 }
453 if len(headers) > 0 {
454 opts = append(opts, openai.WithHeaders(headers))
455 }
456 if baseURL != "" {
457 opts = append(opts, openai.WithBaseURL(baseURL))
458 }
459 return openai.New(opts...)
460}
461
462func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) fantasy.Provider {
463 opts := []openrouter.Option{
464 openrouter.WithAPIKey(apiKey),
465 }
466 if c.cfg.Options.Debug {
467 httpClient := log.NewHTTPClient()
468 opts = append(opts, openrouter.WithHTTPClient(httpClient))
469 }
470 if len(headers) > 0 {
471 opts = append(opts, openrouter.WithHeaders(headers))
472 }
473 return openrouter.New(opts...)
474}
475
476func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
477 opts := []openaicompat.Option{
478 openaicompat.WithBaseURL(baseURL),
479 openaicompat.WithAPIKey(apiKey),
480 }
481 if c.cfg.Options.Debug {
482 httpClient := log.NewHTTPClient()
483 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
484 }
485 if len(headers) > 0 {
486 opts = append(opts, openaicompat.WithHeaders(headers))
487 }
488
489 return openaicompat.New(opts...)
490}
491
492func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) fantasy.Provider {
493 opts := []azure.Option{
494 azure.WithBaseURL(baseURL),
495 azure.WithAPIKey(apiKey),
496 }
497 if c.cfg.Options.Debug {
498 httpClient := log.NewHTTPClient()
499 opts = append(opts, azure.WithHTTPClient(httpClient))
500 }
501 if options == nil {
502 options = make(map[string]string)
503 }
504 if apiVersion, ok := options["apiVersion"]; ok {
505 opts = append(opts, azure.WithAPIVersion(apiVersion))
506 }
507 if len(headers) > 0 {
508 opts = append(opts, azure.WithHeaders(headers))
509 }
510
511 return azure.New(opts...)
512}
513
514func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
515 opts := []google.Option{
516 google.WithBaseURL(baseURL),
517 google.WithGeminiAPIKey(apiKey),
518 }
519 if c.cfg.Options.Debug {
520 httpClient := log.NewHTTPClient()
521 opts = append(opts, google.WithHTTPClient(httpClient))
522 }
523 if len(headers) > 0 {
524 opts = append(opts, google.WithHeaders(headers))
525 }
526 return google.New(opts...)
527}
528
529func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) fantasy.Provider {
530 opts := []google.Option{}
531 if c.cfg.Options.Debug {
532 httpClient := log.NewHTTPClient()
533 opts = append(opts, google.WithHTTPClient(httpClient))
534 }
535 if len(headers) > 0 {
536 opts = append(opts, google.WithHeaders(headers))
537 }
538
539 project := options["project"]
540 location := options["location"]
541
542 opts = append(opts, google.WithVertex(project, location))
543
544 return google.New(opts...)
545}
546
547func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
548 if model.Think {
549 return true
550 }
551
552 if model.ProviderOptions == nil {
553 return false
554 }
555
556 opts, err := anthropic.ParseOptions(model.ProviderOptions)
557 if err != nil {
558 return false
559 }
560 if opts.Thinking != nil {
561 return true
562 }
563 return false
564}
565
566func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) {
567 headers := providerCfg.ExtraHeaders
568
569 // handle special headers for anthropic
570 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
571 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
572 }
573
574 // TODO: make sure we have
575 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
576 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
577 var provider fantasy.Provider
578 switch providerCfg.Type {
579 case openai.Name:
580 provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
581 case anthropic.Name:
582 provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
583 case openrouter.Name:
584 provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
585 case azure.Name:
586 provider = c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
587 case google.Name:
588 provider = c.buildGoogleProvider(baseURL, apiKey, headers)
589 // this is not in fantasy since its just the google provider with extra stuff
590 case "google-vertex":
591 provider = c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
592 case openaicompat.Name:
593 provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
594 default:
595 return nil, errors.New("provider type not supported")
596 }
597 return provider, nil
598}
599
600func (c *coordinator) Cancel(sessionID string) {
601 c.currentAgent.Cancel(sessionID)
602}
603
604func (c *coordinator) CancelAll() {
605 c.currentAgent.CancelAll()
606}
607
608func (c *coordinator) ClearQueue(sessionID string) {
609 c.currentAgent.ClearQueue(sessionID)
610}
611
612func (c *coordinator) IsBusy() bool {
613 return c.currentAgent.IsBusy()
614}
615
616func (c *coordinator) IsSessionBusy(sessionID string) bool {
617 return c.currentAgent.IsSessionBusy(sessionID)
618}
619
620func (c *coordinator) Model() Model {
621 return c.currentAgent.Model()
622}
623
624func (c *coordinator) UpdateModels() error {
625 // build the models again so we make sure we get the latest config
626 large, small, err := c.buildAgentModels()
627 if err != nil {
628 return err
629 }
630 c.currentAgent.SetModels(large, small)
631
632 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
633 if !ok {
634 return errors.New("coder agent not configured")
635 }
636
637 tools, err := c.buildTools(agentCfg)
638 if err != nil {
639 return err
640 }
641 c.currentAgent.SetTools(tools)
642 return nil
643}
644
645func (c *coordinator) QueuedPrompts(sessionID string) int {
646 return c.currentAgent.QueuedPrompts(sessionID)
647}
648
649func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
650 providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
651 if !ok {
652 return errors.New("model provider not configured")
653 }
654 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg.Type))
655}