1use std::{cmp::Reverse, sync::Arc};
2
3use collections::{HashSet, IndexMap};
4use feature_flags::ZedProFeatureFlag;
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{
7 Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task,
8 action_with_deprecated_aliases,
9};
10use language_model::{
11 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
12 LanguageModelRegistry,
13};
14use ordered_float::OrderedFloat;
15use picker::{Picker, PickerDelegate};
16use proto::Plan;
17use ui::{ListItem, ListItemSpacing, prelude::*};
18
// Defines the `agent::ToggleModelSelector` action and registers the old
// `assistant::ToggleModelSelector` and `assistant2::ToggleModelSelector` names as
// deprecated aliases of it (presumably so keymaps using the old names keep resolving).
action_with_deprecated_aliases!(
    agent,
    ToggleModelSelector,
    [
        "assistant::ToggleModelSelector",
        "assistant2::ToggleModelSelector"
    ]
);

/// Page opened by the footer's "Try Pro" / "Upgrade to Pro" button.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the model the user confirmed in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback returning the currently active model, if any.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// The language model selector is a plain [`Picker`] driven by
/// [`LanguageModelPickerDelegate`].
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
34
35pub fn language_model_selector(
36 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
37 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
38 window: &mut Window,
39 cx: &mut Context<LanguageModelSelector>,
40) -> LanguageModelSelector {
41 let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
42 Picker::list(delegate, window, cx)
43 .show_scrollbar(true)
44 .width(rems(20.))
45 .max_height(Some(rems(20.).into()))
46}
47
48fn all_models(cx: &App) -> GroupedModels {
49 let providers = LanguageModelRegistry::global(cx).read(cx).providers();
50
51 let recommended = providers
52 .iter()
53 .flat_map(|provider| {
54 provider
55 .recommended_models(cx)
56 .into_iter()
57 .map(|model| ModelInfo {
58 model,
59 icon: provider.icon(),
60 })
61 })
62 .collect();
63
64 let other = providers
65 .iter()
66 .flat_map(|provider| {
67 provider
68 .provided_models(cx)
69 .into_iter()
70 .map(|model| ModelInfo {
71 model,
72 icon: provider.icon(),
73 })
74 })
75 .collect();
76
77 GroupedModels::new(other, recommended)
78}
79
/// A language model paired with the icon of the provider it was collected from.
#[derive(Clone)]
struct ModelInfo {
    /// The model itself.
    model: Arc<dyn LanguageModel>,
    /// Provider icon (taken from `provider.icon()` when the model list is built), rendered
    /// at the start of the model's row.
    icon: IconName,
}
85
/// [`PickerDelegate`] that backs the language model selector.
pub struct LanguageModelPickerDelegate {
    /// Called with the chosen model when the user confirms a model row.
    on_model_changed: OnModelChanged,
    /// Returns the currently active model. It determines the initial selection and which
    /// row is marked as active.
    get_active_model: GetActiveModel,
    /// Every model from every registered provider. Rebuilt when the registry reports a
    /// provider being added, removed, or changing state.
    all_models: Arc<GroupedModels>,
    /// Rows (section separators and models) currently shown for the active query.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted row.
    selected_index: usize,
    /// Background provider authentication started in `new`. It is stored only to keep it
    /// alive for the delegate's lifetime; presumably dropping it would cancel it.
    _authenticate_all_providers_task: Task<()>,
    /// Registry subscription(s), kept alive for the delegate's lifetime.
    _subscriptions: Vec<Subscription>,
}
95
96impl LanguageModelPickerDelegate {
97 fn new(
98 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
99 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
100 window: &mut Window,
101 cx: &mut Context<Picker<Self>>,
102 ) -> Self {
103 let on_model_changed = Arc::new(on_model_changed);
104 let models = all_models(cx);
105 let entries = models.entries();
106
107 Self {
108 on_model_changed: on_model_changed.clone(),
109 all_models: Arc::new(models),
110 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
111 filtered_entries: entries,
112 get_active_model: Arc::new(get_active_model),
113 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
114 _subscriptions: vec![cx.subscribe_in(
115 &LanguageModelRegistry::global(cx),
116 window,
117 |picker, _, event, window, cx| {
118 match event {
119 language_model::Event::ProviderStateChanged
120 | language_model::Event::AddedProvider(_)
121 | language_model::Event::RemovedProvider(_) => {
122 let query = picker.query(cx);
123 picker.delegate.all_models = Arc::new(all_models(cx));
124 // Update matches will automatically drop the previous task
125 // if we get a provider event again
126 picker.update_matches(query, window, cx)
127 }
128 _ => {}
129 }
130 },
131 )],
132 }
133 }
134
135 fn get_active_model_index(
136 entries: &[LanguageModelPickerEntry],
137 active_model: Option<ConfiguredModel>,
138 ) -> usize {
139 entries
140 .iter()
141 .position(|entry| {
142 if let LanguageModelPickerEntry::Model(model) = entry {
143 active_model
144 .as_ref()
145 .map(|active_model| {
146 active_model.model.id() == model.model.id()
147 && active_model.provider.id() == model.model.provider_id()
148 })
149 .unwrap_or_default()
150 } else {
151 false
152 }
153 })
154 .unwrap_or(0)
155 }
156
157 /// Authenticates all providers in the [`LanguageModelRegistry`].
158 ///
159 /// We do this so that we can populate the language selector with all of the
160 /// models from the configured providers.
161 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
162 let authenticate_all_providers = LanguageModelRegistry::global(cx)
163 .read(cx)
164 .providers()
165 .iter()
166 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
167 .collect::<Vec<_>>();
168
169 cx.spawn(async move |_cx| {
170 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
171 if let Err(err) = authenticate_task.await {
172 if matches!(err, AuthenticateError::CredentialsNotFound) {
173 // Since we're authenticating these providers in the
174 // background for the purposes of populating the
175 // language selector, we don't care about providers
176 // where the credentials are not found.
177 } else {
178 // Some providers have noisy failure states that we
179 // don't want to spam the logs with every time the
180 // language model selector is initialized.
181 //
182 // Ideally these should have more clear failure modes
183 // that we know are safe to ignore here, like what we do
184 // with `CredentialsNotFound` above.
185 match provider_id.0.as_ref() {
186 "lmstudio" | "ollama" => {
187 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
188 //
189 // These fail noisily, so we don't log them.
190 }
191 "copilot_chat" => {
192 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
193 }
194 _ => {
195 log::error!(
196 "Failed to authenticate provider: {}: {err}",
197 provider_name.0
198 );
199 }
200 }
201 }
202 }
203 }
204 })
205 }
206
207 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
208 (self.get_active_model)(cx)
209 }
210}
211
212struct GroupedModels {
213 recommended: Vec<ModelInfo>,
214 other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
215}
216
217impl GroupedModels {
218 pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
219 let recommended_ids = recommended
220 .iter()
221 .map(|info| (info.model.provider_id(), info.model.id()))
222 .collect::<HashSet<_>>();
223
224 let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
225 for model in other {
226 if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
227 continue;
228 }
229
230 let provider = model.model.provider_id();
231 if let Some(models) = other_by_provider.get_mut(&provider) {
232 models.push(model);
233 } else {
234 other_by_provider.insert(provider, vec![model]);
235 }
236 }
237
238 Self {
239 recommended,
240 other: other_by_provider,
241 }
242 }
243
244 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
245 let mut entries = Vec::new();
246
247 if !self.recommended.is_empty() {
248 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
249 entries.extend(
250 self.recommended
251 .iter()
252 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
253 );
254 }
255
256 for models in self.other.values() {
257 if models.is_empty() {
258 continue;
259 }
260 entries.push(LanguageModelPickerEntry::Separator(
261 models[0].model.provider_name().0,
262 ));
263 entries.extend(
264 models
265 .iter()
266 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
267 );
268 }
269 entries
270 }
271
272 fn model_infos(&self) -> Vec<ModelInfo> {
273 let other = self
274 .other
275 .values()
276 .flat_map(|model| model.iter())
277 .cloned()
278 .collect::<Vec<_>>();
279 self.recommended
280 .iter()
281 .chain(&other)
282 .cloned()
283 .collect::<Vec<_>>()
284 }
285}
286
/// One row of the picker list.
enum LanguageModelPickerEntry {
    /// A selectable model row.
    Model(ModelInfo),
    /// A non-selectable section header ("Recommended" or a provider's display name).
    Separator(SharedString),
}
291
292struct ModelMatcher {
293 models: Vec<ModelInfo>,
294 bg_executor: BackgroundExecutor,
295 candidates: Vec<StringMatchCandidate>,
296}
297
298impl ModelMatcher {
299 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
300 let candidates = Self::make_match_candidates(&models);
301 Self {
302 models,
303 bg_executor,
304 candidates,
305 }
306 }
307
308 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
309 let mut matches = self.bg_executor.block(match_strings(
310 &self.candidates,
311 &query,
312 false,
313 true,
314 100,
315 &Default::default(),
316 self.bg_executor.clone(),
317 ));
318
319 let sorting_key = |mat: &StringMatch| {
320 let candidate = &self.candidates[mat.candidate_id];
321 (Reverse(OrderedFloat(mat.score)), candidate.id)
322 };
323 matches.sort_unstable_by_key(sorting_key);
324
325 let matched_models: Vec<_> = matches
326 .into_iter()
327 .map(|mat| self.models[mat.candidate_id].clone())
328 .collect();
329
330 matched_models
331 }
332
333 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
334 self.models
335 .iter()
336 .filter(|m| {
337 m.model
338 .name()
339 .0
340 .to_lowercase()
341 .contains(&query.to_lowercase())
342 })
343 .cloned()
344 .collect::<Vec<_>>()
345 }
346
347 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
348 model_infos
349 .iter()
350 .enumerate()
351 .map(|(index, model)| {
352 StringMatchCandidate::new(
353 index,
354 &format!(
355 "{}/{}",
356 &model.model.provider_name().0,
357 &model.model.name().0
358 ),
359 )
360 })
361 .collect::<Vec<_>>()
362 }
363}
364
impl PickerDelegate for LanguageModelPickerDelegate {
    type ListItem = AnyElement;

