From 135aceb72f4f4de1694c22f653d95fd6c5015745 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Fri, 27 Jun 2025 17:46:51 +0200 Subject: [PATCH] feat: add dynamic model switching with agent provider updates --- internal/app/app.go | 4 ++ internal/llm/agent/agent.go | 98 +++++++++++++++++++++++++++++++++++++ internal/tui/tui.go | 7 +++ 3 files changed, 109 insertions(+) diff --git a/internal/app/app.go b/internal/app/app.go index b096c1b4f5612901a1cedeaa2ee758b666cda517..6dd1b9916d593c6f0e053aaef6714723f8fd5c60 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -168,3 +168,7 @@ func (app *App) Shutdown() { } app.CoderAgent.CancelAll() } + +func (app *App) UpdateAgentModel() error { + return app.CoderAgent.UpdateModel() +} diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 8c6faf8c4a06bbef5da279847cd14ce2314648cd..8312b0f8965a5d02f7ce049abff50953cc56e422 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -56,6 +56,7 @@ type Service interface { IsSessionBusy(sessionID string) bool IsBusy() bool Summarize(ctx context.Context, sessionID string) error + UpdateModel() error } type agent struct { @@ -805,3 +806,100 @@ func (a *agent) CancelAll() { return true }) } + +func (a *agent) UpdateModel() error { + cfg := config.Get() + + // Get current provider configuration + currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID) + if currentProviderCfg.ID == "" { + return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name) + } + + // Check if provider has changed + if string(currentProviderCfg.ID) != a.providerID { + // Provider changed, need to recreate the main provider + model := config.GetAgentModel(a.agentCfg.ID) + if model.ID == "" { + return fmt.Errorf("model not found for agent %s", a.agentCfg.Name) + } + + promptID := agentPromptMap[a.agentCfg.ID] + if promptID == "" { + promptID = prompt.PromptDefault + } + + opts := []provider.ProviderClientOption{ + provider.WithModel(a.agentCfg.Model), + provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)), + provider.WithMaxTokens(model.DefaultMaxTokens), + } + + newProvider, err := provider.NewProviderV2(currentProviderCfg, opts...) + if err != nil { + return fmt.Errorf("failed to create new provider: %w", err) + } + + // Update the provider and provider ID + a.provider = newProvider + a.providerID = string(currentProviderCfg.ID) + } + + // Check if small model provider has changed (affects title and summarize providers) + smallModelCfg := cfg.Models.Small + var smallModelProviderCfg config.ProviderConfig + + for _, p := range cfg.Providers { + if p.ID == smallModelCfg.Provider { + smallModelProviderCfg = p + break + } + } + + if smallModelProviderCfg.ID == "" { + return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider) + } + + // Check if summarize provider has changed + if string(smallModelProviderCfg.ID) != a.summarizeProviderID { + var smallModel config.Model + for _, m := range smallModelProviderCfg.Models { + if m.ID == smallModelCfg.ModelID { + smallModel = m + break + } + } + if smallModel.ID == "" { + return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID) + } + + // Recreate title provider + titleOpts := []provider.ProviderClientOption{ + provider.WithModel(config.SmallModel), + provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), + provider.WithMaxTokens(40), + } + newTitleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...) + if err != nil { + return fmt.Errorf("failed to create new title provider: %w", err) + } + + // Recreate summarize provider + summarizeOpts := []provider.ProviderClientOption{ + provider.WithModel(config.SmallModel), + provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)), + provider.WithMaxTokens(smallModel.DefaultMaxTokens), + } + newSummarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...) + if err != nil { + return fmt.Errorf("failed to create new summarize provider: %w", err) + } + + // Update the providers and provider ID + a.titleProvider = newTitleProvider + a.summarizeProvider = newSummarizeProvider + a.summarizeProviderID = string(smallModelProviderCfg.ID) + } + + return nil +} diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 032b481eeaad75531debe7dc453efe19b866dd8d..e3c974ca002529ce1ac90f420afcc5eedf2a45fd 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -173,6 +173,13 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Model Switch case models.ModelSelectedMsg: config.UpdatePreferredModel(config.LargeModel, msg.Model) + + // Update the agent with the new model/provider configuration + if err := a.app.UpdateAgentModel(); err != nil { + logging.ErrorPersist(fmt.Sprintf("Failed to update agent model: %v", err)) + return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.ModelID, err)) + } + return a, util.ReportInfo(fmt.Sprintf("Model changed to %s", msg.Model.ModelID)) // File Picker