@@ -2,8 +2,8 @@ use std::sync::Arc;
use feature_flags::ZedPro;
use gpui::{
- Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Task,
- View, WeakView,
+ Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Model,
+ Subscription, Task, View, WeakView,
};
use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
use picker::{Picker, PickerDelegate};
@@ -17,6 +17,10 @@ type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &AppContext) + 'static>
pub struct LanguageModelSelector {
picker: View<Picker<LanguageModelPickerDelegate>>,
+ /// The task used to update the picker's matches when there is a change to
+ /// the language model registry.
+ update_matches_task: Option<Task<()>>,
+ _subscriptions: Vec<Subscription>,
}
impl LanguageModelSelector {
@@ -26,7 +30,51 @@ impl LanguageModelSelector {
) -> Self {
let on_model_changed = Arc::new(on_model_changed);
- let all_models = LanguageModelRegistry::global(cx)
+ let all_models = Self::all_models(cx);
+ let delegate = LanguageModelPickerDelegate {
+ language_model_selector: cx.view().downgrade(),
+ on_model_changed: on_model_changed.clone(),
+ all_models: all_models.clone(),
+ filtered_models: all_models,
+ selected_index: 0,
+ };
+
+ let picker =
+ cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into())));
+
+ LanguageModelSelector {
+ picker,
+ update_matches_task: None,
+ _subscriptions: vec![cx.subscribe(
+ &LanguageModelRegistry::global(cx),
+ Self::handle_language_model_registry_event,
+ )],
+ }
+ }
+
+ fn handle_language_model_registry_event(
+ &mut self,
+ _registry: Model<LanguageModelRegistry>,
+ event: &language_model::Event,
+ cx: &mut ViewContext<Self>,
+ ) {
+ match event {
+ language_model::Event::ProviderStateChanged
+ | language_model::Event::AddedProvider(_)
+ | language_model::Event::RemovedProvider(_) => {
+ let task = self.picker.update(cx, |this, cx| {
+ let query = this.query(cx);
+ this.delegate.all_models = Self::all_models(cx);
+ this.delegate.update_matches(query, cx)
+ });
+ self.update_matches_task = Some(task);
+ }
+ _ => {}
+ }
+ }
+
+ fn all_models(cx: &AppContext) -> Vec<ModelInfo> {
+ LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.iter()
@@ -44,20 +92,7 @@ impl LanguageModelSelector {
}
})
})
- .collect::<Vec<_>>();
-
- let delegate = LanguageModelPickerDelegate {
- language_model_selector: cx.view().downgrade(),
- on_model_changed: on_model_changed.clone(),
- all_models: all_models.clone(),
- filtered_models: all_models,
- selected_index: 0,
- };
-
- let picker =
- cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into())));
-
- LanguageModelSelector { picker }
+ .collect::<Vec<_>>()
}
}
@@ -152,25 +187,25 @@ impl PickerDelegate for LanguageModelPickerDelegate {
let llm_registry = LanguageModelRegistry::global(cx);
- let configured_models: Vec<_> = llm_registry
+ let configured_providers = llm_registry
.read(cx)
.providers()
.iter()
.filter(|provider| provider.is_authenticated(cx))
.map(|provider| provider.id())
- .collect();
+ .collect::<Vec<_>>();
cx.spawn(|this, mut cx| async move {
let filtered_models = cx
.background_executor()
.spawn(async move {
- let displayed_models = if configured_models.is_empty() {
+ let displayed_models = if configured_providers.is_empty() {
all_models
} else {
all_models
.into_iter()
.filter(|model_info| {
- configured_models.contains(&model_info.model.provider_id())
+ configured_providers.contains(&model_info.model.provider_id())
})
.collect::<Vec<_>>()
};