Simplify language model registry + only emit change events on change (#29086)

Michael Sloan created

* Now only does default fallback logic in the registry

* Only emits change events when there is actually a change

Release Notes:

- N/A

Change summary

Cargo.lock                            |   1 
crates/assistant/src/assistant.rs     |  73 ++-------
crates/eval/src/eval.rs               |  41 ++---
crates/language_model/Cargo.toml      |   1 
crates/language_model/src/registry.rs | 197 +++++++++++-----------------
5 files changed, 119 insertions(+), 194 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -7651,7 +7651,6 @@ dependencies = [
  "http_client",
  "icons",
  "image",
- "log",
  "open_ai",
  "parking_lot",
  "proto",

crates/assistant/src/assistant.rs 🔗

@@ -8,7 +8,7 @@ mod terminal_inline_assistant;
 
 use std::sync::Arc;
 
-use assistant_settings::AssistantSettings;
+use assistant_settings::{AssistantSettings, LanguageModelSelection};
 use assistant_slash_command::SlashCommandRegistry;
 use client::Client;
 use command_palette_hooks::CommandPaletteFilter;
@@ -161,71 +161,38 @@ fn init_language_model_settings(cx: &mut App) {
 
 fn update_active_language_model_from_settings(cx: &mut App) {
     let settings = AssistantSettings::get_global(cx);
-    // Default model - used as fallback
-    let active_model_provider_name =
-        LanguageModelProviderId::from(settings.default_model.provider.clone());
-    let active_model_id = LanguageModelId::from(settings.default_model.model.clone());
 
-    // Inline assistant model
-    let inline_assistant_model = settings
+    fn to_selected_model(selection: &LanguageModelSelection) -> language_model::SelectedModel {
+        language_model::SelectedModel {
+            provider: LanguageModelProviderId::from(selection.provider.clone()),
+            model: LanguageModelId::from(selection.model.clone()),
+        }
+    }
+
+    let default = to_selected_model(&settings.default_model);
+    let inline_assistant = settings
         .inline_assistant_model
         .as_ref()
-        .unwrap_or(&settings.default_model);
-    let inline_assistant_provider_name =
-        LanguageModelProviderId::from(inline_assistant_model.provider.clone());
-    let inline_assistant_model_id = LanguageModelId::from(inline_assistant_model.model.clone());
-
-    // Commit message model
-    let commit_message_model = settings
+        .map(to_selected_model);
+    let commit_message = settings
         .commit_message_model
         .as_ref()
-        .unwrap_or(&settings.default_model);
-    let commit_message_provider_name =
-        LanguageModelProviderId::from(commit_message_model.provider.clone());
-    let commit_message_model_id = LanguageModelId::from(commit_message_model.model.clone());
-
-    // Thread summary model
-    let thread_summary_model = settings
+        .map(to_selected_model);
+    let thread_summary = settings
         .thread_summary_model
         .as_ref()
-        .unwrap_or(&settings.default_model);
-    let thread_summary_provider_name =
-        LanguageModelProviderId::from(thread_summary_model.provider.clone());
-    let thread_summary_model_id = LanguageModelId::from(thread_summary_model.model.clone());
-
+        .map(to_selected_model);
     let inline_alternatives = settings
         .inline_alternatives
         .iter()
-        .map(|alternative| {
-            (
-                LanguageModelProviderId::from(alternative.provider.clone()),
-                LanguageModelId::from(alternative.model.clone()),
-            )
-        })
+        .map(to_selected_model)
         .collect::<Vec<_>>();
 
     LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
-        // Set the default model
-        registry.select_default_model(&active_model_provider_name, &active_model_id, cx);
-
-        // Set the specific models
-        registry.select_inline_assistant_model(
-            &inline_assistant_provider_name,
-            &inline_assistant_model_id,
-            cx,
-        );
-        registry.select_commit_message_model(
-            &commit_message_provider_name,
-            &commit_message_model_id,
-            cx,
-        );
-        registry.select_thread_summary_model(
-            &thread_summary_provider_name,
-            &thread_summary_model_id,
-            cx,
-        );
-
-        // Set the alternatives
+        registry.select_default_model(Some(&default), cx);
+        registry.select_inline_assistant_model(inline_assistant.as_ref(), cx);
+        registry.select_commit_message_model(commit_message.as_ref(), cx);
+        registry.select_thread_summary_model(thread_summary.as_ref(), cx);
         registry.select_inline_alternative_models(inline_alternatives, cx);
     });
 }

crates/eval/src/eval.rs 🔗

@@ -11,12 +11,10 @@ use clap::Parser;
 use extension::ExtensionHostProxy;
 use futures::{StreamExt, future};
 use gpui::http_client::{Uri, read_proxy_from_env};
-use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, Task, UpdateGlobal};
+use gpui::{App, AppContext, Application, AsyncApp, Entity, SemanticVersion, UpdateGlobal};
 use gpui_tokio::Tokio;
 use language::LanguageRegistry;
