language_model_selector.rs

  1use std::sync::Arc;
  2
  3use feature_flags::ZedPro;
  4use gpui::{
  5    Action, AnyElement, AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Model,
  6    Subscription, Task, View, WeakView,
  7};
  8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
  9use picker::{Picker, PickerDelegate};
 10use proto::Plan;
 11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 12use workspace::ShowConfiguration;
 13
 14const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 15
 16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &AppContext) + 'static>;
 17
 18pub struct LanguageModelSelector {
 19    picker: View<Picker<LanguageModelPickerDelegate>>,
 20    /// The task used to update the picker's matches when there is a change to
 21    /// the language model registry.
 22    update_matches_task: Option<Task<()>>,
 23    _subscriptions: Vec<Subscription>,
 24}
 25
 26impl LanguageModelSelector {
 27    pub fn new(
 28        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &AppContext) + 'static,
 29        cx: &mut ViewContext<Self>,
 30    ) -> Self {
 31        let on_model_changed = Arc::new(on_model_changed);
 32
 33        let all_models = Self::all_models(cx);
 34        let delegate = LanguageModelPickerDelegate {
 35            language_model_selector: cx.view().downgrade(),
 36            on_model_changed: on_model_changed.clone(),
 37            all_models: all_models.clone(),
 38            filtered_models: all_models,
 39            selected_index: 0,
 40        };
 41
 42        let picker =
 43            cx.new_view(|cx| Picker::uniform_list(delegate, cx).max_height(Some(rems(20.).into())));
 44
 45        LanguageModelSelector {
 46            picker,
 47            update_matches_task: None,
 48            _subscriptions: vec![cx.subscribe(
 49                &LanguageModelRegistry::global(cx),
 50                Self::handle_language_model_registry_event,
 51            )],
 52        }
 53    }
 54
 55    fn handle_language_model_registry_event(
 56        &mut self,
 57        _registry: Model<LanguageModelRegistry>,
 58        event: &language_model::Event,
 59        cx: &mut ViewContext<Self>,
 60    ) {
 61        match event {
 62            language_model::Event::ProviderStateChanged
 63            | language_model::Event::AddedProvider(_)
 64            | language_model::Event::RemovedProvider(_) => {
 65                let task = self.picker.update(cx, |this, cx| {
 66                    let query = this.query(cx);
 67                    this.delegate.all_models = Self::all_models(cx);
 68                    this.delegate.update_matches(query, cx)
 69                });
 70                self.update_matches_task = Some(task);
 71            }
 72            _ => {}
 73        }
 74    }
 75
 76    fn all_models(cx: &AppContext) -> Vec<ModelInfo> {
 77        LanguageModelRegistry::global(cx)
 78            .read(cx)
 79            .providers()
 80            .iter()
 81            .flat_map(|provider| {
 82                let icon = provider.icon();
 83
 84                provider.provided_models(cx).into_iter().map(move |model| {
 85                    let model = model.clone();
 86                    let icon = model.icon().unwrap_or(icon);
 87
 88                    ModelInfo {
 89                        model: model.clone(),
 90                        icon,
 91                        availability: model.availability(),
 92                    }
 93                })
 94            })
 95            .collect::<Vec<_>>()
 96    }
 97}
 98
 99impl EventEmitter<DismissEvent> for LanguageModelSelector {}
