From 8603a908c1e97109498c16a7ed60a674be85e352 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 29 Jan 2025 13:22:16 -0500 Subject: [PATCH] zeta: Send staff edit predictions through Cloudflare Workers (#23847) This PR makes it so staff edit predictions now go through Cloudflare Workers instead of going to the LLM service. This will allow us to dogfood the new LLM worker to make sure it is working as expected. Release Notes: - N/A --- crates/zeta/src/zeta.rs | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 35867f995089769b06eef6c8bafd4a6614d3dc4f..2f297d97e9a2d405bc5e40431c2f6cc576a91686 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -8,6 +8,7 @@ use anyhow::{anyhow, Context as _, Result}; use arrayvec::ArrayVec; use client::{Client, UserStore}; use collections::{HashMap, HashSet, VecDeque}; +use feature_flags::FeatureFlagAppExt as _; use futures::AsyncReadExt; use gpui::{ actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task, @@ -298,7 +299,7 @@ impl Zeta { perform_predict_edits: F, ) -> Task>> where - F: FnOnce(Arc, LlmApiToken, PredictEditsParams) -> R + 'static, + F: FnOnce(Arc, LlmApiToken, bool, PredictEditsParams) -> R + 'static, R: Future> + Send + 'static, { let snapshot = self.report_changes_for_buffer(buffer, cx); @@ -313,6 +314,7 @@ impl Zeta { let client = self.client.clone(); let llm_token = self.llm_token.clone(); + let is_staff = cx.is_staff(); cx.spawn(|_, cx| async move { let request_sent_at = Instant::now(); @@ -348,7 +350,7 @@ impl Zeta { outline: Some(input_outline.clone()), }; - let response = perform_predict_edits(client, llm_token, body).await?; + let response = perform_predict_edits(client, llm_token, is_staff, body).await?; let output_excerpt = response.output_excerpt; log::debug!("completion response: {}", output_excerpt); @@ -515,7 +517,7 @@ and then another ) -> Task>> { use std::future::ready; - self.request_completion_impl(buffer, position, cx, |_, _, _| ready(Ok(response))) + self.request_completion_impl(buffer, position, cx, |_, _, _, _| ready(Ok(response))) } pub fn request_completion( @@ -530,6 +532,7 @@ and then another fn perform_predict_edits( client: Arc, llm_token: LlmApiToken, + is_staff: bool, body: PredictEditsParams, ) -> impl Future> { async move { @@ -538,14 +541,19 @@ and then another let mut did_retry = false; loop { - let request_builder = http_client::Request::builder(); - let request = request_builder - .method(Method::POST) - .uri( + let request_builder = http_client::Request::builder().method(Method::POST); + let request_builder = if is_staff { + request_builder.uri( + "https://llm-worker-production.zed-industries.workers.dev/predict_edits", + ) + } else { + request_builder.uri( http_client .build_zed_llm_url("/predict_edits", &[])? .as_ref(), ) + }; + let request = request_builder .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", token)) .body(serde_json::to_string(&body)?.into())?;