1use std::{cmp::Reverse, sync::Arc};
2
3use cloud_llm_client::Plan;
4use collections::{HashSet, IndexMap};
5use feature_flags::ZedProFeatureFlag;
6use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
7use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
8use language_model::{
9 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
10 LanguageModelRegistry,
11};
12use ordered_float::OrderedFloat;
13use picker::{Picker, PickerDelegate};
14use ui::{ListItem, ListItemSpacing, prelude::*};
15
/// Opened by the footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the model the user confirms in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback that reports the currently active model, if there is one.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// A list picker over the available language models.
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
22
23pub fn language_model_selector(
24 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
25 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
26 window: &mut Window,
27 cx: &mut Context<LanguageModelSelector>,
28) -> LanguageModelSelector {
29 let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
30 Picker::list(delegate, window, cx)
31 .show_scrollbar(true)
32 .width(rems(20.))
33 .max_height(Some(rems(20.).into()))
34}
35
36fn all_models(cx: &App) -> GroupedModels {
37 let providers = LanguageModelRegistry::global(cx).read(cx).providers();
38
39 let recommended = providers
40 .iter()
41 .flat_map(|provider| {
42 provider
43 .recommended_models(cx)
44 .into_iter()
45 .map(|model| ModelInfo {
46 model,
47 icon: provider.icon(),
48 })
49 })
50 .collect();
51
52 let other = providers
53 .iter()
54 .flat_map(|provider| {
55 provider
56 .provided_models(cx)
57 .into_iter()
58 .map(|model| ModelInfo {
59 model,
60 icon: provider.icon(),
61 })
62 })
63 .collect();
64
65 GroupedModels::new(other, recommended)
66}
67
/// A language model paired with the icon of the provider that exposes it.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Icon of the providing provider; rendered in the row's start slot.
    icon: IconName,
}
73
/// [`PickerDelegate`] backing the language model selector.
pub struct LanguageModelPickerDelegate {
    /// Invoked with the model the user confirms.
    on_model_changed: OnModelChanged,
    /// Reports the active model; used to preselect it and to draw its check mark.
    get_active_model: GetActiveModel,
    /// Every model known to the registry, grouped for display. Rebuilt whenever a
    /// provider is added, removed, or changes state.
    all_models: Arc<GroupedModels>,
    /// Rows currently displayed (separators and models) for the current query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Background authentication of all providers, stored on the delegate so it
    /// lives as long as the picker (presumably dropping it would cancel the work).
    _authenticate_all_providers_task: Task<()>,
    /// Keeps the registry-event subscription alive.
    _subscriptions: Vec<Subscription>,
}
83
84impl LanguageModelPickerDelegate {
85 fn new(
86 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
87 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
88 window: &mut Window,
89 cx: &mut Context<Picker<Self>>,
90 ) -> Self {
91 let on_model_changed = Arc::new(on_model_changed);
92 let models = all_models(cx);
93 let entries = models.entries();
94
95 Self {
96 on_model_changed,
97 all_models: Arc::new(models),
98 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
99 filtered_entries: entries,
100 get_active_model: Arc::new(get_active_model),
101 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
102 _subscriptions: vec![cx.subscribe_in(
103 &LanguageModelRegistry::global(cx),
104 window,
105 |picker, _, event, window, cx| {
106 match event {
107 language_model::Event::ProviderStateChanged(_)
108 | language_model::Event::AddedProvider(_)
109 | language_model::Event::RemovedProvider(_) => {
110 let query = picker.query(cx);
111 picker.delegate.all_models = Arc::new(all_models(cx));
112 // Update matches will automatically drop the previous task
113 // if we get a provider event again
114 picker.update_matches(query, window, cx)
115 }
116 _ => {}
117 }
118 },
119 )],
120 }
121 }
122
123 fn get_active_model_index(
124 entries: &[LanguageModelPickerEntry],
125 active_model: Option<ConfiguredModel>,
126 ) -> usize {
127 entries
128 .iter()
129 .position(|entry| {
130 if let LanguageModelPickerEntry::Model(model) = entry {
131 active_model
132 .as_ref()
133 .map(|active_model| {
134 active_model.model.id() == model.model.id()
135 && active_model.provider.id() == model.model.provider_id()
136 })
137 .unwrap_or_default()
138 } else {
139 false
140 }
141 })
142 .unwrap_or(0)
143 }
144
145 /// Authenticates all providers in the [`LanguageModelRegistry`].
146 ///
147 /// We do this so that we can populate the language selector with all of the
148 /// models from the configured providers.
149 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
150 let authenticate_all_providers = LanguageModelRegistry::global(cx)
151 .read(cx)
152 .providers()
153 .iter()
154 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
155 .collect::<Vec<_>>();
156
157 cx.spawn(async move |_cx| {
158 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
159 if let Err(err) = authenticate_task.await {
160 if matches!(err, AuthenticateError::CredentialsNotFound) {
161 // Since we're authenticating these providers in the
162 // background for the purposes of populating the
163 // language selector, we don't care about providers
164 // where the credentials are not found.
165 } else {
166 // Some providers have noisy failure states that we
167 // don't want to spam the logs with every time the
168 // language model selector is initialized.
169 //
170 // Ideally these should have more clear failure modes
171 // that we know are safe to ignore here, like what we do
172 // with `CredentialsNotFound` above.
173 match provider_id.0.as_ref() {
174 "lmstudio" | "ollama" => {
175 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
176 //
177 // These fail noisily, so we don't log them.
178 }
179 "copilot_chat" => {
180 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
181 }
182 _ => {
183 log::error!(
184 "Failed to authenticate provider: {}: {err}",
185 provider_name.0
186 );
187 }
188 }
189 }
190 }
191 }
192 })
193 }
194
195 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
196 (self.get_active_model)(cx)
197 }
198}
199
/// Models split into a "Recommended" group and per-provider groups.
struct GroupedModels {
    /// Recommended models, listed first under a "Recommended" header.
    recommended: Vec<ModelInfo>,
    /// Remaining models keyed by provider id, in the order providers were first
    /// seen; never contains a (provider, model) pair present in `recommended`.
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
204
205impl GroupedModels {
206 pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
207 let recommended_ids = recommended
208 .iter()
209 .map(|info| (info.model.provider_id(), info.model.id()))
210 .collect::<HashSet<_>>();
211
212 let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
213 for model in other {
214 if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
215 continue;
216 }
217
218 let provider = model.model.provider_id();
219 if let Some(models) = other_by_provider.get_mut(&provider) {
220 models.push(model);
221 } else {
222 other_by_provider.insert(provider, vec![model]);
223 }
224 }
225
226 Self {
227 recommended,
228 other: other_by_provider,
229 }
230 }
231
232 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
233 let mut entries = Vec::new();
234
235 if !self.recommended.is_empty() {
236 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
237 entries.extend(
238 self.recommended
239 .iter()
240 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
241 );
242 }
243
244 for models in self.other.values() {
245 if models.is_empty() {
246 continue;
247 }
248 entries.push(LanguageModelPickerEntry::Separator(
249 models[0].model.provider_name().0,
250 ));
251 entries.extend(
252 models
253 .iter()
254 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
255 );
256 }
257 entries
258 }
259
260 fn model_infos(&self) -> Vec<ModelInfo> {
261 let other = self
262 .other
263 .values()
264 .flat_map(|model| model.iter())
265 .cloned()
266 .collect::<Vec<_>>();
267 self.recommended
268 .iter()
269 .chain(&other)
270 .cloned()
271 .collect::<Vec<_>>()
272 }
273}
274
/// A single row in the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable group header displaying the given title.
    Separator(SharedString),
}
279
/// Searches a fixed list of models, either by case-insensitive substring of the
/// model name or fuzzily over `"provider/name"`.
struct ModelMatcher {
    /// The models being searched; results are clones of these.
    models: Vec<ModelInfo>,
    /// Executor used to run (and wait on) fuzzy matching.
    bg_executor: BackgroundExecutor,
    /// One fuzzy-match candidate per entry of `models`, whose id is that entry's index.
    candidates: Vec<StringMatchCandidate>,
}
285
286impl ModelMatcher {
287 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
288 let candidates = Self::make_match_candidates(&models);
289 Self {
290 models,
291 bg_executor,
292 candidates,
293 }
294 }
295
296 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
297 let mut matches = self.bg_executor.block(match_strings(
298 &self.candidates,
299 query,
300 false,
301 true,
302 100,
303 &Default::default(),
304 self.bg_executor.clone(),
305 ));
306
307 let sorting_key = |mat: &StringMatch| {
308 let candidate = &self.candidates[mat.candidate_id];
309 (Reverse(OrderedFloat(mat.score)), candidate.id)
310 };
311 matches.sort_unstable_by_key(sorting_key);
312
313 let matched_models: Vec<_> = matches
314 .into_iter()
315 .map(|mat| self.models[mat.candidate_id].clone())
316 .collect();
317
318 matched_models
319 }
320
321 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
322 self.models
323 .iter()
324 .filter(|m| {
325 m.model
326 .name()
327 .0
328 .to_lowercase()
329 .contains(&query.to_lowercase())
330 })
331 .cloned()
332 .collect::<Vec<_>>()
333 }
334
335 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
336 model_infos
337 .iter()
338 .enumerate()
339 .map(|(index, model)| {
340 StringMatchCandidate::new(
341 index,
342 &format!(
343 "{}/{}",
344 &model.model.provider_name().0,
345 &model.model.name().0
346 ),
347 )
348 })
349 .collect::<Vec<_>>()
350 }
351}
352
353impl PickerDelegate for LanguageModelPickerDelegate {
354 type ListItem = AnyElement;
355
356 fn match_count(&self) -> usize {
357 self.filtered_entries.len()
358 }
359
360 fn selected_index(&self) -> usize {
361 self.selected_index
362 }
363
364 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
365 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
366 cx.notify();
367 }
368
369 fn can_select(
370 &mut self,
371 ix: usize,
372 _window: &mut Window,
373 _cx: &mut Context<Picker<Self>>,
374 ) -> bool {
375 match self.filtered_entries.get(ix) {
376 Some(LanguageModelPickerEntry::Model(_)) => true,
377 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
378 }
379 }
380
381 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
382 "Select a model…".into()
383 }
384
385 fn update_matches(
386 &mut self,
387 query: String,
388 window: &mut Window,
389 cx: &mut Context<Picker<Self>>,
390 ) -> Task<()> {
391 let all_models = self.all_models.clone();
392 let active_model = (self.get_active_model)(cx);
393 let bg_executor = cx.background_executor();
394
395 let language_model_registry = LanguageModelRegistry::global(cx);
396
397 let configured_providers = language_model_registry
398 .read(cx)
399 .providers()
400 .into_iter()
401 .filter(|provider| provider.is_authenticated(cx))
402 .collect::<Vec<_>>();
403
404 let configured_provider_ids = configured_providers
405 .iter()
406 .map(|provider| provider.id())
407 .collect::<Vec<_>>();
408
409 let recommended_models = all_models
410 .recommended
411 .iter()
412 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
413 .cloned()
414 .collect::<Vec<_>>();
415
416 let available_models = all_models
417 .model_infos()
418 .iter()
419 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
420 .cloned()
421 .collect::<Vec<_>>();
422
423 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
424 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
425
426 let recommended = matcher_rec.exact_search(&query);
427 let all = matcher_all.fuzzy_search(&query);
428
429 let filtered_models = GroupedModels::new(all, recommended);
430
431 cx.spawn_in(window, async move |this, cx| {
432 this.update_in(cx, |this, window, cx| {
433 this.delegate.filtered_entries = filtered_models.entries();
434 // Finds the currently selected model in the list
435 let new_index =
436 Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
437 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
438 cx.notify();
439 })
440 .ok();
441 })
442 }
443
444 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
445 if let Some(LanguageModelPickerEntry::Model(model_info)) =
446 self.filtered_entries.get(self.selected_index)
447 {
448 let model = model_info.model.clone();
449 (self.on_model_changed)(model.clone(), cx);
450
451 let current_index = self.selected_index;
452 self.set_selected_index(current_index, window, cx);
453
454 cx.emit(DismissEvent);
455 }
456 }
457
458 fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
459 cx.emit(DismissEvent);
460 }
461
462 fn render_match(
463 &self,
464 ix: usize,
465 selected: bool,
466 _: &mut Window,
467 cx: &mut Context<Picker<Self>>,
468 ) -> Option<Self::ListItem> {
469 match self.filtered_entries.get(ix)? {
470 LanguageModelPickerEntry::Separator(title) => Some(
471 div()
472 .px_2()
473 .pb_1()
474 .when(ix > 1, |this| {
475 this.mt_1()
476 .pt_2()
477 .border_t_1()
478 .border_color(cx.theme().colors().border_variant)
479 })
480 .child(
481 Label::new(title)
482 .size(LabelSize::XSmall)
483 .color(Color::Muted),
484 )
485 .into_any_element(),
486 ),
487 LanguageModelPickerEntry::Model(model_info) => {
488 let active_model = (self.get_active_model)(cx);
489 let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
490 let active_model_id = active_model.map(|m| m.model.id());
491
492 let is_selected = Some(model_info.model.provider_id()) == active_provider_id
493 && Some(model_info.model.id()) == active_model_id;
494
495 let model_icon_color = if is_selected {
496 Color::Accent
497 } else {
498 Color::Muted
499 };
500
501 Some(
502 ListItem::new(ix)
503 .inset(true)
504 .spacing(ListItemSpacing::Sparse)
505 .toggle_state(selected)
506 .start_slot(
507 Icon::new(model_info.icon)
508 .color(model_icon_color)
509 .size(IconSize::Small),
510 )
511 .child(
512 h_flex()
513 .w_full()
514 .pl_0p5()
515 .gap_1p5()
516 .w(px(240.))
517 .child(Label::new(model_info.model.name().0).truncate()),
518 )
519 .end_slot(div().pr_3().when(is_selected, |this| {
520 this.child(
521 Icon::new(IconName::Check)
522 .color(Color::Accent)
523 .size(IconSize::Small),
524 )
525 }))
526 .into_any_element(),
527 )
528 }
529 }
530 }
531
532 fn render_footer(
533 &self,
534 _: &mut Window,
535 cx: &mut Context<Picker<Self>>,
536 ) -> Option<gpui::AnyElement> {
537 use feature_flags::FeatureFlagAppExt;
538
539 let plan = Plan::ZedPro;
540
541 Some(
542 h_flex()
543 .w_full()
544 .border_t_1()
545 .border_color(cx.theme().colors().border_variant)
546 .p_1()
547 .gap_4()
548 .justify_between()
549 .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
550 this.child(match plan {
551 Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
552 .icon(IconName::ZedAssistant)
553 .icon_size(IconSize::Small)
554 .icon_color(Color::Muted)
555 .icon_position(IconPosition::Start)
556 .on_click(|_, window, cx| {
557 window
558 .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
559 }),
560 Plan::ZedFree | Plan::ZedProTrial => Button::new(
561 "try-pro",
562 if plan == Plan::ZedProTrial {
563 "Upgrade to Pro"
564 } else {
565 "Try Pro"
566 },
567 )
568 .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
569 })
570 })
571 .child(
572 Button::new("configure", "Configure")
573 .icon(IconName::Settings)
574 .icon_size(IconSize::Small)
575 .icon_color(Color::Muted)
576 .icon_position(IconPosition::Start)
577 .on_click(|_, window, cx| {
578 window.dispatch_action(
579 zed_actions::agent::OpenSettings.boxed_clone(),
580 cx,
581 );
582 }),
583 )
584 .into_any(),
585 )
586 }
587}
588
#[cfg(test)]
mod tests {
    // Tests for `ModelMatcher` search ordering and for how `GroupedModels`
    // keeps recommended models out of the per-provider groups.
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal `LanguageModel` stub: only the identity methods return real data.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Creates a model whose name and id are both `name`, and whose provider
        /// id and provider name are both `provider`.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        /// Formats as `"<provider_id>/<name>"`; `assert_models_eq` compares against this.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        // Not exercised by these tests.
        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        // Not exercised by these tests.
        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds `ModelInfo`s from `(provider, name)` pairs, all tagged with `IconName::Ai`.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts that `result` holds exactly the models named in `expected`, in
    /// order, where each expected entry is a `"provider/name"` telemetry id.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
    }

    #[gpui::test]
    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Should not be filtered out from "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
    }
}