1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "io"
10 "log/slog"
11 "slices"
12 "strings"
13
14 "charm.land/fantasy"
15 "github.com/charmbracelet/catwalk/pkg/catwalk"
16 "github.com/charmbracelet/crush/internal/agent/prompt"
17 "github.com/charmbracelet/crush/internal/agent/tools"
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/csync"
20 "github.com/charmbracelet/crush/internal/history"
21 "github.com/charmbracelet/crush/internal/log"
22 "github.com/charmbracelet/crush/internal/lsp"
23 "github.com/charmbracelet/crush/internal/message"
24 "github.com/charmbracelet/crush/internal/permission"
25 "github.com/charmbracelet/crush/internal/session"
26
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/azure"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openaicompat"
32 "charm.land/fantasy/providers/openrouter"
33 "github.com/qjebbs/go-jsons"
34)
35
36type Coordinator interface {
37 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
38 // SetMainAgent(string)
39 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
40 Cancel(sessionID string)
41 CancelAll()
42 IsSessionBusy(sessionID string) bool
43 IsBusy() bool
44 QueuedPrompts(sessionID string) int
45 ClearQueue(sessionID string)
46 Summarize(context.Context, string) error
47 Model() Model
48 UpdateModels(ctx context.Context) error
49}
50
// coordinator is the Coordinator implementation returned by NewCoordinator.
// It keeps the configuration and services needed to build agents and
// forwards session operations to currentAgent.
type coordinator struct {
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]

	// currentAgent receives the calls made through the Coordinator interface.
	currentAgent SessionAgent
	// agents holds the built agents keyed by agent name; NewCoordinator
	// currently registers only config.AgentCoder.
	agents map[string]SessionAgent
}
62
// NewCoordinator creates a Coordinator whose current agent is the coder agent.
//
// It fails when cfg has no coder agent configured, when the coder prompt
// cannot be created, or when building the agent (models, system prompt,
// tools) fails.
func NewCoordinator(
	ctx context.Context,
	cfg *config.Config,
	sessions session.Service,
	messages message.Service,
	permissions permission.Service,
	history history.Service,
	lspClients *csync.Map[string, *lsp.Client],
) (Coordinator, error) {
	coderCfg, found := cfg.Agents[config.AgentCoder]
	if !found {
		return nil, errors.New("coder agent not configured")
	}

	c := &coordinator{
		cfg:         cfg,
		sessions:    sessions,
		messages:    messages,
		permissions: permissions,
		history:     history,
		lspClients:  lspClients,
		agents:      make(map[string]SessionAgent),
	}

	// TODO: make this dynamic when we support multiple agents
	// Named coderTmpl so the imported prompt package is not shadowed.
	coderTmpl, err := coderPrompt(prompt.WithWorkingDir(cfg.WorkingDir()))
	if err != nil {
		return nil, err
	}

	coder, err := c.buildAgent(ctx, coderTmpl, coderCfg)
	if err != nil {
		return nil, err
	}

	c.agents[config.AgentCoder] = coder
	c.currentAgent = coder
	return c, nil
}
101
// Run implements Coordinator.
//
// It forwards prompt to the current agent for sessionID, using the model's
// configured max tokens (falling back to the catwalk default when unset),
// the merged provider options and the sampling parameters. Attachments are
// dropped when the current model does not support images.
func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
	model := c.currentAgent.Model()

	if attachments != nil && !model.CatwalkCfg.SupportsImages {
		attachments = nil
	}

	providerCfg, found := c.cfg.Providers.Get(model.ModelCfg.Provider)
	if !found {
		return nil, errors.New("model provider not configured")
	}

	// A zero value in the user config means "use the catwalk default".
	maxTokens := cmp.Or(model.ModelCfg.MaxTokens, model.CatwalkCfg.DefaultMaxTokens)

	providerOptions, temperature, topP, topK, frequencyPenalty, presencePenalty := mergeCallOptions(model, providerCfg.Type)

	call := SessionAgentCall{
		SessionID:        sessionID,
		Prompt:           prompt,
		Attachments:      attachments,
		MaxOutputTokens:  maxTokens,
		ProviderOptions:  providerOptions,
		Temperature:      temperature,
		TopP:             topP,
		TopK:             topK,
		FrequencyPenalty: frequencyPenalty,
		PresencePenalty:  presencePenalty,
	}
	return c.currentAgent.Run(ctx, call)
}
134
// getProviderOptions builds the provider-specific call options for model.
//
// The catwalk model's provider options and the user's model-config provider
// options are marshalled to JSON and merged with jsons.Merge (the user config
// is passed last, presumably so it takes precedence — confirm against
// go-jsons). Provider-specific defaults derived from the model config
// (reasoning effort, thinking budget, ...) are filled in when not already
// present, and the result is parsed into the typed options for provider
// type tp.
//
// Failures never abort the call: they are logged and the returned options
// simply lack an entry for the provider (or are empty when merging fails).
func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
	options := fantasy.ProviderOptions{}

	cfgOpts := []byte("{}")
	catwalkOpts := []byte("{}")

	if model.ModelCfg.ProviderOptions != nil {
		data, err := json.Marshal(model.ModelCfg.ProviderOptions)
		if err != nil {
			// Previously ignored silently; log so misconfigurations are visible.
			slog.Error("Could not marshal model config provider options", "err", err)
		} else {
			cfgOpts = data
		}
	}

	if model.CatwalkCfg.Options.ProviderOptions != nil {
		data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
		if err != nil {
			slog.Error("Could not marshal catwalk provider options", "err", err)
		} else {
			catwalkOpts = data
		}
	}

	readers := []io.Reader{
		bytes.NewReader(catwalkOpts),
		bytes.NewReader(cfgOpts),
	}

	got, err := jsons.Merge(readers)
	if err != nil {
		slog.Error("Could not merge call config", "err", err)
		return options
	}

	mergedOptions := make(map[string]any)

	err = json.Unmarshal([]byte(got), &mergedOptions)
	if err != nil {
		slog.Error("Could not create config for call", "err", err)
		return options
	}

	// logParseError reports options the provider parser rejected; previously
	// these were dropped without any trace.
	logParseError := func(err error) {
		slog.Error("Could not parse provider options", "provider", tp, "err", err)
	}

	switch tp {
	case openai.Name:
		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
		}
		if openai.IsResponsesModel(model.CatwalkCfg.ID) {
			if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
				mergedOptions["reasoning_summary"] = "auto"
				mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
			}
			parsed, err := openai.ParseResponsesOptions(mergedOptions)
			if err != nil {
				logParseError(err)
			} else {
				options[openai.Name] = parsed
			}
		} else {
			parsed, err := openai.ParseOptions(mergedOptions)
			if err != nil {
				logParseError(err)
			} else {
				options[openai.Name] = parsed
			}
		}
	case anthropic.Name:
		_, hasThink := mergedOptions["thinking"]
		if !hasThink && model.ModelCfg.Think {
			mergedOptions["thinking"] = map[string]any{
				// TODO: kujtim see if we need to make this dynamic
				"budget_tokens": 2000,
			}
		}
		parsed, err := anthropic.ParseOptions(mergedOptions)
		if err != nil {
			logParseError(err)
		} else {
			options[anthropic.Name] = parsed
		}
	case openrouter.Name:
		_, hasReasoning := mergedOptions["reasoning"]
		if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
			mergedOptions["reasoning"] = map[string]any{
				"enabled": true,
				"effort":  model.ModelCfg.ReasoningEffort,
			}
		}
		parsed, err := openrouter.ParseOptions(mergedOptions)
		if err != nil {
			logParseError(err)
		} else {
			options[openrouter.Name] = parsed
		}
	case google.Name:
		// thinking_config is always defaulted when not configured explicitly.
		_, hasReasoning := mergedOptions["thinking_config"]
		if !hasReasoning {
			mergedOptions["thinking_config"] = map[string]any{
				"thinking_budget":  2000,
				"include_thoughts": true,
			}
		}
		parsed, err := google.ParseOptions(mergedOptions)
		if err != nil {
			logParseError(err)
		} else {
			options[google.Name] = parsed
		}
	case azure.Name:
		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
		}
		// azure uses the same options as openaicompat
		parsed, err := openaicompat.ParseOptions(mergedOptions)
		if err != nil {
			logParseError(err)
		} else {
			options[azure.Name] = parsed
		}
	case openaicompat.Name:
		_, hasReasoningEffort := mergedOptions["reasoning_effort"]
		if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
			mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
		}
		parsed, err := openaicompat.ParseOptions(mergedOptions)
		if err != nil {
			logParseError(err)
		} else {
			options[openaicompat.Name] = parsed
		}
	}

	return options
}
255
// mergeCallOptions returns the provider options for model followed by its
// temperature, top-p, top-k, frequency penalty and presence penalty.
//
// Each sampling parameter is taken from the user's model config when it is
// non-nil and otherwise falls back to the catwalk model default.
func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
	providerOpts := getProviderOptions(model, tp)

	user := model.ModelCfg
	defaults := model.CatwalkCfg.Options

	return providerOpts,
		cmp.Or(user.Temperature, defaults.Temperature),
		cmp.Or(user.TopP, defaults.TopP),
		cmp.Or(user.TopK, defaults.TopK),
		cmp.Or(user.FrequencyPenalty, defaults.FrequencyPenalty),
		cmp.Or(user.PresencePenalty, defaults.PresencePenalty)
}
265
// buildAgent assembles a SessionAgent for the given agent config. It builds
// the large and small models, builds the system prompt for the large model's
// provider and model ID, and collects the tools the agent may use.
func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
	large, small, err := c.buildAgentModels(ctx)
	if err != nil {
		return nil, err
	}

	lm := large.Model
	systemPrompt, err := prompt.Build(lm.Provider(), lm.Model(), *c.cfg)
	if err != nil {
		return nil, err
	}

	// agentTools (not "tools") so the tools package is not shadowed.
	agentTools, err := c.buildTools(ctx, agent)
	if err != nil {
		return nil, err
	}

	opts := SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, agentTools}
	return NewSessionAgent(opts), nil
}
283
// buildTools returns the tools an agent with the given config may use.
//
// Built-in tools (plus the agent tool, when AgentToolName is allowed) are
// kept only if their name is listed in agent.AllowedTools.
//
// MCP tools are filtered by agent.AllowedMCP:
//   - nil map: every MCP tool is allowed;
//   - empty map: no MCP tool is allowed;
//   - otherwise a tool is allowed when its MCP server is a key of the map and
//     that server's list is empty (all its tools) or contains the tool name.
//
// Each allowed MCP tool is added exactly once, even if its name is repeated
// in the allow-list.
func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
	var allTools []fantasy.AgentTool
	if slices.Contains(agent.AllowedTools, AgentToolName) {
		agentTool, err := c.agentTool(ctx)
		if err != nil {
			return nil, err
		}
		allTools = append(allTools, agentTool)
	}

	wd := c.cfg.WorkingDir()
	allTools = append(allTools,
		tools.NewBashTool(c.permissions, wd, c.cfg.Options.Attribution),
		tools.NewDownloadTool(c.permissions, wd, nil),
		tools.NewEditTool(c.lspClients, c.permissions, c.history, wd),
		tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, wd),
		tools.NewFetchTool(c.permissions, wd, nil),
		tools.NewGlobTool(wd),
		tools.NewGrepTool(wd),
		tools.NewLsTool(c.permissions, wd, c.cfg.Tools.Ls),
		tools.NewSourcegraphTool(nil),
		tools.NewViewTool(c.lspClients, c.permissions, wd),
		tools.NewWriteTool(c.lspClients, c.permissions, c.history, wd),
	)

	var filteredTools []fantasy.AgentTool
	for _, tool := range allTools {
		if slices.Contains(agent.AllowedTools, tool.Info().Name) {
			filteredTools = append(filteredTools, tool)
		}
	}

	// NOTE(review): MCP tools are fetched with context.Background() rather
	// than ctx — presumably so MCP clients are not tied to the caller's
	// context; confirm before switching to ctx.
	mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)

	for _, mcpTool := range mcpTools {
		if agent.AllowedMCP == nil {
			// No MCP restrictions.
			filteredTools = append(filteredTools, mcpTool)
			continue
		}
		allowed, listed := agent.AllowedMCP[mcpTool.MCP()]
		if !listed {
			// Server not allowed (an empty, non-nil map allows no MCPs at all).
			continue
		}
		// An empty list allows every tool of this MCP server.
		if len(allowed) == 0 || slices.Contains(allowed, mcpTool.MCPToolName()) {
			filteredTools = append(filteredTools, mcpTool)
		}
	}

	return filteredTools, nil
}
343
// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config

