1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "slices"
16 "strings"
17
18 "charm.land/catwalk/pkg/catwalk"
19 "charm.land/fantasy"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/agent/notify"
22 "github.com/charmbracelet/crush/internal/agent/prompt"
23 "github.com/charmbracelet/crush/internal/agent/tools"
24 "github.com/charmbracelet/crush/internal/config"
25 "github.com/charmbracelet/crush/internal/filetracker"
26 "github.com/charmbracelet/crush/internal/history"
27 "github.com/charmbracelet/crush/internal/log"
28 "github.com/charmbracelet/crush/internal/lsp"
29 "github.com/charmbracelet/crush/internal/message"
30 "github.com/charmbracelet/crush/internal/oauth/copilot"
31 "github.com/charmbracelet/crush/internal/permission"
32 "github.com/charmbracelet/crush/internal/pubsub"
33 "github.com/charmbracelet/crush/internal/session"
34 "golang.org/x/sync/errgroup"
35
36 "charm.land/fantasy/providers/anthropic"
37 "charm.land/fantasy/providers/azure"
38 "charm.land/fantasy/providers/bedrock"
39 "charm.land/fantasy/providers/google"
40 "charm.land/fantasy/providers/openai"
41 "charm.land/fantasy/providers/openaicompat"
42 "charm.land/fantasy/providers/openrouter"
43 "charm.land/fantasy/providers/vercel"
44 openaisdk "github.com/charmbracelet/openai-go/option"
45 "github.com/qjebbs/go-jsons"
46)
47
// Coordinator errors.
var (
	// errCoderAgentNotConfigured is returned by NewCoordinator and
	// UpdateModels when the config has no config.AgentCoder agent.
	errCoderAgentNotConfigured = errors.New("coder agent not configured")
	// errModelProviderNotConfigured is returned by Run, Summarize and
	// runSubAgent when the provider of the model in use is missing from the
	// config.
	errModelProviderNotConfigured = errors.New("model provider not configured")
	// errLargeModelNotSelected and errSmallModelNotSelected are returned by
	// buildAgentModels when the config has no large/small model selection.
	errLargeModelNotSelected = errors.New("large model not selected")
	errSmallModelNotSelected = errors.New("small model not selected")
	// errLargeModelProviderNotConfigured and
	// errSmallModelProviderNotConfigured are returned by buildAgentModels
	// when the provider of the selected model is not configured.
	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
	// errLargeModelNotFound and errSmallModelNotFound are returned by
	// buildAgentModels when the provider config does not list the selected
	// model ID.
	errLargeModelNotFound = errors.New("large model not found in provider config")
	errSmallModelNotFound = errors.New("small model not found in provider config")
)
59
// Coordinator runs prompts against the configured agent and exposes
// session-level controls (cancellation, prompt queue, summarization) as well
// as model refresh.
type Coordinator interface {
	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
	// SetMainAgent(string)

	// Run sends prompt, with optional attachments, to the agent for
	// sessionID and returns the agent's result.
	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
	// Cancel requests cancellation of the work for sessionID.
	Cancel(sessionID string)
	// CancelAll requests cancellation of the work for every session.
	CancelAll()
	// IsSessionBusy reports whether the agent is busy with sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether the agent is busy with any session.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are queued for sessionID.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompts queued for sessionID.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for sessionID.
	ClearQueue(sessionID string)
	// Summarize summarizes the session with the given ID.
	Summarize(context.Context, string) error
	// Model returns the model the agent currently uses.
	Model() Model
	// UpdateModels rebuilds the agent's models (and, in the default
	// implementation, its tools) from the current config.
	UpdateModels(ctx context.Context) error
}
75
// coordinator is the default Coordinator implementation. It owns the
// services the agent needs and forwards session-level calls to the current
// agent.
type coordinator struct {
	cfg         *config.ConfigStore
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager
	notify      pubsub.Publisher[notify.Notification]

	// currentAgent receives Run, Cancel, Summarize and the other
	// session-level calls.
	currentAgent SessionAgent
	// agents holds built agents keyed by agent name; NewCoordinator
	// registers the coder agent under config.AgentCoder.
	agents map[string]SessionAgent

	// readyWg tracks the asynchronous system-prompt and tool builds started
	// by buildAgent; Run waits on it before doing anything else.
	readyWg errgroup.Group
}
91
92func NewCoordinator(
93 ctx context.Context,
94 cfg *config.ConfigStore,
95 sessions session.Service,
96 messages message.Service,
97 permissions permission.Service,
98 history history.Service,
99 filetracker filetracker.Service,
100 lspManager *lsp.Manager,
101 notify pubsub.Publisher[notify.Notification],
102) (Coordinator, error) {
103 c := &coordinator{
104 cfg: cfg,
105 sessions: sessions,
106 messages: messages,
107 permissions: permissions,
108 history: history,
109 filetracker: filetracker,
110 lspManager: lspManager,
111 notify: notify,
112 agents: make(map[string]SessionAgent),
113 }
114
115 agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
116 if !ok {
117 return nil, errCoderAgentNotConfigured
118 }
119
120 // TODO: make this dynamic when we support multiple agents
121 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
122 if err != nil {
123 return nil, err
124 }
125
126 agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
127 if err != nil {
128 return nil, err
129 }
130 c.currentAgent = agent
131 c.agents[config.AgentCoder] = agent
132 return c, nil
133}
134
135// Run implements Coordinator.
136func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
137 if err := c.readyWg.Wait(); err != nil {
138 return nil, err
139 }
140
141 // refresh models before each run
142 if err := c.UpdateModels(ctx); err != nil {
143 return nil, fmt.Errorf("failed to update models: %w", err)
144 }
145
146 model := c.currentAgent.Model()
147 maxTokens := model.CatwalkCfg.DefaultMaxTokens
148 if model.ModelCfg.MaxTokens != 0 {
149 maxTokens = model.ModelCfg.MaxTokens
150 }
151
152 if !model.CatwalkCfg.SupportsImages && attachments != nil {
153 // filter out image attachments
154 filteredAttachments := make([]message.Attachment, 0, len(attachments))
155 for _, att := range attachments {
156 if att.IsText() {
157 filteredAttachments = append(filteredAttachments, att)
158 }
159 }
160 attachments = filteredAttachments
161 }
162
163 providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
164 if !ok {
165 return nil, errModelProviderNotConfigured
166 }
167
168 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
169
170 if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
171 slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
172 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
173 return nil, err
174 }
175 }
176
177 run := func() (*fantasy.AgentResult, error) {
178 return c.currentAgent.Run(ctx, SessionAgentCall{
179 SessionID: sessionID,
180 Prompt: prompt,
181 Attachments: attachments,
182 MaxOutputTokens: maxTokens,
183 ProviderOptions: mergedOptions,
184 Temperature: temp,
185 TopP: topP,
186 TopK: topK,
187 FrequencyPenalty: freqPenalty,
188 PresencePenalty: presPenalty,
189 })
190 }
191 result, originalErr := run()
192
193 if c.isUnauthorized(originalErr) {
194 switch {
195 case providerCfg.OAuthToken != nil:
196 slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
197 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
198 return nil, originalErr
199 }
200 slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
201 return run()
202 case strings.Contains(providerCfg.APIKeyTemplate, "$"):
203 slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
204 if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
205 return nil, originalErr
206 }
207 slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
208 return run()
209 }
210 }
211
212 return result, originalErr
213}
214
215func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
216 options := fantasy.ProviderOptions{}
217
218 cfgOpts := []byte("{}")
219 providerCfgOpts := []byte("{}")
220 catwalkOpts := []byte("{}")
221
222 if model.ModelCfg.ProviderOptions != nil {
223 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
224 if err == nil {
225 cfgOpts = data
226 }
227 }
228
229 if providerCfg.ProviderOptions != nil {
230 data, err := json.Marshal(providerCfg.ProviderOptions)
231 if err == nil {
232 providerCfgOpts = data
233 }
234 }
235
236 if model.CatwalkCfg.Options.ProviderOptions != nil {
237 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
238 if err == nil {
239 catwalkOpts = data
240 }
241 }
242
243 readers := []io.Reader{
244 bytes.NewReader(catwalkOpts),
245 bytes.NewReader(providerCfgOpts),
246 bytes.NewReader(cfgOpts),
247 }
248
249 got, err := jsons.Merge(readers)
250 if err != nil {
251 slog.Error("Could not merge call config", "err", err)
252 return options
253 }
254
255 mergedOptions := make(map[string]any)
256
257 err = json.Unmarshal([]byte(got), &mergedOptions)
258 if err != nil {
259 slog.Error("Could not create config for call", "err", err)
260 return options
261 }
262
263 providerType := providerCfg.Type
264 if providerType == "hyper" {
265 if strings.Contains(model.CatwalkCfg.ID, "claude") {
266 providerType = anthropic.Name
267 } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
268 providerType = openai.Name
269 } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
270 providerType = google.Name
271 } else {
272 providerType = openaicompat.Name
273 }
274 }
275
276 switch providerType {
277 case openai.Name, azure.Name:
278 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
279 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
280 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
281 }
282 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
283 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
284 mergedOptions["reasoning_summary"] = "auto"
285 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
286 }
287 parsed, err := openai.ParseResponsesOptions(mergedOptions)
288 if err == nil {
289 options[openai.Name] = parsed
290 }
291 } else {
292 parsed, err := openai.ParseOptions(mergedOptions)
293 if err == nil {
294 options[openai.Name] = parsed
295 }
296 }
297 case anthropic.Name:
298 var (
299 _, hasEffort = mergedOptions["effort"]
300 _, hasThink = mergedOptions["thinking"]
301 )
302 switch {
303 case !hasEffort && model.ModelCfg.ReasoningEffort != "":
304 mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
305 case !hasThink && model.ModelCfg.Think:
306 mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
307 }
308 parsed, err := anthropic.ParseOptions(mergedOptions)
309 if err == nil {
310 options[anthropic.Name] = parsed
311 }
312
313 case openrouter.Name:
314 _, hasReasoning := mergedOptions["reasoning"]
315 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
316 mergedOptions["reasoning"] = map[string]any{
317 "enabled": true,
318 "effort": model.ModelCfg.ReasoningEffort,
319 }
320 }
321 parsed, err := openrouter.ParseOptions(mergedOptions)
322 if err == nil {
323 options[openrouter.Name] = parsed
324 }
325 case vercel.Name:
326 _, hasReasoning := mergedOptions["reasoning"]
327 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
328 mergedOptions["reasoning"] = map[string]any{
329 "enabled": true,
330 "effort": model.ModelCfg.ReasoningEffort,
331 }
332 }
333 parsed, err := vercel.ParseOptions(mergedOptions)
334 if err == nil {
335 options[vercel.Name] = parsed
336 }
337 case google.Name:
338 _, hasReasoning := mergedOptions["thinking_config"]
339 if !hasReasoning {
340 if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
341 mergedOptions["thinking_config"] = map[string]any{
342 "thinking_budget": 2000,
343 "include_thoughts": true,
344 }
345 } else {
346 mergedOptions["thinking_config"] = map[string]any{
347 "thinking_level": model.ModelCfg.ReasoningEffort,
348 "include_thoughts": true,
349 }
350 }
351 }
352 parsed, err := google.ParseOptions(mergedOptions)
353 if err == nil {
354 options[google.Name] = parsed
355 }
356 case openaicompat.Name:
357 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
358 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
359 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
360 }
361 parsed, err := openaicompat.ParseOptions(mergedOptions)
362 if err == nil {
363 options[openaicompat.Name] = parsed
364 }
365 }
366
367 return options
368}
369
370func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
371 modelOptions := getProviderOptions(model, cfg)
372 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
373 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
374 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
375 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
376 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
377 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
378}
379
380func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
381 large, small, err := c.buildAgentModels(ctx, isSubAgent)
382 if err != nil {
383 return nil, err
384 }
385
386 largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
387 result := NewSessionAgent(SessionAgentOptions{
388 LargeModel: large,
389 SmallModel: small,
390 SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix,
391 SystemPrompt: "",
392 IsSubAgent: isSubAgent,
393 DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
394 IsYolo: c.permissions.SkipRequests(),
395 Sessions: c.sessions,
396 Messages: c.messages,
397 Tools: nil,
398 Notify: c.notify,
399 })
400
401 c.readyWg.Go(func() error {
402 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
403 if err != nil {
404 return err
405 }
406 result.SetSystemPrompt(systemPrompt)
407 return nil
408 })
409
410 c.readyWg.Go(func() error {
411 tools, err := c.buildTools(ctx, agent)
412 if err != nil {
413 return err
414 }
415 result.SetTools(tools)
416 return nil
417 })
418
419 return result, nil
420}
421
422func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
423 var allTools []fantasy.AgentTool
424 if slices.Contains(agent.AllowedTools, AgentToolName) {
425 agentTool, err := c.agentTool(ctx)
426 if err != nil {
427 return nil, err
428 }
429 allTools = append(allTools, agentTool)
430 }
431
432 if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
433 agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
434 if err != nil {
435 return nil, err
436 }
437 allTools = append(allTools, agenticFetchTool)
438 }
439
440 // Get the model name for the agent
441 modelName := ""
442 if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
443 if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
444 modelName = model.Name
445 }
446 }
447
448 allTools = append(allTools,
449 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
450 tools.NewCrushInfoTool(c.cfg, c.lspManager),
451 tools.NewJobOutputTool(),
452 tools.NewJobKillTool(),
453 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
454 tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
455 tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
456 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
457 tools.NewGlobTool(c.cfg.WorkingDir()),
458 tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
459 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
460 tools.NewSourcegraphTool(nil),
461 tools.NewTodosTool(c.sessions),
462 tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
463 tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
464 )
465
466 // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
467 if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
468 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
469 }
470
471 if len(c.cfg.Config().MCP) > 0 {
472 allTools = append(
473 allTools,
474 tools.NewListMCPResourcesTool(c.cfg, c.permissions),
475 tools.NewReadMCPResourceTool(c.cfg, c.permissions),
476 )
477 }
478
479 var filteredTools []fantasy.AgentTool
480 for _, tool := range allTools {
481 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
482 filteredTools = append(filteredTools, tool)
483 }
484 }
485
486 for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
487 if agent.AllowedMCP == nil {
488 // No MCP restrictions
489 filteredTools = append(filteredTools, tool)
490 continue
491 }
492 if len(agent.AllowedMCP) == 0 {
493 // No MCPs allowed
494 slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
495 break
496 }
497
498 for mcp, tools := range agent.AllowedMCP {
499 if mcp != tool.MCP() {
500 continue
501 }
502 if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
503 filteredTools = append(filteredTools, tool)
504 break
505 }
506 slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
507 }
508 }
509 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
510 return strings.Compare(a.Info().Name, b.Info().Name)
511 })
512 return filteredTools, nil
513}
514
515// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
516func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
517 largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
518 if !ok {
519 return Model{}, Model{}, errLargeModelNotSelected
520 }
521 smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
522 if !ok {
523 return Model{}, Model{}, errSmallModelNotSelected
524 }
525
526 largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
527 if !ok {
528 return Model{}, Model{}, errLargeModelProviderNotConfigured
529 }
530
531 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
532 if err != nil {
533 return Model{}, Model{}, err
534 }
535
536 smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
537 if !ok {
538 return Model{}, Model{}, errSmallModelProviderNotConfigured
539 }
540
541 smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
542 if err != nil {
543 return Model{}, Model{}, err
544 }
545
546 var largeCatwalkModel *catwalk.Model
547 var smallCatwalkModel *catwalk.Model
548
549 for _, m := range largeProviderCfg.Models {
550 if m.ID == largeModelCfg.Model {
551 largeCatwalkModel = &m
552 }
553 }
554 for _, m := range smallProviderCfg.Models {
555 if m.ID == smallModelCfg.Model {
556 smallCatwalkModel = &m
557 }
558 }
559
560 if largeCatwalkModel == nil {
561 return Model{}, Model{}, errLargeModelNotFound
562 }
563
564 if smallCatwalkModel == nil {
565 return Model{}, Model{}, errSmallModelNotFound
566 }
567
568 largeModelID := largeModelCfg.Model
569 smallModelID := smallModelCfg.Model
570
571 if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
572 largeModelID += ":exacto"
573 }
574
575 if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
576 smallModelID += ":exacto"
577 }
578
579 largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
580 if err != nil {
581 return Model{}, Model{}, err
582 }
583 smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
584 if err != nil {
585 return Model{}, Model{}, err
586 }
587
588 return Model{
589 Model: largeModel,
590 CatwalkCfg: *largeCatwalkModel,
591 ModelCfg: largeModelCfg,
592 }, Model{
593 Model: smallModel,
594 CatwalkCfg: *smallCatwalkModel,
595 ModelCfg: smallModelCfg,
596 }, nil
597}
598
599func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
600 var opts []anthropic.Option
601
602 switch {
603 case strings.HasPrefix(apiKey, "Bearer "):
604 // NOTE: Prevent the SDK from picking up the API key from env.
605 os.Setenv("ANTHROPIC_API_KEY", "")
606 headers["Authorization"] = apiKey
607 case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
608 // NOTE: Prevent the SDK from picking up the API key from env.
609 os.Setenv("ANTHROPIC_API_KEY", "")
610 headers["Authorization"] = "Bearer " + apiKey
611 case apiKey != "":
612 // X-Api-Key header
613 opts = append(opts, anthropic.WithAPIKey(apiKey))
614 }
615
616 if len(headers) > 0 {
617 opts = append(opts, anthropic.WithHeaders(headers))
618 }
619
620 if baseURL != "" {
621 opts = append(opts, anthropic.WithBaseURL(baseURL))
622 }
623
624 if c.cfg.Config().Options.Debug {
625 httpClient := log.NewHTTPClient()
626 opts = append(opts, anthropic.WithHTTPClient(httpClient))
627 }
628 return anthropic.New(opts...)
629}
630
631func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
632 opts := []openai.Option{
633 openai.WithAPIKey(apiKey),
634 openai.WithUseResponsesAPI(),
635 }
636 if c.cfg.Config().Options.Debug {
637 httpClient := log.NewHTTPClient()
638 opts = append(opts, openai.WithHTTPClient(httpClient))
639 }
640 if len(headers) > 0 {
641 opts = append(opts, openai.WithHeaders(headers))
642 }
643 if baseURL != "" {
644 opts = append(opts, openai.WithBaseURL(baseURL))
645 }
646 return openai.New(opts...)
647}
648
649func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
650 opts := []openrouter.Option{
651 openrouter.WithAPIKey(apiKey),
652 }
653 if c.cfg.Config().Options.Debug {
654 httpClient := log.NewHTTPClient()
655 opts = append(opts, openrouter.WithHTTPClient(httpClient))
656 }
657 if len(headers) > 0 {
658 opts = append(opts, openrouter.WithHeaders(headers))
659 }
660 return openrouter.New(opts...)
661}
662
663func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
664 opts := []vercel.Option{
665 vercel.WithAPIKey(apiKey),
666 }
667 if c.cfg.Config().Options.Debug {
668 httpClient := log.NewHTTPClient()
669 opts = append(opts, vercel.WithHTTPClient(httpClient))
670 }
671 if len(headers) > 0 {
672 opts = append(opts, vercel.WithHeaders(headers))
673 }
674 return vercel.New(opts...)
675}
676
677func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
678 opts := []openaicompat.Option{
679 openaicompat.WithBaseURL(baseURL),
680 openaicompat.WithAPIKey(apiKey),
681 }
682
683 // Set HTTP client based on provider and debug mode.
684 var httpClient *http.Client
685 if providerID == string(catwalk.InferenceProviderCopilot) {
686 opts = append(opts, openaicompat.WithUseResponsesAPI())
687 httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
688 } else if c.cfg.Config().Options.Debug {
689 httpClient = log.NewHTTPClient()
690 }
691 if httpClient != nil {
692 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
693 }
694
695 if len(headers) > 0 {
696 opts = append(opts, openaicompat.WithHeaders(headers))
697 }
698
699 for extraKey, extraValue := range extraBody {
700 opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
701 }
702
703 return openaicompat.New(opts...)
704}
705
706func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
707 opts := []azure.Option{
708 azure.WithBaseURL(baseURL),
709 azure.WithAPIKey(apiKey),
710 azure.WithUseResponsesAPI(),
711 }
712 if c.cfg.Config().Options.Debug {
713 httpClient := log.NewHTTPClient()
714 opts = append(opts, azure.WithHTTPClient(httpClient))
715 }
716 if options == nil {
717 options = make(map[string]string)
718 }
719 if apiVersion, ok := options["apiVersion"]; ok {
720 opts = append(opts, azure.WithAPIVersion(apiVersion))
721 }
722 if len(headers) > 0 {
723 opts = append(opts, azure.WithHeaders(headers))
724 }
725
726 return azure.New(opts...)
727}
728
729func (c *coordinator) buildBedrockProvider(apiKey string, headers map[string]string) (fantasy.Provider, error) {
730 var opts []bedrock.Option
731 if c.cfg.Config().Options.Debug {
732 httpClient := log.NewHTTPClient()
733 opts = append(opts, bedrock.WithHTTPClient(httpClient))
734 }
735 if len(headers) > 0 {
736 opts = append(opts, bedrock.WithHeaders(headers))
737 }
738 switch {
739 case apiKey != "":
740 opts = append(opts, bedrock.WithAPIKey(apiKey))
741 case os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "":
742 opts = append(opts, bedrock.WithAPIKey(os.Getenv("AWS_BEARER_TOKEN_BEDROCK")))
743 default:
744 // Skip, let the SDK do authentication.
745 }
746 return bedrock.New(opts...)
747}
748
749func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
750 opts := []google.Option{
751 google.WithBaseURL(baseURL),
752 google.WithGeminiAPIKey(apiKey),
753 }
754 if c.cfg.Config().Options.Debug {
755 httpClient := log.NewHTTPClient()
756 opts = append(opts, google.WithHTTPClient(httpClient))
757 }
758 if len(headers) > 0 {
759 opts = append(opts, google.WithHeaders(headers))
760 }
761 return google.New(opts...)
762}
763
764func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
765 opts := []google.Option{}
766 if c.cfg.Config().Options.Debug {
767 httpClient := log.NewHTTPClient()
768 opts = append(opts, google.WithHTTPClient(httpClient))
769 }
770 if len(headers) > 0 {
771 opts = append(opts, google.WithHeaders(headers))
772 }
773
774 project := options["project"]
775 location := options["location"]
776
777 opts = append(opts, google.WithVertex(project, location))
778
779 return google.New(opts...)
780}
781
782func (c *coordinator) buildHyperProvider(apiKey string) (fantasy.Provider, error) {
783 opts := []hyper.Option{
784 hyper.WithAPIKey(apiKey),
785 }
786 if c.cfg.Config().Options.Debug {
787 httpClient := log.NewHTTPClient()
788 opts = append(opts, hyper.WithHTTPClient(httpClient))
789 }
790 return hyper.New(opts...)
791}
792
793func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
794 if model.Think {
795 return true
796 }
797 opts, err := anthropic.ParseOptions(model.ProviderOptions)
798 return err == nil && opts.Thinking != nil
799}
800
801func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
802 headers := maps.Clone(providerCfg.ExtraHeaders)
803 if headers == nil {
804 headers = make(map[string]string)
805 }
806
807 // handle special headers for anthropic
808 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
809 if v, ok := headers["anthropic-beta"]; ok {
810 headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
811 } else {
812 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
813 }
814 }
815
816 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
817 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
818
819 switch providerCfg.Type {
820 case openai.Name:
821 return c.buildOpenaiProvider(baseURL, apiKey, headers)
822 case anthropic.Name:
823 return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
824 case openrouter.Name:
825 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
826 case vercel.Name:
827 return c.buildVercelProvider(baseURL, apiKey, headers)
828 case azure.Name:
829 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
830 case bedrock.Name:
831 return c.buildBedrockProvider(apiKey, headers)
832 case google.Name:
833 return c.buildGoogleProvider(baseURL, apiKey, headers)
834 case "google-vertex":
835 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
836 case openaicompat.Name:
837 if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
838 if providerCfg.ExtraBody == nil {
839 providerCfg.ExtraBody = map[string]any{}
840 }
841 providerCfg.ExtraBody["tool_stream"] = true
842 }
843 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
844 case hyper.Name:
845 return c.buildHyperProvider(apiKey)
846 default:
847 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
848 }
849}
850
// isExactoSupported reports whether OpenRouter should be asked for the
// ":exacto" variant of modelID. Matching is exact and case-sensitive against
// a fixed list of model IDs.
func isExactoSupported(modelID string) bool {
	exactoModels := [...]string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}
	return slices.Index(exactoModels[:], modelID) != -1
}
861
862func (c *coordinator) Cancel(sessionID string) {
863 c.currentAgent.Cancel(sessionID)
864}
865
866func (c *coordinator) CancelAll() {
867 c.currentAgent.CancelAll()
868}
869
870func (c *coordinator) ClearQueue(sessionID string) {
871 c.currentAgent.ClearQueue(sessionID)
872}
873
874func (c *coordinator) IsBusy() bool {
875 return c.currentAgent.IsBusy()
876}
877
878func (c *coordinator) IsSessionBusy(sessionID string) bool {
879 return c.currentAgent.IsSessionBusy(sessionID)
880}
881
882func (c *coordinator) Model() Model {
883 return c.currentAgent.Model()
884}
885
886func (c *coordinator) UpdateModels(ctx context.Context) error {
887 // build the models again so we make sure we get the latest config
888 large, small, err := c.buildAgentModels(ctx, false)
889 if err != nil {
890 return err
891 }
892 c.currentAgent.SetModels(large, small)
893
894 agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
895 if !ok {
896 return errCoderAgentNotConfigured
897 }
898
899 tools, err := c.buildTools(ctx, agentCfg)
900 if err != nil {
901 return err
902 }
903 c.currentAgent.SetTools(tools)
904 return nil
905}
906
907func (c *coordinator) QueuedPrompts(sessionID string) int {
908 return c.currentAgent.QueuedPrompts(sessionID)
909}
910
911func (c *coordinator) QueuedPromptsList(sessionID string) []string {
912 return c.currentAgent.QueuedPromptsList(sessionID)
913}
914
915func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
916 providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
917 if !ok {
918 return errModelProviderNotConfigured
919 }
920 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
921}
922
923func (c *coordinator) isUnauthorized(err error) bool {
924 var providerErr *fantasy.ProviderError
925 return (errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized) ||
926 errors.Is(err, hyper.ErrUnauthorized)
927}
928
929func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
930 if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
931 slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
932 return err
933 }
934 if err := c.UpdateModels(ctx); err != nil {
935 return err
936 }
937 return nil
938}
939
940func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
941 newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
942 if err != nil {
943 slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
944 return err
945 }
946
947 providerCfg.APIKey = newAPIKey
948 c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
949
950 if err := c.UpdateModels(ctx); err != nil {
951 return err
952 }
953 return nil
954}
955
// subAgentParams holds the parameters for running a sub-agent.
type subAgentParams struct {
	// Agent is the sub-agent to run.
	Agent SessionAgent
	// SessionID is the parent session. The sub-session is created as a task
	// session under it, and the sub-session's cost is added to it.
	SessionID string
	// AgentMessageID and ToolCallID are combined by
	// sessions.CreateAgentToolSessionID to derive the sub-session ID.
	AgentMessageID string
	ToolCallID     string
	// Prompt is the prompt sent to the sub-agent.
	Prompt string
	// SessionTitle is the title given to the created sub-session.
	SessionTitle string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
968
969// runSubAgent runs a sub-agent and handles session management and cost accumulation.
970// It creates a sub-session, runs the agent with the given prompt, and propagates
971// the cost to the parent session.
972func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
973 // Create sub-session
974 agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
975 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
976 if err != nil {
977 return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
978 }
979
980 // Call session setup function if provided
981 if params.SessionSetup != nil {
982 params.SessionSetup(session.ID)
983 }
984
985 // Get model configuration
986 model := params.Agent.Model()
987 maxTokens := model.CatwalkCfg.DefaultMaxTokens
988 if model.ModelCfg.MaxTokens != 0 {
989 maxTokens = model.ModelCfg.MaxTokens
990 }
991
992 providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
993 if !ok {
994 return fantasy.ToolResponse{}, errModelProviderNotConfigured
995 }
996
997 // Run the agent
998 result, err := params.Agent.Run(ctx, SessionAgentCall{
999 SessionID: session.ID,
1000 Prompt: params.Prompt,
1001 MaxOutputTokens: maxTokens,
1002 ProviderOptions: getProviderOptions(model, providerCfg),
1003 Temperature: model.ModelCfg.Temperature,
1004 TopP: model.ModelCfg.TopP,
1005 TopK: model.ModelCfg.TopK,
1006 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
1007 PresencePenalty: model.ModelCfg.PresencePenalty,
1008 NonInteractive: true,
1009 })
1010 if err != nil {
1011 return fantasy.NewTextErrorResponse("error generating response"), nil
1012 }
1013
1014 // Update parent session cost
1015 if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1016 return fantasy.ToolResponse{}, err
1017 }
1018
1019 return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1020}
1021
1022// updateParentSessionCost accumulates the cost from a child session to its parent session.
1023func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1024 childSession, err := c.sessions.Get(ctx, childSessionID)
1025 if err != nil {
1026 return fmt.Errorf("get child session: %w", err)
1027 }
1028
1029 parentSession, err := c.sessions.Get(ctx, parentSessionID)
1030 if err != nil {
1031 return fmt.Errorf("get parent session: %w", err)
1032 }
1033
1034 parentSession.Cost += childSession.Cost
1035
1036 if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1037 return fmt.Errorf("save parent session: %w", err)
1038 }
1039
1040 return nil
1041}