@@ -53,18 +53,22 @@ pub async fn run_format_prompt(
let prompt = format_zeta_prompt(prompt_inputs, zeta_format);
let prefill = zeta_prompt::get_prefill(prompt_inputs, zeta_format);
- let (expected_patch, expected_cursor_offset) = example
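+    // Missing or unconvertible expected patches fall back to the default output rather than failing.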
+ let expected_output = example
.spec
.expected_patches_with_cursor_positions()
.into_iter()
.next()
- .context("expected patches is empty")?;
- let expected_output = zeta2_output_for_patch(
- prompt_inputs,
- &expected_patch,
- expected_cursor_offset,
- zeta_format,
- )?;
+ .and_then(|(expected_patch, expected_cursor_offset)| {
+ zeta2_output_for_patch(
+ prompt_inputs,
+ &expected_patch,
+ expected_cursor_offset,
+ zeta_format,
+ )
+ .ok()
+ })
+ .unwrap_or_default();
+
let rejected_output = example.spec.rejected_patch.as_ref().and_then(|patch| {
zeta2_output_for_patch(prompt_inputs, patch, None, zeta_format).ok()
});
@@ -358,6 +358,7 @@ enum PredictionProvider {
Mercury,
Zeta1,
Zeta2(ZetaFormat),
+ Baseten(ZetaFormat),
Teacher(TeacherBackend),
TeacherNonBatching(TeacherBackend),
Repair,
@@ -376,6 +377,7 @@ impl std::fmt::Display for PredictionProvider {
PredictionProvider::Mercury => write!(f, "mercury"),
PredictionProvider::Zeta1 => write!(f, "zeta1"),
PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
+ PredictionProvider::Baseten(format) => write!(f, "baseten:{format}"),
PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
PredictionProvider::TeacherNonBatching(backend) => {
write!(f, "teacher-non-batching:{backend}")
@@ -415,6 +417,13 @@ impl std::str::FromStr for PredictionProvider {
Ok(PredictionProvider::TeacherNonBatching(backend))
}
"repair" => Ok(PredictionProvider::Repair),
+ "baseten" => {
+ let format = arg
+ .map(ZetaFormat::parse)
+ .transpose()?
+                    .unwrap_or_default();
+ Ok(PredictionProvider::Baseten(format))
+ }
_ => {
anyhow::bail!(
"unknown provider `{provider}`. Valid options: sweep, mercury, zeta1, zeta2, zeta2:<version>, teacher, teacher:<backend>, teacher-non-batching, repair\n\
@@ -6,14 +6,18 @@ use crate::{
headless::EpAppState,
load_project::run_load_project,
openai_client::OpenAiClient,
+ parse_output::parse_prediction_output,
paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
- progress::{ExampleProgress, InfoStyle, Step},
+ progress::{ExampleProgress, InfoStyle, Step, StepProgress},
retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
+use cloud_llm_client::predict_edits_v3::{RawCompletionRequest, RawCompletionResponse};
use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
-use futures::{FutureExt as _, StreamExt as _, future::Shared};
+use futures::{AsyncReadExt as _, FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
+use http_client::{AsyncBody, HttpClient, Method};
+use reqwest_client::ReqwestClient;
use std::{
fs,
sync::{
@@ -79,6 +83,22 @@ pub async fn run_prediction(
.await;
}
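+    // Baseten hosts Zeta models behind an HTTP API: format the prompt exactly as
+    // Zeta2 would, call the hosted model, and return before the local pipeline below.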
+ if let PredictionProvider::Baseten(format) = provider {
+ run_format_prompt(
+ example,
+ &FormatPromptArgs {
+ provider: PredictionProvider::Zeta2(format),
+ },
+ app_state.clone(),
+ example_progress,
+ cx,
+ )
+ .await?;
+
+ let step_progress = example_progress.start(Step::Predict);
+ return predict_baseten(example, format, &step_progress).await;
+ }
+
run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
@@ -116,7 +136,8 @@ pub async fn run_prediction(
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher(..)
| PredictionProvider::TeacherNonBatching(..)
- | PredictionProvider::Repair => {
+ | PredictionProvider::Repair
+ | PredictionProvider::Baseten(_) => {
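+            // These providers are dispatched earlier and never reach model selection.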
unreachable!()
}
};
@@ -480,6 +501,89 @@ async fn predict_openai(
Ok(())
}
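+/// Requests a completion for the example's formatted prompt from a
+/// Baseten-hosted Zeta model, then parses and records the prediction.
+/// Requires the `ZED_ZETA_MODEL` and `BASETEN_API_KEY` environment variables.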
+pub async fn predict_baseten(
+ example: &mut Example,
+ format: ZetaFormat,
+ step_progress: &StepProgress,
+) -> anyhow::Result<()> {
+ let model_id =
+ std::env::var("ZED_ZETA_MODEL").context("ZED_ZETA_MODEL environment variable required")?;
+
+ let api_key =
+ std::env::var("BASETEN_API_KEY").context("BASETEN_API_KEY environment variable not set")?;
+
+ let prompt = example.prompt.as_ref().context("Prompt is required")?;
+ let prompt_text = prompt.input.clone();
+ let prefill = prompt.prefill.clone().unwrap_or_default();
+
+ step_progress.set_substatus("running prediction via baseten");
+
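+    // Baseten addresses deployments per environment; derive the environment name
+    // from the Zeta format and target its synchronous completions endpoint.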
+    let environment = <&'static str>::from(&format).to_lowercase();
+ let url = format!(
+ "https://model-{model_id}.api.baseten.co/environments/{environment}/sync/v1/completions"
+ );
+
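+    // Deterministic decoding: zero temperature, a 2048-token cap, and no stop sequences.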
+ let request_body = RawCompletionRequest {
+ model: model_id,
+        prompt: prompt_text,
+ max_tokens: Some(2048),
+ temperature: Some(0.),
+ stop: vec![],
+ environment: None,
+ };
+
+ let body_bytes =
+ serde_json::to_vec(&request_body).context("Failed to serialize request body")?;
+
+ let http_client: Arc<dyn HttpClient> = Arc::new(ReqwestClient::new());
+ let request = http_client::Request::builder()
+ .method(Method::POST)
+ .uri(&url)
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Api-Key {api_key}"))
+ .body(AsyncBody::from(body_bytes))?;
+
+ let mut response = http_client.send(request).await?;
+ let status = response.status();
+
+ let mut body = String::new();
+ response
+ .body_mut()
+ .read_to_string(&mut body)
+ .await
+ .context("Failed to read Baseten response body")?;
+
+ if !status.is_success() {
+ anyhow::bail!("Baseten API returned {status}: {body}");
+ }
+
+ let completion: RawCompletionResponse =
+ serde_json::from_str(&body).context("Failed to parse Baseten response")?;
+
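+    // Take the first choice's text; a response with no choices yields an empty output.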
+ let actual_output = completion
+ .choices
+ .into_iter()
+ .next()
+ .map(|choice| choice.text)
+ .unwrap_or_default();
+
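+    // Re-attach the prefill so the parser sees the full output, not just the continuation.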
+ let actual_output = format!("{prefill}{actual_output}");
+
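+    // The completion uses the Zeta format, so reuse the Zeta2 output parser.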
+ let (actual_patch, actual_cursor) =
+ parse_prediction_output(example, &actual_output, PredictionProvider::Zeta2(format))?;
+
+ let prediction = ExamplePrediction {
+ actual_patch: Some(actual_patch),
+ actual_output,
+ actual_cursor,
+ error: None,
+ provider: PredictionProvider::Baseten(format),
+ };
+
+ example.predictions.push(prediction);
+ Ok(())
+}
+
pub async fn sync_batches(provider: Option<&PredictionProvider>) -> anyhow::Result<()> {
match provider {
Some(PredictionProvider::Teacher(backend)) => match backend {