assistant: Limit amount of concurrent completion requests (#13856)

Bennet Bo Fenner created

This PR refactors the completion providers to only process a maximum
amount of completion requests at a time.

Also started refactoring language model providers to use traits, so it's
easier to allow specifying multiple providers in the future.

Release Notes:

- N/A

Change summary

crates/assistant/src/assistant.rs                     |   2 
crates/assistant/src/assistant_panel.rs               |  21 
crates/assistant/src/assistant_settings.rs            |   3 
crates/assistant/src/completion_provider.rs           | 533 ++++++------
crates/assistant/src/completion_provider/anthropic.rs | 106 +-
crates/assistant/src/completion_provider/cloud.rs     |  57 
crates/assistant/src/completion_provider/fake.rs      |  98 ++
crates/assistant/src/completion_provider/ollama.rs    | 202 ++--
crates/assistant/src/completion_provider/open_ai.rs   | 113 +-
crates/assistant/src/inline_assistant.rs              |  37 
crates/assistant/src/terminal_inline_assistant.rs     |   5 
11 files changed, 669 insertions(+), 508 deletions(-)

Detailed changes

crates/assistant/src/assistant.rs 🔗

@@ -163,7 +163,7 @@ impl LanguageModelRequestMessage {
     }
 }
 
-#[derive(Debug, Default, Serialize)]
+#[derive(Debug, Default, Serialize, Deserialize)]
 pub struct LanguageModelRequest {
     pub model: LanguageModel,
     pub messages: Vec<LanguageModelRequestMessage>,

crates/assistant/src/assistant_panel.rs 🔗

@@ -1409,7 +1409,7 @@ impl Context {
             }
 
             let request = self.to_completion_request(cx);
-            let stream = CompletionProvider::global(cx).complete(request);
+            let response = CompletionProvider::global(cx).complete(request, cx);
             let assistant_message = self
                 .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
                 .unwrap();
@@ -1422,11 +1422,12 @@ impl Context {
 
             let task = cx.spawn({
                 |this, mut cx| async move {
+                    let response = response.await;
                     let assistant_message_id = assistant_message.id;
                     let mut response_latency = None;
                     let stream_completion = async {
                         let request_start = Instant::now();
-                        let mut messages = stream.await?;
+                        let mut messages = response.inner.await?;
 
                         while let Some(message) = messages.next().await {
                             if response_latency.is_none() {
@@ -1718,10 +1719,11 @@ impl Context {
                 temperature: 1.0,
             };
 
-            let stream = CompletionProvider::global(cx).complete(request);
+            let response = CompletionProvider::global(cx).complete(request, cx);
             self.pending_summary = cx.spawn(|this, mut cx| {
                 async move {
-                    let mut messages = stream.await?;
+                    let response = response.await;
+                    let mut messages = response.inner.await?;
 
                     while let Some(message) = messages.next().await {
                         let text = message?;
@@ -3642,7 +3644,7 @@ mod tests {
     #[gpui::test]
     fn test_inserting_and_removing_messages(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        FakeCompletionProvider::setup_test(cx);
         cx.set_global(settings_store);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -3774,7 +3776,7 @@ mod tests {
     fn test_message_splitting(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
         cx.set_global(settings_store);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        FakeCompletionProvider::setup_test(cx);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 
@@ -3867,7 +3869,7 @@ mod tests {
     #[gpui::test]
     fn test_messages_for_offsets(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        FakeCompletionProvider::setup_test(cx);
         cx.set_global(settings_store);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -3952,7 +3954,8 @@ mod tests {
     async fn test_slash_commands(cx: &mut TestAppContext) {
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        cx.update(|cx| FakeCompletionProvider::setup_test(cx));
+
         cx.update(Project::init_settings);
         cx.update(init);
         let fs = FakeFs::new(cx.background_executor.clone());
@@ -4147,7 +4150,7 @@ mod tests {
     async fn test_serialization(cx: &mut TestAppContext) {
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        cx.update(FakeCompletionProvider::setup_test);
         cx.update(init);
         let registry = Arc::new(LanguageRegistry::test(cx.executor()));
         let context =

crates/assistant/src/assistant_settings.rs 🔗

@@ -1,5 +1,6 @@
 use std::fmt;
 
+use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
 pub use anthropic::Model as AnthropicModel;
 use gpui::Pixels;
 pub use ollama::Model as OllamaModel;
@@ -15,8 +16,6 @@ use serde::{
 use settings::{Settings, SettingsSources};
 use strum::{EnumIter, IntoEnumIterator};
 
-use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
-
 #[derive(Clone, Debug, Default, PartialEq, EnumIter)]
 pub enum CloudModel {
     Gpt3Point5Turbo,

crates/assistant/src/completion_provider.rs 🔗

@@ -11,6 +11,8 @@ pub use cloud::*;
 pub use fake::*;
 pub use ollama::*;
 pub use open_ai::*;
+use parking_lot::RwLock;
+use smol::lock::{Semaphore, SemaphoreGuardArc};
 
 use crate::{
     assistant_settings::{AssistantProvider, AssistantSettings},
@@ -21,8 +23,8 @@ use client::Client;
 use futures::{future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
 use settings::{Settings, SettingsStore};
-use std::sync::Arc;
 use std::time::Duration;
+use std::{any::Any, sync::Arc};
 
 /// Choose which model to use for openai provider.
 /// If the model is not available, try to use the first available model, or fallback to the original model.
@@ -39,272 +41,91 @@ fn choose_openai_model(
 }
 
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
-    let mut settings_version = 0;
-    let provider = match &AssistantSettings::get_global(cx).provider {
-        AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
-            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
-        ),
-        AssistantProvider::OpenAi {
-            model,
-            api_url,
-            low_speed_timeout_in_seconds,
-            available_models,
-        } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
-            choose_openai_model(model, available_models),
-            api_url.clone(),
-            client.http_client(),
-            low_speed_timeout_in_seconds.map(Duration::from_secs),
-            settings_version,
-        )),
-        AssistantProvider::Anthropic {
-            model,
-            api_url,
-            low_speed_timeout_in_seconds,
-        } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
-            model.clone(),
-            api_url.clone(),
-            client.http_client(),
-            low_speed_timeout_in_seconds.map(Duration::from_secs),
-            settings_version,
-        )),
-        AssistantProvider::Ollama {
-            model,
-            api_url,
-            low_speed_timeout_in_seconds,
-        } => CompletionProvider::Ollama(OllamaCompletionProvider::new(
-            model.clone(),
-            api_url.clone(),
-            client.http_client(),
-            low_speed_timeout_in_seconds.map(Duration::from_secs),
-            settings_version,
-            cx,
-        )),
-    };
-    cx.set_global(provider);
+    let provider = create_provider_from_settings(client.clone(), 0, cx);
+    cx.set_global(CompletionProvider::new(provider, Some(client)));
 
+    let mut settings_version = 0;
     cx.observe_global::<SettingsStore>(move |cx| {
         settings_version += 1;
         cx.update_global::<CompletionProvider, _>(|provider, cx| {
-            match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
-                (
-                    CompletionProvider::OpenAi(provider),
-                    AssistantProvider::OpenAi {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                        available_models,
-                    },
-                ) => {
-                    provider.update(
-                        choose_openai_model(model, available_models),
-                        api_url.clone(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    );
-                }
-                (
-                    CompletionProvider::Anthropic(provider),
-                    AssistantProvider::Anthropic {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    provider.update(
-                        model.clone(),
-                        api_url.clone(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    );
-                }
-
-                (
-                    CompletionProvider::Ollama(provider),
-                    AssistantProvider::Ollama {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    provider.update(
-                        model.clone(),
-                        api_url.clone(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                        cx,
-                    );
-                }
-
-                (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
-                    provider.update(model.clone(), settings_version);
-                }
-                (_, AssistantProvider::ZedDotDev { model }) => {
-                    *provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
-                        model.clone(),
-                        client.clone(),
-                        settings_version,
-                        cx,
-                    ));
-                }
-                (
-                    _,
-                    AssistantProvider::OpenAi {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                        available_models,
-                    },
-                ) => {
-                    *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
-                        choose_openai_model(model, available_models),
-                        api_url.clone(),
-                        client.http_client(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    ));
-                }
-                (
-                    _,
-                    AssistantProvider::Anthropic {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
-                        model.clone(),
-                        api_url.clone(),
-                        client.http_client(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    ));
-                }
-                (
-                    _,
-                    AssistantProvider::Ollama {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    *provider = CompletionProvider::Ollama(OllamaCompletionProvider::new(
-                        model.clone(),
-                        api_url.clone(),
-                        client.http_client(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                        cx,
-                    ));
-                }
-            }
+            provider.update_settings(settings_version, cx);
         })
     })
     .detach();
 }
 
-pub enum CompletionProvider {
-    OpenAi(OpenAiCompletionProvider),
-    Anthropic(AnthropicCompletionProvider),
-    Cloud(CloudCompletionProvider),
-    #[cfg(test)]
-    Fake(FakeCompletionProvider),
-    Ollama(OllamaCompletionProvider),
+pub struct CompletionResponse {
+    pub inner: BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>,
+    _lock: SemaphoreGuardArc,
 }
 
-impl gpui::Global for CompletionProvider {}
+pub trait LanguageModelCompletionProvider: Send + Sync {
+    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
+    fn settings_version(&self) -> usize;
+    fn is_authenticated(&self) -> bool;
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
+    fn model(&self) -> LanguageModel;
+    fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>>;
+    fn complete(
+        &self,
+        request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+
+    fn as_any_mut(&mut self) -> &mut dyn Any;
+}
+
+const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
+
+pub struct CompletionProvider {
+    provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
+    client: Option<Arc<Client>>,
+    request_limiter: Arc<Semaphore>,
+}
 
 impl CompletionProvider {
-    pub fn global(cx: &AppContext) -> &Self {
-        cx.global::<Self>()
+    pub fn new(
+        provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
+        client: Option<Arc<Client>>,
+    ) -> Self {
+        Self {
+            provider,
+            client,
+            request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
+        }
     }
 
     pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider
-                .available_models(cx)
-                .map(LanguageModel::OpenAi)
-                .collect(),
-            CompletionProvider::Anthropic(provider) => provider
-                .available_models()
-                .map(LanguageModel::Anthropic)
-                .collect(),
-            CompletionProvider::Cloud(provider) => provider
-                .available_models()
-                .map(LanguageModel::Cloud)
-                .collect(),
-            CompletionProvider::Ollama(provider) => provider
-                .available_models()
-                .map(|model| LanguageModel::Ollama(model.clone()))
-                .collect(),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => unimplemented!(),
-        }
+        self.provider.read().available_models(cx)
     }
 
     pub fn settings_version(&self) -> usize {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.settings_version(),
-            CompletionProvider::Anthropic(provider) => provider.settings_version(),
-            CompletionProvider::Cloud(provider) => provider.settings_version(),
-            CompletionProvider::Ollama(provider) => provider.settings_version(),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => unimplemented!(),
-        }
+        self.provider.read().settings_version()
     }
 
     pub fn is_authenticated(&self) -> bool {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
-            CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
-            CompletionProvider::Cloud(provider) => provider.is_authenticated(),
-            CompletionProvider::Ollama(provider) => provider.is_authenticated(),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => true,
-        }
+        self.provider.read().is_authenticated()
     }
 
     pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
-            CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
-            CompletionProvider::Cloud(provider) => provider.authenticate(cx),
-            CompletionProvider::Ollama(provider) => provider.authenticate(cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => Task::ready(Ok(())),
-        }
+        self.provider.read().authenticate(cx)
     }
 
     pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
-            CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
-            CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
-            CompletionProvider::Ollama(provider) => provider.authentication_prompt(cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => unimplemented!(),
-        }
+        self.provider.read().authentication_prompt(cx)
     }
 
     pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
-            CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
-            CompletionProvider::Cloud(_) => Task::ready(Ok(())),
-            CompletionProvider::Ollama(provider) => provider.reset_credentials(cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => Task::ready(Ok(())),
-        }
+        self.provider.read().reset_credentials(cx)
     }
 
     pub fn model(&self) -> LanguageModel {
-        match self {
-            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
-            CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
-            CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
-            CompletionProvider::Ollama(provider) => LanguageModel::Ollama(provider.model()),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => LanguageModel::default(),
-        }
+        self.provider.read().model()
     }
 
     pub fn count_tokens(
@@ -312,27 +133,241 @@ impl CompletionProvider {
         request: LanguageModelRequest,
         cx: &AppContext,
     ) -> BoxFuture<'static, Result<usize>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
-            CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
-            CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
-            CompletionProvider::Ollama(provider) => provider.count_tokens(request, cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
-        }
+        self.provider.read().count_tokens(request, cx)
     }
 
     pub fn complete(
         &self,
         request: LanguageModelRequest,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.complete(request),
-            CompletionProvider::Anthropic(provider) => provider.complete(request),
-            CompletionProvider::Cloud(provider) => provider.complete(request),
-            CompletionProvider::Ollama(provider) => provider.complete(request),
-            #[cfg(test)]
-            CompletionProvider::Fake(provider) => provider.complete(),
+        cx: &AppContext,
+    ) -> Task<CompletionResponse> {
+        let rate_limiter = self.request_limiter.clone();
+        let provider = self.provider.clone();
+        cx.background_executor().spawn(async move {
+            let lock = rate_limiter.acquire_arc().await;
+            let response = provider.read().complete(request);
+            CompletionResponse {
+                inner: response,
+                _lock: lock,
+            }
+        })
+    }
+}
+
+impl gpui::Global for CompletionProvider {}
+
+impl CompletionProvider {
+    pub fn global(cx: &AppContext) -> &Self {
+        cx.global::<Self>()
+    }
+
+    pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
+        &mut self,
+        update: impl FnOnce(&mut T) -> R,
+    ) -> Option<R> {
+        let mut provider = self.provider.write();
+        if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
+            Some(update(provider))
+        } else {
+            None
+        }
+    }
+
+    pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
+        let updated = match &AssistantSettings::get_global(cx).provider {
+            AssistantProvider::ZedDotDev { model } => self
+                .update_current_as::<_, CloudCompletionProvider>(|provider| {
+                    provider.update(model.clone(), version);
+                }),
+            AssistantProvider::OpenAi {
+                model,
+                api_url,
+                low_speed_timeout_in_seconds,
+                available_models,
+            } => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
+                provider.update(
+                    choose_openai_model(&model, &available_models),
+                    api_url.clone(),
+                    low_speed_timeout_in_seconds.map(Duration::from_secs),
+                    version,
+                );
+            }),
+            AssistantProvider::Anthropic {
+                model,
+                api_url,
+                low_speed_timeout_in_seconds,
+            } => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
+                provider.update(
+                    model.clone(),
+                    api_url.clone(),
+                    low_speed_timeout_in_seconds.map(Duration::from_secs),
+                    version,
+                );
+            }),
+            AssistantProvider::Ollama {
+                model,
+                api_url,
+                low_speed_timeout_in_seconds,
+            } => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
+                provider.update(
+                    model.clone(),
+                    api_url.clone(),
+                    low_speed_timeout_in_seconds.map(Duration::from_secs),
+                    version,
+                    cx,
+                );
+            }),
+        };
+
+        // Previously configured provider was changed to another one
+        if updated.is_none() {
+            if let Some(client) = self.client.clone() {
+                self.provider = create_provider_from_settings(client, version, cx);
+            } else {
+                log::warn!("completion provider cannot be created because client is not set");
+            }
         }
     }
 }
+
+fn create_provider_from_settings(
+    client: Arc<Client>,
+    settings_version: usize,
+    cx: &mut AppContext,
+) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
+    match &AssistantSettings::get_global(cx).provider {
+        AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
+            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
+        )),
+        AssistantProvider::OpenAi {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+            available_models,
+        } => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
+            choose_openai_model(&model, &available_models),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+        ))),
+        AssistantProvider::Anthropic {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+        } => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
+            model.clone(),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+        ))),
+        AssistantProvider::Ollama {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+        } => Arc::new(RwLock::new(OllamaCompletionProvider::new(
+            model.clone(),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+            cx,
+        ))),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use gpui::AppContext;
+    use parking_lot::RwLock;
+    use settings::SettingsStore;
+    use smol::stream::StreamExt;
+
+    use crate::{
+        completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider,
+        FakeCompletionProvider, LanguageModelRequest,
+    };
+
+    #[gpui::test]
+    fn test_rate_limiting(cx: &mut AppContext) {
+        SettingsStore::test(cx);
+        let fake_provider = FakeCompletionProvider::setup_test(cx);
+
+        let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);
+
+        // Enqueue some requests
+        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
+            let response = provider.complete(
+                LanguageModelRequest {
+                    temperature: i as f32 / 10.0,
+                    ..Default::default()
+                },
+                cx,
+            );
+            cx.background_executor()
+                .spawn(async move {
+                    let response = response.await;
+                    let mut stream = response.inner.await.unwrap();
+                    while let Some(message) = stream.next().await {
+                        message.unwrap();
+                    }
+                })
+                .detach();
+        }
+        cx.background_executor().run_until_parked();
+
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS
+        );
+
+        // Get the first completion request that is in flight and mark it as completed.
+        let completion = fake_provider
+            .running_completions()
+            .into_iter()
+            .next()
+            .unwrap();
+        fake_provider.finish_completion(&completion);
+
+        // Ensure that the number of in-flight completion requests is reduced.
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
+        );
+
+        cx.background_executor().run_until_parked();
+
+        // Ensure that another completion request was allowed to acquire the lock.
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS
+        );
+
+        // Mark all completion requests as finished that are in flight.
+        for request in fake_provider.running_completions() {
+            fake_provider.finish_completion(&request);
+        }
+
+        assert_eq!(fake_provider.completion_count(), 0);
+
+        // Wait until the background tasks acquire the lock again.
+        cx.background_executor().run_until_parked();
+
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
+        );
+
+        // Finish all remaining completion requests.
+        for request in fake_provider.running_completions() {
+            fake_provider.finish_completion(&request);
+        }
+
+        cx.background_executor().run_until_parked();
+
+        assert_eq!(fake_provider.completion_count(), 0);
+    }
+}

