agent: Fuzzy search in model selector (#30281)

Oleksiy Syvokon created

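This change enables fuzzy search over model providers and names. Each
model is matched against the string "<provider name>/<model name>", so,
for example, the query "z41" will match "zed/gpt-4.1". Recommended models
are still filtered with a case-insensitive substring match on the model
name.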

Release Notes:

- Agent: Improved model selection with fuzzy search support

Change summary

Cargo.lock                                                    |   3 
crates/language_model_selector/Cargo.toml                     |  12 
crates/language_model_selector/src/language_model_selector.rs | 363 ++++
3 files changed, 336 insertions(+), 42 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -7813,9 +7813,12 @@ version = "0.1.0"
 dependencies = [
  "collections",
  "feature_flags",
+ "futures 0.3.31",
+ "fuzzy",
  "gpui",
  "language_model",
  "log",
+ "ordered-float 2.10.1",
  "picker",
  "proto",
  "ui",

crates/language_model_selector/Cargo.toml 🔗

@@ -11,14 +11,26 @@ workspace = true
 [lib]
 path = "src/language_model_selector.rs"
 
+[features]
+test-support = [
+    "gpui/test-support",
+]
+
 [dependencies]
 collections.workspace = true
 feature_flags.workspace = true
+futures.workspace = true
+fuzzy.workspace = true
 gpui.workspace = true
 language_model.workspace = true
 log.workspace = true
+ordered-float.workspace = true
 picker.workspace = true
 proto.workspace = true
 ui.workspace = true
 workspace-hack.workspace = true
 zed_actions.workspace = true
+
+[dev-dependencies]
+gpui = { workspace = true, "features" = ["test-support"] }
+language_model = { workspace = true, "features" = ["test-support"] }

crates/language_model_selector/src/language_model_selector.rs 🔗

@@ -1,15 +1,18 @@
-use std::sync::Arc;
+use std::{cmp::Reverse, sync::Arc};
 
 use collections::{HashSet, IndexMap};
 use feature_flags::ZedProFeatureFlag;
+use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
 use gpui::{
-    Action, AnyElement, AnyView, App, Corner, DismissEvent, Entity, EventEmitter, FocusHandle,
-    Focusable, Subscription, Task, WeakEntity, action_with_deprecated_aliases,
+    Action, AnyElement, AnyView, App, BackgroundExecutor, Corner, DismissEvent, Entity,
+    EventEmitter, FocusHandle, Focusable, Subscription, Task, WeakEntity,
+    action_with_deprecated_aliases,
 };
 use language_model::{
     AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
     LanguageModelRegistry,
 };
+use ordered_float::OrderedFloat;
 use picker::{Picker, PickerDelegate};
 use proto::Plan;
 use ui::{ListItem, ListItemSpacing, PopoverMenu, PopoverMenuHandle, PopoverTrigger, prelude::*};
@@ -322,6 +325,23 @@ struct GroupedModels {
 }
 
 impl GroupedModels {
+    pub fn new(other: Vec<ModelInfo>, recommended: Vec<ModelInfo>) -> Self {
+        let mut other_by_provider: IndexMap<_, Vec<ModelInfo>> = IndexMap::default();
+        for model in other {
+            let provider = model.model.provider_id();
+            if let Some(models) = other_by_provider.get_mut(&provider) {
+                models.push(model);
+            } else {
+                other_by_provider.insert(provider, vec![model]);
+            }
+        }
+
+        Self {
+            recommended,
+            other: other_by_provider,
+        }
+    }
+
     fn entries(&self) -> Vec<LanguageModelPickerEntry> {
         let mut entries = Vec::new();
 
@@ -349,6 +369,20 @@ impl GroupedModels {
         }
         entries
     }
+
+    fn model_infos(&self) -> Vec<ModelInfo> {
+        let other = self
+            .other
+            .values()
+            .flat_map(|model| model.iter())
+            .cloned()
+            .collect::<Vec<_>>();
+        self.recommended
+            .iter()
+            .chain(&other)
+            .cloned()
+            .collect::<Vec<_>>()
+    }
 }
 
 enum LanguageModelPickerEntry {
@@ -356,6 +390,78 @@ enum LanguageModelPickerEntry {
     Separator(SharedString),
 }
 
+struct ModelMatcher {
+    models: Vec<ModelInfo>,
+    bg_executor: BackgroundExecutor,
+    candidates: Vec<StringMatchCandidate>,
+}
+
+impl ModelMatcher {
+    fn new(models: Vec<ModelInfo>, bg_executor: BackgroundExecutor) -> ModelMatcher {
+        let candidates = Self::make_match_candidates(&models);
+        Self {
+            models,
+            bg_executor,
+            candidates,
+        }
+    }
+
+    pub fn fuzzy_search(&self, query: &str) -> Vec<ModelInfo> {
+        let mut matches = self.bg_executor.block(match_strings(
+            &self.candidates,
+            &query,
+            false,
+            100,
+            &Default::default(),
+            self.bg_executor.clone(),
+        ));
+
+        let sorting_key = |mat: &StringMatch| {
+            let candidate = &self.candidates[mat.candidate_id];
+            (Reverse(OrderedFloat(mat.score)), candidate.id)
+        };
+        matches.sort_unstable_by_key(sorting_key);
+
+        let matched_models: Vec<_> = matches
+            .into_iter()
+            .map(|mat| self.models[mat.candidate_id].clone())
+            .collect();
+
+        matched_models
+    }
+
+    pub fn exact_search(&self, query: &str) -> Vec<ModelInfo> {
+        self.models
+            .iter()
+            .filter(|m| {
+                m.model
+                    .name()
+                    .0
+                    .to_lowercase()
+                    .contains(&query.to_lowercase())
+            })
+            .cloned()
+            .collect::<Vec<_>>()
+    }
+
+    fn make_match_candidates(model_infos: &Vec<ModelInfo>) -> Vec<StringMatchCandidate> {
+        model_infos
+            .iter()
+            .enumerate()
+            .map(|(index, model)| {
+                StringMatchCandidate::new(
+                    index,
+                    &format!(
+                        "{}/{}",
+                        &model.model.provider_name().0,
+                        &model.model.name().0
+                    ),
+                )
+            })
+            .collect::<Vec<_>>()
+    }
+}
+
 impl PickerDelegate for LanguageModelPickerDelegate {
     type ListItem = AnyElement;
 
@@ -396,56 +502,45 @@ impl PickerDelegate for LanguageModelPickerDelegate {
     ) -> Task<()> {
         let all_models = self.all_models.clone();
         let current_index = self.selected_index;
+        let bg_executor = cx.background_executor();
 
         let language_model_registry = LanguageModelRegistry::global(cx);
 
         let configured_providers = language_model_registry
             .read(cx)
             .providers()
-            .iter()
+            .into_iter()
             .filter(|provider| provider.is_authenticated(cx))
+            .collect::<Vec<_>>();
+
+        let configured_provider_ids = configured_providers
+            .iter()
             .map(|provider| provider.id())
             .collect::<Vec<_>>();
 
-        cx.spawn_in(window, async move |this, cx| {
-            let filtered_models = cx
-                .background_spawn(async move {
-                    let matches = |info: &ModelInfo| {
-                        info.model
-                            .name()
-                            .0
-                            .to_lowercase()
-                            .contains(&query.to_lowercase())
-                    };
-
-                    let recommended_models = all_models
-                        .recommended
-                        .iter()
-                        .filter(|r| {
-                            configured_providers.contains(&r.model.provider_id()) && matches(r)
-                        })
-                        .cloned()
-                        .collect();
-                    let mut other_models = IndexMap::default();
-                    for (provider_id, models) in &all_models.other {
-                        if configured_providers.contains(&provider_id) {
-                            other_models.insert(
-                                provider_id.clone(),
-                                models
-                                    .iter()
-                                    .filter(|m| matches(m))
-                                    .cloned()
-                                    .collect::<Vec<_>>(),
-                            );
-                        }
-                    }
-                    GroupedModels {
-                        recommended: recommended_models,
-                        other: other_models,
-                    }
-                })
-                .await;
+        let recommended_models = all_models
+            .recommended
+            .iter()
+            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
+            .cloned()
+            .collect::<Vec<_>>();
+
+        let available_models = all_models
+            .model_infos()
+            .iter()
+            .filter(|m| configured_provider_ids.contains(&m.model.provider_id()))
+            .cloned()
+            .collect::<Vec<_>>();
+
+        let matcher_rec = ModelMatcher::new(recommended_models, bg_executor.clone());
+        let matcher_all = ModelMatcher::new(available_models, bg_executor.clone());
+
+        let recommended = matcher_rec.exact_search(&query);
+        let all = matcher_all.fuzzy_search(&query);
 
+        let filtered_models = GroupedModels::new(all, recommended);
+
+        cx.spawn_in(window, async move |this, cx| {
             this.update_in(cx, |this, window, cx| {
                 this.delegate.filtered_entries = filtered_models.entries();
                 // Preserve selection focus
@@ -607,3 +702,187 @@ impl PickerDelegate for LanguageModelPickerDelegate {
         )
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use futures::{future::BoxFuture, stream::BoxStream};
+    use gpui::{AsyncApp, TestAppContext, http_client};
+    use language_model::{
+        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
+        LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+        LanguageModelRequest, LanguageModelToolChoice,
+    };
+    use ui::IconName;
+
+    #[derive(Clone)]
+    struct TestLanguageModel {
+        name: LanguageModelName,
+        id: LanguageModelId,
+        provider_id: LanguageModelProviderId,
+        provider_name: LanguageModelProviderName,
+    }
+
+    impl TestLanguageModel {
+        fn new(name: &str, provider: &str) -> Self {
+            Self {
+                name: LanguageModelName::from(name.to_string()),
+                id: LanguageModelId::from(name.to_string()),
+                provider_id: LanguageModelProviderId::from(provider.to_string()),
+                provider_name: LanguageModelProviderName::from(provider.to_string()),
+            }
+        }
+    }
+
+    impl LanguageModel for TestLanguageModel {
+        fn id(&self) -> LanguageModelId {
+            self.id.clone()
+        }
+
+        fn name(&self) -> LanguageModelName {
+            self.name.clone()
+        }
+
+        fn provider_id(&self) -> LanguageModelProviderId {
+            self.provider_id.clone()
+        }
+
+        fn provider_name(&self) -> LanguageModelProviderName {
+            self.provider_name.clone()
+        }
+
+        fn supports_tools(&self) -> bool {
+            false
+        }
+
+        fn supports_tool_choice(&self, _choice: LanguageModelToolChoice) -> bool {
+            false
+        }
+
+        fn telemetry_id(&self) -> String {
+            format!("{}/{}", self.provider_id.0, self.name.0)
+        }
+
+        fn max_token_count(&self) -> usize {
+            1000
+        }
+
+        fn count_tokens(
+            &self,
+            _: LanguageModelRequest,
+            _: &App,
+        ) -> BoxFuture<'static, http_client::Result<usize>> {
+            unimplemented!()
+        }
+
+        fn stream_completion(
+            &self,
+            _: LanguageModelRequest,
+            _: &AsyncApp,
+        ) -> BoxFuture<
+            'static,
+            http_client::Result<
+                BoxStream<
+                    'static,
+                    http_client::Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
+                >,
+            >,
+        > {
+            unimplemented!()
+        }
+    }
+
+    fn create_models(model_specs: Vec<(&str, &str)>) -> Vec<ModelInfo> {
+        model_specs
+            .into_iter()
+            .map(|(provider, name)| ModelInfo {
+                model: Arc::new(TestLanguageModel::new(name, provider)),
+                icon: IconName::Ai,
+            })
+            .collect()
+    }
+
+    fn assert_models_eq(result: Vec<ModelInfo>, expected: Vec<&str>) {
+        assert_eq!(
+            result.len(),
+            expected.len(),
+            "Number of models doesn't match"
+        );
+
+        for (i, expected_name) in expected.iter().enumerate() {
+            assert_eq!(
+                result[i].model.telemetry_id(),
+                *expected_name,
+                "Model at position {} doesn't match expected model",
+                i
+            );
+        }
+    }
+
+    #[gpui::test]
+    fn test_exact_match(cx: &mut TestAppContext) {
+        let models = create_models(vec![
+            ("zed", "Claude 3.7 Sonnet"),
+            ("zed", "Claude 3.7 Sonnet Thinking"),
+            ("zed", "gpt-4.1"),
+            ("zed", "gpt-4.1-nano"),
+            ("openai", "gpt-3.5-turbo"),
+            ("openai", "gpt-4.1"),
+            ("openai", "gpt-4.1-nano"),
+            ("ollama", "mistral"),
+            ("ollama", "deepseek"),
+        ]);
+        let matcher = ModelMatcher::new(models, cx.background_executor.clone());
+
+        // Matching is case-insensitive, and results keep the order of the models list.
+        let results = matcher.exact_search("GPT-4.1");
+        assert_models_eq(
+            results,
+            vec![
+                "zed/gpt-4.1",
+                "zed/gpt-4.1-nano",
+                "openai/gpt-4.1",
+                "openai/gpt-4.1-nano",
+            ],
+        );
+    }
+
+    #[gpui::test]
+    fn test_fuzzy_match(cx: &mut TestAppContext) {
+        let models = create_models(vec![
+            ("zed", "Claude 3.7 Sonnet"),
+            ("zed", "Claude 3.7 Sonnet Thinking"),
+            ("zed", "gpt-4.1"),
+            ("zed", "gpt-4.1-nano"),
+            ("openai", "gpt-3.5-turbo"),
+            ("openai", "gpt-4.1"),
+            ("openai", "gpt-4.1-nano"),
+            ("ollama", "mistral"),
+            ("ollama", "deepseek"),
+        ]);
+        let matcher = ModelMatcher::new(models, cx.background_executor.clone());
+
+        // Results should preserve models order whenever possible.
+        // In the case below, `zed/gpt-4.1` and `openai/gpt-4.1` have identical
+        // similarity scores, but `zed/gpt-4.1` was higher in the models list,
+        // so it should appear first in the results.
+        let results = matcher.fuzzy_search("41");
+        assert_models_eq(
+            results,
+            vec![
+                "zed/gpt-4.1",
+                "openai/gpt-4.1",
+                "zed/gpt-4.1-nano",
+                "openai/gpt-4.1-nano",
+            ],
+        );
+
+        // Model provider should be searchable as well
+        let results = matcher.fuzzy_search("ol"); // meaning "ollama"
+        assert_models_eq(results, vec!["ollama/mistral", "ollama/deepseek"]);
+
+        // Query characters may be spread across the provider and the model name
+        let results = matcher.fuzzy_search("z4n");
+        assert_models_eq(results, vec!["zed/gpt-4.1-nano"]);
+    }
+}