Detailed changes
@@ -25,8 +25,8 @@ use gpui::{
};
use language::{Buffer, LanguageRegistry};
use language_model::{
- LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent,
- RequestUsage, Role, StopReason,
+ LanguageModelRequestMessage, LanguageModelToolUseId, MessageContent, RequestUsage, Role,
+ StopReason,
};
use markdown::parser::{CodeBlockKind, CodeBlockMetadata};
use markdown::{HeadingLevelStyles, Markdown, MarkdownElement, MarkdownStyle, ParsedMarkdown};
@@ -1252,7 +1252,7 @@ impl ActiveThread {
cx.emit(ActiveThreadEvent::EditingMessageTokenCountChanged);
state._update_token_count_task.take();
- let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
+ let Some(configured_model) = self.thread.read(cx).configured_model() else {
state.last_estimated_token_count.take();
return;
};
@@ -1305,7 +1305,7 @@ impl ActiveThread {
temperature: None,
};
- Some(default_model.model.count_tokens(request, cx))
+ Some(configured_model.model.count_tokens(request, cx))
})? {
task.await?
} else {
@@ -1338,7 +1338,7 @@ impl ActiveThread {
return;
};
let edited_text = state.editor.read(cx).text(cx);
- self.thread.update(cx, |thread, cx| {
+ let thread_model = self.thread.update(cx, |thread, cx| {
thread.edit_message(
message_id,
Role::User,
@@ -1348,9 +1348,10 @@ impl ActiveThread {
for message_id in self.messages_after(message_id) {
thread.delete_message(*message_id, cx);
}
+ thread.get_or_init_configured_model(cx)
});
- let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
+ let Some(model) = thread_model else {
return;
};
@@ -951,6 +951,7 @@ mod tests {
ThemeSettings::register(cx);
ContextServerSettings::register(cx);
EditorSettings::register(cx);
+ language_model::init_settings(cx);
});
let fs = FakeFs::new(cx.executor());
@@ -2,6 +2,8 @@ use assistant_settings::AssistantSettings;
use fs::Fs;
use gpui::{Entity, FocusHandle, SharedString};
+use crate::Thread;
+use language_model::{ConfiguredModel, LanguageModelRegistry};
use language_model_selector::{
LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
};
@@ -9,7 +11,11 @@ use settings::update_settings_file;
use std::sync::Arc;
use ui::{ButtonLike, PopoverMenuHandle, Tooltip, prelude::*};
-pub use language_model_selector::ModelType;
+#[derive(Clone)]
+pub enum ModelType {
+ Default(Entity<Thread>),
+ InlineAssistant,
+}
pub struct AssistantModelSelector {
selector: Entity<LanguageModelSelector>,
@@ -24,18 +30,39 @@ impl AssistantModelSelector {
focus_handle: FocusHandle,
model_type: ModelType,
window: &mut Window,
- cx: &mut App,
+ cx: &mut Context<Self>,
) -> Self {
Self {
- selector: cx.new(|cx| {
+ selector: cx.new(move |cx| {
let fs = fs.clone();
LanguageModelSelector::new(
+ {
+ let model_type = model_type.clone();
+ move |cx| match &model_type {
+ ModelType::Default(thread) => thread.read(cx).configured_model(),
+ ModelType::InlineAssistant => {
+ LanguageModelRegistry::read_global(cx).inline_assistant_model()
+ }
+ }
+ },
move |model, cx| {
let provider = model.provider_id().0.to_string();
let model_id = model.id().0.to_string();
-
- match model_type {
- ModelType::Default => {
+ match &model_type {
+ ModelType::Default(thread) => {
+ thread.update(cx, |thread, cx| {
+ let registry = LanguageModelRegistry::read_global(cx);
+ if let Some(provider) = registry.provider(&model.provider_id())
+ {
+ thread.set_configured_model(
+ Some(ConfiguredModel {
+ provider,
+ model: model.clone(),
+ }),
+ cx,
+ );
+ }
+ });
update_settings_file::<AssistantSettings>(
fs.clone(),
cx,
@@ -58,7 +85,6 @@ impl AssistantModelSelector {
}
}
},
- model_type,
window,
cx,
)
@@ -1274,12 +1274,12 @@ impl AssistantPanel {
let is_generating = thread.is_generating();
let message_editor = self.message_editor.read(cx);
- let conversation_token_usage = thread.total_token_usage(cx);
+ let conversation_token_usage = thread.total_token_usage();
let (total_token_usage, is_estimating) = if let Some((editing_message_id, unsent_tokens)) =
self.thread.read(cx).editing_message_id()
{
let combined = thread
- .token_usage_up_to_message(editing_message_id, cx)
+ .token_usage_up_to_message(editing_message_id)
.add(unsent_tokens);
(combined, unsent_tokens > 0)
@@ -1,4 +1,4 @@
-use crate::assistant_model_selector::AssistantModelSelector;
+use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::buffer_codegen::BufferCodegen;
use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
@@ -20,7 +20,7 @@ use gpui::{
Focusable, FontWeight, Subscription, TextStyle, WeakEntity, Window, anchored, deferred, point,
};
use language_model::{LanguageModel, LanguageModelRegistry};
-use language_model_selector::{ModelType, ToggleModelSelector};
+use language_model_selector::ToggleModelSelector;
use parking_lot::Mutex;
use settings::Settings;
use std::cmp;
@@ -1,7 +1,7 @@
use std::collections::BTreeMap;
use std::sync::Arc;
-use crate::assistant_model_selector::ModelType;
+use crate::assistant_model_selector::{AssistantModelSelector, ModelType};
use crate::context::{ContextLoadResult, load_context};
use crate::tool_compatibility::{IncompatibleToolsState, IncompatibleToolsTooltip};
use buffer_diff::BufferDiff;
@@ -21,9 +21,7 @@ use gpui::{
Task, TextStyle, WeakEntity, linear_color_stop, linear_gradient, point, pulsating_between,
};
use language::{Buffer, Language};
-use language_model::{
- ConfiguredModel, LanguageModelRegistry, LanguageModelRequestMessage, MessageContent,
-};
+use language_model::{ConfiguredModel, LanguageModelRequestMessage, MessageContent};
use language_model_selector::ToggleModelSelector;
use multi_buffer;
use project::Project;
@@ -36,7 +34,6 @@ use util::ResultExt as _;
use workspace::Workspace;
use zed_llm_client::CompletionMode;
-use crate::assistant_model_selector::AssistantModelSelector;
use crate::context_picker::{ContextPicker, ContextPickerCompletionProvider};
use crate::context_store::ContextStore;
use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind};
@@ -153,6 +150,17 @@ impl MessageEditor {
}),
];
+ let model_selector = cx.new(|cx| {
+ AssistantModelSelector::new(
+ fs.clone(),
+ model_selector_menu_handle,
+ editor.focus_handle(cx),
+ ModelType::Default(thread.clone()),
+ window,
+ cx,
+ )
+ });
+
Self {
editor: editor.clone(),
project: thread.read(cx).project().clone(),
@@ -165,16 +173,7 @@ impl MessageEditor {
context_picker_menu_handle,
load_context_task: None,
last_loaded_context: None,
- model_selector: cx.new(|cx| {
- AssistantModelSelector::new(
- fs.clone(),
- model_selector_menu_handle,
- editor.focus_handle(cx),
- ModelType::Default,
- window,
- cx,
- )
- }),
+ model_selector,
edits_expanded: false,
editor_is_expanded: false,
profile_selector: cx
@@ -263,15 +262,11 @@ impl MessageEditor {
self.editor.read(cx).text(cx).trim().is_empty()
}
- fn is_model_selected(&self, cx: &App) -> bool {
- LanguageModelRegistry::read_global(cx)
- .default_model()
- .is_some()
- }
-
fn send_to_model(&mut self, window: &mut Window, cx: &mut Context<Self>) {
- let model_registry = LanguageModelRegistry::read_global(cx);
- let Some(ConfiguredModel { model, provider }) = model_registry.default_model() else {
+ let Some(ConfiguredModel { model, provider }) = self
+ .thread
+ .update(cx, |thread, cx| thread.get_or_init_configured_model(cx))
+ else {
return;
};
@@ -408,14 +403,13 @@ impl MessageEditor {
return None;
}
- let model = LanguageModelRegistry::read_global(cx)
- .default_model()
- .map(|default| default.model.clone())?;
- if !model.supports_max_mode() {
+ let thread = self.thread.read(cx);
+ let model = thread.configured_model();
+ if !model?.model.supports_max_mode() {
return None;
}
- let active_completion_mode = self.thread.read(cx).completion_mode();
+ let active_completion_mode = thread.completion_mode();
Some(
IconButton::new("max-mode", IconName::SquarePlus)
@@ -442,24 +436,21 @@ impl MessageEditor {
cx: &mut Context<Self>,
) -> Div {
let thread = self.thread.read(cx);
+ let model = thread.configured_model();
let editor_bg_color = cx.theme().colors().editor_background;
let is_generating = thread.is_generating();
let focus_handle = self.editor.focus_handle(cx);
- let is_model_selected = self.is_model_selected(cx);
+ let is_model_selected = model.is_some();
let is_editor_empty = self.is_editor_empty(cx);
- let model = LanguageModelRegistry::read_global(cx)
- .default_model()
- .map(|default| default.model.clone());
-
let incompatible_tools = model
.as_ref()
.map(|model| {
self.incompatible_tools_state.update(cx, |state, cx| {
state
- .incompatible_tools(model, cx)
+ .incompatible_tools(&model.model, cx)
.iter()
.cloned()
.collect::<Vec<_>>()
@@ -1058,7 +1049,7 @@ impl MessageEditor {
cx.emit(MessageEditorEvent::Changed);
self.update_token_count_task.take();
- let Some(default_model) = LanguageModelRegistry::read_global(cx).default_model() else {
+ let Some(model) = self.thread.read(cx).configured_model() else {
self.last_estimated_token_count.take();
return;
};
@@ -1111,7 +1102,7 @@ impl MessageEditor {
temperature: None,
};
- Some(default_model.model.count_tokens(request, cx))
+ Some(model.model.count_tokens(request, cx))
})? {
task.await?
} else {
@@ -1143,7 +1134,7 @@ impl Focusable for MessageEditor {
impl Render for MessageEditor {
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
let thread = self.thread.read(cx);
- let total_token_usage = thread.total_token_usage(cx);
+ let total_token_usage = thread.total_token_usage();
let token_usage_ratio = total_token_usage.ratio();
let action_log = self.thread.read(cx).action_log();
@@ -22,8 +22,8 @@ use language_model::{
LanguageModelId, LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
- ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
- TokenUsage,
+ ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, SelectedModel,
+ StopReason, TokenUsage,
};
use postage::stream::Stream as _;
use project::Project;
@@ -41,8 +41,8 @@ use zed_llm_client::CompletionMode;
use crate::ThreadStore;
use crate::context::{AgentContext, ContextLoadResult, LoadedContext};
use crate::thread_store::{
- SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
- SerializedToolUse, SharedProjectContext,
+ SerializedLanguageModel, SerializedMessage, SerializedMessageSegment, SerializedThread,
+ SerializedToolResult, SerializedToolUse, SharedProjectContext,
};
use crate::tool_use::{PendingToolUse, ToolUse, ToolUseMetadata, ToolUseState};
@@ -332,6 +332,7 @@ pub struct Thread {
Box<dyn FnMut(&LanguageModelRequest, &[Result<LanguageModelCompletionEvent, String>])>,
>,
remaining_turns: u32,
+ configured_model: Option<ConfiguredModel>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -351,6 +352,8 @@ impl Thread {
cx: &mut Context<Self>,
) -> Self {
let (detailed_summary_tx, detailed_summary_rx) = postage::watch::channel();
+ let configured_model = LanguageModelRegistry::read_global(cx).default_model();
+
Self {
id: ThreadId::new(),
updated_at: Utc::now(),
@@ -388,6 +391,7 @@ impl Thread {
last_auto_capture_at: None,
request_callback: None,
remaining_turns: u32::MAX,
+ configured_model,
}
}
@@ -411,6 +415,19 @@ impl Thread {
let (detailed_summary_tx, detailed_summary_rx) =
postage::watch::channel_with(serialized.detailed_summary_state);
+ let configured_model = LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+ serialized
+ .model
+ .and_then(|model| {
+ let model = SelectedModel {
+ provider: model.provider.clone().into(),
+ model: model.model.clone().into(),
+ };
+ registry.select_model(&model, cx)
+ })
+ .or_else(|| registry.default_model())
+ });
+
Self {
id,
updated_at: serialized.updated_at,
@@ -468,6 +485,7 @@ impl Thread {
last_auto_capture_at: None,
request_callback: None,
remaining_turns: u32::MAX,
+ configured_model,
}
}
@@ -507,6 +525,22 @@ impl Thread {
self.project_context.clone()
}
+ pub fn get_or_init_configured_model(&mut self, cx: &App) -> Option<ConfiguredModel> {
+ if self.configured_model.is_none() {
+ self.configured_model = LanguageModelRegistry::read_global(cx).default_model();
+ }
+ self.configured_model.clone()
+ }
+
+ pub fn configured_model(&self) -> Option<ConfiguredModel> {
+ self.configured_model.clone()
+ }
+
+ pub fn set_configured_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
+ self.configured_model = model;
+ cx.notify();
+ }
+
pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
pub fn summary_or_default(&self) -> SharedString {
@@ -952,6 +986,13 @@ impl Thread {
request_token_usage: this.request_token_usage.clone(),
detailed_summary_state: this.detailed_summary_rx.borrow().clone(),
exceeded_window_error: this.exceeded_window_error.clone(),
+ model: this
+ .configured_model
+ .as_ref()
+ .map(|model| SerializedLanguageModel {
+ provider: model.provider.id().0.to_string(),
+ model: model.model.id().0.to_string(),
+ }),
})
})
}
@@ -1733,7 +1774,7 @@ impl Thread {
tool_use_id.clone(),
tool_name,
Err(anyhow!("Error parsing input JSON: {error}")),
- cx,
+ self.configured_model.as_ref(),
);
let ui_text = if let Some(pending_tool_use) = &pending_tool_use {
pending_tool_use.ui_text.clone()
@@ -1808,7 +1849,7 @@ impl Thread {
tool_use_id.clone(),
tool_name,
output,
- cx,
+ thread.configured_model.as_ref(),
);
thread.tool_finished(tool_use_id, pending_tool_use, false, window, cx);
})
@@ -1826,10 +1867,9 @@ impl Thread {
cx: &mut Context<Self>,
) {
if self.all_tools_finished() {
- let model_registry = LanguageModelRegistry::read_global(cx);
- if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
+ if let Some(ConfiguredModel { model, .. }) = self.configured_model.as_ref() {
if !canceled {
- self.send_to_model(model, window, cx);
+ self.send_to_model(model.clone(), window, cx);
}
self.auto_capture_telemetry(cx);
}
@@ -2254,8 +2294,8 @@ impl Thread {
self.cumulative_token_usage
}
- pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
- let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
+ pub fn token_usage_up_to_message(&self, message_id: MessageId) -> TotalTokenUsage {
+ let Some(model) = self.configured_model.as_ref() else {
return TotalTokenUsage::default();
};
@@ -2283,9 +2323,8 @@ impl Thread {
}
}
- pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
- let model_registry = LanguageModelRegistry::read_global(cx);
- let Some(model) = model_registry.default_model() else {
+ pub fn total_token_usage(&self) -> TotalTokenUsage {
+ let Some(model) = self.configured_model.as_ref() else {
return TotalTokenUsage::default();
};
@@ -2336,8 +2375,12 @@ impl Thread {
"Permission to run tool action denied by user"
));
- self.tool_use
- .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
+ self.tool_use.insert_tool_output(
+ tool_use_id.clone(),
+ tool_name,
+ err,
+ self.configured_model.as_ref(),
+ );
self.tool_finished(tool_use_id.clone(), None, true, window, cx);
}
}
@@ -2769,6 +2812,7 @@ fn main() {{
prompt_store::init(cx);
thread_store::init(cx);
workspace::init_settings(cx);
+ language_model::init_settings(cx);
ThemeSettings::register(cx);
ContextServerSettings::register(cx);
EditorSettings::register(cx);
@@ -640,6 +640,14 @@ pub struct SerializedThread {
pub detailed_summary_state: DetailedSummaryState,
#[serde(default)]
pub exceeded_window_error: Option<ExceededWindowError>,
+ #[serde(default)]
+ pub model: Option<SerializedLanguageModel>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct SerializedLanguageModel {
+ pub provider: String,
+ pub model: String,
}
impl SerializedThread {
@@ -774,6 +782,7 @@ impl LegacySerializedThread {
request_token_usage: Vec::new(),
detailed_summary_state: DetailedSummaryState::default(),
exceeded_window_error: None,
+ model: None,
}
}
}
@@ -7,7 +7,7 @@ use futures::FutureExt as _;
use futures::future::Shared;
use gpui::{App, Entity, SharedString, Task};
use language_model::{
- LanguageModel, LanguageModelRegistry, LanguageModelRequestMessage, LanguageModelToolResult,
+ ConfiguredModel, LanguageModel, LanguageModelRequestMessage, LanguageModelToolResult,
LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
};
use ui::IconName;
@@ -353,7 +353,7 @@ impl ToolUseState {
tool_use_id: LanguageModelToolUseId,
tool_name: Arc<str>,
output: Result<String>,
- cx: &App,
+ configured_model: Option<&ConfiguredModel>,
) -> Option<PendingToolUse> {
let metadata = self.tool_use_metadata_by_id.remove(&tool_use_id);
@@ -373,13 +373,10 @@ impl ToolUseState {
match output {
Ok(tool_result) => {
- let model_registry = LanguageModelRegistry::read_global(cx);
-
const BYTES_PER_TOKEN_ESTIMATE: usize = 3;
// Protect from clearly large output
- let tool_output_limit = model_registry
- .default_model()
+ let tool_output_limit = configured_model
.map(|model| model.model.max_token_count() * BYTES_PER_TOKEN_ESTIMATE)
.unwrap_or(usize::MAX);
@@ -37,7 +37,7 @@ use language_model::{
ConfiguredModel, LanguageModel, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelTextStream, Role, report_assistant_event,
};
-use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
+use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use project::{CodeAction, LspAction, ProjectTransaction};
@@ -1759,6 +1759,7 @@ impl PromptEditor {
language_model_selector: cx.new(|cx| {
let fs = fs.clone();
LanguageModelSelector::new(
+ |cx| LanguageModelRegistry::read_global(cx).default_model(),
move |model, cx| {
update_settings_file::<AssistantSettings>(
fs.clone(),
@@ -1766,7 +1767,6 @@ impl PromptEditor {
move |settings, _| settings.set_model(model.clone()),
);
},
- ModelType::Default,
window,
cx,
)
@@ -19,7 +19,7 @@ use language_model::{
ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
Role, report_assistant_event,
};
-use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType};
+use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use prompt_store::PromptBuilder;
use settings::{Settings, update_settings_file};
use std::{
@@ -749,6 +749,7 @@ impl PromptEditor {
language_model_selector: cx.new(|cx| {
let fs = fs.clone();
LanguageModelSelector::new(
+ |cx| LanguageModelRegistry::read_global(cx).default_model(),
move |model, cx| {
update_settings_file::<AssistantSettings>(
fs.clone(),
@@ -756,7 +757,6 @@ impl PromptEditor {
move |settings, _| settings.set_model(model.clone()),
);
},
- ModelType::Default,
window,
cx,
)
@@ -39,7 +39,7 @@ use language_model::{
Role,
};
use language_model_selector::{
- LanguageModelSelector, LanguageModelSelectorPopoverMenu, ModelType, ToggleModelSelector,
+ LanguageModelSelector, LanguageModelSelectorPopoverMenu, ToggleModelSelector,
};
use multi_buffer::MultiBufferRow;
use picker::Picker;
@@ -291,6 +291,7 @@ impl ContextEditor {
dragged_file_worktrees: Vec::new(),
language_model_selector: cx.new(|cx| {
LanguageModelSelector::new(
+ |cx| LanguageModelRegistry::read_global(cx).default_model(),
move |model, cx| {
update_settings_file::<AssistantSettings>(
fs.clone(),
@@ -298,7 +299,6 @@ impl ContextEditor {
move |settings, _| settings.set_model(model.clone()),
);
},
- ModelType::Default,
window,
cx,
)
@@ -39,10 +39,14 @@ pub use crate::telemetry::*;
pub const ZED_CLOUD_PROVIDER_ID: &str = "zed.dev";
pub fn init(client: Arc<Client>, cx: &mut App) {
- registry::init(cx);
+ init_settings(cx);
RefreshLlmTokenListener::register(client.clone(), cx);
}
+pub fn init_settings(cx: &mut App) {
+ registry::init(cx);
+}
+
/// The availability of a [`LanguageModel`].
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum LanguageModelAvailability {
@@ -188,7 +188,7 @@ impl LanguageModelRegistry {
.collect::<Vec<_>>();
}
- fn select_model(
+ pub fn select_model(
&mut self,
selected_model: &SelectedModel,
cx: &mut Context<Self>,
@@ -22,7 +22,8 @@ action_with_deprecated_aliases!(
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
-type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
+type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
+type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
pub struct LanguageModelSelector {
picker: Entity<Picker<LanguageModelPickerDelegate>>,
@@ -30,16 +31,10 @@ pub struct LanguageModelSelector {
_subscriptions: Vec<Subscription>,
}
-#[derive(Clone, Copy)]
-pub enum ModelType {
- Default,
- InlineAssistant,
-}
-
impl LanguageModelSelector {
pub fn new(
- on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
- model_type: ModelType,
+ get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
+ on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
window: &mut Window,
cx: &mut Context<Self>,
) -> Self {
@@ -52,9 +47,9 @@ impl LanguageModelSelector {
language_model_selector: cx.entity().downgrade(),
on_model_changed: on_model_changed.clone(),
all_models: Arc::new(all_models),
- selected_index: Self::get_active_model_index(&entries, model_type, cx),
+ selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
filtered_entries: entries,
- model_type,
+ get_active_model: Arc::new(get_active_model),
};
let picker = cx.new(|cx| {
@@ -204,26 +199,13 @@ impl LanguageModelSelector {
}
pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
- let model_type = self.picker.read(cx).delegate.model_type;
- Self::active_model_by_type(model_type, cx)
- }
-
- fn active_model_by_type(model_type: ModelType, cx: &App) -> Option<ConfiguredModel> {
- match model_type {
- ModelType::Default => LanguageModelRegistry::read_global(cx).default_model(),
- ModelType::InlineAssistant => {
- LanguageModelRegistry::read_global(cx).inline_assistant_model()
- }
- }
+ (self.picker.read(cx).delegate.get_active_model)(cx)
}
fn get_active_model_index(
entries: &[LanguageModelPickerEntry],
- model_type: ModelType,
- cx: &App,
+ active_model: Option<ConfiguredModel>,
) -> usize {
- let active_model = Self::active_model_by_type(model_type, cx);
-
entries
.iter()
.position(|entry| {
@@ -232,7 +214,7 @@ impl LanguageModelSelector {
.as_ref()
.map(|active_model| {
active_model.model.id() == model.model.id()
- && active_model.model.provider_id() == model.model.provider_id()
+ && active_model.provider.id() == model.model.provider_id()
})
.unwrap_or_default()
} else {
@@ -325,10 +307,10 @@ struct ModelInfo {
pub struct LanguageModelPickerDelegate {
language_model_selector: WeakEntity<LanguageModelSelector>,
on_model_changed: OnModelChanged,
+ get_active_model: GetActiveModel,
all_models: Arc<GroupedModels>,
filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize,
- model_type: ModelType,
}
struct GroupedModels {
@@ -522,8 +504,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.into_any_element(),
),
LanguageModelPickerEntry::Model(model_info) => {
- let active_model = LanguageModelSelector::active_model_by_type(self.model_type, cx);
-
+ let active_model = (self.get_active_model)(cx);
let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
let active_model_id = active_model.map(|m| m.model.id());