1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "slices"
16 "strings"
17
18 "charm.land/fantasy"
19 "github.com/charmbracelet/catwalk/pkg/catwalk"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/agent/prompt"
22 "github.com/charmbracelet/crush/internal/agent/tools"
23 "github.com/charmbracelet/crush/internal/config"
24 "github.com/charmbracelet/crush/internal/csync"
25 "github.com/charmbracelet/crush/internal/history"
26 "github.com/charmbracelet/crush/internal/log"
27 "github.com/charmbracelet/crush/internal/lsp"
28 "github.com/charmbracelet/crush/internal/message"
29 "github.com/charmbracelet/crush/internal/oauth/copilot"
30 "github.com/charmbracelet/crush/internal/permission"
31 "github.com/charmbracelet/crush/internal/session"
32 "golang.org/x/sync/errgroup"
33
34 "charm.land/fantasy/providers/anthropic"
35 "charm.land/fantasy/providers/azure"
36 "charm.land/fantasy/providers/bedrock"
37 "charm.land/fantasy/providers/google"
38 "charm.land/fantasy/providers/openai"
39 "charm.land/fantasy/providers/openaicompat"
40 "charm.land/fantasy/providers/openrouter"
41 openaisdk "github.com/openai/openai-go/v2/option"
42 "github.com/qjebbs/go-jsons"
43)
44
45type Coordinator interface {
46 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
47 // SetMainAgent(string)
48 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
49 Cancel(sessionID string)
50 CancelAll()
51 IsSessionBusy(sessionID string) bool
52 IsBusy() bool
53 QueuedPrompts(sessionID string) int
54 QueuedPromptsList(sessionID string) []string
55 ClearQueue(sessionID string)
56 Summarize(context.Context, string) error
57 Model() Model
58 UpdateModels(ctx context.Context) error
59}
60
61type coordinator struct {
62 cfg *config.Config
63 sessions session.Service
64 messages message.Service
65 permissions permission.Service
66 history history.Service
67 lspClients *csync.Map[string, *lsp.Client]
68
69 currentAgent SessionAgent
70 agents map[string]SessionAgent
71
72 readyWg errgroup.Group
73}
74
75func NewCoordinator(
76 ctx context.Context,
77 cfg *config.Config,
78 sessions session.Service,
79 messages message.Service,
80 permissions permission.Service,
81 history history.Service,
82 lspClients *csync.Map[string, *lsp.Client],
83) (Coordinator, error) {
84 c := &coordinator{
85 cfg: cfg,
86 sessions: sessions,
87 messages: messages,
88 permissions: permissions,
89 history: history,
90 lspClients: lspClients,
91 agents: make(map[string]SessionAgent),
92 }
93
94 agentCfg, ok := cfg.Agents[config.AgentCoder]
95 if !ok {
96 return nil, errors.New("coder agent not configured")
97 }
98
99 // TODO: make this dynamic when we support multiple agents
100 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
101 if err != nil {
102 return nil, err
103 }
104
105 agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
106 if err != nil {
107 return nil, err
108 }
109 c.currentAgent = agent
110 c.agents[config.AgentCoder] = agent
111 return c, nil
112}
113
114// Run implements Coordinator.
115func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
116 if err := c.readyWg.Wait(); err != nil {
117 return nil, err
118 }
119
120 model := c.currentAgent.Model()
121 maxTokens := model.CatwalkCfg.DefaultMaxTokens
122 if model.ModelCfg.MaxTokens != 0 {
123 maxTokens = model.ModelCfg.MaxTokens
124 }
125
126 if !model.CatwalkCfg.SupportsImages && attachments != nil {
127 // filter out image attachments
128 filteredAttachments := make([]message.Attachment, 0, len(attachments))
129 for _, att := range attachments {
130 if att.IsText() {
131 filteredAttachments = append(filteredAttachments, att)
132 }
133 }
134 attachments = filteredAttachments
135 }
136
137 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
138 if !ok {
139 return nil, errors.New("model provider not configured")
140 }
141
142 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
143
144 if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
145 slog.Info("Token needs to be refreshed", "provider", providerCfg.ID)
146 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
147 return nil, err
148 }
149 }
150
151 run := func() (*fantasy.AgentResult, error) {
152 return c.currentAgent.Run(ctx, SessionAgentCall{
153 SessionID: sessionID,
154 Prompt: prompt,
155 Attachments: attachments,
156 MaxOutputTokens: maxTokens,
157 ProviderOptions: mergedOptions,
158 Temperature: temp,
159 TopP: topP,
160 TopK: topK,
161 FrequencyPenalty: freqPenalty,
162 PresencePenalty: presPenalty,
163 })
164 }
165 result, originalErr := run()
166
167 if c.isUnauthorized(originalErr) {
168 switch {
169 case providerCfg.OAuthToken != nil:
170 slog.Info("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
171 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
172 return nil, originalErr
173 }
174 slog.Info("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
175 return run()
176 case strings.Contains(providerCfg.APIKeyTemplate, "$"):
177 slog.Info("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
178 if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
179 return nil, originalErr
180 }
181 slog.Info("Retrying request with refreshed API key", "provider", providerCfg.ID)
182 return run()
183 }
184 }
185
186 return result, originalErr
187}
188
189func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
190 options := fantasy.ProviderOptions{}
191
192 cfgOpts := []byte("{}")
193 providerCfgOpts := []byte("{}")
194 catwalkOpts := []byte("{}")
195
196 if model.ModelCfg.ProviderOptions != nil {
197 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
198 if err == nil {
199 cfgOpts = data
200 }
201 }
202
203 if providerCfg.ProviderOptions != nil {
204 data, err := json.Marshal(providerCfg.ProviderOptions)
205 if err == nil {
206 providerCfgOpts = data
207 }
208 }
209
210 if model.CatwalkCfg.Options.ProviderOptions != nil {
211 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
212 if err == nil {
213 catwalkOpts = data
214 }
215 }
216
217 readers := []io.Reader{
218 bytes.NewReader(catwalkOpts),
219 bytes.NewReader(providerCfgOpts),
220 bytes.NewReader(cfgOpts),
221 }
222
223 got, err := jsons.Merge(readers)
224 if err != nil {
225 slog.Error("Could not merge call config", "err", err)
226 return options
227 }
228
229 mergedOptions := make(map[string]any)
230
231 err = json.Unmarshal([]byte(got), &mergedOptions)
232 if err != nil {
233 slog.Error("Could not create config for call", "err", err)
234 return options
235 }
236
237 switch providerCfg.Type {
238 case openai.Name, azure.Name:
239 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
240 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
241 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
242 }
243 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
244 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
245 mergedOptions["reasoning_summary"] = "auto"
246 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
247 }
248 parsed, err := openai.ParseResponsesOptions(mergedOptions)
249 if err == nil {
250 options[openai.Name] = parsed
251 }
252 } else {
253 parsed, err := openai.ParseOptions(mergedOptions)
254 if err == nil {
255 options[openai.Name] = parsed
256 }
257 }
258 case anthropic.Name:
259 _, hasThink := mergedOptions["thinking"]
260 if !hasThink && model.ModelCfg.Think {
261 mergedOptions["thinking"] = map[string]any{
262 // TODO: kujtim see if we need to make this dynamic
263 "budget_tokens": 2000,
264 }
265 }
266 parsed, err := anthropic.ParseOptions(mergedOptions)
267 if err == nil {
268 options[anthropic.Name] = parsed
269 }
270
271 case openrouter.Name:
272 _, hasReasoning := mergedOptions["reasoning"]
273 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
274 mergedOptions["reasoning"] = map[string]any{
275 "enabled": true,
276 "effort": model.ModelCfg.ReasoningEffort,
277 }
278 }
279 parsed, err := openrouter.ParseOptions(mergedOptions)
280 if err == nil {
281 options[openrouter.Name] = parsed
282 }
283 case google.Name:
284 _, hasReasoning := mergedOptions["thinking_config"]
285 if !hasReasoning {
286 mergedOptions["thinking_config"] = map[string]any{
287 "thinking_budget": 2000,
288 "include_thoughts": true,
289 }
290 }
291 parsed, err := google.ParseOptions(mergedOptions)
292 if err == nil {
293 options[google.Name] = parsed
294 }
295 case openaicompat.Name:
296 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
297 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
298 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
299 }
300 if providerCfg.ID == string(catwalk.InferenceProviderCopilot) && openai.IsResponsesModel(model.CatwalkCfg.ID) {
301 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
302 mergedOptions["reasoning_summary"] = "auto"
303 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
304 }
305 parsed, err := openai.ParseResponsesOptions(mergedOptions)
306 if err == nil {
307 options[openaicompat.Name] = parsed
308 }
309 } else {
310 parsed, err := openaicompat.ParseOptions(mergedOptions)
311 if err == nil {
312 options[openaicompat.Name] = parsed
313 }
314 }
315 }
316
317 return options
318}
319
320func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
321 modelOptions := getProviderOptions(model, cfg)
322 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
323 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
324 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
325 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
326 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
327 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
328}
329
330func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
331 large, small, err := c.buildAgentModels(ctx, isSubAgent)
332 if err != nil {
333 return nil, err
334 }
335
336 largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
337 result := NewSessionAgent(SessionAgentOptions{
338 large,
339 small,
340 largeProviderCfg.SystemPromptPrefix,
341 "",
342 isSubAgent,
343 c.cfg.Options.DisableAutoSummarize,
344 c.permissions.SkipRequests(),
345 c.sessions,
346 c.messages,
347 nil,
348 })
349
350 c.readyWg.Go(func() error {
351 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
352 if err != nil {
353 return err
354 }
355 result.SetSystemPrompt(systemPrompt)
356 return nil
357 })
358
359 c.readyWg.Go(func() error {
360 tools, err := c.buildTools(ctx, agent)
361 if err != nil {
362 return err
363 }
364 result.SetTools(tools)
365 return nil
366 })
367
368 return result, nil
369}
370
371func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
372 var allTools []fantasy.AgentTool
373 if slices.Contains(agent.AllowedTools, AgentToolName) {
374 agentTool, err := c.agentTool(ctx)
375 if err != nil {
376 return nil, err
377 }
378 allTools = append(allTools, agentTool)
379 }
380
381 if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
382 agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
383 if err != nil {
384 return nil, err
385 }
386 allTools = append(allTools, agenticFetchTool)
387 }
388
389 // Get the model name for the agent
390 modelName := ""
391 if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
392 if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
393 modelName = model.Name
394 }
395 }
396
397 allTools = append(allTools,
398 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
399 tools.NewJobOutputTool(),
400 tools.NewJobKillTool(),
401 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
402 tools.NewEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
403 tools.NewMultiEditTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
404 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
405 tools.NewGlobTool(c.cfg.WorkingDir()),
406 tools.NewGrepTool(c.cfg.WorkingDir()),
407 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
408 tools.NewSourcegraphTool(nil),
409 tools.NewTodosTool(c.sessions),
410 tools.NewViewTool(c.lspClients, c.permissions, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
411 tools.NewWriteTool(c.lspClients, c.permissions, c.history, c.cfg.WorkingDir()),
412 )
413
414 if len(c.cfg.LSP) > 0 {
415 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspClients), tools.NewReferencesTool(c.lspClients))
416 }
417
418 var filteredTools []fantasy.AgentTool
419 for _, tool := range allTools {
420 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
421 filteredTools = append(filteredTools, tool)
422 }
423 }
424
425 for _, tool := range tools.GetMCPTools(c.permissions, c.cfg.WorkingDir()) {
426 if agent.AllowedMCP == nil {
427 // No MCP restrictions
428 filteredTools = append(filteredTools, tool)
429 continue
430 }
431 if len(agent.AllowedMCP) == 0 {
432 // No MCPs allowed
433 slog.Debug("no MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
434 break
435 }
436
437 for mcp, tools := range agent.AllowedMCP {
438 if mcp != tool.MCP() {
439 continue
440 }
441 if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
442 filteredTools = append(filteredTools, tool)
443 }
444 }
445 slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
446 }
447 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
448 return strings.Compare(a.Info().Name, b.Info().Name)
449 })
450 return filteredTools, nil
451}
452
453// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
454func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
455 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
456 if !ok {
457 return Model{}, Model{}, errors.New("large model not selected")
458 }
459 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
460 if !ok {
461 return Model{}, Model{}, errors.New("small model not selected")
462 }
463
464 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
465 if !ok {
466 return Model{}, Model{}, errors.New("large model provider not configured")
467 }
468
469 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
470 if err != nil {
471 return Model{}, Model{}, err
472 }
473
474 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
475 if !ok {
476 return Model{}, Model{}, errors.New("large model provider not configured")
477 }
478
479 smallProvider, err := c.buildProvider(smallProviderCfg, largeModelCfg, true)
480 if err != nil {
481 return Model{}, Model{}, err
482 }
483
484 var largeCatwalkModel *catwalk.Model
485 var smallCatwalkModel *catwalk.Model
486
487 for _, m := range largeProviderCfg.Models {
488 if m.ID == largeModelCfg.Model {
489 largeCatwalkModel = &m
490 }
491 }
492 for _, m := range smallProviderCfg.Models {
493 if m.ID == smallModelCfg.Model {
494 smallCatwalkModel = &m
495 }
496 }
497
498 if largeCatwalkModel == nil {
499 return Model{}, Model{}, errors.New("large model not found in provider config")
500 }
501
502 if smallCatwalkModel == nil {
503 return Model{}, Model{}, errors.New("small model not found in provider config")
504 }
505
506 largeModelID := largeModelCfg.Model
507 smallModelID := smallModelCfg.Model
508
509 if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
510 largeModelID += ":exacto"
511 }
512
513 if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
514 smallModelID += ":exacto"
515 }
516
517 largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
518 if err != nil {
519 return Model{}, Model{}, err
520 }
521 smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
522 if err != nil {
523 return Model{}, Model{}, err
524 }
525
526 return Model{
527 Model: largeModel,
528 CatwalkCfg: *largeCatwalkModel,
529 ModelCfg: largeModelCfg,
530 }, Model{
531 Model: smallModel,
532 CatwalkCfg: *smallCatwalkModel,
533 ModelCfg: smallModelCfg,
534 }, nil
535}
536
537func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
538 var opts []anthropic.Option
539
540 if strings.HasPrefix(apiKey, "Bearer ") {
541 // NOTE: Prevent the SDK from picking up the API key from env.
542 os.Setenv("ANTHROPIC_API_KEY", "")
543 headers["Authorization"] = apiKey
544 } else if apiKey != "" {
545 // X-Api-Key header
546 opts = append(opts, anthropic.WithAPIKey(apiKey))
547 }
548
549 if len(headers) > 0 {
550 opts = append(opts, anthropic.WithHeaders(headers))
551 }
552
553 if baseURL != "" {
554 opts = append(opts, anthropic.WithBaseURL(baseURL))
555 }
556
557 if c.cfg.Options.Debug {
558 httpClient := log.NewHTTPClient()
559 opts = append(opts, anthropic.WithHTTPClient(httpClient))
560 }
561 return anthropic.New(opts...)
562}
563
564func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
565 opts := []openai.Option{
566 openai.WithAPIKey(apiKey),
567 openai.WithUseResponsesAPI(),
568 }
569 if c.cfg.Options.Debug {
570 httpClient := log.NewHTTPClient()
571 opts = append(opts, openai.WithHTTPClient(httpClient))
572 }
573 if len(headers) > 0 {
574 opts = append(opts, openai.WithHeaders(headers))
575 }
576 if baseURL != "" {
577 opts = append(opts, openai.WithBaseURL(baseURL))
578 }
579 return openai.New(opts...)
580}
581
582func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
583 opts := []openrouter.Option{
584 openrouter.WithAPIKey(apiKey),
585 }
586 if c.cfg.Options.Debug {
587 httpClient := log.NewHTTPClient()
588 opts = append(opts, openrouter.WithHTTPClient(httpClient))
589 }
590 if len(headers) > 0 {
591 opts = append(opts, openrouter.WithHeaders(headers))
592 }
593 return openrouter.New(opts...)
594}
595
596func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
597 opts := []openaicompat.Option{
598 openaicompat.WithBaseURL(baseURL),
599 openaicompat.WithAPIKey(apiKey),
600 }
601
602 // Set HTTP client based on provider and debug mode.
603 var httpClient *http.Client
604 if providerID == string(catwalk.InferenceProviderCopilot) {
605 opts = append(opts, openaicompat.WithUseResponsesAPI())
606 httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug)
607 } else if c.cfg.Options.Debug {
608 httpClient = log.NewHTTPClient()
609 }
610 if httpClient != nil {
611 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
612 }
613
614 if len(headers) > 0 {
615 opts = append(opts, openaicompat.WithHeaders(headers))
616 }
617
618 for extraKey, extraValue := range extraBody {
619 opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
620 }
621
622 return openaicompat.New(opts...)
623}
624
625func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
626 opts := []azure.Option{
627 azure.WithBaseURL(baseURL),
628 azure.WithAPIKey(apiKey),
629 azure.WithUseResponsesAPI(),
630 }
631 if c.cfg.Options.Debug {
632 httpClient := log.NewHTTPClient()
633 opts = append(opts, azure.WithHTTPClient(httpClient))
634 }
635 if options == nil {
636 options = make(map[string]string)
637 }
638 if apiVersion, ok := options["apiVersion"]; ok {
639 opts = append(opts, azure.WithAPIVersion(apiVersion))
640 }
641 if len(headers) > 0 {
642 opts = append(opts, azure.WithHeaders(headers))
643 }
644
645 return azure.New(opts...)
646}
647
648func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
649 var opts []bedrock.Option
650 if c.cfg.Options.Debug {
651 httpClient := log.NewHTTPClient()
652 opts = append(opts, bedrock.WithHTTPClient(httpClient))
653 }
654 if len(headers) > 0 {
655 opts = append(opts, bedrock.WithHeaders(headers))
656 }
657 bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
658 if bearerToken != "" {
659 opts = append(opts, bedrock.WithAPIKey(bearerToken))
660 }
661 return bedrock.New(opts...)
662}
663
664func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
665 opts := []google.Option{
666 google.WithBaseURL(baseURL),
667 google.WithGeminiAPIKey(apiKey),
668 }
669 if c.cfg.Options.Debug {
670 httpClient := log.NewHTTPClient()
671 opts = append(opts, google.WithHTTPClient(httpClient))
672 }
673 if len(headers) > 0 {
674 opts = append(opts, google.WithHeaders(headers))
675 }
676 return google.New(opts...)
677}
678
679func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
680 opts := []google.Option{}
681 if c.cfg.Options.Debug {
682 httpClient := log.NewHTTPClient()
683 opts = append(opts, google.WithHTTPClient(httpClient))
684 }
685 if len(headers) > 0 {
686 opts = append(opts, google.WithHeaders(headers))
687 }
688
689 project := options["project"]
690 location := options["location"]
691
692 opts = append(opts, google.WithVertex(project, location))
693
694 return google.New(opts...)
695}
696
697func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
698 opts := []hyper.Option{
699 hyper.WithBaseURL(baseURL),
700 hyper.WithAPIKey(apiKey),
701 }
702 if c.cfg.Options.Debug {
703 httpClient := log.NewHTTPClient()
704 opts = append(opts, hyper.WithHTTPClient(httpClient))
705 }
706 return hyper.New(opts...)
707}
708
709func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
710 if model.Think {
711 return true
712 }
713
714 if model.ProviderOptions == nil {
715 return false
716 }
717
718 opts, err := anthropic.ParseOptions(model.ProviderOptions)
719 if err != nil {
720 return false
721 }
722 if opts.Thinking != nil {
723 return true
724 }
725 return false
726}
727
728func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
729 headers := maps.Clone(providerCfg.ExtraHeaders)
730 if headers == nil {
731 headers = make(map[string]string)
732 }
733
734 // handle special headers for anthropic
735 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
736 if v, ok := headers["anthropic-beta"]; ok {
737 headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
738 } else {
739 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
740 }
741 }
742
743 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
744 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
745
746 switch providerCfg.Type {
747 case openai.Name:
748 return c.buildOpenaiProvider(baseURL, apiKey, headers)
749 case anthropic.Name:
750 return c.buildAnthropicProvider(baseURL, apiKey, headers)
751 case openrouter.Name:
752 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
753 case azure.Name:
754 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
755 case bedrock.Name:
756 return c.buildBedrockProvider(headers)
757 case google.Name:
758 return c.buildGoogleProvider(baseURL, apiKey, headers)
759 case "google-vertex":
760 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
761 case openaicompat.Name:
762 if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
763 if providerCfg.ExtraBody == nil {
764 providerCfg.ExtraBody = map[string]any{}
765 }
766 providerCfg.ExtraBody["tool_stream"] = true
767 }
768 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
769 case hyper.Name:
770 return c.buildHyperProvider(baseURL, apiKey)
771 default:
772 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
773 }
774}
775
// isExactoSupported reports whether modelID is one of the models for which
// the OpenRouter ":exacto" variant is requested. Matching is exact and
// case-sensitive.
func isExactoSupported(modelID string) bool {
	exactoModels := []string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}
	return slices.Index(exactoModels, modelID) >= 0
}
786
787func (c *coordinator) Cancel(sessionID string) {
788 c.currentAgent.Cancel(sessionID)
789}
790
791func (c *coordinator) CancelAll() {
792 c.currentAgent.CancelAll()
793}
794
795func (c *coordinator) ClearQueue(sessionID string) {
796 c.currentAgent.ClearQueue(sessionID)
797}
798
799func (c *coordinator) IsBusy() bool {
800 return c.currentAgent.IsBusy()
801}
802
803func (c *coordinator) IsSessionBusy(sessionID string) bool {
804 return c.currentAgent.IsSessionBusy(sessionID)
805}
806
807func (c *coordinator) Model() Model {
808 return c.currentAgent.Model()
809}
810
811func (c *coordinator) UpdateModels(ctx context.Context) error {
812 // build the models again so we make sure we get the latest config
813 large, small, err := c.buildAgentModels(ctx, false)
814 if err != nil {
815 return err
816 }
817 c.currentAgent.SetModels(large, small)
818
819 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
820 if !ok {
821 return errors.New("coder agent not configured")
822 }
823
824 tools, err := c.buildTools(ctx, agentCfg)
825 if err != nil {
826 return err
827 }
828 c.currentAgent.SetTools(tools)
829 return nil
830}
831
832func (c *coordinator) QueuedPrompts(sessionID string) int {
833 return c.currentAgent.QueuedPrompts(sessionID)
834}
835
836func (c *coordinator) QueuedPromptsList(sessionID string) []string {
837 return c.currentAgent.QueuedPromptsList(sessionID)
838}
839
840func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
841 providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
842 if !ok {
843 return errors.New("model provider not configured")
844 }
845 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
846}
847
848func (c *coordinator) isUnauthorized(err error) bool {
849 var providerErr *fantasy.ProviderError
850 return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
851}
852
853func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
854 if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
855 slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
856 return err
857 }
858 if err := c.UpdateModels(ctx); err != nil {
859 return err
860 }
861 return nil
862}
863
864func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
865 newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
866 if err != nil {
867 slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
868 return err
869 }
870
871 providerCfg.APIKey = newAPIKey
872 c.cfg.Providers.Set(providerCfg.ID, providerCfg)
873
874 if err := c.UpdateModels(ctx); err != nil {
875 return err
876 }
877 return nil
878}