diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 35867f995089769b06eef6c8bafd4a6614d3dc4f..2f297d97e9a2d405bc5e40431c2f6cc576a91686 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -8,6 +8,7 @@ use anyhow::{anyhow, Context as _, Result}; use arrayvec::ArrayVec; use client::{Client, UserStore}; use collections::{HashMap, HashSet, VecDeque}; +use feature_flags::FeatureFlagAppExt as _; use futures::AsyncReadExt; use gpui::{ actions, App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, Subscription, Task, @@ -298,7 +299,7 @@ impl Zeta { perform_predict_edits: F, ) -> Task>> where - F: FnOnce(Arc, LlmApiToken, PredictEditsParams) -> R + 'static, + F: FnOnce(Arc, LlmApiToken, bool, PredictEditsParams) -> R + 'static, R: Future> + Send + 'static, { let snapshot = self.report_changes_for_buffer(buffer, cx); @@ -313,6 +314,7 @@ impl Zeta { let client = self.client.clone(); let llm_token = self.llm_token.clone(); + let is_staff = cx.is_staff(); cx.spawn(|_, cx| async move { let request_sent_at = Instant::now(); @@ -348,7 +350,7 @@ impl Zeta { outline: Some(input_outline.clone()), }; - let response = perform_predict_edits(client, llm_token, body).await?; + let response = perform_predict_edits(client, llm_token, is_staff, body).await?; let output_excerpt = response.output_excerpt; log::debug!("completion response: {}", output_excerpt); @@ -515,7 +517,7 @@ and then another ) -> Task>> { use std::future::ready; - self.request_completion_impl(buffer, position, cx, |_, _, _| ready(Ok(response))) + self.request_completion_impl(buffer, position, cx, |_, _, _, _| ready(Ok(response))) } pub fn request_completion( @@ -530,6 +532,7 @@ and then another fn perform_predict_edits( client: Arc, llm_token: LlmApiToken, + is_staff: bool, body: PredictEditsParams, ) -> impl Future> { async move { @@ -538,14 +541,19 @@ and then another let mut did_retry = false; loop { - let request_builder = http_client::Request::builder(); - let request = request_builder - .method(Method::POST) - .uri( + let request_builder = http_client::Request::builder().method(Method::POST); + let request_builder = if is_staff { + request_builder.uri( + "https://llm-worker-production.zed-industries.workers.dev/predict_edits", + ) + } else { + request_builder.uri( http_client .build_zed_llm_url("/predict_edits", &[])? .as_ref(), ) + }; + let request = request_builder .header("Content-Type", "application/json") .header("Authorization", format!("Bearer {}", token)) .body(serde_json::to_string(&body)?.into())?;