crates/assistant/src/completion_provider/anthropic.rs 🔗

@@ -2,7 +2,7 @@ use crate::{
     assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
     Role,
 };
-use crate::{count_open_ai_tokens, LanguageModelRequestMessage};
+use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage};
 use anthropic::{stream_completion, Request, RequestMessage};
 use anyhow::{anyhow, Result};
 use editor::{Editor, EditorElement, EditorStyle};
@@ -26,50 +26,22 @@ pub struct AnthropicCompletionProvider {
     settings_version: usize,
 }
 
-impl AnthropicCompletionProvider {
-    pub fn new(
-        model: AnthropicModel,
-        api_url: String,
-        http_client: Arc<dyn HttpClient>,
-        low_speed_timeout: Option<Duration>,
-        settings_version: usize,
-    ) -> Self {
-        Self {
-            api_key: None,
-            api_url,
-            model,
-            http_client,
-            low_speed_timeout,
-            settings_version,
-        }
-    }
-
-    pub fn update(
-        &mut self,
-        model: AnthropicModel,
-        api_url: String,
-        low_speed_timeout: Option<Duration>,
-        settings_version: usize,
-    ) {
-        self.model = model;
-        self.api_url = api_url;
-        self.low_speed_timeout = low_speed_timeout;
-        self.settings_version = settings_version;
-    }
-
-    pub fn available_models(&self) -> impl Iterator<Item = AnthropicModel> {
+impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
         AnthropicModel::iter()
+            .map(LanguageModel::Anthropic)
+            .collect()
     }
 
-    pub fn settings_version(&self) -> usize {
+    fn settings_version(&self) -> usize {
         self.settings_version
     }
 
-    pub fn is_authenticated(&self) -> bool {
+    fn is_authenticated(&self) -> bool {
         self.api_key.is_some()
     }
 
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
         if self.is_authenticated() {
             Task::ready(Ok(()))
         } else {
@@ -85,36 +57,36 @@ impl AnthropicCompletionProvider {
                     String::from_utf8(api_key)?
                 };
                 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                    if let CompletionProvider::Anthropic(provider) = provider {
+                    provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
                         provider.api_key = Some(api_key);
-                    }
+                    });
                 })
             })
         }
     }
 
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
         let delete_credentials = cx.delete_credentials(&self.api_url);
         cx.spawn(|mut cx| async move {
             delete_credentials.await.log_err();
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::Anthropic(provider) = provider {
+                provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
                     provider.api_key = None;
-                }
+                });
             })
         })
     }
 
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
         cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
             .into()
     }
 
