inline assistant: Fix model picker (#29136)

Agus Zubiaga created

Release Notes:

- inline assistant: Fixed a bug where the default model would be used
even when a specific inline assistant model was configured

Change summary

crates/agent/src/assistant_model_selector.rs                  | 17 -
crates/agent/src/buffer_codegen.rs                            | 14 
crates/agent/src/inline_assistant.rs                          |  9 +
crates/agent/src/inline_prompt_editor.rs                      |  4 
crates/assistant/src/inline_assistant.rs                      |  3 
crates/assistant/src/terminal_inline_assistant.rs             |  3 
crates/assistant_context_editor/src/context_editor.rs         |  3 
crates/language_model_selector/src/language_model_selector.rs | 39 ++++
8 files changed, 61 insertions(+), 31 deletions(-)

Detailed changes

crates/agent/src/assistant_model_selector.rs 🔗

@@ -1,7 +1,7 @@
 use assistant_settings::AssistantSettings;
 use fs::Fs;
 use gpui::{Entity, FocusHandle, SharedString};
-use language_model::LanguageModelRegistry;
+
 use language_model_selector::{
     LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
 };
@@ -9,17 +9,12 @@ use settings::update_settings_file;
 use std::sync::Arc;
 use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
 
-#[derive(Clone, Copy)]
-pub enum ModelType {
-    Default,
-    InlineAssistant,
-}
+pub use language_model_selector::ModelType;
 
 pub struct AssistantModelSelector {
     selector: Entity<LanguageModelSelector>,
     menu_handle: PopoverMenuHandle<LanguageModelSelector>,
     focus_handle: FocusHandle,
-    model_type: ModelType,
 }
 
 impl AssistantModelSelector {
@@ -63,13 +58,13 @@ impl AssistantModelSelector {
                             }
                         }
                     },
+                    model_type,
                     window,
                     cx,
                 )
             }),
             menu_handle,
             focus_handle,
-            model_type,
         }
     }
 
@@ -82,11 +77,7 @@ impl Render for AssistantModelSelector {
     fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
         let focus_handle = self.focus_handle.clone();
 
-        let model_registry = LanguageModelRegistry::read_global(cx);
-        let model = match self.model_type {
-            ModelType::Default => model_registry.default_model(),
-            ModelType::InlineAssistant => model_registry.inline_assistant_model(),
-        };
+        let model = self.selector.read(cx).active_model(cx);
         let (model_name, model_icon) = match model {
             Some(model) => (model.model.name().0, Some(model.provider.icon())),
             _ => (SharedString::from("No model selected"), None),

crates/agent/src/buffer_codegen.rs 🔗

@@ -1,7 +1,7 @@
 use crate::context::attach_context_to_message;
 use crate::context_store::ContextStore;
 use crate::inline_prompt_editor::CodegenStatus;
-use anyhow::{Context as _, Result};
+use anyhow::Result;
 use client::telemetry::Telemetry;
 use collections::HashSet;
 use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
@@ -131,7 +131,12 @@ impl BufferCodegen {
         cx.notify();
     }
 
-    pub fn start(&mut self, user_prompt: String, cx: &mut Context<Self>) -> Result<()> {
+    pub fn start(
+        &mut self,
+        primary_model: Arc<dyn LanguageModel>,
+        user_prompt: String,
+        cx: &mut Context<Self>,
+    ) -> Result<()> {
         let alternative_models = LanguageModelRegistry::read_global(cx)
             .inline_alternative_models()
             .to_vec();
@@ -155,11 +160,6 @@ impl BufferCodegen {
             }));
         }
 
-        let primary_model = LanguageModelRegistry::read_global(cx)
-            .default_model()
-            .context("no active model")?
-            .model;
-
         for (model, alternative) in iter::once(primary_model)
             .chain(alternative_models)
             .zip(&self.alternatives)

crates/agent/src/inline_assistant.rs 🔗

@@ -24,6 +24,7 @@ use gpui::{
     WeakEntity, Window, point,
 };
 use language::{Buffer, Point, Selection, TransactionId};
+use language_model::ConfiguredModel;
 use language_model::{LanguageModelRegistry, report_assistant_event};
 use multi_buffer::MultiBufferRow;
 use parking_lot::Mutex;
@@ -1221,9 +1222,15 @@ impl InlineAssistant {
             self.prompt_history.pop_front();
         }
 
+        let Some(ConfiguredModel { model, .. }) =
+            LanguageModelRegistry::read_global(cx).inline_assistant_model()
+        else {
+            return;
+        };
+
         assist
             .codegen
-            .update(cx, |codegen, cx| codegen.start(user_prompt, cx))
+            .update(cx, |codegen, cx| codegen.start(model, user_prompt, cx))
             .log_err();
     }
 

crates/agent/src/inline_prompt_editor.rs 🔗

