1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "slices"
16 "strings"
17
18 "charm.land/catwalk/pkg/catwalk"
19 "charm.land/fantasy"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/agent/prompt"
22 "github.com/charmbracelet/crush/internal/agent/tools"
23 "github.com/charmbracelet/crush/internal/config"
24 "github.com/charmbracelet/crush/internal/filetracker"
25 "github.com/charmbracelet/crush/internal/history"
26 "github.com/charmbracelet/crush/internal/log"
27 "github.com/charmbracelet/crush/internal/lsp"
28 "github.com/charmbracelet/crush/internal/message"
29 "github.com/charmbracelet/crush/internal/oauth/copilot"
30 "github.com/charmbracelet/crush/internal/permission"
31 "github.com/charmbracelet/crush/internal/session"
32 "golang.org/x/sync/errgroup"
33
34 "charm.land/fantasy/providers/anthropic"
35 "charm.land/fantasy/providers/azure"
36 "charm.land/fantasy/providers/bedrock"
37 "charm.land/fantasy/providers/google"
38 "charm.land/fantasy/providers/openai"
39 "charm.land/fantasy/providers/openaicompat"
40 "charm.land/fantasy/providers/openrouter"
41 "charm.land/fantasy/providers/vercel"
42 openaisdk "github.com/openai/openai-go/v2/option"
43 "github.com/qjebbs/go-jsons"
44)
45
// Coordinator routes session-level requests (runs, cancellation, queue
// inspection, summarization, model refresh) to the currently active agent.
type Coordinator interface {
	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
	// SetMainAgent(string)

	// Run sends prompt and any attachments to the active agent for sessionID
	// and returns the agent's result.
	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
	// Cancel forwards a cancellation request for sessionID to the active agent.
	Cancel(sessionID string)
	// CancelAll forwards a cancellation request for all sessions to the active agent.
	CancelAll()
	// IsSessionBusy reports whether the active agent considers sessionID busy.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether the active agent is busy with any session.
	IsBusy() bool
	// QueuedPrompts reports how many prompts the active agent has queued for sessionID.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompts the active agent has queued for sessionID.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue asks the active agent to clear the prompt queue of sessionID.
	ClearQueue(sessionID string)
	// Summarize asks the active agent to summarize the session with the given ID.
	Summarize(context.Context, string) error
	// Model returns the active agent's model.
	Model() Model
	// UpdateModels rebuilds the active agent's models and tools from the
	// current configuration.
	UpdateModels(ctx context.Context) error
}
61
// coordinator is the Coordinator implementation returned by NewCoordinator.
// Every delegated call goes to currentAgent, which is currently the coder agent.
type coordinator struct {
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager

	// currentAgent receives every delegated call; agents indexes the built
	// agents by name (NewCoordinator only registers config.AgentCoder).
	currentAgent SessionAgent
	agents       map[string]SessionAgent

	// readyWg tracks the background system-prompt and tool construction
	// started by buildAgent; Run waits on it before executing.
	readyWg errgroup.Group
}
76
77func NewCoordinator(
78 ctx context.Context,
79 cfg *config.Config,
80 sessions session.Service,
81 messages message.Service,
82 permissions permission.Service,
83 history history.Service,
84 filetracker filetracker.Service,
85 lspManager *lsp.Manager,
86) (Coordinator, error) {
87 c := &coordinator{
88 cfg: cfg,
89 sessions: sessions,
90 messages: messages,
91 permissions: permissions,
92 history: history,
93 filetracker: filetracker,
94 lspManager: lspManager,
95 agents: make(map[string]SessionAgent),
96 }
97
98 agentCfg, ok := cfg.Agents[config.AgentCoder]
99 if !ok {
100 return nil, errors.New("coder agent not configured")
101 }
102
103 // TODO: make this dynamic when we support multiple agents
104 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
105 if err != nil {
106 return nil, err
107 }
108
109 agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
110 if err != nil {
111 return nil, err
112 }
113 c.currentAgent = agent
114 c.agents[config.AgentCoder] = agent
115 return c, nil
116}
117
118// Run implements Coordinator.
119func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
120 if err := c.readyWg.Wait(); err != nil {
121 return nil, err
122 }
123
124 // refresh models before each run
125 if err := c.UpdateModels(ctx); err != nil {
126 return nil, fmt.Errorf("failed to update models: %w", err)
127 }
128
129 model := c.currentAgent.Model()
130 maxTokens := model.CatwalkCfg.DefaultMaxTokens
131 if model.ModelCfg.MaxTokens != 0 {
132 maxTokens = model.ModelCfg.MaxTokens
133 }
134
135 if !model.CatwalkCfg.SupportsImages && attachments != nil {
136 // filter out image attachments
137 filteredAttachments := make([]message.Attachment, 0, len(attachments))
138 for _, att := range attachments {
139 if att.IsText() {
140 filteredAttachments = append(filteredAttachments, att)
141 }
142 }
143 attachments = filteredAttachments
144 }
145
146 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
147 if !ok {
148 return nil, errors.New("model provider not configured")
149 }
150
151 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
152
153 if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
154 slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
155 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
156 return nil, err
157 }
158 }
159
160 run := func() (*fantasy.AgentResult, error) {
161 return c.currentAgent.Run(ctx, SessionAgentCall{
162 SessionID: sessionID,
163 Prompt: prompt,
164 Attachments: attachments,
165 MaxOutputTokens: maxTokens,
166 ProviderOptions: mergedOptions,
167 Temperature: temp,
168 TopP: topP,
169 TopK: topK,
170 FrequencyPenalty: freqPenalty,
171 PresencePenalty: presPenalty,
172 })
173 }
174 result, originalErr := run()
175
176 if c.isUnauthorized(originalErr) {
177 switch {
178 case providerCfg.OAuthToken != nil:
179 slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
180 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
181 return nil, originalErr
182 }
183 slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
184 return run()
185 case strings.Contains(providerCfg.APIKeyTemplate, "$"):
186 slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
187 if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
188 return nil, originalErr
189 }
190 slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
191 return run()
192 }
193 }
194
195 return result, originalErr
196}
197
198func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
199 options := fantasy.ProviderOptions{}
200
201 cfgOpts := []byte("{}")
202 providerCfgOpts := []byte("{}")
203 catwalkOpts := []byte("{}")
204
205 if model.ModelCfg.ProviderOptions != nil {
206 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
207 if err == nil {
208 cfgOpts = data
209 }
210 }
211
212 if providerCfg.ProviderOptions != nil {
213 data, err := json.Marshal(providerCfg.ProviderOptions)
214 if err == nil {
215 providerCfgOpts = data
216 }
217 }
218
219 if model.CatwalkCfg.Options.ProviderOptions != nil {
220 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
221 if err == nil {
222 catwalkOpts = data
223 }
224 }
225
226 readers := []io.Reader{
227 bytes.NewReader(catwalkOpts),
228 bytes.NewReader(providerCfgOpts),
229 bytes.NewReader(cfgOpts),
230 }
231
232 got, err := jsons.Merge(readers)
233 if err != nil {
234 slog.Error("Could not merge call config", "err", err)
235 return options
236 }
237
238 mergedOptions := make(map[string]any)
239
240 err = json.Unmarshal([]byte(got), &mergedOptions)
241 if err != nil {
242 slog.Error("Could not create config for call", "err", err)
243 return options
244 }
245
246 providerType := providerCfg.Type
247 if providerType == "hyper" {
248 if strings.Contains(model.CatwalkCfg.ID, "claude") {
249 providerType = anthropic.Name
250 } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
251 providerType = openai.Name
252 } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
253 providerType = google.Name
254 } else {
255 providerType = openaicompat.Name
256 }
257 }
258
259 switch providerType {
260 case openai.Name, azure.Name:
261 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
262 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
263 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
264 }
265 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
266 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
267 mergedOptions["reasoning_summary"] = "auto"
268 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
269 }
270 parsed, err := openai.ParseResponsesOptions(mergedOptions)
271 if err == nil {
272 options[openai.Name] = parsed
273 }
274 } else {
275 parsed, err := openai.ParseOptions(mergedOptions)
276 if err == nil {
277 options[openai.Name] = parsed
278 }
279 }
280 case anthropic.Name:
281 var (
282 _, hasEffort = mergedOptions["effort"]
283 _, hasThink = mergedOptions["thinking"]
284 )
285 switch {
286 case !hasEffort && model.ModelCfg.ReasoningEffort != "":
287 mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
288 case !hasThink && model.ModelCfg.Think:
289 mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
290 }
291 parsed, err := anthropic.ParseOptions(mergedOptions)
292 if err == nil {
293 options[anthropic.Name] = parsed
294 }
295
296 case openrouter.Name:
297 _, hasReasoning := mergedOptions["reasoning"]
298 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
299 mergedOptions["reasoning"] = map[string]any{
300 "enabled": true,
301 "effort": model.ModelCfg.ReasoningEffort,
302 }
303 }
304 parsed, err := openrouter.ParseOptions(mergedOptions)
305 if err == nil {
306 options[openrouter.Name] = parsed
307 }
308 case vercel.Name:
309 _, hasReasoning := mergedOptions["reasoning"]
310 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
311 mergedOptions["reasoning"] = map[string]any{
312 "enabled": true,
313 "effort": model.ModelCfg.ReasoningEffort,
314 }
315 }
316 parsed, err := vercel.ParseOptions(mergedOptions)
317 if err == nil {
318 options[vercel.Name] = parsed
319 }
320 case google.Name:
321 _, hasReasoning := mergedOptions["thinking_config"]
322 if !hasReasoning {
323 if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
324 mergedOptions["thinking_config"] = map[string]any{
325 "thinking_budget": 2000,
326 "include_thoughts": true,
327 }
328 } else {
329 mergedOptions["thinking_config"] = map[string]any{
330 "thinking_level": model.ModelCfg.ReasoningEffort,
331 "include_thoughts": true,
332 }
333 }
334 }
335 parsed, err := google.ParseOptions(mergedOptions)
336 if err == nil {
337 options[google.Name] = parsed
338 }
339 case openaicompat.Name:
340 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
341 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
342 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
343 }
344 parsed, err := openaicompat.ParseOptions(mergedOptions)
345 if err == nil {
346 options[openaicompat.Name] = parsed
347 }
348 }
349
350 return options
351}
352
353func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
354 modelOptions := getProviderOptions(model, cfg)
355 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
356 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
357 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
358 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
359 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
360 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
361}
362
363func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
364 large, small, err := c.buildAgentModels(ctx, isSubAgent)
365 if err != nil {
366 return nil, err
367 }
368
369 largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
370 result := NewSessionAgent(SessionAgentOptions{
371 large,
372 small,
373 largeProviderCfg.SystemPromptPrefix,
374 "",
375 isSubAgent,
376 c.cfg.Options.DisableAutoSummarize,
377 c.permissions.SkipRequests(),
378 c.sessions,
379 c.messages,
380 nil,
381 })
382
383 c.readyWg.Go(func() error {
384 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
385 if err != nil {
386 return err
387 }
388 result.SetSystemPrompt(systemPrompt)
389 return nil
390 })
391
392 c.readyWg.Go(func() error {
393 tools, err := c.buildTools(ctx, agent)
394 if err != nil {
395 return err
396 }
397 result.SetTools(tools)
398 return nil
399 })
400
401 return result, nil
402}
403
404func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
405 var allTools []fantasy.AgentTool
406 if slices.Contains(agent.AllowedTools, AgentToolName) {
407 agentTool, err := c.agentTool(ctx)
408 if err != nil {
409 return nil, err
410 }
411 allTools = append(allTools, agentTool)
412 }
413
414 if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
415 agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
416 if err != nil {
417 return nil, err
418 }
419 allTools = append(allTools, agenticFetchTool)
420 }
421
422 // Get the model name for the agent
423 modelName := ""
424 if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
425 if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
426 modelName = model.Name
427 }
428 }
429
430 allTools = append(allTools,
431 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
432 tools.NewJobOutputTool(),
433 tools.NewJobKillTool(),
434 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
435 tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
436 tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
437 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
438 tools.NewGlobTool(c.cfg.WorkingDir()),
439 tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Tools.Grep),
440 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
441 tools.NewSourcegraphTool(nil),
442 tools.NewTodosTool(c.sessions),
443 tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
444 tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
445 )
446
447 // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
448 if len(c.cfg.LSP) > 0 || c.cfg.Options.AutoLSP == nil || *c.cfg.Options.AutoLSP {
449 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
450 }
451
452 if len(c.cfg.MCP) > 0 {
453 allTools = append(
454 allTools,
455 tools.NewListMCPResourcesTool(c.cfg, c.permissions),
456 tools.NewReadMCPResourceTool(c.cfg, c.permissions),
457 )
458 }
459
460 var filteredTools []fantasy.AgentTool
461 for _, tool := range allTools {
462 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
463 filteredTools = append(filteredTools, tool)
464 }
465 }
466
467 for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
468 if agent.AllowedMCP == nil {
469 // No MCP restrictions
470 filteredTools = append(filteredTools, tool)
471 continue
472 }
473 if len(agent.AllowedMCP) == 0 {
474 // No MCPs allowed
475 slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
476 break
477 }
478
479 for mcp, tools := range agent.AllowedMCP {
480 if mcp != tool.MCP() {
481 continue
482 }
483 if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
484 filteredTools = append(filteredTools, tool)
485 break
486 }
487 slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
488 }
489 }
490 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
491 return strings.Compare(a.Info().Name, b.Info().Name)
492 })
493 return filteredTools, nil
494}
495
496// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
497func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
498 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
499 if !ok {
500 return Model{}, Model{}, errors.New("large model not selected")
501 }
502 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
503 if !ok {
504 return Model{}, Model{}, errors.New("small model not selected")
505 }
506
507 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
508 if !ok {
509 return Model{}, Model{}, errors.New("large model provider not configured")
510 }
511
512 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
513 if err != nil {
514 return Model{}, Model{}, err
515 }
516
517 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
518 if !ok {
519 return Model{}, Model{}, errors.New("small model provider not configured")
520 }
521
522 smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
523 if err != nil {
524 return Model{}, Model{}, err
525 }
526
527 var largeCatwalkModel *catwalk.Model
528 var smallCatwalkModel *catwalk.Model
529
530 for _, m := range largeProviderCfg.Models {
531 if m.ID == largeModelCfg.Model {
532 largeCatwalkModel = &m
533 }
534 }
535 for _, m := range smallProviderCfg.Models {
536 if m.ID == smallModelCfg.Model {
537 smallCatwalkModel = &m
538 }
539 }
540
541 if largeCatwalkModel == nil {
542 return Model{}, Model{}, errors.New("large model not found in provider config")
543 }
544
545 if smallCatwalkModel == nil {
546 return Model{}, Model{}, errors.New("small model not found in provider config")
547 }
548
549 largeModelID := largeModelCfg.Model
550 smallModelID := smallModelCfg.Model
551
552 if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
553 largeModelID += ":exacto"
554 }
555
556 if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
557 smallModelID += ":exacto"
558 }
559
560 largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
561 if err != nil {
562 return Model{}, Model{}, err
563 }
564 smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
565 if err != nil {
566 return Model{}, Model{}, err
567 }
568
569 return Model{
570 Model: largeModel,
571 CatwalkCfg: *largeCatwalkModel,
572 ModelCfg: largeModelCfg,
573 }, Model{
574 Model: smallModel,
575 CatwalkCfg: *smallCatwalkModel,
576 ModelCfg: smallModelCfg,
577 }, nil
578}
579
580func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
581 var opts []anthropic.Option
582
583 switch {
584 case strings.HasPrefix(apiKey, "Bearer "):
585 // NOTE: Prevent the SDK from picking up the API key from env.
586 os.Setenv("ANTHROPIC_API_KEY", "")
587 headers["Authorization"] = apiKey
588 case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
589 // NOTE: Prevent the SDK from picking up the API key from env.
590 os.Setenv("ANTHROPIC_API_KEY", "")
591 headers["Authorization"] = "Bearer " + apiKey
592 case apiKey != "":
593 // X-Api-Key header
594 opts = append(opts, anthropic.WithAPIKey(apiKey))
595 }
596
597 if len(headers) > 0 {
598 opts = append(opts, anthropic.WithHeaders(headers))
599 }
600
601 if baseURL != "" {
602 opts = append(opts, anthropic.WithBaseURL(baseURL))
603 }
604
605 if c.cfg.Options.Debug {
606 httpClient := log.NewHTTPClient()
607 opts = append(opts, anthropic.WithHTTPClient(httpClient))
608 }
609 return anthropic.New(opts...)
610}
611
612func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
613 opts := []openai.Option{
614 openai.WithAPIKey(apiKey),
615 openai.WithUseResponsesAPI(),
616 }
617 if c.cfg.Options.Debug {
618 httpClient := log.NewHTTPClient()
619 opts = append(opts, openai.WithHTTPClient(httpClient))
620 }
621 if len(headers) > 0 {
622 opts = append(opts, openai.WithHeaders(headers))
623 }
624 if baseURL != "" {
625 opts = append(opts, openai.WithBaseURL(baseURL))
626 }
627 return openai.New(opts...)
628}
629
630func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
631 opts := []openrouter.Option{
632 openrouter.WithAPIKey(apiKey),
633 }
634 if c.cfg.Options.Debug {
635 httpClient := log.NewHTTPClient()
636 opts = append(opts, openrouter.WithHTTPClient(httpClient))
637 }
638 if len(headers) > 0 {
639 opts = append(opts, openrouter.WithHeaders(headers))
640 }
641 return openrouter.New(opts...)
642}
643
644func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
645 opts := []vercel.Option{
646 vercel.WithAPIKey(apiKey),
647 }
648 if c.cfg.Options.Debug {
649 httpClient := log.NewHTTPClient()
650 opts = append(opts, vercel.WithHTTPClient(httpClient))
651 }
652 if len(headers) > 0 {
653 opts = append(opts, vercel.WithHeaders(headers))
654 }
655 return vercel.New(opts...)
656}
657
658func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
659 opts := []openaicompat.Option{
660 openaicompat.WithBaseURL(baseURL),
661 openaicompat.WithAPIKey(apiKey),
662 }
663
664 // Set HTTP client based on provider and debug mode.
665 var httpClient *http.Client
666 if providerID == string(catwalk.InferenceProviderCopilot) {
667 opts = append(opts, openaicompat.WithUseResponsesAPI())
668 httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug)
669 } else if c.cfg.Options.Debug {
670 httpClient = log.NewHTTPClient()
671 }
672 if httpClient != nil {
673 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
674 }
675
676 if len(headers) > 0 {
677 opts = append(opts, openaicompat.WithHeaders(headers))
678 }
679
680 for extraKey, extraValue := range extraBody {
681 opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
682 }
683
684 return openaicompat.New(opts...)
685}
686
687func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
688 opts := []azure.Option{
689 azure.WithBaseURL(baseURL),
690 azure.WithAPIKey(apiKey),
691 azure.WithUseResponsesAPI(),
692 }
693 if c.cfg.Options.Debug {
694 httpClient := log.NewHTTPClient()
695 opts = append(opts, azure.WithHTTPClient(httpClient))
696 }
697 if options == nil {
698 options = make(map[string]string)
699 }
700 if apiVersion, ok := options["apiVersion"]; ok {
701 opts = append(opts, azure.WithAPIVersion(apiVersion))
702 }
703 if len(headers) > 0 {
704 opts = append(opts, azure.WithHeaders(headers))
705 }
706
707 return azure.New(opts...)
708}
709
710func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
711 var opts []bedrock.Option
712 if c.cfg.Options.Debug {
713 httpClient := log.NewHTTPClient()
714 opts = append(opts, bedrock.WithHTTPClient(httpClient))
715 }
716 if len(headers) > 0 {
717 opts = append(opts, bedrock.WithHeaders(headers))
718 }
719 bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
720 if bearerToken != "" {
721 opts = append(opts, bedrock.WithAPIKey(bearerToken))
722 }
723 return bedrock.New(opts...)
724}
725
726func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
727 opts := []google.Option{
728 google.WithBaseURL(baseURL),
729 google.WithGeminiAPIKey(apiKey),
730 }
731 if c.cfg.Options.Debug {
732 httpClient := log.NewHTTPClient()
733 opts = append(opts, google.WithHTTPClient(httpClient))
734 }
735 if len(headers) > 0 {
736 opts = append(opts, google.WithHeaders(headers))
737 }
738 return google.New(opts...)
739}
740
741func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
742 opts := []google.Option{}
743 if c.cfg.Options.Debug {
744 httpClient := log.NewHTTPClient()
745 opts = append(opts, google.WithHTTPClient(httpClient))
746 }
747 if len(headers) > 0 {
748 opts = append(opts, google.WithHeaders(headers))
749 }
750
751 project := options["project"]
752 location := options["location"]
753
754 opts = append(opts, google.WithVertex(project, location))
755
756 return google.New(opts...)
757}
758
759func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
760 opts := []hyper.Option{
761 hyper.WithBaseURL(baseURL),
762 hyper.WithAPIKey(apiKey),
763 }
764 if c.cfg.Options.Debug {
765 httpClient := log.NewHTTPClient()
766 opts = append(opts, hyper.WithHTTPClient(httpClient))
767 }
768 return hyper.New(opts...)
769}
770
771func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
772 if model.Think {
773 return true
774 }
775
776 if model.ProviderOptions == nil {
777 return false
778 }
779
780 opts, err := anthropic.ParseOptions(model.ProviderOptions)
781 if err != nil {
782 return false
783 }
784 if opts.Thinking != nil {
785 return true
786 }
787 return false
788}
789
790func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
791 headers := maps.Clone(providerCfg.ExtraHeaders)
792 if headers == nil {
793 headers = make(map[string]string)
794 }
795
796 // handle special headers for anthropic
797 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
798 if v, ok := headers["anthropic-beta"]; ok {
799 headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
800 } else {
801 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
802 }
803 }
804
805 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
806 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
807
808 switch providerCfg.Type {
809 case openai.Name:
810 return c.buildOpenaiProvider(baseURL, apiKey, headers)
811 case anthropic.Name:
812 return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
813 case openrouter.Name:
814 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
815 case vercel.Name:
816 return c.buildVercelProvider(baseURL, apiKey, headers)
817 case azure.Name:
818 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
819 case bedrock.Name:
820 return c.buildBedrockProvider(headers)
821 case google.Name:
822 return c.buildGoogleProvider(baseURL, apiKey, headers)
823 case "google-vertex":
824 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
825 case openaicompat.Name:
826 if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
827 if providerCfg.ExtraBody == nil {
828 providerCfg.ExtraBody = map[string]any{}
829 }
830 providerCfg.ExtraBody["tool_stream"] = true
831 }
832 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
833 case hyper.Name:
834 return c.buildHyperProvider(baseURL, apiKey)
835 default:
836 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
837 }
838}
839
// isExactoSupported reports whether modelID is on the allow-list of models
// that buildAgentModels requests with the ":exacto" suffix when they are
// served through OpenRouter. Matching is exact and case-sensitive.
func isExactoSupported(modelID string) bool {
	return slices.Index([]string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}, modelID) >= 0
}
850
// Cancel forwards a cancellation request for sessionID to the current agent.
func (c *coordinator) Cancel(sessionID string) {
	c.currentAgent.Cancel(sessionID)
}
854
// CancelAll forwards a cancel-everything request to the current agent.
func (c *coordinator) CancelAll() {
	c.currentAgent.CancelAll()
}
858
// ClearQueue asks the current agent to clear the prompt queue of sessionID.
func (c *coordinator) ClearQueue(sessionID string) {
	c.currentAgent.ClearQueue(sessionID)
}
862
// IsBusy reports whether the current agent is busy with any session.
func (c *coordinator) IsBusy() bool {
	return c.currentAgent.IsBusy()
}
866
// IsSessionBusy reports whether the current agent considers sessionID busy.
func (c *coordinator) IsSessionBusy(sessionID string) bool {
	return c.currentAgent.IsSessionBusy(sessionID)
}
870
// Model returns the current agent's model, as last installed by SetModels
// (presumably the large model; confirm in SessionAgent).
func (c *coordinator) Model() Model {
	return c.currentAgent.Model()
}
874
875func (c *coordinator) UpdateModels(ctx context.Context) error {
876 // build the models again so we make sure we get the latest config
877 large, small, err := c.buildAgentModels(ctx, false)
878 if err != nil {
879 return err
880 }
881 c.currentAgent.SetModels(large, small)
882
883 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
884 if !ok {
885 return errors.New("coder agent not configured")
886 }
887
888 tools, err := c.buildTools(ctx, agentCfg)
889 if err != nil {
890 return err
891 }
892 c.currentAgent.SetTools(tools)
893 return nil
894}
895
// QueuedPrompts reports how many prompts the current agent has queued for sessionID.
func (c *coordinator) QueuedPrompts(sessionID string) int {
	return c.currentAgent.QueuedPrompts(sessionID)
}
899
// QueuedPromptsList returns the prompts the current agent has queued for sessionID.
func (c *coordinator) QueuedPromptsList(sessionID string) []string {
	return c.currentAgent.QueuedPromptsList(sessionID)
}
903
904func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
905 providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
906 if !ok {
907 return errors.New("model provider not configured")
908 }
909 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
910}
911
912func (c *coordinator) isUnauthorized(err error) bool {
913 var providerErr *fantasy.ProviderError
914 return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
915}
916
917func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
918 if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
919 slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
920 return err
921 }
922 if err := c.UpdateModels(ctx); err != nil {
923 return err
924 }
925 return nil
926}
927
928func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
929 newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
930 if err != nil {
931 slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
932 return err
933 }
934
935 providerCfg.APIKey = newAPIKey
936 c.cfg.Providers.Set(providerCfg.ID, providerCfg)
937
938 if err := c.UpdateModels(ctx); err != nil {
939 return err
940 }
941 return nil
942}
943
// subAgentParams holds the parameters for running a sub-agent.
type subAgentParams struct {
	// Agent is the sub-agent to run.
	Agent SessionAgent
	// SessionID is the parent session: the sub-session is created under it,
	// and its cost is later added to it.
	SessionID string
	// AgentMessageID and ToolCallID identify the spawning tool call; together
	// they derive the sub-session ID.
	AgentMessageID string
	ToolCallID     string
	// Prompt is the input sent to the sub-agent.
	Prompt string
	// SessionTitle is passed to the sub-session on creation.
	SessionTitle string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
956
957// runSubAgent runs a sub-agent and handles session management and cost accumulation.
958// It creates a sub-session, runs the agent with the given prompt, and propagates
959// the cost to the parent session.
960func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
961 // Create sub-session
962 agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
963 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
964 if err != nil {
965 return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
966 }
967
968 // Call session setup function if provided
969 if params.SessionSetup != nil {
970 params.SessionSetup(session.ID)
971 }
972
973 // Get model configuration
974 model := params.Agent.Model()
975 maxTokens := model.CatwalkCfg.DefaultMaxTokens
976 if model.ModelCfg.MaxTokens != 0 {
977 maxTokens = model.ModelCfg.MaxTokens
978 }
979
980 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
981 if !ok {
982 return fantasy.ToolResponse{}, errors.New("model provider not configured")
983 }
984
985 // Run the agent
986 result, err := params.Agent.Run(ctx, SessionAgentCall{
987 SessionID: session.ID,
988 Prompt: params.Prompt,
989 MaxOutputTokens: maxTokens,
990 ProviderOptions: getProviderOptions(model, providerCfg),
991 Temperature: model.ModelCfg.Temperature,
992 TopP: model.ModelCfg.TopP,
993 TopK: model.ModelCfg.TopK,
994 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
995 PresencePenalty: model.ModelCfg.PresencePenalty,
996 })
997 if err != nil {
998 return fantasy.NewTextErrorResponse("error generating response"), nil
999 }
1000
1001 // Update parent session cost
1002 if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1003 return fantasy.ToolResponse{}, err
1004 }
1005
1006 return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1007}
1008
1009// updateParentSessionCost accumulates the cost from a child session to its parent session.
1010func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1011 childSession, err := c.sessions.Get(ctx, childSessionID)
1012 if err != nil {
1013 return fmt.Errorf("get child session: %w", err)
1014 }
1015
1016 parentSession, err := c.sessions.Get(ctx, parentSessionID)
1017 if err != nil {
1018 return fmt.Errorf("get parent session: %w", err)
1019 }
1020
1021 parentSession.Cost += childSession.Cost
1022
1023 if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1024 return fmt.Errorf("save parent session: %w", err)
1025 }
1026
1027 return nil
1028}