-    pub fn model(&self) -> AnthropicModel {
-        self.model.clone()
+    fn model(&self) -> LanguageModel {
+        LanguageModel::Anthropic(self.model.clone())
     }
 
-    pub fn count_tokens(
+    fn count_tokens(
         &self,
         request: LanguageModelRequest,
         cx: &AppContext,
@@ -122,7 +94,7 @@ impl AnthropicCompletionProvider {
         count_open_ai_tokens(request, cx.background_executor())
     }
 
-    pub fn complete(
+    fn complete(
         &self,
         request: LanguageModelRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -167,12 +139,48 @@ impl AnthropicCompletionProvider {
         .boxed()
     }
 
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
+}
+
+impl AnthropicCompletionProvider {
+    pub fn new(
+        model: AnthropicModel,
+        api_url: String,
+        http_client: Arc<dyn HttpClient>,
+        low_speed_timeout: Option<Duration>,
+        settings_version: usize,
+    ) -> Self {
+        Self {
+            api_key: None,
+            api_url,
+            model,
+            http_client,
+            low_speed_timeout,
+            settings_version,
+        }
+    }
+
+    pub fn update(
+        &mut self,
+        model: AnthropicModel,
+        api_url: String,
+        low_speed_timeout: Option<Duration>,
+        settings_version: usize,
+    ) {
+        self.model = model;
+        self.api_url = api_url;
+        self.low_speed_timeout = low_speed_timeout;
+        self.settings_version = settings_version;
+    }
+
     fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
         preprocess_anthropic_request(&mut request);
 
         let model = match request.model {
             LanguageModel::Anthropic(model) => model,
-            _ => self.model(),
+            _ => self.model.clone(),
         };
 
         let mut system_message = String::new();
@@ -278,9 +286,9 @@ impl AuthenticationPrompt {
         cx.spawn(|_, mut cx| async move {
             write_credentials.await?;
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::Anthropic(provider) = provider {
+                provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
                     provider.api_key = Some(api_key);
-                }
+                });
             })
         })
         .detach_and_log_err(cx);

crates/assistant/src/completion_provider/cloud.rs 🔗

@@ -1,6 +1,6 @@
 use crate::{
     assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
-    LanguageModelRequest,
+    LanguageModelCompletionProvider, LanguageModelRequest,
 };
 use anyhow::{anyhow, Result};
 use client::{proto, Client};
@@ -30,11 +30,9 @@ impl CloudCompletionProvider {
         let maintain_client_status = cx.spawn(|mut cx| async move {
             while let Some(status) = status_rx.next().await {
                 let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                    if let CompletionProvider::Cloud(provider) = provider {
+                    provider.update_current_as::<_, Self>(|provider| {
                         provider.status = status;
-                    } else {
-                        unreachable!()
-                    }
+                    });
                 });
             }
         });
@@ -51,44 +49,53 @@ impl CloudCompletionProvider {
         self.model = model;
         self.settings_version = settings_version;
     }
+}
 
