From 7d0a6ad42d5551b2212af0a78966357d8269f636 Mon Sep 17 00:00:00 2001 From: Andrey Nering Date: Wed, 22 Oct 2025 14:57:13 -0300 Subject: [PATCH] refactor: adjustment based on fantasy api changes --- internal/agent/agent_tool.go | 4 +-- internal/agent/common_test.go | 32 ++++++++++++------ internal/agent/coordinator.go | 59 +++++++++++++++++----------------- internal/app/app.go | 9 +++--- internal/tui/page/chat/chat.go | 6 ++-- internal/tui/tui.go | 2 +- 6 files changed, 62 insertions(+), 50 deletions(-) diff --git a/internal/agent/agent_tool.go b/internal/agent/agent_tool.go index 03a2f0c8c8cfa53eafae47c2fcffc2b1fc36886e..64d01a95ecbc77d33e49d3fbac46bd3136c7cf15 100644 --- a/internal/agent/agent_tool.go +++ b/internal/agent/agent_tool.go @@ -25,7 +25,7 @@ const ( AgentToolName = "agent" ) -func (c *coordinator) agentTool() (fantasy.AgentTool, error) { +func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) { agentCfg, ok := c.cfg.Agents[config.AgentTask] if !ok { return nil, errors.New("task agent not configured") @@ -35,7 +35,7 @@ func (c *coordinator) agentTool() (fantasy.AgentTool, error) { return nil, err } - agent, err := c.buildAgent(prompt, agentCfg) + agent, err := c.buildAgent(ctx, prompt, agentCfg) if err != nil { return nil, err } diff --git a/internal/agent/common_test.go b/internal/agent/common_test.go index 0c1b34a6c88543862bea6550198cfb6b05b5c6c9..7e1bd5920c5d21d3d67fb7cbeab77f2cd446790d 100644 --- a/internal/agent/common_test.go +++ b/internal/agent/common_test.go @@ -47,43 +47,55 @@ type modelPair struct { } func anthropicBuilder(model string) builderFunc { - return func(_ *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) { - provider := anthropic.New( + return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) { + provider, err := anthropic.New( anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")), anthropic.WithHTTPClient(&http.Client{Transport: r}), ) - return provider.LanguageModel(model) + if err != nil { + return nil, err + } + return provider.LanguageModel(t.Context(), model) } } func openaiBuilder(model string) builderFunc { - return func(_ *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) { - provider := openai.New( + return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) { + provider, err := openai.New( openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")), openai.WithHTTPClient(&http.Client{Transport: r}), ) - return provider.LanguageModel(model) + if err != nil { + return nil, err + } + return provider.LanguageModel(t.Context(), model) } } func openRouterBuilder(model string) builderFunc { return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) { - provider := openrouter.New( + provider, err := openrouter.New( openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")), openrouter.WithHTTPClient(&http.Client{Transport: r}), ) - return provider.LanguageModel(model) + if err != nil { + return nil, err + } + return provider.LanguageModel(t.Context(), model) } } func zAIBuilder(model string) builderFunc { return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) { - provider := openaicompat.New( + provider, err := openaicompat.New( openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"), openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")), openaicompat.WithHTTPClient(&http.Client{Transport: r}), ) - return provider.LanguageModel(model) + if err != nil { + return nil, err + } + return provider.LanguageModel(t.Context(), model) } } diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index ef3c0a1ab12d669b0c9d3f09a3597285c1ee512f..ef70b7cd12ef275b07a5beb6c37a3840a99d2b83 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -45,7 +45,7 @@ type Coordinator interface { ClearQueue(sessionID string) Summarize(context.Context, string) error Model() Model - UpdateModels() error + UpdateModels(ctx context.Context) error } type coordinator struct { @@ -61,6 +61,7 @@ type coordinator struct { } func NewCoordinator( + ctx context.Context, cfg *config.Config, sessions session.Service, messages message.Service, @@ -89,7 +90,7 @@ func NewCoordinator( return nil, err } - agent, err := c.buildAgent(prompt, agentCfg) + agent, err := c.buildAgent(ctx, prompt, agentCfg) if err != nil { return nil, err } @@ -255,8 +256,8 @@ func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *f return modelOptions, temp, topP, topK, freqPenalty, presPenalty } -func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) { - large, small, err := c.buildAgentModels() +func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) { + large, small, err := c.buildAgentModels(ctx) if err != nil { return nil, err } @@ -266,17 +267,17 @@ func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (Ses return nil, err } - tools, err := c.buildTools(agent) + tools, err := c.buildTools(ctx, agent) if err != nil { return nil, err } return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil } -func (c *coordinator) buildTools(agent config.Agent) ([]fantasy.AgentTool, error) { +func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) { var allTools []fantasy.AgentTool if slices.Contains(agent.AllowedTools, AgentToolName) { - agentTool, err := c.agentTool() + agentTool, err := c.agentTool(ctx) if err != nil { return nil, err } @@ -334,7 +335,7 @@ func (c *coordinator) buildTools(agent config.Agent) ([]fantasy.AgentTool, error } // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config -func (c *coordinator) buildAgentModels() (Model, Model, error) { +func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error) { largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge] if !ok { return Model{}, Model{}, errors.New("large model not selected") @@ -386,11 +387,11 @@ func (c *coordinator) buildAgentModels() (Model, Model, error) { return Model{}, Model{}, errors.New("snall model not found in provider config") } - largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model) + largeModel, err := largeProvider.LanguageModel(ctx, largeModelCfg.Model) if err != nil { return Model{}, Model{}, err } - smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model) + smallModel, err := smallProvider.LanguageModel(ctx, smallModelCfg.Model) if err != nil { return Model{}, Model{}, err } @@ -406,7 +407,7 @@ func (c *coordinator) buildAgentModels() (Model, Model, error) { }, nil } -func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider { +func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) { hasBearerAuth := false for key := range headers { if strings.ToLower(key) == "authorization" { @@ -441,7 +442,7 @@ func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map return anthropic.New(opts...) } -func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider { +func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) { opts := []openai.Option{ openai.WithAPIKey(apiKey), openai.WithUseResponsesAPI(), @@ -459,7 +460,7 @@ func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[st return openai.New(opts...) } -func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) fantasy.Provider { +func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) { opts := []openrouter.Option{ openrouter.WithAPIKey(apiKey), } @@ -473,7 +474,7 @@ func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[stri return openrouter.New(opts...) } -func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider { +func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) { opts := []openaicompat.Option{ openaicompat.WithBaseURL(baseURL), openaicompat.WithAPIKey(apiKey), @@ -489,7 +490,7 @@ func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers return openaicompat.New(opts...) } -func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) fantasy.Provider { +func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) { opts := []azure.Option{ azure.WithBaseURL(baseURL), azure.WithAPIKey(apiKey), @@ -511,7 +512,7 @@ func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[str return azure.New(opts...) } -func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider { +func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) { opts := []google.Option{ google.WithBaseURL(baseURL), google.WithGeminiAPIKey(apiKey), @@ -526,7 +527,7 @@ func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[st return google.New(opts...) } -func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) fantasy.Provider { +func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) { opts := []google.Option{} if c.cfg.Options.Debug { httpClient := log.NewHTTPClient() @@ -574,27 +575,25 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con // TODO: make sure we have apiKey, _ := c.cfg.Resolve(providerCfg.APIKey) baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL) - var provider fantasy.Provider + switch providerCfg.Type { case openai.Name: - provider = c.buildOpenaiProvider(baseURL, apiKey, headers) + return c.buildOpenaiProvider(baseURL, apiKey, headers) case anthropic.Name: - provider = c.buildAnthropicProvider(baseURL, apiKey, headers) + return c.buildAnthropicProvider(baseURL, apiKey, headers) case openrouter.Name: - provider = c.buildOpenrouterProvider(baseURL, apiKey, headers) + return c.buildOpenrouterProvider(baseURL, apiKey, headers) case azure.Name: - provider = c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams) + return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams) case google.Name: - provider = c.buildGoogleProvider(baseURL, apiKey, headers) - // this is not in fantasy since its just the google provider with extra stuff + return c.buildGoogleProvider(baseURL, apiKey, headers) case "vertexai": - provider = c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams) + return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams) case openaicompat.Name: - provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers) + return c.buildOpenaiCompatProvider(baseURL, apiKey, headers) default: return nil, errors.New("provider type not supported") } - return provider, nil } func (c *coordinator) Cancel(sessionID string) { @@ -621,9 +620,9 @@ func (c *coordinator) Model() Model { return c.currentAgent.Model() } -func (c *coordinator) UpdateModels() error { +func (c *coordinator) UpdateModels(ctx context.Context) error { // build the models again so we make sure we get the latest config - large, small, err := c.buildAgentModels() + large, small, err := c.buildAgentModels(ctx) if err != nil { return err } @@ -634,7 +633,7 @@ func (c *coordinator) UpdateModels() error { return errors.New("coder agent not configured") } - tools, err := c.buildTools(agentCfg) + tools, err := c.buildTools(ctx, agentCfg) if err != nil { return err } diff --git a/internal/app/app.go b/internal/app/app.go index 8309d21057a4c00ef18e2094a611d59b597ccf27..cb715124d24294394926c275a370812a539600ec 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -87,7 +87,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { // TODO: remove the concept of agent config, most likely. if cfg.IsConfigured() { - if err := app.InitCoderAgent(); err != nil { + if err := app.InitCoderAgent(ctx); err != nil { return nil, fmt.Errorf("failed to initialize coder agent: %w", err) } } else { @@ -207,8 +207,8 @@ func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool } } -func (app *App) UpdateAgentModel() error { - return app.AgentCoordinator.UpdateModels() +func (app *App) UpdateAgentModel(ctx context.Context) error { + return app.AgentCoordinator.UpdateModels(ctx) } func (app *App) setupEvents() { @@ -262,13 +262,14 @@ func setupSubscriber[T any]( }) } -func (app *App) InitCoderAgent() error { +func (app *App) InitCoderAgent(ctx context.Context) error { coderAgentCfg := app.config.Agents[config.AgentCoder] if coderAgentCfg.ID == "" { return fmt.Errorf("coder agent configuration is missing") } var err error app.AgentCoordinator, err = agent.NewCoordinator( + ctx, app.config, app.Sessions, app.Messages, diff --git a/internal/tui/page/chat/chat.go b/internal/tui/page/chat/chat.go index dc67509fa7dba992fe050540dbd3ff3f0942cdfe..356acfb8856dc3414c8cddfbcfc4a1d226b87d9f 100644 --- a/internal/tui/page/chat/chat.go +++ b/internal/tui/page/chat/chat.go @@ -346,7 +346,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) { p.splashFullScreen = true return p, p.SetSize(p.width, p.height) } - err := p.app.InitCoderAgent() + err := p.app.InitCoderAgent(context.TODO()) if err != nil { return p, util.ReportError(err) } @@ -538,7 +538,7 @@ func (p *chatPage) toggleThinking() tea.Cmd { cfg.Models[agentCfg.Model] = currentModel // Update the agent with the new configuration - if err := p.app.UpdateAgentModel(); err != nil { + if err := p.app.UpdateAgentModel(context.TODO()); err != nil { return util.InfoMsg{ Type: util.InfoTypeError, Msg: "Failed to update thinking mode: " + err.Error(), @@ -589,7 +589,7 @@ func (p *chatPage) handleReasoningEffortSelected(effort string) tea.Cmd { } // Update the agent with the new configuration - if err := p.app.UpdateAgentModel(); err != nil { + if err := p.app.UpdateAgentModel(context.TODO()); err != nil { return util.InfoMsg{ Type: util.InfoTypeError, Msg: "Failed to update reasoning effort: " + err.Error(), diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 1b9ea177b64e559594bd87b13745fc0b7675f15f..c25372295b5c906a6511e9d3fc821ea1e1f97661 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -204,7 +204,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { config.Get().UpdatePreferredModel(msg.ModelType, msg.Model) // Update the agent with the new model/provider configuration - if err := a.app.UpdateAgentModel(); err != nil { + if err := a.app.UpdateAgentModel(context.TODO()); err != nil { return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.Model, err)) }