assistant: Ensure that zed provider is listed as first option (#15496)

Bennet Bo Fenner created

Release Notes:

- N/A

Change summary

crates/assistant/src/model_selector.rs |  1 +
crates/language_model/src/registry.rs  | 20 ++++++++++++++++----
2 files changed, 17 insertions(+), 4 deletions(-)

Detailed changes

crates/assistant/src/model_selector.rs 🔗

@@ -60,6 +60,7 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
                 for (index, provider) in LanguageModelRegistry::global(cx)
                     .read(cx)
                     .providers()
+                    .into_iter()
                     .enumerate()
                 {
                     if index > 0 {

crates/language_model/src/registry.rs 🔗

@@ -132,8 +132,20 @@ impl LanguageModelRegistry {
         }
     }
 
-    pub fn providers(&self) -> impl Iterator<Item = &Arc<dyn LanguageModelProvider>> {
-        self.providers.values()
+    pub fn providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
+        let zed_provider_id = LanguageModelProviderId(crate::provider::cloud::PROVIDER_ID.into());
+        let mut providers = Vec::with_capacity(self.providers.len());
+        if let Some(provider) = self.providers.get(&zed_provider_id) {
+            providers.push(provider.clone());
+        }
+        providers.extend(self.providers.values().filter_map(|p| {
+            if p.id() != zed_provider_id {
+                Some(p.clone())
+            } else {
+                None
+            }
+        }));
+        providers
     }
 
     pub fn available_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
@@ -222,7 +234,7 @@ mod tests {
             registry.register_provider(FakeLanguageModelProvider::default(), cx);
         });
 
-        let providers = registry.read(cx).providers().collect::<Vec<_>>();
+        let providers = registry.read(cx).providers();
         assert_eq!(providers.len(), 1);
         assert_eq!(providers[0].id(), crate::provider::fake::provider_id());
 
@@ -230,7 +242,7 @@ mod tests {
             registry.unregister_provider(&crate::provider::fake::provider_id(), cx);
         });
 
-        let providers = registry.read(cx).providers().collect::<Vec<_>>();
+        let providers = registry.read(cx).providers();
         assert!(providers.is_empty());
     }
 }