language_model_selector.rs

  1use std::sync::Arc;
  2
  3use feature_flags::ZedPro;
  4use gpui::{
  5    Action, AnyElement, App, DismissEvent, Entity, EventEmitter, FocusHandle, Focusable,
  6    Subscription, Task, WeakEntity,
  7};
  8use language_model::{LanguageModel, LanguageModelAvailability, LanguageModelRegistry};
  9use picker::{Picker, PickerDelegate};
 10use proto::Plan;
 11use ui::{prelude::*, ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger};
 12use workspace::ShowConfiguration;
 13
 14const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 15
 16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &App) + 'static>;
 17
 18pub struct LanguageModelSelector {
 19    picker: Entity<Picker<LanguageModelPickerDelegate>>,
 20    /// The task used to update the picker's matches when there is a change to
 21    /// the language model registry.
 22    update_matches_task: Option<Task<()>>,
 23    _subscriptions: Vec<Subscription>,
 24}
 25
 26impl LanguageModelSelector {
 27    pub fn new(
 28        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &App) + 'static,
 29        window: &mut Window,
 30        cx: &mut Context<Self>,
 31    ) -> Self {
 32        let on_model_changed = Arc::new(on_model_changed);
 33
 34        let all_models = Self::all_models(cx);
 35        let delegate = LanguageModelPickerDelegate {
 36            language_model_selector: cx.model().downgrade(),
 37            on_model_changed: on_model_changed.clone(),
 38            all_models: all_models.clone(),
 39            filtered_models: all_models,
 40            selected_index: 0,
 41        };
 42
 43        let picker = cx.new(|cx| {
 44            Picker::uniform_list(delegate, window, cx).max_height(Some(rems(20.).into()))
 45        });
 46
 47        LanguageModelSelector {
 48            picker,
 49            update_matches_task: None,
 50            _subscriptions: vec![cx.subscribe_in(
 51                &LanguageModelRegistry::global(cx),
 52                window,
 53                Self::handle_language_model_registry_event,
 54            )],
 55        }
 56    }
 57
 58    fn handle_language_model_registry_event(
 59        &mut self,
 60        _registry: &Entity<LanguageModelRegistry>,
 61        event: &language_model::Event,
 62        window: &mut Window,
 63        cx: &mut Context<Self>,
 64    ) {
 65        match event {
 66            language_model::Event::ProviderStateChanged
 67            | language_model::Event::AddedProvider(_)
 68            | language_model::Event::RemovedProvider(_) => {
 69                let task = self.picker.update(cx, |this, cx| {
 70                    let query = this.query(cx);
 71                    this.delegate.all_models = Self::all_models(cx);
 72                    this.delegate.update_matches(query, window, cx)
 73                });
 74                self.update_matches_task = Some(task);
 75            }
 76            _ => {}
 77        }
 78    }
 79
 80    fn all_models(cx: &App) -> Vec<ModelInfo> {
 81        LanguageModelRegistry::global(cx)
 82            .read(cx)
 83            .providers()
 84            .iter()
 85            .flat_map(|provider| {
 86                let icon = provider.icon();
 87
 88                provider.provided_models(cx).into_iter().map(move |model| {
 89                    let model = model.clone();
 90                    let icon = model.icon().unwrap_or(icon);
 91
 92                    ModelInfo {
 93                        model: model.clone(),
 94                        icon,
 95                        availability: model.availability(),
 96                    }
 97                })
 98            })
 99            .collect::<Vec<_>>()
100    }
101}
102
103impl EventEmitter<DismissEvent> for LanguageModelSelector {}
104
105impl Focusable for LanguageModelSelector {
106    fn focus_handle(&self, cx: &App) -> FocusHandle {
107        self.picker.focus_handle(cx)
108    }
109}
110
111impl Render for LanguageModelSelector {
112    fn render(&mut self, _window: &mut Window, _cx: &mut Context<Self>) -> impl IntoElement {
113        self.picker.clone()
114    }
115}
116
117#[derive(IntoElement)]
118pub struct LanguageModelSelectorPopoverMenu<T>
119where
120    T: PopoverTrigger,
121{
122    language_model_selector: Entity<LanguageModelSelector>,
123    trigger: T,
124    handle: Option<PopoverMenuHandle<LanguageModelSelector>>,
125}
126
127impl<T: PopoverTrigger> LanguageModelSelectorPopoverMenu<T> {
128    pub fn new(language_model_selector: Entity<LanguageModelSelector>, trigger: T) -> Self {
129        Self {
130            language_model_selector,
131            trigger,
132            handle: None,
133        }
134    }
135
136    pub fn with_handle(mut self, handle: PopoverMenuHandle<LanguageModelSelector>) -> Self {
137        self.handle = Some(handle);
138        self
139    }
140}
141
142impl<T: PopoverTrigger> RenderOnce for LanguageModelSelectorPopoverMenu<T> {
143    fn render(self, _window: &mut Window, _cx: &mut App) -> impl IntoElement {
144        let language_model_selector = self.language_model_selector.clone();
145
146        PopoverMenu::new("model-switcher")
147            .menu(move |_window, _cx| Some(language_model_selector.clone()))
148            .trigger(self.trigger)
149            .attach(gpui::Corner::BottomLeft)
150            .when_some(self.handle.clone(), |menu, handle| menu.with_handle(handle))
151    }
152}
153
154#[derive(Clone)]
155struct ModelInfo {
156    model: Arc<dyn LanguageModel>,
157    icon: IconName,
158    availability: LanguageModelAvailability,
159}
160
161pub struct LanguageModelPickerDelegate {
162    language_model_selector: WeakEntity<LanguageModelSelector>,
163    on_model_changed: OnModelChanged,
164    all_models: Vec<ModelInfo>,
165    filtered_models: Vec<ModelInfo>,
166    selected_index: usize,
167}
168
169impl PickerDelegate for LanguageModelPickerDelegate {
170    type ListItem = ListItem;
171
172    fn match_count(&self) -> usize {
173        self.filtered_models.len()
174    }
175
176    fn selected_index(&self) -> usize {
177        self.selected_index
178    }
179
180    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
181        self.selected_index = ix.min(self.filtered_models.len().saturating_sub(1));
182        cx.notify();
183    }
184
185    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
186        "Select a model...".into()
187    }
188
189    fn update_matches(
190        &mut self,
191        query: String,
192        window: &mut Window,
193        cx: &mut Context<Picker<Self>>,
194    ) -> Task<()> {
195        let all_models = self.all_models.clone();
196
197        let llm_registry = LanguageModelRegistry::global(cx);
198
199        let configured_providers = llm_registry
200            .read(cx)
201            .providers()
202            .iter()
203            .filter(|provider| provider.is_authenticated(cx))
204            .map(|provider| provider.id())
205            .collect::<Vec<_>>();
206
207        cx.spawn_in(window, |this, mut cx| async move {
208            let filtered_models = cx
209                .background_executor()
210                .spawn(async move {
211                    let displayed_models = if configured_providers.is_empty() {
212                        all_models
213                    } else {
214                        all_models
215                            .into_iter()
216                            .filter(|model_info| {
217                                configured_providers.contains(&model_info.model.provider_id())
218                            })
219                            .collect::<Vec<_>>()
220                    };
221
222                    if query.is_empty() {
223                        displayed_models
224                    } else {
225                        displayed_models
226                            .into_iter()
227                            .filter(|model_info| {
228                                model_info
229                                    .model
230                                    .name()
231                                    .0
232                                    .to_lowercase()
233                                    .contains(&query.to_lowercase())
234                            })
235                            .collect()
236                    }
237                })
238                .await;
239
240            this.update_in(&mut cx, |this, window, cx| {
241                this.delegate.filtered_models = filtered_models;
242                this.delegate.set_selected_index(0, window, cx);
243                cx.notify();
244            })
245            .ok();
246        })
247    }
248
249    fn confirm(&mut self, _secondary: bool, _: &mut Window, cx: &mut Context<Picker<Self>>) {
250        if let Some(model_info) = self.filtered_models.get(self.selected_index) {
251            let model = model_info.model.clone();
252            (self.on_model_changed)(model.clone(), cx);
253
254            cx.emit(DismissEvent);
255        }
256    }
257
258    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
259        self.language_model_selector
260            .update(cx, |_this, cx| cx.emit(DismissEvent))
261            .ok();
262    }
263
264    fn render_header(&self, _: &mut Window, cx: &mut Context<Picker<Self>>) -> Option<AnyElement> {
265        let configured_models_count = LanguageModelRegistry::global(cx)
266            .read(cx)
267            .providers()
268            .iter()
269            .filter(|provider| provider.is_authenticated(cx))
270            .count();
271
272        if configured_models_count > 0 {
273            Some(
274                Label::new("Configured Models")
275                    .size(LabelSize::Small)
276                    .color(Color::Muted)
277                    .mt_1()
278                    .mb_0p5()
279                    .ml_3()
280                    .into_any_element(),
281            )
282        } else {
283            None
284        }
285    }
286
287    fn render_match(
288        &self,
289        ix: usize,
290        selected: bool,
291        _: &mut Window,
292        cx: &mut Context<Picker<Self>>,
293    ) -> Option<Self::ListItem> {
294        use feature_flags::FeatureFlagAppExt;
295        let show_badges = cx.has_flag::<ZedPro>();
296
297        let model_info = self.filtered_models.get(ix)?;
298        let provider_name: String = model_info.model.provider_name().0.clone().into();
299
300        let active_provider_id = LanguageModelRegistry::read_global(cx)
301            .active_provider()
302            .map(|m| m.id());
303
304        let active_model_id = LanguageModelRegistry::read_global(cx)
305            .active_model()
306            .map(|m| m.id());
307
308        let is_selected = Some(model_info.model.provider_id()) == active_provider_id
309            && Some(model_info.model.id()) == active_model_id;
310
311        Some(
312            ListItem::new(ix)
313                .inset(true)
314                .spacing(ListItemSpacing::Sparse)
315                .toggle_state(selected)
316                .start_slot(
317                    div().pr_0p5().child(
318                        Icon::new(model_info.icon)
319                            .color(Color::Muted)
320                            .size(IconSize::Medium),
321                    ),
322                )
323                .child(
324                    h_flex()
325                        .w_full()
326                        .items_center()
327                        .gap_1p5()
328                        .min_w(px(200.))
329                        .child(Label::new(model_info.model.name().0.clone()))
330                        .child(
331                            h_flex()
332                                .gap_0p5()
333                                .child(
334                                    Label::new(provider_name)
335                                        .size(LabelSize::XSmall)
336                                        .color(Color::Muted),
337                                )
338                                .children(match model_info.availability {
339                                    LanguageModelAvailability::Public => None,
340                                    LanguageModelAvailability::RequiresPlan(Plan::Free) => None,
341                                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro) => {
342                                        show_badges.then(|| {
343                                            Label::new("Pro")
344                                                .size(LabelSize::XSmall)
345                                                .color(Color::Muted)
346                                        })
347                                    }
348                                }),
349                        ),
350                )
351                .end_slot(div().when(is_selected, |this| {
352                    this.child(
353                        Icon::new(IconName::Check)
354                            .color(Color::Accent)
355                            .size(IconSize::Small),
356                    )
357                })),
358        )
359    }
360
361    fn render_footer(
362        &self,
363        _: &mut Window,
364        cx: &mut Context<Picker<Self>>,
365    ) -> Option<gpui::AnyElement> {
366        use feature_flags::FeatureFlagAppExt;
367
368        let plan = proto::Plan::ZedPro;
369        let is_trial = false;
370
371        Some(
372            h_flex()
373                .w_full()
374                .border_t_1()
375                .border_color(cx.theme().colors().border_variant)
376                .p_1()
377                .gap_4()
378                .justify_between()
379                .when(cx.has_flag::<ZedPro>(), |this| {
380                    this.child(match plan {
381                        // Already a Zed Pro subscriber
382                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
383                            .icon(IconName::ZedAssistant)
384                            .icon_size(IconSize::Small)
385                            .icon_color(Color::Muted)
386                            .icon_position(IconPosition::Start)
387                            .on_click(|_, window, cx| {
388                                window
389                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
390                            }),
391                        // Free user
392                        Plan::Free => Button::new(
393                            "try-pro",
394                            if is_trial {
395                                "Upgrade to Pro"
396                            } else {
397                                "Try Pro"
398                            },
399                        )
400                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
401                    })
402                })
403                .child(
404                    Button::new("configure", "Configure")
405                        .icon(IconName::Settings)
406                        .icon_size(IconSize::Small)
407                        .icon_color(Color::Muted)
408                        .icon_position(IconPosition::Start)
409                        .on_click(|_, window, cx| {
410                            window.dispatch_action(ShowConfiguration.boxed_clone(), cx);
411                        }),
412                )
413                .into_any(),
414        )
415    }
416}