@@ -2801,21 +2801,19 @@ pub struct ContextEditorToolbarItem {
fs: Arc<dyn Fs>,
workspace: WeakView<Workspace>,
active_context_editor: Option<WeakView<ContextEditor>>,
- model_selector_menu_handle: PopoverMenuHandle<ContextMenu>,
model_summary_editor: View<Editor>,
}
impl ContextEditorToolbarItem {
pub fn new(
workspace: &Workspace,
- model_selector_menu_handle: PopoverMenuHandle<ContextMenu>,
+ _model_selector_menu_handle: PopoverMenuHandle<ContextMenu>,
model_summary_editor: View<Editor>,
) -> Self {
Self {
fs: workspace.app_state().fs.clone(),
workspace: workspace.weak_handle(),
active_context_editor: None,
- model_selector_menu_handle,
model_summary_editor,
}
}
@@ -2946,49 +2944,46 @@ impl Render for ContextEditorToolbarItem {
});
let right_side = h_flex()
.gap_2()
- .child(
- ModelSelector::new(
- self.fs.clone(),
- ButtonLike::new("active-model")
- .style(ButtonStyle::Subtle)
- .child(
- h_flex()
- .w_full()
- .gap_0p5()
- .child(
- div()
- .overflow_x_hidden()
- .flex_grow()
- .whitespace_nowrap()
- .child(
- Label::new(
- LanguageModelRegistry::read_global(cx)
- .active_model()
- .map(|model| {
- format!(
- "{}: {}",
- model.provider_name().0,
- model.name().0
- )
- })
- .unwrap_or_else(|| "No model selected".into()),
- )
- .size(LabelSize::Small)
- .color(Color::Muted),
- ),
- )
- .child(
- Icon::new(IconName::ChevronDown)
- .color(Color::Muted)
- .size(IconSize::XSmall),
- ),
- )
- .tooltip(move |cx| {
- Tooltip::for_action("Change Model", &ToggleModelSelector, cx)
- }),
- )
- .with_handle(self.model_selector_menu_handle.clone()),
- )
+ .child(ModelSelector::new(
+ self.fs.clone(),
+ ButtonLike::new("active-model")
+ .style(ButtonStyle::Subtle)
+ .child(
+ h_flex()
+ .w_full()
+ .gap_0p5()
+ .child(
+ div()
+ .overflow_x_hidden()
+ .flex_grow()
+ .whitespace_nowrap()
+ .child(
+ Label::new(
+ LanguageModelRegistry::read_global(cx)
+ .active_model()
+ .map(|model| {
+ format!(
+ "{}: {}",
+ model.provider_name().0,
+ model.name().0
+ )
+ })
+ .unwrap_or_else(|| "No model selected".into()),
+ )
+ .size(LabelSize::Small)
+ .color(Color::Muted),
+ ),
+ )
+ .child(
+ Icon::new(IconName::ChevronDown)
+ .color(Color::Muted)
+ .size(IconSize::XSmall),
+ ),
+ )
+ .tooltip(move |cx| {
+ Tooltip::for_action("Change Model", &ToggleModelSelector, cx)
+ }),
+ ))
.children(self.render_remaining_tokens(cx))
.child(self.render_inject_context_menu(cx));
@@ -1,21 +1,45 @@
+use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
+use proto::Plan;
+
use std::sync::Arc;
+use ui::ListItemSpacing;
-use crate::{assistant_settings::AssistantSettings, ShowConfiguration};
+use crate::assistant_settings::AssistantSettings;
+use crate::ShowConfiguration;
use fs::Fs;
-use gpui::{Action, SharedString};
-use language_model::{LanguageModelAvailability, LanguageModelRegistry};
-use proto::Plan;
+use gpui::Action;
+use gpui::SharedString;
+use gpui::Task;
+use picker::{Picker, PickerDelegate};
use settings::update_settings_file;
-use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
+use ui::{prelude::*, ListItem, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
+
+const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
#[derive(IntoElement)]
pub struct ModelSelector<T: PopoverTrigger> {
- handle: Option<PopoverMenuHandle<ContextMenu>>,
+ handle: Option<PopoverMenuHandle<Picker<ModelPickerDelegate>>>,
fs: Arc<dyn Fs>,
trigger: T,
info_text: Option<SharedString>,
}
+pub struct ModelPickerDelegate {
+ fs: Arc<dyn Fs>,
+ all_models: Vec<ModelInfo>,
+ filtered_models: Vec<ModelInfo>,
+ selected_index: usize,
+}
+
+#[derive(Clone)]
+struct ModelInfo {
+ model: Arc<dyn LanguageModel>,
+ _provider_name: SharedString,
+ provider_icon: IconName,
+ availability: LanguageModelAvailability,
+ is_selected: bool,
+}
+
impl<T: PopoverTrigger> ModelSelector<T> {
pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
ModelSelector {
@@ -26,7 +50,7 @@ impl<T: PopoverTrigger> ModelSelector<T> {
}
}
- pub fn with_handle(mut self, handle: PopoverMenuHandle<ContextMenu>) -> Self {
+ pub fn with_handle(mut self, handle: PopoverMenuHandle<Picker<ModelPickerDelegate>>) -> Self {
self.handle = Some(handle);
self
}
@@ -37,148 +61,228 @@ impl<T: PopoverTrigger> ModelSelector<T> {
}
}
-impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
- fn render(self, _cx: &mut WindowContext) -> impl IntoElement {
- let mut menu = PopoverMenu::new("model-switcher");
- if let Some(handle) = self.handle {
- menu = menu.with_handle(handle);
- }
+impl PickerDelegate for ModelPickerDelegate {
+ type ListItem = ListItem;
- let info_text = self.info_text.clone();
-
- menu.menu(move |cx| {
- ContextMenu::build(cx, |mut menu, cx| {
- if let Some(info_text) = info_text.clone() {
- menu = menu
- .custom_row(move |_cx| {
- Label::new(info_text.clone())
- .color(Color::Muted)
- .into_any_element()
- })
- .separator();
- }
-
- for (index, provider) in LanguageModelRegistry::global(cx)
- .read(cx)
- .providers()
- .into_iter()
- .enumerate()
- {
- let provider_icon = provider.icon();
- let provider_name = provider.name().0.clone();
-
- if index > 0 {
- menu = menu.separator();
+ fn match_count(&self) -> usize {
+ self.filtered_models.len()
+ }
+
+ fn selected_index(&self) -> usize {
+ self.selected_index
+ }
+
+ fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
+ self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
+ cx.notify();
+ }
+
+ fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
+ "Select a model...".into()
+ }
+
+ fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
+ let all_models = self.all_models.clone();
+ cx.spawn(|this, mut cx| async move {
+ let filtered_models = cx
+ .background_executor()
+ .spawn(async move {
+ if query.is_empty() {
+ all_models
+ } else {
+ all_models
+ .into_iter()
+ .filter(|model_info| {
+ model_info
+ .model
+ .name()
+ .0
+ .to_lowercase()
+ .contains(&query.to_lowercase())
+ })
+ .collect()
}
- menu = menu.custom_row(move |_| {
- h_flex()
- .pb_1()
- .gap_1p5()
- .w_full()
- .child(
- Icon::new(provider_icon)
- .color(Color::Muted)
+ })
+ .await;
+
+ this.update(&mut cx, |this, cx| {
+ this.delegate.filtered_models = filtered_models;
+ this.delegate.set_selected_index(0, cx);
+ cx.notify();
+ })
+ .ok();
+ })
+ }
+
+ fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
+ if let Some(model_info) = self.filtered_models.get(self.selected_index) {
+ let model = model_info.model.clone();
+ update_settings_file::<AssistantSettings>(self.fs.clone(), cx, move |settings, _| {
+ settings.set_model(model.clone())
+ });
+
+ // Update the selection status
+ let selected_model_id = model_info.model.id();
+ for model in &mut self.all_models {
+ model.is_selected = model.model.id() == selected_model_id;
+ }
+ for model in &mut self.filtered_models {
+ model.is_selected = model.model.id() == selected_model_id;
+ }
+ }
+ }
+
+ fn dismissed(&mut self, _cx: &mut ViewContext<Picker<Self>>) {}
+
+ fn render_match(
+ &self,
+ ix: usize,
+ selected: bool,
+ cx: &mut ViewContext<Picker<Self>>,
+ ) -> Option<Self::ListItem> {
+ let model_info = self.filtered_models.get(ix)?;
+
+ Some(
+ ListItem::new(ix)
+ .inset(true)
+ .spacing(ListItemSpacing::Sparse)
+ .selected(selected)
+ .start_slot(
+ div().pr_1().child(
+ Icon::new(model_info.provider_icon)
+ .color(Color::Muted)
+ .size(IconSize::XSmall),
+ ),
+ )
+ .child(
+ h_flex()
+ .w_full()
+ .justify_between()
+ .font_buffer(cx)
+ .min_w(px(200.))
+ .child(
+ h_flex()
+ .gap_2()
+ .child(Label::new(model_info.model.name().0.clone()))
+ .children(match model_info.availability {
+ LanguageModelAvailability::Public => None,
+ LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
+ LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => Some(
+ Label::new("Pro")
+ .size(LabelSize::XSmall)
+ .color(Color::Muted),
+ ),
+ }),
+ )
+ .child(div().when(model_info.is_selected, |this| {
+ this.child(
+ Icon::new(IconName::Check)
+ .color(Color::Accent)
.size(IconSize::Small),
)
- .child(Label::new(provider_name.clone()))
- .into_any_element()
- });
-
- let available_models = provider.provided_models(cx);
- if available_models.is_empty() {
- menu = menu.custom_entry(
- {
- move |_| {
- h_flex()
- .w_full()
- .gap_1()
- .child(Icon::new(IconName::Settings))
- .child(Label::new("Configure"))
- .into_any()
- }
- },
- {
- |cx| {
- cx.dispatch_action(ShowConfiguration.boxed_clone());
- }
- },
- );
- }
+ })),
+ ),
+ )
+ }
+
+ fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
+ let plan = proto::Plan::ZedPro;
+ let is_trial = false;
- let selected_provider = LanguageModelRegistry::read_global(cx)
- .active_provider()
- .map(|m| m.id());
- let selected_model = LanguageModelRegistry::read_global(cx)
- .active_model()
- .map(|m| m.id());
-
- for available_model in available_models {
- menu = menu.custom_entry(
- {
- let id = available_model.id();
- let provider_id = available_model.provider_id();
- let model_name = available_model.name().0.clone();
- let availability = available_model.availability();
- let selected_model = selected_model.clone();
- let selected_provider = selected_provider.clone();
- move |cx| {
- h_flex()
- .w_full()
- .justify_between()
- .font_buffer(cx)
- .min_w(px(260.))
- .child(
- h_flex()
- .gap_2()
- .child(Label::new(model_name.clone()))
- .children(match availability {
- LanguageModelAvailability::Public => None,
- LanguageModelAvailability::RequiresPlan(
- Plan::Free,
- ) => None,
- LanguageModelAvailability::RequiresPlan(
- Plan::ZedPro,
- ) => Some(
- Label::new("Pro")
- .size(LabelSize::XSmall)
- .color(Color::Muted),
- ),
- }),
- )
- .child(div().when(
- selected_model.as_ref() == Some(&id)
- && selected_provider.as_ref() == Some(&provider_id),
- |this| {
- this.child(
- Icon::new(IconName::Check)
- .color(Color::Accent)
- .size(IconSize::Small),
- )
- },
- ))
- .into_any()
- }
- },
- {
- let fs = self.fs.clone();
- let model = available_model.clone();
- move |cx| {
- let model = model.clone();
- update_settings_file::<AssistantSettings>(
- fs.clone(),
- cx,
- move |settings, _| settings.set_model(model),
- );
- }
- },
- );
+ Some(
+ h_flex()
+ .w_full()
+ .border_t_1()
+ .border_color(cx.theme().colors().border)
+ .p_1()
+ .gap_4()
+ .justify_between()
+ .child(match plan {
+ // Already a zed pro subscriber
+ Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
+ .icon(IconName::ZedAssistant)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .icon_position(IconPosition::Start)
+ .on_click(|_, cx| {
+ cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
+ }),
+ // Free user
+ Plan::Free => Button::new(
+ "try-pro",
+ if is_trial {
+ "Upgrade to Pro"
+ } else {
+ "Try Pro"
+ },
+ )
+ .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
+ })
+ .child(
+ Button::new("configure", "Configure")
+ .icon(IconName::Settings)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .icon_position(IconPosition::Start)
+ .on_click(|_, cx| {
+ cx.dispatch_action(ShowConfiguration.boxed_clone());
+ }),
+ )
+ .into_any(),
+ )
+ }
+}
+
+impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
+ fn render(self, cx: &mut WindowContext) -> impl IntoElement {
+ let selected_provider = LanguageModelRegistry::read_global(cx)
+ .active_provider()
+ .map(|m| m.id());
+ let selected_model = LanguageModelRegistry::read_global(cx)
+ .active_model()
+ .map(|m| m.id());
+
+ let all_models = LanguageModelRegistry::global(cx)
+ .read(cx)
+ .providers()
+ .iter()
+ .flat_map(|provider| {
+ let provider_name = provider.name().0.clone();
+ let provider_icon = provider.icon();
+ let provider_id = provider.id();
+ let selected_model = selected_model.clone();
+ let selected_provider = selected_provider.clone();
+
+ provider.provided_models(cx).into_iter().map(move |model| {
+ let model = model.clone();
+
+ ModelInfo {
+ model: model.clone(),
+ _provider_name: provider_name.clone(),
+ provider_icon,
+ availability: model.availability(),
+ is_selected: selected_model.as_ref() == Some(&model.id())
+ && selected_provider.as_ref() == Some(&provider_id),
}
- }
- menu
+ })
})
- .into()
- })
- .trigger(self.trigger)
- .attach(gpui::AnchorCorner::BottomLeft)
+ .collect::<Vec<_>>();
+
+ let delegate = ModelPickerDelegate {
+ fs: self.fs.clone(),
+ all_models: all_models.clone(),
+ filtered_models: all_models,
+ selected_index: 0,
+ };
+
+ let picker_view = cx.new_view(|cx| {
+ let picker = Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into()));
+ picker
+ });
+
+ PopoverMenu::new("model-switcher")
+ .menu(move |_cx| Some(picker_view.clone()))
+ .trigger(self.trigger)
+ .attach(gpui::AnchorCorner::BottomLeft)
}
}
@@ -0,0 +1,184 @@
+use std::sync::Arc;
+
+use crate::{assistant_settings::AssistantSettings, ShowConfiguration};
+use fs::Fs;
+use gpui::{Action, SharedString};
+use language_model::{LanguageModelAvailability, LanguageModelRegistry};
+use proto::Plan;
+use settings::update_settings_file;
+use ui::{prelude::*, ContextMenu, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
+
+#[derive(IntoElement)]
+pub struct ModelSelector<T: PopoverTrigger> {
+ handle: Option<PopoverMenuHandle<ContextMenu>>,
+ fs: Arc<dyn Fs>,
+ trigger: T,
+ info_text: Option<SharedString>,
+}
+
+impl<T: PopoverTrigger> ModelSelector<T> {
+ pub fn new(fs: Arc<dyn Fs>, trigger: T) -> Self {
+ ModelSelector {
+ handle: None,
+ fs,
+ trigger,
+ info_text: None,
+ }
+ }
+
+ pub fn with_handle(mut self, handle: PopoverMenuHandle<ContextMenu>) -> Self {
+ self.handle = Some(handle);
+ self
+ }
+
+ pub fn with_info_text(mut self, text: impl Into<SharedString>) -> Self {
+ self.info_text = Some(text.into());
+ self
+ }
+}
+
+impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
+ fn render(self, _cx: &mut WindowContext) -> impl IntoElement {
+ let mut menu = PopoverMenu::new("model-switcher");
+ if let Some(handle) = self.handle {
+ menu = menu.with_handle(handle);
+ }
+
+ let info_text = self.info_text.clone();
+
+ menu.menu(move |cx| {
+ ContextMenu::build(cx, |mut menu, cx| {
+ if let Some(info_text) = info_text.clone() {
+ menu = menu
+ .custom_row(move |_cx| {
+ Label::new(info_text.clone())
+ .color(Color::Muted)
+ .into_any_element()
+ })
+ .separator();
+ }
+
+ for (index, provider) in LanguageModelRegistry::global(cx)
+ .read(cx)
+ .providers()
+ .into_iter()
+ .enumerate()
+ {
+ let provider_icon = provider.icon();
+ let provider_name = provider.name().0.clone();
+
+ if index > 0 {
+ menu = menu.separator();
+ }
+ menu = menu.custom_row(move |_| {
+ h_flex()
+ .pb_1()
+ .gap_1p5()
+ .w_full()
+ .child(
+ Icon::new(provider_icon)
+ .color(Color::Muted)
+ .size(IconSize::Small),
+ )
+ .child(Label::new(provider_name.clone()))
+ .into_any_element()
+ });
+
+ let available_models = provider.provided_models(cx);
+ if available_models.is_empty() {
+ menu = menu.custom_entry(
+ {
+ move |_| {
+ h_flex()
+ .w_full()
+ .gap_1()
+ .child(Icon::new(IconName::Settings))
+ .child(Label::new("Configure"))
+ .into_any()
+ }
+ },
+ {
+ |cx| {
+ cx.dispatch_action(ShowConfiguration.boxed_clone());
+ }
+ },
+ );
+ }
+
+ let selected_provider = LanguageModelRegistry::read_global(cx)
+ .active_provider()
+ .map(|m| m.id());
+ let selected_model = LanguageModelRegistry::read_global(cx)
+ .active_model()
+ .map(|m| m.id());
+
+ for available_model in available_models {
+ menu = menu.custom_entry(
+ {
+ let id = available_model.id();
+ let provider_id = available_model.provider_id();
+ let model_name = available_model.name().0.clone();
+ let availability = available_model.availability();
+ let selected_model = selected_model.clone();
+ let selected_provider = selected_provider.clone();
+ move |cx| {
+ h_flex()
+ .w_full()
+ .justify_between()
+ .font_buffer(cx)
+ .min_w(px(260.))
+ .child(
+ h_flex()
+ .gap_2()
+ .child(Label::new(model_name.clone()))
+ .children(match availability {
+ LanguageModelAvailability::Public => None,
+ LanguageModelAvailability::RequiresPlan(
+ Plan::Free,
+ ) => None,
+ LanguageModelAvailability::RequiresPlan(
+ Plan::ZedPro,
+ ) => Some(
+ Label::new("Pro")
+ .size(LabelSize::XSmall)
+ .color(Color::Muted),
+ ),
+ }),
+ )
+ .child(div().when(
+ selected_model.as_ref() == Some(&id)
+ && selected_provider.as_ref() == Some(&provider_id),
+ |this| {
+ this.child(
+ Icon::new(IconName::Check)
+ .color(Color::Accent)
+ .size(IconSize::Small),
+ )
+ },
+ ))
+ .into_any()
+ }
+ },
+ {
+ let fs = self.fs.clone();
+ let model = available_model.clone();
+ move |cx| {
+ let model = model.clone();
+ update_settings_file::<AssistantSettings>(
+ fs.clone(),
+ cx,
+ move |settings, _| settings.set_model(model),
+ );
+ }
+ },
+ );
+ }
+ }
+ menu
+ })
+ .into()
+ })
+ .trigger(self.trigger)
+ .attach(gpui::AnchorCorner::BottomLeft)
+ }
+}