// buildAgentModels resolves the selected large and small models.
//
// For each one it looks up the provider config, builds the provider (using
// that model's own config), finds the model's catwalk definition in the
// provider's model list and creates the language model. Any missing piece
// or construction failure is returned as an error.
func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error) {
	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
	if !ok {
		return Model{}, Model{}, errors.New("large model not selected")
	}
	smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
	if !ok {
		return Model{}, Model{}, errors.New("small model not selected")
	}

	largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
	if !ok {
		return Model{}, Model{}, errors.New("large model provider not configured")
	}

	largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
	if err != nil {
		return Model{}, Model{}, err
	}

	smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
	if !ok {
		return Model{}, Model{}, errors.New("small model provider not configured")
	}

	// Use the small model's own config: buildProvider derives model-specific
	// settings (e.g. the Anthropic thinking header) from it. Passing the
	// large model's config here made the small provider mirror the large one.
	smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg)
	if err != nil {
		return Model{}, Model{}, err
	}

	// Index into the slices instead of taking the address of the range
	// variable, and stop at the first matching ID.
	var largeCatwalkModel, smallCatwalkModel *catwalk.Model
	for i := range largeProviderCfg.Models {
		if largeProviderCfg.Models[i].ID == largeModelCfg.Model {
			largeCatwalkModel = &largeProviderCfg.Models[i]
			break
		}
	}
	for i := range smallProviderCfg.Models {
		if smallProviderCfg.Models[i].ID == smallModelCfg.Model {
			smallCatwalkModel = &smallProviderCfg.Models[i]
			break
		}
	}

	if largeCatwalkModel == nil {
		return Model{}, Model{}, errors.New("large model not found in provider config")
	}

	if smallCatwalkModel == nil {
		return Model{}, Model{}, errors.New("small model not found in provider config")
	}

	largeModel, err := largeProvider.LanguageModel(ctx, largeModelCfg.Model)
	if err != nil {
		return Model{}, Model{}, err
	}
	smallModel, err := smallProvider.LanguageModel(ctx, smallModelCfg.Model)
	if err != nil {
		return Model{}, Model{}, err
	}

	return Model{
			Model:      largeModel,
			CatwalkCfg: *largeCatwalkModel,
			ModelCfg:   largeModelCfg,
		}, Model{
			Model:      smallModel,
			CatwalkCfg: *smallCatwalkModel,
			ModelCfg:   smallModelCfg,
		}, nil
}
416
// buildAnthropicProvider creates an Anthropic provider.
//
// If headers already contain an Authorization header (matched after
// lower-casing the name), apiKey is discarded so the X-Api-Key header is not
// sent as well. apiKey, headers and baseURL are only passed on when
// non-empty; debug mode installs the HTTP client from log.NewHTTPClient.
func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
	for name := range headers {
		if strings.ToLower(name) == "authorization" {
			// Bearer auth supplied via headers: don't also use X-Api-Key.
			apiKey = ""
			break
		}
	}

	var opts []anthropic.Option
	if apiKey != "" {
		// Standard X-Api-Key authentication.
		opts = append(opts, anthropic.WithAPIKey(apiKey))
	}
	if len(headers) > 0 {
		opts = append(opts, anthropic.WithHeaders(headers))
	}
	if baseURL != "" {
		opts = append(opts, anthropic.WithBaseURL(baseURL))
	}
	if c.cfg.Options.Debug {
		opts = append(opts, anthropic.WithHTTPClient(log.NewHTTPClient()))
	}
	return anthropic.New(opts...)
}
451
452func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
453 opts := []openai.Option{
454 openai.WithAPIKey(apiKey),
455 openai.WithUseResponsesAPI(),
456 }
457 if c.cfg.Options.Debug {
458 httpClient := log.NewHTTPClient()
459 opts = append(opts, openai.WithHTTPClient(httpClient))
460 }
461 if len(headers) > 0 {
462 opts = append(opts, openai.WithHeaders(headers))
463 }
464 if baseURL != "" {
465 opts = append(opts, openai.WithBaseURL(baseURL))
466 }
467 return openai.New(opts...)
468}
469
470func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
471 opts := []openrouter.Option{
472 openrouter.WithAPIKey(apiKey),
473 }
474 if c.cfg.Options.Debug {
475 httpClient := log.NewHTTPClient()
476 opts = append(opts, openrouter.WithHTTPClient(httpClient))
477 }
478 if len(headers) > 0 {
479 opts = append(opts, openrouter.WithHeaders(headers))
480 }
481 return openrouter.New(opts...)
482}
483
484func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
485 opts := []openaicompat.Option{
486 openaicompat.WithBaseURL(baseURL),
487 openaicompat.WithAPIKey(apiKey),
488 }
489 if c.cfg.Options.Debug {
490 httpClient := log.NewHTTPClient()
491 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
492 }
493 if len(headers) > 0 {
494 opts = append(opts, openaicompat.WithHeaders(headers))
495 }
496
497 return openaicompat.New(opts...)
498}
499
500func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
501 opts := []azure.Option{
502 azure.WithBaseURL(baseURL),
503 azure.WithAPIKey(apiKey),
504 }
505 if c.cfg.Options.Debug {
506 httpClient := log.NewHTTPClient()
507 opts = append(opts, azure.WithHTTPClient(httpClient))
508 }
509 if options == nil {
510 options = make(map[string]string)
511 }
512 if apiVersion, ok := options["apiVersion"]; ok {
513 opts = append(opts, azure.WithAPIVersion(apiVersion))
514 }
515 if len(headers) > 0 {
516 opts = append(opts, azure.WithHeaders(headers))
517 }
518
519 return azure.New(opts...)
520}
521
522func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
523 opts := []google.Option{
524 google.WithBaseURL(baseURL),
525 google.WithGeminiAPIKey(apiKey),
526 }
527 if c.cfg.Options.Debug {
528 httpClient := log.NewHTTPClient()
529 opts = append(opts, google.WithHTTPClient(httpClient))
530 }
531 if len(headers) > 0 {
532 opts = append(opts, google.WithHeaders(headers))
533 }
534 return google.New(opts...)
535}
536
537func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
538 opts := []google.Option{}
539 if c.cfg.Options.Debug {
540 httpClient := log.NewHTTPClient()
541 opts = append(opts, google.WithHTTPClient(httpClient))
542 }
543 if len(headers) > 0 {
544 opts = append(opts, google.WithHeaders(headers))
545 }
546
547 project := options["project"]
548 location := options["location"]
549
550 opts = append(opts, google.WithVertex(project, location))
551
552 return google.New(opts...)
553}
554
555func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
556 if model.Think {
557 return true
558 }
559
560 if model.ProviderOptions == nil {
561 return false
562 }
563
564 opts, err := anthropic.ParseOptions(model.ProviderOptions)
565 if err != nil {
566 return false
567 }
568 if opts.Thinking != nil {
569 return true
570 }
571 return false
572}
573
// buildProvider creates the fantasy provider described by providerCfg.
//
// model is used only to decide whether Anthropic's interleaved-thinking beta
// header must be sent. The API key and base URL are resolved through
// c.cfg.Resolve before being handed to the provider-specific builder. An
// unsupported provider type yields an error.
func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) {
	headers := providerCfg.ExtraHeaders

	// handle special headers for anthropic
	if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
		// Copy before adding the header: ExtraHeaders is the map stored in
		// c.cfg, so writing into it would leak the beta header into later
		// builds (e.g. a non-thinking small model), and writing into a nil
		// map (no extra headers configured) panics.
		withBeta := make(map[string]string, len(headers)+1)
		for name, value := range headers {
			withBeta[name] = value
		}
		withBeta["anthropic-beta"] = "interleaved-thinking-2025-05-14"
		headers = withBeta
	}

	// Resolution errors are deliberately ignored: not every provider needs
	// these values (vertexai takes no API key, Anthropic may authenticate via
	// an Authorization header). NOTE(review): on error the value is whatever
	// Resolve returns (presumably empty) — confirm.
	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)

	switch providerCfg.Type {
	case openai.Name:
		return c.buildOpenaiProvider(baseURL, apiKey, headers)
	case anthropic.Name:
		return c.buildAnthropicProvider(baseURL, apiKey, headers)
	case openrouter.Name:
		return c.buildOpenrouterProvider(baseURL, apiKey, headers)
	case azure.Name:
		return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
	case google.Name:
		return c.buildGoogleProvider(baseURL, apiKey, headers)
	case "vertexai":
		return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
	case openaicompat.Name:
		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
	default:
		return nil, errors.New("provider type not supported")
	}
}
605
606func (c *coordinator) Cancel(sessionID string) {
607 c.currentAgent.Cancel(sessionID)
608}
609
610func (c *coordinator) CancelAll() {
611 c.currentAgent.CancelAll()
612}
613
614func (c *coordinator) ClearQueue(sessionID string) {
615 c.currentAgent.ClearQueue(sessionID)
616}
617
618func (c *coordinator) IsBusy() bool {
619 return c.currentAgent.IsBusy()
620}
621
622func (c *coordinator) IsSessionBusy(sessionID string) bool {
623 return c.currentAgent.IsSessionBusy(sessionID)
624}
625
626func (c *coordinator) Model() Model {
627 return c.currentAgent.Model()
628}
629
// UpdateModels rebuilds the current agent's models and tools from the latest
// configuration.
//
// Every fallible step (building the models, looking up the coder agent
// config, building its tools) runs before the agent is modified. On error the
// agent therefore keeps its previous models and tools, instead of ending up
// with new models but stale tools.
func (c *coordinator) UpdateModels(ctx context.Context) error {
	// build the models again so we make sure we get the latest config
	large, small, err := c.buildAgentModels(ctx)
	if err != nil {
		return err
	}

	agentCfg, ok := c.cfg.Agents[config.AgentCoder]
	if !ok {
		return errors.New("coder agent not configured")
	}

	agentTools, err := c.buildTools(ctx, agentCfg)
	if err != nil {
		return err
	}

	c.currentAgent.SetModels(large, small)
	c.currentAgent.SetTools(agentTools)
	return nil
}
650
651func (c *coordinator) QueuedPrompts(sessionID string) int {
652 return c.currentAgent.QueuedPrompts(sessionID)
653}
654
// Summarize implements Coordinator. It asks the current agent to summarize
// sessionID, passing provider options computed for the agent's model.
//
// The model is read once, as Run does, so the provider lookup and the
// provider options are guaranteed to describe the same model.
func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
	model := c.currentAgent.Model()
	providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
	if !ok {
		return errors.New("model provider not configured")
	}
	return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(model, providerCfg.Type))
}