Detailed changes
@@ -108,6 +108,33 @@ dependencies = [
"util",
]
+[[package]]
+name = "ai2"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "async-trait",
+ "bincode",
+ "futures 0.3.28",
+ "gpui2",
+ "isahc",
+ "language2",
+ "lazy_static",
+ "log",
+ "matrixmultiply",
+ "ordered-float 2.10.0",
+ "parking_lot 0.11.2",
+ "parse_duration",
+ "postage",
+ "rand 0.8.5",
+ "regex",
+ "rusqlite",
+ "serde",
+ "serde_json",
+ "tiktoken-rs",
+ "util",
+]
+
[[package]]
name = "alacritty_config"
version = "0.1.2-dev"
@@ -10903,6 +10930,7 @@ dependencies = [
name = "zed2"
version = "0.109.0"
dependencies = [
+ "ai2",
"anyhow",
"async-compression",
"async-recursion 0.3.2",
@@ -0,0 +1,38 @@
+[package]
+name = "ai"
+version = "0.1.0"
+edition = "2021"
+publish = false
+
+[lib]
+path = "src/ai.rs"
+doctest = false
+
+[features]
+test-support = []
+
+[dependencies]
+gpui = { path = "../gpui" }
+util = { path = "../util" }
+language = { path = "../language" }
+async-trait.workspace = true
+anyhow.workspace = true
+futures.workspace = true
+lazy_static.workspace = true
+ordered-float.workspace = true
+parking_lot.workspace = true
+isahc.workspace = true
+regex.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+postage.workspace = true
+rand.workspace = true
+log.workspace = true
+parse_duration = "2.1.1"
+tiktoken-rs = "0.5.0"
+matrixmultiply = "0.3.7"
+rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
+bincode = "1.3.3"
+
+[dev-dependencies]
+gpui = { path = "../gpui", features = ["test-support"] }
@@ -0,0 +1,38 @@
+[package]
+name = "ai2"
+version = "0.1.0"
+edition = "2021"
+publish = false
+
+[lib]
+path = "src/ai2.rs"
+doctest = false
+
+[features]
+test-support = []
+
+[dependencies]
+gpui2 = { path = "../gpui2" }
+util = { path = "../util" }
+language2 = { path = "../language2" }
+async-trait.workspace = true
+anyhow.workspace = true
+futures.workspace = true
+lazy_static.workspace = true
+ordered-float.workspace = true
+parking_lot.workspace = true
+isahc.workspace = true
+regex.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+postage.workspace = true
+rand.workspace = true
+log.workspace = true
+parse_duration = "2.1.1"
+tiktoken-rs = "0.5.0"
+matrixmultiply = "0.3.7"
+rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
+bincode = "1.3.3"
+
+[dev-dependencies]
+gpui2 = { path = "../gpui2", features = ["test-support"] }
@@ -0,0 +1,8 @@
+pub mod auth;
+pub mod completion;
+pub mod embedding;
+pub mod models;
+pub mod prompts;
+pub mod providers;
+#[cfg(any(test, feature = "test-support"))]
+pub mod test;
@@ -0,0 +1,17 @@
+use async_trait::async_trait;
+use gpui2::AppContext;
+
+#[derive(Clone, Debug)]
+pub enum ProviderCredential {
+ Credentials { api_key: String },
+ NoCredentials,
+ NotNeeded,
+}
+
+#[async_trait]
+pub trait CredentialProvider: Send + Sync {
+ fn has_credentials(&self) -> bool;
+ async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential;
+ async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential);
+ async fn delete_credentials(&self, cx: &mut AppContext);
+}
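The trait above is implemented by every provider; as a minimal sketch (illustrative only, not part of the diff), a provider that needs no credentials can satisfy it with no-op implementations, mirroring the fake providers added later in this change. `LocalProvider` is a hypothetical name:

use async_trait::async_trait;
use gpui2::AppContext;

// Hypothetical provider used only for illustration.
struct LocalProvider;

#[async_trait]
impl CredentialProvider for LocalProvider {
    fn has_credentials(&self) -> bool {
        true
    }
    async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
        ProviderCredential::NotNeeded
    }
    async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
    async fn delete_credentials(&self, _cx: &mut AppContext) {}
}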
@@ -0,0 +1,23 @@
+use anyhow::Result;
+use futures::{future::BoxFuture, stream::BoxStream};
+
+use crate::{auth::CredentialProvider, models::LanguageModel};
+
+pub trait CompletionRequest: Send + Sync {
+ fn data(&self) -> serde_json::Result<String>;
+}
+
+pub trait CompletionProvider: CredentialProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel>;
+ fn complete(
+ &self,
+ prompt: Box<dyn CompletionRequest>,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+ fn box_clone(&self) -> Box<dyn CompletionProvider>;
+}
+
+impl Clone for Box<dyn CompletionProvider> {
+ fn clone(&self) -> Box<dyn CompletionProvider> {
+ self.box_clone()
+ }
+}
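A minimal sketch of how a caller might drive this interface (illustrative, not part of the diff); `provider` and `request` are assumed to be constructed elsewhere:

use futures::StreamExt;

async fn collect_completion(
    provider: Box<dyn CompletionProvider>,
    request: Box<dyn CompletionRequest>,
) -> anyhow::Result<String> {
    // `complete` resolves to a stream of incremental text chunks.
    let mut chunks = provider.complete(request).await?;
    let mut output = String::new();
    while let Some(chunk) = chunks.next().await {
        output.push_str(&chunk?);
    }
    Ok(output)
}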
@@ -0,0 +1,123 @@
+use std::time::Instant;
+
+use anyhow::Result;
+use async_trait::async_trait;
+use ordered_float::OrderedFloat;
+use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
+use rusqlite::ToSql;
+
+use crate::auth::CredentialProvider;
+use crate::models::LanguageModel;
+
+#[derive(Debug, PartialEq, Clone)]
+pub struct Embedding(pub Vec<f32>);
+
+// This is needed for semantic index functionality.
+// Unfortunately it has to live wherever the "Embedding" struct is defined.
+// Keeping it here, though, introduces a 'rusqlite' dependency into the AI crate,
+// which is less than ideal.
+impl FromSql for Embedding {
+ fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+ let bytes = value.as_blob()?;
+ let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
+ if embedding.is_err() {
+ return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
+ }
+ Ok(Embedding(embedding.unwrap()))
+ }
+}
+
+impl ToSql for Embedding {
+ fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+ let bytes = bincode::serialize(&self.0)
+ .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
+ Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
+ }
+}
+impl From<Vec<f32>> for Embedding {
+ fn from(value: Vec<f32>) -> Self {
+ Embedding(value)
+ }
+}
+
+impl Embedding {
+ pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
+ let len = self.0.len();
+ assert_eq!(len, other.0.len());
+
+ let mut result = 0.0;
+ unsafe {
+ matrixmultiply::sgemm(
+ 1,
+ len,
+ 1,
+ 1.0,
+ self.0.as_ptr(),
+ len as isize,
+ 1,
+ other.0.as_ptr(),
+ 1,
+ len as isize,
+ 0.0,
+ &mut result as *mut f32,
+ 1,
+ 1,
+ );
+ }
+ OrderedFloat(result)
+ }
+}
+
+#[async_trait]
+pub trait EmbeddingProvider: CredentialProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel>;
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
+ fn max_tokens_per_batch(&self) -> usize;
+ fn rate_limit_expiration(&self) -> Option<Instant>;
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use rand::prelude::*;
+
+ #[gpui2::test]
+ fn test_similarity(mut rng: StdRng) {
+ assert_eq!(
+ Embedding::from(vec![1., 0., 0., 0., 0.])
+ .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
+ 0.
+ );
+ assert_eq!(
+ Embedding::from(vec![2., 0., 0., 0., 0.])
+ .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
+ 6.
+ );
+
+ for _ in 0..100 {
+ let size = 1536;
+ let mut a = vec![0.; size];
+ let mut b = vec![0.; size];
+ for (a, b) in a.iter_mut().zip(b.iter_mut()) {
+ *a = rng.gen();
+ *b = rng.gen();
+ }
+ let a = Embedding::from(a);
+ let b = Embedding::from(b);
+
+ assert_eq!(
+ round_to_decimals(a.similarity(&b), 1),
+ round_to_decimals(reference_dot(&a.0, &b.0), 1)
+ );
+ }
+
+ fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
+ let factor = (10.0 as f32).powi(decimal_places);
+ (n * factor).round() / factor
+ }
+
+ fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
+ OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
+ }
+ }
+}
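Note that `Embedding::similarity` above is a raw dot product (a 1×len by len×1 sgemm), so it equals cosine similarity only when the vectors are unit-normalized, as OpenAI's ada-002 embeddings are documented to be and as the fake test provider's are. A reference sketch for unnormalized vectors (illustrative, not part of the diff):

fn cosine_similarity(a: &Embedding, b: &Embedding) -> f32 {
    // Normalize the plain dot product by both vector norms.
    let dot: f32 = a.0.iter().zip(b.0.iter()).map(|(x, y)| x * y).sum();
    let norm_a = a.0.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.0.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}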
@@ -0,0 +1,16 @@
+pub enum TruncationDirection {
+ Start,
+ End,
+}
+
+pub trait LanguageModel {
+ fn name(&self) -> String;
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String>;
+ fn capacity(&self) -> anyhow::Result<usize>;
+}
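A small usage sketch of this trait (illustrative, not part of the diff), showing how a prompt template might fit content into a model's context window; `reserved` is a hypothetical number of tokens set aside for the reply:

use crate::models::{LanguageModel, TruncationDirection};

fn fit_to_budget(
    model: &dyn LanguageModel,
    content: &str,
    reserved: usize,
) -> anyhow::Result<String> {
    let budget = model.capacity()?.saturating_sub(reserved);
    if model.count_tokens(content)? <= budget {
        Ok(content.to_string())
    } else {
        // TruncationDirection::End keeps the beginning of the content.
        model.truncate(content, budget, TruncationDirection::End)
    }
}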
@@ -0,0 +1,330 @@
+use std::cmp::Reverse;
+use std::ops::Range;
+use std::sync::Arc;
+
+use language2::BufferSnapshot;
+use util::ResultExt;
+
+use crate::models::LanguageModel;
+use crate::prompts::repository_context::PromptCodeSnippet;
+
+pub(crate) enum PromptFileType {
+ Text,
+ Code,
+}
+
+// TODO: Set this up to manage defaults well
+pub struct PromptArguments {
+ pub model: Arc<dyn LanguageModel>,
+ pub user_prompt: Option<String>,
+ pub language_name: Option<String>,
+ pub project_name: Option<String>,
+ pub snippets: Vec<PromptCodeSnippet>,
+ pub reserved_tokens: usize,
+ pub buffer: Option<BufferSnapshot>,
+ pub selected_range: Option<Range<usize>>,
+}
+
+impl PromptArguments {
+ pub(crate) fn get_file_type(&self) -> PromptFileType {
+ if self
+ .language_name
+ .as_ref()
+ .map(|name| !["Markdown", "Plain Text"].contains(&name.as_str()))
+ .unwrap_or(true)
+ {
+ PromptFileType::Code
+ } else {
+ PromptFileType::Text
+ }
+ }
+}
+
+pub trait PromptTemplate {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)>;
+}
+
+#[repr(i8)]
+#[derive(PartialEq, Eq, Ord)]
+pub enum PromptPriority {
+ Mandatory, // Ignores truncation
+ Ordered { order: usize }, // Truncates based on priority
+}
+
+impl PartialOrd for PromptPriority {
+ fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+ match (self, other) {
+ (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
+ (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
+ (Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
+ (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
+ }
+ }
+}
+
+pub struct PromptChain {
+ args: PromptArguments,
+ templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+}
+
+impl PromptChain {
+ pub fn new(
+ args: PromptArguments,
+ templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+ ) -> Self {
+ PromptChain { args, templates }
+ }
+
+ pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
+ // Argsort based on Prompt Priority
+ let separator = "\n";
+ let separator_tokens = self.args.model.count_tokens(separator)?;
+ let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
+ sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
+
+ // If Truncate
+ let mut tokens_outstanding = if truncate {
+ Some(self.args.model.capacity()? - self.args.reserved_tokens)
+ } else {
+ None
+ };
+
+ let mut prompts = vec!["".to_string(); sorted_indices.len()];
+ for idx in sorted_indices {
+ let (_, template) = &self.templates[idx];
+
+ if let Some((template_prompt, prompt_token_count)) =
+ template.generate(&self.args, tokens_outstanding).log_err()
+ {
+ if !template_prompt.is_empty() {
+ prompts[idx] = template_prompt;
+
+ if let Some(remaining_tokens) = tokens_outstanding {
+ let new_tokens = prompt_token_count + separator_tokens;
+ tokens_outstanding = if remaining_tokens > new_tokens {
+ Some(remaining_tokens - new_tokens)
+ } else {
+ Some(0)
+ };
+ }
+ }
+ }
+ }
+
+ prompts.retain(|x| !x.is_empty());
+
+ let full_prompt = prompts.join(separator);
+ let total_token_count = self.args.model.count_tokens(&full_prompt)?;
+ anyhow::Ok((full_prompt, total_token_count))
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+ use crate::models::TruncationDirection;
+ use crate::test::FakeLanguageModel;
+
+ use super::*;
+
+ #[test]
+ pub fn test_prompt_chain() {
+ struct TestPromptTemplate {}
+ impl PromptTemplate for TestPromptTemplate {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let mut content = "This is a test prompt template".to_string();
+
+ let mut token_count = args.model.count_tokens(&content)?;
+ if let Some(max_token_length) = max_token_length {
+ if token_count > max_token_length {
+ content = args.model.truncate(
+ &content,
+ max_token_length,
+ TruncationDirection::End,
+ )?;
+ token_count = max_token_length;
+ }
+ }
+
+ anyhow::Ok((content, token_count))
+ }
+ }
+
+ struct TestLowPriorityTemplate {}
+ impl PromptTemplate for TestLowPriorityTemplate {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let mut content = "This is a low priority test prompt template".to_string();
+
+ let mut token_count = args.model.count_tokens(&content)?;
+ if let Some(max_token_length) = max_token_length {
+ if token_count > max_token_length {
+ content = args.model.truncate(
+ &content,
+ max_token_length,
+ TruncationDirection::End,
+ )?;
+ token_count = max_token_length;
+ }
+ }
+
+ anyhow::Ok((content, token_count))
+ }
+ }
+
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens: 0,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(false).unwrap();
+
+ assert_eq!(
+ prompt,
+ "This is a test prompt template\nThis is a low priority test prompt template"
+ .to_string()
+ );
+
+ assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
+
+ // Testing with Truncation Off
+ // Should ignore capacity and return all prompts
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens: 0,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(false).unwrap();
+
+ assert_eq!(
+ prompt,
+ "This is a test prompt template\nThis is a low priority test prompt template"
+ .to_string()
+ );
+
+ assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
+
+ // Testing with Truncation On
+ // Should truncate the combined prompt to fit within capacity
+ let capacity = 20;
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens: 0,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 2 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(true).unwrap();
+
+ assert_eq!(prompt, "This is a test promp".to_string());
+ assert_eq!(token_count, capacity);
+
+ // Change Ordering of Prompts Based on Priority
+ let capacity = 120;
+ let reserved_tokens = 10;
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Mandatory,
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(true).unwrap();
+
+ assert_eq!(
+ prompt,
+ "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
+ .to_string()
+ );
+ assert_eq!(token_count, capacity - reserved_tokens);
+ }
+}
@@ -0,0 +1,164 @@
+use anyhow::anyhow;
+use language2::BufferSnapshot;
+use language2::ToOffset;
+
+use crate::models::LanguageModel;
+use crate::models::TruncationDirection;
+use crate::prompts::base::PromptArguments;
+use crate::prompts::base::PromptTemplate;
+use std::fmt::Write;
+use std::ops::Range;
+use std::sync::Arc;
+
+fn retrieve_context(
+ buffer: &BufferSnapshot,
+ selected_range: &Option<Range<usize>>,
+ model: Arc<dyn LanguageModel>,
+ max_token_count: Option<usize>,
+) -> anyhow::Result<(String, usize, bool)> {
+ let mut prompt = String::new();
+ let mut truncated = false;
+ if let Some(selected_range) = selected_range {
+ let start = selected_range.start.to_offset(buffer);
+ let end = selected_range.end.to_offset(buffer);
+
+ let start_window = buffer.text_for_range(0..start).collect::<String>();
+
+ let mut selected_window = String::new();
+ if start == end {
+ write!(selected_window, "<|START|>").unwrap();
+ } else {
+ write!(selected_window, "<|START|").unwrap();
+ }
+
+ write!(
+ selected_window,
+ "{}",
+ buffer.text_for_range(start..end).collect::<String>()
+ )
+ .unwrap();
+
+ if start != end {
+ write!(selected_window, "|END|>").unwrap();
+ }
+
+ let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
+
+ if let Some(max_token_count) = max_token_count {
+ let selected_tokens = model.count_tokens(&selected_window)?;
+ if selected_tokens > max_token_count {
+ return Err(anyhow!(
+ "selected range is greater than model context window, truncation not possible"
+ ));
+ };
+
+ let mut remaining_tokens = max_token_count - selected_tokens;
+ let start_window_tokens = model.count_tokens(&start_window)?;
+ let end_window_tokens = model.count_tokens(&end_window)?;
+ let outside_tokens = start_window_tokens + end_window_tokens;
+ if outside_tokens > remaining_tokens {
+ let (start_goal_tokens, end_goal_tokens) =
+ if start_window_tokens < end_window_tokens {
+ let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
+ remaining_tokens -= start_goal_tokens;
+ let end_goal_tokens = remaining_tokens.min(end_window_tokens);
+ (start_goal_tokens, end_goal_tokens)
+ } else {
+ let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
+ remaining_tokens -= end_goal_tokens;
+ let start_goal_tokens = remaining_tokens.min(start_window_tokens);
+ (start_goal_tokens, end_goal_tokens)
+ };
+
+ let truncated_start_window =
+ model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
+ let truncated_end_window =
+ model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
+ writeln!(
+ prompt,
+ "{truncated_start_window}{selected_window}{truncated_end_window}"
+ )
+ .unwrap();
+ truncated = true;
+ } else {
+ writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
+ }
+ } else {
+ // If we don't have a selected range, include the entire file.
+ writeln!(prompt, "{}", &buffer.text()).unwrap();
+
+ // Dumb truncation strategy
+ if let Some(max_token_count) = max_token_count {
+ if model.count_tokens(&prompt)? > max_token_count {
+ truncated = true;
+ prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
+ }
+ }
+ }
+ }
+
+ let token_count = model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, token_count, truncated))
+}
+
+pub struct FileContext {}
+
+impl PromptTemplate for FileContext {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ if let Some(buffer) = &args.buffer {
+ let mut prompt = String::new();
+ // Add Initial Preamble
+ // TODO: Do we want to add the path in here?
+ writeln!(
+ prompt,
+ "The file you are currently working on has the following content:"
+ )
+ .unwrap();
+
+ let language_name = args
+ .language_name
+ .clone()
+ .unwrap_or("".to_string())
+ .to_lowercase();
+
+ let (context, _, truncated) = retrieve_context(
+ buffer,
+ &args.selected_range,
+ args.model.clone(),
+ max_token_length,
+ )?;
+ writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
+
+ if truncated {
+ writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
+ }
+
+ if let Some(selected_range) = &args.selected_range {
+ let start = selected_range.start.to_offset(buffer);
+ let end = selected_range.end.to_offset(buffer);
+
+ if start == end {
+ writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
+ } else {
+ writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
+ }
+ }
+
+ // Really dumb truncation strategy
+ if let Some(max_tokens) = max_token_length {
+ prompt = args
+ .model
+ .truncate(&prompt, max_tokens, TruncationDirection::End)?;
+ }
+
+ let token_count = args.model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, token_count))
+ } else {
+ Err(anyhow!("no buffer provided to retrieve file context from"))
+ }
+ }
+}
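To illustrate the budget split in retrieve_context (numbers are illustrative): with max_token_count = Some(100) and a 40-token selection, 60 tokens remain for the surrounding text; if the window before the selection is 50 tokens and the window after is 30, the smaller side is capped at half the remainder first, so the end window is allotted min(30, 30) = 30 tokens and the start window min(30, 50) = 30 tokens, and each window is truncated toward the selection before the three pieces are stitched back together.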
@@ -0,0 +1,99 @@
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
+use anyhow::anyhow;
+use std::fmt::Write;
+
+pub fn capitalize(s: &str) -> String {
+ let mut c = s.chars();
+ match c.next() {
+ None => String::new(),
+ Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
+ }
+}
+
+pub struct GenerateInlineContent {}
+
+impl PromptTemplate for GenerateInlineContent {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let Some(user_prompt) = &args.user_prompt else {
+ return Err(anyhow!("user prompt not provided"));
+ };
+
+ let file_type = args.get_file_type();
+ let content_type = match &file_type {
+ PromptFileType::Code => "code",
+ PromptFileType::Text => "text",
+ };
+
+ let mut prompt = String::new();
+
+ if let Some(selected_range) = &args.selected_range {
+ if selected_range.start == selected_range.end {
+ writeln!(
+ prompt,
+ "Assume the cursor is located where the `<|START|>` span is."
+ )
+ .unwrap();
+ writeln!(
+ prompt,
+ "{} can't be replaced, so assume your answer will be inserted at the cursor.",
+ capitalize(content_type)
+ )
+ .unwrap();
+ writeln!(
+ prompt,
+ "Generate {content_type} based on the users prompt: {user_prompt}",
+ )
+ .unwrap();
+ } else {
+ writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
+ writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
+ writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
+ }
+ } else {
+ writeln!(
+ prompt,
+ "Generate {content_type} based on the users prompt: {user_prompt}"
+ )
+ .unwrap();
+ }
+
+ if let Some(language_name) = &args.language_name {
+ writeln!(
+ prompt,
+ "Your answer MUST always and only be valid {}.",
+ language_name
+ )
+ .unwrap();
+ }
+ writeln!(prompt, "Never make remarks about the output.").unwrap();
+ writeln!(
+ prompt,
+ "Do not return anything else, except the generated {content_type}."
+ )
+ .unwrap();
+
+ match file_type {
+ PromptFileType::Code => {
+ // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
+ }
+ _ => {}
+ }
+
+ // Really dumb truncation strategy
+ if let Some(max_tokens) = max_token_length {
+ prompt = args.model.truncate(
+ &prompt,
+ max_tokens,
+ crate::models::TruncationDirection::End,
+ )?;
+ }
+
+ let token_count = args.model.count_tokens(&prompt)?;
+
+ anyhow::Ok((prompt, token_count))
+ }
+}
@@ -0,0 +1,5 @@
+pub mod base;
+pub mod file_context;
+pub mod generate;
+pub mod preamble;
+pub mod repository_context;
@@ -0,0 +1,52 @@
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
+use std::fmt::Write;
+
+pub struct EngineerPreamble {}
+
+impl PromptTemplate for EngineerPreamble {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let mut prompts = Vec::new();
+
+ match args.get_file_type() {
+ PromptFileType::Code => {
+ prompts.push(format!(
+ "You are an expert {}engineer.",
+ args.language_name.clone().map(|name| name + " ").unwrap_or_default()
+ ));
+ }
+ PromptFileType::Text => {
+ prompts.push("You are an expert engineer.".to_string());
+ }
+ }
+
+ if let Some(project_name) = args.project_name.clone() {
+ prompts.push(format!(
+ "You are currently working inside the '{project_name}' project in code editor Zed."
+ ));
+ }
+
+ if let Some(mut remaining_tokens) = max_token_length {
+ let mut prompt = String::new();
+ let mut total_count = 0;
+ for prompt_piece in prompts {
+ let prompt_token_count =
+ args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
+ if remaining_tokens > prompt_token_count {
+ writeln!(prompt, "{prompt_piece}").unwrap();
+ remaining_tokens -= prompt_token_count;
+ total_count += prompt_token_count;
+ }
+ }
+
+ anyhow::Ok((prompt, total_count))
+ } else {
+ let prompt = prompts.join("\n");
+ let token_count = args.model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, token_count))
+ }
+ }
+}
@@ -0,0 +1,98 @@
+use crate::prompts::base::{PromptArguments, PromptTemplate};
+use std::fmt::Write;
+use std::{ops::Range, path::PathBuf};
+
+use gpui2::{AsyncAppContext, Handle};
+use language2::{Anchor, Buffer};
+
+#[derive(Clone)]
+pub struct PromptCodeSnippet {
+ path: Option<PathBuf>,
+ language_name: Option<String>,
+ content: String,
+}
+
+impl PromptCodeSnippet {
+ pub fn new(
+ buffer: Handle<Buffer>,
+ range: Range<Anchor>,
+ cx: &mut AsyncAppContext,
+ ) -> anyhow::Result<Self> {
+ let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
+ let snapshot = buffer.snapshot();
+ let content = snapshot.text_for_range(range.clone()).collect::<String>();
+
+ let language_name = buffer
+ .language()
+ .map(|language| language.name().to_string().to_lowercase());
+
+ let file_path = buffer
+ .file()
+ .map(|file| file.path().to_path_buf());
+
+ (content, language_name, file_path)
+ })?;
+
+ anyhow::Ok(PromptCodeSnippet {
+ path: file_path,
+ language_name,
+ content,
+ })
+ }
+}
+
+impl ToString for PromptCodeSnippet {
+ fn to_string(&self) -> String {
+ let path = self
+ .path
+ .as_ref()
+ .map(|path| path.to_string_lossy().to_string())
+ .unwrap_or_default();
+ let language_name = self.language_name.clone().unwrap_or_default();
+ let content = self.content.clone();
+
+ format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
+ }
+}
+
+pub struct RepositoryContext {}
+
+impl PromptTemplate for RepositoryContext {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+ let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
+ let mut prompt = String::new();
+
+ let mut remaining_tokens = max_token_length;
+ let separator_token_length = args.model.count_tokens("\n")?;
+ for snippet in &args.snippets {
+ let mut snippet_prompt = template.to_string();
+ let content = snippet.to_string();
+ writeln!(snippet_prompt, "{content}").unwrap();
+
+ let token_count = args.model.count_tokens(&snippet_prompt)?;
+ if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
+ if let Some(tokens_left) = remaining_tokens {
+ if tokens_left >= token_count {
+ writeln!(prompt, "{snippet_prompt}").unwrap();
+ remaining_tokens = if tokens_left >= (token_count + separator_token_length)
+ {
+ Some(tokens_left - token_count - separator_token_length)
+ } else {
+ Some(0)
+ };
+ }
+ } else {
+ writeln!(prompt, "{snippet_prompt}").unwrap();
+ }
+ }
+ }
+
+ let total_token_count = args.model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, total_token_count))
+ }
+}
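To illustrate the snippet budgeting (numbers are illustrative): with max_token_length = Some(300) and three snippets that each render to 120 tokens, the first two are included (leaving 300 → 179 → 58 tokens after each snippet plus the one-token separator) and the third is skipped because 120 > 58; any snippet whose rendered prompt exceeds MAXIMUM_SNIPPET_TOKEN_COUNT (500 tokens) is skipped regardless of the remaining budget.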
@@ -0,0 +1 @@
+pub mod open_ai;
@@ -0,0 +1,306 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::{
+ future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
+ Stream, StreamExt,
+};
+use gpui2::{AppContext, Executor};
+use isahc::{http::StatusCode, Request, RequestExt};
+use parking_lot::RwLock;
+use serde::{Deserialize, Serialize};
+use std::{
+ env,
+ fmt::{self, Display},
+ io,
+ sync::Arc,
+};
+use util::ResultExt;
+
+use crate::{
+ auth::{CredentialProvider, ProviderCredential},
+ completion::{CompletionProvider, CompletionRequest},
+ models::LanguageModel,
+};
+
+use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+ User,
+ Assistant,
+ System,
+}
+
+impl Role {
+ pub fn cycle(&mut self) {
+ *self = match self {
+ Role::User => Role::Assistant,
+ Role::Assistant => Role::System,
+ Role::System => Role::User,
+ }
+ }
+}
+
+impl Display for Role {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Role::User => write!(f, "User"),
+ Role::Assistant => write!(f, "Assistant"),
+ Role::System => write!(f, "System"),
+ }
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct RequestMessage {
+ pub role: Role,
+ pub content: String,
+}
+
+#[derive(Debug, Default, Serialize)]
+pub struct OpenAIRequest {
+ pub model: String,
+ pub messages: Vec<RequestMessage>,
+ pub stream: bool,
+ pub stop: Vec<String>,
+ pub temperature: f32,
+}
+
+impl CompletionRequest for OpenAIRequest {
+ fn data(&self) -> serde_json::Result<String> {
+ serde_json::to_string(self)
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ResponseMessage {
+ pub role: Option<Role>,
+ pub content: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIUsage {
+ pub prompt_tokens: u32,
+ pub completion_tokens: u32,
+ pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ChatChoiceDelta {
+ pub index: u32,
+ pub delta: ResponseMessage,
+ pub finish_reason: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIResponseStreamEvent {
+ pub id: Option<String>,
+ pub object: String,
+ pub created: u32,
+ pub model: String,
+ pub choices: Vec<ChatChoiceDelta>,
+ pub usage: Option<OpenAIUsage>,
+}
+
+pub async fn stream_completion(
+ credential: ProviderCredential,
+ executor: Arc<Executor>,
+ request: Box<dyn CompletionRequest>,
+) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+ let api_key = match credential {
+ ProviderCredential::Credentials { api_key } => api_key,
+ _ => {
+ return Err(anyhow!("no credentials provider for completion"));
+ }
+ };
+
+ let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+
+ let json_data = request.data()?;
+ let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(json_data)?
+ .send_async()
+ .await?;
+
+ let status = response.status();
+ if status == StatusCode::OK {
+ executor
+ .spawn(async move {
+ let mut lines = BufReader::new(response.body_mut()).lines();
+
+ fn parse_line(
+ line: Result<String, io::Error>,
+ ) -> Result<Option<OpenAIResponseStreamEvent>> {
+ if let Some(data) = line?.strip_prefix("data: ") {
+ let event = serde_json::from_str(&data)?;
+ Ok(Some(event))
+ } else {
+ Ok(None)
+ }
+ }
+
+ while let Some(line) = lines.next().await {
+ if let Some(event) = parse_line(line).transpose() {
+ let done = event.as_ref().map_or(false, |event| {
+ event
+ .choices
+ .last()
+ .map_or(false, |choice| choice.finish_reason.is_some())
+ });
+ if tx.unbounded_send(event).is_err() {
+ break;
+ }
+
+ if done {
+ break;
+ }
+ }
+ }
+
+ anyhow::Ok(())
+ })
+ .detach();
+
+ Ok(rx)
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ #[derive(Deserialize)]
+ struct OpenAIResponse {
+ error: OpenAIError,
+ }
+
+ #[derive(Deserialize)]
+ struct OpenAIError {
+ message: String,
+ }
+
+ match serde_json::from_str::<OpenAIResponse>(&body) {
+ Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+ "Failed to connect to OpenAI API: {}",
+ response.error.message,
+ )),
+
+ _ => Err(anyhow!(
+ "Failed to connect to OpenAI API: {} {}",
+ response.status(),
+ body,
+ )),
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct OpenAICompletionProvider {
+ model: OpenAILanguageModel,
+ credential: Arc<RwLock<ProviderCredential>>,
+ executor: Arc<Executor>,
+}
+
+impl OpenAICompletionProvider {
+ pub fn new(model_name: &str, executor: Arc<Executor>) -> Self {
+ let model = OpenAILanguageModel::load(model_name);
+ let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
+ Self {
+ model,
+ credential,
+ executor,
+ }
+ }
+}
+
+#[async_trait]
+impl CredentialProvider for OpenAICompletionProvider {
+ fn has_credentials(&self) -> bool {
+ match *self.credential.read() {
+ ProviderCredential::Credentials { .. } => true,
+ _ => false,
+ }
+ }
+ async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
+ let existing_credential = self.credential.read().clone();
+
+ let retrieved_credential = cx
+ .run_on_main(move |cx| match existing_credential {
+ ProviderCredential::Credentials { .. } => {
+ return existing_credential.clone();
+ }
+ _ => {
+ if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
+ return ProviderCredential::Credentials { api_key };
+ }
+
+ if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
+ {
+ if let Some(api_key) = String::from_utf8(api_key).log_err() {
+ return ProviderCredential::Credentials { api_key };
+ } else {
+ return ProviderCredential::NoCredentials;
+ }
+ } else {
+ return ProviderCredential::NoCredentials;
+ }
+ }
+ })
+ .await;
+
+ *self.credential.write() = retrieved_credential.clone();
+ retrieved_credential
+ }
+
+ async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
+ *self.credential.write() = credential.clone();
+ let credential = credential.clone();
+ cx.run_on_main(move |cx| match credential {
+ ProviderCredential::Credentials { api_key } => {
+ cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+ .log_err();
+ }
+ _ => {}
+ })
+ .await;
+ }
+ async fn delete_credentials(&self, cx: &mut AppContext) {
+ cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
+ .await;
+ *self.credential.write() = ProviderCredential::NoCredentials;
+ }
+}
+
+impl CompletionProvider for OpenAICompletionProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+ model
+ }
+ fn complete(
+ &self,
+ prompt: Box<dyn CompletionRequest>,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ // Currently the CompletionRequest for OpenAI includes a 'model' parameter.
+ // This means that the model is determined by the CompletionRequest and not the CompletionProvider,
+ // even though the provider itself is model-based via its language model.
+ // At some point in the future we should rectify this.
+ let credential = self.credential.read().clone();
+ let request = stream_completion(credential, self.executor.clone(), prompt);
+ async move {
+ let response = request.await?;
+ let stream = response
+ .filter_map(|response| async move {
+ match response {
+ Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+ Err(error) => Some(Err(error)),
+ }
+ })
+ .boxed();
+ Ok(stream)
+ }
+ .boxed()
+ }
+ fn box_clone(&self) -> Box<dyn CompletionProvider> {
+ Box::new((*self).clone())
+ }
+}
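A short sketch of constructing a request for this provider (illustrative, not part of the diff); the executor is assumed to come from the running gpui2 app, and the model name and message content are placeholders:

use std::sync::Arc;

fn request_summary(executor: Arc<gpui2::Executor>) {
    let provider = OpenAICompletionProvider::new("gpt-4", executor);
    let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
        model: "gpt-4".into(),
        messages: vec![RequestMessage {
            role: Role::User,
            content: "Summarize the selected function.".into(),
        }],
        stream: true,
        stop: Vec::new(),
        temperature: 0.7,
    });
    // `complete` resolves to a stream of text chunks once credentials have been
    // retrieved or saved on the provider.
    let _chunks = provider.complete(request);
}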
@@ -0,0 +1,313 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui2::Executor;
+use gpui2::{serde_json, AppContext};
+use isahc::http::StatusCode;
+use isahc::prelude::Configurable;
+use isahc::{AsyncBody, Response};
+use lazy_static::lazy_static;
+use parking_lot::{Mutex, RwLock};
+use parse_duration::parse;
+use postage::watch;
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::ops::Add;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tiktoken_rs::{cl100k_base, CoreBPE};
+use util::http::{HttpClient, Request};
+use util::ResultExt;
+
+use crate::auth::{CredentialProvider, ProviderCredential};
+use crate::embedding::{Embedding, EmbeddingProvider};
+use crate::models::LanguageModel;
+use crate::providers::open_ai::OpenAILanguageModel;
+
+use crate::providers::open_ai::OPENAI_API_URL;
+
+lazy_static! {
+ static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
+}
+
+#[derive(Clone)]
+pub struct OpenAIEmbeddingProvider {
+ model: OpenAILanguageModel,
+ credential: Arc<RwLock<ProviderCredential>>,
+ pub client: Arc<dyn HttpClient>,
+ pub executor: Arc<Executor>,
+ rate_limit_count_rx: watch::Receiver<Option<Instant>>,
+ rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
+}
+
+#[derive(Serialize)]
+struct OpenAIEmbeddingRequest<'a> {
+ model: &'static str,
+ input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingResponse {
+ data: Vec<OpenAIEmbedding>,
+ usage: OpenAIEmbeddingUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIEmbedding {
+ embedding: Vec<f32>,
+ index: usize,
+ object: String,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingUsage {
+ prompt_tokens: usize,
+ total_tokens: usize,
+}
+
+impl OpenAIEmbeddingProvider {
+ pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Executor>) -> Self {
+ let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
+ let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
+
+ let model = OpenAILanguageModel::load("text-embedding-ada-002");
+ let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
+
+ OpenAIEmbeddingProvider {
+ model,
+ credential,
+ client,
+ executor,
+ rate_limit_count_rx,
+ rate_limit_count_tx,
+ }
+ }
+
+ fn get_api_key(&self) -> Result<String> {
+ match self.credential.read().clone() {
+ ProviderCredential::Credentials { api_key } => Ok(api_key),
+ _ => Err(anyhow!("api credentials not provided")),
+ }
+ }
+
+ fn resolve_rate_limit(&self) {
+ let reset_time = *self.rate_limit_count_tx.lock().borrow();
+
+ if let Some(reset_time) = reset_time {
+ if Instant::now() >= reset_time {
+ *self.rate_limit_count_tx.lock().borrow_mut() = None
+ }
+ }
+
+ log::trace!(
+ "resolving reset time: {:?}",
+ *self.rate_limit_count_tx.lock().borrow()
+ );
+ }
+
+ fn update_reset_time(&self, reset_time: Instant) {
+ let original_time = *self.rate_limit_count_tx.lock().borrow();
+
+ let updated_time = if let Some(original_time) = original_time {
+ if reset_time < original_time {
+ Some(reset_time)
+ } else {
+ Some(original_time)
+ }
+ } else {
+ Some(reset_time)
+ };
+
+ log::trace!("updating rate limit time: {:?}", updated_time);
+
+ *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
+ }
+ async fn send_request(
+ &self,
+ api_key: &str,
+ spans: Vec<&str>,
+ request_timeout: u64,
+ ) -> Result<Response<AsyncBody>> {
+ let request = Request::post(format!("{OPENAI_API_URL}/embeddings"))
+ .redirect_policy(isahc::config::RedirectPolicy::Follow)
+ .timeout(Duration::from_secs(request_timeout))
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(
+ serde_json::to_string(&OpenAIEmbeddingRequest {
+ input: spans.clone(),
+ model: "text-embedding-ada-002",
+ })
+ .unwrap()
+ .into(),
+ )?;
+
+ Ok(self.client.send(request).await?)
+ }
+}
+
+#[async_trait]
+impl CredentialProvider for OpenAIEmbeddingProvider {
+ fn has_credentials(&self) -> bool {
+ match *self.credential.read() {
+ ProviderCredential::Credentials { .. } => true,
+ _ => false,
+ }
+ }
+ async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
+ let existing_credential = self.credential.read().clone();
+
+ let retrieved_credential = cx
+ .run_on_main(move |cx| match existing_credential {
+ ProviderCredential::Credentials { .. } => {
+ return existing_credential.clone();
+ }
+ _ => {
+ if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
+ return ProviderCredential::Credentials { api_key };
+ }
+
+ if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
+ {
+ if let Some(api_key) = String::from_utf8(api_key).log_err() {
+ return ProviderCredential::Credentials { api_key };
+ } else {
+ return ProviderCredential::NoCredentials;
+ }
+ } else {
+ return ProviderCredential::NoCredentials;
+ }
+ }
+ })
+ .await;
+
+ *self.credential.write() = retrieved_credential.clone();
+ retrieved_credential
+ }
+
+ async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
+ *self.credential.write() = credential.clone();
+ let credential = credential.clone();
+ cx.run_on_main(move |cx| match credential {
+ ProviderCredential::Credentials { api_key } => {
+ cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+ .log_err();
+ }
+ _ => {}
+ })
+ .await;
+ }
+ async fn delete_credentials(&self, cx: &mut AppContext) {
+ cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
+ .await;
+ *self.credential.write() = ProviderCredential::NoCredentials;
+ }
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddingProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+ model
+ }
+
+ fn max_tokens_per_batch(&self) -> usize {
+ 50000
+ }
+
+ fn rate_limit_expiration(&self) -> Option<Instant> {
+ *self.rate_limit_count_rx.borrow()
+ }
+
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+ const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
+ const MAX_RETRIES: usize = 4;
+
+ let api_key = self.get_api_key()?;
+
+ let mut request_number = 0;
+ let mut rate_limiting = false;
+ let mut request_timeout: u64 = 15;
+ let mut response: Response<AsyncBody>;
+ while request_number < MAX_RETRIES {
+ response = self
+ .send_request(
+ &api_key,
+ spans.iter().map(|x| &**x).collect(),
+ request_timeout,
+ )
+ .await?;
+
+ request_number += 1;
+
+ match response.status() {
+ StatusCode::REQUEST_TIMEOUT => {
+ request_timeout += 5;
+ }
+ StatusCode::OK => {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+ log::trace!(
+ "openai embedding completed. tokens: {:?}",
+ response.usage.total_tokens
+ );
+
+ // If we complete a request successfully that was previously rate_limited
+ // resolve the rate limit
+ if rate_limiting {
+ self.resolve_rate_limit()
+ }
+
+ return Ok(response
+ .data
+ .into_iter()
+ .map(|embedding| Embedding::from(embedding.embedding))
+ .collect());
+ }
+ StatusCode::TOO_MANY_REQUESTS => {
+ rate_limiting = true;
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ let delay_duration = {
+ let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+ if let Some(time_to_reset) =
+ response.headers().get("x-ratelimit-reset-tokens")
+ {
+ if let Ok(time_str) = time_to_reset.to_str() {
+ parse(time_str).unwrap_or(delay)
+ } else {
+ delay
+ }
+ } else {
+ delay
+ }
+ };
+
+ // Record the new reset time so subsequent requests wait until the rate limit lifts
+ let reset_time = Instant::now().add(delay_duration);
+ self.update_reset_time(reset_time);
+
+ log::trace!(
+ "openai rate limiting: waiting {:?} until lifted",
+ &delay_duration
+ );
+
+ self.executor.timer(delay_duration).await;
+ }
+ _ => {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ return Err(anyhow!(
+ "open ai bad request: {:?} {:?}",
+ &response.status(),
+ body
+ ));
+ }
+ }
+ }
+ Err(anyhow!("openai max retries"))
+ }
+}
@@ -0,0 +1,9 @@
+pub mod completion;
+pub mod embedding;
+pub mod model;
+
+pub use completion::*;
+pub use embedding::*;
+pub use model::OpenAILanguageModel;
+
+pub const OPENAI_API_URL: &str = "https://api.openai.com/v1";
@@ -0,0 +1,57 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+use crate::models::{LanguageModel, TruncationDirection};
+
+#[derive(Clone)]
+pub struct OpenAILanguageModel {
+ name: String,
+ bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+ pub fn load(model_name: &str) -> Self {
+ let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+ OpenAILanguageModel {
+ name: model_name.to_string(),
+ bpe,
+ }
+ }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+ fn name(&self) -> String {
+ self.name.clone()
+ }
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+ if let Some(bpe) = &self.bpe {
+ anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String> {
+ if let Some(bpe) = &self.bpe {
+ let tokens = bpe.encode_with_special_tokens(content);
+ if tokens.len() > length {
+ match direction {
+ TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+ TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+ }
+ } else {
+ bpe.decode(tokens)
+ }
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn capacity(&self) -> anyhow::Result<usize> {
+ anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+ }
+}
@@ -0,0 +1,11 @@
+pub trait LanguageModel {
+ fn name(&self) -> String;
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String>;
+ fn capacity(&self) -> anyhow::Result<usize>;
+}
@@ -0,0 +1,193 @@
+use std::{
+ sync::atomic::{self, AtomicUsize, Ordering},
+ time::Instant,
+};
+
+use async_trait::async_trait;
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui2::AppContext;
+use parking_lot::Mutex;
+
+use crate::{
+ auth::{CredentialProvider, ProviderCredential},
+ completion::{CompletionProvider, CompletionRequest},
+ embedding::{Embedding, EmbeddingProvider},
+ models::{LanguageModel, TruncationDirection},
+};
+
+#[derive(Clone)]
+pub struct FakeLanguageModel {
+ pub capacity: usize,
+}
+
+impl LanguageModel for FakeLanguageModel {
+ fn name(&self) -> String {
+ "dummy".to_string()
+ }
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+ anyhow::Ok(content.chars().count())
+ }
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String> {
+ println!("TRYING TO TRUNCATE: {:?}", length.clone());
+
+ if length > self.count_tokens(content)? {
+ println!("NOT TRUNCATING");
+ return anyhow::Ok(content.to_string());
+ }
+
+ anyhow::Ok(match direction {
+ TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
+ .into_iter()
+ .collect::<String>(),
+ TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
+ .into_iter()
+ .collect::<String>(),
+ })
+ }
+ fn capacity(&self) -> anyhow::Result<usize> {
+ anyhow::Ok(self.capacity)
+ }
+}
+
+pub struct FakeEmbeddingProvider {
+ pub embedding_count: AtomicUsize,
+}
+
+impl Clone for FakeEmbeddingProvider {
+ fn clone(&self) -> Self {
+ FakeEmbeddingProvider {
+ embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
+ }
+ }
+}
+
+impl Default for FakeEmbeddingProvider {
+ fn default() -> Self {
+ FakeEmbeddingProvider {
+ embedding_count: AtomicUsize::default(),
+ }
+ }
+}
+
+impl FakeEmbeddingProvider {
+ pub fn embedding_count(&self) -> usize {
+ self.embedding_count.load(atomic::Ordering::SeqCst)
+ }
+
+ pub fn embed_sync(&self, span: &str) -> Embedding {
+ let mut result = vec![1.0; 26];
+ for letter in span.chars() {
+ let letter = letter.to_ascii_lowercase();
+ if letter as u32 >= 'a' as u32 {
+ let ix = (letter as u32) - ('a' as u32);
+ if ix < 26 {
+ result[ix as usize] += 1.0;
+ }
+ }
+ }
+
+ let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
+ for x in &mut result {
+ *x /= norm;
+ }
+
+ result.into()
+ }
+}
+
+#[async_trait]
+impl CredentialProvider for FakeEmbeddingProvider {
+ fn has_credentials(&self) -> bool {
+ true
+ }
+ async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
+ ProviderCredential::NotNeeded
+ }
+ async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
+ async fn delete_credentials(&self, _cx: &mut AppContext) {}
+}
+
+#[async_trait]
+impl EmbeddingProvider for FakeEmbeddingProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ Box::new(FakeLanguageModel { capacity: 1000 })
+ }
+ fn max_tokens_per_batch(&self) -> usize {
+ 1000
+ }
+
+ fn rate_limit_expiration(&self) -> Option<Instant> {
+ None
+ }
+
+ async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
+ self.embedding_count
+ .fetch_add(spans.len(), atomic::Ordering::SeqCst);
+
+ anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
+ }
+}
+
+pub struct FakeCompletionProvider {
+ last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
+}
+
+impl Clone for FakeCompletionProvider {
+ fn clone(&self) -> Self {
+ Self {
+ last_completion_tx: Mutex::new(None),
+ }
+ }
+}
+
+impl FakeCompletionProvider {
+ pub fn new() -> Self {
+ Self {
+ last_completion_tx: Mutex::new(None),
+ }
+ }
+
+ pub fn send_completion(&self, completion: impl Into<String>) {
+ let mut tx = self.last_completion_tx.lock();
+ tx.as_mut().unwrap().try_send(completion.into()).unwrap();
+ }
+
+ pub fn finish_completion(&self) {
+ self.last_completion_tx.lock().take().unwrap();
+ }
+}
+
+#[async_trait]
+impl CredentialProvider for FakeCompletionProvider {
+ fn has_credentials(&self) -> bool {
+ true
+ }
+ async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
+ ProviderCredential::NotNeeded
+ }
+ async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
+ async fn delete_credentials(&self, _cx: &mut AppContext) {}
+}
+
+impl CompletionProvider for FakeCompletionProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
+ model
+ }
+ fn complete(
+ &self,
+ _prompt: Box<dyn CompletionRequest>,
+ ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
+ let (tx, rx) = mpsc::channel(1);
+ *self.last_completion_tx.lock() = Some(tx);
+ async move { Ok(rx.map(|chunk| Ok(chunk)).boxed()) }.boxed()
+ }
+ fn box_clone(&self) -> Box<dyn CompletionProvider> {
+ Box::new((*self).clone())
+ }
+}
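A sketch of how a test might drive FakeCompletionProvider (illustrative, not part of the diff); `request` can be any boxed CompletionRequest:

use futures::StreamExt;

async fn exercise_fake_provider(request: Box<dyn CompletionRequest>) -> anyhow::Result<()> {
    let provider = FakeCompletionProvider::new();
    let mut chunks = provider.complete(request).await?;

    // Chunks sent through send_completion appear on the stream returned by complete().
    provider.send_completion("hello ");
    assert_eq!(chunks.next().await.transpose()?, Some("hello ".to_string()));

    provider.send_completion("world");
    assert_eq!(chunks.next().await.transpose()?, Some("world".to_string()));

    provider.finish_completion(); // dropping the sender ends the stream
    assert!(chunks.next().await.is_none());
    Ok(())
}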
@@ -15,6 +15,7 @@ name = "Zed"
path = "src/main.rs"
[dependencies]
+ai2 = { path = "../ai2" }
# audio = { path = "../audio" }
# activity_indicator = { path = "../activity_indicator" }
# auto_update = { path = "../auto_update" }