added credential provider to completion provider

KCaverly created

Change summary

crates/ai/src/completion.rs                   | 10 +++++++++-
crates/ai/src/providers/open_ai/completion.rs | 10 +++++++++-
crates/ai/src/providers/open_ai/embedding.rs  |  1 -
crates/ai/src/test.rs                         |  3 +++
crates/assistant/src/codegen.rs               |  7 +------
5 files changed, 22 insertions(+), 9 deletions(-)

Detailed changes

crates/ai/src/completion.rs 🔗

@@ -1,7 +1,11 @@
 use anyhow::Result;
 use futures::{future::BoxFuture, stream::BoxStream};
+use gpui::AppContext;
 
-use crate::models::LanguageModel;
+use crate::{
+    auth::{CredentialProvider, ProviderCredential},
+    models::LanguageModel,
+};
 
 pub trait CompletionRequest: Send + Sync {
     fn data(&self) -> serde_json::Result<String>;
@@ -9,6 +13,10 @@ pub trait CompletionRequest: Send + Sync {
 
 pub trait CompletionProvider {
     fn base_model(&self) -> Box<dyn LanguageModel>;
+    fn credential_provider(&self) -> Box<dyn CredentialProvider>;
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+        self.credential_provider().retrieve_credentials(cx)
+    }
     fn complete(
         &self,
         prompt: Box<dyn CompletionRequest>,

crates/ai/src/providers/open_ai/completion.rs 🔗

@@ -13,11 +13,12 @@ use std::{
 };
 
 use crate::{
+    auth::CredentialProvider,
     completion::{CompletionProvider, CompletionRequest},
     models::LanguageModel,
 };
 
-use super::OpenAILanguageModel;
+use super::{auth::OpenAICredentialProvider, OpenAILanguageModel};
 
 pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
 
@@ -186,6 +187,7 @@ pub async fn stream_completion(
 
 pub struct OpenAICompletionProvider {
     model: OpenAILanguageModel,
+    credential_provider: OpenAICredentialProvider,
     api_key: String,
     executor: Arc<Background>,
 }
@@ -193,8 +195,10 @@ pub struct OpenAICompletionProvider {
 impl OpenAICompletionProvider {
     pub fn new(model_name: &str, api_key: String, executor: Arc<Background>) -> Self {
         let model = OpenAILanguageModel::load(model_name);
+        let credential_provider = OpenAICredentialProvider {};
         Self {
             model,
+            credential_provider,
             api_key,
             executor,
         }
@@ -206,6 +210,10 @@ impl CompletionProvider for OpenAICompletionProvider {
         let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
         model
     }
+    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
+        let provider: Box<dyn CredentialProvider> = Box::new(self.credential_provider.clone());
+        provider
+    }
     fn complete(
         &self,
         prompt: Box<dyn CompletionRequest>,

crates/ai/src/providers/open_ai/embedding.rs 🔗

@@ -11,7 +11,6 @@ use parking_lot::Mutex;
 use parse_duration::parse;
 use postage::watch;
 use serde::{Deserialize, Serialize};
-use std::env;
 use std::ops::Add;
 use std::sync::Arc;
 use std::time::{Duration, Instant};

crates/ai/src/test.rs 🔗

@@ -155,6 +155,9 @@ impl CompletionProvider for TestCompletionProvider {
         let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
         model
     }
+    fn credential_provider(&self) -> Box<dyn CredentialProvider> {
+        Box::new(NullCredentialProvider {})
+    }
     fn complete(
         &self,
         _prompt: Box<dyn CompletionRequest>,

crates/assistant/src/codegen.rs 🔗

@@ -336,18 +336,13 @@ fn strip_markdown_codeblock(
 mod tests {
     use super::*;
     use ai::test::TestCompletionProvider;
-    use futures::{
-        future::BoxFuture,
-        stream::{self, BoxStream},
-    };
+    use futures::stream::{self};
     use gpui::{executor::Deterministic, TestAppContext};
     use indoc::indoc;
     use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
-    use parking_lot::Mutex;
     use rand::prelude::*;
     use serde::Serialize;
     use settings::SettingsStore;
-    use smol::future::FutureExt;
 
     #[derive(Serialize)]
     pub struct DummyCompletionRequest {