@@ -1,4 +1,4 @@
-use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
+use crate::assistant_model_selector::AssistantModelSelector;
 use crate::buffer_codegen::BufferCodegen;
 use crate::context_picker::ContextPicker;
 use crate::context_store::ContextStore;
@@ -20,7 +20,7 @@ use gpui::{
     Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point,
 };
 use language_model::{LanguageModel, LanguageModelRegistry};
-use language_model_selector::ToggleModelSelector;
+use language_model_selector::{ModelType, ToggleModelSelector};
 use parking_lot::Mutex;
 use settings::Settings;
 use std::cmp;

crates/assistant/src/inline_assistant.rs 🔗

@@ -37,7 +37,7 @@ use language_model::{
     ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
 };
-use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
+use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
 use multi_buffer::MultiBufferRow;
 use parking_lot::Mutex;
 use project::{CodeAction, LspAction, ProjectTransaction};
@@ -1766,6 +1766,7 @@ impl PromptEditor {
                             move |settings, _| settings.set_model(model.clone()),
                         );
                     },
+                    ModelType::Default,
                     window,
                     cx,
                 )

crates/assistant/src/terminal_inline_assistant.rs 🔗

@@ -19,7 +19,7 @@ use language_model::{
     ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
     Role, report_assistant_event,
 };
-use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
+use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
 use prompt_store::PromptBuilder;
 use settings::{Settings, update_settings_file};
 use std::{
@@ -755,6 +755,7 @@ impl PromptEditor {
                             move |settings, _| settings.set_model(model.clone()),
                         );
                     },
+                    ModelType::Default,
                     window,
                     cx,
                 )

crates/assistant_context_editor/src/context_editor.rs 🔗

@@ -39,7 +39,7 @@ use language_model::{
     Role,
 };
 use language_model_selector::{
-    LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
+    LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType, ToggleModelSelector,
 };
 use multi_buffer::MultiBufferRow;
 use picker::Picker;
@@ -298,6 +298,7 @@ impl ContextEditor {
                             move |settings, _| settings.set_model(model.clone()),
                         );
                     },
+                    ModelType::Default,
                     window,
                     cx,
                 )

crates/language_model_selector/src/language_model_selector.rs 🔗

@@ -7,7 +7,8 @@ use gpui::{
     Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
 };
 use language_model::{
-    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
+    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
+    LanguageModelRegistry,
 };
 use picker::{Picker, PickerDelegate};
 use proto::Plan;
@@ -29,9 +30,16 @@ pub struct LanguageModelSelector {
     _subscriptions: Vec<Subscription>,
 }
 
+#[derive(Clone, Copy)]
+pub enum ModelType {
+    Default,
+    InlineAssistant,
+}
+
 impl LanguageModelSelector {
     pub fn new(
         on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
+        model_type: ModelType,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Self {
@@ -44,8 +52,9 @@ impl LanguageModelSelector {
             language_model_selector: cx.entity().downgrade(),
             on_model_changed: on_model_changed.clone(),
             all_models: Arc::new(all_models),
-            selected_index: Self::get_active_model_index(&entries, cx),
+            selected_index: Self::get_active_model_index(&entries, model_type, cx),
             filtered_entries: entries,
+            model_type,
         };
 
         let picker = cx.new(|cx| {
@@ -194,8 +203,27 @@ impl LanguageModelSelector {
         }
     }
 
-    fn get_active_model_index(entries: &[LanguageModelPickerEntry], cx: &App) -> usize {
-        let active_model = LanguageModelRegistry::read_global(cx).default_model();
+    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
+        let model_type = self.picker.read(cx).delegate.model_type;
+        Self::active_model_by_type(model_type, cx)
+    }
+
+    fn active_model_by_type(model_type: ModelType, cx: &App) -> Option<ConfiguredModel> {
+        match model_type {
+            ModelType::Default => LanguageModelRegistry::read_global(cx).default_model(),
+            ModelType::InlineAssistant => {
+                LanguageModelRegistry::read_global(cx).inline_assistant_model()
+            }
+        }
+    }
+
+    fn get_active_model_index(
+        entries: &[LanguageModelPickerEntry],
+        model_type: ModelType,
+        cx: &App,
+    ) -> usize {
+        let active_model = Self::active_model_by_type(model_type, cx);
+
         entries
             .iter()
             .position(|entry| {
@@ -300,6 +328,7 @@ pub struct LanguageModelPickerDelegate {
     all_models: Arc<GroupedModels>,
     filtered_entries: Vec<LanguageModelPickerEntry>,
     selected_index: usize,
+    model_type: ModelType,
 }
 
 struct GroupedModels {
@@ -493,7 +522,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
                     .into_any_element(),
             ),
             LanguageModelPickerEntry::Model(model_info) => {
-                let active_model = LanguageModelRegistry::read_global(cx).default_model();
+                let active_model = LanguageModelSelector::active_model_by_type(self.model_type, cx);
 
                 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
                 let active_model_id = active_model.map(|m| m.model.id());