Detailed changes
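This patch replaces the `CompletionProvider` enum (one variant per backend) with a `LanguageModelCompletionProvider` trait object behind an `Arc<RwLock<...>>`, adds semaphore-based rate limiting capped at four concurrent completion requests, and updates call sites and tests to match. The hunks below appear grouped by file.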
@@ -163,7 +163,7 @@ impl LanguageModelRequestMessage {
}
}
-#[derive(Debug, Default, Serialize)]
+#[derive(Debug, Default, Serialize, Deserialize)]
pub struct LanguageModelRequest {
pub model: LanguageModel,
pub messages: Vec<LanguageModelRequestMessage>,
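The added `Deserialize` derive is what lets the reworked `FakeCompletionProvider` (see the fake-provider hunks further down) key in-flight completions by the request's JSON serialization and recover them in `running_completions`. A minimal sketch of that round trip, assuming `serde_json` is in scope:

    let request = LanguageModelRequest::default();
    let key = serde_json::to_string(&request).unwrap(); // used as a map key
    let recovered: LanguageModelRequest = serde_json::from_str(&key).unwrap();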
@@ -1409,7 +1409,7 @@ impl Context {
}
let request = self.to_completion_request(cx);
- let stream = CompletionProvider::global(cx).complete(request);
+ let response = CompletionProvider::global(cx).complete(request, cx);
let assistant_message = self
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
.unwrap();
@@ -1422,11 +1422,12 @@ impl Context {
let task = cx.spawn({
|this, mut cx| async move {
+ let response = response.await;
let assistant_message_id = assistant_message.id;
let mut response_latency = None;
let stream_completion = async {
let request_start = Instant::now();
- let mut messages = stream.await?;
+ let mut messages = response.inner.await?;
while let Some(message) = messages.next().await {
if response_latency.is_none() {
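`complete` now returns a `Task<CompletionResponse>` rather than a bare future: the task resolves once a rate-limiter permit is acquired, and `CompletionResponse::inner` is the original streaming future, holding the permit for as long as the response lives. A sketch of the resulting two-stage await, using the names from this patch:

    let response = CompletionProvider::global(cx).complete(request, cx);
    // later, inside the spawned task:
    let response = response.await;          // resolves once a permit is acquired
    let mut chunks = response.inner.await?; // then the provider's stream begins
    while let Some(chunk) = chunks.next().await {
        let _text: String = chunk?;
    }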
@@ -1718,10 +1719,11 @@ impl Context {
temperature: 1.0,
};
- let stream = CompletionProvider::global(cx).complete(request);
+ let response = CompletionProvider::global(cx).complete(request, cx);
self.pending_summary = cx.spawn(|this, mut cx| {
async move {
- let mut messages = stream.await?;
+ let response = response.await;
+ let mut messages = response.inner.await?;
while let Some(message) = messages.next().await {
let text = message?;
@@ -3642,7 +3644,7 @@ mod tests {
#[gpui::test]
fn test_inserting_and_removing_messages(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
- cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+ FakeCompletionProvider::setup_test(cx);
cx.set_global(settings_store);
init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -3774,7 +3776,7 @@ mod tests {
fn test_message_splitting(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
- cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+ FakeCompletionProvider::setup_test(cx);
init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -3867,7 +3869,7 @@ mod tests {
#[gpui::test]
fn test_messages_for_offsets(cx: &mut AppContext) {
let settings_store = SettingsStore::test(cx);
- cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+ FakeCompletionProvider::setup_test(cx);
cx.set_global(settings_store);
init(cx);
let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -3952,7 +3954,8 @@ mod tests {
async fn test_slash_commands(cx: &mut TestAppContext) {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
- cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+ cx.update(|cx| FakeCompletionProvider::setup_test(cx));
+
cx.update(Project::init_settings);
cx.update(init);
let fs = FakeFs::new(cx.background_executor.clone());
@@ -4147,7 +4150,7 @@ mod tests {
async fn test_serialization(cx: &mut TestAppContext) {
let settings_store = cx.update(SettingsStore::test);
cx.set_global(settings_store);
- cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+ cx.update(FakeCompletionProvider::setup_test);
cx.update(init);
let registry = Arc::new(LanguageRegistry::test(cx.executor()));
let context =
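The test hunks above swap manual construction of the `Fake` enum variant for a single `FakeCompletionProvider::setup_test(cx)` helper. Its body, shown in full in the fake-provider hunks below, boils down to registering the fake behind the new trait-object wrapper (with `Arc` and `RwLock` in scope):

    let this = FakeCompletionProvider::default();
    let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
    cx.set_global(provider); // the fake needs no Client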
@@ -1,5 +1,6 @@
use std::fmt;
+use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
pub use anthropic::Model as AnthropicModel;
use gpui::Pixels;
pub use ollama::Model as OllamaModel;
@@ -15,8 +16,6 @@ use serde::{
use settings::{Settings, SettingsSources};
use strum::{EnumIter, IntoEnumIterator};
-use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
-
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
Gpt3Point5Turbo,
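The settings-module hunks above merely relocate an existing `use` statement; the heart of the refactor is the completion-provider module that follows.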
@@ -11,6 +11,8 @@ pub use cloud::*;
pub use fake::*;
pub use ollama::*;
pub use open_ai::*;
+use parking_lot::RwLock;
+use smol::lock::{Semaphore, SemaphoreGuardArc};
use crate::{
assistant_settings::{AssistantProvider, AssistantSettings},
@@ -21,8 +23,8 @@ use client::Client;
use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
use settings::{Settings, SettingsStore};
-use std::sync::Arc;
use std::time::Duration;
+use std::{any::Any, sync::Arc};
/// Choose which model to use for openai provider.
/// If the model is not available, try to use the first available model, or fallback to the original model.
@@ -39,272 +41,91 @@ fn choose_openai_model(
}
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
- let mut settings_version = 0;
- let provider = match &AssistantSettings::get_global(cx).provider {
- AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
- CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
- ),
- AssistantProvider::OpenAi {
- model,
- api_url,
- low_speed_timeout_in_seconds,
- available_models,
- } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
- choose_openai_model(model, available_models),
- api_url.clone(),
- client.http_client(),
- low_speed_timeout_in_seconds.map(Duration::from_secs),
- settings_version,
- )),
- AssistantProvider::Anthropic {
- model,
- api_url,
- low_speed_timeout_in_seconds,
- } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
- model.clone(),
- api_url.clone(),
- client.http_client(),
- low_speed_timeout_in_seconds.map(Duration::from_secs),
- settings_version,
- )),
- AssistantProvider::Ollama {
- model,
- api_url,
- low_speed_timeout_in_seconds,
- } => CompletionProvider::Ollama(OllamaCompletionProvider::new(
- model.clone(),
- api_url.clone(),
- client.http_client(),
- low_speed_timeout_in_seconds.map(Duration::from_secs),
- settings_version,
- cx,
- )),
- };
- cx.set_global(provider);
+ let provider = create_provider_from_settings(client.clone(), 0, cx);
+ cx.set_global(CompletionProvider::new(provider, Some(client)));
+ let mut settings_version = 0;
cx.observe_global::<SettingsStore>(move |cx| {
settings_version += 1;
cx.update_global::<CompletionProvider, _>(|provider, cx| {
- match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
- (
- CompletionProvider::OpenAi(provider),
- AssistantProvider::OpenAi {
- model,
- api_url,
- low_speed_timeout_in_seconds,
- available_models,
- },
- ) => {
- provider.update(
- choose_openai_model(model, available_models),
- api_url.clone(),
- low_speed_timeout_in_seconds.map(Duration::from_secs),
- settings_version,
- );
- }
- (
- CompletionProvider::Anthropic(provider),
- AssistantProvider::Anthropic {
- model,
- api_url,
- low_speed_timeout_in_seconds,
- },
- ) => {
- provider.update(
- model.clone(),
- api_url.clone(),
- low_speed_timeout_in_seconds.map(Duration::from_secs),
- settings_version,
- );
- }
-
- (
- CompletionProvider::Ollama(provider),
- AssistantProvider::Ollama {
- model,
- api_url,
- low_speed_timeout_in_seconds,
- },
- ) => {
- provider.update(
- model.clone(),
- api_url.clone(),
- low_speed_timeout_in_seconds.map(Duration::from_secs),
- settings_version,
- cx,
- );
- }
-
- (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
- provider.update(model.clone(), settings_version);
- }
- (_, AssistantProvider::ZedDotDev { model }) => {
- *provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
- model.clone(),
- client.clone(),
- settings_version,
- cx,
- ));
- }
- (
- _,
- AssistantProvider::OpenAi {
- model,
- api_url,
- low_speed_timeout_in_seconds,
- available_models,
- },
- ) => {
- *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
- choose_openai_model(model, available_models),
- api_url.clone(),
- client.http_client(),
- low_speed_timeout_in_seconds.map(Duration::from_secs),
- settings_version,
- ));
- }
- (
- _,
- AssistantProvider::Anthropic {
- model,
- api_url,
- low_speed_timeout_in_seconds,
- },
- ) => {
- *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
- model.clone(),
- api_url.clone(),
- client.http_client(),
- low_speed_timeout_in_seconds.map(Duration::from_secs),
- settings_version,
- ));
- }
- (
- _,
- AssistantProvider::Ollama {
- model,
- api_url,
- low_speed_timeout_in_seconds,
- },
- ) => {
- *provider = CompletionProvider::Ollama(OllamaCompletionProvider::new(
- model.clone(),
- api_url.clone(),
- client.http_client(),
- low_speed_timeout_in_seconds.map(Duration::from_secs),
- settings_version,
- cx,
- ));
- }
- }
+ provider.update_settings(settings_version, cx);
})
})
.detach();
}
-pub enum CompletionProvider {
- OpenAi(OpenAiCompletionProvider),
- Anthropic(AnthropicCompletionProvider),
- Cloud(CloudCompletionProvider),
- #[cfg(test)]
- Fake(FakeCompletionProvider),
- Ollama(OllamaCompletionProvider),
+pub struct CompletionResponse {
+ pub inner: BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>,
+ _lock: SemaphoreGuardArc,
}
-impl gpui::Global for CompletionProvider {}
+pub trait LanguageModelCompletionProvider: Send + Sync {
+ fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
+ fn settings_version(&self) -> usize;
+ fn is_authenticated(&self) -> bool;
+ fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
+ fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
+ fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
+ fn model(&self) -> LanguageModel;
+ fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ cx: &AppContext,
+ ) -> BoxFuture<'static, Result<usize>>;
+ fn complete(
+ &self,
+ request: LanguageModelRequest,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+
+ fn as_any_mut(&mut self) -> &mut dyn Any;
+}
+
+const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
+
+pub struct CompletionProvider {
+ provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
+ client: Option<Arc<Client>>,
+ request_limiter: Arc<Semaphore>,
+}
impl CompletionProvider {
- pub fn global(cx: &AppContext) -> &Self {
- cx.global::<Self>()
+ pub fn new(
+ provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
+ client: Option<Arc<Client>>,
+ ) -> Self {
+ Self {
+ provider,
+ client,
+ request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
+ }
}
pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
- match self {
- CompletionProvider::OpenAi(provider) => provider
- .available_models(cx)
- .map(LanguageModel::OpenAi)
- .collect(),
- CompletionProvider::Anthropic(provider) => provider
- .available_models()
- .map(LanguageModel::Anthropic)
- .collect(),
- CompletionProvider::Cloud(provider) => provider
- .available_models()
- .map(LanguageModel::Cloud)
- .collect(),
- CompletionProvider::Ollama(provider) => provider
- .available_models()
- .map(|model| LanguageModel::Ollama(model.clone()))
- .collect(),
- #[cfg(test)]
- CompletionProvider::Fake(_) => unimplemented!(),
- }
+ self.provider.read().available_models(cx)
}
pub fn settings_version(&self) -> usize {
- match self {
- CompletionProvider::OpenAi(provider) => provider.settings_version(),
- CompletionProvider::Anthropic(provider) => provider.settings_version(),
- CompletionProvider::Cloud(provider) => provider.settings_version(),
- CompletionProvider::Ollama(provider) => provider.settings_version(),
- #[cfg(test)]
- CompletionProvider::Fake(_) => unimplemented!(),
- }
+ self.provider.read().settings_version()
}
pub fn is_authenticated(&self) -> bool {
- match self {
- CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
- CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
- CompletionProvider::Cloud(provider) => provider.is_authenticated(),
- CompletionProvider::Ollama(provider) => provider.is_authenticated(),
- #[cfg(test)]
- CompletionProvider::Fake(_) => true,
- }
+ self.provider.read().is_authenticated()
}
pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
- match self {
- CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
- CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
- CompletionProvider::Cloud(provider) => provider.authenticate(cx),
- CompletionProvider::Ollama(provider) => provider.authenticate(cx),
- #[cfg(test)]
- CompletionProvider::Fake(_) => Task::ready(Ok(())),
- }
+ self.provider.read().authenticate(cx)
}
pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
- match self {
- CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
- CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
- CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
- CompletionProvider::Ollama(provider) => provider.authentication_prompt(cx),
- #[cfg(test)]
- CompletionProvider::Fake(_) => unimplemented!(),
- }
+ self.provider.read().authentication_prompt(cx)
}
pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
- match self {
- CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
- CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
- CompletionProvider::Cloud(_) => Task::ready(Ok(())),
- CompletionProvider::Ollama(provider) => provider.reset_credentials(cx),
- #[cfg(test)]
- CompletionProvider::Fake(_) => Task::ready(Ok(())),
- }
+ self.provider.read().reset_credentials(cx)
}
pub fn model(&self) -> LanguageModel {
- match self {
- CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
- CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
- CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
- CompletionProvider::Ollama(provider) => LanguageModel::Ollama(provider.model()),
- #[cfg(test)]
- CompletionProvider::Fake(_) => LanguageModel::default(),
- }
+ self.provider.read().model()
}
pub fn count_tokens(
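`MAX_CONCURRENT_COMPLETION_REQUESTS` is enforced with a `smol` semaphore: `complete` acquires an owned permit on the background executor before invoking the underlying provider, and the permit rides inside `CompletionResponse`, so a slot frees up only when the response is dropped. A standalone sketch of the gating idea, assuming an async context:

    use smol::lock::Semaphore;
    use std::sync::Arc;

    let limiter = Arc::new(Semaphore::new(4));
    let permit = limiter.acquire_arc().await; // suspends while four requests are in flight
    // ... issue the request and store `permit` alongside the response ...
    drop(permit);                             // frees the slot for the next waiter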
@@ -312,27 +133,241 @@ impl CompletionProvider {
request: LanguageModelRequest,
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
- match self {
- CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
- CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
- CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
- CompletionProvider::Ollama(provider) => provider.count_tokens(request, cx),
- #[cfg(test)]
- CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
- }
+ self.provider.read().count_tokens(request, cx)
}
pub fn complete(
&self,
request: LanguageModelRequest,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
- match self {
- CompletionProvider::OpenAi(provider) => provider.complete(request),
- CompletionProvider::Anthropic(provider) => provider.complete(request),
- CompletionProvider::Cloud(provider) => provider.complete(request),
- CompletionProvider::Ollama(provider) => provider.complete(request),
- #[cfg(test)]
- CompletionProvider::Fake(provider) => provider.complete(),
+ cx: &AppContext,
+ ) -> Task<CompletionResponse> {
+ let rate_limiter = self.request_limiter.clone();
+ let provider = self.provider.clone();
+ cx.background_executor().spawn(async move {
+ let lock = rate_limiter.acquire_arc().await;
+ let response = provider.read().complete(request);
+ CompletionResponse {
+ inner: response,
+ _lock: lock,
+ }
+ })
+ }
+}
+
+impl gpui::Global for CompletionProvider {}
+
+impl CompletionProvider {
+ pub fn global(cx: &AppContext) -> &Self {
+ cx.global::<Self>()
+ }
+
+ pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
+ &mut self,
+ update: impl FnOnce(&mut T) -> R,
+ ) -> Option<R> {
+ let mut provider = self.provider.write();
+ if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
+ Some(update(provider))
+ } else {
+ None
+ }
+ }
+
+ pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
+ let updated = match &AssistantSettings::get_global(cx).provider {
+ AssistantProvider::ZedDotDev { model } => self
+ .update_current_as::<_, CloudCompletionProvider>(|provider| {
+ provider.update(model.clone(), version);
+ }),
+ AssistantProvider::OpenAi {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ available_models,
+ } => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
+ provider.update(
+ choose_openai_model(&model, &available_models),
+ api_url.clone(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ version,
+ );
+ }),
+ AssistantProvider::Anthropic {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ } => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
+ provider.update(
+ model.clone(),
+ api_url.clone(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ version,
+ );
+ }),
+ AssistantProvider::Ollama {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ } => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
+ provider.update(
+ model.clone(),
+ api_url.clone(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ version,
+ cx,
+ );
+ }),
+ };
+
+    // The configured provider was switched to a different kind; rebuild it from the current settings.
+ if updated.is_none() {
+ if let Some(client) = self.client.clone() {
+ self.provider = create_provider_from_settings(client, version, cx);
+ } else {
+ log::warn!("completion provider cannot be created because client is not set");
+ }
}
}
}
+
+fn create_provider_from_settings(
+ client: Arc<Client>,
+ settings_version: usize,
+ cx: &mut AppContext,
+) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
+ match &AssistantSettings::get_global(cx).provider {
+ AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
+ CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
+ )),
+ AssistantProvider::OpenAi {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ available_models,
+ } => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
+ choose_openai_model(&model, &available_models),
+ api_url.clone(),
+ client.http_client(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ settings_version,
+ ))),
+ AssistantProvider::Anthropic {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ } => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
+ model.clone(),
+ api_url.clone(),
+ client.http_client(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ settings_version,
+ ))),
+ AssistantProvider::Ollama {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ } => Arc::new(RwLock::new(OllamaCompletionProvider::new(
+ model.clone(),
+ api_url.clone(),
+ client.http_client(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ settings_version,
+ cx,
+ ))),
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::sync::Arc;
+
+ use gpui::AppContext;
+ use parking_lot::RwLock;
+ use settings::SettingsStore;
+ use smol::stream::StreamExt;
+
+ use crate::{
+ completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider,
+ FakeCompletionProvider, LanguageModelRequest,
+ };
+
+ #[gpui::test]
+ fn test_rate_limiting(cx: &mut AppContext) {
+ SettingsStore::test(cx);
+ let fake_provider = FakeCompletionProvider::setup_test(cx);
+
+ let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);
+
+ // Enqueue some requests
+ for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
+ let response = provider.complete(
+ LanguageModelRequest {
+ temperature: i as f32 / 10.0,
+ ..Default::default()
+ },
+ cx,
+ );
+ cx.background_executor()
+ .spawn(async move {
+ let response = response.await;
+ let mut stream = response.inner.await.unwrap();
+ while let Some(message) = stream.next().await {
+ message.unwrap();
+ }
+ })
+ .detach();
+ }
+ cx.background_executor().run_until_parked();
+
+ assert_eq!(
+ fake_provider.completion_count(),
+ MAX_CONCURRENT_COMPLETION_REQUESTS
+ );
+
+ // Get the first completion request that is in flight and mark it as completed.
+ let completion = fake_provider
+ .running_completions()
+ .into_iter()
+ .next()
+ .unwrap();
+ fake_provider.finish_completion(&completion);
+
+ // Ensure that the number of in-flight completion requests is reduced.
+ assert_eq!(
+ fake_provider.completion_count(),
+ MAX_CONCURRENT_COMPLETION_REQUESTS - 1
+ );
+
+ cx.background_executor().run_until_parked();
+
+ // Ensure that another completion request was allowed to acquire the lock.
+ assert_eq!(
+ fake_provider.completion_count(),
+ MAX_CONCURRENT_COMPLETION_REQUESTS
+ );
+
+    // Mark all in-flight completion requests as finished.
+ for request in fake_provider.running_completions() {
+ fake_provider.finish_completion(&request);
+ }
+
+ assert_eq!(fake_provider.completion_count(), 0);
+
+ // Wait until the background tasks acquire the lock again.
+ cx.background_executor().run_until_parked();
+
+ assert_eq!(
+ fake_provider.completion_count(),
+ MAX_CONCURRENT_COMPLETION_REQUESTS - 1
+ );
+
+ // Finish all remaining completion requests.
+ for request in fake_provider.running_completions() {
+ fake_provider.finish_completion(&request);
+ }
+
+ cx.background_executor().run_until_parked();
+
+ assert_eq!(fake_provider.completion_count(), 0);
+ }
+}
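The trait's `as_any_mut` requirement exists solely to power `update_current_as`, which the per-provider modules below use to mutate concrete state (API keys, model lists) through the global trait object. The mechanism, mirroring the code above:

    let mut provider = self.provider.write(); // RwLock write guard on the trait object
    if let Some(p) = provider.as_any_mut().downcast_mut::<AnthropicCompletionProvider>() {
        p.api_key = Some(api_key);            // concrete-type access after the downcast
    }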
@@ -2,7 +2,7 @@ use crate::{
assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
Role,
};
-use crate::{count_open_ai_tokens, LanguageModelRequestMessage};
+use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage};
use anthropic::{stream_completion, Request, RequestMessage};
use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle};
@@ -26,50 +26,22 @@ pub struct AnthropicCompletionProvider {
settings_version: usize,
}
-impl AnthropicCompletionProvider {
- pub fn new(
- model: AnthropicModel,
- api_url: String,
- http_client: Arc<dyn HttpClient>,
- low_speed_timeout: Option<Duration>,
- settings_version: usize,
- ) -> Self {
- Self {
- api_key: None,
- api_url,
- model,
- http_client,
- low_speed_timeout,
- settings_version,
- }
- }
-
- pub fn update(
- &mut self,
- model: AnthropicModel,
- api_url: String,
- low_speed_timeout: Option<Duration>,
- settings_version: usize,
- ) {
- self.model = model;
- self.api_url = api_url;
- self.low_speed_timeout = low_speed_timeout;
- self.settings_version = settings_version;
- }
-
- pub fn available_models(&self) -> impl Iterator<Item = AnthropicModel> {
+impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
+ fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
AnthropicModel::iter()
+ .map(LanguageModel::Anthropic)
+ .collect()
}
- pub fn settings_version(&self) -> usize {
+ fn settings_version(&self) -> usize {
self.settings_version
}
- pub fn is_authenticated(&self) -> bool {
+ fn is_authenticated(&self) -> bool {
self.api_key.is_some()
}
- pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+ fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
@@ -85,36 +57,36 @@ impl AnthropicCompletionProvider {
String::from_utf8(api_key)?
};
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
- if let CompletionProvider::Anthropic(provider) = provider {
+ provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.api_key = Some(api_key);
- }
+ });
})
})
}
}
- pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+ fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
let delete_credentials = cx.delete_credentials(&self.api_url);
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
- if let CompletionProvider::Anthropic(provider) = provider {
+ provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.api_key = None;
- }
+ });
})
})
}
- pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+ fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
.into()
}
- pub fn model(&self) -> AnthropicModel {
- self.model.clone()
+ fn model(&self) -> LanguageModel {
+ LanguageModel::Anthropic(self.model.clone())
}
- pub fn count_tokens(
+ fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
@@ -122,7 +94,7 @@ impl AnthropicCompletionProvider {
count_open_ai_tokens(request, cx.background_executor())
}
- pub fn complete(
+ fn complete(
&self,
request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -167,12 +139,48 @@ impl AnthropicCompletionProvider {
.boxed()
}
+ fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+ self
+ }
+}
+
+impl AnthropicCompletionProvider {
+ pub fn new(
+ model: AnthropicModel,
+ api_url: String,
+ http_client: Arc<dyn HttpClient>,
+ low_speed_timeout: Option<Duration>,
+ settings_version: usize,
+ ) -> Self {
+ Self {
+ api_key: None,
+ api_url,
+ model,
+ http_client,
+ low_speed_timeout,
+ settings_version,
+ }
+ }
+
+ pub fn update(
+ &mut self,
+ model: AnthropicModel,
+ api_url: String,
+ low_speed_timeout: Option<Duration>,
+ settings_version: usize,
+ ) {
+ self.model = model;
+ self.api_url = api_url;
+ self.low_speed_timeout = low_speed_timeout;
+ self.settings_version = settings_version;
+ }
+
fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
preprocess_anthropic_request(&mut request);
let model = match request.model {
LanguageModel::Anthropic(model) => model,
- _ => self.model(),
+ _ => self.model.clone(),
};
let mut system_message = String::new();
@@ -278,9 +286,9 @@ impl AuthenticationPrompt {
cx.spawn(|_, mut cx| async move {
write_credentials.await?;
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
- if let CompletionProvider::Anthropic(provider) = provider {
+ provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
provider.api_key = Some(api_key);
- }
+ });
})
})
.detach_and_log_err(cx);
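Each backend now implements `LanguageModelCompletionProvider` directly. The Anthropic hunks above establish the pattern: inherent methods become trait methods, `model()` returns the wrapped `LanguageModel`, and global mutation goes through `update_current_as`. The cloud, fake, Ollama, and OpenAI hunks below repeat it.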
@@ -1,6 +1,6 @@
use crate::{
assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
- LanguageModelRequest,
+ LanguageModelCompletionProvider, LanguageModelRequest,
};
use anyhow::{anyhow, Result};
use client::{proto, Client};
@@ -30,11 +30,9 @@ impl CloudCompletionProvider {
let maintain_client_status = cx.spawn(|mut cx| async move {
while let Some(status) = status_rx.next().await {
let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
- if let CompletionProvider::Cloud(provider) = provider {
+ provider.update_current_as::<_, Self>(|provider| {
provider.status = status;
- } else {
- unreachable!()
- }
+ });
});
}
});
@@ -51,44 +49,53 @@ impl CloudCompletionProvider {
self.model = model;
self.settings_version = settings_version;
}
+}
- pub fn available_models(&self) -> impl Iterator<Item = CloudModel> {
+impl LanguageModelCompletionProvider for CloudCompletionProvider {
+ fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
Some(custom_model)
} else {
None
};
- CloudModel::iter().filter_map(move |model| {
- if let CloudModel::Custom(_) = model {
- Some(CloudModel::Custom(custom_model.take()?))
- } else {
- Some(model)
- }
- })
+ CloudModel::iter()
+ .filter_map(move |model| {
+ if let CloudModel::Custom(_) = model {
+ Some(CloudModel::Custom(custom_model.take()?))
+ } else {
+ Some(model)
+ }
+ })
+ .map(LanguageModel::Cloud)
+ .collect()
}
- pub fn settings_version(&self) -> usize {
+ fn settings_version(&self) -> usize {
self.settings_version
}
- pub fn model(&self) -> CloudModel {
- self.model.clone()
- }
-
- pub fn is_authenticated(&self) -> bool {
+ fn is_authenticated(&self) -> bool {
self.status.is_connected()
}
- pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+ fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
let client = self.client.clone();
cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
}
- pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+ fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|_cx| AuthenticationPrompt).into()
}
- pub fn count_tokens(
+ fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
+ Task::ready(Ok(()))
+ }
+
+ fn model(&self) -> LanguageModel {
+ LanguageModel::Cloud(self.model.clone())
+ }
+
+ fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
@@ -128,7 +135,7 @@ impl CloudCompletionProvider {
}
}
- pub fn complete(
+ fn complete(
&self,
mut request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -161,6 +168,10 @@ impl CloudCompletionProvider {
})
.boxed()
}
+
+ fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+ self
+ }
}
struct AuthenticationPrompt;
@@ -1,29 +1,107 @@
use anyhow::Result;
+use collections::HashMap;
use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::{AnyView, AppContext, Task};
use std::sync::Arc;
+use ui::WindowContext;
+
+use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
#[derive(Clone, Default)]
pub struct FakeCompletionProvider {
- current_completion_tx: Arc<parking_lot::Mutex<Option<mpsc::UnboundedSender<String>>>>,
+ current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
}
impl FakeCompletionProvider {
- pub fn complete(&self) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
- let (tx, rx) = mpsc::unbounded();
- *self.current_completion_tx.lock() = Some(tx);
- async move { Ok(rx.map(Ok).boxed()) }.boxed()
+ #[cfg(test)]
+ pub fn setup_test(cx: &mut AppContext) -> Self {
+ use crate::CompletionProvider;
+ use parking_lot::RwLock;
+
+ let this = Self::default();
+ let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
+ cx.set_global(provider);
+ this
+ }
+
+ pub fn running_completions(&self) -> Vec<LanguageModelRequest> {
+ self.current_completion_txs
+ .lock()
+ .keys()
+ .map(|k| serde_json::from_str(k).unwrap())
+ .collect()
+ }
+
+ pub fn completion_count(&self) -> usize {
+ self.current_completion_txs.lock().len()
}
- pub fn send_completion(&self, chunk: String) {
- self.current_completion_tx
+ pub fn send_completion(&self, request: &LanguageModelRequest, chunk: String) {
+ let json = serde_json::to_string(request).unwrap();
+ self.current_completion_txs
.lock()
- .as_ref()
+ .get(&json)
.unwrap()
.unbounded_send(chunk)
.unwrap();
}
- pub fn finish_completion(&self) {
- self.current_completion_tx.lock().take();
+ pub fn finish_completion(&self, request: &LanguageModelRequest) {
+ self.current_completion_txs
+ .lock()
+ .remove(&serde_json::to_string(request).unwrap());
+ }
+}
+
+impl LanguageModelCompletionProvider for FakeCompletionProvider {
+ fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+ vec![LanguageModel::default()]
+ }
+
+ fn settings_version(&self) -> usize {
+ 0
+ }
+
+ fn is_authenticated(&self) -> bool {
+ true
+ }
+
+ fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
+ Task::ready(Ok(()))
+ }
+
+ fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
+ unimplemented!()
+ }
+
+ fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
+ Task::ready(Ok(()))
+ }
+
+ fn model(&self) -> LanguageModel {
+ LanguageModel::default()
+ }
+
+ fn count_tokens(
+ &self,
+ _request: LanguageModelRequest,
+ _cx: &AppContext,
+ ) -> BoxFuture<'static, Result<usize>> {
+ futures::future::ready(Ok(0)).boxed()
+ }
+
+ fn complete(
+ &self,
+        request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let (tx, rx) = mpsc::unbounded();
+        self.current_completion_txs
+            .lock()
+            .insert(serde_json::to_string(&request).unwrap(), tx);
+ async move { Ok(rx.map(Ok).boxed()) }.boxed()
+ }
+
+ fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+ self
}
}
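With completions keyed per request, a test now drives the fake provider per request instead of through one shared channel. A condensed sketch of the pattern the codegen tests below follow:

    let provider = FakeCompletionProvider::setup_test(cx);
    let request = LanguageModelRequest::default();
    // ... trigger code that calls CompletionProvider::complete(request, cx) ...
    cx.background_executor().run_until_parked();        // let the permit be acquired
    provider.send_completion(&request, "chunk".into()); // stream one chunk
    provider.finish_completion(&request);               // close the stream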
@@ -1,3 +1,4 @@
+use crate::LanguageModelCompletionProvider;
use crate::{
assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
@@ -26,6 +27,108 @@ pub struct OllamaCompletionProvider {
available_models: Vec<OllamaModel>,
}
+impl LanguageModelCompletionProvider for OllamaCompletionProvider {
+ fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+ self.available_models
+ .iter()
+ .map(|m| LanguageModel::Ollama(m.clone()))
+ .collect()
+ }
+
+ fn settings_version(&self) -> usize {
+ self.settings_version
+ }
+
+ fn is_authenticated(&self) -> bool {
+ !self.available_models.is_empty()
+ }
+
+ fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+ if self.is_authenticated() {
+ Task::ready(Ok(()))
+ } else {
+ self.fetch_models(cx)
+ }
+ }
+
+ fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+ let fetch_models = Box::new(move |cx: &mut WindowContext| {
+ cx.update_global::<CompletionProvider, _>(|provider, cx| {
+ provider
+ .update_current_as::<_, OllamaCompletionProvider>(|provider| {
+ provider.fetch_models(cx)
+ })
+ .unwrap_or_else(|| Task::ready(Ok(())))
+ })
+ });
+
+ cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
+ .into()
+ }
+
+ fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+ self.fetch_models(cx)
+ }
+
+ fn model(&self) -> LanguageModel {
+ LanguageModel::Ollama(self.model.clone())
+ }
+
+ fn count_tokens(
+ &self,
+ request: LanguageModelRequest,
+ _cx: &AppContext,
+ ) -> BoxFuture<'static, Result<usize>> {
+ // There is no endpoint for this _yet_ in Ollama
+ // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
+ let token_count = request
+ .messages
+ .iter()
+ .map(|msg| msg.content.chars().count())
+ .sum::<usize>()
+ / 4;
+
+ async move { Ok(token_count) }.boxed()
+ }
+
+ fn complete(
+ &self,
+ request: LanguageModelRequest,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ let request = self.to_ollama_request(request);
+
+ let http_client = self.http_client.clone();
+ let api_url = self.api_url.clone();
+ let low_speed_timeout = self.low_speed_timeout;
+ async move {
+ let request =
+ stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
+ let response = request.await?;
+ let stream = response
+ .filter_map(|response| async move {
+ match response {
+ Ok(delta) => {
+ let content = match delta.message {
+ ChatMessage::User { content } => content,
+ ChatMessage::Assistant { content } => content,
+ ChatMessage::System { content } => content,
+ };
+ Some(Ok(content))
+ }
+ Err(error) => Some(Err(error)),
+ }
+ })
+ .boxed();
+ Ok(stream)
+ }
+ .boxed()
+ }
+
+ fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+ self
+ }
+}
+
impl OllamaCompletionProvider {
pub fn new(
model: OllamaModel,
@@ -87,36 +190,12 @@ impl OllamaCompletionProvider {
self.settings_version = settings_version;
}
- pub fn available_models(&self) -> impl Iterator<Item = &OllamaModel> {
- self.available_models.iter()
- }
-
pub fn select_first_available_model(&mut self) {
if let Some(model) = self.available_models.first() {
self.model = model.clone();
}
}
- pub fn settings_version(&self) -> usize {
- self.settings_version
- }
-
- pub fn is_authenticated(&self) -> bool {
- !self.available_models.is_empty()
- }
-
- pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
- if self.is_authenticated() {
- Task::ready(Ok(()))
- } else {
- self.fetch_models(cx)
- }
- }
-
- pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
- self.fetch_models(cx)
- }
-
pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
let http_client = self.http_client.clone();
let api_url = self.api_url.clone();
@@ -137,90 +216,21 @@ impl OllamaCompletionProvider {
models.sort_by(|a, b| a.name.cmp(&b.name));
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
- if let CompletionProvider::Ollama(provider) = provider {
+ provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
provider.available_models = models;
if !provider.available_models.is_empty() && provider.model.name.is_empty() {
provider.select_first_available_model()
}
- }
+ });
})
})
}
- pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
- let fetch_models = Box::new(move |cx: &mut WindowContext| {
- cx.update_global::<CompletionProvider, _>(|provider, cx| {
- if let CompletionProvider::Ollama(provider) = provider {
- provider.fetch_models(cx)
- } else {
- Task::ready(Ok(()))
- }
- })
- });
-
- cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
- .into()
- }
-
- pub fn model(&self) -> OllamaModel {
- self.model.clone()
- }
-
- pub fn count_tokens(
- &self,
- request: LanguageModelRequest,
- _cx: &AppContext,
- ) -> BoxFuture<'static, Result<usize>> {
- // There is no endpoint for this _yet_ in Ollama
- // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
- let token_count = request
- .messages
- .iter()
- .map(|msg| msg.content.chars().count())
- .sum::<usize>()
- / 4;
-
- async move { Ok(token_count) }.boxed()
- }
-
- pub fn complete(
- &self,
- request: LanguageModelRequest,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
- let request = self.to_ollama_request(request);
-
- let http_client = self.http_client.clone();
- let api_url = self.api_url.clone();
- let low_speed_timeout = self.low_speed_timeout;
- async move {
- let request =
- stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
- let response = request.await?;
- let stream = response
- .filter_map(|response| async move {
- match response {
- Ok(delta) => {
- let content = match delta.message {
- ChatMessage::User { content } => content,
- ChatMessage::Assistant { content } => content,
- ChatMessage::System { content } => content,
- };
- Some(Ok(content))
- }
- Err(error) => Some(Err(error)),
- }
- })
- .boxed();
- Ok(stream)
- }
- .boxed()
- }
-
fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
let model = match request.model {
LanguageModel::Ollama(model) => model,
- _ => self.model(),
+ _ => self.model.clone(),
};
ChatRequest {
@@ -1,5 +1,6 @@
use crate::assistant_settings::CloudModel;
use crate::assistant_settings::{AssistantProvider, AssistantSettings};
+use crate::LanguageModelCompletionProvider;
use crate::{
assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
};
@@ -57,37 +58,75 @@ impl OpenAiCompletionProvider {
self.settings_version = settings_version;
}
- pub fn available_models(&self, cx: &AppContext) -> impl Iterator<Item = OpenAiModel> {
+ fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
+ let model = match request.model {
+ LanguageModel::OpenAi(model) => model,
+ _ => self.model.clone(),
+ };
+
+ Request {
+ model,
+ messages: request
+ .messages
+ .into_iter()
+ .map(|msg| match msg.role {
+ Role::User => RequestMessage::User {
+ content: msg.content,
+ },
+ Role::Assistant => RequestMessage::Assistant {
+ content: Some(msg.content),
+ tool_calls: Vec::new(),
+ },
+ Role::System => RequestMessage::System {
+ content: msg.content,
+ },
+ })
+ .collect(),
+ stream: true,
+ stop: request.stop,
+ temperature: request.temperature,
+ tools: Vec::new(),
+ tool_choice: None,
+ }
+ }
+}
+
+impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
+ fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
if let AssistantProvider::OpenAi {
available_models, ..
} = &AssistantSettings::get_global(cx).provider
{
if !available_models.is_empty() {
- // available_models is set, just return it
- return available_models.clone().into_iter();
+ return available_models
+ .iter()
+ .cloned()
+ .map(LanguageModel::OpenAi)
+ .collect();
}
}
let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
- // available_models is not set but the default model is set to custom, only show custom
vec![self.model.clone()]
} else {
- // default case, use all models except custom
OpenAiModel::iter()
.filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
.collect()
};
- available_models.into_iter()
+ available_models
+ .into_iter()
+ .map(LanguageModel::OpenAi)
+ .collect()
}
- pub fn settings_version(&self) -> usize {
+ fn settings_version(&self) -> usize {
self.settings_version
}
- pub fn is_authenticated(&self) -> bool {
+ fn is_authenticated(&self) -> bool {
self.api_key.is_some()
}
- pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+ fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
@@ -103,36 +142,36 @@ impl OpenAiCompletionProvider {
String::from_utf8(api_key)?
};
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
- if let CompletionProvider::OpenAi(provider) = provider {
+ provider.update_current_as::<_, Self>(|provider| {
provider.api_key = Some(api_key);
- }
+ });
})
})
}
}
- pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+ fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
let delete_credentials = cx.delete_credentials(&self.api_url);
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
- if let CompletionProvider::OpenAi(provider) = provider {
+ provider.update_current_as::<_, Self>(|provider| {
provider.api_key = None;
- }
+ });
})
})
}
- pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+ fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
.into()
}
- pub fn model(&self) -> OpenAiModel {
- self.model.clone()
+ fn model(&self) -> LanguageModel {
+ LanguageModel::OpenAi(self.model.clone())
}
- pub fn count_tokens(
+ fn count_tokens(
&self,
request: LanguageModelRequest,
cx: &AppContext,
@@ -140,7 +179,7 @@ impl OpenAiCompletionProvider {
count_open_ai_tokens(request, cx.background_executor())
}
- pub fn complete(
+ fn complete(
&self,
request: LanguageModelRequest,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -173,36 +212,8 @@ impl OpenAiCompletionProvider {
.boxed()
}
- fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
- let model = match request.model {
- LanguageModel::OpenAi(model) => model,
- _ => self.model(),
- };
-
- Request {
- model,
- messages: request
- .messages
- .into_iter()
- .map(|msg| match msg.role {
- Role::User => RequestMessage::User {
- content: msg.content,
- },
- Role::Assistant => RequestMessage::Assistant {
- content: Some(msg.content),
- tool_calls: Vec::new(),
- },
- Role::System => RequestMessage::System {
- content: msg.content,
- },
- })
- .collect(),
- stream: true,
- stop: request.stop,
- temperature: request.temperature,
- tools: Vec::new(),
- tool_choice: None,
- }
+ fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+ self
}
}
@@ -284,9 +295,9 @@ impl AuthenticationPrompt {
cx.spawn(|_, mut cx| async move {
write_credentials.await?;
cx.update_global::<CompletionProvider, _>(|provider, _cx| {
- if let CompletionProvider::OpenAi(provider) = provider {
+ provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
provider.api_key = Some(api_key);
- }
+ });
})
})
.detach_and_log_err(cx);
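The remaining hunks adapt the completion call sites in the two `Codegen` implementations and their tests to the new `Task<CompletionResponse>` shape.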
@@ -1986,13 +1986,14 @@ impl Codegen {
.unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
let model_telemetry_id = prompt.model.telemetry_id();
- let response = CompletionProvider::global(cx).complete(prompt);
+ let response = CompletionProvider::global(cx).complete(prompt, cx);
let telemetry = self.telemetry.clone();
self.edit_position = range.start;
self.diff = Diff::default();
self.status = CodegenStatus::Pending;
self.generation = cx.spawn(|this, mut cx| {
async move {
+ let response = response.await;
let generate = async {
let mut edit_start = range.start.to_offset(&snapshot);
@@ -2002,7 +2003,7 @@ impl Codegen {
let mut response_latency = None;
let request_start = Instant::now();
let diff = async {
- let chunks = StripInvalidSpans::new(response.await?);
+ let chunks = StripInvalidSpans::new(response.inner.await?);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
@@ -2473,9 +2474,8 @@ mod tests {
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
- let provider = FakeCompletionProvider::default();
cx.set_global(cx.update(SettingsStore::test));
- cx.set_global(CompletionProvider::Fake(provider.clone()));
+ let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
cx.update(language_settings::init);
let text = indoc! {"
@@ -2495,8 +2495,11 @@ mod tests {
});
let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, cx));
- let request = LanguageModelRequest::default();
- codegen.update(cx, |codegen, cx| codegen.start(request, cx));
+ codegen.update(cx, |codegen, cx| {
+ codegen.start(LanguageModelRequest::default(), cx)
+ });
+
+ cx.background_executor.run_until_parked();
let mut new_text = concat!(
" let mut x = 0;\n",
@@ -2508,11 +2511,11 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
- provider.send_completion(chunk.into());
+ provider.send_completion(&LanguageModelRequest::default(), chunk.into());
new_text = suffix;
cx.background_executor.run_until_parked();
}
- provider.finish_completion();
+ provider.finish_completion(&LanguageModelRequest::default());
cx.background_executor.run_until_parked();
assert_eq!(
@@ -2533,8 +2536,7 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
- let provider = FakeCompletionProvider::default();
- cx.set_global(CompletionProvider::Fake(provider.clone()));
+ let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
@@ -2555,6 +2557,8 @@ mod tests {
let request = LanguageModelRequest::default();
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
+ cx.background_executor.run_until_parked();
+
let mut new_text = concat!(
"t mut x = 0;\n",
"while x < 10 {\n",
@@ -2565,11 +2569,11 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
- provider.send_completion(chunk.into());
+ provider.send_completion(&LanguageModelRequest::default(), chunk.into());
new_text = suffix;
cx.background_executor.run_until_parked();
}
- provider.finish_completion();
+ provider.finish_completion(&LanguageModelRequest::default());
cx.background_executor.run_until_parked();
assert_eq!(
@@ -2590,8 +2594,7 @@ mod tests {
cx: &mut TestAppContext,
mut rng: StdRng,
) {
- let provider = FakeCompletionProvider::default();
- cx.set_global(CompletionProvider::Fake(provider.clone()));
+ let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
cx.set_global(cx.update(SettingsStore::test));
cx.update(language_settings::init);
@@ -2612,6 +2615,8 @@ mod tests {
let request = LanguageModelRequest::default();
codegen.update(cx, |codegen, cx| codegen.start(request, cx));
+ cx.background_executor.run_until_parked();
+
let mut new_text = concat!(
"let mut x = 0;\n",
"while x < 10 {\n",
@@ -2622,11 +2627,11 @@ mod tests {
let max_len = cmp::min(new_text.len(), 10);
let len = rng.gen_range(1..=max_len);
let (chunk, suffix) = new_text.split_at(len);
- provider.send_completion(chunk.into());
+ provider.send_completion(&LanguageModelRequest::default(), chunk.into());
new_text = suffix;
cx.background_executor.run_until_parked();
}
- provider.finish_completion();
+ provider.finish_completion(&LanguageModelRequest::default());
cx.background_executor.run_until_parked();
assert_eq!(
@@ -1026,9 +1026,10 @@ impl Codegen {
let telemetry = self.telemetry.clone();
let model_telemetry_id = prompt.model.telemetry_id();
- let response = CompletionProvider::global(cx).complete(prompt);
+ let response = CompletionProvider::global(cx).complete(prompt, cx);
self.generation = cx.spawn(|this, mut cx| async move {
+ let response = response.await;
let generate = async {
let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
@@ -1036,7 +1037,7 @@ impl Codegen {
let mut response_latency = None;
let request_start = Instant::now();
let task = async {
- let mut response = response.await?;
+ let mut response = response.inner.await?;
while let Some(chunk) = response.next().await {
if response_latency.is_none() {
response_latency = Some(request_start.elapsed());