diff --git a/crates/language_model_selector/src/language_model_selector.rs b/crates/language_model_selector/src/language_model_selector.rs index 7d77bfd6ea2f91ca40d99bc2bb68661509a9e734..cf8260c73bb0aaef2d15ef9c2bb750a4cdd6c8e4 100644 --- a/crates/language_model_selector/src/language_model_selector.rs +++ b/crates/language_model_selector/src/language_model_selector.rs @@ -2,8 +2,8 @@ use std::sync::Arc; use feature_flags::ZedPro; use gpui::{ - Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Task, - View, WeakView, + Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Model, + Subscription, Task, View, WeakView, }; use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry}; use picker::{Picker, PickerDelegate}; @@ -17,6 +17,10 @@ type OnModelChanged = Arc, &AppContext) + 'static> pub struct LanguageModelSelector { picker: View>, + /// The task used to update the picker's matches when there is a change to + /// the language model registry. + update_matches_task: Option>, + _subscriptions: Vec, } impl LanguageModelSelector { @@ -26,7 +30,51 @@ impl LanguageModelSelector { ) -> Self { let on_model_changed = Arc::new(on_model_changed); - let all_models = LanguageModelRegistry::global(cx) + let all_models = Self::all_models(cx); + let delegate = LanguageModelPickerDelegate { + language_model_selector: cx.view().downgrade(), + on_model_changed: on_model_changed.clone(), + all_models: all_models.clone(), + filtered_models: all_models, + selected_index: 0, + }; + + let picker = + cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()))); + + LanguageModelSelector { + picker, + update_matches_task: None, + _subscriptions: vec![cx.subscribe( + &LanguageModelRegistry::global(cx), + Self::handle_language_model_registry_event, + )], + } + } + + fn handle_language_model_registry_event( + &mut self, + _registry: Model, + event: &language_model::Event, + cx: &mut ViewContext, + ) { + match event { + language_model::Event::ProviderStateChanged + | language_model::Event::AddedProvider(_) + | language_model::Event::RemovedProvider(_) => { + let task = self.picker.update(cx, |this, cx| { + let query = this.query(cx); + this.delegate.all_models = Self::all_models(cx); + this.delegate.update_matches(query, cx) + }); + self.update_matches_task = Some(task); + } + _ => {} + } + } + + fn all_models(cx: &AppContext) -> Vec { + LanguageModelRegistry::global(cx) .read(cx) .providers() .iter() @@ -44,20 +92,7 @@ impl LanguageModelSelector { } }) }) - .collect::>(); - - let delegate = LanguageModelPickerDelegate { - language_model_selector: cx.view().downgrade(), - on_model_changed: on_model_changed.clone(), - all_models: all_models.clone(), - filtered_models: all_models, - selected_index: 0, - }; - - let picker = - cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()))); - - LanguageModelSelector { picker } + .collect::>() } } @@ -152,25 +187,25 @@ impl PickerDelegate for LanguageModelPickerDelegate { let llm_registry = LanguageModelRegistry::global(cx); - let configured_models: Vec<_> = llm_registry + let configured_providers = llm_registry .read(cx) .providers() .iter() .filter(|provider| provider.is_authenticated(cx)) .map(|provider| provider.id()) - .collect(); + .collect::>(); cx.spawn(|this, mut cx| async move { let filtered_models = cx .background_executor() .spawn(async move { - let displayed_models = if configured_models.is_empty() { + let displayed_models = if configured_providers.is_empty() { all_models } else { all_models .into_iter() .filter(|model_info| { - configured_models.contains(&model_info.model.provider_id()) + configured_providers.contains(&model_info.model.provider_id()) }) .collect::>() };