Update model selector (#15665)

Nate Butler and Marshall Bowers created

Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers <elliott.codes@gmail.com>

Change summary

assets/icons/ai_anthropic.svg                      |  4 +
assets/icons/ai_google.svg                         |  3 
assets/icons/ai_ollama.svg                         |  5 +
assets/icons/ai_open_ai.svg                        |  1 
assets/icons/ai_zed.svg                            | 10 ++
crates/assistant/src/model_selector.rs             | 59 +++++++++++++--
crates/language_model/src/language_model.rs        |  4 +
crates/language_model/src/provider/anthropic.rs    |  4 +
crates/language_model/src/provider/cloud.rs        |  6 +
crates/language_model/src/provider/copilot_chat.rs |  4 +
crates/language_model/src/provider/google.rs       |  4 +
crates/language_model/src/provider/ollama.rs       |  4 +
crates/language_model/src/provider/open_ai.rs      |  4 +
crates/ui/src/components/icon.rs                   | 10 ++
crates/ui/src/styles/typography.rs                 | 16 ++++
15 files changed, 128 insertions(+), 10 deletions(-)

Detailed changes

assets/icons/ai_anthropic.svg 🔗

@@ -0,0 +1,4 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M3.43331 10.1846L6.66616 2.33334L9.89902 10.1846M3.43331 10.1846L1.9995 13.6667M3.43331 10.1846H9.89902M11.3328 13.6667L9.89902 10.1846" stroke="black" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
+<path d="M14.0613 13.647L9.34721 2.33334" stroke="black" stroke-width="1.25" stroke-linecap="round" stroke-linejoin="round"/>
+</svg>

assets/icons/ai_google.svg 🔗

@@ -0,0 +1,3 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M14.8695 8.16639C14.8695 12.2258 12.0896 15.1147 7.98425 15.1147C4.04818 15.1147 0.869492 11.9361 0.869492 7.99999C0.869492 4.06393 4.04818 0.885239 7.98425 0.885239C9.90064 0.885239 11.5129 1.58811 12.7551 2.74712L10.8187 4.60901C8.28547 2.16475 3.57482 4.00081 3.57482 7.99999C3.57482 10.4816 5.5572 12.4926 7.98425 12.4926C10.8015 12.4926 11.8572 10.4729 12.0236 9.42581H7.98425V6.97868H14.7576C14.8236 7.34303 14.8695 7.69303 14.8695 8.16639Z" fill="black"/>
+</svg>

assets/icons/ai_ollama.svg 🔗

@@ -0,0 +1,5 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<path d="M2.84221 15C2.80954 14.65 2.89261 13.8016 3.48621 13.208C2.85154 12.48 1.93501 10.6712 3.34621 9.26C2.17021 7.496 2.84221 4.808 5.30621 4.668C5.41821 4.14533 6.24388 3.06077 8.05828 3.06077C9.87268 3.06077 10.5818 4.14533 10.6938 4.668C13.1578 4.808 13.8298 7.496 12.6538 9.26C14.065 10.6712 13.1485 12.48 12.5138 13.208C13.1074 13.8016 13.1905 14.65 13.1578 15M3.93421 4.864C3.74755 3.51066 3.6549 1 4.83021 1C5.64221 0.999997 5.83821 2.68 6.14621 3.632M12.0658 4.864C12.2525 3.51066 12.3451 1 11.1698 1C10.3578 0.999997 10.1618 2.68 9.85379 3.632M8.05828 7.6965C7.48995 7.72052 6.28918 7.74527 6.28918 9.10927C6.28918 10.4733 7.54678 10.5621 8.05828 10.5621C8.56978 10.5621 9.71082 10.4733 9.71082 9.10927C9.71082 7.74527 8.62661 7.72052 8.05828 7.6965Z" stroke="black" stroke-width="1.25"/>
+<circle cx="4.98426" cy="7.76129" r="0.666553" fill="black"/>
+<circle cx="0.666553" cy="0.666553" r="0.666553" transform="matrix(-1 0 0 1 11.6823 7.09473)" fill="black"/>
+</svg>

assets/icons/ai_zed.svg 🔗

@@ -0,0 +1,10 @@
+<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
+<g clip-path="url(#clip0_1882_101)">
+<path fill-rule="evenodd" clip-rule="evenodd" d="M2.3125 1.875C2.07088 1.875 1.875 2.07088 1.875 2.3125V11.9375H1V2.3125C1 1.58763 1.58763 1 2.3125 1H14.0344C14.6191 1 14.9118 1.70688 14.4984 2.12029L7.27887 9.33984H9.3125V8.4375H10.1875V9.55859C10.1875 9.92103 9.89369 10.2148 9.53125 10.2148H6.40387L4.89996 11.7187H11.7187V6.25H12.5937V11.7187C12.5937 12.202 12.202 12.5937 11.7187 12.5937H4.02496L2.49371 14.125H13.6875C13.9291 14.125 14.125 13.9291 14.125 13.6875V4.0625H15V13.6875C15 14.4124 14.4124 15 13.6875 15H1.96561C1.38095 15 1.08816 14.2931 1.50157 13.8797L8.69379 6.6875H6.6875V7.5625H5.8125V6.46875C5.8125 6.10631 6.10631 5.8125 6.46875 5.8125H9.56879L11.1 4.28125H4.28125V9.75H3.40625V4.28125C3.40625 3.798 3.798 3.40625 4.28125 3.40625H11.975L13.5063 1.875H2.3125Z" fill="black"/>
+</g>
+<defs>
+<clipPath id="clip0_1882_101">
+<rect width="14" height="14" fill="white" transform="translate(1 1)"/>
+</clipPath>
+</defs>
+</svg>

crates/assistant/src/model_selector.rs 🔗

@@ -3,7 +3,8 @@ use std::sync::Arc;
 use crate::assistant_settings::AssistantSettings;
 use fs::Fs;
 use gpui::SharedString;
-use language_model::LanguageModelRegistry;
+use language_model::{LanguageModelAvailability, LanguageModelRegistry};
+use proto::Plan;
 use settings::update_settings_file;
 use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 
@@ -37,7 +38,7 @@ impl<T: PopoverTrigger> ModelSelector<T> {
 }
 
 impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
-    fn render(self, _: &mut WindowContext) -> impl IntoElement {
+    fn render(self, _cx: &mut WindowContext) -> impl IntoElement {
         let mut menu = PopoverMenu::new("model-switcher");
         if let Some(handle) = self.handle {
             menu = menu.with_handle(handle);
@@ -63,10 +64,25 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
                     .into_iter()
                     .enumerate()
                 {
+                    let provider_icon = provider.icon();
+                    let provider_name = provider.name().0.clone();
+
                     if index > 0 {
                         menu = menu.separator();
                     }
-                    menu = menu.header(provider.name().0);
+                    menu = menu.custom_row(move |_| {
+                        h_flex()
+                            .pb_1()
+                            .gap_1p5()
+                            .w_full()
+                            .child(
+                                Icon::new(provider_icon)
+                                    .color(Color::Muted)
+                                    .size(IconSize::Small),
+                            )
+                            .child(Label::new(provider_name.clone()))
+                            .into_any_element()
+                    });
 
                     let available_models = provider.provided_models(cx);
                     if available_models.is_empty() {
@@ -109,19 +125,44 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
                                 let id = available_model.id();
                                 let provider_id = available_model.provider_id();
                                 let model_name = available_model.name().0.clone();
-                                let _availability = available_model.availability();
+                                let availability = available_model.availability();
                                 let selected_model = selected_model.clone();
                                 let selected_provider = selected_provider.clone();
-                                move |_| {
+                                move |cx| {
                                     h_flex()
                                         .w_full()
                                         .justify_between()
-                                        .child(Label::new(model_name.clone()))
-                                        .when(
+                                        .font_buffer(cx)
+                                        .min_w(px(260.))
+                                        .child(
+                                            h_flex()
+                                                .gap_2()
+                                                .child(Label::new(model_name.clone()))
+                                                .children(match availability {
+                                                    LanguageModelAvailability::Public => None,
+                                                    LanguageModelAvailability::RequiresPlan(
+                                                        Plan::Free,
+                                                    ) => None,
+                                                    LanguageModelAvailability::RequiresPlan(
+                                                        Plan::ZedPro,
+                                                    ) => Some(
+                                                        Label::new("Pro")
+                                                            .size(LabelSize::XSmall)
+                                                            .color(Color::Muted),
+                                                    ),
+                                                }),
+                                        )
+                                        .child(div().when(
                                             selected_model.as_ref() == Some(&id)
                                                 && selected_provider.as_ref() == Some(&provider_id),
-                                            |this| this.child(Icon::new(IconName::Check)),
-                                        )
+                                            |this| {
+                                                this.child(
+                                                    Icon::new(IconName::Check)
+                                                        .color(Color::Accent)
+                                                        .size(IconSize::Small),
+                                                )
+                                            },
+                                        ))
                                         .into_any()
                                 }
                             },

crates/language_model/src/language_model.rs 🔗

@@ -22,6 +22,7 @@ pub use role::*;
 use schemars::JsonSchema;
 use serde::de::DeserializeOwned;
 use std::{future::Future, sync::Arc};
+use ui::IconName;
 
 pub fn init(
     user_store: Model<UserStore>,
@@ -102,6 +103,9 @@ pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
 pub trait LanguageModelProvider: 'static {
     fn id(&self) -> LanguageModelProviderId;
     fn name(&self) -> LanguageModelProviderName;
+    fn icon(&self) -> IconName {
+        IconName::ZedAssistant
+    }
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
     fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
     fn is_authenticated(&self, cx: &AppContext) -> bool;

crates/language_model/src/provider/anthropic.rs 🔗

@@ -115,6 +115,10 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
 
+    fn icon(&self) -> IconName {
+        IconName::AiAnthropic
+    }
+
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
         let mut models = BTreeMap::default();
 

crates/language_model/src/provider/cloud.rs 🔗

@@ -23,7 +23,7 @@ use crate::{LanguageModelAvailability, LanguageModelProvider};
 use super::anthropic::count_anthropic_tokens;
 
 pub const PROVIDER_ID: &str = "zed.dev";
-pub const PROVIDER_NAME: &str = "Zed AI";
+pub const PROVIDER_NAME: &str = "Zed";
 
 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct ZedDotDevSettings {
@@ -128,6 +128,10 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
 
+    fn icon(&self) -> IconName {
+        IconName::AiZed
+    }
+
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
         let mut models = BTreeMap::default();
 

crates/language_model/src/provider/copilot_chat.rs 🔗

@@ -91,6 +91,10 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
 
+    fn icon(&self) -> IconName {
+        IconName::Copilot
+    }
+
     fn provided_models(&self, _cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
         CopilotChatModel::iter()
             .map(|model| {

crates/language_model/src/provider/google.rs 🔗

@@ -97,6 +97,10 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
 
+    fn icon(&self) -> IconName {
+        IconName::AiGoogle
+    }
+
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
         let mut models = BTreeMap::default();
 

crates/language_model/src/provider/ollama.rs 🔗

@@ -108,6 +108,10 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
 
+    fn icon(&self) -> IconName {
+        IconName::AiOllama
+    }
+
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
         self.state
             .read(cx)

crates/language_model/src/provider/open_ai.rs 🔗

@@ -98,6 +98,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
 
+    fn icon(&self) -> IconName {
+        IconName::AiOpenAi
+    }
+
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
         let mut models = BTreeMap::default();
 

crates/ui/src/components/icon.rs 🔗

@@ -106,6 +106,11 @@ impl IconSize {
 )]
 pub enum IconName {
     Ai,
+    AiAnthropic,
+    AiOpenAi,
+    AiGoogle,
+    AiOllama,
+    AiZed,
     ArrowCircle,
     ArrowDown,
     ArrowDownFromLine,
@@ -262,6 +267,11 @@ impl IconName {
     pub fn path(self) -> &'static str {
         match self {
             IconName::Ai => "icons/ai.svg",
+            IconName::AiAnthropic => "icons/ai_anthropic.svg",
+            IconName::AiOpenAi => "icons/ai_open_ai.svg",
+            IconName::AiGoogle => "icons/ai_google.svg",
+            IconName::AiOllama => "icons/ai_ollama.svg",
+            IconName::AiZed => "icons/ai_zed.svg",
             IconName::ArrowCircle => "icons/arrow_circle.svg",
             IconName::ArrowDown => "icons/arrow_down.svg",
             IconName::ArrowDownFromLine => "icons/arrow_down_from_line.svg",

crates/ui/src/styles/typography.rs 🔗

@@ -8,6 +8,22 @@ use crate::{rems_from_px, Color};
 
 /// Extends [`gpui::Styled`] with typography-related styling methods.
 pub trait StyledTypography: Styled + Sized {
+    /// Sets the font family to the buffer font.
+    fn font_buffer(self, cx: &WindowContext) -> Self {
+        let settings = ThemeSettings::get_global(cx);
+        let buffer_font_family = settings.buffer_font.family.clone();
+
+        self.font_family(buffer_font_family)
+    }
+
+    /// Sets the font family to the UI font.
+    fn font_ui(self, cx: &WindowContext) -> Self {
+        let settings = ThemeSettings::get_global(cx);
+        let ui_font_family = settings.ui_font.family.clone();
+
+        self.font_family(ui_font_family)
+    }
+
     /// Sets the text size using a [`UiTextSize`].
     fn text_ui_size(self, size: TextSize, cx: &WindowContext) -> Self {
         self.text_size(size.rems(cx))