1use std::{cmp::Reverse, sync::Arc};
2
3use collections::IndexMap;
4use futures::{StreamExt, channel::mpsc};
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
7use language_model::{
8 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
9 LanguageModelProviderId, LanguageModelRegistry,
10};
11use ordered_float::OrderedFloat;
12use picker::{Picker, PickerDelegate};
13use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
14use zed_actions::agent::OpenSettings;
15
/// Callback invoked with the model the user confirms in the selector.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback reporting the currently configured model, if any; used to
/// preselect and highlight the active row.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// A [`Picker`] driven by [`LanguageModelPickerDelegate`].
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
20
21pub fn language_model_selector(
22 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
23 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
24 popover_styles: bool,
25 focus_handle: FocusHandle,
26 window: &mut Window,
27 cx: &mut Context<LanguageModelSelector>,
28) -> LanguageModelSelector {
29 let delegate = LanguageModelPickerDelegate::new(
30 get_active_model,
31 on_model_changed,
32 popover_styles,
33 focus_handle,
34 window,
35 cx,
36 );
37
38 if popover_styles {
39 Picker::list(delegate, window, cx)
40 .show_scrollbar(true)
41 .width(rems(20.))
42 .max_height(Some(rems(20.).into()))
43 } else {
44 Picker::list(delegate, window, cx).show_scrollbar(true)
45 }
46}
47
48fn all_models(cx: &App) -> GroupedModels {
49 let providers = LanguageModelRegistry::global(cx).read(cx).providers();
50
51 let recommended = providers
52 .iter()
53 .flat_map(|provider| {
54 provider
55 .recommended_models(cx)
56 .into_iter()
57 .map(|model| ModelInfo {
58 model,
59 icon: ProviderIcon::from_provider(provider.as_ref()),
60 })
61 })
62 .collect();
63
64 let all: Vec<ModelInfo> = providers
65 .iter()
66 .flat_map(|provider| {
67 provider
68 .provided_models(cx)
69 .into_iter()
70 .map(|model| ModelInfo {
71 model,
72 icon: ProviderIcon::from_provider(provider.as_ref()),
73 })
74 })
75 .collect();
76
77 GroupedModels::new(all, recommended)
78}
79
/// Icon displayed next to a model, derived from the provider that supplies it.
#[derive(Clone)]
enum ProviderIcon {
    /// A named icon, rendered with `Icon::new`.
    Name(IconName),
    /// A path to an external SVG, rendered with `Icon::from_external_svg`.
    Path(SharedString),
}
85
86impl ProviderIcon {
87 fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
88 if let Some(path) = provider.icon_path() {
89 Self::Path(path)
90 } else {
91 Self::Name(provider.icon())
92 }
93 }
94}
95
/// A model together with the icon of the provider it came from.
#[derive(Clone)]
struct ModelInfo {
    model: Arc<dyn LanguageModel>,
    icon: ProviderIcon,
}
101
/// [`PickerDelegate`] listing language models in a "Recommended" section
/// followed by one section per provider.
pub struct LanguageModelPickerDelegate {
    /// Called with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Reports the currently configured model.
    get_active_model: GetActiveModel,
    /// Every known model; rebuilt when the registry reports provider changes.
    all_models: Arc<GroupedModels>,
    /// Rows (headers and models) currently shown after filtering by the query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Background authentication of all providers; stored so the task lives
    /// as long as the delegate.
    _authenticate_all_providers_task: Task<()>,
    /// Loop that rebuilds `all_models` and refreshes the picker on registry events.
    _refresh_models_task: Task<()>,
    /// Whether the picker is shown as a popover (enables the "Configure" footer).
    popover_styles: bool,
    /// Used to resolve the `OpenSettings` key binding shown in the footer.
    focus_handle: FocusHandle,
}
113
114impl LanguageModelPickerDelegate {
115 fn new(
116 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
117 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
118 popover_styles: bool,
119 focus_handle: FocusHandle,
120 window: &mut Window,
121 cx: &mut Context<Picker<Self>>,
122 ) -> Self {
123 let on_model_changed = Arc::new(on_model_changed);
124 let models = all_models(cx);
125 let entries = models.entries();
126
127 Self {
128 on_model_changed,
129 all_models: Arc::new(models),
130 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
131 filtered_entries: entries,
132 get_active_model: Arc::new(get_active_model),
133 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
134 _refresh_models_task: {
135 // Create a channel to signal when models need refreshing
136 let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
137
138 // Subscribe to registry events and send refresh signals through the channel
139 let registry = LanguageModelRegistry::global(cx);
140 cx.subscribe(®istry, move |_picker, _, event, _cx| match event {
141 language_model::Event::ProviderStateChanged(_) => {
142 refresh_tx.unbounded_send(()).ok();
143 }
144 language_model::Event::AddedProvider(_) => {
145 refresh_tx.unbounded_send(()).ok();
146 }
147 language_model::Event::RemovedProvider(_) => {
148 refresh_tx.unbounded_send(()).ok();
149 }
150 _ => {}
151 })
152 .detach();
153
154 // Spawn a task that listens for refresh signals and updates the picker
155 cx.spawn_in(window, async move |this, cx| {
156 while let Some(()) = refresh_rx.next().await {
157 let result = this.update_in(cx, |picker, window, cx| {
158 picker.delegate.all_models = Arc::new(all_models(cx));
159 picker.refresh(window, cx);
160 });
161 if result.is_err() {
162 // Picker was dropped, exit the loop
163 break;
164 }
165 }
166 })
167 },
168 popover_styles,
169 focus_handle,
170 }
171 }
172
173 fn get_active_model_index(
174 entries: &[LanguageModelPickerEntry],
175 active_model: Option<ConfiguredModel>,
176 ) -> usize {
177 entries
178 .iter()
179 .position(|entry| {
180 if let LanguageModelPickerEntry::Model(model) = entry {
181 active_model
182 .as_ref()
183 .map(|active_model| {
184 active_model.model.id() == model.model.id()
185 && active_model.provider.id() == model.model.provider_id()
186 })
187 .unwrap_or_default()
188 } else {
189 false
190 }
191 })
192 .unwrap_or(0)
193 }
194
195 /// Authenticates all providers in the [`LanguageModelRegistry`].
196 ///
197 /// We do this so that we can populate the language selector with all of the
198 /// models from the configured providers.
199 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
200 let authenticate_all_providers = LanguageModelRegistry::global(cx)
201 .read(cx)
202 .providers()
203 .iter()
204 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
205 .collect::<Vec<_>>();
206
207 cx.spawn(async move |_cx| {
208 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
209 if let Err(err) = authenticate_task.await {
210 if matches!(err, AuthenticateError::CredentialsNotFound) {
211 // Since we're authenticating these providers in the
212 // background for the purposes of populating the
213 // language selector, we don't care about providers
214 // where the credentials are not found.
215 } else {
216 // Some providers have noisy failure states that we
217 // don't want to spam the logs with every time the
218 // language model selector is initialized.
219 //
220 // Ideally these should have more clear failure modes
221 // that we know are safe to ignore here, like what we do
222 // with `CredentialsNotFound` above.
223 match provider_id.0.as_ref() {
224 "lmstudio" | "ollama" => {
225 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
226 //
227 // These fail noisily, so we don't log them.
228 }
229 "copilot_chat" => {
230 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
231 }
232 _ => {
233 log::error!(
234 "Failed to authenticate provider: {}: {err:#}",
235 provider_name.0
236 );
237 }
238 }
239 }
240 }
241 }
242 })
243 }
244
245 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
246 (self.get_active_model)(cx)
247 }
248}
249
/// Models split into a "Recommended" list and per-provider groups.
struct GroupedModels {
    /// Models shown under the "Recommended" header, in the order supplied.
    recommended: Vec<ModelInfo>,
    /// All models keyed by provider id, in the order each provider was first
    /// seen and with each provider's models in their original order.
    all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
254
255impl GroupedModels {
256 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
257 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
258 for model in all {
259 let provider = model.model.provider_id();
260 if let Some(models) = all_by_provider.get_mut(&provider) {
261 models.push(model);
262 } else {
263 all_by_provider.insert(provider, vec![model]);
264 }
265 }
266
267 Self {
268 recommended,
269 all: all_by_provider,
270 }
271 }
272
273 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
274 let mut entries = Vec::new();
275
276 if !self.recommended.is_empty() {
277 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
278 entries.extend(
279 self.recommended
280 .iter()
281 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
282 );
283 }
284
285 for models in self.all.values() {
286 if models.is_empty() {
287 continue;
288 }
289 entries.push(LanguageModelPickerEntry::Separator(
290 models[0].model.provider_name().0,
291 ));
292 entries.extend(
293 models
294 .iter()
295 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
296 );
297 }
298 entries
299 }
300}
301
/// A picker row: either a selectable model or a non-selectable section header.
enum LanguageModelPickerEntry {
    /// A model row.
    Model(ModelInfo),
    /// A section header ("Recommended" or a provider name).
    Separator(SharedString),
}
306
/// Searches a fixed list of models, either by case-insensitive substring of
/// the model name or fuzzily against "provider/name".
struct ModelMatcher {
    /// Models being searched; fuzzy candidate ids are indices into this list.
    models: Vec<ModelInfo>,
    /// Executor passed to the fuzzy matcher and used to block on its result.
    bg_executor: BackgroundExecutor,
    /// Precomputed fuzzy-match candidates, one per entry in `models`.
    candidates: Vec<StringMatchCandidate>,
}
312
313impl ModelMatcher {
314 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
315 let candidates = Self::make_match_candidates(&models);
316 Self {
317 models,
318 bg_executor,
319 candidates,
320 }
321 }
322
323 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
324 let mut matches = self.bg_executor.block(match_strings(
325 &self.candidates,
326 query,
327 false,
328 true,
329 100,
330 &Default::default(),
331 self.bg_executor.clone(),
332 ));
333
334 let sorting_key = |mat: &StringMatch| {
335 let candidate = &self.candidates[mat.candidate_id];
336 (Reverse(OrderedFloat(mat.score)), candidate.id)
337 };
338 matches.sort_unstable_by_key(sorting_key);
339
340 let matched_models: Vec<_> = matches
341 .into_iter()
342 .map(|mat| self.models[mat.candidate_id].clone())
343 .collect();
344
345 matched_models
346 }
347
348 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
349 self.models
350 .iter()
351 .filter(|m| {
352 m.model
353 .name()
354 .0
355 .to_lowercase()
356 .contains(&query.to_lowercase())
357 })
358 .cloned()
359 .collect::<Vec<_>>()
360 }
361
362 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
363 model_infos
364 .iter()
365 .enumerate()
366 .map(|(index, model)| {
367 StringMatchCandidate::new(
368 index,
369 &format!(
370 "{}/{}",
371 &model.model.provider_name().0,
372 &model.model.name().0
373 ),
374 )
375 })
376 .collect::<Vec<_>>()
377 }
378}
379
380impl PickerDelegate for LanguageModelPickerDelegate {
381 type ListItem = AnyElement;
382
383 fn match_count(&self) -> usize {
384 self.filtered_entries.len()
385 }
386
387 fn selected_index(&self) -> usize {
388 self.selected_index
389 }
390
391 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
392 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
393 cx.notify();
394 }
395
396 fn can_select(
397 &mut self,
398 ix: usize,
399 _window: &mut Window,
400 _cx: &mut Context<Picker<Self>>,
401 ) -> bool {
402 match self.filtered_entries.get(ix) {
403 Some(LanguageModelPickerEntry::Model(_)) => true,
404 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
405 }
406 }
407
408 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
409 "Select a model…".into()
410 }
411
412 fn update_matches(
413 &mut self,
414 query: String,
415 window: &mut Window,
416 cx: &mut Context<Picker<Self>>,
417 ) -> Task<()> {
418 let all_models = self.all_models.clone();
419 let active_model = (self.get_active_model)(cx);
420 let bg_executor = cx.background_executor();
421
422 let language_model_registry = LanguageModelRegistry::global(cx);
423
424 let configured_providers = language_model_registry
425 .read(cx)
426 .providers()
427 .into_iter()
428 .filter(|provider| provider.is_authenticated(cx))
429 .collect::<Vec<_>>();
430
431 let configured_provider_ids = configured_providers
432 .iter()
433 .map(|provider| provider.id())
434 .collect::<Vec<_>>();
435
436 let recommended_models = all_models
437 .recommended
438 .iter()
439 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
440 .cloned()
441 .collect::<Vec<_>>();
442
443 let available_models = all_models
444 .all
445 .values()
446 .flat_map(|models| models.iter())
447 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
448 .cloned()
449 .collect::<Vec<_>>();
450
451 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
452 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
453
454 let recommended = matcher_rec.exact_search(&query);
455 let all = matcher_all.fuzzy_search(&query);
456
457 let filtered_models = GroupedModels::new(all, recommended);
458
459 cx.spawn_in(window, async move |this, cx| {
460 this.update_in(cx, |this, window, cx| {
461 this.delegate.filtered_entries = filtered_models.entries();
462 // Finds the currently selected model in the list
463 let new_index =
464 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
465 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
466 cx.notify();
467 })
468 .ok();
469 })
470 }
471
472 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
473 if let Some(LanguageModelPickerEntry::Model(model_info)) =
474 self.filtered_entries.get(self.selected_index)
475 {
476 let model = model_info.model.clone();
477 (self.on_model_changed)(model.clone(), cx);
478
479 let current_index = self.selected_index;
480 self.set_selected_index(current_index, window, cx);
481
482 cx.emit(DismissEvent);
483 }
484 }
485
486 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
487 cx.emit(DismissEvent);
488 }
489
490 fn render_match(
491 &self,
492 ix: usize,
493 selected: bool,
494 _: &mut Window,
495 cx: &mut Context<Picker<Self>>,
496 ) -> Option<Self::ListItem> {
497 match self.filtered_entries.get(ix)? {
498 LanguageModelPickerEntry::Separator(title) => Some(
499 div()
500 .px_2()
501 .pb_1()
502 .when(ix > 1, |this| {
503 this.mt_1()
504 .pt_2()
505 .border_t_1()
506 .border_color(cx.theme().colors().border_variant)
507 })
508 .child(
509 Label::new(title)
510 .size(LabelSize::XSmall)
511 .color(Color::Muted),
512 )
513 .into_any_element(),
514 ),
515 LanguageModelPickerEntry::Model(model_info) => {
516 let active_model = (self.get_active_model)(cx);
517 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
518 let active_model_id = active_model.map(|m| m.model.id());
519
520 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
521 && Some(model_info.model.id()) == active_model_id;
522
523 let model_icon_color = if is_selected {
524 Color::Accent
525 } else {
526 Color::Muted
527 };
528
529 Some(
530 ListItem::new(ix)
531 .inset(true)
532 .spacing(ListItemSpacing::Sparse)
533 .toggle_state(selected)
534 .child(
535 h_flex()
536 .w_full()
537 .gap_1p5()
538 .child(match &model_info.icon {
539 ProviderIcon::Name(icon_name) => Icon::new(*icon_name)
540 .color(model_icon_color)
541 .size(IconSize::Small),
542 ProviderIcon::Path(icon_path) => {
543 Icon::from_external_svg(icon_path.clone())
544 .color(model_icon_color)
545 .size(IconSize::Small)
546 }
547 })
548 .child(Label::new(model_info.model.name().0).truncate()),
549 )
550 .end_slot(div().pr_3().when(is_selected, |this| {
551 this.child(
552 Icon::new(IconName::Check)
553 .color(Color::Accent)
554 .size(IconSize::Small),
555 )
556 }))
557 .into_any_element(),
558 )
559 }
560 }
561 }
562
563 fn render_footer(
564 &self,
565 _window: &mut Window,
566 cx: &mut Context<Picker<Self>>,
567 ) -> Option<gpui::AnyElement> {
568 let focus_handle = self.focus_handle.clone();
569
570 if !self.popover_styles {
571 return None;
572 }
573
574 Some(
575 h_flex()
576 .w_full()
577 .p_1p5()
578 .border_t_1()
579 .border_color(cx.theme().colors().border_variant)
580 .child(
581 Button::new("configure", "Configure")
582 .full_width()
583 .style(ButtonStyle::Outlined)
584 .key_binding(
585 KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
586 .map(|kb| kb.size(rems_from_px(12.))),
587 )
588 .on_click(|_, window, cx| {
589 window.dispatch_action(OpenSettings.boxed_clone(), cx);
590 }),
591 )
592 .into_any(),
593 )
594 }
595}
596
// Unit tests for `ModelMatcher` search ordering and `GroupedModels` grouping,
// built on a minimal in-memory `LanguageModel` stub.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// A stub model that only carries identifying metadata.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Creates a model whose id equals `name`, and whose provider id and
        /// provider name both equal `provider`.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        /// Formats as "provider_id/name"; the assertions identify models by this.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        // Not exercised by these tests.
        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        // Not exercised by these tests.
        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, all using the `Ai` icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: ProviderIcon::Name(IconName::Ai),
            })
            .collect()
    }

    /// Asserts that `result` contains exactly the models in `expected`, in
    /// order, compared by their `telemetry_id` ("provider/name").
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Also recommended; should still appear in `all`
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should also appear in "all"
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Also recommended; should still appear in `all`
            ("zed", "gemini"),
            ("copilot", "claude"), // Same name, different provider; appears in `all` separately
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_all_models = grouped_models
            .all
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            actual_all_models,
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }
}