Allow customization of the model used for tool calling (#15479)

Antonio Scandurra and Nathan created

We also eliminate the `completion` crate and moved its logic into
`LanguageModelRegistry`.

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>

Change summary

Cargo.lock                                         |  28 -
Cargo.toml                                         |   2 
crates/anthropic/src/anthropic.rs                  |  19 
crates/assistant/Cargo.toml                        |   2 
crates/assistant/src/assistant.rs                  |  23 
crates/assistant/src/assistant_panel.rs            |  34 +
crates/assistant/src/assistant_settings.rs         |  24 
crates/assistant/src/context.rs                    |  61 +-
crates/assistant/src/inline_assistant.rs           |  34 
crates/assistant/src/model_selector.rs             |  16 
crates/assistant/src/prompt_library.rs             |  19 
crates/assistant/src/terminal_inline_assistant.rs  |  36 
crates/collab/Cargo.toml                           |   1 
crates/collab/src/tests/test_server.rs             |   1 
crates/completion/Cargo.toml                       |  45 --
crates/completion/LICENSE-GPL                      |   1 
crates/completion/src/completion.rs                | 312 ----------------
crates/copilot/src/copilot_chat.rs                 |   6 
crates/language_model/Cargo.toml                   |   2 
crates/language_model/src/language_model.rs        |  28 +
crates/language_model/src/provider/anthropic.rs    |  55 +-
crates/language_model/src/provider/cloud.rs        |  93 ++--
crates/language_model/src/provider/copilot_chat.rs |  59 +-
crates/language_model/src/provider/fake.rs         |   6 
crates/language_model/src/provider/google.rs       |  16 
crates/language_model/src/provider/ollama.rs       |  68 +--
crates/language_model/src/provider/open_ai.rs      |  17 
crates/language_model/src/rate_limiter.rs          |  70 +++
crates/language_model/src/registry.rs              |  75 +++
crates/language_model/src/settings.rs              |  12 
crates/semantic_index/Cargo.toml                   |   1 
crates/semantic_index/src/semantic_index.rs        |   3 
32 files changed, 478 insertions(+), 691 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -406,7 +406,6 @@ dependencies = [
  "clock",
  "collections",
  "command_palette_hooks",
- "completion",
  "ctor",
  "editor",
  "env_logger",
@@ -2470,7 +2469,6 @@ dependencies = [
  "clock",
  "collab_ui",
  "collections",
- "completion",
  "ctor",
  "dashmap 6.0.1",
  "dev_server_projects",
@@ -2655,30 +2653,6 @@ dependencies = [
  "gpui",
 ]
 
-[[package]]
-name = "completion"
-version = "0.1.0"
-dependencies = [
- "anyhow",
- "ctor",
- "editor",
- "env_logger",
- "futures 0.3.28",
- "gpui",
- "language",
- "language_model",
- "project",
- "rand 0.8.5",
- "schemars",
- "serde",
- "serde_json",
- "settings",
- "smol",
- "text",
- "ui",
- "unindent",
-]
-
 [[package]]
 name = "concurrent-queue"
 version = "2.2.0"
@@ -6048,6 +6022,7 @@ dependencies = [
  "serde",
  "serde_json",
  "settings",
+ "smol",
  "strum",
  "text",
  "theme",
@@ -9506,7 +9481,6 @@ dependencies = [
  "client",
  "clock",
  "collections",
- "completion",
  "env_logger",
  "fs",
  "futures 0.3.28",

Cargo.toml 🔗

@@ -19,7 +19,6 @@ members = [
     "crates/collections",
     "crates/command_palette",
     "crates/command_palette_hooks",
-    "crates/completion",
     "crates/copilot",
     "crates/db",
     "crates/dev_server_projects",
@@ -190,7 +189,6 @@ collab_ui = { path = "crates/collab_ui" }
 collections = { path = "crates/collections" }
 command_palette = { path = "crates/command_palette" }
 command_palette_hooks = { path = "crates/command_palette_hooks" }
-completion = { path = "crates/completion" }
 copilot = { path = "crates/copilot" }
 db = { path = "crates/db" }
 dev_server_projects = { path = "crates/dev_server_projects" }

crates/anthropic/src/anthropic.rs 🔗

@@ -21,7 +21,12 @@ pub enum Model {
     #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
     Claude3Haiku,
     #[serde(rename = "custom")]
-    Custom { name: String, max_tokens: usize },
+    Custom {
+        name: String,
+        max_tokens: usize,
+        /// Override this model with a different Anthropic model for tool calls.
+        tool_override: Option<String>,
+    },
 }
 
 impl Model {
@@ -68,6 +73,18 @@ impl Model {
             Self::Custom { max_tokens, .. } => *max_tokens,
         }
     }
+
+    pub fn tool_model_id(&self) -> &str {
+        if let Self::Custom {
+            tool_override: Some(tool_override),
+            ..
+        } = self
+        {
+            tool_override
+        } else {
+            self.id()
+        }
+    }
 }
 
 pub async fn complete(

crates/assistant/Cargo.toml 🔗

@@ -32,7 +32,6 @@ client.workspace = true
 clock.workspace = true
 collections.workspace = true
 command_palette_hooks.workspace = true
-completion.workspace = true
 editor.workspace = true
 fs.workspace = true
 futures.workspace = true
@@ -77,7 +76,6 @@ workspace.workspace = true
 picker.workspace = true
 
 [dev-dependencies]
-completion = { workspace = true, features = ["test-support"] }
 ctor.workspace = true
 editor = { workspace = true, features = ["test-support"] }
 env_logger.workspace = true

crates/assistant/src/assistant.rs 🔗

@@ -15,7 +15,6 @@ use assistant_settings::AssistantSettings;
 use assistant_slash_command::SlashCommandRegistry;
 use client::{proto, Client};
 use command_palette_hooks::CommandPaletteFilter;
-use completion::LanguageModelCompletionProvider;
 pub use context::*;
 pub use context_store::*;
 use fs::Fs;
@@ -192,7 +191,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
 
     context_store::init(&client);
     prompt_library::init(cx);
-    init_completion_provider(cx);
+    init_language_model_settings(cx);
     assistant_slash_command::init(cx);
     register_slash_commands(cx);
     assistant_panel::init(cx);
@@ -217,8 +216,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
     .detach();
 }
 
-fn init_completion_provider(cx: &mut AppContext) {
-    completion::init(cx);
+fn init_language_model_settings(cx: &mut AppContext) {
     update_active_language_model_from_settings(cx);
 
     cx.observe_global::<SettingsStore>(update_active_language_model_from_settings)
@@ -233,20 +231,9 @@ fn update_active_language_model_from_settings(cx: &mut AppContext) {
     let settings = AssistantSettings::get_global(cx);
     let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
     let model_id = LanguageModelId::from(settings.default_model.model.clone());
-
-    let Some(provider) = LanguageModelRegistry::global(cx)
-        .read(cx)
-        .provider(&provider_name)
-    else {
-        return;
-    };
-
-    let models = provider.provided_models(cx);
-    if let Some(model) = models.iter().find(|model| model.id() == model_id).cloned() {
-        LanguageModelCompletionProvider::global(cx).update(cx, |completion_provider, cx| {
-            completion_provider.set_active_model(model, cx);
-        });
-    }
+    LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
+        registry.select_active_model(&provider_name, &model_id, cx);
+    });
 }
 
 fn register_slash_commands(cx: &mut AppContext) {

crates/assistant/src/assistant_panel.rs 🔗

@@ -19,7 +19,6 @@ use anyhow::{anyhow, Result};
 use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
 use client::proto;
 use collections::{BTreeSet, HashMap, HashSet};
-use completion::LanguageModelCompletionProvider;
 use editor::{
     actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
     display_map::{
@@ -43,7 +42,7 @@ use language::{
     language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point,
     ToOffset,
 };
-use language_model::{LanguageModelProviderId, Role};
+use language_model::{LanguageModelProviderId, LanguageModelRegistry, Role};
 use multi_buffer::MultiBufferRow;
 use picker::{Picker, PickerDelegate};
 use project::{Project, ProjectLspAdapterDelegate};
@@ -392,9 +391,9 @@ impl AssistantPanel {
             cx.subscribe(&context_editor_toolbar, Self::handle_toolbar_event),
             cx.subscribe(&model_summary_editor, Self::handle_summary_editor_event),
             cx.subscribe(&context_store, Self::handle_context_store_event),
-            cx.observe(
-                &LanguageModelCompletionProvider::global(cx),
-                |this, _, cx| {
+            cx.subscribe(
+                &LanguageModelRegistry::global(cx),
+                |this, _, _: &language_model::ActiveModelChanged, cx| {
                     this.completion_provider_changed(cx);
                 },
             ),
@@ -560,7 +559,7 @@ impl AssistantPanel {
             })
         }
 
-        let Some(new_provider_id) = LanguageModelCompletionProvider::read_global(cx)
+        let Some(new_provider_id) = LanguageModelRegistry::read_global(cx)
             .active_provider()
             .map(|p| p.id())
         else {
@@ -599,7 +598,7 @@ impl AssistantPanel {
     }
 
     fn authentication_prompt(cx: &mut WindowContext) -> Option<AnyView> {
-        if let Some(provider) = LanguageModelCompletionProvider::read_global(cx).active_provider() {
+        if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() {
             if !provider.is_authenticated(cx) {
                 return Some(provider.authentication_prompt(cx));
             }
@@ -904,9 +903,9 @@ impl AssistantPanel {
     }
 
     fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
-        LanguageModelCompletionProvider::read_global(cx)
-            .reset_credentials(cx)
-            .detach_and_log_err(cx);
+        if let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() {
+            provider.reset_credentials(cx).detach_and_log_err(cx);
+        }
     }
 
     fn toggle_model_selector(&mut self, _: &ToggleModelSelector, cx: &mut ViewContext<Self>) {
@@ -1041,11 +1040,18 @@ impl AssistantPanel {
     }
 
     fn is_authenticated(&mut self, cx: &mut ViewContext<Self>) -> bool {
-        LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx)
+        LanguageModelRegistry::read_global(cx)
+            .active_provider()
+            .map_or(false, |provider| provider.is_authenticated(cx))
     }
 
     fn authenticate(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
-        LanguageModelCompletionProvider::read_global(cx).authenticate(cx)
+        LanguageModelRegistry::read_global(cx)
+            .active_provider()
+            .map_or(
+                Task::ready(Err(anyhow!("no active language model provider"))),
+                |provider| provider.authenticate(cx),
+            )
     }
 
     fn render_signed_in(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
@@ -2707,7 +2713,7 @@ impl ContextEditorToolbarItem {
     }
 
     fn render_remaining_tokens(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
-        let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
+        let model = LanguageModelRegistry::read_global(cx).active_model()?;
         let context = &self
             .active_context_editor
             .as_ref()?
@@ -2779,7 +2785,7 @@ impl Render for ContextEditorToolbarItem {
                                         .whitespace_nowrap()
                                         .child(
                                             Label::new(
-                                                LanguageModelCompletionProvider::read_global(cx)
+                                                LanguageModelRegistry::read_global(cx)
                                                     .active_model()
                                                     .map(|model| model.name().0)
                                                     .unwrap_or_else(|| "No model selected".into()),

crates/assistant/src/assistant_settings.rs 🔗

@@ -52,7 +52,7 @@ pub struct AssistantSettings {
     pub dock: AssistantDockPosition,
     pub default_width: Pixels,
     pub default_height: Pixels,
-    pub default_model: AssistantDefaultModel,
+    pub default_model: LanguageModelSelection,
     pub using_outdated_settings_version: bool,
 }
 
@@ -198,25 +198,25 @@ impl AssistantSettingsContent {
                         .clone()
                         .and_then(|provider| match provider {
                             AssistantProviderContentV1::ZedDotDev { default_model } => {
-                                default_model.map(|model| AssistantDefaultModel {
+                                default_model.map(|model| LanguageModelSelection {
                                     provider: "zed.dev".to_string(),
                                     model: model.id().to_string(),
                                 })
                             }
                             AssistantProviderContentV1::OpenAi { default_model, .. } => {
-                                default_model.map(|model| AssistantDefaultModel {
+                                default_model.map(|model| LanguageModelSelection {
                                     provider: "openai".to_string(),
                                     model: model.id().to_string(),
                                 })
                             }
                             AssistantProviderContentV1::Anthropic { default_model, .. } => {
-                                default_model.map(|model| AssistantDefaultModel {
+                                default_model.map(|model| LanguageModelSelection {
                                     provider: "anthropic".to_string(),
                                     model: model.id().to_string(),
                                 })
                             }
                             AssistantProviderContentV1::Ollama { default_model, .. } => {
-                                default_model.map(|model| AssistantDefaultModel {
+                                default_model.map(|model| LanguageModelSelection {
                                     provider: "ollama".to_string(),
                                     model: model.id().to_string(),
                                 })
@@ -231,7 +231,7 @@ impl AssistantSettingsContent {
                 dock: settings.dock,
                 default_width: settings.default_width,
                 default_height: settings.default_height,
-                default_model: Some(AssistantDefaultModel {
+                default_model: Some(LanguageModelSelection {
                     provider: "openai".to_string(),
                     model: settings
                         .default_open_ai_model
@@ -325,7 +325,7 @@ impl AssistantSettingsContent {
                     _ => {}
                 },
                 VersionedAssistantSettingsContent::V2(settings) => {
-                    settings.default_model = Some(AssistantDefaultModel { provider, model });
+                    settings.default_model = Some(LanguageModelSelection { provider, model });
                 }
             },
             AssistantSettingsContent::Legacy(settings) => {
@@ -382,11 +382,11 @@ pub struct AssistantSettingsContentV2 {
     /// Default: 320
     default_height: Option<f32>,
     /// The default model to use when creating new contexts.
-    default_model: Option<AssistantDefaultModel>,
+    default_model: Option<LanguageModelSelection>,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
-pub struct AssistantDefaultModel {
+pub struct LanguageModelSelection {
     #[schemars(schema_with = "providers_schema")]
     pub provider: String,
     pub model: String,
@@ -407,7 +407,7 @@ fn providers_schema(_: &mut schemars::gen::SchemaGenerator) -> schemars::schema:
     .into()
 }
 
-impl Default for AssistantDefaultModel {
+impl Default for LanguageModelSelection {
     fn default() -> Self {
         Self {
             provider: "openai".to_string(),
@@ -542,7 +542,7 @@ mod tests {
             assert!(!AssistantSettings::get_global(cx).using_outdated_settings_version);
             assert_eq!(
                 AssistantSettings::get_global(cx).default_model,
-                AssistantDefaultModel {
+                LanguageModelSelection {
                     provider: "openai".into(),
                     model: "gpt-4o".into(),
                 }
@@ -555,7 +555,7 @@ mod tests {
                 |settings, _| {
                     *settings = AssistantSettingsContent::Versioned(
                         VersionedAssistantSettingsContent::V2(AssistantSettingsContentV2 {
-                            default_model: Some(AssistantDefaultModel {
+                            default_model: Some(LanguageModelSelection {
                                 provider: "test-provider".into(),
                                 model: "gpt-99".into(),
                             }),

crates/assistant/src/context.rs 🔗

@@ -1,6 +1,6 @@
 use crate::{
-    prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
-    LanguageModelCompletionProvider, MessageId, MessageStatus,
+    prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion, MessageId,
+    MessageStatus,
 };
 use anyhow::{anyhow, Context as _, Result};
 use assistant_slash_command::{
@@ -18,7 +18,10 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
 use language::{
     AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
 };
-use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
+use language_model::{
+    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool,
+    Role,
+};
 use open_ai::Model as OpenAiModel;
 use paths::contexts_dir;
 use project::Project;
@@ -1180,17 +1183,16 @@ impl Context {
 
     pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
         let request = self.to_completion_request(cx);
+        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+            return;
+        };
         self.pending_token_count = cx.spawn(|this, mut cx| {
             async move {
                 cx.background_executor()
                     .timer(Duration::from_millis(200))
                     .await;
 
-                let token_count = cx
-                    .update(|cx| {
-                        LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
-                    })?
-                    .await?;
+                let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
                 this.update(&mut cx, |this, cx| {
                     this.token_count = Some(token_count);
                     cx.notify()
@@ -1368,6 +1370,10 @@ impl Context {
             }
         }
 
+        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+            return Task::ready(Err(anyhow!("no active model")).log_err());
+        };
+
         let mut request = self.to_completion_request(cx);
         let edit_step_range = edit_step.source_range.clone();
         let step_text = self
@@ -1388,12 +1394,7 @@ impl Context {
                     content: prompt,
                 });
 
-                let tool_use = cx
-                    .update(|cx| {
-                        LanguageModelCompletionProvider::read_global(cx)
-                            .use_tool::<EditTool>(request, cx)
-                    })?
-                    .await?;
+                let tool_use = model.use_tool::<EditTool>(request, &cx).await?;
 
                 this.update(&mut cx, |this, cx| {
                     let step_index = this
@@ -1568,6 +1569,8 @@ impl Context {
     }
 
     pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
+        let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
+        let model = LanguageModelRegistry::read_global(cx).active_model()?;
         let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
             message
                 .start
@@ -1575,14 +1578,12 @@ impl Context {
                 .then_some(message.id)
         })?;
 
-        if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
+        if !provider.is_authenticated(cx) {
             log::info!("completion provider has no credentials");
             return None;
         }
 
         let request = self.to_completion_request(cx);
-        let stream =
-            LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
         let assistant_message = self
             .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
             .unwrap();
@@ -1594,6 +1595,7 @@ impl Context {
 
         let task = cx.spawn({
             |this, mut cx| async move {
+                let stream = model.stream_completion(request, &cx);
                 let assistant_message_id = assistant_message.id;
                 let mut response_latency = None;
                 let stream_completion = async {
@@ -1662,14 +1664,10 @@ impl Context {
                     });
 
                     if let Some(telemetry) = this.telemetry.as_ref() {
-                        let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
-                            .active_model()
-                            .map(|m| m.telemetry_id())
-                            .unwrap_or_default();
                         telemetry.report_assistant_event(
                             Some(this.id.0.clone()),
                             AssistantKind::Panel,
-                            model_telemetry_id,
+                            model.telemetry_id(),
                             response_latency,
                             error_message,
                         );
@@ -1935,8 +1933,15 @@ impl Context {
     }
 
     pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
+        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
+            return;
+        };
+        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+            return;
+        };
+
         if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
-            if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
+            if !provider.is_authenticated(cx) {
                 return;
             }
 
@@ -1953,10 +1958,9 @@ impl Context {
                 temperature: 1.0,
             };
 
-            let stream =
-                LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
             self.pending_summary = cx.spawn(|this, mut cx| {
                 async move {
+                    let stream = model.stream_completion(request, &cx);
                     let mut messages = stream.await?;
 
                     let mut replaced = !replace_old;
@@ -2490,7 +2494,6 @@ mod tests {
     fn test_inserting_and_removing_messages(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
         language_model::LanguageModelRegistry::test(cx);
-        completion::LanguageModelCompletionProvider::test(cx);
         cx.set_global(settings_store);
         assistant_panel::init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -2623,7 +2626,6 @@ mod tests {
         let settings_store = SettingsStore::test(cx);
         cx.set_global(settings_store);
         language_model::LanguageModelRegistry::test(cx);
-        completion::LanguageModelCompletionProvider::test(cx);
         assistant_panel::init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 
@@ -2717,7 +2719,6 @@ mod tests {
     fn test_messages_for_offsets(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
         language_model::LanguageModelRegistry::test(cx);
-        completion::LanguageModelCompletionProvider::test(cx);
         cx.set_global(settings_store);
         assistant_panel::init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -2803,7 +2804,6 @@ mod tests {
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
         cx.update(language_model::LanguageModelRegistry::test);
-        cx.update(completion::LanguageModelCompletionProvider::test);
         cx.update(Project::init_settings);
         cx.update(assistant_panel::init);
         let fs = FakeFs::new(cx.background_executor.clone());
@@ -2930,7 +2930,6 @@ mod tests {
         cx.set_global(settings_store);
 
         let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
-        cx.update(completion::LanguageModelCompletionProvider::test);
 
         let fake_model = fake_provider.test_model();
         cx.update(assistant_panel::init);
@@ -3032,7 +3031,6 @@ mod tests {
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
         cx.update(language_model::LanguageModelRegistry::test);
-        cx.update(completion::LanguageModelCompletionProvider::test);
         cx.update(assistant_panel::init);
         let registry = Arc::new(LanguageRegistry::test(cx.executor()));
         let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
@@ -3109,7 +3107,6 @@ mod tests {
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
         cx.update(language_model::LanguageModelRegistry::test);
-        cx.update(completion::LanguageModelCompletionProvider::test);
 
         cx.update(assistant_panel::init);
         let slash_commands = cx.update(SlashCommandRegistry::default_global);

crates/assistant/src/inline_assistant.rs 🔗

@@ -1,6 +1,6 @@
 use crate::{
     humanize_token_count, prompts::generate_content_prompt, AssistantPanel, AssistantPanelEvent,
-    Hunk, LanguageModelCompletionProvider, ModelSelector, StreamingDiff,
+    Hunk, ModelSelector, StreamingDiff,
 };
 use anyhow::{anyhow, Context as _, Result};
 use client::telemetry::Telemetry;
@@ -27,7 +27,9 @@ use gpui::{
     WindowContext,
 };
 use language::{Buffer, IndentKind, Point, Selection, TransactionId};
-use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
+use language_model::{
+    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
+};
 use multi_buffer::MultiBufferRow;
 use parking_lot::Mutex;
 use rope::Rope;
@@ -1328,7 +1330,7 @@ impl Render for PromptEditor {
                                     Tooltip::with_meta(
                                         format!(
                                             "Using {}",
-                                            LanguageModelCompletionProvider::read_global(cx)
+                                            LanguageModelRegistry::read_global(cx)
                                                 .active_model()
                                                 .map(|model| model.name().0)
                                                 .unwrap_or_else(|| "No model selected".into()),
@@ -1662,7 +1664,7 @@ impl PromptEditor {
     }
 
     fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
-        let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
+        let model = LanguageModelRegistry::read_global(cx).active_model()?;
         let token_count = self.token_count?;
         let max_token_count = model.max_token_count();
 
@@ -2013,8 +2015,12 @@ impl Codegen {
         assistant_panel_context: Option<LanguageModelRequest>,
         cx: &AppContext,
     ) -> BoxFuture<'static, Result<usize>> {
-        let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
-        LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
+        if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
+            let request = self.build_request(user_prompt, assistant_panel_context, edit_range, cx);
+            model.count_tokens(request, cx)
+        } else {
+            future::ready(Err(anyhow!("no active model"))).boxed()
+        }
     }
 
     pub fn start(
@@ -2024,6 +2030,10 @@ impl Codegen {
         assistant_panel_context: Option<LanguageModelRequest>,
         cx: &mut ModelContext<Self>,
     ) -> Result<()> {
+        let model = LanguageModelRegistry::read_global(cx)
+            .active_model()
+            .context("no active model")?;
+
         self.undo(cx);
 
         // Handle initial insertion
@@ -2053,10 +2063,7 @@ impl Codegen {
             None
         };
 
-        let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
-            .active_model_telemetry_id()
-            .context("no active model")?;
-
+        let telemetry_id = model.telemetry_id();
         let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
             .trim()
             .to_lowercase()
@@ -2067,10 +2074,10 @@ impl Codegen {
             let request =
                 self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx);
             let chunks =
-                LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
+                cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
             async move { Ok(chunks.await?.boxed()) }.boxed_local()
         };
-        self.handle_stream(model_telemetry_id, edit_range, chunks, cx);
+        self.handle_stream(telemetry_id, edit_range, chunks, cx);
         Ok(())
     }
 
@@ -2657,7 +2664,6 @@ mod tests {
     async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_model::LanguageModelRegistry::test);
-        cx.update(completion::LanguageModelCompletionProvider::test);
         cx.update(language_settings::init);
 
         let text = indoc! {"
@@ -2789,7 +2795,6 @@ mod tests {
         mut rng: StdRng,
     ) {
         cx.update(LanguageModelRegistry::test);
-        cx.update(completion::LanguageModelCompletionProvider::test);
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_settings::init);
 
@@ -2853,7 +2858,6 @@ mod tests {
     #[gpui::test(iterations = 10)]
     async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
         cx.update(LanguageModelRegistry::test);
-        cx.update(completion::LanguageModelCompletionProvider::test);
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_settings::init);
 

crates/assistant/src/model_selector.rs 🔗

@@ -1,6 +1,6 @@
 use std::sync::Arc;
 
-use crate::{assistant_settings::AssistantSettings, LanguageModelCompletionProvider};
+use crate::assistant_settings::AssistantSettings;
 use fs::Fs;
 use gpui::SharedString;
 use language_model::LanguageModelRegistry;
@@ -81,13 +81,13 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
                                 }
                             },
                             {
-                                let provider = provider.id();
+                                let provider = provider.clone();
                                 move |cx| {
-                                    LanguageModelCompletionProvider::global(cx).update(
+                                    LanguageModelRegistry::global(cx).update(
                                         cx,
                                         |completion_provider, cx| {
                                             completion_provider
-                                                .set_active_provider(provider.clone(), cx)
+                                                .set_active_provider(Some(provider.clone()), cx);
                                         },
                                     );
                                 }
@@ -95,12 +95,12 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
                         );
                     }
 
-                    let selected_model = LanguageModelCompletionProvider::read_global(cx)
-                        .active_model()
-                        .map(|m| m.id());
-                    let selected_provider = LanguageModelCompletionProvider::read_global(cx)
+                    let selected_provider = LanguageModelRegistry::read_global(cx)
                         .active_provider()
                         .map(|m| m.id());
+                    let selected_model = LanguageModelRegistry::read_global(cx)
+                        .active_model()
+                        .map(|m| m.id());
 
                     for available_model in available_models {
                         menu = menu.custom_entry(

crates/assistant/src/prompt_library.rs 🔗

@@ -1,6 +1,5 @@
 use crate::{
     slash_command::SlashCommandCompletionProvider, AssistantPanel, InlineAssist, InlineAssistant,
-    LanguageModelCompletionProvider,
 };
 use anyhow::{anyhow, Result};
 use assets::Assets;
@@ -19,7 +18,9 @@ use gpui::{
 };
 use heed::{types::SerdeBincode, Database, RoTxn};
 use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
-use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
+use language_model::{
+    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
+};
 use parking_lot::RwLock;
 use picker::{Picker, PickerDelegate};
 use rope::Rope;
@@ -636,7 +637,10 @@ impl PromptLibrary {
         };
 
         let prompt_editor = &self.prompt_editors[&active_prompt_id].body_editor;
-        let provider = LanguageModelCompletionProvider::read_global(cx);
+        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
+            return;
+        };
+
         let initial_prompt = action.prompt.clone();
         if provider.is_authenticated(cx) {
             InlineAssistant::update_global(cx, |assistant, cx| {
@@ -725,6 +729,9 @@ impl PromptLibrary {
     }
 
     fn count_tokens(&mut self, prompt_id: PromptId, cx: &mut ViewContext<Self>) {
+        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+            return;
+        };
         if let Some(prompt) = self.prompt_editors.get_mut(&prompt_id) {
             let editor = &prompt.body_editor.read(cx);
             let buffer = &editor.buffer().read(cx).as_singleton().unwrap().read(cx);
@@ -736,7 +743,7 @@ impl PromptLibrary {
                     cx.background_executor().timer(DEBOUNCE_TIMEOUT).await;
                     let token_count = cx
                         .update(|cx| {
-                            LanguageModelCompletionProvider::read_global(cx).count_tokens(
+                            model.count_tokens(
                                 LanguageModelRequest {
                                     messages: vec![LanguageModelRequestMessage {
                                         role: Role::System,
@@ -804,7 +811,7 @@ impl PromptLibrary {
                 let prompt_metadata = self.store.metadata(prompt_id)?;
                 let prompt_editor = &self.prompt_editors[&prompt_id];
                 let focus_handle = prompt_editor.body_editor.focus_handle(cx);
-                let current_model = LanguageModelCompletionProvider::read_global(cx).active_model();
+                let model = LanguageModelRegistry::read_global(cx).active_model();
                 let settings = ThemeSettings::get_global(cx);
 
                 Some(
@@ -914,7 +921,7 @@ impl PromptLibrary {
                                                                     None,
                                                                     format!(
                                                                         "Model: {}",
-                                                                        current_model
+                                                                        model
                                                                             .as_ref()
                                                                             .map(|model| model
                                                                                 .name()

crates/assistant/src/terminal_inline_assistant.rs 🔗

@@ -1,6 +1,6 @@
 use crate::{
     humanize_token_count, prompts::generate_terminal_assistant_prompt, AssistantPanel,
-    AssistantPanelEvent, LanguageModelCompletionProvider, ModelSelector,
+    AssistantPanelEvent, ModelSelector,
 };
 use anyhow::{Context as _, Result};
 use client::telemetry::Telemetry;
@@ -16,7 +16,9 @@ use gpui::{
     Subscription, Task, TextStyle, UpdateGlobal, View, WeakView,
 };
 use language::Buffer;
-use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
+use language_model::{
+    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
+};
 use settings::Settings;
 use std::{
     cmp,
@@ -556,7 +558,7 @@ impl Render for PromptEditor {
                                 Tooltip::with_meta(
                                     format!(
                                         "Using {}",
-                                        LanguageModelCompletionProvider::read_global(cx)
+                                        LanguageModelRegistry::read_global(cx)
                                             .active_model()
                                             .map(|model| model.name().0)
                                             .unwrap_or_else(|| "No model selected".into()),
@@ -700,6 +702,9 @@ impl PromptEditor {
 
     fn count_tokens(&mut self, cx: &mut ViewContext<Self>) {
         let assist_id = self.id;
+        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+            return;
+        };
         self.pending_token_count = cx.spawn(|this, mut cx| async move {
             cx.background_executor().timer(Duration::from_secs(1)).await;
             let request =
@@ -707,11 +712,7 @@ impl PromptEditor {
                     inline_assistant.request_for_inline_assist(assist_id, cx)
                 })??;
 
-            let token_count = cx
-                .update(|cx| {
-                    LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
-                })?
-                .await?;
+            let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
             this.update(&mut cx, |this, cx| {
                 this.token_count = Some(token_count);
                 cx.notify();
@@ -840,7 +841,7 @@ impl PromptEditor {
     }
 
     fn render_token_count(&self, cx: &mut ViewContext<Self>) -> Option<impl IntoElement> {
-        let model = LanguageModelCompletionProvider::read_global(cx).active_model()?;
+        let model = LanguageModelRegistry::read_global(cx).active_model()?;
         let token_count = self.token_count?;
         let max_token_count = model.max_token_count();
 
@@ -982,19 +983,16 @@ impl Codegen {
     }
 
     pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
-        self.status = CodegenStatus::Pending;
-        self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
+        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+            return;
+        };
 
         let telemetry = self.telemetry.clone();
-        let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
-            .active_model()
-            .map(|m| m.telemetry_id())
-            .unwrap_or_default();
-        let response =
-            LanguageModelCompletionProvider::read_global(cx).stream_completion(prompt, cx);
-
+        self.status = CodegenStatus::Pending;
+        self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
         self.generation = cx.spawn(|this, mut cx| async move {
-            let response = response.await;
+            let model_telemetry_id = model.telemetry_id();
+            let response = model.stream_completion(prompt, &cx).await;
             let generate = async {
                 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
 

crates/collab/Cargo.toml 🔗

@@ -80,7 +80,6 @@ channel.workspace = true
 client = { workspace = true, features = ["test-support"] }
 collab_ui = { workspace = true, features = ["test-support"] }
 collections = { workspace = true, features = ["test-support"] }
-completion = { workspace = true, features = ["test-support"] }
 ctor.workspace = true
 editor = { workspace = true, features = ["test-support"] }
 env_logger.workspace = true

crates/collab/src/tests/test_server.rs 🔗

@@ -300,7 +300,6 @@ impl TestServer {
             dev_server_projects::init(client.clone(), cx);
             settings::KeymapFile::load_asset(os_keymap, cx).unwrap();
             language_model::LanguageModelRegistry::test(cx);
-            completion::init(cx);
             assistant::context_store::init(&client);
         });
 

crates/completion/Cargo.toml 🔗

@@ -1,45 +0,0 @@
-[package]
-name = "completion"
-version = "0.1.0"
-edition = "2021"
-publish = false
-license = "GPL-3.0-or-later"
-
-[lints]
-workspace = true
-
-[lib]
-path = "src/completion.rs"
-doctest = false
-
-[features]
-test-support = [
-    "editor/test-support",
-    "language/test-support",
-    "language_model/test-support",
-    "project/test-support",
-    "text/test-support",
-]
-
-[dependencies]
-anyhow.workspace = true
-futures.workspace = true
-gpui.workspace = true
-language_model.workspace = true
-schemars.workspace = true
-serde.workspace = true
-serde_json.workspace = true
-settings.workspace = true
-smol.workspace = true
-ui.workspace = true
-
-[dev-dependencies]
-ctor.workspace = true
-editor = { workspace = true, features = ["test-support"] }
-env_logger.workspace = true
-language = { workspace = true, features = ["test-support"] }
-project = { workspace = true, features = ["test-support"] }
-language_model = { workspace = true, features = ["test-support"] }
-rand.workspace = true
-text = { workspace = true, features = ["test-support"] }
-unindent.workspace = true

crates/completion/src/completion.rs 🔗

@@ -1,312 +0,0 @@
-use anyhow::{anyhow, Result};
-use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
-use gpui::{AppContext, Global, Model, ModelContext, Task};
-use language_model::{
-    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
-    LanguageModelRequest, LanguageModelTool,
-};
-use smol::{
-    future::FutureExt,
-    lock::{Semaphore, SemaphoreGuardArc},
-};
-use std::{future, pin::Pin, sync::Arc, task::Poll};
-use ui::Context;
-
-pub fn init(cx: &mut AppContext) {
-    let completion_provider = cx.new_model(|cx| LanguageModelCompletionProvider::new(cx));
-    cx.set_global(GlobalLanguageModelCompletionProvider(completion_provider));
-}
-
-struct GlobalLanguageModelCompletionProvider(Model<LanguageModelCompletionProvider>);
-
-impl Global for GlobalLanguageModelCompletionProvider {}
-
-pub struct LanguageModelCompletionProvider {
-    active_provider: Option<Arc<dyn LanguageModelProvider>>,
-    active_model: Option<Arc<dyn LanguageModel>>,
-    request_limiter: Arc<Semaphore>,
-}
-
-const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
-
-pub struct LanguageModelCompletionResponse {
-    inner: BoxStream<'static, Result<String>>,
-    _lock: SemaphoreGuardArc,
-}
-
-impl futures::Stream for LanguageModelCompletionResponse {
-    type Item = Result<String>;
-
-    fn poll_next(
-        mut self: Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-    ) -> Poll<Option<Self::Item>> {
-        Pin::new(&mut self.inner).poll_next(cx)
-    }
-}
-
-impl LanguageModelCompletionProvider {
-    pub fn global(cx: &AppContext) -> Model<Self> {
-        cx.global::<GlobalLanguageModelCompletionProvider>()
-            .0
-            .clone()
-    }
-
-    pub fn read_global(cx: &AppContext) -> &Self {
-        cx.global::<GlobalLanguageModelCompletionProvider>()
-            .0
-            .read(cx)
-    }
-
-    #[cfg(any(test, feature = "test-support"))]
-    pub fn test(cx: &mut AppContext) {
-        let provider = cx.new_model(|cx| {
-            let mut this = Self::new(cx);
-            let available_model = LanguageModelRegistry::read_global(cx)
-                .available_models(cx)
-                .first()
-                .unwrap()
-                .clone();
-            this.set_active_model(available_model, cx);
-            this
-        });
-        cx.set_global(GlobalLanguageModelCompletionProvider(provider));
-    }
-
-    pub fn new(cx: &mut ModelContext<Self>) -> Self {
-        cx.observe(&LanguageModelRegistry::global(cx), |_, _, cx| {
-            cx.notify();
-        })
-        .detach();
-
-        Self {
-            active_provider: None,
-            active_model: None,
-            request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
-        }
-    }
-
-    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
-        self.active_provider.clone()
-    }
-
-    pub fn set_active_provider(
-        &mut self,
-        provider_id: LanguageModelProviderId,
-        cx: &mut ModelContext<Self>,
-    ) {
-        self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_id);
-        self.active_model = None;
-        cx.notify();
-    }
-
-    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
-        self.active_model.clone()
-    }
-
-    pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
-        if self.active_model.as_ref().map_or(false, |m| {
-            m.id() == model.id() && m.provider_id() == model.provider_id()
-        }) {
-            return;
-        }
-
-        self.active_provider =
-            LanguageModelRegistry::read_global(cx).provider(&model.provider_id());
-        self.active_model = Some(model.clone());
-
-        if let Some(provider) = self.active_provider.as_ref() {
-            provider.load_model(model, cx);
-        }
-
-        cx.notify();
-    }
-
-    pub fn is_authenticated(&self, cx: &AppContext) -> bool {
-        self.active_provider
-            .as_ref()
-            .map_or(false, |provider| provider.is_authenticated(cx))
-    }
-
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
-        self.active_provider
-            .as_ref()
-            .map_or(Task::ready(Ok(())), |provider| provider.authenticate(cx))
-    }
-
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
-        self.active_provider
-            .as_ref()
-            .map_or(Task::ready(Ok(())), |provider| {
-                provider.reset_credentials(cx)
-            })
-    }
-
-    pub fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &AppContext,
-    ) -> BoxFuture<'static, Result<usize>> {
-        if let Some(model) = self.active_model() {
-            model.count_tokens(request, cx)
-        } else {
-            future::ready(Err(anyhow!("no active model"))).boxed()
-        }
-    }
-
-    pub fn stream_completion(
-        &self,
-        request: LanguageModelRequest,
-        cx: &AppContext,
-    ) -> Task<Result<LanguageModelCompletionResponse>> {
-        if let Some(language_model) = self.active_model() {
-            let rate_limiter = self.request_limiter.clone();
-            cx.spawn(|cx| async move {
-                let lock = rate_limiter.acquire_arc().await;
-                let response = language_model.stream_completion(request, &cx).await?;
-                Ok(LanguageModelCompletionResponse {
-                    inner: response,
-                    _lock: lock,
-                })
-            })
-        } else {
-            Task::ready(Err(anyhow!("No active model set")))
-        }
-    }
-
-    pub fn complete(&self, request: LanguageModelRequest, cx: &AppContext) -> Task<Result<String>> {
-        let response = self.stream_completion(request, cx);
-        cx.foreground_executor().spawn(async move {
-            let mut chunks = response.await?;
-            let mut completion = String::new();
-            while let Some(chunk) = chunks.next().await {
-                let chunk = chunk?;
-                completion.push_str(&chunk);
-            }
-            Ok(completion)
-        })
-    }
-
-    pub fn use_tool<T: LanguageModelTool>(
-        &self,
-        request: LanguageModelRequest,
-        cx: &AppContext,
-    ) -> Task<Result<T>> {
-        if let Some(language_model) = self.active_model() {
-            cx.spawn(|cx| async move {
-                let schema = schemars::schema_for!(T);
-                let schema_json = serde_json::to_value(&schema).unwrap();
-                let request =
-                    language_model.use_tool(request, T::name(), T::description(), schema_json, &cx);
-                let response = request.await?;
-                Ok(serde_json::from_value(response)?)
-            })
-        } else {
-            Task::ready(Err(anyhow!("No active model set")))
-        }
-    }
-
-    pub fn active_model_telemetry_id(&self) -> Option<String> {
-        self.active_model.as_ref().map(|m| m.telemetry_id())
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use futures::StreamExt;
-    use gpui::AppContext;
-    use settings::SettingsStore;
-    use ui::Context;
-
-    use crate::{
-        LanguageModelCompletionProvider, LanguageModelRequest, MAX_CONCURRENT_COMPLETION_REQUESTS,
-    };
-
-    use language_model::LanguageModelRegistry;
-
-    #[gpui::test]
-    fn test_rate_limiting(cx: &mut AppContext) {
-        SettingsStore::test(cx);
-        let fake_provider = LanguageModelRegistry::test(cx);
-
-        let model = LanguageModelRegistry::read_global(cx)
-            .available_models(cx)
-            .first()
-            .cloned()
-            .unwrap();
-
-        let provider = cx.new_model(|cx| {
-            let mut provider = LanguageModelCompletionProvider::new(cx);
-            provider.set_active_model(model.clone(), cx);
-            provider
-        });
-
-        let fake_model = fake_provider.test_model();
-
-        // Enqueue some requests
-        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
-            let response = provider.read(cx).stream_completion(
-                LanguageModelRequest {
-                    temperature: i as f32 / 10.0,
-                    ..Default::default()
-                },
-                cx,
-            );
-            cx.background_executor()
-                .spawn(async move {
-                    let mut stream = response.await.unwrap();
-                    while let Some(message) = stream.next().await {
-                        message.unwrap();
-                    }
-                })
-                .detach();
-        }
-        cx.background_executor().run_until_parked();
-        assert_eq!(
-            fake_model.completion_count(),
-            MAX_CONCURRENT_COMPLETION_REQUESTS
-        );
-
-        // Get the first completion request that is in flight and mark it as completed.
-        let completion = fake_model.pending_completions().into_iter().next().unwrap();
-        fake_model.finish_completion(&completion);
-
-        // Ensure that the number of in-flight completion requests is reduced.
-        assert_eq!(
-            fake_model.completion_count(),
-            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
-        );
-
-        cx.background_executor().run_until_parked();
-
-        // Ensure that another completion request was allowed to acquire the lock.
-        assert_eq!(
-            fake_model.completion_count(),
-            MAX_CONCURRENT_COMPLETION_REQUESTS
-        );
-
-        // Mark all completion requests as finished that are in flight.
-        for request in fake_model.pending_completions() {
-            fake_model.finish_completion(&request);
-        }
-
-        assert_eq!(fake_model.completion_count(), 0);
-
-        // Wait until the background tasks acquire the lock again.
-        cx.background_executor().run_until_parked();
-
-        assert_eq!(
-            fake_model.completion_count(),
-            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
-        );
-
-        // Finish all remaining completion requests.
-        for request in fake_model.pending_completions() {
-            fake_model.finish_completion(&request);
-        }
-
-        cx.background_executor().run_until_parked();
-
-        assert_eq!(fake_model.completion_count(), 0);
-    }
-}

crates/copilot/src/copilot_chat.rs 🔗

@@ -208,13 +208,13 @@ impl CopilotChat {
     pub async fn stream_completion(
         request: Request,
         low_speed_timeout: Option<Duration>,
-        cx: &mut AsyncAppContext,
+        mut cx: AsyncAppContext,
     ) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
         let Some(this) = cx.update(|cx| Self::global(cx)).ok().flatten() else {
             return Err(anyhow!("Copilot chat is not enabled"));
         };
 
-        let (oauth_token, api_token, client) = this.read_with(cx, |this, _| {
+        let (oauth_token, api_token, client) = this.read_with(&cx, |this, _| {
             (
                 this.oauth_token.clone(),
                 this.api_token.clone(),
@@ -229,7 +229,7 @@ impl CopilotChat {
             _ => {
                 let token =
                     request_api_token(&oauth_token, client.clone(), low_speed_timeout).await?;
-                this.update(cx, |this, cx| {
+                this.update(&mut cx, |this, cx| {
                     this.api_token = Some(token.clone());
                     cx.notify();
                 })?;

crates/language_model/Cargo.toml 🔗

@@ -33,6 +33,7 @@ google_ai = { workspace = true, features = ["schemars"] }
 gpui.workspace = true
 http_client.workspace = true
 inline_completion_button.workspace = true
+log.workspace = true
 menu.workspace = true
 ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
@@ -42,6 +43,7 @@ schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true
+smol.workspace = true
 strum.workspace = true
 theme.workspace = true
 tiktoken-rs.workspace = true

crates/language_model/src/language_model.rs 🔗

@@ -1,24 +1,24 @@
 mod model;
 pub mod provider;
+mod rate_limiter;
 mod registry;
 mod request;
 mod role;
 pub mod settings;
 
-use std::sync::Arc;
-
 use anyhow::Result;
 use client::Client;
 use futures::{future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, AppContext, AsyncAppContext, SharedString, Task, WindowContext};
-
 pub use model::*;
 use project::Fs;
+pub(crate) use rate_limiter::*;
 pub use registry::*;
 pub use request::*;
 pub use role::*;
 use schemars::JsonSchema;
 use serde::de::DeserializeOwned;
+use std::{future::Future, sync::Arc};
 
 pub fn init(client: Arc<Client>, fs: Arc<dyn Fs>, cx: &mut AppContext) {
     settings::init(fs, cx);
@@ -46,7 +46,7 @@ pub trait LanguageModel: Send + Sync {
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 
-    fn use_tool(
+    fn use_any_tool(
         &self,
         request: LanguageModelRequest,
         name: String,
@@ -56,6 +56,22 @@ pub trait LanguageModel: Send + Sync {
     ) -> BoxFuture<'static, Result<serde_json::Value>>;
 }
 
+impl dyn LanguageModel {
+    pub fn use_tool<T: LanguageModelTool>(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AsyncAppContext,
+    ) -> impl 'static + Future<Output = Result<T>> {
+        let schema = schemars::schema_for!(T);
+        let schema_json = serde_json::to_value(&schema).unwrap();
+        let request = self.use_any_tool(request, T::name(), T::description(), schema_json, cx);
+        async move {
+            let response = request.await?;
+            Ok(serde_json::from_value(response)?)
+        }
+    }
+}
+
 pub trait LanguageModelTool: 'static + DeserializeOwned + JsonSchema {
     fn name() -> String;
     fn description() -> String;
@@ -67,9 +83,9 @@ pub trait LanguageModelProvider: 'static {
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
     fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
     fn is_authenticated(&self, cx: &AppContext) -> bool;
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
     fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
-    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
+    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
 }
 
 pub trait LanguageModelProviderState: 'static {

crates/language_model/src/provider/anthropic.rs 🔗

@@ -1,7 +1,7 @@
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, Role,
+    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 };
 use anyhow::{anyhow, Context as _, Result};
 use collections::BTreeMap;
@@ -36,6 +36,7 @@ pub struct AnthropicSettings {
 pub struct AvailableModel {
     pub name: String,
     pub max_tokens: usize,
+    pub tool_override: Option<String>,
 }
 
 pub struct AnthropicLanguageModelProvider {
@@ -98,6 +99,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
                 anthropic::Model::Custom {
                     name: model.name.clone(),
                     max_tokens: model.max_tokens,
+                    tool_override: model.tool_override.clone(),
                 },
             );
         }
@@ -110,6 +112,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
                     model,
                     state: self.state.clone(),
                     http_client: self.http_client.clone(),
+                    request_limiter: RateLimiter::new(4),
                 }) as Arc<dyn LanguageModel>
             })
             .collect()
@@ -119,7 +122,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
         self.state.read(cx).api_key.is_some()
     }
 
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
         if self.is_authenticated(cx) {
             Task::ready(Ok(()))
         } else {
@@ -152,7 +155,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
             .into()
     }
 
-    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
         let state = self.state.clone();
         let delete_credentials =
             cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
@@ -171,6 +174,7 @@ pub struct AnthropicModel {
     model: anthropic::Model,
     state: gpui::Model<State>,
     http_client: Arc<dyn HttpClient>,
+    request_limiter: RateLimiter,
 }
 
 pub fn count_anthropic_tokens(
@@ -296,14 +300,14 @@ impl LanguageModel for AnthropicModel {
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         let request = request.into_anthropic(self.model.id().into());
         let request = self.stream_completion(request, cx);
-        async move {
+        let future = self.request_limiter.stream(async move {
             let response = request.await?;
-            Ok(anthropic::extract_text_from_events(response).boxed())
-        }
-        .boxed()
+            Ok(anthropic::extract_text_from_events(response))
+        });
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 
-    fn use_tool(
+    fn use_any_tool(
         &self,
         request: LanguageModelRequest,
         tool_name: String,
@@ -311,7 +315,7 @@ impl LanguageModel for AnthropicModel {
         input_schema: serde_json::Value,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<serde_json::Value>> {
-        let mut request = request.into_anthropic(self.model.id().into());
+        let mut request = request.into_anthropic(self.model.tool_model_id().into());
         request.tool_choice = Some(anthropic::ToolChoice::Tool {
             name: tool_name.clone(),
         });
@@ -322,25 +326,26 @@ impl LanguageModel for AnthropicModel {
         }];
 
         let response = self.request_completion(request, cx);
-        async move {
-            let response = response.await?;
-            response
-                .content
-                .into_iter()
-                .find_map(|content| {
-                    if let anthropic::Content::ToolUse { name, input, .. } = content {
-                        if name == tool_name {
-                            Some(input)
+        self.request_limiter
+            .run(async move {
+                let response = response.await?;
+                response
+                    .content
+                    .into_iter()
+                    .find_map(|content| {
+                        if let anthropic::Content::ToolUse { name, input, .. } = content {
+                            if name == tool_name {
+                                Some(input)
+                            } else {
+                                None
+                            }
                         } else {
                             None
                         }
-                    } else {
-                        None
-                    }
-                })
-                .context("tool not used")
-        }
-        .boxed()
+                    })
+                    .context("tool not used")
+            })
+            .boxed()
     }
 }
 

crates/language_model/src/provider/cloud.rs 🔗

@@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens;
 use crate::{
     settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
     LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest,
+    LanguageModelProviderState, LanguageModelRequest, RateLimiter,
 };
 use anyhow::{anyhow, Context as _, Result};
 use client::Client;
@@ -41,6 +41,7 @@ pub struct AvailableModel {
     provider: AvailableProvider,
     name: String,
     max_tokens: usize,
+    tool_override: Option<String>,
 }
 
 pub struct CloudLanguageModelProvider {
@@ -56,7 +57,7 @@ struct State {
 }
 
 impl State {
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
         let client = self.client.clone();
         cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
     }
@@ -142,6 +143,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                 AvailableProvider::Anthropic => CloudModel::Anthropic(anthropic::Model::Custom {
                     name: model.name.clone(),
                     max_tokens: model.max_tokens,
+                    tool_override: model.tool_override.clone(),
                 }),
                 AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
                     name: model.name.clone(),
@@ -162,6 +164,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                     id: LanguageModelId::from(model.id().to_string()),
                     model,
                     client: self.client.clone(),
+                    request_limiter: RateLimiter::new(4),
                 }) as Arc<dyn LanguageModel>
             })
             .collect()
@@ -171,8 +174,8 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
         self.state.read(cx).status.is_connected()
     }
 
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
-        self.state.read(cx).authenticate(cx)
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
+        self.state.update(cx, |state, cx| state.authenticate(cx))
     }
 
     fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
@@ -182,7 +185,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
         .into()
     }
 
-    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
         Task::ready(Ok(()))
     }
 }
@@ -191,6 +194,7 @@ pub struct CloudLanguageModel {
     id: LanguageModelId,
     model: CloudModel,
     client: Arc<Client>,
+    request_limiter: RateLimiter,
 }
 
 impl LanguageModel for CloudLanguageModel {
@@ -256,7 +260,7 @@ impl LanguageModel for CloudLanguageModel {
             CloudModel::Anthropic(model) => {
                 let client = self.client.clone();
                 let request = request.into_anthropic(model.id().into());
-                async move {
+                let future = self.request_limiter.stream(async move {
                     let request = serde_json::to_string(&request)?;
                     let stream = client
                         .request_stream(proto::StreamCompleteWithLanguageModel {
@@ -266,15 +270,14 @@ impl LanguageModel for CloudLanguageModel {
                         .await?;
                     Ok(anthropic::extract_text_from_events(
                         stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    )
-                    .boxed())
-                }
-                .boxed()
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
                 let request = request.into_open_ai(model.id().into());
-                async move {
+                let future = self.request_limiter.stream(async move {
                     let request = serde_json::to_string(&request)?;
                     let stream = client
                         .request_stream(proto::StreamCompleteWithLanguageModel {
@@ -284,15 +287,14 @@ impl LanguageModel for CloudLanguageModel {
                         .await?;
                     Ok(open_ai::extract_text_from_events(
                         stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    )
-                    .boxed())
-                }
-                .boxed()
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let request = request.into_google(model.id().into());
-                async move {
+                let future = self.request_limiter.stream(async move {
                     let request = serde_json::to_string(&request)?;
                     let stream = client
                         .request_stream(proto::StreamCompleteWithLanguageModel {
@@ -302,15 +304,14 @@ impl LanguageModel for CloudLanguageModel {
                         .await?;
                     Ok(google_ai::extract_text_from_events(
                         stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    )
-                    .boxed())
-                }
-                .boxed()
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
         }
     }
 
-    fn use_tool(
+    fn use_any_tool(
         &self,
         request: LanguageModelRequest,
         tool_name: String,
@@ -321,7 +322,7 @@ impl LanguageModel for CloudLanguageModel {
         match &self.model {
             CloudModel::Anthropic(model) => {
                 let client = self.client.clone();
-                let mut request = request.into_anthropic(model.id().into());
+                let mut request = request.into_anthropic(model.tool_model_id().into());
                 request.tool_choice = Some(anthropic::ToolChoice::Tool {
                     name: tool_name.clone(),
                 });
@@ -331,32 +332,34 @@ impl LanguageModel for CloudLanguageModel {
                     input_schema,
                 }];
 
-                async move {
-                    let request = serde_json::to_string(&request)?;
-                    let response = client
-                        .request(proto::CompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::Anthropic as i32,
-                            request,
-                        })
-                        .await?;
-                    let response: anthropic::Response = serde_json::from_str(&response.completion)?;
-                    response
-                        .content
-                        .into_iter()
-                        .find_map(|content| {
-                            if let anthropic::Content::ToolUse { name, input, .. } = content {
-                                if name == tool_name {
-                                    Some(input)
+                self.request_limiter
+                    .run(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let response = client
+                            .request(proto::CompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::Anthropic as i32,
+                                request,
+                            })
+                            .await?;
+                        let response: anthropic::Response =
+                            serde_json::from_str(&response.completion)?;
+                        response
+                            .content
+                            .into_iter()
+                            .find_map(|content| {
+                                if let anthropic::Content::ToolUse { name, input, .. } = content {
+                                    if name == tool_name {
+                                        Some(input)
+                                    } else {
+                                        None
+                                    }
                                 } else {
                                     None
                                 }
-                            } else {
-                                None
-                            }
-                        })
-                        .context("tool not used")
-                }
-                .boxed()
+                            })
+                            .context("tool not used")
+                    })
+                    .boxed()
             }
             CloudModel::OpenAi(_) => {
                 future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()

crates/language_model/src/provider/copilot_chat.rs 🔗

@@ -27,7 +27,7 @@ use crate::settings::AllLanguageModelSettings;
 use crate::LanguageModelProviderState;
 use crate::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
-    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, Role,
+    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, RateLimiter, Role,
 };
 
 use super::open_ai::count_open_ai_tokens;
@@ -85,7 +85,12 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
 
     fn provided_models(&self, _cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
         CopilotChatModel::iter()
-            .map(|model| Arc::new(CopilotChatLanguageModel { model }) as Arc<dyn LanguageModel>)
+            .map(|model| {
+                Arc::new(CopilotChatLanguageModel {
+                    model,
+                    request_limiter: RateLimiter::new(4),
+                }) as Arc<dyn LanguageModel>
+            })
             .collect()
     }
 
@@ -95,7 +100,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
             .unwrap_or(false)
     }
 
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
         let result = if self.is_authenticated(cx) {
             Ok(())
         } else if let Some(copilot) = Copilot::global(cx) {
@@ -121,7 +126,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
         cx.new_view(|cx| AuthenticationPrompt::new(cx)).into()
     }
 
-    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
         let Some(copilot) = Copilot::global(cx) else {
             return Task::ready(Err(anyhow::anyhow!(
                 "Copilot is not available. Please ensure Copilot is enabled and running and try again."
@@ -145,6 +150,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
 
 pub struct CopilotChatLanguageModel {
     model: CopilotChatModel,
+    request_limiter: RateLimiter,
 }
 
 impl LanguageModel for CopilotChatLanguageModel {
@@ -215,30 +221,35 @@ impl LanguageModel for CopilotChatLanguageModel {
             return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed();
         };
 
-        cx.spawn(|mut cx| async move {
-            let response = CopilotChat::stream_completion(request, low_speed_timeout, &mut cx).await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(result) => {
-                            let choice = result.choices.first();
-                            match choice {
-                                Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
-                                None => Some(Err(anyhow::anyhow!(
-                                    "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
-                                ))),
+        let request_limiter = self.request_limiter.clone();
+        let future = cx.spawn(|cx| async move {
+            let response = CopilotChat::stream_completion(request, low_speed_timeout, cx);
+            request_limiter.stream(async move {
+                let response = response.await?;
+                let stream = response
+                    .filter_map(|response| async move {
+                        match response {
+                            Ok(result) => {
+                                let choice = result.choices.first();
+                                match choice {
+                                    Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
+                                    None => Some(Err(anyhow::anyhow!(
+                                        "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
+                                    ))),
+                                }
                             }
+                            Err(err) => Some(Err(err)),
                         }
-                        Err(err) => Some(Err(err)),
-                    }
-                })
-                .boxed();
-            Ok(stream)
-        })
-        .boxed()
+                    })
+                    .boxed();
+                Ok(stream)
+            }).await
+        });
+
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 
-    fn use_tool(
+    fn use_any_tool(
         &self,
         _request: LanguageModelRequest,
         _name: String,

crates/language_model/src/provider/fake.rs 🔗

@@ -60,7 +60,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
         true
     }
 
-    fn authenticate(&self, _: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, _: &mut AppContext) -> Task<Result<()>> {
         Task::ready(Ok(()))
     }
 
@@ -68,7 +68,7 @@ impl LanguageModelProvider for FakeLanguageModelProvider {
         unimplemented!()
     }
 
-    fn reset_credentials(&self, _: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, _: &mut AppContext) -> Task<Result<()>> {
         Task::ready(Ok(()))
     }
 }
@@ -173,7 +173,7 @@ impl LanguageModel for FakeLanguageModel {
         async move { Ok(rx.map(Ok).boxed()) }.boxed()
     }
 
-    fn use_tool(
+    fn use_any_tool(
         &self,
         _request: LanguageModelRequest,
         _name: String,

crates/language_model/src/provider/google.rs 🔗

@@ -20,7 +20,7 @@ use util::ResultExt;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest,
+    LanguageModelProviderState, LanguageModelRequest, RateLimiter,
 };
 
 const PROVIDER_ID: &str = "google";
@@ -111,6 +111,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
                     model,
                     state: self.state.clone(),
                     http_client: self.http_client.clone(),
+                    rate_limiter: RateLimiter::new(4),
                 }) as Arc<dyn LanguageModel>
             })
             .collect()
@@ -120,7 +121,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
         self.state.read(cx).api_key.is_some()
     }
 
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
         if self.is_authenticated(cx) {
             Task::ready(Ok(()))
         } else {
@@ -153,7 +154,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
             .into()
     }
 
-    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
         let state = self.state.clone();
         let delete_credentials =
             cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
@@ -172,6 +173,7 @@ pub struct GoogleLanguageModel {
     model: google_ai::Model,
     state: gpui::Model<State>,
     http_client: Arc<dyn HttpClient>,
+    rate_limiter: RateLimiter,
 }
 
 impl LanguageModel for GoogleLanguageModel {
@@ -243,17 +245,17 @@ impl LanguageModel for GoogleLanguageModel {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
         };
 
-        async move {
+        let future = self.rate_limiter.stream(async move {
             let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
             let response =
                 stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
             let events = response.await?;
             Ok(google_ai::extract_text_from_events(events).boxed())
-        }
-        .boxed()
+        });
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 
-    fn use_tool(
+    fn use_any_tool(
         &self,
         _request: LanguageModelRequest,
         _name: String,

crates/language_model/src/provider/ollama.rs 🔗

@@ -12,7 +12,7 @@ use ui::{prelude::*, ButtonLike, ElevationIndex};
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, Role,
+    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 };
 
 const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
@@ -39,7 +39,7 @@ struct State {
 }
 
 impl State {
-    fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
+    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
         let settings = &AllLanguageModelSettings::get_global(cx).ollama;
         let http_client = self.http_client.clone();
         let api_url = settings.api_url.clone();
@@ -80,37 +80,10 @@ impl OllamaLanguageModelProvider {
                 }),
             }),
         };
-        this.fetch_models(cx).detach();
+        this.state
+            .update(cx, |state, cx| state.fetch_models(cx).detach());
         this
     }
-
-    fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
-        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
-        let http_client = self.http_client.clone();
-        let api_url = settings.api_url.clone();
-
-        let state = self.state.clone();
-        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
-        cx.spawn(|mut cx| async move {
-            let models = get_models(http_client.as_ref(), &api_url, None).await?;
-
-            let mut models: Vec<ollama::Model> = models
-                .into_iter()
-                // Since there is no metadata from the Ollama API
-                // indicating which models are embedding models,
-                // simply filter out models with "-embed" in their name
-                .filter(|model| !model.name.contains("-embed"))
-                .map(|model| ollama::Model::new(&model.name))
-                .collect();
-
-            models.sort_by(|a, b| a.name.cmp(&b.name));
-
-            state.update(&mut cx, |this, cx| {
-                this.available_models = models;
-                cx.notify();
-            })
-        })
-    }
 }
 
 impl LanguageModelProviderState for OllamaLanguageModelProvider {
@@ -140,6 +113,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
                     id: LanguageModelId::from(model.name.clone()),
                     model: model.clone(),
                     http_client: self.http_client.clone(),
+                    request_limiter: RateLimiter::new(4),
                 }) as Arc<dyn LanguageModel>
             })
             .collect()
@@ -158,11 +132,11 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
         !self.state.read(cx).available_models.is_empty()
     }
 
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
         if self.is_authenticated(cx) {
             Task::ready(Ok(()))
         } else {
-            self.fetch_models(cx)
+            self.state.update(cx, |state, cx| state.fetch_models(cx))
         }
     }
 
@@ -176,8 +150,8 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
             .into()
     }
 
-    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
-        self.fetch_models(cx)
+    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
+        self.state.update(cx, |state, cx| state.fetch_models(cx))
     }
 }
 
@@ -185,6 +159,7 @@ pub struct OllamaLanguageModel {
     id: LanguageModelId,
     model: ollama::Model,
     http_client: Arc<dyn HttpClient>,
+    request_limiter: RateLimiter,
 }
 
 impl OllamaLanguageModel {
@@ -235,14 +210,14 @@ impl LanguageModel for OllamaLanguageModel {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
 
-    fn max_token_count(&self) -> usize {
-        self.model.max_token_count()
-    }
-
     fn telemetry_id(&self) -> String {
         format!("ollama/{}", self.model.id())
     }
 
+    fn max_token_count(&self) -> usize {
+        self.model.max_token_count()
+    }
+
     fn count_tokens(
         &self,
         request: LanguageModelRequest,
@@ -275,10 +250,10 @@ impl LanguageModel for OllamaLanguageModel {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
         };
 
-        async move {
-            let request =
-                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
-            let response = request.await?;
+        let future = self.request_limiter.stream(async move {
+            let response =
+                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
+                    .await?;
             let stream = response
                 .filter_map(|response| async move {
                     match response {
@@ -295,11 +270,12 @@ impl LanguageModel for OllamaLanguageModel {
                 })
                 .boxed();
             Ok(stream)
-        }
-        .boxed()
+        });
+
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 
-    fn use_tool(
+    fn use_any_tool(
         &self,
         _request: LanguageModelRequest,
         _name: String,

crates/language_model/src/provider/open_ai.rs 🔗

@@ -20,7 +20,7 @@ use util::ResultExt;
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, Role,
+    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 };
 
 const PROVIDER_ID: &str = "openai";
@@ -112,6 +112,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
                     model,
                     state: self.state.clone(),
                     http_client: self.http_client.clone(),
+                    request_limiter: RateLimiter::new(4),
                 }) as Arc<dyn LanguageModel>
             })
             .collect()
@@ -121,7 +122,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
         self.state.read(cx).api_key.is_some()
     }
 
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
         if self.is_authenticated(cx) {
             Task::ready(Ok(()))
         } else {
@@ -153,7 +154,7 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
             .into()
     }
 
-    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
         let settings = &AllLanguageModelSettings::get_global(cx).openai;
         let delete_credentials = cx.delete_credentials(&settings.api_url);
         let state = self.state.clone();
@@ -172,6 +173,7 @@ pub struct OpenAiLanguageModel {
     model: open_ai::Model,
     state: gpui::Model<State>,
     http_client: Arc<dyn HttpClient>,
+    request_limiter: RateLimiter,
 }
 
 impl LanguageModel for OpenAiLanguageModel {
@@ -226,7 +228,7 @@ impl LanguageModel for OpenAiLanguageModel {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
         };
 
-        async move {
+        let future = self.request_limiter.stream(async move {
             let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
             let request = stream_completion(
                 http_client.as_ref(),
@@ -237,11 +239,12 @@ impl LanguageModel for OpenAiLanguageModel {
             );
             let response = request.await?;
             Ok(open_ai::extract_text_from_events(response).boxed())
-        }
-        .boxed()
+        });
+
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 
-    fn use_tool(
+    fn use_any_tool(
         &self,
         _request: LanguageModelRequest,
         _name: String,

crates/language_model/src/rate_limiter.rs 🔗

@@ -0,0 +1,70 @@
+use anyhow::Result;
+use futures::Stream;
+use smol::lock::{Semaphore, SemaphoreGuardArc};
+use std::{
+    future::Future,
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Poll},
+};
+
+#[derive(Clone)]
+pub struct RateLimiter {
+    semaphore: Arc<Semaphore>,
+}
+
+pub struct RateLimitGuard<T> {
+    inner: T,
+    _guard: SemaphoreGuardArc,
+}
+
+impl<T> Stream for RateLimitGuard<T>
+where
+    T: Stream,
+{
+    type Item = T::Item;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
+        unsafe { Pin::map_unchecked_mut(self, |this| &mut this.inner).poll_next(cx) }
+    }
+}
+
+impl RateLimiter {
+    pub fn new(limit: usize) -> Self {
+        Self {
+            semaphore: Arc::new(Semaphore::new(limit)),
+        }
+    }
+
+    pub fn run<'a, Fut, T>(&self, future: Fut) -> impl 'a + Future<Output = Result<T>>
+    where
+        Fut: 'a + Future<Output = Result<T>>,
+    {
+        let guard = self.semaphore.acquire_arc();
+        async move {
+            let guard = guard.await;
+            let result = future.await?;
+            drop(guard);
+            Ok(result)
+        }
+    }
+
+    pub fn stream<'a, Fut, T>(
+        &self,
+        future: Fut,
+    ) -> impl 'a + Future<Output = Result<impl Stream<Item = T::Item>>>
+    where
+        Fut: 'a + Future<Output = Result<T>>,
+        T: Stream,
+    {
+        let guard = self.semaphore.acquire_arc();
+        async move {
+            let guard = guard.await;
+            let inner = future.await?;
+            Ok(RateLimitGuard {
+                inner,
+                _guard: guard,
+            })
+        }
+    }
+}

crates/language_model/src/registry.rs 🔗

@@ -4,11 +4,12 @@ use crate::{
         copilot_chat::CopilotChatLanguageModelProvider, google::GoogleLanguageModelProvider,
         ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
     },
-    LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
+    LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
+    LanguageModelProviderState,
 };
 use client::Client;
 use collections::BTreeMap;
-use gpui::{AppContext, Global, Model, ModelContext};
+use gpui::{AppContext, EventEmitter, Global, Model, ModelContext};
 use std::sync::Arc;
 use ui::Context;
 
@@ -70,9 +71,19 @@ impl Global for GlobalLanguageModelRegistry {}
 
 #[derive(Default)]
 pub struct LanguageModelRegistry {
+    active_model: Option<ActiveModel>,
     providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
 }
 
+pub struct ActiveModel {
+    provider: Arc<dyn LanguageModelProvider>,
+    model: Option<Arc<dyn LanguageModel>>,
+}
+
+pub struct ActiveModelChanged;
+
+impl EventEmitter<ActiveModelChanged> for LanguageModelRegistry {}
+
 impl LanguageModelRegistry {
     pub fn global(cx: &AppContext) -> Model<Self> {
         cx.global::<GlobalLanguageModelRegistry>().0.clone()
@@ -88,6 +99,8 @@ impl LanguageModelRegistry {
         let registry = cx.new_model(|cx| {
             let mut registry = Self::default();
             registry.register_provider(fake_provider.clone(), cx);
+            let model = fake_provider.provided_models(cx)[0].clone();
+            registry.set_active_model(Some(model), cx);
             registry
         });
         cx.set_global(GlobalLanguageModelRegistry(registry));
@@ -136,6 +149,64 @@ impl LanguageModelRegistry {
     ) -> Option<Arc<dyn LanguageModelProvider>> {
         self.providers.get(name).cloned()
     }
+
+    pub fn select_active_model(
+        &mut self,
+        provider: &LanguageModelProviderId,
+        model_id: &LanguageModelId,
+        cx: &mut ModelContext<Self>,
+    ) {
+        let Some(provider) = self.provider(&provider) else {
+            return;
+        };
+
+        let models = provider.provided_models(cx);
+        if let Some(model) = models.iter().find(|model| &model.id() == model_id).cloned() {
+            self.set_active_model(Some(model), cx);
+        }
+    }
+
+    pub fn set_active_provider(
+        &mut self,
+        provider: Option<Arc<dyn LanguageModelProvider>>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        self.active_model = provider.map(|provider| ActiveModel {
+            provider,
+            model: None,
+        });
+        cx.emit(ActiveModelChanged);
+    }
+
+    pub fn set_active_model(
+        &mut self,
+        model: Option<Arc<dyn LanguageModel>>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        if let Some(model) = model {
+            let provider_id = model.provider_id();
+            if let Some(provider) = self.providers.get(&provider_id).cloned() {
+                self.active_model = Some(ActiveModel {
+                    provider,
+                    model: Some(model),
+                });
+                cx.emit(ActiveModelChanged);
+            } else {
+                log::warn!("Active model's provider not found in registry");
+            }
+        } else {
+            self.active_model = None;
+            cx.emit(ActiveModelChanged);
+        }
+    }
+
+    pub fn active_provider(&self) -> Option<Arc<dyn LanguageModelProvider>> {
+        Some(self.active_model.as_ref()?.provider.clone())
+    }
+
+    pub fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
+        self.active_model.as_ref()?.model.clone()
+    }
 }
 
 #[cfg(test)]

crates/language_model/src/settings.rs 🔗

@@ -89,9 +89,15 @@ impl AnthropicSettingsContent {
                         models
                             .into_iter()
                             .filter_map(|model| match model {
-                                anthropic::Model::Custom { name, max_tokens } => {
-                                    Some(provider::anthropic::AvailableModel { name, max_tokens })
-                                }
+                                anthropic::Model::Custom {
+                                    name,
+                                    max_tokens,
+                                    tool_override,
+                                } => Some(provider::anthropic::AvailableModel {
+                                    name,
+                                    max_tokens,
+                                    tool_override,
+                                }),
                                 _ => None,
                             })
                             .collect()

crates/semantic_index/Cargo.toml 🔗

@@ -22,7 +22,6 @@ anyhow.workspace = true
 client.workspace = true
 clock.workspace = true
 collections.workspace = true
-completion.workspace = true
 fs.workspace = true
 futures.workspace = true
 futures-batch.workspace = true

crates/semantic_index/src/semantic_index.rs 🔗

@@ -1261,6 +1261,3 @@ mod tests {
         );
     }
 }
-
-// See https://github.com/zed-industries/zed/pull/14823#discussion_r1684616398 for why this is here and when it should be removed.
-type _TODO = completion::LanguageModelCompletionProvider;