1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "slices"
16 "strings"
17
18 "charm.land/catwalk/pkg/catwalk"
19 "charm.land/fantasy"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/agent/notify"
22 "github.com/charmbracelet/crush/internal/agent/prompt"
23 "github.com/charmbracelet/crush/internal/agent/tools"
24 "github.com/charmbracelet/crush/internal/config"
25 "github.com/charmbracelet/crush/internal/filetracker"
26 "github.com/charmbracelet/crush/internal/history"
27 "github.com/charmbracelet/crush/internal/log"
28 "github.com/charmbracelet/crush/internal/lsp"
29 "github.com/charmbracelet/crush/internal/message"
30 "github.com/charmbracelet/crush/internal/oauth/copilot"
31 "github.com/charmbracelet/crush/internal/permission"
32 "github.com/charmbracelet/crush/internal/pubsub"
33 "github.com/charmbracelet/crush/internal/session"
34 "golang.org/x/sync/errgroup"
35
36 "charm.land/fantasy/providers/anthropic"
37 "charm.land/fantasy/providers/azure"
38 "charm.land/fantasy/providers/bedrock"
39 "charm.land/fantasy/providers/google"
40 "charm.land/fantasy/providers/openai"
41 "charm.land/fantasy/providers/openaicompat"
42 "charm.land/fantasy/providers/openrouter"
43 "charm.land/fantasy/providers/vercel"
44 openaisdk "github.com/charmbracelet/openai-go/option"
45 "github.com/qjebbs/go-jsons"
46)
47
// Coordinator errors.
//
// General configuration errors: the coder agent or the active model's provider
// is missing from the config.
var (
	errCoderAgentNotConfigured    = errors.New("coder agent not configured")
	errModelProviderNotConfigured = errors.New("model provider not configured")
)

