1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "path/filepath"
16 "slices"
17 "strings"
18
19 "charm.land/catwalk/pkg/catwalk"
20 "charm.land/fantasy"
21 "github.com/charmbracelet/crush/internal/agent/hyper"
22 "github.com/charmbracelet/crush/internal/agent/notify"
23 "github.com/charmbracelet/crush/internal/agent/prompt"
24 "github.com/charmbracelet/crush/internal/agent/tools"
25 "github.com/charmbracelet/crush/internal/config"
26 "github.com/charmbracelet/crush/internal/filetracker"
27 "github.com/charmbracelet/crush/internal/history"
28 "github.com/charmbracelet/crush/internal/log"
29 "github.com/charmbracelet/crush/internal/lsp"
30 "github.com/charmbracelet/crush/internal/message"
31 "github.com/charmbracelet/crush/internal/oauth/copilot"
32 "github.com/charmbracelet/crush/internal/permission"
33 "github.com/charmbracelet/crush/internal/pubsub"
34 "github.com/charmbracelet/crush/internal/session"
35 "golang.org/x/sync/errgroup"
36
37 "charm.land/fantasy/providers/anthropic"
38 "charm.land/fantasy/providers/azure"
39 "charm.land/fantasy/providers/bedrock"
40 "charm.land/fantasy/providers/google"
41 "charm.land/fantasy/providers/openai"
42 "charm.land/fantasy/providers/openaicompat"
43 "charm.land/fantasy/providers/openrouter"
44 "charm.land/fantasy/providers/vercel"
45 openaisdk "github.com/charmbracelet/openai-go/option"
46 "github.com/qjebbs/go-jsons"
47)
48
// Coordinator errors.
//
// These sentinel errors are returned while the coordinator resolves the coder
// agent, the selected large/small models, and their providers from the
// configuration.
var (
	// Agent and generic provider lookups.
	errCoderAgentNotConfigured    = errors.New("coder agent not configured")
	errModelProviderNotConfigured = errors.New("model provider not configured")

	// Large model resolution.
	errLargeModelNotSelected           = errors.New("large model not selected")
	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
	errLargeModelNotFound              = errors.New("large model not found in provider config")

	// Small model resolution.
	errSmallModelNotSelected           = errors.New("small model not selected")
	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
	errSmallModelNotFound              = errors.New("small model not found in provider config")
)
60
61type Coordinator interface {
62 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
63 // SetMainAgent(string)
64 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
65 Cancel(sessionID string)
66 CancelAll()
67 IsSessionBusy(sessionID string) bool
68 IsBusy() bool
69 QueuedPrompts(sessionID string) int
70 QueuedPromptsList(sessionID string) []string
71 ClearQueue(sessionID string)
72 Summarize(context.Context, string) error
73 Model() Model
74 UpdateModels(ctx context.Context) error
75}
76
// coordinator is the default Coordinator implementation. It holds the
// services that agents and tools are built from, plus the agents it built.
type coordinator struct {
	cfg         *config.ConfigStore
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager
	notify      pubsub.Publisher[notify.Notification]

	// currentAgent is the agent that the Coordinator methods delegate to.
	currentAgent SessionAgent

	// agents holds the built agents keyed by agent name; currently only
	// config.AgentCoder is registered (see NewCoordinator).
	agents map[string]SessionAgent

	// readyWg tracks the background setup (system prompt and tools) started
	// by buildAgent; Run waits on it before executing a prompt.
	readyWg errgroup.Group
}
92
93func NewCoordinator(
94 ctx context.Context,
95 cfg *config.ConfigStore,
96 sessions session.Service,
97 messages message.Service,
98 permissions permission.Service,
99 history history.Service,
100 filetracker filetracker.Service,
101 lspManager *lsp.Manager,
102 notify pubsub.Publisher[notify.Notification],
103) (Coordinator, error) {
104 c := &coordinator{
105 cfg: cfg,
106 sessions: sessions,
107 messages: messages,
108 permissions: permissions,
109 history: history,
110 filetracker: filetracker,
111 lspManager: lspManager,
112 notify: notify,
113 agents: make(map[string]SessionAgent),
114 }
115
116 agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
117 if !ok {
118 return nil, errCoderAgentNotConfigured
119 }
120
121 // TODO: make this dynamic when we support multiple agents
122 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
123 if err != nil {
124 return nil, err
125 }
126
127 agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
128 if err != nil {
129 return nil, err
130 }
131 c.currentAgent = agent
132 c.agents[config.AgentCoder] = agent
133 return c, nil
134}
135
136// Run implements Coordinator.
137func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
138 if err := c.readyWg.Wait(); err != nil {
139 return nil, err
140 }
141
142 // refresh models before each run
143 if err := c.UpdateModels(ctx); err != nil {
144 return nil, fmt.Errorf("failed to update models: %w", err)
145 }
146
147 model := c.currentAgent.Model()
148 maxTokens := model.CatwalkCfg.DefaultMaxTokens
149 if model.ModelCfg.MaxTokens != 0 {
150 maxTokens = model.ModelCfg.MaxTokens
151 }
152
153 if !model.CatwalkCfg.SupportsImages && attachments != nil {
154 // filter out image attachments
155 filteredAttachments := make([]message.Attachment, 0, len(attachments))
156 for _, att := range attachments {
157 if att.IsText() {
158 filteredAttachments = append(filteredAttachments, att)
159 }
160 }
161 attachments = filteredAttachments
162 }
163
164 providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
165 if !ok {
166 return nil, errModelProviderNotConfigured
167 }
168
169 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
170
171 if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
172 slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
173 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
174 return nil, err
175 }
176 }
177
178 run := func() (*fantasy.AgentResult, error) {
179 return c.currentAgent.Run(ctx, SessionAgentCall{
180 SessionID: sessionID,
181 Prompt: prompt,
182 Attachments: attachments,
183 MaxOutputTokens: maxTokens,
184 ProviderOptions: mergedOptions,
185 Temperature: temp,
186 TopP: topP,
187 TopK: topK,
188 FrequencyPenalty: freqPenalty,
189 PresencePenalty: presPenalty,
190 })
191 }
192 result, originalErr := run()
193
194 if c.isUnauthorized(originalErr) {
195 switch {
196 case providerCfg.OAuthToken != nil:
197 slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
198 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
199 return nil, originalErr
200 }
201 slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
202 return run()
203 case strings.Contains(providerCfg.APIKeyTemplate, "$"):
204 slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
205 if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
206 return nil, originalErr
207 }
208 slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
209 return run()
210 }
211 }
212
213 return result, originalErr
214}
215
216func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
217 options := fantasy.ProviderOptions{}
218
219 cfgOpts := []byte("{}")
220 providerCfgOpts := []byte("{}")
221 catwalkOpts := []byte("{}")
222
223 if model.ModelCfg.ProviderOptions != nil {
224 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
225 if err == nil {
226 cfgOpts = data
227 }
228 }
229
230 if providerCfg.ProviderOptions != nil {
231 data, err := json.Marshal(providerCfg.ProviderOptions)
232 if err == nil {
233 providerCfgOpts = data
234 }
235 }
236
237 if model.CatwalkCfg.Options.ProviderOptions != nil {
238 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
239 if err == nil {
240 catwalkOpts = data
241 }
242 }
243
244 readers := []io.Reader{
245 bytes.NewReader(catwalkOpts),
246 bytes.NewReader(providerCfgOpts),
247 bytes.NewReader(cfgOpts),
248 }
249
250 got, err := jsons.Merge(readers)
251 if err != nil {
252 slog.Error("Could not merge call config", "err", err)
253 return options
254 }
255
256 mergedOptions := make(map[string]any)
257
258 err = json.Unmarshal([]byte(got), &mergedOptions)
259 if err != nil {
260 slog.Error("Could not create config for call", "err", err)
261 return options
262 }
263
264 providerType := providerCfg.Type
265 if providerType == "hyper" {
266 if strings.Contains(model.CatwalkCfg.ID, "claude") {
267 providerType = anthropic.Name
268 } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
269 providerType = openai.Name
270 } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
271 providerType = google.Name
272 } else {
273 providerType = openaicompat.Name
274 }
275 }
276
277 switch providerType {
278 case openai.Name, azure.Name:
279 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
280 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
281 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
282 }
283 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
284 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
285 mergedOptions["reasoning_summary"] = "auto"
286 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
287 }
288 parsed, err := openai.ParseResponsesOptions(mergedOptions)
289 if err == nil {
290 options[openai.Name] = parsed
291 }
292 } else {
293 parsed, err := openai.ParseOptions(mergedOptions)
294 if err == nil {
295 options[openai.Name] = parsed
296 }
297 }
298 case anthropic.Name:
299 var (
300 _, hasEffort = mergedOptions["effort"]
301 _, hasThink = mergedOptions["thinking"]
302 )
303 switch {
304 case !hasEffort && model.ModelCfg.ReasoningEffort != "":
305 mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
306 case !hasThink && model.ModelCfg.Think:
307 mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
308 }
309 parsed, err := anthropic.ParseOptions(mergedOptions)
310 if err == nil {
311 options[anthropic.Name] = parsed
312 }
313
314 case openrouter.Name:
315 _, hasReasoning := mergedOptions["reasoning"]
316 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
317 mergedOptions["reasoning"] = map[string]any{
318 "enabled": true,
319 "effort": model.ModelCfg.ReasoningEffort,
320 }
321 }
322 parsed, err := openrouter.ParseOptions(mergedOptions)
323 if err == nil {
324 options[openrouter.Name] = parsed
325 }
326 case vercel.Name:
327 _, hasReasoning := mergedOptions["reasoning"]
328 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
329 mergedOptions["reasoning"] = map[string]any{
330 "enabled": true,
331 "effort": model.ModelCfg.ReasoningEffort,
332 }
333 }
334 parsed, err := vercel.ParseOptions(mergedOptions)
335 if err == nil {
336 options[vercel.Name] = parsed
337 }
338 case google.Name:
339 _, hasReasoning := mergedOptions["thinking_config"]
340 if !hasReasoning {
341 if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
342 mergedOptions["thinking_config"] = map[string]any{
343 "thinking_budget": 2000,
344 "include_thoughts": true,
345 }
346 } else {
347 mergedOptions["thinking_config"] = map[string]any{
348 "thinking_level": model.ModelCfg.ReasoningEffort,
349 "include_thoughts": true,
350 }
351 }
352 }
353 parsed, err := google.ParseOptions(mergedOptions)
354 if err == nil {
355 options[google.Name] = parsed
356 }
357 case openaicompat.Name:
358 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
359 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
360 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
361 }
362 parsed, err := openaicompat.ParseOptions(mergedOptions)
363 if err == nil {
364 options[openaicompat.Name] = parsed
365 }
366 }
367
368 return options
369}
370
371func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
372 modelOptions := getProviderOptions(model, cfg)
373 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
374 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
375 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
376 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
377 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
378 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
379}
380
381func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
382 large, small, err := c.buildAgentModels(ctx, isSubAgent)
383 if err != nil {
384 return nil, err
385 }
386
387 largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
388 result := NewSessionAgent(SessionAgentOptions{
389 LargeModel: large,
390 SmallModel: small,
391 SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix,
392 SystemPrompt: "",
393 IsSubAgent: isSubAgent,
394 DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
395 IsYolo: c.permissions.SkipRequests(),
396 Sessions: c.sessions,
397 Messages: c.messages,
398 Tools: nil,
399 Notify: c.notify,
400 })
401
402 c.readyWg.Go(func() error {
403 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
404 if err != nil {
405 return err
406 }
407 result.SetSystemPrompt(systemPrompt)
408 return nil
409 })
410
411 c.readyWg.Go(func() error {
412 tools, err := c.buildTools(ctx, agent)
413 if err != nil {
414 return err
415 }
416 result.SetTools(tools)
417 return nil
418 })
419
420 return result, nil
421}
422
423func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
424 var allTools []fantasy.AgentTool
425 if slices.Contains(agent.AllowedTools, AgentToolName) {
426 agentTool, err := c.agentTool(ctx)
427 if err != nil {
428 return nil, err
429 }
430 allTools = append(allTools, agentTool)
431 }
432
433 if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
434 agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
435 if err != nil {
436 return nil, err
437 }
438 allTools = append(allTools, agenticFetchTool)
439 }
440
441 // Get the model name for the agent
442 modelName := ""
443 if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
444 if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
445 modelName = model.Name
446 }
447 }
448
449 logFile := filepath.Join(c.cfg.Config().Options.DataDirectory, "logs", "crush.log")
450
451 allTools = append(allTools,
452 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
453 tools.NewCrushInfoTool(c.cfg, c.lspManager),
454 tools.NewCrushLogsTool(logFile),
455 tools.NewJobOutputTool(),
456 tools.NewJobKillTool(),
457 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
458 tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
459 tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
460 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
461 tools.NewGlobTool(c.cfg.WorkingDir()),
462 tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
463 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
464 tools.NewSourcegraphTool(nil),
465 tools.NewTodosTool(c.sessions),
466 tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
467 tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
468 )
469
470 // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
471 if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
472 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
473 }
474
475 if len(c.cfg.Config().MCP) > 0 {
476 allTools = append(
477 allTools,
478 tools.NewListMCPResourcesTool(c.cfg, c.permissions),
479 tools.NewReadMCPResourceTool(c.cfg, c.permissions),
480 )
481 }
482
483 var filteredTools []fantasy.AgentTool
484 for _, tool := range allTools {
485 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
486 filteredTools = append(filteredTools, tool)
487 }
488 }
489
490 for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
491 if agent.AllowedMCP == nil {
492 // No MCP restrictions
493 filteredTools = append(filteredTools, tool)
494 continue
495 }
496 if len(agent.AllowedMCP) == 0 {
497 // No MCPs allowed
498 slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
499 break
500 }
501
502 for mcp, tools := range agent.AllowedMCP {
503 if mcp != tool.MCP() {
504 continue
505 }
506 if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
507 filteredTools = append(filteredTools, tool)
508 break
509 }
510 slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
511 }
512 }
513 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
514 return strings.Compare(a.Info().Name, b.Info().Name)
515 })
516 return filteredTools, nil
517}
518
519// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
520func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
521 largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
522 if !ok {
523 return Model{}, Model{}, errLargeModelNotSelected
524 }
525 smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
526 if !ok {
527 return Model{}, Model{}, errSmallModelNotSelected
528 }
529
530 largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
531 if !ok {
532 return Model{}, Model{}, errLargeModelProviderNotConfigured
533 }
534
535 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
536 if err != nil {
537 return Model{}, Model{}, err
538 }
539
540 smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
541 if !ok {
542 return Model{}, Model{}, errSmallModelProviderNotConfigured
543 }
544
545 smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
546 if err != nil {
547 return Model{}, Model{}, err
548 }
549
550 var largeCatwalkModel *catwalk.Model
551 var smallCatwalkModel *catwalk.Model
552
553 for _, m := range largeProviderCfg.Models {
554 if m.ID == largeModelCfg.Model {
555 largeCatwalkModel = &m
556 }
557 }
558 for _, m := range smallProviderCfg.Models {
559 if m.ID == smallModelCfg.Model {
560 smallCatwalkModel = &m
561 }
562 }
563
564 if largeCatwalkModel == nil {
565 return Model{}, Model{}, errLargeModelNotFound
566 }
567
568 if smallCatwalkModel == nil {
569 return Model{}, Model{}, errSmallModelNotFound
570 }
571
572 largeModelID := largeModelCfg.Model
573 smallModelID := smallModelCfg.Model
574
575 if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
576 largeModelID += ":exacto"
577 }
578
579 if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
580 smallModelID += ":exacto"
581 }
582
583 largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
584 if err != nil {
585 return Model{}, Model{}, err
586 }
587 smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
588 if err != nil {
589 return Model{}, Model{}, err
590 }
591
592 return Model{
593 Model: largeModel,
594 CatwalkCfg: *largeCatwalkModel,
595 ModelCfg: largeModelCfg,
596 }, Model{
597 Model: smallModel,
598 CatwalkCfg: *smallCatwalkModel,
599 ModelCfg: smallModelCfg,
600 }, nil
601}
602
603func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
604 var opts []anthropic.Option
605
606 switch {
607 case strings.HasPrefix(apiKey, "Bearer "):
608 // NOTE: Prevent the SDK from picking up the API key from env.
609 os.Setenv("ANTHROPIC_API_KEY", "")
610 headers["Authorization"] = apiKey
611 case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
612 // NOTE: Prevent the SDK from picking up the API key from env.
613 os.Setenv("ANTHROPIC_API_KEY", "")
614 headers["Authorization"] = "Bearer " + apiKey
615 case apiKey != "":
616 // X-Api-Key header
617 opts = append(opts, anthropic.WithAPIKey(apiKey))
618 }
619
620 if len(headers) > 0 {
621 opts = append(opts, anthropic.WithHeaders(headers))
622 }
623
624 if baseURL != "" {
625 opts = append(opts, anthropic.WithBaseURL(baseURL))
626 }
627
628 if c.cfg.Config().Options.Debug {
629 httpClient := log.NewHTTPClient()
630 opts = append(opts, anthropic.WithHTTPClient(httpClient))
631 }
632 return anthropic.New(opts...)
633}
634
635func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
636 opts := []openai.Option{
637 openai.WithAPIKey(apiKey),
638 openai.WithUseResponsesAPI(),
639 }
640 if c.cfg.Config().Options.Debug {
641 httpClient := log.NewHTTPClient()
642 opts = append(opts, openai.WithHTTPClient(httpClient))
643 }
644 if len(headers) > 0 {
645 opts = append(opts, openai.WithHeaders(headers))
646 }
647 if baseURL != "" {
648 opts = append(opts, openai.WithBaseURL(baseURL))
649 }
650 return openai.New(opts...)
651}
652
653func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
654 opts := []openrouter.Option{
655 openrouter.WithAPIKey(apiKey),
656 }
657 if c.cfg.Config().Options.Debug {
658 httpClient := log.NewHTTPClient()
659 opts = append(opts, openrouter.WithHTTPClient(httpClient))
660 }
661 if len(headers) > 0 {
662 opts = append(opts, openrouter.WithHeaders(headers))
663 }
664 return openrouter.New(opts...)
665}
666
667func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
668 opts := []vercel.Option{
669 vercel.WithAPIKey(apiKey),
670 }
671 if c.cfg.Config().Options.Debug {
672 httpClient := log.NewHTTPClient()
673 opts = append(opts, vercel.WithHTTPClient(httpClient))
674 }
675 if len(headers) > 0 {
676 opts = append(opts, vercel.WithHeaders(headers))
677 }
678 return vercel.New(opts...)
679}
680
681func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
682 opts := []openaicompat.Option{
683 openaicompat.WithBaseURL(baseURL),
684 openaicompat.WithAPIKey(apiKey),
685 }
686
687 // Set HTTP client based on provider and debug mode.
688 var httpClient *http.Client
689 if providerID == string(catwalk.InferenceProviderCopilot) {
690 opts = append(opts, openaicompat.WithUseResponsesAPI())
691 httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
692 } else if c.cfg.Config().Options.Debug {
693 httpClient = log.NewHTTPClient()
694 }
695 if httpClient != nil {
696 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
697 }
698
699 if len(headers) > 0 {
700 opts = append(opts, openaicompat.WithHeaders(headers))
701 }
702
703 for extraKey, extraValue := range extraBody {
704 opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
705 }
706
707 return openaicompat.New(opts...)
708}
709
710func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
711 opts := []azure.Option{
712 azure.WithBaseURL(baseURL),
713 azure.WithAPIKey(apiKey),
714 azure.WithUseResponsesAPI(),
715 }
716 if c.cfg.Config().Options.Debug {
717 httpClient := log.NewHTTPClient()
718 opts = append(opts, azure.WithHTTPClient(httpClient))
719 }
720 if options == nil {
721 options = make(map[string]string)
722 }
723 if apiVersion, ok := options["apiVersion"]; ok {
724 opts = append(opts, azure.WithAPIVersion(apiVersion))
725 }
726 if len(headers) > 0 {
727 opts = append(opts, azure.WithHeaders(headers))
728 }
729
730 return azure.New(opts...)
731}
732
733func (c *coordinator) buildBedrockProvider(apiKey string, headers map[string]string) (fantasy.Provider, error) {
734 var opts []bedrock.Option
735 if c.cfg.Config().Options.Debug {
736 httpClient := log.NewHTTPClient()
737 opts = append(opts, bedrock.WithHTTPClient(httpClient))
738 }
739 if len(headers) > 0 {
740 opts = append(opts, bedrock.WithHeaders(headers))
741 }
742 switch {
743 case apiKey != "":
744 opts = append(opts, bedrock.WithAPIKey(apiKey))
745 case os.Getenv("AWS_BEARER_TOKEN_BEDROCK") != "":
746 opts = append(opts, bedrock.WithAPIKey(os.Getenv("AWS_BEARER_TOKEN_BEDROCK")))
747 default:
748 // Skip, let the SDK do authentication.
749 }
750 return bedrock.New(opts...)
751}
752
753func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
754 opts := []google.Option{
755 google.WithBaseURL(baseURL),
756 google.WithGeminiAPIKey(apiKey),
757 }
758 if c.cfg.Config().Options.Debug {
759 httpClient := log.NewHTTPClient()
760 opts = append(opts, google.WithHTTPClient(httpClient))
761 }
762 if len(headers) > 0 {
763 opts = append(opts, google.WithHeaders(headers))
764 }
765 return google.New(opts...)
766}
767
768func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
769 opts := []google.Option{}
770 if c.cfg.Config().Options.Debug {
771 httpClient := log.NewHTTPClient()
772 opts = append(opts, google.WithHTTPClient(httpClient))
773 }
774 if len(headers) > 0 {
775 opts = append(opts, google.WithHeaders(headers))
776 }
777
778 project := options["project"]
779 location := options["location"]
780
781 opts = append(opts, google.WithVertex(project, location))
782
783 return google.New(opts...)
784}
785
786func (c *coordinator) buildHyperProvider(apiKey string) (fantasy.Provider, error) {
787 opts := []hyper.Option{
788 hyper.WithAPIKey(apiKey),
789 }
790 if c.cfg.Config().Options.Debug {
791 httpClient := log.NewHTTPClient()
792 opts = append(opts, hyper.WithHTTPClient(httpClient))
793 }
794 return hyper.New(opts...)
795}
796
797func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
798 if model.Think {
799 return true
800 }
801 opts, err := anthropic.ParseOptions(model.ProviderOptions)
802 return err == nil && opts.Thinking != nil
803}
804
805func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
806 headers := maps.Clone(providerCfg.ExtraHeaders)
807 if headers == nil {
808 headers = make(map[string]string)
809 }
810
811 // handle special headers for anthropic
812 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
813 if v, ok := headers["anthropic-beta"]; ok {
814 headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
815 } else {
816 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
817 }
818 }
819
820 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
821 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
822
823 switch providerCfg.Type {
824 case openai.Name:
825 return c.buildOpenaiProvider(baseURL, apiKey, headers)
826 case anthropic.Name:
827 return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
828 case openrouter.Name:
829 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
830 case vercel.Name:
831 return c.buildVercelProvider(baseURL, apiKey, headers)
832 case azure.Name:
833 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
834 case bedrock.Name:
835 return c.buildBedrockProvider(apiKey, headers)
836 case google.Name:
837 return c.buildGoogleProvider(baseURL, apiKey, headers)
838 case "google-vertex":
839 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
840 case openaicompat.Name:
841 if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
842 if providerCfg.ExtraBody == nil {
843 providerCfg.ExtraBody = map[string]any{}
844 }
845 providerCfg.ExtraBody["tool_stream"] = true
846 }
847 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
848 case hyper.Name:
849 return c.buildHyperProvider(apiKey)
850 default:
851 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
852 }
853}
854
// isExactoSupported reports whether modelID is on the fixed allow-list of
// OpenRouter models that get the ":exacto" suffix. The match is exact and
// case-sensitive.
func isExactoSupported(modelID string) bool {
	exactoModels := []string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}
	return slices.Index(exactoModels, modelID) >= 0
}
865
866func (c *coordinator) Cancel(sessionID string) {
867 c.currentAgent.Cancel(sessionID)
868}
869
870func (c *coordinator) CancelAll() {
871 c.currentAgent.CancelAll()
872}
873
874func (c *coordinator) ClearQueue(sessionID string) {
875 c.currentAgent.ClearQueue(sessionID)
876}
877
878func (c *coordinator) IsBusy() bool {
879 return c.currentAgent.IsBusy()
880}
881
882func (c *coordinator) IsSessionBusy(sessionID string) bool {
883 return c.currentAgent.IsSessionBusy(sessionID)
884}
885
886func (c *coordinator) Model() Model {
887 return c.currentAgent.Model()
888}
889
890func (c *coordinator) UpdateModels(ctx context.Context) error {
891 // build the models again so we make sure we get the latest config
892 large, small, err := c.buildAgentModels(ctx, false)
893 if err != nil {
894 return err
895 }
896 c.currentAgent.SetModels(large, small)
897
898 agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
899 if !ok {
900 return errCoderAgentNotConfigured
901 }
902
903 tools, err := c.buildTools(ctx, agentCfg)
904 if err != nil {
905 return err
906 }
907 c.currentAgent.SetTools(tools)
908 return nil
909}
910
911func (c *coordinator) QueuedPrompts(sessionID string) int {
912 return c.currentAgent.QueuedPrompts(sessionID)
913}
914
915func (c *coordinator) QueuedPromptsList(sessionID string) []string {
916 return c.currentAgent.QueuedPromptsList(sessionID)
917}
918
919func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
920 providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
921 if !ok {
922 return errModelProviderNotConfigured
923 }
924 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
925}
926
927func (c *coordinator) isUnauthorized(err error) bool {
928 var providerErr *fantasy.ProviderError
929 return (errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized) ||
930 errors.Is(err, hyper.ErrUnauthorized)
931}
932
933func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
934 if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
935 slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
936 return err
937 }
938 if err := c.UpdateModels(ctx); err != nil {
939 return err
940 }
941 return nil
942}
943
944func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
945 newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
946 if err != nil {
947 slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
948 return err
949 }
950
951 providerCfg.APIKey = newAPIKey
952 c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
953
954 if err := c.UpdateModels(ctx); err != nil {
955 return err
956 }
957 return nil
958}
959
// subAgentParams holds the parameters for running a sub-agent.
type subAgentParams struct {
	// Agent is the sub-agent that runSubAgent executes.
	Agent SessionAgent

	// SessionID is the parent session. The sub-session is created under it
	// and the sub-agent's cost is added to it.
	SessionID string

	// AgentMessageID and ToolCallID are passed to CreateAgentToolSessionID to
	// derive the sub-session ID.
	AgentMessageID string
	ToolCallID     string

	// Prompt is sent to the sub-agent.
	Prompt string

	// SessionTitle is the title given to the created sub-session.
	SessionTitle string

	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
972
973// runSubAgent runs a sub-agent and handles session management and cost accumulation.
974// It creates a sub-session, runs the agent with the given prompt, and propagates
975// the cost to the parent session.
976func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
977 // Create sub-session
978 agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
979 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
980 if err != nil {
981 return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
982 }
983
984 // Call session setup function if provided
985 if params.SessionSetup != nil {
986 params.SessionSetup(session.ID)
987 }
988
989 // Get model configuration
990 model := params.Agent.Model()
991 maxTokens := model.CatwalkCfg.DefaultMaxTokens
992 if model.ModelCfg.MaxTokens != 0 {
993 maxTokens = model.ModelCfg.MaxTokens
994 }
995
996 providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
997 if !ok {
998 return fantasy.ToolResponse{}, errModelProviderNotConfigured
999 }
1000
1001 // Run the agent
1002 result, err := params.Agent.Run(ctx, SessionAgentCall{
1003 SessionID: session.ID,
1004 Prompt: params.Prompt,
1005 MaxOutputTokens: maxTokens,
1006 ProviderOptions: getProviderOptions(model, providerCfg),
1007 Temperature: model.ModelCfg.Temperature,
1008 TopP: model.ModelCfg.TopP,
1009 TopK: model.ModelCfg.TopK,
1010 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
1011 PresencePenalty: model.ModelCfg.PresencePenalty,
1012 NonInteractive: true,
1013 })
1014 if err != nil {
1015 return fantasy.NewTextErrorResponse("error generating response"), nil
1016 }
1017
1018 // Update parent session cost
1019 if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1020 return fantasy.ToolResponse{}, err
1021 }
1022
1023 return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1024}
1025
1026// updateParentSessionCost accumulates the cost from a child session to its parent session.
1027func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1028 childSession, err := c.sessions.Get(ctx, childSessionID)
1029 if err != nil {
1030 return fmt.Errorf("get child session: %w", err)
1031 }
1032
1033 parentSession, err := c.sessions.Get(ctx, parentSessionID)
1034 if err != nil {
1035 return fmt.Errorf("get parent session: %w", err)
1036 }
1037
1038 parentSession.Cost += childSession.Cost
1039
1040 if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1041 return fmt.Errorf("save parent session: %w", err)
1042 }
1043
1044 return nil
1045}