collab: Remove code for embeddings (#29310)

Marshall Bowers created

This PR removes the embeddings-related code from collab and the
protocol, as we weren't using it anywhere.

Release Notes:

- N/A

Change summary

Cargo.lock                                   |   1 
crates/collab/Cargo.toml                     |   1 
crates/collab/src/rpc.rs                     | 137 ----------------------
crates/proto/proto/ai.proto                  |  23 ---
crates/proto/proto/zed.proto                 |   5 
crates/proto/src/proto.rs                    |   6 
crates/semantic_index/src/embedding.rs       |   2 
crates/semantic_index/src/embedding/cloud.rs |  93 --------------
8 files changed, 1 insertion(+), 267 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3017,7 +3017,6 @@ dependencies = [
  "nanoid",
  "node_runtime",
  "notifications",
- "open_ai",
  "parking_lot",
  "pretty_assertions",
  "project",

crates/collab/Cargo.toml 🔗

@@ -41,7 +41,6 @@ jsonwebtoken.workspace = true
 livekit_api.workspace = true
 log.workspace = true
 nanoid.workspace = true
-open_ai.workspace = true
 parking_lot.workspace = true
 prometheus = "0.14"
 prost.workspace = true

crates/collab/src/rpc.rs 🔗

@@ -34,10 +34,8 @@ use collections::{HashMap, HashSet};
 pub use connection_pool::{ConnectionPool, ZedVersion};
 use core::fmt::{self, Debug, Formatter};
 use http_client::HttpClient;
-use open_ai::{OPEN_AI_API_URL, OpenAiEmbeddingModel};
 use reqwest_client::ReqwestClient;
 use rpc::proto::split_repository_update;
-use sha2::Digest;
 use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
 
 use futures::{
@@ -437,18 +435,6 @@ impl Server {
                             .await
                     }
                 }
-            })
-            .add_request_handler(get_cached_embeddings)
-            .add_request_handler({
-                let app_state = app_state.clone();
-                move |request, response, session| {
-                    compute_embeddings(
-                        request,
-                        response,
-                        session,
-                        app_state.config.openai_api_key.clone(),
-                    )
-                }
             });
 
         Arc::new(server)
@@ -3780,129 +3766,6 @@ impl RateLimit for FreeCountLanguageModelTokensRateLimit {
     }
 }
 
-struct ZedProComputeEmbeddingsRateLimit;
-
-impl RateLimit for ZedProComputeEmbeddingsRateLimit {
-    fn capacity(&self) -> usize {
-        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
-            .ok()
-            .and_then(|v| v.parse().ok())
-            .unwrap_or(5000) // Picked arbitrarily
-    }
-
-    fn refill_duration(&self) -> chrono::Duration {
-        chrono::Duration::hours(1)
-    }
-
-    fn db_name(&self) -> &'static str {
-        "zed-pro:compute-embeddings"
-    }
-}
-
-struct FreeComputeEmbeddingsRateLimit;
-
-impl RateLimit for FreeComputeEmbeddingsRateLimit {
-    fn capacity(&self) -> usize {
-        std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
-            .ok()
-            .and_then(|v| v.parse().ok())
-            .unwrap_or(5000 / 10) // Picked arbitrarily
-    }
-
-    fn refill_duration(&self) -> chrono::Duration {
-        chrono::Duration::hours(1)
-    }
-
-    fn db_name(&self) -> &'static str {
-        "free:compute-embeddings"
-    }
-}
-
-async fn compute_embeddings(
-    request: proto::ComputeEmbeddings,
-    response: Response<proto::ComputeEmbeddings>,
-    session: Session,
-    api_key: Option<Arc<str>>,
-) -> Result<()> {
-    let api_key = api_key.context("no OpenAI API key configured on the server")?;
-    authorize_access_to_legacy_llm_endpoints(&session).await?;
-
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
-        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
-        proto::Plan::Free | proto::Plan::ZedProTrial => Box::new(FreeComputeEmbeddingsRateLimit),
-    };
-
-    session
-        .app_state
-        .rate_limiter
-        .check(&*rate_limit, session.user_id())
-        .await?;
-
-    let embeddings = match request.model.as_str() {
-        "openai/text-embedding-3-small" => {
-            open_ai::embed(
-                session.http_client.as_ref(),
-                OPEN_AI_API_URL,
-                &api_key,
-                OpenAiEmbeddingModel::TextEmbedding3Small,
-                request.texts.iter().map(|text| text.as_str()),
-            )
-            .await?
-        }
-        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
-    };
-
-    let embeddings = request
-        .texts
-        .iter()
-        .map(|text| {
-            let mut hasher = sha2::Sha256::new();
-            hasher.update(text.as_bytes());
-            let result = hasher.finalize();
-            result.to_vec()
-        })
-        .zip(
-            embeddings
-                .data
-                .into_iter()
-                .map(|embedding| embedding.embedding),
-        )
-        .collect::<HashMap<_, _>>();
-
-    let db = session.db().await;
-    db.save_embeddings(&request.model, &embeddings)
-        .await
-        .context("failed to save embeddings")
-        .trace_err();
-
-    response.send(proto::ComputeEmbeddingsResponse {
-        embeddings: embeddings
-            .into_iter()
-            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
-            .collect(),
-    })?;
-    Ok(())
-}
-
-async fn get_cached_embeddings(
-    request: proto::GetCachedEmbeddings,
-    response: Response<proto::GetCachedEmbeddings>,
-    session: Session,
-) -> Result<()> {
-    authorize_access_to_legacy_llm_endpoints(&session).await?;
-
-    let db = session.db().await;
-    let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
-
-    response.send(proto::GetCachedEmbeddingsResponse {
-        embeddings: embeddings
-            .into_iter()
-            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
-            .collect(),
-    })?;
-    Ok(())
-}
-
 /// This is leftover from before the LLM service.
 ///
 /// The endpoints protected by this check will be moved there eventually.

