refactor: adjustment based on fantasy api changes

Andrey Nering created

Change summary

internal/agent/agent_tool.go   |  4 +-
internal/agent/common_test.go  | 32 +++++++++++++------
internal/agent/coordinator.go  | 59 +++++++++++++++++------------------
internal/app/app.go            |  9 +++--
internal/tui/page/chat/chat.go |  6 +-
internal/tui/tui.go            |  2 
6 files changed, 62 insertions(+), 50 deletions(-)

Detailed changes

internal/agent/agent_tool.go 🔗

@@ -25,7 +25,7 @@ const (
 	AgentToolName = "agent"
 )
 
-func (c *coordinator) agentTool() (fantasy.AgentTool, error) {
+func (c *coordinator) agentTool(ctx context.Context) (fantasy.AgentTool, error) {
 	agentCfg, ok := c.cfg.Agents[config.AgentTask]
 	if !ok {
 		return nil, errors.New("task agent not configured")
@@ -35,7 +35,7 @@ func (c *coordinator) agentTool() (fantasy.AgentTool, error) {
 		return nil, err
 	}
 
-	agent, err := c.buildAgent(prompt, agentCfg)
+	agent, err := c.buildAgent(ctx, prompt, agentCfg)
 	if err != nil {
 		return nil, err
 	}

internal/agent/common_test.go 🔗

@@ -47,43 +47,55 @@ type modelPair struct {
 }
 
 func anthropicBuilder(model string) builderFunc {
-	return func(_ *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
-		provider := anthropic.New(
+	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
+		provider, err := anthropic.New(
 			anthropic.WithAPIKey(os.Getenv("CRUSH_ANTHROPIC_API_KEY")),
 			anthropic.WithHTTPClient(&http.Client{Transport: r}),
 		)
-		return provider.LanguageModel(model)
+		if err != nil {
+			return nil, err
+		}
+		return provider.LanguageModel(t.Context(), model)
 	}
 }
 
 func openaiBuilder(model string) builderFunc {
-	return func(_ *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
-		provider := openai.New(
+	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
+		provider, err := openai.New(
 			openai.WithAPIKey(os.Getenv("CRUSH_OPENAI_API_KEY")),
 			openai.WithHTTPClient(&http.Client{Transport: r}),
 		)
-		return provider.LanguageModel(model)
+		if err != nil {
+			return nil, err
+		}
+		return provider.LanguageModel(t.Context(), model)
 	}
 }
 
 func openRouterBuilder(model string) builderFunc {
 	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
-		provider := openrouter.New(
+		provider, err := openrouter.New(
 			openrouter.WithAPIKey(os.Getenv("CRUSH_OPENROUTER_API_KEY")),
 			openrouter.WithHTTPClient(&http.Client{Transport: r}),
 		)
-		return provider.LanguageModel(model)
+		if err != nil {
+			return nil, err
+		}
+		return provider.LanguageModel(t.Context(), model)
 	}
 }
 
 func zAIBuilder(model string) builderFunc {
 	return func(t *testing.T, r *recorder.Recorder) (fantasy.LanguageModel, error) {
-		provider := openaicompat.New(
+		provider, err := openaicompat.New(
 			openaicompat.WithBaseURL("https://api.z.ai/api/coding/paas/v4"),
 			openaicompat.WithAPIKey(os.Getenv("CRUSH_ZAI_API_KEY")),
 			openaicompat.WithHTTPClient(&http.Client{Transport: r}),
 		)
-		return provider.LanguageModel(model)
+		if err != nil {
+			return nil, err
+		}
+		return provider.LanguageModel(t.Context(), model)
 	}
 }
 

internal/agent/coordinator.go 🔗

@@ -45,7 +45,7 @@ type Coordinator interface {
 	ClearQueue(sessionID string)
 	Summarize(context.Context, string) error
 	Model() Model
-	UpdateModels() error
+	UpdateModels(ctx context.Context) error
 }
 
 type coordinator struct {
@@ -61,6 +61,7 @@ type coordinator struct {
 }
 
 func NewCoordinator(
+	ctx context.Context,
 	cfg *config.Config,
 	sessions session.Service,
 	messages message.Service,
@@ -89,7 +90,7 @@ func NewCoordinator(
 		return nil, err
 	}
 
-	agent, err := c.buildAgent(prompt, agentCfg)
+	agent, err := c.buildAgent(ctx, prompt, agentCfg)
 	if err != nil {
 		return nil, err
 	}
@@ -255,8 +256,8 @@ func mergeCallOptions(model Model, tp catwalk.Type) (fantasy.ProviderOptions, *f
 	return modelOptions, temp, topP, topK, freqPenalty, presPenalty
 }
 
-func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
-	large, small, err := c.buildAgentModels()
+func (c *coordinator) buildAgent(ctx context.Context, prompt *prompt.Prompt, agent config.Agent) (SessionAgent, error) {
+	large, small, err := c.buildAgentModels(ctx)
 	if err != nil {
 		return nil, err
 	}
@@ -266,17 +267,17 @@ func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (Ses
 		return nil, err
 	}
 
-	tools, err := c.buildTools(agent)
+	tools, err := c.buildTools(ctx, agent)
 	if err != nil {
 		return nil, err
 	}
 	return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil
 }
 
-func (c *coordinator) buildTools(agent config.Agent) ([]fantasy.AgentTool, error) {
+func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fantasy.AgentTool, error) {
 	var allTools []fantasy.AgentTool
 	if slices.Contains(agent.AllowedTools, AgentToolName) {
-		agentTool, err := c.agentTool()
+		agentTool, err := c.agentTool(ctx)
 		if err != nil {
 			return nil, err
 		}
@@ -334,7 +335,7 @@ func (c *coordinator) buildTools(agent config.Agent) ([]fantasy.AgentTool, error
 }
 
 // TODO: when we support multiple agents we need to change this so that we pass in the agent specific model config
-func (c *coordinator) buildAgentModels() (Model, Model, error) {
+func (c *coordinator) buildAgentModels(ctx context.Context) (Model, Model, error) {
 	largeModelCfg, ok := c.cfg.Models[config.SelectedModelTypeLarge]
 	if !ok {
 		return Model{}, Model{}, errors.New("large model not selected")
@@ -386,11 +387,11 @@ func (c *coordinator) buildAgentModels() (Model, Model, error) {
 		return Model{}, Model{}, errors.New("snall model not found in provider config")
 	}
 
-	largeModel, err := largeProvider.LanguageModel(largeModelCfg.Model)
+	largeModel, err := largeProvider.LanguageModel(ctx, largeModelCfg.Model)
 	if err != nil {
 		return Model{}, Model{}, err
 	}
-	smallModel, err := smallProvider.LanguageModel(smallModelCfg.Model)
+	smallModel, err := smallProvider.LanguageModel(ctx, smallModelCfg.Model)
 	if err != nil {
 		return Model{}, Model{}, err
 	}
@@ -406,7 +407,7 @@ func (c *coordinator) buildAgentModels() (Model, Model, error) {
 		}, nil
 }
 
-func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
+func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 	hasBearerAuth := false
 	for key := range headers {
 		if strings.ToLower(key) == "authorization" {
@@ -441,7 +442,7 @@ func (c *coordinator) buildAnthropicProvider(baseURL, apiKey string, headers map
 	return anthropic.New(opts...)
 }
 
-func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
+func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 	opts := []openai.Option{
 		openai.WithAPIKey(apiKey),
 		openai.WithUseResponsesAPI(),
@@ -459,7 +460,7 @@ func (c *coordinator) buildOpenaiProvider(baseURL, apiKey string, headers map[st
 	return openai.New(opts...)
 }
 
-func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) fantasy.Provider {
+func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 	opts := []openrouter.Option{
 		openrouter.WithAPIKey(apiKey),
 	}
@@ -473,7 +474,7 @@ func (c *coordinator) buildOpenrouterProvider(_, apiKey string, headers map[stri
 	return openrouter.New(opts...)
 }
 
-func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
+func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 	opts := []openaicompat.Option{
 		openaicompat.WithBaseURL(baseURL),
 		openaicompat.WithAPIKey(apiKey),
@@ -489,7 +490,7 @@ func (c *coordinator) buildOpenaiCompatProvider(baseURL, apiKey string, headers
 	return openaicompat.New(opts...)
 }
 
-func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) fantasy.Provider {
+func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 	opts := []azure.Option{
 		azure.WithBaseURL(baseURL),
 		azure.WithAPIKey(apiKey),
@@ -511,7 +512,7 @@ func (c *coordinator) buildAzureProvider(baseURL, apiKey string, headers map[str
 	return azure.New(opts...)
 }
 
-func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) fantasy.Provider {
+func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[string]string) (fantasy.Provider, error) {
 	opts := []google.Option{
 		google.WithBaseURL(baseURL),
 		google.WithGeminiAPIKey(apiKey),
@@ -526,7 +527,7 @@ func (c *coordinator) buildGoogleProvider(baseURL, apiKey string, headers map[st
 	return google.New(opts...)
 }
 
-func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) fantasy.Provider {
+func (c *coordinator) buildGoogleVertexProvider(headers map[string]string, options map[string]string) (fantasy.Provider, error) {
 	opts := []google.Option{}
 	if c.cfg.Options.Debug {
 		httpClient := log.NewHTTPClient()
@@ -574,27 +575,25 @@ func (c *coordinator) buildProvider(providerCfg config.ProviderConfig, model con
 	// TODO: make sure we have
 	apiKey, _ := c.cfg.Resolve(providerCfg.APIKey)
 	baseURL, _ := c.cfg.Resolve(providerCfg.BaseURL)
-	var provider fantasy.Provider
+
 	switch providerCfg.Type {
 	case openai.Name:
-		provider = c.buildOpenaiProvider(baseURL, apiKey, headers)
+		return c.buildOpenaiProvider(baseURL, apiKey, headers)
 	case anthropic.Name:
-		provider = c.buildAnthropicProvider(baseURL, apiKey, headers)
+		return c.buildAnthropicProvider(baseURL, apiKey, headers)
 	case openrouter.Name:
-		provider = c.buildOpenrouterProvider(baseURL, apiKey, headers)
+		return c.buildOpenrouterProvider(baseURL, apiKey, headers)
 	case azure.Name:
-		provider = c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
+		return c.buildAzureProvider(baseURL, apiKey, headers, providerCfg.ExtraParams)
 	case google.Name:
-		provider = c.buildGoogleProvider(baseURL, apiKey, headers)
-	// this is not in fantasy since its just the google provider with extra stuff
+		return c.buildGoogleProvider(baseURL, apiKey, headers)
 	case "vertexai":
-		provider = c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
+		return c.buildGoogleVertexProvider(headers, providerCfg.ExtraParams)
 	case openaicompat.Name:
-		provider = c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
+		return c.buildOpenaiCompatProvider(baseURL, apiKey, headers)
 	default:
 		return nil, errors.New("provider type not supported")
 	}
-	return provider, nil
 }
 
 func (c *coordinator) Cancel(sessionID string) {
@@ -621,9 +620,9 @@ func (c *coordinator) Model() Model {
 	return c.currentAgent.Model()
 }
 
-func (c *coordinator) UpdateModels() error {
+func (c *coordinator) UpdateModels(ctx context.Context) error {
 	// build the models again so we make sure we get the latest config
-	large, small, err := c.buildAgentModels()
+	large, small, err := c.buildAgentModels(ctx)
 	if err != nil {
 		return err
 	}
@@ -634,7 +633,7 @@ func (c *coordinator) UpdateModels() error {
 		return errors.New("coder agent not configured")
 	}
 
-	tools, err := c.buildTools(agentCfg)
+	tools, err := c.buildTools(ctx, agentCfg)
 	if err != nil {
 		return err
 	}

internal/app/app.go 🔗

@@ -87,7 +87,7 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
 
 	// TODO: remove the concept of agent config, most likely.
 	if cfg.IsConfigured() {
-		if err := app.InitCoderAgent(); err != nil {
+		if err := app.InitCoderAgent(ctx); err != nil {
 			return nil, fmt.Errorf("failed to initialize coder agent: %w", err)
 		}
 	} else {
@@ -207,8 +207,8 @@ func (app *App) RunNonInteractive(ctx context.Context, prompt string, quiet bool
 	}
 }
 
-func (app *App) UpdateAgentModel() error {
-	return app.AgentCoordinator.UpdateModels()
+func (app *App) UpdateAgentModel(ctx context.Context) error {
+	return app.AgentCoordinator.UpdateModels(ctx)
 }
 
 func (app *App) setupEvents() {
@@ -262,13 +262,14 @@ func setupSubscriber[T any](
 	})
 }
 
-func (app *App) InitCoderAgent() error {
+func (app *App) InitCoderAgent(ctx context.Context) error {
 	coderAgentCfg := app.config.Agents[config.AgentCoder]
 	if coderAgentCfg.ID == "" {
 		return fmt.Errorf("coder agent configuration is missing")
 	}
 	var err error
 	app.AgentCoordinator, err = agent.NewCoordinator(
+		ctx,
 		app.config,
 		app.Sessions,
 		app.Messages,

internal/tui/page/chat/chat.go 🔗

@@ -346,7 +346,7 @@ func (p *chatPage) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 			p.splashFullScreen = true
 			return p, p.SetSize(p.width, p.height)
 		}
-		err := p.app.InitCoderAgent()
+		err := p.app.InitCoderAgent(context.TODO())
 		if err != nil {
 			return p, util.ReportError(err)
 		}
@@ -538,7 +538,7 @@ func (p *chatPage) toggleThinking() tea.Cmd {
 		cfg.Models[agentCfg.Model] = currentModel
 
 		// Update the agent with the new configuration
-		if err := p.app.UpdateAgentModel(); err != nil {
+		if err := p.app.UpdateAgentModel(context.TODO()); err != nil {
 			return util.InfoMsg{
 				Type: util.InfoTypeError,
 				Msg:  "Failed to update thinking mode: " + err.Error(),
@@ -589,7 +589,7 @@ func (p *chatPage) handleReasoningEffortSelected(effort string) tea.Cmd {
 		}
 
 		// Update the agent with the new configuration
-		if err := p.app.UpdateAgentModel(); err != nil {
+		if err := p.app.UpdateAgentModel(context.TODO()); err != nil {
 			return util.InfoMsg{
 				Type: util.InfoTypeError,
 				Msg:  "Failed to update reasoning effort: " + err.Error(),

internal/tui/tui.go 🔗

@@ -204,7 +204,7 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		config.Get().UpdatePreferredModel(msg.ModelType, msg.Model)
 
 		// Update the agent with the new model/provider configuration
-		if err := a.app.UpdateAgentModel(); err != nil {
+		if err := a.app.UpdateAgentModel(context.TODO()); err != nil {
 			return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.Model, err))
 		}