-use language_model::{
-    AuthenticateError, LanguageModel, LanguageModelProviderId, LanguageModelRegistry,
-};
+use language_model::{ConfiguredModel, LanguageModel, LanguageModelRegistry};
 use node_runtime::{NodeBinaryOptions, NodeRuntime};
 use project::Project;
 use project::project_settings::ProjectSettings;
@@ -94,18 +92,25 @@ fn main() {
             .telemetry()
             .start(system_id, installation_id, session_id, cx);
 
-        let model = find_model("claude-3-7-sonnet-latest", cx).unwrap();
+        let model_registry = LanguageModelRegistry::read_global(cx);
+        let model = find_model("claude-3-7-sonnet-latest", model_registry, cx).unwrap();
+        let model_provider_id = model.provider_id();
+        let model_provider = model_registry.provider(&model_provider_id).unwrap();
 
         LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
-            registry.set_default_model(Some(model.clone()), cx);
+            registry.set_default_model(
+                Some(ConfiguredModel {
+                    provider: model_provider.clone(),
+                    model: model.clone(),
+                }),
+                cx,
+            );
         });
 
-        let model_provider_id = model.provider_id();
-
-        let authenticate = authenticate_model_provider(model_provider_id.clone(), cx);
+        let authenticate_task = model_provider.authenticate(cx);
 
         cx.spawn(async move |cx| {
-            authenticate.await.unwrap();
+            authenticate_task.await.unwrap();
 
             std::fs::create_dir_all(REPOS_DIR)?;
             std::fs::create_dir_all(WORKTREES_DIR)?;
@@ -498,8 +503,11 @@ pub fn init(cx: &mut App) -> Arc<AgentAppState> {
     })
 }
 
-pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn LanguageModel>> {
-    let model_registry = LanguageModelRegistry::read_global(cx);
+pub fn find_model(
+    model_name: &str,
+    model_registry: &LanguageModelRegistry,
+    cx: &App,
+) -> anyhow::Result<Arc<dyn LanguageModel>> {
     let model = model_registry
         .available_models(cx)
         .find(|model| model.id().0 == model_name);
@@ -519,15 +527,6 @@ pub fn find_model(model_name: &str, cx: &App) -> anyhow::Result<Arc<dyn Language
     Ok(model)
 }
 
-pub fn authenticate_model_provider(
-    provider_id: LanguageModelProviderId,
-    cx: &mut App,
-) -> Task<std::result::Result<(), AuthenticateError>> {
-    let model_registry = LanguageModelRegistry::read_global(cx);
-    let model_provider = model_registry.provider(&provider_id).unwrap();
-    model_provider.authenticate(cx)
-}
-
 pub async fn get_current_commit_id(repo_path: &Path) -> Option<String> {
     (run_git(repo_path, &["rev-parse", "HEAD"]).await).ok()
 }

crates/language_model/Cargo.toml 🔗

@@ -27,7 +27,6 @@ gpui.workspace = true
 http_client.workspace = true
 icons.workspace = true
 image.workspace = true
-log.workspace = true
 open_ai = { workspace = true, features = ["schemars"] }
 parking_lot.workspace = true
 proto.workspace = true

crates/language_model/src/registry.rs 🔗

@@ -25,12 +25,23 @@ pub struct LanguageModelRegistry {
     inline_alternatives: Vec<Arc<dyn LanguageModel>>,
 }
 
+pub struct SelectedModel {
+    pub provider: LanguageModelProviderId,
+    pub model: LanguageModelId,
+}
+
 #[derive(Clone)]
 pub struct ConfiguredModel {
     pub provider: Arc<dyn LanguageModelProvider>,
     pub model: Arc<dyn LanguageModel>,
 }
 
+impl ConfiguredModel {
+    pub fn is_same_as(&self, other: &ConfiguredModel) -> bool {
+        self.model.id() == other.model.id() && self.provider.id() == other.provider.id()
+    }
+}
+
 pub enum Event {
     DefaultModelChanged,
     InlineAssistantModelChanged,
@@ -59,7 +70,11 @@ impl LanguageModelRegistry {
             let mut registry = Self::default();
             registry.register_provider(fake_provider.clone(), cx);
             let model = fake_provider.provided_models(cx)[0].clone();
-            registry.set_default_model(Some(model), cx);
+            let configured_model = ConfiguredModel {
+                provider: Arc::new(fake_provider.clone()),
+                model,
+            };
+            registry.set_default_model(Some(configured_model), cx);
             registry
         });
         cx.set_global(GlobalLanguageModelRegistry(registry));
@@ -119,144 +134,114 @@ impl LanguageModelRegistry {
         self.providers.get(id).cloned()
     }
 
-    pub fn select_default_model(
-        &mut self,
-        provider: &LanguageModelProviderId,
-        model_id: &LanguageModelId,
-        cx: &mut Context<Self>,
-    ) {
-        let Some(provider) = self.provider(provider) else {
-            return;
-        };
-
-        let models = provider.provided_models(cx);
-        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
-            self.set_default_model(Some(model), cx);
-        }
+    pub fn select_default_model(&mut self, model: Option<&SelectedModel>, cx: &mut Context<Self>) {
+        let configured_model = model.and_then(|model| self.select_model(model, cx));
+        self.set_default_model(configured_model, cx);
     }
 
     pub fn select_inline_assistant_model(
         &mut self,
-        provider: &LanguageModelProviderId,
-        model_id: &LanguageModelId,
+        model: Option<&SelectedModel>,
         cx: &mut Context<Self>,
     ) {
-        let Some(provider) = self.provider(provider) else {
-            return;
-        };
-
-        let models = provider.provided_models(cx);
-        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
-            self.set_inline_assistant_model(Some(model), cx);
-        }
+        let configured_model = model.and_then(|model| self.select_model(model, cx));
+        self.set_inline_assistant_model(configured_model, cx);
     }
 
     pub fn select_commit_message_model(
         &mut self,
-        provider: &LanguageModelProviderId,
-        model_id: &LanguageModelId,
+        model: Option<&SelectedModel>,
         cx: &mut Context<Self>,
     ) {
-        let Some(provider) = self.provider(provider) else {
-            return;
-        };
-
-        let models = provider.provided_models(cx);
-        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
-            self.set_commit_message_model(Some(model), cx);
-        }
+        let configured_model = model.and_then(|model| self.select_model(model, cx));
+        self.set_commit_message_model(configured_model, cx);
     }
 
     pub fn select_thread_summary_model(
         &mut self,
-        provider: &LanguageModelProviderId,
-        model_id: &LanguageModelId,
+        model: Option<&SelectedModel>,
         cx: &mut Context<Self>,
     ) {
-        let Some(provider) = self.provider(provider) else {
-            return;
-        };
-
-        let models = provider.provided_models(cx);
-        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
-            self.set_thread_summary_model(Some(model), cx);
-        }
+        let configured_model = model.and_then(|model| self.select_model(model, cx));
+        self.set_thread_summary_model(configured_model, cx);
     }
 