-    pub fn available_models(&self) -> impl Iterator<Item = CloudModel> {
+impl LanguageModelCompletionProvider for CloudCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
         let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
             Some(custom_model)
         } else {
             None
         };
-        CloudModel::iter().filter_map(move |model| {
-            if let CloudModel::Custom(_) = model {
-                Some(CloudModel::Custom(custom_model.take()?))
-            } else {
-                Some(model)
-            }
-        })
+        CloudModel::iter()
+            .filter_map(move |model| {
+                if let CloudModel::Custom(_) = model {
+                    Some(CloudModel::Custom(custom_model.take()?))
+                } else {
+                    Some(model)
+                }
+            })
+            .map(LanguageModel::Cloud)
+            .collect()
     }
 
-    pub fn settings_version(&self) -> usize {
+    fn settings_version(&self) -> usize {
         self.settings_version
     }
 
-    pub fn model(&self) -> CloudModel {
-        self.model.clone()
-    }
-
-    pub fn is_authenticated(&self) -> bool {
+    fn is_authenticated(&self) -> bool {
         self.status.is_connected()
     }
 
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
         let client = self.client.clone();
         cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
     }
 
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
         cx.new_view(|_cx| AuthenticationPrompt).into()
     }
 
-    pub fn count_tokens(
+    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
+        Task::ready(Ok(()))
+    }
+
+    fn model(&self) -> LanguageModel {
+        LanguageModel::Cloud(self.model.clone())
+    }
+
+    fn count_tokens(
         &self,
         request: LanguageModelRequest,
         cx: &AppContext,
@@ -128,7 +135,7 @@ impl CloudCompletionProvider {
         }
     }
 
