1use std::{cmp::Reverse, sync::Arc};
2
3use collections::IndexMap;
4use futures::{StreamExt, channel::mpsc};
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
7use language_model::{
8 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
9 LanguageModelProviderId, LanguageModelRegistry,
10};
11use ordered_float::OrderedFloat;
12use picker::{Picker, PickerDelegate};
13use ui::{KeyBinding, ListItem, ListItemSpacing, prelude::*};
14use zed_actions::agent::OpenSettings;
15
16type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
17type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
18
19pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
20
21pub fn language_model_selector(
22 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
23 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
24 popover_styles: bool,
25 focus_handle: FocusHandle,
26 window: &mut Window,
27 cx: &mut Context<LanguageModelSelector>,
28) -> LanguageModelSelector {
29 let delegate = LanguageModelPickerDelegate::new(
30 get_active_model,
31 on_model_changed,
32 popover_styles,
33 focus_handle,
34 window,
35 cx,
36 );
37
38 if popover_styles {
39 Picker::list(delegate, window, cx)
40 .show_scrollbar(true)
41 .width(rems(20.))
42 .max_height(Some(rems(20.).into()))
43 } else {
44 Picker::list(delegate, window, cx).show_scrollbar(true)
45 }
46}
47
48fn all_models(cx: &App) -> GroupedModels {
49 let providers = LanguageModelRegistry::global(cx)
50 .read(cx)
51 .visible_providers();
52
53 let recommended = providers
54 .iter()
55 .flat_map(|provider| {
56 provider
57 .recommended_models(cx)
58 .into_iter()
59 .map(|model| ModelInfo {
60 model,
61 icon: ProviderIcon::from_provider(provider.as_ref()),
62 })
63 })
64 .collect();
65
66 let all: Vec<ModelInfo> = providers
67 .iter()
68 .flat_map(|provider| {
69 provider
70 .provided_models(cx)
71 .into_iter()
72 .map(|model| ModelInfo {
73 model,
74 icon: ProviderIcon::from_provider(provider.as_ref()),
75 })
76 })
77 .collect();
78
79 GroupedModels::new(all, recommended)
80}
81
82#[derive(Clone)]
83enum ProviderIcon {
84 Name(IconName),
85 Path(SharedString),
86}
87
88impl ProviderIcon {
89 fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
90 if let Some(path) = provider.icon_path() {
91 Self::Path(path)
92 } else {
93 Self::Name(provider.icon())
94 }
95 }
96}
97
98#[derive(Clone)]
99struct ModelInfo {
100 model: Arc<dyn LanguageModel>,
101 icon: ProviderIcon,
102}
103
104pub struct LanguageModelPickerDelegate {
105 on_model_changed: OnModelChanged,
106 get_active_model: GetActiveModel,
107 all_models: Arc<GroupedModels>,
108 filtered_entries: Vec<LanguageModelPickerEntry>,
109 selected_index: usize,
110 _authenticate_all_providers_task: Task<()>,
111 _refresh_models_task: Task<()>,
112 popover_styles: bool,
113 focus_handle: FocusHandle,
114}
115
116impl LanguageModelPickerDelegate {
117 fn new(
118 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
119 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
120 popover_styles: bool,
121 focus_handle: FocusHandle,
122 window: &mut Window,
123 cx: &mut Context<Picker<Self>>,
124 ) -> Self {
125 let on_model_changed = Arc::new(on_model_changed);
126 let models = all_models(cx);
127 let entries = models.entries();
128
129 Self {
130 on_model_changed,
131 all_models: Arc::new(models),
132 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
133 filtered_entries: entries,
134 get_active_model: Arc::new(get_active_model),
135 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
136 _refresh_models_task: {
137 // Create a channel to signal when models need refreshing
138 let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
139
140 // Subscribe to registry events and send refresh signals through the channel
141 let registry = LanguageModelRegistry::global(cx);
142 cx.subscribe(®istry, move |_picker, _, event, _cx| match event {
143 language_model::Event::ProviderStateChanged(_)
144 | language_model::Event::AddedProvider(_)
145 | language_model::Event::RemovedProvider(_)
146 | language_model::Event::ProvidersChanged => {
147 refresh_tx.unbounded_send(()).ok();
148 }
149 language_model::Event::DefaultModelChanged
150 | language_model::Event::InlineAssistantModelChanged
151 | language_model::Event::CommitMessageModelChanged
152 | language_model::Event::ThreadSummaryModelChanged => {}
153 })
154 .detach();
155
156 // Spawn a task that listens for refresh signals and updates the picker
157 cx.spawn_in(window, async move |this, cx| {
158 while let Some(()) = refresh_rx.next().await {
159 if this
160 .update_in(cx, |picker, window, cx| {
161 picker.delegate.all_models = Arc::new(all_models(cx));
162 picker.refresh(window, cx);
163 })
164 .is_err()
165 {
166 // Picker was dropped, exit the loop
167 break;
168 }
169 }
170 })
171 },
172 popover_styles,
173 focus_handle,
174 }
175 }
176
177 fn get_active_model_index(
178 entries: &[LanguageModelPickerEntry],
179 active_model: Option<ConfiguredModel>,
180 ) -> usize {
181 entries
182 .iter()
183 .position(|entry| {
184 if let LanguageModelPickerEntry::Model(model) = entry {
185 active_model
186 .as_ref()
187 .map(|active_model| {
188 active_model.model.id() == model.model.id()
189 && active_model.provider.id() == model.model.provider_id()
190 })
191 .unwrap_or_default()
192 } else {
193 false
194 }
195 })
196 .unwrap_or(0)
197 }
198
199 /// Authenticates all providers in the [`LanguageModelRegistry`].
200 ///
201 /// We do this so that we can populate the language selector with all of the
202 /// models from the configured providers.
203 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
204 let authenticate_all_providers = LanguageModelRegistry::global(cx)
205 .read(cx)
206 .providers()
207 .iter()
208 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
209 .collect::<Vec<_>>();
210
211 cx.spawn(async move |_cx| {
212 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
213 if let Err(err) = authenticate_task.await {
214 if matches!(err, AuthenticateError::CredentialsNotFound) {
215 // Since we're authenticating these providers in the
216 // background for the purposes of populating the
217 // language selector, we don't care about providers
218 // where the credentials are not found.
219 } else {
220 // Some providers have noisy failure states that we
221 // don't want to spam the logs with every time the
222 // language model selector is initialized.
223 //
224 // Ideally these should have more clear failure modes
225 // that we know are safe to ignore here, like what we do
226 // with `CredentialsNotFound` above.
227 match provider_id.0.as_ref() {
228 "lmstudio" | "ollama" => {
229 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
230 //
231 // These fail noisily, so we don't log them.
232 }
233 "copilot_chat" => {
234 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
235 }
236 _ => {
237 log::error!(
238 "Failed to authenticate provider: {}: {err:#}",
239 provider_name.0
240 );
241 }
242 }
243 }
244 }
245 }
246 })
247 }
248
249 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
250 (self.get_active_model)(cx)
251 }
252}
253
254struct GroupedModels {
255 recommended: Vec<ModelInfo>,
256 all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
257}
258
259impl GroupedModels {
260 pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
261 let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
262 for model in all {
263 let provider = model.model.provider_id();
264 if let Some(models) = all_by_provider.get_mut(&provider) {
265 models.push(model);
266 } else {
267 all_by_provider.insert(provider, vec![model]);
268 }
269 }
270
271 Self {
272 recommended,
273 all: all_by_provider,
274 }
275 }
276
277 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
278 let mut entries = Vec::new();
279
280 if !self.recommended.is_empty() {
281 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
282 entries.extend(
283 self.recommended
284 .iter()
285 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
286 );
287 }
288
289 for models in self.all.values() {
290 if models.is_empty() {
291 continue;
292 }
293 entries.push(LanguageModelPickerEntry::Separator(
294 models[0].model.provider_name().0,
295 ));
296 entries.extend(
297 models
298 .iter()
299 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
300 );
301 }
302 entries
303 }
304}
305
306enum LanguageModelPickerEntry {
307 Model(ModelInfo),
308 Separator(SharedString),
309}
310
311struct ModelMatcher {
312 models: Vec<ModelInfo>,
313 bg_executor: BackgroundExecutor,
314 candidates: Vec<StringMatchCandidate>,
315}
316
317impl ModelMatcher {
318 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
319 let candidates = Self::make_match_candidates(&models);
320 Self {
321 models,
322 bg_executor,
323 candidates,
324 }
325 }
326
327 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
328 let mut matches = self.bg_executor.block(match_strings(
329 &self.candidates,
330 query,
331 false,
332 true,
333 100,
334 &Default::default(),
335 self.bg_executor.clone(),
336 ));
337
338 let sorting_key = |mat: &StringMatch| {
339 let candidate = &self.candidates[mat.candidate_id];
340 (Reverse(OrderedFloat(mat.score)), candidate.id)
341 };
342 matches.sort_unstable_by_key(sorting_key);
343
344 let matched_models: Vec<_> = matches
345 .into_iter()
346 .map(|mat| self.models[mat.candidate_id].clone())
347 .collect();
348
349 matched_models
350 }
351
352 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
353 self.models
354 .iter()
355 .filter(|m| {
356 m.model
357 .name()
358 .0
359 .to_lowercase()
360 .contains(&query.to_lowercase())
361 })
362 .cloned()
363 .collect::<Vec<_>>()
364 }
365
366 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
367 model_infos
368 .iter()
369 .enumerate()
370 .map(|(index, model)| {
371 StringMatchCandidate::new(
372 index,
373 &format!(
374 "{}/{}",
375 &model.model.provider_name().0,
376 &model.model.name().0
377 ),
378 )
379 })
380 .collect::<Vec<_>>()
381 }
382}
383
384impl PickerDelegate for LanguageModelPickerDelegate {
385 type ListItem = AnyElement;
386
387 fn match_count(&self) -> usize {
388 self.filtered_entries.len()
389 }
390
391 fn selected_index(&self) -> usize {
392 self.selected_index
393 }
394
395 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
396 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
397 cx.notify();
398 }
399
400 fn can_select(
401 &mut self,
402 ix: usize,
403 _window: &mut Window,
404 _cx: &mut Context<Picker<Self>>,
405 ) -> bool {
406 match self.filtered_entries.get(ix) {
407 Some(LanguageModelPickerEntry::Model(_)) => true,
408 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
409 }
410 }
411
412 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
413 "Select a model…".into()
414 }
415
416 fn update_matches(
417 &mut self,
418 query: String,
419 window: &mut Window,
420 cx: &mut Context<Picker<Self>>,
421 ) -> Task<()> {
422 let all_models = self.all_models.clone();
423 let active_model = (self.get_active_model)(cx);
424 let bg_executor = cx.background_executor();
425
426 let language_model_registry = LanguageModelRegistry::global(cx);
427
428 let configured_providers = language_model_registry
429 .read(cx)
430 .visible_providers()
431 .into_iter()
432 .filter(|provider| provider.is_authenticated(cx))
433 .collect::<Vec<_>>();
434
435 let configured_provider_ids = configured_providers
436 .iter()
437 .map(|provider| provider.id())
438 .collect::<Vec<_>>();
439
440 let recommended_models = all_models
441 .recommended
442 .iter()
443 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
444 .cloned()
445 .collect::<Vec<_>>();
446
447 let available_models = all_models
448 .all
449 .values()
450 .flat_map(|models| models.iter())
451 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
452 .cloned()
453 .collect::<Vec<_>>();
454
455 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
456 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
457
458 let recommended = matcher_rec.exact_search(&query);
459 let all = matcher_all.fuzzy_search(&query);
460
461 let filtered_models = GroupedModels::new(all, recommended);
462
463 cx.spawn_in(window, async move |this, cx| {
464 this.update_in(cx, |this, window, cx| {
465 this.delegate.filtered_entries = filtered_models.entries();
466 // Finds the currently selected model in the list
467 let new_index =
468 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
469 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
470 cx.notify();
471 })
472 .ok();
473 })
474 }
475
476 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
477 if let Some(LanguageModelPickerEntry::Model(model_info)) =
478 self.filtered_entries.get(self.selected_index)
479 {
480 let model = model_info.model.clone();
481 (self.on_model_changed)(model.clone(), cx);
482
483 let current_index = self.selected_index;
484 self.set_selected_index(current_index, window, cx);
485
486 cx.emit(DismissEvent);
487 }
488 }
489
490 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
491 cx.emit(DismissEvent);
492 }
493
494 fn render_match(
495 &self,
496 ix: usize,
497 selected: bool,
498 _: &mut Window,
499 cx: &mut Context<Picker<Self>>,
500 ) -> Option<Self::ListItem> {
501 match self.filtered_entries.get(ix)? {
502 LanguageModelPickerEntry::Separator(title) => Some(
503 div()
504 .px_2()
505 .pb_1()
506 .when(ix > 1, |this| {
507 this.mt_1()
508 .pt_2()
509 .border_t_1()
510 .border_color(cx.theme().colors().border_variant)
511 })
512 .child(
513 Label::new(title)
514 .size(LabelSize::XSmall)
515 .color(Color::Muted),
516 )
517 .into_any_element(),
518 ),
519 LanguageModelPickerEntry::Model(model_info) => {
520 let active_model = (self.get_active_model)(cx);
521 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
522 let active_model_id = active_model.map(|m| m.model.id());
523
524 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
525 && Some(model_info.model.id()) == active_model_id;
526
527 let model_icon_color = if is_selected {
528 Color::Accent
529 } else {
530 Color::Muted
531 };
532
533 Some(
534 ListItem::new(ix)
535 .inset(true)
536 .spacing(ListItemSpacing::Sparse)
537 .toggle_state(selected)
538 .child(
539 h_flex()
540 .w_full()
541 .gap_1p5()
542 .child(match &model_info.icon {
543 ProviderIcon::Name(icon_name) => Icon::new(*icon_name)
544 .color(model_icon_color)
545 .size(IconSize::Small),
546 ProviderIcon::Path(icon_path) => {
547 Icon::from_external_svg(icon_path.clone())
548 .color(model_icon_color)
549 .size(IconSize::Small)
550 }
551 })
552 .child(Label::new(model_info.model.name().0).truncate()),
553 )
554 .end_slot(div().pr_3().when(is_selected, |this| {
555 this.child(
556 Icon::new(IconName::Check)
557 .color(Color::Accent)
558 .size(IconSize::Small),
559 )
560 }))
561 .into_any_element(),
562 )
563 }
564 }
565 }
566
567 fn render_footer(
568 &self,
569 _window: &mut Window,
570 cx: &mut Context<Picker<Self>>,
571 ) -> Option<gpui::AnyElement> {
572 let focus_handle = self.focus_handle.clone();
573
574 if !self.popover_styles {
575 return None;
576 }
577
578 Some(
579 h_flex()
580 .w_full()
581 .p_1p5()
582 .border_t_1()
583 .border_color(cx.theme().colors().border_variant)
584 .child(
585 Button::new("configure", "Configure")
586 .full_width()
587 .style(ButtonStyle::Outlined)
588 .key_binding(
589 KeyBinding::for_action_in(&OpenSettings, &focus_handle, cx)
590 .map(|kb| kb.size(rems_from_px(12.))),
591 )
592 .on_click(|_, window, cx| {
593 window.dispatch_action(OpenSettings.boxed_clone(), cx);
594 }),
595 )
596 .into_any(),
597 )
598 }
599}
600
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Metadata-only [`LanguageModel`]; completion methods are never exercised here.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Uses `name` as both the model name and id, and `provider` as both the provider
        /// id and name.
        fn new(name: &str, provider: &str) -> Self {
            let name = name.to_string();
            let provider = provider.to_string();
            Self {
                name: LanguageModelName::from(name.clone()),
                id: LanguageModelId::from(name),
                provider_id: LanguageModelProviderId::from(provider.clone()),
                provider_name: LanguageModelProviderName::from(provider),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, all with the generic AI icon.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        let mut infos = Vec::with_capacity(model_specs.len());
        for (provider, name) in model_specs {
            infos.push(ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: ProviderIcon::Name(IconName::Ai),
            });
        }
        infos
    }

    /// The model list shared by the matcher tests.
    fn matcher_fixture() -> Vec<ModelInfo> {
        create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ])
    }

    /// Flattens the per-provider groups of `grouped` back into one list.
    fn flatten_all(grouped: &GroupedModels) -> Vec<ModelInfo> {
        grouped.all.values().flatten().cloned().collect()
    }

    /// Asserts that `result` has exactly the `provider/name` telemetry ids in `expected`,
    /// in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, (info, expected_name)) in result.iter().zip(expected.iter()).enumerate() {
            assert_eq!(
                info.model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(matcher_fixture(), cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        assert_models_eq(
            matcher.exact_search("GPT-4.1"),
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(matcher_fixture(), cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        assert_models_eq(
            matcher.fuzzy_search("41"),
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well ("ol" as in "ollama").
        assert_models_eq(
            matcher.fuzzy_search("ol"),
            vec!["ollama/mistral", "ollama/deepseek"],
        );

        // Fuzzy search
        assert_models_eq(matcher.fuzzy_search("z4n"), vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
        let grouped_models = GroupedModels::new(
            create_models(vec![
                ("zed", "claude"), // Should also appear in "other"
                ("zed", "gemini"),
                ("copilot", "o3"),
            ]),
            create_models(vec![("zed", "claude")]),
        );

        // Recommended models should also appear in "all"
        assert_models_eq(
            flatten_all(&grouped_models),
            vec!["zed/claude", "zed/gemini", "copilot/o3"],
        );
    }

    #[gpui::test]
    fn test_models_from_different_providers(_cx: &mut TestAppContext) {
        let grouped_models = GroupedModels::new(
            create_models(vec![
                ("zed", "claude"), // Should also appear in "other"
                ("zed", "gemini"),
                ("copilot", "claude"), // Different provider, should appear in "other"
            ]),
            create_models(vec![("zed", "claude")]),
        );

        // All models should appear in "all" regardless of recommended status
        assert_models_eq(
            flatten_all(&grouped_models),
            vec!["zed/claude", "zed/gemini", "copilot/claude"],
        );
    }
}