1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "io"
10 "log/slog"
11 "slices"
12 "strings"
13
14 "charm.land/fantasy"
15 "github.com/charmbracelet/catwalk/pkg/catwalk"
16 "github.com/charmbracelet/crush/internal/agent/prompt"
17 "github.com/charmbracelet/crush/internal/agent/tools"
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/csync"
20 "github.com/charmbracelet/crush/internal/history"
21 "github.com/charmbracelet/crush/internal/log"
22 "github.com/charmbracelet/crush/internal/lsp"
23 "github.com/charmbracelet/crush/internal/message"
24 "github.com/charmbracelet/crush/internal/permission"
25 "github.com/charmbracelet/crush/internal/session"
26
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/azure"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openaicompat"
32 "charm.land/fantasy/providers/openrouter"
33 "github.com/qjebbs/go-jsons"
34)
35
36type Coordinator interface {
37 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
38 // SetMainAgent(string)
39 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
40 Cancel(sessionID string)
41 CancelAll()
42 IsSessionBusy(sessionID string) bool
43 IsBusy() bool
44 QueuedPrompts(sessionID string) int
45 ClearQueue(sessionID string)
46 Summarize(context.Context, string) error
47 Model() Model
48 UpdateModels() error
49}
50
// coordinator is the default Coordinator implementation. It holds the
// configuration and services needed to build agents and forwards every
// Coordinator call to currentAgent.
type coordinator struct {
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]

	// currentAgent receives all delegated calls. NewCoordinator sets it to
	// the coder agent, which is also stored in agents.
	currentAgent SessionAgent
	// agents holds the built agents keyed by agent ID (only
	// config.AgentCoder today).
	agents map[string]SessionAgent
}
62
63func NewCoordinator(
64 cfg *config.Config,
65 sessions session.Service,
66 messages message.Service,
67 permissions permission.Service,
68 history history.Service,
69 lspClients *csync.Map[string, *lsp.Client],
70) (Coordinator, error) {
71 c := &coordinator{
72 cfg: cfg,
73 sessions: sessions,
74 messages: messages,
75 permissions: permissions,
76 history: history,
77 lspClients: lspClients,
78 agents: make(map[string]SessionAgent),
79 }
80
81 agentCfg, ok := cfg.Agents[config.AgentCoder]
82 if !ok {
83 return nil, errors.New("coder agent not configured")
84 }
85
86 // TODO: make this dynamic when we support multiple agents
87 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
88 if err != nil {
89 return nil, err
90 }
91
92 agent, err := c.buildAgent(prompt, agentCfg)
93 if err != nil {
94 return nil, err
95 }
96 c.currentAgent = agent
97 c.agents[config.AgentCoder] = agent
98 return c, nil
99}
100
101// Run implements Coordinator.
102func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
103 model := c.currentAgent.Model()
104 maxTokens := model.CatwalkCfg.DefaultMaxTokens
105 if model.ModelCfg.MaxTokens != 0 {
106 maxTokens = model.ModelCfg.MaxTokens
107 }
108
109 if !model.CatwalkCfg.SupportsImages && attachments != nil {
110 attachments = nil
111 }
112
113 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
114 if !ok {
115 return nil, errors.New("model provider not configured")
116 }
117
118 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg.Type)
119
120 return c.currentAgent.Run(ctx, SessionAgentCall{
121 SessionID: sessionID,
122 Prompt: prompt,
123 Attachments: attachments,
124 MaxOutputTokens: maxTokens,
125 ProviderOptions: mergedOptions,
126 Temperature: temp,
127 TopP: topP,
128 TopK: topK,
129 FrequencyPenalty: freqPenalty,
130 PresencePenalty: presPenalty,
131 })
132}
133
134func getProviderOptions(model Model, tp catwalk.Type) fantasy.ProviderOptions {
135 options := fantasy.ProviderOptions{}
136
137 cfgOpts := []byte("{}")
138 catwalkOpts := []byte("{}")
139
140 if model.ModelCfg.ProviderOptions != nil {
141 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
142 if err == nil {
143 cfgOpts = data
144 }
145 }
146
147 if model.CatwalkCfg.Options.ProviderOptions != nil {
148 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
149 if err == nil {
150 catwalkOpts = data
151 }
152 }
153
154 readers := []io.Reader{
155 bytes.NewReader(catwalkOpts),
156 bytes.NewReader(cfgOpts),
157 }
158
159 got, err := jsons.Merge(readers)
160 if err != nil {
161 slog.Error("Could not merge call config", "err", err)
162 return options
163 }
164
165 mergedOptions := make(map[string]any)
166
167 err = json.Unmarshal([]byte(got), &mergedOptions)
168 if err != nil {
169 slog.Error("Could not create config for call", "err", err)
170 return options
171 }
172
173 switch tp {
174 case openai.Name:
175 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
176 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
177 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
178 }
179 parsed, err := openai.ParseOptions(mergedOptions)
180 if err == nil {
181 options[openai.Name] = parsed
182 }
183 case anthropic.Name:
184 _, hasThink := mergedOptions["thinking"]
185 if !hasThink && model.ModelCfg.Think {
186 mergedOptions["thinking"] = map[string]any{
187 // TODO: kujtim see if we need to make this dynamic
188 "budget_tokens": 2000,
189 }
190 }
191 parsed, err := anthropic.ParseOptions(mergedOptions)
192 if err == nil {
193 options[anthropic.Name] = parsed
194 }
195
196 case openrouter.Name:
197 _, hasReasoning := mergedOptions["reasoning"]
198 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
199 mergedOptions["reasoning"] = map[string]any{
200 "enabled": true,
201 "effort": model.ModelCfg.ReasoningEffort,
202 }
203 }
204 parsed, err := openrouter.ParseOptions(mergedOptions)
205 if err == nil {
206 options[openrouter.Name] = parsed
207 }
208 case google.Name:
209 parsed, err := google.ParseOptions(mergedOptions)
210 if err == nil {
211 options[google.Name] = parsed
212 }
213 case azure.Name:
214 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
215 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
216 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
217 }
218 // azure uses the same options as openaicompat
219 parsed, err := openaicompat.ParseOptions(mergedOptions)
220 if err == nil {
221 options[azure.Name] = parsed
222 }
223 case openaicompat.Name:
224 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
225 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
226 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
227 }
228 parsed, err := openaicompat.ParseOptions(mergedOptions)
229 if err == nil {
230 options[openaicompat.Name] = parsed
231 }
232 }
233
234 return options
235}
236
237func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
238 modelOptions := getProviderOptions(model, tp)
239 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
240 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
241 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
242 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
243 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
244 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
245}
246
247func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
248 large, small, err := c.buildAgentModels()
249 if err != nil {
250 return nil, err
251 }
252
253 systemPrompt, err := prompt.Build(large.Model.Provider(), large.Model.Model(), *c.cfg)
254 if err != nil {
255 return nil, err
256 }
257
258 tools, err := c.buildTools(agent)
259 if err != nil {
260 return nil, err
261 }
262 return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
263}
264
265func (c *coordinator) buildTools(agent config.Agent) ([]fantasy.AgentTool, error) {
266 var allTools []fantasy.AgentTool
267 if slices.Contains(agent.AllowedTools, AgentToolName) {
268 agentTool, err := c.agentTool()
269 if err != nil {
270 return nil, err
271 }
272 allTools = append(allTools, agentTool)
273 }
274
275 allTools = append(allTools,
276 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution),
277 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
278 tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
279 tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
280 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
281 tools.NewGlobTool(c.cfg.WorkingDir()),
282 tools.NewGrepTool(c.cfg.WorkingDir()),
283 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
284 tools.NewSourcegraphTool(nil),
285 tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
286 tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
287 )
288
289 var filteredTools []fantasy.AgentTool
290 for _, tool := range allTools {
291 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
292 filteredTools = append(filteredTools, tool)
293 }
294 }
295
296 mcpTools := tools.GetMCPTools(context.Background(), c.permissions, c.cfg)
297
298 for _, mcpTool := range mcpTools {
299 if agent.AllowedMCP == nil {
300 // No MCP restrictions
301 filteredTools = append(filteredTools, mcpTool)
302 } else if len(agent.AllowedMCP) == 0 {
303 // no mcps allowed
304 break
305 }
306
307 for mcp, tools := range agent.AllowedMCP {
308 if mcp == mcpTool.MCP() {
309 if len(tools) == 0 {
310 filteredTools = append(filteredTools, mcpTool)
311 }
312 for _, t := range tools {
313 if t == mcpTool.MCPToolName() {
314 filteredTools = append(filteredTools, mcpTool)
315 }
316 }
317 break
318 }
319 }
320 }
321
322 return filteredTools, nil
323}
324
325// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
326func (c *coordinator) buildAgentModels() (Model, Model, error) {
327 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
328 if !ok {
329 return Model{}, Model{}, errors.New("large model not selected")
330 }
331 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
332 if !ok {
333 return Model{}, Model{}, errors.New("small model not selected")
334 }
335
336 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
337 if !ok {
338 return Model{}, Model{}, errors.New("large model provider not configured")
339 }
340
341 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
342 if err != nil {
343 return Model{}, Model{}, err
344 }
345
346 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
347 if !ok {
348 return Model{}, Model{}, errors.New("large model provider not configured")
349 }
350
351 smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
352 if err != nil {
353 return Model{}, Model{}, err
354 }
355
356 var largeCatwalkModel *catwalk.Model
357 var smallCatwalkModel *catwalk.Model
358
359 for _, m := range largeProviderCfg.Models {
360 if m.ID == largeModelCfg.Model {
361 largeCatwalkModel = &m
362 }
363 }
364 for _, m := range smallProviderCfg.Models {
365 if m.ID == smallModelCfg.Model {
366 smallCatwalkModel = &m
367 }
368 }
369
370 if largeCatwalkModel == nil {
371 return Model{}, Model{}, errors.New("large model not found in provider config")
372 }
373
374 if smallCatwalkModel == nil {
375 return Model{}, Model{}, errors.New("snall model not found in provider config")
376 }
377
378 largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
379 if err != nil {
380 return Model{}, Model{}, err
381 }
382 smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
383 if err != nil {
384 return Model{}, Model{}, err
385 }
386
387 return Model{
388 Model: largeModel,
389 CatwalkCfg: *largeCatwalkModel,
390 ModelCfg: largeModelCfg,
391 }, Model{
392 Model: smallModel,
393 CatwalkCfg: *smallCatwalkModel,
394 ModelCfg: smallModelCfg,
395 }, nil
396}
397
398func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
399 hasBearerAuth := false
400 for key := range headers {
401 if strings.ToLower(key) == "authorization" {
402 hasBearerAuth = true
403 break
404 }
405 }
406 if hasBearerAuth {
407 apiKey = "" // clear apiKey to avoid using X-Api-Key header
408 }
409
410 var opts []anthropic.Option
411
412 if apiKey != "" {
413 // Use standard X-Api-Key header
414 opts = append(opts, anthropic.WithAPIKey(apiKey))
415 }
416
417 if len(headers) > 0 {
418 opts = append(opts, anthropic.WithHeaders(headers))
419 }
420
421 if baseURL != "" {
422 opts = append(opts, anthropic.WithBaseURL(baseURL))
423 }
424
425 if c.cfg.Options.Debug {
426 httpClient := log.NewHTTPClient()
427 opts = append(opts, anthropic.WithHTTPClient(httpClient))
428 }
429
430 return anthropic.New(opts...)
431}
432
433func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
434 opts := []openai.Option{
435 openai.WithAPIKey(apiKey),
436 }
437 if c.cfg.Options.Debug {
438 httpClient := log.NewHTTPClient()
439 opts = append(opts, openai.WithHTTPClient(httpClient))
440 }
441 if len(headers) > 0 {
442 opts = append(opts, openai.WithHeaders(headers))
443 }
444 if baseURL != "" {
445 opts = append(opts, openai.WithBaseURL(baseURL))
446 }
447 return openai.New(opts...)
448}
449
450func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) fantasy.Provider {
451 opts := []openrouter.Option{
452 openrouter.WithAPIKey(apiKey),
453 }
454 if c.cfg.Options.Debug {
455 httpClient := log.NewHTTPClient()
456 opts = append(opts, openrouter.WithHTTPClient(httpClient))
457 }
458 if len(headers) > 0 {
459 opts = append(opts, openrouter.WithHeaders(headers))
460 }
461 return openrouter.New(opts...)
462}
463
464func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
465 opts := []openaicompat.Option{
466 openaicompat.WithBaseURL(baseURL),
467 openaicompat.WithAPIKey(apiKey),
468 }
469 if c.cfg.Options.Debug {
470 httpClient := log.NewHTTPClient()
471 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
472 }
473 if len(headers) > 0 {
474 opts = append(opts, openaicompat.WithHeaders(headers))
475 }
476
477 return openaicompat.New(opts...)
478}
479
480func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) fantasy.Provider {
481 opts := []azure.Option{
482 azure.WithBaseURL(baseURL),
483 azure.WithAPIKey(apiKey),
484 }
485 if c.cfg.Options.Debug {
486 httpClient := log.NewHTTPClient()
487 opts = append(opts, azure.WithHTTPClient(httpClient))
488 }
489 if options == nil {
490 options = make(map[string]string)
491 }
492 if apiVersion, ok := options["apiVersion"]; ok {
493 opts = append(opts, azure.WithAPIVersion(apiVersion))
494 }
495 if len(headers) > 0 {
496 opts = append(opts, azure.WithHeaders(headers))
497 }
498
499 return azure.New(opts...)
500}
501
502func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
503 opts := []google.Option{
504 google.WithBaseURL(baseURL),
505 google.WithGeminiAPIKey(apiKey),
506 }
507 if c.cfg.Options.Debug {
508 httpClient := log.NewHTTPClient()
509 opts = append(opts, google.WithHTTPClient(httpClient))
510 }
511 if len(headers) > 0 {
512 opts = append(opts, google.WithHeaders(headers))
513 }
514 return google.New(opts...)
515}
516
517func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) fantasy.Provider {
518 opts := []google.Option{}
519 if c.cfg.Options.Debug {
520 httpClient := log.NewHTTPClient()
521 opts = append(opts, google.WithHTTPClient(httpClient))
522 }
523 if len(headers) > 0 {
524 opts = append(opts, google.WithHeaders(headers))
525 }
526
527 project := options["project"]
528 location := options["location"]
529
530 opts = append(opts, google.WithVertex(project, location))
531
532 return google.New(opts...)
533}
534
535func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
536 if model.Think {
537 return true
538 }
539
540 if model.ProviderOptions == nil {
541 return false
542 }
543
544 opts, err := anthropic.ParseOptions(model.ProviderOptions)
545 if err != nil {
546 return false
547 }
548 if opts.Thinking != nil {
549 return true
550 }
551 return false
552}
553
554func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) {
555 headers := providerCfg.ExtraHeaders
556
557 // handle special headers for anthropic
558 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
559 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
560 }
561
562 // TODO: make sure we have
563 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
564 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
565 var provider fantasy.Provider
566 switch providerCfg.Type {
567 case openai.Name:
568 provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
569 case anthropic.Name:
570 provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
571 case openrouter.Name:
572 provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
573 case azure.Name:
574 provider = c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
575 case google.Name:
576 provider = c.buildGoogleProvider(baseURL, apiKey, headers)
577 // this is not in fantasy since its just the google provider with extra stuff
578 case "google-vertex":
579 provider = c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
580 case openaicompat.Name:
581 provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
582 default:
583 return nil, errors.New("provider type not supported")
584 }
585 return provider, nil
586}
587
588func (c *coordinator) Cancel(sessionID string) {
589 c.currentAgent.Cancel(sessionID)
590}
591
592func (c *coordinator) CancelAll() {
593 c.currentAgent.CancelAll()
594}
595
596func (c *coordinator) ClearQueue(sessionID string) {
597 c.currentAgent.ClearQueue(sessionID)
598}
599
600func (c *coordinator) IsBusy() bool {
601 return c.currentAgent.IsBusy()
602}
603
604func (c *coordinator) IsSessionBusy(sessionID string) bool {
605 return c.currentAgent.IsSessionBusy(sessionID)
606}
607
608func (c *coordinator) Model() Model {
609 return c.currentAgent.Model()
610}
611
612func (c *coordinator) UpdateModels() error {
613 // build the models again so we make sure we get the latest config
614 large, small, err := c.buildAgentModels()
615 if err != nil {
616 return err
617 }
618 c.currentAgent.SetModels(large, small)
619
620 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
621 if !ok {
622 return errors.New("coder agent not configured")
623 }
624
625 tools, err := c.buildTools(agentCfg)
626 if err != nil {
627 return err
628 }
629 c.currentAgent.SetTools(tools)
630 return nil
631}
632
633func (c *coordinator) QueuedPrompts(sessionID string) int {
634 return c.currentAgent.QueuedPrompts(sessionID)
635}
636
637func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
638 return c.currentAgent.Summarize(ctx, sessionID)
639}