-    pub fn complete(
+    fn complete(
         &self,
         mut request: LanguageModelRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -161,6 +168,10 @@ impl CloudCompletionProvider {
             })
             .boxed()
     }
+
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
 }
 
 struct AuthenticationPrompt;

crates/assistant/src/completion_provider/fake.rs 🔗

@@ -1,29 +1,107 @@
 use anyhow::Result;
+use collections::HashMap;
 use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::{AnyView, AppContext, Task};
 use std::sync::Arc;
+use ui::WindowContext;
+
+use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
 
 #[derive(Clone, Default)]
 pub struct FakeCompletionProvider {
-    current_completion_tx: Arc<parking_lot::Mutex<Option<mpsc::UnboundedSender<String>>>>,
+    current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
 }
 
 impl FakeCompletionProvider {
-    pub fn complete(&self) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let (tx, rx) = mpsc::unbounded();
-        *self.current_completion_tx.lock() = Some(tx);
-        async move { Ok(rx.map(Ok).boxed()) }.boxed()
+    #[cfg(test)]
+    pub fn setup_test(cx: &mut AppContext) -> Self {
+        use crate::CompletionProvider;
+        use parking_lot::RwLock;
+
+        let this = Self::default();
+        let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
+        cx.set_global(provider);
+        this
+    }
+
+    pub fn running_completions(&self) -> Vec<LanguageModelRequest> {
+        self.current_completion_txs
+            .lock()
+            .keys()
+            .map(|k| serde_json::from_str(k).unwrap())
+            .collect()
+    }
+
+    pub fn completion_count(&self) -> usize {
+        self.current_completion_txs.lock().len()
     }
 
-    pub fn send_completion(&self, chunk: String) {
-        self.current_completion_tx
+    pub fn send_completion(&self, request: &LanguageModelRequest, chunk: String) {
+        let json = serde_json::to_string(request).unwrap();
+        self.current_completion_txs
             .lock()
-            .as_ref()
+            .get(&json)
             .unwrap()
             .unbounded_send(chunk)
             .unwrap();
     }
 
-    pub fn finish_completion(&self) {
-        self.current_completion_tx.lock().take();
+    pub fn finish_completion(&self, request: &LanguageModelRequest) {
+        self.current_completion_txs
+            .lock()
+            .remove(&serde_json::to_string(request).unwrap());
+    }
+}
+
+impl LanguageModelCompletionProvider for FakeCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+        vec![LanguageModel::default()]
+    }
+
+    fn settings_version(&self) -> usize {
+        0
+    }
+
+    fn is_authenticated(&self) -> bool {
+        true
+    }
+
+    fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
+        Task::ready(Ok(()))
+    }
+
+    fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
+        unimplemented!()
+    }
+
+    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
+        Task::ready(Ok(()))
+    }
+
+    fn model(&self) -> LanguageModel {
+        LanguageModel::default()
+    }
+
+    fn count_tokens(
+        &self,
+        _request: LanguageModelRequest,
+        _cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        futures::future::ready(Ok(0)).boxed()
+    }
+
+    fn complete(
+        &self,
+        _request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let (tx, rx) = mpsc::unbounded();
+        self.current_completion_txs
+            .lock()
+            .insert(serde_json::to_string(&_request).unwrap(), tx);
+        async move { Ok(rx.map(Ok).boxed()) }.boxed()
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
     }
 }

