1use std::{cmp::Reverse, sync::Arc};
2
3use collections::{HashSet, IndexMap};
4use feature_flags::ZedProFeatureFlag;
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{
7 Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task, actions,
8};
9use language_model::{
10 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
11 LanguageModelRegistry,
12};
13use ordered_float::OrderedFloat;
14use picker::{Picker, PickerDelegate};
15use proto::Plan;
16use ui::{ListItem, ListItemSpacing, prelude::*};
17
// Declares the `agent::ToggleModelSelector` action. The `deprecated_aliases`
// list keeps the older `assistant::` / `assistant2::` action names accepted,
// presumably so existing keymaps that reference them keep working.
actions!(
    agent,
    [
        /// Toggles the language model selector dropdown.
        #[action(deprecated_aliases = ["assistant::ToggleModelSelector", "assistant2::ToggleModelSelector"])]
        ToggleModelSelector
    ]
);
26
27const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";
28
29type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
30type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;
31
32pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
33
34pub fn language_model_selector(
35 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
36 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
37 window: &mut Window,
38 cx: &mut Context<LanguageModelSelector>,
39) -> LanguageModelSelector {
40 let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
41 Picker::list(delegate, window, cx)
42 .show_scrollbar(true)
43 .width(rems(20.))
44 .max_height(Some(rems(20.).into()))
45}
46
47fn all_models(cx: &App) -> GroupedModels {
48 let providers = LanguageModelRegistry::global(cx).read(cx).providers();
49
50 let recommended = providers
51 .iter()
52 .flat_map(|provider| {
53 provider
54 .recommended_models(cx)
55 .into_iter()
56 .map(|model| ModelInfo {
57 model,
58 icon: provider.icon(),
59 })
60 })
61 .collect();
62
63 let other = providers
64 .iter()
65 .flat_map(|provider| {
66 provider
67 .provided_models(cx)
68 .into_iter()
69 .map(|model| ModelInfo {
70 model,
71 icon: provider.icon(),
72 })
73 })
74 .collect();
75
76 GroupedModels::new(other, recommended)
77}
78
/// A language model paired with the icon shown next to it in the picker.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Icon of the provider the model came from (filled from `provider.icon()`
    /// in `all_models`).
    icon: IconName,
}
84
85pub struct LanguageModelPickerDelegate {
86 on_model_changed: OnModelChanged,
87 get_active_model: GetActiveModel,
88 all_models: Arc<GroupedModels>,
89 filtered_entries: Vec<LanguageModelPickerEntry>,
90 selected_index: usize,
91 _authenticate_all_providers_task: Task<()>,
92 _subscriptions: Vec<Subscription>,
93}
94
95impl LanguageModelPickerDelegate {
96 fn new(
97 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
98 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
99 window: &mut Window,
100 cx: &mut Context<Picker<Self>>,
101 ) -> Self {
102 let on_model_changed = Arc::new(on_model_changed);
103 let models = all_models(cx);
104 let entries = models.entries();
105
106 Self {
107 on_model_changed: on_model_changed.clone(),
108 all_models: Arc::new(models),
109 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
110 filtered_entries: entries,
111 get_active_model: Arc::new(get_active_model),
112 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
113 _subscriptions: vec![cx.subscribe_in(
114 &LanguageModelRegistry::global(cx),
115 window,
116 |picker, _, event, window, cx| {
117 match event {
118 language_model::Event::ProviderStateChanged
119 | language_model::Event::AddedProvider(_)
120 | language_model::Event::RemovedProvider(_) => {
121 let query = picker.query(cx);
122 picker.delegate.all_models = Arc::new(all_models(cx));
123 // Update matches will automatically drop the previous task
124 // if we get a provider event again
125 picker.update_matches(query, window, cx)
126 }
127 _ => {}
128 }
129 },
130 )],
131 }
132 }
133
134 fn get_active_model_index(
135 entries: &[LanguageModelPickerEntry],
136 active_model: Option<ConfiguredModel>,
137 ) -> usize {
138 entries
139 .iter()
140 .position(|entry| {
141 if let LanguageModelPickerEntry::Model(model) = entry {
142 active_model
143 .as_ref()
144 .map(|active_model| {
145 active_model.model.id() == model.model.id()
146 && active_model.provider.id() == model.model.provider_id()
147 })
148 .unwrap_or_default()
149 } else {
150 false
151 }
152 })
153 .unwrap_or(0)
154 }
155
156 /// Authenticates all providers in the [`LanguageModelRegistry`].
157 ///
158 /// We do this so that we can populate the language selector with all of the
159 /// models from the configured providers.
160 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
161 let authenticate_all_providers = LanguageModelRegistry::global(cx)
162 .read(cx)
163 .providers()
164 .iter()
165 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
166 .collect::<Vec<_>>();
167
168 cx.spawn(async move |_cx| {
169 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
170 if let Err(err) = authenticate_task.await {
171 if matches!(err, AuthenticateError::CredentialsNotFound) {
172 // Since we're authenticating these providers in the
173 // background for the purposes of populating the
174 // language selector, we don't care about providers
175 // where the credentials are not found.
176 } else {
177 // Some providers have noisy failure states that we
178 // don't want to spam the logs with every time the
179 // language model selector is initialized.
180 //
181 // Ideally these should have more clear failure modes
182 // that we know are safe to ignore here, like what we do
183 // with `CredentialsNotFound` above.
184 match provider_id.0.as_ref() {
185 "lmstudio" | "ollama" => {
186 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
187 //
188 // These fail noisily, so we don't log them.
189 }
190 "copilot_chat" => {
191 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
192 }
193 _ => {
194 log::error!(
195 "Failed to authenticate provider: {}: {err}",
196 provider_name.0
197 );
198 }
199 }
200 }
201 }
202 }
203 })
204 }
205
206 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
207 (self.get_active_model)(cx)
208 }
209}
210
211struct GroupedModels {
212 recommended: Vec<ModelInfo>,
213 other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
214}
215
216impl GroupedModels {
217 pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
218 let recommended_ids = recommended
219 .iter()
220 .map(|info| (info.model.provider_id(), info.model.id()))
221 .collect::<HashSet<_>>();
222
223 let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
224 for model in other {
225 if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
226 continue;
227 }
228
229 let provider = model.model.provider_id();
230 if let Some(models) = other_by_provider.get_mut(&provider) {
231 models.push(model);
232 } else {
233 other_by_provider.insert(provider, vec![model]);
234 }
235 }
236
237 Self {
238 recommended,
239 other: other_by_provider,
240 }
241 }
242
243 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
244 let mut entries = Vec::new();
245
246 if !self.recommended.is_empty() {
247 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
248 entries.extend(
249 self.recommended
250 .iter()
251 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
252 );
253 }
254
255 for models in self.other.values() {
256 if models.is_empty() {
257 continue;
258 }
259 entries.push(LanguageModelPickerEntry::Separator(
260 models[0].model.provider_name().0,
261 ));
262 entries.extend(
263 models
264 .iter()
265 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
266 );
267 }
268 entries
269 }
270
271 fn model_infos(&self) -> Vec<ModelInfo> {
272 let other = self
273 .other
274 .values()
275 .flat_map(|model| model.iter())
276 .cloned()
277 .collect::<Vec<_>>();
278 self.recommended
279 .iter()
280 .chain(&other)
281 .cloned()
282 .collect::<Vec<_>>()
283 }
284}
285
/// A row in the picker: either a model or a section header.
enum LanguageModelPickerEntry {
    /// A model row; the only kind `can_select` accepts.
    Model(ModelInfo),
    /// A section header whose text is "Recommended" or a provider name.
    Separator(SharedString),
}
290
291struct ModelMatcher {
292 models: Vec<ModelInfo>,
293 bg_executor: BackgroundExecutor,
294 candidates: Vec<StringMatchCandidate>,
295}
296
297impl ModelMatcher {
298 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
299 let candidates = Self::make_match_candidates(&models);
300 Self {
301 models,
302 bg_executor,
303 candidates,
304 }
305 }
306
307 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
308 let mut matches = self.bg_executor.block(match_strings(
309 &self.candidates,
310 &query,
311 false,
312 true,
313 100,
314 &Default::default(),
315 self.bg_executor.clone(),
316 ));
317
318 let sorting_key = |mat: &StringMatch| {
319 let candidate = &self.candidates[mat.candidate_id];
320 (Reverse(OrderedFloat(mat.score)), candidate.id)
321 };
322 matches.sort_unstable_by_key(sorting_key);
323
324 let matched_models: Vec<_> = matches
325 .into_iter()
326 .map(|mat| self.models[mat.candidate_id].clone())
327 .collect();
328
329 matched_models
330 }
331
332 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
333 self.models
334 .iter()
335 .filter(|m| {
336 m.model
337 .name()
338 .0
339 .to_lowercase()
340 .contains(&query.to_lowercase())
341 })
342 .cloned()
343 .collect::<Vec<_>>()
344 }
345
346 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
347 model_infos
348 .iter()
349 .enumerate()
350 .map(|(index, model)| {
351 StringMatchCandidate::new(
352 index,
353 &format!(
354 "{}/{}",
355 &model.model.provider_name().0,
356 &model.model.name().0
357 ),
358 )
359 })
360 .collect::<Vec<_>>()
361 }
362}
363
// `PickerDelegate` implementation: supplies the picker's rows, selection
// rules, query filtering, confirmation, and rendering.
impl PickerDelegate for LanguageModelPickerDelegate {
    type ListItem = AnyElement;

