1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "slices"
16 "strings"
17
18 "charm.land/fantasy"
19 "github.com/charmbracelet/catwalk/pkg/catwalk"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/agent/prompt"
22 "github.com/charmbracelet/crush/internal/agent/tools"
23 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
24 "github.com/charmbracelet/crush/internal/config"
25 "github.com/charmbracelet/crush/internal/csync"
26 "github.com/charmbracelet/crush/internal/history"
27 "github.com/charmbracelet/crush/internal/log"
28 "github.com/charmbracelet/crush/internal/lsp"
29 "github.com/charmbracelet/crush/internal/message"
30 "github.com/charmbracelet/crush/internal/oauth/copilot"
31 "github.com/charmbracelet/crush/internal/permission"
32 "github.com/charmbracelet/crush/internal/session"
33 "golang.org/x/sync/errgroup"
34
35 "charm.land/fantasy/providers/anthropic"
36 "charm.land/fantasy/providers/azure"
37 "charm.land/fantasy/providers/bedrock"
38 "charm.land/fantasy/providers/google"
39 "charm.land/fantasy/providers/openai"
40 "charm.land/fantasy/providers/openaicompat"
41 "charm.land/fantasy/providers/openrouter"
42 openaisdk "github.com/openai/openai-go/v2/option"
43 "github.com/qjebbs/go-jsons"
44)
45
46type Coordinator interface {
47 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
48 // SetMainAgent(string)
49 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
50 Cancel(sessionID string)
51 CancelAll()
52 IsSessionBusy(sessionID string) bool
53 IsBusy() bool
54 QueuedPrompts(sessionID string) int
55 QueuedPromptsList(sessionID string) []string
56 ClearQueue(sessionID string)
57 Summarize(context.Context, string) error
58 Model() Model
59 UpdateModels(ctx context.Context) error
60}
61
62type coordinator struct {
63 cfg *config.Config
64 sessions session.Service
65 messages message.Service
66 permissions permission.Service
67 history history.Service
68 lspClients *csync.Map[string, *lsp.Client]
69
70 currentAgent SessionAgent
71 agents map[string]SessionAgent
72
73 readyWg errgroup.Group
74}
75
76func NewCoordinator(
77 ctx context.Context,
78 cfg *config.Config,
79 sessions session.Service,
80 messages message.Service,
81 permissions permission.Service,
82 history history.Service,
83 lspClients *csync.Map[string, *lsp.Client],
84) (Coordinator, error) {
85 c := &coordinator{
86 cfg: cfg,
87 sessions: sessions,
88 messages: messages,
89 permissions: permissions,
90 history: history,
91 lspClients: lspClients,
92 agents: make(map[string]SessionAgent),
93 }
94
95 agentCfg, ok := cfg.Agents[config.AgentCoder]
96 if !ok {
97 return nil, errors.New("coder agent not configured")
98 }
99
100 // TODO: make this dynamic when we support multiple agents
101 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
102 if err != nil {
103 return nil, err
104 }
105
106 agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
107 if err != nil {
108 return nil, err
109 }
110 c.currentAgent = agent
111 c.agents[config.AgentCoder] = agent
112 return c, nil
113}
114
115// Run implements Coordinator.
116func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
117 if err := c.readyWg.Wait(); err != nil {
118 return nil, err
119 }
120
121 model := c.currentAgent.Model()
122 maxTokens := model.CatwalkCfg.DefaultMaxTokens
123 if model.ModelCfg.MaxTokens != 0 {
124 maxTokens = model.ModelCfg.MaxTokens
125 }
126
127 if !model.CatwalkCfg.SupportsImages && attachments != nil {
128 // filter out image attachments
129 filteredAttachments := make([]message.Attachment, 0, len(attachments))
130 for _, att := range attachments {
131 if att.IsText() {
132 filteredAttachments = append(filteredAttachments, att)
133 }
134 }
135 attachments = filteredAttachments
136 }
137
138 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
139 if !ok {
140 return nil, errors.New("model provider not configured")
141 }
142
143 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
144
145 if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
146 slog.Info("Token needs to be refreshed", "provider", providerCfg.ID)
147 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
148 return nil, err
149 }
150 }
151
152 run := func() (*fantasy.AgentResult, error) {
153 return c.currentAgent.Run(ctx, SessionAgentCall{
154 SessionID: sessionID,
155 Prompt: prompt,
156 Attachments: attachments,
157 MaxOutputTokens: maxTokens,
158 ProviderOptions: mergedOptions,
159 Temperature: temp,
160 TopP: topP,
161 TopK: topK,
162 FrequencyPenalty: freqPenalty,
163 PresencePenalty: presPenalty,
164 })
165 }
166 result, originalErr := run()
167
168 if c.isUnauthorized(originalErr) {
169 switch {
170 case providerCfg.OAuthToken != nil:
171 slog.Info("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
172 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
173 return nil, originalErr
174 }
175 slog.Info("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
176 return run()
177 case strings.Contains(providerCfg.APIKeyTemplate, "$"):
178 slog.Info("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
179 if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
180 return nil, originalErr
181 }
182 slog.Info("Retrying request with refreshed API key", "provider", providerCfg.ID)
183 return run()
184 }
185 }
186
187 return result, originalErr
188}
189
190func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
191 options := fantasy.ProviderOptions{}
192
193 cfgOpts := []byte("{}")
194 providerCfgOpts := []byte("{}")
195 catwalkOpts := []byte("{}")
196
197 if model.ModelCfg.ProviderOptions != nil {
198 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
199 if err == nil {
200 cfgOpts = data
201 }
202 }
203
204 if providerCfg.ProviderOptions != nil {
205 data, err := json.Marshal(providerCfg.ProviderOptions)
206 if err == nil {
207 providerCfgOpts = data
208 }
209 }
210
211 if model.CatwalkCfg.Options.ProviderOptions != nil {
212 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
213 if err == nil {
214 catwalkOpts = data
215 }
216 }
217
218 readers := []io.Reader{
219 bytes.NewReader(catwalkOpts),
220 bytes.NewReader(providerCfgOpts),
221 bytes.NewReader(cfgOpts),
222 }
223
224 got, err := jsons.Merge(readers)
225 if err != nil {
226 slog.Error("Could not merge call config", "err", err)
227 return options
228 }
229
230 mergedOptions := make(map[string]any)
231
232 err = json.Unmarshal([]byte(got), &mergedOptions)
233 if err != nil {
234 slog.Error("Could not create config for call", "err", err)
235 return options
236 }
237
238 switch providerCfg.Type {
239 case openai.Name, azure.Name:
240 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
241 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
242 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
243 }
244 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
245 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
246 mergedOptions["reasoning_summary"] = "auto"
247 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
248 }
249 parsed, err := openai.ParseResponsesOptions(mergedOptions)
250 if err == nil {
251 options[openai.Name] = parsed
252 }
253 } else {
254 parsed, err := openai.ParseOptions(mergedOptions)
255 if err == nil {
256 options[openai.Name] = parsed
257 }
258 }
259 case anthropic.Name:
260 _, hasThink := mergedOptions["thinking"]
261 if !hasThink && model.ModelCfg.Think {
262 mergedOptions["thinking"] = map[string]any{
263 // TODO: kujtim see if we need to make this dynamic
264 "budget_tokens": 2000,
265 }
266 }
267 parsed, err := anthropic.ParseOptions(mergedOptions)
268 if err == nil {
269 options[anthropic.Name] = parsed
270 }
271
272 case openrouter.Name:
273 _, hasReasoning := mergedOptions["reasoning"]
274 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
275 mergedOptions["reasoning"] = map[string]any{
276 "enabled": true,
277 "effort": model.ModelCfg.ReasoningEffort,
278 }
279 }
280 parsed, err := openrouter.ParseOptions(mergedOptions)
281 if err == nil {
282 options[openrouter.Name] = parsed
283 }
284 case google.Name:
285 _, hasReasoning := mergedOptions["thinking_config"]
286 if !hasReasoning {
287 mergedOptions["thinking_config"] = map[string]any{
288 "thinking_budget": 2000,
289 "include_thoughts": true,
290 }
291 }
292 parsed, err := google.ParseOptions(mergedOptions)
293 if err == nil {
294 options[google.Name] = parsed
295 }
296 case openaicompat.Name:
297 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
298 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
299 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
300 }
301 parsed, err := openaicompat.ParseOptions(mergedOptions)
302 if err == nil {
303 options[openaicompat.Name] = parsed
304 }
305 }
306
307 return options
308}
309
310func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
311 modelOptions := getProviderOptions(model, cfg)
312 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
313 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
314 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
315 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
316 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
317 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
318}
319
320func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
321 large, small, err := c.buildAgentModels(ctx, isSubAgent)
322 if err != nil {
323 return nil, err
324 }
325
326 largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
327 result := NewSessionAgent(SessionAgentOptions{
328 large,
329 small,
330 largeProviderCfg.SystemPromptPrefix,
331 "",
332 isSubAgent,
333 c.cfg.Options.DisableAutoSummarize,
334 c.permissions.SkipRequests(),
335 c.sessions,
336 c.messages,
337 nil,
338 })
339
340 c.readyWg.Go(func() error {
341 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
342 if err != nil {
343 return err
344 }
345 result.SetSystemPrompt(systemPrompt)
346 return nil
347 })
348
349 c.readyWg.Go(func() error {
350 tools, err := c.buildTools(ctx, agent)
351 if err != nil {
352 return err
353 }
354 result.SetTools(tools)
355 return nil
356 })
357
358 return result, nil
359}
360
361func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
362 var allTools []fantasy.AgentTool
363 if slices.Contains(agent.AllowedTools, AgentToolName) {
364 agentTool, err := c.agentTool(ctx)
365 if err != nil {
366 return nil, err
367 }
368 allTools = append(allTools, agentTool)
369 }
370
371 if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
372 agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
373 if err != nil {
374 return nil, err
375 }
376 allTools = append(allTools, agenticFetchTool)
377 }
378
379 // Get the model name for the agent
380 modelName := ""
381 if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
382 if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
383 modelName = model.Name
384 }
385 }
386
387 allTools = append(allTools,
388 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
389 tools.NewJobOutputTool(),
390 tools.NewJobKillTool(),
391 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
392 tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
393 tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
394 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
395 tools.NewGlobTool(c.cfg.WorkingDir()),
396 tools.NewGrepTool(c.cfg.WorkingDir()),
397 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
398 tools.NewSourcegraphTool(nil),
399 tools.NewTodosTool(c.sessions),
400 tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
401 tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
402 )
403
404 if len(c.cfg.LSP) > 0 {
405 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspClients), tools.NewReferencesTool(c.lspClients))
406 }
407
408 var filteredTools []fantasy.AgentTool
409 for _, tool := range allTools {
410 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
411 filteredTools = append(filteredTools, tool)
412 }
413 }
414
415 // Wait for MCP initialization to complete before reading MCP tools.
416 if err := mcp.WaitForInit(ctx); err != nil {
417 return nil, fmt.Errorf("failed to wait for MCP initialization: %w", err)
418 }
419
420 for _, tool := range tools.GetMCPTools(c.permissions, c.cfg.WorkingDir()) {
421 if agent.AllowedMCP == nil {
422 // No MCP restrictions
423 filteredTools = append(filteredTools, tool)
424 continue
425 }
426 if len(agent.AllowedMCP) == 0 {
427 // No MCPs allowed
428 slog.Debug("no MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
429 break
430 }
431
432 for mcp, tools := range agent.AllowedMCP {
433 if mcp != tool.MCP() {
434 continue
435 }
436 if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
437 filteredTools = append(filteredTools, tool)
438 }
439 }
440 slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
441 }
442 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
443 return strings.Compare(a.Info().Name, b.Info().Name)
444 })
445 return filteredTools, nil
446}
447
448// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
449func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
450 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
451 if !ok {
452 return Model{}, Model{}, errors.New("large model not selected")
453 }
454 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
455 if !ok {
456 return Model{}, Model{}, errors.New("small model not selected")
457 }
458
459 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
460 if !ok {
461 return Model{}, Model{}, errors.New("large model provider not configured")
462 }
463
464 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
465 if err != nil {
466 return Model{}, Model{}, err
467 }
468
469 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
470 if !ok {
471 return Model{}, Model{}, errors.New("large model provider not configured")
472 }
473
474 smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg, true)
475 if err != nil {
476 return Model{}, Model{}, err
477 }
478
479 var largeCatwalkModel *catwalk.Model
480 var smallCatwalkModel *catwalk.Model
481
482 for _, m := range largeProviderCfg.Models {
483 if m.ID == largeModelCfg.Model {
484 largeCatwalkModel = &m
485 }
486 }
487 for _, m := range smallProviderCfg.Models {
488 if m.ID == smallModelCfg.Model {
489 smallCatwalkModel = &m
490 }
491 }
492
493 if largeCatwalkModel == nil {
494 return Model{}, Model{}, errors.New("large model not found in provider config")
495 }
496
497 if smallCatwalkModel == nil {
498 return Model{}, Model{}, errors.New("small model not found in provider config")
499 }
500
501 largeModelID := largeModelCfg.Model
502 smallModelID := smallModelCfg.Model
503
504 if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
505 largeModelID += ":exacto"
506 }
507
508 if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
509 smallModelID += ":exacto"
510 }
511
512 largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
513 if err != nil {
514 return Model{}, Model{}, err
515 }
516 smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
517 if err != nil {
518 return Model{}, Model{}, err
519 }
520
521 return Model{
522 Model: largeModel,
523 CatwalkCfg: *largeCatwalkModel,
524 ModelCfg: largeModelCfg,
525 }, Model{
526 Model: smallModel,
527 CatwalkCfg: *smallCatwalkModel,
528 ModelCfg: smallModelCfg,
529 }, nil
530}
531
532func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
533 var opts []anthropic.Option
534
535 if strings.HasPrefix(apiKey, "Bearer ") {
536 // NOTE: Prevent the SDK from picking up the API key from env.
537 os.Setenv("ANTHROPIC_API_KEY", "")
538 headers["Authorization"] = apiKey
539 } else if apiKey != "" {
540 // X-Api-Key header
541 opts = append(opts, anthropic.WithAPIKey(apiKey))
542 }
543
544 if len(headers) > 0 {
545 opts = append(opts, anthropic.WithHeaders(headers))
546 }
547
548 if baseURL != "" {
549 opts = append(opts, anthropic.WithBaseURL(baseURL))
550 }
551
552 if c.cfg.Options.Debug {
553 httpClient := log.NewHTTPClient()
554 opts = append(opts, anthropic.WithHTTPClient(httpClient))
555 }
556 return anthropic.New(opts...)
557}
558
559func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
560 opts := []openai.Option{
561 openai.WithAPIKey(apiKey),
562 openai.WithUseResponsesAPI(),
563 }
564 if c.cfg.Options.Debug {
565 httpClient := log.NewHTTPClient()
566 opts = append(opts, openai.WithHTTPClient(httpClient))
567 }
568 if len(headers) > 0 {
569 opts = append(opts, openai.WithHeaders(headers))
570 }
571 if baseURL != "" {
572 opts = append(opts, openai.WithBaseURL(baseURL))
573 }
574 return openai.New(opts...)
575}
576
577func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
578 opts := []openrouter.Option{
579 openrouter.WithAPIKey(apiKey),
580 }
581 if c.cfg.Options.Debug {
582 httpClient := log.NewHTTPClient()
583 opts = append(opts, openrouter.WithHTTPClient(httpClient))
584 }
585 if len(headers) > 0 {
586 opts = append(opts, openrouter.WithHeaders(headers))
587 }
588 return openrouter.New(opts...)
589}
590
591func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
592 opts := []openaicompat.Option{
593 openaicompat.WithBaseURL(baseURL),
594 openaicompat.WithAPIKey(apiKey),
595 }
596
597 // Set HTTP client based on provider and debug mode.
598 var httpClient *http.Client
599 if providerID == string(catwalk.InferenceProviderCopilot) {
600 opts = append(opts, openaicompat.WithUseResponsesAPI())
601 httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug)
602 } else if c.cfg.Options.Debug {
603 httpClient = log.NewHTTPClient()
604 }
605 if httpClient != nil {
606 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
607 }
608
609 if len(headers) > 0 {
610 opts = append(opts, openaicompat.WithHeaders(headers))
611 }
612
613 for extraKey, extraValue := range extraBody {
614 opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
615 }
616
617 return openaicompat.New(opts...)
618}
619
620func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
621 opts := []azure.Option{
622 azure.WithBaseURL(baseURL),
623 azure.WithAPIKey(apiKey),
624 azure.WithUseResponsesAPI(),
625 }
626 if c.cfg.Options.Debug {
627 httpClient := log.NewHTTPClient()
628 opts = append(opts, azure.WithHTTPClient(httpClient))
629 }
630 if options == nil {
631 options = make(map[string]string)
632 }
633 if apiVersion, ok := options["apiVersion"]; ok {
634 opts = append(opts, azure.WithAPIVersion(apiVersion))
635 }
636 if len(headers) > 0 {
637 opts = append(opts, azure.WithHeaders(headers))
638 }
639
640 return azure.New(opts...)
641}
642
643func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
644 var opts []bedrock.Option
645 if c.cfg.Options.Debug {
646 httpClient := log.NewHTTPClient()
647 opts = append(opts, bedrock.WithHTTPClient(httpClient))
648 }
649 if len(headers) > 0 {
650 opts = append(opts, bedrock.WithHeaders(headers))
651 }
652 bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
653 if bearerToken != "" {
654 opts = append(opts, bedrock.WithAPIKey(bearerToken))
655 }
656 return bedrock.New(opts...)
657}
658
659func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
660 opts := []google.Option{
661 google.WithBaseURL(baseURL),
662 google.WithGeminiAPIKey(apiKey),
663 }
664 if c.cfg.Options.Debug {
665 httpClient := log.NewHTTPClient()
666 opts = append(opts, google.WithHTTPClient(httpClient))
667 }
668 if len(headers) > 0 {
669 opts = append(opts, google.WithHeaders(headers))
670 }
671 return google.New(opts...)
672}
673
674func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
675 opts := []google.Option{}
676 if c.cfg.Options.Debug {
677 httpClient := log.NewHTTPClient()
678 opts = append(opts, google.WithHTTPClient(httpClient))
679 }
680 if len(headers) > 0 {
681 opts = append(opts, google.WithHeaders(headers))
682 }
683
684 project := options["project"]
685 location := options["location"]
686
687 opts = append(opts, google.WithVertex(project, location))
688
689 return google.New(opts...)
690}
691
692func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
693 opts := []hyper.Option{
694 hyper.WithBaseURL(baseURL),
695 hyper.WithAPIKey(apiKey),
696 }
697 if c.cfg.Options.Debug {
698 httpClient := log.NewHTTPClient()
699 opts = append(opts, hyper.WithHTTPClient(httpClient))
700 }
701 return hyper.New(opts...)
702}
703
704func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
705 if model.Think {
706 return true
707 }
708
709 if model.ProviderOptions == nil {
710 return false
711 }
712
713 opts, err := anthropic.ParseOptions(model.ProviderOptions)
714 if err != nil {
715 return false
716 }
717 if opts.Thinking != nil {
718 return true
719 }
720 return false
721}
722
723func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
724 headers := maps.Clone(providerCfg.ExtraHeaders)
725 if headers == nil {
726 headers = make(map[string]string)
727 }
728
729 // handle special headers for anthropic
730 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
731 if v, ok := headers["anthropic-beta"]; ok {
732 headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
733 } else {
734 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
735 }
736 }
737
738 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
739 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
740
741 switch providerCfg.Type {
742 case openai.Name:
743 return c.buildOpenaiProvider(baseURL, apiKey, headers)
744 case anthropic.Name:
745 return c.buildAnthropicProvider(baseURL, apiKey, headers)
746 case openrouter.Name:
747 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
748 case azure.Name:
749 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
750 case bedrock.Name:
751 return c.buildBedrockProvider(headers)
752 case google.Name:
753 return c.buildGoogleProvider(baseURL, apiKey, headers)
754 case "google-vertex":
755 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
756 case openaicompat.Name:
757 if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
758 if providerCfg.ExtraBody == nil {
759 providerCfg.ExtraBody = map[string]any{}
760 }
761 providerCfg.ExtraBody["tool_stream"] = true
762 }
763 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
764 case hyper.Name:
765 return c.buildHyperProvider(baseURL, apiKey)
766 default:
767 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
768 }
769}
770
// isExactoSupported reports whether modelID is on the hand-maintained list of
// OpenRouter models for which buildAgentModels requests the ":exacto"
// variant. Matching is exact and case-sensitive.
func isExactoSupported(modelID string) bool {
	exacto := [...]string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}
	return slices.Index(exacto[:], modelID) != -1
}
781
782func (c *coordinator) Cancel(sessionID string) {
783 c.currentAgent.Cancel(sessionID)
784}
785
786func (c *coordinator) CancelAll() {
787 c.currentAgent.CancelAll()
788}
789
790func (c *coordinator) ClearQueue(sessionID string) {
791 c.currentAgent.ClearQueue(sessionID)
792}
793
794func (c *coordinator) IsBusy() bool {
795 return c.currentAgent.IsBusy()
796}
797
798func (c *coordinator) IsSessionBusy(sessionID string) bool {
799 return c.currentAgent.IsSessionBusy(sessionID)
800}
801
802func (c *coordinator) Model() Model {
803 return c.currentAgent.Model()
804}
805
806func (c *coordinator) UpdateModels(ctx context.Context) error {
807 // build the models again so we make sure we get the latest config
808 large, small, err := c.buildAgentModels(ctx, false)
809 if err != nil {
810 return err
811 }
812 c.currentAgent.SetModels(large, small)
813
814 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
815 if !ok {
816 return errors.New("coder agent not configured")
817 }
818
819 tools, err := c.buildTools(ctx, agentCfg)
820 if err != nil {
821 return err
822 }
823 c.currentAgent.SetTools(tools)
824 return nil
825}
826
827func (c *coordinator) QueuedPrompts(sessionID string) int {
828 return c.currentAgent.QueuedPrompts(sessionID)
829}
830
831func (c *coordinator) QueuedPromptsList(sessionID string) []string {
832 return c.currentAgent.QueuedPromptsList(sessionID)
833}
834
835func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
836 providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
837 if !ok {
838 return errors.New("model provider not configured")
839 }
840 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
841}
842
843func (c *coordinator) isUnauthorized(err error) bool {
844 var providerErr *fantasy.ProviderError
845 return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
846}
847
848func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
849 if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
850 slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
851 return err
852 }
853 if err := c.UpdateModels(ctx); err != nil {
854 return err
855 }
856 return nil
857}
858
859func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
860 newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
861 if err != nil {
862 slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
863 return err
864 }
865
866 providerCfg.APIKey = newAPIKey
867 c.cfg.Providers.Set(providerCfg.ID, providerCfg)
868
869 if err := c.UpdateModels(ctx); err != nil {
870 return err
871 }
872 return nil
873}