crates/assistant/src/completion_provider/ollama.rs 🔗

@@ -1,3 +1,4 @@
+use crate::LanguageModelCompletionProvider;
 use crate::{
     assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
 };
@@ -26,6 +27,108 @@ pub struct OllamaCompletionProvider {
     available_models: Vec<OllamaModel>,
 }
 
+impl LanguageModelCompletionProvider for OllamaCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+        self.available_models
+            .iter()
+            .map(|m| LanguageModel::Ollama(m.clone()))
+            .collect()
+    }
+
+    fn settings_version(&self) -> usize {
+        self.settings_version
+    }
+
+    fn is_authenticated(&self) -> bool {
+        !self.available_models.is_empty()
+    }
+
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+        if self.is_authenticated() {
+            Task::ready(Ok(()))
+        } else {
+            self.fetch_models(cx)
+        }
+    }
+
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+        let fetch_models = Box::new(move |cx: &mut WindowContext| {
+            cx.update_global::<CompletionProvider, _>(|provider, cx| {
+                provider
+                    .update_current_as::<_, OllamaCompletionProvider>(|provider| {
+                        provider.fetch_models(cx)
+                    })
+                    .unwrap_or_else(|| Task::ready(Ok(())))
+            })
+        });
+
+        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
+            .into()
+    }
+
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+        self.fetch_models(cx)
+    }
+
+    fn model(&self) -> LanguageModel {
+        LanguageModel::Ollama(self.model.clone())
+    }
+
+    fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        _cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        // There is no endpoint for this _yet_ in Ollama
+        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
+        let token_count = request
+            .messages
+            .iter()
+            .map(|msg| msg.content.chars().count())
+            .sum::<usize>()
+            / 4;
+
+        async move { Ok(token_count) }.boxed()
+    }
+
+    fn complete(
+        &self,
+        request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let request = self.to_ollama_request(request);
+
+        let http_client = self.http_client.clone();
+        let api_url = self.api_url.clone();
+        let low_speed_timeout = self.low_speed_timeout;
+        async move {
+            let request =
+                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
+            let response = request.await?;
+            let stream = response
+                .filter_map(|response| async move {
+                    match response {
+                        Ok(delta) => {
+                            let content = match delta.message {
+                                ChatMessage::User { content } => content,
+                                ChatMessage::Assistant { content } => content,
+                                ChatMessage::System { content } => content,
+                            };
+                            Some(Ok(content))
+                        }
+                        Err(error) => Some(Err(error)),
+                    }
+                })
+                .boxed();
+            Ok(stream)
+        }
+        .boxed()
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
+}
+
 impl OllamaCompletionProvider {
     pub fn new(
         model: OllamaModel,
@@ -87,36 +190,12 @@ impl OllamaCompletionProvider {
         self.settings_version = settings_version;
     }
 
-    pub fn available_models(&self) -> impl Iterator<Item = &OllamaModel> {
-        self.available_models.iter()
-    }
-
     pub fn select_first_available_model(&mut self) {
         if let Some(model) = self.available_models.first() {
             self.model = model.clone();
         }
     }
 
-    pub fn settings_version(&self) -> usize {
-        self.settings_version
-    }
-
-    pub fn is_authenticated(&self) -> bool {
-        !self.available_models.is_empty()
-    }
-
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
-        if self.is_authenticated() {
-            Task::ready(Ok(()))
-        } else {
-            self.fetch_models(cx)
-        }
-    }
-
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
-        self.fetch_models(cx)
-    }
-
     pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
         let http_client = self.http_client.clone();
         let api_url = self.api_url.clone();
