1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "slices"
16 "strings"
17
18 "charm.land/catwalk/pkg/catwalk"
19 "charm.land/fantasy"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/agent/notify"
22 "github.com/charmbracelet/crush/internal/agent/prompt"
23 "github.com/charmbracelet/crush/internal/agent/tools"
24 "github.com/charmbracelet/crush/internal/config"
25 "github.com/charmbracelet/crush/internal/filetracker"
26 "github.com/charmbracelet/crush/internal/history"
27 "github.com/charmbracelet/crush/internal/log"
28 "github.com/charmbracelet/crush/internal/lsp"
29 "github.com/charmbracelet/crush/internal/message"
30 "github.com/charmbracelet/crush/internal/oauth/copilot"
31 "github.com/charmbracelet/crush/internal/permission"
32 "github.com/charmbracelet/crush/internal/pubsub"
33 "github.com/charmbracelet/crush/internal/session"
34 "golang.org/x/sync/errgroup"
35
36 "charm.land/fantasy/providers/anthropic"
37 "charm.land/fantasy/providers/azure"
38 "charm.land/fantasy/providers/bedrock"
39 "charm.land/fantasy/providers/google"
40 "charm.land/fantasy/providers/openai"
41 "charm.land/fantasy/providers/openaicompat"
42 "charm.land/fantasy/providers/openrouter"
43 "charm.land/fantasy/providers/vercel"
44 openaisdk "github.com/openai/openai-go/v2/option"
45 "github.com/qjebbs/go-jsons"
46)
47
// Sentinel errors returned by the coordinator when the agent, model, or
// provider configuration is incomplete.
var (
	// Agent and provider wiring.
	errCoderAgentNotConfigured    = errors.New("coder agent not configured")
	errModelProviderNotConfigured = errors.New("model provider not configured")

	// Large model selection.
	errLargeModelNotSelected           = errors.New("large model not selected")
	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
	errLargeModelNotFound              = errors.New("large model not found in provider config")

	// Small model selection.
	errSmallModelNotSelected           = errors.New("small model not selected")
	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
	errSmallModelNotFound              = errors.New("small model not found in provider config")
)
59
// Coordinator is the entry point used to run prompts against the configured
// agent and to inspect or control its in-flight work. The coordinator type in
// this file implements it by forwarding most calls to its current
// SessionAgent.
type Coordinator interface {
	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
	// SetMainAgent(string)

	// Run executes prompt (plus optional attachments) in the given session and
	// returns the agent result.
	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
	// Cancel forwards a cancellation request for sessionID to the agent.
	Cancel(sessionID string)
	// CancelAll forwards a cancellation request for every session to the agent.
	CancelAll()
	// IsSessionBusy reports whether the agent considers sessionID busy.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether the agent considers itself busy.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for sessionID.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompts queued for sessionID.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the prompts queued for sessionID.
	ClearQueue(sessionID string)
	// Summarize asks the agent to summarize the session with the given ID.
	Summarize(context.Context, string) error
	// Model returns the model currently used by the agent.
	Model() Model
	// UpdateModels rebuilds the agent's models and tools from the current
	// configuration.
	UpdateModels(ctx context.Context) error
}
75
// coordinator is the Coordinator implementation in this file. It keeps the
// services handed to NewCoordinator and forwards work to currentAgent.
type coordinator struct {
	// cfg is the configuration store queried for agents, models, providers,
	// options and the working directory.
	cfg *config.ConfigStore
	// sessions is used to create sub-agent sessions and to read/save sessions.
	sessions session.Service
	// messages is passed through to the session agent.
	messages message.Service
	// permissions is handed to the tools and queried for SkipRequests.
	permissions permission.Service
	// history is handed to the file-editing tools.
	history history.Service
	// filetracker is handed to the file view/edit/write tools.
	filetracker filetracker.Service
	// lspManager is handed to the LSP-aware tools.
	lspManager *lsp.Manager
	// notify is passed through to the session agent.
	notify pubsub.Publisher[notify.Notification]

	// currentAgent is the agent every Coordinator method delegates to.
	currentAgent SessionAgent
	// agents holds the built agents keyed by agent name (only the coder agent
	// is registered in this file).
	agents map[string]SessionAgent

	// readyWg tracks the background system-prompt and tool builds started by
	// buildAgent; Run waits on it before doing any work.
	readyWg errgroup.Group
}
91
92func NewCoordinator(
93 ctx context.Context,
94 cfg *config.ConfigStore,
95 sessions session.Service,
96 messages message.Service,
97 permissions permission.Service,
98 history history.Service,
99 filetracker filetracker.Service,
100 lspManager *lsp.Manager,
101 notify pubsub.Publisher[notify.Notification],
102) (Coordinator, error) {
103 c := &coordinator{
104 cfg: cfg,
105 sessions: sessions,
106 messages: messages,
107 permissions: permissions,
108 history: history,
109 filetracker: filetracker,
110 lspManager: lspManager,
111 notify: notify,
112 agents: make(map[string]SessionAgent),
113 }
114
115 agentCfg, ok := cfg.Config().Agents[config.AgentCoder]
116 if !ok {
117 return nil, errCoderAgentNotConfigured
118 }
119
120 // TODO: make this dynamic when we support multiple agents
121 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
122 if err != nil {
123 return nil, err
124 }
125
126 agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
127 if err != nil {
128 return nil, err
129 }
130 c.currentAgent = agent
131 c.agents[config.AgentCoder] = agent
132 return c, nil
133}
134
135// Run implements Coordinator.
136func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
137 if err := c.readyWg.Wait(); err != nil {
138 return nil, err
139 }
140
141 // refresh models before each run
142 if err := c.UpdateModels(ctx); err != nil {
143 return nil, fmt.Errorf("failed to update models: %w", err)
144 }
145
146 model := c.currentAgent.Model()
147 maxTokens := model.CatwalkCfg.DefaultMaxTokens
148 if model.ModelCfg.MaxTokens != 0 {
149 maxTokens = model.ModelCfg.MaxTokens
150 }
151
152 if !model.CatwalkCfg.SupportsImages && attachments != nil {
153 // filter out image attachments
154 filteredAttachments := make([]message.Attachment, 0, len(attachments))
155 for _, att := range attachments {
156 if att.IsText() {
157 filteredAttachments = append(filteredAttachments, att)
158 }
159 }
160 attachments = filteredAttachments
161 }
162
163 providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
164 if !ok {
165 return nil, errModelProviderNotConfigured
166 }
167
168 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
169
170 if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
171 slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
172 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
173 return nil, err
174 }
175 }
176
177 run := func() (*fantasy.AgentResult, error) {
178 return c.currentAgent.Run(ctx, SessionAgentCall{
179 SessionID: sessionID,
180 Prompt: prompt,
181 Attachments: attachments,
182 MaxOutputTokens: maxTokens,
183 ProviderOptions: mergedOptions,
184 Temperature: temp,
185 TopP: topP,
186 TopK: topK,
187 FrequencyPenalty: freqPenalty,
188 PresencePenalty: presPenalty,
189 })
190 }
191 result, originalErr := run()
192
193 if c.isUnauthorized(originalErr) {
194 switch {
195 case providerCfg.OAuthToken != nil:
196 slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
197 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
198 return nil, originalErr
199 }
200 slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
201 return run()
202 case strings.Contains(providerCfg.APIKeyTemplate, "$"):
203 slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
204 if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
205 return nil, originalErr
206 }
207 slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
208 return run()
209 }
210 }
211
212 return result, originalErr
213}
214
215func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
216 options := fantasy.ProviderOptions{}
217
218 cfgOpts := []byte("{}")
219 providerCfgOpts := []byte("{}")
220 catwalkOpts := []byte("{}")
221
222 if model.ModelCfg.ProviderOptions != nil {
223 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
224 if err == nil {
225 cfgOpts = data
226 }
227 }
228
229 if providerCfg.ProviderOptions != nil {
230 data, err := json.Marshal(providerCfg.ProviderOptions)
231 if err == nil {
232 providerCfgOpts = data
233 }
234 }
235
236 if model.CatwalkCfg.Options.ProviderOptions != nil {
237 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
238 if err == nil {
239 catwalkOpts = data
240 }
241 }
242
243 readers := []io.Reader{
244 bytes.NewReader(catwalkOpts),
245 bytes.NewReader(providerCfgOpts),
246 bytes.NewReader(cfgOpts),
247 }
248
249 got, err := jsons.Merge(readers)
250 if err != nil {
251 slog.Error("Could not merge call config", "err", err)
252 return options
253 }
254
255 mergedOptions := make(map[string]any)
256
257 err = json.Unmarshal([]byte(got), &mergedOptions)
258 if err != nil {
259 slog.Error("Could not create config for call", "err", err)
260 return options
261 }
262
263 providerType := providerCfg.Type
264 if providerType == "hyper" {
265 if strings.Contains(model.CatwalkCfg.ID, "claude") {
266 providerType = anthropic.Name
267 } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
268 providerType = openai.Name
269 } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
270 providerType = google.Name
271 } else {
272 providerType = openaicompat.Name
273 }
274 }
275
276 switch providerType {
277 case openai.Name, azure.Name:
278 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
279 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
280 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
281 }
282 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
283 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
284 mergedOptions["reasoning_summary"] = "auto"
285 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
286 }
287 parsed, err := openai.ParseResponsesOptions(mergedOptions)
288 if err == nil {
289 options[openai.Name] = parsed
290 }
291 } else {
292 parsed, err := openai.ParseOptions(mergedOptions)
293 if err == nil {
294 options[openai.Name] = parsed
295 }
296 }
297 case anthropic.Name:
298 var (
299 _, hasEffort = mergedOptions["effort"]
300 _, hasThink = mergedOptions["thinking"]
301 )
302 switch {
303 case !hasEffort && model.ModelCfg.ReasoningEffort != "":
304 mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
305 case !hasThink && model.ModelCfg.Think:
306 mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
307 }
308 parsed, err := anthropic.ParseOptions(mergedOptions)
309 if err == nil {
310 options[anthropic.Name] = parsed
311 }
312
313 case openrouter.Name:
314 _, hasReasoning := mergedOptions["reasoning"]
315 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
316 mergedOptions["reasoning"] = map[string]any{
317 "enabled": true,
318 "effort": model.ModelCfg.ReasoningEffort,
319 }
320 }
321 parsed, err := openrouter.ParseOptions(mergedOptions)
322 if err == nil {
323 options[openrouter.Name] = parsed
324 }
325 case vercel.Name:
326 _, hasReasoning := mergedOptions["reasoning"]
327 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
328 mergedOptions["reasoning"] = map[string]any{
329 "enabled": true,
330 "effort": model.ModelCfg.ReasoningEffort,
331 }
332 }
333 parsed, err := vercel.ParseOptions(mergedOptions)
334 if err == nil {
335 options[vercel.Name] = parsed
336 }
337 case google.Name:
338 _, hasReasoning := mergedOptions["thinking_config"]
339 if !hasReasoning {
340 if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
341 mergedOptions["thinking_config"] = map[string]any{
342 "thinking_budget": 2000,
343 "include_thoughts": true,
344 }
345 } else {
346 mergedOptions["thinking_config"] = map[string]any{
347 "thinking_level": model.ModelCfg.ReasoningEffort,
348 "include_thoughts": true,
349 }
350 }
351 }
352 parsed, err := google.ParseOptions(mergedOptions)
353 if err == nil {
354 options[google.Name] = parsed
355 }
356 case openaicompat.Name:
357 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
358 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
359 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
360 }
361 parsed, err := openaicompat.ParseOptions(mergedOptions)
362 if err == nil {
363 options[openaicompat.Name] = parsed
364 }
365 }
366
367 return options
368}
369
370func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
371 modelOptions := getProviderOptions(model, cfg)
372 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
373 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
374 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
375 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
376 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
377 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
378}
379
380func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
381 large, small, err := c.buildAgentModels(ctx, isSubAgent)
382 if err != nil {
383 return nil, err
384 }
385
386 largeProviderCfg, _ := c.cfg.Config().Providers.Get(large.ModelCfg.Provider)
387 result := NewSessionAgent(SessionAgentOptions{
388 LargeModel: large,
389 SmallModel: small,
390 SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix,
391 SystemPrompt: "",
392 IsSubAgent: isSubAgent,
393 DisableAutoSummarize: c.cfg.Config().Options.DisableAutoSummarize,
394 IsYolo: c.permissions.SkipRequests(),
395 Sessions: c.sessions,
396 Messages: c.messages,
397 Tools: nil,
398 Notify: c.notify,
399 })
400
401 c.readyWg.Go(func() error {
402 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), c.cfg)
403 if err != nil {
404 return err
405 }
406 result.SetSystemPrompt(systemPrompt)
407 return nil
408 })
409
410 c.readyWg.Go(func() error {
411 tools, err := c.buildTools(ctx, agent)
412 if err != nil {
413 return err
414 }
415 result.SetTools(tools)
416 return nil
417 })
418
419 return result, nil
420}
421
422func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
423 var allTools []fantasy.AgentTool
424 if slices.Contains(agent.AllowedTools, AgentToolName) {
425 agentTool, err := c.agentTool(ctx)
426 if err != nil {
427 return nil, err
428 }
429 allTools = append(allTools, agentTool)
430 }
431
432 if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
433 agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
434 if err != nil {
435 return nil, err
436 }
437 allTools = append(allTools, agenticFetchTool)
438 }
439
440 // Get the model name for the agent
441 modelName := ""
442 if modelCfg, ok := c.cfg.Config().Models[agent.Model]; ok {
443 if model := c.cfg.Config().GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
444 modelName = model.Name
445 }
446 }
447
448 allTools = append(allTools,
449 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Options.Attribution, modelName),
450 tools.NewJobOutputTool(),
451 tools.NewJobKillTool(),
452 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
453 tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
454 tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
455 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
456 tools.NewGlobTool(c.cfg.WorkingDir()),
457 tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
458 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
459 tools.NewSourcegraphTool(nil),
460 tools.NewTodosTool(c.sessions),
461 tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
462 tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
463 )
464
465 // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
466 if len(c.cfg.Config().LSP) > 0 || c.cfg.Config().Options.AutoLSP == nil || *c.cfg.Config().Options.AutoLSP {
467 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
468 }
469
470 if len(c.cfg.Config().MCP) > 0 {
471 allTools = append(
472 allTools,
473 tools.NewListMCPResourcesTool(c.cfg, c.permissions),
474 tools.NewReadMCPResourceTool(c.cfg, c.permissions),
475 )
476 }
477
478 var filteredTools []fantasy.AgentTool
479 for _, tool := range allTools {
480 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
481 filteredTools = append(filteredTools, tool)
482 }
483 }
484
485 for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
486 if agent.AllowedMCP == nil {
487 // No MCP restrictions
488 filteredTools = append(filteredTools, tool)
489 continue
490 }
491 if len(agent.AllowedMCP) == 0 {
492 // No MCPs allowed
493 slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
494 break
495 }
496
497 for mcp, tools := range agent.AllowedMCP {
498 if mcp != tool.MCP() {
499 continue
500 }
501 if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
502 filteredTools = append(filteredTools, tool)
503 break
504 }
505 slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
506 }
507 }
508 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
509 return strings.Compare(a.Info().Name, b.Info().Name)
510 })
511 return filteredTools, nil
512}
513
514// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
515func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
516 largeModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeLarge]
517 if !ok {
518 return Model{}, Model{}, errLargeModelNotSelected
519 }
520 smallModelCfg, ok := c.cfg.Config().Models[config.SelectedModelTypeSmall]
521 if !ok {
522 return Model{}, Model{}, errSmallModelNotSelected
523 }
524
525 largeProviderCfg, ok := c.cfg.Config().Providers.Get(largeModelCfg.Provider)
526 if !ok {
527 return Model{}, Model{}, errLargeModelProviderNotConfigured
528 }
529
530 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
531 if err != nil {
532 return Model{}, Model{}, err
533 }
534
535 smallProviderCfg, ok := c.cfg.Config().Providers.Get(smallModelCfg.Provider)
536 if !ok {
537 return Model{}, Model{}, errSmallModelProviderNotConfigured
538 }
539
540 smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
541 if err != nil {
542 return Model{}, Model{}, err
543 }
544
545 var largeCatwalkModel *catwalk.Model
546 var smallCatwalkModel *catwalk.Model
547
548 for _, m := range largeProviderCfg.Models {
549 if m.ID == largeModelCfg.Model {
550 largeCatwalkModel = &m
551 }
552 }
553 for _, m := range smallProviderCfg.Models {
554 if m.ID == smallModelCfg.Model {
555 smallCatwalkModel = &m
556 }
557 }
558
559 if largeCatwalkModel == nil {
560 return Model{}, Model{}, errLargeModelNotFound
561 }
562
563 if smallCatwalkModel == nil {
564 return Model{}, Model{}, errSmallModelNotFound
565 }
566
567 largeModelID := largeModelCfg.Model
568 smallModelID := smallModelCfg.Model
569
570 if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
571 largeModelID += ":exacto"
572 }
573
574 if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
575 smallModelID += ":exacto"
576 }
577
578 largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
579 if err != nil {
580 return Model{}, Model{}, err
581 }
582 smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
583 if err != nil {
584 return Model{}, Model{}, err
585 }
586
587 return Model{
588 Model: largeModel,
589 CatwalkCfg: *largeCatwalkModel,
590 ModelCfg: largeModelCfg,
591 }, Model{
592 Model: smallModel,
593 CatwalkCfg: *smallCatwalkModel,
594 ModelCfg: smallModelCfg,
595 }, nil
596}
597
598func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
599 var opts []anthropic.Option
600
601 switch {
602 case strings.HasPrefix(apiKey, "Bearer "):
603 // NOTE: Prevent the SDK from picking up the API key from env.
604 os.Setenv("ANTHROPIC_API_KEY", "")
605 headers["Authorization"] = apiKey
606 case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
607 // NOTE: Prevent the SDK from picking up the API key from env.
608 os.Setenv("ANTHROPIC_API_KEY", "")
609 headers["Authorization"] = "Bearer " + apiKey
610 case apiKey != "":
611 // X-Api-Key header
612 opts = append(opts, anthropic.WithAPIKey(apiKey))
613 }
614
615 if len(headers) > 0 {
616 opts = append(opts, anthropic.WithHeaders(headers))
617 }
618
619 if baseURL != "" {
620 opts = append(opts, anthropic.WithBaseURL(baseURL))
621 }
622
623 if c.cfg.Config().Options.Debug {
624 httpClient := log.NewHTTPClient()
625 opts = append(opts, anthropic.WithHTTPClient(httpClient))
626 }
627 return anthropic.New(opts...)
628}
629
630func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
631 opts := []openai.Option{
632 openai.WithAPIKey(apiKey),
633 openai.WithUseResponsesAPI(),
634 }
635 if c.cfg.Config().Options.Debug {
636 httpClient := log.NewHTTPClient()
637 opts = append(opts, openai.WithHTTPClient(httpClient))
638 }
639 if len(headers) > 0 {
640 opts = append(opts, openai.WithHeaders(headers))
641 }
642 if baseURL != "" {
643 opts = append(opts, openai.WithBaseURL(baseURL))
644 }
645 return openai.New(opts...)
646}
647
648func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
649 opts := []openrouter.Option{
650 openrouter.WithAPIKey(apiKey),
651 }
652 if c.cfg.Config().Options.Debug {
653 httpClient := log.NewHTTPClient()
654 opts = append(opts, openrouter.WithHTTPClient(httpClient))
655 }
656 if len(headers) > 0 {
657 opts = append(opts, openrouter.WithHeaders(headers))
658 }
659 return openrouter.New(opts...)
660}
661
662func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
663 opts := []vercel.Option{
664 vercel.WithAPIKey(apiKey),
665 }
666 if c.cfg.Config().Options.Debug {
667 httpClient := log.NewHTTPClient()
668 opts = append(opts, vercel.WithHTTPClient(httpClient))
669 }
670 if len(headers) > 0 {
671 opts = append(opts, vercel.WithHeaders(headers))
672 }
673 return vercel.New(opts...)
674}
675
676func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
677 opts := []openaicompat.Option{
678 openaicompat.WithBaseURL(baseURL),
679 openaicompat.WithAPIKey(apiKey),
680 }
681
682 // Set HTTP client based on provider and debug mode.
683 var httpClient *http.Client
684 if providerID == string(catwalk.InferenceProviderCopilot) {
685 opts = append(opts, openaicompat.WithUseResponsesAPI())
686 httpClient = copilot.NewClient(isSubAgent, c.cfg.Config().Options.Debug)
687 } else if c.cfg.Config().Options.Debug {
688 httpClient = log.NewHTTPClient()
689 }
690 if httpClient != nil {
691 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
692 }
693
694 if len(headers) > 0 {
695 opts = append(opts, openaicompat.WithHeaders(headers))
696 }
697
698 for extraKey, extraValue := range extraBody {
699 opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
700 }
701
702 return openaicompat.New(opts...)
703}
704
705func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
706 opts := []azure.Option{
707 azure.WithBaseURL(baseURL),
708 azure.WithAPIKey(apiKey),
709 azure.WithUseResponsesAPI(),
710 }
711 if c.cfg.Config().Options.Debug {
712 httpClient := log.NewHTTPClient()
713 opts = append(opts, azure.WithHTTPClient(httpClient))
714 }
715 if options == nil {
716 options = make(map[string]string)
717 }
718 if apiVersion, ok := options["apiVersion"]; ok {
719 opts = append(opts, azure.WithAPIVersion(apiVersion))
720 }
721 if len(headers) > 0 {
722 opts = append(opts, azure.WithHeaders(headers))
723 }
724
725 return azure.New(opts...)
726}
727
728func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
729 var opts []bedrock.Option
730 if c.cfg.Config().Options.Debug {
731 httpClient := log.NewHTTPClient()
732 opts = append(opts, bedrock.WithHTTPClient(httpClient))
733 }
734 if len(headers) > 0 {
735 opts = append(opts, bedrock.WithHeaders(headers))
736 }
737 bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
738 if bearerToken != "" {
739 opts = append(opts, bedrock.WithAPIKey(bearerToken))
740 }
741 return bedrock.New(opts...)
742}
743
744func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
745 opts := []google.Option{
746 google.WithBaseURL(baseURL),
747 google.WithGeminiAPIKey(apiKey),
748 }
749 if c.cfg.Config().Options.Debug {
750 httpClient := log.NewHTTPClient()
751 opts = append(opts, google.WithHTTPClient(httpClient))
752 }
753 if len(headers) > 0 {
754 opts = append(opts, google.WithHeaders(headers))
755 }
756 return google.New(opts...)
757}
758
759func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
760 opts := []google.Option{}
761 if c.cfg.Config().Options.Debug {
762 httpClient := log.NewHTTPClient()
763 opts = append(opts, google.WithHTTPClient(httpClient))
764 }
765 if len(headers) > 0 {
766 opts = append(opts, google.WithHeaders(headers))
767 }
768
769 project := options["project"]
770 location := options["location"]
771
772 opts = append(opts, google.WithVertex(project, location))
773
774 return google.New(opts...)
775}
776
777func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
778 opts := []hyper.Option{
779 hyper.WithBaseURL(baseURL),
780 hyper.WithAPIKey(apiKey),
781 }
782 if c.cfg.Config().Options.Debug {
783 httpClient := log.NewHTTPClient()
784 opts = append(opts, hyper.WithHTTPClient(httpClient))
785 }
786 return hyper.New(opts...)
787}
788
789func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
790 if model.Think {
791 return true
792 }
793 opts, err := anthropic.ParseOptions(model.ProviderOptions)
794 return err == nil && opts.Thinking != nil
795}
796
797func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
798 headers := maps.Clone(providerCfg.ExtraHeaders)
799 if headers == nil {
800 headers = make(map[string]string)
801 }
802
803 // handle special headers for anthropic
804 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
805 if v, ok := headers["anthropic-beta"]; ok {
806 headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
807 } else {
808 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
809 }
810 }
811
812 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
813 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
814
815 switch providerCfg.Type {
816 case openai.Name:
817 return c.buildOpenaiProvider(baseURL, apiKey, headers)
818 case anthropic.Name:
819 return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
820 case openrouter.Name:
821 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
822 case vercel.Name:
823 return c.buildVercelProvider(baseURL, apiKey, headers)
824 case azure.Name:
825 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
826 case bedrock.Name:
827 return c.buildBedrockProvider(headers)
828 case google.Name:
829 return c.buildGoogleProvider(baseURL, apiKey, headers)
830 case "google-vertex":
831 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
832 case openaicompat.Name:
833 if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
834 if providerCfg.ExtraBody == nil {
835 providerCfg.ExtraBody = map[string]any{}
836 }
837 providerCfg.ExtraBody["tool_stream"] = true
838 }
839 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
840 case hyper.Name:
841 return c.buildHyperProvider(baseURL, apiKey)
842 default:
843 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
844 }
845}
846
// isExactoSupported reports whether modelID is one of the model IDs for which
// the ":exacto" variant is requested. The match is exact and case-sensitive.
func isExactoSupported(modelID string) bool {
	exactoModels := [...]string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}
	return slices.Index(exactoModels[:], modelID) >= 0
}
857
// Cancel implements Coordinator by forwarding the request for sessionID to
// the current agent.
func (c *coordinator) Cancel(sessionID string) {
	c.currentAgent.Cancel(sessionID)
}
861
// CancelAll implements Coordinator by forwarding the request to the current
// agent.
func (c *coordinator) CancelAll() {
	c.currentAgent.CancelAll()
}
865
// ClearQueue implements Coordinator by forwarding the request for sessionID
// to the current agent.
func (c *coordinator) ClearQueue(sessionID string) {
	c.currentAgent.ClearQueue(sessionID)
}
869
// IsBusy implements Coordinator and returns the current agent's busy state.
func (c *coordinator) IsBusy() bool {
	return c.currentAgent.IsBusy()
}
873
// IsSessionBusy implements Coordinator and returns the current agent's busy
// state for sessionID.
func (c *coordinator) IsSessionBusy(sessionID string) bool {
	return c.currentAgent.IsSessionBusy(sessionID)
}
877
// Model implements Coordinator and returns the model reported by the current
// agent.
func (c *coordinator) Model() Model {
	return c.currentAgent.Model()
}
881
882func (c *coordinator) UpdateModels(ctx context.Context) error {
883 // build the models again so we make sure we get the latest config
884 large, small, err := c.buildAgentModels(ctx, false)
885 if err != nil {
886 return err
887 }
888 c.currentAgent.SetModels(large, small)
889
890 agentCfg, ok := c.cfg.Config().Agents[config.AgentCoder]
891 if !ok {
892 return errCoderAgentNotConfigured
893 }
894
895 tools, err := c.buildTools(ctx, agentCfg)
896 if err != nil {
897 return err
898 }
899 c.currentAgent.SetTools(tools)
900 return nil
901}
902
// QueuedPrompts implements Coordinator and returns the current agent's queued
// prompt count for sessionID.
func (c *coordinator) QueuedPrompts(sessionID string) int {
	return c.currentAgent.QueuedPrompts(sessionID)
}
906
// QueuedPromptsList implements Coordinator and returns the current agent's
// queued prompts for sessionID.
func (c *coordinator) QueuedPromptsList(sessionID string) []string {
	return c.currentAgent.QueuedPromptsList(sessionID)
}
910
911func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
912 providerCfg, ok := c.cfg.Config().Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
913 if !ok {
914 return errModelProviderNotConfigured
915 }
916 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
917}
918
919func (c *coordinator) isUnauthorized(err error) bool {
920 var providerErr *fantasy.ProviderError
921 return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
922}
923
924func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
925 if err := c.cfg.RefreshOAuthToken(ctx, config.ScopeGlobal, providerCfg.ID); err != nil {
926 slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
927 return err
928 }
929 if err := c.UpdateModels(ctx); err != nil {
930 return err
931 }
932 return nil
933}
934
935func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
936 newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
937 if err != nil {
938 slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
939 return err
940 }
941
942 providerCfg.APIKey = newAPIKey
943 c.cfg.Config().Providers.Set(providerCfg.ID, providerCfg)
944
945 if err := c.UpdateModels(ctx); err != nil {
946 return err
947 }
948 return nil
949}
950
// subAgentParams holds the parameters for running a sub-agent.
type subAgentParams struct {
	// Agent is the session agent that executes the prompt.
	Agent SessionAgent
	// SessionID is the parent session: the sub-session is created under it and
	// the sub-session's cost is added to it.
	SessionID string
	// AgentMessageID and ToolCallID are combined to derive the sub-session ID.
	AgentMessageID string
	ToolCallID     string
	// Prompt is the prompt the sub-agent runs.
	Prompt string
	// SessionTitle is the title given to the sub-session.
	SessionTitle string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
963
964// runSubAgent runs a sub-agent and handles session management and cost accumulation.
965// It creates a sub-session, runs the agent with the given prompt, and propagates
966// the cost to the parent session.
967func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
968 // Create sub-session
969 agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
970 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
971 if err != nil {
972 return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
973 }
974
975 // Call session setup function if provided
976 if params.SessionSetup != nil {
977 params.SessionSetup(session.ID)
978 }
979
980 // Get model configuration
981 model := params.Agent.Model()
982 maxTokens := model.CatwalkCfg.DefaultMaxTokens
983 if model.ModelCfg.MaxTokens != 0 {
984 maxTokens = model.ModelCfg.MaxTokens
985 }
986
987 providerCfg, ok := c.cfg.Config().Providers.Get(model.ModelCfg.Provider)
988 if !ok {
989 return fantasy.ToolResponse{}, errModelProviderNotConfigured
990 }
991
992 // Run the agent
993 result, err := params.Agent.Run(ctx, SessionAgentCall{
994 SessionID: session.ID,
995 Prompt: params.Prompt,
996 MaxOutputTokens: maxTokens,
997 ProviderOptions: getProviderOptions(model, providerCfg),
998 Temperature: model.ModelCfg.Temperature,
999 TopP: model.ModelCfg.TopP,
1000 TopK: model.ModelCfg.TopK,
1001 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
1002 PresencePenalty: model.ModelCfg.PresencePenalty,
1003 NonInteractive: true,
1004 })
1005 if err != nil {
1006 return fantasy.NewTextErrorResponse("error generating response"), nil
1007 }
1008
1009 // Update parent session cost
1010 if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1011 return fantasy.ToolResponse{}, err
1012 }
1013
1014 return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1015}
1016
1017// updateParentSessionCost accumulates the cost from a child session to its parent session.
1018func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1019 childSession, err := c.sessions.Get(ctx, childSessionID)
1020 if err != nil {
1021 return fmt.Errorf("get child session: %w", err)
1022 }
1023
1024 parentSession, err := c.sessions.Get(ctx, parentSessionID)
1025 if err != nil {
1026 return fmt.Errorf("get parent session: %w", err)
1027 }
1028
1029 parentSession.Cost += childSession.Cost
1030
1031 if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1032 return fmt.Errorf("save parent session: %w", err)
1033 }
1034
1035 return nil
1036}