language_model_selector: Refresh the models when the providers change (#22624)

Marshall Bowers created

This PR fixes an issue introduced in #21939 where the list of models in
the language model selector could be outdated.

Since we're no longer recreating the picker each render, we now need to
make sure we are updating the list of models accordingly when there are
changes to the language model providers.

I noticed it specifically in Assistant1.

Release Notes:

- Fixed a staleness issue with the language model selector.

Change summary

crates/language_model_selector/src/language_model_selector.rs | 77 +++-
1 file changed, 56 insertions(+), 21 deletions(-)

Detailed changes

crates/language_model_selector/src/language_model_selector.rs 🔗

@@ -2,8 +2,8 @@ use std::sync::Arc;
 
 use feature_flags::ZedPro;
 use gpui::{
-    Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Task,
-    View, WeakView,
+    Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Model,
+    Subscription, Task, View, WeakView,
 };
 use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
 use picker::{Picker, PickerDelegate};
@@ -17,6 +17,10 @@ type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &AppContext) + 'static>
 
 pub struct LanguageModelSelector {
     picker: View<Picker<LanguageModelPickerDelegate>>,
+    /// The task used to update the picker's matches when there is a change to
+    /// the language model registry.
+    update_matches_task: Option<Task<()>>,
+    _subscriptions: Vec<Subscription>,
 }
 
 impl LanguageModelSelector {
@@ -26,7 +30,51 @@ impl LanguageModelSelector {
     ) -> Self {
         let on_model_changed = Arc::new(on_model_changed);
 
-        let all_models = LanguageModelRegistry::global(cx)
+        let all_models = Self::all_models(cx);
+        let delegate = LanguageModelPickerDelegate {
+            language_model_selector: cx.view().downgrade(),
+            on_model_changed: on_model_changed.clone(),
+            all_models: all_models.clone(),
+            filtered_models: all_models,
+            selected_index: 0,
+        };
+
+        let picker =
+            cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into())));
+
+        LanguageModelSelector {
+            picker,
+            update_matches_task: None,
+            _subscriptions: vec![cx.subscribe(
+                &LanguageModelRegistry::global(cx),
+                Self::handle_language_model_registry_event,
+            )],
+        }
+    }
+
+    fn handle_language_model_registry_event(
+        &mut self,
+        _registry: Model<LanguageModelRegistry>,
+        event: &language_model::Event,
+        cx: &mut ViewContext<Self>,
+    ) {
+        match event {
+            language_model::Event::ProviderStateChanged
+            | language_model::Event::AddedProvider(_)
+            | language_model::Event::RemovedProvider(_) => {
+                let task = self.picker.update(cx, |this, cx| {
+                    let query = this.query(cx);
+                    this.delegate.all_models = Self::all_models(cx);
+                    this.delegate.update_matches(query, cx)
+                });
+                self.update_matches_task = Some(task);
+            }
+            _ => {}
+        }
+    }
+
+    fn all_models(cx: &AppContext) -> Vec<ModelInfo> {
+        LanguageModelRegistry::global(cx)
             .read(cx)
             .providers()
             .iter()
@@ -44,20 +92,7 @@ impl LanguageModelSelector {
                     }
                 })
             })
-            .collect::<Vec<_>>();
-
-        let delegate = LanguageModelPickerDelegate {
-            language_model_selector: cx.view().downgrade(),
-            on_model_changed: on_model_changed.clone(),
-            all_models: all_models.clone(),
-            filtered_models: all_models,
-            selected_index: 0,
-        };
-
-        let picker =
-            cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into())));
-
-        LanguageModelSelector { picker }
+            .collect::<Vec<_>>()
     }
 }
 
@@ -152,25 +187,25 @@ impl PickerDelegate for LanguageModelPickerDelegate {
 
         let llm_registry = LanguageModelRegistry::global(cx);
 
-        let configured_models: Vec<_> = llm_registry
+        let configured_providers = llm_registry
             .read(cx)
             .providers()
             .iter()
             .filter(|provider| provider.is_authenticated(cx))
             .map(|provider| provider.id())
-            .collect();
+            .collect::<Vec<_>>();
 
         cx.spawn(|this, mut cx| async move {
             let filtered_models = cx
                 .background_executor()
                 .spawn(async move {
-                    let displayed_models = if configured_models.is_empty() {
+                    let displayed_models = if configured_providers.is_empty() {
                         all_models
                     } else {
                         all_models
                             .into_iter()
                             .filter(|model_info| {
-                                configured_models.contains(&model_info.model.provider_id())
+                                configured_providers.contains(&model_info.model.provider_id())
                             })
                             .collect::<Vec<_>>()
                     };