@@ -137,90 +216,21 @@ impl OllamaCompletionProvider {
             models.sort_by(|a, b| a.name.cmp(&b.name));
 
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::Ollama(provider) = provider {
+                provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
                     provider.available_models = models;
 
                     if !provider.available_models.is_empty() && provider.model.name.is_empty() {
                         provider.select_first_available_model()
                     }
-                }
+                });
             })
         })
     }
 
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
-        let fetch_models = Box::new(move |cx: &mut WindowContext| {
-            cx.update_global::<CompletionProvider, _>(|provider, cx| {
-                if let CompletionProvider::Ollama(provider) = provider {
-                    provider.fetch_models(cx)
-                } else {
-                    Task::ready(Ok(()))
-                }
-            })
-        });
-
-        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
-            .into()
-    }
-
-    pub fn model(&self) -> OllamaModel {
-        self.model.clone()
-    }
-
-    pub fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        _cx: &AppContext,
-    ) -> BoxFuture<'static, Result<usize>> {
-        // There is no endpoint for this _yet_ in Ollama
-        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
-        let token_count = request
-            .messages
-            .iter()
-            .map(|msg| msg.content.chars().count())
-            .sum::<usize>()
-            / 4;
-
-        async move { Ok(token_count) }.boxed()
-    }
-
-    pub fn complete(
-        &self,
-        request: LanguageModelRequest,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let request = self.to_ollama_request(request);
-
-        let http_client = self.http_client.clone();
-        let api_url = self.api_url.clone();
-        let low_speed_timeout = self.low_speed_timeout;
-        async move {
-            let request =
-                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
-            let response = request.await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(delta) => {
-                            let content = match delta.message {
-                                ChatMessage::User { content } => content,
-                                ChatMessage::Assistant { content } => content,
-                                ChatMessage::System { content } => content,
-                            };
-                            Some(Ok(content))
-                        }
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
-            Ok(stream)
-        }
-        .boxed()
-    }
-
     fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
         let model = match request.model {
             LanguageModel::Ollama(model) => model,
-            _ => self.model(),
+            _ => self.model.clone(),
         };
 
         ChatRequest {

crates/assistant/src/completion_provider/open_ai.rs 🔗

@@ -1,5 +1,6 @@
 use crate::assistant_settings::CloudModel;
 use crate::assistant_settings::{AssistantProvider, AssistantSettings};
+use crate::LanguageModelCompletionProvider;
 use crate::{
     assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
 };
@@ -57,37 +58,75 @@ impl OpenAiCompletionProvider {
         self.settings_version = settings_version;
     }
 
-    pub fn available_models(&self, cx: &AppContext) -> impl Iterator<Item = OpenAiModel> {
+    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
+        let model = match request.model {
+            LanguageModel::OpenAi(model) => model,
+            _ => self.model.clone(),
+        };
+
+        Request {
+            model,
+            messages: request
+                .messages
+                .into_iter()
+                .map(|msg| match msg.role {
+                    Role::User => RequestMessage::User {
+                        content: msg.content,
+                    },
+                    Role::Assistant => RequestMessage::Assistant {
+                        content: Some(msg.content),
+                        tool_calls: Vec::new(),
+                    },
+                    Role::System => RequestMessage::System {
+                        content: msg.content,
+                    },
+                })
+                .collect(),
+            stream: true,
+            stop: request.stop,
+            temperature: request.temperature,
+            tools: Vec::new(),
+            tool_choice: None,
+        }
+    }
+}
+
+impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
+    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
         if let AssistantProvider::OpenAi {
             available_models, ..
         } = &AssistantSettings::get_global(cx).provider
         {
             if !available_models.is_empty() {
-                // available_models is set, just return it
-                return available_models.clone().into_iter();
+                return available_models
+                    .iter()
+                    .cloned()
+                    .map(LanguageModel::OpenAi)
+                    .collect();
             }
         }
         let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
-            // available_models is not set but the default model is set to custom, only show custom
             vec![self.model.clone()]
         } else {
-            // default case, use all models except custom
             OpenAiModel::iter()
                 .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
                 .collect()
         };
-        available_models.into_iter()
+        available_models
+            .into_iter()
+            .map(LanguageModel::OpenAi)
+            .collect()
     }
 
-    pub fn settings_version(&self) -> usize {
+    fn settings_version(&self) -> usize {
         self.settings_version
     }
 
-    pub fn is_authenticated(&self) -> bool {
+    fn is_authenticated(&self) -> bool {
         self.api_key.is_some()
     }
 
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
         if self.is_authenticated() {
             Task::ready(Ok(()))
         } else {
@@ -103,36 +142,36 @@ impl OpenAiCompletionProvider {
                     String::from_utf8(api_key)?
                 };
                 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                    if let CompletionProvider::OpenAi(provider) = provider {
+                    provider.update_current_as::<_, Self>(|provider| {
                         provider.api_key = Some(api_key);
-                    }
+                    });
                 })
             })
         }
     }
 
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
         let delete_credentials = cx.delete_credentials(&self.api_url);
         cx.spawn(|mut cx| async move {
             delete_credentials.await.log_err();
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::OpenAi(provider) = provider {
+                provider.update_current_as::<_, Self>(|provider| {
                     provider.api_key = None;
-                }
+                });
             })
         })
     }
 
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
         cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
             .into()
     }
 
-    pub fn model(&self) -> OpenAiModel {
-        self.model.clone()
+    fn model(&self) -> LanguageModel {
+        LanguageModel::OpenAi(self.model.clone())
     }
 
-    pub fn count_tokens(
+    fn count_tokens(
         &self,
         request: LanguageModelRequest,
         cx: &AppContext,
@@ -140,7 +179,7 @@ impl OpenAiCompletionProvider {
         count_open_ai_tokens(request, cx.background_executor())
     }
 
-    pub fn complete(
+    fn complete(
         &self,
         request: LanguageModelRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -173,36 +212,8 @@ impl OpenAiCompletionProvider {
         .boxed()
     }
 
-    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
-        let model = match request.model {
-            LanguageModel::OpenAi(model) => model,
-            _ => self.model(),
-        };
-
-        Request {
-            model,
-            messages: request
-                .messages
-                .into_iter()
-                .map(|msg| match msg.role {
-                    Role::User => RequestMessage::User {
-                        content: msg.content,
-                    },
-                    Role::Assistant => RequestMessage::Assistant {
-                        content: Some(msg.content),
-                        tool_calls: Vec::new(),
-                    },
-                    Role::System => RequestMessage::System {
-                        content: msg.content,
-                    },
-                })
-                .collect(),
-            stream: true,
-            stop: request.stop,
-            temperature: request.temperature,
-            tools: Vec::new(),
-            tool_choice: None,
-        }
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
     }
 }
 
@@ -284,9 +295,9 @@ impl AuthenticationPrompt {
         cx.spawn(|_, mut cx| async move {
             write_credentials.await?;
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::OpenAi(provider) = provider {
+                provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
                     provider.api_key = Some(api_key);
-                }
+                });
             })
         })
         .detach_and_log_err(cx);

