@@ -1,6 +1,6 @@
use std::{cmp::Reverse, sync::Arc};
-use collections::{HashSet, IndexMap};
+use collections::IndexMap;
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, Subscription, Task};
use language_model::{
@@ -57,7 +57,7 @@ fn all_models(cx: &App) -> GroupedModels {
})
.collect();
- let other = providers
+ let all = providers
.iter()
.flat_map(|provider| {
provider
@@ -70,7 +70,7 @@ fn all_models(cx: &App) -> GroupedModels {
})
.collect();
- GroupedModels::new(other, recommended)
+ GroupedModels::new(all, recommended)
}
#[derive(Clone)]
@@ -210,33 +210,24 @@ impl LanguageModelPickerDelegate {
struct GroupedModels {
recommended: Vec<ModelInfo>,
- other: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
+ all: IndexMap<LanguageModelProviderId, Vec<ModelInfo>>,
}
impl GroupedModels {
- pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
- let recommended_ids = recommended
- .iter()
- .map(|info| (info.model.provider_id(), info.model.id()))
- .collect::<HashSet<_>>();
-
- let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
- for model in other {
- if recommended_ids.contains(&(model.model.provider_id(), model.model.id())) {
- continue;
- }
-
+ pub fn new(all: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
+ let mut all_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
+ for model in all {
let provider = model.model.provider_id();
- if let Some(models) = other_by_provider.get_mut(&provider) {
+ if let Some(models) = all_by_provider.get_mut(&provider) {
models.push(model);
} else {
- other_by_provider.insert(provider, vec![model]);
+ all_by_provider.insert(provider, vec![model]);
}
}
Self {
recommended,
- other: other_by_provider,
+ all: all_by_provider,
}
}
@@ -252,7 +243,7 @@ impl GroupedModels {
);
}
- for models in self.other.values() {
+ for models in self.all.values() {
if models.is_empty() {
continue;
}
@@ -267,20 +258,6 @@ impl GroupedModels {
}
entries
}
-
- fn model_infos(&self) -> Vec<ModelInfo> {
- let other = self
- .other
- .values()
- .flat_map(|model| model.iter())
- .cloned()
- .collect::<Vec<_>>();
- self.recommended
- .iter()
- .chain(&other)
- .cloned()
- .collect::<Vec<_>>()
- }
}
enum LanguageModelPickerEntry {
@@ -425,8 +402,9 @@ impl PickerDelegate for LanguageModelPickerDelegate {
.collect::<Vec<_>>();
let available_models = all_models
- .model_infos()
- .iter()
+ .all
+ .values()
+ .flat_map(|models| models.iter())
.filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
.cloned()
.collect::<Vec<_>>();
@@ -764,46 +742,52 @@ mod tests {
}
#[gpui::test]
- fn test_exclude_recommended_models(_cx: &mut TestAppContext) {
+ fn test_recommended_models_also_appear_in_other(_cx: &mut TestAppContext) {
let recommended_models = create_models(vec![("zed", "claude")]);
let all_models = create_models(vec![
- ("zed", "claude"), // Should be filtered out from "other"
+ ("zed", "claude"), // Should also appear in "other"
("zed", "gemini"),
("copilot", "o3"),
]);
let grouped_models = GroupedModels::new(all_models, recommended_models);
- let actual_other_models = grouped_models
- .other
+ let actual_all_models = grouped_models
+ .all
.values()
.flatten()
.cloned()
.collect::<Vec<_>>();
- // Recommended models should not appear in "other"
- assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/o3"]);
+ // Recommended models should also appear in "all"
+ assert_models_eq(
+ actual_all_models,
+ vec!["zed/claude", "zed/gemini", "copilot/o3"],
+ );
}
#[gpui::test]
- fn test_dont_exclude_models_from_other_providers(_cx: &mut TestAppContext) {
+ fn test_models_from_different_providers(_cx: &mut TestAppContext) {
let recommended_models = create_models(vec![("zed", "claude")]);
let all_models = create_models(vec![
- ("zed", "claude"), // Should be filtered out from "other"
+ ("zed", "claude"), // Should also appear in "other"
("zed", "gemini"),
- ("copilot", "claude"), // Should not be filtered out from "other"
+ ("copilot", "claude"), // Different provider, should appear in "other"
]);
let grouped_models = GroupedModels::new(all_models, recommended_models);
- let actual_other_models = grouped_models
- .other
+ let actual_all_models = grouped_models
+ .all
.values()
.flatten()
.cloned()
.collect::<Vec<_>>();
- // Recommended models should not appear in "other"
- assert_models_eq(actual_other_models, vec!["zed/gemini", "copilot/claude"]);
+ // All models should appear in "all" regardless of recommended status
+ assert_models_eq(
+ actual_all_models,
+ vec!["zed/claude", "zed/gemini", "copilot/claude"],
+ );
}
}