From 09e178d1e19d19350a1c98785bb97b903fa309af Mon Sep 17 00:00:00 2001
From: Oleksiy Syvokon
Date: Tue, 3 Mar 2026 21:09:26 +0200
Subject: [PATCH] ep: Predict by querying Baseten directly (#50626)

This can be used like `ep predict --provider baseten:V0131GitMergeMarkersPrefix`.
Since it doesn't require load_project, it can be used with captured requests.

Release Notes:

- N/A
---
 .../edit_prediction_cli/src/format_prompt.rs |  20 ++--
 crates/edit_prediction_cli/src/main.rs       |   9 ++
 crates/edit_prediction_cli/src/predict.rs    | 110 +++++++++++++++++-
 3 files changed, 128 insertions(+), 11 deletions(-)

diff --git a/crates/edit_prediction_cli/src/format_prompt.rs b/crates/edit_prediction_cli/src/format_prompt.rs
index ecacd963023d7d113ea5ad77b61fd1d88306fc95..bee79ae8160eeb815a3739b53a5441f6063fb622 100644
--- a/crates/edit_prediction_cli/src/format_prompt.rs
+++ b/crates/edit_prediction_cli/src/format_prompt.rs
@@ -53,18 +53,22 @@ pub async fn run_format_prompt(
     let prompt = format_zeta_prompt(prompt_inputs, zeta_format);
     let prefill = zeta_prompt::get_prefill(prompt_inputs, zeta_format);
 
-    let (expected_patch, expected_cursor_offset) = example
+    let expected_output = example
         .spec
         .expected_patches_with_cursor_positions()
         .into_iter()
         .next()
-        .context("expected patches is empty")?;
-    let expected_output = zeta2_output_for_patch(
-        prompt_inputs,
-        &expected_patch,
-        expected_cursor_offset,
-        zeta_format,
-    )?;
+        .and_then(|(expected_patch, expected_cursor_offset)| {
+            zeta2_output_for_patch(
+                prompt_inputs,
+                &expected_patch,
+                expected_cursor_offset,
+                zeta_format,
+            )
+            .ok()
+        })
+        .unwrap_or_default();
+
     let rejected_output = example.spec.rejected_patch.as_ref().and_then(|patch| {
         zeta2_output_for_patch(prompt_inputs, patch, None, zeta_format).ok()
     });
diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs
index 207a69328fb07277c39463c0c6a460862c95fe42..8bb4b2a8e2f50d448fc314a70e2fc94cfa2c3d71 100644
--- a/crates/edit_prediction_cli/src/main.rs
+++ b/crates/edit_prediction_cli/src/main.rs
@@ -358,6 +358,7 @@ enum PredictionProvider {
     Mercury,
     Zeta1,
     Zeta2(ZetaFormat),
+    Baseten(ZetaFormat),
     Teacher(TeacherBackend),
     TeacherNonBatching(TeacherBackend),
     Repair,
@@ -376,6 +377,7 @@ impl std::fmt::Display for PredictionProvider {
             PredictionProvider::Mercury => write!(f, "mercury"),
             PredictionProvider::Zeta1 => write!(f, "zeta1"),
             PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
+            PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
             PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
             PredictionProvider::TeacherNonBatching(backend) => {
                 write!(f, "teacher-non-batching:{backend}")
@@ -415,6 +417,13 @@ impl std::str::FromStr for PredictionProvider {
                 Ok(PredictionProvider::TeacherNonBatching(backend))
             }
             "repair" => Ok(PredictionProvider::Repair),
+            "baseten" => {
+                let format = arg
+                    .map(ZetaFormat::parse)
+                    .transpose()?
+                    .unwrap_or(ZetaFormat::default());
+                Ok(PredictionProvider::Baseten(format))
+            }
             _ => {
                 anyhow::bail!(
                     "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<format>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs
index 02ba24b8a4f2627b9542254e3d118981737f8318..8f537dc0817a9cb0b4fd74348ae5e43d4f63beb9 100644
--- a/crates/edit_prediction_cli/src/predict.rs
+++ b/crates/edit_prediction_cli/src/predict.rs
@@ -6,14 +6,18 @@ use crate::{
     headless::EpAppState,
     load_project::run_load_project,
     openai_client::OpenAiClient,
+    parse_output::parse_prediction_output,
     paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
-    progress::{ExampleProgress, InfoStyle, Step},
+    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
     retrieve_context::run_context_retrieval,
 };
 use anyhow::Context as _;
+use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
 use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
-use futures::{FutureExt as _, StreamExt as _, future::Shared};
+use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
 use gpui::{AppContext as _, AsyncApp, Task};
+use http_client::{AsyncBody, HttpClient, Method};
+use reqwest_client::ReqwestClient;
 use std::{
     fs,
     sync::{
@@ -79,6 +83,22 @@ pub async fn run_prediction(
         .await;
     }
 
+    if let PredictionProvider::Baseten(format) = provider {
+        run_format_prompt(
+            example,
+            &FormatPromptArgs {
+                provider: PredictionProvider::Zeta2(format),
+            },
+            app_state.clone(),
+            example_progress,
+            cx,
+        )
+        .await?;
+
+        let step_progress = example_progress.start(Step::Predict);
+        return predict_baseten(example, format, &step_progress).await;
+    }
+
     run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
     run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 
@@ -116,7 +136,8 @@ pub async fn run_prediction(
         PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
         PredictionProvider::Teacher(..)
         | PredictionProvider::TeacherNonBatching(..)
-        | PredictionProvider::Repair => {
+        | PredictionProvider::Repair
+        | PredictionProvider::Baseten(_) => {
             unreachable!()
         }
     };
@@ -480,6 +501,89 @@ async fn predict_openai(
     Ok(())
 }
 
+pub async fn predict_baseten(
+    example: &mut Example,
+    format: ZetaFormat,
+    step_progress: &StepProgress,
+) -> anyhow::Result<()> {
+    let model_id =
+        std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable required")?;
+
+    let api_key =
+        std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?;
+
+    let prompt = example.prompt.as_ref().context("Prompt is required")?;
+    let prompt_text = prompt.input.clone();
+    let prefill = prompt.prefill.clone().unwrap_or_default();
+
+    step_progress.set_substatus("running prediction via baseten");
+
+    let environment: String = <&'static str>::from(&format).to_lowercase();
+    let url = format!(
+        "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
+    );
+
+    let request_body = RawCompletionRequest {
+        model: model_id,
+        prompt: prompt_text.clone(),
+        max_tokens: Some(2048),
+        temperature: Some(0.),
+        stop: vec![],
+        environment: None,
+    };
+
+    let body_bytes =
+        serde_json::to_vec(&request_body).context("Failed to serialize request body")?;
+
+    let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
+    let request = http_client::Request::builder()
+        .method(Method::POST)
+        .uri(&url)
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Api-Key {api_key}"))
+        .body(AsyncBody::from(body_bytes))?;
+
+    let mut response = http_client.send(request).await?;
+    let status = response.status();
+
+    let mut body = String::new();
+    response
+        .body_mut()
+        .read_to_string(&mut body)
+        .await
+        .context("Failed to read Baseten response body")?;
+
+    if !status.is_success() {
+        anyhow::bail!("Baseten API returned {status}: {body}");
+    }
+
+    let completion: RawCompletionResponse =
+        serde_json::from_str(&body).context("Failed to parse Baseten response")?;
+
+    let actual_output = completion
+        .choices
+        .into_iter()
+        .next()
+        .map(|choice| choice.text)
+        .unwrap_or_default();
+
+    let actual_output = format!("{prefill}{actual_output}");
+
+    let (actual_patch, actual_cursor) =
+        parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;
+
+    let prediction = ExamplePrediction {
+        actual_patch: Some(actual_patch),
+        actual_output,
+        actual_cursor,
+        error: None,
+        provider: PredictionProvider::Baseten(format),
+    };
+
+    example.predictions.push(prediction);
+    Ok(())
+}
+
 pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
     match provider {
         Some(PredictionProvider::Teacher(backend)) => match backend {
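
For reference, a minimal invocation sketch, assuming the environment that `predict_baseten` reads; the values below are placeholders:

    # predict_baseten requires both variables: the model id is interpolated into
    # the https://model-{model_id}.api.baseten.co/... endpoint URL, and the key
    # is sent as an `Authorization: Api-Key ...` header.
    export ZED_ZETA_MODEL=...
    export BASETEN_API_KEY=...
    ep predict --provider baseten:V0131GitMergeMarkersPrefix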