port ai to zed2 (#3186)

Created by Kyle Caverly

Change summary

Cargo.lock                                        |  28 +
crates/Cargo.toml                                 |  38 +
crates/ai/Cargo.toml                              |   3 
crates/ai/src/ai.rs                               |   6 
crates/ai/src/auth.rs                             |  15 
crates/ai/src/completion.rs                       | 215 ----------
crates/ai/src/embedding.rs                        | 322 ----------------
crates/ai/src/models.rs                           |  70 ---
crates/ai/src/prompts/base.rs                     |  56 -
crates/ai/src/prompts/file_context.rs             |  16 
crates/ai/src/prompts/generate.rs                 |   8 
crates/ai/src/prompts/mod.rs                      |   0 
crates/ai/src/prompts/preamble.rs                 |   2 
crates/ai/src/prompts/repository_context.rs       |   2 
crates/ai/src/providers/mod.rs                    |   1 
crates/ai/src/providers/open_ai/completion.rs     | 298 +++++++++++++++
crates/ai/src/providers/open_ai/embedding.rs      | 306 +++++++++++++++
crates/ai/src/providers/open_ai/mod.rs            |   9 
crates/ai/src/providers/open_ai/model.rs          |  57 ++
crates/ai/src/providers/open_ai/new.rs            |  11 
crates/ai/src/test.rs                             | 191 +++++++++
crates/ai2/Cargo.toml                             |  38 +
crates/ai2/src/ai2.rs                             |   8 
crates/ai2/src/auth.rs                            |  17 
crates/ai2/src/completion.rs                      |  23 +
crates/ai2/src/embedding.rs                       | 123 ++++++
crates/ai2/src/models.rs                          |  16 
crates/ai2/src/prompts/base.rs                    | 330 +++++++++++++++++
crates/ai2/src/prompts/file_context.rs            | 164 ++++++++
crates/ai2/src/prompts/generate.rs                |  99 +++++
crates/ai2/src/prompts/mod.rs                     |   5 
crates/ai2/src/prompts/preamble.rs                |  52 ++
crates/ai2/src/prompts/repository_context.rs      |  98 +++++
crates/ai2/src/providers/mod.rs                   |   1 
crates/ai2/src/providers/open_ai/completion.rs    | 306 +++++++++++++++
crates/ai2/src/providers/open_ai/embedding.rs     | 313 ++++++++++++++++
crates/ai2/src/providers/open_ai/mod.rs           |   9 
crates/ai2/src/providers/open_ai/model.rs         |  57 ++
crates/ai2/src/providers/open_ai/new.rs           |  11 
crates/ai2/src/test.rs                            | 193 +++++++++
crates/assistant/Cargo.toml                       |   1 
crates/assistant/src/assistant.rs                 |   2 
crates/assistant/src/assistant_panel.rs           | 280 ++++++-------
crates/assistant/src/codegen.rs                   |  80 +--
crates/assistant/src/prompts.rs                   |  13 
crates/semantic_index/Cargo.toml                  |   1 
crates/semantic_index/src/embedding_queue.rs      |  15 
crates/semantic_index/src/parsing.rs              |  33 +
crates/semantic_index/src/semantic_index.rs       |  55 +-
crates/semantic_index/src/semantic_index_tests.rs |  93 ----
crates/ui2/src/elements/icon.rs                   |   2 
crates/zed/examples/semantic_index_eval.rs        |   4 
crates/zed2/Cargo.toml                            |   1 
53 files changed, 3,136 insertions(+), 961 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -108,6 +108,33 @@ dependencies = [
  "util",
 ]
 
+[[package]]
+name = "ai2"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "async-trait",
+ "bincode",
+ "futures 0.3.28",
+ "gpui2",
+ "isahc",
+ "language2",
+ "lazy_static",
+ "log",
+ "matrixmultiply",
+ "ordered-float 2.10.0",
+ "parking_lot 0.11.2",
+ "parse_duration",
+ "postage",
+ "rand 0.8.5",
+ "regex",
+ "rusqlite",
+ "serde",
+ "serde_json",
+ "tiktoken-rs",
+ "util",
+]
+
 [[package]]
 name = "alacritty_config"
 version = "0.1.2-dev"
@@ -10903,6 +10930,7 @@ dependencies = [
 name = "zed2"
 version = "0.109.0"
 dependencies = [
+ "ai2",
  "anyhow",
  "async-compression",
  "async-recursion 0.3.2",

crates/Cargo.toml 🔗

@@ -0,0 +1,38 @@
+[package]
+name = "ai"
+version = "0.1.0"
+edition = "2021"
+publish = false
+
+[lib]
+path = "src/ai.rs"
+doctest = false
+
+[features]
+test-support = []
+
+[dependencies]
+gpui = { path = "../gpui" }
+util = { path = "../util" }
+language = { path = "../language" }
+async-trait.workspace = true
+anyhow.workspace = true
+futures.workspace = true
+lazy_static.workspace = true
+ordered-float.workspace = true
+parking_lot.workspace = true
+isahc.workspace = true
+regex.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+postage.workspace = true
+rand.workspace = true
+log.workspace = true
+parse_duration = "2.1.1"
+tiktoken-rs = "0.5.0"
+matrixmultiply = "0.3.7"
+rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
+bincode = "1.3.3"
+
+[dev-dependencies]
+gpui = { path = "../gpui", features = ["test-support"] }

crates/ai/Cargo.toml 🔗

@@ -8,6 +8,9 @@ publish = false
 path = "src/ai.rs"
 doctest = false
 
+[features]
+test-support = []
+
 [dependencies]
 gpui = { path = "../gpui" }
 util = { path = "../util" }

crates/ai/src/ai.rs 🔗

@@ -1,4 +1,8 @@
+pub mod auth;
 pub mod completion;
 pub mod embedding;
 pub mod models;
-pub mod templates;
+pub mod prompts;
+pub mod providers;
+#[cfg(any(test, feature = "test-support"))]
+pub mod test;

crates/ai/src/auth.rs 🔗

@@ -0,0 +1,15 @@
+use gpui::AppContext;
+
+#[derive(Clone, Debug)]
+pub enum ProviderCredential {
+    Credentials { api_key: String },
+    NoCredentials,
+    NotNeeded,
+}
+
+pub trait CredentialProvider: Send + Sync {
+    fn has_credentials(&self) -> bool;
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
+    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
+    fn delete_credentials(&self, cx: &AppContext);
+}
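
The new auth module gives every provider the same credential surface: ProviderCredential distinguishes "no key yet" from "no key needed", and CredentialProvider handles load, save, and delete. As a minimal sketch, a caller might unwrap a retrieved credential like this (the helper below is hypothetical, not part of this change):

    use ai::auth::ProviderCredential;
    use anyhow::{anyhow, Result};

    // Hypothetical helper: turn a retrieved credential into an optional
    // API key, treating NotNeeded as "proceed without a key".
    fn api_key_from(credential: ProviderCredential) -> Result<Option<String>> {
        match credential {
            ProviderCredential::Credentials { api_key } => Ok(Some(api_key)),
            ProviderCredential::NotNeeded => Ok(None),
            ProviderCredential::NoCredentials => Err(anyhow!("no credentials configured")),
        }
    }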

crates/ai/src/completion.rs 🔗

@@ -1,214 +1,23 @@
-use anyhow::{anyhow, Result};
-use futures::{
-    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
-    Stream, StreamExt,
-};
-use gpui::executor::Background;
-use isahc::{http::StatusCode, Request, RequestExt};
-use serde::{Deserialize, Serialize};
-use std::{
-    fmt::{self, Display},
-    io,
-    sync::Arc,
-};
+use anyhow::Result;
+use futures::{future::BoxFuture, stream::BoxStream};
 
-pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
+use crate::{auth::CredentialProvider, models::LanguageModel};
 
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
-    User,
-    Assistant,
-    System,
+pub trait CompletionRequest: Send + Sync {
+    fn data(&self) -> serde_json::Result<String>;
 }
 
-impl Role {
-    pub fn cycle(&mut self) {
-        *self = match self {
-            Role::User => Role::Assistant,
-            Role::Assistant => Role::System,
-            Role::System => Role::User,
-        }
-    }
-}
-
-impl Display for Role {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Role::User => write!(f, "User"),
-            Role::Assistant => write!(f, "Assistant"),
-            Role::System => write!(f, "System"),
-        }
-    }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct RequestMessage {
-    pub role: Role,
-    pub content: String,
-}
-
-#[derive(Debug, Default, Serialize)]
-pub struct OpenAIRequest {
-    pub model: String,
-    pub messages: Vec<RequestMessage>,
-    pub stream: bool,
-    pub stop: Vec<String>,
-    pub temperature: f32,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ResponseMessage {
-    pub role: Option<Role>,
-    pub content: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIUsage {
-    pub prompt_tokens: u32,
-    pub completion_tokens: u32,
-    pub total_tokens: u32,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct ChatChoiceDelta {
-    pub index: u32,
-    pub delta: ResponseMessage,
-    pub finish_reason: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIResponseStreamEvent {
-    pub id: Option<String>,
-    pub object: String,
-    pub created: u32,
-    pub model: String,
-    pub choices: Vec<ChatChoiceDelta>,
-    pub usage: Option<OpenAIUsage>,
-}
-
-pub async fn stream_completion(
-    api_key: String,
-    executor: Arc<Background>,
-    mut request: OpenAIRequest,
-) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
-    request.stream = true;
-
-    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
-
-    let json_data = serde_json::to_string(&request)?;
-    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
-        .header("Content-Type", "application/json")
-        .header("Authorization", format!("Bearer {}", api_key))
-        .body(json_data)?
-        .send_async()
-        .await?;
-
-    let status = response.status();
-    if status == StatusCode::OK {
-        executor
-            .spawn(async move {
-                let mut lines = BufReader::new(response.body_mut()).lines();
-
-                fn parse_line(
-                    line: Result<String, io::Error>,
-                ) -> Result<Option<OpenAIResponseStreamEvent>> {
-                    if let Some(data) = line?.strip_prefix("data: ") {
-                        let event = serde_json::from_str(&data)?;
-                        Ok(Some(event))
-                    } else {
-                        Ok(None)
-                    }
-                }
-
-                while let Some(line) = lines.next().await {
-                    if let Some(event) = parse_line(line).transpose() {
-                        let done = event.as_ref().map_or(false, |event| {
-                            event
-                                .choices
-                                .last()
-                                .map_or(false, |choice| choice.finish_reason.is_some())
-                        });
-                        if tx.unbounded_send(event).is_err() {
-                            break;
-                        }
-
-                        if done {
-                            break;
-                        }
-                    }
-                }
-
-                anyhow::Ok(())
-            })
-            .detach();
-
-        Ok(rx)
-    } else {
-        let mut body = String::new();
-        response.body_mut().read_to_string(&mut body).await?;
-
-        #[derive(Deserialize)]
-        struct OpenAIResponse {
-            error: OpenAIError,
-        }
-
-        #[derive(Deserialize)]
-        struct OpenAIError {
-            message: String,
-        }
-
-        match serde_json::from_str::<OpenAIResponse>(&body) {
-            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
-                "Failed to connect to OpenAI API: {}",
-                response.error.message,
-            )),
-
-            _ => Err(anyhow!(
-                "Failed to connect to OpenAI API: {} {}",
-                response.status(),
-                body,
-            )),
-        }
-    }
-}
-
-pub trait CompletionProvider {
+pub trait CompletionProvider: CredentialProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel>;
     fn complete(
         &self,
-        prompt: OpenAIRequest,
+        prompt: Box<dyn CompletionRequest>,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+    fn box_clone(&self) -> Box<dyn CompletionProvider>;
 }
 
-pub struct OpenAICompletionProvider {
-    api_key: String,
-    executor: Arc<Background>,
-}
-
-impl OpenAICompletionProvider {
-    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
-        Self { api_key, executor }
-    }
-}
-
-impl CompletionProvider for OpenAICompletionProvider {
-    fn complete(
-        &self,
-        prompt: OpenAIRequest,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
-        async move {
-            let response = request.await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
-            Ok(stream)
-        }
-        .boxed()
+impl Clone for Box<dyn CompletionProvider> {
+    fn clone(&self) -> Box<dyn CompletionProvider> {
+        self.box_clone()
     }
 }
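
Since Clone is not object-safe, making Box<dyn CompletionProvider> cloneable uses the standard box_clone idiom, shown here in isolation (plain-std sketch, not tied to this crate):

    trait Provider: Send {
        fn box_clone(&self) -> Box<dyn Provider>;
    }

    impl Clone for Box<dyn Provider> {
        fn clone(&self) -> Box<dyn Provider> {
            self.box_clone()
        }
    }

    #[derive(Clone)]
    struct MyProvider;

    impl Provider for MyProvider {
        fn box_clone(&self) -> Box<dyn Provider> {
            Box::new(self.clone())
        }
    }

Each concrete provider derives Clone and forwards it through box_clone, exactly as OpenAICompletionProvider does below.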

crates/ai/src/embedding.rs 🔗

@@ -1,32 +1,13 @@
-use anyhow::{anyhow, Result};
+use std::time::Instant;
+
+use anyhow::Result;
 use async_trait::async_trait;
-use futures::AsyncReadExt;
-use gpui::executor::Background;
-use gpui::{serde_json, AppContext};
-use isahc::http::StatusCode;
-use isahc::prelude::Configurable;
-use isahc::{AsyncBody, Response};
-use lazy_static::lazy_static;
 use ordered_float::OrderedFloat;
-use parking_lot::Mutex;
-use parse_duration::parse;
-use postage::watch;
 use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 use rusqlite::ToSql;
-use serde::{Deserialize, Serialize};
-use std::env;
-use std::ops::Add;
-use std::sync::Arc;
-use std::time::{Duration, Instant};
-use tiktoken_rs::{cl100k_base, CoreBPE};
-use util::http::{HttpClient, Request};
-use util::ResultExt;
-
-use crate::completion::OPENAI_API_URL;
 
-lazy_static! {
-    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
-}
+use crate::auth::CredentialProvider;
+use crate::models::LanguageModel;
 
 #[derive(Debug, PartialEq, Clone)]
 pub struct Embedding(pub Vec<f32>);
@@ -87,301 +68,14 @@ impl Embedding {
     }
 }
 
-#[derive(Clone)]
-pub struct OpenAIEmbeddings {
-    pub client: Arc<dyn HttpClient>,
-    pub executor: Arc<Background>,
-    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
-    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
-}
-
-#[derive(Serialize)]
-struct OpenAIEmbeddingRequest<'a> {
-    model: &'static str,
-    input: Vec<&'a str>,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingResponse {
-    data: Vec<OpenAIEmbedding>,
-    usage: OpenAIEmbeddingUsage,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAIEmbedding {
-    embedding: Vec<f32>,
-    index: usize,
-    object: String,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingUsage {
-    prompt_tokens: usize,
-    total_tokens: usize,
-}
-
 #[async_trait]
-pub trait EmbeddingProvider: Sync + Send {
-    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        api_key: Option<String>,
-    ) -> Result<Vec<Embedding>>;
+pub trait EmbeddingProvider: CredentialProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel>;
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
-    fn truncate(&self, span: &str) -> (String, usize);
     fn rate_limit_expiration(&self) -> Option<Instant>;
 }
 
-pub struct DummyEmbeddings {}
-
-#[async_trait]
-impl EmbeddingProvider for DummyEmbeddings {
-    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
-        Some("Dummy API KEY".to_string())
-    }
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        None
-    }
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        _api_key: Option<String>,
-    ) -> Result<Vec<Embedding>> {
-        // 1024 is the OpenAI Embeddings size for ada models.
-        // the model we will likely be starting with.
-        let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
-        return Ok(vec![dummy_vec; spans.len()]);
-    }
-
-    fn max_tokens_per_batch(&self) -> usize {
-        OPENAI_INPUT_LIMIT
-    }
-
-    fn truncate(&self, span: &str) -> (String, usize) {
-        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let token_count = tokens.len();
-        let output = if token_count > OPENAI_INPUT_LIMIT {
-            tokens.truncate(OPENAI_INPUT_LIMIT);
-            let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
-            new_input.ok().unwrap_or_else(|| span.to_string())
-        } else {
-            span.to_string()
-        };
-
-        (output, tokens.len())
-    }
-}
-
-const OPENAI_INPUT_LIMIT: usize = 8190;
-
-impl OpenAIEmbeddings {
-    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
-        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
-        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
-
-        OpenAIEmbeddings {
-            client,
-            executor,
-            rate_limit_count_rx,
-            rate_limit_count_tx,
-        }
-    }
-
-    fn resolve_rate_limit(&self) {
-        let reset_time = *self.rate_limit_count_tx.lock().borrow();
-
-        if let Some(reset_time) = reset_time {
-            if Instant::now() >= reset_time {
-                *self.rate_limit_count_tx.lock().borrow_mut() = None
-            }
-        }
-
-        log::trace!(
-            "resolving reset time: {:?}",
-            *self.rate_limit_count_tx.lock().borrow()
-        );
-    }
-
-    fn update_reset_time(&self, reset_time: Instant) {
-        let original_time = *self.rate_limit_count_tx.lock().borrow();
-
-        let updated_time = if let Some(original_time) = original_time {
-            if reset_time < original_time {
-                Some(reset_time)
-            } else {
-                Some(original_time)
-            }
-        } else {
-            Some(reset_time)
-        };
-
-        log::trace!("updating rate limit time: {:?}", updated_time);
-
-        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
-    }
-    async fn send_request(
-        &self,
-        api_key: &str,
-        spans: Vec<&str>,
-        request_timeout: u64,
-    ) -> Result<Response<AsyncBody>> {
-        let request = Request::post("https://api.openai.com/v1/embeddings")
-            .redirect_policy(isahc::config::RedirectPolicy::Follow)
-            .timeout(Duration::from_secs(request_timeout))
-            .header("Content-Type", "application/json")
-            .header("Authorization", format!("Bearer {}", api_key))
-            .body(
-                serde_json::to_string(&OpenAIEmbeddingRequest {
-                    input: spans.clone(),
-                    model: "text-embedding-ada-002",
-                })
-                .unwrap()
-                .into(),
-            )?;
-
-        Ok(self.client.send(request).await?)
-    }
-}
-
-#[async_trait]
-impl EmbeddingProvider for OpenAIEmbeddings {
-    fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
-        if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-            Some(api_key)
-        } else if let Some((_, api_key)) = cx
-            .platform()
-            .read_credentials(OPENAI_API_URL)
-            .log_err()
-            .flatten()
-        {
-            String::from_utf8(api_key).log_err()
-        } else {
-            None
-        }
-    }
-
-    fn max_tokens_per_batch(&self) -> usize {
-        50000
-    }
-
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        *self.rate_limit_count_rx.borrow()
-    }
-    fn truncate(&self, span: &str) -> (String, usize) {
-        let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
-        let output = if tokens.len() > OPENAI_INPUT_LIMIT {
-            tokens.truncate(OPENAI_INPUT_LIMIT);
-            OPENAI_BPE_TOKENIZER
-                .decode(tokens.clone())
-                .ok()
-                .unwrap_or_else(|| span.to_string())
-        } else {
-            span.to_string()
-        };
-
-        (output, tokens.len())
-    }
-
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        api_key: Option<String>,
-    ) -> Result<Vec<Embedding>> {
-        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
-        const MAX_RETRIES: usize = 4;
-
-        let Some(api_key) = api_key else {
-            return Err(anyhow!("no open ai key provided"));
-        };
-
-        let mut request_number = 0;
-        let mut rate_limiting = false;
-        let mut request_timeout: u64 = 15;
-        let mut response: Response<AsyncBody>;
-        while request_number < MAX_RETRIES {
-            response = self
-                .send_request(
-                    &api_key,
-                    spans.iter().map(|x| &**x).collect(),
-                    request_timeout,
-                )
-                .await?;
-
-            request_number += 1;
-
-            match response.status() {
-                StatusCode::REQUEST_TIMEOUT => {
-                    request_timeout += 5;
-                }
-                StatusCode::OK => {
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
-
-                    log::trace!(
-                        "openai embedding completed. tokens: {:?}",
-                        response.usage.total_tokens
-                    );
-
-                    // If we complete a request successfully that was previously rate_limited
-                    // resolve the rate limit
-                    if rate_limiting {
-                        self.resolve_rate_limit()
-                    }
-
-                    return Ok(response
-                        .data
-                        .into_iter()
-                        .map(|embedding| Embedding::from(embedding.embedding))
-                        .collect());
-                }
-                StatusCode::TOO_MANY_REQUESTS => {
-                    rate_limiting = true;
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-
-                    let delay_duration = {
-                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
-                        if let Some(time_to_reset) =
-                            response.headers().get("x-ratelimit-reset-tokens")
-                        {
-                            if let Ok(time_str) = time_to_reset.to_str() {
-                                parse(time_str).unwrap_or(delay)
-                            } else {
-                                delay
-                            }
-                        } else {
-                            delay
-                        }
-                    };
-
-                    // If we've previously rate limited, increment the duration but not the count
-                    let reset_time = Instant::now().add(delay_duration);
-                    self.update_reset_time(reset_time);
-
-                    log::trace!(
-                        "openai rate limiting: waiting {:?} until lifted",
-                        &delay_duration
-                    );
-
-                    self.executor.timer(delay_duration).await;
-                }
-                _ => {
-                    let mut body = String::new();
-                    response.body_mut().read_to_string(&mut body).await?;
-                    return Err(anyhow!(
-                        "open ai bad request: {:?} {:?}",
-                        &response.status(),
-                        body
-                    ));
-                }
-            }
-        }
-        Err(anyhow!("openai max retries"))
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
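
With credentials folded into the provider, embed_batch loses its api_key argument and DummyEmbeddings is deleted in favor of the fakes in crate::test. A sketch of a minimal provider under the new traits (assumes the test-support feature for FakeLanguageModel; the struct itself is illustrative):

    use std::time::Instant;

    use anyhow::Result;
    use async_trait::async_trait;
    use gpui::AppContext;

    use ai::auth::{CredentialProvider, ProviderCredential};
    use ai::embedding::{Embedding, EmbeddingProvider};
    use ai::models::LanguageModel;
    use ai::test::FakeLanguageModel;

    // Illustrative provider: constant vectors, no credentials required.
    struct ConstantEmbeddings;

    impl CredentialProvider for ConstantEmbeddings {
        fn has_credentials(&self) -> bool {
            true
        }
        fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
            ProviderCredential::NotNeeded
        }
        fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
        fn delete_credentials(&self, _cx: &AppContext) {}
    }

    #[async_trait]
    impl EmbeddingProvider for ConstantEmbeddings {
        fn base_model(&self) -> Box<dyn LanguageModel> {
            Box::new(FakeLanguageModel { capacity: 8190 })
        }
        async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
            // 1536 is the dimension of OpenAI's ada embedding models.
            Ok(vec![Embedding::from(vec![0.0; 1536]); spans.len()])
        }
        fn max_tokens_per_batch(&self) -> usize {
            8190
        }
        fn rate_limit_expiration(&self) -> Option<Instant> {
            None
        }
    }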

crates/ai/src/models.rs 🔗

@@ -1,66 +1,16 @@
-use anyhow::anyhow;
-use tiktoken_rs::CoreBPE;
-use util::ResultExt;
+pub enum TruncationDirection {
+    Start,
+    End,
+}
 
 pub trait LanguageModel {
     fn name(&self) -> String;
     fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
-    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
-    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String>;
     fn capacity(&self) -> anyhow::Result<usize>;
 }
-
-pub struct OpenAILanguageModel {
-    name: String,
-    bpe: Option<CoreBPE>,
-}
-
-impl OpenAILanguageModel {
-    pub fn load(model_name: &str) -> Self {
-        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
-        OpenAILanguageModel {
-            name: model_name.to_string(),
-            bpe,
-        }
-    }
-}
-
-impl LanguageModel for OpenAILanguageModel {
-    fn name(&self) -> String {
-        self.name.clone()
-    }
-    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
-        if let Some(bpe) = &self.bpe {
-            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
-        if let Some(bpe) = &self.bpe {
-            let tokens = bpe.encode_with_special_tokens(content);
-            if tokens.len() > length {
-                bpe.decode(tokens[..length].to_vec())
-            } else {
-                bpe.decode(tokens)
-            }
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
-        if let Some(bpe) = &self.bpe {
-            let tokens = bpe.encode_with_special_tokens(content);
-            if tokens.len() > length {
-                bpe.decode(tokens[length..].to_vec())
-            } else {
-                bpe.decode(tokens)
-            }
-        } else {
-            Err(anyhow!("bpe for open ai model was not retrieved"))
-        }
-    }
-    fn capacity(&self) -> anyhow::Result<usize> {
-        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
-    }
-}
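
truncate and truncate_start collapse into a single method taking a TruncationDirection. Call sites migrate mechanically (sketch; model is any LanguageModel):

    use ai::models::{LanguageModel, TruncationDirection};

    fn migrate(model: &dyn LanguageModel, text: &str, n: usize) -> anyhow::Result<()> {
        // Formerly model.truncate(text, n): keep the first n tokens.
        let head = model.truncate(text, n, TruncationDirection::End)?;
        // Formerly model.truncate_start(text, n): drop the first n tokens.
        let tail = model.truncate(text, n, TruncationDirection::Start)?;
        println!("{head} / {tail}");
        Ok(())
    }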

crates/ai/src/templates/base.rs → crates/ai/src/prompts/base.rs 🔗

@@ -6,7 +6,7 @@ use language::BufferSnapshot;
 use util::ResultExt;
 
 use crate::models::LanguageModel;
-use crate::templates::repository_context::PromptCodeSnippet;
+use crate::prompts::repository_context::PromptCodeSnippet;
 
 pub(crate) enum PromptFileType {
     Text,
@@ -125,6 +125,9 @@ impl PromptChain {
 
 #[cfg(test)]
 pub(crate) mod tests {
+    use crate::models::TruncationDirection;
+    use crate::test::FakeLanguageModel;
+
     use super::*;
 
     #[test]
@@ -141,7 +144,11 @@ pub(crate) mod tests {
                 let mut token_count = args.model.count_tokens(&content)?;
                 if let Some(max_token_length) = max_token_length {
                     if token_count > max_token_length {
-                        content = args.model.truncate(&content, max_token_length)?;
+                        content = args.model.truncate(
+                            &content,
+                            max_token_length,
+                            TruncationDirection::End,
+                        )?;
                         token_count = max_token_length;
                     }
                 }
@@ -162,7 +169,11 @@ pub(crate) mod tests {
                 let mut token_count = args.model.count_tokens(&content)?;
                 if let Some(max_token_length) = max_token_length {
                     if token_count > max_token_length {
-                        content = args.model.truncate(&content, max_token_length)?;
+                        content = args.model.truncate(
+                            &content,
+                            max_token_length,
+                            TruncationDirection::End,
+                        )?;
                         token_count = max_token_length;
                     }
                 }
@@ -171,38 +182,7 @@ pub(crate) mod tests {
             }
         }
 
-        #[derive(Clone)]
-        struct DummyLanguageModel {
-            capacity: usize,
-        }
-
-        impl LanguageModel for DummyLanguageModel {
-            fn name(&self) -> String {
-                "dummy".to_string()
-            }
-            fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
-                anyhow::Ok(content.chars().collect::<Vec<char>>().len())
-            }
-            fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
-                anyhow::Ok(
-                    content.chars().collect::<Vec<char>>()[..length]
-                        .into_iter()
-                        .collect::<String>(),
-                )
-            }
-            fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
-                anyhow::Ok(
-                    content.chars().collect::<Vec<char>>()[length..]
-                        .into_iter()
-                        .collect::<String>(),
-                )
-            }
-            fn capacity(&self) -> anyhow::Result<usize> {
-                anyhow::Ok(self.capacity)
-            }
-        }
-
-        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
         let args = PromptArguments {
             model: model.clone(),
             language_name: None,
@@ -238,7 +218,7 @@ pub(crate) mod tests {
 
         // Testing with Truncation Off
         // Should ignore capacity and return all prompts
-        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
         let args = PromptArguments {
             model: model.clone(),
             language_name: None,
@@ -275,7 +255,7 @@ pub(crate) mod tests {
         // Testing with Truncation Off
         // Should ignore capacity and return all prompts
         let capacity = 20;
-        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
         let args = PromptArguments {
             model: model.clone(),
             language_name: None,
@@ -311,7 +291,7 @@ pub(crate) mod tests {
         // Change Ordering of Prompts Based on Priority
         let capacity = 120;
         let reserved_tokens = 10;
-        let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
         let args = PromptArguments {
             model: model.clone(),
             language_name: None,

crates/ai/src/templates/file_context.rs → crates/ai/src/prompts/file_context.rs 🔗

@@ -3,8 +3,9 @@ use language::BufferSnapshot;
 use language::ToOffset;
 
 use crate::models::LanguageModel;
-use crate::templates::base::PromptArguments;
-use crate::templates::base::PromptTemplate;
+use crate::models::TruncationDirection;
+use crate::prompts::base::PromptArguments;
+use crate::prompts::base::PromptTemplate;
 use std::fmt::Write;
 use std::ops::Range;
 use std::sync::Arc;
@@ -70,8 +71,9 @@ fn retrieve_context(
                     };
 
                 let truncated_start_window =
-                    model.truncate_start(&start_window, start_goal_tokens)?;
-                let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+                    model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
+                let truncated_end_window =
+                    model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
                 writeln!(
                     prompt,
                     "{truncated_start_window}{selected_window}{truncated_end_window}"
@@ -89,7 +91,7 @@ fn retrieve_context(
             if let Some(max_token_count) = max_token_count {
                 if model.count_tokens(&prompt)? > max_token_count {
                     truncated = true;
-                    prompt = model.truncate(&prompt, max_token_count)?;
+                    prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
                 }
             }
         }
@@ -148,7 +150,9 @@ impl PromptTemplate for FileContext {
 
             // Really dumb truncation strategy
             if let Some(max_tokens) = max_token_length {
-                prompt = args.model.truncate(&prompt, max_tokens)?;
+                prompt = args
+                    .model
+                    .truncate(&prompt, max_tokens, TruncationDirection::End)?;
             }
 
             let token_count = args.model.count_tokens(&prompt)?;
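
The windowing above keeps the selection intact and trims the surrounding context from the far ends: text before the selection is truncated from its start, text after from its end. With the char-counting FakeLanguageModel from this PR's test module, the effect looks like this (illustrative sketch; the real budget math lives in retrieve_context):

    use ai::models::{LanguageModel, TruncationDirection};
    use ai::test::FakeLanguageModel;

    fn main() -> anyhow::Result<()> {
        let model = FakeLanguageModel { capacity: 100 };
        let (before, selected, after) = ("aaaaaaaaaa", "XX", "bbbbbbbbbb");

        // Keep the last 4 chars of `before` by dropping the first 6 ...
        let start = model.truncate(before, 6, TruncationDirection::Start)?;
        // ... and keep the first 4 chars of `after`.
        let end = model.truncate(after, 4, TruncationDirection::End)?;

        assert_eq!(format!("{start}{selected}{end}"), "aaaaXXbbbb");
        Ok(())
    }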

crates/ai/src/templates/generate.rs → crates/ai/src/prompts/generate.rs 🔗

@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
 use anyhow::anyhow;
 use std::fmt::Write;
 
@@ -85,7 +85,11 @@ impl PromptTemplate for GenerateInlineContent {
 
         // Really dumb truncation strategy
         if let Some(max_tokens) = max_token_length {
-            prompt = args.model.truncate(&prompt, max_tokens)?;
+            prompt = args.model.truncate(
+                &prompt,
+                max_tokens,
+                crate::models::TruncationDirection::End,
+            )?;
         }
 
         let token_count = args.model.count_tokens(&prompt)?;

crates/ai/src/templates/preamble.rs → crates/ai/src/prompts/preamble.rs 🔗

@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
 use std::fmt::Write;
 
 pub struct EngineerPreamble {}

crates/ai/src/templates/repository_context.rs → crates/ai/src/prompts/repository_context.rs 🔗

@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptTemplate};
 use std::fmt::Write;
 use std::{ops::Range, path::PathBuf};
 

crates/ai/src/providers/open_ai/completion.rs 🔗

@@ -0,0 +1,298 @@
+use anyhow::{anyhow, Result};
+use futures::{
+    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
+    Stream, StreamExt,
+};
+use gpui::{executor::Background, AppContext};
+use isahc::{http::StatusCode, Request, RequestExt};
+use parking_lot::RwLock;
+use serde::{Deserialize, Serialize};
+use std::{
+    env,
+    fmt::{self, Display},
+    io,
+    sync::Arc,
+};
+use util::ResultExt;
+
+use crate::{
+    auth::{CredentialProvider, ProviderCredential},
+    completion::{CompletionProvider, CompletionRequest},
+    models::LanguageModel,
+};
+
+use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+}
+
+impl Role {
+    pub fn cycle(&mut self) {
+        *self = match self {
+            Role::User => Role::Assistant,
+            Role::Assistant => Role::System,
+            Role::System => Role::User,
+        }
+    }
+}
+
+impl Display for Role {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Role::User => write!(f, "User"),
+            Role::Assistant => write!(f, "Assistant"),
+            Role::System => write!(f, "System"),
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct RequestMessage {
+    pub role: Role,
+    pub content: String,
+}
+
+#[derive(Debug, Default, Serialize)]
+pub struct OpenAIRequest {
+    pub model: String,
+    pub messages: Vec<RequestMessage>,
+    pub stream: bool,
+    pub stop: Vec<String>,
+    pub temperature: f32,
+}
+
+impl CompletionRequest for OpenAIRequest {
+    fn data(&self) -> serde_json::Result<String> {
+        serde_json::to_string(self)
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ResponseMessage {
+    pub role: Option<Role>,
+    pub content: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIUsage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ChatChoiceDelta {
+    pub index: u32,
+    pub delta: ResponseMessage,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIResponseStreamEvent {
+    pub id: Option<String>,
+    pub object: String,
+    pub created: u32,
+    pub model: String,
+    pub choices: Vec<ChatChoiceDelta>,
+    pub usage: Option<OpenAIUsage>,
+}
+
+pub async fn stream_completion(
+    credential: ProviderCredential,
+    executor: Arc<Background>,
+    request: Box<dyn CompletionRequest>,
+) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+    let api_key = match credential {
+        ProviderCredential::Credentials { api_key } => api_key,
+        _ => {
+            return Err(anyhow!("no credentials provider for completion"));
+        }
+    };
+
+    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+
+    let json_data = request.data()?;
+    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", api_key))
+        .body(json_data)?
+        .send_async()
+        .await?;
+
+    let status = response.status();
+    if status == StatusCode::OK {
+        executor
+            .spawn(async move {
+                let mut lines = BufReader::new(response.body_mut()).lines();
+
+                fn parse_line(
+                    line: Result<String, io::Error>,
+                ) -> Result<Option<OpenAIResponseStreamEvent>> {
+                    if let Some(data) = line?.strip_prefix("data: ") {
+                        let event = serde_json::from_str(&data)?;
+                        Ok(Some(event))
+                    } else {
+                        Ok(None)
+                    }
+                }
+
+                while let Some(line) = lines.next().await {
+                    if let Some(event) = parse_line(line).transpose() {
+                        let done = event.as_ref().map_or(false, |event| {
+                            event
+                                .choices
+                                .last()
+                                .map_or(false, |choice| choice.finish_reason.is_some())
+                        });
+                        if tx.unbounded_send(event).is_err() {
+                            break;
+                        }
+
+                        if done {
+                            break;
+                        }
+                    }
+                }
+
+                anyhow::Ok(())
+            })
+            .detach();
+
+        Ok(rx)
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        #[derive(Deserialize)]
+        struct OpenAIResponse {
+            error: OpenAIError,
+        }
+
+        #[derive(Deserialize)]
+        struct OpenAIError {
+            message: String,
+        }
+
+        match serde_json::from_str::<OpenAIResponse>(&body) {
+            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+                "Failed to connect to OpenAI API: {}",
+                response.error.message,
+            )),
+
+            _ => Err(anyhow!(
+                "Failed to connect to OpenAI API: {} {}",
+                response.status(),
+                body,
+            )),
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct OpenAICompletionProvider {
+    model: OpenAILanguageModel,
+    credential: Arc<RwLock<ProviderCredential>>,
+    executor: Arc<Background>,
+}
+
+impl OpenAICompletionProvider {
+    pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
+        let model = OpenAILanguageModel::load(model_name);
+        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
+        Self {
+            model,
+            credential,
+            executor,
+        }
+    }
+}
+
+impl CredentialProvider for OpenAICompletionProvider {
+    fn has_credentials(&self) -> bool {
+        match *self.credential.read() {
+            ProviderCredential::Credentials { .. } => true,
+            _ => false,
+        }
+    }
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+        let mut credential = self.credential.write();
+        match *credential {
+            ProviderCredential::Credentials { .. } => {
+                return credential.clone();
+            }
+            _ => {
+                if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+                    *credential = ProviderCredential::Credentials { api_key };
+                } else if let Some((_, api_key)) = cx
+                    .platform()
+                    .read_credentials(OPENAI_API_URL)
+                    .log_err()
+                    .flatten()
+                {
+                    if let Some(api_key) = String::from_utf8(api_key).log_err() {
+                        *credential = ProviderCredential::Credentials { api_key };
+                    }
+                } else {
+                };
+            }
+        }
+
+        credential.clone()
+    }
+
+    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+        match credential.clone() {
+            ProviderCredential::Credentials { api_key } => {
+                cx.platform()
+                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+                    .log_err();
+            }
+            _ => {}
+        }
+
+        *self.credential.write() = credential;
+    }
+    fn delete_credentials(&self, cx: &AppContext) {
+        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+        *self.credential.write() = ProviderCredential::NoCredentials;
+    }
+}
+
+impl CompletionProvider for OpenAICompletionProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+        model
+    }
+    fn complete(
+        &self,
+        prompt: Box<dyn CompletionRequest>,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        // Currently the CompletionRequest for OpenAI includes a 'model' parameter.
+        // This means that the model is determined by the CompletionRequest and not
+        // the CompletionProvider, which is currently model based, due to the
+        // language model. At some point in the future we should rectify this.
+        let credential = self.credential.read().clone();
+        let request = stream_completion(credential, self.executor.clone(), prompt);
+        async move {
+            let response = request.await?;
+            let stream = response
+                .filter_map(|response| async move {
+                    match response {
+                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+                        Err(error) => Some(Err(error)),
+                    }
+                })
+                .boxed();
+            Ok(stream)
+        }
+        .boxed()
+    }
+    fn box_clone(&self) -> Box<dyn CompletionProvider> {
+        Box::new((*self).clone())
+    }
+}
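
complete now takes a Box<dyn CompletionRequest> rather than a concrete OpenAIRequest, so callers hand the provider pre-serializable data and get back a stream of content deltas. The call shape, as a sketch (the model name is illustrative):

    use ai::completion::CompletionRequest;
    use ai::providers::open_ai::{OpenAIRequest, RequestMessage, Role};

    fn build_request(prompt: String) -> Box<dyn CompletionRequest> {
        Box::new(OpenAIRequest {
            model: "gpt-4".into(),
            messages: vec![RequestMessage {
                role: Role::User,
                content: prompt,
            }],
            stream: true,
            stop: Vec::new(),
            temperature: 1.0,
        })
    }

stream_completion then parses the SSE body line by line, forwarding each "data: " payload until a choice reports a finish_reason.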

crates/ai/src/providers/open_ai/embedding.rs 🔗

@@ -0,0 +1,306 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui::executor::Background;
+use gpui::{serde_json, AppContext};
+use isahc::http::StatusCode;
+use isahc::prelude::Configurable;
+use isahc::{AsyncBody, Response};
+use lazy_static::lazy_static;
+use parking_lot::{Mutex, RwLock};
+use parse_duration::parse;
+use postage::watch;
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::ops::Add;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tiktoken_rs::{cl100k_base, CoreBPE};
+use util::http::{HttpClient, Request};
+use util::ResultExt;
+
+use crate::auth::{CredentialProvider, ProviderCredential};
+use crate::embedding::{Embedding, EmbeddingProvider};
+use crate::models::LanguageModel;
+use crate::providers::open_ai::OpenAILanguageModel;
+
+use crate::providers::open_ai::OPENAI_API_URL;
+
+lazy_static! {
+    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
+}
+
+#[derive(Clone)]
+pub struct OpenAIEmbeddingProvider {
+    model: OpenAILanguageModel,
+    credential: Arc<RwLock<ProviderCredential>>,
+    pub client: Arc<dyn HttpClient>,
+    pub executor: Arc<Background>,
+    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
+    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
+}
+
+#[derive(Serialize)]
+struct OpenAIEmbeddingRequest<'a> {
+    model: &'static str,
+    input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingResponse {
+    data: Vec<OpenAIEmbedding>,
+    usage: OpenAIEmbeddingUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIEmbedding {
+    embedding: Vec<f32>,
+    index: usize,
+    object: String,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingUsage {
+    prompt_tokens: usize,
+    total_tokens: usize,
+}
+
+impl OpenAIEmbeddingProvider {
+    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
+        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
+
+        let model = OpenAILanguageModel::load("text-embedding-ada-002");
+        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
+
+        OpenAIEmbeddingProvider {
+            model,
+            credential,
+            client,
+            executor,
+            rate_limit_count_rx,
+            rate_limit_count_tx,
+        }
+    }
+
+    fn get_api_key(&self) -> Result<String> {
+        match self.credential.read().clone() {
+            ProviderCredential::Credentials { api_key } => Ok(api_key),
+            _ => Err(anyhow!("api credentials not provided")),
+        }
+    }
+
+    fn resolve_rate_limit(&self) {
+        let reset_time = *self.rate_limit_count_tx.lock().borrow();
+
+        if let Some(reset_time) = reset_time {
+            if Instant::now() >= reset_time {
+                *self.rate_limit_count_tx.lock().borrow_mut() = None
+            }
+        }
+
+        log::trace!(
+            "resolving reset time: {:?}",
+            *self.rate_limit_count_tx.lock().borrow()
+        );
+    }
+
+    fn update_reset_time(&self, reset_time: Instant) {
+        let original_time = *self.rate_limit_count_tx.lock().borrow();
+
+        let updated_time = if let Some(original_time) = original_time {
+            if reset_time < original_time {
+                Some(reset_time)
+            } else {
+                Some(original_time)
+            }
+        } else {
+            Some(reset_time)
+        };
+
+        log::trace!("updating rate limit time: {:?}", updated_time);
+
+        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
+    }
+    async fn send_request(
+        &self,
+        api_key: &str,
+        spans: Vec<&str>,
+        request_timeout: u64,
+    ) -> Result<Response<AsyncBody>> {
+        let request = Request::post("https://api.openai.com/v1/embeddings")
+            .redirect_policy(isahc::config::RedirectPolicy::Follow)
+            .timeout(Duration::from_secs(request_timeout))
+            .header("Content-Type", "application/json")
+            .header("Authorization", format!("Bearer {}", api_key))
+            .body(
+                serde_json::to_string(&OpenAIEmbeddingRequest {
+                    input: spans.clone(),
+                    model: "text-embedding-ada-002",
+                })
+                .unwrap()
+                .into(),
+            )?;
+
+        Ok(self.client.send(request).await?)
+    }
+}
+
+impl CredentialProvider for OpenAIEmbeddingProvider {
+    fn has_credentials(&self) -> bool {
+        match *self.credential.read() {
+            ProviderCredential::Credentials { .. } => true,
+            _ => false,
+        }
+    }
+    fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+        let mut credential = self.credential.write();
+        match *credential {
+            ProviderCredential::Credentials { .. } => {
+                return credential.clone();
+            }
+            _ => {
+                if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+                    *credential = ProviderCredential::Credentials { api_key };
+                } else if let Some((_, api_key)) = cx
+                    .platform()
+                    .read_credentials(OPENAI_API_URL)
+                    .log_err()
+                    .flatten()
+                {
+                    if let Some(api_key) = String::from_utf8(api_key).log_err() {
+                        *credential = ProviderCredential::Credentials { api_key };
+                    }
+                } else {
+                };
+            }
+        }
+
+        credential.clone()
+    }
+
+    fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+        match credential.clone() {
+            ProviderCredential::Credentials { api_key } => {
+                cx.platform()
+                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+                    .log_err();
+            }
+            _ => {}
+        }
+
+        *self.credential.write() = credential;
+    }
+    fn delete_credentials(&self, cx: &AppContext) {
+        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+        *self.credential.write() = ProviderCredential::NoCredentials;
+    }
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddingProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+        model
+    }
+
+    fn max_tokens_per_batch(&self) -> usize {
+        50000
+    }
+
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        *self.rate_limit_count_rx.borrow()
+    }
+
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
+        const MAX_RETRIES: usize = 4;
+
+        let api_key = self.get_api_key()?;
+
+        let mut request_number = 0;
+        let mut rate_limiting = false;
+        let mut request_timeout: u64 = 15;
+        let mut response: Response<AsyncBody>;
+        while request_number < MAX_RETRIES {
+            response = self
+                .send_request(
+                    &api_key,
+                    spans.iter().map(|x| &**x).collect(),
+                    request_timeout,
+                )
+                .await?;
+
+            request_number += 1;
+
+            match response.status() {
+                StatusCode::REQUEST_TIMEOUT => {
+                    request_timeout += 5;
+                }
+                StatusCode::OK => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+                    log::trace!(
+                        "openai embedding completed. tokens: {:?}",
+                        response.usage.total_tokens
+                    );
+
+                    // If we complete a request successfully that was previously rate_limited
+                    // resolve the rate limit
+                    if rate_limiting {
+                        self.resolve_rate_limit()
+                    }
+
+                    return Ok(response
+                        .data
+                        .into_iter()
+                        .map(|embedding| Embedding::from(embedding.embedding))
+                        .collect());
+                }
+                StatusCode::TOO_MANY_REQUESTS => {
+                    rate_limiting = true;
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+
+                    let delay_duration = {
+                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+                        if let Some(time_to_reset) =
+                            response.headers().get("x-ratelimit-reset-tokens")
+                        {
+                            if let Ok(time_str) = time_to_reset.to_str() {
+                                parse(time_str).unwrap_or(delay)
+                            } else {
+                                delay
+                            }
+                        } else {
+                            delay
+                        }
+                    };
+
+                    // If we've previously rate limited, increment the duration but not the count
+                    let reset_time = Instant::now().add(delay_duration);
+                    self.update_reset_time(reset_time);
+
+                    log::trace!(
+                        "openai rate limiting: waiting {:?} until lifted",
+                        &delay_duration
+                    );
+
+                    self.executor.timer(delay_duration).await;
+                }
+                _ => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    return Err(anyhow!(
+                        "open ai bad request: {:?} {:?}",
+                        &response.status(),
+                        body
+                    ));
+                }
+            }
+        }
+        Err(anyhow!("openai max retries"))
+    }
+}
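
The retry loop backs off 3/5/15/45 seconds across at most four attempts, deferring to the server's x-ratelimit-reset-tokens header whenever it parses. The delay policy in isolation (sketch):

    use std::time::Duration;

    use parse_duration::parse;

    const BACKOFF_SECONDS: [u64; 4] = [3, 5, 15, 45];

    // `attempt` is 1-based, matching request_number after its increment.
    fn retry_delay(attempt: usize, reset_header: Option<&str>) -> Duration {
        let fallback = Duration::from_secs(BACKOFF_SECONDS[(attempt - 1).min(3)]);
        reset_header
            .and_then(|value| parse(value).ok())
            .unwrap_or(fallback)
    }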

crates/ai/src/providers/open_ai/mod.rs 🔗

@@ -0,0 +1,9 @@
+pub mod completion;
+pub mod embedding;
+pub mod model;
+
+pub use completion::*;
+pub use embedding::*;
+pub use model::OpenAILanguageModel;
+
+pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";

crates/ai/src/providers/open_ai/model.rs 🔗

@@ -0,0 +1,57 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+use crate::models::{LanguageModel, TruncationDirection};
+
+#[derive(Clone)]
+pub struct OpenAILanguageModel {
+    name: String,
+    bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+    pub fn load(model_name: &str) -> Self {
+        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+        OpenAILanguageModel {
+            name: model_name.to_string(),
+            bpe,
+        }
+    }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+    fn name(&self) -> String {
+        self.name.clone()
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        if let Some(bpe) = &self.bpe {
+            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String> {
+        if let Some(bpe) = &self.bpe {
+            let tokens = bpe.encode_with_special_tokens(content);
+            if tokens.len() > length {
+                match direction {
+                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+                }
+            } else {
+                bpe.decode(tokens)
+            }
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+    }
+}
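
OpenAILanguageModel is now the single tokenizer-backed LanguageModel shared by both providers. Typical use, as a sketch (the model name is illustrative):

    use ai::models::{LanguageModel, TruncationDirection};
    use ai::providers::open_ai::OpenAILanguageModel;

    fn main() -> anyhow::Result<()> {
        // load resolves the model's BPE via tiktoken-rs; count_tokens and
        // truncate return an error if the tokenizer could not be found.
        let model = OpenAILanguageModel::load("gpt-3.5-turbo");
        let text = "fn main() { println!(\"hello\"); }";
        let tokens = model.count_tokens(text)?;
        let clipped = model.truncate(text, tokens / 2, TruncationDirection::End)?;
        println!("{tokens} tokens; clipped to: {clipped}");
        Ok(())
    }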

crates/ai/src/providers/open_ai/new.rs 🔗

@@ -0,0 +1,11 @@
+pub trait LanguageModel {
+    fn name(&self) -> String;
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String>;
+    fn capacity(&self) -> anyhow::Result<usize>;
+}

crates/ai/src/test.rs 🔗

@@ -0,0 +1,191 @@
+use std::{
+    sync::atomic::{self, AtomicUsize, Ordering},
+    time::Instant,
+};
+
+use async_trait::async_trait;
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::AppContext;
+use parking_lot::Mutex;
+
+use crate::{
+    auth::{CredentialProvider, ProviderCredential},
+    completion::{CompletionProvider, CompletionRequest},
+    embedding::{Embedding, EmbeddingProvider},
+    models::{LanguageModel, TruncationDirection},
+};
+
+#[derive(Clone)]
+pub struct FakeLanguageModel {
+    pub capacity: usize,
+}
+
+impl LanguageModel for FakeLanguageModel {
+    fn name(&self) -> String {
+        "dummy".to_string()
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        anyhow::Ok(content.chars().count())
+    }
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String> {
+        println!("TRYING TO TRUNCATE: {:?}", length.clone());
+
+        if length > self.count_tokens(content)? {
+            println!("NOT TRUNCATING");
+            return anyhow::Ok(content.to_string());
+        }
+
+        anyhow::Ok(match direction {
+            TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
+                .into_iter()
+                .collect::<String>(),
+            TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
+                .into_iter()
+                .collect::<String>(),
+        })
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(self.capacity)
+    }
+}
+
+pub struct FakeEmbeddingProvider {
+    pub embedding_count: AtomicUsize,
+}
+
+impl Clone for FakeEmbeddingProvider {
+    fn clone(&self) -> Self {
+        FakeEmbeddingProvider {
+            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
+        }
+    }
+}
+
+impl Default for FakeEmbeddingProvider {
+    fn default() -> Self {
+        FakeEmbeddingProvider {
+            embedding_count: AtomicUsize::default(),
+        }
+    }
+}
+
+impl FakeEmbeddingProvider {
+    pub fn embedding_count(&self) -> usize {
+        self.embedding_count.load(atomic::Ordering::SeqCst)
+    }
+
+    pub fn embed_sync(&self, span: &str) -> Embedding {
+        let mut result = vec![1.0; 26];
+        for letter in span.chars() {
+            let letter = letter.to_ascii_lowercase();
+            if letter as u32 >= 'a' as u32 {
+                let ix = (letter as u32) - ('a' as u32);
+                if ix < 26 {
+                    result[ix as usize] += 1.0;
+                }
+            }
+        }
+
+        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
+        for x in &mut result {
+            *x /= norm;
+        }
+
+        result.into()
+    }
+}
+
+impl CredentialProvider for FakeEmbeddingProvider {
+    fn has_credentials(&self) -> bool {
+        true
+    }
+    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+        ProviderCredential::NotNeeded
+    }
+    fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
+    fn delete_credentials(&self, _cx: &AppContext) {}
+}
+
+#[async_trait]
+impl EmbeddingProvider for FakeEmbeddingProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        Box::new(FakeLanguageModel { capacity: 1000 })
+    }
+    fn max_tokens_per_batch(&self) -> usize {
+        1000
+    }
+
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        None
+    }
+
+    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
+        self.embedding_count
+            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
+
+        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
+    }
+}
+
+pub struct FakeCompletionProvider {
+    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
+}
+
+impl Clone for FakeCompletionProvider {
+    fn clone(&self) -> Self {
+        Self {
+            last_completion_tx: Mutex::new(None),
+        }
+    }
+}
+
+impl FakeCompletionProvider {
+    pub fn new() -> Self {
+        Self {
+            last_completion_tx: Mutex::new(None),
+        }
+    }
+
+    pub fn send_completion(&self, completion: impl Into<String>) {
+        let mut tx = self.last_completion_tx.lock();
+        tx.as_mut().unwrap().try_send(completion.into()).unwrap();
+    }
+
+    pub fn finish_completion(&self) {
+        self.last_completion_tx.lock().take().unwrap();
+    }
+}
+
+impl CredentialProvider for FakeCompletionProvider {
+    fn has_credentials(&self) -> bool {
+        true
+    }
+    fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+        ProviderCredential::NotNeeded
+    }
+    fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
+    fn delete_credentials(&self, _cx: &AppContext) {}
+}
+
+impl CompletionProvider for FakeCompletionProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        Box::new(FakeLanguageModel { capacity: 8190 })
+    }
+    fn complete(
+        &self,
+        _prompt: Box<dyn CompletionRequest>,
+    ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
+        let (tx, rx) = mpsc::channel(1);
+        *self.last_completion_tx.lock() = Some(tx);
+        async move { Ok(rx.map(|completion| Ok(completion)).boxed()) }.boxed()
+    }
+    fn box_clone(&self) -> Box<dyn CompletionProvider> {
+        Box::new((*self).clone())
+    }
+}
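A sketch of how the fake shows up in tests, assuming the crate paths from this PR (`DummyRequest` is a stand-in; any `CompletionRequest` works):

```rust
use ai::completion::{CompletionProvider, CompletionRequest};
use ai::test::FakeCompletionProvider;
use futures::{executor::block_on, StreamExt};

struct DummyRequest; // stand-in request; real callers pass an OpenAIRequest

impl CompletionRequest for DummyRequest {
    fn data(&self) -> serde_json::Result<String> {
        Ok("{}".to_string())
    }
}

#[test]
fn streams_sent_completions() {
    block_on(async {
        let provider = FakeCompletionProvider::new();
        let mut stream = provider.complete(Box::new(DummyRequest)).await.unwrap();

        provider.send_completion("hello");
        assert_eq!(stream.next().await.unwrap().unwrap(), "hello");

        provider.finish_completion(); // drops the sender, ending the stream
        assert!(stream.next().await.is_none());
    });
}
```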

crates/ai2/Cargo.toml 🔗

@@ -0,0 +1,38 @@
+[package]
+name = "ai2"
+version = "0.1.0"
+edition = "2021"
+publish = false
+
+[lib]
+path = "src/ai2.rs"
+doctest = false
+
+[features]
+test-support = []
+
+[dependencies]
+gpui2 = { path = "../gpui2" }
+util = { path = "../util" }
+language2 = { path = "../language2" }
+async-trait.workspace = true
+anyhow.workspace = true
+futures.workspace = true
+lazy_static.workspace = true
+ordered-float.workspace = true
+parking_lot.workspace = true
+isahc.workspace = true
+regex.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+postage.workspace = true
+rand.workspace = true
+log.workspace = true
+parse_duration = "2.1.1"
+tiktoken-rs = "0.5.0"
+matrixmultiply = "0.3.7"
+rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
+bincode = "1.3.3"
+
+[dev-dependencies]
+gpui2 = { path = "../gpui2", features = ["test-support"] }

crates/ai2/src/ai2.rs 🔗

@@ -0,0 +1,8 @@
+pub mod auth;
+pub mod completion;
+pub mod embedding;
+pub mod models;
+pub mod prompts;
+pub mod providers;
+#[cfg(any(test, feature = "test-support"))]
+pub mod test;

crates/ai2/src/auth.rs 🔗

@@ -0,0 +1,17 @@
+use async_trait::async_trait;
+use gpui2::AppContext;
+
+#[derive(Clone, Debug)]
+pub enum ProviderCredential {
+    Credentials { api_key: String },
+    NoCredentials,
+    NotNeeded,
+}
+
+#[async_trait]
+pub trait CredentialProvider: Send + Sync {
+    fn has_credentials(&self) -> bool;
+    async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential;
+    async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential);
+    async fn delete_credentials(&self, cx: &mut AppContext);
+}
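Note that the trait is async in the ai2 port (it was synchronous in the zed1 `ai` crate). A minimal sketch of an implementor that needs no credentials, mirroring how the fake providers in this crate satisfy the trait:

```rust
use ai2::auth::{CredentialProvider, ProviderCredential};
use async_trait::async_trait;
use gpui2::AppContext;

struct NoAuthProvider; // illustrative; real providers wrap an API key

#[async_trait]
impl CredentialProvider for NoAuthProvider {
    fn has_credentials(&self) -> bool {
        true
    }
    async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
        ProviderCredential::NotNeeded
    }
    async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
    async fn delete_credentials(&self, _cx: &mut AppContext) {}
}
```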

crates/ai2/src/completion.rs 🔗

@@ -0,0 +1,23 @@
+use anyhow::Result;
+use futures::{future::BoxFuture, stream::BoxStream};
+
+use crate::{auth::CredentialProvider, models::LanguageModel};
+
+pub trait CompletionRequest: Send + Sync {
+    fn data(&self) -> serde_json::Result<String>;
+}
+
+pub trait CompletionProvider: CredentialProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel>;
+    fn complete(
+        &self,
+        prompt: Box<dyn CompletionRequest>,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+    fn box_clone(&self) -> Box<dyn CompletionProvider>;
+}
+
+impl Clone for Box<dyn CompletionProvider> {
+    fn clone(&self) -> Box<dyn CompletionProvider> {
+        self.box_clone()
+    }
+}
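Any `Serialize` type can back a request; a minimal sketch (the request shape here is illustrative, the real OpenAI one lives under providers/open_ai):

```rust
use ai2::completion::CompletionRequest;
use serde::Serialize;

// Illustrative request body.
#[derive(Serialize)]
struct EchoRequest {
    prompt: String,
}

impl CompletionRequest for EchoRequest {
    fn data(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }
}
```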

crates/ai2/src/embedding.rs 🔗

@@ -0,0 +1,123 @@
+use std::time::Instant;
+
+use anyhow::Result;
+use async_trait::async_trait;
+use ordered_float::OrderedFloat;
+use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
+use rusqlite::ToSql;
+
+use crate::auth::CredentialProvider;
+use crate::models::LanguageModel;
+
+#[derive(Debug, PartialEq, Clone)]
+pub struct Embedding(pub Vec<f32>);
+
+// This is needed for the semantic index functionality.
+// Unfortunately it has to live wherever the "Embedding" struct is created.
+// Keeping it here, though, introduces a 'rusqlite' dependency into this
+// crate, which is less than ideal.
+impl FromSql for Embedding {
+    fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+        let bytes = value.as_blob()?;
+        let embedding: Vec<f32> = bincode::deserialize(bytes)
+            .map_err(|err| rusqlite::types::FromSqlError::Other(err))?;
+        Ok(Embedding(embedding))
+    }
+}
+
+impl ToSql for Embedding {
+    fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+        let bytes = bincode::serialize(&self.0)
+            .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
+        Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
+    }
+}
+impl From<Vec<f32>> for Embedding {
+    fn from(value: Vec<f32>) -> Self {
+        Embedding(value)
+    }
+}
+
+impl Embedding {
+    pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
+        let len = self.0.len();
+        assert_eq!(len, other.0.len());
+
+        let mut result = 0.0;
+        unsafe {
+            matrixmultiply::sgemm(
+                1,
+                len,
+                1,
+                1.0,
+                self.0.as_ptr(),
+                len as isize,
+                1,
+                other.0.as_ptr(),
+                1,
+                len as isize,
+                0.0,
+                &mut result as *mut f32,
+                1,
+                1,
+            );
+        }
+        OrderedFloat(result)
+    }
+}
+
+#[async_trait]
+pub trait EmbeddingProvider: CredentialProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel>;
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
+    fn max_tokens_per_batch(&self) -> usize;
+    fn rate_limit_expiration(&self) -> Option<Instant>;
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use rand::prelude::*;
+
+    #[gpui2::test]
+    fn test_similarity(mut rng: StdRng) {
+        assert_eq!(
+            Embedding::from(vec![1., 0., 0., 0., 0.])
+                .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
+            0.
+        );
+        assert_eq!(
+            Embedding::from(vec![2., 0., 0., 0., 0.])
+                .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
+            6.
+        );
+
+        for _ in 0..100 {
+            let size = 1536;
+            let mut a = vec![0.; size];
+            let mut b = vec![0.; size];
+            for (a, b) in a.iter_mut().zip(b.iter_mut()) {
+                *a = rng.gen();
+                *b = rng.gen();
+            }
+            let a = Embedding::from(a);
+            let b = Embedding::from(b);
+
+            assert_eq!(
+                round_to_decimals(a.similarity(&b), 1),
+                round_to_decimals(reference_dot(&a.0, &b.0), 1)
+            );
+        }
+
+        fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
+            let factor = (10.0 as f32).powi(decimal_places);
+            (n * factor).round() / factor
+        }
+
+        fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
+            OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
+        }
+    }
+}
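As a sanity check on the `ToSql`/`FromSql` impls above, a round-trip sketch assuming the crate paths from this PR:

```rust
use ai2::embedding::Embedding;
use rusqlite::Connection;

fn main() -> anyhow::Result<()> {
    let db = Connection::open_in_memory()?;
    db.execute("CREATE TABLE spans (embedding BLOB)", [])?;

    // Stored as a bincode-serialized blob via the ToSql impl.
    let embedding = Embedding::from(vec![0.6, 0.8]);
    db.execute("INSERT INTO spans (embedding) VALUES (?1)", [&embedding])?;

    // Read back through the FromSql impl.
    let loaded: Embedding =
        db.query_row("SELECT embedding FROM spans", [], |row| row.get(0))?;
    assert_eq!(loaded, embedding);

    // A unit vector's similarity with itself is ~1.0.
    assert!((loaded.similarity(&embedding).into_inner() - 1.0).abs() < 1e-5);
    Ok(())
}
```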

crates/ai2/src/models.rs 🔗

@@ -0,0 +1,16 @@
+pub enum TruncationDirection {
+    Start,
+    End,
+}
+
+pub trait LanguageModel {
+    fn name(&self) -> String;
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String>;
+    fn capacity(&self) -> anyhow::Result<usize>;
+}

crates/ai2/src/prompts/base.rs 🔗

@@ -0,0 +1,330 @@
+use std::cmp::Reverse;
+use std::ops::Range;
+use std::sync::Arc;
+
+use language2::BufferSnapshot;
+use util::ResultExt;
+
+use crate::models::LanguageModel;
+use crate::prompts::repository_context::PromptCodeSnippet;
+
+pub(crate) enum PromptFileType {
+    Text,
+    Code,
+}
+
+// TODO: Set this up to manage defaults well
+pub struct PromptArguments {
+    pub model: Arc<dyn LanguageModel>,
+    pub user_prompt: Option<String>,
+    pub language_name: Option<String>,
+    pub project_name: Option<String>,
+    pub snippets: Vec<PromptCodeSnippet>,
+    pub reserved_tokens: usize,
+    pub buffer: Option<BufferSnapshot>,
+    pub selected_range: Option<Range<usize>>,
+}
+
+impl PromptArguments {
+    pub(crate) fn get_file_type(&self) -> PromptFileType {
+        if self
+            .language_name
+            .as_ref()
+            .map(|name| !["Markdown", "Plain Text"].contains(&name.as_str()))
+            .unwrap_or(true)
+        {
+            PromptFileType::Code
+        } else {
+            PromptFileType::Text
+        }
+    }
+}
+
+pub trait PromptTemplate {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)>;
+}
+
+#[derive(PartialEq, Eq)]
+pub enum PromptPriority {
+    Mandatory,                // Ignores truncation
+    Ordered { order: usize }, // Truncates based on priority
+}
+
+impl Ord for PromptPriority {
+    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
+        match (self, other) {
+            (Self::Mandatory, Self::Mandatory) => std::cmp::Ordering::Equal,
+            (Self::Mandatory, Self::Ordered { .. }) => std::cmp::Ordering::Greater,
+            (Self::Ordered { .. }, Self::Mandatory) => std::cmp::Ordering::Less,
+            // Lower `order` values sort higher (they are kept longer under truncation).
+            (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.cmp(a),
+        }
+    }
+}
+
+impl PartialOrd for PromptPriority {
+    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+pub struct PromptChain {
+    args: PromptArguments,
+    templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+}
+
+impl PromptChain {
+    pub fn new(
+        args: PromptArguments,
+        templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+    ) -> Self {
+        PromptChain { args, templates }
+    }
+
+    pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
+        // Argsort based on prompt priority
+        let separator = "\n";
+        let separator_tokens = self.args.model.count_tokens(separator)?;
+        let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
+        sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
+
+        // If truncating, track the remaining token budget
+        let mut tokens_outstanding = if truncate {
+            Some(self.args.model.capacity()? - self.args.reserved_tokens)
+        } else {
+            None
+        };
+
+        let mut prompts = vec![String::new(); sorted_indices.len()];
+        for idx in sorted_indices {
+            let (_, template) = &self.templates[idx];
+
+            if let Some((template_prompt, prompt_token_count)) =
+                template.generate(&self.args, tokens_outstanding).log_err()
+            {
+                if !template_prompt.is_empty() {
+                    prompts[idx] = template_prompt;
+
+                    if let Some(remaining_tokens) = tokens_outstanding {
+                        let new_tokens = prompt_token_count + separator_tokens;
+                        tokens_outstanding = Some(remaining_tokens.saturating_sub(new_tokens));
+                    }
+                }
+            }
+        }
+
+        prompts.retain(|x| !x.is_empty());
+
+        let full_prompt = prompts.join(separator);
+        let total_token_count = self.args.model.count_tokens(&full_prompt)?;
+        anyhow::Ok((full_prompt, total_token_count))
+    }
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+    use crate::models::TruncationDirection;
+    use crate::test::FakeLanguageModel;
+
+    use super::*;
+
+    #[test]
+    pub fn test_prompt_chain() {
+        struct TestPromptTemplate {}
+        impl PromptTemplate for TestPromptTemplate {
+            fn generate(
+                &self,
+                args: &PromptArguments,
+                max_token_length: Option<usize>,
+            ) -> anyhow::Result<(String, usize)> {
+                let mut content = "This is a test prompt template".to_string();
+
+                let mut token_count = args.model.count_tokens(&content)?;
+                if let Some(max_token_length) = max_token_length {
+                    if token_count > max_token_length {
+                        content = args.model.truncate(
+                            &content,
+                            max_token_length,
+                            TruncationDirection::End,
+                        )?;
+                        token_count = max_token_length;
+                    }
+                }
+
+                anyhow::Ok((content, token_count))
+            }
+        }
+
+        struct TestLowPriorityTemplate {}
+        impl PromptTemplate for TestLowPriorityTemplate {
+            fn generate(
+                &self,
+                args: &PromptArguments,
+                max_token_length: Option<usize>,
+            ) -> anyhow::Result<(String, usize)> {
+                let mut content = "This is a low priority test prompt template".to_string();
+
+                let mut token_count = args.model.count_tokens(&content)?;
+                if let Some(max_token_length) = max_token_length {
+                    if token_count > max_token_length {
+                        content = args.model.truncate(
+                            &content,
+                            max_token_length,
+                            TruncationDirection::End,
+                        )?;
+                        token_count = max_token_length;
+                    }
+                }
+
+                anyhow::Ok((content, token_count))
+            }
+        }
+
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
+            buffer: None,
+            selected_range: None,
+            user_prompt: None,
+        };
+
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(false).unwrap();
+
+        assert_eq!(
+            prompt,
+            "This is a test prompt template\nThis is a low priority test prompt template"
+                .to_string()
+        );
+
+        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
+
+        // Testing with Truncation Off
+        // Should ignore capacity and return all prompts
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
+            buffer: None,
+            selected_range: None,
+            user_prompt: None,
+        };
+
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(false).unwrap();
+
+        assert_eq!(
+            prompt,
+            "This is a test prompt template\nThis is a low priority test prompt template"
+                .to_string()
+        );
+
+        assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
+
+        // Testing with Truncation On
+        // Should truncate the prompts to fit within the model's capacity
+        let capacity = 20;
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens: 0,
+            buffer: None,
+            selected_range: None,
+            user_prompt: None,
+        };
+
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 2 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(true).unwrap();
+
+        assert_eq!(prompt, "This is a test promp".to_string());
+        assert_eq!(token_count, capacity);
+
+        // Change Ordering of Prompts Based on Priority
+        let capacity = 120;
+        let reserved_tokens = 10;
+        let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
+        let args = PromptArguments {
+            model: model.clone(),
+            language_name: None,
+            project_name: None,
+            snippets: Vec::new(),
+            reserved_tokens,
+            buffer: None,
+            selected_range: None,
+            user_prompt: None,
+        };
+        let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+            (
+                PromptPriority::Mandatory,
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 0 },
+                Box::new(TestPromptTemplate {}),
+            ),
+            (
+                PromptPriority::Ordered { order: 1 },
+                Box::new(TestLowPriorityTemplate {}),
+            ),
+        ];
+        let chain = PromptChain::new(args, templates);
+
+        let (prompt, token_count) = chain.generate(true).unwrap();
+
+        assert_eq!(
+            prompt,
+            "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
+                .to_string()
+        );
+        assert_eq!(token_count, capacity - reserved_tokens);
+    }
+}
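Putting these pieces together, a sketch of how a chain might be assembled from the templates in this crate (field values are illustrative; `model` is any `LanguageModel`):

```rust
use std::sync::Arc;

use ai2::models::LanguageModel;
use ai2::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
use ai2::prompts::generate::GenerateInlineContent;
use ai2::prompts::preamble::EngineerPreamble;

fn build_prompt(model: Arc<dyn LanguageModel>) -> anyhow::Result<(String, usize)> {
    let args = PromptArguments {
        model,
        user_prompt: Some("add error handling".into()),
        language_name: Some("Rust".into()),
        project_name: None,
        snippets: Vec::new(),
        reserved_tokens: 1000, // illustrative head-room for the reply
        buffer: None,
        selected_range: None,
    };
    let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
        (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
        (
            PromptPriority::Ordered { order: 0 },
            Box::new(GenerateInlineContent {}),
        ),
    ];
    // Lower-priority templates are truncated first when over capacity.
    PromptChain::new(args, templates).generate(true)
}
```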

crates/ai2/src/prompts/file_context.rs 🔗

@@ -0,0 +1,164 @@
+use anyhow::anyhow;
+use language2::BufferSnapshot;
+use language2::ToOffset;
+
+use crate::models::LanguageModel;
+use crate::models::TruncationDirection;
+use crate::prompts::base::PromptArguments;
+use crate::prompts::base::PromptTemplate;
+use std::fmt::Write;
+use std::ops::Range;
+use std::sync::Arc;
+
+fn retrieve_context(
+    buffer: &BufferSnapshot,
+    selected_range: &Option<Range<usize>>,
+    model: Arc<dyn LanguageModel>,
+    max_token_count: Option<usize>,
+) -> anyhow::Result<(String, usize, bool)> {
+    let mut prompt = String::new();
+    let mut truncated = false;
+    if let Some(selected_range) = selected_range {
+        let start = selected_range.start.to_offset(buffer);
+        let end = selected_range.end.to_offset(buffer);
+
+        let start_window = buffer.text_for_range(0..start).collect::<String>();
+
+        let mut selected_window = String::new();
+        if start == end {
+            write!(selected_window, "<|START|>").unwrap();
+        } else {
+            write!(selected_window, "<|START|").unwrap();
+        }
+
+        write!(
+            selected_window,
+            "{}",
+            buffer.text_for_range(start..end).collect::<String>()
+        )
+        .unwrap();
+
+        if start != end {
+            write!(selected_window, "|END|>").unwrap();
+        }
+
+        let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
+
+        if let Some(max_token_count) = max_token_count {
+            let selected_tokens = model.count_tokens(&selected_window)?;
+            if selected_tokens > max_token_count {
+                return Err(anyhow!(
+                    "selected range is greater than model context window, truncation not possible"
+                ));
+            };
+
+            let mut remaining_tokens = max_token_count - selected_tokens;
+            let start_window_tokens = model.count_tokens(&start_window)?;
+            let end_window_tokens = model.count_tokens(&end_window)?;
+            let outside_tokens = start_window_tokens + end_window_tokens;
+            if outside_tokens > remaining_tokens {
+                let (start_goal_tokens, end_goal_tokens) =
+                    if start_window_tokens < end_window_tokens {
+                        let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
+                        remaining_tokens -= start_goal_tokens;
+                        let end_goal_tokens = remaining_tokens.min(end_window_tokens);
+                        (start_goal_tokens, end_goal_tokens)
+                    } else {
+                        let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
+                        remaining_tokens -= end_goal_tokens;
+                        let start_goal_tokens = remaining_tokens.min(start_window_tokens);
+                        (start_goal_tokens, end_goal_tokens)
+                    };
+
+                let truncated_start_window =
+                    model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
+                let truncated_end_window =
+                    model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
+                writeln!(
+                    prompt,
+                    "{truncated_start_window}{selected_window}{truncated_end_window}"
+                )
+                .unwrap();
+                truncated = true;
+            } else {
+                writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
+            }
+        } else {
+            // If we don't have a selected range, include the entire file.
+            writeln!(prompt, "{}", &buffer.text()).unwrap();
+
+            // Dumb truncation strategy
+            if let Some(max_token_count) = max_token_count {
+                if model.count_tokens(&prompt)? > max_token_count {
+                    truncated = true;
+                    prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
+                }
+            }
+        }
+    }
+
+    let token_count = model.count_tokens(&prompt)?;
+    anyhow::Ok((prompt, token_count, truncated))
+}
+
+pub struct FileContext {}
+
+impl PromptTemplate for FileContext {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        if let Some(buffer) = &args.buffer {
+            let mut prompt = String::new();
+            // Add Initial Preamble
+            // TODO: Do we want to add the path in here?
+            writeln!(
+                prompt,
+                "The file you are currently working on has the following content:"
+            )
+            .unwrap();
+
+            let language_name = args
+                .language_name
+                .clone()
+                .unwrap_or_default()
+                .to_lowercase();
+
+            let (context, _, truncated) = retrieve_context(
+                buffer,
+                &args.selected_range,
+                args.model.clone(),
+                max_token_length,
+            )?;
+            writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
+
+            if truncated {
+                writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
+            }
+
+            if let Some(selected_range) = &args.selected_range {
+                let start = selected_range.start.to_offset(buffer);
+                let end = selected_range.end.to_offset(buffer);
+
+                if start == end {
+                    writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
+                } else {
+                    writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
+                }
+            }
+
+            // Really dumb truncation strategy
+            if let Some(max_tokens) = max_token_length {
+                prompt = args
+                    .model
+                    .truncate(&prompt, max_tokens, TruncationDirection::End)?;
+            }
+
+            let token_count = args.model.count_tokens(&prompt)?;
+            anyhow::Ok((prompt, token_count))
+        } else {
+            Err(anyhow!("no buffer provided to retrieve file context from"))
+        }
+    }
+}

crates/ai2/src/prompts/generate.rs 🔗

@@ -0,0 +1,99 @@
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
+use anyhow::anyhow;
+use std::fmt::Write;
+
+pub fn capitalize(s: &str) -> String {
+    let mut c = s.chars();
+    match c.next() {
+        None => String::new(),
+        Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
+    }
+}
+
+pub struct GenerateInlineContent {}
+
+impl PromptTemplate for GenerateInlineContent {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        let Some(user_prompt) = &args.user_prompt else {
+            return Err(anyhow!("user prompt not provided"));
+        };
+
+        let file_type = args.get_file_type();
+        let content_type = match &file_type {
+            PromptFileType::Code => "code",
+            PromptFileType::Text => "text",
+        };
+
+        let mut prompt = String::new();
+
+        if let Some(selected_range) = &args.selected_range {
+            if selected_range.start == selected_range.end {
+                writeln!(
+                    prompt,
+                    "Assume the cursor is located where the `<|START|>` span is."
+                )
+                .unwrap();
+                writeln!(
+                    prompt,
+                    "{} can't be replaced, so assume your answer will be inserted at the cursor.",
+                    capitalize(content_type)
+                )
+                .unwrap();
+                writeln!(
+                    prompt,
+                    "Generate {content_type} based on the users prompt: {user_prompt}",
+                )
+                .unwrap();
+            } else {
+                writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
+                writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
+                writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
+            }
+        } else {
+            writeln!(
+                prompt,
+                "Generate {content_type} based on the users prompt: {user_prompt}"
+            )
+            .unwrap();
+        }
+
+        if let Some(language_name) = &args.language_name {
+            writeln!(
+                prompt,
+                "Your answer MUST always and only be valid {}.",
+                language_name
+            )
+            .unwrap();
+        }
+        writeln!(prompt, "Never make remarks about the output.").unwrap();
+        writeln!(
+            prompt,
+            "Do not return anything else, except the generated {content_type}."
+        )
+        .unwrap();
+
+        match file_type {
+            PromptFileType::Code => {
+                // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
+            }
+            _ => {}
+        }
+
+        // Really dumb truncation strategy
+        if let Some(max_tokens) = max_token_length {
+            prompt = args.model.truncate(
+                &prompt,
+                max_tokens,
+                crate::models::TruncationDirection::End,
+            )?;
+        }
+
+        let token_count = args.model.count_tokens(&prompt)?;
+
+        anyhow::Ok((prompt, token_count))
+    }
+}
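For reference, `capitalize` only touches the first character; a quick check, assuming the helper above is in scope:

```rust
fn main() {
    assert_eq!(capitalize("code"), "Code");
    assert_eq!(capitalize(""), "");
}
```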

crates/ai2/src/prompts/preamble.rs 🔗

@@ -0,0 +1,52 @@
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
+use std::fmt::Write;
+
+pub struct EngineerPreamble {}
+
+impl PromptTemplate for EngineerPreamble {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        let mut prompts = Vec::new();
+
+        match args.get_file_type() {
+            PromptFileType::Code => {
+                prompts.push(format!(
+                    "You are an expert {}engineer.",
+                    args.language_name
+                        .clone()
+                        .map(|name| format!("{name} "))
+                        .unwrap_or_default()
+                ));
+            }
+            PromptFileType::Text => {
+                prompts.push("You are an expert engineer.".to_string());
+            }
+        }
+
+        if let Some(project_name) = args.project_name.clone() {
+            prompts.push(format!(
+                "You are currently working inside the '{project_name}' project in code editor Zed."
+            ));
+        }
+
+        if let Some(mut remaining_tokens) = max_token_length {
+            let mut prompt = String::new();
+            let mut total_count = 0;
+            for prompt_piece in prompts {
+                let prompt_token_count =
+                    args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
+                if remaining_tokens > prompt_token_count {
+                    writeln!(prompt, "{prompt_piece}").unwrap();
+                    remaining_tokens -= prompt_token_count;
+                    total_count += prompt_token_count;
+                }
+            }
+
+            anyhow::Ok((prompt, total_count))
+        } else {
+            let prompt = prompts.join("\n");
+            let token_count = args.model.count_tokens(&prompt)?;
+            anyhow::Ok((prompt, token_count))
+        }
+    }
+}

crates/ai2/src/prompts/repository_context.rs 🔗

@@ -0,0 +1,98 @@
+use crate::prompts::base::{PromptArguments, PromptTemplate};
+use std::fmt::Write;
+use std::{ops::Range, path::PathBuf};
+
+use gpui2::{AsyncAppContext, Model};
+use language2::{Anchor, Buffer};
+
+#[derive(Clone)]
+pub struct PromptCodeSnippet {
+    path: Option<PathBuf>,
+    language_name: Option<String>,
+    content: String,
+}
+
+impl PromptCodeSnippet {
+    pub fn new(
+        buffer: Model<Buffer>,
+        range: Range<Anchor>,
+        cx: &mut AsyncAppContext,
+    ) -> anyhow::Result<Self> {
+        let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
+            let snapshot = buffer.snapshot();
+            let content = snapshot.text_for_range(range.clone()).collect::<String>();
+
+            let language_name = buffer
+                .language()
+                .map(|language| language.name().to_string().to_lowercase());
+
+            let file_path = buffer
+                .file()
+                .map(|file| file.path().to_path_buf());
+
+            (content, language_name, file_path)
+        })?;
+
+        anyhow::Ok(PromptCodeSnippet {
+            path: file_path,
+            language_name,
+            content,
+        })
+    }
+}
+
+impl std::fmt::Display for PromptCodeSnippet {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let path = self
+            .path
+            .as_ref()
+            .map(|path| path.to_string_lossy().to_string())
+            .unwrap_or_default();
+        let language_name = self.language_name.clone().unwrap_or_default();
+
+        write!(
+            f,
+            "The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```",
+            content = self.content
+        )
+    }
+}
+
+pub struct RepositoryContext {}
+
+impl PromptTemplate for RepositoryContext {
+    fn generate(
+        &self,
+        args: &PromptArguments,
+        max_token_length: Option<usize>,
+    ) -> anyhow::Result<(String, usize)> {
+        const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+        let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
+        let mut prompt = String::new();
+
+        let mut remaining_tokens = max_token_length;
+        let separator_token_length = args.model.count_tokens("\n")?;
+        for snippet in &args.snippets {
+            let mut snippet_prompt = template.to_string();
+            let content = snippet.to_string();
+            writeln!(snippet_prompt, "{content}").unwrap();
+
+            let token_count = args.model.count_tokens(&snippet_prompt)?;
+            if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
+                if let Some(tokens_left) = remaining_tokens {
+                    if tokens_left >= token_count {
+                        writeln!(prompt, "{snippet_prompt}").unwrap();
+                        remaining_tokens = Some(
+                            tokens_left.saturating_sub(token_count + separator_token_length),
+                        );
+                    }
+                } else {
+                    writeln!(prompt, "{snippet_prompt}").unwrap();
+                }
+            }
+        }
+
+        let total_token_count = args.model.count_tokens(&prompt)?;
+        anyhow::Ok((prompt, total_token_count))
+    }
+}

crates/ai2/src/providers/open_ai/completion.rs 🔗

@@ -0,0 +1,306 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::{
+    future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
+    Stream, StreamExt,
+};
+use gpui2::{AppContext, Executor};
+use isahc::{http::StatusCode, Request, RequestExt};
+use parking_lot::RwLock;
+use serde::{Deserialize, Serialize};
+use std::{
+    env,
+    fmt::{self, Display},
+    io,
+    sync::Arc,
+};
+use util::ResultExt;
+
+use crate::{
+    auth::{CredentialProvider, ProviderCredential},
+    completion::{CompletionProvider, CompletionRequest},
+    models::LanguageModel,
+};
+
+use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+}
+
+impl Role {
+    pub fn cycle(&mut self) {
+        *self = match self {
+            Role::User => Role::Assistant,
+            Role::Assistant => Role::System,
+            Role::System => Role::User,
+        }
+    }
+}
+
+impl Display for Role {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Role::User => write!(f, "User"),
+            Role::Assistant => write!(f, "Assistant"),
+            Role::System => write!(f, "System"),
+        }
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct RequestMessage {
+    pub role: Role,
+    pub content: String,
+}
+
+#[derive(Debug, Default, Serialize)]
+pub struct OpenAIRequest {
+    pub model: String,
+    pub messages: Vec<RequestMessage>,
+    pub stream: bool,
+    pub stop: Vec<String>,
+    pub temperature: f32,
+}
+
+impl CompletionRequest for OpenAIRequest {
+    fn data(&self) -> serde_json::Result<String> {
+        serde_json::to_string(self)
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ResponseMessage {
+    pub role: Option<Role>,
+    pub content: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIUsage {
+    pub prompt_tokens: u32,
+    pub completion_tokens: u32,
+    pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ChatChoiceDelta {
+    pub index: u32,
+    pub delta: ResponseMessage,
+    pub finish_reason: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIResponseStreamEvent {
+    pub id: Option<String>,
+    pub object: String,
+    pub created: u32,
+    pub model: String,
+    pub choices: Vec<ChatChoiceDelta>,
+    pub usage: Option<OpenAIUsage>,
+}
+
+pub async fn stream_completion(
+    credential: ProviderCredential,
+    executor: Arc<Executor>,
+    request: Box<dyn CompletionRequest>,
+) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+    let api_key = match credential {
+        ProviderCredential::Credentials { api_key } => api_key,
+        _ => {
+            return Err(anyhow!("no credentials provider for completion"));
+        }
+    };
+
+    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+
+    let json_data = request.data()?;
+    let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Bearer {}", api_key))
+        .body(json_data)?
+        .send_async()
+        .await?;
+
+    let status = response.status();
+    if status == StatusCode::OK {
+        executor
+            .spawn(async move {
+                let mut lines = BufReader::new(response.body_mut()).lines();
+
+                fn parse_line(
+                    line: Result<String, io::Error>,
+                ) -> Result<Option<OpenAIResponseStreamEvent>> {
+                    if let Some(data) = line?.strip_prefix("data: ") {
+                        let event = serde_json::from_str(data)?;
+                        Ok(Some(event))
+                    } else {
+                        Ok(None)
+                    }
+                }
+
+                while let Some(line) = lines.next().await {
+                    if let Some(event) = parse_line(line).transpose() {
+                        let done = event.as_ref().map_or(false, |event| {
+                            event
+                                .choices
+                                .last()
+                                .map_or(false, |choice| choice.finish_reason.is_some())
+                        });
+                        if tx.unbounded_send(event).is_err() {
+                            break;
+                        }
+
+                        if done {
+                            break;
+                        }
+                    }
+                }
+
+                anyhow::Ok(())
+            })
+            .detach();
+
+        Ok(rx)
+    } else {
+        let mut body = String::new();
+        response.body_mut().read_to_string(&mut body).await?;
+
+        #[derive(Deserialize)]
+        struct OpenAIResponse {
+            error: OpenAIError,
+        }
+
+        #[derive(Deserialize)]
+        struct OpenAIError {
+            message: String,
+        }
+
+        match serde_json::from_str::<OpenAIResponse>(&body) {
+            Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+                "Failed to connect to OpenAI API: {}",
+                response.error.message,
+            )),
+
+            _ => Err(anyhow!(
+                "Failed to connect to OpenAI API: {} {}",
+                response.status(),
+                body,
+            )),
+        }
+    }
+}
+
+#[derive(Clone)]
+pub struct OpenAICompletionProvider {
+    model: OpenAILanguageModel,
+    credential: Arc<RwLock<ProviderCredential>>,
+    executor: Arc<Executor>,
+}
+
+impl OpenAICompletionProvider {
+    pub fn new(model_name: &str, executor: Arc<Executor>) -> Self {
+        let model = OpenAILanguageModel::load(model_name);
+        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
+        Self {
+            model,
+            credential,
+            executor,
+        }
+    }
+}
+
+#[async_trait]
+impl CredentialProvider for OpenAICompletionProvider {
+    fn has_credentials(&self) -> bool {
+        matches!(
+            *self.credential.read(),
+            ProviderCredential::Credentials { .. }
+        )
+    }
+    async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
+        let existing_credential = self.credential.read().clone();
+
+        let retrieved_credential = cx
+            .run_on_main(move |cx| match existing_credential {
+                ProviderCredential::Credentials { .. } => {
+                    return existing_credential.clone();
+                }
+                _ => {
+                    if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
+                        return ProviderCredential::Credentials { api_key };
+                    }
+
+                    if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
+                    {
+                        if let Some(api_key) = String::from_utf8(api_key).log_err() {
+                            return ProviderCredential::Credentials { api_key };
+                        } else {
+                            return ProviderCredential::NoCredentials;
+                        }
+                    } else {
+                        return ProviderCredential::NoCredentials;
+                    }
+                }
+            })
+            .await;
+
+        *self.credential.write() = retrieved_credential.clone();
+        retrieved_credential
+    }
+
+    async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
+        *self.credential.write() = credential.clone();
+        cx.run_on_main(move |cx| match credential {
+            ProviderCredential::Credentials { api_key } => {
+                cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+                    .log_err();
+            }
+            _ => {}
+        })
+        .await;
+    }
+    async fn delete_credentials(&self, cx: &mut AppContext) {
+        cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
+            .await;
+        *self.credential.write() = ProviderCredential::NoCredentials;
+    }
+}
+
+impl CompletionProvider for OpenAICompletionProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        Box::new(self.model.clone())
+    }
+    fn complete(
+        &self,
+        prompt: Box<dyn CompletionRequest>,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        // Currently the CompletionRequest for OpenAI includes a 'model' parameter,
+        // so the model is determined by the CompletionRequest rather than by the
+        // CompletionProvider, which is itself model-based via its language model.
+        // At some point in the future we should rectify this.
+        let credential = self.credential.read().clone();
+        let request = stream_completion(credential, self.executor.clone(), prompt);
+        async move {
+            let response = request.await?;
+            let stream = response
+                .filter_map(|response| async move {
+                    match response {
+                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+                        Err(error) => Some(Err(error)),
+                    }
+                })
+                .boxed();
+            Ok(stream)
+        }
+        .boxed()
+    }
+    fn box_clone(&self) -> Box<dyn CompletionProvider> {
+        Box::new((*self).clone())
+    }
+}
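For reference, a sketch of driving the provider above (the request values are illustrative; constructing the provider itself needs a gpui2 `Executor`, so one is taken as given here):

```rust
use ai2::completion::{CompletionProvider, CompletionRequest};
use ai2::providers::open_ai::{OpenAIRequest, RequestMessage, Role};
use futures::StreamExt;

async fn complete_once(provider: Box<dyn CompletionProvider>) -> anyhow::Result<String> {
    let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
        model: "gpt-4".into(),
        messages: vec![RequestMessage {
            role: Role::User,
            content: "Write a haiku about Rust.".into(),
        }],
        stream: true,
        stop: Vec::new(),
        temperature: 1.0,
    });

    // `complete` resolves to a stream of content deltas.
    let mut stream = provider.complete(request).await?;
    let mut output = String::new();
    while let Some(delta) = stream.next().await {
        output.push_str(&delta?);
    }
    Ok(output)
}
```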

crates/ai2/src/providers/open_ai/embedding.rs 🔗

@@ -0,0 +1,313 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui2::Executor;
+use gpui2::{serde_json, AppContext};
+use isahc::http::StatusCode;
+use isahc::prelude::Configurable;
+use isahc::{AsyncBody, Response};
+use lazy_static::lazy_static;
+use parking_lot::{Mutex, RwLock};
+use parse_duration::parse;
+use postage::watch;
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::ops::Add;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tiktoken_rs::{cl100k_base, CoreBPE};
+use util::http::{HttpClient, Request};
+use util::ResultExt;
+
+use crate::auth::{CredentialProvider, ProviderCredential};
+use crate::embedding::{Embedding, EmbeddingProvider};
+use crate::models::LanguageModel;
+use crate::providers::open_ai::OpenAILanguageModel;
+
+use crate::providers::open_ai::OPENAI_API_URL;
+
+lazy_static! {
+    static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
+}
+
+#[derive(Clone)]
+pub struct OpenAIEmbeddingProvider {
+    model: OpenAILanguageModel,
+    credential: Arc<RwLock<ProviderCredential>>,
+    pub client: Arc<dyn HttpClient>,
+    pub executor: Arc<Executor>,
+    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
+    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
+}
+
+#[derive(Serialize)]
+struct OpenAIEmbeddingRequest<'a> {
+    model: &'static str,
+    input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingResponse {
+    data: Vec<OpenAIEmbedding>,
+    usage: OpenAIEmbeddingUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIEmbedding {
+    embedding: Vec<f32>,
+    index: usize,
+    object: String,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingUsage {
+    prompt_tokens: usize,
+    total_tokens: usize,
+}
+
+impl OpenAIEmbeddingProvider {
+    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Executor>) -> Self {
+        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
+        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
+
+        let model = OpenAILanguageModel::load("text-embedding-ada-002");
+        let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
+
+        OpenAIEmbeddingProvider {
+            model,
+            credential,
+            client,
+            executor,
+            rate_limit_count_rx,
+            rate_limit_count_tx,
+        }
+    }
+
+    fn get_api_key(&self) -> Result<String> {
+        match self.credential.read().clone() {
+            ProviderCredential::Credentials { api_key } => Ok(api_key),
+            _ => Err(anyhow!("api credentials not provided")),
+        }
+    }
+
+    fn resolve_rate_limit(&self) {
+        let reset_time = *self.rate_limit_count_tx.lock().borrow();
+
+        if let Some(reset_time) = reset_time {
+            if Instant::now() >= reset_time {
+                *self.rate_limit_count_tx.lock().borrow_mut() = None
+            }
+        }
+
+        log::trace!(
+            "resolving reset time: {:?}",
+            *self.rate_limit_count_tx.lock().borrow()
+        );
+    }
+
+    fn update_reset_time(&self, reset_time: Instant) {
+        let original_time = *self.rate_limit_count_tx.lock().borrow();
+
+        let updated_time = if let Some(original_time) = original_time {
+            if reset_time < original_time {
+                Some(reset_time)
+            } else {
+                Some(original_time)
+            }
+        } else {
+            Some(reset_time)
+        };
+
+        log::trace!("updating rate limit time: {:?}", updated_time);
+
+        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
+    }
+    async fn send_request(
+        &self,
+        api_key: &str,
+        spans: Vec<&str>,
+        request_timeout: u64,
+    ) -> Result<Response<AsyncBody>> {
+        let body = serde_json::to_string(&OpenAIEmbeddingRequest {
+            input: spans,
+            model: "text-embedding-ada-002",
+        })?;
+
+        let request = Request::post(format!("{OPENAI_API_URL}/embeddings"))
+            .redirect_policy(isahc::config::RedirectPolicy::Follow)
+            .timeout(Duration::from_secs(request_timeout))
+            .header("Content-Type", "application/json")
+            .header("Authorization", format!("Bearer {}", api_key))
+            .body(body.into())?;
+
+        Ok(self.client.send(request).await?)
+    }
+}
+
+#[async_trait]
+impl CredentialProvider for OpenAIEmbeddingProvider {
+    fn has_credentials(&self) -> bool {
+        matches!(
+            *self.credential.read(),
+            ProviderCredential::Credentials { .. }
+        )
+    }
+    async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
+        let existing_credential = self.credential.read().clone();
+
+        let retrieved_credential = cx
+            .run_on_main(move |cx| match existing_credential {
+                ProviderCredential::Credentials { .. } => {
+                    return existing_credential.clone();
+                }
+                _ => {
+                    if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
+                        return ProviderCredential::Credentials { api_key };
+                    }
+
+                    if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
+                    {
+                        if let Some(api_key) = String::from_utf8(api_key).log_err() {
+                            return ProviderCredential::Credentials { api_key };
+                        } else {
+                            return ProviderCredential::NoCredentials;
+                        }
+                    } else {
+                        return ProviderCredential::NoCredentials;
+                    }
+                }
+            })
+            .await;
+
+        *self.credential.write() = retrieved_credential.clone();
+        retrieved_credential
+    }
+
+    async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
+        *self.credential.write() = credential.clone();
+        cx.run_on_main(move |cx| match credential {
+            ProviderCredential::Credentials { api_key } => {
+                cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+                    .log_err();
+            }
+            _ => {}
+        })
+        .await;
+    }
+    async fn delete_credentials(&self, cx: &mut AppContext) {
+        cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
+            .await;
+        *self.credential.write() = ProviderCredential::NoCredentials;
+    }
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddingProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        Box::new(self.model.clone())
+    }
+
+    fn max_tokens_per_batch(&self) -> usize {
+        50000
+    }
+
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        *self.rate_limit_count_rx.borrow()
+    }
+
+    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+        const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
+        const MAX_RETRIES: usize = 4;
+
+        let api_key = self.get_api_key()?;
+
+        let mut request_number = 0;
+        let mut rate_limiting = false;
+        let mut request_timeout: u64 = 15;
+        let mut response: Response<AsyncBody>;
+        while request_number < MAX_RETRIES {
+            response = self
+                .send_request(
+                    &api_key,
+                    spans.iter().map(|x| &**x).collect(),
+                    request_timeout,
+                )
+                .await?;
+
+            request_number += 1;
+
+            match response.status() {
+                StatusCode::REQUEST_TIMEOUT => {
+                    request_timeout += 5;
+                }
+                StatusCode::OK => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+                    log::trace!(
+                        "openai embedding completed. tokens: {:?}",
+                        response.usage.total_tokens
+                    );
+
+                    // If a request succeeds after we were previously
+                    // rate-limited, clear the rate limit.
+                    if rate_limiting {
+                        self.resolve_rate_limit()
+                    }
+
+                    return Ok(response
+                        .data
+                        .into_iter()
+                        .map(|embedding| Embedding::from(embedding.embedding))
+                        .collect());
+                }
+                StatusCode::TOO_MANY_REQUESTS => {
+                    rate_limiting = true;
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+
+                    let delay_duration = {
+                        let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+                        if let Some(time_to_reset) =
+                            response.headers().get("x-ratelimit-reset-tokens")
+                        {
+                            if let Ok(time_str) = time_to_reset.to_str() {
+                                parse(time_str).unwrap_or(delay)
+                            } else {
+                                delay
+                            }
+                        } else {
+                            delay
+                        }
+                    };
+
+                    // Record when we expect the rate limit to lift.
+                    let reset_time = Instant::now().add(delay_duration);
+                    self.update_reset_time(reset_time);
+
+                    log::trace!(
+                        "openai rate limiting: waiting {:?} until lifted",
+                        &delay_duration
+                    );
+
+                    self.executor.timer(delay_duration).await;
+                }
+                _ => {
+                    let mut body = String::new();
+                    response.body_mut().read_to_string(&mut body).await?;
+                    return Err(anyhow!(
+                        "open ai request failed: {:?} {:?}",
+                        &response.status(),
+                        body
+                    ));
+                }
+            }
+        }
+        Err(anyhow!("openai max retries"))
+    }
+}
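
The loop above makes up to `MAX_RETRIES` attempts: a 408 grows the request timeout by five seconds, and a 429 sleeps for either the server's `x-ratelimit-reset-tokens` hint or the fixed 3/5/15/45 second backoff table before retrying. A minimal sketch of that policy in isolation; the `fetch` closure and `RetryOutcome` enum are illustrative stand-ins, not types from this change:

```rust
use std::time::Duration;

enum RetryOutcome<T> {
    Done(T),
    RateLimited { reset_after: Option<Duration> },
    Timeout,
}

async fn with_backoff<T, F, Fut>(mut fetch: F) -> anyhow::Result<T>
where
    F: FnMut(Duration) -> Fut,
    Fut: std::future::Future<Output = anyhow::Result<RetryOutcome<T>>>,
{
    const BACKOFF_SECONDS: [u64; 4] = [3, 5, 15, 45];
    let mut request_timeout = Duration::from_secs(15);
    for attempt in 0..BACKOFF_SECONDS.len() {
        match fetch(request_timeout).await? {
            RetryOutcome::Done(value) => return Ok(value),
            RetryOutcome::RateLimited { reset_after } => {
                // Prefer the server-provided reset time over the static table.
                let delay =
                    reset_after.unwrap_or(Duration::from_secs(BACKOFF_SECONDS[attempt]));
                smol::Timer::after(delay).await;
            }
            RetryOutcome::Timeout => {
                // Give the next attempt longer before abandoning the request.
                request_timeout += Duration::from_secs(5);
            }
        }
    }
    Err(anyhow::anyhow!("max retries exceeded"))
}
```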

crates/ai2/src/providers/open_ai/mod.rs 🔗

@@ -0,0 +1,9 @@
+pub mod completion;
+pub mod embedding;
+pub mod model;
+
+pub use completion::*;
+pub use embedding::*;
+pub use model::OpenAILanguageModel;
+
+pub const OPENAI_API_URL: &str = "https://api.openai.com/v1";

crates/ai2/src/providers/open_ai/model.rs 🔗

@@ -0,0 +1,57 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+use crate::models::{LanguageModel, TruncationDirection};
+
+#[derive(Clone)]
+pub struct OpenAILanguageModel {
+    name: String,
+    bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+    pub fn load(model_name: &str) -> Self {
+        let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+        OpenAILanguageModel {
+            name: model_name.to_string(),
+            bpe,
+        }
+    }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+    fn name(&self) -> String {
+        self.name.clone()
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        if let Some(bpe) = &self.bpe {
+            anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String> {
+        if let Some(bpe) = &self.bpe {
+            let tokens = bpe.encode_with_special_tokens(content);
+            if tokens.len() > length {
+                match direction {
+                    TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+                    TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+                }
+            } else {
+                bpe.decode(tokens)
+            }
+        } else {
+            Err(anyhow!("bpe for open ai model was not retrieved"))
+        }
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+    }
+}
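
`truncate` above works at the token level: content is encoded with the model's BPE, the token ids are sliced (`End` keeps the prefix, `Start` keeps the suffix), and the slice is decoded back into a string. A usage sketch, assuming the `ai2` module paths in this PR; the model name and prompt are examples:

```rust
use ai2::models::{LanguageModel, TruncationDirection};
use ai2::providers::open_ai::OpenAILanguageModel;

fn main() -> anyhow::Result<()> {
    let model = OpenAILanguageModel::load("gpt-3.5-turbo");
    let prompt = "fn main() { println!(\"hello\"); }";

    let tokens = model.count_tokens(prompt)?;
    // Keep at most 8 tokens from the front of the prompt.
    let truncated = model.truncate(prompt, 8, TruncationDirection::End)?;
    println!("{tokens} tokens; truncated to {truncated:?}");
    Ok(())
}
```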

crates/ai2/src/providers/open_ai/new.rs 🔗

@@ -0,0 +1,13 @@
+use crate::models::TruncationDirection;
+
+pub trait LanguageModel {
+    fn name(&self) -> String;
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String>;
+    fn capacity(&self) -> anyhow::Result<usize>;
+}

crates/ai2/src/test.rs 🔗

@@ -0,0 +1,186 @@
+use std::{
+    sync::atomic::{self, AtomicUsize, Ordering},
+    time::Instant,
+};
+
+use async_trait::async_trait;
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui2::AppContext;
+use parking_lot::Mutex;
+
+use crate::{
+    auth::{CredentialProvider, ProviderCredential},
+    completion::{CompletionProvider, CompletionRequest},
+    embedding::{Embedding, EmbeddingProvider},
+    models::{LanguageModel, TruncationDirection},
+};
+
+#[derive(Clone)]
+pub struct FakeLanguageModel {
+    pub capacity: usize,
+}
+
+impl LanguageModel for FakeLanguageModel {
+    fn name(&self) -> String {
+        "dummy".to_string()
+    }
+    fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+        anyhow::Ok(content.chars().count())
+    }
+    fn truncate(
+        &self,
+        content: &str,
+        length: usize,
+        direction: TruncationDirection,
+    ) -> anyhow::Result<String> {
+        if length > self.count_tokens(content)? {
+            return anyhow::Ok(content.to_string());
+        }
+
+        anyhow::Ok(match direction {
+            TruncationDirection::End => content.chars().take(length).collect(),
+            TruncationDirection::Start => content.chars().skip(length).collect(),
+        })
+    }
+    fn capacity(&self) -> anyhow::Result<usize> {
+        anyhow::Ok(self.capacity)
+    }
+}
+
+pub struct FakeEmbeddingProvider {
+    pub embedding_count: AtomicUsize,
+}
+
+impl Clone for FakeEmbeddingProvider {
+    fn clone(&self) -> Self {
+        FakeEmbeddingProvider {
+            embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
+        }
+    }
+}
+
+impl Default for FakeEmbeddingProvider {
+    fn default() -> Self {
+        FakeEmbeddingProvider {
+            embedding_count: AtomicUsize::default(),
+        }
+    }
+}
+
+impl FakeEmbeddingProvider {
+    pub fn embedding_count(&self) -> usize {
+        self.embedding_count.load(atomic::Ordering::SeqCst)
+    }
+
+    pub fn embed_sync(&self, span: &str) -> Embedding {
+        let mut result = vec![1.0; 26];
+        for letter in span.chars() {
+            let letter = letter.to_ascii_lowercase();
+            if letter as u32 >= 'a' as u32 {
+                let ix = (letter as u32) - ('a' as u32);
+                if ix < 26 {
+                    result[ix as usize] += 1.0;
+                }
+            }
+        }
+
+        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
+        for x in &mut result {
+            *x /= norm;
+        }
+
+        result.into()
+    }
+}
+
+#[async_trait]
+impl CredentialProvider for FakeEmbeddingProvider {
+    fn has_credentials(&self) -> bool {
+        true
+    }
+    async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
+        ProviderCredential::NotNeeded
+    }
+    async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
+    async fn delete_credentials(&self, _cx: &mut AppContext) {}
+}
+
+#[async_trait]
+impl EmbeddingProvider for FakeEmbeddingProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        Box::new(FakeLanguageModel { capacity: 1000 })
+    }
+    fn max_tokens_per_batch(&self) -> usize {
+        1000
+    }
+
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        None
+    }
+
+    async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
+        self.embedding_count
+            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
+
+        anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
+    }
+}
+
+pub struct FakeCompletionProvider {
+    last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
+}
+
+impl Clone for FakeCompletionProvider {
+    fn clone(&self) -> Self {
+        Self {
+            last_completion_tx: Mutex::new(None),
+        }
+    }
+}
+
+impl FakeCompletionProvider {
+    pub fn new() -> Self {
+        Self {
+            last_completion_tx: Mutex::new(None),
+        }
+    }
+
+    pub fn send_completion(&self, completion: impl Into<String>) {
+        let mut tx = self.last_completion_tx.lock();
+        tx.as_mut().unwrap().try_send(completion.into()).unwrap();
+    }
+
+    pub fn finish_completion(&self) {
+        self.last_completion_tx.lock().take().unwrap();
+    }
+}
+
+#[async_trait]
+impl CredentialProvider for FakeCompletionProvider {
+    fn has_credentials(&self) -> bool {
+        true
+    }
+    async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
+        ProviderCredential::NotNeeded
+    }
+    async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
+    async fn delete_credentials(&self, _cx: &mut AppContext) {}
+}
+
+impl CompletionProvider for FakeCompletionProvider {
+    fn base_model(&self) -> Box<dyn LanguageModel> {
+        let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
+        model
+    }
+    fn complete(
+        &self,
+        _prompt: Box<dyn CompletionRequest>,
+    ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
+        let (tx, rx) = mpsc::channel(1);
+        *self.last_completion_tx.lock() = Some(tx);
+        async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
+    }
+    fn box_clone(&self) -> Box<dyn CompletionProvider> {
+        Box::new((*self).clone())
+    }
+}
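
`FakeCompletionProvider::complete` stashes the sender half of a bounded channel and hands back the receiver as the completion stream, so a test pushes chunks with `send_completion` and ends the stream with `finish_completion`. A sketch of that flow; `DummyRequest` is hypothetical, mirroring the `data` signature of the `DummyCompletionRequest` in the codegen tests below:

```rust
use ai2::completion::{CompletionProvider, CompletionRequest};
use ai2::test::FakeCompletionProvider;
use futures::StreamExt;

// The fake provider never reads the request body, so any impl will do.
#[derive(serde::Serialize)]
struct DummyRequest;

impl CompletionRequest for DummyRequest {
    fn data(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }
}

async fn demo() -> anyhow::Result<()> {
    let provider = FakeCompletionProvider::new();
    let mut stream = provider.complete(Box::new(DummyRequest)).await?;

    // The test drives the "model": each send becomes one streamed chunk,
    // and dropping the sender ends the stream.
    provider.send_completion("hello world");
    provider.finish_completion();

    while let Some(chunk) = stream.next().await {
        print!("{}", chunk?);
    }
    Ok(())
}
```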

crates/assistant/Cargo.toml 🔗

@@ -45,6 +45,7 @@ tiktoken-rs = "0.5"
 [dev-dependencies]
 editor = { path = "../editor", features = ["test-support"] }
 project = { path = "../project", features = ["test-support"] }
+ai = { path = "../ai", features = ["test-support"]}
 
 ctor.workspace = true
 env_logger.workspace = true

crates/assistant/src/assistant.rs 🔗

@@ -4,7 +4,7 @@ mod codegen;
 mod prompts;
 mod streaming_diff;
 
-use ai::completion::Role;
+use ai::providers::open_ai::Role;
 use anyhow::Result;
 pub use assistant_panel::AssistantPanel;
 use assistant_settings::OpenAIModel;

crates/assistant/src/assistant_panel.rs 🔗

@@ -5,12 +5,14 @@ use crate::{
     MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
     SavedMessage,
 };
+
 use ai::{
-    completion::{
-        stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
-    },
-    templates::repository_context::PromptCodeSnippet,
+    auth::ProviderCredential,
+    completion::{CompletionProvider, CompletionRequest},
+    providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage},
 };
+
+use ai::prompts::repository_context::PromptCodeSnippet;
 use anyhow::{anyhow, Result};
 use chrono::{DateTime, Local};
 use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
@@ -43,8 +45,8 @@ use search::BufferSearchBar;
 use semantic_index::{SemanticIndex, SemanticIndexStatus};
 use settings::SettingsStore;
 use std::{
-    cell::{Cell, RefCell},
-    cmp, env,
+    cell::Cell,
+    cmp,
     fmt::Write,
     iter,
     ops::Range,
@@ -97,8 +99,8 @@ pub fn init(cx: &mut AppContext) {
     cx.capture_action(ConversationEditor::copy);
     cx.add_action(ConversationEditor::split);
     cx.capture_action(ConversationEditor::cycle_message_role);
-    cx.add_action(AssistantPanel::save_api_key);
-    cx.add_action(AssistantPanel::reset_api_key);
+    cx.add_action(AssistantPanel::save_credentials);
+    cx.add_action(AssistantPanel::reset_credentials);
     cx.add_action(AssistantPanel::toggle_zoom);
     cx.add_action(AssistantPanel::deploy);
     cx.add_action(AssistantPanel::select_next_match);
@@ -140,9 +142,8 @@ pub struct AssistantPanel {
     zoomed: bool,
     has_focus: bool,
     toolbar: ViewHandle<Toolbar>,
-    api_key: Rc<RefCell<Option<String>>>,
+    completion_provider: Box<dyn CompletionProvider>,
     api_key_editor: Option<ViewHandle<Editor>>,
-    has_read_credentials: bool,
     languages: Arc<LanguageRegistry>,
     fs: Arc<dyn Fs>,
     subscriptions: Vec<Subscription>,
@@ -202,6 +203,11 @@ impl AssistantPanel {
                     });
 
                     let semantic_index = SemanticIndex::global(cx);
+                    // Defaulting to GPT-4 for now; this should eventually be configurable.
+                    let completion_provider = Box::new(OpenAICompletionProvider::new(
+                        "gpt-4",
+                        cx.background().clone(),
+                    ));
 
                     let mut this = Self {
                         workspace: workspace_handle,
@@ -213,9 +219,8 @@ impl AssistantPanel {
                         zoomed: false,
                         has_focus: false,
                         toolbar,
-                        api_key: Rc::new(RefCell::new(None)),
+                        completion_provider,
                         api_key_editor: None,
-                        has_read_credentials: false,
                         languages: workspace.app_state().languages.clone(),
                         fs: workspace.app_state().fs.clone(),
                         width: None,
@@ -254,10 +259,7 @@ impl AssistantPanel {
         cx: &mut ViewContext<Workspace>,
     ) {
         let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
-            if this
-                .update(cx, |assistant, cx| assistant.load_api_key(cx))
-                .is_some()
-            {
+            if this.update(cx, |assistant, _| assistant.has_credentials()) {
                 this
             } else {
                 workspace.focus_panel::<AssistantPanel>(cx);
@@ -289,12 +291,6 @@ impl AssistantPanel {
         cx: &mut ViewContext<Self>,
         project: &ModelHandle<Project>,
     ) {
-        let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
-            api_key
-        } else {
-            return;
-        };
-
         let selection = editor.read(cx).selections.newest_anchor().clone();
         if selection.start.excerpt_id != selection.end.excerpt_id {
             return;
@@ -325,10 +321,13 @@ impl AssistantPanel {
 
         let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
         let provider = Arc::new(OpenAICompletionProvider::new(
-            api_key,
+            "gpt-4",
             cx.background().clone(),
         ));
 
+        // Retrieving credentials authenticates the provider:
+        // provider.retrieve_credentials(cx);
+
         let codegen = cx.add_model(|cx| {
             Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
         });
@@ -745,13 +744,14 @@ impl AssistantPanel {
                 content: prompt,
             });
 
-            let request = OpenAIRequest {
+            let request = Box::new(OpenAIRequest {
                 model: model.full_name().into(),
                 messages,
                 stream: true,
                 stop: vec!["|END|>".to_string()],
                 temperature,
-            };
+            });
+
             codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
             anyhow::Ok(())
         })
@@ -811,7 +811,7 @@ impl AssistantPanel {
     fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
         let editor = cx.add_view(|cx| {
             ConversationEditor::new(
-                self.api_key.clone(),
+                self.completion_provider.clone(),
                 self.languages.clone(),
                 self.fs.clone(),
                 self.workspace.clone(),
@@ -870,17 +870,19 @@ impl AssistantPanel {
         }
     }
 
-    fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
+    fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
         if let Some(api_key) = self
             .api_key_editor
             .as_ref()
             .map(|editor| editor.read(cx).text(cx))
         {
             if !api_key.is_empty() {
-                cx.platform()
-                    .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
-                    .log_err();
-                *self.api_key.borrow_mut() = Some(api_key);
+                let credential = ProviderCredential::Credentials {
+                    api_key: api_key.clone(),
+                };
+
+                self.completion_provider.save_credentials(cx, credential);
+
                 self.api_key_editor.take();
                 cx.focus_self();
                 cx.notify();
@@ -890,9 +892,8 @@ impl AssistantPanel {
         }
     }
 
-    fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
-        cx.platform().delete_credentials(OPENAI_API_URL).log_err();
-        self.api_key.take();
+    fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
+        self.completion_provider.delete_credentials(cx);
         self.api_key_editor = Some(build_api_key_editor(cx));
         cx.focus_self();
         cx.notify();
@@ -1151,13 +1152,12 @@ impl AssistantPanel {
 
         let fs = self.fs.clone();
         let workspace = self.workspace.clone();
-        let api_key = self.api_key.clone();
         let languages = self.languages.clone();
         cx.spawn(|this, mut cx| async move {
             let saved_conversation = fs.load(&path).await?;
             let saved_conversation = serde_json::from_str(&saved_conversation)?;
             let conversation = cx.add_model(|cx| {
-                Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
+                Conversation::deserialize(saved_conversation, path.clone(), languages, cx)
             });
             this.update(&mut cx, |this, cx| {
                 // If, by the time we've loaded the conversation, the user has already opened
@@ -1181,30 +1181,12 @@ impl AssistantPanel {
             .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
     }
 
-    fn load_api_key(&mut self, cx: &mut ViewContext<Self>) -> Option<String> {
-        if self.api_key.borrow().is_none() && !self.has_read_credentials {
-            self.has_read_credentials = true;
-            let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
-                Some(api_key)
-            } else if let Some((_, api_key)) = cx
-                .platform()
-                .read_credentials(OPENAI_API_URL)
-                .log_err()
-                .flatten()
-            {
-                String::from_utf8(api_key).log_err()
-            } else {
-                None
-            };
-            if let Some(api_key) = api_key {
-                *self.api_key.borrow_mut() = Some(api_key);
-            } else if self.api_key_editor.is_none() {
-                self.api_key_editor = Some(build_api_key_editor(cx));
-                cx.notify();
-            }
-        }
+    fn has_credentials(&mut self) -> bool {
+        self.completion_provider.has_credentials()
+    }
 
-        self.api_key.borrow().clone()
+    fn load_credentials(&mut self, cx: &mut ViewContext<Self>) {
+        self.completion_provider.retrieve_credentials(cx);
     }
 }
 
@@ -1389,7 +1371,7 @@ impl Panel for AssistantPanel {
 
     fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
         if active {
-            self.load_api_key(cx);
+            self.load_credentials(cx);
 
             if self.editors.is_empty() {
                 self.new_conversation(cx);
@@ -1454,10 +1436,10 @@ struct Conversation {
     token_count: Option<usize>,
     max_token_count: usize,
     pending_token_count: Task<Option<()>>,
-    api_key: Rc<RefCell<Option<String>>>,
     pending_save: Task<Result<()>>,
     path: Option<PathBuf>,
     _subscriptions: Vec<Subscription>,
+    completion_provider: Box<dyn CompletionProvider>,
 }
 
 impl Entity for Conversation {
@@ -1466,9 +1448,9 @@ impl Entity for Conversation {
 
 impl Conversation {
     fn new(
-        api_key: Rc<RefCell<Option<String>>>,
         language_registry: Arc<LanguageRegistry>,
         cx: &mut ModelContext<Self>,
+        completion_provider: Box<dyn CompletionProvider>,
     ) -> Self {
         let markdown = language_registry.language_for_name("Markdown");
         let buffer = cx.add_model(|cx| {
@@ -1507,8 +1489,8 @@ impl Conversation {
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
             path: None,
-            api_key,
             buffer,
+            completion_provider,
         };
         let message = MessageAnchor {
             id: MessageId(post_inc(&mut this.next_message_id.0)),
@@ -1554,7 +1536,6 @@ impl Conversation {
     fn deserialize(
         saved_conversation: SavedConversation,
         path: PathBuf,
-        api_key: Rc<RefCell<Option<String>>>,
         language_registry: Arc<LanguageRegistry>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
@@ -1563,6 +1544,10 @@ impl Conversation {
             None => Some(Uuid::new_v4().to_string()),
         };
         let model = saved_conversation.model;
+        let completion_provider: Box<dyn CompletionProvider> = Box::new(
+            OpenAICompletionProvider::new(model.full_name(), cx.background().clone()),
+        );
+        completion_provider.retrieve_credentials(cx);
         let markdown = language_registry.language_for_name("Markdown");
         let mut message_anchors = Vec::new();
         let mut next_message_id = MessageId(0);
@@ -1609,8 +1594,8 @@ impl Conversation {
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
             path: Some(path),
-            api_key,
             buffer,
+            completion_provider,
         };
         this.count_remaining_tokens(cx);
         this
@@ -1731,11 +1716,11 @@ impl Conversation {
         }
 
         if should_assist {
-            let Some(api_key) = self.api_key.borrow().clone() else {
+            if !self.completion_provider.has_credentials() {
                 return Default::default();
-            };
+            }
 
-            let request = OpenAIRequest {
+            let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
                 model: self.model.full_name().to_string(),
                 messages: self
                     .messages(cx)
@@ -1745,9 +1730,9 @@ impl Conversation {
                 stream: true,
                 stop: vec![],
                 temperature: 1.0,
-            };
+            });
 
-            let stream = stream_completion(api_key, cx.background().clone(), request);
+            let stream = self.completion_provider.complete(request);
             let assistant_message = self
                 .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
                 .unwrap();
@@ -1765,33 +1750,28 @@ impl Conversation {
                         let mut messages = stream.await?;
 
                         while let Some(message) = messages.next().await {
-                            let mut message = message?;
-                            if let Some(choice) = message.choices.pop() {
-                                this.upgrade(&cx)
-                                    .ok_or_else(|| anyhow!("conversation was dropped"))?
-                                    .update(&mut cx, |this, cx| {
-                                        let text: Arc<str> = choice.delta.content?.into();
-                                        let message_ix =
-                                            this.message_anchors.iter().position(|message| {
-                                                message.id == assistant_message_id
-                                            })?;
-                                        this.buffer.update(cx, |buffer, cx| {
-                                            let offset = this.message_anchors[message_ix + 1..]
-                                                .iter()
-                                                .find(|message| message.start.is_valid(buffer))
-                                                .map_or(buffer.len(), |message| {
-                                                    message
-                                                        .start
-                                                        .to_offset(buffer)
-                                                        .saturating_sub(1)
-                                                });
-                                            buffer.edit([(offset..offset, text)], None, cx);
-                                        });
-                                        cx.emit(ConversationEvent::StreamedCompletion);
-
-                                        Some(())
+                            let text = message?;
+
+                            this.upgrade(&cx)
+                                .ok_or_else(|| anyhow!("conversation was dropped"))?
+                                .update(&mut cx, |this, cx| {
+                                    let message_ix = this
+                                        .message_anchors
+                                        .iter()
+                                        .position(|message| message.id == assistant_message_id)?;
+                                    this.buffer.update(cx, |buffer, cx| {
+                                        let offset = this.message_anchors[message_ix + 1..]
+                                            .iter()
+                                            .find(|message| message.start.is_valid(buffer))
+                                            .map_or(buffer.len(), |message| {
+                                                message.start.to_offset(buffer).saturating_sub(1)
+                                            });
+                                        buffer.edit([(offset..offset, text)], None, cx);
                                     });
-                            }
+                                    cx.emit(ConversationEvent::StreamedCompletion);
+
+                                    Some(())
+                                });
                             smol::future::yield_now().await;
                         }
 
@@ -2013,57 +1993,54 @@ impl Conversation {
 
     fn summarize(&mut self, cx: &mut ModelContext<Self>) {
         if self.message_anchors.len() >= 2 && self.summary.is_none() {
-            let api_key = self.api_key.borrow().clone();
-            if let Some(api_key) = api_key {
-                let messages = self
-                    .messages(cx)
-                    .take(2)
-                    .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
-                    .chain(Some(RequestMessage {
-                        role: Role::User,
-                        content:
-                            "Summarize the conversation into a short title without punctuation"
-                                .into(),
-                    }));
-                let request = OpenAIRequest {
-                    model: self.model.full_name().to_string(),
-                    messages: messages.collect(),
-                    stream: true,
-                    stop: vec![],
-                    temperature: 1.0,
-                };
+            if !self.completion_provider.has_credentials() {
+                return;
+            }
 
-                let stream = stream_completion(api_key, cx.background().clone(), request);
-                self.pending_summary = cx.spawn(|this, mut cx| {
-                    async move {
-                        let mut messages = stream.await?;
+            let messages = self
+                .messages(cx)
+                .take(2)
+                .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
+                .chain(Some(RequestMessage {
+                    role: Role::User,
+                    content: "Summarize the conversation into a short title without punctuation"
+                        .into(),
+                }));
+            let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
+                model: self.model.full_name().to_string(),
+                messages: messages.collect(),
+                stream: true,
+                stop: vec![],
+                temperature: 1.0,
+            });
 
-                        while let Some(message) = messages.next().await {
-                            let mut message = message?;
-                            if let Some(choice) = message.choices.pop() {
-                                let text = choice.delta.content.unwrap_or_default();
-                                this.update(&mut cx, |this, cx| {
-                                    this.summary
-                                        .get_or_insert(Default::default())
-                                        .text
-                                        .push_str(&text);
-                                    cx.emit(ConversationEvent::SummaryChanged);
-                                });
-                            }
-                        }
+            let stream = self.completion_provider.complete(request);
+            self.pending_summary = cx.spawn(|this, mut cx| {
+                async move {
+                    let mut messages = stream.await?;
 
+                    while let Some(message) = messages.next().await {
+                        let text = message?;
                         this.update(&mut cx, |this, cx| {
-                            if let Some(summary) = this.summary.as_mut() {
-                                summary.done = true;
-                                cx.emit(ConversationEvent::SummaryChanged);
-                            }
+                            this.summary
+                                .get_or_insert(Default::default())
+                                .text
+                                .push_str(&text);
+                            cx.emit(ConversationEvent::SummaryChanged);
                         });
-
-                        anyhow::Ok(())
                     }
-                    .log_err()
-                });
-            }
+
+                    this.update(&mut cx, |this, cx| {
+                        if let Some(summary) = this.summary.as_mut() {
+                            summary.done = true;
+                            cx.emit(ConversationEvent::SummaryChanged);
+                        }
+                    });
+
+                    anyhow::Ok(())
+                }
+                .log_err()
+            });
         }
     }
 
@@ -2224,13 +2201,14 @@ struct ConversationEditor {
 
 impl ConversationEditor {
     fn new(
-        api_key: Rc<RefCell<Option<String>>>,
+        completion_provider: Box<dyn CompletionProvider>,
         language_registry: Arc<LanguageRegistry>,
         fs: Arc<dyn Fs>,
         workspace: WeakViewHandle<Workspace>,
         cx: &mut ViewContext<Self>,
     ) -> Self {
-        let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
+        let conversation =
+            cx.add_model(|cx| Conversation::new(language_registry, cx, completion_provider));
         Self::for_conversation(conversation, fs, workspace, cx)
     }
 
@@ -3419,6 +3397,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
 mod tests {
     use super::*;
     use crate::MessageId;
+    use ai::test::FakeCompletionProvider;
     use gpui::AppContext;
 
     #[gpui::test]
@@ -3426,7 +3405,9 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+
+        let completion_provider = Box::new(FakeCompletionProvider::new());
+        let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();
 
         let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3554,7 +3535,9 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+        let completion_provider = Box::new(FakeCompletionProvider::new());
+
+        let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();
 
         let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3650,7 +3633,8 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
-        let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+        let completion_provider = Box::new(FakeCompletionProvider::new());
+        let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();
 
         let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3732,8 +3716,9 @@ mod tests {
         cx.set_global(SettingsStore::test(cx));
         init(cx);
         let registry = Arc::new(LanguageRegistry::test());
+        let completion_provider = Box::new(FakeCompletionProvider::new());
         let conversation =
-            cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
+            cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
         let buffer = conversation.read(cx).buffer.clone();
         let message_0 = conversation.read(cx).message_anchors[0].id;
         let message_1 = conversation.update(cx, |conversation, cx| {
@@ -3770,7 +3755,6 @@ mod tests {
             Conversation::deserialize(
                 conversation.read(cx).serialize(cx),
                 Default::default(),
-                Default::default(),
                 registry.clone(),
                 cx,
             )

crates/assistant/src/codegen.rs 🔗

@@ -1,5 +1,5 @@
 use crate::streaming_diff::{Hunk, StreamingDiff};
-use ai::completion::{CompletionProvider, OpenAIRequest};
+use ai::completion::{CompletionProvider, CompletionRequest};
 use anyhow::Result;
 use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
 use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
@@ -96,7 +96,7 @@ impl Codegen {
         self.error.as_ref()
     }
 
-    pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
+    pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
         let range = self.range();
         let snapshot = self.snapshot.clone();
         let selected_text = snapshot
@@ -336,17 +336,25 @@ fn strip_markdown_codeblock(
 #[cfg(test)]
 mod tests {
     use super::*;
-    use futures::{
-        future::BoxFuture,
-        stream::{self, BoxStream},
-    };
+    use ai::test::FakeCompletionProvider;
+    use futures::stream::{self};
     use gpui::{executor::Deterministic, TestAppContext};
     use indoc::indoc;
     use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
-    use parking_lot::Mutex;
     use rand::prelude::*;
+    use serde::Serialize;
     use settings::SettingsStore;
-    use smol::future::FutureExt;
+
+    #[derive(Serialize)]
+    pub struct DummyCompletionRequest {
+        pub name: String,
+    }
+
+    impl CompletionRequest for DummyCompletionRequest {
+        fn data(&self) -> serde_json::Result<String> {
+            serde_json::to_string(self)
+        }
+    }
 
     #[gpui::test(iterations = 10)]
     async fn test_transform_autoindent(
@@ -372,7 +380,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
         });
-        let provider = Arc::new(TestCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),
@@ -381,7 +389,11 @@ mod tests {
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
             "       let mut x = 0;\n",
@@ -434,7 +446,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 6))
         });
-        let provider = Arc::new(TestCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),
@@ -443,7 +455,11 @@ mod tests {
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
             "t mut x = 0;\n",
@@ -496,7 +512,7 @@ mod tests {
             let snapshot = buffer.snapshot(cx);
             snapshot.anchor_before(Point::new(1, 2))
         });
-        let provider = Arc::new(TestCompletionProvider::new());
+        let provider = Arc::new(FakeCompletionProvider::new());
         let codegen = cx.add_model(|cx| {
             Codegen::new(
                 buffer.clone(),
@@ -505,7 +521,11 @@ mod tests {
                 cx,
             )
         });
-        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+        let request = Box::new(DummyCompletionRequest {
+            name: "test".to_string(),
+        });
+        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
         let mut new_text = concat!(
             "let mut x = 0;\n",
@@ -593,38 +613,6 @@ mod tests {
         }
     }
 
-    struct TestCompletionProvider {
-        last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
-    }
-
-    impl TestCompletionProvider {
-        fn new() -> Self {
-            Self {
-                last_completion_tx: Mutex::new(None),
-            }
-        }
-
-        fn send_completion(&self, completion: impl Into<String>) {
-            let mut tx = self.last_completion_tx.lock();
-            tx.as_mut().unwrap().try_send(completion.into()).unwrap();
-        }
-
-        fn finish_completion(&self) {
-            self.last_completion_tx.lock().take().unwrap();
-        }
-    }
-
-    impl CompletionProvider for TestCompletionProvider {
-        fn complete(
-            &self,
-            _prompt: OpenAIRequest,
-        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-            let (tx, rx) = mpsc::channel(1);
-            *self.last_completion_tx.lock() = Some(tx);
-            async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
-        }
-    }
-
     fn rust_lang() -> Language {
         Language::new(
             LanguageConfig {

crates/assistant/src/prompts.rs 🔗

@@ -1,9 +1,10 @@
-use ai::models::{LanguageModel, OpenAILanguageModel};
-use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
-use ai::templates::file_context::FileContext;
-use ai::templates::generate::GenerateInlineContent;
-use ai::templates::preamble::EngineerPreamble;
-use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::models::LanguageModel;
+use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::prompts::file_context::FileContext;
+use ai::prompts::generate::GenerateInlineContent;
+use ai::prompts::preamble::EngineerPreamble;
+use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::providers::open_ai::OpenAILanguageModel;
 use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
 use std::cmp::{self, Reverse};
 use std::ops::Range;

crates/semantic_index/Cargo.toml 🔗

@@ -42,6 +42,7 @@ sha1 = "0.10.5"
 ndarray = { version = "0.15.0" }
 
 [dev-dependencies]
+ai = { path = "../ai", features = ["test-support"] }
 collections = { path = "../collections", features = ["test-support"] }
 gpui = { path = "../gpui", features = ["test-support"] }
 language = { path = "../language", features = ["test-support"] }

crates/semantic_index/src/embedding_queue.rs 🔗

@@ -41,7 +41,6 @@ pub struct EmbeddingQueue {
     pending_batch_token_count: usize,
     finished_files_tx: channel::Sender<FileToEmbed>,
     finished_files_rx: channel::Receiver<FileToEmbed>,
-    api_key: Option<String>,
 }
 
 #[derive(Clone)]
@@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed {
 }
 
 impl EmbeddingQueue {
-    pub fn new(
-        embedding_provider: Arc<dyn EmbeddingProvider>,
-        executor: Arc<Background>,
-        api_key: Option<String>,
-    ) -> Self {
+    pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
         let (finished_files_tx, finished_files_rx) = channel::unbounded();
         Self {
             embedding_provider,
@@ -64,14 +59,9 @@ impl EmbeddingQueue {
             pending_batch_token_count: 0,
             finished_files_tx,
             finished_files_rx,
-            api_key,
         }
     }
 
-    pub fn set_api_key(&mut self, api_key: Option<String>) {
-        self.api_key = api_key
-    }
-
     pub fn push(&mut self, file: FileToEmbed) {
         if file.spans.is_empty() {
             self.finished_files_tx.try_send(file).unwrap();
@@ -118,7 +108,6 @@ impl EmbeddingQueue {
 
         let finished_files_tx = self.finished_files_tx.clone();
         let embedding_provider = self.embedding_provider.clone();
-        let api_key = self.api_key.clone();
 
         self.executor
             .spawn(async move {
@@ -143,7 +132,7 @@ impl EmbeddingQueue {
                     return;
                 };
 
-                match embedding_provider.embed_batch(spans, api_key).await {
+                match embedding_provider.embed_batch(spans).await {
                     Ok(embeddings) => {
                         let mut embeddings = embeddings.into_iter();
                         for fragment in batch {
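
With credentials owned by the provider, `EmbeddingQueue::new` now takes just a provider and an executor. A construction sketch; the `embedding_queue` module path and the `test-support` import are assumptions:

```rust
use std::sync::Arc;

use ai::test::FakeEmbeddingProvider;
use gpui::executor::Background;
use semantic_index::embedding_queue::EmbeddingQueue;

fn build_queue(executor: Arc<Background>) -> EmbeddingQueue {
    // Credentials live on the provider itself, so no api_key is threaded
    // through the queue anymore.
    EmbeddingQueue::new(Arc::new(FakeEmbeddingProvider::default()), executor)
}
```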

crates/semantic_index/src/parsing.rs 🔗

@@ -1,4 +1,7 @@
-use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::{
+    embedding::{Embedding, EmbeddingProvider},
+    models::TruncationDirection,
+};
 use anyhow::{anyhow, Result};
 use language::{Grammar, Language};
 use rusqlite::{
@@ -108,7 +111,14 @@ impl CodeContextRetriever {
             .replace("<language>", language_name.as_ref())
             .replace("<item>", &content);
         let digest = SpanDigest::from(document_span.as_str());
-        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
+        let model = self.embedding_provider.base_model();
+        let document_span = model.truncate(
+            &document_span,
+            model.capacity()?,
+            ai::models::TruncationDirection::End,
+        )?;
+        let token_count = model.count_tokens(&document_span)?;
+
         Ok(vec![Span {
             range: 0..content.len(),
             content: document_span,
@@ -131,7 +141,15 @@ impl CodeContextRetriever {
             )
             .replace("<item>", &content);
         let digest = SpanDigest::from(document_span.as_str());
-        let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
+
+        let model = self.embedding_provider.base_model();
+        let document_span = model.truncate(
+            &document_span,
+            model.capacity()?,
+            ai::models::TruncationDirection::End,
+        )?;
+        let token_count = model.count_tokens(&document_span)?;
+
         Ok(vec![Span {
             range: 0..content.len(),
             content: document_span,
@@ -222,8 +240,13 @@ impl CodeContextRetriever {
                 .replace("<language>", language_name.as_ref())
                 .replace("item", &span.content);
 
-            let (document_content, token_count) =
-                self.embedding_provider.truncate(&document_content);
+            let model = self.embedding_provider.base_model();
+            let document_content = model.truncate(
+                &document_content,
+                model.capacity()?,
+                TruncationDirection::End,
+            )?;
+            let token_count = model.count_tokens(&document_content)?;
 
             span.content = document_content;
             span.token_count = token_count;
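
Each of the sites above now follows the same pattern: fetch the provider's base model, clamp the span to the model's context window, then count tokens on the clamped text. Factored out as a sketch:

```rust
use ai::embedding::EmbeddingProvider;
use ai::models::TruncationDirection;

fn prepare_span(
    provider: &dyn EmbeddingProvider,
    document_span: &str,
) -> anyhow::Result<(String, usize)> {
    let model = provider.base_model();
    // Clamp the span to the model's context window, keeping the prefix.
    let content =
        model.truncate(document_span, model.capacity()?, TruncationDirection::End)?;
    let token_count = model.count_tokens(&content)?;
    Ok((content, token_count))
}
```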

crates/semantic_index/src/semantic_index.rs 🔗

@@ -7,7 +7,8 @@ pub mod semantic_index_settings;
 mod semantic_index_tests;
 
 use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};
 use collections::{BTreeMap, HashMap, HashSet};
 use db::VectorDatabase;
@@ -88,7 +89,7 @@ pub fn init(
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,
-            Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+            Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
             language_registry,
             cx.clone(),
         )
@@ -123,8 +124,6 @@ pub struct SemanticIndex {
     _embedding_task: Task<()>,
     _parsing_files_tasks: Vec<Task<()>>,
     projects: HashMap<WeakModelHandle<Project>, ProjectState>,
-    api_key: Option<String>,
-    embedding_queue: Arc<Mutex<EmbeddingQueue>>,
 }
 
 struct ProjectState {
@@ -278,18 +277,18 @@ impl SemanticIndex {
         }
     }
 
-    pub fn authenticate(&mut self, cx: &AppContext) {
-        if self.api_key.is_none() {
-            self.api_key = self.embedding_provider.retrieve_credentials(cx);
-
-            self.embedding_queue
-                .lock()
-                .set_api_key(self.api_key.clone());
+    pub fn authenticate(&mut self, cx: &AppContext) -> bool {
+        if !self.embedding_provider.has_credentials() {
+            self.embedding_provider.retrieve_credentials(cx);
+        } else {
+            return true;
         }
+
+        self.embedding_provider.has_credentials()
     }
 
     pub fn is_authenticated(&self) -> bool {
-        self.api_key.is_some()
+        self.embedding_provider.has_credentials()
     }
 
     pub fn enabled(cx: &AppContext) -> bool {
@@ -339,7 +338,7 @@ impl SemanticIndex {
         Ok(cx.add_model(|cx| {
             let t0 = Instant::now();
             let embedding_queue =
-                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
+                EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
             let _embedding_task = cx.background().spawn({
                 let embedded_files = embedding_queue.finished_files();
                 let db = db.clone();
@@ -404,8 +403,6 @@ impl SemanticIndex {
                 _embedding_task,
                 _parsing_files_tasks,
                 projects: Default::default(),
-                api_key: None,
-                embedding_queue
             }
         }))
     }
@@ -720,13 +717,13 @@ impl SemanticIndex {
 
         let index = self.index_project(project.clone(), cx);
         let embedding_provider = self.embedding_provider.clone();
-        let api_key = self.api_key.clone();
 
         cx.spawn(|this, mut cx| async move {
             index.await?;
             let t0 = Instant::now();
+
             let query = embedding_provider
-                .embed_batch(vec![query], api_key)
+                .embed_batch(vec![query])
                 .await?
                 .pop()
                 .ok_or_else(|| anyhow!("could not embed query"))?;
@@ -944,7 +941,6 @@ impl SemanticIndex {
         let fs = self.fs.clone();
         let db_path = self.db.path().clone();
         let background = cx.background().clone();
-        let api_key = self.api_key.clone();
         cx.background().spawn(async move {
             let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
             let mut results = Vec::<SearchResult>::new();
@@ -959,15 +955,10 @@ impl SemanticIndex {
                     .parse_file_with_template(None, &snapshot.text(), language)
                     .log_err()
                     .unwrap_or_default();
-                if Self::embed_spans(
-                    &mut spans,
-                    embedding_provider.as_ref(),
-                    &db,
-                    api_key.clone(),
-                )
-                .await
-                .log_err()
-                .is_some()
+                if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
+                    .await
+                    .log_err()
+                    .is_some()
                 {
                     for span in spans {
                         let similarity = span.embedding.unwrap().similarity(&query);
@@ -1007,9 +998,8 @@ impl SemanticIndex {
         project: ModelHandle<Project>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<()>> {
-        if self.api_key.is_none() {
-            self.authenticate(cx);
-            if self.api_key.is_none() {
+        if !self.is_authenticated() {
+            if !self.authenticate(cx) {
                 return Task::ready(Err(anyhow!("user is not authenticated")));
             }
         }
@@ -1192,7 +1182,6 @@ impl SemanticIndex {
         spans: &mut [Span],
         embedding_provider: &dyn EmbeddingProvider,
         db: &VectorDatabase,
-        api_key: Option<String>,
     ) -> Result<()> {
         let mut batch = Vec::new();
         let mut batch_tokens = 0;
@@ -1215,7 +1204,7 @@ impl SemanticIndex {
 
             if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
                 let batch_embeddings = embedding_provider
-                    .embed_batch(mem::take(&mut batch), api_key.clone())
+                    .embed_batch(mem::take(&mut batch))
                     .await?;
                 embeddings.extend(batch_embeddings);
                 batch_tokens = 0;
@@ -1227,7 +1216,7 @@ impl SemanticIndex {
 
         if !batch.is_empty() {
             let batch_embeddings = embedding_provider
-                .embed_batch(mem::take(&mut batch), api_key)
+                .embed_batch(mem::take(&mut batch))
                 .await?;
 
             embeddings.extend(batch_embeddings);
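
`embed_spans` batches by token budget: spans accumulate until the next span would push the batch past `max_tokens_per_batch`, at which point the pending batch is flushed as one `embed_batch` call. The grouping logic in isolation, as a sketch (the real code also skips spans whose digests are already cached in the database):

```rust
fn batches(spans: &[(String, usize)], max_tokens: usize) -> Vec<Vec<String>> {
    let mut batches = Vec::new();
    let mut batch: Vec<String> = Vec::new();
    let mut batch_tokens = 0;
    for (content, token_count) in spans {
        // Flush before this span would push the batch over the budget.
        if batch_tokens + token_count > max_tokens && !batch.is_empty() {
            batches.push(std::mem::take(&mut batch));
            batch_tokens = 0;
        }
        batch.push(content.clone());
        batch_tokens += token_count;
    }
    if !batch.is_empty() {
        batches.push(batch);
    }
    batches
}
```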

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -4,10 +4,9 @@ use crate::{
     semantic_index_settings::SemanticIndexSettings,
     FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
 };
-use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
-use anyhow::Result;
-use async_trait::async_trait;
-use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
+use ai::test::FakeEmbeddingProvider;
+
+use gpui::{executor::Deterministic, Task, TestAppContext};
 use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
 use parking_lot::Mutex;
 use pretty_assertions::assert_eq;
@@ -15,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs
 use rand::{rngs::StdRng, Rng};
 use serde_json::json;
 use settings::SettingsStore;
-use std::{
-    path::Path,
-    sync::{
-        atomic::{self, AtomicUsize},
-        Arc,
-    },
-    time::{Instant, SystemTime},
-};
+use std::{path::Path, sync::Arc, time::SystemTime};
 use unindent::Unindent;
 use util::RandomCharIter;
 
@@ -228,7 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
 
     let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
 
-    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
+    let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
     for file in &files {
         queue.push(file.clone());
     }
@@ -280,7 +272,7 @@ fn assert_search_results(
 #[gpui::test]
 async fn test_code_context_retrieval_rust() {
     let language = rust_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -382,7 +374,7 @@ async fn test_code_context_retrieval_rust() {
 #[gpui::test]
 async fn test_code_context_retrieval_json() {
     let language = json_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -466,7 +458,7 @@ fn assert_documents_eq(
 #[gpui::test]
 async fn test_code_context_retrieval_javascript() {
     let language = js_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -565,7 +557,7 @@ async fn test_code_context_retrieval_javascript() {
 #[gpui::test]
 async fn test_code_context_retrieval_lua() {
     let language = lua_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -639,7 +631,7 @@ async fn test_code_context_retrieval_lua() {
 #[gpui::test]
 async fn test_code_context_retrieval_elixir() {
     let language = elixir_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -756,7 +748,7 @@ async fn test_code_context_retrieval_elixir() {
 #[gpui::test]
 async fn test_code_context_retrieval_cpp() {
     let language = cpp_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = "
@@ -909,7 +901,7 @@ async fn test_code_context_retrieval_cpp() {
 #[gpui::test]
 async fn test_code_context_retrieval_ruby() {
     let language = ruby_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -1100,7 +1092,7 @@ async fn test_code_context_retrieval_ruby() {
 #[gpui::test]
 async fn test_code_context_retrieval_php() {
     let language = php_lang();
-    let embedding_provider = Arc::new(DummyEmbeddings {});
+    let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
     let mut retriever = CodeContextRetriever::new(embedding_provider);
 
     let text = r#"
@@ -1248,65 +1240,6 @@ async fn test_code_context_retrieval_php() {
     );
 }
 
-#[derive(Default)]
-struct FakeEmbeddingProvider {
-    embedding_count: AtomicUsize,
-}
-
-impl FakeEmbeddingProvider {
-    fn embedding_count(&self) -> usize {
-        self.embedding_count.load(atomic::Ordering::SeqCst)
-    }
-
-    fn embed_sync(&self, span: &str) -> Embedding {
-        let mut result = vec![1.0; 26];
-        for letter in span.chars() {
-            let letter = letter.to_ascii_lowercase();
-            if letter as u32 >= 'a' as u32 {
-                let ix = (letter as u32) - ('a' as u32);
-                if ix < 26 {
-                    result[ix as usize] += 1.0;
-                }
-            }
-        }
-
-        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
-        for x in &mut result {
-            *x /= norm;
-        }
-
-        result.into()
-    }
-}
-
-#[async_trait]
-impl EmbeddingProvider for FakeEmbeddingProvider {
-    fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
-        Some("Fake Credentials".to_string())
-    }
-    fn truncate(&self, span: &str) -> (String, usize) {
-        (span.to_string(), 1)
-    }
-
-    fn max_tokens_per_batch(&self) -> usize {
-        200
-    }
-
-    fn rate_limit_expiration(&self) -> Option<Instant> {
-        None
-    }
-
-    async fn embed_batch(
-        &self,
-        spans: Vec<String>,
-        _api_key: Option<String>,
-    ) -> Result<Vec<Embedding>> {
-        self.embedding_count
-            .fetch_add(spans.len(), atomic::Ordering::SeqCst);
-        Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
-    }
-}
-
 fn js_lang() -> Arc<Language> {
     Arc::new(
         Language::new(
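
The inline `FakeEmbeddingProvider` deleted above moves into the shared `ai` crate (imported earlier in this file's diff as `ai::test::FakeEmbeddingProvider`), so each test file no longer carries its own copy. Based on the removed code, a sketch of what the relocated provider plausibly looks like under the new `embed_batch` signature; the actual contents of `ai::test` are not shown in this diff, and any credential/auth methods on the ported trait are omitted here:

```rust
use std::{
    sync::atomic::{AtomicUsize, Ordering},
    time::Instant,
};

use ai::embedding::{Embedding, EmbeddingProvider};
use anyhow::Result;
use async_trait::async_trait;

// Deterministic fake: embeds a span as a normalized histogram of its
// ASCII letter frequencies, so similar text yields similar vectors.
#[derive(Default)]
pub struct FakeEmbeddingProvider {
    embedding_count: AtomicUsize,
}

impl FakeEmbeddingProvider {
    pub fn embedding_count(&self) -> usize {
        self.embedding_count.load(Ordering::SeqCst)
    }

    fn embed_sync(&self, span: &str) -> Embedding {
        let mut result = vec![1.0; 26];
        for letter in span.chars() {
            let letter = letter.to_ascii_lowercase();
            if letter as u32 >= 'a' as u32 {
                let ix = (letter as u32) - ('a' as u32);
                if ix < 26 {
                    result[ix as usize] += 1.0;
                }
            }
        }

        // L2-normalize so cosine similarity reduces to a dot product.
        let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut result {
            *x /= norm;
        }

        result.into()
    }
}

#[async_trait]
impl EmbeddingProvider for FakeEmbeddingProvider {
    fn truncate(&self, span: &str) -> (String, usize) {
        (span.to_string(), 1)
    }

    fn max_tokens_per_batch(&self) -> usize {
        200
    }

    fn rate_limit_expiration(&self) -> Option<Instant> {
        None
    }

    // No `api_key` argument anymore; the counter lets tests assert how
    // many spans were actually embedded.
    async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
        self.embedding_count
            .fetch_add(spans.len(), Ordering::SeqCst);
        Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
    }
}
```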

crates/ui2/src/elements/icon.rs 🔗

@@ -36,7 +36,7 @@ impl IconColor {
             IconColor::Error => gpui2::red(),
             IconColor::Warning => gpui2::red(),
             IconColor::Success => gpui2::red(),
-            IconColor::Info => gpui2::red()
+            IconColor::Info => gpui2::red(),
         }
     }
 }

crates/zed/examples/semantic_index_eval.rs 🔗

@@ -1,4 +1,4 @@
-use ai::embedding::OpenAIEmbeddings;
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
 use anyhow::{anyhow, Result};
 use client::{self, UserStore};
 use gpui::{AsyncAppContext, ModelHandle, Task};
@@ -475,7 +475,7 @@ fn main() {
             let semantic_index = SemanticIndex::new(
                 fs.clone(),
                 db_file_path,
-                Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+                Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
                 languages.clone(),
                 cx.clone(),
             )
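
Same provider, new address: `OpenAIEmbeddings` is renamed to `OpenAIEmbeddingProvider` and moves under a provider-scoped `providers::open_ai` module. A small usage sketch of the updated call site; `make_provider` is a hypothetical helper, and the `HttpClient`/`Background` parameter types are assumptions based on what `cx.background()` and the example's `http_client` typically carry in this codebase:

```rust
use std::sync::Arc;

// Old path (pre-port): use ai::embedding::OpenAIEmbeddings;
use ai::providers::open_ai::OpenAIEmbeddingProvider;

// Hypothetical helper just to show the construction; the Arc lets
// SemanticIndex share one provider across background tasks.
fn make_provider(
    http_client: Arc<dyn util::http::HttpClient>,
    background: Arc<gpui::executor::Background>,
) -> Arc<OpenAIEmbeddingProvider> {
    Arc::new(OpenAIEmbeddingProvider::new(http_client, background))
}
```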

crates/zed2/Cargo.toml 🔗

@@ -15,6 +15,7 @@ name = "Zed"
 path = "src/main.rs"
 
 [dependencies]
+ai2 = { path = "../ai2" }
 # audio = { path = "../audio" }
 # activity_indicator = { path = "../activity_indicator" }
 # auto_update = { path = "../auto_update" }