1use std::{cmp::Reverse, sync::Arc};
2
3use collections::{HashSet, IndexMap};
4use feature_flags::ZedProFeatureFlag;
5use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
6use gpui::{
7 Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task, actions,
8};
9use language_model::{
10 AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
11 LanguageModelRegistry,
12};
13use ordered_float::OrderedFloat;
14use picker::{Picker, PickerDelegate};
15use proto::Plan;
16use ui::{ListItem, ListItemSpacing, prelude::*};
17
// Declares the `ToggleModelSelector` action under the `agent` namespace.
// The `deprecated_aliases` attribute lists the names this action was previously
// known by in the `assistant` and `assistant2` namespaces.
actions!(
    agent,
    [
        #[action(deprecated_aliases = ["assistant::ToggleModelSelector", "assistant2::ToggleModelSelector"])]
        ToggleModelSelector
    ]
);
25
/// URL opened by the "Try Pro" / "Upgrade to Pro" button in the picker footer.
const TRY_ZED_PRO_URL: &str = "https://zed.dev/pro";

/// Callback invoked with the model the user confirmed in the picker.
type OnModelChanged = Arc<dyn Fn(Arc<dyn LanguageModel>, &mut App) + 'static>;
/// Callback that reports the model currently in use, if any.
type GetActiveModel = Arc<dyn Fn(&App) -> Option<ConfiguredModel> + 'static>;

