internal/app/app.go 🔗
@@ -168,3 +168,7 @@ func (app *App) Shutdown() {
}
app.CoderAgent.CancelAll()
}
+
+func (app *App) UpdateAgentModel() error {
+ return app.CoderAgent.UpdateModel()
+}
Kujtim Hoxha created
internal/app/app.go | 4 +
internal/llm/agent/agent.go | 98 +++++++++++++++++++++++++++++++++++++++
internal/tui/tui.go | 7 ++
3 files changed, 109 insertions(+)
@@ -168,3 +168,7 @@ func (app *App) Shutdown() {
}
app.CoderAgent.CancelAll()
}
+
+func (app *App) UpdateAgentModel() error {
+ return app.CoderAgent.UpdateModel()
+}
@@ -56,6 +56,7 @@ type Service interface {
IsSessionBusy(sessionID string) bool
IsBusy() bool
Summarize(ctx context.Context, sessionID string) error
+ UpdateModel() error
}
type agent struct {
@@ -805,3 +806,100 @@ func (a *agent) CancelAll() {
return true
})
}
+
+func (a *agent) UpdateModel() error {
+ cfg := config.Get()
+
+ // Get current provider configuration
+ currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
+ if currentProviderCfg.ID == "" {
+ return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
+ }
+
+ // Check if provider has changed
+ if string(currentProviderCfg.ID) != a.providerID {
+ // Provider changed, need to recreate the main provider
+ model := config.GetAgentModel(a.agentCfg.ID)
+ if model.ID == "" {
+ return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
+ }
+
+ promptID := agentPromptMap[a.agentCfg.ID]
+ if promptID == "" {
+ promptID = prompt.PromptDefault
+ }
+
+ opts := []provider.ProviderClientOption{
+ provider.WithModel(a.agentCfg.Model),
+ provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
+ provider.WithMaxTokens(model.DefaultMaxTokens),
+ }
+
+ newProvider, err := provider.NewProviderV2(currentProviderCfg, opts...)
+ if err != nil {
+ return fmt.Errorf("failed to create new provider: %w", err)
+ }
+
+ // Update the provider and provider ID
+ a.provider = newProvider
+ a.providerID = string(currentProviderCfg.ID)
+ }
+
+ // Check if small model provider has changed (affects title and summarize providers)
+ smallModelCfg := cfg.Models.Small
+ var smallModelProviderCfg config.ProviderConfig
+
+ for _, p := range cfg.Providers {
+ if p.ID == smallModelCfg.Provider {
+ smallModelProviderCfg = p
+ break
+ }
+ }
+
+ if smallModelProviderCfg.ID == "" {
+ return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
+ }
+
+ // Check if summarize provider has changed
+ if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
+ var smallModel config.Model
+ for _, m := range smallModelProviderCfg.Models {
+ if m.ID == smallModelCfg.ModelID {
+ smallModel = m
+ break
+ }
+ }
+ if smallModel.ID == "" {
+ return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
+ }
+
+ // Recreate title provider
+ titleOpts := []provider.ProviderClientOption{
+ provider.WithModel(config.SmallModel),
+ provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
+ provider.WithMaxTokens(40),
+ }
+ newTitleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
+ if err != nil {
+ return fmt.Errorf("failed to create new title provider: %w", err)
+ }
+
+ // Recreate summarize provider
+ summarizeOpts := []provider.ProviderClientOption{
+ provider.WithModel(config.SmallModel),
+ provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
+ provider.WithMaxTokens(smallModel.DefaultMaxTokens),
+ }
+ newSummarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
+ if err != nil {
+ return fmt.Errorf("failed to create new summarize provider: %w", err)
+ }
+
+ // Update the providers and provider ID
+ a.titleProvider = newTitleProvider
+ a.summarizeProvider = newSummarizeProvider
+ a.summarizeProviderID = string(smallModelProviderCfg.ID)
+ }
+
+ return nil
+}
@@ -173,6 +173,13 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
// Model Switch
case models.ModelSelectedMsg:
config.UpdatePreferredModel(config.LargeModel, msg.Model)
+
+ // Update the agent with the new model/provider configuration
+ if err := a.app.UpdateAgentModel(); err != nil {
+ logging.ErrorPersist(fmt.Sprintf("Failed to update agent model: %v", err))
+ return a, util.ReportError(fmt.Errorf("model changed to %s but failed to update agent: %v", msg.Model.ModelID, err))
+ }
+
return a, util.ReportInfo(fmt.Sprintf("Model changed to %s", msg.Model.ModelID))
// File Picker