Commit by Kyle Caverly:
Small reorganization for the AI crates. I separated out the base traits
and providers to get closer to an abstraction for AI completions, as
opposed to OpenAI-specific functionality.
crates/ai/Cargo.toml | 3
crates/ai/src/ai.rs | 6
crates/ai/src/auth.rs | 15
crates/ai/src/completion.rs | 215 ----------
crates/ai/src/embedding.rs | 322 ----------------
crates/ai/src/models.rs | 70 ---
crates/ai/src/prompts/base.rs | 56 --
crates/ai/src/prompts/file_context.rs | 16
crates/ai/src/prompts/generate.rs | 8
crates/ai/src/prompts/mod.rs | 0
crates/ai/src/prompts/preamble.rs | 2
crates/ai/src/prompts/repository_context.rs | 2
crates/ai/src/providers/mod.rs | 1
crates/ai/src/providers/open_ai/completion.rs | 298 +++++++++++++++
crates/ai/src/providers/open_ai/embedding.rs | 306 ++++++++++++++++
crates/ai/src/providers/open_ai/mod.rs | 9
crates/ai/src/providers/open_ai/model.rs | 57 +++
crates/ai/src/providers/open_ai/new.rs | 11
crates/ai/src/test.rs | 191 ++++++++++
crates/assistant/Cargo.toml | 1
crates/assistant/src/assistant.rs | 2
crates/assistant/src/assistant_panel.rs | 280 ++++++-------
crates/assistant/src/codegen.rs | 80 +--
crates/assistant/src/prompts.rs | 13
crates/semantic_index/Cargo.toml | 1
crates/semantic_index/src/embedding_queue.rs | 15
crates/semantic_index/src/parsing.rs | 33 +
crates/semantic_index/src/semantic_index.rs | 55 +-
crates/semantic_index/src/semantic_index_tests.rs | 93 ----
crates/zed/examples/semantic_index_eval.rs | 4
30 files changed, 1,205 insertions(+), 960 deletions(-)
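For orientation, here is a minimal sketch (not part of the commit) of what the new abstraction buys: callers can stream a completion through the `CompletionProvider` trait object without naming any OpenAI types. The `ai::` paths are taken from the diff below; the helper itself is hypothetical.

use ai::completion::{CompletionProvider, CompletionRequest};
use anyhow::Result;
use futures::StreamExt;

// Hypothetical helper: works with any provider, real or fake.
async fn print_completion(
    provider: &dyn CompletionProvider,
    request: Box<dyn CompletionRequest>,
) -> Result<()> {
    // `complete` resolves to a stream of text chunks.
    let mut chunks = provider.complete(request).await?;
    while let Some(chunk) = chunks.next().await {
        print!("{}", chunk?);
    }
    Ok(())
}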
crates/ai/Cargo.toml
@@ -8,6 +8,9 @@ publish = false
path = "src/ai.rs"
doctest = false
+[features]
+test-support = []
+
[dependencies]
gpui = { path = "../gpui" }
util = { path = "../util" }
crates/ai/src/ai.rs
@@ -1,4 +1,8 @@
+pub mod auth;
pub mod completion;
pub mod embedding;
pub mod models;
-pub mod templates;
+pub mod prompts;
+pub mod providers;
+#[cfg(any(test, feature = "test-support"))]
+pub mod test;
crates/ai/src/auth.rs
@@ -0,0 +1,15 @@
+use gpui::AppContext;
+
+#[derive(Clone, Debug)]
+pub enum ProviderCredential {
+ Credentials { api_key: String },
+ NoCredentials,
+ NotNeeded,
+}
+
+pub trait CredentialProvider: Send + Sync {
+ fn has_credentials(&self) -> bool;
+ fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential;
+ fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential);
+ fn delete_credentials(&self, cx: &AppContext);
+}
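As a hedged sketch of consuming this new trait (the function and its policy are illustrative, not from the commit): a UI layer can match on `ProviderCredential` to decide whether to prompt for an API key.

use ai::auth::{CredentialProvider, ProviderCredential};
use gpui::AppContext;

// Illustrative policy: prompt only when a provider needs a key it doesn't have.
fn should_prompt_for_api_key(provider: &dyn CredentialProvider, cx: &AppContext) -> bool {
    match provider.retrieve_credentials(cx) {
        ProviderCredential::Credentials { .. } => false, // a key was found
        ProviderCredential::NotNeeded => false,          // e.g. a fake/test provider
        ProviderCredential::NoCredentials => true,       // ask the user for one
    }
}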
crates/ai/src/completion.rs
@@ -1,214 +1,23 @@
-use anyhow::{anyhow, Result};
-use futures::{
- future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
- Stream, StreamExt,
-};
-use gpui::executor::Background;
-use isahc::{http::StatusCode, Request, RequestExt};
-use serde::{Deserialize, Serialize};
-use std::{
- fmt::{self, Display},
- io,
- sync::Arc,
-};
+use anyhow::Result;
+use futures::{future::BoxFuture, stream::BoxStream};
-pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
+use crate::{auth::CredentialProvider, models::LanguageModel};
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
- User,
- Assistant,
- System,
+pub trait CompletionRequest: Send + Sync {
+ fn data(&self) -> serde_json::Result<String>;
}
-impl Role {
- pub fn cycle(&mut self) {
- *self = match self {
- Role::User => Role::Assistant,
- Role::Assistant => Role::System,
- Role::System => Role::User,
- }
- }
-}
-
-impl Display for Role {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- Role::User => write!(f, "User"),
- Role::Assistant => write!(f, "Assistant"),
- Role::System => write!(f, "System"),
- }
- }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct RequestMessage {
- pub role: Role,
- pub content: String,
-}
-
-#[derive(Debug, Default, Serialize)]
-pub struct OpenAIRequest {
- pub model: String,
- pub messages: Vec<RequestMessage>,
- pub stream: bool,
- pub stop: Vec<String>,
- pub temperature: f32,
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct ResponseMessage {
- pub role: Option<Role>,
- pub content: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIUsage {
- pub prompt_tokens: u32,
- pub completion_tokens: u32,
- pub total_tokens: u32,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct ChatChoiceDelta {
- pub index: u32,
- pub delta: ResponseMessage,
- pub finish_reason: Option<String>,
-}
-
-#[derive(Deserialize, Debug)]
-pub struct OpenAIResponseStreamEvent {
- pub id: Option<String>,
- pub object: String,
- pub created: u32,
- pub model: String,
- pub choices: Vec<ChatChoiceDelta>,
- pub usage: Option<OpenAIUsage>,
-}
-
-pub async fn stream_completion(
- api_key: String,
- executor: Arc<Background>,
- mut request: OpenAIRequest,
-) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
- request.stream = true;
-
- let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
-
- let json_data = serde_json::to_string(&request)?;
- let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_key))
- .body(json_data)?
- .send_async()
- .await?;
-
- let status = response.status();
- if status == StatusCode::OK {
- executor
- .spawn(async move {
- let mut lines = BufReader::new(response.body_mut()).lines();
-
- fn parse_line(
- line: Result<String, io::Error>,
- ) -> Result<Option<OpenAIResponseStreamEvent>> {
- if let Some(data) = line?.strip_prefix("data: ") {
- let event = serde_json::from_str(&data)?;
- Ok(Some(event))
- } else {
- Ok(None)
- }
- }
-
- while let Some(line) = lines.next().await {
- if let Some(event) = parse_line(line).transpose() {
- let done = event.as_ref().map_or(false, |event| {
- event
- .choices
- .last()
- .map_or(false, |choice| choice.finish_reason.is_some())
- });
- if tx.unbounded_send(event).is_err() {
- break;
- }
-
- if done {
- break;
- }
- }
- }
-
- anyhow::Ok(())
- })
- .detach();
-
- Ok(rx)
- } else {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
-
- #[derive(Deserialize)]
- struct OpenAIResponse {
- error: OpenAIError,
- }
-
- #[derive(Deserialize)]
- struct OpenAIError {
- message: String,
- }
-
- match serde_json::from_str::<OpenAIResponse>(&body) {
- Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
- "Failed to connect to OpenAI API: {}",
- response.error.message,
- )),
-
- _ => Err(anyhow!(
- "Failed to connect to OpenAI API: {} {}",
- response.status(),
- body,
- )),
- }
- }
-}
-
-pub trait CompletionProvider {
+pub trait CompletionProvider: CredentialProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel>;
fn complete(
&self,
- prompt: OpenAIRequest,
+ prompt: Box<dyn CompletionRequest>,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+ fn box_clone(&self) -> Box<dyn CompletionProvider>;
}
-pub struct OpenAICompletionProvider {
- api_key: String,
- executor: Arc<Background>,
-}
-
-impl OpenAICompletionProvider {
- pub fn new(api_key: String, executor: Arc<Background>) -> Self {
- Self { api_key, executor }
- }
-}
-
-impl CompletionProvider for OpenAICompletionProvider {
- fn complete(
- &self,
- prompt: OpenAIRequest,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
- let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
- async move {
- let response = request.await?;
- let stream = response
- .filter_map(|response| async move {
- match response {
- Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
- Err(error) => Some(Err(error)),
- }
- })
- .boxed();
- Ok(stream)
- }
- .boxed()
+impl Clone for Box<dyn CompletionProvider> {
+ fn clone(&self) -> Box<dyn CompletionProvider> {
+ self.box_clone()
}
}
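The `box_clone` method paired with `impl Clone for Box<dyn CompletionProvider>` is the usual workaround for `Clone` not being object-safe. A small sketch of why it matters (the `Session` struct is hypothetical): any struct holding a boxed provider can now derive `Clone`.

use ai::completion::CompletionProvider;

// Derives Clone because Box<dyn CompletionProvider> now implements Clone
// by delegating to box_clone().
#[derive(Clone)]
struct Session {
    provider: Box<dyn CompletionProvider>,
}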
crates/ai/src/embedding.rs
@@ -1,32 +1,13 @@
-use anyhow::{anyhow, Result};
+use std::time::Instant;
+
+use anyhow::Result;
use async_trait::async_trait;
-use futures::AsyncReadExt;
-use gpui::executor::Background;
-use gpui::{serde_json, AppContext};
-use isahc::http::StatusCode;
-use isahc::prelude::Configurable;
-use isahc::{AsyncBody, Response};
-use lazy_static::lazy_static;
use ordered_float::OrderedFloat;
-use parking_lot::Mutex;
-use parse_duration::parse;
-use postage::watch;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
-use serde::{Deserialize, Serialize};
-use std::env;
-use std::ops::Add;
-use std::sync::Arc;
-use std::time::{Duration, Instant};
-use tiktoken_rs::{cl100k_base, CoreBPE};
-use util::http::{HttpClient, Request};
-use util::ResultExt;
-
-use crate::completion::OPENAI_API_URL;
-lazy_static! {
- static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
-}
+use crate::auth::CredentialProvider;
+use crate::models::LanguageModel;
#[derive(Debug, PartialEq, Clone)]
pub struct Embedding(pub Vec<f32>);
@@ -87,301 +68,14 @@ impl Embedding {
}
}
-#[derive(Clone)]
-pub struct OpenAIEmbeddings {
- pub client: Arc<dyn HttpClient>,
- pub executor: Arc<Background>,
- rate_limit_count_rx: watch::Receiver<Option<Instant>>,
- rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
-}
-
-#[derive(Serialize)]
-struct OpenAIEmbeddingRequest<'a> {
- model: &'static str,
- input: Vec<&'a str>,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingResponse {
- data: Vec<OpenAIEmbedding>,
- usage: OpenAIEmbeddingUsage,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAIEmbedding {
- embedding: Vec<f32>,
- index: usize,
- object: String,
-}
-
-#[derive(Deserialize)]
-struct OpenAIEmbeddingUsage {
- prompt_tokens: usize,
- total_tokens: usize,
-}
-
#[async_trait]
-pub trait EmbeddingProvider: Sync + Send {
- fn retrieve_credentials(&self, cx: &AppContext) -> Option<String>;
- async fn embed_batch(
- &self,
- spans: Vec<String>,
- api_key: Option<String>,
- ) -> Result<Vec<Embedding>>;
+pub trait EmbeddingProvider: CredentialProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel>;
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
- fn truncate(&self, span: &str) -> (String, usize);
fn rate_limit_expiration(&self) -> Option<Instant>;
}
-pub struct DummyEmbeddings {}
-
-#[async_trait]
-impl EmbeddingProvider for DummyEmbeddings {
- fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
- Some("Dummy API KEY".to_string())
- }
- fn rate_limit_expiration(&self) -> Option<Instant> {
- None
- }
- async fn embed_batch(
- &self,
- spans: Vec<String>,
- _api_key: Option<String>,
- ) -> Result<Vec<Embedding>> {
- // 1024 is the OpenAI Embeddings size for ada models.
- // the model we will likely be starting with.
- let dummy_vec = Embedding::from(vec![0.32 as f32; 1536]);
- return Ok(vec![dummy_vec; spans.len()]);
- }
-
- fn max_tokens_per_batch(&self) -> usize {
- OPENAI_INPUT_LIMIT
- }
-
- fn truncate(&self, span: &str) -> (String, usize) {
- let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
- let token_count = tokens.len();
- let output = if token_count > OPENAI_INPUT_LIMIT {
- tokens.truncate(OPENAI_INPUT_LIMIT);
- let new_input = OPENAI_BPE_TOKENIZER.decode(tokens.clone());
- new_input.ok().unwrap_or_else(|| span.to_string())
- } else {
- span.to_string()
- };
-
- (output, tokens.len())
- }
-}
-
-const OPENAI_INPUT_LIMIT: usize = 8190;
-
-impl OpenAIEmbeddings {
- pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
- let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
- let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
-
- OpenAIEmbeddings {
- client,
- executor,
- rate_limit_count_rx,
- rate_limit_count_tx,
- }
- }
-
- fn resolve_rate_limit(&self) {
- let reset_time = *self.rate_limit_count_tx.lock().borrow();
-
- if let Some(reset_time) = reset_time {
- if Instant::now() >= reset_time {
- *self.rate_limit_count_tx.lock().borrow_mut() = None
- }
- }
-
- log::trace!(
- "resolving reset time: {:?}",
- *self.rate_limit_count_tx.lock().borrow()
- );
- }
-
- fn update_reset_time(&self, reset_time: Instant) {
- let original_time = *self.rate_limit_count_tx.lock().borrow();
-
- let updated_time = if let Some(original_time) = original_time {
- if reset_time < original_time {
- Some(reset_time)
- } else {
- Some(original_time)
- }
- } else {
- Some(reset_time)
- };
-
- log::trace!("updating rate limit time: {:?}", updated_time);
-
- *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
- }
- async fn send_request(
- &self,
- api_key: &str,
- spans: Vec<&str>,
- request_timeout: u64,
- ) -> Result<Response<AsyncBody>> {
- let request = Request::post("https://api.openai.com/v1/embeddings")
- .redirect_policy(isahc::config::RedirectPolicy::Follow)
- .timeout(Duration::from_secs(request_timeout))
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", api_key))
- .body(
- serde_json::to_string(&OpenAIEmbeddingRequest {
- input: spans.clone(),
- model: "text-embedding-ada-002",
- })
- .unwrap()
- .into(),
- )?;
-
- Ok(self.client.send(request).await?)
- }
-}
-
-#[async_trait]
-impl EmbeddingProvider for OpenAIEmbeddings {
- fn retrieve_credentials(&self, cx: &AppContext) -> Option<String> {
- if let Ok(api_key) = env::var("OPENAI_API_KEY") {
- Some(api_key)
- } else if let Some((_, api_key)) = cx
- .platform()
- .read_credentials(OPENAI_API_URL)
- .log_err()
- .flatten()
- {
- String::from_utf8(api_key).log_err()
- } else {
- None
- }
- }
-
- fn max_tokens_per_batch(&self) -> usize {
- 50000
- }
-
- fn rate_limit_expiration(&self) -> Option<Instant> {
- *self.rate_limit_count_rx.borrow()
- }
- fn truncate(&self, span: &str) -> (String, usize) {
- let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
- let output = if tokens.len() > OPENAI_INPUT_LIMIT {
- tokens.truncate(OPENAI_INPUT_LIMIT);
- OPENAI_BPE_TOKENIZER
- .decode(tokens.clone())
- .ok()
- .unwrap_or_else(|| span.to_string())
- } else {
- span.to_string()
- };
-
- (output, tokens.len())
- }
-
- async fn embed_batch(
- &self,
- spans: Vec<String>,
- api_key: Option<String>,
- ) -> Result<Vec<Embedding>> {
- const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
- const MAX_RETRIES: usize = 4;
-
- let Some(api_key) = api_key else {
- return Err(anyhow!("no open ai key provided"));
- };
-
- let mut request_number = 0;
- let mut rate_limiting = false;
- let mut request_timeout: u64 = 15;
- let mut response: Response<AsyncBody>;
- while request_number < MAX_RETRIES {
- response = self
- .send_request(
- &api_key,
- spans.iter().map(|x| &**x).collect(),
- request_timeout,
- )
- .await?;
-
- request_number += 1;
-
- match response.status() {
- StatusCode::REQUEST_TIMEOUT => {
- request_timeout += 5;
- }
- StatusCode::OK => {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
-
- log::trace!(
- "openai embedding completed. tokens: {:?}",
- response.usage.total_tokens
- );
-
- // If we complete a request successfully that was previously rate_limited
- // resolve the rate limit
- if rate_limiting {
- self.resolve_rate_limit()
- }
-
- return Ok(response
- .data
- .into_iter()
- .map(|embedding| Embedding::from(embedding.embedding))
- .collect());
- }
- StatusCode::TOO_MANY_REQUESTS => {
- rate_limiting = true;
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
-
- let delay_duration = {
- let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
- if let Some(time_to_reset) =
- response.headers().get("x-ratelimit-reset-tokens")
- {
- if let Ok(time_str) = time_to_reset.to_str() {
- parse(time_str).unwrap_or(delay)
- } else {
- delay
- }
- } else {
- delay
- }
- };
-
- // If we've previously rate limited, increment the duration but not the count
- let reset_time = Instant::now().add(delay_duration);
- self.update_reset_time(reset_time);
-
- log::trace!(
- "openai rate limiting: waiting {:?} until lifted",
- &delay_duration
- );
-
- self.executor.timer(delay_duration).await;
- }
- _ => {
- let mut body = String::new();
- response.body_mut().read_to_string(&mut body).await?;
- return Err(anyhow!(
- "open ai bad request: {:?} {:?}",
- &response.status(),
- body
- ));
- }
- }
- }
- Err(anyhow!("openai max retries"))
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;
crates/ai/src/models.rs
@@ -1,66 +1,16 @@
-use anyhow::anyhow;
-use tiktoken_rs::CoreBPE;
-use util::ResultExt;
+pub enum TruncationDirection {
+ Start,
+ End,
+}
pub trait LanguageModel {
fn name(&self) -> String;
fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
- fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
- fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String>;
fn capacity(&self) -> anyhow::Result<usize>;
}
-
-pub struct OpenAILanguageModel {
- name: String,
- bpe: Option<CoreBPE>,
-}
-
-impl OpenAILanguageModel {
- pub fn load(model_name: &str) -> Self {
- let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
- OpenAILanguageModel {
- name: model_name.to_string(),
- bpe,
- }
- }
-}
-
-impl LanguageModel for OpenAILanguageModel {
- fn name(&self) -> String {
- self.name.clone()
- }
- fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
- if let Some(bpe) = &self.bpe {
- anyhow::Ok(bpe.encode_with_special_tokens(content).len())
- } else {
- Err(anyhow!("bpe for open ai model was not retrieved"))
- }
- }
- fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
- if let Some(bpe) = &self.bpe {
- let tokens = bpe.encode_with_special_tokens(content);
- if tokens.len() > length {
- bpe.decode(tokens[..length].to_vec())
- } else {
- bpe.decode(tokens)
- }
- } else {
- Err(anyhow!("bpe for open ai model was not retrieved"))
- }
- }
- fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
- if let Some(bpe) = &self.bpe {
- let tokens = bpe.encode_with_special_tokens(content);
- if tokens.len() > length {
- bpe.decode(tokens[length..].to_vec())
- } else {
- bpe.decode(tokens)
- }
- } else {
- Err(anyhow!("bpe for open ai model was not retrieved"))
- }
- }
- fn capacity(&self) -> anyhow::Result<usize> {
- anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
- }
-}
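A quick sketch of the merged `truncate` API, using the character-based `FakeLanguageModel` this commit adds in `crates/ai/src/test.rs` (import paths assumed): `TruncationDirection::End` keeps the first `length` tokens, while `TruncationDirection::Start` drops them and keeps the tail.

use ai::models::{LanguageModel, TruncationDirection};
use ai::test::FakeLanguageModel; // counts characters as tokens

fn main() -> anyhow::Result<()> {
    let model = FakeLanguageModel { capacity: 100 };
    // Keep the first 5 tokens:
    assert_eq!(model.truncate("hello world", 5, TruncationDirection::End)?, "hello");
    // Drop the first 5 tokens, keeping the tail:
    assert_eq!(model.truncate("hello world", 5, TruncationDirection::Start)?, " world");
    Ok(())
}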
crates/ai/src/prompts/base.rs
@@ -6,7 +6,7 @@ use language::BufferSnapshot
use util::ResultExt;
use crate::models::LanguageModel;
-use crate::templates::repository_context::PromptCodeSnippet;
+use crate::prompts::repository_context::PromptCodeSnippet;
pub(crate) enum PromptFileType {
Text,
@@ -125,6 +125,9 @@ impl PromptChain {
#[cfg(test)]
pub(crate) mod tests {
+ use crate::models::TruncationDirection;
+ use crate::test::FakeLanguageModel;
+
use super::*;
#[test]
@@ -141,7 +144,11 @@ pub(crate) mod tests {
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
- content = args.model.truncate(&content, max_token_length)?;
+ content = args.model.truncate(
+ &content,
+ max_token_length,
+ TruncationDirection::End,
+ )?;
token_count = max_token_length;
}
}
@@ -162,7 +169,11 @@ pub(crate) mod tests {
let mut token_count = args.model.count_tokens(&content)?;
if let Some(max_token_length) = max_token_length {
if token_count > max_token_length {
- content = args.model.truncate(&content, max_token_length)?;
+ content = args.model.truncate(
+ &content,
+ max_token_length,
+ TruncationDirection::End,
+ )?;
token_count = max_token_length;
}
}
@@ -171,38 +182,7 @@ pub(crate) mod tests {
}
}
- #[derive(Clone)]
- struct DummyLanguageModel {
- capacity: usize,
- }
-
- impl LanguageModel for DummyLanguageModel {
- fn name(&self) -> String {
- "dummy".to_string()
- }
- fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
- anyhow::Ok(content.chars().collect::<Vec<char>>().len())
- }
- fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
- anyhow::Ok(
- content.chars().collect::<Vec<char>>()[..length]
- .into_iter()
- .collect::<String>(),
- )
- }
- fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
- anyhow::Ok(
- content.chars().collect::<Vec<char>>()[length..]
- .into_iter()
- .collect::<String>(),
- )
- }
- fn capacity(&self) -> anyhow::Result<usize> {
- anyhow::Ok(self.capacity)
- }
- }
-
- let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 100 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
@@ -238,7 +218,7 @@ pub(crate) mod tests {
// Testing with Truncation Off
// Should ignore capacity and return all prompts
- let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity: 20 });
let args = PromptArguments {
model: model.clone(),
language_name: None,
@@ -275,7 +255,7 @@ pub(crate) mod tests {
// Testing with Truncation Off
// Should ignore capacity and return all prompts
let capacity = 20;
- let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
@@ -311,7 +291,7 @@ pub(crate) mod tests {
// Change Ordering of Prompts Based on Priority
let capacity = 120;
let reserved_tokens = 10;
- let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+ let model: Arc<dyn LanguageModel> = Arc::new(FakeLanguageModel { capacity });
let args = PromptArguments {
model: model.clone(),
language_name: None,
crates/ai/src/prompts/file_context.rs
@@ -3,8 +3,9 @@ use language::BufferSnapshot
use language::ToOffset;
use crate::models::LanguageModel;
-use crate::templates::base::PromptArguments;
-use crate::templates::base::PromptTemplate;
+use crate::models::TruncationDirection;
+use crate::prompts::base::PromptArguments;
+use crate::prompts::base::PromptTemplate;
use std::fmt::Write;
use std::ops::Range;
use std::sync::Arc;
@@ -70,8 +71,9 @@ fn retrieve_context(
};
let truncated_start_window =
- model.truncate_start(&start_window, start_goal_tokens)?;
- let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+ model.truncate(&start_window, start_goal_tokens, TruncationDirection::Start)?;
+ let truncated_end_window =
+ model.truncate(&end_window, end_goal_tokens, TruncationDirection::End)?;
writeln!(
prompt,
"{truncated_start_window}{selected_window}{truncated_end_window}"
@@ -89,7 +91,7 @@ fn retrieve_context(
if let Some(max_token_count) = max_token_count {
if model.count_tokens(&prompt)? > max_token_count {
truncated = true;
- prompt = model.truncate(&prompt, max_token_count)?;
+ prompt = model.truncate(&prompt, max_token_count, TruncationDirection::End)?;
}
}
}
@@ -148,7 +150,9 @@ impl PromptTemplate for FileContext {
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
- prompt = args.model.truncate(&prompt, max_tokens)?;
+ prompt = args
+ .model
+ .truncate(&prompt, max_tokens, TruncationDirection::End)?;
}
let token_count = args.model.count_tokens(&prompt)?;
crates/ai/src/prompts/generate.rs
@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use anyhow::anyhow;
use std::fmt::Write;
@@ -85,7 +85,11 @@ impl PromptTemplate for GenerateInlineContent {
// Really dumb truncation strategy
if let Some(max_tokens) = max_token_length {
- prompt = args.model.truncate(&prompt, max_tokens)?;
+ prompt = args.model.truncate(
+ &prompt,
+ max_tokens,
+ crate::models::TruncationDirection::End,
+ )?;
}
let token_count = args.model.count_tokens(&prompt)?;
crates/ai/src/prompts/preamble.rs
@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptFileType, PromptTemplate};
use std::fmt::Write;
pub struct EngineerPreamble {}
crates/ai/src/prompts/repository_context.rs
@@ -1,4 +1,4 @@
-use crate::templates::base::{PromptArguments, PromptTemplate};
+use crate::prompts::base::{PromptArguments, PromptTemplate};
use std::fmt::Write;
use std::{ops::Range, path::PathBuf};
crates/ai/src/providers/mod.rs
@@ -0,0 +1 @@
+pub mod open_ai;
crates/ai/src/providers/open_ai/completion.rs
@@ -0,0 +1,298 @@
+use anyhow::{anyhow, Result};
+use futures::{
+ future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
+ Stream, StreamExt,
+};
+use gpui::{executor::Background, AppContext};
+use isahc::{http::StatusCode, Request, RequestExt};
+use parking_lot::RwLock;
+use serde::{Deserialize, Serialize};
+use std::{
+ env,
+ fmt::{self, Display},
+ io,
+ sync::Arc,
+};
+use util::ResultExt;
+
+use crate::{
+ auth::{CredentialProvider, ProviderCredential},
+ completion::{CompletionProvider, CompletionRequest},
+ models::LanguageModel,
+};
+
+use crate::providers::open_ai::{OpenAILanguageModel, OPENAI_API_URL};
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+ User,
+ Assistant,
+ System,
+}
+
+impl Role {
+ pub fn cycle(&mut self) {
+ *self = match self {
+ Role::User => Role::Assistant,
+ Role::Assistant => Role::System,
+ Role::System => Role::User,
+ }
+ }
+}
+
+impl Display for Role {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Role::User => write!(f, "User"),
+ Role::Assistant => write!(f, "Assistant"),
+ Role::System => write!(f, "System"),
+ }
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct RequestMessage {
+ pub role: Role,
+ pub content: String,
+}
+
+#[derive(Debug, Default, Serialize)]
+pub struct OpenAIRequest {
+ pub model: String,
+ pub messages: Vec<RequestMessage>,
+ pub stream: bool,
+ pub stop: Vec<String>,
+ pub temperature: f32,
+}
+
+impl CompletionRequest for OpenAIRequest {
+ fn data(&self) -> serde_json::Result<String> {
+ serde_json::to_string(self)
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct ResponseMessage {
+ pub role: Option<Role>,
+ pub content: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIUsage {
+ pub prompt_tokens: u32,
+ pub completion_tokens: u32,
+ pub total_tokens: u32,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ChatChoiceDelta {
+ pub index: u32,
+ pub delta: ResponseMessage,
+ pub finish_reason: Option<String>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct OpenAIResponseStreamEvent {
+ pub id: Option<String>,
+ pub object: String,
+ pub created: u32,
+ pub model: String,
+ pub choices: Vec<ChatChoiceDelta>,
+ pub usage: Option<OpenAIUsage>,
+}
+
+pub async fn stream_completion(
+ credential: ProviderCredential,
+ executor: Arc<Background>,
+ request: Box<dyn CompletionRequest>,
+) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
+ let api_key = match credential {
+ ProviderCredential::Credentials { api_key } => api_key,
+ _ => {
+ return Err(anyhow!("no credentials provider for completion"));
+ }
+ };
+
+ let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
+
+ let json_data = request.data()?;
+ let mut response = Request::post(format!("{OPENAI_API_URL}/chat/completions"))
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(json_data)?
+ .send_async()
+ .await?;
+
+ let status = response.status();
+ if status == StatusCode::OK {
+ executor
+ .spawn(async move {
+ let mut lines = BufReader::new(response.body_mut()).lines();
+
+ fn parse_line(
+ line: Result<String, io::Error>,
+ ) -> Result<Option<OpenAIResponseStreamEvent>> {
+ if let Some(data) = line?.strip_prefix("data: ") {
+ let event = serde_json::from_str(&data)?;
+ Ok(Some(event))
+ } else {
+ Ok(None)
+ }
+ }
+
+ while let Some(line) = lines.next().await {
+ if let Some(event) = parse_line(line).transpose() {
+ let done = event.as_ref().map_or(false, |event| {
+ event
+ .choices
+ .last()
+ .map_or(false, |choice| choice.finish_reason.is_some())
+ });
+ if tx.unbounded_send(event).is_err() {
+ break;
+ }
+
+ if done {
+ break;
+ }
+ }
+ }
+
+ anyhow::Ok(())
+ })
+ .detach();
+
+ Ok(rx)
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ #[derive(Deserialize)]
+ struct OpenAIResponse {
+ error: OpenAIError,
+ }
+
+ #[derive(Deserialize)]
+ struct OpenAIError {
+ message: String,
+ }
+
+ match serde_json::from_str::<OpenAIResponse>(&body) {
+ Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
+ "Failed to connect to OpenAI API: {}",
+ response.error.message,
+ )),
+
+ _ => Err(anyhow!(
+ "Failed to connect to OpenAI API: {} {}",
+ response.status(),
+ body,
+ )),
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct OpenAICompletionProvider {
+ model: OpenAILanguageModel,
+ credential: Arc<RwLock<ProviderCredential>>,
+ executor: Arc<Background>,
+}
+
+impl OpenAICompletionProvider {
+ pub fn new(model_name: &str, executor: Arc<Background>) -> Self {
+ let model = OpenAILanguageModel::load(model_name);
+ let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
+ Self {
+ model,
+ credential,
+ executor,
+ }
+ }
+}
+
+impl CredentialProvider for OpenAICompletionProvider {
+ fn has_credentials(&self) -> bool {
+ match *self.credential.read() {
+ ProviderCredential::Credentials { .. } => true,
+ _ => false,
+ }
+ }
+ fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+ let mut credential = self.credential.write();
+ match *credential {
+ ProviderCredential::Credentials { .. } => {
+ return credential.clone();
+ }
+ _ => {
+ if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ *credential = ProviderCredential::Credentials { api_key };
+ } else if let Some((_, api_key)) = cx
+ .platform()
+ .read_credentials(OPENAI_API_URL)
+ .log_err()
+ .flatten()
+ {
+ if let Some(api_key) = String::from_utf8(api_key).log_err() {
+ *credential = ProviderCredential::Credentials { api_key };
+ }
+ } else {
+ };
+ }
+ }
+
+ credential.clone()
+ }
+
+ fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+ match credential.clone() {
+ ProviderCredential::Credentials { api_key } => {
+ cx.platform()
+ .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+ .log_err();
+ }
+ _ => {}
+ }
+
+ *self.credential.write() = credential;
+ }
+ fn delete_credentials(&self, cx: &AppContext) {
+ cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+ *self.credential.write() = ProviderCredential::NoCredentials;
+ }
+}
+
+impl CompletionProvider for OpenAICompletionProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+ model
+ }
+ fn complete(
+ &self,
+ prompt: Box<dyn CompletionRequest>,
+ ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+ // Currently the CompletionRequest for OpenAI includes a 'model' parameter.
+ // This means the model is determined by the CompletionRequest rather than the
+ // CompletionProvider, which is itself model-based via its language model.
+ // At some point in the future we should rectify this.
+ let credential = self.credential.read().clone();
+ let request = stream_completion(credential, self.executor.clone(), prompt);
+ async move {
+ let response = request.await?;
+ let stream = response
+ .filter_map(|response| async move {
+ match response {
+ Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
+ Err(error) => Some(Err(error)),
+ }
+ })
+ .boxed();
+ Ok(stream)
+ }
+ .boxed()
+ }
+ fn box_clone(&self) -> Box<dyn CompletionProvider> {
+ Box::new((*self).clone())
+ }
+}
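A hedged sketch of wiring the provider up (the surrounding gpui context is assumed; `cx.background()` mirrors how the assistant panel constructs the provider later in this diff):

use ai::auth::CredentialProvider;
use ai::completion::{CompletionProvider, CompletionRequest};
use ai::providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage, Role};
use gpui::AppContext;

fn start_completion(cx: &AppContext) {
    let provider = OpenAICompletionProvider::new("gpt-4", cx.background().clone());
    // Loads OPENAI_API_KEY or the keychain entry into the provider.
    provider.retrieve_credentials(cx);

    let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
        model: "gpt-4".into(),
        messages: vec![RequestMessage {
            role: Role::User,
            content: "Say hello.".into(),
        }],
        stream: true,
        stop: vec![],
        temperature: 1.0,
    });
    // Returns a future resolving to a stream of text chunks.
    let _completion = provider.complete(request);
}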
crates/ai/src/providers/open_ai/embedding.rs
@@ -0,0 +1,306 @@
+use anyhow::{anyhow, Result};
+use async_trait::async_trait;
+use futures::AsyncReadExt;
+use gpui::executor::Background;
+use gpui::{serde_json, AppContext};
+use isahc::http::StatusCode;
+use isahc::prelude::Configurable;
+use isahc::{AsyncBody, Response};
+use lazy_static::lazy_static;
+use parking_lot::{Mutex, RwLock};
+use parse_duration::parse;
+use postage::watch;
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::ops::Add;
+use std::sync::Arc;
+use std::time::{Duration, Instant};
+use tiktoken_rs::{cl100k_base, CoreBPE};
+use util::http::{HttpClient, Request};
+use util::ResultExt;
+
+use crate::auth::{CredentialProvider, ProviderCredential};
+use crate::embedding::{Embedding, EmbeddingProvider};
+use crate::models::LanguageModel;
+use crate::providers::open_ai::OpenAILanguageModel;
+
+use crate::providers::open_ai::OPENAI_API_URL;
+
+lazy_static! {
+ static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
+}
+
+#[derive(Clone)]
+pub struct OpenAIEmbeddingProvider {
+ model: OpenAILanguageModel,
+ credential: Arc<RwLock<ProviderCredential>>,
+ pub client: Arc<dyn HttpClient>,
+ pub executor: Arc<Background>,
+ rate_limit_count_rx: watch::Receiver<Option<Instant>>,
+ rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
+}
+
+#[derive(Serialize)]
+struct OpenAIEmbeddingRequest<'a> {
+ model: &'static str,
+ input: Vec<&'a str>,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingResponse {
+ data: Vec<OpenAIEmbedding>,
+ usage: OpenAIEmbeddingUsage,
+}
+
+#[derive(Debug, Deserialize)]
+struct OpenAIEmbedding {
+ embedding: Vec<f32>,
+ index: usize,
+ object: String,
+}
+
+#[derive(Deserialize)]
+struct OpenAIEmbeddingUsage {
+ prompt_tokens: usize,
+ total_tokens: usize,
+}
+
+impl OpenAIEmbeddingProvider {
+ pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+ let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
+ let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
+
+ let model = OpenAILanguageModel::load("text-embedding-ada-002");
+ let credential = Arc::new(RwLock::new(ProviderCredential::NoCredentials));
+
+ OpenAIEmbeddingProvider {
+ model,
+ credential,
+ client,
+ executor,
+ rate_limit_count_rx,
+ rate_limit_count_tx,
+ }
+ }
+
+ fn get_api_key(&self) -> Result<String> {
+ match self.credential.read().clone() {
+ ProviderCredential::Credentials { api_key } => Ok(api_key),
+ _ => Err(anyhow!("api credentials not provided")),
+ }
+ }
+
+ fn resolve_rate_limit(&self) {
+ let reset_time = *self.rate_limit_count_tx.lock().borrow();
+
+ if let Some(reset_time) = reset_time {
+ if Instant::now() >= reset_time {
+ *self.rate_limit_count_tx.lock().borrow_mut() = None
+ }
+ }
+
+ log::trace!(
+ "resolving reset time: {:?}",
+ *self.rate_limit_count_tx.lock().borrow()
+ );
+ }
+
+ fn update_reset_time(&self, reset_time: Instant) {
+ let original_time = *self.rate_limit_count_tx.lock().borrow();
+
+ let updated_time = if let Some(original_time) = original_time {
+ if reset_time < original_time {
+ Some(reset_time)
+ } else {
+ Some(original_time)
+ }
+ } else {
+ Some(reset_time)
+ };
+
+ log::trace!("updating rate limit time: {:?}", updated_time);
+
+ *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
+ }
+ async fn send_request(
+ &self,
+ api_key: &str,
+ spans: Vec<&str>,
+ request_timeout: u64,
+ ) -> Result<Response<AsyncBody>> {
+ let request = Request::post("https://api.openai.com/v1/embeddings")
+ .redirect_policy(isahc::config::RedirectPolicy::Follow)
+ .timeout(Duration::from_secs(request_timeout))
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", api_key))
+ .body(
+ serde_json::to_string(&OpenAIEmbeddingRequest {
+ input: spans.clone(),
+ model: "text-embedding-ada-002",
+ })
+ .unwrap()
+ .into(),
+ )?;
+
+ Ok(self.client.send(request).await?)
+ }
+}
+
+impl CredentialProvider for OpenAIEmbeddingProvider {
+ fn has_credentials(&self) -> bool {
+ match *self.credential.read() {
+ ProviderCredential::Credentials { .. } => true,
+ _ => false,
+ }
+ }
+ fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential {
+ let mut credential = self.credential.write();
+ match *credential {
+ ProviderCredential::Credentials { .. } => {
+ return credential.clone();
+ }
+ _ => {
+ if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ *credential = ProviderCredential::Credentials { api_key };
+ } else if let Some((_, api_key)) = cx
+ .platform()
+ .read_credentials(OPENAI_API_URL)
+ .log_err()
+ .flatten()
+ {
+ if let Some(api_key) = String::from_utf8(api_key).log_err() {
+ *credential = ProviderCredential::Credentials { api_key };
+ }
+ } else {
+ };
+ }
+ }
+
+ credential.clone()
+ }
+
+ fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {
+ match credential.clone() {
+ ProviderCredential::Credentials { api_key } => {
+ cx.platform()
+ .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
+ .log_err();
+ }
+ _ => {}
+ }
+
+ *self.credential.write() = credential;
+ }
+ fn delete_credentials(&self, cx: &AppContext) {
+ cx.platform().delete_credentials(OPENAI_API_URL).log_err();
+ *self.credential.write() = ProviderCredential::NoCredentials;
+ }
+}
+
+#[async_trait]
+impl EmbeddingProvider for OpenAIEmbeddingProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ let model: Box<dyn LanguageModel> = Box::new(self.model.clone());
+ model
+ }
+
+ fn max_tokens_per_batch(&self) -> usize {
+ 50000
+ }
+
+ fn rate_limit_expiration(&self) -> Option<Instant> {
+ *self.rate_limit_count_rx.borrow()
+ }
+
+ async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
+ const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
+ const MAX_RETRIES: usize = 4;
+
+ let api_key = self.get_api_key()?;
+
+ let mut request_number = 0;
+ let mut rate_limiting = false;
+ let mut request_timeout: u64 = 15;
+ let mut response: Response<AsyncBody>;
+ while request_number < MAX_RETRIES {
+ response = self
+ .send_request(
+ &api_key,
+ spans.iter().map(|x| &**x).collect(),
+ request_timeout,
+ )
+ .await?;
+
+ request_number += 1;
+
+ match response.status() {
+ StatusCode::REQUEST_TIMEOUT => {
+ request_timeout += 5;
+ }
+ StatusCode::OK => {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ let response: OpenAIEmbeddingResponse = serde_json::from_str(&body)?;
+
+ log::trace!(
+ "openai embedding completed. tokens: {:?}",
+ response.usage.total_tokens
+ );
+
+ // If we complete a request successfully that was previously rate_limited
+ // resolve the rate limit
+ if rate_limiting {
+ self.resolve_rate_limit()
+ }
+
+ return Ok(response
+ .data
+ .into_iter()
+ .map(|embedding| Embedding::from(embedding.embedding))
+ .collect());
+ }
+ StatusCode::TOO_MANY_REQUESTS => {
+ rate_limiting = true;
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+
+ let delay_duration = {
+ let delay = Duration::from_secs(BACKOFF_SECONDS[request_number - 1] as u64);
+ if let Some(time_to_reset) =
+ response.headers().get("x-ratelimit-reset-tokens")
+ {
+ if let Ok(time_str) = time_to_reset.to_str() {
+ parse(time_str).unwrap_or(delay)
+ } else {
+ delay
+ }
+ } else {
+ delay
+ }
+ };
+
+ // If we've previously rate limited, increment the duration but not the count
+ let reset_time = Instant::now().add(delay_duration);
+ self.update_reset_time(reset_time);
+
+ log::trace!(
+ "openai rate limiting: waiting {:?} until lifted",
+ &delay_duration
+ );
+
+ self.executor.timer(delay_duration).await;
+ }
+ _ => {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ return Err(anyhow!(
+ "open ai bad request: {:?} {:?}",
+ &response.status(),
+ body
+ ));
+ }
+ }
+ }
+ Err(anyhow!("openai max retries"))
+ }
+}
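One way a caller might respect `max_tokens_per_batch()` when embedding many spans, as a sketch under the assumption that token counts come from `base_model()`; this chunking helper is not part of the commit:

use ai::embedding::{Embedding, EmbeddingProvider};

async fn embed_all(
    provider: &dyn EmbeddingProvider,
    spans: Vec<String>,
) -> anyhow::Result<Vec<Embedding>> {
    let model = provider.base_model();
    let limit = provider.max_tokens_per_batch();
    let (mut embeddings, mut batch, mut batch_tokens) = (Vec::new(), Vec::new(), 0);
    for span in spans {
        let tokens = model.count_tokens(&span)?;
        // Flush the current batch before it would exceed the provider's limit.
        if !batch.is_empty() && batch_tokens + tokens > limit {
            embeddings.extend(provider.embed_batch(std::mem::take(&mut batch)).await?);
            batch_tokens = 0;
        }
        batch_tokens += tokens;
        batch.push(span);
    }
    if !batch.is_empty() {
        embeddings.extend(provider.embed_batch(batch).await?);
    }
    Ok(embeddings)
}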
crates/ai/src/providers/open_ai/mod.rs
@@ -0,0 +1,9 @@
+pub mod completion;
+pub mod embedding;
+pub mod model;
+
+pub use completion::*;
+pub use embedding::*;
+pub use model::OpenAILanguageModel;
+
+pub const OPENAI_API_URL: &'static str = "https://api.openai.com/v1";
crates/ai/src/providers/open_ai/model.rs
@@ -0,0 +1,57 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+use crate::models::{LanguageModel, TruncationDirection};
+
+#[derive(Clone)]
+pub struct OpenAILanguageModel {
+ name: String,
+ bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+ pub fn load(model_name: &str) -> Self {
+ let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+ OpenAILanguageModel {
+ name: model_name.to_string(),
+ bpe,
+ }
+ }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+ fn name(&self) -> String {
+ self.name.clone()
+ }
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+ if let Some(bpe) = &self.bpe {
+ anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String> {
+ if let Some(bpe) = &self.bpe {
+ let tokens = bpe.encode_with_special_tokens(content);
+ if tokens.len() > length {
+ match direction {
+ TruncationDirection::End => bpe.decode(tokens[..length].to_vec()),
+ TruncationDirection::Start => bpe.decode(tokens[length..].to_vec()),
+ }
+ } else {
+ bpe.decode(tokens)
+ }
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn capacity(&self) -> anyhow::Result<usize> {
+ anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+ }
+}
crates/ai/src/providers/open_ai/new.rs
@@ -0,0 +1,11 @@
+pub trait LanguageModel {
+ fn name(&self) -> String;
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String>;
+ fn capacity(&self) -> anyhow::Result<usize>;
+}
crates/ai/src/test.rs
@@ -0,0 +1,191 @@
+use std::{
+ sync::atomic::{self, AtomicUsize, Ordering},
+ time::Instant,
+};
+
+use async_trait::async_trait;
+use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::AppContext;
+use parking_lot::Mutex;
+
+use crate::{
+ auth::{CredentialProvider, ProviderCredential},
+ completion::{CompletionProvider, CompletionRequest},
+ embedding::{Embedding, EmbeddingProvider},
+ models::{LanguageModel, TruncationDirection},
+};
+
+#[derive(Clone)]
+pub struct FakeLanguageModel {
+ pub capacity: usize,
+}
+
+impl LanguageModel for FakeLanguageModel {
+ fn name(&self) -> String {
+ "dummy".to_string()
+ }
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+ anyhow::Ok(content.chars().collect::<Vec<char>>().len())
+ }
+ fn truncate(
+ &self,
+ content: &str,
+ length: usize,
+ direction: TruncationDirection,
+ ) -> anyhow::Result<String> {
+ println!("TRYING TO TRUNCATE: {:?}", length.clone());
+
+ if length > self.count_tokens(content)? {
+ println!("NOT TRUNCATING");
+ return anyhow::Ok(content.to_string());
+ }
+
+ anyhow::Ok(match direction {
+ TruncationDirection::End => content.chars().collect::<Vec<char>>()[..length]
+ .into_iter()
+ .collect::<String>(),
+ TruncationDirection::Start => content.chars().collect::<Vec<char>>()[length..]
+ .into_iter()
+ .collect::<String>(),
+ })
+ }
+ fn capacity(&self) -> anyhow::Result<usize> {
+ anyhow::Ok(self.capacity)
+ }
+}
+
+pub struct FakeEmbeddingProvider {
+ pub embedding_count: AtomicUsize,
+}
+
+impl Clone for FakeEmbeddingProvider {
+ fn clone(&self) -> Self {
+ FakeEmbeddingProvider {
+ embedding_count: AtomicUsize::new(self.embedding_count.load(Ordering::SeqCst)),
+ }
+ }
+}
+
+impl Default for FakeEmbeddingProvider {
+ fn default() -> Self {
+ FakeEmbeddingProvider {
+ embedding_count: AtomicUsize::default(),
+ }
+ }
+}
+
+impl FakeEmbeddingProvider {
+ pub fn embedding_count(&self) -> usize {
+ self.embedding_count.load(atomic::Ordering::SeqCst)
+ }
+
+ pub fn embed_sync(&self, span: &str) -> Embedding {
+ let mut result = vec![1.0; 26];
+ for letter in span.chars() {
+ let letter = letter.to_ascii_lowercase();
+ if letter as u32 >= 'a' as u32 {
+ let ix = (letter as u32) - ('a' as u32);
+ if ix < 26 {
+ result[ix as usize] += 1.0;
+ }
+ }
+ }
+
+ let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
+ for x in &mut result {
+ *x /= norm;
+ }
+
+ result.into()
+ }
+}
+
+impl CredentialProvider for FakeEmbeddingProvider {
+ fn has_credentials(&self) -> bool {
+ true
+ }
+ fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+ ProviderCredential::NotNeeded
+ }
+ fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
+ fn delete_credentials(&self, _cx: &AppContext) {}
+}
+
+#[async_trait]
+impl EmbeddingProvider for FakeEmbeddingProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ Box::new(FakeLanguageModel { capacity: 1000 })
+ }
+ fn max_tokens_per_batch(&self) -> usize {
+ 1000
+ }
+
+ fn rate_limit_expiration(&self) -> Option<Instant> {
+ None
+ }
+
+ async fn embed_batch(&self, spans: Vec<String>) -> anyhow::Result<Vec<Embedding>> {
+ self.embedding_count
+ .fetch_add(spans.len(), atomic::Ordering::SeqCst);
+
+ anyhow::Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
+ }
+}
+
+pub struct FakeCompletionProvider {
+ last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
+}
+
+impl Clone for FakeCompletionProvider {
+ fn clone(&self) -> Self {
+ Self {
+ last_completion_tx: Mutex::new(None),
+ }
+ }
+}
+
+impl FakeCompletionProvider {
+ pub fn new() -> Self {
+ Self {
+ last_completion_tx: Mutex::new(None),
+ }
+ }
+
+ pub fn send_completion(&self, completion: impl Into<String>) {
+ let mut tx = self.last_completion_tx.lock();
+ tx.as_mut().unwrap().try_send(completion.into()).unwrap();
+ }
+
+ pub fn finish_completion(&self) {
+ self.last_completion_tx.lock().take().unwrap();
+ }
+}
+
+impl CredentialProvider for FakeCompletionProvider {
+ fn has_credentials(&self) -> bool {
+ true
+ }
+ fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential {
+ ProviderCredential::NotNeeded
+ }
+ fn save_credentials(&self, _cx: &AppContext, _credential: ProviderCredential) {}
+ fn delete_credentials(&self, _cx: &AppContext) {}
+}
+
+impl CompletionProvider for FakeCompletionProvider {
+ fn base_model(&self) -> Box<dyn LanguageModel> {
+ let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
+ model
+ }
+ fn complete(
+ &self,
+ _prompt: Box<dyn CompletionRequest>,
+ ) -> BoxFuture<'static, anyhow::Result<BoxStream<'static, anyhow::Result<String>>>> {
+ let (tx, rx) = mpsc::channel(1);
+ *self.last_completion_tx.lock() = Some(tx);
+ async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
+ }
+ fn box_clone(&self) -> Box<dyn CompletionProvider> {
+ Box::new((*self).clone())
+ }
+}
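A sketch of how these fakes are meant to drive tests, assuming the `test-support` feature is enabled and using a simple request stub (the stub and test body are illustrative):

use ai::completion::{CompletionProvider, CompletionRequest};
use ai::test::FakeCompletionProvider;
use futures::StreamExt;

struct StubRequest;

impl CompletionRequest for StubRequest {
    fn data(&self) -> serde_json::Result<String> {
        Ok("{}".to_string())
    }
}

fn main() {
    let provider = FakeCompletionProvider::new();
    let completion = provider.complete(Box::new(StubRequest));
    futures::executor::block_on(async {
        let mut chunks = completion.await.unwrap();
        provider.send_completion("Hello");
        assert_eq!(chunks.next().await.unwrap().unwrap(), "Hello");
        provider.finish_completion(); // drops the sender, ending the stream
        assert!(chunks.next().await.is_none());
    });
}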
crates/assistant/Cargo.toml
@@ -45,6 +45,7 @@ tiktoken-rs = "0.5"
[dev-dependencies]
editor = { path = "../editor", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
+ai = { path = "../ai", features = ["test-support"]}
ctor.workspace = true
env_logger.workspace = true
crates/assistant/src/assistant.rs
@@ -4,7 +4,7 @@ mod codegen;
mod prompts;
mod streaming_diff;
-use ai::completion::Role;
+use ai::providers::open_ai::Role;
use anyhow::Result;
pub use assistant_panel::AssistantPanel;
use assistant_settings::OpenAIModel;
crates/assistant/src/assistant_panel.rs
@@ -5,12 +5,14 @@ use crate::{
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
SavedMessage,
};
+
use ai::{
- completion::{
- stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
- },
- templates::repository_context::PromptCodeSnippet,
+ auth::ProviderCredential,
+ completion::{CompletionProvider, CompletionRequest},
+ providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage},
};
+
+use ai::prompts::repository_context::PromptCodeSnippet;
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
use client::{telemetry::AssistantKind, ClickhouseEvent, TelemetrySettings};
@@ -43,8 +45,8 @@ use search::BufferSearchBar;
use semantic_index::{SemanticIndex, SemanticIndexStatus};
use settings::SettingsStore;
use std::{
- cell::{Cell, RefCell},
- cmp, env,
+ cell::Cell,
+ cmp,
fmt::Write,
iter,
ops::Range,
@@ -97,8 +99,8 @@ pub fn init(cx: &mut AppContext) {
cx.capture_action(ConversationEditor::copy);
cx.add_action(ConversationEditor::split);
cx.capture_action(ConversationEditor::cycle_message_role);
- cx.add_action(AssistantPanel::save_api_key);
- cx.add_action(AssistantPanel::reset_api_key);
+ cx.add_action(AssistantPanel::save_credentials);
+ cx.add_action(AssistantPanel::reset_credentials);
cx.add_action(AssistantPanel::toggle_zoom);
cx.add_action(AssistantPanel::deploy);
cx.add_action(AssistantPanel::select_next_match);
@@ -140,9 +142,8 @@ pub struct AssistantPanel {
zoomed: bool,
has_focus: bool,
toolbar: ViewHandle<Toolbar>,
- api_key: Rc<RefCell<Option<String>>>,
+ completion_provider: Box<dyn CompletionProvider>,
api_key_editor: Option<ViewHandle<Editor>>,
- has_read_credentials: bool,
languages: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
subscriptions: Vec<Subscription>,
@@ -202,6 +203,11 @@ impl AssistantPanel {
});
let semantic_index = SemanticIndex::global(cx);
+ // Currently defaulting to GPT-4; allow this to be set via config.
+ let completion_provider = Box::new(OpenAICompletionProvider::new(
+ "gpt-4",
+ cx.background().clone(),
+ ));
let mut this = Self {
workspace: workspace_handle,
@@ -213,9 +219,8 @@ impl AssistantPanel {
zoomed: false,
has_focus: false,
toolbar,
- api_key: Rc::new(RefCell::new(None)),
+ completion_provider,
api_key_editor: None,
- has_read_credentials: false,
languages: workspace.app_state().languages.clone(),
fs: workspace.app_state().fs.clone(),
width: None,
@@ -254,10 +259,7 @@ impl AssistantPanel {
cx: &mut ViewContext<Workspace>,
) {
let this = if let Some(this) = workspace.panel::<AssistantPanel>(cx) {
- if this
- .update(cx, |assistant, cx| assistant.load_api_key(cx))
- .is_some()
- {
+ if this.update(cx, |assistant, _| assistant.has_credentials()) {
this
} else {
workspace.focus_panel::<AssistantPanel>(cx);
@@ -289,12 +291,6 @@ impl AssistantPanel {
cx: &mut ViewContext<Self>,
project: &ModelHandle<Project>,
) {
- let api_key = if let Some(api_key) = self.api_key.borrow().clone() {
- api_key
- } else {
- return;
- };
-
let selection = editor.read(cx).selections.newest_anchor().clone();
if selection.start.excerpt_id != selection.end.excerpt_id {
return;
@@ -325,10 +321,13 @@ impl AssistantPanel {
let inline_assist_id = post_inc(&mut self.next_inline_assist_id);
let provider = Arc::new(OpenAICompletionProvider::new(
- api_key,
+ "gpt-4",
cx.background().clone(),
));
+ // Retrieving credentials authenticates the provider:
+ // provider.retrieve_credentials(cx);
+
let codegen = cx.add_model(|cx| {
Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx)
});
@@ -745,13 +744,14 @@ impl AssistantPanel {
content: prompt,
});
- let request = OpenAIRequest {
+ let request = Box::new(OpenAIRequest {
model: model.full_name().into(),
messages,
stream: true,
stop: vec!["|END|>".to_string()],
temperature,
- };
+ });
+
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
anyhow::Ok(())
})
@@ -811,7 +811,7 @@ impl AssistantPanel {
fn new_conversation(&mut self, cx: &mut ViewContext<Self>) -> ViewHandle<ConversationEditor> {
let editor = cx.add_view(|cx| {
ConversationEditor::new(
- self.api_key.clone(),
+ self.completion_provider.clone(),
self.languages.clone(),
self.fs.clone(),
self.workspace.clone(),
@@ -870,17 +870,19 @@ impl AssistantPanel {
}
}
- fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
+ fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
if let Some(api_key) = self
.api_key_editor
.as_ref()
.map(|editor| editor.read(cx).text(cx))
{
if !api_key.is_empty() {
- cx.platform()
- .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
- .log_err();
- *self.api_key.borrow_mut() = Some(api_key);
+ let credential = ProviderCredential::Credentials {
+ api_key: api_key.clone(),
+ };
+
+ self.completion_provider.save_credentials(cx, credential);
+
self.api_key_editor.take();
cx.focus_self();
cx.notify();
@@ -890,9 +892,8 @@ impl AssistantPanel {
}
}
- fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
- cx.platform().delete_credentials(OPENAI_API_URL).log_err();
- self.api_key.take();
+ fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext<Self>) {
+ self.completion_provider.delete_credentials(cx);
self.api_key_editor = Some(build_api_key_editor(cx));
cx.focus_self();
cx.notify();
@@ -1151,13 +1152,12 @@ impl AssistantPanel {
let fs = self.fs.clone();
let workspace = self.workspace.clone();
- let api_key = self.api_key.clone();
let languages = self.languages.clone();
cx.spawn(|this, mut cx| async move {
let saved_conversation = fs.load(&path).await?;
let saved_conversation = serde_json::from_str(&saved_conversation)?;
let conversation = cx.add_model(|cx| {
- Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx)
+ Conversation::deserialize(saved_conversation, path.clone(), languages, cx)
});
this.update(&mut cx, |this, cx| {
// If, by the time we've loaded the conversation, the user has already opened
@@ -1181,30 +1181,12 @@ impl AssistantPanel {
.position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path))
}
- fn load_api_key(&mut self, cx: &mut ViewContext<Self>) -> Option<String> {
- if self.api_key.borrow().is_none() && !self.has_read_credentials {
- self.has_read_credentials = true;
- let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
- Some(api_key)
- } else if let Some((_, api_key)) = cx
- .platform()
- .read_credentials(OPENAI_API_URL)
- .log_err()
- .flatten()
- {
- String::from_utf8(api_key).log_err()
- } else {
- None
- };
- if let Some(api_key) = api_key {
- *self.api_key.borrow_mut() = Some(api_key);
- } else if self.api_key_editor.is_none() {
- self.api_key_editor = Some(build_api_key_editor(cx));
- cx.notify();
- }
- }
+ fn has_credentials(&mut self) -> bool {
+ self.completion_provider.has_credentials()
+ }
- self.api_key.borrow().clone()
+ fn load_credentials(&mut self, cx: &mut ViewContext<Self>) {
+ self.completion_provider.retrieve_credentials(cx);
}
}
@@ -1389,7 +1371,7 @@ impl Panel for AssistantPanel {
fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
if active {
- self.load_api_key(cx);
+ self.load_credentials(cx);
if self.editors.is_empty() {
self.new_conversation(cx);
@@ -1454,10 +1436,10 @@ struct Conversation {
token_count: Option<usize>,
max_token_count: usize,
pending_token_count: Task<Option<()>>,
- api_key: Rc<RefCell<Option<String>>>,
pending_save: Task<Result<()>>,
path: Option<PathBuf>,
_subscriptions: Vec<Subscription>,
+ completion_provider: Box<dyn CompletionProvider>,
}
impl Entity for Conversation {
@@ -1466,9 +1448,9 @@ impl Entity for Conversation {
impl Conversation {
fn new(
- api_key: Rc<RefCell<Option<String>>>,
language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>,
+ completion_provider: Box<dyn CompletionProvider>,
) -> Self {
let markdown = language_registry.language_for_name("Markdown");
let buffer = cx.add_model(|cx| {
@@ -1507,8 +1489,8 @@ impl Conversation {
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: None,
- api_key,
buffer,
+ completion_provider,
};
let message = MessageAnchor {
id: MessageId(post_inc(&mut this.next_message_id.0)),
@@ -1554,7 +1536,6 @@ impl Conversation {
fn deserialize(
saved_conversation: SavedConversation,
path: PathBuf,
- api_key: Rc<RefCell<Option<String>>>,
language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>,
) -> Self {
@@ -1563,6 +1544,10 @@ impl Conversation {
None => Some(Uuid::new_v4().to_string()),
};
let model = saved_conversation.model;
+ let completion_provider: Box<dyn CompletionProvider> = Box::new(
+ OpenAICompletionProvider::new(model.full_name(), cx.background().clone()),
+ );
+ completion_provider.retrieve_credentials(cx);
let markdown = language_registry.language_for_name("Markdown");
let mut message_anchors = Vec::new();
let mut next_message_id = MessageId(0);
@@ -1609,8 +1594,8 @@ impl Conversation {
_subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
pending_save: Task::ready(Ok(())),
path: Some(path),
- api_key,
buffer,
+ completion_provider,
};
this.count_remaining_tokens(cx);
this
@@ -1731,11 +1716,11 @@ impl Conversation {
}
if should_assist {
- let Some(api_key) = self.api_key.borrow().clone() else {
+ if !self.completion_provider.has_credentials() {
return Default::default();
- };
+ }
- let request = OpenAIRequest {
+ let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
model: self.model.full_name().to_string(),
messages: self
.messages(cx)
@@ -1745,9 +1730,9 @@ impl Conversation {
stream: true,
stop: vec![],
temperature: 1.0,
- };
+ });
- let stream = stream_completion(api_key, cx.background().clone(), request);
+ let stream = self.completion_provider.complete(request);
let assistant_message = self
.insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
.unwrap();
@@ -1765,33 +1750,28 @@ impl Conversation {
let mut messages = stream.await?;
while let Some(message) = messages.next().await {
- let mut message = message?;
- if let Some(choice) = message.choices.pop() {
- this.upgrade(&cx)
- .ok_or_else(|| anyhow!("conversation was dropped"))?
- .update(&mut cx, |this, cx| {
- let text: Arc<str> = choice.delta.content?.into();
- let message_ix =
- this.message_anchors.iter().position(|message| {
- message.id == assistant_message_id
- })?;
- this.buffer.update(cx, |buffer, cx| {
- let offset = this.message_anchors[message_ix + 1..]
- .iter()
- .find(|message| message.start.is_valid(buffer))
- .map_or(buffer.len(), |message| {
- message
- .start
- .to_offset(buffer)
- .saturating_sub(1)
- });
- buffer.edit([(offset..offset, text)], None, cx);
- });
- cx.emit(ConversationEvent::StreamedCompletion);
-
- Some(())
+ let text = message?;
+
+ this.upgrade(&cx)
+ .ok_or_else(|| anyhow!("conversation was dropped"))?
+ .update(&mut cx, |this, cx| {
+ let message_ix = this
+ .message_anchors
+ .iter()
+ .position(|message| message.id == assistant_message_id)?;
+ this.buffer.update(cx, |buffer, cx| {
+ let offset = this.message_anchors[message_ix + 1..]
+ .iter()
+ .find(|message| message.start.is_valid(buffer))
+ .map_or(buffer.len(), |message| {
+ message.start.to_offset(buffer).saturating_sub(1)
+ });
+ buffer.edit([(offset..offset, text)], None, cx);
});
- }
+ cx.emit(ConversationEvent::StreamedCompletion);
+
+ Some(())
+ });
smol::future::yield_now().await;
}
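With the provider yielding text deltas directly, all OpenAI-specific choices/delta unwrapping drops out of the conversation loop. A minimal consumer of the new contract, assuming the trait sketched above, reduces to:

    use anyhow::Result;
    use futures::StreamExt;

    // Hypothetical consumer: each stream item is already a text chunk,
    // so appending the chunks reconstructs the full completion.
    async fn collect_completion(
        provider: &dyn CompletionProvider,
        request: Box<dyn CompletionRequest>,
    ) -> Result<String> {
        let mut chunks = provider.complete(request).await?;
        let mut text = String::new();
        while let Some(chunk) = chunks.next().await {
            text.push_str(&chunk?);
        }
        Ok(text)
    }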
@@ -2013,57 +1993,54 @@ impl Conversation {
fn summarize(&mut self, cx: &mut ModelContext<Self>) {
if self.message_anchors.len() >= 2 && self.summary.is_none() {
- let api_key = self.api_key.borrow().clone();
- if let Some(api_key) = api_key {
- let messages = self
- .messages(cx)
- .take(2)
- .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
- .chain(Some(RequestMessage {
- role: Role::User,
- content:
- "Summarize the conversation into a short title without punctuation"
- .into(),
- }));
- let request = OpenAIRequest {
- model: self.model.full_name().to_string(),
- messages: messages.collect(),
- stream: true,
- stop: vec![],
- temperature: 1.0,
- };
+ if !self.completion_provider.has_credentials() {
+ return;
+ }
- let stream = stream_completion(api_key, cx.background().clone(), request);
- self.pending_summary = cx.spawn(|this, mut cx| {
- async move {
- let mut messages = stream.await?;
+ let messages = self
+ .messages(cx)
+ .take(2)
+ .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
+ .chain(Some(RequestMessage {
+ role: Role::User,
+ content: "Summarize the conversation into a short title without punctuation"
+ .into(),
+ }));
+ let request: Box<dyn CompletionRequest> = Box::new(OpenAIRequest {
+ model: self.model.full_name().to_string(),
+ messages: messages.collect(),
+ stream: true,
+ stop: vec![],
+ temperature: 1.0,
+ });
- while let Some(message) = messages.next().await {
- let mut message = message?;
- if let Some(choice) = message.choices.pop() {
- let text = choice.delta.content.unwrap_or_default();
- this.update(&mut cx, |this, cx| {
- this.summary
- .get_or_insert(Default::default())
- .text
- .push_str(&text);
- cx.emit(ConversationEvent::SummaryChanged);
- });
- }
- }
+ let stream = self.completion_provider.complete(request);
+ self.pending_summary = cx.spawn(|this, mut cx| {
+ async move {
+ let mut messages = stream.await?;
+ while let Some(message) = messages.next().await {
+ let text = message?;
this.update(&mut cx, |this, cx| {
- if let Some(summary) = this.summary.as_mut() {
- summary.done = true;
- cx.emit(ConversationEvent::SummaryChanged);
- }
+ this.summary
+ .get_or_insert(Default::default())
+ .text
+ .push_str(&text);
+ cx.emit(ConversationEvent::SummaryChanged);
});
-
- anyhow::Ok(())
}
- .log_err()
- });
- }
+
+ this.update(&mut cx, |this, cx| {
+ if let Some(summary) = this.summary.as_mut() {
+ summary.done = true;
+ cx.emit(ConversationEvent::SummaryChanged);
+ }
+ });
+
+ anyhow::Ok(())
+ }
+ .log_err()
+ });
}
}
@@ -2224,13 +2201,14 @@ struct ConversationEditor {
impl ConversationEditor {
fn new(
- api_key: Rc<RefCell<Option<String>>>,
+ completion_provider: Box<dyn CompletionProvider>,
language_registry: Arc<LanguageRegistry>,
fs: Arc<dyn Fs>,
workspace: WeakViewHandle<Workspace>,
cx: &mut ViewContext<Self>,
) -> Self {
- let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx));
+ let conversation =
+ cx.add_model(|cx| Conversation::new(language_registry, cx, completion_provider));
Self::for_conversation(conversation, fs, workspace, cx)
}
@@ -3419,6 +3397,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
mod tests {
use super::*;
use crate::MessageId;
+ use ai::test::FakeCompletionProvider;
use gpui::AppContext;
#[gpui::test]
@@ -3426,7 +3405,9 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
- let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+
+ let completion_provider = Box::new(FakeCompletionProvider::new());
+ let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3554,7 +3535,9 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
- let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+ let completion_provider = Box::new(FakeCompletionProvider::new());
+
+ let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3650,7 +3633,8 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
- let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx));
+ let completion_provider = Box::new(FakeCompletionProvider::new());
+ let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_1 = conversation.read(cx).message_anchors[0].clone();
@@ -3732,8 +3716,9 @@ mod tests {
cx.set_global(SettingsStore::test(cx));
init(cx);
let registry = Arc::new(LanguageRegistry::test());
+ let completion_provider = Box::new(FakeCompletionProvider::new());
let conversation =
- cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx));
+ cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider));
let buffer = conversation.read(cx).buffer.clone();
let message_0 = conversation.read(cx).message_anchors[0].id;
let message_1 = conversation.update(cx, |conversation, cx| {
@@ -3770,7 +3755,6 @@ mod tests {
Conversation::deserialize(
conversation.read(cx).serialize(cx),
Default::default(),
- Default::default(),
registry.clone(),
cx,
)
crates/assistant/src/codegen.rs 🔗
@@ -1,5 +1,5 @@
use crate::streaming_diff::{Hunk, StreamingDiff};
-use ai::completion::{CompletionProvider, OpenAIRequest};
+use ai::completion::{CompletionProvider, CompletionRequest};
use anyhow::Result;
use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
@@ -96,7 +96,7 @@ impl Codegen {
self.error.as_ref()
}
- pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
+ pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
let range = self.range();
let snapshot = self.snapshot.clone();
let selected_text = snapshot
@@ -336,17 +336,25 @@ fn strip_markdown_codeblock(
#[cfg(test)]
mod tests {
use super::*;
- use futures::{
- future::BoxFuture,
- stream::{self, BoxStream},
- };
+ use ai::test::FakeCompletionProvider;
+ use futures::stream;
use gpui::{executor::Deterministic, TestAppContext};
use indoc::indoc;
use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
- use parking_lot::Mutex;
use rand::prelude::*;
+ use serde::Serialize;
use settings::SettingsStore;
- use smol::future::FutureExt;
+
+ #[derive(Serialize)]
+ pub struct DummyCompletionRequest {
+ pub name: String,
+ }
+
+ impl CompletionRequest for DummyCompletionRequest {
+ fn data(&self) -> serde_json::Result<String> {
+ serde_json::to_string(self)
+ }
+ }
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(
@@ -372,7 +380,7 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
});
- let provider = Arc::new(TestCompletionProvider::new());
+ let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
@@ -381,7 +389,11 @@ mod tests {
cx,
)
});
- codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+ let request = Box::new(DummyCompletionRequest {
+ name: "test".to_string(),
+ });
+ codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
" let mut x = 0;\n",
@@ -434,7 +446,7 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 6))
});
- let provider = Arc::new(TestCompletionProvider::new());
+ let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
@@ -443,7 +455,11 @@ mod tests {
cx,
)
});
- codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+ let request = Box::new(DummyCompletionRequest {
+ name: "test".to_string(),
+ });
+ codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
"t mut x = 0;\n",
@@ -496,7 +512,7 @@ mod tests {
let snapshot = buffer.snapshot(cx);
snapshot.anchor_before(Point::new(1, 2))
});
- let provider = Arc::new(TestCompletionProvider::new());
+ let provider = Arc::new(FakeCompletionProvider::new());
let codegen = cx.add_model(|cx| {
Codegen::new(
buffer.clone(),
@@ -505,7 +521,11 @@ mod tests {
cx,
)
});
- codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
+
+ let request = Box::new(DummyCompletionRequest {
+ name: "test".to_string(),
+ });
+ codegen.update(cx, |codegen, cx| codegen.start(request, cx));
let mut new_text = concat!(
"let mut x = 0;\n",
@@ -593,38 +613,6 @@ mod tests {
}
}
- struct TestCompletionProvider {
- last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
- }
-
- impl TestCompletionProvider {
- fn new() -> Self {
- Self {
- last_completion_tx: Mutex::new(None),
- }
- }
-
- fn send_completion(&self, completion: impl Into<String>) {
- let mut tx = self.last_completion_tx.lock();
- tx.as_mut().unwrap().try_send(completion.into()).unwrap();
- }
-
- fn finish_completion(&self) {
- self.last_completion_tx.lock().take().unwrap();
- }
- }
-
- impl CompletionProvider for TestCompletionProvider {
- fn complete(
- &self,
- _prompt: OpenAIRequest,
- ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
- let (tx, rx) = mpsc::channel(1);
- *self.last_completion_tx.lock() = Some(tx);
- async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
- }
- }
-
fn rust_lang() -> Language {
Language::new(
LanguageConfig {
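The deleted TestCompletionProvider now lives in crates/ai/src/test.rs as FakeCompletionProvider. Presumably it mirrors the old fake while implementing the new trait, along these lines (CredentialProvider methods elided; the real definition may differ):

    use anyhow::Result;
    use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
    use parking_lot::Mutex;

    pub struct FakeCompletionProvider {
        last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
    }

    impl FakeCompletionProvider {
        pub fn new() -> Self {
            Self {
                last_completion_tx: Mutex::new(None),
            }
        }

        // Push one text chunk into the pending completion stream.
        pub fn send_completion(&self, completion: impl Into<String>) {
            let mut tx = self.last_completion_tx.lock();
            tx.as_mut().unwrap().try_send(completion.into()).unwrap();
        }

        // Drop the sender so the stream terminates.
        pub fn finish_completion(&self) {
            self.last_completion_tx.lock().take().unwrap();
        }
    }

    impl CompletionProvider for FakeCompletionProvider {
        fn complete(
            &self,
            _prompt: Box<dyn CompletionRequest>,
        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
            let (tx, rx) = mpsc::channel(1);
            *self.last_completion_tx.lock() = Some(tx);
            async move { Ok(rx.map(Ok).boxed()) }.boxed()
        }
    }

Since the fake ignores the request payload, any serde-serializable type, like the DummyCompletionRequest above, is enough for tests; only data() needs to succeed.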
crates/assistant/src/prompts.rs 🔗
@@ -1,9 +1,10 @@
-use ai::models::{LanguageModel, OpenAILanguageModel};
-use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
-use ai::templates::file_context::FileContext;
-use ai::templates::generate::GenerateInlineContent;
-use ai::templates::preamble::EngineerPreamble;
-use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::models::LanguageModel;
+use ai::prompts::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::prompts::file_context::FileContext;
+use ai::prompts::generate::GenerateInlineContent;
+use ai::prompts::preamble::EngineerPreamble;
+use ai::prompts::repository_context::{PromptCodeSnippet, RepositoryContext};
+use ai::providers::open_ai::OpenAILanguageModel;
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use std::cmp::{self, Reverse};
use std::ops::Range;
crates/semantic_index/Cargo.toml 🔗
@@ -42,6 +42,7 @@ sha1 = "0.10.5"
ndarray = { version = "0.15.0" }
[dev-dependencies]
+ai = { path = "../ai", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] }
language = { path = "../language", features = ["test-support"] }
crates/semantic_index/src/embedding_queue.rs 🔗
@@ -41,7 +41,6 @@ pub struct EmbeddingQueue {
pending_batch_token_count: usize,
finished_files_tx: channel::Sender<FileToEmbed>,
finished_files_rx: channel::Receiver<FileToEmbed>,
- api_key: Option<String>,
}
#[derive(Clone)]
@@ -51,11 +50,7 @@ pub struct FileFragmentToEmbed {
}
impl EmbeddingQueue {
- pub fn new(
- embedding_provider: Arc<dyn EmbeddingProvider>,
- executor: Arc<Background>,
- api_key: Option<String>,
- ) -> Self {
+ pub fn new(embedding_provider: Arc<dyn EmbeddingProvider>, executor: Arc<Background>) -> Self {
let (finished_files_tx, finished_files_rx) = channel::unbounded();
Self {
embedding_provider,
@@ -64,14 +59,9 @@ impl EmbeddingQueue {
pending_batch_token_count: 0,
finished_files_tx,
finished_files_rx,
- api_key,
}
}
- pub fn set_api_key(&mut self, api_key: Option<String>) {
- self.api_key = api_key
- }
-
pub fn push(&mut self, file: FileToEmbed) {
if file.spans.is_empty() {
self.finished_files_tx.try_send(file).unwrap();
@@ -118,7 +108,6 @@ impl EmbeddingQueue {
let finished_files_tx = self.finished_files_tx.clone();
let embedding_provider = self.embedding_provider.clone();
- let api_key = self.api_key.clone();
self.executor
.spawn(async move {
@@ -143,7 +132,7 @@ impl EmbeddingQueue {
return;
};
- match embedding_provider.embed_batch(spans, api_key).await {
+ match embedding_provider.embed_batch(spans).await {
Ok(embeddings) => {
let mut embeddings = embeddings.into_iter();
for fragment in batch {
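For reference, the reworked EmbeddingProvider surface as inferred from the call sites in this change; the actual trait lives in crates/ai/src/embedding.rs and may differ:

    use anyhow::Result;
    use async_trait::async_trait;

    // Credentials now come from the CredentialProvider supertrait, so
    // embed_batch loses its api_key parameter, and base_model() exposes
    // truncation and token counting (used in parsing.rs below).
    #[async_trait]
    pub trait EmbeddingProvider: CredentialProvider {
        fn base_model(&self) -> Box<dyn LanguageModel>;
        fn max_tokens_per_batch(&self) -> usize;
        async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
    }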
crates/semantic_index/src/parsing.rs 🔗
@@ -1,4 +1,7 @@
-use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::{
+ embedding::{Embedding, EmbeddingProvider},
+ models::TruncationDirection,
+};
use anyhow::{anyhow, Result};
use language::{Grammar, Language};
use rusqlite::{
@@ -108,7 +111,14 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref())
.replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str());
- let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
+ let model = self.embedding_provider.base_model();
+ let document_span = model.truncate(
+ &document_span,
+ model.capacity()?,
+ TruncationDirection::End,
+ )?;
+ let token_count = model.count_tokens(&document_span)?;
+
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
@@ -131,7 +141,15 @@ impl CodeContextRetriever {
)
.replace("<item>", &content);
let digest = SpanDigest::from(document_span.as_str());
- let (document_span, token_count) = self.embedding_provider.truncate(&document_span);
+
+ let model = self.embedding_provider.base_model();
+ let document_span = model.truncate(
+ &document_span,
+ model.capacity()?,
+ TruncationDirection::End,
+ )?;
+ let token_count = model.count_tokens(&document_span)?;
+
Ok(vec![Span {
range: 0..content.len(),
content: document_span,
@@ -222,8 +240,13 @@ impl CodeContextRetriever {
.replace("<language>", language_name.as_ref())
.replace("item", &span.content);
- let (document_content, token_count) =
- self.embedding_provider.truncate(&document_content);
+ let model = self.embedding_provider.base_model();
+ let document_content = model.truncate(
+ &document_content,
+ model.capacity()?,
+ TruncationDirection::End,
+ )?;
+ let token_count = model.count_tokens(&document_content)?;
span.content = document_content;
span.token_count = token_count;
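The truncate-then-count sequence now appears three times in this file; a hypothetical helper, not part of this change, could fold it into one call:

    use ai::models::{LanguageModel, TruncationDirection};
    use anyhow::Result;

    // Hypothetical: truncate a span to the model's capacity and report
    // the resulting token count in one place.
    fn truncate_to_capacity(model: &dyn LanguageModel, text: &str) -> Result<(String, usize)> {
        let truncated = model.truncate(text, model.capacity()?, TruncationDirection::End)?;
        let token_count = model.count_tokens(&truncated)?;
        Ok((truncated, token_count))
    }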
crates/semantic_index/src/semantic_index.rs 🔗
@@ -7,7 +7,8 @@ pub mod semantic_index_settings;
mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use ai::embedding::{Embedding, EmbeddingProvider};
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
@@ -88,7 +89,7 @@ pub fn init(
let semantic_index = SemanticIndex::new(
fs,
db_file_path,
- Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+ Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
language_registry,
cx.clone(),
)
@@ -123,8 +124,6 @@ pub struct SemanticIndex {
_embedding_task: Task<()>,
_parsing_files_tasks: Vec<Task<()>>,
projects: HashMap<WeakModelHandle<Project>, ProjectState>,
- api_key: Option<String>,
- embedding_queue: Arc<Mutex<EmbeddingQueue>>,
}
struct ProjectState {
@@ -278,18 +277,18 @@ impl SemanticIndex {
}
}
- pub fn authenticate(&mut self, cx: &AppContext) {
- if self.api_key.is_none() {
- self.api_key = self.embedding_provider.retrieve_credentials(cx);
-
- self.embedding_queue
- .lock()
- .set_api_key(self.api_key.clone());
+ pub fn authenticate(&mut self, cx: &AppContext) -> bool {
+ if !self.embedding_provider.has_credentials() {
+ self.embedding_provider.retrieve_credentials(cx);
+ } else {
+ return true;
}
+
+ self.embedding_provider.has_credentials()
}
pub fn is_authenticated(&self) -> bool {
- self.api_key.is_some()
+ self.embedding_provider.has_credentials()
}
pub fn enabled(cx: &AppContext) -> bool {
@@ -339,7 +338,7 @@ impl SemanticIndex {
Ok(cx.add_model(|cx| {
let t0 = Instant::now();
let embedding_queue =
- EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone(), None);
+ EmbeddingQueue::new(embedding_provider.clone(), cx.background().clone());
let _embedding_task = cx.background().spawn({
let embedded_files = embedding_queue.finished_files();
let db = db.clone();
@@ -404,8 +403,6 @@ impl SemanticIndex {
_embedding_task,
_parsing_files_tasks,
projects: Default::default(),
- api_key: None,
- embedding_queue
}
}))
}
@@ -720,13 +717,13 @@ impl SemanticIndex {
let index = self.index_project(project.clone(), cx);
let embedding_provider = self.embedding_provider.clone();
- let api_key = self.api_key.clone();
cx.spawn(|this, mut cx| async move {
index.await?;
let t0 = Instant::now();
+
let query = embedding_provider
- .embed_batch(vec![query], api_key)
+ .embed_batch(vec![query])
.await?
.pop()
.ok_or_else(|| anyhow!("could not embed query"))?;
@@ -944,7 +941,6 @@ impl SemanticIndex {
let fs = self.fs.clone();
let db_path = self.db.path().clone();
let background = cx.background().clone();
- let api_key = self.api_key.clone();
cx.background().spawn(async move {
let db = VectorDatabase::new(fs, db_path.clone(), background).await?;
let mut results = Vec::<SearchResult>::new();
@@ -959,15 +955,10 @@ impl SemanticIndex {
.parse_file_with_template(None, &snapshot.text(), language)
.log_err()
.unwrap_or_default();
- if Self::embed_spans(
- &mut spans,
- embedding_provider.as_ref(),
- &db,
- api_key.clone(),
- )
- .await
- .log_err()
- .is_some()
+ if Self::embed_spans(&mut spans, embedding_provider.as_ref(), &db)
+ .await
+ .log_err()
+ .is_some()
{
for span in spans {
let similarity = span.embedding.unwrap().similarity(&query);
@@ -1007,9 +998,8 @@ impl SemanticIndex {
project: ModelHandle<Project>,
cx: &mut ModelContext<Self>,
) -> Task<Result<()>> {
- if self.api_key.is_none() {
- self.authenticate(cx);
- if self.api_key.is_none() {
+ if !self.is_authenticated() {
+ if !self.authenticate(cx) {
return Task::ready(Err(anyhow!("user is not authenticated")));
}
}
@@ -1192,7 +1182,6 @@ impl SemanticIndex {
spans: &mut [Span],
embedding_provider: &dyn EmbeddingProvider,
db: &VectorDatabase,
- api_key: Option<String>,
) -> Result<()> {
let mut batch = Vec::new();
let mut batch_tokens = 0;
@@ -1215,7 +1204,7 @@ impl SemanticIndex {
if batch_tokens + span.token_count > embedding_provider.max_tokens_per_batch() {
let batch_embeddings = embedding_provider
- .embed_batch(mem::take(&mut batch), api_key.clone())
+ .embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);
batch_tokens = 0;
@@ -1227,7 +1216,7 @@ impl SemanticIndex {
if !batch.is_empty() {
let batch_embeddings = embedding_provider
- .embed_batch(mem::take(&mut batch), api_key)
+ .embed_batch(mem::take(&mut batch))
.await?;
embeddings.extend(batch_embeddings);
crates/semantic_index/src/semantic_index_tests.rs 🔗
@@ -4,10 +4,9 @@ use crate::{
semantic_index_settings::SemanticIndexSettings,
FileToEmbed, JobHandle, SearchResult, SemanticIndex, EMBEDDING_QUEUE_FLUSH_TIMEOUT,
};
-use ai::embedding::{DummyEmbeddings, Embedding, EmbeddingProvider};
-use anyhow::Result;
-use async_trait::async_trait;
-use gpui::{executor::Deterministic, AppContext, Task, TestAppContext};
+use ai::test::FakeEmbeddingProvider;
+
+use gpui::{executor::Deterministic, Task, TestAppContext};
use language::{Language, LanguageConfig, LanguageRegistry, ToOffset};
use parking_lot::Mutex;
use pretty_assertions::assert_eq;
@@ -15,14 +14,7 @@ use project::{project_settings::ProjectSettings, search::PathMatcher, FakeFs, Fs
use rand::{rngs::StdRng, Rng};
use serde_json::json;
use settings::SettingsStore;
-use std::{
- path::Path,
- sync::{
- atomic::{self, AtomicUsize},
- Arc,
- },
- time::{Instant, SystemTime},
-};
+use std::{path::Path, sync::Arc, time::SystemTime};
use unindent::Unindent;
use util::RandomCharIter;
@@ -228,7 +220,7 @@ async fn test_embedding_batching(cx: &mut TestAppContext, mut rng: StdRng) {
let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
- let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background(), None);
+ let mut queue = EmbeddingQueue::new(embedding_provider.clone(), cx.background());
for file in &files {
queue.push(file.clone());
}
@@ -280,7 +272,7 @@ fn assert_search_results(
#[gpui::test]
async fn test_code_context_retrieval_rust() {
let language = rust_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@@ -382,7 +374,7 @@ async fn test_code_context_retrieval_rust() {
#[gpui::test]
async fn test_code_context_retrieval_json() {
let language = json_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -466,7 +458,7 @@ fn assert_documents_eq(
#[gpui::test]
async fn test_code_context_retrieval_javascript() {
let language = js_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@@ -565,7 +557,7 @@ async fn test_code_context_retrieval_javascript() {
#[gpui::test]
async fn test_code_context_retrieval_lua() {
let language = lua_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -639,7 +631,7 @@ async fn test_code_context_retrieval_lua() {
#[gpui::test]
async fn test_code_context_retrieval_elixir() {
let language = elixir_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -756,7 +748,7 @@ async fn test_code_context_retrieval_elixir() {
#[gpui::test]
async fn test_code_context_retrieval_cpp() {
let language = cpp_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = "
@@ -909,7 +901,7 @@ async fn test_code_context_retrieval_cpp() {
#[gpui::test]
async fn test_code_context_retrieval_ruby() {
let language = ruby_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -1100,7 +1092,7 @@ async fn test_code_context_retrieval_ruby() {
#[gpui::test]
async fn test_code_context_retrieval_php() {
let language = php_lang();
- let embedding_provider = Arc::new(DummyEmbeddings {});
+ let embedding_provider = Arc::new(FakeEmbeddingProvider::default());
let mut retriever = CodeContextRetriever::new(embedding_provider);
let text = r#"
@@ -1248,65 +1240,6 @@ async fn test_code_context_retrieval_php() {
);
}
-#[derive(Default)]
-struct FakeEmbeddingProvider {
- embedding_count: AtomicUsize,
-}
-
-impl FakeEmbeddingProvider {
- fn embedding_count(&self) -> usize {
- self.embedding_count.load(atomic::Ordering::SeqCst)
- }
-
- fn embed_sync(&self, span: &str) -> Embedding {
- let mut result = vec![1.0; 26];
- for letter in span.chars() {
- let letter = letter.to_ascii_lowercase();
- if letter as u32 >= 'a' as u32 {
- let ix = (letter as u32) - ('a' as u32);
- if ix < 26 {
- result[ix as usize] += 1.0;
- }
- }
- }
-
- let norm = result.iter().map(|x| x * x).sum::<f32>().sqrt();
- for x in &mut result {
- *x /= norm;
- }
-
- result.into()
- }
-}
-
-#[async_trait]
-impl EmbeddingProvider for FakeEmbeddingProvider {
- fn retrieve_credentials(&self, _cx: &AppContext) -> Option<String> {
- Some("Fake Credentials".to_string())
- }
- fn truncate(&self, span: &str) -> (String, usize) {
- (span.to_string(), 1)
- }
-
- fn max_tokens_per_batch(&self) -> usize {
- 200
- }
-
- fn rate_limit_expiration(&self) -> Option<Instant> {
- None
- }
-
- async fn embed_batch(
- &self,
- spans: Vec<String>,
- _api_key: Option<String>,
- ) -> Result<Vec<Embedding>> {
- self.embedding_count
- .fetch_add(spans.len(), atomic::Ordering::SeqCst);
- Ok(spans.iter().map(|span| self.embed_sync(span)).collect())
- }
-}
-
fn js_lang() -> Arc<Language> {
Arc::new(
Language::new(
crates/zed/examples/semantic_index_eval.rs 🔗
@@ -1,4 +1,4 @@
-use ai::embedding::OpenAIEmbeddings;
+use ai::providers::open_ai::OpenAIEmbeddingProvider;
use anyhow::{anyhow, Result};
use client::{self, UserStore};
use gpui::{AsyncAppContext, ModelHandle, Task};
@@ -475,7 +475,7 @@ fn main() {
let semantic_index = SemanticIndex::new(
fs.clone(),
db_file_path,
- Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+ Arc::new(OpenAIEmbeddingProvider::new(http_client, cx.background())),
languages.clone(),
cx.clone(),
)