1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "slices"
16 "strings"
17
18 "charm.land/catwalk/pkg/catwalk"
19 "charm.land/fantasy"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/agent/notify"
22 "github.com/charmbracelet/crush/internal/agent/prompt"
23 "github.com/charmbracelet/crush/internal/agent/tools"
24 "github.com/charmbracelet/crush/internal/config"
25 "github.com/charmbracelet/crush/internal/filetracker"
26 "github.com/charmbracelet/crush/internal/history"
27 "github.com/charmbracelet/crush/internal/log"
28 "github.com/charmbracelet/crush/internal/lsp"
29 "github.com/charmbracelet/crush/internal/message"
30 "github.com/charmbracelet/crush/internal/oauth/copilot"
31 "github.com/charmbracelet/crush/internal/permission"
32 "github.com/charmbracelet/crush/internal/pubsub"
33 "github.com/charmbracelet/crush/internal/session"
34 "golang.org/x/sync/errgroup"
35
36 "charm.land/fantasy/providers/anthropic"
37 "charm.land/fantasy/providers/azure"
38 "charm.land/fantasy/providers/bedrock"
39 "charm.land/fantasy/providers/google"
40 "charm.land/fantasy/providers/openai"
41 "charm.land/fantasy/providers/openaicompat"
42 "charm.land/fantasy/providers/openrouter"
43 "charm.land/fantasy/providers/vercel"
44 openaisdk "github.com/charmbracelet/openai-go/option"
45 "github.com/qjebbs/go-jsons"
46)
47
// Coordinator errors.
//
// Sentinel errors returned by the coordinator:
//   - errCoderAgentNotConfigured: the config has no config.AgentCoder entry
//     (NewCoordinator, UpdateModels).
//   - errModelProviderNotConfigured: the provider of the agent's current model
//     is missing from the config (Run, Summarize, runSubAgent).
//   - the remaining errors are returned by buildAgentModels when the large or
//     small model selection, its provider, or its catalog entry is missing.
var (
	errCoderAgentNotConfigured         = errors.New("coder agent not configured")
	errModelProviderNotConfigured      = errors.New("model provider not configured")
	errLargeModelNotSelected           = errors.New("large model not selected")
	errSmallModelNotSelected           = errors.New("small model not selected")
	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
	errLargeModelNotFound              = errors.New("large model not found in provider config")
	errSmallModelNotFound              = errors.New("small model not found in provider config")
)
59
// Coordinator runs prompts against the configured agent and exposes that
// agent's per-session state (busy flags, prompt queue, model selection).
type Coordinator interface {
	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
	// SetMainAgent(string)

	// Run sends prompt, with optional attachments, to the agent for sessionID
	// and returns the agent's result.
	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
	// Cancel forwards a cancel request for sessionID to the agent.
	Cancel(sessionID string)
	// CancelAll forwards a cancel request for every session to the agent.
	CancelAll()
	// IsSessionBusy reports whether the agent considers sessionID busy.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether the agent considers itself busy.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for sessionID.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompts queued for sessionID.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue clears the prompts queued for sessionID.
	ClearQueue(sessionID string)
	// Summarize asks the agent to summarize the session with the given ID.
	Summarize(context.Context, string) error
	// Model returns the agent's current model.
	Model() Model
	// UpdateModels rebuilds the agent's models and tools from the current
	// config.
	UpdateModels(ctx context.Context) error
}
75
// coordinator is the Coordinator implementation. It owns the current
// SessionAgent and the services needed to build that agent's models and tools.
type coordinator struct {
	// cfg is the config store consulted for agents, models, providers and
	// options.
	cfg *config.ConfigStore
	// Services handed to the session agent and to the tools it uses.
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager
	notify      pubsub.Publisher[notify.Notification]

	// currentAgent is the agent every Coordinator method delegates to.
	currentAgent SessionAgent
	// agents holds the built agents by name; NewCoordinator registers only
	// config.AgentCoder.
	agents map[string]SessionAgent

	// readyWg tracks the background system-prompt and tool builds started by
	// buildAgent; Run waits on it before doing any work.
	readyWg errgroup.Group
}
91
92func NewCoordinator(
93 ctx context.Context,
94 cfg *config.ConfigStore,
95 sessions session.Service,
96 messages message.Service,
97 permissions permission.Service,
98 history history.Service,
99 filetracker filetracker.Service,
100 lspManager *lsp.Manager,
101 notify pubsub.Publisher[notify.Notification],
102) (Coordinator, error) {
103 c := &coordinator{
104 cfg: cfg,
105 sessions: sessions,
106 messages: messages,
107 permissions: permissions,
108 history: history,
109 filetracker: filetracker,
110 lspManager: lspManager,
111 notify: notify,
112 agents: make(map[string]SessionAgent),
113 }
114
115 agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
116 if !ok {
117 return nil, errCoderAgentNotConfigured
118 }
119
120 // TODO: make this dynamic when we support multiple agents
121 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
122 if err != nil {
123 return nil, err
124 }
125
126 agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
127 if err != nil {
128 return nil, err
129 }
130 c.currentAgent = agent
131 c.agents[config.AgentCoder] = agent
132 return c, nil
133}
134
135// Run implements Coordinator.
136func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
137 if err := c.readyWg.Wait(); err != nil {
138 return nil, err
139 }
140
141 // refresh models before each run
142 if err := c.UpdateModels(ctx); err != nil {
143 return nil, fmt.Errorf("failed to update models: %w", err)
144 }
145
146 model := c.currentAgent.Model()
147 maxTokens := model.CatwalkCfg.DefaultMaxTokens
148 if model.ModelCfg.MaxTokens != 0 {
149 maxTokens = model.ModelCfg.MaxTokens
150 }
151
152 if !model.CatwalkCfg.SupportsImages && attachments != nil {
153 // filter out image attachments
154 filteredAttachments := make([]message.Attachment, 0, len(attachments))
155 for _, att := range attachments {
156 if att.IsText() {
157 filteredAttachments = append(filteredAttachments, att)
158 }
159 }
160 attachments = filteredAttachments
161 }
162
163 providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
164 if !ok {
165 return nil, errModelProviderNotConfigured
166 }
167
168 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
169
170 if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
171 slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
172 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
173 return nil, err
174 }
175 }
176
177 run := func() (*fantasy.AgentResult, error) {
178 return c.currentAgent.Run(ctx, SessionAgentCall{
179 SessionID: sessionID,
180 Prompt: prompt,
181 Attachments: attachments,
182 MaxOutputTokens: maxTokens,
183 ProviderOptions: mergedOptions,
184 Temperature: temp,
185 TopP: topP,
186 TopK: topK,
187 FrequencyPenalty: freqPenalty,
188 PresencePenalty: presPenalty,
189 })
190 }
191 result, originalErr := run()
192
193 if c.isUnauthorized(originalErr) {
194 switch {
195 case providerCfg.OAuthToken != nil:
196 slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
197 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
198 return nil, originalErr
199 }
200 slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
201 return run()
202 case strings.Contains(providerCfg.APIKeyTemplate, "$"):
203 slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
204 if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
205 return nil, originalErr
206 }
207 slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
208 return run()
209 }
210 }
211
212 return result, originalErr
213}
214
215func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
216 options := fantasy.ProviderOptions{}
217
218 cfgOpts := []byte("{}")
219 providerCfgOpts := []byte("{}")
220 catwalkOpts := []byte("{}")
221
222 if model.ModelCfg.ProviderOptions != nil {
223 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
224 if err == nil {
225 cfgOpts = data
226 }
227 }
228
229 if providerCfg.ProviderOptions != nil {
230 data, err := json.Marshal(providerCfg.ProviderOptions)
231 if err == nil {
232 providerCfgOpts = data
233 }
234 }
235
236 if model.CatwalkCfg.Options.ProviderOptions != nil {
237 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
238 if err == nil {
239 catwalkOpts = data
240 }
241 }
242
243 readers := []io.Reader{
244 bytes.NewReader(catwalkOpts),
245 bytes.NewReader(providerCfgOpts),
246 bytes.NewReader(cfgOpts),
247 }
248
249 got, err := jsons.Merge(readers)
250 if err != nil {
251 slog.Error("Could not merge call config", "err", err)
252 return options
253 }
254
255 mergedOptions := make(map[string]any)
256
257 err = json.Unmarshal([]byte(got), &mergedOptions)
258 if err != nil {
259 slog.Error("Could not create config for call", "err", err)
260 return options
261 }
262
263 providerType := providerCfg.Type
264 if providerType == "hyper" {
265 if strings.Contains(model.CatwalkCfg.ID, "claude") {
266 providerType = anthropic.Name
267 } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
268 providerType = openai.Name
269 } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
270 providerType = google.Name
271 } else {
272 providerType = openaicompat.Name
273 }
274 }
275
276 switch providerType {
277 case openai.Name, azure.Name:
278 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
279 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
280 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
281 }
282 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
283 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
284 mergedOptions["reasoning_summary"] = "auto"
285 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
286 }
287 parsed, err := openai.ParseResponsesOptions(mergedOptions)
288 if err == nil {
289 options[openai.Name] = parsed
290 }
291 } else {
292 parsed, err := openai.ParseOptions(mergedOptions)
293 if err == nil {
294 options[openai.Name] = parsed
295 }
296 }
297 case anthropic.Name:
298 var (
299 _, hasEffort = mergedOptions["effort"]
300 _, hasThink = mergedOptions["thinking"]
301 )
302 switch {
303 case !hasEffort && model.ModelCfg.ReasoningEffort != "":
304 mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
305 case !hasThink && model.ModelCfg.Think:
306 mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
307 }
308 parsed, err := anthropic.ParseOptions(mergedOptions)
309 if err == nil {
310 options[anthropic.Name] = parsed
311 }
312
313 case openrouter.Name:
314 _, hasReasoning := mergedOptions["reasoning"]
315 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
316 mergedOptions["reasoning"] = map[string]any{
317 "enabled": true,
318 "effort": model.ModelCfg.ReasoningEffort,
319 }
320 }
321 parsed, err := openrouter.ParseOptions(mergedOptions)
322 if err == nil {
323 options[openrouter.Name] = parsed
324 }
325 case vercel.Name:
326 _, hasReasoning := mergedOptions["reasoning"]
327 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
328 mergedOptions["reasoning"] = map[string]any{
329 "enabled": true,
330 "effort": model.ModelCfg.ReasoningEffort,
331 }
332 }
333 parsed, err := vercel.ParseOptions(mergedOptions)
334 if err == nil {
335 options[vercel.Name] = parsed
336 }
337 case google.Name:
338 _, hasReasoning := mergedOptions["thinking_config"]
339 if !hasReasoning {
340 if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
341 mergedOptions["thinking_config"] = map[string]any{
342 "thinking_budget": 2000,
343 "include_thoughts": true,
344 }
345 } else {
346 mergedOptions["thinking_config"] = map[string]any{
347 "thinking_level": model.ModelCfg.ReasoningEffort,
348 "include_thoughts": true,
349 }
350 }
351 }
352 parsed, err := google.ParseOptions(mergedOptions)
353 if err == nil {
354 options[google.Name] = parsed
355 }
356 case openaicompat.Name:
357 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
358 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
359 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
360 }
361 parsed, err := openaicompat.ParseOptions(mergedOptions)
362 if err == nil {
363 options[openaicompat.Name] = parsed
364 }
365 }
366
367 return options
368}
369
370func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
371 modelOptions := getProviderOptions(model, cfg)
372 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
373 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
374 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
375 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
376 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
377 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
378}
379
380func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
381 large, small, err := c.buildAgentModels(ctx, isSubAgent)
382 if err != nil {
383 return nil, err
384 }
385
386 largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
387 result := NewSessionAgent(SessionAgentOptions{
388 LargeModel: large,
389 SmallModel: small,
390 SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix,
391 SystemPrompt: "",
392 IsSubAgent: isSubAgent,
393 DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
394 IsYolo: c.permissions.SkipRequests(),
395 Sessions: c.sessions,
396 Messages: c.messages,
397 Tools: nil,
398 Notify: c.notify,
399 })
400
401 c.readyWg.Go(func() error {
402 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
403 if err != nil {
404 return err
405 }
406 result.SetSystemPrompt(systemPrompt)
407 return nil
408 })
409
410 c.readyWg.Go(func() error {
411 tools, err := c.buildTools(ctx, agent)
412 if err != nil {
413 return err
414 }
415 result.SetTools(tools)
416 return nil
417 })
418
419 return result, nil
420}
421
422func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
423 var allTools []fantasy.AgentTool
424 if slices.Contains(agent.AllowedTools, AgentToolName) {
425 agentTool, err := c.agentTool(ctx)
426 if err != nil {
427 return nil, err
428 }
429 allTools = append(allTools, agentTool)
430 }
431
432 if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
433 agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
434 if err != nil {
435 return nil, err
436 }
437 allTools = append(allTools, agenticFetchTool)
438 }
439
440 // Get the model name for the agent
441 modelName := ""
442 if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
443 if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
444 modelName = model.Name
445 }
446 }
447
448 allTools = append(allTools,
449 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
450 tools.NewJobOutputTool(),
451 tools.NewJobKillTool(),
452 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
453 tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
454 tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
455 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
456 tools.NewGlobTool(c.cfg.WorkingDir()),
457 tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
458 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
459 tools.NewSourcegraphTool(nil),
460 tools.NewTodosTool(c.sessions),
461 tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
462 tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
463 )
464
465 // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
466 if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
467 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
468 }
469
470 if len(c.cfg.Config().MCP) > 0 {
471 allTools = append(
472 allTools,
473 tools.NewListMCPResourcesTool(c.cfg, c.permissions),
474 tools.NewReadMCPResourceTool(c.cfg, c.permissions),
475 )
476 }
477
478 var filteredTools []fantasy.AgentTool
479 for _, tool := range allTools {
480 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
481 filteredTools = append(filteredTools, tool)
482 }
483 }
484
485 for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
486 if agent.AllowedMCP == nil {
487 // No MCP restrictions
488 filteredTools = append(filteredTools, tool)
489 continue
490 }
491 if len(agent.AllowedMCP) == 0 {
492 // No MCPs allowed
493 slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
494 break
495 }
496
497 for mcp, tools := range agent.AllowedMCP {
498 if mcp != tool.MCP() {
499 continue
500 }
501 if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
502 filteredTools = append(filteredTools, tool)
503 break
504 }
505 slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
506 }
507 }
508 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
509 return strings.Compare(a.Info().Name, b.Info().Name)
510 })
511 return filteredTools, nil
512}
513
514// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
515func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
516 largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
517 if !ok {
518 return Model{}, Model{}, errLargeModelNotSelected
519 }
520 smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
521 if !ok {
522 return Model{}, Model{}, errSmallModelNotSelected
523 }
524
525 largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
526 if !ok {
527 return Model{}, Model{}, errLargeModelProviderNotConfigured
528 }
529
530 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
531 if err != nil {
532 return Model{}, Model{}, err
533 }
534
535 smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
536 if !ok {
537 return Model{}, Model{}, errSmallModelProviderNotConfigured
538 }
539
540 smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
541 if err != nil {
542 return Model{}, Model{}, err
543 }
544
545 var largeCatwalkModel *catwalk.Model
546 var smallCatwalkModel *catwalk.Model
547
548 for _, m := range largeProviderCfg.Models {
549 if m.ID == largeModelCfg.Model {
550 largeCatwalkModel = &m
551 }
552 }
553 for _, m := range smallProviderCfg.Models {
554 if m.ID == smallModelCfg.Model {
555 smallCatwalkModel = &m
556 }
557 }
558
559 if largeCatwalkModel == nil {
560 return Model{}, Model{}, errLargeModelNotFound
561 }
562
563 if smallCatwalkModel == nil {
564 return Model{}, Model{}, errSmallModelNotFound
565 }
566
567 largeModelID := largeModelCfg.Model
568 smallModelID := smallModelCfg.Model
569
570 if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
571 largeModelID += ":exacto"
572 }
573
574 if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
575 smallModelID += ":exacto"
576 }
577
578 largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
579 if err != nil {
580 return Model{}, Model{}, err
581 }
582 smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
583 if err != nil {
584 return Model{}, Model{}, err
585 }
586
587 return Model{
588 Model: largeModel,
589 CatwalkCfg: *largeCatwalkModel,
590 ModelCfg: largeModelCfg,
591 }, Model{
592 Model: smallModel,
593 CatwalkCfg: *smallCatwalkModel,
594 ModelCfg: smallModelCfg,
595 }, nil
596}
597
598func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
599 var opts []anthropic.Option
600
601 switch {
602 case strings.HasPrefix(apiKey, "Bearer "):
603 // NOTE: Prevent the SDK from picking up the API key from env.
604 os.Setenv("ANTHROPIC_API_KEY", "")
605 headers["Authorization"] = apiKey
606 case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
607 // NOTE: Prevent the SDK from picking up the API key from env.
608 os.Setenv("ANTHROPIC_API_KEY", "")
609 headers["Authorization"] = "Bearer " + apiKey
610 case apiKey != "":
611 // X-Api-Key header
612 opts = append(opts, anthropic.WithAPIKey(apiKey))
613 }
614
615 if len(headers) > 0 {
616 opts = append(opts, anthropic.WithHeaders(headers))
617 }
618
619 if baseURL != "" {
620 opts = append(opts, anthropic.WithBaseURL(baseURL))
621 }
622
623 if c.cfg.Config().Options.Debug {
624 httpClient := log.NewHTTPClient()
625 opts = append(opts, anthropic.WithHTTPClient(httpClient))
626 }
627 return anthropic.New(opts...)
628}
629
630func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
631 opts := []openai.Option{
632 openai.WithAPIKey(apiKey),
633 openai.WithUseResponsesAPI(),
634 }
635 if c.cfg.Config().Options.Debug {
636 httpClient := log.NewHTTPClient()
637 opts = append(opts, openai.WithHTTPClient(httpClient))
638 }
639 if len(headers) > 0 {
640 opts = append(opts, openai.WithHeaders(headers))
641 }
642 if baseURL != "" {
643 opts = append(opts, openai.WithBaseURL(baseURL))
644 }
645 return openai.New(opts...)
646}
647
648func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
649 opts := []openrouter.Option{
650 openrouter.WithAPIKey(apiKey),
651 }
652 if c.cfg.Config().Options.Debug {
653 httpClient := log.NewHTTPClient()
654 opts = append(opts, openrouter.WithHTTPClient(httpClient))
655 }
656 if len(headers) > 0 {
657 opts = append(opts, openrouter.WithHeaders(headers))
658 }
659 return openrouter.New(opts...)
660}
661
662func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
663 opts := []vercel.Option{
664 vercel.WithAPIKey(apiKey),
665 }
666 if c.cfg.Config().Options.Debug {
667 httpClient := log.NewHTTPClient()
668 opts = append(opts, vercel.WithHTTPClient(httpClient))
669 }
670 if len(headers) > 0 {
671 opts = append(opts, vercel.WithHeaders(headers))
672 }
673 return vercel.New(opts...)
674}
675
676func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
677 opts := []openaicompat.Option{
678 openaicompat.WithBaseURL(baseURL),
679 openaicompat.WithAPIKey(apiKey),
680 }
681
682 // Set HTTP client based on provider and debug mode.
683 var httpClient *http.Client
684 if providerID == string(catwalk.InferenceProviderCopilot) {
685 opts = append(opts, openaicompat.WithUseResponsesAPI())
686 httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
687 } else if c.cfg.Config().Options.Debug {
688 httpClient = log.NewHTTPClient()
689 }
690 if httpClient != nil {
691 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
692 }
693
694 if len(headers) > 0 {
695 opts = append(opts, openaicompat.WithHeaders(headers))
696 }
697
698 for extraKey, extraValue := range extraBody {
699 opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
700 }
701
702 return openaicompat.New(opts...)
703}
704
705func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
706 opts := []azure.Option{
707 azure.WithBaseURL(baseURL),
708 azure.WithAPIKey(apiKey),
709 azure.WithUseResponsesAPI(),
710 }
711 if c.cfg.Config().Options.Debug {
712 httpClient := log.NewHTTPClient()
713 opts = append(opts, azure.WithHTTPClient(httpClient))
714 }
715 if options == nil {
716 options = make(map[string]string)
717 }
718 if apiVersion, ok := options["apiVersion"]; ok {
719 opts = append(opts, azure.WithAPIVersion(apiVersion))
720 }
721 if len(headers) > 0 {
722 opts = append(opts, azure.WithHeaders(headers))
723 }
724
725 return azure.New(opts...)
726}
727
728func (c *coordinator) buildBedrockProvider(apiKey string, headers map[string]string) (fantasy.Provider, error) {
729 var opts []bedrock.Option
730 if c.cfg.Config().Options.Debug {
731 httpClient := log.NewHTTPClient()
732 opts = append(opts, bedrock.WithHTTPClient(httpClient))
733 }
734 if len(headers) > 0 {
735 opts = append(opts, bedrock.WithHeaders(headers))
736 }
737 switch {
738 case apiKey != "":
739 opts = append(opts, bedrock.WithAPIKey(apiKey))
740 case os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "":
741 opts = append(opts, bedrock.WithAPIKey(os.Getenv("AWS_BEARER_TOKEN_BEDROCK")))
742 default:
743 // Skip, let the SDK do authentication.
744 }
745 return bedrock.New(opts...)
746}
747
748func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
749 opts := []google.Option{
750 google.WithBaseURL(baseURL),
751 google.WithGeminiAPIKey(apiKey),
752 }
753 if c.cfg.Config().Options.Debug {
754 httpClient := log.NewHTTPClient()
755 opts = append(opts, google.WithHTTPClient(httpClient))
756 }
757 if len(headers) > 0 {
758 opts = append(opts, google.WithHeaders(headers))
759 }
760 return google.New(opts...)
761}
762
763func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
764 opts := []google.Option{}
765 if c.cfg.Config().Options.Debug {
766 httpClient := log.NewHTTPClient()
767 opts = append(opts, google.WithHTTPClient(httpClient))
768 }
769 if len(headers) > 0 {
770 opts = append(opts, google.WithHeaders(headers))
771 }
772
773 project := options["project"]
774 location := options["location"]
775
776 opts = append(opts, google.WithVertex(project, location))
777
778 return google.New(opts...)
779}
780
781func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
782 opts := []hyper.Option{
783 hyper.WithBaseURL(baseURL),
784 hyper.WithAPIKey(apiKey),
785 }
786 if c.cfg.Config().Options.Debug {
787 httpClient := log.NewHTTPClient()
788 opts = append(opts, hyper.WithHTTPClient(httpClient))
789 }
790 return hyper.New(opts...)
791}
792
793func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
794 if model.Think {
795 return true
796 }
797 opts, err := anthropic.ParseOptions(model.ProviderOptions)
798 return err == nil && opts.Thinking != nil
799}
800
801func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
802 headers := maps.Clone(providerCfg.ExtraHeaders)
803 if headers == nil {
804 headers = make(map[string]string)
805 }
806
807 // handle special headers for anthropic
808 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
809 if v, ok := headers["anthropic-beta"]; ok {
810 headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
811 } else {
812 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
813 }
814 }
815
816 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
817 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
818
819 switch providerCfg.Type {
820 case openai.Name:
821 return c.buildOpenaiProvider(baseURL, apiKey, headers)
822 case anthropic.Name:
823 return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
824 case openrouter.Name:
825 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
826 case vercel.Name:
827 return c.buildVercelProvider(baseURL, apiKey, headers)
828 case azure.Name:
829 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
830 case bedrock.Name:
831 return c.buildBedrockProvider(apiKey, headers)
832 case google.Name:
833 return c.buildGoogleProvider(baseURL, apiKey, headers)
834 case "google-vertex":
835 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
836 case openaicompat.Name:
837 if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
838 if providerCfg.ExtraBody == nil {
839 providerCfg.ExtraBody = map[string]any{}
840 }
841 providerCfg.ExtraBody["tool_stream"] = true
842 }
843 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
844 case hyper.Name:
845 return c.buildHyperProvider(baseURL, apiKey)
846 default:
847 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
848 }
849}
850
// isExactoSupported reports whether modelID is one of the model IDs for which
// the ":exacto" variant is requested. The match is exact and case-sensitive.
func isExactoSupported(modelID string) bool {
	exactoModels := [...]string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}
	return slices.Index(exactoModels[:], modelID) >= 0
}
861
// Cancel implements Coordinator by forwarding the cancel request for
// sessionID to the current agent.
func (c *coordinator) Cancel(sessionID string) {
	c.currentAgent.Cancel(sessionID)
}
865
// CancelAll implements Coordinator by forwarding the request to the current
// agent.
func (c *coordinator) CancelAll() {
	c.currentAgent.CancelAll()
}
869
// ClearQueue implements Coordinator by asking the current agent to clear the
// queue of sessionID.
func (c *coordinator) ClearQueue(sessionID string) {
	c.currentAgent.ClearQueue(sessionID)
}
873
// IsBusy implements Coordinator; it returns the current agent's busy state.
func (c *coordinator) IsBusy() bool {
	return c.currentAgent.IsBusy()
}
877
// IsSessionBusy implements Coordinator; it returns the current agent's busy
// state for sessionID.
func (c *coordinator) IsSessionBusy(sessionID string) bool {
	return c.currentAgent.IsSessionBusy(sessionID)
}
881
// Model implements Coordinator; it returns the model reported by the current
// agent.
func (c *coordinator) Model() Model {
	return c.currentAgent.Model()
}
885
886func (c *coordinator) UpdateModels(ctx context.Context) error {
887 // build the models again so we make sure we get the latest config
888 large, small, err := c.buildAgentModels(ctx, false)
889 if err != nil {
890 return err
891 }
892 c.currentAgent.SetModels(large, small)
893
894 agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
895 if !ok {
896 return errCoderAgentNotConfigured
897 }
898
899 tools, err := c.buildTools(ctx, agentCfg)
900 if err != nil {
901 return err
902 }
903 c.currentAgent.SetTools(tools)
904 return nil
905}
906
// QueuedPrompts implements Coordinator; it returns the current agent's count
// of queued prompts for sessionID.
func (c *coordinator) QueuedPrompts(sessionID string) int {
	return c.currentAgent.QueuedPrompts(sessionID)
}
910
// QueuedPromptsList implements Coordinator; it returns the current agent's
// queued prompts for sessionID.
func (c *coordinator) QueuedPromptsList(sessionID string) []string {
	return c.currentAgent.QueuedPromptsList(sessionID)
}
914
915func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
916 providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
917 if !ok {
918 return errModelProviderNotConfigured
919 }
920 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
921}
922
923func (c *coordinator) isUnauthorized(err error) bool {
924 var providerErr *fantasy.ProviderError
925 return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
926}
927
928func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
929 if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
930 slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
931 return err
932 }
933 if err := c.UpdateModels(ctx); err != nil {
934 return err
935 }
936 return nil
937}
938
939func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
940 newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
941 if err != nil {
942 slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
943 return err
944 }
945
946 providerCfg.APIKey = newAPIKey
947 c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
948
949 if err := c.UpdateModels(ctx); err != nil {
950 return err
951 }
952 return nil
953}
954
// subAgentParams holds the parameters for running a sub-agent.
type subAgentParams struct {
	// Agent is the agent that executes the prompt.
	Agent SessionAgent
	// SessionID is the parent session: the sub-session is created under it
	// and the sub-session's cost is added to it.
	SessionID string
	// AgentMessageID and ToolCallID are combined into the sub-session ID.
	AgentMessageID string
	ToolCallID     string
	// Prompt is the prompt sent to the sub-agent.
	Prompt string
	// SessionTitle is the title given to the sub-session.
	SessionTitle string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
967
968// runSubAgent runs a sub-agent and handles session management and cost accumulation.
969// It creates a sub-session, runs the agent with the given prompt, and propagates
970// the cost to the parent session.
971func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
972 // Create sub-session
973 agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
974 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
975 if err != nil {
976 return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
977 }
978
979 // Call session setup function if provided
980 if params.SessionSetup != nil {
981 params.SessionSetup(session.ID)
982 }
983
984 // Get model configuration
985 model := params.Agent.Model()
986 maxTokens := model.CatwalkCfg.DefaultMaxTokens
987 if model.ModelCfg.MaxTokens != 0 {
988 maxTokens = model.ModelCfg.MaxTokens
989 }
990
991 providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
992 if !ok {
993 return fantasy.ToolResponse{}, errModelProviderNotConfigured
994 }
995
996 // Run the agent
997 result, err := params.Agent.Run(ctx, SessionAgentCall{
998 SessionID: session.ID,
999 Prompt: params.Prompt,
1000 MaxOutputTokens: maxTokens,
1001 ProviderOptions: getProviderOptions(model, providerCfg),
1002 Temperature: model.ModelCfg.Temperature,
1003 TopP: model.ModelCfg.TopP,
1004 TopK: model.ModelCfg.TopK,
1005 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
1006 PresencePenalty: model.ModelCfg.PresencePenalty,
1007 NonInteractive: true,
1008 })
1009 if err != nil {
1010 return fantasy.NewTextErrorResponse("error generating response"), nil
1011 }
1012
1013 // Update parent session cost
1014 if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1015 return fantasy.ToolResponse{}, err
1016 }
1017
1018 return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1019}
1020
1021// updateParentSessionCost accumulates the cost from a child session to its parent session.
1022func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1023 childSession, err := c.sessions.Get(ctx, childSessionID)
1024 if err != nil {
1025 return fmt.Errorf("get child session: %w", err)
1026 }
1027
1028 parentSession, err := c.sessions.Get(ctx, parentSessionID)
1029 if err != nil {
1030 return fmt.Errorf("get parent session: %w", err)
1031 }
1032
1033 parentSession.Cost += childSession.Cost
1034
1035 if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1036 return fmt.Errorf("save parent session: %w", err)
1037 }
1038
1039 return nil
1040}