language_model_selector.rs

  1use std::{cmp::Reverse, sync::Arc};
  2
  3use collections::IndexMap;
  4use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
  5use gpui::{
  6    Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
  7};
  8use language_model::{
  9    AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
 10    LanguageModelRegistry,
 11};
 12use ordered_float::OrderedFloat;
 13use picker::{Picker, PickerDelegate};
 14use ui::prelude::*;
 15use zed_actions::agent::OpenSettings;
 16
 17use crate::ui::{ModelSelectorFooter, ModelSelectorHeader, ModelSelectorListItem};
 18
 19type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
 20type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
 21
 22pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
 23
 24pub fn language_model_selector(
 25    get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
 26    on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
 27    popover_styles: bool,
 28    focus_handle: FocusHandle,
 29    window: &mut Window,
 30    cx: &mut Context<LanguageModelSelector>,
 31) -> LanguageModelSelector {
 32    let delegate = LanguageModelPickerDelegate::new(
 33        get_active_model,
 34        on_model_changed,
 35        popover_styles,
 36        focus_handle,
 37        window,
 38        cx,
 39    );
 40
 41    if popover_styles {
 42        Picker::list(delegate, window, cx)
 43            .show_scrollbar(true)
 44            .width(rems(20.))
 45            .max_height(Some(rems(20.).into()))
 46    } else {
 47        Picker::list(delegate, window, cx).show_scrollbar(true)
 48    }
 49}
 50
 51fn all_models(cx: &App) -> GroupedModels {
 52    let providers = LanguageModelRegistry::global(cx).read(cx).providers();
 53
 54    let recommended = providers
 55        .iter()
 56        .flat_map(|provider| {
 57            provider
 58                .recommended_models(cx)
 59                .into_iter()
 60                .map(|model| ModelInfo {
 61                    model,
 62                    icon: provider.icon(),
 63                })
 64        })
 65        .collect();
 66
 67    let all = providers
 68        .iter()
 69        .flat_map(|provider| {
 70            provider
 71                .provided_models(cx)
 72                .into_iter()
 73                .map(|model| ModelInfo {
 74                    model,
 75                    icon: provider.icon(),
 76                })
 77        })
 78        .collect();
 79
 80    GroupedModels::new(all, recommended)
 81}
 82
 83#[derive(Clone)]
 84struct ModelInfo {
 85    model: Arc<dyn LanguageModel>,
 86    icon: IconName,
 87}
 88
 89pub struct LanguageModelPickerDelegate {
 90    on_model_changed: OnModelChanged,
 91    get_active_model: GetActiveModel,
 92    all_models: Arc<GroupedModels>,
 93    filtered_entries: Vec<LanguageModelPickerEntry>,
 94    selected_index: usize,
 95    _authenticate_all_providers_task: Task<()>,
 96    _subscriptions: Vec<Subscription>,
 97    popover_styles: bool,
 98    focus_handle: FocusHandle,
 99}
