1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "slices"
16 "strings"
17
18 "charm.land/fantasy"
19 "github.com/charmbracelet/catwalk/pkg/catwalk"
20 "github.com/charmbracelet/crush/internal/agent/prompt"
21 "github.com/charmbracelet/crush/internal/agent/tools"
22 "github.com/charmbracelet/crush/internal/config"
23 "github.com/charmbracelet/crush/internal/csync"
24 "github.com/charmbracelet/crush/internal/history"
25 "github.com/charmbracelet/crush/internal/log"
26 "github.com/charmbracelet/crush/internal/lsp"
27 "github.com/charmbracelet/crush/internal/message"
28 "github.com/charmbracelet/crush/internal/permission"
29 "github.com/charmbracelet/crush/internal/session"
30 "golang.org/x/sync/errgroup"
31
32 "charm.land/fantasy/providers/anthropic"
33 "charm.land/fantasy/providers/azure"
34 "charm.land/fantasy/providers/bedrock"
35 "charm.land/fantasy/providers/google"
36 "charm.land/fantasy/providers/openai"
37 "charm.land/fantasy/providers/openaicompat"
38 "charm.land/fantasy/providers/openrouter"
39 openaisdk "github.com/openai/openai-go/v2/option"
40 "github.com/qjebbs/go-jsons"
41)
42
43type Coordinator interface {
44 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
45 // SetMainAgent(string)
46 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
47 Cancel(sessionID string)
48 CancelAll()
49 IsSessionBusy(sessionID string) bool
50 IsBusy() bool
51 QueuedPrompts(sessionID string) int
52 QueuedPromptsList(sessionID string) []string
53 ClearQueue(sessionID string)
54 Summarize(context.Context, string) error
55 Model() Model
56 UpdateModels(ctx context.Context) error
57}
58
// coordinator is the Coordinator implementation. It currently drives a single
// coder agent; every Coordinator method delegates to currentAgent.
type coordinator struct {
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	lspClients  *csync.Map[string, *lsp.Client]

	// currentAgent is the agent that Run, Cancel, Summarize, etc. forward to.
	currentAgent SessionAgent
	// agents holds built agents keyed by agent name (only config.AgentCoder
	// is registered today, by NewCoordinator).
	agents map[string]SessionAgent

	// readyWg tracks the background tool construction started by buildAgent;
	// Run waits on it before delegating to the agent.
	readyWg errgroup.Group
}
72
73func NewCoordinator(
74 ctx context.Context,
75 cfg *config.Config,
76 sessions session.Service,
77 messages message.Service,
78 permissions permission.Service,
79 history history.Service,
80 lspClients *csync.Map[string, *lsp.Client],
81) (Coordinator, error) {
82 c := &coordinator{
83 cfg: cfg,
84 sessions: sessions,
85 messages: messages,
86 permissions: permissions,
87 history: history,
88 lspClients: lspClients,
89 agents: make(map[string]SessionAgent),
90 }
91
92 agentCfg, ok := cfg.Agents[config.AgentCoder]
93 if !ok {
94 return nil, errors.New("coder agent not configured")
95 }
96
97 // TODO: make this dynamic when we support multiple agents
98 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
99 if err != nil {
100 return nil, err
101 }
102
103 agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
104 if err != nil {
105 return nil, err
106 }
107 c.currentAgent = agent
108 c.agents[config.AgentCoder] = agent
109 return c, nil
110}
111
112// Run implements Coordinator.
113func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
114 if err := c.readyWg.Wait(); err != nil {
115 return nil, err
116 }
117
118 model := c.currentAgent.Model()
119 maxTokens := model.CatwalkCfg.DefaultMaxTokens
120 if model.ModelCfg.MaxTokens != 0 {
121 maxTokens = model.ModelCfg.MaxTokens
122 }
123
124 if !model.CatwalkCfg.SupportsImages && attachments != nil {
125 attachments = nil
126 }
127
128 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
129 if !ok {
130 return nil, errors.New("model provider not configured")
131 }
132
133 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
134
135 if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
136 slog.Info("Token needs to be refreshed", "provider", providerCfg.ID)
137 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
138 return nil, err
139 }
140 }
141
142 run := func() (*fantasy.AgentResult, error) {
143 return c.currentAgent.Run(ctx, SessionAgentCall{
144 SessionID: sessionID,
145 Prompt: prompt,
146 Attachments: attachments,
147 MaxOutputTokens: maxTokens,
148 ProviderOptions: mergedOptions,
149 Temperature: temp,
150 TopP: topP,
151 TopK: topK,
152 FrequencyPenalty: freqPenalty,
153 PresencePenalty: presPenalty,
154 })
155 }
156 result, originalErr := run()
157
158 if c.isUnauthorized(originalErr) {
159 switch {
160 case providerCfg.OAuthToken != nil:
161 slog.Info("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
162 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
163 return nil, originalErr
164 }
165 slog.Info("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
166 return run()
167 case strings.Contains(providerCfg.APIKeyTemplate, "$"):
168 slog.Info("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
169 if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
170 return nil, originalErr
171 }
172 slog.Info("Retrying request with refreshed API key", "provider", providerCfg.ID)
173 return run()
174 }
175 }
176
177 return result, originalErr
178}
179
180func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
181 options := fantasy.ProviderOptions{}
182
183 cfgOpts := []byte("{}")
184 providerCfgOpts := []byte("{}")
185 catwalkOpts := []byte("{}")
186
187 if model.ModelCfg.ProviderOptions != nil {
188 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
189 if err == nil {
190 cfgOpts = data
191 }
192 }
193
194 if providerCfg.ProviderOptions != nil {
195 data, err := json.Marshal(providerCfg.ProviderOptions)
196 if err == nil {
197 providerCfgOpts = data
198 }
199 }
200
201 if model.CatwalkCfg.Options.ProviderOptions != nil {
202 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
203 if err == nil {
204 catwalkOpts = data
205 }
206 }
207
208 readers := []io.Reader{
209 bytes.NewReader(catwalkOpts),
210 bytes.NewReader(providerCfgOpts),
211 bytes.NewReader(cfgOpts),
212 }
213
214 got, err := jsons.Merge(readers)
215 if err != nil {
216 slog.Error("Could not merge call config", "err", err)
217 return options
218 }
219
220 mergedOptions := make(map[string]any)
221
222 err = json.Unmarshal([]byte(got), &mergedOptions)
223 if err != nil {
224 slog.Error("Could not create config for call", "err", err)
225 return options
226 }
227
228 switch providerCfg.Type {
229 case openai.Name, azure.Name:
230 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
231 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
232 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
233 }
234 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
235 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
236 mergedOptions["reasoning_summary"] = "auto"
237 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
238 }
239 parsed, err := openai.ParseResponsesOptions(mergedOptions)
240 if err == nil {
241 options[openai.Name] = parsed
242 }
243 } else {
244 parsed, err := openai.ParseOptions(mergedOptions)
245 if err == nil {
246 options[openai.Name] = parsed
247 }
248 }
249 case anthropic.Name:
250 _, hasThink := mergedOptions["thinking"]
251 if !hasThink && model.ModelCfg.Think {
252 mergedOptions["thinking"] = map[string]any{
253 // TODO: kujtim see if we need to make this dynamic
254 "budget_tokens": 2000,
255 }
256 }
257 parsed, err := anthropic.ParseOptions(mergedOptions)
258 if err == nil {
259 options[anthropic.Name] = parsed
260 }
261
262 case openrouter.Name:
263 _, hasReasoning := mergedOptions["reasoning"]
264 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
265 mergedOptions["reasoning"] = map[string]any{
266 "enabled": true,
267 "effort": model.ModelCfg.ReasoningEffort,
268 }
269 }
270 parsed, err := openrouter.ParseOptions(mergedOptions)
271 if err == nil {
272 options[openrouter.Name] = parsed
273 }
274 case google.Name:
275 _, hasReasoning := mergedOptions["thinking_config"]
276 if !hasReasoning {
277 mergedOptions["thinking_config"] = map[string]any{
278 "thinking_budget": 2000,
279 "include_thoughts": true,
280 }
281 }
282 parsed, err := google.ParseOptions(mergedOptions)
283 if err == nil {
284 options[google.Name] = parsed
285 }
286 case openaicompat.Name:
287 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
288 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
289 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
290 }
291 parsed, err := openaicompat.ParseOptions(mergedOptions)
292 if err == nil {
293 options[openaicompat.Name] = parsed
294 }
295 }
296
297 return options
298}
299
300func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
301 modelOptions := getProviderOptions(model, cfg)
302 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
303 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
304 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
305 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
306 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
307 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
308}
309
310func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
311 large, small, err := c.buildAgentModels(ctx)
312 if err != nil {
313 return nil, err
314 }
315
316 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
317 if err != nil {
318 return nil, err
319 }
320
321 largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
322 result := NewSessionAgent(SessionAgentOptions{
323 large,
324 small,
325 largeProviderCfg.SystemPromptPrefix,
326 systemPrompt,
327 isSubAgent,
328 c.cfg.Options.DisableAutoSummarize,
329 c.permissions.SkipRequests(),
330 c.sessions,
331 c.messages,
332 nil,
333 })
334 c.readyWg.Go(func() error {
335 tools, err := c.buildTools(ctx, agent)
336 if err != nil {
337 return err
338 }
339 result.SetTools(tools)
340 return nil
341 })
342
343 return result, nil
344}
345
346func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
347 var allTools []fantasy.AgentTool
348 if slices.Contains(agent.AllowedTools, AgentToolName) {
349 agentTool, err := c.agentTool(ctx)
350 if err != nil {
351 return nil, err
352 }
353 allTools = append(allTools, agentTool)
354 }
355
356 if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
357 agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
358 if err != nil {
359 return nil, err
360 }
361 allTools = append(allTools, agenticFetchTool)
362 }
363
364 // Get the model name for the agent
365 modelName := ""
366 if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
367 if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
368 modelName = model.Name
369 }
370 }
371
372 allTools = append(allTools,
373 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
374 tools.NewJobOutputTool(),
375 tools.NewJobKillTool(),
376 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
377 tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
378 tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
379 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
380 tools.NewGlobTool(c.cfg.WorkingDir()),
381 tools.NewGrepTool(c.cfg.WorkingDir()),
382 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
383 tools.NewSourcegraphTool(nil),
384 tools.NewTodosTool(c.sessions),
385 tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir()),
386 tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
387 )
388
389 if len(c.cfg.LSP) > 0 {
390 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspClients), tools.NewReferencesTool(c.lspClients))
391 }
392
393 var filteredTools []fantasy.AgentTool
394 for _, tool := range allTools {
395 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
396 filteredTools = append(filteredTools, tool)
397 }
398 }
399
400 for _, tool := range tools.GetMCPTools(c.permissions, c.cfg.WorkingDir()) {
401 // Check MCP-specific disabled tools.
402 if mcpCfg, ok := c.cfg.MCP[tool.MCP()]; ok {
403 if slices.Contains(mcpCfg.DisabledTools, tool.MCPToolName()) {
404 continue
405 }
406 }
407 if agent.AllowedMCP == nil {
408 // No MCP restrictions
409 filteredTools = append(filteredTools, tool)
410 continue
411 }
412 if len(agent.AllowedMCP) == 0 {
413 // No MCPs allowed
414 slog.Debug("no MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
415 break
416 }
417
418 for mcp, tools := range agent.AllowedMCP {
419 if mcp != tool.MCP() {
420 continue
421 }
422 if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
423 filteredTools = append(filteredTools, tool)
424 }
425 }
426 slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
427 }
428 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
429 return strings.Compare(a.Info().Name, b.Info().Name)
430 })
431 return filteredTools, nil
432}
433
434// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
435func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error) {
436 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
437 if !ok {
438 return Model{}, Model{}, errors.New("large model not selected")
439 }
440 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
441 if !ok {
442 return Model{}, Model{}, errors.New("small model not selected")
443 }
444
445 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
446 if !ok {
447 return Model{}, Model{}, errors.New("large model provider not configured")
448 }
449
450 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg)
451 if err != nil {
452 return Model{}, Model{}, err
453 }
454
455 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
456 if !ok {
457 return Model{}, Model{}, errors.New("large model provider not configured")
458 }
459
460 smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg)
461 if err != nil {
462 return Model{}, Model{}, err
463 }
464
465 var largeCatwalkModel *catwalk.Model
466 var smallCatwalkModel *catwalk.Model
467
468 for _, m := range largeProviderCfg.Models {
469 if m.ID == largeModelCfg.Model {
470 largeCatwalkModel = &m
471 }
472 }
473 for _, m := range smallProviderCfg.Models {
474 if m.ID == smallModelCfg.Model {
475 smallCatwalkModel = &m
476 }
477 }
478
479 if largeCatwalkModel == nil {
480 return Model{}, Model{}, errors.New("large model not found in provider config")
481 }
482
483 if smallCatwalkModel == nil {
484 return Model{}, Model{}, errors.New("snall model not found in provider config")
485 }
486
487 largeModelID := largeModelCfg.Model
488 smallModelID := smallModelCfg.Model
489
490 if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
491 largeModelID += ":exacto"
492 }
493
494 if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
495 smallModelID += ":exacto"
496 }
497
498 largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
499 if err != nil {
500 return Model{}, Model{}, err
501 }
502 smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
503 if err != nil {
504 return Model{}, Model{}, err
505 }
506
507 return Model{
508 Model: largeModel,
509 CatwalkCfg: *largeCatwalkModel,
510 ModelCfg: largeModelCfg,
511 }, Model{
512 Model: smallModel,
513 CatwalkCfg: *smallCatwalkModel,
514 ModelCfg: smallModelCfg,
515 }, nil
516}
517
518func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
519 var opts []anthropic.Option
520
521 if strings.HasPrefix(apiKey, "Bearer ") {
522 // NOTE: Prevent the SDK from picking up the API key from env.
523 os.Setenv("ANTHROPIC_API_KEY", "")
524
525 headers["Authorization"] = apiKey
526 } else if apiKey != "" {
527 // X-Api-Key header
528 opts = append(opts, anthropic.WithAPIKey(apiKey))
529 }
530
531 if len(headers) > 0 {
532 opts = append(opts, anthropic.WithHeaders(headers))
533 }
534
535 if baseURL != "" {
536 opts = append(opts, anthropic.WithBaseURL(baseURL))
537 }
538
539 if c.cfg.Options.Debug {
540 httpClient := log.NewHTTPClient()
541 opts = append(opts, anthropic.WithHTTPClient(httpClient))
542 }
543
544 return anthropic.New(opts...)
545}
546
547func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
548 opts := []openai.Option{
549 openai.WithAPIKey(apiKey),
550 openai.WithUseResponsesAPI(),
551 }
552 if c.cfg.Options.Debug {
553 httpClient := log.NewHTTPClient()
554 opts = append(opts, openai.WithHTTPClient(httpClient))
555 }
556 if len(headers) > 0 {
557 opts = append(opts, openai.WithHeaders(headers))
558 }
559 if baseURL != "" {
560 opts = append(opts, openai.WithBaseURL(baseURL))
561 }
562 return openai.New(opts...)
563}
564
565func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
566 opts := []openrouter.Option{
567 openrouter.WithAPIKey(apiKey),
568 }
569 if c.cfg.Options.Debug {
570 httpClient := log.NewHTTPClient()
571 opts = append(opts, openrouter.WithHTTPClient(httpClient))
572 }
573 if len(headers) > 0 {
574 opts = append(opts, openrouter.WithHeaders(headers))
575 }
576 return openrouter.New(opts...)
577}
578
579func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any) (fantasy.Provider, error) {
580 opts := []openaicompat.Option{
581 openaicompat.WithBaseURL(baseURL),
582 openaicompat.WithAPIKey(apiKey),
583 }
584 if c.cfg.Options.Debug {
585 httpClient := log.NewHTTPClient()
586 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
587 }
588 if len(headers) > 0 {
589 opts = append(opts, openaicompat.WithHeaders(headers))
590 }
591
592 for extraKey, extraValue := range extraBody {
593 opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
594 }
595
596 return openaicompat.New(opts...)
597}
598
599func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
600 opts := []azure.Option{
601 azure.WithBaseURL(baseURL),
602 azure.WithAPIKey(apiKey),
603 azure.WithUseResponsesAPI(),
604 }
605 if c.cfg.Options.Debug {
606 httpClient := log.NewHTTPClient()
607 opts = append(opts, azure.WithHTTPClient(httpClient))
608 }
609 if options == nil {
610 options = make(map[string]string)
611 }
612 if apiVersion, ok := options["apiVersion"]; ok {
613 opts = append(opts, azure.WithAPIVersion(apiVersion))
614 }
615 if len(headers) > 0 {
616 opts = append(opts, azure.WithHeaders(headers))
617 }
618
619 return azure.New(opts...)
620}
621
622func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
623 var opts []bedrock.Option
624 if c.cfg.Options.Debug {
625 httpClient := log.NewHTTPClient()
626 opts = append(opts, bedrock.WithHTTPClient(httpClient))
627 }
628 if len(headers) > 0 {
629 opts = append(opts, bedrock.WithHeaders(headers))
630 }
631 bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
632 if bearerToken != "" {
633 opts = append(opts, bedrock.WithAPIKey(bearerToken))
634 }
635 return bedrock.New(opts...)
636}
637
638func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
639 opts := []google.Option{
640 google.WithBaseURL(baseURL),
641 google.WithGeminiAPIKey(apiKey),
642 }
643 if c.cfg.Options.Debug {
644 httpClient := log.NewHTTPClient()
645 opts = append(opts, google.WithHTTPClient(httpClient))
646 }
647 if len(headers) > 0 {
648 opts = append(opts, google.WithHeaders(headers))
649 }
650 return google.New(opts...)
651}
652
653func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
654 opts := []google.Option{}
655 if c.cfg.Options.Debug {
656 httpClient := log.NewHTTPClient()
657 opts = append(opts, google.WithHTTPClient(httpClient))
658 }
659 if len(headers) > 0 {
660 opts = append(opts, google.WithHeaders(headers))
661 }
662
663 project := options["project"]
664 location := options["location"]
665
666 opts = append(opts, google.WithVertex(project, location))
667
668 return google.New(opts...)
669}
670
671func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
672 if model.Think {
673 return true
674 }
675
676 if model.ProviderOptions == nil {
677 return false
678 }
679
680 opts, err := anthropic.ParseOptions(model.ProviderOptions)
681 if err != nil {
682 return false
683 }
684 if opts.Thinking != nil {
685 return true
686 }
687 return false
688}
689
690func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel) (fantasy.Provider, error) {
691 headers := maps.Clone(providerCfg.ExtraHeaders)
692 if headers == nil {
693 headers = make(map[string]string)
694 }
695
696 // handle special headers for anthropic
697 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
698 if v, ok := headers["anthropic-beta"]; ok {
699 headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
700 } else {
701 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
702 }
703 }
704
705 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
706 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
707
708 switch providerCfg.Type {
709 case openai.Name:
710 return c.buildOpenaiProvider(baseURL, apiKey, headers)
711 case anthropic.Name:
712 return c.buildAnthropicProvider(baseURL, apiKey, headers)
713 case openrouter.Name:
714 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
715 case azure.Name:
716 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
717 case bedrock.Name:
718 return c.buildBedrockProvider(headers)
719 case google.Name:
720 return c.buildGoogleProvider(baseURL, apiKey, headers)
721 case "google-vertex":
722 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
723 case openaicompat.Name:
724 if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
725 if providerCfg.ExtraBody == nil {
726 providerCfg.ExtraBody = map[string]any{}
727 }
728 providerCfg.ExtraBody["tool_stream"] = true
729 }
730 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody)
731 default:
732 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
733 }
734}
735
// isExactoSupported reports whether modelID is on the fixed allow-list of
// models for which buildAgentModels appends the ":exacto" suffix to OpenRouter
// model IDs. The match is exact and case-sensitive.
func isExactoSupported(modelID string) bool {
	exactoModels := [...]string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}
	return slices.Index(exactoModels[:], modelID) >= 0
}
746
747func (c *coordinator) Cancel(sessionID string) {
748 c.currentAgent.Cancel(sessionID)
749}
750
751func (c *coordinator) CancelAll() {
752 c.currentAgent.CancelAll()
753}
754
755func (c *coordinator) ClearQueue(sessionID string) {
756 c.currentAgent.ClearQueue(sessionID)
757}
758
759func (c *coordinator) IsBusy() bool {
760 return c.currentAgent.IsBusy()
761}
762
763func (c *coordinator) IsSessionBusy(sessionID string) bool {
764 return c.currentAgent.IsSessionBusy(sessionID)
765}
766
767func (c *coordinator) Model() Model {
768 return c.currentAgent.Model()
769}
770
771func (c *coordinator) UpdateModels(ctx context.Context) error {
772 // build the models again so we make sure we get the latest config
773 large, small, err := c.buildAgentModels(ctx)
774 if err != nil {
775 return err
776 }
777 c.currentAgent.SetModels(large, small)
778
779 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
780 if !ok {
781 return errors.New("coder agent not configured")
782 }
783
784 tools, err := c.buildTools(ctx, agentCfg)
785 if err != nil {
786 return err
787 }
788 c.currentAgent.SetTools(tools)
789 return nil
790}
791
792func (c *coordinator) QueuedPrompts(sessionID string) int {
793 return c.currentAgent.QueuedPrompts(sessionID)
794}
795
796func (c *coordinator) QueuedPromptsList(sessionID string) []string {
797 return c.currentAgent.QueuedPromptsList(sessionID)
798}
799
800func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
801 providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
802 if !ok {
803 return errors.New("model provider not configured")
804 }
805 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
806}
807
808func (c *coordinator) isUnauthorized(err error) bool {
809 var providerErr *fantasy.ProviderError
810 return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
811}
812
813func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
814 if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
815 slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
816 return err
817 }
818 if err := c.UpdateModels(ctx); err != nil {
819 return err
820 }
821 return nil
822}
823
824func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
825 newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
826 if err != nil {
827 slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
828 return err
829 }
830
831 providerCfg.APIKey = newAPIKey
832 c.cfg.Providers.Set(providerCfg.ID, providerCfg)
833
834 if err := c.UpdateModels(ctx); err != nil {
835 return err
836 }
837 return nil
838}