/// A [`Picker`] whose delegate is a [`LanguageModelPickerDelegate`].
pub type LanguageModelSelector = Picker<LanguageModelPickerDelegate>;
32
33pub fn language_model_selector(
34 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
35 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
36 window: &mut Window,
37 cx: &mut Context<LanguageModelSelector>,
38) -> LanguageModelSelector {
39 let delegate = LanguageModelPickerDelegate::new(get_active_model, on_model_changed, window, cx);
40 Picker::list(delegate, window, cx)
41 .show_scrollbar(true)
42 .width(rems(20.))
43 .max_height(Some(rems(20.).into()))
44}
45
46fn all_models(cx: &App) -> GroupedModels {
47 let providers = LanguageModelRegistry::global(cx).read(cx).providers();
48
49 let recommended = providers
50 .iter()
51 .flat_map(|provider| {
52 provider
53 .recommended_models(cx)
54 .into_iter()
55 .map(|model| ModelInfo {
56 model,
57 icon: provider.icon(),
58 })
59 })
60 .collect();
61
62 let other = providers
63 .iter()
64 .flat_map(|provider| {
65 provider
66 .provided_models(cx)
67 .into_iter()
68 .map(|model| ModelInfo {
69 model,
70 icon: provider.icon(),
71 })
72 })
73 .collect();
74
75 GroupedModels::new(other, recommended)
76}
77
/// A language model together with the icon displayed next to it in the picker.
#[derive(Clone)]
struct ModelInfo {
    /// The model this entry represents.
    model: Arc<dyn LanguageModel>,
    /// Icon rendered at the start of the model's list item.
    icon: IconName,
}
83
/// [`PickerDelegate`] that lists language models grouped under separators.
pub struct LanguageModelPickerDelegate {
    /// Called with the chosen model when the user confirms an entry.
    on_model_changed: OnModelChanged,
    /// Reports the currently active model; used for the initial selection and
    /// for marking the active model while rendering.
    get_active_model: GetActiveModel,
    /// Every known model, before any query or authentication filtering.
    all_models: Arc<GroupedModels>,
    /// Entries (separators and models) currently shown in the list.
    filtered_entries: Vec<LanguageModelPickerEntry>,
    /// Index into `filtered_entries` of the highlighted entry.
    selected_index: usize,
    /// Background authentication of all providers.
    // NOTE(review): presumably stored so the task is not cancelled when dropped — confirm.
    _authenticate_all_providers_task: Task<()>,
    /// Subscription to registry events that trigger a rebuild of the model list.
    _subscriptions: Vec<Subscription>,
}
93
94impl LanguageModelPickerDelegate {
95 fn new(
96 get_active_model: impl Fn(&App) -> Option<ConfiguredModel> + 'static,
97 on_model_changed: impl Fn(Arc<dyn LanguageModel>, &mut App) + 'static,
98 window: &mut Window,
99 cx: &mut Context<Picker<Self>>,
100 ) -> Self {
101 let on_model_changed = Arc::new(on_model_changed);
102 let models = all_models(cx);
103 let entries = models.entries();
104
105 Self {
106 on_model_changed: on_model_changed.clone(),
107 all_models: Arc::new(models),
108 selected_index: Self::get_active_model_index(&entries, get_active_model(cx)),
109 filtered_entries: entries,
110 get_active_model: Arc::new(get_active_model),
111 _authenticate_all_providers_task: Self::authenticate_all_providers(cx),
112 _subscriptions: vec![cx.subscribe_in(
113 &LanguageModelRegistry::global(cx),
114 window,
115 |picker, _, event, window, cx| {
116 match event {
117 language_model::Event::ProviderStateChanged
118 | language_model::Event::AddedProvider(_)
119 | language_model::Event::RemovedProvider(_) => {
120 let query = picker.query(cx);
121 picker.delegate.all_models = Arc::new(all_models(cx));
122 // Update matches will automatically drop the previous task
123 // if we get a provider event again
124 picker.update_matches(query, window, cx)
125 }
126 _ => {}
127 }
128 },
129 )],
130 }
131 }
132
133 fn get_active_model_index(
134 entries: &[LanguageModelPickerEntry],
135 active_model: Option<ConfiguredModel>,
136 ) -> usize {
137 entries
138 .iter()
139 .position(|entry| {
140 if let LanguageModelPickerEntry::Model(model) = entry {
141 active_model
142 .as_ref()
143 .map(|active_model| {
144 active_model.model.id() == model.model.id()
145 && active_model.provider.id() == model.model.provider_id()
146 })
147 .unwrap_or_default()
148 } else {
149 false
150 }
151 })
152 .unwrap_or(0)
153 }
154
155 /// Authenticates all providers in the [`LanguageModelRegistry`].
156 ///
157 /// We do this so that we can populate the language selector with all of the
158 /// models from the configured providers.
159 fn authenticate_all_providers(cx: &mut App) -> Task<()> {
160 let authenticate_all_providers = LanguageModelRegistry::global(cx)
161 .read(cx)
162 .providers()
163 .iter()
164 .map(|provider| (provider.id(), provider.name(), provider.authenticate(cx)))
165 .collect::<Vec<_>>();
166
167 cx.spawn(async move |_cx| {
168 for (provider_id, provider_name, authenticate_task) in authenticate_all_providers {
169 if let Err(err) = authenticate_task.await {
170 if matches!(err, AuthenticateError::CredentialsNotFound) {
171 // Since we're authenticating these providers in the
172 // background for the purposes of populating the
173 // language selector, we don't care about providers
174 // where the credentials are not found.
175 } else {
176 // Some providers have noisy failure states that we
177 // don't want to spam the logs with every time the
178 // language model selector is initialized.
179 //
180 // Ideally these should have more clear failure modes
181 // that we know are safe to ignore here, like what we do
182 // with `CredentialsNotFound` above.
183 match provider_id.0.as_ref() {
184 "lmstudio" | "ollama" => {
185 // LM Studio and Ollama both make fetch requests to the local APIs to determine if they are "authenticated".
186 //
187 // These fail noisily, so we don't log them.
188 }
189 "copilot_chat" => {
190 // Copilot Chat returns an error if Copilot is not enabled, so we don't log those errors.
191 }
192 _ => {
193 log::error!(
194 "Failed to authenticate provider: {}: {err}",
195 provider_name.0
196 );
197 }
198 }
199 }
200 }
201 }
202 })
203 }
204
205 pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
206 (self.get_active_model)(cx)
207 }
208}
209
/// Models split into a "recommended" group and per-provider groups.
struct GroupedModels {
    /// Models shown under the "Recommended" separator.
    recommended: Vec<ModelInfo>,
    /// Remaining models keyed by provider id, in insertion order. Models that
    /// are already in `recommended` are excluded by [`GroupedModels::new`].
    other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
214
215impl GroupedModels {
216 pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
217 let recommended_ids = recommended
218 .iter()
219 .map(|info| (info.model.provider_id(), info.model.id()))
220 .collect::<HashSet<_>>();
221
222 let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
223 for model in other {
224 if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
225 continue;
226 }
227
228 let provider = model.model.provider_id();
229 if let Some(models) = other_by_provider.get_mut(&provider) {
230 models.push(model);
231 } else {
232 other_by_provider.insert(provider, vec![model]);
233 }
234 }
235
236 Self {
237 recommended,
238 other: other_by_provider,
239 }
240 }
241
242 fn entries(&self) -> Vec<LanguageModelPickerEntry> {
243 let mut entries = Vec::new();
244
245 if !self.recommended.is_empty() {
246 entries.push(LanguageModelPickerEntry::Separator("Recommended".into()));
247 entries.extend(
248 self.recommended
249 .iter()
250 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
251 );
252 }
253
254 for models in self.other.values() {
255 if models.is_empty() {
256 continue;
257 }
258 entries.push(LanguageModelPickerEntry::Separator(
259 models[0].model.provider_name().0,
260 ));
261 entries.extend(
262 models
263 .iter()
264 .map(|info| LanguageModelPickerEntry::Model(info.clone())),
265 );
266 }
267 entries
268 }
269
270 fn model_infos(&self) -> Vec<ModelInfo> {
271 let other = self
272 .other
273 .values()
274 .flat_map(|model| model.iter())
275 .cloned()
276 .collect::<Vec<_>>();
277 self.recommended
278 .iter()
279 .chain(&other)
280 .cloned()
281 .collect::<Vec<_>>()
282 }
283}
284
/// A single row in the picker list.
enum LanguageModelPickerEntry {
    /// A model row; the only kind `can_select` accepts.
    Model(ModelInfo),
    /// A group heading carrying the title to display.
    Separator(SharedString),
}
289
/// Searches a fixed list of models by exact substring or fuzzy match.
struct ModelMatcher {
    /// The models being searched.
    models: Vec<ModelInfo>,
    /// Executor used to run (and block on) the fuzzy match.
    bg_executor: BackgroundExecutor,
    /// One fuzzy-match candidate per model; `candidates[i]` is built from
    /// `models[i]` and carries `i` as its id.
    candidates: Vec<StringMatchCandidate>,
}
295
296impl ModelMatcher {
297 fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
298 let candidates = Self::make_match_candidates(&models);
299 Self {
300 models,
301 bg_executor,
302 candidates,
303 }
304 }
305
306 pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
307 let mut matches = self.bg_executor.block(match_strings(
308 &self.candidates,
309 &query,
310 false,
311 true,
312 100,
313 &Default::default(),
314 self.bg_executor.clone(),
315 ));
316
317 let sorting_key = |mat: &StringMatch| {
318 let candidate = &self.candidates[mat.candidate_id];
319 (Reverse(OrderedFloat(mat.score)), candidate.id)
320 };
321 matches.sort_unstable_by_key(sorting_key);
322
323 let matched_models: Vec<_> = matches
324 .into_iter()
325 .map(|mat| self.models[mat.candidate_id].clone())
326 .collect();
327
328 matched_models
329 }
330
331 pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
332 self.models
333 .iter()
334 .filter(|m| {
335 m.model
336 .name()
337 .0
338 .to_lowercase()
339 .contains(&query.to_lowercase())
340 })
341 .cloned()
342 .collect::<Vec<_>>()
343 }
344
345 fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
346 model_infos
347 .iter()
348 .enumerate()
349 .map(|(index, model)| {
350 StringMatchCandidate::new(
351 index,
352 &format!(
353 "{}/{}",
354 &model.model.provider_name().0,
355 &model.model.name().0
356 ),
357 )
358 })
359 .collect::<Vec<_>>()
360 }
361}
362
363impl PickerDelegate for LanguageModelPickerDelegate {
364 type ListItem = AnyElement;
365
    /// Number of rows currently shown, counting separators as well as models.
    fn match_count(&self) -> usize {
        self.filtered_entries.len()
    }
369
    /// Index into `filtered_entries` of the highlighted row.
    fn selected_index(&self) -> usize {
        self.selected_index
    }
373
374 fn set_selected_index(&mut self, ix: usize, _: &mut Window, cx: &mut Context<Picker<Self>>) {
375 self.selected_index = ix.min(self.filtered_entries.len().saturating_sub(1));
376 cx.notify();
377 }
378
379 fn can_select(
380 &mut self,
381 ix: usize,
382 _window: &mut Window,
383 _cx: &mut Context<Picker<Self>>,
384 ) -> bool {
385 match self.filtered_entries.get(ix) {
386 Some(LanguageModelPickerEntry::Model(_)) => true,
387 Some(LanguageModelPickerEntry::Separator(_)) | None => false,
388 }
389 }
390
391 fn placeholder_text(&self, _window: &mut Window, _cx: &mut App) -> Arc<str> {
392 "Select a model…".into()
393 }
394
395 fn update_matches(
396 &mut self,
397 query: String,
398 window: &mut Window,
399 cx: &mut Context<Picker<Self>>,
400 ) -> Task<()> {
401 let all_models = self.all_models.clone();
402 let current_index = self.selected_index;
403 let bg_executor = cx.background_executor();
404
405 let language_model_registry = LanguageModelRegistry::global(cx);
406
407 let configured_providers = language_model_registry
408 .read(cx)
409 .providers()
410 .into_iter()
411 .filter(|provider| provider.is_authenticated(cx))
412 .collect::<Vec<_>>();
413
414 let configured_provider_ids = configured_providers
415 .iter()
416 .map(|provider| provider.id())
417 .collect::<Vec<_>>();
418
419 let recommended_models = all_models
420 .recommended
421 .iter()
422 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
423 .cloned()
424 .collect::<Vec<_>>();
425
426 let available_models = all_models
427 .model_infos()
428 .iter()
429 .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
430 .cloned()
431 .collect::<Vec<_>>();
432
433 let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
434 let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
435
436 let recommended = matcher_rec.exact_search(&query);
437 let all = matcher_all.fuzzy_search(&query);
438
439 let filtered_models = GroupedModels::new(all, recommended);
440
441 cx.spawn_in(window, async move |this, cx| {
442 this.update_in(cx, |this, window, cx| {
443 this.delegate.filtered_entries = filtered_models.entries();
444 // Preserve selection focus
445 let new_index = if current_index >= this.delegate.filtered_entries.len() {
446 0
447 } else {
448 current_index
449 };
450 this.set_selected_index(new_index, Some(picker::Direction::Down), true, window, cx);
451 cx.notify();
452 })
453 .ok();
454 })
455 }
456
457 fn confirm(&mut self, _secondary: bool, window: &mut Window, cx: &mut Context<Picker<Self>>) {
458 if let Some(LanguageModelPickerEntry::Model(model_info)) =
459 self.filtered_entries.get(self.selected_index)
460 {
461 let model = model_info.model.clone();
462 (self.on_model_changed)(model.clone(), cx);
463
464 let current_index = self.selected_index;
465 self.set_selected_index(current_index, window, cx);
466
467 cx.emit(DismissEvent);
468 }
469 }
470
    /// Emits [`DismissEvent`] when the picker is dismissed.
    fn dismissed(&mut self, _: &mut Window, cx: &mut Context<Picker<Self>>) {
        cx.emit(DismissEvent);
    }
474
    /// Renders the row at `ix`, or returns `None` when `ix` is out of range.
    ///
    /// Separators render as a small muted label. Model rows render the
    /// provider icon and model name, with a check mark on the active model.
    /// `selected` is the picker highlight and is independent of which model is
    /// active.
    fn render_match(
        &self,
        ix: usize,
        selected: bool,
        _: &mut Window,
        cx: &mut Context<Picker<Self>>,
    ) -> Option<Self::ListItem> {
        match self.filtered_entries.get(ix)? {
            LanguageModelPickerEntry::Separator(title) => Some(
                div()
                    .px_2()
                    .pb_1()
                    // Separators at index 2 or later get extra top spacing and a
                    // divider line.
                    .when(ix > 1, |this| {
                        this.mt_1()
                            .pt_2()
                            .border_t_1()
                            .border_color(cx.theme().colors().border_variant)
                    })
                    .child(
                        Label::new(title)
                            .size(LabelSize::XSmall)
                            .color(Color::Muted),
                    )
                    .into_any_element(),
            ),
            LanguageModelPickerEntry::Model(model_info) => {
                let active_model = (self.get_active_model)(cx);
                let active_provider_id = active_model.as_ref().map(|m| m.provider.id());
                let active_model_id = active_model.map(|m| m.model.id());

                // Despite the name, this means "is the active model": both the
                // provider id and the model id must match.
                let is_selected = Some(model_info.model.provider_id()) == active_provider_id
                    && Some(model_info.model.id()) == active_model_id;

                let model_icon_color = if is_selected {
                    Color::Accent
                } else {
                    Color::Muted
                };

                Some(
                    ListItem::new(ix)
                        .inset(true)
                        .spacing(ListItemSpacing::Sparse)
                        // Highlight state comes from the picker, not the active model.
                        .toggle_state(selected)
                        .start_slot(
                            Icon::new(model_info.icon)
                                .color(model_icon_color)
                                .size(IconSize::Small),
                        )
                        .child(
                            h_flex()
                                .w_full()
                                .pl_0p5()
                                .gap_1p5()
                                .w(px(240.))
                                .child(Label::new(model_info.model.name().0.clone()).truncate()),
                        )
                        // The check mark is only added for the active model.
                        .end_slot(div().pr_3().when(is_selected, |this| {
                            this.child(
                                Icon::new(IconName::Check)
                                    .color(Color::Accent)
                                    .size(IconSize::Small),
                            )
                        }))
                        .into_any_element(),
                )
            }
        }
    }
544
    /// Renders the footer: an optional plan button (only when the
    /// `ZedProFeatureFlag` flag is set) and a "Configure" button that
    /// dispatches `zed_actions::agent::OpenConfiguration`.
    fn render_footer(
        &self,
        _: &mut Window,
        cx: &mut Context<Picker<Self>>,
    ) -> Option<gpui::AnyElement> {
        use feature_flags::FeatureFlagAppExt;

        // The plan is fixed to `ZedPro` here, so only the first arm of the
        // `match` below is currently taken.
        let plan = proto::Plan::ZedPro;

        Some(
            h_flex()
                .w_full()
                .border_t_1()
                .border_color(cx.theme().colors().border_variant)
                .p_1()
                .gap_4()
                .justify_between()
                .when(cx.has_flag::<ZedProFeatureFlag>(), |this| {
                    this.child(match plan {
                        // Pro users get a button that opens account settings.
                        Plan::ZedPro => Button::new("zed-pro", "Zed Pro")
                            .icon(IconName::ZedAssistant)
                            .icon_size(IconSize::Small)
                            .icon_color(Color::Muted)
                            .icon_position(IconPosition::Start)
                            .on_click(|_, window, cx| {
                                window
                                    .dispatch_action(Box::new(zed_actions::OpenAccountSettings), cx)
                            }),
                        // Free and trial users get a button that opens the Pro page.
                        Plan::Free | Plan::ZedProTrial => Button::new(
                            "try-pro",
                            if plan == Plan::ZedProTrial {
                                "Upgrade to Pro"
                            } else {
                                "Try Pro"
                            },
                        )
                        .on_click(|_, _, cx| cx.open_url(TRY_ZED_PRO_URL)),
                    })
                })
                .child(
                    Button::new("configure", "Configure")
                        .icon(IconName::Settings)
                        .icon_size(IconSize::Small)
                        .icon_color(Color::Muted)
                        .icon_position(IconPosition::Start)
                        .on_click(|_, window, cx| {
                            window.dispatch_action(
                                zed_actions::agent::OpenConfiguration.boxed_clone(),
                                cx,
                            );
                        }),
                )
                .into_any(),
        )
    }
600}
601
602#[cfg(test)]
603mod tests {
604 use super::*;
605 use futures::{future::BoxFuture, stream::BoxStream};
606 use gpui::{AsyncApp, TestAppContext, http_client};
607 use language_model::{
608 LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
609 LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
610 LanguageModelRequest, LanguageModelToolChoice,
611 };
612 use ui::IconName;
613
    /// Minimal [`LanguageModel`] used by the tests; it only carries identity
    /// (name, id and provider) and cannot serve requests.
    #[derive(Clone)]
    struct TestLanguageModel {
        name: LanguageModelName,
        id: LanguageModelId,
        provider_id: LanguageModelProviderId,
        provider_name: LanguageModelProviderName,
    }
621
622 impl TestLanguageModel {
623 fn new(name: &str, provider: &str) -> Self {
624 Self {
625 name: LanguageModelName::from(name.to_string()),
626 id: LanguageModelId::from(name.to_string()),
627 provider_id: LanguageModelProviderId::from(provider.to_string()),
628 provider_name: LanguageModelProviderName::from(provider.to_string()),
629 }
630 }
631 }
632
    // Test double: the identity accessors return the stored values, capability
    // queries all answer `false`, and the request methods are unimplemented
    // because the tests in this module never call them.
    impl LanguageModel for TestLanguageModel {
        fn id(&self) -> LanguageModelId {
            self.id.clone()
        }

        fn name(&self) -> LanguageModelName {
            self.name.clone()
        }

        fn provider_id(&self) -> LanguageModelProviderId {
            self.provider_id.clone()
        }

        fn provider_name(&self) -> LanguageModelProviderName {
            self.provider_name.clone()
        }

        fn supports_tools(&self) -> bool {
            false
        }

        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
            false
        }

        fn supports_images(&self) -> bool {
            false
        }

        // Formatted as "provider_id/name"; `assert_models_eq` compares against
        // this string.
        fn telemetry_id(&self) -> String {
            format!("{}/{}", self.provider_id.0, self.name.0)
        }

        fn max_token_count(&self) -> u64 {
            1000
        }

        // Panics if called.
        fn count_tokens(
            &self,
            _: LanguageModelRequest,
            _: &App,
        ) -> BoxFuture<'static, http_client::Result<u64>> {
            unimplemented!()
        }

        // Panics if called.
        fn stream_completion(
            &self,
            _: LanguageModelRequest,
            _: &AsyncApp,
        ) -> BoxFuture<
            'static,
            Result<
                BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >,
                LanguageModelCompletionError,
            >,
        > {
            unimplemented!()
        }
    }