// Model resolution errors, reported by buildAgentModels for the large and the
// small model selection respectively.
var (
	errLargeModelNotSelected           = errors.New("large model not selected")
	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
	errLargeModelNotFound              = errors.New("large model not found in provider config")

	errSmallModelNotSelected           = errors.New("small model not selected")
	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
	errSmallModelNotFound              = errors.New("small model not found in provider config")
)
59
60type Coordinator interface {
61 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
62 // SetMainAgent(string)
63 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
64 Cancel(sessionID string)
65 CancelAll()
66 IsSessionBusy(sessionID string) bool
67 IsBusy() bool
68 QueuedPrompts(sessionID string) int
69 QueuedPromptsList(sessionID string) []string
70 ClearQueue(sessionID string)
71 Summarize(context.Context, string) error
72 Model() Model
73 UpdateModels(ctx context.Context) error
74}
75
// coordinator is the default Coordinator implementation. It currently drives a
// single coder agent and holds the services needed to build that agent's
// models and tools.
type coordinator struct {
	cfg         *config.ConfigStore
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager
	notify      pubsub.Publisher[notify.Notification]

	// currentAgent is the agent the Coordinator methods delegate to.
	currentAgent SessionAgent
	// agents holds built agents keyed by agent name; NewCoordinator registers
	// the coder agent under config.AgentCoder.
	agents map[string]SessionAgent

	// readyWg tracks the background system-prompt and tool builds started by
	// buildAgent; Run waits on it before executing.
	readyWg errgroup.Group
}
91
92func NewCoordinator(
93 ctx context.Context,
94 cfg *config.ConfigStore,
95 sessions session.Service,
96 messages message.Service,
97 permissions permission.Service,
98 history history.Service,
99 filetracker filetracker.Service,
100 lspManager *lsp.Manager,
101 notify pubsub.Publisher[notify.Notification],
102) (Coordinator, error) {
103 c := &coordinator{
104 cfg: cfg,
105 sessions: sessions,
106 messages: messages,
107 permissions: permissions,
108 history: history,
109 filetracker: filetracker,
110 lspManager: lspManager,
111 notify: notify,
112 agents: make(map[string]SessionAgent),
113 }
114
115 agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
116 if !ok {
117 return nil, errCoderAgentNotConfigured
118 }
119
120 // TODO: make this dynamic when we support multiple agents
121 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
122 if err != nil {
123 return nil, err
124 }
125
126 agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
127 if err != nil {
128 return nil, err
129 }
130 c.currentAgent = agent
131 c.agents[config.AgentCoder] = agent
132 return c, nil
133}
134
135// Run implements Coordinator.
136func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
137 if err := c.readyWg.Wait(); err != nil {
138 return nil, err
139 }
140
141 // refresh models before each run
142 if err := c.UpdateModels(ctx); err != nil {
143 return nil, fmt.Errorf("failed to update models: %w", err)
144 }
145
146 model := c.currentAgent.Model()
147 maxTokens := model.CatwalkCfg.DefaultMaxTokens
148 if model.ModelCfg.MaxTokens != 0 {
149 maxTokens = model.ModelCfg.MaxTokens
150 }
151
152 if !model.CatwalkCfg.SupportsImages && attachments != nil {
153 // filter out image attachments
154 filteredAttachments := make([]message.Attachment, 0, len(attachments))
155 for _, att := range attachments {
156 if att.IsText() {
157 filteredAttachments = append(filteredAttachments, att)
158 }
159 }
160 attachments = filteredAttachments
161 }
162
163 providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
164 if !ok {
165 return nil, errModelProviderNotConfigured
166 }
167
168 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
169
170 if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
171 slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
172 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
173 return nil, err
174 }
175 }
176
177 run := func() (*fantasy.AgentResult, error) {
178 return c.currentAgent.Run(ctx, SessionAgentCall{
179 SessionID: sessionID,
180 Prompt: prompt,
181 Attachments: attachments,
182 MaxOutputTokens: maxTokens,
183 ProviderOptions: mergedOptions,
184 Temperature: temp,
185 TopP: topP,
186 TopK: topK,
187 FrequencyPenalty: freqPenalty,
188 PresencePenalty: presPenalty,
189 })
190 }
191 result, originalErr := run()
192
193 if c.isUnauthorized(originalErr) {
194 switch {
195 case providerCfg.OAuthToken != nil:
196 slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
197 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
198 return nil, originalErr
199 }
200 slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
201 return run()
202 case strings.Contains(providerCfg.APIKeyTemplate, "$"):
203 slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
204 if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
205 return nil, originalErr
206 }
207 slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
208 return run()
209 }
210 }
211
212 return result, originalErr
213}
214
215func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
216 options := fantasy.ProviderOptions{}
217
218 cfgOpts := []byte("{}")
219 providerCfgOpts := []byte("{}")
220 catwalkOpts := []byte("{}")
221
222 if model.ModelCfg.ProviderOptions != nil {
223 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
224 if err == nil {
225 cfgOpts = data
226 }
227 }
228
229 if providerCfg.ProviderOptions != nil {
230 data, err := json.Marshal(providerCfg.ProviderOptions)
231 if err == nil {
232 providerCfgOpts = data
233 }
234 }
235
236 if model.CatwalkCfg.Options.ProviderOptions != nil {
237 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
238 if err == nil {
239 catwalkOpts = data
240 }
241 }
242
243 readers := []io.Reader{
244 bytes.NewReader(catwalkOpts),
245 bytes.NewReader(providerCfgOpts),
246 bytes.NewReader(cfgOpts),
247 }
248
249 got, err := jsons.Merge(readers)
250 if err != nil {
251 slog.Error("Could not merge call config", "err", err)
252 return options
253 }
254
255 mergedOptions := make(map[string]any)
256
257 err = json.Unmarshal([]byte(got), &mergedOptions)
258 if err != nil {
259 slog.Error("Could not create config for call", "err", err)
260 return options
261 }
262
263 providerType := providerCfg.Type
264 if providerType == "hyper" {
265 if strings.Contains(model.CatwalkCfg.ID, "claude") {
266 providerType = anthropic.Name
267 } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
268 providerType = openai.Name
269 } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
270 providerType = google.Name
271 } else {
272 providerType = openaicompat.Name
273 }
274 }
275
276 switch providerType {
277 case openai.Name, azure.Name:
278 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
279 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
280 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
281 }
282 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
283 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
284 mergedOptions["reasoning_summary"] = "auto"
285 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
286 }
287 parsed, err := openai.ParseResponsesOptions(mergedOptions)
288 if err == nil {
289 options[openai.Name] = parsed
290 }
291 } else {
292 parsed, err := openai.ParseOptions(mergedOptions)
293 if err == nil {
294 options[openai.Name] = parsed
295 }
296 }
297 case anthropic.Name:
298 var (
299 _, hasEffort = mergedOptions["effort"]
300 _, hasThink = mergedOptions["thinking"]
301 )
302 switch {
303 case !hasEffort && model.ModelCfg.ReasoningEffort != "":
304 mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
305 case !hasThink && model.ModelCfg.Think:
306 mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
307 }
308 parsed, err := anthropic.ParseOptions(mergedOptions)
309 if err == nil {
310 options[anthropic.Name] = parsed
311 }
312
313 case openrouter.Name:
314 _, hasReasoning := mergedOptions["reasoning"]
315 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
316 mergedOptions["reasoning"] = map[string]any{
317 "enabled": true,
318 "effort": model.ModelCfg.ReasoningEffort,
319 }
320 }
321 parsed, err := openrouter.ParseOptions(mergedOptions)
322 if err == nil {
323 options[openrouter.Name] = parsed
324 }
325 case vercel.Name:
326 _, hasReasoning := mergedOptions["reasoning"]
327 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
328 mergedOptions["reasoning"] = map[string]any{
329 "enabled": true,
330 "effort": model.ModelCfg.ReasoningEffort,
331 }
332 }
333 parsed, err := vercel.ParseOptions(mergedOptions)
334 if err == nil {
335 options[vercel.Name] = parsed
336 }
337 case google.Name:
338 _, hasReasoning := mergedOptions["thinking_config"]
339 if !hasReasoning {
340 if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
341 mergedOptions["thinking_config"] = map[string]any{
342 "thinking_budget": 2000,
343 "include_thoughts": true,
344 }
345 } else {
346 mergedOptions["thinking_config"] = map[string]any{
347 "thinking_level": model.ModelCfg.ReasoningEffort,
348 "include_thoughts": true,
349 }
350 }
351 }
352 parsed, err := google.ParseOptions(mergedOptions)
353 if err == nil {
354 options[google.Name] = parsed
355 }
356 case openaicompat.Name:
357 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
358 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
359 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
360 }
361 parsed, err := openaicompat.ParseOptions(mergedOptions)
362 if err == nil {
363 options[openaicompat.Name] = parsed
364 }
365 }
366
367 return options
368}
369
370func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
371 modelOptions := getProviderOptions(model, cfg)
372 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
373 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
374 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
375 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
376 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
377 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
378}
379
380func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
381 large, small, err := c.buildAgentModels(ctx, isSubAgent)
382 if err != nil {
383 return nil, err
384 }
385
386 largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
387 result := NewSessionAgent(SessionAgentOptions{
388 LargeModel: large,
389 SmallModel: small,
390 SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix,
391 SystemPrompt: "",
392 IsSubAgent: isSubAgent,
393 DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
394 IsYolo: c.permissions.SkipRequests(),
395 Sessions: c.sessions,
396 Messages: c.messages,
397 Tools: nil,
398 Notify: c.notify,
399 })
400
401 c.readyWg.Go(func() error {
402 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
403 if err != nil {
404 return err
405 }
406 result.SetSystemPrompt(systemPrompt)
407 return nil
408 })
409
410 c.readyWg.Go(func() error {
411 tools, err := c.buildTools(ctx, agent)
412 if err != nil {
413 return err
414 }
415 result.SetTools(tools)
416 return nil
417 })
418
419 return result, nil
420}
421
422func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
423 var allTools []fantasy.AgentTool
424 if slices.Contains(agent.AllowedTools, AgentToolName) {
425 agentTool, err := c.agentTool(ctx)
426 if err != nil {
427 return nil, err
428 }
429 allTools = append(allTools, agentTool)
430 }
431
432 if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
433 agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
434 if err != nil {
435 return nil, err
436 }
437 allTools = append(allTools, agenticFetchTool)
438 }
439
440 // Get the model name for the agent
441 modelName := ""
442 if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
443 if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
444 modelName = model.Name
445 }
446 }
447
448 allTools = append(allTools,
449 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
450 tools.NewJobOutputTool(),
451 tools.NewJobKillTool(),
452 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
453 tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
454 tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
455 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
456 tools.NewGlobTool(c.cfg.WorkingDir()),
457 tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
458 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
459 tools.NewSourcegraphTool(nil),
460 tools.NewTodosTool(c.sessions),
461 tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
462 tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
463 )
464
465 // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
466 if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
467 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
468 }
469
470 if len(c.cfg.Config().MCP) > 0 {
471 allTools = append(
472 allTools,
473 tools.NewListMCPResourcesTool(c.cfg, c.permissions),
474 tools.NewReadMCPResourceTool(c.cfg, c.permissions),
475 )
476 }
477
478 var filteredTools []fantasy.AgentTool
479 for _, tool := range allTools {
480 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
481 filteredTools = append(filteredTools, tool)
482 }
483 }
484
485 for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
486 if agent.AllowedMCP == nil {
487 // No MCP restrictions
488 filteredTools = append(filteredTools, tool)
489 continue
490 }
491 if len(agent.AllowedMCP) == 0 {
492 // No MCPs allowed
493 slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
494 break
495 }
496
497 for mcp, tools := range agent.AllowedMCP {
498 if mcp != tool.MCP() {
499 continue
500 }
501 if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
502 filteredTools = append(filteredTools, tool)
503 break
504 }
505 slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
506 }
507 }
508 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
509 return strings.Compare(a.Info().Name, b.Info().Name)
510 })
511 return filteredTools, nil
512}
513
514// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
515func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
516 largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
517 if !ok {
518 return Model{}, Model{}, errLargeModelNotSelected
519 }
520 smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
521 if !ok {
522 return Model{}, Model{}, errSmallModelNotSelected
523 }
524
525 largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
526 if !ok {
527 return Model{}, Model{}, errLargeModelProviderNotConfigured
528 }
529
530 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
531 if err != nil {
532 return Model{}, Model{}, err
533 }
534
535 smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
536 if !ok {
537 return Model{}, Model{}, errSmallModelProviderNotConfigured
538 }
539
540 smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
541 if err != nil {
542 return Model{}, Model{}, err
543 }
544
545 var largeCatwalkModel *catwalk.Model
546 var smallCatwalkModel *catwalk.Model
547
548 for _, m := range largeProviderCfg.Models {
549 if m.ID == largeModelCfg.Model {
550 largeCatwalkModel = &m
551 }
552 }
553 for _, m := range smallProviderCfg.Models {
554 if m.ID == smallModelCfg.Model {
555 smallCatwalkModel = &m
556 }
557 }
558
559 if largeCatwalkModel == nil {
560 return Model{}, Model{}, errLargeModelNotFound
561 }
562
563 if smallCatwalkModel == nil {
564 return Model{}, Model{}, errSmallModelNotFound
565 }
566
567 largeModelID := largeModelCfg.Model
568 smallModelID := smallModelCfg.Model
569
570 if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
571 largeModelID += ":exacto"
572 }
573
574 if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
575 smallModelID += ":exacto"
576 }
577
578 largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
579 if err != nil {
580 return Model{}, Model{}, err
581 }
582 smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
583 if err != nil {
584 return Model{}, Model{}, err
585 }
586
587 return Model{
588 Model: largeModel,
589 CatwalkCfg: *largeCatwalkModel,
590 ModelCfg: largeModelCfg,
591 }, Model{
592 Model: smallModel,
593 CatwalkCfg: *smallCatwalkModel,
594 ModelCfg: smallModelCfg,
595 }, nil
596}
597
598func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
599 var opts []anthropic.Option
600
601 switch {
602 case strings.HasPrefix(apiKey, "Bearer "):
603 // NOTE: Prevent the SDK from picking up the API key from env.
604 os.Setenv("ANTHROPIC_API_KEY", "")
605 headers["Authorization"] = apiKey
606 case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
607 // NOTE: Prevent the SDK from picking up the API key from env.
608 os.Setenv("ANTHROPIC_API_KEY", "")
609 headers["Authorization"] = "Bearer " + apiKey
610 case apiKey != "":
611 // X-Api-Key header
612 opts = append(opts, anthropic.WithAPIKey(apiKey))
613 }
614
615 if len(headers) > 0 {
616 opts = append(opts, anthropic.WithHeaders(headers))
617 }
618
619 if baseURL != "" {
620 opts = append(opts, anthropic.WithBaseURL(baseURL))
621 }
622
623 if c.cfg.Config().Options.Debug {
624 httpClient := log.NewHTTPClient()
625 opts = append(opts, anthropic.WithHTTPClient(httpClient))
626 }
627 return anthropic.New(opts...)
628}
629
630func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
631 opts := []openai.Option{
632 openai.WithAPIKey(apiKey),
633 openai.WithUseResponsesAPI(),
634 }
635 if c.cfg.Config().Options.Debug {
636 httpClient := log.NewHTTPClient()
637 opts = append(opts, openai.WithHTTPClient(httpClient))
638 }
639 if len(headers) > 0 {
640 opts = append(opts, openai.WithHeaders(headers))
641 }
642 if baseURL != "" {
643 opts = append(opts, openai.WithBaseURL(baseURL))
644 }
645 return openai.New(opts...)
646}
647
648func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
649 opts := []openrouter.Option{
650 openrouter.WithAPIKey(apiKey),
651 }
652 if c.cfg.Config().Options.Debug {
653 httpClient := log.NewHTTPClient()
654 opts = append(opts, openrouter.WithHTTPClient(httpClient))
655 }
656 if len(headers) > 0 {
657 opts = append(opts, openrouter.WithHeaders(headers))
658 }
659 return openrouter.New(opts...)
660}
661
662func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
663 opts := []vercel.Option{
664 vercel.WithAPIKey(apiKey),
665 }
666 if c.cfg.Config().Options.Debug {
667 httpClient := log.NewHTTPClient()
668 opts = append(opts, vercel.WithHTTPClient(httpClient))
669 }
670 if len(headers) > 0 {
671 opts = append(opts, vercel.WithHeaders(headers))
672 }
673 return vercel.New(opts...)
674}
675
676func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
677 opts := []openaicompat.Option{
678 openaicompat.WithBaseURL(baseURL),
679 openaicompat.WithAPIKey(apiKey),
680 }
681
682 // Set HTTP client based on provider and debug mode.
683 var httpClient *http.Client
684 if providerID == string(catwalk.InferenceProviderCopilot) {
685 opts = append(opts, openaicompat.WithUseResponsesAPI())
686 httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
687 } else if c.cfg.Config().Options.Debug {
688 httpClient = log.NewHTTPClient()
689 }
690 if httpClient != nil {
691 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
692 }
693
694 if len(headers) > 0 {
695 opts = append(opts, openaicompat.WithHeaders(headers))
696 }
697
698 for extraKey, extraValue := range extraBody {
699 opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
700 }
701
702 return openaicompat.New(opts...)
703}
704
705func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
706 opts := []azure.Option{
707 azure.WithBaseURL(baseURL),
708 azure.WithAPIKey(apiKey),
709 azure.WithUseResponsesAPI(),
710 }
711 if c.cfg.Config().Options.Debug {
712 httpClient := log.NewHTTPClient()
713 opts = append(opts, azure.WithHTTPClient(httpClient))
714 }
715 if options == nil {
716 options = make(map[string]string)
717 }
718 if apiVersion, ok := options["apiVersion"]; ok {
719 opts = append(opts, azure.WithAPIVersion(apiVersion))
720 }
721 if len(headers) > 0 {
722 opts = append(opts, azure.WithHeaders(headers))
723 }
724
725 return azure.New(opts...)
726}
727
728func (c *coordinator) buildBedrockProvider(apiKey string, headers map[string]string) (fantasy.Provider, error) {
729 var opts []bedrock.Option
730 if c.cfg.Config().Options.Debug {
731 httpClient := log.NewHTTPClient()
732 opts = append(opts, bedrock.WithHTTPClient(httpClient))
733 }
734 if len(headers) > 0 {
735 opts = append(opts, bedrock.WithHeaders(headers))
736 }
737 switch {
738 case apiKey != "":
739 opts = append(opts, bedrock.WithAPIKey(apiKey))
740 case os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "":
741 opts = append(opts, bedrock.WithAPIKey(os.Getenv("AWS_BEARER_TOKEN_BEDROCK")))
742 default:
743 // Skip, let the SDK do authentication.
744 }
745 return bedrock.New(opts...)
746}
747
748func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
749 opts := []google.Option{
750 google.WithBaseURL(baseURL),
751 google.WithGeminiAPIKey(apiKey),
752 }
753 if c.cfg.Config().Options.Debug {
754 httpClient := log.NewHTTPClient()
755 opts = append(opts, google.WithHTTPClient(httpClient))
756 }
757 if len(headers) > 0 {
758 opts = append(opts, google.WithHeaders(headers))
759 }
760 return google.New(opts...)
761}
762
763func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
764 opts := []google.Option{}
765 if c.cfg.Config().Options.Debug {
766 httpClient := log.NewHTTPClient()
767 opts = append(opts, google.WithHTTPClient(httpClient))
768 }
769 if len(headers) > 0 {
770 opts = append(opts, google.WithHeaders(headers))
771 }
772
773 project := options["project"]
774 location := options["location"]
775
776 opts = append(opts, google.WithVertex(project, location))
777
778 return google.New(opts...)
779}
780
781func (c *coordinator) buildHyperProvider(apiKey string) (fantasy.Provider, error) {
782 opts := []hyper.Option{
783 hyper.WithAPIKey(apiKey),
784 }
785 if c.cfg.Config().Options.Debug {
786 httpClient := log.NewHTTPClient()
787 opts = append(opts, hyper.WithHTTPClient(httpClient))
788 }
789 return hyper.New(opts...)
790}
791
792func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
793 if model.Think {
794 return true
795 }
796 opts, err := anthropic.ParseOptions(model.ProviderOptions)
797 return err == nil && opts.Thinking != nil
798}
799
800func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
801 headers := maps.Clone(providerCfg.ExtraHeaders)
802 if headers == nil {
803 headers = make(map[string]string)
804 }
805
806 // handle special headers for anthropic
807 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
808 if v, ok := headers["anthropic-beta"]; ok {
809 headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
810 } else {
811 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
812 }
813 }
814
815 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
816 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
817
818 switch providerCfg.Type {
819 case openai.Name:
820 return c.buildOpenaiProvider(baseURL, apiKey, headers)
821 case anthropic.Name:
822 return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
823 case openrouter.Name:
824 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
825 case vercel.Name:
826 return c.buildVercelProvider(baseURL, apiKey, headers)
827 case azure.Name:
828 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
829 case bedrock.Name:
830 return c.buildBedrockProvider(apiKey, headers)
831 case google.Name:
832 return c.buildGoogleProvider(baseURL, apiKey, headers)
833 case "google-vertex":
834 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
835 case openaicompat.Name:
836 if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
837 if providerCfg.ExtraBody == nil {
838 providerCfg.ExtraBody = map[string]any{}
839 }
840 providerCfg.ExtraBody["tool_stream"] = true
841 }
842 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
843 case hyper.Name:
844 return c.buildHyperProvider(apiKey)
845 default:
846 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
847 }
848}
849
// isExactoSupported reports whether modelID (an exact, case-sensitive OpenRouter
// model ID) is one of the models that get the ":exacto" variant.
func isExactoSupported(modelID string) bool {
	exactoModels := []string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}
	return slices.Index(exactoModels, modelID) != -1
}
860
861func (c *coordinator) Cancel(sessionID string) {
862 c.currentAgent.Cancel(sessionID)
863}
864
865func (c *coordinator) CancelAll() {
866 c.currentAgent.CancelAll()
867}
868
869func (c *coordinator) ClearQueue(sessionID string) {
870 c.currentAgent.ClearQueue(sessionID)
871}
872
873func (c *coordinator) IsBusy() bool {
874 return c.currentAgent.IsBusy()
875}
876
877func (c *coordinator) IsSessionBusy(sessionID string) bool {
878 return c.currentAgent.IsSessionBusy(sessionID)
879}
880
881func (c *coordinator) Model() Model {
882 return c.currentAgent.Model()
883}
884
885func (c *coordinator) UpdateModels(ctx context.Context) error {
886 // build the models again so we make sure we get the latest config
887 large, small, err := c.buildAgentModels(ctx, false)
888 if err != nil {
889 return err
890 }
891 c.currentAgent.SetModels(large, small)
892
893 agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
894 if !ok {
895 return errCoderAgentNotConfigured
896 }
897
898 tools, err := c.buildTools(ctx, agentCfg)
899 if err != nil {
900 return err
901 }
902 c.currentAgent.SetTools(tools)
903 return nil
904}
905
906func (c *coordinator) QueuedPrompts(sessionID string) int {
907 return c.currentAgent.QueuedPrompts(sessionID)
908}
909
910func (c *coordinator) QueuedPromptsList(sessionID string) []string {
911 return c.currentAgent.QueuedPromptsList(sessionID)
912}
913
914func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
915 providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
916 if !ok {
917 return errModelProviderNotConfigured
918 }
919 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
920}
921
922func (c *coordinator) isUnauthorized(err error) bool {
923 var providerErr *fantasy.ProviderError
924 return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
925}
926
927func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
928 if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
929 slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
930 return err
931 }
932 if err := c.UpdateModels(ctx); err != nil {
933 return err
934 }
935 return nil
936}
937
938func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
939 newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
940 if err != nil {
941 slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
942 return err
943 }
944
945 providerCfg.APIKey = newAPIKey
946 c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
947
948 if err := c.UpdateModels(ctx); err != nil {
949 return err
950 }
951 return nil
952}
953
// subAgentParams holds the parameters for running a sub-agent via runSubAgent.
type subAgentParams struct {
	// Agent is the sub-agent to run; its Model() determines the call options.
	Agent SessionAgent
	// SessionID is the parent session: the sub-session is created under it and
	// the sub-session's cost is added back to it.
	SessionID string
	// AgentMessageID and ToolCallID are passed to CreateAgentToolSessionID to
	// derive the sub-session ID.
	AgentMessageID string
	ToolCallID     string
	// Prompt is the prompt sent to the sub-agent.
	Prompt string
	// SessionTitle is the title of the created sub-session.
	SessionTitle string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
966
967// runSubAgent runs a sub-agent and handles session management and cost accumulation.
968// It creates a sub-session, runs the agent with the given prompt, and propagates
969// the cost to the parent session.
970func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
971 // Create sub-session
972 agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
973 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
974 if err != nil {
975 return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
976 }
977
978 // Call session setup function if provided
979 if params.SessionSetup != nil {
980 params.SessionSetup(session.ID)
981 }
982
983 // Get model configuration
984 model := params.Agent.Model()
985 maxTokens := model.CatwalkCfg.DefaultMaxTokens
986 if model.ModelCfg.MaxTokens != 0 {
987 maxTokens = model.ModelCfg.MaxTokens
988 }
989
990 providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
991 if !ok {
992 return fantasy.ToolResponse{}, errModelProviderNotConfigured
993 }
994
995 // Run the agent
996 result, err := params.Agent.Run(ctx, SessionAgentCall{
997 SessionID: session.ID,
998 Prompt: params.Prompt,
999 MaxOutputTokens: maxTokens,
1000 ProviderOptions: getProviderOptions(model, providerCfg),
1001 Temperature: model.ModelCfg.Temperature,
1002 TopP: model.ModelCfg.TopP,
1003 TopK: model.ModelCfg.TopK,
1004 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
1005 PresencePenalty: model.ModelCfg.PresencePenalty,
1006 NonInteractive: true,
1007 })
1008 if err != nil {
1009 return fantasy.NewTextErrorResponse("error generating response"), nil
1010 }
1011
1012 // Update parent session cost
1013 if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1014 return fantasy.ToolResponse{}, err
1015 }
1016
1017 return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1018}
1019
1020// updateParentSessionCost accumulates the cost from a child session to its parent session.
1021func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1022 childSession, err := c.sessions.Get(ctx, childSessionID)
1023 if err != nil {
1024 return fmt.Errorf("get child session: %w", err)
1025 }
1026
1027 parentSession, err := c.sessions.Get(ctx, parentSessionID)
1028 if err != nil {
1029 return fmt.Errorf("get parent session: %w", err)
1030 }
1031
1032 parentSession.Cost += childSession.Cost
1033
1034 if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1035 return fmt.Errorf("save parent session: %w", err)
1036 }
1037
1038 return nil
1039}