collab: Remove LLM completions over RPC (#16114)

Created by Marshall Bowers

This PR removes the LLM completion messages from the RPC protocol, since these
requests now go through the LLM service as of #16113.
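
For readers following the migration: instead of a `CompleteWithLanguageModel` envelope over the collab RPC connection, completions are now requested from the LLM service directly. Below is a minimal client-side sketch of that shape, assuming a hypothetical endpoint, bearer-token auth, and `reqwest`/`tokio`/`serde_json`; the URL, auth scheme, and payload layout are illustrative assumptions (the JSON body only mirrors the fields of the removed proto message), not taken from the actual service:

```rust
// Hypothetical sketch only: endpoint, auth, and payload shape are assumptions.
// Requires reqwest with the "json" feature and tokio with "macros" + "rt".
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = reqwest::Client::new();

    // Instead of sending proto::CompleteWithLanguageModel over collab RPC,
    // the editor talks to a dedicated LLM service over HTTP.
    let response = client
        .post("https://llm.example.com/completion") // assumed endpoint
        .bearer_auth(std::env::var("LLM_SERVICE_TOKEN")?) // assumed auth scheme
        .json(&json!({
            // Mirrors the removed message's fields: a provider plus an
            // opaque, provider-specific request serialized as JSON.
            "provider": "anthropic",
            "request": "{\"model\":\"claude-3-5-sonnet\",\"max_tokens\":1024}",
        }))
        .send()
        .await?
        .error_for_status()?;

    let completion: serde_json::Value = response.json().await?;
    println!("{completion}");
    Ok(())
}
```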

Release Notes:

- N/A

Change summary

crates/collab/src/rpc.rs     | 267 --------------------------------------
crates/proto/proto/zed.proto |  24 ---
crates/proto/src/proto.rs    |   9 -
3 files changed, 1 insertion(+), 299 deletions(-)

Detailed changes

crates/collab/src/rpc.rs 🔗

@@ -105,18 +105,6 @@ impl<R: RequestMessage> Response<R> {
     }
 }
 
-struct StreamingResponse<R: RequestMessage> {
-    peer: Arc<Peer>,
-    receipt: Receipt<R>,
-}
-
-impl<R: RequestMessage> StreamingResponse<R> {
-    fn send(&self, payload: R::Response) -> Result<()> {
-        self.peer.respond(self.receipt, payload)?;
-        Ok(())
-    }
-}
-
 #[derive(Clone, Debug)]
 pub enum Principal {
     User(User),
@@ -630,31 +618,6 @@ impl Server {
             ))
             .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
             .add_message_handler(update_context)
-            .add_request_handler({
-                let app_state = app_state.clone();
-                move |request, response, session| {
-                    let app_state = app_state.clone();
-                    async move {
-                        complete_with_language_model(request, response, session, &app_state.config)
-                            .await
-                    }
-                }
-            })
-            .add_streaming_request_handler({
-                let app_state = app_state.clone();
-                move |request, response, session| {
-                    let app_state = app_state.clone();
-                    async move {
-                        stream_complete_with_language_model(
-                            request,
-                            response,
-                            session,
-                            &app_state.config,
-                        )
-                        .await
-                    }
-                }
-            })
             .add_request_handler({
                 let app_state = app_state.clone();
                 move |request, response, session| {
@@ -948,40 +911,6 @@ impl Server {
         })
     }
 
-    fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
-    where
-        F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
-        Fut: Send + Future<Output = Result<()>>,
-        M: RequestMessage,
-    {
-        let handler = Arc::new(handler);
-        self.add_handler(move |envelope, session| {
-            let receipt = envelope.receipt();
-            let handler = handler.clone();
-            async move {
-                let peer = session.peer.clone();
-                let response = StreamingResponse {
-                    peer: peer.clone(),
-                    receipt,
-                };
-                match (handler)(envelope.payload, response, session).await {
-                    Ok(()) => {
-                        peer.end_stream(receipt)?;
-                        Ok(())
-                    }
-                    Err(error) => {
-                        let proto_err = match &error {
-                            Error::Internal(err) => err.to_proto(),
-                            _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
-                        };
-                        peer.respond_with_error(receipt, proto_err)?;
-                        Err(error)
-                    }
-                }
-            }
-        })
-    }
-
     #[allow(clippy::too_many_arguments)]
     pub fn handle_connection(
         self: &Arc<Self>,
@@ -4561,202 +4490,6 @@ async fn acknowledge_buffer_version(
     Ok(())
 }
 
-struct ZedProCompleteWithLanguageModelRateLimit;
-
-impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
-    fn capacity(&self) -> usize {
-        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
-            .ok()
-            .and_then(|v| v.parse().ok())
-            .unwrap_or(120) // Picked arbitrarily
-    }
-
-    fn refill_duration(&self) -> chrono::Duration {
-        chrono::Duration::hours(1)
-    }
-
-    fn db_name(&self) -> &'static str {
-        "zed-pro:complete-with-language-model"
-    }
-}
-
-struct FreeCompleteWithLanguageModelRateLimit;
-
-impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
-    fn capacity(&self) -> usize {
-        std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
-            .ok()
-            .and_then(|v| v.parse().ok())
-            .unwrap_or(120 / 10) // Picked arbitrarily
-    }
-
-    fn refill_duration(&self) -> chrono::Duration {
-        chrono::Duration::hours(1)
-    }
-
-    fn db_name(&self) -> &'static str {
-        "free:complete-with-language-model"
-    }
-}
-
-async fn complete_with_language_model(
-    request: proto::CompleteWithLanguageModel,
-    response: Response<proto::CompleteWithLanguageModel>,
-    session: Session,
-    config: &Config,
-) -> Result<()> {
-    let Some(session) = session.for_user() else {
-        return Err(anyhow!("user not found"))?;
-    };
-    authorize_access_to_language_models(&session).await?;
-
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
-        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
-        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
-    };
-
-    session
-        .app_state
-        .rate_limiter
-        .check(&*rate_limit, session.user_id())
-        .await?;
-
-    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
-        Some(proto::LanguageModelProvider::Anthropic) => {
-            let api_key = config
-                .anthropic_api_key
-                .as_ref()
-                .context("no Anthropic AI API key configured on the server")?;
-            anthropic::complete(
-                session.http_client.as_ref(),
-                anthropic::ANTHROPIC_API_URL,
-                api_key,
-                serde_json::from_str(&request.request)?,
-            )
-            .await?
-        }
-        _ => return Err(anyhow!("unsupported provider"))?,
-    };
-
-    response.send(proto::CompleteWithLanguageModelResponse {
-        completion: serde_json::to_string(&result)?,
-    })?;
-
-    Ok(())
-}
-
-async fn stream_complete_with_language_model(
-    request: proto::StreamCompleteWithLanguageModel,
-    response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
-    session: Session,
-    config: &Config,
-) -> Result<()> {
-    let Some(session) = session.for_user() else {
-        return Err(anyhow!("user not found"))?;
-    };
-    authorize_access_to_language_models(&session).await?;
-
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
-        proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
-        proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
-    };
-
-    session
-        .app_state
-        .rate_limiter
-        .check(&*rate_limit, session.user_id())
-        .await?;
-
-    match proto::LanguageModelProvider::from_i32(request.provider) {
-        Some(proto::LanguageModelProvider::Anthropic) => {
-            let api_key = config
-                .anthropic_api_key
-                .as_ref()
-                .context("no Anthropic AI API key configured on the server")?;
-            let mut chunks = anthropic::stream_completion(
-                session.http_client.as_ref(),
-                anthropic::ANTHROPIC_API_URL,
-                api_key,
-                serde_json::from_str(&request.request)?,
-                None,
-            )
-            .await?;
-            while let Some(event) = chunks.next().await {
-                let chunk = event?;
-                response.send(proto::StreamCompleteWithLanguageModelResponse {
-                    event: serde_json::to_string(&chunk)?,
-                })?;
-            }
-        }
-        Some(proto::LanguageModelProvider::OpenAi) => {
-            let api_key = config
-                .openai_api_key
-                .as_ref()
-                .context("no OpenAI API key configured on the server")?;
-            let mut events = open_ai::stream_completion(
-                session.http_client.as_ref(),
-                open_ai::OPEN_AI_API_URL,
-                api_key,
-                serde_json::from_str(&request.request)?,
-                None,
-            )
-            .await?;
-            while let Some(event) = events.next().await {
-                let event = event?;
-                response.send(proto::StreamCompleteWithLanguageModelResponse {
-                    event: serde_json::to_string(&event)?,
-                })?;
-            }
-        }
-        Some(proto::LanguageModelProvider::Google) => {
-            let api_key = config
-                .google_ai_api_key
-                .as_ref()
-                .context("no Google AI API key configured on the server")?;
-            let mut events = google_ai::stream_generate_content(
-                session.http_client.as_ref(),
-                google_ai::API_URL,
-                api_key,
-                serde_json::from_str(&request.request)?,
-            )
-            .await?;
-            while let Some(event) = events.next().await {
-                let event = event?;
-                response.send(proto::StreamCompleteWithLanguageModelResponse {
-                    event: serde_json::to_string(&event)?,
-                })?;
-            }
-        }
-        Some(proto::LanguageModelProvider::Zed) => {
-            let api_key = config
-                .qwen2_7b_api_key
-                .as_ref()
-                .context("no Qwen2-7B API key configured on the server")?;
-            let api_url = config
-                .qwen2_7b_api_url
-                .as_ref()
-                .context("no Qwen2-7B URL configured on the server")?;
-            let mut events = open_ai::stream_completion(
-                session.http_client.as_ref(),
-                &api_url,
-                api_key,
-                serde_json::from_str(&request.request)?,
-                None,
-            )
-            .await?;
-            while let Some(event) = events.next().await {
-                let event = event?;
-                response.send(proto::StreamCompleteWithLanguageModelResponse {
-                    event: serde_json::to_string(&event)?,
-                })?;
-            }
-        }
-        None => return Err(anyhow!("unknown provider"))?,
-    }
-
-    Ok(())
-}
-
 async fn count_language_model_tokens(
     request: proto::CountLanguageModelTokens,
     response: Response<proto::CountLanguageModelTokens>,

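The `add_streaming_request_handler` infrastructure removed above existed solely so that one request could produce many response frames followed by an explicit end-of-stream. Here is a simplified, self-contained sketch of that pattern, with a channel standing in for the peer/receipt machinery; the names and types are illustrative, not collab internals:

```rust
use tokio::sync::mpsc;

/// Simplified stand-in for the removed StreamingResponse: each `send`
/// forwards one response frame, and dropping the sender ends the stream.
struct StreamingResponse {
    tx: mpsc::UnboundedSender<String>,
}

impl StreamingResponse {
    fn send(&self, event: String) -> Result<(), mpsc::error::SendError<String>> {
        self.tx.send(event)
    }
}

async fn handle_stream_completion(response: StreamingResponse) {
    // In the removed handler this loop consumed provider events
    // (Anthropic/OpenAI/Google) and forwarded each one as a frame.
    for chunk in ["partial ", "completion ", "text"] {
        response.send(chunk.to_string()).expect("receiver dropped");
    }
    // Returning drops `tx`, which plays the role of peer.end_stream(receipt).
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = mpsc::unbounded_channel();
    tokio::spawn(handle_stream_completion(StreamingResponse { tx }));
    while let Some(frame) = rx.recv().await {
        print!("{frame}");
    }
    println!();
}
```
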
crates/proto/proto/zed.proto 🔗

@@ -197,10 +197,6 @@ message Envelope {
 
         JoinHostedProject join_hosted_project = 164;
 
-        CompleteWithLanguageModel complete_with_language_model = 226;
-        CompleteWithLanguageModelResponse complete_with_language_model_response = 227;
-        StreamCompleteWithLanguageModel stream_complete_with_language_model = 228;
-        StreamCompleteWithLanguageModelResponse stream_complete_with_language_model_response = 229;
         CountLanguageModelTokens count_language_model_tokens = 230;
         CountLanguageModelTokensResponse count_language_model_tokens_response = 231;
         GetCachedEmbeddings get_cached_embeddings = 189;
@@ -279,7 +275,7 @@ message Envelope {
 
     reserved 158 to 161;
     reserved 166 to 169;
-    reserved 224 to 225;
+    reserved 224 to 229;
 }
 
 // Messages
@@ -2084,24 +2080,6 @@ enum LanguageModelRole {
     reserved 3;
 }
 
-message CompleteWithLanguageModel {
-    LanguageModelProvider provider = 1;
-    string request = 2;
-}
-
-message CompleteWithLanguageModelResponse {
-    string completion = 1;
-}
-
-message StreamCompleteWithLanguageModel {
-    LanguageModelProvider provider = 1;
-    string request = 2;
-}
-
-message StreamCompleteWithLanguageModelResponse {
-    string event = 1;
-}
-
 message CountLanguageModelTokens {
     LanguageModelProvider provider = 1;
     string request = 2;

crates/proto/src/proto.rs 🔗

@@ -298,10 +298,6 @@ messages!(
     (PrepareRename, Background),
     (PrepareRenameResponse, Background),
     (ProjectEntryResponse, Foreground),
-    (CompleteWithLanguageModel, Background),
-    (CompleteWithLanguageModelResponse, Background),
-    (StreamCompleteWithLanguageModel, Background),
-    (StreamCompleteWithLanguageModelResponse, Background),
     (CountLanguageModelTokens, Background),
     (CountLanguageModelTokensResponse, Background),
     (RefreshInlayHints, Foreground),
@@ -476,11 +472,6 @@ request_messages!(
     (PerformRename, PerformRenameResponse),
     (Ping, Ack),
     (PrepareRename, PrepareRenameResponse),
-    (CompleteWithLanguageModel, CompleteWithLanguageModelResponse),
-    (
-        StreamCompleteWithLanguageModel,
-        StreamCompleteWithLanguageModelResponse
-    ),
     (CountLanguageModelTokens, CountLanguageModelTokensResponse),
     (RefreshInlayHints, Ack),
     (RejoinChannelBuffers, RejoinChannelBuffersResponse),
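
A side note on the `request_messages!` entries removed above: that table is what pairs each request type with its response type, so deleting the envelope messages also means deleting their pairings. A rough standalone sketch of the idea behind it, modeled on the `R::Response` usage visible in rpc.rs; the trait and structs below are illustrative, not the proto crate's real definitions:

```rust
/// Illustrative stand-in for a RequestMessage-style trait: each request type
/// names its response type at compile time.
trait RequestMessage {
    type Response;
}

struct CountLanguageModelTokens; // fields elided for brevity

struct CountLanguageModelTokensResponse {
    token_count: u32,
}

// Conceptually what a `request_messages!((Req, Resp), ...)` entry provides:
// removing the entry removes the ability to await that response type.
impl RequestMessage for CountLanguageModelTokens {
    type Response = CountLanguageModelTokensResponse;
}

fn respond<R: RequestMessage>(_request: R, response: R::Response) -> R::Response {
    response
}

fn main() {
    let resp = respond(
        CountLanguageModelTokens,
        CountLanguageModelTokensResponse { token_count: 42 },
    );
    println!("{} tokens", resp.token_count);
}
```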