crates/assistant/src/inline_assistant.rs 🔗

@@ -1986,13 +1986,14 @@ impl Codegen {
             .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 
         let model_telemetry_id = prompt.model.telemetry_id();
-        let response = CompletionProvider::global(cx).complete(prompt);
+        let response = CompletionProvider::global(cx).complete(prompt, cx);
         let telemetry = self.telemetry.clone();
         self.edit_position = range.start;
         self.diff = Diff::default();
         self.status = CodegenStatus::Pending;
         self.generation = cx.spawn(|this, mut cx| {
             async move {
+                let response = response.await;
                 let generate = async {
                     let mut edit_start = range.start.to_offset(&snapshot);
 
@@ -2002,7 +2003,7 @@ impl Codegen {
                             let mut response_latency = None;
                             let request_start = Instant::now();
                             let diff = async {
-                                let chunks = StripInvalidSpans::new(response.await?);
+                                let chunks = StripInvalidSpans::new(response.inner.await?);
                                 futures::pin_mut!(chunks);
                                 let mut diff = StreamingDiff::new(selected_text.to_string());
 
@@ -2473,9 +2474,8 @@ mod tests {
 
     #[gpui::test(iterations = 10)]
     async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
-        let provider = FakeCompletionProvider::default();
         cx.set_global(cx.update(SettingsStore::test));
-        cx.set_global(CompletionProvider::Fake(provider.clone()));
+        let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
         cx.update(language_settings::init);
 
         let text = indoc! {"
@@ -2495,8 +2495,11 @@ mod tests {
         });
         let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, cx));
 
-        let request = LanguageModelRequest::default();
-        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
+        codegen.update(cx, |codegen, cx| {
+            codegen.start(LanguageModelRequest::default(), cx)
+        });
+
+        cx.background_executor.run_until_parked();
 
         let mut new_text = concat!(
             "       let mut x = 0;\n",
@@ -2508,11 +2511,11 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
-            provider.send_completion(chunk.into());
+            provider.send_completion(&LanguageModelRequest::default(), chunk.into());
             new_text = suffix;
             cx.background_executor.run_until_parked();
         }
-        provider.finish_completion();
+        provider.finish_completion(&LanguageModelRequest::default());
         cx.background_executor.run_until_parked();
 
         assert_eq!(
@@ -2533,8 +2536,7 @@ mod tests {
         cx: &mut TestAppContext,
         mut rng: StdRng,
     ) {
-        let provider = FakeCompletionProvider::default();
-        cx.set_global(CompletionProvider::Fake(provider.clone()));
+        let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_settings::init);
 
@@ -2555,6 +2557,8 @@ mod tests {
         let request = LanguageModelRequest::default();
         codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
+        cx.background_executor.run_until_parked();
+
         let mut new_text = concat!(
             "t mut x = 0;\n",
             "while x < 10 {\n",
@@ -2565,11 +2569,11 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
-            provider.send_completion(chunk.into());
+            provider.send_completion(&LanguageModelRequest::default(), chunk.into());
             new_text = suffix;
             cx.background_executor.run_until_parked();
         }
-        provider.finish_completion();
+        provider.finish_completion(&LanguageModelRequest::default());
         cx.background_executor.run_until_parked();
 
         assert_eq!(
@@ -2590,8 +2594,7 @@ mod tests {
         cx: &mut TestAppContext,
         mut rng: StdRng,
     ) {
-        let provider = FakeCompletionProvider::default();
-        cx.set_global(CompletionProvider::Fake(provider.clone()));
+        let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_settings::init);
 
@@ -2612,6 +2615,8 @@ mod tests {
         let request = LanguageModelRequest::default();
         codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
+        cx.background_executor.run_until_parked();
+
         let mut new_text = concat!(
             "let mut x = 0;\n",
             "while x < 10 {\n",
@@ -2622,11 +2627,11 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
-            provider.send_completion(chunk.into());
+            provider.send_completion(&LanguageModelRequest::default(), chunk.into());
             new_text = suffix;
             cx.background_executor.run_until_parked();
         }
-        provider.finish_completion();
+        provider.finish_completion(&LanguageModelRequest::default());
         cx.background_executor.run_until_parked();
 
         assert_eq!(

crates/assistant/src/terminal_inline_assistant.rs 🔗

@@ -1026,9 +1026,10 @@ impl Codegen {
 
         let telemetry = self.telemetry.clone();
         let model_telemetry_id = prompt.model.telemetry_id();
-        let response = CompletionProvider::global(cx).complete(prompt);
+        let response = CompletionProvider::global(cx).complete(prompt, cx);
 
         self.generation = cx.spawn(|this, mut cx| async move {
+            let response = response.await;
             let generate = async {
                 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
 
@@ -1036,7 +1037,7 @@ impl Codegen {
                     let mut response_latency = None;
                     let request_start = Instant::now();
                     let task = async {
-                        let mut response = response.await?;
+                        let mut response = response.inner.await?;
                         while let Some(chunk) = response.next().await {
                             if response_latency.is_none() {
                                 response_latency = Some(request_start.elapsed());