695
696 fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
697 model_specs
698 .into_iter()
699 .map(|(provider, name)| ModelInfo {
700 model: Arc::new(TestLanguageModel::new(name, provider)),
701 icon: IconName::Ai,
702 })
703 .collect()
704 }
705
706 fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
707 assert_eq!(
708 result.len(),
709 expected.len(),
710 "Number of models doesn't match"
711 );
712
713 for (i, expected_name) in expected.iter().enumerate() {
714 assert_eq!(
715 result[i].model.telemetry_id(),
716 *expected_name,
717 "Model at position {} doesn't match expected model",
718 i
719 );
720 }
721 }
722
    /// `exact_search` matches on the model name only, ignores case, and keeps
    /// the input order.
    #[gpui::test]
    fn test_exact_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // The order of models should be maintained, and case doesn't matter.
        let results = matcher.exact_search("GPT-4.1");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1",
                "openai/gpt-4.1-nano",
            ],
        );
    }
750
    /// `fuzzy_search` ranks by score, breaks ties by original position, and
    /// also matches against the provider name.
    #[gpui::test]
    fn test_fuzzy_match(cx: &mut TestAppContext) {
        let models = create_models(vec![
            ("zed", "Claude 3.7 Sonnet"),
            ("zed", "Claude 3.7 Sonnet Thinking"),
            ("zed", "gpt-4.1"),
            ("zed", "gpt-4.1-nano"),
            ("openai", "gpt-3.5-turbo"),
            ("openai", "gpt-4.1"),
            ("openai", "gpt-4.1-nano"),
            ("ollama", "mistral"),
            ("ollama", "deepseek"),
        ]);
        let matcher = ModelMatcher::new(models, cx.background_executor.clone());

        // Results should preserve the model order whenever possible.
        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
        // so it should appear first in the results.
        let results = matcher.fuzzy_search("41");
        assert_models_eq(
            results,
            vec![
                "zed/gpt-4.1",
                "openai/gpt-4.1",
                "zed/gpt-4.1-nano",
                "openai/gpt-4.1-nano",
            ],
        );

        // The model provider should be searchable as well.
        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);

        // Non-contiguous characters match across the provider and model name.
        let results = matcher.fuzzy_search("z4n");
        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
    }
789
    /// A model listed as recommended is left out of the per-provider groups.
    #[gpui::test]
    fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "o3"),
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Recommended models should not appear in "other".
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
    }
811
    /// Exclusion of recommended models is keyed on (provider, model id): a
    /// model with the same id from a different provider must be kept.
    #[gpui::test]
    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
        let recommended_models = create_models(vec![("zed", "claude")]);
        let all_models = create_models(vec![
            ("zed", "claude"), // Should be filtered out from "other"
            ("zed", "gemini"),
            ("copilot", "claude"), // Should not be filtered out from "other"
        ]);

        let grouped_models = GroupedModels::new(all_models, recommended_models);

        let actual_other_models = grouped_models
            .other
            .values()
            .flatten()
            .cloned()
            .collect::<Vec<_>>();

        // Only `zed/claude` is removed; `copilot/claude` shares the model id but
        // belongs to a different provider, so it stays in "other".
        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
    }
833}