    /// Number of rows currently shown, counting both separators and models.
    fn match_count(&self) -> usize {
        self.filtered_entries.len()
    }

    fn selected_index(&self) -> usize {
        self.selected_index
    }

    /// Stores the highlighted row and notifies.
    ///
    /// The index is clamped to the last entry, or to 0 when the list is empty.
    fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
        self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
        cx.notify();
    }

    /// Only model rows are selectable; separators and out-of-range indices are not.
    fn can_select(
        &mut self,
        ix: usize,
        _window: &mut Window,
        _cx: &mut Context<Picker<Self>>,
    ) -> bool {
        match self.filtered_entries.get(ix) {
            Some(LanguageModelPickerEntry::Model(_)) => true,
            Some(LanguageModelPickerEntry::Separator(_)) | None => false,
        }
    }

    fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
        "Select a model…".into()
    }

    /// Recomputes `filtered_entries` for `query`.
    ///
    /// Only models from providers reporting `is_authenticated` are considered:
    /// - recommended models are matched with a case-insensitive substring search on the
    ///   model name;
    /// - all models are fuzzy-matched on "provider/name".
    ///
    /// The results are regrouped with `GroupedModels::new`, which drops the recommended
    /// hits from the per-provider sections. Matching happens synchronously in this call
    /// (`ModelMatcher::fuzzy_search` blocks on the background executor). The returned task
    /// only applies the results to the picker.
    fn update_matches(
        &mut self,
        query: String,
        window: &mut Window,
        cx: &mut Context<Picker<Self>>,
    ) -> Task<()> {
        let all_models = self.all_models.clone();
        let current_index = self.selected_index;
        let bg_executor = cx.background_executor();

        let language_model_registry = LanguageModelRegistry::global(cx);

        let configured_providers = language_model_registry
            .read(cx)
            .providers()
            .into_iter()
            .filter(|provider| provider.is_authenticated(cx))
            .collect::<Vec<_>>();

        let configured_provider_ids = configured_providers
            .iter()
            .map(|provider| provider.id())
            .collect::<Vec<_>>();

        let recommended_models = all_models
            .recommended
            .iter()
            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
            .cloned()
            .collect::<Vec<_>>();

        let available_models = all_models
            .model_infos()
            .iter()
            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
            .cloned()
            .collect::<Vec<_>>();

        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());

        let recommended = matcher_rec.exact_search(&query);
        let all = matcher_all.fuzzy_search(&query);

        let filtered_models = GroupedModels::new(all, recommended);

        cx.spawn_in(window, async move |this, cx| {
            this.update_in(cx, |this, window, cx| {
                this.delegate.filtered_entries = filtered_models.entries();
                // Preserve selection focus: keep the previous index while it is still in
                // range, otherwise fall back to the top of the list.
                let new_index = if current_index >= this.delegate.filtered_entries.len() {
                    0
                } else {
                    current_index
                };
                // NOTE(review): `Direction::Down` presumably lets the picker step off a
                // non-selectable separator row downward — confirm in `Picker`.
                this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
                cx.notify();
            })
            .ok();
        })
    }

    /// Passes the highlighted model to `on_model_changed` and dismisses the picker.
    ///
    /// Does nothing when the highlighted row is not a model.
    fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
        if let Some(LanguageModelPickerEntry::Model(model_info)) =
            self.filtered_entries.get(self.selected_index)
        {
            let model = model_info.model.clone();
            (self.on_model_changed)(model.clone(), cx);

            // Re-applies the current (already in-range) index; in effect this only
            // triggers `cx.notify()`.
            let current_index = self.selected_index;
            self.set_selected_index(current_index, window, cx);

            cx.emit(DismissEvent);
        }
    }

    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
        cx.emit(DismissEvent);
    }

    /// Renders either a section header or a model row.
    ///
    /// A model row shows the provider icon and the model name. The row of the currently
    /// active model gets an accent-colored icon and a trailing check mark.
    fn render_match(
        &self,
        ix: usize,
        selected: bool,
        _: &mut Window,
        cx: &mut Context<Picker<Self>>,
    ) -> Option<Self::ListItem> {
        match self.filtered_entries.get(ix)? {
            LanguageModelPickerEntry::Separator(title) => Some(
                div()
                    .px_2()
                    .pb_1()
                    // Headers below the top of the list get a divider line above them.
                    .when(ix > 1, |this| {
                        this.mt_1()
                            .pt_2()
                            .border_t_1()
                            .border_color(cx.theme().colors().border_variant)
                    })
                    .child(
                        Label::new(title)
                            .size(LabelSize::XSmall)
                            .color(Color::Muted),
                    )
                    .into_any_element(),
            ),
            LanguageModelPickerEntry::Model(model_info) => {
                let active_model = (self.get_active_model)(cx);
                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
                let active_model_id = active_model.map(|m| m.model.id());

                // True when this row is the *active* model. This is distinct from
                // `selected`, which is the keyboard highlight.
                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
                    && Some(model_info.model.id()) == active_model_id;

                let model_icon_color = if is_selected {
                    Color::Accent
                } else {
                    Color::Muted
                };

                Some(
                    ListItem::new(ix)
                        .inset(true)
                        .spacing(ListItemSpacing::Sparse)
                        .toggle_state(selected)
                        .start_slot(
                            Icon::new(model_info.icon)
                                .color(model_icon_color)
                                .size(IconSize::Small),
                        )
                        .child(
                            h_flex()
                                .w_full()
                                .pl_0p5()
                                .gap_1p5()
                                .w(px(240.))
                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
                        )
                        .end_slot(div().pr_3().when(is_selected, |this| {
                            this.child(
                                Icon::new(IconName::Check)
                                    .color(Color::Accent)
                                    .size(IconSize::Small),
                            )
                        }))
                        .into_any_element(),
                )
            }
        }
    }

    /// Renders the footer. It contains a "Configure" button that opens the agent
    /// configuration. When the `ZedProFeatureFlag` is on, it also contains a plan-dependent
    /// account button.
    fn render_footer(
        &self,
        _: &mut Window,
        cx: &mut Context<Picker<Self>>,
    ) -> Option<gpui::AnyElement> {
        use feature_flags::FeatureFlagAppExt;

        // NOTE(review): the plan is hard-coded, so the `Plan::Free | Plan::ZedProTrial`
        // arm below is never taken. This is presumably a placeholder for the user's actual
        // plan — confirm.
        let plan = proto::Plan::ZedPro;

        Some(
            h_flex()
                .w_full()
                .border_t_1()
                .border_color(cx.theme().colors().border_variant)
                .p_1()
                .gap_4()
                .justify_between()
                .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
                    this.child(match plan {
                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
                            .icon(IconName::ZedAssistant)
                            .icon_size(IconSize::Small)
                            .icon_color(Color::Muted)
                            .icon_position(IconPosition::Start)
                            .on_click(|_, window, cx| {
                                window
                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
                            }),
                        Plan::Free | Plan::ZedProTrial => Button::new(
                            "try-pro",
                            if plan == Plan::ZedProTrial {
                                "Upgrade to Pro"
                            } else {
                                "Try Pro"
                            },
                        )
                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
                    })
                })
                .child(
                    Button::new("configure", "Configure")
                        .icon(IconName::Settings)
                        .icon_size(IconSize::Small)
                        .icon_color(Color::Muted)
                        .icon_position(IconPosition::Start)
                        .on_click(|_, window, cx| {
                            window.dispatch_action(
                                zed_actions::agent::OpenConfiguration.boxed_clone(),
                                cx,
                            );
                        }),
                )
                .into_any(),
        )
    }
}
603
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{future::BoxFuture, stream::BoxStream};
    use gpui::{AsyncApp, TestAppContext, http_client};
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
        LanguageModelRequest, LanguageModelToolChoice,
    };
    use ui::IconName;

    /// Minimal `LanguageModel` stub that only carries identity (name / id / provider).
    /// The model name doubles as its id.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }

    impl TestLanguageModel {
        /// Builds a stub whose id equals `name` and whose provider id and name both
        /// equal `provider`.
        fn new(name: &str, provider: &str) -> Self {
            Self {
                name: LanguageModelName::from(name.to_string()),
                id: LanguageModelId::from(name.to_string()),
                provider_id: LanguageModelProviderId::from(provider.to_string()),
                provider_name: LanguageModelProviderName::from(provider.to_string()),
            }
        }
    }

    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // "provider/name" — the tests below use this as the model's printable identity.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        // The selector never counts tokens or streams, so these are left unimplemented.
        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }

    /// Turns `(provider, name)` pairs into `ModelInfo`s backed by `TestLanguageModel`.
    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
        model_specs
            .into_iter()
            .map(|(provider, name)| ModelInfo {
                model: Arc::new(TestLanguageModel::new(name, provider)),
                icon: IconName::Ai,
            })
            .collect()
    }

    /// Asserts that `result` holds exactly the `expected` "provider/name" ids, in order.
    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
        assert_eq!(
            result.len(),
            expected.len(),
            "Number of models doesn't match"
        );

        for (i, expected_name) in expected.iter().enumerate() {
            assert_eq!(
                result[i].model.telemetry_id(),
                *expected_name,
                "Model at position {} doesn't match expected model",
                i
            );
        }
    }

    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, case doesn't matter
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }

    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve models order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // Model provider should be searchable as well
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Fuzzy search
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }

    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other"
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
    }

    #[gpui::test]
    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Should not be filtered out from "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other", but the same model id from a
        // different provider must stay.
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
    }
}