ep: Predict by querying Baseten directly (#50626)

Oleksiy Syvokon created

This can be used like `ep predict --provider
baseten:V0131GitMergeMarkersPrefix`. Since it doesn't require
load_project, it can be used with captured requests.


Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/format_prompt.rs |  20 ++-
crates/edit_prediction_cli/src/main.rs          |   9 +
crates/edit_prediction_cli/src/predict.rs       | 110 ++++++++++++++++++
3 files changed, 128 insertions(+), 11 deletions(-)

Detailed changes

crates/edit_prediction_cli/src/format_prompt.rs 🔗

@@ -53,18 +53,22 @@ pub async fn run_format_prompt(
 
             let prompt = format_zeta_prompt(prompt_inputs, zeta_format);
             let prefill = zeta_prompt::get_prefill(prompt_inputs, zeta_format);
-            let (expected_patch, expected_cursor_offset) = example
+            let expected_output = example
                 .spec
                 .expected_patches_with_cursor_positions()
                 .into_iter()
                 .next()
-                .context("expected patches is empty")?;
-            let expected_output = zeta2_output_for_patch(
-                prompt_inputs,
-                &expected_patch,
-                expected_cursor_offset,
-                zeta_format,
-            )?;
+                .and_then(|(expected_patch, expected_cursor_offset)| {
+                    zeta2_output_for_patch(
+                        prompt_inputs,
+                        &expected_patch,
+                        expected_cursor_offset,
+                        zeta_format,
+                    )
+                    .ok()
+                })
+                .unwrap_or_default();
+
             let rejected_output = example.spec.rejected_patch.as_ref().and_then(|patch| {
                 zeta2_output_for_patch(prompt_inputs, patch, None, zeta_format).ok()
             });

crates/edit_prediction_cli/src/main.rs 🔗

@@ -358,6 +358,7 @@ enum PredictionProvider {
     Mercury,
     Zeta1,
     Zeta2(ZetaFormat),
+    Baseten(ZetaFormat),
     Teacher(TeacherBackend),
     TeacherNonBatching(TeacherBackend),
     Repair,
@@ -376,6 +377,7 @@ impl std::fmt::Display for PredictionProvider {
             PredictionProvider::Mercury => write!(f, "mercury"),
             PredictionProvider::Zeta1 => write!(f, "zeta1"),
             PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
+            PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
             PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
             PredictionProvider::TeacherNonBatching(backend) => {
                 write!(f, "teacher-non-batching:{backend}")
@@ -415,6 +417,13 @@ impl std::str::FromStr for PredictionProvider {
                 Ok(PredictionProvider::TeacherNonBatching(backend))
             }
             "repair" => Ok(PredictionProvider::Repair),
+            "baseten" => {
+                let format = arg
+                    .map(ZetaFormat::parse)
+                    .transpose()?
+                    .unwrap_or(ZetaFormat::default());
+                Ok(PredictionProvider::Baseten(format))
+            }
             _ => {
                 anyhow::bail!(
                     "unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\

crates/edit_prediction_cli/src/predict.rs 🔗

@@ -6,14 +6,18 @@ use crate::{
     headless::EpAppState,
     load_project::run_load_project,
     openai_client::OpenAiClient,
+    parse_output::parse_prediction_output,
     paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
-    progress::{ExampleProgress, InfoStyle, Step},
+    progress::{ExampleProgress, InfoStyle, Step, StepProgress},
     retrieve_context::run_context_retrieval,
 };
 use anyhow::Context as _;
+use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
 use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
-use futures::{FutureExt as _, StreamExt as _, future::Shared};
+use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
 use gpui::{AppContext as _, AsyncApp, Task};
+use http_client::{AsyncBody, HttpClient, Method};
+use reqwest_client::ReqwestClient;
 use std::{
     fs,
     sync::{
@@ -79,6 +83,22 @@ pub async fn run_prediction(
         .await;
     }
 
+    if let PredictionProvider::Baseten(format) = provider {
+        run_format_prompt(
+            example,
+            &FormatPromptArgs {
+                provider: PredictionProvider::Zeta2(format),
+            },
+            app_state.clone(),
+            example_progress,
+            cx,
+        )
+        .await?;
+
+        let step_progress = example_progress.start(Step::Predict);
+        return predict_baseten(example, format, &step_progress).await;
+    }
+
     run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
     run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 
@@ -116,7 +136,8 @@ pub async fn run_prediction(
             PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
             PredictionProvider::Teacher(..)
             | PredictionProvider::TeacherNonBatching(..)
-            | PredictionProvider::Repair => {
+            | PredictionProvider::Repair
+            | PredictionProvider::Baseten(_) => {
                 unreachable!()
             }
         };
@@ -480,6 +501,89 @@ async fn predict_openai(
     Ok(())
 }
 
+pub async fn predict_baseten(
+    example: &mut Example,
+    format: ZetaFormat,
+    step_progress: &StepProgress,
+) -> anyhow::Result<()> {
+    let model_id =
+        std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable required")?;
+
+    let api_key =
+        std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?;
+
+    let prompt = example.prompt.as_ref().context("Prompt is required")?;
+    let prompt_text = prompt.input.clone();
+    let prefill = prompt.prefill.clone().unwrap_or_default();
+
+    step_progress.set_substatus("running prediction via baseten");
+
+    let environment: String = <&'static str>::from(&format).to_lowercase();
+    let url = format!(
+        "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
+    );
+
+    let request_body = RawCompletionRequest {
+        model: model_id,
+        prompt: prompt_text.clone(),
+        max_tokens: Some(2048),
+        temperature: Some(0.),
+        stop: vec![],
+        environment: None,
+    };
+
+    let body_bytes =
+        serde_json::to_vec(&request_body).context("Failed to serialize request body")?;
+
+    let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
+    let request = http_client::Request::builder()
+        .method(Method::POST)
+        .uri(&url)
+        .header("Content-Type", "application/json")
+        .header("Authorization", format!("Api-Key {api_key}"))
+        .body(AsyncBody::from(body_bytes))?;
+
+    let mut response = http_client.send(request).await?;
+    let status = response.status();
+
+    let mut body = String::new();
+    response
+        .body_mut()
+        .read_to_string(&mut body)
+        .await
+        .context("Failed to read Baseten response body")?;
+
+    if !status.is_success() {
+        anyhow::bail!("Baseten API returned {status}: {body}");
+    }
+
+    let completion: RawCompletionResponse =
+        serde_json::from_str(&body).context("Failed to parse Baseten response")?;
+
+    let actual_output = completion
+        .choices
+        .into_iter()
+        .next()
+        .map(|choice| choice.text)
+        .unwrap_or_default();
+
+    let actual_output = format!("{prefill}{actual_output}");
+
+    let (actual_patch, actual_cursor) =
+        parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;
+
+    let prediction = ExamplePrediction {
+        actual_patch: Some(actual_patch),
+        actual_output,
+        actual_cursor,
+        error: None,
+        provider: PredictionProvider::Baseten(format),
+    };
+
+    example.predictions.push(prediction);
+    Ok(())
+}
+
 pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
     match provider {
         Some(PredictionProvider::Teacher(backend)) => match backend {