Associate each thread with a model (#29573)

Max Brunsfeld and Richard Feldman created

This PR makes it possible to use different LLM models in the agent
panels of two different projects, simultaneously. It also properly
restores a thread's original model when restoring it from the history,
rather than having it use the default model. As before, newly-created
threads will use the current default model.

Release Notes:

- Enabled different project windows to use different models in the agent
panel
- Enhanced the agent panel so that when revisiting old threads, their
original model will be used.

---------

Co-authored-by: Richard Feldman <oss@rtfeldman.com>

Change summary

crates/agent/src/active_thread.rs                             | 13 
crates/agent/src/agent_diff.rs                                |  1 
crates/agent/src/assistant_model_selector.rs                  | 40 ++
crates/agent/src/assistant_panel.rs                           |  4 
crates/agent/src/inline_prompt_editor.rs                      |  4 
crates/agent/src/message_editor.rs                            | 65 +--
crates/agent/src/thread.rs                                    | 76 +++-
crates/agent/src/thread_store.rs                              |  9 
crates/agent/src/tool_use.rs                                  |  9 
crates/assistant/src/inline_assistant.rs                      |  4 
crates/assistant/src/terminal_inline_assistant.rs             |  4 
crates/assistant_context_editor/src/context_editor.rs         |  4 
crates/language_model/src/language_model.rs                   |  6 
crates/language_model/src/registry.rs                         |  2 
crates/language_model_selector/src/language_model_selector.rs | 41 -
15 files changed, 168 insertions(+), 114 deletions(-)

Detailed changes

crates/agent/src/active_thread.rs 🔗

@@ -25,8 +25,8 @@ use gpui::{
 };
 use language::{Buffer, LanguageRegistry};
 use language_model::{
-    LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent,
-    RequestUsage, Role, StopReason,
+    LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, RequestUsage, Role,
+    StopReason,
 };
 use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
 use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
@@ -1252,7 +1252,7 @@ impl ActiveThread {
         cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged);
         state._update_token_count_task.take();
 
-        let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
+        let Some(configured_model) = self.thread.read(cx).configured_model() else {
             state.last_estimated_token_count.take();
             return;
         };
@@ -1305,7 +1305,7 @@ impl ActiveThread {
                     temperature: None,
                 };
 
-                Some(default_model.model.count_tokens(request, cx))
+                Some(configured_model.model.count_tokens(request, cx))
             })? {
                 task.await?
             } else {
@@ -1338,7 +1338,7 @@ impl ActiveThread {
             return;
         };
         let edited_text = state.editor.read(cx).text(cx);
-        self.thread.update(cx, |thread, cx| {
+        let thread_model = self.thread.update(cx, |thread, cx| {
             thread.edit_message(
                 message_id,
                 Role::User,
@@ -1348,9 +1348,10 @@ impl ActiveThread {
             for message_id in self.messages_after(message_id) {
                 thread.delete_message(*message_id, cx);
             }
+            thread.get_or_init_configured_model(cx)
         });
 
-        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
+        let Some(model) = thread_model else {
             return;
         };
 

crates/agent/src/agent_diff.rs 🔗

@@ -951,6 +951,7 @@ mod tests {
             ThemeSettings::register(cx);
             ContextServerSettings::register(cx);
             EditorSettings::register(cx);
+            language_model::init_settings(cx);
         });
 
         let fs = FakeFs::new(cx.executor());

crates/agent/src/assistant_model_selector.rs 🔗

@@ -2,6 +2,8 @@ use assistant_settings::AssistantSettings;
 use fs::Fs;
 use gpui::{Entity, FocusHandle, SharedString};
 
+use crate::Thread;
+use language_model::{ConfiguredModel, LanguageModelRegistry};
 use language_model_selector::{
     LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
 };
@@ -9,7 +11,11 @@ use settings::update_settings_file;
 use std::sync::Arc;
 use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
 
-pub use language_model_selector::ModelType;
+#[derive(Clone)]
+pub enum ModelType {
+    Default(Entity<Thread>),
+    InlineAssistant,
+}
 
 pub struct AssistantModelSelector {
     selector: Entity<LanguageModelSelector>,
@@ -24,18 +30,39 @@ impl AssistantModelSelector {
         focus_handle: FocusHandle,
         model_type: ModelType,
         window: &mut Window,
-        cx: &mut App,
+        cx: &mut Context<Self>,
     ) -> Self {
         Self {
-            selector: cx.new(|cx| {
+            selector: cx.new(move |cx| {
                 let fs = fs.clone();
                 LanguageModelSelector::new(
+                    {
+                        let model_type = model_type.clone();
+                        move |cx| match &model_type {
+                            ModelType::Default(thread) => thread.read(cx).configured_model(),
+                            ModelType::InlineAssistant => {
+                                LanguageModelRegistry::read_global(cx).inline_assistant_model()
+                            }
+                        }
+                    },
                     move |model, cx| {
                         let provider = model.provider_id().0.to_string();
                         let model_id = model.id().0.to_string();
-
-                        match model_type {
-                            ModelType::Default => {
+                        match &model_type {
+                            ModelType::Default(thread) => {
+                                thread.update(cx, |thread, cx| {
+                                    let registry = LanguageModelRegistry::read_global(cx);
+                                    if let Some(provider) = registry.provider(&model.provider_id())
+                                    {
+                                        thread.set_configured_model(
+                                            Some(ConfiguredModel {
+                                                provider,
+                                                model: model.clone(),
+                                            }),
+                                            cx,
+                                        );
+                                    }
+                                });
                                 update_settings_file::<AssistantSettings>(
                                     fs.clone(),
                                     cx,
@@ -58,7 +85,6 @@ impl AssistantModelSelector {
                             }
                         }
                     },
-                    model_type,
                     window,
                     cx,
                 )

crates/agent/src/assistant_panel.rs 🔗

@@ -1274,12 +1274,12 @@ impl AssistantPanel {
         let is_generating = thread.is_generating();
         let message_editor = self.message_editor.read(cx);
 
-        let conversation_token_usage = thread.total_token_usage(cx);
+        let conversation_token_usage = thread.total_token_usage();
         let (total_token_usage, is_estimating) = if let Some((editing_message_id, unsent_tokens)) =
             self.thread.read(cx).editing_message_id()
         {
             let combined = thread
-                .token_usage_up_to_message(editing_message_id, cx)
+                .token_usage_up_to_message(editing_message_id)
                 .add(unsent_tokens);
 
             (combined, unsent_tokens > 0)

crates/agent/src/inline_prompt_editor.rs 🔗

@@ -1,4 +1,4 @@
-use crate::assistant_model_selector::AssistantModelSelector;
+use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
 use crate::buffer_codegen::BufferCodegen;
 use crate::context_picker::ContextPicker;
 use crate::context_store::ContextStore;
@@ -20,7 +20,7 @@ use gpui::{
     Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point,
 };
 use language_model::{LanguageModel, LanguageModelRegistry};
-use language_model_selector::{ModelType, ToggleModelSelector};
+use language_model_selector::ToggleModelSelector;
 use parking_lot::Mutex;
 use settings::Settings;
 use std::cmp;

crates/agent/src/message_editor.rs 🔗

@@ -1,7 +1,7 @@
 use std::collections::BTreeMap;
 use std::sync::Arc;
 
-use crate::assistant_model_selector::ModelType;
+use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
 use crate::context::{ContextLoadResult, load_context};
 use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
 use buffer_diff::BufferDiff;
@@ -21,9 +21,7 @@ use gpui::{
     Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
 };
 use language::{Buffer, Language};
-use language_model::{
-    ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage, MessageContent,
-};
+use language_model::{ConfiguredModel, LanguageModelRequestMessage, MessageContent};
 use language_model_selector::ToggleModelSelector;
 use multi_buffer;
 use project::Project;
@@ -36,7 +34,6 @@ use util::ResultExt as _;
 use workspace::Workspace;
 use zed_llm_client::CompletionMode;
 
-use crate::assistant_model_selector::AssistantModelSelector;
 use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
 use crate::context_store::ContextStore;
 use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
@@ -153,6 +150,17 @@ impl MessageEditor {
             }),
         ];
 
+        let model_selector = cx.new(|cx| {
+            AssistantModelSelector::new(
+                fs.clone(),
+                model_selector_menu_handle,
+                editor.focus_handle(cx),
+                ModelType::Default(thread.clone()),
+                window,
+                cx,
+            )
+        });
+
         Self {
             editor: editor.clone(),
             project: thread.read(cx).project().clone(),
@@ -165,16 +173,7 @@ impl MessageEditor {
             context_picker_menu_handle,
             load_context_task: None,
             last_loaded_context: None,
-            model_selector: cx.new(|cx| {
-                AssistantModelSelector::new(
-                    fs.clone(),
-                    model_selector_menu_handle,
-                    editor.focus_handle(cx),
-                    ModelType::Default,
-                    window,
-                    cx,
-                )
-            }),
+            model_selector,
             edits_expanded: false,
             editor_is_expanded: false,
             profile_selector: cx
@@ -263,15 +262,11 @@ impl MessageEditor {
         self.editor.read(cx).text(cx).trim().is_empty()
     }
 
-    fn is_model_selected(&self, cx: &App) -> bool {
-        LanguageModelRegistry::read_global(cx)
-            .default_model()
-            .is_some()
-    }
-
     fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) {
-        let model_registry = LanguageModelRegistry::read_global(cx);
-        let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else {
+        let Some(ConfiguredModel { model, provider }) = self
+            .thread
+            .update(cx, |thread, cx| thread.get_or_init_configured_model(cx))
+        else {
             return;
         };
 
@@ -408,14 +403,13 @@ impl MessageEditor {
             return None;
         }
 
-        let model = LanguageModelRegistry::read_global(cx)
-            .default_model()
-            .map(|default| default.model.clone())?;
-        if !model.supports_max_mode() {
+        let thread = self.thread.read(cx);
+        let model = thread.configured_model();
+        if !model?.model.supports_max_mode() {
             return None;
         }
 
-        let active_completion_mode = self.thread.read(cx).completion_mode();
+        let active_completion_mode = thread.completion_mode();
 
         Some(
             IconButton::new("max-mode", IconName::SquarePlus)
@@ -442,24 +436,21 @@ impl MessageEditor {
         cx: &mut Context<Self>,
     ) -> Div {
         let thread = self.thread.read(cx);
+        let model = thread.configured_model();
 
         let editor_bg_color = cx.theme().colors().editor_background;
         let is_generating = thread.is_generating();
         let focus_handle = self.editor.focus_handle(cx);
 
-        let is_model_selected = self.is_model_selected(cx);
+        let is_model_selected = model.is_some();
         let is_editor_empty = self.is_editor_empty(cx);
 
-        let model = LanguageModelRegistry::read_global(cx)
-            .default_model()
-            .map(|default| default.model.clone());
-
         let incompatible_tools = model
             .as_ref()
             .map(|model| {
                 self.incompatible_tools_state.update(cx, |state, cx| {
                     state
-                        .incompatible_tools(model, cx)
+                        .incompatible_tools(&model.model, cx)
                         .iter()
                         .cloned()
                         .collect::<Vec<_>>()
@@ -1058,7 +1049,7 @@ impl MessageEditor {
         cx.emit(MessageEditorEvent::Changed);
         self.update_token_count_task.take();
 
-        let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
+        let Some(model) = self.thread.read(cx).configured_model() else {
             self.last_estimated_token_count.take();
             return;
         };
@@ -1111,7 +1102,7 @@ impl MessageEditor {
                     temperature: None,
                 };
 
-                Some(default_model.model.count_tokens(request, cx))
+                Some(model.model.count_tokens(request, cx))
             })? {
                 task.await?
             } else {
@@ -1143,7 +1134,7 @@ impl Focusable for MessageEditor {
 impl Render for MessageEditor {
     fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
         let thread = self.thread.read(cx);
-        let total_token_usage = thread.total_token_usage(cx);
+        let total_token_usage = thread.total_token_usage();
         let token_usage_ratio = total_token_usage.ratio();
 
         let action_log = self.thread.read(cx).action_log();

crates/agent/src/thread.rs 🔗

@@ -22,8 +22,8 @@ use language_model::{
     LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
     LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
-    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
-    TokenUsage,
+    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
+    StopReason, TokenUsage,
 };
 use postage::stream::Stream as _;
 use project::Project;
@@ -41,8 +41,8 @@ use zed_llm_client::CompletionMode;
 use crate::ThreadStore;
 use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
 use crate::thread_store::{
-    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
-    SerializedToolUse, SharedProjectContext,
+    SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, SerializedThread,
+    SerializedToolResult, SerializedToolUse, SharedProjectContext,
 };
 use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
 
@@ -332,6 +332,7 @@ pub struct Thread {
         Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
     >,
     remaining_turns: u32,
+    configured_model: Option<ConfiguredModel>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -351,6 +352,8 @@ impl Thread {
         cx: &mut Context<Self>,
     ) -> Self {
         let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
+        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
+
         Self {
             id: ThreadId::new(),
             updated_at: Utc::now(),
@@ -388,6 +391,7 @@ impl Thread {
             last_auto_capture_at: None,
             request_callback: None,
             remaining_turns: u32::MAX,
+            configured_model,
         }
     }
 
@@ -411,6 +415,19 @@ impl Thread {
         let (detailed_summary_tx, detailed_summary_rx) =
             postage::watch::channel_with(serialized.detailed_summary_state);
 
+        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+            serialized
+                .model
+                .and_then(|model| {
+                    let model = SelectedModel {
+                        provider: model.provider.clone().into(),
+                        model: model.model.clone().into(),
+                    };
+                    registry.select_model(&model, cx)
+                })
+                .or_else(|| registry.default_model())
+        });
+
         Self {
             id,
             updated_at: serialized.updated_at,
@@ -468,6 +485,7 @@ impl Thread {
             last_auto_capture_at: None,
             request_callback: None,
             remaining_turns: u32::MAX,
+            configured_model,
         }
     }
 
@@ -507,6 +525,22 @@ impl Thread {
         self.project_context.clone()
     }
 
+    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
+        if self.configured_model.is_none() {
+            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
+        }
+        self.configured_model.clone()
+    }
+
+    pub fn configured_model(&self) -> Option<ConfiguredModel> {
+        self.configured_model.clone()
+    }
+
+    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
+        self.configured_model = model;
+        cx.notify();
+    }
+
     pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 
     pub fn summary_or_default(&self) -> SharedString {
@@ -952,6 +986,13 @@ impl Thread {
                 request_token_usage: this.request_token_usage.clone(),
                 detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                 exceeded_window_error: this.exceeded_window_error.clone(),
+                model: this
+                    .configured_model
+                    .as_ref()
+                    .map(|model| SerializedLanguageModel {
+                        provider: model.provider.id().0.to_string(),
+                        model: model.model.id().0.to_string(),
+                    }),
             })
         })
     }
@@ -1733,7 +1774,7 @@ impl Thread {
             tool_use_id.clone(),
             tool_name,
             Err(anyhow!("Error parsing input JSON: {error}")),
-            cx,
+            self.configured_model.as_ref(),
         );
         let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
             pending_tool_use.ui_text.clone()
@@ -1808,7 +1849,7 @@ impl Thread {
                             tool_use_id.clone(),
                             tool_name,
                             output,
-                            cx,
+                            thread.configured_model.as_ref(),
                         );
                         thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                     })
@@ -1826,10 +1867,9 @@ impl Thread {
         cx: &mut Context<Self>,
     ) {
         if self.all_tools_finished() {
-            let model_registry = LanguageModelRegistry::read_global(cx);
-            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
+            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
                 if !canceled {
-                    self.send_to_model(model, window, cx);
+                    self.send_to_model(model.clone(), window, cx);
                 }
                 self.auto_capture_telemetry(cx);
             }
@@ -2254,8 +2294,8 @@ impl Thread {
         self.cumulative_token_usage
     }
 
-    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
-        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
+    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
+        let Some(model) = self.configured_model.as_ref() else {
             return TotalTokenUsage::default();
         };
 
@@ -2283,9 +2323,8 @@ impl Thread {
         }
     }
 
-    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
-        let model_registry = LanguageModelRegistry::read_global(cx);
-        let Some(model) = model_registry.default_model() else {
+    pub fn total_token_usage(&self) -> TotalTokenUsage {
+        let Some(model) = self.configured_model.as_ref() else {
             return TotalTokenUsage::default();
         };
 
@@ -2336,8 +2375,12 @@ impl Thread {
             "Permission to run tool action denied by user"
         ));
 
-        self.tool_use
-            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
+        self.tool_use.insert_tool_output(
+            tool_use_id.clone(),
+            tool_name,
+            err,
+            self.configured_model.as_ref(),
+        );
         self.tool_finished(tool_use_id.clone(), None, true, window, cx);
     }
 }
@@ -2769,6 +2812,7 @@ fn main() {{
             prompt_store::init(cx);
             thread_store::init(cx);
             workspace::init_settings(cx);
+            language_model::init_settings(cx);
             ThemeSettings::register(cx);
             ContextServerSettings::register(cx);
             EditorSettings::register(cx);

crates/agent/src/thread_store.rs 🔗

@@ -640,6 +640,14 @@ pub struct SerializedThread {
     pub detailed_summary_state: DetailedSummaryState,
     #[serde(default)]
     pub exceeded_window_error: Option<ExceededWindowError>,
+    #[serde(default)]
+    pub model: Option<SerializedLanguageModel>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct SerializedLanguageModel {
+    pub provider: String,
+    pub model: String,
 }
 
 impl SerializedThread {
@@ -774,6 +782,7 @@ impl LegacySerializedThread {
             request_token_usage: Vec::new(),
             detailed_summary_state: DetailedSummaryState::default(),
             exceeded_window_error: None,
+            model: None,
         }
     }
 }

crates/agent/src/tool_use.rs 🔗

@@ -7,7 +7,7 @@ use futures::FutureExt as _;
 use futures::future::Shared;
 use gpui::{App, Entity, SharedString, Task};
 use language_model::{
-    LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
+    ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult,
     LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 };
 use ui::IconName;
@@ -353,7 +353,7 @@ impl ToolUseState {
         tool_use_id: LanguageModelToolUseId,
         tool_name: Arc<str>,
         output: Result<String>,
-        cx: &App,
+        configured_model: Option<&ConfiguredModel>,
     ) -> Option<PendingToolUse> {
         let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
 
@@ -373,13 +373,10 @@ impl ToolUseState {
 
         match output {
             Ok(tool_result) => {
-                let model_registry = LanguageModelRegistry::read_global(cx);
-
                 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
 
                 // Protect from clearly large output
-                let tool_output_limit = model_registry
-                    .default_model()
+                let tool_output_limit = configured_model
                     .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
                     .unwrap_or(usize::MAX);
 

crates/assistant/src/inline_assistant.rs 🔗

@@ -37,7 +37,7 @@ use language_model::{
     ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
 };
-use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
+use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
 use multi_buffer::MultiBufferRow;
 use parking_lot::Mutex;
 use project::{CodeAction, LspAction, ProjectTransaction};
@@ -1759,6 +1759,7 @@ impl PromptEditor {
             language_model_selector: cx.new(|cx| {
                 let fs = fs.clone();
                 LanguageModelSelector::new(
+                    |cx| LanguageModelRegistry::read_global(cx).default_model(),
                     move |model, cx| {
                         update_settings_file::<AssistantSettings>(
                             fs.clone(),
@@ -1766,7 +1767,6 @@ impl PromptEditor {
                             move |settings, _| settings.set_model(model.clone()),
                         );
                     },
-                    ModelType::Default,
                     window,
                     cx,
                 )

crates/assistant/src/terminal_inline_assistant.rs 🔗

@@ -19,7 +19,7 @@ use language_model::{
     ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
     Role, report_assistant_event,
 };
-use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
+use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
 use prompt_store::PromptBuilder;
 use settings::{Settings, update_settings_file};
 use std::{
@@ -749,6 +749,7 @@ impl PromptEditor {
             language_model_selector: cx.new(|cx| {
                 let fs = fs.clone();
                 LanguageModelSelector::new(
+                    |cx| LanguageModelRegistry::read_global(cx).default_model(),
                     move |model, cx| {
                         update_settings_file::<AssistantSettings>(
                             fs.clone(),
@@ -756,7 +757,6 @@ impl PromptEditor {
                             move |settings, _| settings.set_model(model.clone()),
                         );
                     },
-                    ModelType::Default,
                     window,
                     cx,
                 )

crates/assistant_context_editor/src/context_editor.rs 🔗

@@ -39,7 +39,7 @@ use language_model::{
     Role,
 };
 use language_model_selector::{
-    LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType, ToggleModelSelector,
+    LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
 };
 use multi_buffer::MultiBufferRow;
 use picker::Picker;
@@ -291,6 +291,7 @@ impl ContextEditor {
             dragged_file_worktrees: Vec::new(),
             language_model_selector: cx.new(|cx| {
                 LanguageModelSelector::new(
+                    |cx| LanguageModelRegistry::read_global(cx).default_model(),
                     move |model, cx| {
                         update_settings_file::<AssistantSettings>(
                             fs.clone(),
@@ -298,7 +299,6 @@ impl ContextEditor {
                             move |settings, _| settings.set_model(model.clone()),
                         );
                     },
-                    ModelType::Default,
                     window,
                     cx,
                 )

crates/language_model/src/language_model.rs 🔗

@@ -39,10 +39,14 @@ pub use crate::telemetry::*;
 pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
 
 pub fn init(client: Arc<Client>, cx: &mut App) {
-    registry::init(cx);
+    init_settings(cx);
     RefreshLlmTokenListener::register(client.clone(), cx);
 }
 
+pub fn init_settings(cx: &mut App) {
+    registry::init(cx);
+}
+
 /// The availability of a [`LanguageModel`].
 #[derive(Debug, PartialEq, Eq, Clone, Copy)]
 pub enum LanguageModelAvailability {

crates/language_model/src/registry.rs 🔗

@@ -188,7 +188,7 @@ impl LanguageModelRegistry {
             .collect::<Vec<_>>();
     }
 
-    fn select_model(
+    pub fn select_model(
         &mut self,
         selected_model: &SelectedModel,
         cx: &mut Context<Self>,

crates/language_model_selector/src/language_model_selector.rs 🔗

@@ -22,7 +22,8 @@ action_with_deprecated_aliases!(
 
 const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 
-type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
+type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
+type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 
 pub struct LanguageModelSelector {
     picker: Entity<Picker<LanguageModelPickerDelegate>>,
@@ -30,16 +31,10 @@ pub struct LanguageModelSelector {
     _subscriptions: Vec<Subscription>,
 }
 
-#[derive(Clone, Copy)]
-pub enum ModelType {
-    Default,
-    InlineAssistant,
-}
-
 impl LanguageModelSelector {
     pub fn new(
-        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
-        model_type: ModelType,
+        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
+        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Self {
@@ -52,9 +47,9 @@ impl LanguageModelSelector {
             language_model_selector: cx.entity().downgrade(),
             on_model_changed: on_model_changed.clone(),
             all_models: Arc::new(all_models),
-            selected_index: Self::get_active_model_index(&entries, model_type, cx),
+            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
             filtered_entries: entries,
-            model_type,
+            get_active_model: Arc::new(get_active_model),
         };
 
         let picker = cx.new(|cx| {
@@ -204,26 +199,13 @@ impl LanguageModelSelector {
     }
 
     pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
-        let model_type = self.picker.read(cx).delegate.model_type;
-        Self::active_model_by_type(model_type, cx)
-    }
-
-    fn active_model_by_type(model_type: ModelType, cx: &App) -> Option<ConfiguredModel> {
-        match model_type {
-            ModelType::Default => LanguageModelRegistry::read_global(cx).default_model(),
-            ModelType::InlineAssistant => {
-                LanguageModelRegistry::read_global(cx).inline_assistant_model()
-            }
-        }
+        (self.picker.read(cx).delegate.get_active_model)(cx)
     }
 
     fn get_active_model_index(
         entries: &[LanguageModelPickerEntry],
-        model_type: ModelType,
-        cx: &App,
+        active_model: Option<ConfiguredModel>,
     ) -> usize {
-        let active_model = Self::active_model_by_type(model_type, cx);
-
         entries
             .iter()
             .position(|entry| {
@@ -232,7 +214,7 @@ impl LanguageModelSelector {
                         .as_ref()
                         .map(|active_model| {
                             active_model.model.id() == model.model.id()
-                                && active_model.model.provider_id() == model.model.provider_id()
+                                && active_model.provider.id() == model.model.provider_id()
                         })
                         .unwrap_or_default()
                 } else {
@@ -325,10 +307,10 @@ struct ModelInfo {
 pub struct LanguageModelPickerDelegate {
     language_model_selector: WeakEntity<LanguageModelSelector>,
     on_model_changed: OnModelChanged,
+    get_active_model: GetActiveModel,
     all_models: Arc<GroupedModels>,
     filtered_entries: Vec<LanguageModelPickerEntry>,
     selected_index: usize,
-    model_type: ModelType,
 }
 
 struct GroupedModels {
@@ -522,8 +504,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
                     .into_any_element(),
             ),
             LanguageModelPickerEntry::Model(model_info) => {
-                let active_model = LanguageModelSelector::active_model_by_type(self.model_type, cx);
-
+                let active_model = (self.get_active_model)(cx);
                 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
                 let active_model_id = active_model.map(|m| m.model.id());