assistant: Show only configured models in the model picker (#20392)

Danilo Leal created

Closes https://github.com/zed-industries/zed/issues/16568

This PR introduces some changes to how we display models in the model
selector within the assistant panel. Basically, it comes down to this:

- If you don't have any provider configured, you should see _all_
available models in the picker
- But, once you've configured some, you should _only_ see models from
them in the picker

Visually, nothing's changed much aside from the added "Configured
Models" label at the top to ensure the understanding that that's a list
of, well, configured models only. 😬

<img width="700" alt="Screenshot 2024-11-07 at 23 42 41"
src="https://github.com/user-attachments/assets/219ed386-2318-43a6-abea-1de0cda8dc53">

Release Notes:

- Change model selector in the assistant panel to only show configured
models

Change summary

crates/assistant/src/assistant_panel.rs     |   2 
crates/assistant/src/model_selector.rs      | 110 ++++++++++++++++------
crates/language_model/src/provider/cloud.rs |  29 +----
3 files changed, 88 insertions(+), 53 deletions(-)

Detailed changes

crates/assistant/src/assistant_panel.rs 🔗

@@ -4507,7 +4507,6 @@ impl Render for ContextEditorToolbarItem {
     fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
         let left_side = h_flex()
             .group("chat-title-group")
-            .pl_0p5()
             .gap_1()
             .items_center()
             .flex_grow()
@@ -4598,6 +4597,7 @@ impl Render for ContextEditorToolbarItem {
             .children(self.render_remaining_tokens(cx));
 
         h_flex()
+            .px_0p5()
             .size_full()
             .gap_2()
             .justify_between()

crates/assistant/src/model_selector.rs 🔗

@@ -1,21 +1,17 @@
 use feature_flags::ZedPro;
-use gpui::Action;
-use gpui::DismissEvent;
 
 use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
 use proto::Plan;
 use workspace::ShowConfiguration;
 
 use std::sync::Arc;
-use ui::ListItemSpacing;
 
 use crate::assistant_settings::AssistantSettings;
 use fs::Fs;
-use gpui::SharedString;
-use gpui::Task;
+use gpui::{Action, AnyElement, DismissEvent, SharedString, Task};
 use picker::{Picker, PickerDelegate};
 use settings::update_settings_file;
-use ui::{prelude::*, ListItem, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
+use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 
 const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 
@@ -85,14 +81,36 @@ impl PickerDelegate for ModelPickerDelegate {
 
     fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
         let all_models = self.all_models.clone();
+
+        let llm_registry = LanguageModelRegistry::global(cx);
+
+        let configured_models: Vec<_> = llm_registry
+            .read(cx)
+            .providers()
+            .iter()
+            .filter(|provider| provider.is_authenticated(cx))
+            .map(|provider| provider.id())
+            .collect();
+
         cx.spawn(|this, mut cx| async move {
             let filtered_models = cx
                 .background_executor()
                 .spawn(async move {
-                    if query.is_empty() {
+                    let displayed_models = if configured_models.is_empty() {
                         all_models
                     } else {
                         all_models
+                            .into_iter()
+                            .filter(|model_info| {
+                                configured_models.contains(&model_info.model.provider_id())
+                            })
+                            .collect::<Vec<_>>()
+                    };
+
+                    if query.is_empty() {
+                        displayed_models
+                    } else {
+                        displayed_models
                             .into_iter()
                             .filter(|model_info| {
                                 model_info
@@ -141,6 +159,29 @@ impl PickerDelegate for ModelPickerDelegate {
 
     fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
 
+    fn render_header(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<AnyElement> {
+        let configured_models_count = LanguageModelRegistry::global(cx)
+            .read(cx)
+            .providers()
+            .iter()
+            .filter(|provider| provider.is_authenticated(cx))
+            .count();
+
+        if configured_models_count > 0 {
+            Some(
+                Label::new("Configured Models")
+                    .size(LabelSize::Small)
+                    .color(Color::Muted)
+                    .mt_1()
+                    .mb_0p5()
+                    .ml_3()
+                    .into_any_element(),
+            )
+        } else {
+            None
+        }
+    }
+
     fn render_match(
         &self,
         ix: usize,
@@ -148,9 +189,10 @@ impl PickerDelegate for ModelPickerDelegate {
         cx: &mut ViewContext<Picker<Self>>,
     ) -> Option<Self::ListItem> {
         use feature_flags::FeatureFlagAppExt;
-        let model_info = self.filtered_models.get(ix)?;
         let show_badges = cx.has_flag::<ZedPro>();
-        let provider_name: String = model_info.model.provider_name().0.into();
+
+        let model_info = self.filtered_models.get(ix)?;
+        let provider_name: String = model_info.model.provider_name().0.clone().into();
 
         Some(
             ListItem::new(ix)
@@ -165,27 +207,32 @@ impl PickerDelegate for ModelPickerDelegate {
                     ),
                 )
                 .child(
-                    h_flex().w_full().justify_between().min_w(px(200.)).child(
-                        h_flex()
-                            .gap_1p5()
-                            .child(Label::new(model_info.model.name().0.clone()))
-                            .child(
-                                Label::new(provider_name)
-                                    .size(LabelSize::XSmall)
-                                    .color(Color::Muted),
-                            )
-                            .children(match model_info.availability {
-                                LanguageModelAvailability::Public => None,
-                                LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
-                                LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
-                                    show_badges.then(|| {
-                                        Label::new("Pro")
-                                            .size(LabelSize::XSmall)
-                                            .color(Color::Muted)
-                                    })
-                                }
-                            }),
-                    ),
+                    h_flex()
+                        .w_full()
+                        .items_center()
+                        .gap_1p5()
+                        .min_w(px(200.))
+                        .child(Label::new(model_info.model.name().0.clone()))
+                        .child(
+                            h_flex()
+                                .gap_0p5()
+                                .child(
+                                    Label::new(provider_name)
+                                        .size(LabelSize::XSmall)
+                                        .color(Color::Muted),
+                                )
+                                .children(match model_info.availability {
+                                    LanguageModelAvailability::Public => None,
+                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
+                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
+                                        show_badges.then(|| {
+                                            Label::new("Pro")
+                                                .size(LabelSize::XSmall)
+                                                .color(Color::Muted)
+                                        })
+                                    }
+                                }),
+                        ),
                 )
                 .end_slot(div().when(model_info.is_selected, |this| {
                     this.child(
@@ -213,7 +260,7 @@ impl PickerDelegate for ModelPickerDelegate {
                 .justify_between()
                 .when(cx.has_flag::<ZedPro>(), |this| {
                     this.child(match plan {
-                        // Already a zed pro subscriber
+                        // Already a Zed Pro subscriber
                         Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
                             .icon(IconName::ZedAssistant)
                             .icon_size(IconSize::Small)
@@ -254,6 +301,7 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
         let selected_provider = LanguageModelRegistry::read_global(cx)
             .active_provider()
             .map(|m| m.id());
+
         let selected_model = LanguageModelRegistry::read_global(cx)
             .active_model()
             .map(|m| m.id());

crates/language_model/src/provider/cloud.rs 🔗

@@ -912,7 +912,7 @@ impl Render for ConfigurationView {
 
         let is_pro = plan == Some(proto::Plan::ZedPro);
         let subscription_text = Label::new(if is_pro {
-            "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
+            "You have full access to Zed's hosted LLMs, which include models from Anthropic, OpenAI, and Google. They come with faster speeds and higher limits through Zed Pro."
         } else {
             "You have basic access to models from Anthropic through the Zed AI Free plan."
         });
@@ -957,27 +957,14 @@ impl Render for ConfigurationView {
                 })
         } else {
             v_flex()
-                .gap_6()
-                .child(Label::new("Use the zed.dev to access language models."))
+                .gap_2()
+                .child(Label::new("Use Zed AI to access hosted language models."))
                 .child(
-                    v_flex()
-                        .gap_2()
-                        .child(
-                            Button::new("sign_in", "Sign in")
-                                .icon_color(Color::Muted)
-                                .icon(IconName::Github)
-                                .icon_position(IconPosition::Start)
-                                .style(ButtonStyle::Filled)
-                                .full_width()
-                                .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
-                        )
-                        .child(
-                            div().flex().w_full().items_center().child(
-                                Label::new("Sign in to enable collaboration.")
-                                    .color(Color::Muted)
-                                    .size(LabelSize::Small),
-                            ),
-                        ),
+                    Button::new("sign_in", "Sign In")
+                        .icon_color(Color::Muted)
+                        .icon(IconName::Github)
+                        .icon_position(IconPosition::Start)
+                        .on_click(cx.listener(move |this, _, cx| this.authenticate(cx))),
                 )
         }
     }