Cargo.lock 🔗
@@ -7813,9 +7813,12 @@ version = "0.1.0"
dependencies = [
"collections",
"feature_flags",
+ "futures 0.3.31",
+ "fuzzy",
"gpui",
"language_model",
"log",
+ "ordered-float 2.10.1",
"picker",
"proto",
"ui",
Created by Oleksiy Syvokon
This change enables fuzzy search on model providers and names. For
example, the query "z41" will match "zed/gpt-4.1".
Release Notes:
- Agent: Improved model selection with fuzzy search support
Cargo.lock | 3
crates/language_model_selector/Cargo.toml | 12
crates/language_model_selector/src/language_model_selector.rs | 363 ++++
3 files changed, 336 insertions(+), 42 deletions(-)
@@ -7813,9 +7813,12 @@ version = "0.1.0"
dependencies = [
"collections",
"feature_flags",
+ "futures 0.3.31",
+ "fuzzy",
"gpui",
"language_model",
"log",
+ "ordered-float 2.10.1",
"picker",
"proto",
"ui",
@@ -11,14 +11,26 @@ workspace = true
[lib]
path = "src/language_model_selector.rs"
+[features]
+test-support = [
+ "gpui/test-support",
+]
+
[dependencies]
collections.workspace = true
feature_flags.workspace = true
+futures.workspace = true
+fuzzy.workspace = true
gpui.workspace = true
language_model.workspace = true
log.workspace = true
+ordered-float.workspace = true
picker.workspace = true
proto.workspace = true
ui.workspace = true
workspace-hack.workspace = true
zed_actions.workspace = true
+
+[dev-dependencies]
+gpui = { workspace = true, features = ["test-support"] }
+language_model = { workspace = true, features = ["test-support"] }
@@ -1,15 +1,18 @@
-use std::sync::Arc;
+use std::{cmp::Reverse, sync::Arc};
use collections::{HashSet, IndexMap};
use feature_flags::ZedProFeatureFlag;
+use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{
- Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
- Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
+ Action, AnyElement, AnyView, App, BackgroundExecutor, Corner, DismissEvent, Entity,
+ EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
+ action_with_deprecated_aliases,
};
use language_model::{
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
LanguageModelRegistry,
};
+use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
use proto::Plan;
use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
@@ -322,6 +325,23 @@ struct GroupedModels {
}
impl GroupedModels {
+    /// Groups `other` models by provider id (first-seen provider order is
+    /// preserved by the `IndexMap`); `recommended` stays a flat list.
+    pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
+        let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
+        for model in other {
+            // Entry API: one map lookup instead of `get_mut` + `insert`.
+            other_by_provider
+                .entry(model.model.provider_id())
+                .or_default()
+                .push(model);
+        }
+
+        Self {
+            recommended,
+            other: other_by_provider,
+        }
+    }
+
fn entries(&self) -> Vec<LanguageModelPickerEntry> {
let mut entries = Vec::new();
@@ -349,6 +369,20 @@ impl GroupedModels {
}
entries
}
+
+    /// All models in display order: recommended first, then each provider's
+    /// models in provider insertion order.
+    ///
+    /// Single clone pass — the original collected `other` into an
+    /// intermediate `Vec` (one clone) and then cloned the whole chain again.
+    fn model_infos(&self) -> Vec<ModelInfo> {
+        self.recommended
+            .iter()
+            .chain(self.other.values().flatten())
+            .cloned()
+            .collect()
+    }
}
enum LanguageModelPickerEntry {
@@ -356,6 +390,78 @@ enum LanguageModelPickerEntry {
Separator(SharedString),
}
+/// Matches a fixed set of models against user queries.
+///
+/// Candidates are formatted as "provider/model" so both parts are searchable
+/// (e.g. the query "z41" matches "zed/gpt-4.1").
+struct ModelMatcher {
+    models: Vec<ModelInfo>,
+    bg_executor: BackgroundExecutor,
+    // Parallel to `models`: candidate `id` == index into `models`.
+    candidates: Vec<StringMatchCandidate>,
+}
+
+impl ModelMatcher {
+    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
+        let candidates = Self::make_match_candidates(&models);
+        Self {
+            models,
+            bg_executor,
+            candidates,
+        }
+    }
+
+    /// Fuzzy-matches `query` against "provider/model" candidates, best score
+    /// first; equal scores keep the original model order.
+    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
+        let mut matches = self.bg_executor.block(match_strings(
+            &self.candidates,
+            query,
+            false,
+            100,
+            &Default::default(),
+            self.bg_executor.clone(),
+        ));
+
+        // Sort by descending score; break ties with the candidate's original
+        // position so the incoming model order is preserved.
+        matches.sort_unstable_by_key(|mat: &StringMatch| {
+            (
+                Reverse(OrderedFloat(mat.score)),
+                self.candidates[mat.candidate_id].id,
+            )
+        });
+
+        matches
+            .into_iter()
+            .map(|mat| self.models[mat.candidate_id].clone())
+            .collect()
+    }
+
+    /// Case-insensitive substring match on the model *name* only (provider is
+    /// ignored), preserving the original model order.
+    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
+        // Lowercase the query once, not once per model.
+        let needle = query.to_lowercase();
+        self.models
+            .iter()
+            .filter(|m| m.model.name().0.to_lowercase().contains(&needle))
+            .cloned()
+            .collect()
+    }
+
+    // `&[ModelInfo]` rather than `&Vec<ModelInfo>` (clippy::ptr_arg).
+    fn make_match_candidates(model_infos: &[ModelInfo]) -> Vec<StringMatchCandidate> {
+        model_infos
+            .iter()
+            .enumerate()
+            .map(|(index, model)| {
+                StringMatchCandidate::new(
+                    index,
+                    &format!(
+                        "{}/{}",
+                        model.model.provider_name().0,
+                        model.model.name().0
+                    ),
+                )
+            })
+            .collect()
+    }
+}
+
impl PickerDelegate for LanguageModelPickerDelegate {
type ListItem = AnyElement;
@@ -396,56 +502,45 @@ impl PickerDelegate for LanguageModelPickerDelegate {
) -> Task<()> {
let all_models = self.all_models.clone();
let current_index = self.selected_index;
+ let bg_executor = cx.background_executor();
let language_model_registry = LanguageModelRegistry::global(cx);
let configured_providers = language_model_registry
.read(cx)
.providers()
- .iter()
+ .into_iter()
.filter(|provider| provider.is_authenticated(cx))
+ .collect::<Vec<_>>();
+
+ let configured_provider_ids = configured_providers
+ .iter()
.map(|provider| provider.id())
.collect::<Vec<_>>();
- cx.spawn_in(window, async move |this, cx| {
- let filtered_models = cx
- .background_spawn(async move {
- let matches = |info: &ModelInfo| {
- info.model
- .name()
- .0
- .to_lowercase()
- .contains(&query.to_lowercase())
- };
-
- let recommended_models = all_models
- .recommended
- .iter()
- .filter(|r| {
- configured_providers.contains(&r.model.provider_id()) && matches(r)
- })
- .cloned()
- .collect();
- let mut other_models = IndexMap::default();
- for (provider_id, models) in &all_models.other {
- if configured_providers.contains(&provider_id) {
- other_models.insert(
- provider_id.clone(),
- models
- .iter()
- .filter(|m| matches(m))
- .cloned()
- .collect::<Vec<_>>(),
- );
- }
- }
- GroupedModels {
- recommended: recommended_models,
- other: other_models,
- }
- })
- .await;
+ let recommended_models = all_models
+ .recommended
+ .iter()
+ .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
+ .cloned()
+ .collect::<Vec<_>>();
+
+            // `model_infos()` already returns an owned Vec — consume it
+            // instead of borrowing and cloning every element.
+            let available_models = all_models
+                .model_infos()
+                .into_iter()
+                .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
+                .collect::<Vec<_>>();
+
+ let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
+ let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
+
+ let recommended = matcher_rec.exact_search(&query);
+ let all = matcher_all.fuzzy_search(&query);
+ let filtered_models = GroupedModels::new(all, recommended);
+
+ cx.spawn_in(window, async move |this, cx| {
this.update_in(cx, |this, window, cx| {
this.delegate.filtered_entries = filtered_models.entries();
// Preserve selection focus
@@ -607,3 +702,187 @@ impl PickerDelegate for LanguageModelPickerDelegate {
)
}
}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use futures::{future::BoxFuture, stream::BoxStream};
+    use gpui::{AsyncApp, TestAppContext, http_client};
+    use language_model::{
+        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
+        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+        LanguageModelRequest, LanguageModelToolChoice,
+    };
+    use ui::IconName;
+
+    // Minimal stand-in model: only the identity accessors (id/name/provider)
+    // matter for the matcher tests below.
+    #[derive(Clone)]
+    struct TestLanguageModel {
+        name: LanguageModelName,
+        id: LanguageModelId,
+        provider_id: LanguageModelProviderId,
+        provider_name: LanguageModelProviderName,
+    }
+
+    impl TestLanguageModel {
+        fn new(name: &str, provider: &str) -> Self {
+            Self {
+                name: LanguageModelName::from(name.to_string()),
+                id: LanguageModelId::from(name.to_string()),
+                provider_id: LanguageModelProviderId::from(provider.to_string()),
+                provider_name: LanguageModelProviderName::from(provider.to_string()),
+            }
+        }
+    }
+
+    impl LanguageModel for TestLanguageModel {
+        fn id(&self) -> LanguageModelId {
+            self.id.clone()
+        }
+
+        fn name(&self) -> LanguageModelName {
+            self.name.clone()
+        }
+
+        fn provider_id(&self) -> LanguageModelProviderId {
+            self.provider_id.clone()
+        }
+
+        fn provider_name(&self) -> LanguageModelProviderName {
+            self.provider_name.clone()
+        }
+
+        fn supports_tools(&self) -> bool {
+            false
+        }
+
+        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
+            false
+        }
+
+        fn telemetry_id(&self) -> String {
+            format!("{}/{}", self.provider_id.0, self.name.0)
+        }
+
+        fn max_token_count(&self) -> usize {
+            1000
+        }
+
+        // The completion APIs below are never exercised by the matcher tests.
+        fn count_tokens(
+            &self,
+            _: LanguageModelRequest,
+            _: &App,
+        ) -> BoxFuture<'static, http_client::Result<usize>> {
+            unimplemented!()
+        }
+
+        fn stream_completion(
+            &self,
+            _: LanguageModelRequest,
+            _: &AsyncApp,
+        ) -> BoxFuture<
+            'static,
+            http_client::Result<
+                BoxStream<
+                    'static,
+                    http_client::Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
+                >,
+            >,
+        > {
+            unimplemented!()
+        }
+    }
+
+    // Builds ModelInfo stubs from (provider, model name) pairs.
+    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
+        model_specs
+            .into_iter()
+            .map(|(provider, name)| ModelInfo {
+                model: Arc::new(TestLanguageModel::new(name, provider)),
+                icon: IconName::Ai,
+            })
+            .collect()
+    }
+
+    // Asserts `result` contains exactly the `expected` telemetry ids
+    // ("provider/name"), in order.
+    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
+        assert_eq!(
+            result.len(),
+            expected.len(),
+            "Number of models doesn't match"
+        );
+
+        for (i, expected_name) in expected.iter().enumerate() {
+            assert_eq!(
+                result[i].model.telemetry_id(),
+                *expected_name,
+                "Model at position {} doesn't match expected model",
+                i
+            );
+        }
+    }
+
+    #[gpui::test]
+    fn test_exact_match(cx: &mut TestAppContext) {
+        let models = create_models(vec![
+            ("zed", "Claude 3.7 Sonnet"),
+            ("zed", "Claude 3.7 Sonnet Thinking"),
+            ("zed", "gpt-4.1"),
+            ("zed", "gpt-4.1-nano"),
+            ("openai", "gpt-3.5-turbo"),
+            ("openai", "gpt-4.1"),
+            ("openai", "gpt-4.1-nano"),
+            ("ollama", "mistral"),
+            ("ollama", "deepseek"),
+        ]);
+        let matcher = ModelMatcher::new(models, cx.background_executor.clone());
+
+        // The order of models should be maintained, case doesn't matter
+        let results = matcher.exact_search("GPT-4.1");
+        assert_models_eq(
+            results,
+            vec![
+                "zed/gpt-4.1",
+                "zed/gpt-4.1-nano",
+                "openai/gpt-4.1",
+                "openai/gpt-4.1-nano",
+            ],
+        );
+    }
+
+    #[gpui::test]
+    fn test_fuzzy_match(cx: &mut TestAppContext) {
+        let models = create_models(vec![
+            ("zed", "Claude 3.7 Sonnet"),
+            ("zed", "Claude 3.7 Sonnet Thinking"),
+            ("zed", "gpt-4.1"),
+            ("zed", "gpt-4.1-nano"),
+            ("openai", "gpt-3.5-turbo"),
+            ("openai", "gpt-4.1"),
+            ("openai", "gpt-4.1-nano"),
+            ("ollama", "mistral"),
+            ("ollama", "deepseek"),
+        ]);
+        let matcher = ModelMatcher::new(models, cx.background_executor.clone());
+
+        // Results should preserve models order whenever possible.
+        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
+        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
+        // so it should appear first in the results.
+        let results = matcher.fuzzy_search("41");
+        assert_models_eq(
+            results,
+            vec![
+                "zed/gpt-4.1",
+                "openai/gpt-4.1",
+                "zed/gpt-4.1-nano",
+                "openai/gpt-4.1-nano",
+            ],
+        );
+
+        // Model provider should be searchable as well
+        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
+        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);
+
+        // Fuzzy search
+        let results = matcher.fuzzy_search("z4n");
+        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
+    }
+}