@@ -46,53 +46,35 @@ pub fn language_model_selector(
}
fn all_models(cx: &App) -> GroupedModels {
- let mut recommended = Vec::new();
- let mut recommended_set = HashSet::default();
- for provider in LanguageModelRegistry::global(cx)
- .read(cx)
- .providers()
+ let providers = LanguageModelRegistry::global(cx).read(cx).providers();
+
+ let recommended = providers
.iter()
- {
- let models = provider.recommended_models(cx);
- recommended_set.extend(models.iter().map(|model| (model.provider_id(), model.id())));
- recommended.extend(
+ .flat_map(|provider| {
provider
.recommended_models(cx)
.into_iter()
- .map(move |model| ModelInfo {
- model: model.clone(),
+ .map(|model| ModelInfo {
+ model,
icon: provider.icon(),
- }),
- );
- }
+ })
+ })
+ .collect();
- let other_models = LanguageModelRegistry::global(cx)
- .read(cx)
- .providers()
+ let other = providers
.iter()
- .map(|provider| {
- (
- provider.id(),
- provider
- .provided_models(cx)
- .into_iter()
- .filter_map(|model| {
- let not_included =
- !recommended_set.contains(&(model.provider_id(), model.id()));
- not_included.then(|| ModelInfo {
- model: model.clone(),
- icon: provider.icon(),
- })
- })
- .collect::<Vec<_>>(),
- )
+ .flat_map(|provider| {
+ provider
+ .provided_models(cx)
+ .into_iter()
+ .map(|model| ModelInfo {
+ model,
+ icon: provider.icon(),
+ })
})
- .collect::<IndexMap<_, _>>();
+ .collect();
- GroupedModels {
- recommended,
- other: other_models,
- }
+ GroupedModels::new(other, recommended)
}
#[derive(Clone)]
@@ -234,11 +216,14 @@ struct GroupedModels {
impl GroupedModels {
pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
- let recommended_ids: HashSet<_> = recommended.iter().map(|info| info.model.id()).collect();
+ let recommended_ids = recommended
+ .iter()
+ .map(|info| (info.model.provider_id(), info.model.id()))
+ .collect::<HashSet<_>>();
let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
for model in other {
- if recommended_ids.contains(&model.model.id()) {
+ if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
continue;
}
@@ -823,4 +808,26 @@ mod tests {
// Recommended models should not appear in "other"
assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
}
+
+ #[gpui::test]
+ fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
+ let recommended_models = create_models(vec![("zed", "claude")]);
+ let all_models = create_models(vec![
+ ("zed", "claude"), // Should be filtered out from "other"
+ ("zed", "gemini"),
+ ("copilot", "claude"), // Should not be filtered out from "other"
+ ]);
+
+ let grouped_models = GroupedModels::new(all_models, recommended_models);
+
+ let actual_other_models = grouped_models
+ .other
+ .values()
+ .flatten()
+ .cloned()
+ .collect::<Vec<_>>();
+
+ // Recommended models should not appear in "other"
+ assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
+ }
}