1use std::{cmp::Reverse, sync::Arc};
2
3use collections::IndexMap;
4use futures::{StreamExt, channel::mpsc};
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
7use language_model::{
8 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
9 LanguageModelProviderId, LanguageModelRegistry,
10};
11use ordered_float::OrderedFloat;
12use picker::{Picker, PickerDelegate};
13use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
14use zed_actions::agent::OpenSettings;
15
/// Callback invoked with the model the user confirmed in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback returning the currently active model, if any; used to preselect
/// and highlight it in the list.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// A [`Picker`] driven by [`LanguageModelPickerDelegate`].
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
20
21pub fn language_model_selector(
22 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
23 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
24 popover_styles: bool,
25 focus_handle: FocusHandle,
26 window: &mut Window,
27 cx: &mut Context<LanguageModelSelector>,
28) -> LanguageModelSelector {
29 let delegate = LanguageModelPickerDelegate::new(
30 get_active_model,
31 on_model_changed,
32 popover_styles,
33 focus_handle,
34 window,
35 cx,
36 );
37
38 if popover_styles {
39 Picker::list(delegate, window, cx)
40 .show_scrollbar(true)
41 .width(rems(20.))
42 .max_height(Some(rems(20.).into()))
43 } else {
44 Picker::list(delegate, window, cx).show_scrollbar(true)
45 }
46}
47
48fn all_models(cx: &App) -> GroupedModels {
49 let providers = LanguageModelRegistry::global(cx)
50 .read(cx)
51 .visible_providers();
52
53 let recommended = providers
54 .iter()
55 .flat_map(|provider| {
56 provider
57 .recommended_models(cx)
58 .into_iter()
59 .map(|model| ModelInfo {
60 model,
61 icon: ProviderIcon::from_provider(provider.as_ref()),
62 })
63 })
64 .collect();
65
66 let all: Vec<ModelInfo> = providers
67 .iter()
68 .flat_map(|provider| {
69 provider
70 .provided_models(cx)
71 .into_iter()
72 .map(|model| ModelInfo {
73 model,
74 icon: ProviderIcon::from_provider(provider.as_ref()),
75 })
76 })
77 .collect();
78
79 GroupedModels::new(all, recommended)
80}
81
82#[derive(Clone)]
83enum ProviderIcon {
84 Name(IconName),
85 Path(SharedString),
86}
87
88impl ProviderIcon {
89 fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
90 if let Some(path) = provider.icon_path() {
91 Self::Path(path)
92 } else {
93 Self::Name(provider.icon())
94 }
95 }
96}
97
98#[derive(Clone)]
99struct ModelInfo {
100 model: Arc<dyn LanguageModel>,
101 icon: ProviderIcon,
102}
103
/// [`PickerDelegate`] backing the language model selector.
///
/// Holds every model from the registry's visible providers, the entries
/// currently shown after filtering, and the background tasks that keep that
/// data up to date.
pub struct LanguageModelPickerDelegate {
    /// Called with the chosen model when the user confirms a selection.
    on_model_changed: OnModelChanged,
    /// Returns the currently active model; used to preselect and highlight it.
    get_active_model: GetActiveModel,
    /// All models from the visible providers, grouped by provider.
    all_models: Arc<GroupedModels>,
    /// Entries (separators and models) currently displayed after filtering.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Held so the provider-authentication task is not dropped while the
    /// delegate exists.
    _authenticate_all_providers_task: Task<()>,
    /// Held so the task that reloads `all_models` on registry events is not
    /// dropped while the delegate exists.
    _refresh_models_task: Task<()>,
    /// Whether the picker is rendered in popover mode (fixed size + footer).
    popover_styles: bool,
    /// Focus handle used to resolve the key binding shown in the footer.
    focus_handle: FocusHandle,
}
115
116impl LanguageModelPickerDelegate {
117 fn new(
118 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
119 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
120 popover_styles: bool,
121 focus_handle: FocusHandle,
122 window: &mut Window,
123 cx: &mut Context<Picker<Self>>,
124 ) -> Self {
125 let on_model_changed = Arc::new(on_model_changed);
126 let models = all_models(cx);
127 let entries = models.entries();
128
129 Self {
130 on_model_changed,
131 all_models: Arc::new(models),
132 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
133 filtered_entries: entries,
134 get_active_model: Arc::new(get_active_model),
135 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
136 _refresh_models_task: {
137 // Create a channel to signal when models need refreshing
138 let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
139
140 // Subscribe to registry events and send refresh signals through the channel
141 let registry = LanguageModelRegistry::global(cx);
142 cx.subscribe(®istry, move |_picker, _, event, _cx| match event {
143 language_model::Event::ProviderStateChanged(_) => {
144 refresh_tx.unbounded_send(()).ok();
145 }
146 language_model::Event::AddedProvider(_) => {
147 refresh_tx.unbounded_send(()).ok();
148 }
149 language_model::Event::RemovedProvider(_) => {
150 refresh_tx.unbounded_send(()).ok();
151 }
152 _ => {}
153 })
154 .detach();
155
156 // Spawn a task that listens for refresh signals and updates the picker
157 cx.spawn_in(window, async move |this, cx| {
158 while let Some(()) = refresh_rx.next().await {
159 let result = this.update_in(cx, |picker, window, cx| {
160 picker.delegate.all_models = Arc::new(all_models(cx));
161 picker.refresh(window, cx);
162 });
163 if result.is_err() {
164 // Picker was dropped, exit the loop
165 break;
166 }
167 }
168 })
169 },
170 popover_styles,
171 focus_handle,
172 }
173 }
174
175 fn get_active_model_index(
176 entries: &[LanguageModelPickerEntry],
177 active_model: Option<ConfiguredModel>,
178 ) -> usize {
179 entries
180 .iter()
181 .position(|entry| {
182 if let LanguageModelPickerEntry::Model(model) = entry {
183 active_model
184 .as_ref()
185 .map(|active_model| {
186 active_model.model.id() == model.model.id()
187 && active_model.provider.id() == model.model.provider_id()
188 })
189 .unwrap_or_default()
190 } else {
191 false
192 }
193 })
194 .unwrap_or(0)
195 }
196
197 /// Authenticates all providers in the [`LanguageModelRegistry`].
198 ///
199 /// We do this so that we can populate the language selector with all of the
200 /// models from the configured providers.
201 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
202 let authenticate_all_providers = LanguageModelRegistry::global(cx)
203 .read(cx)
204 .providers()
205 .iter()
206 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
207 .collect::<Vec<_>>();
208
209 cx.spawn(async move |_cx| {
210 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
211 if let Err(err) = authenticate_task.await {
212 if matches!(err, AuthenticateError::CredentialsNotFound) {
213 // Since we're authenticating these providers in the
214 // background for the purposes of populating the
215 // language selector, we don't care about providers
216 // where the credentials are not found.
217 } else {
218 // Some providers have noisy failure states that we
219 // don't want to spam the logs with every time the
220 // language model selector is initialized.
221 //
222 // Ideally these should have more clear failure modes
223 // that we know are safe to ignore here, like what we do
224 // with `CredentialsNotFound` above.
225 match provider_id.0.as_ref() {
226 "lmstudio" | "ollama" => {
227 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
228 //
229 // These fail noisily, so we don't log them.
230 }
231 "copilot_chat" => {
232 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
233 }
234 _ => {
235 log::error!(
236 "Failed to authenticate provider: {}: {err:#}",
237 provider_name.0
238 );
239 }
240 }
241 }
242 }
243 }
244 })
245 }
246
247 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
248 (self.get_active_model)(cx)
249 }
250}
251
252struct GroupedModels {
253 recommended: Vec<ModelInfo>,
254 all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
255}
256
257impl GroupedModels {
258 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
259 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
260 for model in all {
261 let provider = model.model.provider_id();
262 if let Some(models) = all_by_provider.get_mut(&provider) {
263 models.push(model);
264 } else {
265 all_by_provider.insert(provider, vec![model]);
266 }
267 }
268
269 Self {
270 recommended,
271 all: all_by_provider,
272 }
273 }
274
275 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
276 let mut entries = Vec::new();
277
278 if !self.recommended.is_empty() {
279 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
280 entries.extend(
281 self.recommended
282 .iter()
283 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
284 );
285 }
286
287 for models in self.all.values() {
288 if models.is_empty() {
289 continue;
290 }
291 entries.push(LanguageModelPickerEntry::Separator(
292 models[0].model.provider_name().0,
293 ));
294 entries.extend(
295 models
296 .iter()
297 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
298 );
299 }
300 entries
301 }
302}
303
/// A row in the picker list: either a selectable model or a non-selectable
/// group header.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A group header, e.g. "Recommended" or a provider's name.
    Separator(SharedString),
}
308
309struct ModelMatcher {
310 models: Vec<ModelInfo>,
311 bg_executor: BackgroundExecutor,
312 candidates: Vec<StringMatchCandidate>,
313}
314
315impl ModelMatcher {
316 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
317 let candidates = Self::make_match_candidates(&models);
318 Self {
319 models,
320 bg_executor,
321 candidates,
322 }
323 }
324
325 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
326 let mut matches = self.bg_executor.block(match_strings(
327 &self.candidates,
328 query,
329 false,
330 true,
331 100,
332 &Default::default(),
333 self.bg_executor.clone(),
334 ));
335
336 let sorting_key = |mat: &StringMatch| {
337 let candidate = &self.candidates[mat.candidate_id];
338 (Reverse(OrderedFloat(mat.score)), candidate.id)
339 };
340 matches.sort_unstable_by_key(sorting_key);
341
342 let matched_models: Vec<_> = matches
343 .into_iter()
344 .map(|mat| self.models[mat.candidate_id].clone())
345 .collect();
346
347 matched_models
348 }
349
350 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
351 self.models
352 .iter()
353 .filter(|m| {
354 m.model
355 .name()
356 .0
357 .to_lowercase()
358 .contains(&query.to_lowercase())
359 })
360 .cloned()
361 .collect::<Vec<_>>()
362 }
363
364 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
365 model_infos
366 .iter()
367 .enumerate()
368 .map(|(index, model)| {
369 StringMatchCandidate::new(
370 index,
371 &format!(
372 "{}/{}",
373 &model.model.provider_name().0,
374 &model.model.name().0
375 ),
376 )
377 })
378 .collect::<Vec<_>>()
379 }
380}
381
382impl PickerDelegate for LanguageModelPickerDelegate {
383 type ListItem = AnyElement;
384
385 fn match_count(&self) -> usize {
386 self.filtered_entries.len()
387 }
388
389 fn selected_index(&self) -> usize {
390 self.selected_index
391 }
392
393 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
394 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
395 cx.notify();
396 }
397
398 fn can_select(
399 &mut self,
400 ix: usize,
401 _window: &mut Window,
402 _cx: &mut Context<Picker<Self>>,
403 ) -> bool {
404 match self.filtered_entries.get(ix) {
405 Some(LanguageModelPickerEntry::Model(_)) => true,
406 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
407 }
408 }
409
410 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
411 "Select a model…".into()
412 }
413
414 fn update_matches(
415 &mut self,
416 query: String,
417 window: &mut Window,
418 cx: &mut Context<Picker<Self>>,
419 ) -> Task<()> {
420 let all_models = self.all_models.clone();
421 let active_model = (self.get_active_model)(cx);
422 let bg_executor = cx.background_executor();
423
424 let language_model_registry = LanguageModelRegistry::global(cx);
425
426 let configured_providers = language_model_registry
427 .read(cx)
428 .providers()
429 .into_iter()
430 .filter(|provider| provider.is_authenticated(cx))
431 .collect::<Vec<_>>();
432
433 let configured_provider_ids = configured_providers
434 .iter()
435 .map(|provider| provider.id())
436 .collect::<Vec<_>>();
437
438 let recommended_models = all_models
439 .recommended
440 .iter()
441 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
442 .cloned()
443 .collect::<Vec<_>>();
444
445 let available_models = all_models
446 .all
447 .values()
448 .flat_map(|models| models.iter())
449 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
450 .cloned()
451 .collect::<Vec<_>>();
452
453 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
454 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
455
456 let recommended = matcher_rec.exact_search(&query);
457 let all = matcher_all.fuzzy_search(&query);
458
459 let filtered_models = GroupedModels::new(all, recommended);
460
461 cx.spawn_in(window, async move |this, cx| {
462 this.update_in(cx, |this, window, cx| {
463 this.delegate.filtered_entries = filtered_models.entries();
464 // Finds the currently selected model in the list
465 let new_index =
466 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
467 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
468 cx.notify();
469 })
470 .ok();
471 })
472 }
473
474 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
475 if let Some(LanguageModelPickerEntry::Model(model_info)) =
476 self.filtered_entries.get(self.selected_index)
477 {
478 let model = model_info.model.clone();
479 (self.on_model_changed)(model.clone(), cx);
480
481 let current_index = self.selected_index;
482 self.set_selected_index(current_index, window, cx);
483
484 cx.emit(DismissEvent);
485 }
486 }
487
488 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
489 cx.emit(DismissEvent);
490 }
491
492 fn render_match(
493 &self,
494 ix: usize,
495 selected: bool,
496 _: &mut Window,
497 cx: &mut Context<Picker<Self>>,
498 ) -> Option<Self::ListItem> {
499 match self.filtered_entries.get(ix)? {
500 LanguageModelPickerEntry::Separator(title) => Some(
501 div()
502 .px_2()
503 .pb_1()
504 .when(ix > 1, |this| {
505 this.mt_1()
506 .pt_2()
507 .border_t_1()
508 .border_color(cx.theme().colors().border_variant)
509 })
510 .child(
511 Label::new(title)
512 .size(LabelSize::XSmall)
513 .color(Color::Muted),
514 )
515 .into_any_element(),
516 ),
517 LanguageModelPickerEntry::Model(model_info) => {
518 let active_model = (self.get_active_model)(cx);
519 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
520 let active_model_id = active_model.map(|m| m.model.id());
521
522 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
523 && Some(model_info.model.id()) == active_model_id;
524
525 let model_icon_color = if is_selected {
526 Color::Accent
527 } else {
528 Color::Muted
529 };
530
531 Some(
532 ListItem::new(ix)
533 .inset(true)
534 .spacing(ListItemSpacing::Sparse)
535 .toggle_state(selected)
536 .child(
537 h_flex()
538 .w_full()
539 .gap_1p5()
540 .child(match &model_info.icon {
541 ProviderIcon::Name(icon_name) => {
542 log::info!("ICON_DEBUG model_selector using Icon::new for {:?}", icon_name);
543 Icon::new(*icon_name)
544 .color(model_icon_color)
545 .size(IconSize::Small)
546 }
547 ProviderIcon::Path(icon_path) => {
548 log::info!("ICON_DEBUG model_selector using from_external_svg path={}", icon_path);
549 Icon::from_external_svg(icon_path.clone())
550 .color(model_icon_color)
551 .size(IconSize::Small)
552 }
553 })
554 .child(Label::new(model_info.model.name().0).truncate()),
555 )
556 .end_slot(div().pr_3().when(is_selected, |this| {
557 this.child(
558 Icon::new(IconName::Check)
559 .color(Color::Accent)
560 .size(IconSize::Small),
561 )
562 }))
563 .into_any_element(),
564 )
565 }
566 }
567 }
568
569 fn render_footer(
570 &self,
571 _window: &mut Window,
572 cx: &mut Context<Picker<Self>>,
573 ) -> Option<gpui::AnyElement> {
574 let focus_handle = self.focus_handle.clone();
575
576 if !self.popover_styles {
577 return None;
578 }
579
580 Some(
581 h_flex()
582 .w_full()
583 .p_1p5()
584 .border_t_1()
585 .border_color(cx.theme().colors().border_variant)
586 .child(
587 Button::new("configure", "Configure")
588 .full_width()
589 .style(ButtonStyle::Outlined)
590 .key_binding(
591 KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
592 .map(|kb| kb.size(rems_from_px(12.))),
593 )
594 .on_click(|_, window, cx| {
595 window.dispatch_action(OpenSettings.boxed_clone(), cx);
596 }),
597 )
598 .into_any(),
599 )
600 }
601}
602
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal [`LanguageModel`] stub: only identity (name, id, provider) is
    /// meaningful; the model id equals its name.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Uses `name` for both the model name and id, and `provider` for both
        /// the provider id and provider name.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // "provider/name"; the assertions below identify models by this string.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        // Not exercised by these tests.
        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        // Not exercised by these tests.
        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, all using the same icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: ProviderIcon::Name(IconName::Ai),
            })
            .collect()
    }

    /// Asserts `result` matches `expected` ("provider/name" strings) in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }
}