100
101impl FocusableView for LanguageModelSelector {
102    fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
103        self.picker.focus_handle(cx)
104    }
105}
106
107impl Render for LanguageModelSelector {
108    fn render(&mut self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
109        self.picker.clone()
110    }
111}
112
113#[derive(IntoElement)]
114pub struct LanguageModelSelectorPopoverMenu<T>
115where
116    T: PopoverTrigger,
117{
118    language_model_selector: View<LanguageModelSelector>,
119    trigger: T,
120    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
121}
122
123impl<T: PopoverTrigger> LanguageModelSelectorPopoverMenu<T> {
124    pub fn new(language_model_selector: View<LanguageModelSelector>, trigger: T) -> Self {
125        Self {
126            language_model_selector,
127            trigger,
128            handle: None,
129        }
130    }
131
132    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
133        self.handle = Some(handle);
134        self
135    }
136}
137
138impl<T: PopoverTrigger> RenderOnce for LanguageModelSelectorPopoverMenu<T> {
139    fn render(self, _cx: &mut WindowContext) -> impl IntoElement {
140        let language_model_selector = self.language_model_selector.clone();
141
142        PopoverMenu::new("model-switcher")
143            .menu(move |_cx| Some(language_model_selector.clone()))
144            .trigger(self.trigger)
145            .attach(gpui::Corner::BottomLeft)
146            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
147    }
148}
149
150#[derive(Clone)]
151struct ModelInfo {
152    model: Arc<dyn LanguageModel>,
153    icon: IconName,
154    availability: LanguageModelAvailability,
155}
156
157pub struct LanguageModelPickerDelegate {
158    language_model_selector: WeakView<LanguageModelSelector>,
159    on_model_changed: OnModelChanged,
160    all_models: Vec<ModelInfo>,
161    filtered_models: Vec<ModelInfo>,
162    selected_index: usize,
163}
164
165impl PickerDelegate for LanguageModelPickerDelegate {
166    type ListItem = ListItem;
167
168    fn match_count(&self) -> usize {
169        self.filtered_models.len()
170    }
171
172    fn selected_index(&self) -> usize {
173        self.selected_index
174    }
175
176    fn set_selected_index(&mut self, ix: usize, cx: &mut ViewContext<Picker<Self>>) {
177        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
178        cx.notify();
179    }
180
181    fn placeholder_text(&self, _cx: &mut WindowContext) -> Arc<str> {
182        "Select a model...".into()
183    }
184
185    fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
186        let all_models = self.all_models.clone();
187
188        let llm_registry = LanguageModelRegistry::global(cx);
189
190        let configured_providers = llm_registry
191            .read(cx)
192            .providers()
193            .iter()
194            .filter(|provider| provider.is_authenticated(cx))
195            .map(|provider| provider.id())
196            .collect::<Vec<_>>();
197
198        cx.spawn(|this, mut cx| async move {
199            let filtered_models = cx
200                .background_executor()
201                .spawn(async move {
202                    let displayed_models = if configured_providers.is_empty() {
203                        all_models
204                    } else {
205                        all_models
206                            .into_iter()
207                            .filter(|model_info| {
208                                configured_providers.contains(&model_info.model.provider_id())
209                            })
210                            .collect::<Vec<_>>()
211                    };
212
213                    if query.is_empty() {
214                        displayed_models
215                    } else {
216                        displayed_models
217                            .into_iter()
218                            .filter(|model_info| {
219                                model_info
220                                    .model
221                                    .name()
222                                    .0
223                                    .to_lowercase()
224                                    .contains(&query.to_lowercase())
225                            })
226                            .collect()
227                    }
228                })
229                .await;
230
231            this.update(&mut cx, |this, cx| {
232                this.delegate.filtered_models = filtered_models;
233                this.delegate.set_selected_index(0, cx);
234                cx.notify();
235            })
236            .ok();
237        })
238    }
239
240    fn confirm(&mut self, _secondary: bool, cx: &mut ViewContext<Picker<Self>>) {
241        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
242            let model = model_info.model.clone();
243            (self.on_model_changed)(model.clone(), cx);
244
245            cx.emit(DismissEvent);
246        }
247    }
248
249    fn dismissed(&mut self, cx: &mut ViewContext<Picker<Self>>) {
250        self.language_model_selector
251            .update(cx, |_this, cx| cx.emit(DismissEvent))
252            .ok();
253    }
254
255    fn render_header(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<AnyElement> {
256        let configured_models_count = LanguageModelRegistry::global(cx)
257            .read(cx)
258            .providers()
259            .iter()
260            .filter(|provider| provider.is_authenticated(cx))
261            .count();
262
263        if configured_models_count > 0 {
264            Some(
265                Label::new("Configured Models")
266                    .size(LabelSize::Small)
267                    .color(Color::Muted)
268                    .mt_1()
269                    .mb_0p5()
270                    .ml_3()
271                    .into_any_element(),
272            )
273        } else {
274            None
275        }
276    }
277
278    fn render_match(
279        &self,
280        ix: usize,
281        selected: bool,
282        cx: &mut ViewContext<Picker<Self>>,
283    ) -> Option<Self::ListItem> {
284        use feature_flags::FeatureFlagAppExt;
285        let show_badges = cx.has_flag::<ZedPro>();
286
287        let model_info = self.filtered_models.get(ix)?;
288        let provider_name: String = model_info.model.provider_name().0.clone().into();
289
290        let active_provider_id = LanguageModelRegistry::read_global(cx)
291            .active_provider()
292            .map(|m| m.id());
293
294        let active_model_id = LanguageModelRegistry::read_global(cx)
295            .active_model()
296            .map(|m| m.id());
297
298        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
299            && Some(model_info.model.id()) == active_model_id;
300
301        Some(
302            ListItem::new(ix)
303                .inset(true)
304                .spacing(ListItemSpacing::Sparse)
305                .toggle_state(selected)
306                .start_slot(
307                    div().pr_0p5().child(
308                        Icon::new(model_info.icon)
309                            .color(Color::Muted)
310                            .size(IconSize::Medium),
311                    ),
312                )
313                .child(
314                    h_flex()
315                        .w_full()
316                        .items_center()
317                        .gap_1p5()
318                        .min_w(px(200.))
319                        .child(Label::new(model_info.model.name().0.clone()))
320                        .child(
321                            h_flex()
322                                .gap_0p5()
323                                .child(
324                                    Label::new(provider_name)
325                                        .size(LabelSize::XSmall)
326                                        .color(Color::Muted),
327                                )
328                                .children(match model_info.availability {
329                                    LanguageModelAvailability::Public => None,
330                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
331                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
332                                        show_badges.then(|| {
333                                            Label::new("Pro")
334                                                .size(LabelSize::XSmall)
335                                                .color(Color::Muted)
336                                        })
337                                    }
338                                }),
339                        ),
340                )
341                .end_slot(div().when(is_selected, |this| {
342                    this.child(
343                        Icon::new(IconName::Check)
344                            .color(Color::Accent)
345                            .size(IconSize::Small),
346                    )
347                })),
348        )
349    }
350
351    fn render_footer(&self, cx: &mut ViewContext<Picker<Self>>) -> Option<gpui::AnyElement> {
352        use feature_flags::FeatureFlagAppExt;
353
354        let plan = proto::Plan::ZedPro;
355        let is_trial = false;
356
357        Some(
358            h_flex()
359                .w_full()
360                .border_t_1()
361                .border_color(cx.theme().colors().border_variant)
362                .p_1()
363                .gap_4()
364                .justify_between()
365                .when(cx.has_flag::<ZedPro>(), |this| {
366                    this.child(match plan {
367                        // Already a Zed Pro subscriber
368                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
369                            .icon(IconName::ZedAssistant)
370                            .icon_size(IconSize::Small)
371                            .icon_color(Color::Muted)
372                            .icon_position(IconPosition::Start)
373                            .on_click(|_, cx| {
374                                cx.dispatch_action(Box::new(zed_actions::OpenAccountSettings))
375                            }),
376                        // Free user
377                        Plan::Free => Button::new(
378                            "try-pro",
379                            if is_trial {
380                                "Upgrade to Pro"
381                            } else {
382                                "Try Pro"
383                            },
384                        )
385                        .on_click(|_, cx| cx.open_url(TRY_ZED_PRO_URL)),
386                    })
387                })
388                .child(
389                    Button::new("configure", "Configure")
390                        .icon(IconName::Settings)
391                        .icon_size(IconSize::Small)
392                        .icon_color(Color::Muted)
393                        .icon_position(IconPosition::Start)
394                        .on_click(|_, cx| {
395                            cx.dispatch_action(ShowConfiguration.boxed_clone());
396                        }),
397                )
398                .into_any(),
399        )
400    }
401}