1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "io"
10 "log/slog"
11 "slices"
12 "strings"
13
14 "charm.land/fantasy"
15 "github.com/charmbracelet/catwalk/pkg/catwalk"
16 "github.com/charmbracelet/crush/internal/agent/prompt"
17 "github.com/charmbracelet/crush/internal/agent/tools"
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/csync"
20 "github.com/charmbracelet/crush/internal/history"
21 "github.com/charmbracelet/crush/internal/log"
22 "github.com/charmbracelet/crush/internal/lsp"
23 "github.com/charmbracelet/crush/internal/message"
24 "github.com/charmbracelet/crush/internal/permission"
25 "github.com/charmbracelet/crush/internal/session"
26
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/azure"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openaicompat"
32 "charm.land/fantasy/providers/openrouter"
33 "github.com/qjebbs/go-jsons"
34)
35
36type Coordinator interface {
37 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
38 // SetMainAgent(string)
39 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
40 Cancel(sessionID string)
41 CancelAll()
42 IsSessionBusy(sessionID string) bool
43 IsBusy() bool
44 QueuedPrompts(sessionID string) int
45 ClearQueue(sessionID string)
46 Summarize(context.Context, string) error
47 Model() Model
48 UpdateModels(ctx context.Context) error
49}
50
// coordinator is the Coordinator implementation. It holds the services the
// agent and its tools are built from, and forwards Coordinator calls to
// currentAgent.
type coordinator struct {
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	// lspClients is handed to the file tools (edit, multi-edit, view, write)
	// and to the diagnostics/references tools.
	lspClients *csync.Map[string, *lsp.Client]

	// currentAgent receives every delegated Coordinator call.
	currentAgent SessionAgent
	// agents indexes built agents by name; NewCoordinator only registers
	// config.AgentCoder today.
	agents map[string]SessionAgent
}
62
63func NewCoordinator(
64 ctx context.Context,
65 cfg *config.Config,
66 sessions session.Service,
67 messages message.Service,
68 permissions permission.Service,
69 history history.Service,
70 lspClients *csync.Map[string, *lsp.Client],
71) (Coordinator, error) {
72 c := &coordinator{
73 cfg: cfg,
74 sessions: sessions,
75 messages: messages,
76 permissions: permissions,
77 history: history,
78 lspClients: lspClients,
79 agents: make(map[string]SessionAgent),
80 }
81
82 agentCfg, ok := cfg.Agents[config.AgentCoder]
83 if !ok {
84 return nil, errors.New("coder agent not configured")
85 }
86
87 // TODO: make this dynamic when we support multiple agents
88 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
89 if err != nil {
90 return nil, err
91 }
92
93 agent, err := c.buildAgent(ctx, prompt, agentCfg)
94 if err != nil {
95 return nil, err
96 }
97 c.currentAgent = agent
98 c.agents[config.AgentCoder] = agent
99 return c, nil
100}
101
102// Run implements Coordinator.
103func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
104 model := c.currentAgent.Model()
105 maxTokens := model.CatwalkCfg.DefaultMaxTokens
106 if model.ModelCfg.MaxTokens != 0 {
107 maxTokens = model.ModelCfg.MaxTokens
108 }
109
110 if !model.CatwalkCfg.SupportsImages && attachments != nil {
111 attachments = nil
112 }
113
114 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
115 if !ok {
116 return nil, errors.New("model provider not configured")
117 }
118
119 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg.Type)
120
121 return c.currentAgent.Run(ctx, SessionAgentCall{
122 SessionID: sessionID,
123 Prompt: prompt,
124 Attachments: attachments,
125 MaxOutputTokens: maxTokens,
126 ProviderOptions: mergedOptions,
127 Temperature: temp,
128 TopP: topP,
129 TopK: topK,
130 FrequencyPenalty: freqPenalty,
131 PresencePenalty: presPenalty,
132 })
133}
134
135func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
136 options := fantasy.ProviderOptions{}
137
138 cfgOpts := []byte("{}")
139 catwalkOpts := []byte("{}")
140
141 if model.ModelCfg.ProviderOptions != nil {
142 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
143 if err == nil {
144 cfgOpts = data
145 }
146 }
147
148 if model.CatwalkCfg.Options.ProviderOptions != nil {
149 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
150 if err == nil {
151 catwalkOpts = data
152 }
153 }
154
155 readers := []io.Reader{
156 bytes.NewReader(catwalkOpts),
157 bytes.NewReader(cfgOpts),
158 }
159
160 got, err := jsons.Merge(readers)
161 if err != nil {
162 slog.Error("Could not merge call config", "err", err)
163 return options
164 }
165
166 mergedOptions := make(map[string]any)
167
168 err = json.Unmarshal([]byte(got), &mergedOptions)
169 if err != nil {
170 slog.Error("Could not create config for call", "err", err)
171 return options
172 }
173
174 switch tp {
175 case openai.Name:
176 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
177 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
178 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
179 }
180 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
181 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
182 mergedOptions["reasoning_summary"] = "auto"
183 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
184 }
185 parsed, err := openai.ParseResponsesOptions(mergedOptions)
186 if err == nil {
187 options[openai.Name] = parsed
188 }
189 } else {
190 parsed, err := openai.ParseOptions(mergedOptions)
191 if err == nil {
192 options[openai.Name] = parsed
193 }
194 }
195 case anthropic.Name:
196 _, hasThink := mergedOptions["thinking"]
197 if !hasThink && model.ModelCfg.Think {
198 mergedOptions["thinking"] = map[string]any{
199 // TODO: kujtim see if we need to make this dynamic
200 "budget_tokens": 2000,
201 }
202 }
203 parsed, err := anthropic.ParseOptions(mergedOptions)
204 if err == nil {
205 options[anthropic.Name] = parsed
206 }
207
208 case openrouter.Name:
209 _, hasReasoning := mergedOptions["reasoning"]
210 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
211 mergedOptions["reasoning"] = map[string]any{
212 "enabled": true,
213 "effort": model.ModelCfg.ReasoningEffort,
214 }
215 }
216 parsed, err := openrouter.ParseOptions(mergedOptions)
217 if err == nil {
218 options[openrouter.Name] = parsed
219 }
220 case google.Name:
221 _, hasReasoning := mergedOptions["thinking_config"]
222 if !hasReasoning {
223 mergedOptions["thinking_config"] = map[string]any{
224 "thinking_budget": 2000,
225 "include_thoughts": true,
226 }
227 }
228 parsed, err := google.ParseOptions(mergedOptions)
229 if err == nil {
230 options[google.Name] = parsed
231 }
232 case azure.Name:
233 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
234 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
235 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
236 }
237 // azure uses the same options as openaicompat
238 parsed, err := openaicompat.ParseOptions(mergedOptions)
239 if err == nil {
240 options[azure.Name] = parsed
241 }
242 case openaicompat.Name:
243 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
244 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
245 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
246 }
247 parsed, err := openaicompat.ParseOptions(mergedOptions)
248 if err == nil {
249 options[openaicompat.Name] = parsed
250 }
251 }
252
253 return options
254}
255
256func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
257 modelOptions := getProviderOptions(model, tp)
258 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
259 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
260 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
261 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
262 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
263 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
264}
265
266func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
267 large, small, err := c.buildAgentModels(ctx)
268 if err != nil {
269 return nil, err
270 }
271
272 systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
273 if err != nil {
274 return nil, err
275 }
276
277 tools, err := c.buildTools(ctx, agent)
278 if err != nil {
279 return nil, err
280 }
281 return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
282}
283
284func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
285 var allTools []fantasy.AgentTool
286 if slices.Contains(agent.AllowedTools, AgentToolName) {
287 agentTool, err := c.agentTool(ctx)
288 if err != nil {
289 return nil, err
290 }
291 allTools = append(allTools, agentTool)
292 }
293
294 allTools = append(allTools,
295 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
296 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
297 tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
298 tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
299 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
300 tools.NewGlobTool(c.cfg.WorkingDir()),
301 tools.NewGrepTool(c.cfg.WorkingDir()),
302 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
303 tools.NewSourcegraphTool(nil),
304 tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
305 tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
306 )
307
308 if len(c.cfg.LSP) > 0 {
309 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspClients), tools.NewReferencesTool(c.lspClients))
310 }
311
312 var filteredTools []fantasy.AgentTool
313 for _, tool := range allTools {
314 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
315 filteredTools = append(filteredTools, tool)
316 }
317 }
318
319 mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
320
321 for _, mcpTool := range mcpTools {
322 if agent.AllowedMCP == nil {
323 // No MCP restrictions
324 filteredTools = append(filteredTools, mcpTool)
325 } else if len(agent.AllowedMCP) == 0 {
326 // no mcps allowed
327 break
328 }
329
330 for mcp, tools := range agent.AllowedMCP {
331 if mcp == mcpTool.MCP() {
332 if len(tools) == 0 {
333 filteredTools = append(filteredTools, mcpTool)
334 }
335 for _, t := range tools {
336 if t == mcpTool.MCPToolName() {
337 filteredTools = append(filteredTools, mcpTool)
338 }
339 }
340 break
341 }
342 }
343 }
344 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
345 return strings.Compare(a.Info().Name, b.Info().Name)
346 })
347 return filteredTools, nil
348}
349
350// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
351func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error) {
352 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
353 if !ok {
354 return Model{}, Model{}, errors.New("large model not selected")
355 }
356 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
357 if !ok {
358 return Model{}, Model{}, errors.New("small model not selected")
359 }
360
361 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
362 if !ok {
363 return Model{}, Model{}, errors.New("large model provider not configured")
364 }
365
366 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
367 if err != nil {
368 return Model{}, Model{}, err
369 }
370
371 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
372 if !ok {
373 return Model{}, Model{}, errors.New("large model provider not configured")
374 }
375
376 smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
377 if err != nil {
378 return Model{}, Model{}, err
379 }
380
381 var largeCatwalkModel *catwalk.Model
382 var smallCatwalkModel *catwalk.Model
383
384 for _, m := range largeProviderCfg.Models {
385 if m.ID == largeModelCfg.Model {
386 largeCatwalkModel = &m
387 }
388 }
389 for _, m := range smallProviderCfg.Models {
390 if m.ID == smallModelCfg.Model {
391 smallCatwalkModel = &m
392 }
393 }
394
395 if largeCatwalkModel == nil {
396 return Model{}, Model{}, errors.New("large model not found in provider config")
397 }
398
399 if smallCatwalkModel == nil {
400 return Model{}, Model{}, errors.New("snall model not found in provider config")
401 }
402
403 largeModel, err := largeProvider.LanguageModel(ctx, largeModelCfg.Model)
404 if err != nil {
405 return Model{}, Model{}, err
406 }
407 smallModel, err := smallProvider.LanguageModel(ctx, smallModelCfg.Model)
408 if err != nil {
409 return Model{}, Model{}, err
410 }
411
412 return Model{
413 Model: largeModel,
414 CatwalkCfg: *largeCatwalkModel,
415 ModelCfg: largeModelCfg,
416 }, Model{
417 Model: smallModel,
418 CatwalkCfg: *smallCatwalkModel,
419 ModelCfg: smallModelCfg,
420 }, nil
421}
422
423func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
424 hasBearerAuth := false
425 for key := range headers {
426 if strings.ToLower(key) == "authorization" {
427 hasBearerAuth = true
428 break
429 }
430 }
431 if hasBearerAuth {
432 apiKey = "" // clear apiKey to avoid using X-Api-Key header
433 }
434
435 var opts []anthropic.Option
436
437 if apiKey != "" {
438 // Use standard X-Api-Key header
439 opts = append(opts, anthropic.WithAPIKey(apiKey))
440 }
441
442 if len(headers) > 0 {
443 opts = append(opts, anthropic.WithHeaders(headers))
444 }
445
446 if baseURL != "" {
447 opts = append(opts, anthropic.WithBaseURL(baseURL))
448 }
449
450 if c.cfg.Options.Debug {
451 httpClient := log.NewHTTPClient()
452 opts = append(opts, anthropic.WithHTTPClient(httpClient))
453 }
454
455 return anthropic.New(opts...)
456}
457
458func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
459 opts := []openai.Option{
460 openai.WithAPIKey(apiKey),
461 openai.WithUseResponsesAPI(),
462 }
463 if c.cfg.Options.Debug {
464 httpClient := log.NewHTTPClient()
465 opts = append(opts, openai.WithHTTPClient(httpClient))
466 }
467 if len(headers) > 0 {
468 opts = append(opts, openai.WithHeaders(headers))
469 }
470 if baseURL != "" {
471 opts = append(opts, openai.WithBaseURL(baseURL))
472 }
473 return openai.New(opts...)
474}
475
476func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
477 opts := []openrouter.Option{
478 openrouter.WithAPIKey(apiKey),
479 }
480 if c.cfg.Options.Debug {
481 httpClient := log.NewHTTPClient()
482 opts = append(opts, openrouter.WithHTTPClient(httpClient))
483 }
484 if len(headers) > 0 {
485 opts = append(opts, openrouter.WithHeaders(headers))
486 }
487 return openrouter.New(opts...)
488}
489
490func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
491 opts := []openaicompat.Option{
492 openaicompat.WithBaseURL(baseURL),
493 openaicompat.WithAPIKey(apiKey),
494 }
495 if c.cfg.Options.Debug {
496 httpClient := log.NewHTTPClient()
497 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
498 }
499 if len(headers) > 0 {
500 opts = append(opts, openaicompat.WithHeaders(headers))
501 }
502
503 return openaicompat.New(opts...)
504}
505
506func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
507 opts := []azure.Option{
508 azure.WithBaseURL(baseURL),
509 azure.WithAPIKey(apiKey),
510 }
511 if c.cfg.Options.Debug {
512 httpClient := log.NewHTTPClient()
513 opts = append(opts, azure.WithHTTPClient(httpClient))
514 }
515 if options == nil {
516 options = make(map[string]string)
517 }
518 if apiVersion, ok := options["apiVersion"]; ok {
519 opts = append(opts, azure.WithAPIVersion(apiVersion))
520 }
521 if len(headers) > 0 {
522 opts = append(opts, azure.WithHeaders(headers))
523 }
524
525 return azure.New(opts...)
526}
527
528func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
529 opts := []google.Option{
530 google.WithBaseURL(baseURL),
531 google.WithGeminiAPIKey(apiKey),
532 }
533 if c.cfg.Options.Debug {
534 httpClient := log.NewHTTPClient()
535 opts = append(opts, google.WithHTTPClient(httpClient))
536 }
537 if len(headers) > 0 {
538 opts = append(opts, google.WithHeaders(headers))
539 }
540 return google.New(opts...)
541}
542
543func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
544 opts := []google.Option{}
545 if c.cfg.Options.Debug {
546 httpClient := log.NewHTTPClient()
547 opts = append(opts, google.WithHTTPClient(httpClient))
548 }
549 if len(headers) > 0 {
550 opts = append(opts, google.WithHeaders(headers))
551 }
552
553 project := options["project"]
554 location := options["location"]
555
556 opts = append(opts, google.WithVertex(project, location))
557
558 return google.New(opts...)
559}
560
561func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
562 if model.Think {
563 return true
564 }
565
566 if model.ProviderOptions == nil {
567 return false
568 }
569
570 opts, err := anthropic.ParseOptions(model.ProviderOptions)
571 if err != nil {
572 return false
573 }
574 if opts.Thinking != nil {
575 return true
576 }
577 return false
578}
579
580func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) {
581 headers := providerCfg.ExtraHeaders
582
583 // handle special headers for anthropic
584 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
585 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
586 }
587
588 // TODO: make sure we have
589 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
590 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
591
592 switch providerCfg.Type {
593 case openai.Name:
594 return c.buildOpenaiProvider(baseURL, apiKey, headers)
595 case anthropic.Name:
596 return c.buildAnthropicProvider(baseURL, apiKey, headers)
597 case openrouter.Name:
598 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
599 case azure.Name:
600 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
601 case google.Name:
602 return c.buildGoogleProvider(baseURL, apiKey, headers)
603 case "vertexai":
604 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
605 case openaicompat.Name:
606 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
607 default:
608 return nil, errors.New("provider type not supported")
609 }
610}
611
612func (c *coordinator) Cancel(sessionID string) {
613 c.currentAgent.Cancel(sessionID)
614}
615
616func (c *coordinator) CancelAll() {
617 c.currentAgent.CancelAll()
618}
619
620func (c *coordinator) ClearQueue(sessionID string) {
621 c.currentAgent.ClearQueue(sessionID)
622}
623
624func (c *coordinator) IsBusy() bool {
625 return c.currentAgent.IsBusy()
626}
627
628func (c *coordinator) IsSessionBusy(sessionID string) bool {
629 return c.currentAgent.IsSessionBusy(sessionID)
630}
631
632func (c *coordinator) Model() Model {
633 return c.currentAgent.Model()
634}
635
636func (c *coordinator) UpdateModels(ctx context.Context) error {
637 // build the models again so we make sure we get the latest config
638 large, small, err := c.buildAgentModels(ctx)
639 if err != nil {
640 return err
641 }
642 c.currentAgent.SetModels(large, small)
643
644 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
645 if !ok {
646 return errors.New("coder agent not configured")
647 }
648
649 tools, err := c.buildTools(ctx, agentCfg)
650 if err != nil {
651 return err
652 }
653 c.currentAgent.SetTools(tools)
654 return nil
655}
656
657func (c *coordinator) QueuedPrompts(sessionID string) int {
658 return c.currentAgent.QueuedPrompts(sessionID)
659}
660
661func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
662 providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
663 if !ok {
664 return errors.New("model provider not configured")
665 }
666 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg.Type))
667}