1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "slices"
16 "strings"
17
18 "charm.land/catwalk/pkg/catwalk"
19 "charm.land/fantasy"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/agent/prompt"
22 "github.com/charmbracelet/crush/internal/agent/tools"
23 "github.com/charmbracelet/crush/internal/config"
24 "github.com/charmbracelet/crush/internal/filetracker"
25 "github.com/charmbracelet/crush/internal/history"
26 "github.com/charmbracelet/crush/internal/log"
27 "github.com/charmbracelet/crush/internal/lsp"
28 "github.com/charmbracelet/crush/internal/message"
29 "github.com/charmbracelet/crush/internal/oauth/copilot"
30 "github.com/charmbracelet/crush/internal/permission"
31 "github.com/charmbracelet/crush/internal/session"
32 "golang.org/x/sync/errgroup"
33
34 "charm.land/fantasy/providers/anthropic"
35 "charm.land/fantasy/providers/azure"
36 "charm.land/fantasy/providers/bedrock"
37 "charm.land/fantasy/providers/google"
38 "charm.land/fantasy/providers/openai"
39 "charm.land/fantasy/providers/openaicompat"
40 "charm.land/fantasy/providers/openrouter"
41 "charm.land/fantasy/providers/vercel"
42 openaisdk "github.com/openai/openai-go/v2/option"
43 "github.com/qjebbs/go-jsons"
44)
45
// Coordinator is what the rest of the application uses to drive the agent:
// running prompts, inspecting and clearing per-session prompt queues,
// cancelling work, summarizing sessions, and rebuilding models from config.
type Coordinator interface {
	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
	// SetMainAgent(string)

	// Run sends prompt (plus any attachments) to the agent for sessionID and
	// returns the agent's result. verbose is forwarded as ShowToolCalls.
	Run(ctx context.Context, sessionID, prompt string, verbose bool, attachments ...message.Attachment) (*fantasy.AgentResult, error)
	// Cancel cancels work for sessionID (delegated to the current agent).
	Cancel(sessionID string)
	// CancelAll cancels work for all sessions (delegated to the current agent).
	CancelAll()
	// IsSessionBusy reports whether the current agent is busy with sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether the current agent is busy.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are queued for sessionID.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the prompts queued for sessionID.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue clears the prompt queue for sessionID.
	ClearQueue(sessionID string)
	// Summarize summarizes the given session ID using the current model.
	Summarize(context.Context, string) error
	// Model returns the current agent's model (the one Run derives its call
	// options from).
	Model() Model
	// UpdateModels rebuilds the models and tools from the latest config.
	UpdateModels(ctx context.Context) error
}
61
// coordinator is the default Coordinator implementation. It holds the
// services its tools and agents need and, for now, a single coder agent that
// is stored both as currentAgent and under config.AgentCoder in agents.
type coordinator struct {
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager

	// currentAgent receives all delegated calls; agents is keyed by agent
	// name (only config.AgentCoder is populated by NewCoordinator).
	currentAgent SessionAgent
	agents       map[string]SessionAgent

	// readyWg tracks the background system-prompt and tool builds started
	// by buildAgent; Run waits on it before executing.
	readyWg errgroup.Group
}
76
77func NewCoordinator(
78 ctx context.Context,
79 cfg *config.Config,
80 sessions session.Service,
81 messages message.Service,
82 permissions permission.Service,
83 history history.Service,
84 filetracker filetracker.Service,
85 lspManager *lsp.Manager,
86) (Coordinator, error) {
87 c := &coordinator{
88 cfg: cfg,
89 sessions: sessions,
90 messages: messages,
91 permissions: permissions,
92 history: history,
93 filetracker: filetracker,
94 lspManager: lspManager,
95 agents: make(map[string]SessionAgent),
96 }
97
98 agentCfg, ok := cfg.Agents[config.AgentCoder]
99 if !ok {
100 return nil, errors.New("coder agent not configured")
101 }
102
103 // TODO: make this dynamic when we support multiple agents
104 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
105 if err != nil {
106 return nil, err
107 }
108
109 agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
110 if err != nil {
111 return nil, err
112 }
113 c.currentAgent = agent
114 c.agents[config.AgentCoder] = agent
115 return c, nil
116}
117
118// Run implements Coordinator.
119func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, verbose bool, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
120 if err := c.readyWg.Wait(); err != nil {
121 return nil, err
122 }
123
124 // refresh models before each run
125 if err := c.UpdateModels(ctx); err != nil {
126 return nil, fmt.Errorf("failed to update models: %w", err)
127 }
128
129 model := c.currentAgent.Model()
130 maxTokens := model.CatwalkCfg.DefaultMaxTokens
131 if model.ModelCfg.MaxTokens != 0 {
132 maxTokens = model.ModelCfg.MaxTokens
133 }
134
135 if !model.CatwalkCfg.SupportsImages && attachments != nil {
136 // filter out image attachments
137 filteredAttachments := make([]message.Attachment, 0, len(attachments))
138 for _, att := range attachments {
139 if att.IsText() {
140 filteredAttachments = append(filteredAttachments, att)
141 }
142 }
143 attachments = filteredAttachments
144 }
145
146 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
147 if !ok {
148 return nil, errors.New("model provider not configured")
149 }
150
151 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
152
153 if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
154 slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
155 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
156 return nil, err
157 }
158 }
159
160 run := func() (*fantasy.AgentResult, error) {
161 return c.currentAgent.Run(ctx, SessionAgentCall{
162 SessionID: sessionID,
163 Prompt: prompt,
164 Attachments: attachments,
165 MaxOutputTokens: maxTokens,
166 ShowToolCalls: verbose,
167 ProviderOptions: mergedOptions,
168 Temperature: temp,
169 TopP: topP,
170 TopK: topK,
171 FrequencyPenalty: freqPenalty,
172 PresencePenalty: presPenalty,
173 })
174 }
175 result, originalErr := run()
176
177 if c.isUnauthorized(originalErr) {
178 switch {
179 case providerCfg.OAuthToken != nil:
180 slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
181 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
182 return nil, originalErr
183 }
184 slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
185 return run()
186 case strings.Contains(providerCfg.APIKeyTemplate, "$"):
187 slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
188 if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
189 return nil, originalErr
190 }
191 slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
192 return run()
193 }
194 }
195
196 return result, originalErr
197}
198
199func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
200 options := fantasy.ProviderOptions{}
201
202 cfgOpts := []byte("{}")
203 providerCfgOpts := []byte("{}")
204 catwalkOpts := []byte("{}")
205
206 if model.ModelCfg.ProviderOptions != nil {
207 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
208 if err == nil {
209 cfgOpts = data
210 }
211 }
212
213 if providerCfg.ProviderOptions != nil {
214 data, err := json.Marshal(providerCfg.ProviderOptions)
215 if err == nil {
216 providerCfgOpts = data
217 }
218 }
219
220 if model.CatwalkCfg.Options.ProviderOptions != nil {
221 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
222 if err == nil {
223 catwalkOpts = data
224 }
225 }
226
227 readers := []io.Reader{
228 bytes.NewReader(catwalkOpts),
229 bytes.NewReader(providerCfgOpts),
230 bytes.NewReader(cfgOpts),
231 }
232
233 got, err := jsons.Merge(readers)
234 if err != nil {
235 slog.Error("Could not merge call config", "err", err)
236 return options
237 }
238
239 mergedOptions := make(map[string]any)
240
241 err = json.Unmarshal([]byte(got), &mergedOptions)
242 if err != nil {
243 slog.Error("Could not create config for call", "err", err)
244 return options
245 }
246
247 providerType := providerCfg.Type
248 if providerType == "hyper" {
249 if strings.Contains(model.CatwalkCfg.ID, "claude") {
250 providerType = anthropic.Name
251 } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
252 providerType = openai.Name
253 } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
254 providerType = google.Name
255 } else {
256 providerType = openaicompat.Name
257 }
258 }
259
260 switch providerType {
261 case openai.Name, azure.Name:
262 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
263 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
264 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
265 }
266 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
267 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
268 mergedOptions["reasoning_summary"] = "auto"
269 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
270 }
271 parsed, err := openai.ParseResponsesOptions(mergedOptions)
272 if err == nil {
273 options[openai.Name] = parsed
274 }
275 } else {
276 parsed, err := openai.ParseOptions(mergedOptions)
277 if err == nil {
278 options[openai.Name] = parsed
279 }
280 }
281 case anthropic.Name:
282 var (
283 _, hasEffort = mergedOptions["effort"]
284 _, hasThink = mergedOptions["thinking"]
285 )
286 switch {
287 case !hasEffort && model.ModelCfg.ReasoningEffort != "":
288 mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
289 case !hasThink && model.ModelCfg.Think:
290 mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
291 }
292 parsed, err := anthropic.ParseOptions(mergedOptions)
293 if err == nil {
294 options[anthropic.Name] = parsed
295 }
296
297 case openrouter.Name:
298 _, hasReasoning := mergedOptions["reasoning"]
299 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
300 mergedOptions["reasoning"] = map[string]any{
301 "enabled": true,
302 "effort": model.ModelCfg.ReasoningEffort,
303 }
304 }
305 parsed, err := openrouter.ParseOptions(mergedOptions)
306 if err == nil {
307 options[openrouter.Name] = parsed
308 }
309 case vercel.Name:
310 _, hasReasoning := mergedOptions["reasoning"]
311 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
312 mergedOptions["reasoning"] = map[string]any{
313 "enabled": true,
314 "effort": model.ModelCfg.ReasoningEffort,
315 }
316 }
317 parsed, err := vercel.ParseOptions(mergedOptions)
318 if err == nil {
319 options[vercel.Name] = parsed
320 }
321 case google.Name:
322 _, hasReasoning := mergedOptions["thinking_config"]
323 if !hasReasoning {
324 if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
325 mergedOptions["thinking_config"] = map[string]any{
326 "thinking_budget": 2000,
327 "include_thoughts": true,
328 }
329 } else {
330 mergedOptions["thinking_config"] = map[string]any{
331 "thinking_level": model.ModelCfg.ReasoningEffort,
332 "include_thoughts": true,
333 }
334 }
335 }
336 parsed, err := google.ParseOptions(mergedOptions)
337 if err == nil {
338 options[google.Name] = parsed
339 }
340 case openaicompat.Name:
341 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
342 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
343 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
344 }
345 parsed, err := openaicompat.ParseOptions(mergedOptions)
346 if err == nil {
347 options[openaicompat.Name] = parsed
348 }
349 }
350
351 return options
352}
353
354func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
355 modelOptions := getProviderOptions(model, cfg)
356 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
357 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
358 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
359 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
360 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
361 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
362}
363
364func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
365 large, small, err := c.buildAgentModels(ctx, isSubAgent)
366 if err != nil {
367 return nil, err
368 }
369
370 largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
371 result := NewSessionAgent(SessionAgentOptions{
372 large,
373 small,
374 largeProviderCfg.SystemPromptPrefix,
375 "",
376 isSubAgent,
377 c.cfg.Options.DisableAutoSummarize,
378 c.permissions.SkipRequests(),
379 c.sessions,
380 c.messages,
381 nil,
382 })
383
384 c.readyWg.Go(func() error {
385 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
386 if err != nil {
387 return err
388 }
389 result.SetSystemPrompt(systemPrompt)
390 return nil
391 })
392
393 c.readyWg.Go(func() error {
394 tools, err := c.buildTools(ctx, agent)
395 if err != nil {
396 return err
397 }
398 result.SetTools(tools)
399 return nil
400 })
401
402 return result, nil
403}
404
405func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
406 var allTools []fantasy.AgentTool
407 if slices.Contains(agent.AllowedTools, AgentToolName) {
408 agentTool, err := c.agentTool(ctx)
409 if err != nil {
410 return nil, err
411 }
412 allTools = append(allTools, agentTool)
413 }
414
415 if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
416 agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
417 if err != nil {
418 return nil, err
419 }
420 allTools = append(allTools, agenticFetchTool)
421 }
422
423 // Get the model name for the agent
424 modelName := ""
425 if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
426 if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
427 modelName = model.Name
428 }
429 }
430
431 allTools = append(allTools,
432 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
433 tools.NewJobOutputTool(),
434 tools.NewJobKillTool(),
435 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
436 tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
437 tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
438 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
439 tools.NewGlobTool(c.cfg.WorkingDir()),
440 tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Tools.Grep),
441 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
442 tools.NewSourcegraphTool(nil),
443 tools.NewTodosTool(c.sessions),
444 tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
445 tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
446 )
447
448 // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
449 if len(c.cfg.LSP) > 0 || c.cfg.Options.AutoLSP == nil || *c.cfg.Options.AutoLSP {
450 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
451 }
452
453 if len(c.cfg.MCP) > 0 {
454 allTools = append(
455 allTools,
456 tools.NewListMCPResourcesTool(c.cfg, c.permissions),
457 tools.NewReadMCPResourceTool(c.cfg, c.permissions),
458 )
459 }
460
461 var filteredTools []fantasy.AgentTool
462 for _, tool := range allTools {
463 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
464 filteredTools = append(filteredTools, tool)
465 }
466 }
467
468 for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
469 if agent.AllowedMCP == nil {
470 // No MCP restrictions
471 filteredTools = append(filteredTools, tool)
472 continue
473 }
474 if len(agent.AllowedMCP) == 0 {
475 // No MCPs allowed
476 slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
477 break
478 }
479
480 for mcp, tools := range agent.AllowedMCP {
481 if mcp != tool.MCP() {
482 continue
483 }
484 if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
485 filteredTools = append(filteredTools, tool)
486 break
487 }
488 slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
489 }
490 }
491 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
492 return strings.Compare(a.Info().Name, b.Info().Name)
493 })
494 return filteredTools, nil
495}
496
497// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
498func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
499 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
500 if !ok {
501 return Model{}, Model{}, errors.New("large model not selected")
502 }
503 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
504 if !ok {
505 return Model{}, Model{}, errors.New("small model not selected")
506 }
507
508 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
509 if !ok {
510 return Model{}, Model{}, errors.New("large model provider not configured")
511 }
512
513 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
514 if err != nil {
515 return Model{}, Model{}, err
516 }
517
518 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
519 if !ok {
520 return Model{}, Model{}, errors.New("small model provider not configured")
521 }
522
523 smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
524 if err != nil {
525 return Model{}, Model{}, err
526 }
527
528 var largeCatwalkModel *catwalk.Model
529 var smallCatwalkModel *catwalk.Model
530
531 for _, m := range largeProviderCfg.Models {
532 if m.ID == largeModelCfg.Model {
533 largeCatwalkModel = &m
534 }
535 }
536 for _, m := range smallProviderCfg.Models {
537 if m.ID == smallModelCfg.Model {
538 smallCatwalkModel = &m
539 }
540 }
541
542 if largeCatwalkModel == nil {
543 return Model{}, Model{}, errors.New("large model not found in provider config")
544 }
545
546 if smallCatwalkModel == nil {
547 return Model{}, Model{}, errors.New("small model not found in provider config")
548 }
549
550 largeModelID := largeModelCfg.Model
551 smallModelID := smallModelCfg.Model
552
553 if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
554 largeModelID += ":exacto"
555 }
556
557 if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
558 smallModelID += ":exacto"
559 }
560
561 largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
562 if err != nil {
563 return Model{}, Model{}, err
564 }
565 smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
566 if err != nil {
567 return Model{}, Model{}, err
568 }
569
570 return Model{
571 Model: largeModel,
572 CatwalkCfg: *largeCatwalkModel,
573 ModelCfg: largeModelCfg,
574 }, Model{
575 Model: smallModel,
576 CatwalkCfg: *smallCatwalkModel,
577 ModelCfg: smallModelCfg,
578 }, nil
579}
580
581func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
582 var opts []anthropic.Option
583
584 switch {
585 case strings.HasPrefix(apiKey, "Bearer "):
586 // NOTE: Prevent the SDK from picking up the API key from env.
587 os.Setenv("ANTHROPIC_API_KEY", "")
588 headers["Authorization"] = apiKey
589 case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
590 // NOTE: Prevent the SDK from picking up the API key from env.
591 os.Setenv("ANTHROPIC_API_KEY", "")
592 headers["Authorization"] = "Bearer " + apiKey
593 case apiKey != "":
594 // X-Api-Key header
595 opts = append(opts, anthropic.WithAPIKey(apiKey))
596 }
597
598 if len(headers) > 0 {
599 opts = append(opts, anthropic.WithHeaders(headers))
600 }
601
602 if baseURL != "" {
603 opts = append(opts, anthropic.WithBaseURL(baseURL))
604 }
605
606 if c.cfg.Options.Debug {
607 httpClient := log.NewHTTPClient()
608 opts = append(opts, anthropic.WithHTTPClient(httpClient))
609 }
610 return anthropic.New(opts...)
611}
612
613func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
614 opts := []openai.Option{
615 openai.WithAPIKey(apiKey),
616 openai.WithUseResponsesAPI(),
617 }
618 if c.cfg.Options.Debug {
619 httpClient := log.NewHTTPClient()
620 opts = append(opts, openai.WithHTTPClient(httpClient))
621 }
622 if len(headers) > 0 {
623 opts = append(opts, openai.WithHeaders(headers))
624 }
625 if baseURL != "" {
626 opts = append(opts, openai.WithBaseURL(baseURL))
627 }
628 return openai.New(opts...)
629}
630
631func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
632 opts := []openrouter.Option{
633 openrouter.WithAPIKey(apiKey),
634 }
635 if c.cfg.Options.Debug {
636 httpClient := log.NewHTTPClient()
637 opts = append(opts, openrouter.WithHTTPClient(httpClient))
638 }
639 if len(headers) > 0 {
640 opts = append(opts, openrouter.WithHeaders(headers))
641 }
642 return openrouter.New(opts...)
643}
644
645func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
646 opts := []vercel.Option{
647 vercel.WithAPIKey(apiKey),
648 }
649 if c.cfg.Options.Debug {
650 httpClient := log.NewHTTPClient()
651 opts = append(opts, vercel.WithHTTPClient(httpClient))
652 }
653 if len(headers) > 0 {
654 opts = append(opts, vercel.WithHeaders(headers))
655 }
656 return vercel.New(opts...)
657}
658
659func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
660 opts := []openaicompat.Option{
661 openaicompat.WithBaseURL(baseURL),
662 openaicompat.WithAPIKey(apiKey),
663 }
664
665 // Set HTTP client based on provider and debug mode.
666 var httpClient *http.Client
667 if providerID == string(catwalk.InferenceProviderCopilot) {
668 opts = append(opts, openaicompat.WithUseResponsesAPI())
669 httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug)
670 } else if c.cfg.Options.Debug {
671 httpClient = log.NewHTTPClient()
672 }
673 if httpClient != nil {
674 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
675 }
676
677 if len(headers) > 0 {
678 opts = append(opts, openaicompat.WithHeaders(headers))
679 }
680
681 for extraKey, extraValue := range extraBody {
682 opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
683 }
684
685 return openaicompat.New(opts...)
686}
687
688func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
689 opts := []azure.Option{
690 azure.WithBaseURL(baseURL),
691 azure.WithAPIKey(apiKey),
692 azure.WithUseResponsesAPI(),
693 }
694 if c.cfg.Options.Debug {
695 httpClient := log.NewHTTPClient()
696 opts = append(opts, azure.WithHTTPClient(httpClient))
697 }
698 if options == nil {
699 options = make(map[string]string)
700 }
701 if apiVersion, ok := options["apiVersion"]; ok {
702 opts = append(opts, azure.WithAPIVersion(apiVersion))
703 }
704 if len(headers) > 0 {
705 opts = append(opts, azure.WithHeaders(headers))
706 }
707
708 return azure.New(opts...)
709}
710
711func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
712 var opts []bedrock.Option
713 if c.cfg.Options.Debug {
714 httpClient := log.NewHTTPClient()
715 opts = append(opts, bedrock.WithHTTPClient(httpClient))
716 }
717 if len(headers) > 0 {
718 opts = append(opts, bedrock.WithHeaders(headers))
719 }
720 bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
721 if bearerToken != "" {
722 opts = append(opts, bedrock.WithAPIKey(bearerToken))
723 }
724 return bedrock.New(opts...)
725}
726
727func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
728 opts := []google.Option{
729 google.WithBaseURL(baseURL),
730 google.WithGeminiAPIKey(apiKey),
731 }
732 if c.cfg.Options.Debug {
733 httpClient := log.NewHTTPClient()
734 opts = append(opts, google.WithHTTPClient(httpClient))
735 }
736 if len(headers) > 0 {
737 opts = append(opts, google.WithHeaders(headers))
738 }
739 return google.New(opts...)
740}
741
742func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
743 opts := []google.Option{}
744 if c.cfg.Options.Debug {
745 httpClient := log.NewHTTPClient()
746 opts = append(opts, google.WithHTTPClient(httpClient))
747 }
748 if len(headers) > 0 {
749 opts = append(opts, google.WithHeaders(headers))
750 }
751
752 project := options["project"]
753 location := options["location"]
754
755 opts = append(opts, google.WithVertex(project, location))
756
757 return google.New(opts...)
758}
759
760func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
761 opts := []hyper.Option{
762 hyper.WithBaseURL(baseURL),
763 hyper.WithAPIKey(apiKey),
764 }
765 if c.cfg.Options.Debug {
766 httpClient := log.NewHTTPClient()
767 opts = append(opts, hyper.WithHTTPClient(httpClient))
768 }
769 return hyper.New(opts...)
770}
771
772func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
773 if model.Think {
774 return true
775 }
776
777 if model.ProviderOptions == nil {
778 return false
779 }
780
781 opts, err := anthropic.ParseOptions(model.ProviderOptions)
782 if err != nil {
783 return false
784 }
785 if opts.Thinking != nil {
786 return true
787 }
788 return false
789}
790
791func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
792 headers := maps.Clone(providerCfg.ExtraHeaders)
793 if headers == nil {
794 headers = make(map[string]string)
795 }
796
797 // handle special headers for anthropic
798 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
799 if v, ok := headers["anthropic-beta"]; ok {
800 headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
801 } else {
802 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
803 }
804 }
805
806 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
807 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
808
809 switch providerCfg.Type {
810 case openai.Name:
811 return c.buildOpenaiProvider(baseURL, apiKey, headers)
812 case anthropic.Name:
813 return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
814 case openrouter.Name:
815 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
816 case vercel.Name:
817 return c.buildVercelProvider(baseURL, apiKey, headers)
818 case azure.Name:
819 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
820 case bedrock.Name:
821 return c.buildBedrockProvider(headers)
822 case google.Name:
823 return c.buildGoogleProvider(baseURL, apiKey, headers)
824 case "google-vertex":
825 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
826 case openaicompat.Name:
827 if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
828 if providerCfg.ExtraBody == nil {
829 providerCfg.ExtraBody = map[string]any{}
830 }
831 providerCfg.ExtraBody["tool_stream"] = true
832 }
833 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
834 case hyper.Name:
835 return c.buildHyperProvider(baseURL, apiKey)
836 default:
837 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
838 }
839}
840
// isExactoSupported reports whether modelID is one of the model IDs that get
// the ":exacto" suffix when served through OpenRouter. The ID must match
// exactly; an already-suffixed ID does not match.
func isExactoSupported(modelID string) bool {
	exacto := []string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}
	return slices.Index(exacto, modelID) >= 0
}
851
// Cancel delegates to the current agent's Cancel for sessionID.
func (c *coordinator) Cancel(sessionID string) {
	c.currentAgent.Cancel(sessionID)
}