crates/proto/proto/ai.proto 🔗

@@ -188,26 +188,3 @@ enum LanguageModelProvider {
     Google = 2;
     Zed = 3;
 }
-
-message GetCachedEmbeddings {
-    string model = 1;
-    repeated bytes digests = 2;
-}
-
-message GetCachedEmbeddingsResponse {
-    repeated Embedding embeddings = 1;
-}
-
-message ComputeEmbeddings {
-    string model = 1;
-    repeated string texts = 2;
-}
-
-message ComputeEmbeddingsResponse {
-    repeated Embedding embeddings = 1;
-}
-
-message Embedding {
-    bytes digest = 1;
-    repeated float dimensions = 2;
-}

crates/proto/proto/zed.proto 🔗

@@ -208,10 +208,6 @@ message Envelope {
 
         CountLanguageModelTokens count_language_model_tokens = 230;
         CountLanguageModelTokensResponse count_language_model_tokens_response = 231;
-        GetCachedEmbeddings get_cached_embeddings = 189;
-        GetCachedEmbeddingsResponse get_cached_embeddings_response = 190;
-        ComputeEmbeddings compute_embeddings = 191;
-        ComputeEmbeddingsResponse compute_embeddings_response = 192;
 
         UpdateChannelMessage update_channel_message = 170;
         ChannelMessageUpdate channel_message_update = 171;
@@ -394,6 +390,7 @@ message Envelope {
     reserved 166 to 169;
     reserved 177 to 185;
     reserved 188;
+    reserved 189 to 192;
     reserved 193 to 195;
     reserved 197;
     reserved 200 to 202;

crates/proto/src/proto.rs 🔗

@@ -49,8 +49,6 @@ messages!(
     (ChannelMessageUpdate, Foreground),
     (CloseBuffer, Foreground),
     (Commit, Background),
-    (ComputeEmbeddings, Background),
-    (ComputeEmbeddingsResponse, Background),
     (CopyProjectEntry, Foreground),
     (CountLanguageModelTokens, Background),
     (CountLanguageModelTokensResponse, Background),
@@ -82,8 +80,6 @@ messages!(
     (FormatBuffers, Foreground),
     (FormatBuffersResponse, Foreground),
     (FuzzySearchUsers, Foreground),
-    (GetCachedEmbeddings, Background),
-    (GetCachedEmbeddingsResponse, Background),
     (GetChannelMembers, Foreground),
     (GetChannelMembersResponse, Foreground),
     (GetChannelMessages, Background),
@@ -319,7 +315,6 @@ request_messages!(
     (CancelCall, Ack),
     (Commit, Ack),
     (CopyProjectEntry, ProjectEntryResponse),
-    (ComputeEmbeddings, ComputeEmbeddingsResponse),
     (CreateChannel, CreateChannelResponse),
     (CreateProjectEntry, ProjectEntryResponse),
     (CreateRoom, CreateRoomResponse),
@@ -332,7 +327,6 @@ request_messages!(
     (ApplyCodeActionKind, ApplyCodeActionKindResponse),
     (FormatBuffers, FormatBuffersResponse),
     (FuzzySearchUsers, UsersResponse),
-    (GetCachedEmbeddings, GetCachedEmbeddingsResponse),
     (GetChannelMembers, GetChannelMembersResponse),
     (GetChannelMessages, GetChannelMessagesResponse),
     (GetChannelMessagesById, GetChannelMessagesResponse),

crates/semantic_index/src/embedding/cloud.rs 🔗

@@ -1,93 +0,0 @@
-use crate::{Embedding, EmbeddingProvider, TextToEmbed};
-use anyhow::{Context as _, Result, anyhow};
-use client::{Client, proto};
-use collections::HashMap;
-use futures::{FutureExt, future::BoxFuture};
-use std::sync::Arc;
-
-pub struct CloudEmbeddingProvider {
-    model: String,
-    client: Arc<Client>,
-}
-
-impl CloudEmbeddingProvider {
-    pub fn new(client: Arc<Client>) -> Self {
-        Self {
-            model: "openai/text-embedding-3-small".into(),
-            client,
-        }
-    }
-}
-
-impl EmbeddingProvider for CloudEmbeddingProvider {
-    fn embed<'a>(&'a self, texts: &'a [TextToEmbed<'a>]) -> BoxFuture<'a, Result<Vec<Embedding>>> {
-        // First, fetch any embeddings that are cached based on the requested texts' digests
-        // Then compute any embeddings that are missing.
-        async move {
-            if !self.client.status().borrow().is_connected() {
-                return Err(anyhow!("sign in required"));
-            }
-
-            let cached_embeddings = self.client.request(proto::GetCachedEmbeddings {
-                model: self.model.clone(),
-                digests: texts
-                    .iter()
-                    .map(|to_embed| to_embed.digest.to_vec())
-                    .collect(),
-            });
-            let mut embeddings = cached_embeddings
-                .await
-                .context("failed to fetch cached embeddings via cloud model")?
-                .embeddings
-                .into_iter()
-                .map(|embedding| {
-                    let digest: [u8; 32] = embedding
-                        .digest
-                        .try_into()
-                        .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
-                    Ok((digest, embedding.dimensions))
-                })
-                .collect::<Result<HashMap<_, _>>>()?;
-
-            let compute_embeddings_request = proto::ComputeEmbeddings {
-                model: self.model.clone(),
-                texts: texts
-                    .iter()
-                    .filter_map(|to_embed| {
-                        if embeddings.contains_key(&to_embed.digest) {
-                            None
-                        } else {
-                            Some(to_embed.text.to_string())
-                        }
-                    })
-                    .collect(),
-            };
-            if !compute_embeddings_request.texts.is_empty() {
-                let missing_embeddings = self.client.request(compute_embeddings_request).await?;
-                for embedding in missing_embeddings.embeddings {
-                    let digest: [u8; 32] = embedding
-                        .digest
-                        .try_into()
-                        .map_err(|_| anyhow!("invalid digest for cached embedding"))?;
-                    embeddings.insert(digest, embedding.dimensions);
-                }
-            }
-
-            texts
-                .iter()
-                .map(|to_embed| {
-                    let embedding =
-                        embeddings.get(&to_embed.digest).cloned().with_context(|| {
-                            format!("server did not return an embedding for {:?}", to_embed)
-                        })?;
-                    Ok(Embedding::new(embedding))
-                })
-                .collect()
-        }
-        .boxed()
-    }
-
-    fn batch_size(&self) -> usize {
-        2048
-    }
-}