From 74c82e1e1a85b24023232bd62831ceb118e83a0c Mon Sep 17 00:00:00 2001 From: "gcp-cherry-pick-bot[bot]" <98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com> Date: Fri, 3 Jan 2025 15:33:08 -0500 Subject: [PATCH] language_model_selector: Refresh the models when the providers change (cherry-pick #22624) (#22626) Cherry-picked language_model_selector: Refresh the models when the providers change (#22624) This PR fixes an issue introduced in #21939 where the list of models in the language model selector could be outdated. Since we're no longer recreating the picker each render, we now need to make sure we are updating the list of models accordingly when there are changes to the language model providers. I noticed it specifically in Assistant1. Release Notes: - Fixed a staleness issue with the language model selector. Co-authored-by: Marshall Bowers --- .../src/language_model_selector.rs | 77 ++++++++++++++----- 1 file changed, 56 insertions(+), 21 deletions(-) diff --git a/crates/language_model_selector/src/language_model_selector.rs b/crates/language_model_selector/src/language_model_selector.rs index 7d77bfd6ea2f91ca40d99bc2bb68661509a9e734..cf8260c73bb0aaef2d15ef9c2bb750a4cdd6c8e4 100644 --- a/crates/language_model_selector/src/language_model_selector.rs +++ b/crates/language_model_selector/src/language_model_selector.rs @@ -2,8 +2,8 @@ use std::sync::Arc; use feature_flags::ZedPro; use gpui::{ - Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Task, - View, WeakView, + Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Model, + Subscription, Task, View, WeakView, }; use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry}; use picker::{Picker, PickerDelegate}; @@ -17,6 +17,10 @@ type OnModelChanged = Arc, &AppContext) + 'static> pub struct LanguageModelSelector { picker: View>, + /// The task used to update the picker's matches when there is a change to + /// the language model registry. + update_matches_task: Option>, + _subscriptions: Vec, } impl LanguageModelSelector { @@ -26,7 +30,51 @@ impl LanguageModelSelector { ) -> Self { let on_model_changed = Arc::new(on_model_changed); - let all_models = LanguageModelRegistry::global(cx) + let all_models = Self::all_models(cx); + let delegate = LanguageModelPickerDelegate { + language_model_selector: cx.view().downgrade(), + on_model_changed: on_model_changed.clone(), + all_models: all_models.clone(), + filtered_models: all_models, + selected_index: 0, + }; + + let picker = + cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()))); + + LanguageModelSelector { + picker, + update_matches_task: None, + _subscriptions: vec![cx.subscribe( + &LanguageModelRegistry::global(cx), + Self::handle_language_model_registry_event, + )], + } + } + + fn handle_language_model_registry_event( + &mut self, + _registry: Model, + event: &language_model::Event, + cx: &mut ViewContext, + ) { + match event { + language_model::Event::ProviderStateChanged + | language_model::Event::AddedProvider(_) + | language_model::Event::RemovedProvider(_) => { + let task = self.picker.update(cx, |this, cx| { + let query = this.query(cx); + this.delegate.all_models = Self::all_models(cx); + this.delegate.update_matches(query, cx) + }); + self.update_matches_task = Some(task); + } + _ => {} + } + } + + fn all_models(cx: &AppContext) -> Vec { + LanguageModelRegistry::global(cx) .read(cx) .providers() .iter() @@ -44,20 +92,7 @@ impl LanguageModelSelector { } }) }) - .collect::>(); - - let delegate = LanguageModelPickerDelegate { - language_model_selector: cx.view().downgrade(), - on_model_changed: on_model_changed.clone(), - all_models: all_models.clone(), - filtered_models: all_models, - selected_index: 0, - }; - - let picker = - cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()))); - - LanguageModelSelector { picker } + .collect::>() } } @@ -152,25 +187,25 @@ impl PickerDelegate for LanguageModelPickerDelegate { let llm_registry = LanguageModelRegistry::global(cx); - let configured_models: Vec<_> = llm_registry + let configured_providers = llm_registry .read(cx) .providers() .iter() .filter(|provider| provider.is_authenticated(cx)) .map(|provider| provider.id()) - .collect(); + .collect::>(); cx.spawn(|this, mut cx| async move { let filtered_models = cx .background_executor() .spawn(async move { - let displayed_models = if configured_models.is_empty() { + let displayed_models = if configured_providers.is_empty() { all_models } else { all_models .into_iter() .filter(|model_info| { - configured_models.contains(&model_info.model.provider_id()) + configured_providers.contains(&model_info.model.provider_id()) }) .collect::>() };