// CancelAll delegates to the current agent's CancelAll.
func (c *coordinator) CancelAll() {
	c.currentAgent.CancelAll()
}

// ClearQueue delegates to the current agent's ClearQueue for sessionID.
func (c *coordinator) ClearQueue(sessionID string) {
	c.currentAgent.ClearQueue(sessionID)
}

// IsBusy reports the current agent's IsBusy.
func (c *coordinator) IsBusy() bool {
	return c.currentAgent.IsBusy()
}

// IsSessionBusy reports the current agent's IsSessionBusy for sessionID.
func (c *coordinator) IsSessionBusy(sessionID string) bool {
	return c.currentAgent.IsSessionBusy(sessionID)
}

// Model returns the current agent's model.
func (c *coordinator) Model() Model {
	return c.currentAgent.Model()
}
875
876func (c *coordinator) UpdateModels(ctx context.Context) error {
877 // build the models again so we make sure we get the latest config
878 large, small, err := c.buildAgentModels(ctx, false)
879 if err != nil {
880 return err
881 }
882 c.currentAgent.SetModels(large, small)
883
884 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
885 if !ok {
886 return errors.New("coder agent not configured")
887 }
888
889 tools, err := c.buildTools(ctx, agentCfg)
890 if err != nil {
891 return err
892 }
893 c.currentAgent.SetTools(tools)
894 return nil
895}
896
// QueuedPrompts returns the current agent's queued-prompt count for sessionID.
func (c *coordinator) QueuedPrompts(sessionID string) int {
	return c.currentAgent.QueuedPrompts(sessionID)
}

