language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use cloud_llm_client::Plan;
  4use collections::{HashSet, IndexMap};
  5use feature_flags::ZedProFeatureFlag;
  6use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  7use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
  8use language_model::{
  9    ConfiguredModel, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
 10};
 11use ordered_float::OrderedFloat;
 12use picker::{Picker, PickerDelegate};
 13use ui::{ListItem, ListItemSpacing, prelude::*};
 14
 15const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
 16
 17type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 18type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 19
 20pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 21
 22pub fn language_model_selector(
 23    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 24    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 25    window: &mut Window,
 26    cx: &mut Context<LanguageModelSelector>,
 27) -> LanguageModelSelector {
 28    let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
 29    Picker::list(delegate, window, cx)
 30        .show_scrollbar(true)
 31        .width(rems(20.))
 32        .max_height(Some(rems(20.).into()))
 33}
 34
 35fn all_models(cx: &App) -> GroupedModels {
 36    let providers = LanguageModelRegistry::global(cx).read(cx).providers();
 37
 38    let recommended = providers
 39        .iter()
 40        .flat_map(|provider| {
 41            provider
 42                .recommended_models(cx)
 43                .into_iter()
 44                .map(|model| ModelInfo {
 45                    model,
 46                    icon: provider.icon(),
 47                })
 48        })
 49        .collect();
 50
 51    let other = providers
 52        .iter()
 53        .flat_map(|provider| {
 54            provider
 55                .provided_models(cx)
 56                .into_iter()
 57                .map(|model| ModelInfo {
 58                    model,
 59                    icon: provider.icon(),
 60                })
 61        })
 62        .collect();
 63
 64    GroupedModels::new(other, recommended)
 65}
 66
 67#[derive(Clone)]
 68struct ModelInfo {
 69    model: Arc<dyn LanguageModel>,
 70    icon: IconName,
 71}
 72
 73pub struct LanguageModelPickerDelegate {
 74    on_model_changed: OnModelChanged,
 75    get_active_model: GetActiveModel,
 76    all_models: Arc<GroupedModels>,
 77    filtered_entries: Vec<LanguageModelPickerEntry>,
 78    selected_index: usize,
 79    _subscriptions: Vec<Subscription>,
 80}
 81
 82impl LanguageModelPickerDelegate {
 83    fn new(
 84        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 85        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 86        window: &mut Window,
 87        cx: &mut Context<Picker<Self>>,
 88    ) -> Self {
 89        let on_model_changed = Arc::new(on_model_changed);
 90        let models = all_models(cx);
 91        let entries = models.entries();
 92
 93        Self {
 94            on_model_changed,
 95            all_models: Arc::new(models),
 96            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
 97            filtered_entries: entries,
 98            get_active_model: Arc::new(get_active_model),
 99            _subscriptions: vec![cx.subscribe_in(
100                &LanguageModelRegistry::global(cx),
101                window,
102                |picker, _, event, window, cx| {
103                    match event {
104                        language_model::Event::ProviderStateChanged(_)
105                        | language_model::Event::AddedProvider(_)
106                        | language_model::Event::RemovedProvider(_) => {
107                            let query = picker.query(cx);
108                            picker.delegate.all_models = Arc::new(all_models(cx));
109                            // Update matches will automatically drop the previous task
110                            // if we get a provider event again
111                            picker.update_matches(query, window, cx)
112                        }
113                        _ => {}
114                    }
115                },
116            )],
117        }
118    }
119
120    fn get_active_model_index(
121        entries: &[LanguageModelPickerEntry],
122        active_model: Option<ConfiguredModel>,
123    ) -> usize {
124        entries
125            .iter()
126            .position(|entry| {
127                if let LanguageModelPickerEntry::Model(model) = entry {
128                    active_model
129                        .as_ref()
130                        .map(|active_model| {
131                            active_model.model.id() == model.model.id()
132                                && active_model.provider.id() == model.model.provider_id()
133                        })
134                        .unwrap_or_default()
135                } else {
136                    false
137                }
138            })
139            .unwrap_or(0)
140    }
141
142    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
143        (self.get_active_model)(cx)
144    }
145}
146
147struct GroupedModels {
148    recommended: Vec<ModelInfo>,
149    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
150}
151
152impl GroupedModels {
153    pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
154        let recommended_ids = recommended
155            .iter()
156            .map(|info| (info.model.provider_id(), info.model.id()))
157            .collect::<HashSet<_>>();
158
159        let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
160        for model in other {
161            if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
162                continue;
163            }
164
165            let provider = model.model.provider_id();
166            if let Some(models) = other_by_provider.get_mut(&provider) {
167                models.push(model);
168            } else {
169                other_by_provider.insert(provider, vec![model]);
170            }
171        }
172
173        Self {
174            recommended,
175            other: other_by_provider,
176        }
177    }
178
179    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
180        let mut entries = Vec::new();
181
182        if !self.recommended.is_empty() {
183            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
184            entries.extend(
185                self.recommended
186                    .iter()
187                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
188            );
189        }
190
191        for models in self.other.values() {
192            if models.is_empty() {
193                continue;
194            }
195            entries.push(LanguageModelPickerEntry::Separator(
196                models[0].model.provider_name().0,
197            ));
198            entries.extend(
199                models
200                    .iter()
201                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
202            );
203        }
204        entries
205    }
206
207    fn model_infos(&self) -> Vec<ModelInfo> {
208        let other = self
209            .other
210            .values()
211            .flat_map(|model| model.iter())
212            .cloned()
213            .collect::<Vec<_>>();
214        self.recommended
215            .iter()
216            .chain(&other)
217            .cloned()
218            .collect::<Vec<_>>()
219    }
220}
221
222enum LanguageModelPickerEntry {
223    Model(ModelInfo),
224    Separator(SharedString),
225}
226
227struct ModelMatcher {
228    models: Vec<ModelInfo>,
229    bg_executor: BackgroundExecutor,
230    candidates: Vec<StringMatchCandidate>,
231}
232
233impl ModelMatcher {
234    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
235        let candidates = Self::make_match_candidates(&models);
236        Self {
237            models,
238            bg_executor,
239            candidates,
240        }
241    }
242
243    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
244        let mut matches = self.bg_executor.block(match_strings(
245            &self.candidates,
246            query,
247            false,
248            true,
249            100,
250            &Default::default(),
251            self.bg_executor.clone(),
252        ));
253
254        let sorting_key = |mat: &StringMatch| {
255            let candidate = &self.candidates[mat.candidate_id];
256            (Reverse(OrderedFloat(mat.score)), candidate.id)
257        };
258        matches.sort_unstable_by_key(sorting_key);
259
260        let matched_models: Vec<_> = matches
261            .into_iter()
262            .map(|mat| self.models[mat.candidate_id].clone())
263            .collect();
264
265        matched_models
266    }
267
268    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
269        self.models
270            .iter()
271            .filter(|m| {
272                m.model
273                    .name()
274                    .0
275                    .to_lowercase()
276                    .contains(&query.to_lowercase())
277            })
278            .cloned()
279            .collect::<Vec<_>>()
280    }
281
282    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
283        model_infos
284            .iter()
285            .enumerate()
286            .map(|(index, model)| {
287                StringMatchCandidate::new(
288                    index,
289                    &format!(
290                        "{}/{}",
291                        &model.model.provider_name().0,
292                        &model.model.name().0
293                    ),
294                )
295            })
296            .collect::<Vec<_>>()
297    }
298}
299
300impl PickerDelegate for LanguageModelPickerDelegate {
301    type ListItem = AnyElement;
302
303    fn match_count(&self) -> usize {
304        self.filtered_entries.len()
305    }
306
307    fn selected_index(&self) -> usize {
308        self.selected_index
309    }
310
311    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
312        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
313        cx.notify();
314    }
315
316    fn can_select(
317        &mut self,
318        ix: usize,
319        _window: &mut Window,
320        _cx: &mut Context<Picker<Self>>,
321    ) -> bool {
322        match self.filtered_entries.get(ix) {
323            Some(LanguageModelPickerEntry::Model(_)) => true,
324            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
325        }
326    }
327
328    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
329        "Select a model…".into()
330    }
331
332    fn update_matches(
333        &mut self,
334        query: String,
335        window: &mut Window,
336        cx: &mut Context<Picker<Self>>,
337    ) -> Task<()> {
338        let all_models = self.all_models.clone();
339        let active_model = (self.get_active_model)(cx);
340        let bg_executor = cx.background_executor();
341
342        let language_model_registry = LanguageModelRegistry::global(cx);
343
344        let configured_providers = language_model_registry
345            .read(cx)
346            .providers()
347            .into_iter()
348            .filter(|provider| provider.is_authenticated(cx))
349            .collect::<Vec<_>>();
350
351        let configured_provider_ids = configured_providers
352            .iter()
353            .map(|provider| provider.id())
354            .collect::<Vec<_>>();
355
356        let recommended_models = all_models
357            .recommended
358            .iter()
359            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
360            .cloned()
361            .collect::<Vec<_>>();
362
363        let available_models = all_models
364            .model_infos()
365            .iter()
366            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
367            .cloned()
368            .collect::<Vec<_>>();
369
370        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
371        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
372
373        let recommended = matcher_rec.exact_search(&query);
374        let all = matcher_all.fuzzy_search(&query);
375
376        let filtered_models = GroupedModels::new(all, recommended);
377
378        cx.spawn_in(window, async move |this, cx| {
379            this.update_in(cx, |this, window, cx| {
380                this.delegate.filtered_entries = filtered_models.entries();
381                // Finds the currently selected model in the list
382                let new_index =
383                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
384                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
385                cx.notify();
386            })
387            .ok();
388        })
389    }
390
391    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
392        if let Some(LanguageModelPickerEntry::Model(model_info)) =
393            self.filtered_entries.get(self.selected_index)
394        {
395            let model = model_info.model.clone();
396            (self.on_model_changed)(model.clone(), cx);
397
398            let current_index = self.selected_index;
399            self.set_selected_index(current_index, window, cx);
400
401            cx.emit(DismissEvent);
402        }
403    }
404
405    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
406        cx.emit(DismissEvent);
407    }
408
409    fn render_match(
410        &self,
411        ix: usize,
412        selected: bool,
413        _: &mut Window,
414        cx: &mut Context<Picker<Self>>,
415    ) -> Option<Self::ListItem> {
416        match self.filtered_entries.get(ix)? {
417            LanguageModelPickerEntry::Separator(title) => Some(
418                div()
419                    .px_2()
420                    .pb_1()
421                    .when(ix > 1, |this| {
422                        this.mt_1()
423                            .pt_2()
424                            .border_t_1()
425                            .border_color(cx.theme().colors().border_variant)
426                    })
427                    .child(
428                        Label::new(title)
429                            .size(LabelSize::XSmall)
430                            .color(Color::Muted),
431                    )
432                    .into_any_element(),
433            ),
434            LanguageModelPickerEntry::Model(model_info) => {
435                let active_model = (self.get_active_model)(cx);
436                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
437                let active_model_id = active_model.map(|m| m.model.id());
438
439                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
440                    && Some(model_info.model.id()) == active_model_id;
441
442                let model_icon_color = if is_selected {
443                    Color::Accent
444                } else {
445                    Color::Muted
446                };
447
448                Some(
449                    ListItem::new(ix)
450                        .inset(true)
451                        .spacing(ListItemSpacing::Sparse)
452                        .toggle_state(selected)
453                        .start_slot(
454                            Icon::new(model_info.icon)
455                                .color(model_icon_color)
456                                .size(IconSize::Small),
457                        )
458                        .child(
459                            h_flex()
460                                .w_full()
461                                .pl_0p5()
462                                .gap_1p5()
463                                .w(px(240.))
464                                .child(Label::new(model_info.model.name().0).truncate()),
465                        )
466                        .end_slot(div().pr_3().when(is_selected, |this| {
467                            this.child(
468                                Icon::new(IconName::Check)
469                                    .color(Color::Accent)
470                                    .size(IconSize::Small),
471                            )
472                        }))
473                        .into_any_element(),
474                )
475            }
476        }
477    }
478
479    fn render_footer(
480        &self,
481        _: &mut Window,
482        cx: &mut Context<Picker<Self>>,
483    ) -> Option<gpui::AnyElement> {
484        use feature_flags::FeatureFlagAppExt;
485
486        let plan = Plan::ZedPro;
487
488        Some(
489            h_flex()
490                .w_full()
491                .border_t_1()
492                .border_color(cx.theme().colors().border_variant)
493                .p_1()
494                .gap_4()
495                .justify_between()
496                .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
497                    this.child(match plan {
498                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
499                            .icon(IconName::ZedAssistant)
500                            .icon_size(IconSize::Small)
501                            .icon_color(Color::Muted)
502                            .icon_position(IconPosition::Start)
503                            .on_click(|_, window, cx| {
504                                window
505                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
506                            }),
507                        Plan::ZedFree | Plan::ZedProTrial => Button::new(
508                            "try-pro",
509                            if plan == Plan::ZedProTrial {
510                                "Upgrade to Pro"
511                            } else {
512                                "Try Pro"
513                            },
514                        )
515                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
516                    })
517                })
518                .child(
519                    Button::new("configure", "Configure")
520                        .icon(IconName::Settings)
521                        .icon_size(IconSize::Small)
522                        .icon_color(Color::Muted)
523                        .icon_position(IconPosition::Start)
524                        .on_click(|_, window, cx| {
525                            window.dispatch_action(
526                                zed_actions::agent::OpenSettings.boxed_clone(),
527                                cx,
528                            );
529                        }),
530                )
531                .into_any(),
532        )
533    }
534}
535
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// In-memory `LanguageModel` stub for exercising matching and grouping.
    ///
    /// Only the identity accessors return meaningful values. The token-counting and
    /// completion methods are `unimplemented!()` and must not be called.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Builds a model whose id equals its name and whose provider id equals its
        /// provider name.
        fn new(name: &str, provider: &str) -> Self {
            let name = name.to_string();
            let provider = provider.to_string();
            Self {
                name: LanguageModelName::from(name.clone()),
                id: LanguageModelId::from(name),
                provider_id: LanguageModelProviderId::from(provider.clone()),
                provider_name: LanguageModelProviderName::from(provider),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        /// `"provider/name"`; the tests below identify models by this string.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Wraps each `(provider, name)` pair in a `ModelInfo` with a placeholder icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        let mut infos = Vec::with_capacity(model_specs.len());
        for (provider, name) in model_specs {
            infos.push(ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            });
        }
        infos
    }

    /// The model list shared by the matcher tests.
    fn sample_models() -> Vec<ModelInfo> {
        create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ])
    }

    /// Flattens the per-provider "other" groups into one list, in group order.
    fn other_models(grouped: &GroupedModels) -> Vec<ModelInfo> {
        grouped.other.values().flatten().cloned().collect()
    }

    /// Asserts that `result` holds exactly the models in `expected`, in order. Models are
    /// identified by their `"provider/name"` telemetry id.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, (info, expected_name)) in result.iter().zip(expected.iter()).enumerate() {
            assert_eq!(
                info.model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // Matching ignores case, and results keep the input order.
        assert_models_eq(
            matcher.exact_search("GPT-4.1"),
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // Equal scores fall back to input order: `zed/gpt-4.1` and `openai/gpt-4.1` score
        // the same, and the zed entry comes first in the input, so it is listed first.
        assert_models_eq(
            matcher.fuzzy_search("41"),
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // The provider name is part of the searchable text ("ol" matches "ollama").
        assert_models_eq(
            matcher.fuzzy_search("ol"),
            vec!["ollama/mistral", "ollama/deepseek"],
        );

        // Non-contiguous characters still match.
        assert_models_eq(matcher.fuzzy_search("z4n"), vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended = create_models(vec![("zed", "claude")]);
        let candidates = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped = GroupedModels::new(candidates, recommended);

        // Recommended models should not appear in "other".
        assert_models_eq(other_models(&grouped), vec!["zed/gemini", "copilot/o3"]);
    }

    #[gpui::test]
    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
        let recommended = create_models(vec![("zed", "claude")]);
        let candidates = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Should not be filtered out from "other"
        ]);

        let grouped = GroupedModels::new(candidates, recommended);

        // Exclusion is keyed on (provider, model): the same model name from another
        // provider stays in "other".
        assert_models_eq(
            other_models(&grouped),
            vec!["zed/gemini", "copilot/claude"],
        );
    }
}