Marshall Bowers created this PR:
This PR removes the embeddings-related code from collab and the
protocol, as we weren't using it anywhere.
Release Notes:
- N/A
Cargo.lock                                   |   1 -
crates/collab/Cargo.toml                     |   1 -
crates/collab/src/rpc.rs                     | 137 ----------------------
crates/proto/proto/ai.proto                  |  23 ---
crates/proto/proto/zed.proto                 |   5 +-
crates/proto/src/proto.rs                    |   6 -
crates/semantic_index/src/embedding.rs       |   2 -
crates/semantic_index/src/embedding/cloud.rs |  93 --------------
8 files changed, 1 insertion(+), 267 deletions(-)
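The removed protocol cached embeddings by content digest: a client first asked collab for already-computed vectors via GetCachedEmbeddings, then sent only the still-missing texts in a ComputeEmbeddings request, and the server keyed both the cache and the response by the SHA-256 digest of each text (see the deleted compute_embeddings handler and CloudEmbeddingProvider below). A minimal sketch of that digest keying, using the same sha2 calls as the deleted handler; the helper name is illustrative, not from the codebase:

```rust
use sha2::{Digest, Sha256};

/// Cache key for a text's embedding, mirroring the hashing done in the
/// removed `compute_embeddings` handler (SHA-256 over the raw UTF-8 bytes).
fn embedding_cache_key(text: &str) -> [u8; 32] {
    let mut hasher = Sha256::new();
    hasher.update(text.as_bytes());
    hasher.finalize().into()
}

fn main() {
    let key = embedding_cache_key("let x = 1;");
    println!("{:02x?}", key);
}
```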
Cargo.lock
@@ -3017,7 +3017,6 @@ dependencies = [
"nanoid",
"node_runtime",
"notifications",
- "open_ai",
"parking_lot",
"pretty_assertions",
"project",
crates/collab/Cargo.toml
@@ -41,7 +41,6 @@ jsonwebtoken.workspace = true
livekit_api.workspace = true
log.workspace = true
nanoid.workspace = true
-open_ai.workspace = true
parking_lot.workspace = true
prometheus = "0.14"
prost.workspace = true
crates/collab/src/rpc.rs
@@ -34,10 +34,8 @@ use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
use http_client::HttpClient;
-use open_ai::{OPEN_AI_API_URL, OpenAiEmbeddingModel};
use reqwest_client::ReqwestClient;
use rpc::proto::split_repository_update;
-use sha2::Digest;
use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
use futures::{
@@ -437,18 +435,6 @@ impl Server {
.await
}
}
- })
- .add_request_handler(get_cached_embeddings)
- .add_request_handler({
- let app_state = app_state.clone();
- move |request, response, session| {
- compute_embeddings(
- request,
- response,
- session,
- app_state.config.openai_api_key.clone(),
- )
- }
});
Arc::new(server)
@@ -3780,129 +3766,6 @@ impl RateLimit for FreeCountLanguageModelTokensRateLimit {
}
}
-struct ZedProComputeEmbeddingsRateLimit;
-
-impl RateLimit for ZedProComputeEmbeddingsRateLimit {
- fn capacity(&self) -> usize {
- std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
- .ok()
- .and_then(|v| v.parse().ok())
- .unwrap_or(5000) // Picked arbitrarily
- }
-
- fn refill_duration(&self) -> chrono::Duration {
- chrono::Duration::hours(1)
- }
-
- fn db_name(&self) -> &'static str {
- "zed-pro:compute-embeddings"
- }
-}
-
-struct FreeComputeEmbeddingsRateLimit;
-
-impl RateLimit for FreeComputeEmbeddingsRateLimit {
- fn capacity(&self) -> usize {
- std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
- .ok()
- .and_then(|v| v.parse().ok())
- .unwrap_or(5000 / 10) // Picked arbitrarily
- }
-
- fn refill_duration(&self) -> chrono::Duration {
- chrono::Duration::hours(1)
- }
-
- fn db_name(&self) -> &'static str {
- "free:compute-embeddings"
- }
-}
-
-async fn compute_embeddings(
- request: proto::ComputeEmbeddings,
- response: Response<proto::ComputeEmbeddings>,
- session: Session,
- api_key: Option<Arc<str>>,
-) -> Result<()> {
- let api_key = api_key.context("no OpenAI API key configured on the server")?;
- authorize_access_to_legacy_llm_endpoints(&session).await?;
-
- let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
- proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
- proto::Plan::Free | proto::Plan::ZedProTrial => Box::new(FreeComputeEmbeddingsRateLimit),
- };
-
- session
- .app_state
- .rate_limiter
- .check(&*rate_limit, session.user_id())
- .await?;
-
- let embeddings = match request.model.as_str() {
- "openai/text-embedding-3-small" => {
- open_ai::embed(
- session.http_client.as_ref(),
- OPEN_AI_API_URL,
- &api_key,
- OpenAiEmbeddingModel::TextEmbedding3Small,
- request.texts.iter().map(|text| text.as_str()),
- )
- .await?
- }
- provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
- };
-
- let embeddings = request
- .texts
- .iter()
- .map(|text| {
- let mut hasher = sha2::Sha256::new();
- hasher.update(text.as_bytes());
- let result = hasher.finalize();
- result.to_vec()
- })
- .zip(
- embeddings
- .data
- .into_iter()
- .map(|embedding| embedding.embedding),
- )
- .collect::<HashMap<_, _>>();
-
- let db = session.db().await;
- db.save_embeddings(&request.model, &embeddings)
- .await
- .context("failed to save embeddings")
- .trace_err();
-
- response.send(proto::ComputeEmbeddingsResponse {
- embeddings: embeddings
- .into_iter()
- .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
- .collect(),
- })?;
- Ok(())
-}
-
-async fn get_cached_embeddings(
- request: proto::GetCachedEmbeddings,
- response: Response<proto::GetCachedEmbeddings>,
- session: Session,
-) -> Result<()> {
- authorize_access_to_legacy_llm_endpoints(&session).await?;
-
- let db = session.db().await;
- let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
-
- response.send(proto::GetCachedEmbeddingsResponse {
- embeddings: embeddings
- .into_iter()
- .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
- .collect(),
- })?;
- Ok(())
-}
-
/// This is leftover from before the LLM service.
///
/// The endpoints protected by this check will be moved there eventually.
crates/proto/proto/ai.proto
@@ -188,26 +188,3 @@ enum LanguageModelProvider {
Google = 2;
Zed = 3;
}
-
-message GetCachedEmbeddings {
- string model = 1;
- repeated bytes digests = 2;
-}
-
-message GetCachedEmbeddingsResponse {
- repeated Embedding embeddings = 1;
-}
-
-message ComputeEmbeddings {
- string model = 1;
- repeated string texts = 2;
-}
-
-message ComputeEmbeddingsResponse {
- repeated Embedding embeddings = 1;
-}
-
-message Embedding {
- bytes digest = 1;
- repeated float dimensions = 2;
-}
crates/proto/proto/zed.proto
@@ -208,10 +208,6 @@ message Envelope {
CountLanguageModelTokens count_language_model_tokens = 230;
CountLanguageModelTokensResponse count_language_model_tokens_response = 231;
- GetCachedEmbeddings get_cached_embeddings = 189;
- GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
- ComputeEmbeddings compute_embeddings = 191;
- ComputeEmbeddingsResponse compute_embeddings_response = 192;
UpdateChannelMessage update_channel_message = 170;
ChannelMessageUpdate channel_message_update = 171;
@@ -394,6 +390,7 @@ message Envelope {
reserved 166 to 169;
reserved 177 to 185;
reserved 188;
+ reserved 189 to 192;
reserved 193 to 195;
reserved 197;
reserved 200 to 202;
crates/proto/src/proto.rs
@@ -49,8 +49,6 @@ messages!(
(ChannelMessageUpdate, Foreground),
(CloseBuffer, Foreground),
(Commit, Background),
- (ComputeEmbeddings, Background),
- (ComputeEmbeddingsResponse, Background),
(CopyProjectEntry, Foreground),
(CountLanguageModelTokens, Background),
(CountLanguageModelTokensResponse, Background),
@@ -82,8 +80,6 @@ messages!(
(FormatBuffers, Foreground),
(FormatBuffersResponse, Foreground),
(FuzzySearchUsers, Foreground),
- (GetCachedEmbeddings, Background),
- (GetCachedEmbeddingsResponse, Background),
(GetChannelMembers, Foreground),
(GetChannelMembersResponse, Foreground),
(GetChannelMessages, Background),
@@ -319,7 +315,6 @@ request_messages!(
(CancelCall, Ack),
(Commit, Ack),
(CopyProjectEntry, ProjectEntryResponse),
- (ComputeEmbeddings, ComputeEmbeddingsResponse),
(CreateChannel, CreateChannelResponse),
(CreateProjectEntry, ProjectEntryResponse),
(CreateRoom, CreateRoomResponse),
@@ -332,7 +327,6 @@ request_messages!(
(ApplyCodeActionKind, ApplyCodeActionKindResponse),
(FormatBuffers, FormatBuffersResponse),
(FuzzySearchUsers, UsersResponse),
- (GetCachedEmbeddings, GetCachedEmbeddingsResponse),
(GetChannelMembers, GetChannelMembersResponse),
(GetChannelMessages, GetChannelMessagesResponse),
(GetChannelMessagesById, GetChannelMessagesResponse),
crates/semantic_index/src/embedding.rs
@@ -1,9 +1,7 @@
-mod cloud;
mod lmstudio;
mod ollama;
mod open_ai;
-pub use cloud::*;
pub use lmstudio::*;
pub use ollama::*;
pub use open_ai::*;
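With the cloud module gone, crates/semantic_index keeps only its lmstudio, ollama, and open_ai embedding providers; the CloudEmbeddingProvider deleted below was the client-side caller of the collab endpoints removed above. For orientation, here is a rough sketch of the EmbeddingProvider interface that impl satisfied, reconstructed from the deleted code rather than taken from the crate's actual definitions:

```rust
use anyhow::Result;
use futures::future::BoxFuture;

// Shapes inferred from the deleted impl; exact field layout is an assumption.
pub struct TextToEmbed<'a> {
    pub text: &'a str,
    pub digest: [u8; 32], // SHA-256 of `text`, used as the cache key
}

pub struct Embedding(Vec<f32>);

impl Embedding {
    pub fn new(dimensions: Vec<f32>) -> Self {
        Self(dimensions)
    }
}

pub trait EmbeddingProvider {
    /// Embed a batch of texts, returning one vector per input, in order.
    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>>;

    /// Maximum number of texts to hand to `embed` in a single call.
    fn batch_size(&self) -> usize;
}
```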
crates/semantic_index/src/embedding/cloud.rs
@@ -1,93 +0,0 @@
-use crate::{Embedding, EmbeddingProvider, TextToEmbed};
-use anyhow::{Context as _, Result, anyhow};
-use client::{Client, proto};
-use collections::HashMap;
-use futures::{FutureExt, future::BoxFuture};
-use std::sync::Arc;
-
-pub struct CloudEmbeddingProvider {
- model: String,
- client: Arc<Client>,
-}
-
-impl CloudEmbeddingProvider {
- pub fn new(client: Arc<Client>) -> Self {
- Self {
- model: "openai/text-embedding-3-small".into(),
- client,
- }
- }
-}
-
-impl EmbeddingProvider for CloudEmbeddingProvider {
- fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
- // First, fetch any embeddings that are cached based on the requested texts' digests
- // Then compute any embeddings that are missing.
- async move {
- if !self.client.status().borrow().is_connected() {
- return Err(anyhow!("sign in required"));
- }
-
- let cached_embeddings = self.client.request(proto::GetCachedEmbeddings {
- model: self.model.clone(),
- digests: texts
- .iter()
- .map(|to_embed| to_embed.digest.to_vec())
- .collect(),
- });
- let mut embeddings = cached_embeddings
- .await
- .context("failed to fetch cached embeddings via cloud model")?
- .embeddings
- .into_iter()
- .map(|embedding| {
- let digest: [u8; 32] = embedding
- .digest
- .try_into()
- .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
- Ok((digest, embedding.dimensions))
- })
- .collect::<Result<HashMap<_, _>>>()?;
-
- let compute_embeddings_request = proto::ComputeEmbeddings {
- model: self.model.clone(),
- texts: texts
- .iter()
- .filter_map(|to_embed| {
- if embeddings.contains_key(&to_embed.digest) {
- None
- } else {
- Some(to_embed.text.to_string())
- }
- })
- .collect(),
- };
- if !compute_embeddings_request.texts.is_empty() {
- let missing_embeddings = self.client.request(compute_embeddings_request).await?;
- for embedding in missing_embeddings.embeddings {
- let digest: [u8; 32] = embedding
- .digest
- .try_into()
- .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
- embeddings.insert(digest, embedding.dimensions);
- }
- }
-
- texts
- .iter()
- .map(|to_embed| {
- let embedding =
- embeddings.get(&to_embed.digest).cloned().with_context(|| {
- format!("server did not return an embedding for {:?}", to_embed)
- })?;
- Ok(Embedding::new(embedding))
- })
- .collect()
- }
- .boxed()
- }
-
- fn batch_size(&self) -> usize {
- 2048
- }
-}