100
101impl LanguageModelPickerDelegate {
102    fn new(
103        get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
104        on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
105        popover_styles: bool,
106        focus_handle: FocusHandle,
107        window: &mut Window,
108        cx: &mut Context<Picker<Self>>,
109    ) -> Self {
110        let on_model_changed = Arc::new(on_model_changed);
111        let models = all_models(cx);
112        let entries = models.entries();
113
114        Self {
115            on_model_changed,
116            all_models: Arc::new(models),
117            selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
118            filtered_entries: entries,
119            get_active_model: Arc::new(get_active_model),
120            _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
121            _subscriptions: vec![cx.subscribe_in(
122                &LanguageModelRegistry::global(cx),
123                window,
124                |picker, _, event, window, cx| {
125                    match event {
126                        language_model::Event::ProviderStateChanged(_)
127                        | language_model::Event::AddedProvider(_)
128                        | language_model::Event::RemovedProvider(_) => {
129                            let query = picker.query(cx);
130                            picker.delegate.all_models = Arc::new(all_models(cx));
131                            // Update matches will automatically drop the previous task
132                            // if we get a provider event again
133                            picker.update_matches(query, window, cx)
134                        }
135                        _ => {}
136                    }
137                },
138            )],
139            popover_styles,
140            focus_handle,
141        }
142    }
143
144    fn get_active_model_index(
145        entries: &[LanguageModelPickerEntry],
146        active_model: Option<ConfiguredModel>,
147    ) -> usize {
148        entries
149            .iter()
150            .position(|entry| {
151                if let LanguageModelPickerEntry::Model(model) = entry {
152                    active_model
153                        .as_ref()
154                        .map(|active_model| {
155                            active_model.model.id() == model.model.id()
156                                && active_model.provider.id() == model.model.provider_id()
157                        })
158                        .unwrap_or_default()
159                } else {
160                    false
161                }
162            })
163            .unwrap_or(0)
164    }
165
166    /// Authenticates all providers in the [`LanguageModelRegistry`].
167    ///
168    /// We do this so that we can populate the language selector with all of the
169    /// models from the configured providers.
170    fn authenticate_all_providers(cx: &mut App) -> Task<()> {
171        let authenticate_all_providers = LanguageModelRegistry::global(cx)
172            .read(cx)
173            .providers()
174            .iter()
175            .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
176            .collect::<Vec<_>>();
177
178        cx.spawn(async move |_cx| {
179            for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
180                if let Err(err) = authenticate_task.await {
181                    if matches!(err, AuthenticateError::CredentialsNotFound) {
182                        // Since we're authenticating these providers in the
183                        // background for the purposes of populating the
184                        // language selector, we don't care about providers
185                        // where the credentials are not found.
186                    } else {
187                        // Some providers have noisy failure states that we
188                        // don't want to spam the logs with every time the
189                        // language model selector is initialized.
190                        //
191                        // Ideally these should have more clear failure modes
192                        // that we know are safe to ignore here, like what we do
193                        // with `CredentialsNotFound` above.
194                        match provider_id.0.as_ref() {
195                            "lmstudio" | "ollama" => {
196                                // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
197                                //
198                                // These fail noisily, so we don't log them.
199                            }
200                            "copilot_chat" => {
201                                // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
202                            }
203                            _ => {
204                                log::error!(
205                                    "Failed to authenticate provider: {}: {err:#}",
206                                    provider_name.0
207                                );
208                            }
209                        }
210                    }
211                }
212            }
213        })
214    }
215
216    pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
217        (self.get_active_model)(cx)
218    }
219}
220
221struct GroupedModels {
222    recommended: Vec<ModelInfo>,
223    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
224}
225
226impl GroupedModels {
227    pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
228        let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
229        for model in all {
230            let provider = model.model.provider_id();
231            if let Some(models) = all_by_provider.get_mut(&provider) {
232                models.push(model);
233            } else {
234                all_by_provider.insert(provider, vec![model]);
235            }
236        }
237
238        Self {
239            recommended,
240            all: all_by_provider,
241        }
242    }
243
244    fn entries(&self) -> Vec<LanguageModelPickerEntry> {
245        let mut entries = Vec::new();
246
247        if !self.recommended.is_empty() {
248            entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
249            entries.extend(
250                self.recommended
251                    .iter()
252                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
253            );
254        }
255
256        for models in self.all.values() {
257            if models.is_empty() {
258                continue;
259            }
260            entries.push(LanguageModelPickerEntry::Separator(
261                models[0].model.provider_name().0,
262            ));
263            entries.extend(
264                models
265                    .iter()
266                    .map(|info| LanguageModelPickerEntry::Model(info.clone())),
267            );
268        }
269        entries
270    }
271}
272
273enum LanguageModelPickerEntry {
274    Model(ModelInfo),
275    Separator(SharedString),
276}
277
278struct ModelMatcher {
279    models: Vec<ModelInfo>,
280    bg_executor: BackgroundExecutor,
281    candidates: Vec<StringMatchCandidate>,
282}
283
284impl ModelMatcher {
285    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
286        let candidates = Self::make_match_candidates(&models);
287        Self {
288            models,
289            bg_executor,
290            candidates,
291        }
292    }
293
294    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
295        let mut matches = self.bg_executor.block(match_strings(
296            &self.candidates,
297            query,
298            false,
299            true,
300            100,
301            &Default::default(),
302            self.bg_executor.clone(),
303        ));
304
305        let sorting_key = |mat: &StringMatch| {
306            let candidate = &self.candidates[mat.candidate_id];
307            (Reverse(OrderedFloat(mat.score)), candidate.id)
308        };
309        matches.sort_unstable_by_key(sorting_key);
310
311        let matched_models: Vec<_> = matches
312            .into_iter()
313            .map(|mat| self.models[mat.candidate_id].clone())
314            .collect();
315
316        matched_models
317    }
318
319    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
320        self.models
321            .iter()
322            .filter(|m| {
323                m.model
324                    .name()
325                    .0
326                    .to_lowercase()
327                    .contains(&query.to_lowercase())
328            })
329            .cloned()
330            .collect::<Vec<_>>()
331    }
332
333    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
334        model_infos
335            .iter()
336            .enumerate()
337            .map(|(index, model)| {
338                StringMatchCandidate::new(
339                    index,
340                    &format!(
341                        "{}/{}",
342                        &model.model.provider_name().0,
343                        &model.model.name().0
344                    ),
345                )
346            })
347            .collect::<Vec<_>>()
348    }
349}
350
351impl PickerDelegate for LanguageModelPickerDelegate {
352    type ListItem = AnyElement;
353
354    fn match_count(&self) -> usize {
355        self.filtered_entries.len()
356    }
357
358    fn selected_index(&self) -> usize {
359        self.selected_index
360    }
361
362    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
363        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
364        cx.notify();
365    }
366
367    fn can_select(
368        &mut self,
369        ix: usize,
370        _window: &mut Window,
371        _cx: &mut Context<Picker<Self>>,
372    ) -> bool {
373        match self.filtered_entries.get(ix) {
374            Some(LanguageModelPickerEntry::Model(_)) => true,
375            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
376        }
377    }
378
379    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
380        "Select a model…".into()
381    }
382
383    fn update_matches(
384        &mut self,
385        query: String,
386        window: &mut Window,
387        cx: &mut Context<Picker<Self>>,
388    ) -> Task<()> {
389        let all_models = self.all_models.clone();
390        let active_model = (self.get_active_model)(cx);
391        let bg_executor = cx.background_executor();
392
393        let language_model_registry = LanguageModelRegistry::global(cx);
394
395        let configured_providers = language_model_registry
396            .read(cx)
397            .providers()
398            .into_iter()
399            .filter(|provider| provider.is_authenticated(cx))
400            .collect::<Vec<_>>();
401
402        let configured_provider_ids = configured_providers
403            .iter()
404            .map(|provider| provider.id())
405            .collect::<Vec<_>>();
406
407        let recommended_models = all_models
408            .recommended
409            .iter()
410            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
411            .cloned()
412            .collect::<Vec<_>>();
413
414        let available_models = all_models
415            .all
416            .values()
417            .flat_map(|models| models.iter())
418            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
419            .cloned()
420            .collect::<Vec<_>>();
421
422        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
423        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
424
425        let recommended = matcher_rec.exact_search(&query);
426        let all = matcher_all.fuzzy_search(&query);
427
428        let filtered_models = GroupedModels::new(all, recommended);
429
430        cx.spawn_in(window, async move |this, cx| {
431            this.update_in(cx, |this, window, cx| {
432                this.delegate.filtered_entries = filtered_models.entries();
433                // Finds the currently selected model in the list
434                let new_index =
435                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
436                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
437                cx.notify();
438            })
439            .ok();
440        })
441    }
442
443    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
444        if let Some(LanguageModelPickerEntry::Model(model_info)) =
445            self.filtered_entries.get(self.selected_index)
446        {
447            let model = model_info.model.clone();
448            (self.on_model_changed)(model.clone(), cx);
449
450            let current_index = self.selected_index;
451            self.set_selected_index(current_index, window, cx);
452
453            cx.emit(DismissEvent);
454        }
455    }
456
457    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
458        cx.emit(DismissEvent);
459    }
460
461    fn render_match(
462        &self,
463        ix: usize,
464        is_focused: bool,
465        _: &mut Window,
466        cx: &mut Context<Picker<Self>>,
467    ) -> Option<Self::ListItem> {
468        match self.filtered_entries.get(ix)? {
469            LanguageModelPickerEntry::Separator(title) => {
470                Some(ModelSelectorHeader::new(title, ix > 1).into_any_element())
471            }
472            LanguageModelPickerEntry::Model(model_info) => {
473                let active_model = (self.get_active_model)(cx);
474                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
475                let active_model_id = active_model.map(|m| m.model.id());
476
477                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
478                    && Some(model_info.model.id()) == active_model_id;
479
480                Some(
481                    ModelSelectorListItem::new(ix, model_info.model.name().0)
482                        .is_focused(is_focused)
483                        .is_selected(is_selected)
484                        .icon(model_info.icon)
485                        .into_any_element(),
486                )
487            }
488        }
489    }
490
491    fn render_footer(
492        &self,
493        _window: &mut Window,
494        _cx: &mut Context<Picker<Self>>,
495    ) -> Option<gpui::AnyElement> {
496        if !self.popover_styles {
497            return None;
498        }
499
500        let focus_handle = self.focus_handle.clone();
501
502        Some(ModelSelectorFooter::new(OpenSettings.boxed_clone(), focus_handle).into_any_element())
503    }
504}
505
// Unit tests for `ModelMatcher` search ordering and `GroupedModels` grouping, driven by
// an in-memory `TestLanguageModel` stub instead of real providers.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal `LanguageModel` stub carrying only identity metadata.
    ///
    /// `new` uses one string for both the model id and name, and one string for both
    /// the provider id and name.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    // Only the identity accessors return meaningful data. The request-handling methods
    // are not exercised by these tests and panic via `unimplemented!()`.
    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // Formatted as `provider/name`; `assert_models_eq` identifies models by this string.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds one `ModelInfo` per `(provider, name)` pair, all using `IconName::Ai`.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts that `result` holds exactly the models in `expected`, in the same order.
    /// Each expected entry is a `provider/name` telemetry id.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "all"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // Flatten the per-provider groups back into one list (provider order, then model order).
        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "all"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "all"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // Flatten the per-provider groups back into one list (provider order, then model order).
        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }
}