Detailed changes
@@ -1,11 +1,14 @@
use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream};
+use crate::models::LanguageModel;
+
pub trait CompletionRequest: Send + Sync {
fn data(&self) -> serde_json::Result<String>;
}
pub trait CompletionProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel>;
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
@@ -12,7 +12,12 @@ use std::{
sync::Arc,
};
-use crate::completion::{CompletionProvider, CompletionRequest};
+use crate::{
+ completion::{CompletionProvider, CompletionRequest},
+ models::LanguageModel,
+};
+
+use super::OpenAILanguageModel;
pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
@@ -180,17 +185,27 @@ pub async fn stream_completion(
}
pub struct OpenAICompletionProvider {
+ model: OpenAILanguageModel,
api_key: String,
executor: Arc<Background>,
}
impl OpenAICompletionProvider {
- pub fn new(api_key: String, executor: Arc<Background>) -> Self {
- Self { api_key, executor }
+ pub fn new(model_name: &str, api_key: String, executor: Arc<Background>) -> Self {
+ let model = OpenAILanguageModel::load(model_name);
+ Self {
+ model,
+ api_key,
+ executor,
+ }
}
}
impl CompletionProvider for OpenAICompletionProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+ model
+ }
fn complete(
&self,
prompt: Box<dyn CompletionRequest>,
@@ -26,7 +26,6 @@ use crate::providers::open_ai::OpenAILanguageModel;
use crate::providers::open_ai::auth::OpenAICredentialProvider;
lazy_static! {
- static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
@@ -328,6 +328,7 @@ impl AssistantPanel {
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
let provider = Arc::new(OpenAICompletionProvider::new(
+ "gpt-4",
api_key,
cx.background().clone(),
));
@@ -335,6 +335,7 @@ fn strip_markdown_codeblock(
#[cfg(test)]
mod tests {
use super::*;
+ use ai::{models::LanguageModel, test::FakeLanguageModel};
use futures::{
future::BoxFuture,
stream::{self, BoxStream},
@@ -638,6 +639,10 @@ mod tests {
}
impl CompletionProvider for TestCompletionProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
+ model
+ }
fn complete(
&self,
_prompt: Box<dyn CompletionRequest>,