1package agent
2
3import (
4 "bytes"
5 "cmp"
6 "context"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "maps"
13 "net/http"
14 "os"
15 "slices"
16 "strings"
17
18 "charm.land/catwalk/pkg/catwalk"
19 "charm.land/fantasy"
20 "github.com/charmbracelet/crush/internal/agent/hyper"
21 "github.com/charmbracelet/crush/internal/agent/notify"
22 "github.com/charmbracelet/crush/internal/agent/prompt"
23 "github.com/charmbracelet/crush/internal/agent/tools"
24 "github.com/charmbracelet/crush/internal/config"
25 "github.com/charmbracelet/crush/internal/filetracker"
26 "github.com/charmbracelet/crush/internal/history"
27 "github.com/charmbracelet/crush/internal/log"
28 "github.com/charmbracelet/crush/internal/lsp"
29 "github.com/charmbracelet/crush/internal/message"
30 "github.com/charmbracelet/crush/internal/oauth/copilot"
31 "github.com/charmbracelet/crush/internal/permission"
32 "github.com/charmbracelet/crush/internal/pubsub"
33 "github.com/charmbracelet/crush/internal/session"
34 "golang.org/x/sync/errgroup"
35
36 "charm.land/fantasy/providers/anthropic"
37 "charm.land/fantasy/providers/azure"
38 "charm.land/fantasy/providers/bedrock"
39 "charm.land/fantasy/providers/google"
40 "charm.land/fantasy/providers/openai"
41 "charm.land/fantasy/providers/openaicompat"
42 "charm.land/fantasy/providers/openrouter"
43 "charm.land/fantasy/providers/vercel"
44 openaisdk "github.com/openai/openai-go/v2/option"
45 "github.com/qjebbs/go-jsons"
46)
47
// Coordinator errors.
//
// Sentinel errors returned by the coordinator when a required agent, model,
// or provider entry is missing from the configuration.
var (
	// errCoderAgentNotConfigured: the config has no entry for the coder agent.
	errCoderAgentNotConfigured = errors.New("coder agent not configured")
	// errModelProviderNotConfigured: the provider of the agent's current model
	// is not present in the provider config.
	errModelProviderNotConfigured = errors.New("model provider not configured")
	// errLargeModelNotSelected / errSmallModelNotSelected: no model is selected
	// for the large / small slot.
	errLargeModelNotSelected = errors.New("large model not selected")
	errSmallModelNotSelected = errors.New("small model not selected")
	// errLargeModelProviderNotConfigured / errSmallModelProviderNotConfigured:
	// the selected model references a provider that is not configured.
	errLargeModelProviderNotConfigured = errors.New("large model provider not configured")
	errSmallModelProviderNotConfigured = errors.New("small model provider not configured")
	// errLargeModelNotFound / errSmallModelNotFound: the selected model ID is
	// not listed among the provider's models.
	errLargeModelNotFound = errors.New("large model not found in provider config")
	errSmallModelNotFound = errors.New("small model not found in provider config")
)
59
// Coordinator is the entry point used to drive agent sessions: it runs
// prompts, exposes busy/queue state, and rebuilds models and tools.
//
// The method notes below describe the coordinator implementation in this
// file, which forwards most calls to its current agent.
type Coordinator interface {
	// INFO: (kujtim) this is not used yet we will use this when we have multiple agents
	// SetMainAgent(string)

	// Run sends prompt (plus optional attachments) to the agent for sessionID.
	Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error)
	// Cancel forwards a cancel request for sessionID to the current agent.
	Cancel(sessionID string)
	// CancelAll forwards a cancel-all request to the current agent.
	CancelAll()
	// IsSessionBusy reports the current agent's busy state for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports the current agent's overall busy state.
	IsBusy() bool
	// QueuedPrompts returns the current agent's queued prompt count for sessionID.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the current agent's queued prompts for sessionID.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue asks the current agent to clear the queue for sessionID.
	ClearQueue(sessionID string)
	// Summarize asks the current agent to summarize the session with the given ID.
	Summarize(context.Context, string) error
	// Model returns the current agent's model.
	Model() Model
	// UpdateModels rebuilds the models and tools from the current config.
	UpdateModels(ctx context.Context) error
	// RefreshTools rebuilds only the tool set from the current config.
	RefreshTools(ctx context.Context) error
}
76
// coordinator is the Coordinator implementation in this file. It holds the
// services handed to NewCoordinator and forwards most calls to currentAgent.
type coordinator struct {
	cfg         *config.Config
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	history     history.Service
	filetracker filetracker.Service
	lspManager  *lsp.Manager
	notify      pubsub.Publisher[notify.Notification]

	// currentAgent is the agent every Coordinator method delegates to.
	currentAgent SessionAgent
	// agents holds the built agents keyed by agent ID (only the coder agent
	// is registered in this file).
	agents map[string]SessionAgent

	// readyWg tracks the asynchronous system-prompt and tool setup started by
	// buildAgent; Run waits on it before doing any work.
	readyWg errgroup.Group
}
92
93func NewCoordinator(
94 ctx context.Context,
95 cfg *config.Config,
96 sessions session.Service,
97 messages message.Service,
98 permissions permission.Service,
99 history history.Service,
100 filetracker filetracker.Service,
101 lspManager *lsp.Manager,
102 notify pubsub.Publisher[notify.Notification],
103) (Coordinator, error) {
104 c := &coordinator{
105 cfg: cfg,
106 sessions: sessions,
107 messages: messages,
108 permissions: permissions,
109 history: history,
110 filetracker: filetracker,
111 lspManager: lspManager,
112 notify: notify,
113 agents: make(map[string]SessionAgent),
114 }
115
116 agentCfg, ok := cfg.Agents[config.AgentCoder]
117 if !ok {
118 return nil, errCoderAgentNotConfigured
119 }
120
121 // TODO: make this dynamic when we support multiple agents
122 prompt, err := coderPrompt(prompt.WithWorkingDir(c.cfg.WorkingDir()))
123 if err != nil {
124 return nil, err
125 }
126
127 agent, err := c.buildAgent(ctx, prompt, agentCfg, false)
128 if err != nil {
129 return nil, err
130 }
131 c.currentAgent = agent
132 c.agents[config.AgentCoder] = agent
133 return c, nil
134}
135
136// Run implements Coordinator.
137func (c *coordinator) Run(ctx context.Context, sessionID string, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
138 if err := c.readyWg.Wait(); err != nil {
139 return nil, err
140 }
141
142 // refresh models before each run
143 if err := c.UpdateModels(ctx); err != nil {
144 return nil, fmt.Errorf("failed to update models: %w", err)
145 }
146
147 model := c.currentAgent.Model()
148 maxTokens := model.CatwalkCfg.DefaultMaxTokens
149 if model.ModelCfg.MaxTokens != 0 {
150 maxTokens = model.ModelCfg.MaxTokens
151 }
152
153 if !model.CatwalkCfg.SupportsImages && attachments != nil {
154 // filter out image attachments
155 filteredAttachments := make([]message.Attachment, 0, len(attachments))
156 for _, att := range attachments {
157 if att.IsText() {
158 filteredAttachments = append(filteredAttachments, att)
159 }
160 }
161 attachments = filteredAttachments
162 }
163
164 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
165 if !ok {
166 return nil, errModelProviderNotConfigured
167 }
168
169 mergedOptions, temp, topP, topK, freqPenalty, presPenalty := mergeCallOptions(model, providerCfg)
170
171 if providerCfg.OAuthToken != nil && providerCfg.OAuthToken.IsExpired() {
172 slog.Debug("Token needs to be refreshed", "provider", providerCfg.ID)
173 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
174 return nil, err
175 }
176 }
177
178 run := func() (*fantasy.AgentResult, error) {
179 return c.currentAgent.Run(ctx, SessionAgentCall{
180 SessionID: sessionID,
181 Prompt: prompt,
182 Attachments: attachments,
183 MaxOutputTokens: maxTokens,
184 ProviderOptions: mergedOptions,
185 Temperature: temp,
186 TopP: topP,
187 TopK: topK,
188 FrequencyPenalty: freqPenalty,
189 PresencePenalty: presPenalty,
190 })
191 }
192 result, originalErr := run()
193
194 if c.isUnauthorized(originalErr) {
195 switch {
196 case providerCfg.OAuthToken != nil:
197 slog.Debug("Received 401. Refreshing token and retrying", "provider", providerCfg.ID)
198 if err := c.refreshOAuth2Token(ctx, providerCfg); err != nil {
199 return nil, originalErr
200 }
201 slog.Debug("Retrying request with refreshed OAuth token", "provider", providerCfg.ID)
202 return run()
203 case strings.Contains(providerCfg.APIKeyTemplate, "$"):
204 slog.Debug("Received 401. Refreshing API Key template and retrying", "provider", providerCfg.ID)
205 if err := c.refreshApiKeyTemplate(ctx, providerCfg); err != nil {
206 return nil, originalErr
207 }
208 slog.Debug("Retrying request with refreshed API key", "provider", providerCfg.ID)
209 return run()
210 }
211 }
212
213 return result, originalErr
214}
215
216func getProviderOptions(model Model, providerCfg config.ProviderConfig) fantasy.ProviderOptions {
217 options := fantasy.ProviderOptions{}
218
219 cfgOpts := []byte("{}")
220 providerCfgOpts := []byte("{}")
221 catwalkOpts := []byte("{}")
222
223 if model.ModelCfg.ProviderOptions != nil {
224 data, err := json.Marshal(model.ModelCfg.ProviderOptions)
225 if err == nil {
226 cfgOpts = data
227 }
228 }
229
230 if providerCfg.ProviderOptions != nil {
231 data, err := json.Marshal(providerCfg.ProviderOptions)
232 if err == nil {
233 providerCfgOpts = data
234 }
235 }
236
237 if model.CatwalkCfg.Options.ProviderOptions != nil {
238 data, err := json.Marshal(model.CatwalkCfg.Options.ProviderOptions)
239 if err == nil {
240 catwalkOpts = data
241 }
242 }
243
244 readers := []io.Reader{
245 bytes.NewReader(catwalkOpts),
246 bytes.NewReader(providerCfgOpts),
247 bytes.NewReader(cfgOpts),
248 }
249
250 got, err := jsons.Merge(readers)
251 if err != nil {
252 slog.Error("Could not merge call config", "err", err)
253 return options
254 }
255
256 mergedOptions := make(map[string]any)
257
258 err = json.Unmarshal([]byte(got), &mergedOptions)
259 if err != nil {
260 slog.Error("Could not create config for call", "err", err)
261 return options
262 }
263
264 providerType := providerCfg.Type
265 if providerType == "hyper" {
266 if strings.Contains(model.CatwalkCfg.ID, "claude") {
267 providerType = anthropic.Name
268 } else if strings.Contains(model.CatwalkCfg.ID, "gpt") {
269 providerType = openai.Name
270 } else if strings.Contains(model.CatwalkCfg.ID, "gemini") {
271 providerType = google.Name
272 } else {
273 providerType = openaicompat.Name
274 }
275 }
276
277 switch providerType {
278 case openai.Name, azure.Name:
279 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
280 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
281 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
282 }
283 if openai.IsResponsesModel(model.CatwalkCfg.ID) {
284 if openai.IsResponsesReasoningModel(model.CatwalkCfg.ID) {
285 mergedOptions["reasoning_summary"] = "auto"
286 mergedOptions["include"] = []openai.IncludeType{openai.IncludeReasoningEncryptedContent}
287 }
288 parsed, err := openai.ParseResponsesOptions(mergedOptions)
289 if err == nil {
290 options[openai.Name] = parsed
291 }
292 } else {
293 parsed, err := openai.ParseOptions(mergedOptions)
294 if err == nil {
295 options[openai.Name] = parsed
296 }
297 }
298 case anthropic.Name:
299 var (
300 _, hasEffort = mergedOptions["effort"]
301 _, hasThink = mergedOptions["thinking"]
302 )
303 switch {
304 case !hasEffort && model.ModelCfg.ReasoningEffort != "":
305 mergedOptions["effort"] = model.ModelCfg.ReasoningEffort
306 case !hasThink && model.ModelCfg.Think:
307 mergedOptions["thinking"] = map[string]any{"budget_tokens": 2000}
308 }
309 parsed, err := anthropic.ParseOptions(mergedOptions)
310 if err == nil {
311 options[anthropic.Name] = parsed
312 }
313
314 case openrouter.Name:
315 _, hasReasoning := mergedOptions["reasoning"]
316 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
317 mergedOptions["reasoning"] = map[string]any{
318 "enabled": true,
319 "effort": model.ModelCfg.ReasoningEffort,
320 }
321 }
322 parsed, err := openrouter.ParseOptions(mergedOptions)
323 if err == nil {
324 options[openrouter.Name] = parsed
325 }
326 case vercel.Name:
327 _, hasReasoning := mergedOptions["reasoning"]
328 if !hasReasoning && model.ModelCfg.ReasoningEffort != "" {
329 mergedOptions["reasoning"] = map[string]any{
330 "enabled": true,
331 "effort": model.ModelCfg.ReasoningEffort,
332 }
333 }
334 parsed, err := vercel.ParseOptions(mergedOptions)
335 if err == nil {
336 options[vercel.Name] = parsed
337 }
338 case google.Name:
339 _, hasReasoning := mergedOptions["thinking_config"]
340 if !hasReasoning {
341 if strings.HasPrefix(model.CatwalkCfg.ID, "gemini-2") {
342 mergedOptions["thinking_config"] = map[string]any{
343 "thinking_budget": 2000,
344 "include_thoughts": true,
345 }
346 } else {
347 mergedOptions["thinking_config"] = map[string]any{
348 "thinking_level": model.ModelCfg.ReasoningEffort,
349 "include_thoughts": true,
350 }
351 }
352 }
353 parsed, err := google.ParseOptions(mergedOptions)
354 if err == nil {
355 options[google.Name] = parsed
356 }
357 case openaicompat.Name:
358 _, hasReasoningEffort := mergedOptions["reasoning_effort"]
359 if !hasReasoningEffort && model.ModelCfg.ReasoningEffort != "" {
360 mergedOptions["reasoning_effort"] = model.ModelCfg.ReasoningEffort
361 }
362 parsed, err := openaicompat.ParseOptions(mergedOptions)
363 if err == nil {
364 options[openaicompat.Name] = parsed
365 }
366 }
367
368 return options
369}
370
371func mergeCallOptions(model Model, cfg config.ProviderConfig) (fantasy.ProviderOptions, *float64, *float64, *int64, *float64, *float64) {
372 modelOptions := getProviderOptions(model, cfg)
373 temp := cmp.Or(model.ModelCfg.Temperature, model.CatwalkCfg.Options.Temperature)
374 topP := cmp.Or(model.ModelCfg.TopP, model.CatwalkCfg.Options.TopP)
375 topK := cmp.Or(model.ModelCfg.TopK, model.CatwalkCfg.Options.TopK)
376 freqPenalty := cmp.Or(model.ModelCfg.FrequencyPenalty, model.CatwalkCfg.Options.FrequencyPenalty)
377 presPenalty := cmp.Or(model.ModelCfg.PresencePenalty, model.CatwalkCfg.Options.PresencePenalty)
378 return modelOptions, temp, topP, topK, freqPenalty, presPenalty
379}
380
381func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent, isSubAgent bool) (SessionAgent, error) {
382 large, small, err := c.buildAgentModels(ctx, isSubAgent)
383 if err != nil {
384 return nil, err
385 }
386
387 largeProviderCfg, _ := c.cfg.Providers.Get(large.ModelCfg.Provider)
388 result := NewSessionAgent(SessionAgentOptions{
389 LargeModel: large,
390 SmallModel: small,
391 SystemPromptPrefix: largeProviderCfg.SystemPromptPrefix,
392 SystemPrompt: "",
393 IsSubAgent: isSubAgent,
394 DisableAutoSummarize: c.cfg.Options.DisableAutoSummarize,
395 IsYolo: c.permissions.SkipRequests(),
396 Sessions: c.sessions,
397 Messages: c.messages,
398 Tools: nil,
399 Notify: c.notify,
400 })
401
402 c.readyWg.Go(func() error {
403 systemPrompt, err := prompt.Build(ctx, large.Model.Provider(), large.Model.Model(), *c.cfg)
404 if err != nil {
405 return err
406 }
407 result.SetSystemPrompt(systemPrompt)
408 return nil
409 })
410
411 c.readyWg.Go(func() error {
412 tools, err := c.buildTools(ctx, agent)
413 if err != nil {
414 return err
415 }
416 result.SetTools(tools)
417 return nil
418 })
419
420 return result, nil
421}
422
423func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
424 var allTools []fantasy.AgentTool
425 if slices.Contains(agent.AllowedTools, AgentToolName) {
426 agentTool, err := c.agentTool(ctx)
427 if err != nil {
428 return nil, err
429 }
430 allTools = append(allTools, agentTool)
431 }
432
433 if slices.Contains(agent.AllowedTools, tools.AgenticFetchToolName) {
434 agenticFetchTool, err := c.agenticFetchTool(ctx, nil)
435 if err != nil {
436 return nil, err
437 }
438 allTools = append(allTools, agenticFetchTool)
439 }
440
441 // Get the model name for the agent
442 modelName := ""
443 if modelCfg, ok := c.cfg.Models[agent.Model]; ok {
444 if model := c.cfg.GetModel(modelCfg.Provider, modelCfg.Model); model != nil {
445 modelName = model.Name
446 }
447 }
448
449 allTools = append(allTools,
450 tools.NewBashTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Options.Attribution, modelName),
451 tools.NewJobOutputTool(),
452 tools.NewJobKillTool(),
453 tools.NewDownloadTool(c.permissions, c.cfg.WorkingDir(), nil),
454 tools.NewEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
455 tools.NewMultiEditTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
456 tools.NewFetchTool(c.permissions, c.cfg.WorkingDir(), nil),
457 tools.NewGlobTool(c.cfg.WorkingDir()),
458 tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Tools.Grep),
459 tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Tools.Ls),
460 tools.NewSourcegraphTool(nil),
461 tools.NewTodosTool(c.sessions),
462 tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.cfg.WorkingDir(), c.cfg.Options.SkillsPaths...),
463 tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
464 )
465
466 // Add LSP tools if user has configured LSPs or auto_lsp is enabled (nil or true).
467 if len(c.cfg.LSP) > 0 || c.cfg.Options.AutoLSP == nil || *c.cfg.Options.AutoLSP {
468 allTools = append(allTools, tools.NewDiagnosticsTool(c.lspManager), tools.NewReferencesTool(c.lspManager), tools.NewLSPRestartTool(c.lspManager))
469 }
470
471 if len(c.cfg.MCP) > 0 {
472 allTools = append(
473 allTools,
474 tools.NewListMCPResourcesTool(c.cfg, c.permissions),
475 tools.NewReadMCPResourceTool(c.cfg, c.permissions),
476 )
477 }
478
479 var filteredTools []fantasy.AgentTool
480 for _, tool := range allTools {
481 if slices.Contains(agent.AllowedTools, tool.Info().Name) {
482 filteredTools = append(filteredTools, tool)
483 }
484 }
485
486 for _, tool := range tools.GetMCPTools(c.permissions, c.cfg, c.cfg.WorkingDir()) {
487 if agent.AllowedMCP == nil {
488 // No MCP restrictions
489 filteredTools = append(filteredTools, tool)
490 continue
491 }
492 if len(agent.AllowedMCP) == 0 {
493 // No MCPs allowed
494 slog.Debug("No MCPs allowed", "tool", tool.Name(), "agent", agent.Name)
495 break
496 }
497
498 for mcp, tools := range agent.AllowedMCP {
499 if mcp != tool.MCP() {
500 continue
501 }
502 if len(tools) == 0 || slices.Contains(tools, tool.MCPToolName()) {
503 filteredTools = append(filteredTools, tool)
504 break
505 }
506 slog.Debug("MCP not allowed", "tool", tool.Name(), "agent", agent.Name)
507 }
508 }
509 slices.SortFunc(filteredTools, func(a, b fantasy.AgentTool) int {
510 return strings.Compare(a.Info().Name, b.Info().Name)
511 })
512 return filteredTools, nil
513}
514
515// TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
516func (c *coordinator) buildAgentModels(ctx context.Context, isSubAgent bool) (Model, Model, error) {
517 largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
518 if !ok {
519 return Model{}, Model{}, errLargeModelNotSelected
520 }
521 smallModelCfg, ok := c.cfg.Models[config.SelectedModelTypeSmall]
522 if !ok {
523 return Model{}, Model{}, errSmallModelNotSelected
524 }
525
526 largeProviderCfg, ok := c.cfg.Providers.Get(largeModelCfg.Provider)
527 if !ok {
528 return Model{}, Model{}, errLargeModelProviderNotConfigured
529 }
530
531 largeProvider, err := c.buildProvider(largeProviderCfg, largeModelCfg, isSubAgent)
532 if err != nil {
533 return Model{}, Model{}, err
534 }
535
536 smallProviderCfg, ok := c.cfg.Providers.Get(smallModelCfg.Provider)
537 if !ok {
538 return Model{}, Model{}, errSmallModelProviderNotConfigured
539 }
540
541 smallProvider, err := c.buildProvider(smallProviderCfg, smallModelCfg, true)
542 if err != nil {
543 return Model{}, Model{}, err
544 }
545
546 var largeCatwalkModel *catwalk.Model
547 var smallCatwalkModel *catwalk.Model
548
549 for _, m := range largeProviderCfg.Models {
550 if m.ID == largeModelCfg.Model {
551 largeCatwalkModel = &m
552 }
553 }
554 for _, m := range smallProviderCfg.Models {
555 if m.ID == smallModelCfg.Model {
556 smallCatwalkModel = &m
557 }
558 }
559
560 if largeCatwalkModel == nil {
561 return Model{}, Model{}, errLargeModelNotFound
562 }
563
564 if smallCatwalkModel == nil {
565 return Model{}, Model{}, errSmallModelNotFound
566 }
567
568 largeModelID := largeModelCfg.Model
569 smallModelID := smallModelCfg.Model
570
571 if largeModelCfg.Provider == openrouter.Name && isExactoSupported(largeModelID) {
572 largeModelID += ":exacto"
573 }
574
575 if smallModelCfg.Provider == openrouter.Name && isExactoSupported(smallModelID) {
576 smallModelID += ":exacto"
577 }
578
579 largeModel, err := largeProvider.LanguageModel(ctx, largeModelID)
580 if err != nil {
581 return Model{}, Model{}, err
582 }
583 smallModel, err := smallProvider.LanguageModel(ctx, smallModelID)
584 if err != nil {
585 return Model{}, Model{}, err
586 }
587
588 return Model{
589 Model: largeModel,
590 CatwalkCfg: *largeCatwalkModel,
591 ModelCfg: largeModelCfg,
592 }, Model{
593 Model: smallModel,
594 CatwalkCfg: *smallCatwalkModel,
595 ModelCfg: smallModelCfg,
596 }, nil
597}
598
599func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string, providerID string) (fantasy.Provider, error) {
600 var opts []anthropic.Option
601
602 switch {
603 case strings.HasPrefix(apiKey, "Bearer "):
604 // NOTE: Prevent the SDK from picking up the API key from env.
605 os.Setenv("ANTHROPIC_API_KEY", "")
606 headers["Authorization"] = apiKey
607 case providerID == string(catwalk.InferenceProviderMiniMax) || providerID == string(catwalk.InferenceProviderMiniMaxChina):
608 // NOTE: Prevent the SDK from picking up the API key from env.
609 os.Setenv("ANTHROPIC_API_KEY", "")
610 headers["Authorization"] = "Bearer " + apiKey
611 case apiKey != "":
612 // X-Api-Key header
613 opts = append(opts, anthropic.WithAPIKey(apiKey))
614 }
615
616 if len(headers) > 0 {
617 opts = append(opts, anthropic.WithHeaders(headers))
618 }
619
620 if baseURL != "" {
621 opts = append(opts, anthropic.WithBaseURL(baseURL))
622 }
623
624 if c.cfg.Options.Debug {
625 httpClient := log.NewHTTPClient()
626 opts = append(opts, anthropic.WithHTTPClient(httpClient))
627 }
628 return anthropic.New(opts...)
629}
630
631func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
632 opts := []openai.Option{
633 openai.WithAPIKey(apiKey),
634 openai.WithUseResponsesAPI(),
635 }
636 if c.cfg.Options.Debug {
637 httpClient := log.NewHTTPClient()
638 opts = append(opts, openai.WithHTTPClient(httpClient))
639 }
640 if len(headers) > 0 {
641 opts = append(opts, openai.WithHeaders(headers))
642 }
643 if baseURL != "" {
644 opts = append(opts, openai.WithBaseURL(baseURL))
645 }
646 return openai.New(opts...)
647}
648
649func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
650 opts := []openrouter.Option{
651 openrouter.WithAPIKey(apiKey),
652 }
653 if c.cfg.Options.Debug {
654 httpClient := log.NewHTTPClient()
655 opts = append(opts, openrouter.WithHTTPClient(httpClient))
656 }
657 if len(headers) > 0 {
658 opts = append(opts, openrouter.WithHeaders(headers))
659 }
660 return openrouter.New(opts...)
661}
662
663func (c *coordinator) buildVercelProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
664 opts := []vercel.Option{
665 vercel.WithAPIKey(apiKey),
666 }
667 if c.cfg.Options.Debug {
668 httpClient := log.NewHTTPClient()
669 opts = append(opts, vercel.WithHTTPClient(httpClient))
670 }
671 if len(headers) > 0 {
672 opts = append(opts, vercel.WithHeaders(headers))
673 }
674 return vercel.New(opts...)
675}
676
677func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string, extraBody map[string]any, providerID string, isSubAgent bool) (fantasy.Provider, error) {
678 opts := []openaicompat.Option{
679 openaicompat.WithBaseURL(baseURL),
680 openaicompat.WithAPIKey(apiKey),
681 }
682
683 // Set HTTP client based on provider and debug mode.
684 var httpClient *http.Client
685 if providerID == string(catwalk.InferenceProviderCopilot) {
686 opts = append(opts, openaicompat.WithUseResponsesAPI())
687 httpClient = copilot.NewClient(isSubAgent, c.cfg.Options.Debug)
688 } else if c.cfg.Options.Debug {
689 httpClient = log.NewHTTPClient()
690 }
691 if httpClient != nil {
692 opts = append(opts, openaicompat.WithHTTPClient(httpClient))
693 }
694
695 if len(headers) > 0 {
696 opts = append(opts, openaicompat.WithHeaders(headers))
697 }
698
699 for extraKey, extraValue := range extraBody {
700 opts = append(opts, openaicompat.WithSDKOptions(openaisdk.WithJSONSet(extraKey, extraValue)))
701 }
702
703 return openaicompat.New(opts...)
704}
705
706func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
707 opts := []azure.Option{
708 azure.WithBaseURL(baseURL),
709 azure.WithAPIKey(apiKey),
710 azure.WithUseResponsesAPI(),
711 }
712 if c.cfg.Options.Debug {
713 httpClient := log.NewHTTPClient()
714 opts = append(opts, azure.WithHTTPClient(httpClient))
715 }
716 if options == nil {
717 options = make(map[string]string)
718 }
719 if apiVersion, ok := options["apiVersion"]; ok {
720 opts = append(opts, azure.WithAPIVersion(apiVersion))
721 }
722 if len(headers) > 0 {
723 opts = append(opts, azure.WithHeaders(headers))
724 }
725
726 return azure.New(opts...)
727}
728
729func (c *coordinator) buildBedrockProvider(headers map[string]string) (fantasy.Provider, error) {
730 var opts []bedrock.Option
731 if c.cfg.Options.Debug {
732 httpClient := log.NewHTTPClient()
733 opts = append(opts, bedrock.WithHTTPClient(httpClient))
734 }
735 if len(headers) > 0 {
736 opts = append(opts, bedrock.WithHeaders(headers))
737 }
738 bearerToken := os.Getenv("AWS_BEARER_TOKEN_BEDROCK")
739 if bearerToken != "" {
740 opts = append(opts, bedrock.WithAPIKey(bearerToken))
741 }
742 return bedrock.New(opts...)
743}
744
745func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
746 opts := []google.Option{
747 google.WithBaseURL(baseURL),
748 google.WithGeminiAPIKey(apiKey),
749 }
750 if c.cfg.Options.Debug {
751 httpClient := log.NewHTTPClient()
752 opts = append(opts, google.WithHTTPClient(httpClient))
753 }
754 if len(headers) > 0 {
755 opts = append(opts, google.WithHeaders(headers))
756 }
757 return google.New(opts...)
758}
759
760func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
761 opts := []google.Option{}
762 if c.cfg.Options.Debug {
763 httpClient := log.NewHTTPClient()
764 opts = append(opts, google.WithHTTPClient(httpClient))
765 }
766 if len(headers) > 0 {
767 opts = append(opts, google.WithHeaders(headers))
768 }
769
770 project := options["project"]
771 location := options["location"]
772
773 opts = append(opts, google.WithVertex(project, location))
774
775 return google.New(opts...)
776}
777
778func (c *coordinator) buildHyperProvider(baseURL, apiKey string) (fantasy.Provider, error) {
779 opts := []hyper.Option{
780 hyper.WithBaseURL(baseURL),
781 hyper.WithAPIKey(apiKey),
782 }
783 if c.cfg.Options.Debug {
784 httpClient := log.NewHTTPClient()
785 opts = append(opts, hyper.WithHTTPClient(httpClient))
786 }
787 return hyper.New(opts...)
788}
789
790func (c *coordinator) isAnthropicThinking(model config.SelectedModel) bool {
791 if model.Think {
792 return true
793 }
794 opts, err := anthropic.ParseOptions(model.ProviderOptions)
795 return err == nil && opts.Thinking != nil
796}
797
798func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model config.SelectedModel, isSubAgent bool) (fantasy.Provider, error) {
799 headers := maps.Clone(providerCfg.ExtraHeaders)
800 if headers == nil {
801 headers = make(map[string]string)
802 }
803
804 // handle special headers for anthropic
805 if providerCfg.Type == anthropic.Name && c.isAnthropicThinking(model) {
806 if v, ok := headers["anthropic-beta"]; ok {
807 headers["anthropic-beta"] = v + ",interleaved-thinking-2025-05-14"
808 } else {
809 headers["anthropic-beta"] = "interleaved-thinking-2025-05-14"
810 }
811 }
812
813 apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
814 baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
815
816 switch providerCfg.Type {
817 case openai.Name:
818 return c.buildOpenaiProvider(baseURL, apiKey, headers)
819 case anthropic.Name:
820 return c.buildAnthropicProvider(baseURL, apiKey, headers, providerCfg.ID)
821 case openrouter.Name:
822 return c.buildOpenrouterProvider(baseURL, apiKey, headers)
823 case vercel.Name:
824 return c.buildVercelProvider(baseURL, apiKey, headers)
825 case azure.Name:
826 return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
827 case bedrock.Name:
828 return c.buildBedrockProvider(headers)
829 case google.Name:
830 return c.buildGoogleProvider(baseURL, apiKey, headers)
831 case "google-vertex":
832 return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
833 case openaicompat.Name:
834 if providerCfg.ID == string(catwalk.InferenceProviderZAI) {
835 if providerCfg.ExtraBody == nil {
836 providerCfg.ExtraBody = map[string]any{}
837 }
838 providerCfg.ExtraBody["tool_stream"] = true
839 }
840 return c.buildOpenaiCompatProvider(baseURL, apiKey, headers, providerCfg.ExtraBody, providerCfg.ID, isSubAgent)
841 case hyper.Name:
842 return c.buildHyperProvider(baseURL, apiKey)
843 default:
844 return nil, fmt.Errorf("provider type not supported: %q", providerCfg.Type)
845 }
846}
847
// isExactoSupported reports whether modelID is one of the model IDs in the
// fixed exacto list below. The match is exact and case-sensitive.
func isExactoSupported(modelID string) bool {
	exactoModels := [...]string{
		"moonshotai/kimi-k2-0905",
		"deepseek/deepseek-v3.1-terminus",
		"z-ai/glm-4.6",
		"openai/gpt-oss-120b",
		"qwen/qwen3-coder",
	}
	return slices.Index(exactoModels[:], modelID) >= 0
}
858
859func (c *coordinator) Cancel(sessionID string) {
860 c.currentAgent.Cancel(sessionID)
861}
862
863func (c *coordinator) CancelAll() {
864 c.currentAgent.CancelAll()
865}
866
867func (c *coordinator) ClearQueue(sessionID string) {
868 c.currentAgent.ClearQueue(sessionID)
869}
870
871func (c *coordinator) IsBusy() bool {
872 return c.currentAgent.IsBusy()
873}
874
875func (c *coordinator) IsSessionBusy(sessionID string) bool {
876 return c.currentAgent.IsSessionBusy(sessionID)
877}
878
879func (c *coordinator) Model() Model {
880 return c.currentAgent.Model()
881}
882
883func (c *coordinator) UpdateModels(ctx context.Context) error {
884 // build the models again so we make sure we get the latest config
885 large, small, err := c.buildAgentModels(ctx, false)
886 if err != nil {
887 return err
888 }
889 c.currentAgent.SetModels(large, small)
890
891 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
892 if !ok {
893 return errCoderAgentNotConfigured
894 }
895
896 tools, err := c.buildTools(ctx, agentCfg)
897 if err != nil {
898 return err
899 }
900 c.currentAgent.SetTools(tools)
901 return nil
902}
903
904func (c *coordinator) RefreshTools(ctx context.Context) error {
905 agentCfg, ok := c.cfg.Agents[config.AgentCoder]
906 if !ok {
907 return errors.New("coder agent not configured")
908 }
909
910 tools, err := c.buildTools(ctx, agentCfg)
911 if err != nil {
912 return err
913 }
914 c.currentAgent.SetTools(tools)
915 slog.Debug("refreshed agent tools", "count", len(tools))
916 return nil
917}
918
919func (c *coordinator) QueuedPrompts(sessionID string) int {
920 return c.currentAgent.QueuedPrompts(sessionID)
921}
922
923func (c *coordinator) QueuedPromptsList(sessionID string) []string {
924 return c.currentAgent.QueuedPromptsList(sessionID)
925}
926
927func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
928 providerCfg, ok := c.cfg.Providers.Get(c.currentAgent.Model().ModelCfg.Provider)
929 if !ok {
930 return errModelProviderNotConfigured
931 }
932 return c.currentAgent.Summarize(ctx, sessionID, getProviderOptions(c.currentAgent.Model(), providerCfg))
933}
934
935func (c *coordinator) isUnauthorized(err error) bool {
936 var providerErr *fantasy.ProviderError
937 return errors.As(err, &providerErr) && providerErr.StatusCode == http.StatusUnauthorized
938}
939
940func (c *coordinator) refreshOAuth2Token(ctx context.Context, providerCfg config.ProviderConfig) error {
941 if err := c.cfg.RefreshOAuthToken(ctx, providerCfg.ID); err != nil {
942 slog.Error("Failed to refresh OAuth token after 401 error", "provider", providerCfg.ID, "error", err)
943 return err
944 }
945 if err := c.UpdateModels(ctx); err != nil {
946 return err
947 }
948 return nil
949}
950
951func (c *coordinator) refreshApiKeyTemplate(ctx context.Context, providerCfg config.ProviderConfig) error {
952 newAPIKey, err := c.cfg.Resolve(providerCfg.APIKeyTemplate)
953 if err != nil {
954 slog.Error("Failed to re-resolve API key after 401 error", "provider", providerCfg.ID, "error", err)
955 return err
956 }
957
958 providerCfg.APIKey = newAPIKey
959 c.cfg.Providers.Set(providerCfg.ID, providerCfg)
960
961 if err := c.UpdateModels(ctx); err != nil {
962 return err
963 }
964 return nil
965}
966
// subAgentParams holds the parameters for running a sub-agent.
type subAgentParams struct {
	// Agent is the sub-agent that runs the prompt.
	Agent SessionAgent
	// SessionID is the parent session: the task session is created under it
	// and the sub-agent's cost is added to it.
	SessionID string
	// AgentMessageID and ToolCallID are combined to derive the task session ID.
	AgentMessageID string
	ToolCallID     string
	// Prompt is the prompt passed to the sub-agent.
	Prompt string
	// SessionTitle is the title given to the task session.
	SessionTitle string
	// SessionSetup is an optional callback invoked after session creation
	// but before agent execution, for custom session configuration.
	SessionSetup func(sessionID string)
}
979
980// runSubAgent runs a sub-agent and handles session management and cost accumulation.
981// It creates a sub-session, runs the agent with the given prompt, and propagates
982// the cost to the parent session.
983func (c *coordinator) runSubAgent(ctx context.Context, params subAgentParams) (fantasy.ToolResponse, error) {
984 // Create sub-session
985 agentToolSessionID := c.sessions.CreateAgentToolSessionID(params.AgentMessageID, params.ToolCallID)
986 session, err := c.sessions.CreateTaskSession(ctx, agentToolSessionID, params.SessionID, params.SessionTitle)
987 if err != nil {
988 return fantasy.ToolResponse{}, fmt.Errorf("create session: %w", err)
989 }
990
991 // Call session setup function if provided
992 if params.SessionSetup != nil {
993 params.SessionSetup(session.ID)
994 }
995
996 // Get model configuration
997 model := params.Agent.Model()
998 maxTokens := model.CatwalkCfg.DefaultMaxTokens
999 if model.ModelCfg.MaxTokens != 0 {
1000 maxTokens = model.ModelCfg.MaxTokens
1001 }
1002
1003 providerCfg, ok := c.cfg.Providers.Get(model.ModelCfg.Provider)
1004 if !ok {
1005 return fantasy.ToolResponse{}, errModelProviderNotConfigured
1006 }
1007
1008 // Run the agent
1009 result, err := params.Agent.Run(ctx, SessionAgentCall{
1010 SessionID: session.ID,
1011 Prompt: params.Prompt,
1012 MaxOutputTokens: maxTokens,
1013 ProviderOptions: getProviderOptions(model, providerCfg),
1014 Temperature: model.ModelCfg.Temperature,
1015 TopP: model.ModelCfg.TopP,
1016 TopK: model.ModelCfg.TopK,
1017 FrequencyPenalty: model.ModelCfg.FrequencyPenalty,
1018 PresencePenalty: model.ModelCfg.PresencePenalty,
1019 NonInteractive: true,
1020 })
1021 if err != nil {
1022 return fantasy.NewTextErrorResponse("error generating response"), nil
1023 }
1024
1025 // Update parent session cost
1026 if err := c.updateParentSessionCost(ctx, session.ID, params.SessionID); err != nil {
1027 return fantasy.ToolResponse{}, err
1028 }
1029
1030 return fantasy.NewTextResponse(result.Response.Content.Text()), nil
1031}
1032
1033// updateParentSessionCost accumulates the cost from a child session to its parent session.
1034func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionID, parentSessionID string) error {
1035 childSession, err := c.sessions.Get(ctx, childSessionID)
1036 if err != nil {
1037 return fmt.Errorf("get child session: %w", err)
1038 }
1039
1040 parentSession, err := c.sessions.Get(ctx, parentSessionID)
1041 if err != nil {
1042 return fmt.Errorf("get parent session: %w", err)
1043 }
1044
1045 parentSession.Cost += childSession.Cost
1046
1047 if _, err := c.sessions.Save(ctx, parentSession); err != nil {
1048 return fmt.Errorf("save parent session: %w", err)
1049 }
1050
1051 return nil
1052}