1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "slices"
13 "strings"
14
15 "charm.land/fantasy"
16 "github.com/charmbracelet/catwalk/pkg/catwalk"
17 "github.com/charmbracelet/crush/internal/agent/prompt"
18 "github.com/charmbracelet/crush/internal/agent/tools"
19 "github.com/charmbracelet/crush/internal/config"
20 "github.com/charmbracelet/crush/internal/csync"
21 "github.com/charmbracelet/crush/internal/history"
22 "github.com/charmbracelet/crush/internal/log"
23 "github.com/charmbracelet/crush/internal/lsp"
24 "github.com/charmbracelet/crush/internal/message"
25 "github.com/charmbracelet/crush/internal/permission"
26 "github.com/charmbracelet/crush/internal/session"
27
28 "charm.land/fantasy/providers/anthropic"
29 "charm.land/fantasy/providers/azure"
30 "charm.land/fantasy/providers/google"
31 "charm.land/fantasy/providers/openai"
32 "charm.land/fantasy/providers/openaicompat"
33 "charm.land/fantasy/providers/openrouter"
34 "github.com/qjebbs/go-jsons"
35)
36
37type Coordinator interface {
38 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
39 // SetMainAgent(string)
40 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
41 Cancel(sessionID string)
42 CancelAll()
43 IsSessionBusy(sessionID string) bool
44 IsBusy() bool
45 QueuedPrompts(sessionID string) int
46 ClearQueue(sessionID string)
47 Summarize(context.Context, string) error
48 Model() Model
49 UpdateModels(ctx context.Context) error
50}
51
// coordinator is the Coordinator implementation. It holds the configuration
// and services needed to build agents, and forwards per-session work to
// currentAgent.
type coordinator struct {
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]

	// currentAgent is the agent that Run, Cancel, Summarize and the other
	// Coordinator methods delegate to.
	currentAgent SessionAgent
	// agents holds the built agents keyed by agent name (e.g.
	// config.AgentCoder); NewCoordinator registers only the coder agent.
	agents map[string]SessionAgent
}
63
64func NewCoordinator(
65 ctx context.Context,
66 cfg *config.Config,
67 sessions session.Service,
68 messages message.Service,
69 permissions permission.Service,
70 history history.Service,
71 lspClients *csync.Map[string, *lsp.Client],
72) (Coordinator, error) {
73 c := &coordinator{
74 cfg: cfg,
75 sessions: sessions,
76 messages: messages,
77 permissions: permissions,
78 history: history,
79 lspClients: lspClients,
80 agents: make(map[string]SessionAgent),
81 }
82
83 agentCfg, ok := cfg.Agents[config.AgentCoder]
84 if !ok {
85 return nil, errors.New("coder agent not configured")
86 }
87
88 // TODO: make this dynamic when we support multiple agents
89 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
90 if err != nil {
91 return nil, err
92 }
93
94 agent, err := c.buildAgent(ctx, prompt, agentCfg)
95 if err != nil {
96 return nil, err
97 }
98 c.currentAgent = agent
99 c.agents[config.AgentCoder] = agent
100 return c, nil
101}
102
103// Run implements Coordinator.
104func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
105 model := c.currentAgent.Model()
106 maxTokens := model.CatwalkCfg.DefaultMaxTokens
107 if model.ModelCfg.MaxTokens != 0 {
108 maxTokens = model.ModelCfg.MaxTokens
109 }
110
111 if !model.CatwalkCfg.SupportsImages && attachments != nil {
112 attachments = nil
113 }
114
115 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
116 if !ok {
117 return nil, errors.New("model provider not configured")
118 }
119
120 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg.Type)
121
122 return c.currentAgent.Run(ctx, SessionAgentCall{
123 SessionID: sessionID,
124 Prompt: prompt,
125 Attachments: attachments,
126 MaxOutputTokens: maxTokens,
127 ProviderOptions: mergedOptions,
128 Temperature: temp,
129 TopP: topP,
130 TopK: topK,
131 FrequencyPenalty: freqPenalty,
132 PresencePenalty: presPenalty,
133 })
134}
135
136func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
137 options := fantasy.ProviderOptions{}
138
139 cfgOpts := []byte("{}")
140 catwalkOpts := []byte("{}")
141
142 if model.ModelCfg.ProviderOptions != nil {
143 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
144 if err == nil {
145 cfgOpts = data
146 }
147 }
148
149 if model.CatwalkCfg.Options.ProviderOptions != nil {
150 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
151 if err == nil {
152 catwalkOpts = data
153 }
154 }
155
156 readers := []io.Reader{
157 bytes.NewReader(catwalkOpts),
158 bytes.NewReader(cfgOpts),
159 }
160
161 got, err := jsons.Merge(readers)
162 if err != nil {
163 slog.Error("Could not merge call config", "err", err)
164 return options
165 }
166
167 mergedOptions := make(map[string]any)
168
169 err = json.Unmarshal([]byte(got), &mergedOptions)
170 if err != nil {
171 slog.Error("Could not create config for call", "err", err)
172 return options
173 }
174
175 switch tp {
176 case openai.Name:
177 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
178 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
179 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
180 }
181 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
182 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
183 mergedOptions["reasoning_summary"] = "auto"
184 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
185 }
186 parsed, err := openai.ParseResponsesOptions(mergedOptions)
187 if err == nil {
188 options[openai.Name] = parsed
189 }
190 } else {
191 parsed, err := openai.ParseOptions(mergedOptions)
192 if err == nil {
193 options[openai.Name] = parsed
194 }
195 }
196 case anthropic.Name:
197 _, hasThink := mergedOptions["thinking"]
198 if !hasThink && model.ModelCfg.Think {
199 mergedOptions["thinking"] = map[string]any{
200 // TODO: kujtim see if we need to make this dynamic
201 "budget_tokens": 2000,
202 }
203 }
204 parsed, err := anthropic.ParseOptions(mergedOptions)
205 if err == nil {
206 options[anthropic.Name] = parsed
207 }
208
209 case openrouter.Name:
210 _, hasReasoning := mergedOptions["reasoning"]
211 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
212 mergedOptions["reasoning"] = map[string]any{
213 "enabled": true,
214 "effort": model.ModelCfg.ReasoningEffort,
215 }
216 }
217 parsed, err := openrouter.ParseOptions(mergedOptions)
218 if err == nil {
219 options[openrouter.Name] = parsed
220 }
221 case google.Name:
222 _, hasReasoning := mergedOptions["thinking_config"]
223 if !hasReasoning {
224 mergedOptions["thinking_config"] = map[string]any{
225 "thinking_budget": 2000,
226 "include_thoughts": true,
227 }
228 }
229 parsed, err := google.ParseOptions(mergedOptions)
230 if err == nil {
231 options[google.Name] = parsed
232 }
233 case azure.Name:
234 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
235 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
236 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
237 }
238 // azure uses the same options as openaicompat
239 parsed, err := openaicompat.ParseOptions(mergedOptions)
240 if err == nil {
241 options[azure.Name] = parsed
242 }
243 case openaicompat.Name:
244 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
245 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
246 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
247 }
248 parsed, err := openaicompat.ParseOptions(mergedOptions)
249 if err == nil {
250 options[openaicompat.Name] = parsed
251 }
252 }
253
254 return options
255}
256
257func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
258 modelOptions := getProviderOptions(model, tp)
259 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
260 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
261 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
262 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
263 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
264 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
265}
266
267func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
268 large, small, err := c.buildAgentModels(ctx)
269 if err != nil {
270 return nil, err
271 }
272
273 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
274 if err != nil {
275 return nil, err
276 }
277
278 tools, err := c.buildTools(ctx, agent)
279 if err != nil {
280 return nil, err
281 }
282 return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
283}
284
285func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
286 var allTools []fantasy.AgentTool
287 if slices.Contains(agent.AllowedTools, AgentToolName) {
288 agentTool, err := c.agentTool(ctx)
289 if err != nil {
290 return nil, err
291 }
292 allTools = append(allTools, agentTool)
293 }
294
295 allTools = append(allTools,
296 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
297 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
298 tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
299 tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
300 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
301 tools.NewGlobTool(c.cfg.WorkingDir()),
302 tools.NewGrepTool(c.cfg.WorkingDir()),
303 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
304 tools.NewSourcegraphTool(nil),
305 tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
306 tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
307 )
308
309 if len(c.cfg.LSP) > 0 {
310 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspClients), tools.NewReferencesTool(c.lspClients))
311 }
312
313 var filteredTools []fantasy.AgentTool
314 for _, tool := range allTools {
315 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
316 filteredTools = append(filteredTools, tool)
317 }
318 }
319
320 mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
321
322 for _, mcpTool := range mcpTools {
323 if agent.AllowedMCP == nil {
324 // No MCP restrictions
325 filteredTools = append(filteredTools, mcpTool)
326 } else if len(agent.AllowedMCP) == 0 {
327 // no mcps allowed
328 break
329 }
330
331 for mcp, tools := range agent.AllowedMCP {
332 if mcp == mcpTool.MCP() {
333 if len(tools) == 0 {
334 filteredTools = append(filteredTools, mcpTool)
335 }
336 for _, t := range tools {
337 if t == mcpTool.MCPToolName() {
338 filteredTools = append(filteredTools, mcpTool)
339 }
340 }
341 break
342 }
343 }
344 }
345 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
346 return strings.Compare(a.Info().Name, b.Info().Name)
347 })
348 return filteredTools, nil
349}
350
351// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
352func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error) {
353 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
354 if !ok {
355 return Model{}, Model{}, errors.New("large model not selected")
356 }
357 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
358 if !ok {
359 return Model{}, Model{}, errors.New("small model not selected")
360 }
361
362 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
363 if !ok {
364 return Model{}, Model{}, errors.New("large model provider not configured")
365 }
366
367 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
368 if err != nil {
369 return Model{}, Model{}, err
370 }
371
372 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
373 if !ok {
374 return Model{}, Model{}, errors.New("large model provider not configured")
375 }
376
377 smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
378 if err != nil {
379 return Model{}, Model{}, err
380 }
381
382 var largeCatwalkModel *catwalk.Model
383 var smallCatwalkModel *catwalk.Model
384
385 for _, m := range largeProviderCfg.Models {
386 if m.ID == largeModelCfg.Model {
387 largeCatwalkModel = &m
388 }
389 }
390 for _, m := range smallProviderCfg.Models {
391 if m.ID == smallModelCfg.Model {
392 smallCatwalkModel = &m
393 }
394 }
395
396 if largeCatwalkModel == nil {
397 return Model{}, Model{}, errors.New("large model not found in provider config")
398 }
399
400 if smallCatwalkModel == nil {
401 return Model{}, Model{}, errors.New("snall model not found in provider config")
402 }
403
404 largeModel, err := largeProvider.LanguageModel(ctx, largeModelCfg.Model)
405 if err != nil {
406 return Model{}, Model{}, err
407 }
408 smallModel, err := smallProvider.LanguageModel(ctx, smallModelCfg.Model)
409 if err != nil {
410 return Model{}, Model{}, err
411 }
412
413 return Model{
414 Model: largeModel,
415 CatwalkCfg: *largeCatwalkModel,
416 ModelCfg: largeModelCfg,
417 }, Model{
418 Model: smallModel,
419 CatwalkCfg: *smallCatwalkModel,
420 ModelCfg: smallModelCfg,
421 }, nil
422}
423
424func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
425 hasBearerAuth := false
426 for key := range headers {
427 if strings.ToLower(key) == "authorization" {
428 hasBearerAuth = true
429 break
430 }
431 }
432 if hasBearerAuth {
433 apiKey = "" // clear apiKey to avoid using X-Api-Key header
434 }
435
436 var opts []anthropic.Option
437
438 if apiKey != "" {
439 // Use standard X-Api-Key header
440 opts = append(opts, anthropic.WithAPIKey(apiKey))
441 }
442
443 if len(headers) > 0 {
444 opts = append(opts, anthropic.WithHeaders(headers))
445 }
446
447 if baseURL != "" {
448 opts = append(opts, anthropic.WithBaseURL(baseURL))
449 }
450
451 if c.cfg.Options.Debug {
452 httpClient := log.NewHTTPClient()
453 opts = append(opts, anthropic.WithHTTPClient(httpClient))
454 }
455
456 return anthropic.New(opts...)
457}
458
459func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
460 opts := []openai.Option{
461 openai.WithAPIKey(apiKey),
462 openai.WithUseResponsesAPI(),
463 }
464 if c.cfg.Options.Debug {
465 httpClient := log.NewHTTPClient()
466 opts = append(opts, openai.WithHTTPClient(httpClient))
467 }
468 if len(headers) > 0 {
469 opts = append(opts, openai.WithHeaders(headers))
470 }
471 if baseURL != "" {
472 opts = append(opts, openai.WithBaseURL(baseURL))
473 }
474 return openai.New(opts...)
475}
476
477func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
478 opts := []openrouter.Option{
479 openrouter.WithAPIKey(apiKey),
480 }
481 if c.cfg.Options.Debug {
482 httpClient := log.NewHTTPClient()
483 opts = append(opts, openrouter.WithHTTPClient(httpClient))
484 }
485 if len(headers) > 0 {
486 opts = append(opts, openrouter.WithHeaders(headers))
487 }
488 return openrouter.New(opts...)
489}
490
491func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
492 opts := []openaicompat.Option{
493 openaicompat.WithBaseURL(baseURL),
494 openaicompat.WithAPIKey(apiKey),
495 }
496 if c.cfg.Options.Debug {
497 httpClient := log.NewHTTPClient()
498 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
499 }
500 if len(headers) > 0 {
501 opts = append(opts, openaicompat.WithHeaders(headers))
502 }
503
504 return openaicompat.New(opts...)
505}
506
507func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
508 opts := []azure.Option{
509 azure.WithBaseURL(baseURL),
510 azure.WithAPIKey(apiKey),
511 }
512 if c.cfg.Options.Debug {
513 httpClient := log.NewHTTPClient()
514 opts = append(opts, azure.WithHTTPClient(httpClient))
515 }
516 if options == nil {
517 options = make(map[string]string)
518 }
519 if apiVersion, ok := options["apiVersion"]; ok {
520 opts = append(opts, azure.WithAPIVersion(apiVersion))
521 }
522 if len(headers) > 0 {
523 opts = append(opts, azure.WithHeaders(headers))
524 }
525
526 return azure.New(opts...)
527}
528
529func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
530 opts := []google.Option{
531 google.WithBaseURL(baseURL),
532 google.WithGeminiAPIKey(apiKey),
533 }
534 if c.cfg.Options.Debug {
535 httpClient := log.NewHTTPClient()
536 opts = append(opts, google.WithHTTPClient(httpClient))
537 }
538 if len(headers) > 0 {
539 opts = append(opts, google.WithHeaders(headers))
540 }
541 return google.New(opts...)
542}
543
544func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
545 opts := []google.Option{}
546 if c.cfg.Options.Debug {
547 httpClient := log.NewHTTPClient()
548 opts = append(opts, google.WithHTTPClient(httpClient))
549 }
550 if len(headers) > 0 {
551 opts = append(opts, google.WithHeaders(headers))
552 }
553
554 project := options["project"]
555 location := options["location"]
556
557 opts = append(opts, google.WithVertex(project, location))
558
559 return google.New(opts...)
560}
561
562func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
563 if model.Think {
564 return true
565 }
566
567 if model.ProviderOptions == nil {
568 return false
569 }
570
571 opts, err := anthropic.ParseOptions(model.ProviderOptions)
572 if err != nil {
573 return false
574 }
575 if opts.Thinking != nil {
576 return true
577 }
578 return false
579}
580
581func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) {
582 headers := providerCfg.ExtraHeaders
583
584 // handle special headers for anthropic
585 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
586 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
587 }
588
589 // TODO: make sure we have
590 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
591 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
592
593 switch providerCfg.Type {
594 case openai.Name:
595 return c.buildOpenaiProvider(baseURL, apiKey, headers)
596 case anthropic.Name:
597 return c.buildAnthropicProvider(baseURL, apiKey, headers)
598 case openrouter.Name:
599 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
600 case azure.Name:
601 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
602 case google.Name:
603 return c.buildGoogleProvider(baseURL, apiKey, headers)
604 case "google-vertex", "vertexai":
605 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
606 case openaicompat.Name:
607 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
608 default:
609 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
610 }
611}
612
613func (c *coordinator) Cancel(sessionID string) {
614 c.currentAgent.Cancel(sessionID)
615}
616
617func (c *coordinator) CancelAll() {
618 c.currentAgent.CancelAll()
619}
620
621func (c *coordinator) ClearQueue(sessionID string) {
622 c.currentAgent.ClearQueue(sessionID)
623}
624
625func (c *coordinator) IsBusy() bool {
626 return c.currentAgent.IsBusy()
627}
628
629func (c *coordinator) IsSessionBusy(sessionID string) bool {
630 return c.currentAgent.IsSessionBusy(sessionID)
631}
632
633func (c *coordinator) Model() Model {
634 return c.currentAgent.Model()
635}
636
637func (c *coordinator) UpdateModels(ctx context.Context) error {
638 // build the models again so we make sure we get the latest config
639 large, small, err := c.buildAgentModels(ctx)
640 if err != nil {
641 return err
642 }
643 c.currentAgent.SetModels(large, small)
644
645 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
646 if !ok {
647 return errors.New("coder agent not configured")
648 }
649
650 tools, err := c.buildTools(ctx, agentCfg)
651 if err != nil {
652 return err
653 }
654 c.currentAgent.SetTools(tools)
655 return nil
656}
657
658func (c *coordinator) QueuedPrompts(sessionID string) int {
659 return c.currentAgent.QueuedPrompts(sessionID)
660}
661
662func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
663 providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
664 if !ok {
665 return errors.New("model provider not configured")
666 }
667 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg.Type))
668}