Detailed changes
@@ -108,6 +108,33 @@ dependencies = [
"util",
]
+[[package]]
+name = "ai2"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "async-trait",
+ "bincode",
+ "futures 0.3.28",
+ "gpui2",
+ "isahc",
+ "language2",
+ "lazy_static",
+ "log",
+ "matrixmultiply",
+ "ordered-float 2.10.0",
+ "parking_lot 0.11.2",
+ "parse_duration",
+ "postage",
+ "rand 0.8.5",
+ "regex",
+ "rusqlite",
+ "serde",
+ "serde_json",
+ "tiktoken-rs",
+ "util",
+]
+
[[package]]
name = "alacritty_config"
version = "0.1.2-dev"
@@ -10903,6 +10930,7 @@ dependencies = [
name = "zed2"
version = "0.109.0"
dependencies = [
+ "ai2",
"anyhow",
"async-compression",
"async-recursion 0.3.2",
@@ -0,0 +1,38 @@
+[package]
+name = "ai"
+version = "0.1.0"
+edition = "2021"
+publish = false
+
+[lib]
+path = "src/ai.rs"
+doctest = false
+
+[features]
+test-support = []
+
+[dependencies]
+gpui = { path = "../gpui" }
+util = { path = "../util" }
+language = { path = "../language" }
+async-trait.workspace = true
+anyhow.workspace = true
+futures.workspace = true
+lazy_static.workspace = true
+ordered-float.workspace = true
+parking_lot.workspace = true
+isahc.workspace = true
+regex.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+postage.workspace = true
+rand.workspace = true
+log.workspace = true
+parse_duration = "2.1.1"
+tiktoken-rs = "0.5.0"
+matrixmultiply = "0.3.7"
+rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
+bincode = "1.3.3"
+
+[dev-dependencies]
+gpui = { path = "../gpui", features = ["test-support"] }
@@ -8,6 +8,9 @@ publish = false
path = "src/ai.rs"
doctest = false
+[features]
+test-support = []
+
[dependencies]
gpui = { path = "../gpui" }
util = { path = "../util" }
@@ -1,4 +1,8 @@
+pub mod auth;
pub mod completion;
pub mod embedding;
pub mod models;
-pub mod templates;
+pub mod prompts;
+pub mod providers;
+#[cfg(any(test, feature = "test-support"))]
+pub mod test;
@@ -0,0 +1,15 @@
+use gpui::AppContext;
+
+#[derive(Clone, Debug)]
+pub enum ProviderCredential {
+ Credentials { api_key: String },
+ NoCredentials,
+ NotNeeded,
+}
+
+pub trait CredentialProvider: Send + Sync {
+ fn has_credentials(&self) -> bool;
+ fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
+ fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
+ fn delete_credentials(&self, cx: &AppContext);
+}
@@ -1,214 +1,23 @@
-use anyhow::{anyhow, Result};
-use futures::{
- future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
- Stream, StreamExt,
-};
-use gpui::executor::Background;
-use isahc::{http::StatusCode, Request, RequestExt};
-use serde::{Deserialize, Serialize};
-use std::{
- fmt::{self, Display},
- io,
- sync::Arc,
-};
+use anyhow::Result;
+use futures::{future::BoxFuture, stream::BoxStream};
-pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
+use crate::{auth::CredentialProvider, models::LanguageModel};
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
- User,
- Assistant,
- System,
+pub trait CompletionRequest: Send + Sync {
+ fn data(&self) -> serde_json::Result<String>;
}
-impl Role {
- pub fn cycle(&mut self) {
- *self = match self {
- Role::User => Role::Assistant,
- Role::Assistant => Role::System,
- Role::System => Role::User,
- }
- }
-}
-
-impl Display for Role {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- Role::User => write!(f, "User"),
- Role::Assistant => write!(f, "Assistant"),
- Role::System => write!(f, "System"),
- }
- }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct RequestMessage {
- pub role: Role,
- pub content: String,
-}
-
-#[derive(Debug, Default, Serialize)]
-pub struct OpenAIRequest {
- pub model: String,
- pub messages: Vec<RequestMessage>,
- pub stream: bool,
- pub stop: Vec<String>,
- pub temperature: f32,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ResponseMessage {
- pub role: Option<Role>,
- pub content: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIUsage {
- pub prompt_tokens: u32,
- pub completion_tokens: u32,
- pub total_tokens: u32,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct ChatChoiceDelta {
- pub index: u32,
- pub delta: ResponseMessage,
- pub finish_reason: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIResponseStreamEvent {
- pub id: Option<String>,
- pub object: String,
- pub created: u32,
- pub model: String,
- pub choices: Vec<ChatChoiceDelta>,
- pub usage: Option<OpenAIUsage>,
-}
-
-pub async fn stream_completion(
- api_key: String,
- executor: Arc<Background>,
- mut request: OpenAIRequest,
-) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
- request.stream = true;
-
- let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
-
- let json_data = serde_json::to_string(&request)?;
- let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_key))
- .body(json_data)?
- .send_async()
- .await?;
-
- let status = response.status();
- if status == StatusCode::OK {
- executor
- .spawn(async move {
- let mut lines = BufReader::new(response.body_mut()).lines();
-
- fn parse_line(
- line: Result<String, io::Error>,
- ) -> Result<Option<OpenAIResponseStreamEvent>> {
- if let Some(data) = line?.strip_prefix("data: ") {
- let event = serde_json::from_str(&data)?;
- Ok(Some(event))
- } else {
- Ok(None)
- }
- }
-
- while let Some(line) = lines.next().await {
- if let Some(event) = parse_line(line).transpose() {
- let done = event.as_ref().map_or(false, |event| {
- event
- .choices
- .last()
- .map_or(false, |choice| choice.finish_reason.is_some())
- });
- if tx.unbounded_send(event).is_err() {
- break;
- }
-
- if done {
- break;
- }
- }
- }
-
- anyhow::Ok(())
- })
- .detach();
-
- Ok(rx)
- } else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
-
- #[derive(Deserialize)]
- struct OpenAIResponse {
- error: OpenAIError,
- }
-
- #[derive(Deserialize)]
- struct OpenAIError {
- message: String,
- }
-
- match serde_json::from_str::<OpenAIResponse>(&body) {
- Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
- "Failed to connect to OpenAI API: {}",
- response.error.message,
- )),
-
- _ => Err(anyhow!(
- "Failed to connect to OpenAI API: {} {}",
- response.status(),
- body,
- )),
- }
- }
-}
-
-pub trait CompletionProvider {
+pub trait CompletionProvider: CredentialProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel>;
fn complete(
&self,
- prompt: OpenAIRequest,
+ prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+ fn box_clone(&self) -> Box<dyn CompletionProvider>;
}
-pub struct OpenAICompletionProvider {
- api_key: String,
- executor: Arc<Background>,
-}
-
-impl OpenAICompletionProvider {
- pub fn new(api_key: String, executor: Arc<Background>) -> Self {
- Self { api_key, executor }
- }
-}
-
-impl CompletionProvider for OpenAICompletionProvider {
- fn complete(
- &self,
- prompt: OpenAIRequest,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
- let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
- async move {
- let response = request.await?;
- let stream = response
- .filter_map(|response| async move {
- match response {
- Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
- Err(error) => Some(Err(error)),
- }
- })
- .boxed();
- Ok(stream)
- }
- .boxed()
+impl Clone for Box<dyn CompletionProvider> {
+ fn clone(&self) -> Box<dyn CompletionProvider> {
+ self.box_clone()
}
}
@@ -1,32 +1,13 @@
-use anyhow::{anyhow, Result};
+use std::time::Instant;
+
+use anyhow::Result;
use async_trait::async_trait;
-use futures::AsyncReadExt;
-use gpui::executor::Background;
-use gpui::{serde_json, AppContext};
-use isahc::http::StatusCode;
-use isahc::prelude::Configurable;
-use isahc::{AsyncBody, Response};
-use lazy_static::lazy_static;
use ordered_float::OrderedFloat;
-use parking_lot::Mutex;
-use parse_duration::parse;
-use postage::watch;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
-use serde::{Deserialize, Serialize};
-use std::env;
-use std::ops::Add;
-use std::sync::Arc;
-use std::time::{Duration, Instant};
-use tiktoken_rs::{cl100k_base, CoreBPE};
-use util::http::{HttpClient, Request};
-use util::ResultExt;
-
-use crate::completion::OPENAI_API_URL;
-lazy_static! {
- static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
-}
+use crate::auth::CredentialProvider;
+use crate::models::LanguageModel;
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
@@ -87,301 +68,14 @@ impl Embedding {
}
}
-#[derive(Clone)]
-pub struct OpenAIEmbeddings {
- pub client: Arc<dyn HttpClient>,
- pub executor: Arc<Background>,
- rate_limit_count_rx: watch::Receiver<Option<Instant>>,
- rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
-}
-
-#[derive(Serialize)]
-struct OpenAIEmbeddingRequest<'a> {
- model: &'static str,
- input: Vec<&'a str>,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingResponse {
- data: Vec<OpenAIEmbedding>,
- usage: OpenAIEmbeddingUsage,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAIEmbedding {
- embedding: Vec<f32>,
- index: usize,
- object: String,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingUsage {
- prompt_tokens: usize,
- total_tokens: usize,
-}
-
#[async_trait]
-pub trait EmbeddingProvider: Sync + Send {
- fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
- async fn embed_batch(
- &self,
- spans: Vec<String>,
- api_key: Option<String>,
- ) -> Result<Vec<Embedding>>;
+pub trait EmbeddingProvider: CredentialProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel>;
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
- fn truncate(&self, span: &str) -> (String, usize);
fn rate_limit_expiration(&self) -> Option<Instant>;
}
-pub struct DummyEmbeddings {}
-
-#[async_trait]
-impl EmbeddingProvider for DummyEmbeddings {
- fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
- Some("Dummy API KEY".to_string())
- }
- fn rate_limit_expiration(&self) -> Option<Instant> {
- None
- }
- async fn embed_batch(
- &self,
- spans: Vec<String>,
- _api_key: Option<String>,
- ) -> Result<Vec<Embedding>> {
- // 1024 is the OpenAI Embeddings size for ada models.
- // the model we will likely be starting with.
- let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
- return Ok(vec![dummy_vec; spans.len()]);
- }
-
- fn max_tokens_per_batch(&self) -> usize {
- OPENAI_INPUT_LIMIT
- }
-
- fn truncate(&self, span: &str) -> (String, usize) {
- let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
- let token_count = tokens.len();
- let output = if token_count > OPENAI_INPUT_LIMIT {
- tokens.truncate(OPENAI_INPUT_LIMIT);
- let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
- new_input.ok().unwrap_or_else(|| span.to_string())
- } else {
- span.to_string()
- };
-
- (output, tokens.len())
- }
-}
-
-const OPENAI_INPUT_LIMIT: usize = 8190;
-
-impl OpenAIEmbeddings {
- pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
- let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
- let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
-
- OpenAIEmbeddings {
- client,
- executor,
- rate_limit_count_rx,
- rate_limit_count_tx,
- }
- }
-
- fn resolve_rate_limit(&self) {
- let reset_time = *self.rate_limit_count_tx.lock().borrow();
-
- if let Some(reset_time) = reset_time {
- if Instant::now() >= reset_time {
- *self.rate_limit_count_tx.lock().borrow_mut() = None
- }
- }
-
- log::trace!(
- "resolving reset time: {:?}",
- *self.rate_limit_count_tx.lock().borrow()
- );
- }
-
- fn update_reset_time(&self, reset_time: Instant) {
- let original_time = *self.rate_limit_count_tx.lock().borrow();
-
- let updated_time = if let Some(original_time) = original_time {
- if reset_time < original_time {
- Some(reset_time)
- } else {
- Some(original_time)
- }
- } else {
- Some(reset_time)
- };
-
- log::trace!("updating rate limit time: {:?}", updated_time);
-
- *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
- }
- async fn send_request(
- &self,
- api_key: &str,
- spans: Vec<&str>,
- request_timeout: u64,
- ) -> Result<Response<AsyncBody>> {
- let request = Request::post("https://api.openai.com/v1/embeddings")
- .redirect_policy(isahc::config::RedirectPolicy::Follow)
- .timeout(Duration::from_secs(request_timeout))
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_key))
- .body(
- serde_json::to_string(&OpenAIEmbeddingRequest {
- input: spans.clone(),
- model: "text-embedding-ada-002",
- })
- .unwrap()
- .into(),
- )?;
-
- Ok(self.client.send(request).await?)
- }
-}
-
-#[async_trait]
-impl EmbeddingProvider for OpenAIEmbeddings {
- fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
- if let Ok(api_key) = env::var("OPENAI_API_KEY") {
- Some(api_key)
- } else if let Some((_, api_key)) = cx
- .platform()
- .read_credentials(OPENAI_API_URL)
- .log_err()
- .flatten()
- {
- String::from_utf8(api_key).log_err()
- } else {
- None
- }
- }
-
- fn max_tokens_per_batch(&self) -> usize {
- 50000
- }
-
- fn rate_limit_expiration(&self) -> Option<Instant> {
- *self.rate_limit_count_rx.borrow()
- }
- fn truncate(&self, span: &str) -> (String, usize) {
- let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
- let output = if tokens.len() > OPENAI_INPUT_LIMIT {
- tokens.truncate(OPENAI_INPUT_LIMIT);
- OPENAI_BPE_TOKENIZER
- .decode(tokens.clone())
- .ok()
- .unwrap_or_else(|| span.to_string())
- } else {
- span.to_string()
- };
-
- (output, tokens.len())
- }
-
- async fn embed_batch(
- &self,
- spans: Vec<String>,
- api_key: Option<String>,
- ) -> Result<Vec<Embedding>> {
- const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
- const MAX_RETRIES: usize = 4;
-
- let Some(api_key) = api_key else {
- return Err(anyhow!("no open ai key provided"));
- };
-
- let mut request_number = 0;
- let mut rate_limiting = false;
- let mut request_timeout: u64 = 15;
- let mut response: Response<AsyncBody>;
- while request_number < MAX_RETRIES {
- response = self
- .send_request(
- &api_key,
- spans.iter().map(|x| &**x).collect(),
- request_timeout,
- )
- .await?;
-
- request_number += 1;
-
- match response.status() {
- StatusCode::REQUEST_TIMEOUT => {
- request_timeout += 5;
- }
- StatusCode::OK => {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
-
- log::trace!(
- "openai embedding completed. tokens: {:?}",
- response.usage.total_tokens
- );
-
- // If we complete a request successfully that was previously rate_limited
- // resolve the rate limit
- if rate_limiting {
- self.resolve_rate_limit()
- }
-
- return Ok(response
- .data
- .into_iter()
- .map(|embedding| Embedding::from(embedding.embedding))
- .collect());
- }
- StatusCode::TOO_MANY_REQUESTS => {
- rate_limiting = true;
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
-
- let delay_duration = {
- let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
- if let Some(time_to_reset) =
- response.headers().get("x-ratelimit-reset-tokens")
- {
- if let Ok(time_str) = time_to_reset.to_str() {
- parse(time_str).unwrap_or(delay)
- } else {
- delay
- }
- } else {
- delay
- }
- };
-
- // If we've previously rate limited, increment the duration but not the count
- let reset_time = Instant::now().add(delay_duration);
- self.update_reset_time(reset_time);
-
- log::trace!(
- "openai rate limiting: waiting {:?} until lifted",
- &delay_duration
- );
-
- self.executor.timer(delay_duration).await;
- }
- _ => {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- return Err(anyhow!(
- "open ai bad request: {:?} {:?}",
- &response.status(),
- body
- ));
- }
- }
- }
- Err(anyhow!("openai max retries"))
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;
@@ -1,66 +1,16 @@
-use anyhow::anyhow;
-use tiktoken_rs::CoreBPE;
-use util::ResultExt;
+pub enum TruncationDirection {
+ Start,
+ End,
+}
pub trait LanguageModel {
fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
- fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
- fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}
-
-pub struct OpenAILanguageModel {
- name: String,
- bpe: Option<CoreBPE>,
-}
-
-impl OpenAILanguageModel {
- pub fn load(model_name: &str) -> Self {
- let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
- OpenAILanguageModel {
- name: model_name.to_string(),
- bpe,
- }
- }
-}
-
-impl LanguageModel for OpenAILanguageModel {
- fn name(&self) -> String {
- self.name.clone()
- }
- fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
- if let Some(bpe) = &self.bpe {
- anyhow::Ok(bpe.encode_with_special_tokens(content).len())
- } else {
- Err(anyhow!("bpe for open ai model was not retrieved"))
- }
- }
- fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
- if let Some(bpe) = &self.bpe {
- let tokens = bpe.encode_with_special_tokens(content);
- if tokens.len() > length {
- bpe.decode(tokens[..length].to_vec())
- } else {
- bpe.decode(tokens)
- }
- } else {
- Err(anyhow!("bpe for open ai model was not retrieved"))
- }
- }
- fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
- if let Some(bpe) = &self.bpe {
- let tokens = bpe.encode_with_special_tokens(content);
- if tokens.len() > length {
- bpe.decode(tokens[length..].to_vec())
- } else {
- bpe.decode(tokens)
- }
- } else {
- Err(anyhow!("bpe for open ai model was not retrieved"))
- }
- }
- fn capacity(&self) -> anyhow::Result<usize> {
- anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
- }
-}
@@ -6,7 +6,7 @@ use language::BufferSnapshot;
use util::ResultExt;
use crate::models::LanguageModel;
-use crate::templates::repository_context::PromptCodeSnippet;
+use crate::prompts::repository_context::PromptCodeSnippet;
pub(crate) enum PromptFileType {
Text,
@@ -125,6 +125,9 @@ impl PromptChain {
#[cfg(test)]
pub(crate) mod tests {
+ use crate::models::TruncationDirection;
+ use crate::test::FakeLanguageModel;
+
use super::*;
#[test]
@@ -141,7 +144,11 @@ pub(crate) mod tests {
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
- content = args.model.truncate(&content, max_token_length)?;
+ content = args.model.truncate(
+ &content,
+ max_token_length,
+ TruncationDirection::End,
+ )?;
token_count = max_token_length;
}
}
@@ -162,7 +169,11 @@ pub(crate) mod tests {
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
- content = args.model.truncate(&content, max_token_length)?;
+ content = args.model.truncate(
+ &content,
+ max_token_length,
+ TruncationDirection::End,
+ )?;
token_count = max_token_length;
}
}
@@ -171,38 +182,7 @@ pub(crate) mod tests {
}
}
- #[derive(Clone)]
- struct DummyLanguageModel {
- capacity: usize,
- }
-
- impl LanguageModel for DummyLanguageModel {
- fn name(&self) -> String {
- "dummy".to_string()
- }
- fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
- anyhow::Ok(content.chars().collect::<Vec<char>>().len())
- }
- fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
- anyhow::Ok(
- content.chars().collect::<Vec<char>>()[..length]
- .into_iter()
- .collect::<String>(),
- )
- }
- fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
- anyhow::Ok(
- content.chars().collect::<Vec<char>>()[length..]
- .into_iter()
- .collect::<String>(),
- )
- }
- fn capacity(&self) -> anyhow::Result<usize> {
- anyhow::Ok(self.capacity)
- }
- }
-
- let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
@@ -238,7 +218,7 @@ pub(crate) mod tests {
// Testing with Truncation Off
// Should ignore capacity and return all prompts
- let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
@@ -275,7 +255,7 @@ pub(crate) mod tests {
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let capacity = 20;
- let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
@@ -311,7 +291,7 @@ pub(crate) mod tests {
// Change Ordering of Prompts Based on Priority
let capacity = 120;
let reserved_tokens = 10;
- let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
@@ -3,8 +3,9 @@ use language::BufferSnapshot;
use language::ToOffset;
use crate::models::LanguageModel;
-use crate::templates::base::PromptArguments;
-use crate::templates::base::PromptTemplate;
+use crate::models::TruncationDirection;
+use crate::prompts::base::PromptArguments;
+use crate::prompts::base::PromptTemplate;
use std::fmt::Write;
use std::ops::Range;
use std::sync::Arc;
@@ -70,8 +71,9 @@ fn retrieve_context(
};
let truncated_start_window =
- model.truncate_start(&start_window, start_goal_tokens)?;
- let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+ model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
+ let truncated_end_window =
+ model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
writeln!(
prompt,
"{truncated_start_window}{selected_window}{truncated_end_window}"
@@ -89,7 +91,7 @@ fn retrieve_context(
if let Some(max_token_count) = max_token_count {
if model.count_tokens(&prompt)? > max_token_count {
truncated = true;
- prompt = model.truncate(&prompt, max_token_count)?;
+ prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
}
}
}
@@ -148,7 +150,9 @@ impl PromptTemplate for FileContext {
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
- prompt = args.model.truncate(&prompt, max_tokens)?;
+ prompt = args
+ .model
+ .truncate(&prompt, max_tokens, TruncationDirection::End)?;
}
let token_count = args.model.count_tokens(&prompt)?;
@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use anyhow::anyhow;
use std::fmt::Write;
@@ -85,7 +85,11 @@ impl PromptTemplate for GenerateInlineContent {
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
- prompt = args.model.truncate(&prompt, max_tokens)?;
+ prompt = args.model.truncate(
+ &prompt,
+ max_tokens,
+ crate::models::TruncationDirection::End,
+ )?;
}
let token_count = args.model.count_tokens(&prompt)?;
@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use std::fmt::Write;
pub struct EngineerPreamble {}
@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptTemplate};
use std::fmt::Write;
use std::{ops::Range, path::PathBuf};
@@ -0,0 +1 @@
+pub mod open_ai;
@@ -0,0 +1,298 @@
+use anyhow::{anyhow, Result};
+use futures::{
+ future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
+ Stream, StreamExt,
+};
+use gpui::{executor::Background, AppContext};
+use isahc::{http::StatusCode, Request, RequestExt};
+use parking_lot::RwLock;
+use serde::{Deserialize, Serialize};
+use std::{
+ env,
+ fmt::{self, Display},
+ io,
+ sync::Arc,
+};
+use util::ResultExt;
+
+use crate::{
+ auth::{CredentialProvider, ProviderCredential},
+ completion::{CompletionProvider, CompletionRequest},
+ models::LanguageModel,
+};
+
+use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+ User,
+ Assistant,
+ System,
+}
+
+impl Role {
+ pub fn cycle(&mut self) {
+ *self = match self {
+ Role::User => Role::Assistant,
+ Role::Assistant => Role::System,
+ Role::System => Role::User,
+ }
+ }
+}
+
+impl Display for Role {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Role::User => write!(f, "User"),
+ Role::Assistant => write!(f, "Assistant"),
+ Role::System => write!(f, "System"),
+ }
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct RequestMessage {
+ pub role: Role,
+ pub content: String,
+}
+
+#[derive(Debug, Default, Serialize)]
+pub struct OpenAIRequest {
+ pub model: String,
+ pub messages: Vec<RequestMessage>,
+ pub stream: bool,
+ pub stop: Vec<String>,
+ pub temperature: f32,
+}
+
+impl CompletionRequest for OpenAIRequest {
+ fn data(&self) -> serde_json::Result<String> {
+ serde_json::to_string(self)
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ResponseMessage {
+ pub role: Option<Role>,
+ pub content: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIUsage {
+ pub prompt_tokens: u32,
+ pub completion_tokens: u32,
+ pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ChatChoiceDelta {
+ pub index: u32,
+ pub delta: ResponseMessage,
+ pub finish_reason: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIResponseStreamEvent {
+ pub id: Option<String>,
+ pub object: String,
+ pub created: u32,
+ pub model: String,
+ pub choices: Vec<ChatChoiceDelta>,
+ pub usage: Option<OpenAIUsage>,
+}
+
+pub async fn stream_completion(
+ credential: ProviderCredential,
+ executor: Arc<Background>,
+ request: Box<dyn CompletionRequest>,
+) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+ let api_key = match credential {
+ ProviderCredential::Credentials { api_key } => api_key,
+ _ => {
+ return Err(anyhow!("no credentials provider for completion"));
+ }
+ };
+
+ let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+
+ let json_data = request.data()?;
+ let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(json_data)?
+ .send_async()
+ .await?;
+
+ let status = response.status();
+ if status == StatusCode::OK {
+ executor
+ .spawn(async move {
+ let mut lines = BufReader::new(response.body_mut()).lines();
+
+ fn parse_line(
+ line: Result<String, io::Error>,
+ ) -> Result<Option<OpenAIResponseStreamEvent>> {
+ if let Some(data) = line?.strip_prefix("data: ") {
+ let event = serde_json::from_str(&data)?;
+ Ok(Some(event))
+ } else {
+ Ok(None)
+ }
+ }
+
+ while let Some(line) = lines.next().await {
+ if let Some(event) = parse_line(line).transpose() {
+ let done = event.as_ref().map_or(false, |event| {
+ event
+ .choices
+ .last()
+ .map_or(false, |choice| choice.finish_reason.is_some())
+ });
+ if tx.unbounded_send(event).is_err() {
+ break;
+ }
+
+ if done {
+ break;
+ }
+ }
+ }
+
+ anyhow::Ok(())
+ })
+ .detach();
+
+ Ok(rx)
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ #[derive(Deserialize)]
+ struct OpenAIResponse {
+ error: OpenAIError,
+ }
+
+ #[derive(Deserialize)]
+ struct OpenAIError {
+ message: String,
+ }
+
+ match serde_json::from_str::<OpenAIResponse>(&body) {
+ Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+ "Failed to connect to OpenAI API: {}",
+ response.error.message,
+ )),
+
+ _ => Err(anyhow!(
+ "Failed to connect to OpenAI API: {} {}",
+ response.status(),
+ body,
+ )),
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct OpenAICompletionProvider {
+ model: OpenAILanguageModel,
+ credential: Arc<RwLock<ProviderCredential>>,
+ executor: Arc<Background>,
+}
+
+impl OpenAICompletionProvider {
+ pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
+ let model = OpenAILanguageModel::load(model_name);
+ let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
+ Self {
+ model,
+ credential,
+ executor,
+ }
+ }
+}
+
+impl CredentialProvider for OpenAICompletionProvider {
+ fn has_credentials(&self) -> bool {
+ match *self.credential.read() {
+ ProviderCredential::Credentials { .. } => true,
+ _ => false,
+ }
+ }
+ fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+ let mut credential = self.credential.write();
+ match *credential {
+ ProviderCredential::Credentials { .. } => {
+ return credential.clone();
+ }
+ _ => {
+ if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ *credential = ProviderCredential::Credentials { api_key };
+ } else if let Some((_, api_key)) = cx
+ .platform()
+ .read_credentials(OPENAI_API_URL)
+ .log_err()
+ .flatten()
+ {
+ if let Some(api_key) = String::from_utf8(api_key).log_err() {
+ *credential = ProviderCredential::Credentials { api_key };
+ }
+ } else {
+ };
+ }
+ }
+
+ credential.clone()
+ }
+
+ fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+ match credential.clone() {
+ ProviderCredential::Credentials { api_key } => {
+ cx.platform()
+ .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+ .log_err();
+ }
+ _ => {}
+ }
+
+ *self.credential.write() = credential;
+ }
+ fn delete_credentials(&self, cx: &AppContext) {
+ cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+ *self.credential.write() = ProviderCredential::NoCredentials;
+ }
+}
+
+impl CompletionProvider for OpenAICompletionProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+ model
+ }
+ fn complete(
+ &self,
+ prompt: Box<dyn CompletionRequest>,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ // Currently the CompletionRequest for OpenAI includes a 'model' parameter.
+ // This means that the model is determined by the CompletionRequest and not the CompletionProvider,
+ // which is currently model-based, due to the language model.
+ // At some point in the future we should rectify this.
+ let credential = self.credential.read().clone();
+ let request = stream_completion(credential, self.executor.clone(), prompt);
+ async move {
+ let response = request.await?;
+ let stream = response
+ .filter_map(|response| async move {
+ match response {
+ Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+ Err(error) => Some(Err(error)),
+ }
+ })
+ .boxed();
+ Ok(stream)
+ }
+ .boxed()
+ }
+ fn box_clone(&self) -> Box<dyn CompletionProvider> {
+ Box::new((*self).clone())
+ }
+}
@@ -0,0 +1,306 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui::executor::Background;
+use gpui::{serde_json, AppContext};
+use isahc::http::StatusCode;
+use isahc::prelude::Configurable;
+use isahc::{AsyncBody, Response};
+use lazy_static::lazy_static;
+use parking_lot::{Mutex, RwLock};
+use parse_duration::parse;
+use postage::watch;
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::ops::Add;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tiktoken_rs::{cl100k_base, CoreBPE};
+use util::http::{HttpClient, Request};
+use util::ResultExt;
+
+use crate::auth::{CredentialProvider, ProviderCredential};
+use crate::embedding::{Embedding, EmbeddingProvider};
+use crate::models::LanguageModel;
+use crate::providers::open_ai::OpenAILanguageModel;
+
+use crate::providers::open_ai::OPENAI_API_URL;
+
+lazy_static! {
+ static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
+}
+
+#[derive(Clone)]
+pub struct OpenAIEmbeddingProvider {
+ model: OpenAILanguageModel,
+ credential: Arc<RwLock<ProviderCredential>>,
+ pub client: Arc<dyn HttpClient>,
+ pub executor: Arc<Background>,
+ rate_limit_count_rx: watch::Receiver<Option<Instant>>,
+ rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
+}
+
+#[derive(Serialize)]
+struct OpenAIEmbeddingRequest<'a> {
+ model: &'static str,
+ input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingResponse {
+ data: Vec<OpenAIEmbedding>,
+ usage: OpenAIEmbeddingUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIEmbedding {
+ embedding: Vec<f32>,
+ index: usize,
+ object: String,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingUsage {
+ prompt_tokens: usize,
+ total_tokens: usize,
+}
+
+impl OpenAIEmbeddingProvider {
+ pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+ let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
+ let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
+
+ let model = OpenAILanguageModel::load("text-embedding-ada-002");
+ let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
+
+ OpenAIEmbeddingProvider {
+ model,
+ credential,
+ client,
+ executor,
+ rate_limit_count_rx,
+ rate_limit_count_tx,
+ }
+ }
+
+ fn get_api_key(&self) -> Result<String> {
+ match self.credential.read().clone() {
+ ProviderCredential::Credentials { api_key } => Ok(api_key),
+ _ => Err(anyhow!("api credentials not provided")),
+ }
+ }
+
+ fn resolve_rate_limit(&self) {
+ let reset_time = *self.rate_limit_count_tx.lock().borrow();
+
+ if let Some(reset_time) = reset_time {
+ if Instant::now() >= reset_time {
+ *self.rate_limit_count_tx.lock().borrow_mut() = None
+ }
+ }
+
+ log::trace!(
+ "resolving reset time: {:?}",
+ *self.rate_limit_count_tx.lock().borrow()
+ );
+ }
+
+ fn update_reset_time(&self, reset_time: Instant) {
+ let original_time = *self.rate_limit_count_tx.lock().borrow();
+
+ let updated_time = if let Some(original_time) = original_time {
+ if reset_time < original_time {
+ Some(reset_time)
+ } else {
+ Some(original_time)
+ }
+ } else {
+ Some(reset_time)
+ };
+
+ log::trace!("updating rate limit time: {:?}", updated_time);
+
+ *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
+ }
+ async fn send_request(
+ &self,
+ api_key: &str,
+ spans: Vec<&str>,
+ request_timeout: u64,
+ ) -> Result<Response<AsyncBody>> {
+ let request = Request::post("https://api.openai.com/v1/embeddings")
+ .redirect_policy(isahc::config::RedirectPolicy::Follow)
+ .timeout(Duration::from_secs(request_timeout))
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(
+ serde_json::to_string(&OpenAIEmbeddingRequest {
+ input: spans.clone(),
+ model: "text-embedding-ada-002",
+ })
+ .unwrap()
+ .into(),
+ )?;
+
+ Ok(self.client.send(request).await?)
+ }
+}
+
+impl CredentialProvider for OpenAIEmbeddingProvider {
+ fn has_credentials(&self) -> bool {
+ match *self.credential.read() {
+ ProviderCredential::Credentials { .. } => true,
+ _ => false,
+ }
+ }
+ fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+ let mut credential = self.credential.write();
+ match *credential {
+ ProviderCredential::Credentials { .. } => {
+ return credential.clone();
+ }
+ _ => {
+ if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ *credential = ProviderCredential::Credentials { api_key };
+ } else if let Some((_, api_key)) = cx
+ .platform()
+ .read_credentials(OPENAI_API_URL)
+ .log_err()
+ .flatten()
+ {
+ if let Some(api_key) = String::from_utf8(api_key).log_err() {
+ *credential = ProviderCredential::Credentials { api_key };
+ }
+ } else {
+ };
+ }
+ }
+
+ credential.clone()
+ }
+
+ fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+ match credential.clone() {
+ ProviderCredential::Credentials { api_key } => {
+ cx.platform()
+ .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+ .log_err();
+ }
+ _ => {}
+ }
+
+ *self.credential.write() = credential;
+ }
+ fn delete_credentials(&self, cx: &AppContext) {
+ cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+ *self.credential.write() = ProviderCredential::NoCredentials;
+ }
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddingProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+ model
+ }
+
+ fn max_tokens_per_batch(&self) -> usize {
+ 50000
+ }
+
+ fn rate_limit_expiration(&self) -> Option<Instant> {
+ *self.rate_limit_count_rx.borrow()
+ }
+
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+ const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
+ const MAX_RETRIES: usize = 4;
+
+ let api_key = self.get_api_key()?;
+
+ let mut request_number = 0;
+ let mut rate_limiting = false;
+ let mut request_timeout: u64 = 15;
+ let mut response: Response<AsyncBody>;
+ while request_number < MAX_RETRIES {
+ response = self
+ .send_request(
+ &api_key,
+ spans.iter().map(|x| &**x).collect(),
+ request_timeout,
+ )
+ .await?;
+
+ request_number += 1;
+
+ match response.status() {
+ StatusCode::REQUEST_TIMEOUT => {
+ request_timeout += 5;
+ }
+ StatusCode::OK => {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+ log::trace!(
+ "openai embedding completed. tokens: {:?}",
+ response.usage.total_tokens
+ );
+
+ // If we complete a request successfully that was previously rate_limited
+ // resolve the rate limit
+ if rate_limiting {
+ self.resolve_rate_limit()
+ }
+
+ return Ok(response
+ .data
+ .into_iter()
+ .map(|embedding| Embedding::from(embedding.embedding))
+ .collect());
+ }
+ StatusCode::TOO_MANY_REQUESTS => {
+ rate_limiting = true;
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ let delay_duration = {
+ let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+ if let Some(time_to_reset) =
+ response.headers().get("x-ratelimit-reset-tokens")
+ {
+ if let Ok(time_str) = time_to_reset.to_str() {
+ parse(time_str).unwrap_or(delay)
+ } else {
+ delay
+ }
+ } else {
+ delay
+ }
+ };
+
+ // If we've previously rate limited, increment the duration but not the count
+ let reset_time = Instant::now().add(delay_duration);
+ self.update_reset_time(reset_time);
+
+ log::trace!(
+ "openai rate limiting: waiting {:?} until lifted",
+ &delay_duration
+ );
+
+ self.executor.timer(delay_duration).await;
+ }
+ _ => {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ return Err(anyhow!(
+ "open ai bad request: {:?} {:?}",
+ &response.status(),
+ body
+ ));
+ }
+ }
+ }
+ Err(anyhow!("openai max retries"))
+ }
+}
@@ -0,0 +1,9 @@
+pub mod completion;
+pub mod embedding;
+pub mod model;
+
+pub use completion::*;
+pub use embedding::*;
+pub use model::OpenAILanguageModel;
+
+pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
@@ -0,0 +1,57 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+use crate::models::{LanguageModel, TruncationDirection};
+
+#[derive(Clone)]
+pub struct OpenAILanguageModel {
+ name: String,
+ bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+ pub fn load(model_name: &str) -> Self {
+ let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+ OpenAILanguageModel {
+ name: model_name.to_string(),
+ bpe,
+ }
+ }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+ fn name(&self) -> String {
+ self.name.clone()
+ }
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+ if let Some(bpe) = &self.bpe {
+ anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String> {
+ if let Some(bpe) = &self.bpe {
+ let tokens = bpe.encode_with_special_tokens(content);
+ if tokens.len() > length {
+ match direction {
+ TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+ TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+ }
+ } else {
+ bpe.decode(tokens)
+ }
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn capacity(&self) -> anyhow::Result<usize> {
+ anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+ }
+}
@@ -0,0 +1,11 @@
+pub trait LanguageModel {
+ fn name(&self) -> String;
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String>;
+ fn capacity(&self) -> anyhow::Result<usize>;
+}
@@ -0,0 +1,191 @@
+use std::{
+ sync::atomic::{self, AtomicUsize, Ordering},
+ time::Instant,
+};
+
+use async_trait::async_trait;
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::AppContext;
+use parking_lot::Mutex;
+
+use crate::{
+ auth::{CredentialProvider, ProviderCredential},
+ completion::{CompletionProvider, CompletionRequest},
+ embedding::{Embedding, EmbeddingProvider},
+ models::{LanguageModel, TruncationDirection},
+};
+
+#[derive(Clone)]
+pub struct FakeLanguageModel {
+ pub capacity: usize,
+}
+
+impl LanguageModel for FakeLanguageModel {
+ fn name(&self) -> String {
+ "dummy".to_string()
+ }
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+ anyhow::Ok(content.chars().collect::<Vec<char>>().len())
+ }
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String> {
+ println!("TRYING TO TRUNCATE: {:?}", length.clone());
+
+ if length > self.count_tokens(content)? {
+ println!("NOT TRUNCATING");
+ return anyhow::Ok(content.to_string());
+ }
+
+ anyhow::Ok(match direction {
+ TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
+ .into_iter()
+ .collect::<String>(),
+ TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
+ .into_iter()
+ .collect::<String>(),
+ })
+ }
+ fn capacity(&self) -> anyhow::Result<usize> {
+ anyhow::Ok(self.capacity)
+ }
+}
+
+pub struct FakeEmbeddingProvider {
+ pub embedding_count: AtomicUsize,
+}
+
+impl Clone for FakeEmbeddingProvider {
+ fn clone(&self) -> Self {
+ FakeEmbeddingProvider {
+ embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
+ }
+ }
+}
+
+impl Default for FakeEmbeddingProvider {
+ fn default() -> Self {
+ FakeEmbeddingProvider {
+ embedding_count: AtomicUsize::default(),
+ }
+ }
+}
+
+impl FakeEmbeddingProvider {
+ pub fn embedding_count(&self) -> usize {
+ self.embedding_count.load(atomic::Ordering::SeqCst)
+ }
+
+ pub fn embed_sync(&self, span: &str) -> Embedding {
+ let mut result = vec![1.0; 26];
+ for letter in span.chars() {
+ let letter = letter.to_ascii_lowercase();
+ if letter as u32 >= 'a' as u32 {
+ let ix = (letter as u32) - ('a' as u32);
+ if ix < 26 {
+ result[ix as usize] += 1.0;
+ }
+ }
+ }
+
+ let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
+ for x in &mut result {
+ *x /= norm;
+ }
+
+ result.into()
+ }
+}
+
+impl CredentialProvider for FakeEmbeddingProvider {
+ fn has_credentials(&self) -> bool {
+ true
+ }
+ fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+ ProviderCredential::NotNeeded
+ }
+ fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
+ fn delete_credentials(&self, _cx: &AppContext) {}
+}
+
+#[async_trait]
+impl EmbeddingProvider for FakeEmbeddingProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ Box::new(FakeLanguageModel { capacity: 1000 })
+ }
+ fn max_tokens_per_batch(&self) -> usize {
+ 1000
+ }
+
+ fn rate_limit_expiration(&self) -> Option<Instant> {
+ None
+ }
+
+ async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
+ self.embedding_count
+ .fetch_add(spans.len(), atomic::Ordering::SeqCst);
+
+ anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
+ }
+}
+
+pub struct FakeCompletionProvider {
+ last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
+}
+
+impl Clone for FakeCompletionProvider {
+ fn clone(&self) -> Self {
+ Self {
+ last_completion_tx: Mutex::new(None),
+ }
+ }
+}
+
+impl FakeCompletionProvider {
+ pub fn new() -> Self {
+ Self {
+ last_completion_tx: Mutex::new(None),
+ }
+ }
+
+ pub fn send_completion(&self, completion: impl Into<String>) {
+ let mut tx = self.last_completion_tx.lock();
+ tx.as_mut().unwrap().try_send(completion.into()).unwrap();
+ }
+
+ pub fn finish_completion(&self) {
+ self.last_completion_tx.lock().take().unwrap();
+ }
+}
+
+impl CredentialProvider for FakeCompletionProvider {
+ fn has_credentials(&self) -> bool {
+ true
+ }
+ fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+ ProviderCredential::NotNeeded
+ }
+ fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
+ fn delete_credentials(&self, _cx: &AppContext) {}
+}
+
+impl CompletionProvider for FakeCompletionProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
+ model
+ }
+ fn complete(
+ &self,
+ _prompt: Box<dyn CompletionRequest>,
+ ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
+ let (tx, rx) = mpsc::channel(1);
+ *self.last_completion_tx.lock() = Some(tx);
+ async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
+ }
+ fn box_clone(&self) -> Box<dyn CompletionProvider> {
+ Box::new((*self).clone())
+ }
+}
@@ -0,0 +1,38 @@
+[package]
+name = "ai2"
+version = "0.1.0"
+edition = "2021"
+publish = false
+
+[lib]
+path = "src/ai2.rs"
+doctest = false
+
+[features]
+test-support = []
+
+[dependencies]
+gpui2 = { path = "../gpui2" }
+util = { path = "../util" }
+language2 = { path = "../language2" }
+async-trait.workspace = true
+anyhow.workspace = true
+futures.workspace = true
+lazy_static.workspace = true
+ordered-float.workspace = true
+parking_lot.workspace = true
+isahc.workspace = true
+regex.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+postage.workspace = true
+rand.workspace = true
+log.workspace = true
+parse_duration = "2.1.1"
+tiktoken-rs = "0.5.0"
+matrixmultiply = "0.3.7"
+rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
+bincode = "1.3.3"
+
+[dev-dependencies]
+gpui2 = { path = "../gpui2", features = ["test-support"] }
@@ -0,0 +1,8 @@
+pub mod auth;
+pub mod completion;
+pub mod embedding;
+pub mod models;
+pub mod prompts;
+pub mod providers;
+#[cfg(any(test, feature = "test-support"))]
+pub mod test;
@@ -0,0 +1,17 @@
+use async_trait::async_trait;
+use gpui2::AppContext;
+
+#[derive(Clone, Debug)]
+pub enum ProviderCredential {
+ Credentials { api_key: String },
+ NoCredentials,
+ NotNeeded,
+}
+
+#[async_trait]
+pub trait CredentialProvider: Send + Sync {
+ fn has_credentials(&self) -> bool;
+ async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential;
+ async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential);
+ async fn delete_credentials(&self, cx: &mut AppContext);
+}
@@ -0,0 +1,23 @@
+use anyhow::Result;
+use futures::{future::BoxFuture, stream::BoxStream};
+
+use crate::{auth::CredentialProvider, models::LanguageModel};
+
+pub trait CompletionRequest: Send + Sync {
+ fn data(&self) -> serde_json::Result<String>;
+}
+
+pub trait CompletionProvider: CredentialProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel>;
+ fn complete(
+ &self,
+ prompt: Box<dyn CompletionRequest>,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+ fn box_clone(&self) -> Box<dyn CompletionProvider>;
+}
+
+impl Clone for Box<dyn CompletionProvider> {
+ fn clone(&self) -> Box<dyn CompletionProvider> {
+ self.box_clone()
+ }
+}
@@ -0,0 +1,123 @@
+use std::time::Instant;
+
+use anyhow::Result;
+use async_trait::async_trait;
+use ordered_float::OrderedFloat;
+use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
+use rusqlite::ToSql;
+
+use crate::auth::CredentialProvider;
+use crate::models::LanguageModel;
+
+#[derive(Debug, PartialEq, Clone)]
+pub struct Embedding(pub Vec<f32>);
+
+// This is needed for semantic index functionality.
+// Unfortunately it has to live wherever the "Embedding" struct is created.
+// Keeping it here, though, introduces a 'rusqlite' dependency into AI,
+// which is less than ideal.
+impl FromSql for Embedding {
+ fn column_result(value: ValueRef) -> FromSqlResult<Self> {
+ let bytes = value.as_blob()?;
+ let embedding: Result<Vec<f32>, Box<bincode::ErrorKind>> = bincode::deserialize(bytes);
+ if embedding.is_err() {
+ return Err(rusqlite::types::FromSqlError::Other(embedding.unwrap_err()));
+ }
+ Ok(Embedding(embedding.unwrap()))
+ }
+}
+
+impl ToSql for Embedding {
+ fn to_sql(&self) -> rusqlite::Result<ToSqlOutput> {
+ let bytes = bincode::serialize(&self.0)
+ .map_err(|err| rusqlite::Error::ToSqlConversionFailure(Box::new(err)))?;
+ Ok(ToSqlOutput::Owned(rusqlite::types::Value::Blob(bytes)))
+ }
+}
+impl From<Vec<f32>> for Embedding {
+ fn from(value: Vec<f32>) -> Self {
+ Embedding(value)
+ }
+}
+
+impl Embedding {
+ pub fn similarity(&self, other: &Self) -> OrderedFloat<f32> {
+ let len = self.0.len();
+ assert_eq!(len, other.0.len());
+
+ let mut result = 0.0;
+ unsafe {
+ matrixmultiply::sgemm(
+ 1,
+ len,
+ 1,
+ 1.0,
+ self.0.as_ptr(),
+ len as isize,
+ 1,
+ other.0.as_ptr(),
+ 1,
+ len as isize,
+ 0.0,
+ &mut result as *mut f32,
+ 1,
+ 1,
+ );
+ }
+ OrderedFloat(result)
+ }
+}
+
+#[async_trait]
+pub trait EmbeddingProvider: CredentialProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel>;
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
+ fn max_tokens_per_batch(&self) -> usize;
+ fn rate_limit_expiration(&self) -> Option<Instant>;
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use rand::prelude::*;
+
+ #[gpui2::test]
+ fn test_similarity(mut rng: StdRng) {
+ assert_eq!(
+ Embedding::from(vec![1., 0., 0., 0., 0.])
+ .similarity(&Embedding::from(vec![0., 1., 0., 0., 0.])),
+ 0.
+ );
+ assert_eq!(
+ Embedding::from(vec![2., 0., 0., 0., 0.])
+ .similarity(&Embedding::from(vec![3., 1., 0., 0., 0.])),
+ 6.
+ );
+
+ for _ in 0..100 {
+ let size = 1536;
+ let mut a = vec![0.; size];
+ let mut b = vec![0.; size];
+ for (a, b) in a.iter_mut().zip(b.iter_mut()) {
+ *a = rng.gen();
+ *b = rng.gen();
+ }
+ let a = Embedding::from(a);
+ let b = Embedding::from(b);
+
+ assert_eq!(
+ round_to_decimals(a.similarity(&b), 1),
+ round_to_decimals(reference_dot(&a.0, &b.0), 1)
+ );
+ }
+
+ fn round_to_decimals(n: OrderedFloat<f32>, decimal_places: i32) -> f32 {
+ let factor = (10.0 as f32).powi(decimal_places);
+ (n * factor).round() / factor
+ }
+
+ fn reference_dot(a: &[f32], b: &[f32]) -> OrderedFloat<f32> {
+ OrderedFloat(a.iter().zip(b.iter()).map(|(a, b)| a * b).sum())
+ }
+ }
+}
@@ -0,0 +1,16 @@
+pub enum TruncationDirection {
+ Start,
+ End,
+}
+
+pub trait LanguageModel {
+ fn name(&self) -> String;
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String>;
+ fn capacity(&self) -> anyhow::Result<usize>;
+}
@@ -0,0 +1,330 @@
+use std::cmp::Reverse;
+use std::ops::Range;
+use std::sync::Arc;
+
+use language2::BufferSnapshot;
+use util::ResultExt;
+
+use crate::models::LanguageModel;
+use crate::prompts::repository_context::PromptCodeSnippet;
+
+pub(crate) enum PromptFileType {
+ Text,
+ Code,
+}
+
+// TODO: Set this up to manage defaults well
+pub struct PromptArguments {
+ pub model: Arc<dyn LanguageModel>,
+ pub user_prompt: Option<String>,
+ pub language_name: Option<String>,
+ pub project_name: Option<String>,
+ pub snippets: Vec<PromptCodeSnippet>,
+ pub reserved_tokens: usize,
+ pub buffer: Option<BufferSnapshot>,
+ pub selected_range: Option<Range<usize>>,
+}
+
+impl PromptArguments {
+ pub(crate) fn get_file_type(&self) -> PromptFileType {
+ if self
+ .language_name
+ .as_ref()
+ .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
+ .unwrap_or(true)
+ {
+ PromptFileType::Code
+ } else {
+ PromptFileType::Text
+ }
+ }
+}
+
+pub trait PromptTemplate {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)>;
+}
+
+#[repr(i8)]
+#[derive(PartialEq, Eq, Ord)]
+pub enum PromptPriority {
+ Mandatory, // Ignores truncation
+ Ordered { order: usize }, // Truncates based on priority
+}
+
+impl PartialOrd for PromptPriority {
+ fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+ match (self, other) {
+ (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
+ (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
+ (Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
+ (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
+ }
+ }
+}
+
+pub struct PromptChain {
+ args: PromptArguments,
+ templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+}
+
+impl PromptChain {
+ pub fn new(
+ args: PromptArguments,
+ templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+ ) -> Self {
+ PromptChain { args, templates }
+ }
+
+ pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
+ // Argsort based on Prompt Priority
+ let seperator = "\n";
+ let seperator_tokens = self.args.model.count_tokens(seperator)?;
+ let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
+ sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
+
+ // If Truncate
+ let mut tokens_outstanding = if truncate {
+ Some(self.args.model.capacity()? - self.args.reserved_tokens)
+ } else {
+ None
+ };
+
+ let mut prompts = vec!["".to_string(); sorted_indices.len()];
+ for idx in sorted_indices {
+ let (_, template) = &self.templates[idx];
+
+ if let Some((template_prompt, prompt_token_count)) =
+ template.generate(&self.args, tokens_outstanding).log_err()
+ {
+ if template_prompt != "" {
+ prompts[idx] = template_prompt;
+
+ if let Some(remaining_tokens) = tokens_outstanding {
+ let new_tokens = prompt_token_count + seperator_tokens;
+ tokens_outstanding = if remaining_tokens > new_tokens {
+ Some(remaining_tokens - new_tokens)
+ } else {
+ Some(0)
+ };
+ }
+ }
+ }
+ }
+
+ prompts.retain(|x| x != "");
+
+ let full_prompt = prompts.join(seperator);
+ let total_token_count = self.args.model.count_tokens(&full_prompt)?;
+ anyhow::Ok((prompts.join(seperator), total_token_count))
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+ use crate::models::TruncationDirection;
+ use crate::test::FakeLanguageModel;
+
+ use super::*;
+
+ #[test]
+ pub fn test_prompt_chain() {
+ struct TestPromptTemplate {}
+ impl PromptTemplate for TestPromptTemplate {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let mut content = "This is a test prompt template".to_string();
+
+ let mut token_count = args.model.count_tokens(&content)?;
+ if let Some(max_token_length) = max_token_length {
+ if token_count > max_token_length {
+ content = args.model.truncate(
+ &content,
+ max_token_length,
+ TruncationDirection::End,
+ )?;
+ token_count = max_token_length;
+ }
+ }
+
+ anyhow::Ok((content, token_count))
+ }
+ }
+
+ struct TestLowPriorityTemplate {}
+ impl PromptTemplate for TestLowPriorityTemplate {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let mut content = "This is a low priority test prompt template".to_string();
+
+ let mut token_count = args.model.count_tokens(&content)?;
+ if let Some(max_token_length) = max_token_length {
+ if token_count > max_token_length {
+ content = args.model.truncate(
+ &content,
+ max_token_length,
+ TruncationDirection::End,
+ )?;
+ token_count = max_token_length;
+ }
+ }
+
+ anyhow::Ok((content, token_count))
+ }
+ }
+
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens: 0,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(false).unwrap();
+
+ assert_eq!(
+ prompt,
+ "This is a test prompt template\nThis is a low priority test prompt template"
+ .to_string()
+ );
+
+ assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
+
+ // Testing with Truncation Off
+ // Should ignore capacity and return all prompts
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens: 0,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(false).unwrap();
+
+ assert_eq!(
+ prompt,
+ "This is a test prompt template\nThis is a low priority test prompt template"
+ .to_string()
+ );
+
+ assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
+
+ // Testing with Truncation On
+ // Should truncate the combined prompt down to the model's capacity
+ let capacity = 20;
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens: 0,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 2 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(true).unwrap();
+
+ assert_eq!(prompt, "This is a test promp".to_string());
+ assert_eq!(token_count, capacity);
+
+ // Change Ordering of Prompts Based on Priority
+ let capacity = 120;
+ let reserved_tokens = 10;
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Mandatory,
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(true).unwrap();
+
+ assert_eq!(
+ prompt,
+ "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
+ .to_string()
+ );
+ assert_eq!(token_count, capacity - reserved_tokens);
+ }
+}
@@ -0,0 +1,164 @@
+use anyhow::anyhow;
+use language2::BufferSnapshot;
+use language2::ToOffset;
+
+use crate::models::LanguageModel;
+use crate::models::TruncationDirection;
+use crate::prompts::base::PromptArguments;
+use crate::prompts::base::PromptTemplate;
+use std::fmt::Write;
+use std::ops::Range;
+use std::sync::Arc;
+
+fn retrieve_context(
+ buffer: &BufferSnapshot,
+ selected_range: &Option<Range<usize>>,
+ model: Arc<dyn LanguageModel>,
+ max_token_count: Option<usize>,
+) -> anyhow::Result<(String, usize, bool)> {
+ let mut prompt = String::new();
+ let mut truncated = false;
+ if let Some(selected_range) = selected_range {
+ let start = selected_range.start.to_offset(buffer);
+ let end = selected_range.end.to_offset(buffer);
+
+ let start_window = buffer.text_for_range(0..start).collect::<String>();
+
+ let mut selected_window = String::new();
+ if start == end {
+ write!(selected_window, "<|START|>").unwrap();
+ } else {
+ write!(selected_window, "<|START|").unwrap();
+ }
+
+ write!(
+ selected_window,
+ "{}",
+ buffer.text_for_range(start..end).collect::<String>()
+ )
+ .unwrap();
+
+ if start != end {
+ write!(selected_window, "|END|>").unwrap();
+ }
+
+ let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
+
+ if let Some(max_token_count) = max_token_count {
+ let selected_tokens = model.count_tokens(&selected_window)?;
+ if selected_tokens > max_token_count {
+ return Err(anyhow!(
+ "selected range is greater than model context window, truncation not possible"
+ ));
+ };
+
+ let mut remaining_tokens = max_token_count - selected_tokens;
+ let start_window_tokens = model.count_tokens(&start_window)?;
+ let end_window_tokens = model.count_tokens(&end_window)?;
+ let outside_tokens = start_window_tokens + end_window_tokens;
+ if outside_tokens > remaining_tokens {
+ let (start_goal_tokens, end_goal_tokens) =
+ if start_window_tokens < end_window_tokens {
+ let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
+ remaining_tokens -= start_goal_tokens;
+ let end_goal_tokens = remaining_tokens.min(end_window_tokens);
+ (start_goal_tokens, end_goal_tokens)
+ } else {
+ let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
+ remaining_tokens -= end_goal_tokens;
+ let start_goal_tokens = remaining_tokens.min(start_window_tokens);
+ (start_goal_tokens, end_goal_tokens)
+ };
+
+ let truncated_start_window =
+ model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
+ let truncated_end_window =
+ model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
+ writeln!(
+ prompt,
+ "{truncated_start_window}{selected_window}{truncated_end_window}"
+ )
+ .unwrap();
+ truncated = true;
+ } else {
+ writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
+ }
+ } else {
+ // If we don't have a selected range, include the entire file.
+ writeln!(prompt, "{}", &buffer.text()).unwrap();
+
+ // Dumb truncation strategy
+ if let Some(max_token_count) = max_token_count {
+ if model.count_tokens(&prompt)? > max_token_count {
+ truncated = true;
+ prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
+ }
+ }
+ }
+ }
+
+ let token_count = model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, token_count, truncated))
+}
+
+pub struct FileContext {}
+
+impl PromptTemplate for FileContext {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ if let Some(buffer) = &args.buffer {
+ let mut prompt = String::new();
+ // Add Initial Preamble
+ // TODO: Do we want to add the path in here?
+ writeln!(
+ prompt,
+ "The file you are currently working on has the following content:"
+ )
+ .unwrap();
+
+ let language_name = args
+ .language_name
+ .clone()
+ .unwrap_or("".to_string())
+ .to_lowercase();
+
+ let (context, _, truncated) = retrieve_context(
+ buffer,
+ &args.selected_range,
+ args.model.clone(),
+ max_token_length,
+ )?;
+ writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
+
+ if truncated {
+ writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
+ }
+
+ if let Some(selected_range) = &args.selected_range {
+ let start = selected_range.start.to_offset(buffer);
+ let end = selected_range.end.to_offset(buffer);
+
+ if start == end {
+ writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
+ } else {
+ writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
+ }
+ }
+
+ // Really dumb truncation strategy
+ if let Some(max_tokens) = max_token_length {
+ prompt = args
+ .model
+ .truncate(&prompt, max_tokens, TruncationDirection::End)?;
+ }
+
+ let token_count = args.model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, token_count))
+ } else {
+ Err(anyhow!("no buffer provided to retrieve file context from"))
+ }
+ }
+}
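
For reference, a small illustrative sketch (not part of the change) of the window-budget arithmetic `retrieve_context` uses when the text around a selection does not fit: the smaller of the two windows is given at most half of the remaining budget, and whatever is left goes to the other side. The token counts below are hypothetical.

```rust
fn main() {
    let max_token_count = 100usize; // budget passed in as `max_token_count`
    let selected_tokens = 20usize; // tokens in the selected range
    let (start_window_tokens, end_window_tokens) = (300usize, 50usize);

    let mut remaining_tokens = max_token_count - selected_tokens; // 80
    // start_window_tokens >= end_window_tokens, so the end window is budgeted first.
    let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens); // 40
    remaining_tokens -= end_goal_tokens; // 40
    let start_goal_tokens = remaining_tokens.min(start_window_tokens); // 40

    assert_eq!((start_goal_tokens, end_goal_tokens), (40, 40));
}
```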
@@ -0,0 +1,99 @@
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
+use anyhow::anyhow;
+use std::fmt::Write;
+
+pub fn capitalize(s: &str) -> String {
+ let mut c = s.chars();
+ match c.next() {
+ None => String::new(),
+ Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
+ }
+}
+
+pub struct GenerateInlineContent {}
+
+impl PromptTemplate for GenerateInlineContent {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let Some(user_prompt) = &args.user_prompt else {
+ return Err(anyhow!("user prompt not provided"));
+ };
+
+ let file_type = args.get_file_type();
+ let content_type = match &file_type {
+ PromptFileType::Code => "code",
+ PromptFileType::Text => "text",
+ };
+
+ let mut prompt = String::new();
+
+ if let Some(selected_range) = &args.selected_range {
+ if selected_range.start == selected_range.end {
+ writeln!(
+ prompt,
+ "Assume the cursor is located where the `<|START|>` span is."
+ )
+ .unwrap();
+ writeln!(
+ prompt,
+ "{} can't be replaced, so assume your answer will be inserted at the cursor.",
+ capitalize(content_type)
+ )
+ .unwrap();
+ writeln!(
+ prompt,
+ "Generate {content_type} based on the users prompt: {user_prompt}",
+ )
+ .unwrap();
+ } else {
+ writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
+ writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
+ writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
+ }
+ } else {
+ writeln!(
+ prompt,
+ "Generate {content_type} based on the users prompt: {user_prompt}"
+ )
+ .unwrap();
+ }
+
+ if let Some(language_name) = &args.language_name {
+ writeln!(
+ prompt,
+ "Your answer MUST always and only be valid {}.",
+ language_name
+ )
+ .unwrap();
+ }
+ writeln!(prompt, "Never make remarks about the output.").unwrap();
+ writeln!(
+ prompt,
+ "Do not return anything else, except the generated {content_type}."
+ )
+ .unwrap();
+
+ match file_type {
+ PromptFileType::Code => {
+ // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
+ }
+ _ => {}
+ }
+
+ // Really dumb truncation strategy
+ if let Some(max_tokens) = max_token_length {
+ prompt = args.model.truncate(
+ &prompt,
+ max_tokens,
+ crate::models::TruncationDirection::End,
+ )?;
+ }
+
+ let token_count = args.model.count_tokens(&prompt)?;
+
+ anyhow::Ok((prompt, token_count))
+ }
+}
@@ -0,0 +1,5 @@
+pub mod base;
+pub mod file_context;
+pub mod generate;
+pub mod preamble;
+pub mod repository_context;
@@ -0,0 +1,52 @@
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
+use std::fmt::Write;
+
+pub struct EngineerPreamble {}
+
+impl PromptTemplate for EngineerPreamble {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let mut prompts = Vec::new();
+
+ match args.get_file_type() {
+ PromptFileType::Code => {
+ prompts.push(format!(
+ "You are an expert {}engineer.",
+ args.language_name
+ .clone()
+ .map(|language| format!("{language} "))
+ .unwrap_or_default()
+ ));
+ }
+ PromptFileType::Text => {
+ prompts.push("You are an expert engineer.".to_string());
+ }
+ }
+
+ if let Some(project_name) = args.project_name.clone() {
+ prompts.push(format!(
+ "You are currently working inside the '{project_name}' project in code editor Zed."
+ ));
+ }
+
+ if let Some(mut remaining_tokens) = max_token_length {
+ let mut prompt = String::new();
+ let mut total_count = 0;
+ for prompt_piece in prompts {
+ let prompt_token_count =
+ args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
+ if remaining_tokens > prompt_token_count {
+ writeln!(prompt, "{prompt_piece}").unwrap();
+ remaining_tokens -= prompt_token_count;
+ total_count += prompt_token_count;
+ }
+ }
+
+ anyhow::Ok((prompt, total_count))
+ } else {
+ let prompt = prompts.join("\n");
+ let token_count = args.model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, token_count))
+ }
+ }
+}
@@ -0,0 +1,98 @@
+use crate::prompts::base::{PromptArguments, PromptTemplate};
+use std::fmt::Write;
+use std::{ops::Range, path::PathBuf};
+
+use gpui2::{AsyncAppContext, Model};
+use language2::{Anchor, Buffer};
+
+#[derive(Clone)]
+pub struct PromptCodeSnippet {
+ path: Option<PathBuf>,
+ language_name: Option<String>,
+ content: String,
+}
+
+impl PromptCodeSnippet {
+ pub fn new(
+ buffer: Model<Buffer>,
+ range: Range<Anchor>,
+ cx: &mut AsyncAppContext,
+ ) -> anyhow::Result<Self> {
+ let (content, language_name, file_path) = buffer.update(cx, |buffer, _| {
+ let snapshot = buffer.snapshot();
+ let content = snapshot.text_for_range(range.clone()).collect::<String>();
+
+ let language_name = buffer
+ .language()
+ .map(|language| language.name().to_string().to_lowercase());
+
+ let file_path = buffer
+ .file()
+ .map(|file| file.path().to_path_buf());
+
+ (content, language_name, file_path)
+ })?;
+
+ anyhow::Ok(PromptCodeSnippet {
+ path: file_path,
+ language_name,
+ content,
+ })
+ }
+}
+
+impl ToString for PromptCodeSnippet {
+ fn to_string(&self) -> String {
+ let path = self
+ .path
+ .as_ref()
+ .map(|path| path.to_string_lossy().to_string())
+ .unwrap_or_default();
+ let language_name = self.language_name.clone().unwrap_or_default();
+ let content = self.content.clone();
+
+ format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
+ }
+}
+
+pub struct RepositoryContext {}
+
+impl PromptTemplate for RepositoryContext {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+ let template = "You are working inside a large repository; here are a few code snippets that may be useful.";
+ let mut prompt = String::new();
+
+ let mut remaining_tokens = max_token_length;
+ let separator_token_length = args.model.count_tokens("\n")?;
+ for snippet in &args.snippets {
+ let mut snippet_prompt = template.to_string();
+ let content = snippet.to_string();
+ writeln!(snippet_prompt, "{content}").unwrap();
+
+ let token_count = args.model.count_tokens(&snippet_prompt)?;
+ if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
+ if let Some(tokens_left) = remaining_tokens {
+ if tokens_left >= token_count {
+ writeln!(prompt, "{snippet_prompt}").unwrap();
+ remaining_tokens = if tokens_left >= (token_count + separator_token_length)
+ {
+ Some(tokens_left - token_count - separator_token_length)
+ } else {
+ Some(0)
+ };
+ }
+ } else {
+ writeln!(prompt, "{snippet_prompt}").unwrap();
+ }
+ }
+ }
+
+ let total_token_count = args.model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, total_token_count))
+ }
+}
@@ -0,0 +1 @@
+pub mod open_ai;
@@ -0,0 +1,306 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::{
+ future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
+ Stream, StreamExt,
+};
+use gpui2::{AppContext, Executor};
+use isahc::{http::StatusCode, Request, RequestExt};
+use parking_lot::RwLock;
+use serde::{Deserialize, Serialize};
+use std::{
+ env,
+ fmt::{self, Display},
+ io,
+ sync::Arc,
+};
+use util::ResultExt;
+
+use crate::{
+ auth::{CredentialProvider, ProviderCredential},
+ completion::{CompletionProvider, CompletionRequest},
+ models::LanguageModel,
+};
+
+use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+ User,
+ Assistant,
+ System,
+}
+
+impl Role {
+ pub fn cycle(&mut self) {
+ *self = match self {
+ Role::User => Role::Assistant,
+ Role::Assistant => Role::System,
+ Role::System => Role::User,
+ }
+ }
+}
+
+impl Display for Role {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Role::User => write!(f, "User"),
+ Role::Assistant => write!(f, "Assistant"),
+ Role::System => write!(f, "System"),
+ }
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct RequestMessage {
+ pub role: Role,
+ pub content: String,
+}
+
+#[derive(Debug, Default, Serialize)]
+pub struct OpenAIRequest {
+ pub model: String,
+ pub messages: Vec<RequestMessage>,
+ pub stream: bool,
+ pub stop: Vec<String>,
+ pub temperature: f32,
+}
+
+impl CompletionRequest for OpenAIRequest {
+ fn data(&self) -> serde_json::Result<String> {
+ serde_json::to_string(self)
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ResponseMessage {
+ pub role: Option<Role>,
+ pub content: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIUsage {
+ pub prompt_tokens: u32,
+ pub completion_tokens: u32,
+ pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ChatChoiceDelta {
+ pub index: u32,
+ pub delta: ResponseMessage,
+ pub finish_reason: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIResponseStreamEvent {
+ pub id: Option<String>,
+ pub object: String,
+ pub created: u32,
+ pub model: String,
+ pub choices: Vec<ChatChoiceDelta>,
+ pub usage: Option<OpenAIUsage>,
+}
+
+pub async fn stream_completion(
+ credential: ProviderCredential,
+ executor: Arc<Executor>,
+ request: Box<dyn CompletionRequest>,
+) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+ let api_key = match credential {
+ ProviderCredential::Credentials { api_key } => api_key,
+ _ => {
+ return Err(anyhow!("no credentials provider for completion"));
+ }
+ };
+
+ let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+
+ let json_data = request.data()?;
+ let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(json_data)?
+ .send_async()
+ .await?;
+
+ let status = response.status();
+ if status == StatusCode::OK {
+ executor
+ .spawn(async move {
+ let mut lines = BufReader::new(response.body_mut()).lines();
+
+ fn parse_line(
+ line: Result<String, io::Error>,
+ ) -> Result<Option<OpenAIResponseStreamEvent>> {
+ if let Some(data) = line?.strip_prefix("data: ") {
+ let event = serde_json::from_str(&data)?;
+ Ok(Some(event))
+ } else {
+ Ok(None)
+ }
+ }
+
+ while let Some(line) = lines.next().await {
+ if let Some(event) = parse_line(line).transpose() {
+ let done = event.as_ref().map_or(false, |event| {
+ event
+ .choices
+ .last()
+ .map_or(false, |choice| choice.finish_reason.is_some())
+ });
+ if tx.unbounded_send(event).is_err() {
+ break;
+ }
+
+ if done {
+ break;
+ }
+ }
+ }
+
+ anyhow::Ok(())
+ })
+ .detach();
+
+ Ok(rx)
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ #[derive(Deserialize)]
+ struct OpenAIResponse {
+ error: OpenAIError,
+ }
+
+ #[derive(Deserialize)]
+ struct OpenAIError {
+ message: String,
+ }
+
+ match serde_json::from_str::<OpenAIResponse>(&body) {
+ Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+ "Failed to connect to OpenAI API: {}",
+ response.error.message,
+ )),
+
+ _ => Err(anyhow!(
+ "Failed to connect to OpenAI API: {} {}",
+ response.status(),
+ body,
+ )),
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct OpenAICompletionProvider {
+ model: OpenAILanguageModel,
+ credential: Arc<RwLock<ProviderCredential>>,
+ executor: Arc<Executor>,
+}
+
+impl OpenAICompletionProvider {
+ pub fn new(model_name: &str, executor: Arc<Executor>) -> Self {
+ let model = OpenAILanguageModel::load(model_name);
+ let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
+ Self {
+ model,
+ credential,
+ executor,
+ }
+ }
+}
+
+#[async_trait]
+impl CredentialProvider for OpenAICompletionProvider {
+ fn has_credentials(&self) -> bool {
+ match *self.credential.read() {
+ ProviderCredential::Credentials { .. } => true,
+ _ => false,
+ }
+ }
+ async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
+ let existing_credential = self.credential.read().clone();
+
+ let retrieved_credential = cx
+ .run_on_main(move |cx| match existing_credential {
+ ProviderCredential::Credentials { .. } => {
+ return existing_credential.clone();
+ }
+ _ => {
+ if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
+ return ProviderCredential::Credentials { api_key };
+ }
+
+ if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
+ {
+ if let Some(api_key) = String::from_utf8(api_key).log_err() {
+ return ProviderCredential::Credentials { api_key };
+ } else {
+ return ProviderCredential::NoCredentials;
+ }
+ } else {
+ return ProviderCredential::NoCredentials;
+ }
+ }
+ })
+ .await;
+
+ *self.credential.write() = retrieved_credential.clone();
+ retrieved_credential
+ }
+
+ async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
+ *self.credential.write() = credential.clone();
+ let credential = credential.clone();
+ cx.run_on_main(move |cx| match credential {
+ ProviderCredential::Credentials { api_key } => {
+ cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+ .log_err();
+ }
+ _ => {}
+ })
+ .await;
+ }
+ async fn delete_credentials(&self, cx: &mut AppContext) {
+ cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
+ .await;
+ *self.credential.write() = ProviderCredential::NoCredentials;
+ }
+}
+
+impl CompletionProvider for OpenAICompletionProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+ model
+ }
+ fn complete(
+ &self,
+ prompt: Box<dyn CompletionRequest>,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ // Currently the CompletionRequest for OpenAI includes a 'model' parameter.
+ // This means the model is determined by the CompletionRequest rather than the
+ // CompletionProvider, even though the provider is constructed around a specific
+ // language model. At some point in the future we should rectify this.
+ let credential = self.credential.read().clone();
+ let request = stream_completion(credential, self.executor.clone(), prompt);
+ async move {
+ let response = request.await?;
+ let stream = response
+ .filter_map(|response| async move {
+ match response {
+ Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+ Err(error) => Some(Err(error)),
+ }
+ })
+ .boxed();
+ Ok(stream)
+ }
+ .boxed()
+ }
+ fn box_clone(&self) -> Box<dyn CompletionProvider> {
+ Box::new((*self).clone())
+ }
+}
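
A minimal sketch (not part of the change) of how a caller is expected to drive `OpenAICompletionProvider`. It assumes an `executor` obtained from the running gpui2 app, that credentials have already been retrieved or saved on the provider, and the `ai` crate paths shown here (the `ai2` variant is analogous).

```rust
use std::sync::Arc;

use anyhow::Result;
use futures::StreamExt;
use gpui2::Executor;

use ai::completion::{CompletionProvider, CompletionRequest};
use ai::providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage, Role};

// Streams a completion and concatenates the text chunks.
async fn stream_reply(executor: Arc<Executor>) -> Result<String> {
    let provider = OpenAICompletionProvider::new("gpt-4", executor);

    let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
        model: "gpt-4".into(),
        messages: vec![RequestMessage {
            role: Role::User,
            content: "Say hello".into(),
        }],
        stream: true,
        stop: vec![],
        temperature: 1.0,
    });

    let mut chunks = provider.complete(request).await?;
    let mut reply = String::new();
    while let Some(chunk) = chunks.next().await {
        reply.push_str(&chunk?);
    }
    Ok(reply)
}
```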
@@ -0,0 +1,313 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui2::Executor;
+use gpui2::{serde_json, AppContext};
+use isahc::http::StatusCode;
+use isahc::prelude::Configurable;
+use isahc::{AsyncBody, Response};
+use lazy_static::lazy_static;
+use parking_lot::{Mutex, RwLock};
+use parse_duration::parse;
+use postage::watch;
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::ops::Add;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tiktoken_rs::{cl100k_base, CoreBPE};
+use util::http::{HttpClient, Request};
+use util::ResultExt;
+
+use crate::auth::{CredentialProvider, ProviderCredential};
+use crate::embedding::{Embedding, EmbeddingProvider};
+use crate::models::LanguageModel;
+use crate::providers::open_ai::OpenAILanguageModel;
+
+use crate::providers::open_ai::OPENAI_API_URL;
+
+lazy_static! {
+ static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
+}
+
+#[derive(Clone)]
+pub struct OpenAIEmbeddingProvider {
+ model: OpenAILanguageModel,
+ credential: Arc<RwLock<ProviderCredential>>,
+ pub client: Arc<dyn HttpClient>,
+ pub executor: Arc<Executor>,
+ rate_limit_count_rx: watch::Receiver<Option<Instant>>,
+ rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
+}
+
+#[derive(Serialize)]
+struct OpenAIEmbeddingRequest<'a> {
+ model: &'static str,
+ input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingResponse {
+ data: Vec<OpenAIEmbedding>,
+ usage: OpenAIEmbeddingUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIEmbedding {
+ embedding: Vec<f32>,
+ index: usize,
+ object: String,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingUsage {
+ prompt_tokens: usize,
+ total_tokens: usize,
+}
+
+impl OpenAIEmbeddingProvider {
+ pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Executor>) -> Self {
+ let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
+ let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
+
+ let model = OpenAILanguageModel::load("text-embedding-ada-002");
+ let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
+
+ OpenAIEmbeddingProvider {
+ model,
+ credential,
+ client,
+ executor,
+ rate_limit_count_rx,
+ rate_limit_count_tx,
+ }
+ }
+
+ fn get_api_key(&self) -> Result<String> {
+ match self.credential.read().clone() {
+ ProviderCredential::Credentials { api_key } => Ok(api_key),
+ _ => Err(anyhow!("api credentials not provided")),
+ }
+ }
+
+ fn resolve_rate_limit(&self) {
+ let reset_time = *self.rate_limit_count_tx.lock().borrow();
+
+ if let Some(reset_time) = reset_time {
+ if Instant::now() >= reset_time {
+ *self.rate_limit_count_tx.lock().borrow_mut() = None
+ }
+ }
+
+ log::trace!(
+ "resolving reset time: {:?}",
+ *self.rate_limit_count_tx.lock().borrow()
+ );
+ }
+
+ fn update_reset_time(&self, reset_time: Instant) {
+ let original_time = *self.rate_limit_count_tx.lock().borrow();
+
+ let updated_time = if let Some(original_time) = original_time {
+ if reset_time < original_time {
+ Some(reset_time)
+ } else {
+ Some(original_time)
+ }
+ } else {
+ Some(reset_time)
+ };
+
+ log::trace!("updating rate limit time: {:?}", updated_time);
+
+ *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
+ }
+ async fn send_request(
+ &self,
+ api_key: &str,
+ spans: Vec<&str>,
+ request_timeout: u64,
+ ) -> Result<Response<AsyncBody>> {
+ let request = Request::post("https://api.openai.com/v1/embeddings")
+ .redirect_policy(isahc::config::RedirectPolicy::Follow)
+ .timeout(Duration::from_secs(request_timeout))
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(
+ serde_json::to_string(&OpenAIEmbeddingRequest {
+ input: spans.clone(),
+ model: "text-embedding-ada-002",
+ })
+ .unwrap()
+ .into(),
+ )?;
+
+ Ok(self.client.send(request).await?)
+ }
+}
+
+#[async_trait]
+impl CredentialProvider for OpenAIEmbeddingProvider {
+ fn has_credentials(&self) -> bool {
+ match *self.credential.read() {
+ ProviderCredential::Credentials { .. } => true,
+ _ => false,
+ }
+ }
+ async fn retrieve_credentials(&self, cx: &mut AppContext) -> ProviderCredential {
+ let existing_credential = self.credential.read().clone();
+
+ let retrieved_credential = cx
+ .run_on_main(move |cx| match existing_credential {
+ ProviderCredential::Credentials { .. } => {
+ return existing_credential.clone();
+ }
+ _ => {
+ if let Some(api_key) = env::var("OPENAI_API_KEY").log_err() {
+ return ProviderCredential::Credentials { api_key };
+ }
+
+ if let Some(Some((_, api_key))) = cx.read_credentials(OPENAI_API_URL).log_err()
+ {
+ if let Some(api_key) = String::from_utf8(api_key).log_err() {
+ return ProviderCredential::Credentials { api_key };
+ } else {
+ return ProviderCredential::NoCredentials;
+ }
+ } else {
+ return ProviderCredential::NoCredentials;
+ }
+ }
+ })
+ .await;
+
+ *self.credential.write() = retrieved_credential.clone();
+ retrieved_credential
+ }
+
+ async fn save_credentials(&self, cx: &mut AppContext, credential: ProviderCredential) {
+ *self.credential.write() = credential.clone();
+ let credential = credential.clone();
+ cx.run_on_main(move |cx| match credential {
+ ProviderCredential::Credentials { api_key } => {
+ cx.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+ .log_err();
+ }
+ _ => {}
+ })
+ .await;
+ }
+ async fn delete_credentials(&self, cx: &mut AppContext) {
+ cx.run_on_main(move |cx| cx.delete_credentials(OPENAI_API_URL).log_err())
+ .await;
+ *self.credential.write() = ProviderCredential::NoCredentials;
+ }
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddingProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+ model
+ }
+
+ fn max_tokens_per_batch(&self) -> usize {
+ 50000
+ }
+
+ fn rate_limit_expiration(&self) -> Option<Instant> {
+ *self.rate_limit_count_rx.borrow()
+ }
+
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+ const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
+ const MAX_RETRIES: usize = 4;
+
+ let api_key = self.get_api_key()?;
+
+ let mut request_number = 0;
+ let mut rate_limiting = false;
+ let mut request_timeout: u64 = 15;
+ let mut response: Response<AsyncBody>;
+ while request_number < MAX_RETRIES {
+ response = self
+ .send_request(
+ &api_key,
+ spans.iter().map(|x| &**x).collect(),
+ request_timeout,
+ )
+ .await?;
+
+ request_number += 1;
+
+ match response.status() {
+ StatusCode::REQUEST_TIMEOUT => {
+ request_timeout += 5;
+ }
+ StatusCode::OK => {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+ log::trace!(
+ "openai embedding completed. tokens: {:?}",
+ response.usage.total_tokens
+ );
+
+ // If a request completes successfully after we were previously rate limited,
+ // resolve the rate limit
+ if rate_limiting {
+ self.resolve_rate_limit()
+ }
+
+ return Ok(response
+ .data
+ .into_iter()
+ .map(|embedding| Embedding::from(embedding.embedding))
+ .collect());
+ }
+ StatusCode::TOO_MANY_REQUESTS => {
+ rate_limiting = true;
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ let delay_duration = {
+ let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+ if let Some(time_to_reset) =
+ response.headers().get("x-ratelimit-reset-tokens")
+ {
+ if let Ok(time_str) = time_to_reset.to_str() {
+ parse(time_str).unwrap_or(delay)
+ } else {
+ delay
+ }
+ } else {
+ delay
+ }
+ };
+
+ // Record the expected rate-limit reset time, keeping the earliest one seen
+ let reset_time = Instant::now().add(delay_duration);
+ self.update_reset_time(reset_time);
+
+ log::trace!(
+ "openai rate limiting: waiting {:?} until lifted",
+ &delay_duration
+ );
+
+ self.executor.timer(delay_duration).await;
+ }
+ _ => {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ return Err(anyhow!(
+ "open ai bad request: {:?} {:?}",
+ &response.status(),
+ body
+ ));
+ }
+ }
+ }
+ Err(anyhow!("openai max retries"))
+ }
+}
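
For clarity, a small sketch (not part of the change) of the retry delay `embed_batch` applies on HTTP 429: the nth failed attempt falls back to `BACKOFF_SECONDS[n - 1]`, unless the `x-ratelimit-reset-tokens` header parses into a duration (the `"20s"` header value below is a hypothetical example).

```rust
use std::time::Duration;

const BACKOFF_SECONDS: [u64; 4] = [3, 5, 15, 45];

// `attempt` is 1-based, matching `request_number` after the increment in embed_batch.
fn retry_delay(attempt: usize, reset_header: Option<&str>) -> Duration {
    let fallback = Duration::from_secs(BACKOFF_SECONDS[attempt - 1]);
    reset_header
        .and_then(|value| parse_duration::parse(value).ok())
        .unwrap_or(fallback)
}

fn main() {
    assert_eq!(retry_delay(1, None), Duration::from_secs(3));
    assert_eq!(retry_delay(2, Some("20s")), Duration::from_secs(20));
}
```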
@@ -0,0 +1,9 @@
+pub mod completion;
+pub mod embedding;
+pub mod model;
+
+pub use completion::*;
+pub use embedding::*;
+pub use model::OpenAILanguageModel;
+
+pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
@@ -0,0 +1,57 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+use crate::models::{LanguageModel, TruncationDirection};
+
+#[derive(Clone)]
+pub struct OpenAILanguageModel {
+ name: String,
+ bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+ pub fn load(model_name: &str) -> Self {
+ let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+ OpenAILanguageModel {
+ name: model_name.to_string(),
+ bpe,
+ }
+ }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+ fn name(&self) -> String {
+ self.name.clone()
+ }
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+ if let Some(bpe) = &self.bpe {
+ anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String> {
+ if let Some(bpe) = &self.bpe {
+ let tokens = bpe.encode_with_special_tokens(content);
+ if tokens.len() > length {
+ match direction {
+ TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+ TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+ }
+ } else {
+ bpe.decode(tokens)
+ }
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn capacity(&self) -> anyhow::Result<usize> {
+ anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+ }
+}
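
A brief usage sketch (not part of the change) for `OpenAILanguageModel`. It assumes tiktoken-rs recognizes the model name; an unrecognized name leaves `bpe` unset and these calls return errors.

```rust
use ai::models::{LanguageModel, TruncationDirection};
use ai::providers::open_ai::OpenAILanguageModel;

fn main() -> anyhow::Result<()> {
    let model = OpenAILanguageModel::load("gpt-3.5-turbo");
    let text = "fn main() { println!(\"hello\"); }";

    let tokens = model.count_tokens(text)?;
    println!("{tokens} tokens, model capacity {}", model.capacity()?);

    // Keep roughly the first 8 tokens of the text.
    let truncated = model.truncate(text, 8, TruncationDirection::End)?;
    println!("truncated: {truncated}");
    Ok(())
}
```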
@@ -0,0 +1,11 @@
+pub trait LanguageModel {
+ fn name(&self) -> String;
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String>;
+ fn capacity(&self) -> anyhow::Result<usize>;
+}
@@ -0,0 +1,193 @@
+use std::{
+ sync::atomic::{self, AtomicUsize, Ordering},
+ time::Instant,
+};
+
+use async_trait::async_trait;
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui2::AppContext;
+use parking_lot::Mutex;
+
+use crate::{
+ auth::{CredentialProvider, ProviderCredential},
+ completion::{CompletionProvider, CompletionRequest},
+ embedding::{Embedding, EmbeddingProvider},
+ models::{LanguageModel, TruncationDirection},
+};
+
+#[derive(Clone)]
+pub struct FakeLanguageModel {
+ pub capacity: usize,
+}
+
+impl LanguageModel for FakeLanguageModel {
+ fn name(&self) -> String {
+ "dummy".to_string()
+ }
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+ anyhow::Ok(content.chars().count())
+ }
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String> {
+ println!("TRYING TO TRUNCATE: {:?}", length.clone());
+
+ if length > self.count_tokens(content)? {
+ println!("NOT TRUNCATING");
+ return anyhow::Ok(content.to_string());
+ }
+
+ anyhow::Ok(match direction {
+ TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
+ .into_iter()
+ .collect::<String>(),
+ TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
+ .into_iter()
+ .collect::<String>(),
+ })
+ }
+ fn capacity(&self) -> anyhow::Result<usize> {
+ anyhow::Ok(self.capacity)
+ }
+}
+
+pub struct FakeEmbeddingProvider {
+ pub embedding_count: AtomicUsize,
+}
+
+impl Clone for FakeEmbeddingProvider {
+ fn clone(&self) -> Self {
+ FakeEmbeddingProvider {
+ embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
+ }
+ }
+}
+
+impl Default for FakeEmbeddingProvider {
+ fn default() -> Self {
+ FakeEmbeddingProvider {
+ embedding_count: AtomicUsize::default(),
+ }
+ }
+}
+
+impl FakeEmbeddingProvider {
+ pub fn embedding_count(&self) -> usize {
+ self.embedding_count.load(atomic::Ordering::SeqCst)
+ }
+
+ pub fn embed_sync(&self, span: &str) -> Embedding {
+ let mut result = vec![1.0; 26];
+ for letter in span.chars() {
+ let letter = letter.to_ascii_lowercase();
+ if letter as u32 >= 'a' as u32 {
+ let ix = (letter as u32) - ('a' as u32);
+ if ix < 26 {
+ result[ix as usize] += 1.0;
+ }
+ }
+ }
+
+ let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
+ for x in &mut result {
+ *x /= norm;
+ }
+
+ result.into()
+ }
+}
+
+#[async_trait]
+impl CredentialProvider for FakeEmbeddingProvider {
+ fn has_credentials(&self) -> bool {
+ true
+ }
+ async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
+ ProviderCredential::NotNeeded
+ }
+ async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
+ async fn delete_credentials(&self, _cx: &mut AppContext) {}
+}
+
+#[async_trait]
+impl EmbeddingProvider for FakeEmbeddingProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ Box::new(FakeLanguageModel { capacity: 1000 })
+ }
+ fn max_tokens_per_batch(&self) -> usize {
+ 1000
+ }
+
+ fn rate_limit_expiration(&self) -> Option<Instant> {
+ None
+ }
+
+ async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
+ self.embedding_count
+ .fetch_add(spans.len(), atomic::Ordering::SeqCst);
+
+ anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
+ }
+}
+
+pub struct FakeCompletionProvider {
+ last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
+}
+
+impl Clone for FakeCompletionProvider {
+ fn clone(&self) -> Self {
+ Self {
+ last_completion_tx: Mutex::new(None),
+ }
+ }
+}
+
+impl FakeCompletionProvider {
+ pub fn new() -> Self {
+ Self {
+ last_completion_tx: Mutex::new(None),
+ }
+ }
+
+ pub fn send_completion(&self, completion: impl Into<String>) {
+ let mut tx = self.last_completion_tx.lock();
+ tx.as_mut().unwrap().try_send(completion.into()).unwrap();
+ }
+
+ pub fn finish_completion(&self) {
+ self.last_completion_tx.lock().take().unwrap();
+ }
+}
+
+#[async_trait]
+impl CredentialProvider for FakeCompletionProvider {
+ fn has_credentials(&self) -> bool {
+ true
+ }
+ async fn retrieve_credentials(&self, _cx: &mut AppContext) -> ProviderCredential {
+ ProviderCredential::NotNeeded
+ }
+ async fn save_credentials(&self, _cx: &mut AppContext, _credential: ProviderCredential) {}
+ async fn delete_credentials(&self, _cx: &mut AppContext) {}
+}
+
+impl CompletionProvider for FakeCompletionProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
+ model
+ }
+ fn complete(
+ &self,
+ _prompt: Box<dyn CompletionRequest>,
+ ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
+ let (tx, rx) = mpsc::channel(1);
+ *self.last_completion_tx.lock() = Some(tx);
+ async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
+ }
+ fn box_clone(&self) -> Box<dyn CompletionProvider> {
+ Box::new((*self).clone())
+ }
+}
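
A sketch (not part of the change) of how a test typically drives `FakeCompletionProvider`. `DummyCompletionRequest` mirrors the helper the codegen tests define later in this diff; it is not exported by the crate itself. Paths are shown for the `ai` crate; the `ai2` variant is analogous.

```rust
use ai::completion::{CompletionProvider, CompletionRequest};
use ai::test::FakeCompletionProvider;
use futures::{executor::block_on, StreamExt};
use serde::Serialize;

#[derive(Serialize)]
struct DummyCompletionRequest {
    name: String,
}

impl CompletionRequest for DummyCompletionRequest {
    fn data(&self) -> serde_json::Result<String> {
        serde_json::to_string(self)
    }
}

fn main() {
    let provider = FakeCompletionProvider::new();
    let request: Box<dyn CompletionRequest> = Box::new(DummyCompletionRequest {
        name: "test".to_string(),
    });

    // Calling complete() installs the channel the fake uses to stream chunks.
    let stream = provider.complete(request);
    provider.send_completion("hello world");
    provider.finish_completion();

    let chunks = block_on(async {
        stream
            .await
            .unwrap()
            .map(|chunk| chunk.unwrap())
            .collect::<Vec<String>>()
            .await
    });
    assert_eq!(chunks.concat(), "hello world");
}
```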
@@ -45,6 +45,7 @@ tiktoken-rs = "0.5"
[dev-dependencies]
editor = { path = "../editor", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
+ai = { path = "../ai", features = ["test-support"]}
ctor.workspace = true
env_logger.workspace = true
@@ -4,7 +4,7 @@ mod codegen;
mod prompts;
mod streaming_diff;
-use ai::completion::Role;
+use ai::providers::open_ai::Role;
use anyhow::Result;
pub use assistant_panel::AssistantPanel;
use assistant_settings::OpenAIModel;
@@ -5,12 +5,14 @@ use crate::{
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
SavedMessage,
};
+
use ai::{
- completion::{
- stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
- },
- templates::repository_context::PromptCodeSnippet,
+ auth::ProviderCredential,
+ completion::{CompletionProvider, CompletionRequest},
+ providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage},
};
+
+use ai::prompts::repository_context::PromptCodeSnippet;
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
@@ -43,8 +45,8 @@ use search::BufferSearchBar;
use semantic_index::{SemanticIndex, SemanticIndexStatus};
use settings::SettingsStore;
use std::{
- cell::{Cell, RefCell},
- cmp, env,
+ cell::Cell,
+ cmp,
fmt::Write,
iter,
ops::Range,
@@ -97,8 +99,8 @@ pub fn init(cx: &mut AppContext) {
cx.capture_action(ConversationEditor::copy);
cx.add_action(ConversationEditor::split);
cx.capture_action(ConversationEditor::cycle_message_role);
- cx.add_action(AssistantPanel::save_api_key);
- cx.add_action(AssistantPanel::reset_api_key);
+ cx.add_action(AssistantPanel::save_credentials);
+ cx.add_action(AssistantPanel::reset_credentials);
cx.add_action(AssistantPanel::toggle_zoom);
cx.add_action(AssistantPanel::deploy);
cx.add_action(AssistantPanel::select_next_match);
@@ -140,9 +142,8 @@ pub struct AssistantPanel {
zoomed: bool,
has_focus: bool,
toolbar: ViewHandle<Toolbar>,
- api_key: Rc<RefCell<Option<String>>>,
+ completion_provider: Box<dyn CompletionProvider>,
api_key_editor: Option<ViewHandle<Editor>>,
- has_read_credentials: bool,
languages: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
subscriptions: Vec<Subscription>,
@@ -202,6 +203,11 @@ impl AssistantPanel {
});
let semantic_index = SemanticIndex::global(cx);
+ // Currently defaults to GPT-4; allow this to be set via config.
+ let completion_provider = Box::new(OpenAICompletionProvider::new(
+ "gpt-4",
+ cx.background().clone(),
+ ));
let mut this = Self {
workspace: workspace_handle,
@@ -213,9 +219,8 @@ impl AssistantPanel {
zoomed: false,
has_focus: false,
toolbar,
- api_key: Rc::new(RefCell::new(None)),
+ completion_provider,
api_key_editor: None,
- has_read_credentials: false,
languages: workspace.app_state().languages.clone(),
fs: workspace.app_state().fs.clone(),
width: None,
@@ -254,10 +259,7 @@ impl AssistantPanel {
cx: &mut ViewContext<Workspace>,
) {
let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
- if this
- .update(cx, |assistant, cx| assistant.load_api_key(cx))
- .is_some()
- {
+ if this.update(cx, |assistant, _| assistant.has_credentials()) {
this
} else {
workspace.focus_panel::<AssistantPanel>(cx);
@@ -289,12 +291,6 @@ impl AssistantPanel {
cx: &mut ViewContext<Self>,
project: &ModelHandle<Project>,
) {
- let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
- api_key
- } else {
- return;
- };
-
let selection = editor.read(cx).selections.newest_anchor().clone();
if selection.start.excerpt_id != selection.end.excerpt_id {
return;
@@ -325,10 +321,13 @@ impl AssistantPanel {
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
let provider = Arc::new(OpenAICompletionProvider::new(
- api_key,
+ "gpt-4",
cx.background().clone(),
));
+ // Retrieving credentials authenticates the provider
+ // provider.retrieve_credentials(cx);
+
let codegen = cx.add_model(|cx| {
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
});
@@ -745,13 +744,14 @@ impl AssistantPanel {
content: prompt,
});
- let request = OpenAIRequest {
+ let request = Box::new(OpenAIRequest {
model: model.full_name().into(),
messages,
stream: true,
stop: vec!["|END|>".to_string()],
temperature,
- };
+ });
+
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
anyhow::Ok(())
})
@@ -811,7 +811,7 @@ impl AssistantPanel {
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
let editor = cx.add_view(|cx| {
ConversationEditor::new(
- self.api_key.clone(),
+ self.completion_provider.clone(),
self.languages.clone(),
self.fs.clone(),
self.workspace.clone(),
@@ -870,17 +870,19 @@ impl AssistantPanel {
}
}
- fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
+ fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
if let Some(api_key) = self
.api_key_editor
.as_ref()
.map(|editor| editor.read(cx).text(cx))
{
if !api_key.is_empty() {
- cx.platform()
- .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
- .log_err();
- *self.api_key.borrow_mut() = Some(api_key);
+ let credential = ProviderCredential::Credentials {
+ api_key: api_key.clone(),
+ };
+
+ self.completion_provider.save_credentials(cx, credential);
+
self.api_key_editor.take();
cx.focus_self();
cx.notify();
@@ -890,9 +892,8 @@ impl AssistantPanel {
}
}
- fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
- cx.platform().delete_credentials(OPENAI_API_URL).log_err();
- self.api_key.take();
+ fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
+ self.completion_provider.delete_credentials(cx);
self.api_key_editor = Some(build_api_key_editor(cx));
cx.focus_self();
cx.notify();
@@ -1151,13 +1152,12 @@ impl AssistantPanel {
let fs = self.fs.clone();
let workspace = self.workspace.clone();
- let api_key = self.api_key.clone();
let languages = self.languages.clone();
cx.spawn(|this, mut cx| async move {
let saved_conversation = fs.load(&path).await?;
let saved_conversation = serde_json::from_str(&saved_conversation)?;
let conversation = cx.add_model(|cx| {
- Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
+ Conversation::deserialize(saved_conversation, path.clone(), languages, cx)
});
this.update(&mut cx, |this, cx| {
// If, by the time we've loaded the conversation, the user has already opened
@@ -1181,30 +1181,12 @@ impl AssistantPanel {
.position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
}
- fn load_api_key(&mut self, cx: &mut ViewContext<Self>) -> Option<String> {
- if self.api_key.borrow().is_none() && !self.has_read_credentials {
- self.has_read_credentials = true;
- let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
- Some(api_key)
- } else if let Some((_, api_key)) = cx
- .platform()
- .read_credentials(OPENAI_API_URL)
- .log_err()
- .flatten()
- {
- String::from_utf8(api_key).log_err()
- } else {
- None
- };
- if let Some(api_key) = api_key {
- *self.api_key.borrow_mut() = Some(api_key);
- } else if self.api_key_editor.is_none() {
- self.api_key_editor = Some(build_api_key_editor(cx));
- cx.notify();
- }
- }
+ fn has_credentials(&mut self) -> bool {
+ self.completion_provider.has_credentials()
+ }
- self.api_key.borrow().clone()
+ fn load_credentials(&mut self, cx: &mut ViewContext<Self>) {
+ self.completion_provider.retrieve_credentials(cx);
}
}
@@ -1389,7 +1371,7 @@ impl Panel for AssistantPanel {
fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
if active {
- self.load_api_key(cx);
+ self.load_credentials(cx);
if self.editors.is_empty() {
self.new_conversation(cx);
@@ -1454,10 +1436,10 @@ struct Conversation {
token_count: Option<usize>,
max_token_count: usize,
pending_token_count: Task<Option<()>>,
- api_key: Rc<RefCell<Option<String>>>,
pending_save: Task<Result<()>>,
path: Option<PathBuf>,
_subscriptions: Vec<Subscription>,
+ completion_provider: Box<dyn CompletionProvider>,
}
impl Entity for Conversation {
@@ -1466,9 +1448,9 @@ impl Entity for Conversation {
impl Conversation {
fn new(
- api_key: Rc<RefCell<Option<String>>>,
language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>,
+ completion_provider: Box<dyn CompletionProvider>,
) -> Self {
let markdown = language_registry.language_for_name("Markdown");
let buffer = cx.add_model(|cx| {
@@ -1507,8 +1489,8 @@ impl Conversation {
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: None,
- api_key,
buffer,
+ completion_provider,
};
let message = MessageAnchor {
id: MessageId(post_inc(&mut this.next_message_id.0)),
@@ -1554,7 +1536,6 @@ impl Conversation {
fn deserialize(
saved_conversation: SavedConversation,
path: PathBuf,
- api_key: Rc<RefCell<Option<String>>>,
language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>,
) -> Self {
@@ -1563,6 +1544,10 @@ impl Conversation {
None => Some(Uuid::new_v4().to_string()),
};
let model = saved_conversation.model;
+ let completion_provider: Box<dyn CompletionProvider> = Box::new(
+ OpenAICompletionProvider::new(model.full_name(), cx.background().clone()),
+ );
+ completion_provider.retrieve_credentials(cx);
let markdown = language_registry.language_for_name("Markdown");
let mut message_anchors = Vec::new();
let mut next_message_id = MessageId(0);
@@ -1609,8 +1594,8 @@ impl Conversation {
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: Some(path),
- api_key,
buffer,
+ completion_provider,
};
this.count_remaining_tokens(cx);
this
@@ -1731,11 +1716,11 @@ impl Conversation {
}
if should_assist {
- let Some(api_key) = self.api_key.borrow().clone() else {
+ if !self.completion_provider.has_credentials() {
return Default::default();
- };
+ }
- let request = OpenAIRequest {
+ let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
model: self.model.full_name().to_string(),
messages: self
.messages(cx)
@@ -1745,9 +1730,9 @@ impl Conversation {
stream: true,
stop: vec![],
temperature: 1.0,
- };
+ });
- let stream = stream_completion(api_key, cx.background().clone(), request);
+ let stream = self.completion_provider.complete(request);
let assistant_message = self
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
.unwrap();
@@ -1765,33 +1750,28 @@ impl Conversation {
let mut messages = stream.await?;
while let Some(message) = messages.next().await {
- let mut message = message?;
- if let Some(choice) = message.choices.pop() {
- this.upgrade(&cx)
- .ok_or_else(|| anyhow!("conversation was dropped"))?
- .update(&mut cx, |this, cx| {
- let text: Arc<str> = choice.delta.content?.into();
- let message_ix =
- this.message_anchors.iter().position(|message| {
- message.id == assistant_message_id
- })?;
- this.buffer.update(cx, |buffer, cx| {
- let offset = this.message_anchors[message_ix + 1..]
- .iter()
- .find(|message| message.start.is_valid(buffer))
- .map_or(buffer.len(), |message| {
- message
- .start
- .to_offset(buffer)
- .saturating_sub(1)
- });
- buffer.edit([(offset..offset, text)], None, cx);
- });
- cx.emit(ConversationEvent::StreamedCompletion);
-
- Some(())
+ let text = message?;
+
+ this.upgrade(&cx)
+ .ok_or_else(|| anyhow!("conversation was dropped"))?
+ .update(&mut cx, |this, cx| {
+ let message_ix = this
+ .message_anchors
+ .iter()
+ .position(|message| message.id == assistant_message_id)?;
+ this.buffer.update(cx, |buffer, cx| {
+ let offset = this.message_anchors[message_ix + 1..]
+ .iter()
+ .find(|message| message.start.is_valid(buffer))
+ .map_or(buffer.len(), |message| {
+ message.start.to_offset(buffer).saturating_sub(1)
+ });
+ buffer.edit([(offset..offset, text)], None, cx);
});
- }
+ cx.emit(ConversationEvent::StreamedCompletion);
+
+ Some(())
+ });
smol::future::yield_now().await;
}
@@ -2013,57 +1993,54 @@ impl Conversation {
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
if self.message_anchors.len() >= 2 && self.summary.is_none() {
- let api_key = self.api_key.borrow().clone();
- if let Some(api_key) = api_key {
- let messages = self
- .messages(cx)
- .take(2)
- .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
- .chain(Some(RequestMessage {
- role: Role::User,
- content:
- "Summarize the conversation into a short title without punctuation"
- .into(),
- }));
- let request = OpenAIRequest {
- model: self.model.full_name().to_string(),
- messages: messages.collect(),
- stream: true,
- stop: vec![],
- temperature: 1.0,
- };
+ if !self.completion_provider.has_credentials() {
+ return;
+ }
- let stream = stream_completion(api_key, cx.background().clone(), request);
- self.pending_summary = cx.spawn(|this, mut cx| {
- async move {
- let mut messages = stream.await?;
+ let messages = self
+ .messages(cx)
+ .take(2)
+ .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
+ .chain(Some(RequestMessage {
+ role: Role::User,
+ content: "Summarize the conversation into a short title without punctuation"
+ .into(),
+ }));
+ let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
+ model: self.model.full_name().to_string(),
+ messages: messages.collect(),
+ stream: true,
+ stop: vec![],
+ temperature: 1.0,
+ });
- while let Some(message) = messages.next().await {
- let mut message = message?;
- if let Some(choice) = message.choices.pop() {
- let text = choice.delta.content.unwrap_or_default();
- this.update(&mut cx, |this, cx| {
- this.summary
- .get_or_insert(Default::default())
- .text
- .push_str(&text);
- cx.emit(ConversationEvent::SummaryChanged);
- });
- }
- }
+ let stream = self.completion_provider.complete(request);
+ self.pending_summary = cx.spawn(|this, mut cx| {
+ async move {
+ let mut messages = stream.await?;
+ while let Some(message) = messages.next().await {
+ let text = message?;
this.update(&mut cx, |this, cx| {
- if let Some(summary) = this.summary.as_mut() {
- summary.done = true;
- cx.emit(ConversationEvent::SummaryChanged);
- }
+ this.summary
+ .get_or_insert(Default::default())
+ .text
+ .push_str(&text);
+ cx.emit(ConversationEvent::SummaryChanged);
});
-
- anyhow::Ok(())
}
- .log_err()
- });
- }
+
+ this.update(&mut cx, |this, cx| {
+ if let Some(summary) = this.summary.as_mut() {
+ summary.done = true;
+ cx.emit(ConversationEvent::SummaryChanged);
+ }
+ });
+
+ anyhow::Ok(())
+ }
+ .log_err()
+ });
}
}
@@ -2224,13 +2201,14 @@ struct ConversationEditor {
impl ConversationEditor {
fn new(
- api_key: Rc<RefCell<Option<String>>>,
+ completion_provider: Box<dyn CompletionProvider>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
workspace: WeakViewHandle<Workspace>,
cx: &mut ViewContext<Self>,
) -> Self {
- let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
+ let conversation =
+ cx.add_model(|cx| Conversation::new(language_registry, cx, completion_provider));
Self::for_conversation(conversation, fs, workspace, cx)
}
@@ -3419,6 +3397,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
mod tests {
use super::*;
use crate::MessageId;
+ use ai::test::FakeCompletionProvider;
use gpui::AppContext;
#[gpui::test]
@@ -3426,7 +3405,9 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
- let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+
+ let completion_provider = Box::new(FakeCompletionProvider::new());
+ let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3554,7 +3535,9 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
- let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+ let completion_provider = Box::new(FakeCompletionProvider::new());
+
+ let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3650,7 +3633,8 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
- let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+ let completion_provider = Box::new(FakeCompletionProvider::new());
+ let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3732,8 +3716,9 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
+ let completion_provider = Box::new(FakeCompletionProvider::new());
let conversation =
- cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
+ cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_0 = conversation.read(cx).message_anchors[0].id;
let message_1 = conversation.update(cx, |conversation, cx| {
@@ -3770,7 +3755,6 @@ mod tests {
Conversation::deserialize(
conversation.read(cx).serialize(cx),
Default::default(),
- Default::default(),
registry.clone(),
cx,
)
@@ -1,5 +1,5 @@
use crate::streaming_diff::{Hunk, StreamingDiff};
-use ai::completion::{CompletionProvider, OpenAIRequest};
+use ai::completion::{CompletionProvider, CompletionRequest};
use anyhow::Result;
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
@@ -96,7 +96,7 @@ impl Codegen {
self.error.as_ref()
}
- pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
+ pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
let range = self.range();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
@@ -336,17 +336,25 @@ fn strip_markdown_codeblock(
#[cfg(test)]
mod tests {
use super::*;
- use futures::{
- future::BoxFuture,
- stream::{self, BoxStream},
- };
+ use ai::test::FakeCompletionProvider;
+ use futures::stream::{self};
use gpui::{executor::Deterministic, TestAppContext};
use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
- use parking_lot::Mutex;
use rand::prelude::*;
+ use serde::Serialize;
use settings::SettingsStore;
- use smol::future::FutureExt;
+
+ #[derive(Serialize)]
+ pub struct DummyCompletionRequest {
+ pub name: String,
+ }
+
+ impl CompletionRequest for DummyCompletionRequest {
+ fn data(&self) -> serde_json::Result<String> {
+ serde_json::to_string(self)
+ }
+ }
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(
@@ -372,7 +380,7 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
- let provider = Arc::new(TestCompletionProvider::new());
+ let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
@@ -381,7 +389,11 @@ mod tests {
cx,
)
});
- codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+ let request = Box::new(DummyCompletionRequest {
+ name: "test".to_string(),
+ });
+ codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
" let mut x = 0;\n",
@@ -434,7 +446,7 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))
});
- let provider = Arc::new(TestCompletionProvider::new());
+ let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
@@ -443,7 +455,11 @@ mod tests {
cx,
)
});
- codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+ let request = Box::new(DummyCompletionRequest {
+ name: "test".to_string(),
+ });
+ codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
"t mut x = 0;\n",
@@ -496,7 +512,7 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))
});
- let provider = Arc::new(TestCompletionProvider::new());
+ let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
@@ -505,7 +521,11 @@ mod tests {
cx,
)
});
- codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+ let request = Box::new(DummyCompletionRequest {
+ name: "test".to_string(),
+ });
+ codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
"let mut x = 0;\n",
@@ -593,38 +613,6 @@ mod tests {
}
}
- struct TestCompletionProvider {
- last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
- }
-
- impl TestCompletionProvider {
- fn new() -> Self {
- Self {
- last_completion_tx: Mutex::new(None),
- }
- }
-
- fn send_completion(&self, completion: impl Into<String>) {
- let mut tx = self.last_completion_tx.lock();
- tx.as_mut().unwrap().try_send(completion.into()).unwrap();
- }
-
- fn finish_completion(&self) {
- self.last_completion_tx.lock().take().unwrap();
- }
- }
-
- impl CompletionProvider for TestCompletionProvider {
- fn complete(
- &self,
- _prompt: OpenAIRequest,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
- let (tx, rx) = mpsc::channel(1);
- *self.last_completion_tx.lock() = Some(tx);
- async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
- }
- }
-
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
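
Since `OpenAIRequest` is no longer baked into `Codegen`, any type that serializes to JSON can be handed to `Codegen::start` as a `Box<dyn CompletionRequest>`. A minimal sketch of a provider-agnostic request, mirroring the `DummyCompletionRequest` test type above (the `EchoRequest` name is illustrative and not part of this change):

    use ai::completion::CompletionRequest;
    use serde::Serialize;

    // Hypothetical request type; only data() is required by the trait.
    #[derive(Serialize)]
    struct EchoRequest {
        prompt: String,
    }

    impl CompletionRequest for EchoRequest {
        fn data(&self) -> serde_json::Result<String> {
            // Providers receive the request as an opaque JSON payload.
            serde_json::to_string(self)
        }
    }
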
@@ -1,9 +1,10 @@
-use ai::models::{LanguageModel, OpenAILanguageModel};
-use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
-use ai::templates::file_context::FileContext;
-use ai::templates::generate::GenerateInlineContent;
-use ai::templates::preamble::EngineerPreamble;
-use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::models::LanguageModel;
+use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::prompts::file_context::FileContext;
+use ai::prompts::generate::GenerateInlineContent;
+use ai::prompts::preamble::EngineerPreamble;
+use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::providers::open_ai::OpenAILanguageModel;
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use std::cmp::{self, Reverse};
use std::ops::Range;
@@ -13,7 +13,7 @@ use collections::HashSet;
use futures::{future::Shared, FutureExt};
use gpui2::{
AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Subscription, Task,
- WeakHandle,
+ WeakModel,
};
use postage::watch;
use project2::Project;
@@ -42,7 +42,7 @@ pub struct IncomingCall {
pub struct ActiveCall {
room: Option<(Model<Room>, Vec<Subscription>)>,
pending_room_creation: Option<Shared<Task<Result<Model<Room>, Arc<anyhow::Error>>>>>,
- location: Option<WeakHandle<Project>>,
+ location: Option<WeakModel<Project>>,
pending_invites: HashSet<u64>,
incoming_call: (
watch::Sender<Option<IncomingCall>>,
@@ -347,7 +347,7 @@ impl ActiveCall {
}
}
- pub fn location(&self) -> Option<&WeakHandle<Project>> {
+ pub fn location(&self) -> Option<&WeakModel<Project>> {
self.location.as_ref()
}
@@ -1,7 +1,7 @@
use anyhow::{anyhow, Result};
use client2::ParticipantIndex;
use client2::{proto, User};
-use gpui2::WeakHandle;
+use gpui2::WeakModel;
pub use live_kit_client::Frame;
use project2::Project;
use std::{fmt, sync::Arc};
@@ -33,7 +33,7 @@ impl ParticipantLocation {
#[derive(Clone, Default)]
pub struct LocalParticipant {
pub projects: Vec<proto::ParticipantProject>,
- pub active_project: Option<WeakHandle<Project>>,
+ pub active_project: Option<WeakModel<Project>>,
}
#[derive(Clone, Debug)]
@@ -16,7 +16,7 @@ use collections::{BTreeMap, HashMap, HashSet};
use fs::Fs;
use futures::{FutureExt, StreamExt};
use gpui2::{
- AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakHandle,
+ AppContext, AsyncAppContext, Context, EventEmitter, Model, ModelContext, Task, WeakModel,
};
use language2::LanguageRegistry;
use live_kit_client::{LocalTrackPublication, RemoteAudioTrackUpdate, RemoteVideoTrackUpdate};
@@ -61,8 +61,8 @@ pub struct Room {
channel_id: Option<u64>,
// live_kit: Option<LiveKitRoom>,
status: RoomStatus,
- shared_projects: HashSet<WeakHandle<Project>>,
- joined_projects: HashSet<WeakHandle<Project>>,
+ shared_projects: HashSet<WeakModel<Project>>,
+ joined_projects: HashSet<WeakModel<Project>>,
local_participant: LocalParticipant,
remote_participants: BTreeMap<u64, RemoteParticipant>,
pending_participants: Vec<Arc<User>>,
@@ -424,7 +424,7 @@ impl Room {
}
async fn maintain_connection(
- this: WeakHandle<Self>,
+ this: WeakModel<Self>,
client: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
@@ -14,8 +14,8 @@ use futures::{
future::BoxFuture, AsyncReadExt, FutureExt, SinkExt, StreamExt, TryFutureExt as _, TryStreamExt,
};
use gpui2::{
- serde_json, AnyHandle, AnyWeakHandle, AppContext, AsyncAppContext, Model, SemanticVersion,
- Task, WeakHandle,
+ serde_json, AnyModel, AnyWeakModel, AppContext, AsyncAppContext, Model, SemanticVersion, Task,
+ WeakModel,
};
use lazy_static::lazy_static;
use parking_lot::RwLock;
@@ -227,7 +227,7 @@ struct ClientState {
_reconnect_task: Option<Task<()>>,
reconnect_interval: Duration,
entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>,
- models_by_message_type: HashMap<TypeId, AnyWeakHandle>,
+ models_by_message_type: HashMap<TypeId, AnyWeakModel>,
entity_types_by_message_type: HashMap<TypeId, TypeId>,
#[allow(clippy::type_complexity)]
message_handlers: HashMap<
@@ -236,7 +236,7 @@ struct ClientState {
dyn Send
+ Sync
+ Fn(
- AnyHandle,
+ AnyModel,
Box<dyn AnyTypedEnvelope>,
&Arc<Client>,
AsyncAppContext,
@@ -246,7 +246,7 @@ struct ClientState {
}
enum WeakSubscriber {
- Entity { handle: AnyWeakHandle },
+ Entity { handle: AnyWeakModel },
Pending(Vec<Box<dyn AnyTypedEnvelope>>),
}
@@ -552,7 +552,7 @@ impl Client {
#[track_caller]
pub fn add_message_handler<M, E, H, F>(
self: &Arc<Self>,
- entity: WeakHandle<E>,
+ entity: WeakModel<E>,
handler: H,
) -> Subscription
where
@@ -594,7 +594,7 @@ impl Client {
pub fn add_request_handler<M, E, H, F>(
self: &Arc<Self>,
- model: WeakHandle<E>,
+ model: WeakModel<E>,
handler: H,
) -> Subscription
where
@@ -628,7 +628,7 @@ impl Client {
where
M: EntityMessage,
E: 'static + Send,
- H: 'static + Send + Sync + Fn(AnyHandle, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
+ H: 'static + Send + Sync + Fn(AnyModel, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
F: 'static + Future<Output = Result<()>> + Send,
{
let model_type_id = TypeId::of::<E>();
@@ -8,7 +8,7 @@ use collections::{HashMap, HashSet};
use futures::{channel::oneshot, future::Shared, Future, FutureExt, TryFutureExt};
use gpui2::{
AppContext, AsyncAppContext, Context, EntityId, EventEmitter, Model, ModelContext, Task,
- WeakHandle,
+ WeakModel,
};
use language2::{
language_settings::{all_language_settings, language_settings},
@@ -278,7 +278,7 @@ pub struct Copilot {
http: Arc<dyn HttpClient>,
node_runtime: Arc<dyn NodeRuntime>,
server: CopilotServer,
- buffers: HashSet<WeakHandle<Buffer>>,
+ buffers: HashSet<WeakModel<Buffer>>,
server_id: LanguageServerId,
_subscription: gpui2::Subscription,
}
@@ -383,7 +383,7 @@ impl Copilot {
new_server_id: LanguageServerId,
http: Arc<dyn HttpClient>,
node_runtime: Arc<dyn NodeRuntime>,
- this: WeakHandle<Self>,
+ this: WeakModel<Self>,
mut cx: AsyncAppContext,
) -> impl Future<Output = ()> {
async move {
@@ -706,7 +706,7 @@ impl Copilot {
Ok(())
}
- fn unregister_buffer(&mut self, buffer: &WeakHandle<Buffer>) {
+ fn unregister_buffer(&mut self, buffer: &WeakModel<Buffer>) {
if let Ok(server) = self.server.as_running() {
if let Some(buffer) = server.registered_buffers.remove(&buffer.entity_id()) {
server
@@ -711,7 +711,7 @@ impl Context for AppContext {
type Result<T> = T;
/// Build an entity that is owned by the application. The given function will be invoked with
- /// a `ModelContext` and must return an object representing the entity. A `Handle` will be returned
+ /// a `ModelContext` and must return an object representing the entity. A `Model` will be returned
/// which can be used to access the entity in a context.
fn build_model<T: 'static + Send>(
&mut self,
@@ -724,18 +724,18 @@ impl Context for AppContext {
})
}
- /// Update the entity referenced by the given handle. The function is passed a mutable reference to the
+ /// Update the entity referenced by the given model. The function is passed a mutable reference to the
/// entity along with a `ModelContext` for the entity.
fn update_entity<T: 'static, R>(
&mut self,
- handle: &Model<T>,
+ model: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> R {
self.update(|cx| {
- let mut entity = cx.entities.lease(handle);
+ let mut entity = cx.entities.lease(model);
let result = update(
&mut entity,
- &mut ModelContext::mutable(cx, handle.downgrade()),
+ &mut ModelContext::mutable(cx, model.downgrade()),
);
cx.entities.end_lease(entity);
result
@@ -61,21 +61,21 @@ impl EntityMap {
where
T: 'static + Send,
{
- let handle = slot.0;
- self.entities.insert(handle.entity_id, Box::new(entity));
- handle
+ let model = slot.0;
+ self.entities.insert(model.entity_id, Box::new(entity));
+ model
}
/// Move an entity to the stack.
- pub fn lease<'a, T>(&mut self, handle: &'a Model<T>) -> Lease<'a, T> {
- self.assert_valid_context(handle);
+ pub fn lease<'a, T>(&mut self, model: &'a Model<T>) -> Lease<'a, T> {
+ self.assert_valid_context(model);
let entity = Some(
self.entities
- .remove(handle.entity_id)
+ .remove(model.entity_id)
.expect("Circular entity lease. Is the entity already being updated?"),
);
Lease {
- handle,
+ model,
entity,
entity_type: PhantomData,
}
@@ -84,18 +84,18 @@ impl EntityMap {
/// Return an entity after moving it to the stack.
pub fn end_lease<T>(&mut self, mut lease: Lease<T>) {
self.entities
- .insert(lease.handle.entity_id, lease.entity.take().unwrap());
+ .insert(lease.model.entity_id, lease.entity.take().unwrap());
}
- pub fn read<T: 'static>(&self, handle: &Model<T>) -> &T {
- self.assert_valid_context(handle);
- self.entities[handle.entity_id].downcast_ref().unwrap()
+ pub fn read<T: 'static>(&self, model: &Model<T>) -> &T {
+ self.assert_valid_context(model);
+ self.entities[model.entity_id].downcast_ref().unwrap()
}
- fn assert_valid_context(&self, handle: &AnyHandle) {
+ fn assert_valid_context(&self, model: &AnyModel) {
debug_assert!(
- Weak::ptr_eq(&handle.entity_map, &Arc::downgrade(&self.ref_counts)),
- "used a handle with the wrong context"
+ Weak::ptr_eq(&model.entity_map, &Arc::downgrade(&self.ref_counts)),
+ "used a model with the wrong context"
);
}
@@ -115,7 +115,7 @@ impl EntityMap {
pub struct Lease<'a, T> {
entity: Option<AnyBox>,
- pub handle: &'a Model<T>,
+ pub model: &'a Model<T>,
entity_type: PhantomData<T>,
}
@@ -145,13 +145,13 @@ impl<'a, T> Drop for Lease<'a, T> {
#[derive(Deref, DerefMut)]
pub struct Slot<T>(Model<T>);
-pub struct AnyHandle {
+pub struct AnyModel {
pub(crate) entity_id: EntityId,
entity_type: TypeId,
entity_map: Weak<RwLock<EntityRefCounts>>,
}
-impl AnyHandle {
+impl AnyModel {
fn new(id: EntityId, entity_type: TypeId, entity_map: Weak<RwLock<EntityRefCounts>>) -> Self {
Self {
entity_id: id,
@@ -164,8 +164,8 @@ impl AnyHandle {
self.entity_id
}
- pub fn downgrade(&self) -> AnyWeakHandle {
- AnyWeakHandle {
+ pub fn downgrade(&self) -> AnyWeakModel {
+ AnyWeakModel {
entity_id: self.entity_id,
entity_type: self.entity_type,
entity_ref_counts: self.entity_map.clone(),
@@ -175,7 +175,7 @@ impl AnyHandle {
pub fn downcast<T: 'static>(&self) -> Option<Model<T>> {
if TypeId::of::<T>() == self.entity_type {
Some(Model {
- any_handle: self.clone(),
+ any_model: self.clone(),
entity_type: PhantomData,
})
} else {
@@ -184,16 +184,16 @@ impl AnyHandle {
}
}
-impl Clone for AnyHandle {
+impl Clone for AnyModel {
fn clone(&self) -> Self {
if let Some(entity_map) = self.entity_map.upgrade() {
let entity_map = entity_map.read();
let count = entity_map
.counts
.get(self.entity_id)
- .expect("detected over-release of a handle");
+ .expect("detected over-release of a model");
let prev_count = count.fetch_add(1, SeqCst);
- assert_ne!(prev_count, 0, "Detected over-release of a handle.");
+ assert_ne!(prev_count, 0, "Detected over-release of a model.");
}
Self {
@@ -204,16 +204,16 @@ impl Clone for AnyHandle {
}
}
-impl Drop for AnyHandle {
+impl Drop for AnyModel {
fn drop(&mut self) {
if let Some(entity_map) = self.entity_map.upgrade() {
let entity_map = entity_map.upgradable_read();
let count = entity_map
.counts
.get(self.entity_id)
- .expect("Detected over-release of a handle.");
+ .expect("Detected over-release of a model.");
let prev_count = count.fetch_sub(1, SeqCst);
- assert_ne!(prev_count, 0, "Detected over-release of a handle.");
+ assert_ne!(prev_count, 0, "Detected over-release of a model.");
if prev_count == 1 {
// We were the last reference to this entity, so we can remove it.
let mut entity_map = RwLockUpgradableReadGuard::upgrade(entity_map);
@@ -223,31 +223,31 @@ impl Drop for AnyHandle {
}
}
-impl<T> From<Model<T>> for AnyHandle {
- fn from(handle: Model<T>) -> Self {
- handle.any_handle
+impl<T> From<Model<T>> for AnyModel {
+ fn from(model: Model<T>) -> Self {
+ model.any_model
}
}
-impl Hash for AnyHandle {
+impl Hash for AnyModel {
fn hash<H: Hasher>(&self, state: &mut H) {
self.entity_id.hash(state);
}
}
-impl PartialEq for AnyHandle {
+impl PartialEq for AnyModel {
fn eq(&self, other: &Self) -> bool {
self.entity_id == other.entity_id
}
}
-impl Eq for AnyHandle {}
+impl Eq for AnyModel {}
#[derive(Deref, DerefMut)]
pub struct Model<T> {
#[deref]
#[deref_mut]
- any_handle: AnyHandle,
+ any_model: AnyModel,
entity_type: PhantomData<T>,
}
@@ -260,14 +260,14 @@ impl<T: 'static> Model<T> {
T: 'static,
{
Self {
- any_handle: AnyHandle::new(id, TypeId::of::<T>(), entity_map),
+ any_model: AnyModel::new(id, TypeId::of::<T>(), entity_map),
entity_type: PhantomData,
}
}
- pub fn downgrade(&self) -> WeakHandle<T> {
- WeakHandle {
- any_handle: self.any_handle.downgrade(),
+ pub fn downgrade(&self) -> WeakModel<T> {
+ WeakModel {
+ any_model: self.any_model.downgrade(),
entity_type: self.entity_type,
}
}
@@ -276,7 +276,7 @@ impl<T: 'static> Model<T> {
cx.entities.read(self)
}
- /// Update the entity referenced by this handle with the given function.
+ /// Update the entity referenced by this model with the given function.
///
/// The update function receives a context appropriate for its environment.
/// When updating in an `AppContext`, it receives a `ModelContext`.
@@ -296,7 +296,7 @@ impl<T: 'static> Model<T> {
impl<T> Clone for Model<T> {
fn clone(&self) -> Self {
Self {
- any_handle: self.any_handle.clone(),
+ any_model: self.any_model.clone(),
entity_type: self.entity_type,
}
}
@@ -306,8 +306,8 @@ impl<T> std::fmt::Debug for Model<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
- "Handle {{ entity_id: {:?}, entity_type: {:?} }}",
- self.any_handle.entity_id,
+ "Model {{ entity_id: {:?}, entity_type: {:?} }}",
+ self.any_model.entity_id,
type_name::<T>()
)
}
@@ -315,32 +315,32 @@ impl<T> std::fmt::Debug for Model<T> {
impl<T> Hash for Model<T> {
fn hash<H: Hasher>(&self, state: &mut H) {
- self.any_handle.hash(state);
+ self.any_model.hash(state);
}
}
impl<T> PartialEq for Model<T> {
fn eq(&self, other: &Self) -> bool {
- self.any_handle == other.any_handle
+ self.any_model == other.any_model
}
}
impl<T> Eq for Model<T> {}
-impl<T> PartialEq<WeakHandle<T>> for Model<T> {
- fn eq(&self, other: &WeakHandle<T>) -> bool {
+impl<T> PartialEq<WeakModel<T>> for Model<T> {
+ fn eq(&self, other: &WeakModel<T>) -> bool {
self.entity_id() == other.entity_id()
}
}
#[derive(Clone)]
-pub struct AnyWeakHandle {
+pub struct AnyWeakModel {
pub(crate) entity_id: EntityId,
entity_type: TypeId,
entity_ref_counts: Weak<RwLock<EntityRefCounts>>,
}
-impl AnyWeakHandle {
+impl AnyWeakModel {
pub fn entity_id(&self) -> EntityId {
self.entity_id
}
@@ -354,14 +354,14 @@ impl AnyWeakHandle {
ref_count > 0
}
- pub fn upgrade(&self) -> Option<AnyHandle> {
+ pub fn upgrade(&self) -> Option<AnyModel> {
let entity_map = self.entity_ref_counts.upgrade()?;
entity_map
.read()
.counts
.get(self.entity_id)?
.fetch_add(1, SeqCst);
- Some(AnyHandle {
+ Some(AnyModel {
entity_id: self.entity_id,
entity_type: self.entity_type,
entity_map: self.entity_ref_counts.clone(),
@@ -369,55 +369,55 @@ impl AnyWeakHandle {
}
}
-impl<T> From<WeakHandle<T>> for AnyWeakHandle {
- fn from(handle: WeakHandle<T>) -> Self {
- handle.any_handle
+impl<T> From<WeakModel<T>> for AnyWeakModel {
+ fn from(model: WeakModel<T>) -> Self {
+ model.any_model
}
}
-impl Hash for AnyWeakHandle {
+impl Hash for AnyWeakModel {
fn hash<H: Hasher>(&self, state: &mut H) {
self.entity_id.hash(state);
}
}
-impl PartialEq for AnyWeakHandle {
+impl PartialEq for AnyWeakModel {
fn eq(&self, other: &Self) -> bool {
self.entity_id == other.entity_id
}
}
-impl Eq for AnyWeakHandle {}
+impl Eq for AnyWeakModel {}
#[derive(Deref, DerefMut)]
-pub struct WeakHandle<T> {
+pub struct WeakModel<T> {
#[deref]
#[deref_mut]
- any_handle: AnyWeakHandle,
+ any_model: AnyWeakModel,
entity_type: PhantomData<T>,
}
-unsafe impl<T> Send for WeakHandle<T> {}
-unsafe impl<T> Sync for WeakHandle<T> {}
+unsafe impl<T> Send for WeakModel<T> {}
+unsafe impl<T> Sync for WeakModel<T> {}
-impl<T> Clone for WeakHandle<T> {
+impl<T> Clone for WeakModel<T> {
fn clone(&self) -> Self {
Self {
- any_handle: self.any_handle.clone(),
+ any_model: self.any_model.clone(),
entity_type: self.entity_type,
}
}
}
-impl<T: 'static> WeakHandle<T> {
+impl<T: 'static> WeakModel<T> {
pub fn upgrade(&self) -> Option<Model<T>> {
Some(Model {
- any_handle: self.any_handle.upgrade()?,
+ any_model: self.any_model.upgrade()?,
entity_type: self.entity_type,
})
}
- /// Update the entity referenced by this handle with the given function if
+ /// Update the entity referenced by this model with the given function if
/// the referenced entity still exists. Returns an error if the entity has
/// been released.
///
@@ -441,21 +441,21 @@ impl<T: 'static> WeakHandle<T> {
}
}
-impl<T> Hash for WeakHandle<T> {
+impl<T> Hash for WeakModel<T> {
fn hash<H: Hasher>(&self, state: &mut H) {
- self.any_handle.hash(state);
+ self.any_model.hash(state);
}
}
-impl<T> PartialEq for WeakHandle<T> {
+impl<T> PartialEq for WeakModel<T> {
fn eq(&self, other: &Self) -> bool {
- self.any_handle == other.any_handle
+ self.any_model == other.any_model
}
}
-impl<T> Eq for WeakHandle<T> {}
+impl<T> Eq for WeakModel<T> {}
-impl<T> PartialEq<Model<T>> for WeakHandle<T> {
+impl<T> PartialEq<Model<T>> for WeakModel<T> {
fn eq(&self, other: &Model<T>) -> bool {
self.entity_id() == other.entity_id()
}
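
The gpui2 rename is mechanical: `AnyHandle`/`AnyWeakHandle`/`WeakHandle` become `AnyModel`/`AnyWeakModel`/`WeakModel`, while the strong reference remains `Model<T>`. A minimal sketch of the renamed surface, assuming the `build_model`, `downgrade`, `upgrade`, and `update` methods shown in the hunks above (the `Counter` entity is illustrative):

    use gpui2::{AppContext, Context as _, Model, WeakModel};

    struct Counter {
        count: usize,
    }

    fn bump(cx: &mut AppContext) {
        // Strong reference to app-owned state is a Model<T>.
        let counter: Model<Counter> = cx.build_model(|_cx| Counter { count: 0 });

        // Weak references are WeakModel<T>; upgrading yields Option<Model<T>>.
        let weak: WeakModel<Counter> = counter.downgrade();
        if let Some(counter) = weak.upgrade() {
            counter.update(cx, |counter, _cx| counter.count += 1);
        }
    }
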
@@ -1,6 +1,6 @@
use crate::{
AppContext, AsyncAppContext, Context, Effect, EntityId, EventEmitter, MainThread, Model,
- Reference, Subscription, Task, WeakHandle,
+ Reference, Subscription, Task, WeakModel,
};
use derive_more::{Deref, DerefMut};
use futures::FutureExt;
@@ -15,11 +15,11 @@ pub struct ModelContext<'a, T> {
#[deref]
#[deref_mut]
app: Reference<'a, AppContext>,
- model_state: WeakHandle<T>,
+ model_state: WeakModel<T>,
}
impl<'a, T: 'static> ModelContext<'a, T> {
- pub(crate) fn mutable(app: &'a mut AppContext, model_state: WeakHandle<T>) -> Self {
+ pub(crate) fn mutable(app: &'a mut AppContext, model_state: WeakModel<T>) -> Self {
Self {
app: Reference::Mutable(app),
model_state,
@@ -36,7 +36,7 @@ impl<'a, T: 'static> ModelContext<'a, T> {
.expect("The entity must be alive if we have a model context")
}
- pub fn weak_handle(&self) -> WeakHandle<T> {
+ pub fn weak_handle(&self) -> WeakModel<T> {
self.model_state.clone()
}
@@ -184,7 +184,7 @@ impl<'a, T: 'static> ModelContext<'a, T> {
pub fn spawn<Fut, R>(
&self,
- f: impl FnOnce(WeakHandle<T>, AsyncAppContext) -> Fut + Send + 'static,
+ f: impl FnOnce(WeakModel<T>, AsyncAppContext) -> Fut + Send + 'static,
) -> Task<R>
where
T: 'static,
@@ -197,7 +197,7 @@ impl<'a, T: 'static> ModelContext<'a, T> {
pub fn spawn_on_main<Fut, R>(
&self,
- f: impl FnOnce(WeakHandle<T>, MainThread<AsyncAppContext>) -> Fut + Send + 'static,
+ f: impl FnOnce(WeakModel<T>, MainThread<AsyncAppContext>) -> Fut + Send + 'static,
) -> Task<R>
where
Fut: Future<Output = R> + 'static,
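
`ModelContext::spawn` and `spawn_on_main` now hand the closure a `WeakModel<T>` rather than a `WeakHandle<T>`. A minimal sketch of the usual pattern, assuming an illustrative `Worker` entity and simply holding the returned `Task` (both assumptions, not part of this change):

    use gpui2::ModelContext;

    struct Worker;

    impl Worker {
        fn kick_off(&mut self, cx: &mut ModelContext<Worker>) {
            // The closure receives a WeakModel<Worker>; upgrade it before touching the entity.
            let _task = cx.spawn(|this, _cx| async move {
                if let Some(_worker) = this.upgrade() {
                    // ... background work against the upgraded model ...
                }
            });
        }
    }
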
@@ -333,7 +333,7 @@ pub trait StatefulInteractive<V: 'static>: StatelessInteractive<V> {
Some(Box::new(move |view_state, cursor_offset, cx| {
let drag = listener(view_state, cx);
let drag_handle_view = Some(
- View::for_handle(cx.handle().upgrade().unwrap(), move |view_state, cx| {
+ View::for_handle(cx.model().upgrade().unwrap(), move |view_state, cx| {
(drag.render_drag_handle)(view_state, cx)
})
.into_any(),
@@ -1,6 +1,6 @@
use crate::{
AnyBox, AnyElement, AvailableSpace, BorrowWindow, Bounds, Component, Element, ElementId,
- EntityId, LayoutId, Model, Pixels, Size, ViewContext, VisualContext, WeakHandle, WindowContext,
+ EntityId, LayoutId, Model, Pixels, Size, ViewContext, VisualContext, WeakModel, WindowContext,
};
use anyhow::{Context, Result};
use parking_lot::Mutex;
@@ -116,7 +116,7 @@ impl<V: 'static> Element<()> for View<V> {
}
pub struct WeakView<V> {
- pub(crate) state: WeakHandle<V>,
+ pub(crate) state: WeakModel<V>,
render: Weak<Mutex<dyn Fn(&mut V, &mut ViewContext<V>) -> AnyElement<V> + Send + 'static>>,
}
@@ -7,7 +7,7 @@ use crate::{
MouseButton, MouseDownEvent, MouseMoveEvent, MouseUpEvent, Path, Pixels, PlatformAtlas,
PlatformWindow, Point, PolychromeSprite, Quad, Reference, RenderGlyphParams, RenderImageParams,
RenderSvgParams, ScaledPixels, SceneBuilder, Shadow, SharedString, Size, Style, Subscription,
- TaffyLayoutEngine, Task, Underline, UnderlineStyle, View, VisualContext, WeakHandle, WeakView,
+ TaffyLayoutEngine, Task, Underline, UnderlineStyle, View, VisualContext, WeakModel, WeakView,
WindowOptions, SUBPIXEL_VARIANTS,
};
use anyhow::Result;
@@ -1257,13 +1257,13 @@ impl Context for WindowContext<'_, '_> {
fn update_entity<T: 'static, R>(
&mut self,
- handle: &Model<T>,
+ model: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> R {
- let mut entity = self.entities.lease(handle);
+ let mut entity = self.entities.lease(model);
let result = update(
&mut *entity,
- &mut ModelContext::mutable(&mut *self.app, handle.downgrade()),
+ &mut ModelContext::mutable(&mut *self.app, model.downgrade()),
);
self.entities.end_lease(entity);
result
@@ -1555,7 +1555,7 @@ impl<'a, 'w, V: 'static> ViewContext<'a, 'w, V> {
self.view.clone()
}
- pub fn handle(&self) -> WeakHandle<V> {
+ pub fn model(&self) -> WeakModel<V> {
self.view.state.clone()
}
@@ -1872,10 +1872,10 @@ impl<'a, 'w, V> Context for ViewContext<'a, 'w, V> {
fn update_entity<T: 'static, R>(
&mut self,
- handle: &Model<T>,
+ model: &Model<T>,
update: impl FnOnce(&mut T, &mut Self::ModelContext<'_, T>) -> R,
) -> R {
- self.window_cx.update_entity(handle, update)
+ self.window_cx.update_entity(model, update)
}
}
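
On the view side, `cx.handle()` is renamed to `cx.model()` and now returns a `WeakModel<V>`. A minimal sketch (the `stash_weak_ref` helper is illustrative):

    use gpui2::{ViewContext, WeakModel};

    // Hypothetical helper: grab a weak reference to the view's backing state.
    fn stash_weak_ref<V: 'static>(cx: &ViewContext<'_, '_, V>) -> WeakModel<V> {
        cx.model()
    }
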
@@ -26,8 +26,8 @@ use futures::{
};
use globset::{Glob, GlobSet, GlobSetBuilder};
use gpui2::{
- AnyHandle, AppContext, AsyncAppContext, Context, EventEmitter, Executor, Model, ModelContext,
- Task, WeakHandle,
+ AnyModel, AppContext, AsyncAppContext, Context, EventEmitter, Executor, Model, ModelContext,
+ Task, WeakModel,
};
use itertools::Itertools;
use language2::{
@@ -153,7 +153,7 @@ pub struct Project {
incomplete_remote_buffers: HashMap<u64, Option<Model<Buffer>>>,
buffer_snapshots: HashMap<u64, HashMap<LanguageServerId, Vec<LspBufferSnapshot>>>, // buffer_id -> server_id -> vec of snapshots
buffers_being_formatted: HashSet<u64>,
- buffers_needing_diff: HashSet<WeakHandle<Buffer>>,
+ buffers_needing_diff: HashSet<WeakModel<Buffer>>,
git_diff_debouncer: DelayedDebounced,
nonce: u128,
_maintain_buffer_languages: Task<()>,
@@ -245,14 +245,14 @@ enum LocalProjectUpdate {
enum OpenBuffer {
Strong(Model<Buffer>),
- Weak(WeakHandle<Buffer>),
+ Weak(WeakModel<Buffer>),
Operations(Vec<Operation>),
}
#[derive(Clone)]
enum WorktreeHandle {
Strong(Model<Worktree>),
- Weak(WeakHandle<Worktree>),
+ Weak(WeakModel<Worktree>),
}
enum ProjectClientState {
@@ -1671,7 +1671,7 @@ impl Project {
&mut self,
path: impl Into<ProjectPath>,
cx: &mut ModelContext<Self>,
- ) -> Task<Result<(ProjectEntryId, AnyHandle)>> {
+ ) -> Task<Result<(ProjectEntryId, AnyModel)>> {
let task = self.open_buffer(path, cx);
cx.spawn(move |_, mut cx| async move {
let buffer = task.await?;
@@ -1681,7 +1681,7 @@ impl Project {
})?
.ok_or_else(|| anyhow!("no project entry"))?;
- let buffer: &AnyHandle = &buffer;
+ let buffer: &AnyModel = &buffer;
Ok((project_entry_id, buffer.clone()))
})
}
@@ -2158,7 +2158,7 @@ impl Project {
}
async fn send_buffer_ordered_messages(
- this: WeakHandle<Self>,
+ this: WeakModel<Self>,
rx: UnboundedReceiver<BufferOrderedMessage>,
mut cx: AsyncAppContext,
) -> Result<()> {
@@ -2166,7 +2166,7 @@ impl Project {
let mut operations_by_buffer_id = HashMap::default();
async fn flush_operations(
- this: &WeakHandle<Project>,
+ this: &WeakModel<Project>,
operations_by_buffer_id: &mut HashMap<u64, Vec<proto::Operation>>,
needs_resync_with_host: &mut bool,
is_local: bool,
@@ -2931,7 +2931,7 @@ impl Project {
}
async fn setup_and_insert_language_server(
- this: WeakHandle<Self>,
+ this: WeakModel<Self>,
initialization_options: Option<serde_json::Value>,
pending_server: PendingLanguageServer,
adapter: Arc<CachedLspAdapter>,
@@ -2970,7 +2970,7 @@ impl Project {
}
async fn setup_pending_language_server(
- this: WeakHandle<Self>,
+ this: WeakModel<Self>,
initialization_options: Option<serde_json::Value>,
pending_server: PendingLanguageServer,
adapter: Arc<CachedLspAdapter>,
@@ -3748,7 +3748,7 @@ impl Project {
}
async fn on_lsp_workspace_edit(
- this: WeakHandle<Self>,
+ this: WeakModel<Self>,
params: lsp2::ApplyWorkspaceEditParams,
server_id: LanguageServerId,
adapter: Arc<CachedLspAdapter>,
@@ -4360,7 +4360,7 @@ impl Project {
}
async fn format_via_lsp(
- this: &WeakHandle<Self>,
+ this: &WeakModel<Self>,
buffer: &Model<Buffer>,
abs_path: &Path,
language_server: &Arc<LanguageServer>,
@@ -1,5 +1,5 @@
use crate::Project;
-use gpui2::{AnyWindowHandle, Context, Model, ModelContext, WeakHandle};
+use gpui2::{AnyWindowHandle, Context, Model, ModelContext, WeakModel};
use settings2::Settings;
use std::path::{Path, PathBuf};
use terminal2::{
@@ -11,7 +11,7 @@ use terminal2::{
use std::os::unix::ffi::OsStrExt;
pub struct Terminals {
- pub(crate) local_handles: Vec<WeakHandle<terminal2::Terminal>>,
+ pub(crate) local_handles: Vec<WeakModel<terminal2::Terminal>>,
}
impl Project {
@@ -121,7 +121,7 @@ impl Project {
}
}
- pub fn local_terminal_handles(&self) -> &Vec<WeakHandle<terminal2::Terminal>> {
+ pub fn local_terminal_handles(&self) -> &Vec<WeakModel<terminal2::Terminal>> {
&self.terminals.local_handles
}
}
@@ -42,6 +42,7 @@ sha1 = "0.10.5"
ndarray = { version = "0.15.0" }
[dev-dependencies]
+ai = { path = "../ai", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }
@@ -41,7 +41,6 @@ pub struct EmbeddingQueue {
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
- api_key: Option<String>,
}
#[derive(Clone)]
@@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed {
}
impl EmbeddingQueue {
- pub fn new(
- embedding_provider: Arc<dyn EmbeddingProvider>,
- executor: Arc<Background>,
- api_key: Option<String>,
- ) -> Self {
+ pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
embedding_provider,
@@ -64,14 +59,9 @@ impl EmbeddingQueue {
pending_batch_token_count: 0,
finished_files_tx,
finished_files_rx,
- api_key,
}
}
- pub fn set_api_key(&mut self, api_key: Option<String>) {
- self.api_key = api_key
- }
-
pub fn push(&mut self, file: FileToEmbed) {
if file.spans.is_empty() {
self.finished_files_tx.try_send(file).unwrap();
@@ -118,7 +108,6 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
- let api_key = self.api_key.clone();
self.executor
.spawn(async move {
@@ -143,7 +132,7 @@ impl EmbeddingQueue {
return;
};
- match embedding_provider.embed_batch(spans, api_key).await {
+ match embedding_provider.embed_batch(spans).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {
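
With credentials now owned by the provider via the new `CredentialProvider` trait, `EmbeddingQueue` and `EmbeddingProvider::embed_batch` stop threading an `Option<String>` api key through every call. A minimal sketch of the new call shape (the `embed_all` helper is illustrative, not part of this change):

    use std::sync::Arc;

    use ai::embedding::{Embedding, EmbeddingProvider};
    use anyhow::Result;

    // Hypothetical helper: the provider authenticates itself, so callers only pass spans.
    async fn embed_all(
        provider: Arc<dyn EmbeddingProvider>,
        spans: Vec<String>,
    ) -> Result<Vec<Embedding>> {
        provider.embed_batch(spans).await
    }
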
@@ -1,4 +1,7 @@
-use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::{
+ embedding::{Embedding, EmbeddingProvider},
+ models::TruncationDirection,
+};
use anyhow::{anyhow, Result};
use language::{Grammar, Language};
use rusqlite::{
@@ -108,7 +111,14 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref())
.replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str());
- let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
+ let model = self.embedding_provider.base_model();
+ let document_span = model.truncate(
+ &document_span,
+ model.capacity()?,
+ ai::models::TruncationDirection::End,
+ )?;
+ let token_count = model.count_tokens(&document_span)?;
+
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
@@ -131,7 +141,15 @@ impl CodeContextRetriever {
)
.replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str());
- let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
+
+ let model = self.embedding_provider.base_model();
+ let document_span = model.truncate(
+ &document_span,
+ model.capacity()?,
+ ai::models::TruncationDirection::End,
+ )?;
+ let token_count = model.count_tokens(&document_span)?;
+
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
@@ -222,8 +240,13 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref())
.replace("item", &span.content);
- let (document_content, token_count) =
- self.embedding_provider.truncate(&document_content);
+ let model = self.embedding_provider.base_model();
+ let document_content = model.truncate(
+ &document_content,
+ model.capacity()?,
+ TruncationDirection::End,
+ )?;
+ let token_count = model.count_tokens(&document_content)?;
span.content = document_content;
span.token_count = token_count;
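
Truncation and token counting likewise move off the embedding provider onto its base language model. A minimal sketch of the new flow, assuming `base_model()` exposes the `truncate`, `capacity`, and `count_tokens` methods used above (the helper name is illustrative):

    use ai::embedding::EmbeddingProvider;
    use ai::models::TruncationDirection;
    use anyhow::Result;

    // Hypothetical helper mirroring the call sites above.
    fn truncate_for_embedding(
        provider: &dyn EmbeddingProvider,
        text: &str,
    ) -> Result<(String, usize)> {
        let model = provider.base_model();
        let truncated = model.truncate(text, model.capacity()?, TruncationDirection::End)?;
        let token_count = model.count_tokens(&truncated)?;
        Ok((truncated, token_count))
    }
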
@@ -7,7 +7,8 @@ pub mod semantic_index_settings;
mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
@@ -88,7 +89,7 @@ pub fn init(
let semantic_index = SemanticIndex::new(
fs,
db_file_path,
- Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+ Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
language_registry,
cx.clone(),
)
@@ -123,8 +124,6 @@ pub struct SemanticIndex {
_embedding_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
- api_key: Option<String>,
- embedding_queue: Arc<Mutex<EmbeddingQueue>>,
}
struct ProjectState {
@@ -278,18 +277,18 @@ impl SemanticIndex {
}
}
- pub fn authenticate(&mut self, cx: &AppContext) {
- if self.api_key.is_none() {
- self.api_key = self.embedding_provider.retrieve_credentials(cx);
-
- self.embedding_queue
- .lock()
- .set_api_key(self.api_key.clone());
+ pub fn authenticate(&mut self, cx: &AppContext) -> bool {
+ if !self.embedding_provider.has_credentials() {
+ self.embedding_provider.retrieve_credentials(cx);
+ } else {
+ return true;
}
+
+ self.embedding_provider.has_credentials()
}
pub fn is_authenticated(&self) -> bool {
- self.api_key.is_some()
+ self.embedding_provider.has_credentials()
}
pub fn enabled(cx: &AppContext) -> bool {
@@ -339,7 +338,7 @@ impl SemanticIndex {
Ok(cx.add_model(|cx| {
let t0 = Instant::now();
let embedding_queue =
- EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
+ EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
let _embedding_task = cx.background().spawn({
let embedded_files = embedding_queue.finished_files();
let db = db.clone();
@@ -404,8 +403,6 @@ impl SemanticIndex {
_embedding_task,
_parsing_files_tasks,
projects: Default::default(),
- api_key: None,
- embedding_queue
}
}))
}
@@ -720,13 +717,13 @@ impl SemanticIndex {
let index = self.index_project(project.clone(), cx);
let embedding_provider = self.embedding_provider.clone();
- let api_key = self.api_key.clone();
cx.spawn(|this, mut cx| async move {
index.await?;
let t0 = Instant::now();
+
let query = embedding_provider
- .embed_batch(vec![query], api_key)
+ .embed_batch(vec![query])
.await?
.pop()
.ok_or_else(|| anyhow!("could not embed query"))?;
@@ -944,7 +941,6 @@ impl SemanticIndex {
let fs = self.fs.clone();
let db_path = self.db.path().clone();
let background = cx.background().clone();
- let api_key = self.api_key.clone();
cx.background().spawn(async move {
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
let mut results = Vec::<SearchResult>::new();
@@ -959,15 +955,10 @@ impl SemanticIndex {
.parse_file_with_template(None, &snapshot.text(), language)
.log_err()
.unwrap_or_default();
- if Self::embed_spans(
- &mut spans,
- embedding_provider.as_ref(),
- &db,
- api_key.clone(),
- )
- .await
- .log_err()
- .is_some()
+ if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
+ .await
+ .log_err()
+ .is_some()
{
for span in spans {
let similarity = span.embedding.unwrap().similarity(&query);
@@ -1007,9 +998,8 @@ impl SemanticIndex {
project: ModelHandle<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
- if self.api_key.is_none() {
- self.authenticate(cx);
- if self.api_key.is_none() {
+ if !self.is_authenticated() {
+ if !self.authenticate(cx) {
return Task::ready(Err(anyhow!("user is not authenticated")));
}
}
@@ -1192,7 +1182,6 @@ impl SemanticIndex {
spans: &mut [Span],
embedding_provider: &dyn EmbeddingProvider,
db: &VectorDatabase,
- api_key: Option<String>,
) -> Result<()> {
let mut batch = Vec::new();
let mut batch_tokens = 0;
@@ -1215,7 +1204,7 @@ impl SemanticIndex {
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
let batch_embeddings = embedding_provider
- .embed_batch(mem::take(&mut batch), api_key.clone())
+ .embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);
batch_tokens = 0;
@@ -1227,7 +1216,7 @@ impl SemanticIndex {
if !batch.is_empty() {
let batch_embeddings = embedding_provider
- .embed_batch(mem::take(&mut batch), api_key)
+ .embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);
@@ -4,10 +4,9 @@ use crate::{
semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
};
-use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
-use anyhow::Result;
-use async_trait::async_trait;
-use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
+use ai::test::FakeEmbeddingProvider;
+
+use gpui::{executor::Deterministic, Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
use parking_lot::Mutex;
use pretty_assertions::assert_eq;
@@ -15,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs
use rand::{rngs::StdRng, Rng};
use serde_json::json;
use settings::SettingsStore;
-use std::{
- path::Path,
- sync::{
- atomic::{self, AtomicUsize},
- Arc,
- },
- time::{Instant, SystemTime},
-};
+use std::{path::Path, sync::Arc, time::SystemTime};
use unindent::Unindent;
use util::RandomCharIter;
@@ -228,7 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
+ let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
for file in &files {
queue.push(file.clone());
}
@@ -280,7 +272,7 @@ fn assert_search_results(
#[gpui::test]
async fn test_code_context_retrieval_rust() {
let language = rust_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@@ -382,7 +374,7 @@ async fn test_code_context_retrieval_rust() {
#[gpui::test]
async fn test_code_context_retrieval_json() {
let language = json_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -466,7 +458,7 @@ fn assert_documents_eq(
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
let language = js_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@@ -565,7 +557,7 @@ async fn test_code_context_retrieval_javascript() {
#[gpui::test]
async fn test_code_context_retrieval_lua() {
let language = lua_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -639,7 +631,7 @@ async fn test_code_context_retrieval_lua() {
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
let language = elixir_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -756,7 +748,7 @@ async fn test_code_context_retrieval_elixir() {
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
let language = cpp_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@@ -909,7 +901,7 @@ async fn test_code_context_retrieval_cpp() {
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
let language = ruby_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -1100,7 +1092,7 @@ async fn test_code_context_retrieval_ruby() {
#[gpui::test]
async fn test_code_context_retrieval_php() {
let language = php_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -1248,65 +1240,6 @@ async fn test_code_context_retrieval_php() {
);
}
-#[derive(Default)]
-struct FakeEmbeddingProvider {
- embedding_count: AtomicUsize,
-}
-
-impl FakeEmbeddingProvider {
- fn embedding_count(&self) -> usize {
- self.embedding_count.load(atomic::Ordering::SeqCst)
- }
-
- fn embed_sync(&self, span: &str) -> Embedding {
- let mut result = vec![1.0; 26];
- for letter in span.chars() {
- let letter = letter.to_ascii_lowercase();
- if letter as u32 >= 'a' as u32 {
- let ix = (letter as u32) - ('a' as u32);
- if ix < 26 {
- result[ix as usize] += 1.0;
- }
- }
- }
-
- let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
- for x in &mut result {
- *x /= norm;
- }
-
- result.into()
- }
-}
-
-#[async_trait]
-impl EmbeddingProvider for FakeEmbeddingProvider {
- fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
- Some("Fake Credentials".to_string())
- }
- fn truncate(&self, span: &str) -> (String, usize) {
- (span.to_string(), 1)
- }
-
- fn max_tokens_per_batch(&self) -> usize {
- 200
- }
-
- fn rate_limit_expiration(&self) -> Option<Instant> {
- None
- }
-
- async fn embed_batch(
- &self,
- spans: Vec<String>,
- _api_key: Option<String>,
- ) -> Result<Vec<Embedding>> {
- self.embedding_count
- .fetch_add(spans.len(), atomic::Ordering::SeqCst);
- Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
- }
-}
-
fn js_lang() -> Arc<Language> {
Arc::new(
Language::new(
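
The local `FakeEmbeddingProvider` and `DummyEmbeddings` test doubles are replaced by `ai::test::FakeEmbeddingProvider`, pulled in through the `test-support` feature added to the dev-dependency above. A minimal sketch of the test setup inside the semantic_index tests (assuming `CodeContextRetriever::new` still takes the provider directly, as the hunks above suggest):

    use std::sync::Arc;

    use ai::test::FakeEmbeddingProvider;

    // Hypothetical helper: build a retriever against the shared fake provider.
    fn test_retriever() -> CodeContextRetriever {
        let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
        CodeContextRetriever::new(embedding_provider)
    }
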
@@ -1,4 +1,4 @@
-use ai::embedding::OpenAIEmbeddings;
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Result};
use client::{self, UserStore};
use gpui::{AsyncAppContext, ModelHandle, Task};
@@ -475,7 +475,7 @@ fn main() {
let semantic_index = SemanticIndex::new(
fs.clone(),
db_file_path,
- Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+ Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
languages.clone(),
cx.clone(),
)
@@ -15,6 +15,7 @@ name = "Zed"
path = "src/main.rs"
[dependencies]
+ai2 = { path = "../ai2"}
# audio = { path = "../audio" }
# activity_indicator = { path = "../activity_indicator" }
# auto_update = { path = "../auto_update" }