1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "io"
10 "log/slog"
11 "slices"
12 "strings"
13
14 "charm.land/fantasy"
15 "github.com/charmbracelet/catwalk/pkg/catwalk"
16 "github.com/charmbracelet/crush/internal/agent/prompt"
17 "github.com/charmbracelet/crush/internal/agent/tools"
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/csync"
20 "github.com/charmbracelet/crush/internal/history"
21 "github.com/charmbracelet/crush/internal/log"
22 "github.com/charmbracelet/crush/internal/lsp"
23 "github.com/charmbracelet/crush/internal/message"
24 "github.com/charmbracelet/crush/internal/permission"
25 "github.com/charmbracelet/crush/internal/session"
26
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/azure"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openaicompat"
32 "charm.land/fantasy/providers/openrouter"
33 "github.com/qjebbs/go-jsons"
34)
35
36type Coordinator interface {
37 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
38 // SetMainAgent(string)
39 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
40 Cancel(sessionID string)
41 CancelAll()
42 IsSessionBusy(sessionID string) bool
43 IsBusy() bool
44 QueuedPrompts(sessionID string) int
45 ClearQueue(sessionID string)
46 Summarize(context.Context, string) error
47 Model() Model
48 UpdateModels(ctx context.Context) error
49}
50
// coordinator is the Coordinator implementation built by NewCoordinator.
// It holds the services the agents are built from and forwards every
// Coordinator call to currentAgent.
type coordinator struct {
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]

	// currentAgent receives every delegated call (Run, Cancel, Summarize, ...).
	currentAgent SessionAgent
	// agents maps config agent names to built agents; NewCoordinator
	// currently registers only config.AgentCoder.
	agents map[string]SessionAgent
}
62
63func NewCoordinator(
64 ctx context.Context,
65 cfg *config.Config,
66 sessions session.Service,
67 messages message.Service,
68 permissions permission.Service,
69 history history.Service,
70 lspClients *csync.Map[string, *lsp.Client],
71) (Coordinator, error) {
72 c := &coordinator{
73 cfg: cfg,
74 sessions: sessions,
75 messages: messages,
76 permissions: permissions,
77 history: history,
78 lspClients: lspClients,
79 agents: make(map[string]SessionAgent),
80 }
81
82 agentCfg, ok := cfg.Agents[config.AgentCoder]
83 if !ok {
84 return nil, errors.New("coder agent not configured")
85 }
86
87 // TODO: make this dynamic when we support multiple agents
88 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
89 if err != nil {
90 return nil, err
91 }
92
93 agent, err := c.buildAgent(ctx, prompt, agentCfg)
94 if err != nil {
95 return nil, err
96 }
97 c.currentAgent = agent
98 c.agents[config.AgentCoder] = agent
99 return c, nil
100}
101
102// Run implements Coordinator.
103func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
104 model := c.currentAgent.Model()
105 maxTokens := model.CatwalkCfg.DefaultMaxTokens
106 if model.ModelCfg.MaxTokens != 0 {
107 maxTokens = model.ModelCfg.MaxTokens
108 }
109
110 if !model.CatwalkCfg.SupportsImages && attachments != nil {
111 attachments = nil
112 }
113
114 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
115 if !ok {
116 return nil, errors.New("model provider not configured")
117 }
118
119 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg.Type)
120
121 return c.currentAgent.Run(ctx, SessionAgentCall{
122 SessionID: sessionID,
123 Prompt: prompt,
124 Attachments: attachments,
125 MaxOutputTokens: maxTokens,
126 ProviderOptions: mergedOptions,
127 Temperature: temp,
128 TopP: topP,
129 TopK: topK,
130 FrequencyPenalty: freqPenalty,
131 PresencePenalty: presPenalty,
132 })
133}
134
135func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
136 options := fantasy.ProviderOptions{}
137
138 cfgOpts := []byte("{}")
139 catwalkOpts := []byte("{}")
140
141 if model.ModelCfg.ProviderOptions != nil {
142 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
143 if err == nil {
144 cfgOpts = data
145 }
146 }
147
148 if model.CatwalkCfg.Options.ProviderOptions != nil {
149 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
150 if err == nil {
151 catwalkOpts = data
152 }
153 }
154
155 readers := []io.Reader{
156 bytes.NewReader(catwalkOpts),
157 bytes.NewReader(cfgOpts),
158 }
159
160 got, err := jsons.Merge(readers)
161 if err != nil {
162 slog.Error("Could not merge call config", "err", err)
163 return options
164 }
165
166 mergedOptions := make(map[string]any)
167
168 err = json.Unmarshal([]byte(got), &mergedOptions)
169 if err != nil {
170 slog.Error("Could not create config for call", "err", err)
171 return options
172 }
173
174 switch tp {
175 case openai.Name:
176 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
177 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
178 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
179 }
180 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
181 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
182 mergedOptions["reasoning_summary"] = "auto"
183 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
184 }
185 parsed, err := openai.ParseResponsesOptions(mergedOptions)
186 if err == nil {
187 options[openai.Name] = parsed
188 }
189 } else {
190 parsed, err := openai.ParseOptions(mergedOptions)
191 if err == nil {
192 options[openai.Name] = parsed
193 }
194 }
195 case anthropic.Name:
196 _, hasThink := mergedOptions["thinking"]
197 if !hasThink && model.ModelCfg.Think {
198 mergedOptions["thinking"] = map[string]any{
199 // TODO: kujtim see if we need to make this dynamic
200 "budget_tokens": 2000,
201 }
202 }
203 parsed, err := anthropic.ParseOptions(mergedOptions)
204 if err == nil {
205 options[anthropic.Name] = parsed
206 }
207
208 case openrouter.Name:
209 _, hasReasoning := mergedOptions["reasoning"]
210 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
211 mergedOptions["reasoning"] = map[string]any{
212 "enabled": true,
213 "effort": model.ModelCfg.ReasoningEffort,
214 }
215 }
216 parsed, err := openrouter.ParseOptions(mergedOptions)
217 if err == nil {
218 options[openrouter.Name] = parsed
219 }
220 case google.Name:
221 parsed, err := google.ParseOptions(mergedOptions)
222 if err == nil {
223 options[google.Name] = parsed
224 }
225 case azure.Name:
226 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
227 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
228 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
229 }
230 // azure uses the same options as openaicompat
231 parsed, err := openaicompat.ParseOptions(mergedOptions)
232 if err == nil {
233 options[azure.Name] = parsed
234 }
235 case openaicompat.Name:
236 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
237 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
238 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
239 }
240 parsed, err := openaicompat.ParseOptions(mergedOptions)
241 if err == nil {
242 options[openaicompat.Name] = parsed
243 }
244 }
245
246 return options
247}
248
249func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
250 modelOptions := getProviderOptions(model, tp)
251 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
252 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
253 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
254 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
255 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
256 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
257}
258
259func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
260 large, small, err := c.buildAgentModels(ctx)
261 if err != nil {
262 return nil, err
263 }
264
265 systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
266 if err != nil {
267 return nil, err
268 }
269
270 tools, err := c.buildTools(ctx, agent)
271 if err != nil {
272 return nil, err
273 }
274 return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
275}
276
277func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
278 var allTools []fantasy.AgentTool
279 if slices.Contains(agent.AllowedTools, AgentToolName) {
280 agentTool, err := c.agentTool(ctx)
281 if err != nil {
282 return nil, err
283 }
284 allTools = append(allTools, agentTool)
285 }
286
287 allTools = append(allTools,
288 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
289 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
290 tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
291 tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
292 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
293 tools.NewGlobTool(c.cfg.WorkingDir()),
294 tools.NewGrepTool(c.cfg.WorkingDir()),
295 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
296 tools.NewSourcegraphTool(nil),
297 tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
298 tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
299 )
300
301 var filteredTools []fantasy.AgentTool
302 for _, tool := range allTools {
303 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
304 filteredTools = append(filteredTools, tool)
305 }
306 }
307
308 mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
309
310 for _, mcpTool := range mcpTools {
311 if agent.AllowedMCP == nil {
312 // No MCP restrictions
313 filteredTools = append(filteredTools, mcpTool)
314 } else if len(agent.AllowedMCP) == 0 {
315 // no mcps allowed
316 break
317 }
318
319 for mcp, tools := range agent.AllowedMCP {
320 if mcp == mcpTool.MCP() {
321 if len(tools) == 0 {
322 filteredTools = append(filteredTools, mcpTool)
323 }
324 for _, t := range tools {
325 if t == mcpTool.MCPToolName() {
326 filteredTools = append(filteredTools, mcpTool)
327 }
328 }
329 break
330 }
331 }
332 }
333
334 return filteredTools, nil
335}
336
337// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
338func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error) {
339 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
340 if !ok {
341 return Model{}, Model{}, errors.New("large model not selected")
342 }
343 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
344 if !ok {
345 return Model{}, Model{}, errors.New("small model not selected")
346 }
347
348 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
349 if !ok {
350 return Model{}, Model{}, errors.New("large model provider not configured")
351 }
352
353 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
354 if err != nil {
355 return Model{}, Model{}, err
356 }
357
358 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
359 if !ok {
360 return Model{}, Model{}, errors.New("large model provider not configured")
361 }
362
363 smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
364 if err != nil {
365 return Model{}, Model{}, err
366 }
367
368 var largeCatwalkModel *catwalk.Model
369 var smallCatwalkModel *catwalk.Model
370
371 for _, m := range largeProviderCfg.Models {
372 if m.ID == largeModelCfg.Model {
373 largeCatwalkModel = &m
374 }
375 }
376 for _, m := range smallProviderCfg.Models {
377 if m.ID == smallModelCfg.Model {
378 smallCatwalkModel = &m
379 }
380 }
381
382 if largeCatwalkModel == nil {
383 return Model{}, Model{}, errors.New("large model not found in provider config")
384 }
385
386 if smallCatwalkModel == nil {
387 return Model{}, Model{}, errors.New("snall model not found in provider config")
388 }
389
390 largeModel, err := largeProvider.LanguageModel(ctx, largeModelCfg.Model)
391 if err != nil {
392 return Model{}, Model{}, err
393 }
394 smallModel, err := smallProvider.LanguageModel(ctx, smallModelCfg.Model)
395 if err != nil {
396 return Model{}, Model{}, err
397 }
398
399 return Model{
400 Model: largeModel,
401 CatwalkCfg: *largeCatwalkModel,
402 ModelCfg: largeModelCfg,
403 }, Model{
404 Model: smallModel,
405 CatwalkCfg: *smallCatwalkModel,
406 ModelCfg: smallModelCfg,
407 }, nil
408}
409
410func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
411 hasBearerAuth := false
412 for key := range headers {
413 if strings.ToLower(key) == "authorization" {
414 hasBearerAuth = true
415 break
416 }
417 }
418 if hasBearerAuth {
419 apiKey = "" // clear apiKey to avoid using X-Api-Key header
420 }
421
422 var opts []anthropic.Option
423
424 if apiKey != "" {
425 // Use standard X-Api-Key header
426 opts = append(opts, anthropic.WithAPIKey(apiKey))
427 }
428
429 if len(headers) > 0 {
430 opts = append(opts, anthropic.WithHeaders(headers))
431 }
432
433 if baseURL != "" {
434 opts = append(opts, anthropic.WithBaseURL(baseURL))
435 }
436
437 if c.cfg.Options.Debug {
438 httpClient := log.NewHTTPClient()
439 opts = append(opts, anthropic.WithHTTPClient(httpClient))
440 }
441
442 return anthropic.New(opts...)
443}
444
445func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
446 opts := []openai.Option{
447 openai.WithAPIKey(apiKey),
448 openai.WithUseResponsesAPI(),
449 }
450 if c.cfg.Options.Debug {
451 httpClient := log.NewHTTPClient()
452 opts = append(opts, openai.WithHTTPClient(httpClient))
453 }
454 if len(headers) > 0 {
455 opts = append(opts, openai.WithHeaders(headers))
456 }
457 if baseURL != "" {
458 opts = append(opts, openai.WithBaseURL(baseURL))
459 }
460 return openai.New(opts...)
461}
462
463func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
464 opts := []openrouter.Option{
465 openrouter.WithAPIKey(apiKey),
466 }
467 if c.cfg.Options.Debug {
468 httpClient := log.NewHTTPClient()
469 opts = append(opts, openrouter.WithHTTPClient(httpClient))
470 }
471 if len(headers) > 0 {
472 opts = append(opts, openrouter.WithHeaders(headers))
473 }
474 return openrouter.New(opts...)
475}
476
477func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
478 opts := []openaicompat.Option{
479 openaicompat.WithBaseURL(baseURL),
480 openaicompat.WithAPIKey(apiKey),
481 }
482 if c.cfg.Options.Debug {
483 httpClient := log.NewHTTPClient()
484 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
485 }
486 if len(headers) > 0 {
487 opts = append(opts, openaicompat.WithHeaders(headers))
488 }
489
490 return openaicompat.New(opts...)
491}
492
493func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
494 opts := []azure.Option{
495 azure.WithBaseURL(baseURL),
496 azure.WithAPIKey(apiKey),
497 }
498 if c.cfg.Options.Debug {
499 httpClient := log.NewHTTPClient()
500 opts = append(opts, azure.WithHTTPClient(httpClient))
501 }
502 if options == nil {
503 options = make(map[string]string)
504 }
505 if apiVersion, ok := options["apiVersion"]; ok {
506 opts = append(opts, azure.WithAPIVersion(apiVersion))
507 }
508 if len(headers) > 0 {
509 opts = append(opts, azure.WithHeaders(headers))
510 }
511
512 return azure.New(opts...)
513}
514
515func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
516 opts := []google.Option{
517 google.WithBaseURL(baseURL),
518 google.WithGeminiAPIKey(apiKey),
519 }
520 if c.cfg.Options.Debug {
521 httpClient := log.NewHTTPClient()
522 opts = append(opts, google.WithHTTPClient(httpClient))
523 }
524 if len(headers) > 0 {
525 opts = append(opts, google.WithHeaders(headers))
526 }
527 return google.New(opts...)
528}
529
530func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
531 opts := []google.Option{}
532 if c.cfg.Options.Debug {
533 httpClient := log.NewHTTPClient()
534 opts = append(opts, google.WithHTTPClient(httpClient))
535 }
536 if len(headers) > 0 {
537 opts = append(opts, google.WithHeaders(headers))
538 }
539
540 project := options["project"]
541 location := options["location"]
542
543 opts = append(opts, google.WithVertex(project, location))
544
545 return google.New(opts...)
546}
547
548func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
549 if model.Think {
550 return true
551 }
552
553 if model.ProviderOptions == nil {
554 return false
555 }
556
557 opts, err := anthropic.ParseOptions(model.ProviderOptions)
558 if err != nil {
559 return false
560 }
561 if opts.Thinking != nil {
562 return true
563 }
564 return false
565}
566
567func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) {
568 headers := providerCfg.ExtraHeaders
569
570 // handle special headers for anthropic
571 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
572 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
573 }
574
575 // TODO: make sure we have
576 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
577 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
578
579 switch providerCfg.Type {
580 case openai.Name:
581 return c.buildOpenaiProvider(baseURL, apiKey, headers)
582 case anthropic.Name:
583 return c.buildAnthropicProvider(baseURL, apiKey, headers)
584 case openrouter.Name:
585 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
586 case azure.Name:
587 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
588 case google.Name:
589 return c.buildGoogleProvider(baseURL, apiKey, headers)
590 case "vertexai":
591 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
592 case openaicompat.Name:
593 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
594 default:
595 return nil, errors.New("provider type not supported")
596 }
597}
598
599func (c *coordinator) Cancel(sessionID string) {
600 c.currentAgent.Cancel(sessionID)
601}
602
603func (c *coordinator) CancelAll() {
604 c.currentAgent.CancelAll()
605}
606
607func (c *coordinator) ClearQueue(sessionID string) {
608 c.currentAgent.ClearQueue(sessionID)
609}
610
611func (c *coordinator) IsBusy() bool {
612 return c.currentAgent.IsBusy()
613}
614
615func (c *coordinator) IsSessionBusy(sessionID string) bool {
616 return c.currentAgent.IsSessionBusy(sessionID)
617}
618
619func (c *coordinator) Model() Model {
620 return c.currentAgent.Model()
621}
622
623func (c *coordinator) UpdateModels(ctx context.Context) error {
624 // build the models again so we make sure we get the latest config
625 large, small, err := c.buildAgentModels(ctx)
626 if err != nil {
627 return err
628 }
629 c.currentAgent.SetModels(large, small)
630
631 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
632 if !ok {
633 return errors.New("coder agent not configured")
634 }
635
636 tools, err := c.buildTools(ctx, agentCfg)
637 if err != nil {
638 return err
639 }
640 c.currentAgent.SetTools(tools)
641 return nil
642}
643
644func (c *coordinator) QueuedPrompts(sessionID string) int {
645 return c.currentAgent.QueuedPrompts(sessionID)
646}
647
648func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
649 providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
650 if !ok {
651 return errors.New("model provider not configured")
652 }
653 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg.Type))
654}