1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "slices"
16 "strings"
17
18 "charm.land/catwalk/pkg/catwalk"
19 "charm.land/fantasy"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/agent/prompt"
22 "github.com/charmbracelet/crush/internal/agent/tools"
23 "github.com/charmbracelet/crush/internal/config"
24 "github.com/charmbracelet/crush/internal/filetracker"
25 "github.com/charmbracelet/crush/internal/history"
26 "github.com/charmbracelet/crush/internal/log"
27 "github.com/charmbracelet/crush/internal/lsp"
28 "github.com/charmbracelet/crush/internal/message"
29 "github.com/charmbracelet/crush/internal/oauth/copilot"
30 "github.com/charmbracelet/crush/internal/permission"
31 "github.com/charmbracelet/crush/internal/session"
32 "golang.org/x/sync/errgroup"
33
34 "charm.land/fantasy/providers/anthropic"
35 "charm.land/fantasy/providers/azure"
36 "charm.land/fantasy/providers/bedrock"
37 "charm.land/fantasy/providers/google"
38 "charm.land/fantasy/providers/openai"
39 "charm.land/fantasy/providers/openaicompat"
40 "charm.land/fantasy/providers/openrouter"
41 "charm.land/fantasy/providers/vercel"
42 openaisdk "github.com/openai/openai-go/v2/option"
43 "github.com/qjebbs/go-jsons"
44)
45
46// Coordinator errors.
47var (
48 errCoderAgentNotConfigured = errors.New("coder agent not configured")
49 errModelProviderNotConfigured = errors.New("model provider not configured")
50 errLargeModelNotSelected = errors.New("large model not selected")
51 errSmallModelNotSelected = errors.New("small model not selected")
52 errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
53 errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
54 errLargeModelNotFound = errors.New("large model not found in provider config")
55 errSmallModelNotFound = errors.New("small model not found in provider config")
56)
57
58type Coordinator interface {
59 // INFO: (kujtim) this is not used yet we will use this when we have multiple agents
60 // SetMainAgent(string)
61 Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
62 Cancel(sessionID string)
63 CancelAll()
64 IsSessionBusy(sessionID string) bool
65 IsBusy() bool
66 QueuedPrompts(sessionID string) int
67 QueuedPromptsList(sessionID string) []string
68 ClearQueue(sessionID string)
69 Summarize(context.Context, string) error
70 Model() Model
71 UpdateModels(ctx context.Context) error
72}
73
// coordinator is the Coordinator implementation. Today it wraps a single coder
// agent built in NewCoordinator; every Coordinator method delegates to
// currentAgent.
type coordinator struct {
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager

	// currentAgent is the agent all Coordinator calls are forwarded to.
	currentAgent SessionAgent
	// agents holds built agents keyed by agent name; NewCoordinator only
	// registers config.AgentCoder.
	agents map[string]SessionAgent

	// readyWg tracks the background setup started by buildAgent (system
	// prompt and tool construction). Run waits on it before executing.
	readyWg errgroup.Group
}
88
89func NewCoordinator(
90 ctx context.Context,
91 cfg *config.Config,
92 sessions session.Service,
93 messages message.Service,
94 permissions permission.Service,
95 history history.Service,
96 filetracker filetracker.Service,
97 lspManager *lsp.Manager,
98) (Coordinator, error) {
99 c := &coordinator{
100 cfg: cfg,
101 sessions: sessions,
102 messages: messages,
103 permissions: permissions,
104 history: history,
105 filetracker: filetracker,
106 lspManager: lspManager,
107 agents: make(map[string]SessionAgent),
108 }
109
110 agentCfg, ok := cfg.Agents[config.AgentCoder]
111 if !ok {
112 return nil, errCoderAgentNotConfigured
113 }
114
115 // TODO: make this dynamic when we support multiple agents
116 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
117 if err != nil {
118 return nil, err
119 }
120
121 agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
122 if err != nil {
123 return nil, err
124 }
125 c.currentAgent = agent
126 c.agents[config.AgentCoder] = agent
127 return c, nil
128}
129
130// Run implements Coordinator.
131func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
132 if err := c.readyWg.Wait(); err != nil {
133 return nil, err
134 }
135
136 // refresh models before each run
137 if err := c.UpdateModels(ctx); err != nil {
138 return nil, fmt.Errorf("failed to update models: %w", err)
139 }
140
141 model := c.currentAgent.Model()
142 maxTokens := model.CatwalkCfg.DefaultMaxTokens
143 if model.ModelCfg.MaxTokens != 0 {
144 maxTokens = model.ModelCfg.MaxTokens
145 }
146
147 if !model.CatwalkCfg.SupportsImages && attachments != nil {
148 // filter out image attachments
149 filteredAttachments := make([]message.Attachment, 0, len(attachments))
150 for _, att := range attachments {
151 if att.IsText() {
152 filteredAttachments = append(filteredAttachments, att)
153 }
154 }
155 attachments = filteredAttachments
156 }
157
158 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
159 if !ok {
160 return nil, errModelProviderNotConfigured
161 }
162
163 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
164
165 if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
166 slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
167 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
168 return nil, err
169 }
170 }
171
172 run := func() (*fantasy.AgentResult, error) {
173 return c.currentAgent.Run(ctx, SessionAgentCall{
174 SessionID: sessionID,
175 Prompt: prompt,
176 Attachments: attachments,
177 MaxOutputTokens: maxTokens,
178 ProviderOptions: mergedOptions,
179 Temperature: temp,
180 TopP: topP,
181 TopK: topK,
182 FrequencyPenalty: freqPenalty,
183 PresencePenalty: presPenalty,
184 })
185 }
186 result, originalErr := run()
187
188 if c.isUnauthorized(originalErr) {
189 switch {
190 case providerCfg.OAuthToken != nil:
191 slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
192 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
193 return nil, originalErr
194 }
195 slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
196 return run()
197 case strings.Contains(providerCfg.APIKeyTemplate, "$"):
198 slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
199 if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
200 return nil, originalErr
201 }
202 slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
203 return run()
204 }
205 }
206
207 return result, originalErr
208}
209
210func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
211 options := fantasy.ProviderOptions{}
212
213 cfgOpts := []byte("{}")
214 providerCfgOpts := []byte("{}")
215 catwalkOpts := []byte("{}")
216
217 if model.ModelCfg.ProviderOptions != nil {
218 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
219 if err == nil {
220 cfgOpts = data
221 }
222 }
223
224 if providerCfg.ProviderOptions != nil {
225 data, err := json.Marshal(providerCfg.ProviderOptions)
226 if err == nil {
227 providerCfgOpts = data
228 }
229 }
230
231 if model.CatwalkCfg.Options.ProviderOptions != nil {
232 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
233 if err == nil {
234 catwalkOpts = data
235 }
236 }
237
238 readers := []io.Reader{
239 bytes.NewReader(catwalkOpts),
240 bytes.NewReader(providerCfgOpts),
241 bytes.NewReader(cfgOpts),
242 }
243
244 got, err := jsons.Merge(readers)
245 if err != nil {
246 slog.Error("Could not merge call config", "err", err)
247 return options
248 }
249
250 mergedOptions := make(map[string]any)
251
252 err = json.Unmarshal([]byte(got), &mergedOptions)
253 if err != nil {
254 slog.Error("Could not create config for call", "err", err)
255 return options
256 }
257
258 providerType := providerCfg.Type
259 if providerType == "hyper" {
260 if strings.Contains(model.CatwalkCfg.ID, "claude") {
261 providerType = anthropic.Name
262 } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
263 providerType = openai.Name
264 } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
265 providerType = google.Name
266 } else {
267 providerType = openaicompat.Name
268 }
269 }
270
271 switch providerType {
272 case openai.Name, azure.Name:
273 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
274 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
275 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
276 }
277 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
278 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
279 mergedOptions["reasoning_summary"] = "auto"
280 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
281 }
282 parsed, err := openai.ParseResponsesOptions(mergedOptions)
283 if err == nil {
284 options[openai.Name] = parsed
285 }
286 } else {
287 parsed, err := openai.ParseOptions(mergedOptions)
288 if err == nil {
289 options[openai.Name] = parsed
290 }
291 }
292 case anthropic.Name:
293 var (
294 _, hasEffort = mergedOptions["effort"]
295 _, hasThink = mergedOptions["thinking"]
296 )
297 switch {
298 case !hasEffort && model.ModelCfg.ReasoningEffort != "":
299 mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
300 case !hasThink && model.ModelCfg.Think:
301 mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
302 }
303 parsed, err := anthropic.ParseOptions(mergedOptions)
304 if err == nil {
305 options[anthropic.Name] = parsed
306 }
307
308 case openrouter.Name:
309 _, hasReasoning := mergedOptions["reasoning"]
310 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
311 mergedOptions["reasoning"] = map[string]any{
312 "enabled": true,
313 "effort": model.ModelCfg.ReasoningEffort,
314 }
315 }
316 parsed, err := openrouter.ParseOptions(mergedOptions)
317 if err == nil {
318 options[openrouter.Name] = parsed
319 }
320 case vercel.Name:
321 _, hasReasoning := mergedOptions["reasoning"]
322 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
323 mergedOptions["reasoning"] = map[string]any{
324 "enabled": true,
325 "effort": model.ModelCfg.ReasoningEffort,
326 }
327 }
328 parsed, err := vercel.ParseOptions(mergedOptions)
329 if err == nil {
330 options[vercel.Name] = parsed
331 }
332 case google.Name:
333 _, hasReasoning := mergedOptions["thinking_config"]
334 if !hasReasoning {
335 if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
336 mergedOptions["thinking_config"] = map[string]any{
337 "thinking_budget": 2000,
338 "include_thoughts": true,
339 }
340 } else {
341 mergedOptions["thinking_config"] = map[string]any{
342 "thinking_level": model.ModelCfg.ReasoningEffort,
343 "include_thoughts": true,
344 }
345 }
346 }
347 parsed, err := google.ParseOptions(mergedOptions)
348 if err == nil {
349 options[google.Name] = parsed
350 }
351 case openaicompat.Name:
352 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
353 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
354 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
355 }
356 parsed, err := openaicompat.ParseOptions(mergedOptions)
357 if err == nil {
358 options[openaicompat.Name] = parsed
359 }
360 }
361
362 return options
363}
364
365func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
366 modelOptions := getProviderOptions(model, cfg)
367 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
368 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
369 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
370 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
371 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
372 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
373}
374
375func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
376 large, small, err := c.buildAgentModels(ctx, isSubAgent)
377 if err != nil {
378 return nil, err
379 }
380
381 largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
382 result := NewSessionAgent(SessionAgentOptions{
383 large,
384 small,
385 largeProviderCfg.SystemPromptPrefix,
386 "",
387 isSubAgent,
388 c.cfg.Options.DisableAutoSummarize,
389 c.permissions.SkipRequests(),
390 c.sessions,
391 c.messages,
392 nil,
393 })
394
395 c.readyWg.Go(func() error {
396 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
397 if err != nil {
398 return err
399 }
400 result.SetSystemPrompt(systemPrompt)
401 return nil
402 })
403
404 c.readyWg.Go(func() error {
405 tools, err := c.buildTools(ctx, agent)
406 if err != nil {
407 return err
408 }
409 result.SetTools(tools)
410 return nil
411 })
412
413 return result, nil
414}
415
416func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
417 var allTools []fantasy.AgentTool
418 if slices.Contains(agent.AllowedTools, AgentToolName) {
419 agentTool, err := c.agentTool(ctx)
420 if err != nil {
421 return nil, err
422 }
423 allTools = append(allTools, agentTool)
424 }
425
426 if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
427 agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
428 if err != nil {
429 return nil, err
430 }
431 allTools = append(allTools, agenticFetchTool)
432 }
433
434 // Get the model name for the agent
435 modelName := ""
436 if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
437 if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
438 modelName = model.Name
439 }
440 }
441
442 allTools = append(allTools,
443 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
444 tools.NewJobOutputTool(),
445 tools.NewJobKillTool(),
446 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
447 tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
448 tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
449 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
450 tools.NewGlobTool(c.cfg.WorkingDir()),
451 tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Tools.Grep),
452 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
453 tools.NewSourcegraphTool(nil),
454 tools.NewTodosTool(c.sessions),
455 tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
456 tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
457 )
458
459 // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
460 if len(c.cfg.LSP) > 0 || c.cfg.Options.AutoLSP == nil || *c.cfg.Options.AutoLSP {
461 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
462 }
463
464 if len(c.cfg.MCP) > 0 {
465 allTools = append(
466 allTools,
467 tools.NewListMCPResourcesTool(c.cfg, c.permissions),
468 tools.NewReadMCPResourceTool(c.cfg, c.permissions),
469 )
470 }
471
472 var filteredTools []fantasy.AgentTool
473 for _, tool := range allTools {
474 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
475 filteredTools = append(filteredTools, tool)
476 }
477 }
478
479 for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
480 if agent.AllowedMCP == nil {
481 // No MCP restrictions
482 filteredTools = append(filteredTools, tool)
483 continue
484 }
485 if len(agent.AllowedMCP) == 0 {
486 // No MCPs allowed
487 slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
488 break
489 }
490
491 for mcp, tools := range agent.AllowedMCP {
492 if mcp != tool.MCP() {
493 continue
494 }
495 if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
496 filteredTools = append(filteredTools, tool)
497 break
498 }
499 slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
500 }
501 }
502 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
503 return strings.Compare(a.Info().Name, b.Info().Name)
504 })
505 return filteredTools, nil
506}
507
508// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
509func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
510 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
511 if !ok {
512 return Model{}, Model{}, errLargeModelNotSelected
513 }
514 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
515 if !ok {
516 return Model{}, Model{}, errSmallModelNotSelected
517 }
518
519 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
520 if !ok {
521 return Model{}, Model{}, errLargeModelProviderNotConfigured
522 }
523
524 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
525 if err != nil {
526 return Model{}, Model{}, err
527 }
528
529 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
530 if !ok {
531 return Model{}, Model{}, errSmallModelProviderNotConfigured
532 }
533
534 smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
535 if err != nil {
536 return Model{}, Model{}, err
537 }
538
539 var largeCatwalkModel *catwalk.Model
540 var smallCatwalkModel *catwalk.Model
541
542 for _, m := range largeProviderCfg.Models {
543 if m.ID == largeModelCfg.Model {
544 largeCatwalkModel = &m
545 }
546 }
547 for _, m := range smallProviderCfg.Models {
548 if m.ID == smallModelCfg.Model {
549 smallCatwalkModel = &m
550 }
551 }
552
553 if largeCatwalkModel == nil {
554 return Model{}, Model{}, errLargeModelNotFound
555 }
556
557 if smallCatwalkModel == nil {
558 return Model{}, Model{}, errSmallModelNotFound
559 }
560
561 largeModelID := largeModelCfg.Model
562 smallModelID := smallModelCfg.Model
563
564 if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
565 largeModelID += ":exacto"
566 }
567
568 if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
569 smallModelID += ":exacto"
570 }
571
572 largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
573 if err != nil {
574 return Model{}, Model{}, err
575 }
576 smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
577 if err != nil {
578 return Model{}, Model{}, err
579 }
580
581 return Model{
582 Model: largeModel,
583 CatwalkCfg: *largeCatwalkModel,
584 ModelCfg: largeModelCfg,
585 }, Model{
586 Model: smallModel,
587 CatwalkCfg: *smallCatwalkModel,
588 ModelCfg: smallModelCfg,
589 }, nil
590}
591
592func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
593 var opts []anthropic.Option
594
595 switch {
596 case strings.HasPrefix(apiKey, "Bearer "):
597 // NOTE: Prevent the SDK from picking up the API key from env.
598 os.Setenv("ANTHROPIC_API_KEY", "")
599 headers["Authorization"] = apiKey
600 case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
601 // NOTE: Prevent the SDK from picking up the API key from env.
602 os.Setenv("ANTHROPIC_API_KEY", "")
603 headers["Authorization"] = "Bearer " + apiKey
604 case apiKey != "":
605 // X-Api-Key header
606 opts = append(opts, anthropic.WithAPIKey(apiKey))
607 }
608
609 if len(headers) > 0 {
610 opts = append(opts, anthropic.WithHeaders(headers))
611 }
612
613 if baseURL != "" {
614 opts = append(opts, anthropic.WithBaseURL(baseURL))
615 }
616
617 if c.cfg.Options.Debug {
618 httpClient := log.NewHTTPClient()
619 opts = append(opts, anthropic.WithHTTPClient(httpClient))
620 }
621 return anthropic.New(opts...)
622}
623
624func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
625 opts := []openai.Option{
626 openai.WithAPIKey(apiKey),
627 openai.WithUseResponsesAPI(),
628 }
629 if c.cfg.Options.Debug {
630 httpClient := log.NewHTTPClient()
631 opts = append(opts, openai.WithHTTPClient(httpClient))
632 }
633 if len(headers) > 0 {
634 opts = append(opts, openai.WithHeaders(headers))
635 }
636 if baseURL != "" {
637 opts = append(opts, openai.WithBaseURL(baseURL))
638 }
639 return openai.New(opts...)
640}
641
642func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
643 opts := []openrouter.Option{
644 openrouter.WithAPIKey(apiKey),
645 }
646 if c.cfg.Options.Debug {
647 httpClient := log.NewHTTPClient()
648 opts = append(opts, openrouter.WithHTTPClient(httpClient))
649 }
650 if len(headers) > 0 {
651 opts = append(opts, openrouter.WithHeaders(headers))
652 }
653 return openrouter.New(opts...)
654}
655
656func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
657 opts := []vercel.Option{
658 vercel.WithAPIKey(apiKey),
659 }
660 if c.cfg.Options.Debug {
661 httpClient := log.NewHTTPClient()
662 opts = append(opts, vercel.WithHTTPClient(httpClient))
663 }
664 if len(headers) > 0 {
665 opts = append(opts, vercel.WithHeaders(headers))
666 }
667 return vercel.New(opts...)
668}
669
670func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
671 opts := []openaicompat.Option{
672 openaicompat.WithBaseURL(baseURL),
673 openaicompat.WithAPIKey(apiKey),
674 }
675
676 // Set HTTP client based on provider and debug mode.
677 var httpClient *http.Client
678 if providerID == string(catwalk.InferenceProviderCopilot) {
679 opts = append(opts, openaicompat.WithUseResponsesAPI())
680 httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug)
681 } else if c.cfg.Options.Debug {
682 httpClient = log.NewHTTPClient()
683 }
684 if httpClient != nil {
685 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
686 }
687
688 if len(headers) > 0 {
689 opts = append(opts, openaicompat.WithHeaders(headers))
690 }
691
692 for extraKey, extraValue := range extraBody {
693 opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
694 }
695
696 return openaicompat.New(opts...)
697}
698
699func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
700 opts := []azure.Option{
701 azure.WithBaseURL(baseURL),
702 azure.WithAPIKey(apiKey),
703 azure.WithUseResponsesAPI(),
704 }
705 if c.cfg.Options.Debug {
706 httpClient := log.NewHTTPClient()
707 opts = append(opts, azure.WithHTTPClient(httpClient))
708 }
709 if options == nil {
710 options = make(map[string]string)
711 }
712 if apiVersion, ok := options["apiVersion"]; ok {
713 opts = append(opts, azure.WithAPIVersion(apiVersion))
714 }
715 if len(headers) > 0 {
716 opts = append(opts, azure.WithHeaders(headers))
717 }
718
719 return azure.New(opts...)
720}
721
722func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
723 var opts []bedrock.Option
724 if c.cfg.Options.Debug {
725 httpClient := log.NewHTTPClient()
726 opts = append(opts, bedrock.WithHTTPClient(httpClient))
727 }
728 if len(headers) > 0 {
729 opts = append(opts, bedrock.WithHeaders(headers))
730 }
731 bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
732 if bearerToken != "" {
733 opts = append(opts, bedrock.WithAPIKey(bearerToken))
734 }
735 return bedrock.New(opts...)
736}
737
738func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
739 opts := []google.Option{
740 google.WithBaseURL(baseURL),
741 google.WithGeminiAPIKey(apiKey),
742 }
743 if c.cfg.Options.Debug {
744 httpClient := log.NewHTTPClient()
745 opts = append(opts, google.WithHTTPClient(httpClient))
746 }
747 if len(headers) > 0 {
748 opts = append(opts, google.WithHeaders(headers))
749 }
750 return google.New(opts...)
751}
752
753func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
754 opts := []google.Option{}
755 if c.cfg.Options.Debug {
756 httpClient := log.NewHTTPClient()
757 opts = append(opts, google.WithHTTPClient(httpClient))
758 }
759 if len(headers) > 0 {
760 opts = append(opts, google.WithHeaders(headers))
761 }
762
763 project := options["project"]
764 location := options["location"]
765
766 opts = append(opts, google.WithVertex(project, location))
767
768 return google.New(opts...)
769}
770
771func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
772 opts := []hyper.Option{
773 hyper.WithBaseURL(baseURL),
774 hyper.WithAPIKey(apiKey),
775 }
776 if c.cfg.Options.Debug {
777 httpClient := log.NewHTTPClient()
778 opts = append(opts, hyper.WithHTTPClient(httpClient))
779 }
780 return hyper.New(opts...)
781}
782
783func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
784 if model.Think {
785 return true
786 }
787 opts, err := anthropic.ParseOptions(model.ProviderOptions)
788 return err == nil && opts.Thinking != nil
789}
790
791func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
792 headers := maps.Clone(providerCfg.ExtraHeaders)
793 if headers == nil {
794 headers = make(map[string]string)
795 }
796
797 // handle special headers for anthropic
798 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
799 if v, ok := headers["anthropic-beta"]; ok {
800 headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
801 } else {
802 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
803 }
804 }
805
806 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
807 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
808
809 switch providerCfg.Type {
810 case openai.Name:
811 return c.buildOpenaiProvider(baseURL, apiKey, headers)
812 case anthropic.Name:
813 return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
814 case openrouter.Name:
815 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
816 case vercel.Name:
817 return c.buildVercelProvider(baseURL, apiKey, headers)
818 case azure.Name:
819 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
820 case bedrock.Name:
821 return c.buildBedrockProvider(headers)
822 case google.Name:
823 return c.buildGoogleProvider(baseURL, apiKey, headers)
824 case "google-vertex":
825 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
826 case openaicompat.Name:
827 if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
828 if providerCfg.ExtraBody == nil {
829 providerCfg.ExtraBody = map[string]any{}
830 }
831 providerCfg.ExtraBody["tool_stream"] = true
832 }
833 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
834 case hyper.Name:
835 return c.buildHyperProvider(baseURL, apiKey)
836 default:
837 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
838 }
839}
840
// isExactoSupported reports whether modelID is one of the OpenRouter models
// that buildAgentModels requests with the ":exacto" suffix. The match is exact.
func isExactoSupported(modelID string) bool {
	exactoModels := [...]string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}
	return slices.Index(exactoModels[:], modelID) >= 0
}
851
852func (c *coordinator) Cancel(sessionID string) {
853 c.currentAgent.Cancel(sessionID)
854}
855
856func (c *coordinator) CancelAll() {
857 c.currentAgent.CancelAll()
858}
859
860func (c *coordinator) ClearQueue(sessionID string) {
861 c.currentAgent.ClearQueue(sessionID)
862}
863
864func (c *coordinator) IsBusy() bool {
865 return c.currentAgent.IsBusy()
866}
867
868func (c *coordinator) IsSessionBusy(sessionID string) bool {
869 return c.currentAgent.IsSessionBusy(sessionID)
870}
871
872func (c *coordinator) Model() Model {
873 return c.currentAgent.Model()
874}
875
876func (c *coordinator) UpdateModels(ctx context.Context) error {
877 // build the models again so we make sure we get the latest config
878 large, small, err := c.buildAgentModels(ctx, false)
879 if err != nil {
880 return err
881 }
882 c.currentAgent.SetModels(large, small)
883
884 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
885 if !ok {
886 return errCoderAgentNotConfigured
887 }
888
889 tools, err := c.buildTools(ctx, agentCfg)
890 if err != nil {
891 return err
892 }
893 c.currentAgent.SetTools(tools)
894 return nil
895}
896
897func (c *coordinator) QueuedPrompts(sessionID string) int {
898 return c.currentAgent.QueuedPrompts(sessionID)
899}
900
901func (c *coordinator) QueuedPromptsList(sessionID string) []string {
902 return c.currentAgent.QueuedPromptsList(sessionID)
903}
904
905func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
906 providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
907 if !ok {
908 return errModelProviderNotConfigured
909 }
910 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
911}
912
913func (c *coordinator) isUnauthorized(err error) bool {
914 var providerErr *fantasy.ProviderError
915 return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
916}
917
918func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
919 if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
920 slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
921 return err
922 }
923 if err := c.UpdateModels(ctx); err != nil {
924 return err
925 }
926 return nil
927}
928
929func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
930 newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
931 if err != nil {
932 slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
933 return err
934 }
935
936 providerCfg.APIKey = newAPIKey
937 c.cfg.Providers.Set(providerCfg.ID, providerCfg)
938
939 if err := c.UpdateModels(ctx); err != nil {
940 return err
941 }
942 return nil
943}
944
// subAgentParams holds the parameters for running a sub-agent via runSubAgent.
type subAgentParams struct {
	// Agent is the sub-agent that receives Prompt.
	Agent SessionAgent
	// SessionID is the parent session: the sub-session is created under it
	// and the sub-session's cost is added to it.
	SessionID string
	// AgentMessageID and ToolCallID are passed to
	// sessions.CreateAgentToolSessionID to derive the sub-session ID.
	AgentMessageID string
	ToolCallID     string
	// Prompt is the prompt sent to Agent.
	Prompt string
	// SessionTitle is passed to CreateTaskSession for the new sub-session.
	SessionTitle string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
957
958// runSubAgent runs a sub-agent and handles session management and cost accumulation.
959// It creates a sub-session, runs the agent with the given prompt, and propagates
960// the cost to the parent session.
961func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
962 // Create sub-session
963 agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
964 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
965 if err != nil {
966 return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
967 }
968
969 // Call session setup function if provided
970 if params.SessionSetup != nil {
971 params.SessionSetup(session.ID)
972 }
973
974 // Get model configuration
975 model := params.Agent.Model()
976 maxTokens := model.CatwalkCfg.DefaultMaxTokens
977 if model.ModelCfg.MaxTokens != 0 {
978 maxTokens = model.ModelCfg.MaxTokens
979 }
980
981 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
982 if !ok {
983 return fantasy.ToolResponse{}, errModelProviderNotConfigured
984 }
985
986 // Run the agent
987 result, err := params.Agent.Run(ctx, SessionAgentCall{
988 SessionID: session.ID,
989 Prompt: params.Prompt,
990 MaxOutputTokens: maxTokens,
991 ProviderOptions: getProviderOptions(model, providerCfg),
992 Temperature: model.ModelCfg.Temperature,
993 TopP: model.ModelCfg.TopP,
994 TopK: model.ModelCfg.TopK,
995 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
996 PresencePenalty: model.ModelCfg.PresencePenalty,
997 })
998 if err != nil {
999 return fantasy.NewTextErrorResponse("error generating response"), nil
1000 }
1001
1002 // Update parent session cost
1003 if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1004 return fantasy.ToolResponse{}, err
1005 }
1006
1007 return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1008}
1009
1010// updateParentSessionCost accumulates the cost from a child session to its parent session.
1011func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1012 childSession, err := c.sessions.Get(ctx, childSessionID)
1013 if err != nil {
1014 return fmt.Errorf("get child session: %w", err)
1015 }
1016
1017 parentSession, err := c.sessions.Get(ctx, parentSessionID)
1018 if err != nil {
1019 return fmt.Errorf("get parent session: %w", err)
1020 }
1021
1022 parentSession.Cost += childSession.Cost
1023
1024 if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1025 return fmt.Errorf("save parent session: %w", err)
1026 }
1027
1028 return nil
1029}