1use std::{cmp::Reverse, sync::Arc};
2
3use collections::IndexMap;
4use futures::{StreamExt, channel::mpsc};
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
7use language_model::{
8 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
9 LanguageModelProviderId, LanguageModelRegistry,
10};
11use ordered_float::OrderedFloat;
12use picker::{Picker, PickerDelegate};
13use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
14use zed_actions::agent::OpenSettings;
15
16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
17type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
18
19pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
20
21pub fn language_model_selector(
22 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
23 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
24 popover_styles: bool,
25 focus_handle: FocusHandle,
26 window: &mut Window,
27 cx: &mut Context<LanguageModelSelector>,
28) -> LanguageModelSelector {
29 let delegate = LanguageModelPickerDelegate::new(
30 get_active_model,
31 on_model_changed,
32 popover_styles,
33 focus_handle,
34 window,
35 cx,
36 );
37
38 if popover_styles {
39 Picker::list(delegate, window, cx)
40 .show_scrollbar(true)
41 .width(rems(20.))
42 .max_height(Some(rems(20.).into()))
43 } else {
44 Picker::list(delegate, window, cx).show_scrollbar(true)
45 }
46}
47
48fn all_models(cx: &App) -> GroupedModels {
49 let providers = LanguageModelRegistry::global(cx)
50 .read(cx)
51 .visible_providers();
52
53 let recommended = providers
54 .iter()
55 .flat_map(|provider| {
56 provider
57 .recommended_models(cx)
58 .into_iter()
59 .map(|model| ModelInfo {
60 model,
61 icon: ProviderIcon::from_provider(provider.as_ref()),
62 })
63 })
64 .collect();
65
66 let all: Vec<ModelInfo> = providers
67 .iter()
68 .flat_map(|provider| {
69 provider
70 .provided_models(cx)
71 .into_iter()
72 .map(|model| ModelInfo {
73 model,
74 icon: ProviderIcon::from_provider(provider.as_ref()),
75 })
76 })
77 .collect();
78
79 GroupedModels::new(all, recommended)
80}
81
82#[derive(Clone)]
83enum ProviderIcon {
84 Name(IconName),
85 Path(SharedString),
86}
87
88impl ProviderIcon {
89 fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
90 if let Some(path) = provider.icon_path() {
91 Self::Path(path)
92 } else {
93 Self::Name(provider.icon())
94 }
95 }
96}
97
98#[derive(Clone)]
99struct ModelInfo {
100 model: Arc<dyn LanguageModel>,
101 icon: ProviderIcon,
102}
103
104pub struct LanguageModelPickerDelegate {
105 on_model_changed: OnModelChanged,
106 get_active_model: GetActiveModel,
107 all_models: Arc<GroupedModels>,
108 filtered_entries: Vec<LanguageModelPickerEntry>,
109 selected_index: usize,
110 _authenticate_all_providers_task: Task<()>,
111 _refresh_models_task: Task<()>,
112 popover_styles: bool,
113 focus_handle: FocusHandle,
114}
115
116impl LanguageModelPickerDelegate {
117 fn new(
118 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
119 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
120 popover_styles: bool,
121 focus_handle: FocusHandle,
122 window: &mut Window,
123 cx: &mut Context<Picker<Self>>,
124 ) -> Self {
125 let on_model_changed = Arc::new(on_model_changed);
126 let models = all_models(cx);
127 let entries = models.entries();
128
129 Self {
130 on_model_changed,
131 all_models: Arc::new(models),
132 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
133 filtered_entries: entries,
134 get_active_model: Arc::new(get_active_model),
135 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
136 _refresh_models_task: {
137 // Create a channel to signal when models need refreshing
138 let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
139
140 // Subscribe to registry events and send refresh signals through the channel
141 let registry = LanguageModelRegistry::global(cx);
142 cx.subscribe(®istry, move |_picker, _, event, _cx| match event {
143 language_model::Event::ProviderStateChanged(_) => {
144 refresh_tx.unbounded_send(()).ok();
145 }
146 language_model::Event::AddedProvider(_) => {
147 refresh_tx.unbounded_send(()).ok();
148 }
149 language_model::Event::RemovedProvider(_) => {
150 refresh_tx.unbounded_send(()).ok();
151 }
152 language_model::Event::ProvidersChanged => {
153 refresh_tx.unbounded_send(()).ok();
154 }
155 _ => {}
156 })
157 .detach();
158
159 // Spawn a task that listens for refresh signals and updates the picker
160 cx.spawn_in(window, async move |this, cx| {
161 while let Some(()) = refresh_rx.next().await {
162 let result = this.update_in(cx, |picker, window, cx| {
163 picker.delegate.all_models = Arc::new(all_models(cx));
164 picker.refresh(window, cx);
165 });
166 if result.is_err() {
167 // Picker was dropped, exit the loop
168 break;
169 }
170 }
171 })
172 },
173 popover_styles,
174 focus_handle,
175 }
176 }
177
178 fn get_active_model_index(
179 entries: &[LanguageModelPickerEntry],
180 active_model: Option<ConfiguredModel>,
181 ) -> usize {
182 entries
183 .iter()
184 .position(|entry| {
185 if let LanguageModelPickerEntry::Model(model) = entry {
186 active_model
187 .as_ref()
188 .map(|active_model| {
189 active_model.model.id() == model.model.id()
190 && active_model.provider.id() == model.model.provider_id()
191 })
192 .unwrap_or_default()
193 } else {
194 false
195 }
196 })
197 .unwrap_or(0)
198 }
199
200 /// Authenticates all providers in the [`LanguageModelRegistry`].
201 ///
202 /// We do this so that we can populate the language selector with all of the
203 /// models from the configured providers.
204 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
205 let authenticate_all_providers = LanguageModelRegistry::global(cx)
206 .read(cx)
207 .providers()
208 .iter()
209 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
210 .collect::<Vec<_>>();
211
212 cx.spawn(async move |_cx| {
213 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
214 if let Err(err) = authenticate_task.await {
215 if matches!(err, AuthenticateError::CredentialsNotFound) {
216 // Since we're authenticating these providers in the
217 // background for the purposes of populating the
218 // language selector, we don't care about providers
219 // where the credentials are not found.
220 } else {
221 // Some providers have noisy failure states that we
222 // don't want to spam the logs with every time the
223 // language model selector is initialized.
224 //
225 // Ideally these should have more clear failure modes
226 // that we know are safe to ignore here, like what we do
227 // with `CredentialsNotFound` above.
228 match provider_id.0.as_ref() {
229 "lmstudio" | "ollama" => {
230 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
231 //
232 // These fail noisily, so we don't log them.
233 }
234 "copilot_chat" => {
235 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
236 }
237 _ => {
238 log::error!(
239 "Failed to authenticate provider: {}: {err:#}",
240 provider_name.0
241 );
242 }
243 }
244 }
245 }
246 }
247 })
248 }
249
250 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
251 (self.get_active_model)(cx)
252 }
253}
254
255struct GroupedModels {
256 recommended: Vec<ModelInfo>,
257 all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
258}
259
260impl GroupedModels {
261 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
262 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
263 for model in all {
264 let provider = model.model.provider_id();
265 if let Some(models) = all_by_provider.get_mut(&provider) {
266 models.push(model);
267 } else {
268 all_by_provider.insert(provider, vec![model]);
269 }
270 }
271
272 Self {
273 recommended,
274 all: all_by_provider,
275 }
276 }
277
278 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
279 let mut entries = Vec::new();
280
281 if !self.recommended.is_empty() {
282 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
283 entries.extend(
284 self.recommended
285 .iter()
286 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
287 );
288 }
289
290 for models in self.all.values() {
291 if models.is_empty() {
292 continue;
293 }
294 entries.push(LanguageModelPickerEntry::Separator(
295 models[0].model.provider_name().0,
296 ));
297 entries.extend(
298 models
299 .iter()
300 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
301 );
302 }
303 entries
304 }
305}
306
307enum LanguageModelPickerEntry {
308 Model(ModelInfo),
309 Separator(SharedString),
310}
311
312struct ModelMatcher {
313 models: Vec<ModelInfo>,
314 bg_executor: BackgroundExecutor,
315 candidates: Vec<StringMatchCandidate>,
316}
317
318impl ModelMatcher {
319 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
320 let candidates = Self::make_match_candidates(&models);
321 Self {
322 models,
323 bg_executor,
324 candidates,
325 }
326 }
327
328 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
329 let mut matches = self.bg_executor.block(match_strings(
330 &self.candidates,
331 query,
332 false,
333 true,
334 100,
335 &Default::default(),
336 self.bg_executor.clone(),
337 ));
338
339 let sorting_key = |mat: &StringMatch| {
340 let candidate = &self.candidates[mat.candidate_id];
341 (Reverse(OrderedFloat(mat.score)), candidate.id)
342 };
343 matches.sort_unstable_by_key(sorting_key);
344
345 let matched_models: Vec<_> = matches
346 .into_iter()
347 .map(|mat| self.models[mat.candidate_id].clone())
348 .collect();
349
350 matched_models
351 }
352
353 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
354 self.models
355 .iter()
356 .filter(|m| {
357 m.model
358 .name()
359 .0
360 .to_lowercase()
361 .contains(&query.to_lowercase())
362 })
363 .cloned()
364 .collect::<Vec<_>>()
365 }
366
367 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
368 model_infos
369 .iter()
370 .enumerate()
371 .map(|(index, model)| {
372 StringMatchCandidate::new(
373 index,
374 &format!(
375 "{}/{}",
376 &model.model.provider_name().0,
377 &model.model.name().0
378 ),
379 )
380 })
381 .collect::<Vec<_>>()
382 }
383}
384
385impl PickerDelegate for LanguageModelPickerDelegate {
386 type ListItem = AnyElement;
387
388 fn match_count(&self) -> usize {
389 self.filtered_entries.len()
390 }
391
392 fn selected_index(&self) -> usize {
393 self.selected_index
394 }
395
396 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
397 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
398 cx.notify();
399 }
400
401 fn can_select(
402 &mut self,
403 ix: usize,
404 _window: &mut Window,
405 _cx: &mut Context<Picker<Self>>,
406 ) -> bool {
407 match self.filtered_entries.get(ix) {
408 Some(LanguageModelPickerEntry::Model(_)) => true,
409 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
410 }
411 }
412
413 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
414 "Select a model…".into()
415 }
416
417 fn update_matches(
418 &mut self,
419 query: String,
420 window: &mut Window,
421 cx: &mut Context<Picker<Self>>,
422 ) -> Task<()> {
423 let all_models = self.all_models.clone();
424 let active_model = (self.get_active_model)(cx);
425 let bg_executor = cx.background_executor();
426
427 let language_model_registry = LanguageModelRegistry::global(cx);
428
429 let configured_providers = language_model_registry
430 .read(cx)
431 .visible_providers()
432 .into_iter()
433 .filter(|provider| provider.is_authenticated(cx))
434 .collect::<Vec<_>>();
435
436 let configured_provider_ids = configured_providers
437 .iter()
438 .map(|provider| provider.id())
439 .collect::<Vec<_>>();
440
441 let recommended_models = all_models
442 .recommended
443 .iter()
444 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
445 .cloned()
446 .collect::<Vec<_>>();
447
448 let available_models = all_models
449 .all
450 .values()
451 .flat_map(|models| models.iter())
452 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
453 .cloned()
454 .collect::<Vec<_>>();
455
456 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
457 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
458
459 let recommended = matcher_rec.exact_search(&query);
460 let all = matcher_all.fuzzy_search(&query);
461
462 let filtered_models = GroupedModels::new(all, recommended);
463
464 cx.spawn_in(window, async move |this, cx| {
465 this.update_in(cx, |this, window, cx| {
466 this.delegate.filtered_entries = filtered_models.entries();
467 // Finds the currently selected model in the list
468 let new_index =
469 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
470 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
471 cx.notify();
472 })
473 .ok();
474 })
475 }
476
477 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
478 if let Some(LanguageModelPickerEntry::Model(model_info)) =
479 self.filtered_entries.get(self.selected_index)
480 {
481 let model = model_info.model.clone();
482 (self.on_model_changed)(model.clone(), cx);
483
484 let current_index = self.selected_index;
485 self.set_selected_index(current_index, window, cx);
486
487 cx.emit(DismissEvent);
488 }
489 }
490
491 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
492 cx.emit(DismissEvent);
493 }
494
495 fn render_match(
496 &self,
497 ix: usize,
498 selected: bool,
499 _: &mut Window,
500 cx: &mut Context<Picker<Self>>,
501 ) -> Option<Self::ListItem> {
502 match self.filtered_entries.get(ix)? {
503 LanguageModelPickerEntry::Separator(title) => Some(
504 div()
505 .px_2()
506 .pb_1()
507 .when(ix > 1, |this| {
508 this.mt_1()
509 .pt_2()
510 .border_t_1()
511 .border_color(cx.theme().colors().border_variant)
512 })
513 .child(
514 Label::new(title)
515 .size(LabelSize::XSmall)
516 .color(Color::Muted),
517 )
518 .into_any_element(),
519 ),
520 LanguageModelPickerEntry::Model(model_info) => {
521 let active_model = (self.get_active_model)(cx);
522 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
523 let active_model_id = active_model.map(|m| m.model.id());
524
525 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
526 && Some(model_info.model.id()) == active_model_id;
527
528 let model_icon_color = if is_selected {
529 Color::Accent
530 } else {
531 Color::Muted
532 };
533
534 Some(
535 ListItem::new(ix)
536 .inset(true)
537 .spacing(ListItemSpacing::Sparse)
538 .toggle_state(selected)
539 .child(
540 h_flex()
541 .w_full()
542 .gap_1p5()
543 .child(match &model_info.icon {
544 ProviderIcon::Name(icon_name) => Icon::new(*icon_name)
545 .color(model_icon_color)
546 .size(IconSize::Small),
547 ProviderIcon::Path(icon_path) => {
548 Icon::from_external_svg(icon_path.clone())
549 .color(model_icon_color)
550 .size(IconSize::Small)
551 }
552 })
553 .child(Label::new(model_info.model.name().0).truncate()),
554 )
555 .end_slot(div().pr_3().when(is_selected, |this| {
556 this.child(
557 Icon::new(IconName::Check)
558 .color(Color::Accent)
559 .size(IconSize::Small),
560 )
561 }))
562 .into_any_element(),
563 )
564 }
565 }
566 }
567
568 fn render_footer(
569 &self,
570 _window: &mut Window,
571 cx: &mut Context<Picker<Self>>,
572 ) -> Option<gpui::AnyElement> {
573 let focus_handle = self.focus_handle.clone();
574
575 if !self.popover_styles {
576 return None;
577 }
578
579 Some(
580 h_flex()
581 .w_full()
582 .p_1p5()
583 .border_t_1()
584 .border_color(cx.theme().colors().border_variant)
585 .child(
586 Button::new("configure", "Configure")
587 .full_width()
588 .style(ButtonStyle::Outlined)
589 .key_binding(
590 KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
591 .map(|kb| kb.size(rems_from_px(12.))),
592 )
593 .on_click(|_, window, cx| {
594 window.dispatch_action(OpenSettings.boxed_clone(), cx);
595 }),
596 )
597 .into_any(),
598 )
599 }
600}
601
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// A minimal [`LanguageModel`] whose telemetry id is `"<provider>/<name>"`, which
    /// lets tests identify models by that string.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Uses `name` as both the model name and id, and `provider` as both the
        /// provider id and name.
        fn new(name: &str, provider: &str) -> Self {
            let name = name.to_string();
            let provider = provider.to_string();
            Self {
                id: LanguageModelId::from(name.clone()),
                name: LanguageModelName::from(name),
                provider_id: LanguageModelProviderId::from(provider.clone()),
                provider_name: LanguageModelProviderName::from(provider),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds one `ModelInfo` per `(provider, name)` pair, all with the generic AI icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        let mut models = Vec::with_capacity(model_specs.len());
        for (provider, name) in model_specs {
            models.push(ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: ProviderIcon::Name(IconName::Ai),
            });
        }
        models
    }

    /// The model catalogue shared by the search tests.
    fn sample_models() -> Vec<ModelInfo> {
        create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ])
    }

    /// Flattens the per-provider buckets of `grouped`, preserving their order.
    fn flatten_all(grouped: &GroupedModels) -> Vec<ModelInfo> {
        grouped.all.values().flatten().cloned().collect()
    }

    /// Asserts that `result` lists exactly the `expected` telemetry ids, in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, (info, expected_name)) in result.iter().zip(expected.iter()).enumerate() {
            assert_eq!(
                info.model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        assert_models_eq(
            matcher.exact_search("GPT-4.1"),
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // `zed/gpt-4.1` and `openai/gpt-4.1` have identical similarity scores, but
        // `zed/gpt-4.1` comes first in the models list, so it should lead the results.
        assert_models_eq(
            matcher.fuzzy_search("41"),
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // The model provider is searchable as well ("ol" as in "ollama").
        assert_models_eq(
            matcher.fuzzy_search("ol"),
            vec!["ollama/mistral", "ollama/deepseek"],
        );

        // Genuinely fuzzy: non-contiguous characters.
        assert_models_eq(matcher.fuzzy_search("z4n"), vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // Recommended models should also appear in "all"
        assert_models_eq(
            flatten_all(&grouped_models),
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should also appear in "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Different provider, should appear in "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            flatten_all(&grouped_models),
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }
}