From 17903a09990ec5eae0f4f39aa806e3d106e22342 Mon Sep 17 00:00:00 2001
From: Max Brunsfeld
Date: Mon, 28 Apr 2025 16:43:16 -0700
Subject: [PATCH] Associate each thread with a model (#29573)

This PR makes it possible to use different LLM models in the agent panels
of two different projects simultaneously. It also restores a thread's
original model when the thread is reopened from the history, rather than
falling back to the default model. As before, newly created threads use
the current default model.

Release Notes:

- Enabled different project windows to use different models in the agent panel
- Improved the agent panel so that revisiting an old thread restores its original model.

---------

Co-authored-by: Richard Feldman
---
 crates/agent/src/active_thread.rs             | 13 ++--
 crates/agent/src/agent_diff.rs                |  1 +
 crates/agent/src/assistant_model_selector.rs  | 40 ++++++++--
 crates/agent/src/assistant_panel.rs           |  4 +-
 crates/agent/src/inline_prompt_editor.rs      |  4 +-
 crates/agent/src/message_editor.rs            | 65 +++++++---------
 crates/agent/src/thread.rs                    | 76 +++++++++++++++----
 crates/agent/src/thread_store.rs              |  9 +++
 crates/agent/src/tool_use.rs                  |  9 +--
 crates/assistant/src/inline_assistant.rs      |  4 +-
 .../src/terminal_inline_assistant.rs          |  4 +-
 .../src/context_editor.rs                     |  4 +-
 crates/language_model/src/language_model.rs   |  6 +-
 crates/language_model/src/registry.rs         |  2 +-
 .../src/language_model_selector.rs            | 41 +++-------
 15 files changed, 168 insertions(+), 114 deletions(-)

diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs
index e8e5351c1aa7c7df3cd66e9d56055f55046c003c..e4c528e0046c597b8d8aa4c0f8f427cca3126294 100644
--- a/crates/agent/src/active_thread.rs
+++ b/crates/agent/src/active_thread.rs
@@ -25,8 +25,8 @@ use gpui::{
 };
 use language::{Buffer, LanguageRegistry};
 use language_model::{
-    LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent,
-    RequestUsage, Role, StopReason,
+    LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, RequestUsage, Role,
+    StopReason,
 };
 use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
 use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
@@ -1252,7 +1252,7 @@ impl ActiveThread {
         cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged);
         state._update_token_count_task.take();
 
-        let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
+        let Some(configured_model) = self.thread.read(cx).configured_model() else {
             state.last_estimated_token_count.take();
             return;
         };
@@ -1305,7 +1305,7 @@ impl ActiveThread {
                 temperature: None,
             };
 
-            Some(default_model.model.count_tokens(request, cx))
+            Some(configured_model.model.count_tokens(request, cx))
         })? {
             task.await?
         } else {
@@ -1338,7 +1338,7 @@ impl ActiveThread {
             return;
         };
         let edited_text = state.editor.read(cx).text(cx);
-        self.thread.update(cx, |thread, cx| {
+        let thread_model = self.thread.update(cx, |thread, cx| {
             thread.edit_message(
                 message_id,
                 Role::User,
@@ -1348,9 +1348,10 @@ impl ActiveThread {
             for message_id in self.messages_after(message_id) {
                 thread.delete_message(*message_id, cx);
             }
+            thread.get_or_init_configured_model(cx)
         });
 
-        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
+        let Some(model) = thread_model else {
            return;
        };
 
diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs
index b09c0015c5c357d49b75795d949a431a93ee2a5c..2d7ab8df4a33b5dd29ac87c8d361a16cb9b95c4c 100644
--- a/crates/agent/src/agent_diff.rs
+++ b/crates/agent/src/agent_diff.rs
@@ -951,6 +951,7 @@ mod tests {
             ThemeSettings::register(cx);
             ContextServerSettings::register(cx);
             EditorSettings::register(cx);
+            language_model::init_settings(cx);
         });
 
         let fs = FakeFs::new(cx.executor());
diff --git a/crates/agent/src/assistant_model_selector.rs b/crates/agent/src/assistant_model_selector.rs
index f63eb588bc104b5982bf354658b7f33aee5ac48f..fc8d035293f63b2d36e3f8fe8511ae159bbb81bf 100644
--- a/crates/agent/src/assistant_model_selector.rs
+++ b/crates/agent/src/assistant_model_selector.rs
@@ -2,6 +2,8 @@ use assistant_settings::AssistantSettings;
 use fs::Fs;
 use gpui::{Entity, FocusHandle, SharedString};
 
+use crate::Thread;
+use language_model::{ConfiguredModel, LanguageModelRegistry};
 use language_model_selector::{
     LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
 };
@@ -9,7 +11,11 @@ use settings::update_settings_file;
 use std::sync::Arc;
 use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
 
-pub use language_model_selector::ModelType;
+#[derive(Clone)]
+pub enum ModelType {
+    Default(Entity<Thread>),
+    InlineAssistant,
+}
 
 pub struct AssistantModelSelector {
     selector: Entity<LanguageModelSelector>,
@@ -24,18 +30,39 @@ impl AssistantModelSelector {
         focus_handle: FocusHandle,
         model_type: ModelType,
         window: &mut Window,
-        cx: &mut App,
+        cx: &mut Context<Self>,
     ) -> Self {
         Self {
-            selector: cx.new(|cx| {
+            selector: cx.new(move |cx| {
                 let fs = fs.clone();
                 LanguageModelSelector::new(
+                    {
+                        let model_type = model_type.clone();
+                        move |cx| match &model_type {
+                            ModelType::Default(thread) => thread.read(cx).configured_model(),
+                            ModelType::InlineAssistant => {
+                                LanguageModelRegistry::read_global(cx).inline_assistant_model()
+                            }
+                        }
+                    },
                     move |model, cx| {
                         let provider = model.provider_id().0.to_string();
                         let model_id = model.id().0.to_string();
-
-                        match model_type {
-                            ModelType::Default => {
+                        match &model_type {
+                            ModelType::Default(thread) => {
+                                thread.update(cx, |thread, cx| {
+                                    let registry = LanguageModelRegistry::read_global(cx);
+                                    if let Some(provider) = registry.provider(&model.provider_id())
+                                    {
+                                        thread.set_configured_model(
+                                            Some(ConfiguredModel {
+                                                provider,
+                                                model: model.clone(),
+                                            }),
+                                            cx,
+                                        );
+                                    }
+                                });
                                 update_settings_file::<AssistantSettings>(
                                     fs.clone(),
                                     cx,
@@ -58,7 +85,6 @@ impl AssistantModelSelector {
                             }
                         }
                     },
-                    model_type,
                     window,
                     cx,
                 )
diff --git a/crates/agent/src/assistant_panel.rs b/crates/agent/src/assistant_panel.rs
index ce96388dfffecad0e19dac9261f75adee2a08f07..3308f87a45a5495e92d43192dc6009fcc2664f74 100644
--- a/crates/agent/src/assistant_panel.rs
+++ b/crates/agent/src/assistant_panel.rs
@@ -1274,12 +1274,12 @@ impl AssistantPanel {
         let is_generating = thread.is_generating();
         let message_editor = self.message_editor.read(cx);
 
-        let conversation_token_usage = thread.total_token_usage(cx);
+        let conversation_token_usage = thread.total_token_usage();
         let (total_token_usage, is_estimating) = if let Some((editing_message_id, unsent_tokens)) =
             self.thread.read(cx).editing_message_id()
         {
             let combined = thread
-                .token_usage_up_to_message(editing_message_id, cx)
+                .token_usage_up_to_message(editing_message_id)
                 .add(unsent_tokens);
 
             (combined, unsent_tokens > 0)
diff --git a/crates/agent/src/inline_prompt_editor.rs b/crates/agent/src/inline_prompt_editor.rs
index 3c25ffb46ddcebd5a0c97e93930d3d70e82c3597..66ff1d94ebbb39aded1c7993a18c111a50008bde 100644
--- a/crates/agent/src/inline_prompt_editor.rs
+++ b/crates/agent/src/inline_prompt_editor.rs
@@ -1,4 +1,4 @@
-use crate::assistant_model_selector::AssistantModelSelector;
+use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
 use crate::buffer_codegen::BufferCodegen;
 use crate::context_picker::ContextPicker;
 use crate::context_store::ContextStore;
@@ -20,7 +20,7 @@ use gpui::{
     Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point,
 };
 use language_model::{LanguageModel, LanguageModelRegistry};
-use language_model_selector::{ModelType, ToggleModelSelector};
+use language_model_selector::ToggleModelSelector;
 use parking_lot::Mutex;
 use settings::Settings;
 use std::cmp;
diff --git a/crates/agent/src/message_editor.rs b/crates/agent/src/message_editor.rs
index 65ae2db0b57c0c08b1a409489b8417f1d80ab2f9..4b3a98c90c4e51c72bc77e0d9cd67a1db1715676 100644
--- a/crates/agent/src/message_editor.rs
+++ b/crates/agent/src/message_editor.rs
@@ -1,7 +1,7 @@
 use std::collections::BTreeMap;
 use std::sync::Arc;
 
-use crate::assistant_model_selector::ModelType;
+use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
 use crate::context::{ContextLoadResult, load_context};
 use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
 use buffer_diff::BufferDiff;
@@ -21,9 +21,7 @@ use gpui::{
     Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
 };
 use language::{Buffer, Language};
-use language_model::{
-    ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage, MessageContent,
-};
+use language_model::{ConfiguredModel, LanguageModelRequestMessage, MessageContent};
 use language_model_selector::ToggleModelSelector;
 use multi_buffer;
 use project::Project;
@@ -36,7 +34,6 @@ use util::ResultExt as _;
 use workspace::Workspace;
 use zed_llm_client::CompletionMode;
 
-use crate::assistant_model_selector::AssistantModelSelector;
 use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
 use crate::context_store::ContextStore;
 use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
@@ -153,6 +150,17 @@ impl MessageEditor {
             }),
         ];
 
+        let model_selector = cx.new(|cx| {
+            AssistantModelSelector::new(
+                fs.clone(),
+                model_selector_menu_handle,
+                editor.focus_handle(cx),
+                ModelType::Default(thread.clone()),
+                window,
+                cx,
+            )
+        });
+
         Self {
             editor: editor.clone(),
             project: thread.read(cx).project().clone(),
@@ -165,16 +173,7 @@ impl MessageEditor {
             context_picker_menu_handle,
             load_context_task: None,
             last_loaded_context: None,
-            model_selector: cx.new(|cx| {
-                AssistantModelSelector::new(
-                    fs.clone(),
-                    model_selector_menu_handle,
-                    editor.focus_handle(cx),
-                    ModelType::Default,
-                    window,
-                    cx,
-                )
-            }),
+            model_selector,
             edits_expanded: false,
             editor_is_expanded: false,
             profile_selector: cx
@@ -263,15 +262,11 @@ impl MessageEditor {
         self.editor.read(cx).text(cx).trim().is_empty()
     }
 
-    fn is_model_selected(&self, cx: &App) -> bool {
-        LanguageModelRegistry::read_global(cx)
-            .default_model()
-            .is_some()
-    }
-
     fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) {
-        let model_registry = LanguageModelRegistry::read_global(cx);
-        let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else {
+        let Some(ConfiguredModel { model, provider }) = self
+            .thread
+            .update(cx, |thread, cx| thread.get_or_init_configured_model(cx))
+        else {
             return;
         };
 
@@ -408,14 +403,13 @@ impl MessageEditor {
             return None;
         }
 
-        let model = LanguageModelRegistry::read_global(cx)
-            .default_model()
-            .map(|default| default.model.clone())?;
-        if !model.supports_max_mode() {
+        let thread = self.thread.read(cx);
+        let model = thread.configured_model();
+        if !model?.model.supports_max_mode() {
             return None;
         }
 
-        let active_completion_mode = self.thread.read(cx).completion_mode();
+        let active_completion_mode = thread.completion_mode();
 
         Some(
             IconButton::new("max-mode", IconName::SquarePlus)
@@ -442,24 +436,21 @@ impl MessageEditor {
         cx: &mut Context<Self>,
     ) -> Div {
         let thread = self.thread.read(cx);
+        let model = thread.configured_model();
         let editor_bg_color = cx.theme().colors().editor_background;
 
         let is_generating = thread.is_generating();
         let focus_handle = self.editor.focus_handle(cx);
 
-        let is_model_selected = self.is_model_selected(cx);
+        let is_model_selected = model.is_some();
         let is_editor_empty = self.is_editor_empty(cx);
 
-        let model = LanguageModelRegistry::read_global(cx)
-            .default_model()
-            .map(|default| default.model.clone());
-
         let incompatible_tools = model
             .as_ref()
             .map(|model| {
                 self.incompatible_tools_state.update(cx, |state, cx| {
                     state
-                        .incompatible_tools(model, cx)
+                        .incompatible_tools(&model.model, cx)
                         .iter()
                        .cloned()
                        .collect::<Vec<_>>()
@@ -1058,7 +1049,7 @@ impl MessageEditor {
         cx.emit(MessageEditorEvent::Changed);
         self.update_token_count_task.take();
 
-        let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
+        let Some(model) = self.thread.read(cx).configured_model() else {
             self.last_estimated_token_count.take();
             return;
         };
@@ -1111,7 +1102,7 @@ impl MessageEditor {
                 temperature: None,
             };
 
-            Some(default_model.model.count_tokens(request, cx))
+            Some(model.model.count_tokens(request, cx))
         })? {
             task.await?
         } else {
@@ -1143,7 +1134,7 @@ impl Focusable for MessageEditor {
 impl Render for MessageEditor {
     fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
         let thread = self.thread.read(cx);
-        let total_token_usage = thread.total_token_usage(cx);
+        let total_token_usage = thread.total_token_usage();
         let token_usage_ratio = total_token_usage.ratio();
 
         let action_log = self.thread.read(cx).action_log();
diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs
index a723d082c1f5901d6bf0413ece10b14c4cbd39fe..d8008915357c9da1d0e0976ab6473540da051000 100644
--- a/crates/agent/src/thread.rs
+++ b/crates/agent/src/thread.rs
@@ -22,8 +22,8 @@ use language_model::{
     LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
     LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
-    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
-    TokenUsage,
+    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
+    StopReason, TokenUsage,
 };
 use postage::stream::Stream as _;
 use project::Project;
@@ -41,8 +41,8 @@ use zed_llm_client::CompletionMode;
 use crate::ThreadStore;
 use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
 use crate::thread_store::{
-    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
-    SerializedToolUse, SharedProjectContext,
+    SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, SerializedThread,
+    SerializedToolResult, SerializedToolUse, SharedProjectContext,
 };
 use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
@@ -332,6 +332,7 @@ pub struct Thread {
         Box])>,
     >,
     remaining_turns: u32,
+    configured_model: Option<ConfiguredModel>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -351,6 +352,8 @@ impl Thread {
         cx: &mut Context<Self>,
     ) -> Self {
         let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
+        let configured_model = LanguageModelRegistry::read_global(cx).default_model();
+
         Self {
             id: ThreadId::new(),
             updated_at: Utc::now(),
@@ -388,6 +391,7 @@ impl Thread {
             last_auto_capture_at: None,
             request_callback: None,
             remaining_turns: u32::MAX,
+            configured_model,
         }
     }
 
@@ -411,6 +415,19 @@ impl Thread {
         let (detailed_summary_tx, detailed_summary_rx) =
             postage::watch::channel_with(serialized.detailed_summary_state);
 
+        let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+            serialized
+                .model
+                .and_then(|model| {
+                    let model = SelectedModel {
+                        provider: model.provider.clone().into(),
+                        model: model.model.clone().into(),
+                    };
+                    registry.select_model(&model, cx)
+                })
+                .or_else(|| registry.default_model())
+        });
+
         Self {
             id,
             updated_at: serialized.updated_at,
@@ -468,6 +485,7 @@ impl Thread {
             last_auto_capture_at: None,
             request_callback: None,
             remaining_turns: u32::MAX,
+            configured_model,
         }
     }
 
@@ -507,6 +525,22 @@ impl Thread {
         self.project_context.clone()
     }
 
+    pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
+        if self.configured_model.is_none() {
+            self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
+        }
+        self.configured_model.clone()
+    }
+
+    pub fn configured_model(&self) -> Option<ConfiguredModel> {
+        self.configured_model.clone()
+    }
+
+    pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
+        self.configured_model = model;
+        cx.notify();
+    }
+
     pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 
     pub fn summary_or_default(&self) -> SharedString {
@@ -952,6 +986,13 @@ impl Thread {
                 request_token_usage: this.request_token_usage.clone(),
                 detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
                 exceeded_window_error: this.exceeded_window_error.clone(),
+                model: this
+                    .configured_model
+                    .as_ref()
+                    .map(|model| SerializedLanguageModel {
+                        provider: model.provider.id().0.to_string(),
+                        model: model.model.id().0.to_string(),
+                    }),
             })
         })
     }
@@ -1733,7 +1774,7 @@ impl Thread {
                     tool_use_id.clone(),
                     tool_name,
                     Err(anyhow!("Error parsing input JSON: {error}")),
-                    cx,
+                    self.configured_model.as_ref(),
                 );
                 let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
                     pending_tool_use.ui_text.clone()
@@ -1808,7 +1849,7 @@ impl Thread {
                         tool_use_id.clone(),
                         tool_name,
                         output,
-                        cx,
+                        thread.configured_model.as_ref(),
                     );
                     thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
                 })
@@ -1826,10 +1867,9 @@ impl Thread {
         cx: &mut Context<Self>,
     ) {
         if self.all_tools_finished() {
-            let model_registry = LanguageModelRegistry::read_global(cx);
-            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
+            if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
                 if !canceled {
-                    self.send_to_model(model, window, cx);
+                    self.send_to_model(model.clone(), window, cx);
                 }
                 self.auto_capture_telemetry(cx);
             }
@@ -2254,8 +2294,8 @@ impl Thread {
         self.cumulative_token_usage
     }
 
-    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
-        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
+    pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
+        let Some(model) = self.configured_model.as_ref() else {
             return TotalTokenUsage::default();
         };
 
@@ -2283,9 +2323,8 @@ impl Thread {
         }
     }
 
-    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
-        let model_registry = LanguageModelRegistry::read_global(cx);
-        let Some(model) = model_registry.default_model() else {
+    pub fn total_token_usage(&self) -> TotalTokenUsage {
+        let Some(model) = self.configured_model.as_ref() else {
             return TotalTokenUsage::default();
         };
 
@@ -2336,8 +2375,12 @@ impl Thread {
                     "Permission to run tool action denied by user"
                 ));
 
-                self.tool_use
-                    .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
+                self.tool_use.insert_tool_output(
+                    tool_use_id.clone(),
+                    tool_name,
+                    err,
+                    self.configured_model.as_ref(),
+                );
                 self.tool_finished(tool_use_id.clone(), None, true, window, cx);
             }
         }
@@ -2769,6 +2812,7 @@ fn main() {{
        prompt_store::init(cx);
        thread_store::init(cx);
        workspace::init_settings(cx);
+       language_model::init_settings(cx);
        ThemeSettings::register(cx);
        ContextServerSettings::register(cx);
        EditorSettings::register(cx);
diff --git a/crates/agent/src/thread_store.rs b/crates/agent/src/thread_store.rs
index 9b42737be67ddd66761ad51e7d29a5955c211950..8465144b3913d87a5c3e3239966c33ce512dd7f5 100644
--- a/crates/agent/src/thread_store.rs
+++ b/crates/agent/src/thread_store.rs
@@ -640,6 +640,14 @@ pub struct SerializedThread {
     pub detailed_summary_state: DetailedSummaryState,
     #[serde(default)]
     pub exceeded_window_error: Option,
+    #[serde(default)]
+    pub model: Option<SerializedLanguageModel>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct SerializedLanguageModel {
+    pub provider: String,
+    pub model: String,
 }
 
 impl SerializedThread {
@@ -774,6 +782,7 @@ impl LegacySerializedThread {
             request_token_usage: Vec::new(),
             detailed_summary_state: DetailedSummaryState::default(),
             exceeded_window_error: None,
+            model: None,
         }
     }
 }
diff --git a/crates/agent/src/tool_use.rs b/crates/agent/src/tool_use.rs
index 9b5b2f02d9c24d9cec7e87bcdafb7e9f2b225235..b7a1747726ed0680ba860b6e10b6fcf66ecf8dd6 100644
--- a/crates/agent/src/tool_use.rs
+++ b/crates/agent/src/tool_use.rs
@@ -7,7 +7,7 @@ use futures::FutureExt as _;
 use futures::future::Shared;
 use gpui::{App, Entity, SharedString, Task};
 use language_model::{
-    LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
+    ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult,
     LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 };
 use ui::IconName;
@@ -353,7 +353,7 @@ impl ToolUseState {
         tool_use_id: LanguageModelToolUseId,
         tool_name: Arc,
         output: Result,
-        cx: &App,
+        configured_model: Option<&ConfiguredModel>,
     ) -> Option {
 
         let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
@@ -373,13 +373,10 @@ impl ToolUseState {
 
         match output {
             Ok(tool_result) => {
-                let model_registry = LanguageModelRegistry::read_global(cx);
-
                 const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
 
                 // Protect from clearly large output
-                let tool_output_limit = model_registry
-                    .default_model()
+                let tool_output_limit = configured_model
                     .map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
                     .unwrap_or(usize::MAX);
 
diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs
index 6495bea21d4de758aabbfe47f00add5d82d93154..852f5b1c7c60a26686924aee0d10afcd375b8d86 100644
--- a/crates/assistant/src/inline_assistant.rs
+++ b/crates/assistant/src/inline_assistant.rs
@@ -37,7 +37,7 @@ use language_model::{
     ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
     LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
 };
-use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
+use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
 use multi_buffer::MultiBufferRow;
 use parking_lot::Mutex;
 use project::{CodeAction, LspAction, ProjectTransaction};
@@ -1759,6 +1759,7 @@ impl PromptEditor {
             language_model_selector: cx.new(|cx| {
                 let fs = fs.clone();
                 LanguageModelSelector::new(
+                    |cx| LanguageModelRegistry::read_global(cx).default_model(),
                     move |model, cx| {
                         update_settings_file::<AssistantSettings>(
                             fs.clone(),
@@ -1766,7 +1767,6 @@ impl PromptEditor {
                             move |settings, _| settings.set_model(model.clone()),
                         );
                     },
-                    ModelType::Default,
                     window,
                     cx,
                 )
diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs
index fb47173acd0cf67e3ec7cc0a05b48152e38a1805..12733ba038b5546237e471cb59a36a7ee836b472 100644
--- a/crates/assistant/src/terminal_inline_assistant.rs
+++ b/crates/assistant/src/terminal_inline_assistant.rs
@@ -19,7 +19,7 @@ use language_model::{
     ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
     Role, report_assistant_event,
 };
-use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
+use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
 use prompt_store::PromptBuilder;
 use settings::{Settings, update_settings_file};
 use std::{
@@ -749,6 +749,7 @@ impl PromptEditor {
             language_model_selector: cx.new(|cx| {
                 let fs = fs.clone();
                 LanguageModelSelector::new(
+                    |cx| LanguageModelRegistry::read_global(cx).default_model(),
                     move |model, cx| {
                         update_settings_file::<AssistantSettings>(
                             fs.clone(),
@@ -756,7 +757,6 @@ impl PromptEditor {
                             move |settings, _| settings.set_model(model.clone()),
                         );
                     },
-                    ModelType::Default,
                     window,
                     cx,
                 )
diff --git a/crates/assistant_context_editor/src/context_editor.rs b/crates/assistant_context_editor/src/context_editor.rs
index b2dd4c95a10708def0127080c4ca6b6884734319..975789f01bcd727ea2599eba85eab61806f96e7e 100644
--- a/crates/assistant_context_editor/src/context_editor.rs
+++ b/crates/assistant_context_editor/src/context_editor.rs
@@ -39,7 +39,7 @@ use language_model::{
     Role,
 };
 use language_model_selector::{
-    LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType, ToggleModelSelector,
+    LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
 };
 use multi_buffer::MultiBufferRow;
 use picker::Picker;
@@ -291,6 +291,7 @@ impl ContextEditor {
             dragged_file_worktrees: Vec::new(),
             language_model_selector: cx.new(|cx| {
                 LanguageModelSelector::new(
+                    |cx| LanguageModelRegistry::read_global(cx).default_model(),
                     move |model, cx| {
                         update_settings_file::<AssistantSettings>(
                             fs.clone(),
@@ -298,7 +299,6 @@ impl ContextEditor {
                             move |settings, _| settings.set_model(model.clone()),
                         );
                     },
-                    ModelType::Default,
                     window,
                     cx,
                 )
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
index 81a03c33059fb2659e87d901c52f13bef7ea0e9c..4c9e918756c64320022b3c65c885efebe3dd6732 100644
--- a/crates/language_model/src/language_model.rs
+++ b/crates/language_model/src/language_model.rs
@@ -39,10 +39,14 @@ pub use crate::telemetry::*;
 pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
 
 pub fn init(client: Arc, cx: &mut App) {
-    registry::init(cx);
+    init_settings(cx);
     RefreshLlmTokenListener::register(client.clone(), cx);
 }
 
+pub fn init_settings(cx: &mut App) {
+    registry::init(cx);
+}
+
 /// The availability of a [`LanguageModel`].
 #[derive(Debug, PartialEq, Eq, Clone, Copy)]
 pub enum LanguageModelAvailability {
diff --git a/crates/language_model/src/registry.rs b/crates/language_model/src/registry.rs
index 1f17a6e822acdbdee22e4567b9caf70d1a22f7f7..46b0bc56fd8183be7addf1b25e0b7e3d93d30b12 100644
--- a/crates/language_model/src/registry.rs
+++ b/crates/language_model/src/registry.rs
@@ -188,7 +188,7 @@ impl LanguageModelRegistry {
             .collect::<Vec<_>>();
     }
 
-    fn select_model(
+    pub fn select_model(
         &mut self,
         selected_model: &SelectedModel,
         cx: &mut Context<Self>,
diff --git a/crates/language_model_selector/src/language_model_selector.rs b/crates/language_model_selector/src/language_model_selector.rs
index e677c3e3da4e0afdc8a3b2d76b09695d05600a2b..7ca093e1ea8144df9378f0636343c300033de61b 100644
--- a/crates/language_model_selector/src/language_model_selector.rs
+++ b/crates/language_model_selector/src/language_model_selector.rs
@@ -22,7 +22,8 @@ action_with_deprecated_aliases!(
 
 const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 
-type OnModelChanged = Arc, &App) + 'static>;
+type OnModelChanged = Arc, &mut App) + 'static>;
+type GetActiveModel = Arc Option + 'static>;
 
 pub struct LanguageModelSelector {
     picker: Entity>,
@@ -30,16 +31,10 @@ pub struct LanguageModelSelector {
     _subscriptions: Vec,
 }
 
-#[derive(Clone, Copy)]
-pub enum ModelType {
-    Default,
-    InlineAssistant,
-}
-
 impl LanguageModelSelector {
     pub fn new(
-        on_model_changed: impl Fn(Arc, &App) + 'static,
-        model_type: ModelType,
+        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
+        on_model_changed: impl Fn(Arc, &mut App) + 'static,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) -> Self {
@@ -52,9 +47,9 @@ impl LanguageModelSelector {
             language_model_selector: cx.entity().downgrade(),
             on_model_changed: on_model_changed.clone(),
             all_models: Arc::new(all_models),
-            selected_index: Self::get_active_model_index(&entries, model_type, cx),
+            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
             filtered_entries: entries,
-            model_type,
+            get_active_model: Arc::new(get_active_model),
         };
 
         let picker = cx.new(|cx| {
@@ -204,26 +199,13 @@ impl LanguageModelSelector {
     }
 
     pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
-        let model_type = self.picker.read(cx).delegate.model_type;
-        Self::active_model_by_type(model_type, cx)
-    }
-
-    fn active_model_by_type(model_type: ModelType, cx: &App) -> Option<ConfiguredModel> {
-        match model_type {
-            ModelType::Default => LanguageModelRegistry::read_global(cx).default_model(),
-            ModelType::InlineAssistant => {
-                LanguageModelRegistry::read_global(cx).inline_assistant_model()
-            }
-        }
+        (self.picker.read(cx).delegate.get_active_model)(cx)
     }
 
     fn get_active_model_index(
         entries: &[LanguageModelPickerEntry],
-        model_type: ModelType,
-        cx: &App,
+        active_model: Option<ConfiguredModel>,
     ) -> usize {
-        let active_model = Self::active_model_by_type(model_type, cx);
-
         entries
             .iter()
             .position(|entry| {
@@ -232,7 +214,7 @@ impl LanguageModelSelector {
                     .as_ref()
                     .map(|active_model| {
                         active_model.model.id() == model.model.id()
-                            && active_model.model.provider_id() == model.model.provider_id()
+                            && active_model.provider.id() == model.model.provider_id()
                     })
                     .unwrap_or_default()
             } else {
@@ -325,10 +307,10 @@ struct ModelInfo {
 pub struct LanguageModelPickerDelegate {
     language_model_selector: WeakEntity,
     on_model_changed: OnModelChanged,
+    get_active_model: GetActiveModel,
     all_models: Arc,
     filtered_entries: Vec,
     selected_index: usize,
-    model_type: ModelType,
 }
 
 struct GroupedModels {
@@ -522,8 +504,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
                     .into_any_element(),
             ),
             LanguageModelPickerEntry::Model(model_info) => {
-                let active_model = LanguageModelSelector::active_model_by_type(self.model_type, cx);
-
+                let active_model = (self.get_active_model)(cx);
                 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
                 let active_model_id = active_model.map(|m| m.model.id());
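
The persistence side of this change reduces to one optional field on SerializedThread. Below is a minimal, self-contained sketch (not part of the patch itself) of how such a field round-trips with serde. It assumes only the serde (with derive) and serde_json crates; the structs are simplified stand-ins for the real SerializedThread, and the provider/model id strings are arbitrary examples.

use serde::{Deserialize, Serialize};

// Mirrors the new `SerializedLanguageModel` struct added in thread_store.rs.
#[derive(Serialize, Deserialize, Debug)]
struct SerializedLanguageModel {
    provider: String,
    model: String,
}

// Simplified stand-in for `SerializedThread`, keeping only what is needed here.
#[derive(Serialize, Deserialize, Debug)]
struct SerializedThreadSketch {
    summary: String,
    // `#[serde(default)]` lets threads saved before this change, which have no
    // "model" key, still deserialize cleanly as `None`.
    #[serde(default)]
    model: Option<SerializedLanguageModel>,
}

fn main() {
    // A thread saved before this change: no "model" key at all.
    let legacy = r#"{ "summary": "Old thread" }"#;
    let old: SerializedThreadSketch = serde_json::from_str(legacy).unwrap();
    assert!(old.model.is_none());

    // A thread saved after this change records its provider and model ids,
    // so reopening it from the history can restore the same model.
    let current = r#"{
        "summary": "New thread",
        "model": { "provider": "zed.dev", "model": "example-model-id" }
    }"#;
    let restored: SerializedThreadSketch = serde_json::from_str(current).unwrap();
    assert_eq!(restored.model.unwrap().provider, "zed.dev");
}

On load, the saved ids are fed through LanguageModelRegistry::select_model (made public in this patch), falling back to the default model when the serialized provider or model is no longer available, as in Thread::deserialize above.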