add concept of LanguageModel to CompletionProvider

KCaverly created

Change summary

crates/ai/src/completion.rs                   |  3 +++
crates/ai/src/providers/open_ai/completion.rs | 21 ++++++++++++++++++---
crates/ai/src/providers/open_ai/embedding.rs  |  1 -
crates/assistant/src/assistant_panel.rs       |  1 +
crates/assistant/src/codegen.rs               |  5 +++++
5 files changed, 27 insertions(+), 4 deletions(-)

Detailed changes

crates/ai/src/completion.rs 🔗

@@ -1,11 +1,14 @@
 use anyhow::Result;
 use futures::{future::BoxFuture, stream::BoxStream};
 
+use crate::models::LanguageModel;
+
 pub trait CompletionRequest: Send + Sync {
     fn data(&self) -> serde_json::Result<String>;
 }
 
 pub trait CompletionProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel>;
     fn complete(
         &self,
         prompt: Box<dyn CompletionRequest>,

crates/ai/src/providers/open_ai/completion.rs 🔗

@@ -12,7 +12,12 @@ use std::{
     sync::Arc,
 };
 
-use crate::completion::{CompletionProvider, CompletionRequest};
+use crate::{
+    completion::{CompletionProvider, CompletionRequest},
+    models::LanguageModel,
+};
+
+use super::OpenAILanguageModel;
 
 pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
 
@@ -180,17 +185,27 @@ pub async fn stream_completion(
 }
 
 pub struct OpenAICompletionProvider {
+    model: OpenAILanguageModel,
     api_key: String,
     executor: Arc<Background>,
 }
 
 impl OpenAICompletionProvider {
-    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
-        Self { api_key, executor }
+    pub fn new(model_name: &str, api_key: String, executor: Arc<Background>) -> Self {
+        let model = OpenAILanguageModel::load(model_name);
+        Self {
+            model,
+            api_key,
+            executor,
+        }
     }
 }
 
 impl CompletionProvider for OpenAICompletionProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+        model
+    }
     fn complete(
         &self,
         prompt: Box<dyn CompletionRequest>,

crates/ai/src/providers/open_ai/embedding.rs 🔗

@@ -26,7 +26,6 @@ use crate::providers::open_ai::OpenAILanguageModel;
 use crate::providers::open_ai::auth::OpenAICredentialProvider;
 
 lazy_static! {
-    static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
     static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
 }
 

crates/assistant/src/assistant_panel.rs 🔗

@@ -328,6 +328,7 @@ impl AssistantPanel {
 
         let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
         let provider = Arc::new(OpenAICompletionProvider::new(
+            "gpt-4",
             api_key,
             cx.background().clone(),
         ));

crates/assistant/src/codegen.rs 🔗

@@ -335,6 +335,7 @@ fn strip_markdown_codeblock(
 #[cfg(test)]
 mod tests {
     use super::*;
+    use ai::{models::LanguageModel, test::FakeLanguageModel};
     use futures::{
         future::BoxFuture,
         stream::{self, BoxStream},
@@ -638,6 +639,10 @@ mod tests {
     }
 
     impl CompletionProvider for TestCompletionProvider {
+        fn base_model(&self) -> Box<dyn LanguageModel> {
+            let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
+            model
+        }
         fn complete(
             &self,
             _prompt: Box<dyn CompletionRequest>,