// QueuedPromptsList returns the current agent's queued prompts for sessionID.
func (c *coordinator) QueuedPromptsList(sessionID string) []string {
	return c.currentAgent.QueuedPromptsList(sessionID)
}
904
905func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
906 providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
907 if !ok {
908 return errors.New("model provider not configured")
909 }
910 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
911}
912
913func (c *coordinator) isUnauthorized(err error) bool {
914 var providerErr *fantasy.ProviderError
915 return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
916}
917
918func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
919 if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
920 slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
921 return err
922 }
923 if err := c.UpdateModels(ctx); err != nil {
924 return err
925 }
926 return nil
927}
928
929func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
930 newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
931 if err != nil {
932 slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
933 return err
934 }
935
936 providerCfg.APIKey = newAPIKey
937 c.cfg.Providers.Set(providerCfg.ID, providerCfg)
938
939 if err := c.UpdateModels(ctx); err != nil {
940 return err
941 }
942 return nil
943}
944
// subAgentParams holds the parameters for running a sub-agent.
type subAgentParams struct {
	// Agent is the sub-agent to run.
	Agent SessionAgent
	// SessionID is the parent session: the sub-session is created under it
	// and the sub-session's cost is added to it afterwards.
	SessionID string
	// AgentMessageID and ToolCallID identify the spawning tool call; together
	// they derive the sub-session ID.
	AgentMessageID string
	ToolCallID     string
	// Prompt is passed as the prompt of the sub-agent's Run call.
	Prompt string
	// SessionTitle is the title of the created sub-session.
	SessionTitle string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
957
958// runSubAgent runs a sub-agent and handles session management and cost accumulation.
959// It creates a sub-session, runs the agent with the given prompt, and propagates
960// the cost to the parent session.
961func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
962 // Create sub-session
963 agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
964 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
965 if err != nil {
966 return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
967 }
968
969 // Call session setup function if provided
970 if params.SessionSetup != nil {
971 params.SessionSetup(session.ID)
972 }
973
974 // Get model configuration
975 model := params.Agent.Model()
976 maxTokens := model.CatwalkCfg.DefaultMaxTokens
977 if model.ModelCfg.MaxTokens != 0 {
978 maxTokens = model.ModelCfg.MaxTokens
979 }
980
981 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
982 if !ok {
983 return fantasy.ToolResponse{}, errors.New("model provider not configured")
984 }
985
986 // Run the agent
987 result, err := params.Agent.Run(ctx, SessionAgentCall{
988 SessionID: session.ID,
989 Prompt: params.Prompt,
990 MaxOutputTokens: maxTokens,
991 ProviderOptions: getProviderOptions(model, providerCfg),
992 Temperature: model.ModelCfg.Temperature,
993 TopP: model.ModelCfg.TopP,
994 TopK: model.ModelCfg.TopK,
995 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
996 PresencePenalty: model.ModelCfg.PresencePenalty,
997 })
998 if err != nil {
999 return fantasy.NewTextErrorResponse("error generating response"), nil
1000 }
1001
1002 // Update parent session cost
1003 if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1004 return fantasy.ToolResponse{}, err
1005 }
1006
1007 return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1008}
1009
1010// updateParentSessionCost accumulates the cost from a child session to its parent session.
1011func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1012 childSession, err := c.sessions.Get(ctx, childSessionID)
1013 if err != nil {
1014 return fmt.Errorf("get child session: %w", err)
1015 }
1016
1017 parentSession, err := c.sessions.Get(ctx, parentSessionID)
1018 if err != nil {
1019 return fmt.Errorf("get parent session: %w", err)
1020 }
1021
1022 parentSession.Cost += childSession.Cost
1023
1024 if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1025 return fmt.Errorf("save parent session: %w", err)
1026 }
1027
1028 return nil
1029}