    /// Number of rows (separators and models) currently shown.
    fn match_count(&self) -> usize {
        self.filtered_entries.len()
    }

    fn selected_index(&self) -> usize {
        self.selected_index
    }

    /// Stores `ix` clamped to the last row (or `0` for an empty list) and
    /// requests a re-render.
    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
        cx.notify();
    }

    /// Only model rows are selectable; separators and out-of-range indices are not.
    fn can_select(
        &mut self,
        ix: usize,
        _window: &mut Window,
        _cx: &mut Context<Picker<Self>>,
    ) -> bool {
        match self.filtered_entries.get(ix) {
            Some(LanguageModelPickerEntry::Model(_)) => true,
            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
        }
    }

    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
        "Select a model…".into()
    }

    /// Rebuilds `filtered_entries` for `query`, using only models whose
    /// provider reports `is_authenticated`.
    ///
    /// Recommended models are filtered by case-insensitive substring match on
    /// the model name. All available models are fuzzy-matched on
    /// `"{provider}/{name}"`. `GroupedModels::new` then drops recommended hits
    /// from the per-provider groups. So a recommended model that fuzzy-matches
    /// but fails the substring check is shown under its provider instead.
    ///
    /// Both searches run synchronously on this thread; the fuzzy search blocks
    /// on the background executor. Only the final state update and
    /// re-selection of the active model happen in the returned task.
    fn update_matches(
        &mut self,
        query: String,
        window: &mut Window,
        cx: &mut Context<Picker<Self>>,
    ) -> Task<()> {
        let all_models = self.all_models.clone();
        let active_model = (self.get_active_model)(cx);
        let bg_executor = cx.background_executor();

        let language_model_registry = LanguageModelRegistry::global(cx);

        // Providers that currently report themselves as authenticated.
        let configured_providers = language_model_registry
            .read(cx)
            .providers()
            .into_iter()
            .filter(|provider| provider.is_authenticated(cx))
            .collect::<Vec<_>>();

        let configured_provider_ids = configured_providers
            .iter()
            .map(|provider| provider.id())
            .collect::<Vec<_>>();

        let recommended_models = all_models
            .recommended
            .iter()
            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models first, then every provider group.
        let available_models = all_models
            .model_infos()
            .iter()
            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
            .cloned()
            .collect::<Vec<_>>();

        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());

        let recommended = matcher_rec.exact_search(&query);
        let all = matcher_all.fuzzy_search(&query);

        let filtered_models = GroupedModels::new(all, recommended);

        cx.spawn_in(window, async move |this, cx| {
            this.update_in(cx, |this, window, cx| {
                this.delegate.filtered_entries = filtered_models.entries();
                // Finds the currently selected model in the list
                let new_index =
                    Self::get_active_model_index(&this.delegate.filtered_entries, active_model);
                // NOTE(review): `get_active_model_index` falls back to 0, which
                // is a separator row; presumably `Direction::Down` lets the
                // picker advance to the next selectable row — confirm against
                // `Picker::set_selected_index`.
                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
                cx.notify();
            })
            // The picker may already be gone; the update is then skipped.
            .ok();
        })
    }

    /// If the selected row is a model, passes it to `on_model_changed` and
    /// dismisses the picker. Does nothing when a separator is selected.
    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
        if let Some(LanguageModelPickerEntry::Model(model_info)) =
            self.filtered_entries.get(self.selected_index)
        {
            let model = model_info.model.clone();
            (self.on_model_changed)(model.clone(), cx);

            // Re-applies the current index, which clamps it and calls `cx.notify()`.
            let current_index = self.selected_index;
            self.set_selected_index(current_index, window, cx);

            cx.emit(DismissEvent);
        }
    }

    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
        cx.emit(DismissEvent);
    }

    /// Renders row `ix`.
    ///
    /// A separator renders as a muted label. Every separator except the
    /// leading one also gets a top border. `entries()` never emits two
    /// separators in a row, so `ix > 1` acts like `ix > 0`.
    ///
    /// A model row renders as its provider icon plus name. The active model
    /// gets an accent-colored icon and a trailing check mark.
    fn render_match(
        &self,
        ix: usize,
        selected: bool,
        _: &mut Window,
        cx: &mut Context<Picker<Self>>,
    ) -> Option<Self::ListItem> {
        match self.filtered_entries.get(ix)? {
            LanguageModelPickerEntry::Separator(title) => Some(
                div()
                    .px_2()
                    .pb_1()
                    .when(ix > 1, |this| {
                        this.mt_1()
                            .pt_2()
                            .border_t_1()
                            .border_color(cx.theme().colors().border_variant)
                    })
                    .child(
                        Label::new(title)
                            .size(LabelSize::XSmall)
                            .color(Color::Muted),
                    )
                    .into_any_element(),
            ),
            LanguageModelPickerEntry::Model(model_info) => {
                let active_model = (self.get_active_model)(cx);
                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
                let active_model_id = active_model.map(|m| m.model.id());

                // "Selected" here means "is the active model", which is distinct
                // from the keyboard-highlight `selected` parameter.
                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
                    && Some(model_info.model.id()) == active_model_id;

                let model_icon_color = if is_selected {
                    Color::Accent
                } else {
                    Color::Muted
                };

                Some(
                    ListItem::new(ix)
                        .inset(true)
                        .spacing(ListItemSpacing::Sparse)
                        .toggle_state(selected)
                        .start_slot(
                            Icon::new(model_info.icon)
                                .color(model_icon_color)
                                .size(IconSize::Small),
                        )
                        .child(
                            // NOTE(review): `.w_full()` is followed by `.w(px(240.))`;
                            // presumably the later fixed width wins — one of them
                            // looks redundant.
                            h_flex()
                                .w_full()
                                .pl_0p5()
                                .gap_1p5()
                                .w(px(240.))
                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
                        )
                        .end_slot(div().pr_3().when(is_selected, |this| {
                            this.child(
                                Icon::new(IconName::Check)
                                    .color(Color::Accent)
                                    .size(IconSize::Small),
                            )
                        }))
                        .into_any_element(),
                )
            }
        }
    }

    /// Renders the footer: a "Configure" button that dispatches
    /// `agent::OpenConfiguration`. When the `ZedProFeatureFlag` is enabled, a
    /// plan-dependent Zed Pro button is shown before it.
    fn render_footer(
        &self,
        _: &mut Window,
        cx: &mut Context<Picker<Self>>,
    ) -> Option<gpui::AnyElement> {
        use feature_flags::FeatureFlagAppExt;

        // NOTE(review): the plan is hard-coded, so only the `Plan::ZedPro` arm
        // below is reachable and the Free/Trial arms are currently dead.
        // Presumably this is a placeholder until the user's real plan is
        // available here.
        let plan = proto::Plan::ZedPro;

        Some(
            h_flex()
                .w_full()
                .border_t_1()
                .border_color(cx.theme().colors().border_variant)
                .p_1()
                .gap_4()
                .justify_between()
                .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
                    this.child(match plan {
                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
                            .icon(IconName::ZedAssistant)
                            .icon_size(IconSize::Small)
                            .icon_color(Color::Muted)
                            .icon_position(IconPosition::Start)
                            .on_click(|_, window, cx| {
                                window
                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
                            }),
                        Plan::Free | Plan::ZedProTrial => Button::new(
                            "try-pro",
                            if plan == Plan::ZedProTrial {
                                "Upgrade to Pro"
                            } else {
                                "Try Pro"
                            },
                        )
                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
                    })
                })
                .child(
                    Button::new("configure", "Configure")
                        .icon(IconName::Settings)
                        .icon_size(IconSize::Small)
                        .icon_color(Color::Muted)
                        .icon_position(IconPosition::Start)
                        .on_click(|_, window, cx| {
                            window.dispatch_action(
                                zed_actions::agent::OpenConfiguration.boxed_clone(),
                                cx,
                            );
                        }),
                )
                .into_any(),
        )
    }
}
599
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal [`LanguageModel`] stub. Only the identity methods are meaningful;
    /// token counting and streaming are never exercised by these tests.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Uses `name` as both the model name and id, and `provider` as both the
        /// provider id and name.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        /// `"{provider_id}/{name}"`, which is the identifier the assertions compare against.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Builds [`ModelInfo`]s from `(provider, model name)` pairs, preserving order.
    fn create_models(model_specs: &[(&str, &str)]) -> Vec<ModelInfo> {
        model_specs
            .iter()
            .map(|&(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Model list shared by the matcher tests. The order is deliberate: the
    /// tests rely on it to check that ties keep the original order.
    fn sample_models() -> Vec<ModelInfo> {
        create_models(&[
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ])
    }

    /// Asserts that `result` holds exactly the models named by `expected`
    /// (as `"{provider_id}/{name}"`), in the same order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: &[&str]) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, (info, expected_name)) in result.iter().zip(expected).enumerate() {
            assert_eq!(
                info.model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            &[
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let matcher = ModelMatcher::new(sample_models(), cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            &[
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, &["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, &["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended_models = create_models(&[("zed", "claude")]);
        let all_models = create_models(&[
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, &["zed/gemini", "copilot/o3"]);
    }

    #[gpui::test]
    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(&[("zed", "claude")]);
        let all_models = create_models(&[
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Should not be filtered out from "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Exclusion is keyed on (provider, model): a model from another provider
        // that merely shares a recommended model's name must stay in "other"
        assert_models_eq(actual_other_models, &["zed/gemini", "copilot/claude"]);
    }
}