agent: Refine language model selector (#28597)

Bennet Bo Fenner , Danilo Leal , and Danilo Leal created

Release Notes:

- agent: Show recommended models in the agent model selector and display
the provider in the model selector's trigger.

---------

Co-authored-by: Danilo Leal <daniloleal09@gmail.com>
Co-authored-by: Danilo Leal <67129314+danilo-leal@users.noreply.github.com>

Change summary

Cargo.lock                                                    |   1 
assets/icons/ai_anthropic_hosted.svg                          |  12 
crates/agent/src/assistant_model_selector.rs                  |  19 
crates/file_finder/src/file_finder_tests.rs                   |  14 
crates/icons/src/icons.rs                                     |   1 
crates/language_model/src/language_model.rs                   |   7 
crates/language_model/src/model/cloud_model.rs                |   8 
crates/language_model_selector/Cargo.toml                     |   1 
crates/language_model_selector/src/language_model_selector.rs | 351 ++--
crates/language_models/src/provider/anthropic.rs              |  30 
crates/language_models/src/provider/cloud.rs                  |  39 
crates/picker/src/picker.rs                                   |  83 
crates/prompt_library/src/prompt_library.rs                   |   2 
13 files changed, 350 insertions(+), 218 deletions(-)

Detailed changes

Cargo.lock πŸ”—

@@ -7657,6 +7657,7 @@ dependencies = [
 name = "language_model_selector"
 version = "0.1.0"
 dependencies = [
+ "collections",
  "feature_flags",
  "gpui",
  "language_model",

assets/icons/ai_anthropic_hosted.svg πŸ”—

@@ -1,12 +0,0 @@
-<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
-<rect width="16" height="16" rx="2" fill="black" fill-opacity="0.2"/>
-<g clip-path="url(#clip0_1916_18)">
-<path d="M10.652 3.79999H8.816L12.164 12.2H14L10.652 3.79999Z" fill="#1F1F1E"/>
-<path d="M5.348 3.79999L2 12.2H3.872L4.55672 10.436H8.05927L8.744 12.2H10.616L7.268 3.79999H5.348ZM5.16224 8.87599L6.308 5.92399L7.45374 8.87599H5.16224Z" fill="#1F1F1E"/>
-</g>
-<defs>
-<clipPath id="clip0_1916_18">
-<rect width="12" height="8.4" fill="white" transform="translate(2 3.79999)"/>
-</clipPath>
-</defs>
-</svg>

crates/agent/src/assistant_model_selector.rs πŸ”—

@@ -80,17 +80,16 @@ impl AssistantModelSelector {
 
 impl Render for AssistantModelSelector {
     fn render(&mut self, _window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
-        let model_registry = LanguageModelRegistry::read_global(cx);
+        let focus_handle = self.focus_handle.clone();
 
+        let model_registry = LanguageModelRegistry::read_global(cx);
         let model = match self.model_type {
             ModelType::Default => model_registry.default_model(),
             ModelType::InlineAssistant => model_registry.inline_assistant_model(),
         };
-
-        let focus_handle = self.focus_handle.clone();
-        let model_name = match model {
-            Some(model) => model.model.name().0,
-            _ => SharedString::from("No model selected"),
+        let (model_name, model_icon) = match model {
+            Some(model) => (model.model.name().0, Some(model.provider.icon())),
+            _ => (SharedString::from("No model selected"), None),
         };
 
         LanguageModelSelectorPopoverMenu::new(
@@ -100,10 +99,16 @@ impl Render for AssistantModelSelector {
                 .child(
                     h_flex()
                         .gap_0p5()
+                        .children(
+                            model_icon.map(|icon| {
+                                Icon::new(icon).color(Color::Muted).size(IconSize::Small)
+                            }),
+                        )
                         .child(
                             Label::new(model_name)
                                 .size(LabelSize::Small)
-                                .color(Color::Muted),
+                                .color(Color::Muted)
+                                .ml_1(),
                         )
                         .child(
                             Icon::new(IconName::ChevronDown)

crates/file_finder/src/file_finder_tests.rs πŸ”—

@@ -2133,18 +2133,28 @@ async fn test_repeat_toggle_action(cx: &mut gpui::TestAppContext) {
 
     cx.dispatch_action(ToggleFileFinder::default());
     let picker = active_file_picker(&workspace, cx);
+
+    picker.update_in(cx, |picker, window, cx| {
+        picker.update_matches(".txt".to_string(), window, cx)
+    });
+
+    cx.run_until_parked();
+
     picker.update(cx, |picker, _| {
+        assert_eq!(picker.delegate.matches.len(), 6);
         assert_eq!(picker.delegate.selected_index, 0);
-        assert_eq!(picker.logical_scroll_top_index(), 0);
     });
 
     // When toggling repeatedly, the picker scrolls to reveal the selected item.
     cx.dispatch_action(ToggleFileFinder::default());
     cx.dispatch_action(ToggleFileFinder::default());
     cx.dispatch_action(ToggleFileFinder::default());
+
+    cx.run_until_parked();
+
     picker.update(cx, |picker, _| {
+        assert_eq!(picker.delegate.matches.len(), 6);
         assert_eq!(picker.delegate.selected_index, 3);
-        assert_eq!(picker.logical_scroll_top_index(), 3);
     });
 }
 

crates/icons/src/icons.rs πŸ”—

@@ -10,7 +10,6 @@ use strum::{EnumIter, EnumString, IntoStaticStr};
 pub enum IconName {
     Ai,
     AiAnthropic,
-    AiAnthropicHosted,
     AiBedrock,
     AiDeepSeek,
     AiEdit,

crates/language_model/src/language_model.rs πŸ”—

@@ -174,10 +174,6 @@ impl Default for LanguageModelTextStream {
 pub trait LanguageModel: Send + Sync {
     fn id(&self) -> LanguageModelId;
     fn name(&self) -> LanguageModelName;
-    /// If None, falls back to [LanguageModelProvider::icon]
-    fn icon(&self) -> Option<IconName> {
-        None
-    }
     fn provider_id(&self) -> LanguageModelProviderId;
     fn provider_name(&self) -> LanguageModelProviderName;
     fn telemetry_id(&self) -> String;
@@ -304,6 +300,9 @@ pub trait LanguageModelProvider: 'static {
     }
     fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
+    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
+        Vec::new()
+    }
     fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &App) {}
     fn is_authenticated(&self, cx: &App) -> bool;
     fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>>;

crates/language_model/src/model/cloud_model.rs πŸ”—

@@ -6,7 +6,6 @@ use client::Client;
 use gpui::{
     App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _,
 };
-use icons::IconName;
 use proto::{Plan, TypedEnvelope};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
@@ -53,13 +52,6 @@ impl CloudModel {
         }
     }
 
-    pub fn icon(&self) -> Option<IconName> {
-        match self {
-            Self::Anthropic(_) => Some(IconName::AiAnthropicHosted),
-            _ => None,
-        }
-    }
-
     pub fn max_token_count(&self) -> usize {
         match self {
             Self::Anthropic(model) => model.max_token_count(),

crates/language_model_selector/Cargo.toml πŸ”—

@@ -12,6 +12,7 @@ workspace = true
 path = "src/language_model_selector.rs"
 
 [dependencies]
+collections.workspace = true
 feature_flags.workspace = true
 gpui.workspace = true
 language_model.workspace = true

crates/language_model_selector/src/language_model_selector.rs πŸ”—

@@ -1,12 +1,13 @@
 use std::sync::Arc;
 
+use collections::{HashSet, IndexMap};
 use feature_flags::{Assistant2FeatureFlag, ZedPro};
 use gpui::{
     Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
     Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
 };
 use language_model::{
-    AuthenticateError, LanguageModel, LanguageModelAvailability, LanguageModelRegistry,
+    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 };
 use picker::{Picker, PickerDelegate};
 use proto::Plan;
@@ -24,9 +25,6 @@ type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 
 pub struct LanguageModelSelector {
     picker: Entity<Picker<LanguageModelPickerDelegate>>,
-    /// The task used to update the picker's matches when there is a change to
-    /// the language model registry.
-    update_matches_task: Option<Task<()>>,
     _authenticate_all_providers_task: Task<()>,
     _subscriptions: Vec<Subscription>,
 }
@@ -40,16 +38,18 @@ impl LanguageModelSelector {
         let on_model_changed = Arc::new(on_model_changed);
 
         let all_models = Self::all_models(cx);
+        let entries = all_models.entries();
+
         let delegate = LanguageModelPickerDelegate {
             language_model_selector: cx.entity().downgrade(),
             on_model_changed: on_model_changed.clone(),
-            all_models: all_models.clone(),
-            filtered_models: all_models,
-            selected_index: Self::get_active_model_index(cx),
+            all_models: Arc::new(all_models),
+            selected_index: Self::get_active_model_index(&entries, cx),
+            filtered_entries: entries,
         };
 
         let picker = cx.new(|cx| {
-            Picker::uniform_list(delegate, window, cx)
+            Picker::list(delegate, window, cx)
                 .show_scrollbar(true)
                 .width(rems(20.))
                 .max_height(Some(rems(20.).into()))
@@ -59,7 +59,6 @@ impl LanguageModelSelector {
 
         LanguageModelSelector {
             picker,
-            update_matches_task: None,
             _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
             _subscriptions: vec![
                 cx.subscribe_in(
@@ -83,12 +82,13 @@ impl LanguageModelSelector {
             language_model::Event::ProviderStateChanged
             | language_model::Event::AddedProvider(_)
             | language_model::Event::RemovedProvider(_) => {
-                let task = self.picker.update(cx, |this, cx| {
+                self.picker.update(cx, |this, cx| {
                     let query = this.query(cx);
-                    this.delegate.all_models = Self::all_models(cx);
-                    this.delegate.update_matches(query, window, cx)
+                    this.delegate.all_models = Arc::new(Self::all_models(cx));
+                    // Update matches will automatically drop the previous task
+                    // if we get a provider event again
+                    this.update_matches(query, window, cx)
                 });
-                self.update_matches_task = Some(task);
             }
             _ => {}
         }
@@ -144,34 +144,72 @@ impl LanguageModelSelector {
         })
     }
 
-    fn all_models(cx: &App) -> Vec<ModelInfo> {
-        LanguageModelRegistry::global(cx)
+    fn all_models(cx: &App) -> GroupedModels {
+        let mut recommended = Vec::new();
+        let mut recommended_set = HashSet::default();
+        for provider in LanguageModelRegistry::global(cx)
             .read(cx)
             .providers()
             .iter()
-            .flat_map(|provider| {
-                let icon = provider.icon();
-
-                provider.provided_models(cx).into_iter().map(move |model| {
-                    let model = model.clone();
-                    let icon = model.icon().unwrap_or(icon);
-
-                    ModelInfo {
+        {
+            let models = provider.recommended_models(cx);
+            recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
+            recommended.extend(
+                provider
+                    .recommended_models(cx)
+                    .into_iter()
+                    .map(move |model| ModelInfo {
                         model: model.clone(),
-                        icon,
-                        availability: model.availability(),
-                    }
-                })
+                        icon: provider.icon(),
+                    }),
+            );
+        }
+
+        let other_models = LanguageModelRegistry::global(cx)
+            .read(cx)
+            .providers()
+            .iter()
+            .map(|provider| {
+                (
+                    provider.id(),
+                    provider
+                        .provided_models(cx)
+                        .into_iter()
+                        .filter_map(|model| {
+                            let not_included =
+                                !recommended_set.contains(&(model.provider_id(), model.id()));
+                            not_included.then(|| ModelInfo {
+                                model: model.clone(),
+                                icon: provider.icon(),
+                            })
+                        })
+                        .collect::<Vec<_>>(),
+                )
             })
-            .collect::<Vec<_>>()
+            .collect::<IndexMap<_, _>>();
+
+        GroupedModels {
+            recommended,
+            other: other_models,
+        }
     }
 
-    fn get_active_model_index(cx: &App) -> usize {
+    fn get_active_model_index(entries: &[LanguageModelPickerEntry], cx: &App) -> usize {
         let active_model = LanguageModelRegistry::read_global(cx).default_model();
-        Self::all_models(cx)
+        entries
             .iter()
-            .position(|model_info| {
-                Some(model_info.model.id()) == active_model.as_ref().map(|model| model.model.id())
+            .position(|entry| {
+                if let LanguageModelPickerEntry::Model(model) = entry {
+                    active_model
+                        .as_ref()
+                        .map(|active_model| {
+                            active_model.model.id() == model.model.id()
+                                && active_model.model.provider_id() == model.model.provider_id()
+                        })
+                        .unwrap_or_default()
+                } else {
+                    false
+                }
             })
             .unwrap_or(0)
     }
@@ -254,22 +292,61 @@ where
 struct ModelInfo {
     model: Arc<dyn LanguageModel>,
     icon: IconName,
-    availability: LanguageModelAvailability,
 }
 
 pub struct LanguageModelPickerDelegate {
     language_model_selector: WeakEntity<LanguageModelSelector>,
     on_model_changed: OnModelChanged,
-    all_models: Vec<ModelInfo>,
-    filtered_models: Vec<ModelInfo>,
+    all_models: Arc<GroupedModels>,
+    filtered_entries: Vec<LanguageModelPickerEntry>,
     selected_index: usize,
 }
 
+struct GroupedModels {
+    recommended: Vec<ModelInfo>,
+    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
+}
+
+impl GroupedModels {
+    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
+        let mut entries = Vec::new();
+
+        if !self.recommended.is_empty() {
+            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
+            entries.extend(
+                self.recommended
+                    .iter()
+                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
+            );
+        }
+
+        for models in self.other.values() {
+            if models.is_empty() {
+                continue;
+            }
+            entries.push(LanguageModelPickerEntry::Separator(
+                models[0].model.provider_name().0,
+            ));
+            entries.extend(
+                models
+                    .iter()
+                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
+            );
+        }
+        entries
+    }
+}
+
+enum LanguageModelPickerEntry {
+    Model(ModelInfo),
+    Separator(SharedString),
+}
+
 impl PickerDelegate for LanguageModelPickerDelegate {
-    type ListItem = ListItem;
+    type ListItem = AnyElement;
 
     fn match_count(&self) -> usize {
-        self.filtered_models.len()
+        self.filtered_entries.len()
     }
 
     fn selected_index(&self) -> usize {
@@ -277,12 +354,24 @@ impl PickerDelegate for LanguageModelPickerDelegate {
     }
 
     fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
-        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
+        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
         cx.notify();
     }
 
+    fn can_select(
+        &mut self,
+        ix: usize,
+        _window: &mut Window,
+        _cx: &mut Context<Picker<Self>>,
+    ) -> bool {
+        match self.filtered_entries.get(ix) {
+            Some(LanguageModelPickerEntry::Model(_)) => true,
+            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
+        }
+    }
+
     fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
-        "Select a model...".into()
+        "Select a model…".into()
     }
 
     fn update_matches(
@@ -307,22 +396,9 @@ impl PickerDelegate for LanguageModelPickerDelegate {
         cx.spawn_in(window, async move |this, cx| {
             let filtered_models = cx
                 .background_spawn(async move {
-                    let displayed_models = if configured_providers.is_empty() {
-                        all_models
-                    } else {
-                        all_models
-                            .into_iter()
-                            .filter(|model_info| {
-                                configured_providers.contains(&model_info.model.provider_id())
-                            })
-                            .collect::<Vec<_>>()
-                    };
-
-                    if query.is_empty() {
-                        displayed_models
-                    } else {
-                        displayed_models
-                            .into_iter()
+                    let filter_models = |model_infos: &[ModelInfo]| {
+                        model_infos
+                            .iter()
                             .filter(|model_info| {
                                 model_info
                                     .model
@@ -331,20 +407,33 @@ impl PickerDelegate for LanguageModelPickerDelegate {
                                     .to_lowercase()
                                     .contains(&query.to_lowercase())
                             })
-                            .collect()
+                            .cloned()
+                            .collect::<Vec<_>>()
+                    };
+
+                    let recommended_models = filter_models(&all_models.recommended);
+                    let mut other_models = IndexMap::default();
+                    for (provider_id, models) in &all_models.other {
+                        if configured_providers.contains(&provider_id) {
+                            other_models.insert(provider_id.clone(), filter_models(models));
+                        }
+                    }
+                    GroupedModels {
+                        recommended: recommended_models,
+                        other: other_models,
                     }
                 })
                 .await;
 
             this.update_in(cx, |this, window, cx| {
-                this.delegate.filtered_models = filtered_models;
+                this.delegate.filtered_entries = filtered_models.entries();
                 // Preserve selection focus
-                let new_index = if current_index >= this.delegate.filtered_models.len() {
+                let new_index = if current_index >= this.delegate.filtered_entries.len() {
                     0
                 } else {
                     current_index
                 };
-                this.delegate.set_selected_index(new_index, window, cx);
+                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
                 cx.notify();
             })
             .ok();
@@ -352,7 +441,9 @@ impl PickerDelegate for LanguageModelPickerDelegate {
     }
 
     fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
-        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
+        if let Some(LanguageModelPickerEntry::Model(model_info)) =
+            self.filtered_entries.get(self.selected_index)
+        {
             let model = model_info.model.clone();
             (self.on_model_changed)(model.clone(), cx);
 
@@ -369,29 +460,6 @@ impl PickerDelegate for LanguageModelPickerDelegate {
             .ok();
     }
 
-    fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
-        let configured_models_count = LanguageModelRegistry::global(cx)
-            .read(cx)
-            .providers()
-            .iter()
-            .filter(|provider| provider.is_authenticated(cx))
-            .count();
-
-        if configured_models_count > 0 {
-            Some(
-                Label::new("Configured Models")
-                    .size(LabelSize::Small)
-                    .color(Color::Muted)
-                    .mt_1()
-                    .mb_0p5()
-                    .ml_2()
-                    .into_any_element(),
-            )
-        } else {
-            None
-        }
-    }
-
     fn render_match(
         &self,
         ix: usize,
@@ -399,77 +467,68 @@ impl PickerDelegate for LanguageModelPickerDelegate {
         _: &mut Window,
         cx: &mut Context<Picker<Self>>,
     ) -> Option<Self::ListItem> {
-        use feature_flags::FeatureFlagAppExt;
-        let show_badges = cx.has_flag::<ZedPro>();
-
-        let model_info = self.filtered_models.get(ix)?;
-        let provider_name: String = model_info.model.provider_name().0.clone().into();
-
-        let active_model = LanguageModelRegistry::read_global(cx).default_model();
+        match self.filtered_entries.get(ix)? {
+            LanguageModelPickerEntry::Separator(title) => Some(
+                div()
+                    .px_2()
+                    .pb_1()
+                    .when(ix > 1, |this| {
+                        this.mt_1()
+                            .pt_2()
+                            .border_t_1()
+                            .border_color(cx.theme().colors().border_variant)
+                    })
+                    .child(
+                        Label::new(title)
+                            .size(LabelSize::XSmall)
+                            .color(Color::Muted),
+                    )
+                    .into_any_element(),
+            ),
+            LanguageModelPickerEntry::Model(model_info) => {
+                let active_model = LanguageModelRegistry::read_global(cx).default_model();
 
-        let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
-        let active_model_id = active_model.map(|m| m.model.id());
+                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
+                let active_model_id = active_model.map(|m| m.model.id());
 
-        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
-            && Some(model_info.model.id()) == active_model_id;
+                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
+                    && Some(model_info.model.id()) == active_model_id;
 
-        let model_icon_color = if is_selected {
-            Color::Accent
-        } else {
-            Color::Muted
-        };
+                let model_icon_color = if is_selected {
+                    Color::Accent
+                } else {
+                    Color::Muted
+                };
 
-        Some(
-            ListItem::new(ix)
-                .inset(true)
-                .spacing(ListItemSpacing::Sparse)
-                .toggle_state(selected)
-                .start_slot(
-                    Icon::new(model_info.icon)
-                        .color(model_icon_color)
-                        .size(IconSize::Small),
-                )
-                .child(
-                    h_flex()
-                        .w_full()
-                        .items_center()
-                        .gap_1p5()
-                        .pl_0p5()
-                        .w(px(240.))
-                        .child(
-                            div()
-                                .max_w_40()
-                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
+                Some(
+                    ListItem::new(ix)
+                        .inset(true)
+                        .spacing(ListItemSpacing::Sparse)
+                        .toggle_state(selected)
+                        .start_slot(
+                            Icon::new(model_info.icon)
+                                .color(model_icon_color)
+                                .size(IconSize::Small),
                         )
                         .child(
                             h_flex()
-                                .gap_0p5()
-                                .child(
-                                    Label::new(provider_name)
-                                        .size(LabelSize::XSmall)
-                                        .color(Color::Muted),
-                                )
-                                .children(match model_info.availability {
-                                    LanguageModelAvailability::Public => None,
-                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
-                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
-                                        show_badges.then(|| {
-                                            Label::new("Pro")
-                                                .size(LabelSize::XSmall)
-                                                .color(Color::Muted)
-                                        })
-                                    }
-                                }),
-                        ),
+                                .w_full()
+                                .pl_0p5()
+                                .gap_1p5()
+                                .w(px(240.))
+                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
+                        )
+                        .end_slot(div().pr_3().when(is_selected, |this| {
+                            this.child(
+                                Icon::new(IconName::Check)
+                                    .color(Color::Accent)
+                                    .size(IconSize::Small),
+                            )
+                        }))
+                        .into_any_element(),
                 )
-                .end_slot(div().pr_3().when(is_selected, |this| {
-                    this.child(
-                        Icon::new(IconName::Check)
-                            .color(Color::Accent)
-                            .size(IconSize::Small),
-                    )
-                })),
-        )
+            }
+        }
     }
 
     fn render_footer(

crates/language_models/src/provider/anthropic.rs πŸ”—

@@ -192,6 +192,16 @@ impl AnthropicLanguageModelProvider {
 
         Self { http_client, state }
     }
+
+    fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
+        Arc::new(AnthropicModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            state: self.state.clone(),
+            http_client: self.http_client.clone(),
+            request_limiter: RateLimiter::new(4),
+        }) as Arc<dyn LanguageModel>
+    }
 }
 
 impl LanguageModelProviderState for AnthropicLanguageModelProvider {
@@ -226,6 +236,16 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
         }))
     }
 
+    fn recommended_models(&self, _cx: &App) -> Vec<Arc<dyn LanguageModel>> {
+        [
+            anthropic::Model::Claude3_7Sonnet,
+            anthropic::Model::Claude3_7SonnetThinking,
+        ]
+        .into_iter()
+        .map(|model| self.create_language_model(model))
+        .collect()
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models = BTreeMap::default();
 
@@ -266,15 +286,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
 
         models
             .into_values()
-            .map(|model| {
-                Arc::new(AnthropicModel {
-                    id: LanguageModelId::from(model.id().to_string()),
-                    model,
-                    state: self.state.clone(),
-                    http_client: self.http_client.clone(),
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model| self.create_language_model(model))
             .collect()
     }
 

crates/language_models/src/provider/cloud.rs πŸ”—

@@ -225,6 +225,20 @@ impl CloudLanguageModelProvider {
             _maintain_client_status: maintain_client_status,
         }
     }
+
+    fn create_language_model(
+        &self,
+        model: CloudModel,
+        llm_api_token: LlmApiToken,
+    ) -> Arc<dyn LanguageModel> {
+        Arc::new(CloudLanguageModel {
+            id: LanguageModelId::from(model.id().to_string()),
+            model,
+            llm_api_token: llm_api_token.clone(),
+            client: self.client.clone(),
+            request_limiter: RateLimiter::new(4),
+        }) as Arc<dyn LanguageModel>
+    }
 }
 
 impl LanguageModelProviderState for CloudLanguageModelProvider {
@@ -260,6 +274,17 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
         }))
     }
 
+    fn recommended_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
+        let llm_api_token = self.state.read(cx).llm_api_token.clone();
+        [
+            CloudModel::Anthropic(anthropic::Model::Claude3_7Sonnet),
+            CloudModel::Anthropic(anthropic::Model::Claude3_7SonnetThinking),
+        ]
+        .into_iter()
+        .map(|model| self.create_language_model(model, llm_api_token.clone()))
+        .collect()
+    }
+
     fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
         let mut models = BTreeMap::default();
 
@@ -345,15 +370,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
         let llm_api_token = self.state.read(cx).llm_api_token.clone();
         models
             .into_values()
-            .map(|model| {
-                Arc::new(CloudLanguageModel {
-                    id: LanguageModelId::from(model.id().to_string()),
-                    model,
-                    llm_api_token: llm_api_token.clone(),
-                    client: self.client.clone(),
-                    request_limiter: RateLimiter::new(4),
-                }) as Arc<dyn LanguageModel>
-            })
+            .map(|model| self.create_language_model(model, llm_api_token.clone()))
             .collect()
     }
 
@@ -575,10 +592,6 @@ impl LanguageModel for CloudLanguageModel {
         LanguageModelName::from(self.model.display_name().to_string())
     }
 
-    fn icon(&self) -> Option<IconName> {
-        self.model.icon()
-    }
-
     fn provider_id(&self) -> LanguageModelProviderId {
         LanguageModelProviderId(ZED_CLOUD_PROVIDER_ID.into())
     }

crates/picker/src/picker.rs πŸ”—

@@ -3,8 +3,8 @@ use editor::{Editor, scroll::Autoscroll};
 use gpui::{
     AnyElement, App, ClickEvent, Context, DismissEvent, Entity, EventEmitter, FocusHandle,
     Focusable, Length, ListSizingBehavior, ListState, MouseButton, MouseUpEvent, Render,
-    ScrollHandle, ScrollStrategy, Stateful, Task, UniformListScrollHandle, Window, actions, div,
-    impl_actions, list, prelude::*, uniform_list,
+    ScrollStrategy, Stateful, Task, UniformListScrollHandle, Window, actions, div, impl_actions,
+    list, prelude::*, uniform_list,
 };
 use head::Head;
 use schemars::JsonSchema;
@@ -24,6 +24,11 @@ enum ElementContainer {
     UniformList(UniformListScrollHandle),
 }
 
+pub enum Direction {
+    Up,
+    Down,
+}
+
 actions!(picker, [ConfirmCompletion]);
 
 /// ConfirmInput is an alternative editor action which - instead of selecting active picker entry - treats pickers editor input literally,
@@ -86,6 +91,15 @@ pub trait PickerDelegate: Sized + 'static {
         window: &mut Window,
         cx: &mut Context<Picker<Self>>,
     );
+    fn can_select(
+        &mut self,
+        _ix: usize,
+        _window: &mut Window,
+        _cx: &mut Context<Picker<Self>>,
+    ) -> bool {
+        true
+    }
+
     // Allows binding some optional effect to when the selection changes.
     fn selected_index_changed(
         &self,
@@ -271,10 +285,7 @@ impl<D: PickerDelegate> Picker<D> {
             ElementContainer::UniformList(scroll_handle) => {
                 ScrollbarState::new(scroll_handle.clone())
             }
-            ElementContainer::List(_) => {
-                // todo smit: implement for list
-                ScrollbarState::new(ScrollHandle::new())
-            }
+            ElementContainer::List(state) => ScrollbarState::new(state.clone()),
         };
         let focus_handle = cx.focus_handle();
         let mut this = Self {
@@ -359,16 +370,58 @@ impl<D: PickerDelegate> Picker<D> {
     }
 
     /// Handles the selecting an index, and passing the change to the delegate.
-    /// If `scroll_to_index` is true, the new selected index will be scrolled into view.
+    /// If `fallback_direction` is set to `None`, the index will not be selected
+    /// if the element at that index cannot be selected.
+    /// If `fallback_direction` is set to
+    /// `Some(..)`, the next selectable element will be selected in the
+    /// specified direction (Down or Up), cycling through all elements until
+    /// finding one that can be selected or returning if there are no selectable elements.
+    /// If `scroll_to_index` is true, the new selected index will be scrolled into
+    /// view.
     ///
     /// If some effect is bound to `selected_index_changed`, it will be executed.
     pub fn set_selected_index(
         &mut self,
-        ix: usize,
+        mut ix: usize,
+        fallback_direction: Option<Direction>,
         scroll_to_index: bool,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
+        let match_count = self.delegate.match_count();
+        if match_count == 0 {
+            return;
+        }
+
+        if let Some(bias) = fallback_direction {
+            let mut curr_ix = ix;
+            while !self.delegate.can_select(curr_ix, window, cx) {
+                curr_ix = match bias {
+                    Direction::Down => {
+                        if curr_ix == match_count - 1 {
+                            0
+                        } else {
+                            curr_ix + 1
+                        }
+                    }
+                    Direction::Up => {
+                        if curr_ix == 0 {
+                            match_count - 1
+                        } else {
+                            curr_ix - 1
+                        }
+                    }
+                };
+                // There is no item that can be selected
+                if ix == curr_ix {
+                    return;
+                }
+            }
+            ix = curr_ix;
+        } else if !self.delegate.can_select(ix, window, cx) {
+            return;
+        }
+
         let previous_index = self.delegate.selected_index();
         self.delegate.set_selected_index(ix, window, cx);
         let current_index = self.delegate.selected_index();
@@ -393,7 +446,7 @@ impl<D: PickerDelegate> Picker<D> {
         if count > 0 {
             let index = self.delegate.selected_index();
             let ix = if index == count - 1 { 0 } else { index + 1 };
-            self.set_selected_index(ix, true, window, cx);
+            self.set_selected_index(ix, Some(Direction::Down), true, window, cx);
             cx.notify();
         }
     }
@@ -408,7 +461,7 @@ impl<D: PickerDelegate> Picker<D> {
         if count > 0 {
             let index = self.delegate.selected_index();
             let ix = if index == 0 { count - 1 } else { index - 1 };
-            self.set_selected_index(ix, true, window, cx);
+            self.set_selected_index(ix, Some(Direction::Up), true, window, cx);
             cx.notify();
         }
     }
@@ -416,7 +469,7 @@ impl<D: PickerDelegate> Picker<D> {
     fn select_first(&mut self, _: &menu::SelectFirst, window: &mut Window, cx: &mut Context<Self>) {
         let count = self.delegate.match_count();
         if count > 0 {
-            self.set_selected_index(0, true, window, cx);
+            self.set_selected_index(0, Some(Direction::Down), true, window, cx);
             cx.notify();
         }
     }
@@ -424,7 +477,7 @@ impl<D: PickerDelegate> Picker<D> {
     fn select_last(&mut self, _: &menu::SelectLast, window: &mut Window, cx: &mut Context<Self>) {
         let count = self.delegate.match_count();
         if count > 0 {
-            self.set_selected_index(count - 1, true, window, cx);
+            self.set_selected_index(count - 1, Some(Direction::Up), true, window, cx);
             cx.notify();
         }
     }
@@ -433,7 +486,7 @@ impl<D: PickerDelegate> Picker<D> {
         let count = self.delegate.match_count();
         let index = self.delegate.selected_index();
         let new_index = if index + 1 == count { 0 } else { index + 1 };
-        self.set_selected_index(new_index, true, window, cx);
+        self.set_selected_index(new_index, Some(Direction::Down), true, window, cx);
         cx.notify();
     }
 
@@ -506,14 +559,14 @@ impl<D: PickerDelegate> Picker<D> {
     ) {
         cx.stop_propagation();
         window.prevent_default();
-        self.set_selected_index(ix, false, window, cx);
+        self.set_selected_index(ix, None, false, window, cx);
         self.do_confirm(secondary, window, cx)
     }
 
     fn do_confirm(&mut self, secondary: bool, window: &mut Window, cx: &mut Context<Self>) {
         if let Some(update_query) = self.delegate.confirm_update_query(window, cx) {
             self.set_query(update_query, window, cx);
-            self.delegate.set_selected_index(0, window, cx);
+            self.set_selected_index(0, Some(Direction::Down), false, window, cx);
         } else {
             self.delegate.confirm(secondary, window, cx)
         }

crates/prompt_library/src/prompt_library.rs πŸ”—

@@ -657,7 +657,7 @@ impl PromptLibrary {
                         .iter()
                         .position(|mat| mat.id == prompt_id)
                     {
-                        picker.set_selected_index(ix, true, window, cx);
+                        picker.set_selected_index(ix, None, true, window, cx);
                     }
                 }
             } else {