-    pub fn set_default_model(
+    /// Selects and sets the inline alternatives for language models based on
+    /// provider name and id.
+    pub fn select_inline_alternative_models(
         &mut self,
-        model: Option<Arc<dyn LanguageModel>>,
+        alternatives: impl IntoIterator<Item = SelectedModel>,
         cx: &mut Context<Self>,
     ) {
-        if let Some(model) = model {
-            let provider_id = model.provider_id();
-            if let Some(provider) = self.providers.get(&provider_id).cloned() {
-                self.default_model = Some(ConfiguredModel { provider, model });
-                cx.emit(Event::DefaultModelChanged);
-            } else {
-                log::warn!("Active model's provider not found in registry");
-            }
-        } else {
-            self.default_model = None;
-            cx.emit(Event::DefaultModelChanged);
+        self.inline_alternatives = alternatives
+            .into_iter()
+            .flat_map(|alternative| {
+                self.select_model(&alternative, cx)
+                    .map(|configured_model| configured_model.model)
+            })
+            .collect::<Vec<_>>();
+    }
+
+    fn select_model(
+        &mut self,
+        selected_model: &SelectedModel,
+        cx: &mut Context<Self>,
+    ) -> Option<ConfiguredModel> {
+        let provider = self.provider(&selected_model.provider)?;
+        let model = provider
+            .provided_models(cx)
+            .iter()
+            .find(|model| model.id() == selected_model.model)?
+            .clone();
+        Some(ConfiguredModel { provider, model })
+    }
+
+    pub fn set_default_model(&mut self, model: Option<ConfiguredModel>, cx: &mut Context<Self>) {
+        match (self.default_model.as_ref(), model.as_ref()) {
+            (Some(old), Some(new)) if old.is_same_as(new) => {}
+            (None, None) => {}
+            _ => cx.emit(Event::DefaultModelChanged),
         }
+        self.default_model = model;
     }
 
     pub fn set_inline_assistant_model(
         &mut self,
-        model: Option<Arc<dyn LanguageModel>>,
+        model: Option<ConfiguredModel>,
         cx: &mut Context<Self>,
     ) {
-        if let Some(model) = model {
-            let provider_id = model.provider_id();
-            if let Some(provider) = self.providers.get(&provider_id).cloned() {
-                self.inline_assistant_model = Some(ConfiguredModel { provider, model });
-                cx.emit(Event::InlineAssistantModelChanged);
-            } else {
-                log::warn!("Inline assistant model's provider not found in registry");
-            }
-        } else {
-            self.inline_assistant_model = None;
-            cx.emit(Event::InlineAssistantModelChanged);
+        match (self.inline_assistant_model.as_ref(), model.as_ref()) {
+            (Some(old), Some(new)) if old.is_same_as(new) => {}
+            (None, None) => {}
+            _ => cx.emit(Event::InlineAssistantModelChanged),
         }
+        self.inline_assistant_model = model;
     }
 
     pub fn set_commit_message_model(
         &mut self,
-        model: Option<Arc<dyn LanguageModel>>,
+        model: Option<ConfiguredModel>,
         cx: &mut Context<Self>,
     ) {
-        if let Some(model) = model {
-            let provider_id = model.provider_id();
-            if let Some(provider) = self.providers.get(&provider_id).cloned() {
-                self.commit_message_model = Some(ConfiguredModel { provider, model });
-                cx.emit(Event::CommitMessageModelChanged);
-            } else {
-                log::warn!("Commit message model's provider not found in registry");
-            }
-        } else {
-            self.commit_message_model = None;
-            cx.emit(Event::CommitMessageModelChanged);
+        match (self.commit_message_model.as_ref(), model.as_ref()) {
+            (Some(old), Some(new)) if old.is_same_as(new) => {}
+            (None, None) => {}
+            _ => cx.emit(Event::CommitMessageModelChanged),
         }
+        self.commit_message_model = model;
     }
 
     pub fn set_thread_summary_model(
         &mut self,
-        model: Option<Arc<dyn LanguageModel>>,
+        model: Option<ConfiguredModel>,
         cx: &mut Context<Self>,
     ) {
-        if let Some(model) = model {
-            let provider_id = model.provider_id();
-            if let Some(provider) = self.providers.get(&provider_id).cloned() {
-                self.thread_summary_model = Some(ConfiguredModel { provider, model });
-                cx.emit(Event::ThreadSummaryModelChanged);
-            } else {
-                log::warn!("Thread summary model's provider not found in registry");
-            }
-        } else {
-            self.thread_summary_model = None;
-            cx.emit(Event::ThreadSummaryModelChanged);
+        match (self.thread_summary_model.as_ref(), model.as_ref()) {
+            (Some(old), Some(new)) if old.is_same_as(new) => {}
+            (None, None) => {}
+            _ => cx.emit(Event::ThreadSummaryModelChanged),
         }
+        self.thread_summary_model = model;
     }
 
     pub fn default_model(&self) -> Option<ConfiguredModel> {
@@ -286,30 +271,6 @@ impl LanguageModelRegistry {
             .or_else(|| self.default_model())
     }
 
-    /// Selects and sets the inline alternatives for language models based on
-    /// provider name and id.
-    pub fn select_inline_alternative_models(
-        &mut self,
-        alternatives: impl IntoIterator<Item = (LanguageModelProviderId, LanguageModelId)>,
-        cx: &mut Context<Self>,
-    ) {
-        let mut selected_alternatives = Vec::new();
-
-        for (provider_id, model_id) in alternatives {
-            if let Some(provider) = self.providers.get(&provider_id) {
-                if let Some(model) = provider
-                    .provided_models(cx)
-                    .iter()
-                    .find(|m| m.id() == model_id)
-                {
-                    selected_alternatives.push(model.clone());
-                }
-            }
-        }
-
-        self.inline_alternatives = selected_alternatives;
-    }
-
     /// The models to use for inline assists. Returns the union of the active
     /// model and all inline alternatives. When there are multiple models, the
     /// user will be able to cycle through results.