Fix model deduplication to use provider ID and model ID (#31751)

Ben Brandt created

Replicates https://github.com/zed-industries/zed/pull/31750 for v0.188.x
branch

Release Notes:

- Fix to make sure all provider models are shown in the model picker

Change summary

crates/assistant_context_editor/src/language_model_selector.rs | 87 ++-
1 file changed, 47 insertions(+), 40 deletions(-)

Detailed changes

crates/assistant_context_editor/src/language_model_selector.rs 🔗

@@ -155,53 +155,35 @@ impl LanguageModelSelector {
     }
 
     fn all_models(cx: &App) -> GroupedModels {
-        let mut recommended = Vec::new();
-        let mut recommended_set = HashSet::default();
-        for provider in LanguageModelRegistry::global(cx)
-            .read(cx)
-            .providers()
+        let providers = LanguageModelRegistry::global(cx).read(cx).providers();
+
+        let recommended = providers
             .iter()
-        {
-            let models = provider.recommended_models(cx);
-            recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
-            recommended.extend(
+            .flat_map(|provider| {
                 provider
                     .recommended_models(cx)
                     .into_iter()
-                    .map(move |model| ModelInfo {
-                        model: model.clone(),
+                    .map(|model| ModelInfo {
+                        model,
                         icon: provider.icon(),
-                    }),
-            );
-        }
+                    })
+            })
+            .collect();
 
-        let other_models = LanguageModelRegistry::global(cx)
-            .read(cx)
-            .providers()
+        let other = providers
             .iter()
-            .map(|provider| {
-                (
-                    provider.id(),
-                    provider
-                        .provided_models(cx)
-                        .into_iter()
-                        .filter_map(|model| {
-                            let not_included =
-                                !recommended_set.contains(&(model.provider_id(), model.id()));
-                            not_included.then(|| ModelInfo {
-                                model: model.clone(),
-                                icon: provider.icon(),
-                            })
-                        })
-                        .collect::<Vec<_>>(),
-                )
+            .flat_map(|provider| {
+                provider
+                    .provided_models(cx)
+                    .into_iter()
+                    .map(|model| ModelInfo {
+                        model,
+                        icon: provider.icon(),
+                    })
             })
-            .collect::<IndexMap<_, _>>();
+            .collect();
 
-        GroupedModels {
-            recommended,
-            other: other_models,
-        }
+        GroupedModels::new(other, recommended)
     }
 
     pub fn active_model(&self, cx: &App) -> Option<ConfiguredModel> {
@@ -326,11 +308,14 @@ struct GroupedModels {
 
 impl GroupedModels {
     pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
-        let recommended_ids: HashSet<_> = recommended.iter().map(|info| info.model.id()).collect();
+        let recommended_ids = recommended
+            .iter()
+            .map(|info| (info.model.provider_id(), info.model.id()))
+            .collect::<HashSet<_>>();
 
         let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
         for model in other {
-            if recommended_ids.contains(&model.model.id()) {
+            if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
                 continue;
             }
 
@@ -917,4 +902,26 @@ mod tests {
         // Recommended models should not appear in "other"
         assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
     }
+
+    #[gpui::test]
+    fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
+        let recommended_models = create_models(vec![("zed", "claude")]);
+        let all_models = create_models(vec![
+            ("zed", "claude"), // Should be filtered out from "other"
+            ("zed", "gemini"),
+            ("copilot", "claude"), // Should not be filtered out from "other"
+        ]);
+
+        let grouped_models = GroupedModels::new(all_models, recommended_models);
+
+        let actual_other_models = grouped_models
+            .other
+            .values()
+            .flatten()
+            .cloned()
+            .collect::<Vec<_>>();
+
+        // Recommended models should not appear in "other"
+        assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
+    }
 }