1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "slices"
16 "strings"
17
18 "charm.land/fantasy"
19 "github.com/charmbracelet/catwalk/pkg/catwalk"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/agent/prompt"
22 "github.com/charmbracelet/crush/internal/agent/tools"
23 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
24 "github.com/charmbracelet/crush/internal/config"
25 "github.com/charmbracelet/crush/internal/csync"
26 "github.com/charmbracelet/crush/internal/history"
27 "github.com/charmbracelet/crush/internal/log"
28 "github.com/charmbracelet/crush/internal/lsp"
29 "github.com/charmbracelet/crush/internal/message"
30 "github.com/charmbracelet/crush/internal/oauth/copilot"
31 "github.com/charmbracelet/crush/internal/permission"
32 "github.com/charmbracelet/crush/internal/session"
33 "golang.org/x/sync/errgroup"
34
35 "charm.land/fantasy/providers/anthropic"
36 "charm.land/fantasy/providers/azure"
37 "charm.land/fantasy/providers/bedrock"
38 "charm.land/fantasy/providers/google"
39 "charm.land/fantasy/providers/openai"
40 "charm.land/fantasy/providers/openaicompat"
41 "charm.land/fantasy/providers/openrouter"
42 openaisdk "github.com/openai/openai-go/v2/option"
43 "github.com/qjebbs/go-jsons"
44)
45
46type Coordinator interface {
47 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
48 // SetMainAgent(string)
49 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
50 Cancel(sessionID string)
51 CancelAll()
52 IsSessionBusy(sessionID string) bool
53 IsBusy() bool
54 QueuedPrompts(sessionID string) int
55 QueuedPromptsList(sessionID string) []string
56 ClearQueue(sessionID string)
57 Summarize(context.Context, string) error
58 Model() Model
59 UpdateModels(ctx context.Context) error
60}
61
62type coordinator struct {
63 cfg *config.Config
64 sessions session.Service
65 messages message.Service
66 permissions permission.Service
67 history history.Service
68 lspClients *csync.Map[string, *lsp.Client]
69
70 currentAgent SessionAgent
71 agents map[string]SessionAgent
72
73 readyWg errgroup.Group
74}
75
76func NewCoordinator(
77 ctx context.Context,
78 cfg *config.Config,
79 sessions session.Service,
80 messages message.Service,
81 permissions permission.Service,
82 history history.Service,
83 lspClients *csync.Map[string, *lsp.Client],
84) (Coordinator, error) {
85 c := &coordinator{
86 cfg: cfg,
87 sessions: sessions,
88 messages: messages,
89 permissions: permissions,
90 history: history,
91 lspClients: lspClients,
92 agents: make(map[string]SessionAgent),
93 }
94
95 agentCfg, ok := cfg.Agents[config.AgentCoder]
96 if !ok {
97 return nil, errors.New("coder agent not configured")
98 }
99
100 // TODO: make this dynamic when we support multiple agents
101 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
102 if err != nil {
103 return nil, err
104 }
105
106 agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
107 if err != nil {
108 return nil, err
109 }
110 c.currentAgent = agent
111 c.agents[config.AgentCoder] = agent
112 return c, nil
113}
114
115// Run implements Coordinator.
116func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
117 if err := c.readyWg.Wait(); err != nil {
118 return nil, err
119 }
120
121 model := c.currentAgent.Model()
122 maxTokens := model.CatwalkCfg.DefaultMaxTokens
123 if model.ModelCfg.MaxTokens != 0 {
124 maxTokens = model.ModelCfg.MaxTokens
125 }
126
127 if !model.CatwalkCfg.SupportsImages && attachments != nil {
128 // filter out image attachments
129 filteredAttachments := make([]message.Attachment, 0, len(attachments))
130 for _, att := range attachments {
131 if att.IsText() {
132 filteredAttachments = append(filteredAttachments, att)
133 }
134 }
135 attachments = filteredAttachments
136 }
137
138 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
139 if !ok {
140 return nil, errors.New("model provider not configured")
141 }
142
143 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
144
145 if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
146 slog.Info("Token needs to be refreshed", "provider", providerCfg.ID)
147 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
148 return nil, err
149 }
150 }
151
152 run := func() (*fantasy.AgentResult, error) {
153 return c.currentAgent.Run(ctx, SessionAgentCall{
154 SessionID: sessionID,
155 Prompt: prompt,
156 Attachments: attachments,
157 MaxOutputTokens: maxTokens,
158 ProviderOptions: mergedOptions,
159 Temperature: temp,
160 TopP: topP,
161 TopK: topK,
162 FrequencyPenalty: freqPenalty,
163 PresencePenalty: presPenalty,
164 })
165 }
166 result, originalErr := run()
167
168 if c.isUnauthorized(originalErr) {
169 switch {
170 case providerCfg.OAuthToken != nil:
171 slog.Info("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
172 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
173 return nil, originalErr
174 }
175 slog.Info("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
176 return run()
177 case strings.Contains(providerCfg.APIKeyTemplate, "$"):
178 slog.Info("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
179 if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
180 return nil, originalErr
181 }
182 slog.Info("Retrying request with refreshed API key", "provider", providerCfg.ID)
183 return run()
184 }
185 }
186
187 return result, originalErr
188}
189
190func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
191 options := fantasy.ProviderOptions{}
192
193 cfgOpts := []byte("{}")
194 providerCfgOpts := []byte("{}")
195 catwalkOpts := []byte("{}")
196
197 if model.ModelCfg.ProviderOptions != nil {
198 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
199 if err == nil {
200 cfgOpts = data
201 }
202 }
203
204 if providerCfg.ProviderOptions != nil {
205 data, err := json.Marshal(providerCfg.ProviderOptions)
206 if err == nil {
207 providerCfgOpts = data
208 }
209 }
210
211 if model.CatwalkCfg.Options.ProviderOptions != nil {
212 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
213 if err == nil {
214 catwalkOpts = data
215 }
216 }
217
218 readers := []io.Reader{
219 bytes.NewReader(catwalkOpts),
220 bytes.NewReader(providerCfgOpts),
221 bytes.NewReader(cfgOpts),
222 }
223
224 got, err := jsons.Merge(readers)
225 if err != nil {
226 slog.Error("Could not merge call config", "err", err)
227 return options
228 }
229
230 mergedOptions := make(map[string]any)
231
232 err = json.Unmarshal([]byte(got), &mergedOptions)
233 if err != nil {
234 slog.Error("Could not create config for call", "err", err)
235 return options
236 }
237
238 switch providerCfg.Type {
239 case openai.Name, azure.Name:
240 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
241 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
242 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
243 }
244 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
245 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
246 mergedOptions["reasoning_summary"] = "auto"
247 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
248 }
249 parsed, err := openai.ParseResponsesOptions(mergedOptions)
250 if err == nil {
251 options[openai.Name] = parsed
252 }
253 } else {
254 parsed, err := openai.ParseOptions(mergedOptions)
255 if err == nil {
256 options[openai.Name] = parsed
257 }
258 }
259 case anthropic.Name:
260 _, hasThink := mergedOptions["thinking"]
261 if !hasThink && model.ModelCfg.Think {
262 mergedOptions["thinking"] = map[string]any{
263 // TODO: kujtim see if we need to make this dynamic
264 "budget_tokens": 2000,
265 }
266 }
267 parsed, err := anthropic.ParseOptions(mergedOptions)
268 if err == nil {
269 options[anthropic.Name] = parsed
270 }
271
272 case openrouter.Name:
273 _, hasReasoning := mergedOptions["reasoning"]
274 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
275 mergedOptions["reasoning"] = map[string]any{
276 "enabled": true,
277 "effort": model.ModelCfg.ReasoningEffort,
278 }
279 }
280 parsed, err := openrouter.ParseOptions(mergedOptions)
281 if err == nil {
282 options[openrouter.Name] = parsed
283 }
284 case google.Name:
285 _, hasReasoning := mergedOptions["thinking_config"]
286 if !hasReasoning {
287 mergedOptions["thinking_config"] = map[string]any{
288 "thinking_budget": 2000,
289 "include_thoughts": true,
290 }
291 }
292 parsed, err := google.ParseOptions(mergedOptions)
293 if err == nil {
294 options[google.Name] = parsed
295 }
296 case openaicompat.Name:
297 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
298 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
299 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
300 }
301 parsed, err := openaicompat.ParseOptions(mergedOptions)
302 if err == nil {
303 options[openaicompat.Name] = parsed
304 }
305 }
306
307 return options
308}
309
310func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
311 modelOptions := getProviderOptions(model, cfg)
312 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
313 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
314 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
315 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
316 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
317 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
318}
319
320func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
321 large, small, err := c.buildAgentModels(ctx, isSubAgent)
322 if err != nil {
323 return nil, err
324 }
325
326 largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
327 result := NewSessionAgent(SessionAgentOptions{
328 large,
329 small,
330 largeProviderCfg.SystemPromptPrefix,
331 "",
332 isSubAgent,
333 c.cfg.Options.DisableAutoSummarize,
334 c.permissions.SkipRequests(),
335 c.sessions,
336 c.messages,
337 nil,
338 })
339
340 c.readyWg.Go(func() error {
341 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
342 if err != nil {
343 return err
344 }
345 result.SetSystemPrompt(systemPrompt)
346 return nil
347 })
348
349 c.readyWg.Go(func() error {
350 tools, err := c.buildTools(ctx, agent)
351 if err != nil {
352 return err
353 }
354 result.SetTools(tools)
355 return nil
356 })
357
358 return result, nil
359}
360
361func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
362 var allTools []fantasy.AgentTool
363 if slices.Contains(agent.AllowedTools, AgentToolName) {
364 agentTool, err := c.agentTool(ctx)
365 if err != nil {
366 return nil, err
367 }
368 allTools = append(allTools, agentTool)
369 }
370
371 if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
372 agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
373 if err != nil {
374 return nil, err
375 }
376 allTools = append(allTools, agenticFetchTool)
377 }
378
379 // Get the model name for the agent
380 modelName := ""
381 if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
382 if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
383 modelName = model.Name
384 }
385 }
386
387 allTools = append(allTools,
388 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
389 tools.NewJobOutputTool(),
390 tools.NewJobKillTool(),
391 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
392 tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
393 tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
394 tools.NewDeleteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
395 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
396 tools.NewGlobTool(c.cfg.WorkingDir()),
397 tools.NewGrepTool(c.cfg.WorkingDir()),
398 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
399 tools.NewSourcegraphTool(nil),
400 tools.NewTodosTool(c.sessions),
401 tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
402 tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
403 )
404
405 if len(c.cfg.LSP) > 0 {
406 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspClients), tools.NewReferencesTool(c.lspClients))
407 }
408
409 var filteredTools []fantasy.AgentTool
410 for _, tool := range allTools {
411 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
412 filteredTools = append(filteredTools, tool)
413 }
414 }
415
416 // Wait for MCP initialization to complete before reading MCP tools.
417 if err := mcp.WaitForInit(ctx); err != nil {
418 return nil, fmt.Errorf("failed to wait for MCP initialization: %w", err)
419 }
420
421 for _, tool := range tools.GetMCPTools(c.permissions, c.cfg.WorkingDir()) {
422 if agent.AllowedMCP == nil {
423 // No MCP restrictions
424 filteredTools = append(filteredTools, tool)
425 continue
426 }
427 if len(agent.AllowedMCP) == 0 {
428 // No MCPs allowed
429 slog.Debug("no MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
430 break
431 }
432
433 for mcp, tools := range agent.AllowedMCP {
434 if mcp != tool.MCP() {
435 continue
436 }
437 if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
438 filteredTools = append(filteredTools, tool)
439 }
440 }
441 slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
442 }
443 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
444 return strings.Compare(a.Info().Name, b.Info().Name)
445 })
446 return filteredTools, nil
447}
448
449// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
450func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
451 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
452 if !ok {
453 return Model{}, Model{}, errors.New("large model not selected")
454 }
455 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
456 if !ok {
457 return Model{}, Model{}, errors.New("small model not selected")
458 }
459
460 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
461 if !ok {
462 return Model{}, Model{}, errors.New("large model provider not configured")
463 }
464
465 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
466 if err != nil {
467 return Model{}, Model{}, err
468 }
469
470 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
471 if !ok {
472 return Model{}, Model{}, errors.New("large model provider not configured")
473 }
474
475 smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg, true)
476 if err != nil {
477 return Model{}, Model{}, err
478 }
479
480 var largeCatwalkModel *catwalk.Model
481 var smallCatwalkModel *catwalk.Model
482
483 for _, m := range largeProviderCfg.Models {
484 if m.ID == largeModelCfg.Model {
485 largeCatwalkModel = &m
486 }
487 }
488 for _, m := range smallProviderCfg.Models {
489 if m.ID == smallModelCfg.Model {
490 smallCatwalkModel = &m
491 }
492 }
493
494 if largeCatwalkModel == nil {
495 return Model{}, Model{}, errors.New("large model not found in provider config")
496 }
497
498 if smallCatwalkModel == nil {
499 return Model{}, Model{}, errors.New("small model not found in provider config")
500 }
501
502 largeModelID := largeModelCfg.Model
503 smallModelID := smallModelCfg.Model
504
505 if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
506 largeModelID += ":exacto"
507 }
508
509 if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
510 smallModelID += ":exacto"
511 }
512
513 largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
514 if err != nil {
515 return Model{}, Model{}, err
516 }
517 smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
518 if err != nil {
519 return Model{}, Model{}, err
520 }
521
522 return Model{
523 Model: largeModel,
524 CatwalkCfg: *largeCatwalkModel,
525 ModelCfg: largeModelCfg,
526 }, Model{
527 Model: smallModel,
528 CatwalkCfg: *smallCatwalkModel,
529 ModelCfg: smallModelCfg,
530 }, nil
531}
532
533func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
534 var opts []anthropic.Option
535
536 if strings.HasPrefix(apiKey, "Bearer ") {
537 // NOTE: Prevent the SDK from picking up the API key from env.
538 os.Setenv("ANTHROPIC_API_KEY", "")
539 headers["Authorization"] = apiKey
540 } else if apiKey != "" {
541 // X-Api-Key header
542 opts = append(opts, anthropic.WithAPIKey(apiKey))
543 }
544
545 if len(headers) > 0 {
546 opts = append(opts, anthropic.WithHeaders(headers))
547 }
548
549 if baseURL != "" {
550 opts = append(opts, anthropic.WithBaseURL(baseURL))
551 }
552
553 if c.cfg.Options.Debug {
554 httpClient := log.NewHTTPClient()
555 opts = append(opts, anthropic.WithHTTPClient(httpClient))
556 }
557 return anthropic.New(opts...)
558}
559
560func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
561 opts := []openai.Option{
562 openai.WithAPIKey(apiKey),
563 openai.WithUseResponsesAPI(),
564 }
565 if c.cfg.Options.Debug {
566 httpClient := log.NewHTTPClient()
567 opts = append(opts, openai.WithHTTPClient(httpClient))
568 }
569 if len(headers) > 0 {
570 opts = append(opts, openai.WithHeaders(headers))
571 }
572 if baseURL != "" {
573 opts = append(opts, openai.WithBaseURL(baseURL))
574 }
575 return openai.New(opts...)
576}
577
578func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
579 opts := []openrouter.Option{
580 openrouter.WithAPIKey(apiKey),
581 }
582 if c.cfg.Options.Debug {
583 httpClient := log.NewHTTPClient()
584 opts = append(opts, openrouter.WithHTTPClient(httpClient))
585 }
586 if len(headers) > 0 {
587 opts = append(opts, openrouter.WithHeaders(headers))
588 }
589 return openrouter.New(opts...)
590}
591
592func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
593 opts := []openaicompat.Option{
594 openaicompat.WithBaseURL(baseURL),
595 openaicompat.WithAPIKey(apiKey),
596 }
597
598 // Set HTTP client based on provider and debug mode.
599 var httpClient *http.Client
600 if providerID == string(catwalk.InferenceProviderCopilot) {
601 opts = append(opts, openaicompat.WithUseResponsesAPI())
602 httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug)
603 } else if c.cfg.Options.Debug {
604 httpClient = log.NewHTTPClient()
605 }
606 if httpClient != nil {
607 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
608 }
609
610 if len(headers) > 0 {
611 opts = append(opts, openaicompat.WithHeaders(headers))
612 }
613
614 for extraKey, extraValue := range extraBody {
615 opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
616 }
617
618 return openaicompat.New(opts...)
619}
620
621func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
622 opts := []azure.Option{
623 azure.WithBaseURL(baseURL),
624 azure.WithAPIKey(apiKey),
625 azure.WithUseResponsesAPI(),
626 }
627 if c.cfg.Options.Debug {
628 httpClient := log.NewHTTPClient()
629 opts = append(opts, azure.WithHTTPClient(httpClient))
630 }
631 if options == nil {
632 options = make(map[string]string)
633 }
634 if apiVersion, ok := options["apiVersion"]; ok {
635 opts = append(opts, azure.WithAPIVersion(apiVersion))
636 }
637 if len(headers) > 0 {
638 opts = append(opts, azure.WithHeaders(headers))
639 }
640
641 return azure.New(opts...)
642}
643
644func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
645 var opts []bedrock.Option
646 if c.cfg.Options.Debug {
647 httpClient := log.NewHTTPClient()
648 opts = append(opts, bedrock.WithHTTPClient(httpClient))
649 }
650 if len(headers) > 0 {
651 opts = append(opts, bedrock.WithHeaders(headers))
652 }
653 bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
654 if bearerToken != "" {
655 opts = append(opts, bedrock.WithAPIKey(bearerToken))
656 }
657 return bedrock.New(opts...)
658}
659
660func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
661 opts := []google.Option{
662 google.WithBaseURL(baseURL),
663 google.WithGeminiAPIKey(apiKey),
664 }
665 if c.cfg.Options.Debug {
666 httpClient := log.NewHTTPClient()
667 opts = append(opts, google.WithHTTPClient(httpClient))
668 }
669 if len(headers) > 0 {
670 opts = append(opts, google.WithHeaders(headers))
671 }
672 return google.New(opts...)
673}
674
675func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
676 opts := []google.Option{}
677 if c.cfg.Options.Debug {
678 httpClient := log.NewHTTPClient()
679 opts = append(opts, google.WithHTTPClient(httpClient))
680 }
681 if len(headers) > 0 {
682 opts = append(opts, google.WithHeaders(headers))
683 }
684
685 project := options["project"]
686 location := options["location"]
687
688 opts = append(opts, google.WithVertex(project, location))
689
690 return google.New(opts...)
691}
692
693func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
694 opts := []hyper.Option{
695 hyper.WithBaseURL(baseURL),
696 hyper.WithAPIKey(apiKey),
697 }
698 if c.cfg.Options.Debug {
699 httpClient := log.NewHTTPClient()
700 opts = append(opts, hyper.WithHTTPClient(httpClient))
701 }
702 return hyper.New(opts...)
703}
704
705func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
706 if model.Think {
707 return true
708 }
709
710 if model.ProviderOptions == nil {
711 return false
712 }
713
714 opts, err := anthropic.ParseOptions(model.ProviderOptions)
715 if err != nil {
716 return false
717 }
718 if opts.Thinking != nil {
719 return true
720 }
721 return false
722}
723
724func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
725 headers := maps.Clone(providerCfg.ExtraHeaders)
726 if headers == nil {
727 headers = make(map[string]string)
728 }
729
730 // handle special headers for anthropic
731 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
732 if v, ok := headers["anthropic-beta"]; ok {
733 headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
734 } else {
735 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
736 }
737 }
738
739 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
740 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
741
742 switch providerCfg.Type {
743 case openai.Name:
744 return c.buildOpenaiProvider(baseURL, apiKey, headers)
745 case anthropic.Name:
746 return c.buildAnthropicProvider(baseURL, apiKey, headers)
747 case openrouter.Name:
748 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
749 case azure.Name:
750 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
751 case bedrock.Name:
752 return c.buildBedrockProvider(headers)
753 case google.Name:
754 return c.buildGoogleProvider(baseURL, apiKey, headers)
755 case "google-vertex":
756 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
757 case openaicompat.Name:
758 if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
759 if providerCfg.ExtraBody == nil {
760 providerCfg.ExtraBody = map[string]any{}
761 }
762 providerCfg.ExtraBody["tool_stream"] = true
763 }
764 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
765 case hyper.Name:
766 return c.buildHyperProvider(baseURL, apiKey)
767 default:
768 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
769 }
770}
771
// isExactoSupported reports whether modelID is on the fixed allow-list of
// OpenRouter models that get the ":exacto" suffix. Matching is exact and
// case-sensitive.
func isExactoSupported(modelID string) bool {
	exactoModels := []string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}
	return slices.Index(exactoModels, modelID) >= 0
}
782
783func (c *coordinator) Cancel(sessionID string) {
784 c.currentAgent.Cancel(sessionID)
785}
786
787func (c *coordinator) CancelAll() {
788 c.currentAgent.CancelAll()
789}
790
791func (c *coordinator) ClearQueue(sessionID string) {
792 c.currentAgent.ClearQueue(sessionID)
793}
794
795func (c *coordinator) IsBusy() bool {
796 return c.currentAgent.IsBusy()
797}
798
799func (c *coordinator) IsSessionBusy(sessionID string) bool {
800 return c.currentAgent.IsSessionBusy(sessionID)
801}
802
803func (c *coordinator) Model() Model {
804 return c.currentAgent.Model()
805}
806
807func (c *coordinator) UpdateModels(ctx context.Context) error {
808 // build the models again so we make sure we get the latest config
809 large, small, err := c.buildAgentModels(ctx, false)
810 if err != nil {
811 return err
812 }
813 c.currentAgent.SetModels(large, small)
814
815 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
816 if !ok {
817 return errors.New("coder agent not configured")
818 }
819
820 tools, err := c.buildTools(ctx, agentCfg)
821 if err != nil {
822 return err
823 }
824 c.currentAgent.SetTools(tools)
825 return nil
826}
827
828func (c *coordinator) QueuedPrompts(sessionID string) int {
829 return c.currentAgent.QueuedPrompts(sessionID)
830}
831
832func (c *coordinator) QueuedPromptsList(sessionID string) []string {
833 return c.currentAgent.QueuedPromptsList(sessionID)
834}
835
836func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
837 providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
838 if !ok {
839 return errors.New("model provider not configured")
840 }
841 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
842}
843
844func (c *coordinator) isUnauthorized(err error) bool {
845 var providerErr *fantasy.ProviderError
846 return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
847}
848
849func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
850 if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
851 slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
852 return err
853 }
854 if err := c.UpdateModels(ctx); err != nil {
855 return err
856 }
857 return nil
858}
859
860func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
861 newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
862 if err != nil {
863 slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
864 return err
865 }
866
867 providerCfg.APIKey = newAPIKey
868 c.cfg.Providers.Set(providerCfg.ID, providerCfg)
869
870 if err := c.UpdateModels(ctx); err != nil {
871 return err
872 }
873 return nil
874}