Cargo.lock 🔗
@@ -5398,6 +5398,7 @@ dependencies = [
"smol",
"sqlez",
"sqlez_macros",
+ "strum 0.27.2",
"telemetry_events",
"tempfile",
"terminal_view",
Max Brunsfeld created
This allows us to switch the prompt format without client-side changes.
If we want to experiment with prompt formats or models other than the
currently-deployed one, we can use the raw endpoint, and do prompt
construction and output processing on the client.
This also adds an optional environment parameter to the raw endpoint, so
that we can use that endpoint in the new scheme where we're deploying to
separate environments for different zeta prompt versions.
Release Notes:
- N/A
Cargo.lock | 1
crates/cloud_llm_client/src/predict_edits_v3.rs | 6
crates/edit_prediction/src/edit_prediction.rs | 100 ++---
crates/edit_prediction/src/edit_prediction_tests.rs | 126 -------
crates/edit_prediction/src/zed_edit_prediction_delegate.rs | 2
crates/edit_prediction/src/zeta1.rs | 19
crates/edit_prediction/src/zeta2.rs | 72 ++--
crates/edit_prediction_cli/Cargo.toml | 1
crates/edit_prediction_cli/src/format_prompt.rs | 8
crates/edit_prediction_cli/src/main.rs | 29 +
crates/edit_prediction_cli/src/parse_output.rs | 32 +-
crates/edit_prediction_cli/src/predict.rs | 16
crates/zed/src/zed/edit_prediction_registry.rs | 4
crates/zeta_prompt/src/zeta_prompt.rs | 51 ++-
14 files changed, 193 insertions(+), 274 deletions(-)
@@ -5398,6 +5398,7 @@ dependencies = [
"smol",
"sqlez",
"sqlez_macros",
+ "strum 0.27.2",
"telemetry_events",
"tempfile",
"terminal_view",
@@ -11,16 +11,14 @@ pub struct RawCompletionRequest {
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
pub stop: Vec<Cow<'static, str>>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub environment: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct PredictEditsV3Request {
#[serde(flatten)]
pub input: zeta_prompt::ZetaPromptInput,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub model: Option<String>,
- #[serde(default)]
- pub prompt_version: zeta_prompt::ZetaVersion,
#[serde(default)]
pub trigger: PredictEditsRequestTrigger,
}
@@ -36,18 +36,18 @@ use semver::Version;
use serde::de::DeserializeOwned;
use settings::{EditPredictionProvider, Settings as _, update_settings_file};
use std::collections::{VecDeque, hash_map};
+use std::env;
use text::Edit;
use workspace::Workspace;
-use zeta_prompt::ZetaPromptInput;
-use zeta_prompt::ZetaVersion;
+use zeta_prompt::{ZetaFormat, ZetaPromptInput};
+use std::mem;
use std::ops::Range;
use std::path::Path;
use std::rc::Rc;
use std::str::FromStr as _;
-use std::sync::{Arc, LazyLock};
+use std::sync::Arc;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
-use std::{env, mem};
use thiserror::Error;
use util::{RangeExt as _, ResultExt as _};
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
@@ -105,9 +105,6 @@ const LAST_CHANGE_GROUPING_TIME: Duration = Duration::from_secs(1);
const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_choice";
const REJECT_REQUEST_DEBOUNCE: Duration = Duration::from_secs(15);
-static EDIT_PREDICTIONS_MODEL_ID: LazyLock<Option<String>> =
- LazyLock::new(|| env::var("ZED_ZETA_MODEL").ok());
-
pub struct Zeta2FeatureFlag;
impl FeatureFlag for Zeta2FeatureFlag {
@@ -133,6 +130,15 @@ struct EditPredictionStoreGlobal(Entity<EditPredictionStore>);
impl Global for EditPredictionStoreGlobal {}
+/// Configuration for using the raw Zeta2 endpoint.
+/// When set, the client uses the raw endpoint and constructs the prompt itself.
+/// The version is also used as the Baseten environment name (lowercased).
+#[derive(Clone)]
+pub struct Zeta2RawConfig {
+ pub model_id: Option<String>,
+ pub format: ZetaFormat,
+}
+
pub struct EditPredictionStore {
client: Arc<Client>,
user_store: Entity<UserStore>,
@@ -141,6 +147,7 @@ pub struct EditPredictionStore {
projects: HashMap<EntityId, ProjectState>,
update_required: bool,
edit_prediction_model: EditPredictionModel,
+ zeta2_raw_config: Option<Zeta2RawConfig>,
pub sweep_ai: SweepAi,
pub mercury: Mercury,
pub ollama: Ollama,
@@ -148,16 +155,13 @@ pub struct EditPredictionStore {
reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
shown_predictions: VecDeque<EditPrediction>,
rated_predictions: HashSet<EditPredictionId>,
- custom_predict_edits_url: Option<Arc<Url>>,
}
#[derive(Copy, Clone, Default, PartialEq, Eq)]
pub enum EditPredictionModel {
#[default]
Zeta1,
- Zeta2 {
- version: ZetaVersion,
- },
+ Zeta2,
Sweep,
Mercury,
Ollama,
@@ -631,9 +635,8 @@ impl EditPredictionStore {
},
),
update_required: false,
- edit_prediction_model: EditPredictionModel::Zeta2 {
- version: Default::default(),
- },
+ edit_prediction_model: EditPredictionModel::Zeta2,
+ zeta2_raw_config: Self::zeta2_raw_config_from_env(),
sweep_ai: SweepAi::new(cx),
mercury: Mercury::new(cx),
ollama: Ollama::new(),
@@ -642,24 +645,30 @@ impl EditPredictionStore {
reject_predictions_tx: reject_tx,
rated_predictions: Default::default(),
shown_predictions: Default::default(),
- custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
- Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
- Err(_) => None,
- },
};
this
}
- #[cfg(test)]
- pub fn set_custom_predict_edits_url(&mut self, url: Url) {
- self.custom_predict_edits_url = Some(url.into());
+ fn zeta2_raw_config_from_env() -> Option<Zeta2RawConfig> {
+ let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
+ let format = ZetaFormat::parse(&version_str).ok()?;
+ let model_id = env::var("ZED_ZETA_MODEL").ok();
+ Some(Zeta2RawConfig { model_id, format })
}
pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
self.edit_prediction_model = model;
}
+ pub fn set_zeta2_raw_config(&mut self, config: Zeta2RawConfig) {
+ self.zeta2_raw_config = Some(config);
+ }
+
+ pub fn zeta2_raw_config(&self) -> Option<&Zeta2RawConfig> {
+ self.zeta2_raw_config.as_ref()
+ }
+
pub fn icons(&self) -> edit_prediction_types::EditPredictionIconSet {
use ui::IconName;
match self.edit_prediction_model {
@@ -673,7 +682,7 @@ impl EditPredictionStore {
EditPredictionModel::Mercury => {
edit_prediction_types::EditPredictionIconSet::new(IconName::Inception)
}
- EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
+ EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
edit_prediction_types::EditPredictionIconSet::new(IconName::ZedPredict)
.with_disabled(IconName::ZedPredictDisabled)
.with_up(IconName::ZedPredictUp)
@@ -796,10 +805,7 @@ impl EditPredictionStore {
}
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
- if matches!(
- self.edit_prediction_model,
- EditPredictionModel::Zeta2 { .. }
- ) {
+ if matches!(self.edit_prediction_model, EditPredictionModel::Zeta2) {
self.user_store.read(cx).edit_prediction_usage()
} else {
None
@@ -1223,7 +1229,7 @@ impl EditPredictionStore {
);
}
EditPredictionModel::Ollama => {}
- EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
+ EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
zeta2::edit_prediction_accepted(self, current_prediction, cx)
}
}
@@ -1359,16 +1365,14 @@ impl EditPredictionStore {
cx: &App,
) {
match self.edit_prediction_model {
- EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
- if self.custom_predict_edits_url.is_none() {
- self.reject_predictions_tx
- .unbounded_send(EditPredictionRejection {
- request_id: prediction_id.to_string(),
- reason,
- was_shown,
- })
- .log_err();
- }
+ EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
+ self.reject_predictions_tx
+ .unbounded_send(EditPredictionRejection {
+ request_id: prediction_id.to_string(),
+ reason,
+ was_shown,
+ })
+ .log_err();
}
EditPredictionModel::Sweep | EditPredictionModel::Ollama => {}
EditPredictionModel::Mercury => {
@@ -1805,24 +1809,16 @@ impl EditPredictionStore {
.detach_and_log_err(cx);
}
}
- let task = match self.edit_prediction_model {
+ let task = match &self.edit_prediction_model {
EditPredictionModel::Zeta1 => {
if should_send_testing_zeta2_request() {
let mut zeta2_inputs = inputs.clone();
zeta2_inputs.trigger = PredictEditsRequestTrigger::Testing;
- zeta2::request_prediction_with_zeta2(
- self,
- zeta2_inputs,
- Default::default(),
- cx,
- )
- .detach();
+ zeta2::request_prediction_with_zeta2(self, zeta2_inputs, cx).detach();
}
zeta1::request_prediction_with_zeta1(self, inputs, cx)
}
- EditPredictionModel::Zeta2 { version } => {
- zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
- }
+ EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
EditPredictionModel::Ollama => self.ollama.request_prediction(inputs, cx),
@@ -1976,7 +1972,6 @@ impl EditPredictionStore {
pub(crate) async fn send_v3_request(
input: ZetaPromptInput,
- prompt_version: ZetaVersion,
client: Arc<Client>,
llm_token: LlmApiToken,
app_version: Version,
@@ -1986,12 +1981,7 @@ impl EditPredictionStore {
.http_client()
.build_zed_llm_url("/predict_edits/v3", &[])?;
- let request = PredictEditsV3Request {
- input,
- model: EDIT_PREDICTIONS_MODEL_ID.clone(),
- prompt_version,
- trigger,
- };
+ let request = PredictEditsV3Request { input, trigger };
Self::send_api_request(
|builder| {
@@ -1343,7 +1343,7 @@ fn model_response(request: &PredictEditsV3Request, diff_to_apply: &str) -> Predi
}
fn prompt_from_request(request: &PredictEditsV3Request) -> String {
- zeta_prompt::format_zeta_prompt(&request.input, request.prompt_version)
+ zeta_prompt::format_zeta_prompt(&request.input, zeta_prompt::ZetaFormat::default())
}
struct RequestChannels {
@@ -2073,6 +2073,20 @@ async fn make_test_ep_store(
)
.unwrap())
}
+ (&Method::POST, "/predict_edits/v3") => {
+ next_request_id += 1;
+ Ok(http_client::Response::builder()
+ .status(200)
+ .body(
+ serde_json::to_string(&PredictEditsV3Response {
+ request_id: format!("request-{next_request_id}"),
+ output: "hello world".to_string(),
+ })
+ .unwrap()
+ .into(),
+ )
+ .unwrap())
+ }
_ => Ok(http_client::Response::builder()
.status(404)
.body("Not Found".into())
@@ -2200,116 +2214,6 @@ async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut
);
}
-#[gpui::test]
-async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
- init_test(cx);
-
- let fs = FakeFs::new(cx.executor());
- fs.insert_tree(
- "/project",
- serde_json::json!({
- "main.rs": "fn main() {\n \n}\n"
- }),
- )
- .await;
-
- let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
-
- let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
- let predict_called_clone = predict_called.clone();
-
- let http_client = FakeHttpClient::create({
- move |req| {
- let uri = req.uri().path().to_string();
- let predict_called = predict_called_clone.clone();
- async move {
- if uri.contains("predict") {
- predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
- Ok(gpui::http_client::Response::builder()
- .body(
- serde_json::to_string(&open_ai::Response {
- id: "test-123".to_string(),
- object: "chat.completion".to_string(),
- created: 0,
- model: "test".to_string(),
- usage: open_ai::Usage {
- prompt_tokens: 0,
- completion_tokens: 0,
- total_tokens: 0,
- },
- choices: vec![open_ai::Choice {
- index: 0,
- message: open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Plain(
- indoc! {"
- ```main.rs
- <|start_of_file|>
- <|editable_region_start|>
- fn main() {
- println!(\"Hello, world!\");
- }
- <|editable_region_end|>
- ```
- "}
- .to_string(),
- )),
- tool_calls: vec![],
- },
- finish_reason: Some("stop".to_string()),
- }],
- })
- .unwrap()
- .into(),
- )
- .unwrap())
- } else {
- Ok(gpui::http_client::Response::builder()
- .status(401)
- .body("Unauthorized".into())
- .unwrap())
- }
- }
- }
- });
-
- let client =
- cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
- cx.update(|cx| {
- language_model::RefreshLlmTokenListener::register(client.clone(), cx);
- });
-
- let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
-
- let buffer = project
- .update(cx, |project, cx| {
- let path = project
- .find_project_path(path!("/project/main.rs"), cx)
- .unwrap();
- project.open_buffer(path, cx)
- })
- .await
- .unwrap();
-
- let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
- ep_store.update(cx, |ep_store, cx| {
- ep_store.register_buffer(&buffer, &project, cx)
- });
- cx.background_executor.run_until_parked();
-
- let completion_task = ep_store.update(cx, |ep_store, cx| {
- ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
- ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
- ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
- });
-
- let _ = completion_task.await;
-
- assert!(
- predict_called.load(std::sync::atomic::Ordering::SeqCst),
- "With custom URL, predict endpoint should be called even without authentication"
- );
-}
-
#[gpui::test]
fn test_compute_diff_between_snapshots(cx: &mut TestAppContext) {
let buffer = cx.new(|cx| {
@@ -70,7 +70,7 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
.with_down(IconName::SweepAiDown)
.with_error(IconName::SweepAiError),
EditPredictionModel::Mercury => EditPredictionIconSet::new(IconName::Inception),
- EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
+ EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
EditPredictionIconSet::new(IconName::ZedPredict)
.with_disabled(IconName::ZedPredictDisabled)
.with_up(IconName::ZedPredictUp)
@@ -81,17 +81,12 @@ pub(crate) fn request_prediction_with_zeta1(
cx,
);
- let (uri, require_auth) = match &store.custom_predict_edits_url {
- Some(custom_url) => (custom_url.clone(), false),
- None => {
- match client
- .http_client()
- .build_zed_llm_url("/predict_edits/v2", &[])
- {
- Ok(url) => (url.into(), true),
- Err(err) => return Task::ready(Err(err)),
- }
- }
+ let uri = match client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/v2", &[])
+ {
+ Ok(url) => Arc::from(url),
+ Err(err) => return Task::ready(Err(err)),
};
cx.spawn(async move |this, cx| {
@@ -127,7 +122,7 @@ pub(crate) fn request_prediction_with_zeta1(
client,
llm_token,
app_version,
- require_auth,
+ true,
)
.await;
@@ -1,9 +1,8 @@
use crate::prediction::EditPredictionResult;
use crate::zeta1::compute_edits_and_cursor_position;
use crate::{
- CurrentEditPrediction, DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent,
- EditPredictionId, EditPredictionModelInput, EditPredictionStartedDebugEvent,
- EditPredictionStore,
+ CurrentEditPrediction, DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId,
+ EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
};
use anyhow::{Result, anyhow};
use cloud_llm_client::predict_edits_v3::RawCompletionRequest;
@@ -14,17 +13,16 @@ use release_channel::AppVersion;
use std::env;
use std::{path::Path, sync::Arc, time::Instant};
-use zeta_prompt::format_zeta_prompt;
-use zeta_prompt::{CURSOR_MARKER, ZetaVersion, v0120_git_merge_markers};
+use zeta_prompt::{CURSOR_MARKER, ZetaFormat, clean_zeta2_model_output, format_zeta_prompt};
pub const MAX_CONTEXT_TOKENS: usize = 350;
-pub fn max_editable_tokens(version: ZetaVersion) -> usize {
- match version {
- ZetaVersion::V0112MiddleAtEnd | ZetaVersion::V0113Ordered => 150,
- ZetaVersion::V0114180EditableRegion => 180,
- ZetaVersion::V0120GitMergeMarkers => 180,
- ZetaVersion::V0131GitMergeMarkersPrefix => 180,
+pub fn max_editable_tokens(format: ZetaFormat) -> usize {
+ match format {
+ ZetaFormat::V0112MiddleAtEnd | ZetaFormat::V0113Ordered => 150,
+ ZetaFormat::V0114180EditableRegion => 180,
+ ZetaFormat::V0120GitMergeMarkers => 180,
+ ZetaFormat::V0131GitMergeMarkersPrefix => 180,
}
}
@@ -40,11 +38,10 @@ pub fn request_prediction_with_zeta2(
trigger,
..
}: EditPredictionModelInput,
- zeta_version: ZetaVersion,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
let buffer_snapshotted_at = Instant::now();
- let custom_url = store.custom_predict_edits_url.clone();
+ let raw_config = store.zeta2_raw_config().cloned();
let Some(excerpt_path) = snapshot
.file()
@@ -59,6 +56,11 @@ pub fn request_prediction_with_zeta2(
let request_task = cx.background_spawn({
async move {
+ let zeta_version = raw_config
+ .as_ref()
+ .map(|config| config.format)
+ .unwrap_or(ZetaFormat::default());
+
let cursor_offset = position.to_offset(&snapshot);
let (editable_offset_range, prompt_input) = zeta2_prompt_input(
&snapshot,
@@ -84,33 +86,36 @@ pub fn request_prediction_with_zeta2(
log::trace!("Sending edit prediction request");
- let (request_id, output_text, usage) = if let Some(custom_url) = custom_url {
- // Use raw endpoint with custom URL
- let prompt = format_zeta_prompt(&prompt_input, zeta_version);
+ let (request_id, output_text, usage) = if let Some(config) = &raw_config {
+ let prompt = format_zeta_prompt(&prompt_input, config.format);
let request = RawCompletionRequest {
- model: EDIT_PREDICTIONS_MODEL_ID.clone().unwrap_or_default(),
+ model: config.model_id.clone().unwrap_or_default(),
prompt,
temperature: None,
stop: vec![],
max_tokens: Some(2048),
+ environment: Some(config.format.to_string().to_lowercase()),
};
let (mut response, usage) = EditPredictionStore::send_raw_llm_request(
request,
client,
- Some(custom_url),
+ None,
llm_token,
app_version,
)
.await?;
let request_id = EditPredictionId(response.id.clone().into());
- let output_text = response.choices.pop().map(|choice| choice.text);
+ let output_text = response.choices.pop().map(|choice| {
+ clean_zeta2_model_output(&choice.text, config.format).to_string()
+ });
+
(request_id, output_text, usage)
} else {
+ // Use V3 endpoint - server handles model/version selection and suffix stripping
let (response, usage) = EditPredictionStore::send_v3_request(
prompt_input.clone(),
- zeta_version,
client,
llm_token,
app_version,
@@ -135,6 +140,13 @@ pub fn request_prediction_with_zeta2(
return Ok((Some((request_id, None)), usage));
};
+ // Client-side cursor marker processing (applies to both raw and v3 responses)
+ let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
+ if let Some(offset) = cursor_offset_in_output {
+ log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
+ output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
+ }
+
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionFinished(
@@ -147,20 +159,6 @@ pub fn request_prediction_with_zeta2(
.ok();
}
- let cursor_offset_in_output = output_text.find(CURSOR_MARKER);
- if let Some(offset) = cursor_offset_in_output {
- log::trace!("Stripping out {CURSOR_MARKER} from response at offset {offset}");
- output_text.replace_range(offset..offset + CURSOR_MARKER.len(), "");
- }
-
- if zeta_version == ZetaVersion::V0120GitMergeMarkers {
- if let Some(stripped) =
- output_text.strip_suffix(v0120_git_merge_markers::END_MARKER)
- {
- output_text = stripped.to_string();
- }
- }
-
let mut old_text = snapshot
.text_for_range(editable_offset_range.clone())
.collect::<String>();
@@ -242,7 +240,7 @@ pub fn zeta2_prompt_input(
events: Vec<Arc<zeta_prompt::Event>>,
excerpt_path: Arc<Path>,
cursor_offset: usize,
- zeta_version: ZetaVersion,
+ zeta_format: ZetaFormat,
) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
let cursor_point = cursor_offset.to_point(snapshot);
@@ -250,7 +248,7 @@ pub fn zeta2_prompt_input(
crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
cursor_point,
snapshot,
- max_editable_tokens(zeta_version),
+ max_editable_tokens(zeta_format),
MAX_CONTEXT_TOKENS,
);
@@ -288,7 +286,7 @@ pub(crate) fn edit_prediction_accepted(
cx: &App,
) {
let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
- if store.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
+ if store.zeta2_raw_config().is_some() && custom_accept_url.is_none() {
return;
}
@@ -50,6 +50,7 @@ settings.workspace = true
shellexpand.workspace = true
smol.workspace = true
sqlez.workspace = true
+strum.workspace = true
sqlez_macros.workspace = true
terminal_view.workspace = true
util.workspace = true
@@ -12,7 +12,7 @@ use language::{Buffer, OffsetRangeExt, Point};
use similar::DiffableStr;
use std::sync::Arc;
use std::{fmt::Write as _, ops::Range};
-use zeta_prompt::ZetaVersion;
+use zeta_prompt::ZetaFormat;
use zeta_prompt::format_zeta_prompt;
pub async fn run_format_prompt(
@@ -54,7 +54,7 @@ pub async fn run_format_prompt(
let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
cursor_point,
&snapshot,
- edit_prediction::zeta2::max_editable_tokens(ZetaVersion::default()),
+ edit_prediction::zeta2::max_editable_tokens(ZetaFormat::default()),
edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
);
let editable_range = editable_range.to_offset(&snapshot);
@@ -126,7 +126,7 @@ pub fn zeta2_output_for_patch(
input: &zeta_prompt::ZetaPromptInput,
patch: &str,
cursor_offset: Option<usize>,
- version: ZetaVersion,
+ version: ZetaFormat,
) -> Result<String> {
let mut old_editable_region =
input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
@@ -155,7 +155,7 @@ pub fn zeta2_output_for_patch(
}
match version {
- ZetaVersion::V0120GitMergeMarkers | ZetaVersion::V0131GitMergeMarkersPrefix => {
+ ZetaFormat::V0120GitMergeMarkers | ZetaFormat::V0131GitMergeMarkersPrefix => {
if !result.ends_with('\n') {
result.push('\n');
}
@@ -31,7 +31,7 @@ use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application, BackgroundExecutor, Task};
-use zeta_prompt::ZetaVersion;
+use zeta_prompt::ZetaFormat;
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
@@ -207,6 +207,8 @@ enum Command {
Qa(qa::QaArgs),
/// Repair predictions that received poor QA scores by generating improved predictions
Repair(repair::RepairArgs),
+ /// Print all valid zeta formats (lowercase, one per line)
+ PrintZetaFormats,
}
impl Display for Command {
@@ -249,6 +251,9 @@ impl Display for Command {
Command::Repair(_) => {
write!(f, "repair")
}
+ Command::PrintZetaFormats => {
+ write!(f, "print-zeta-formats")
+ }
}
}
}
@@ -321,7 +326,7 @@ enum PredictionProvider {
Sweep,
Mercury,
Zeta1,
- Zeta2(ZetaVersion),
+ Zeta2(ZetaFormat),
Teacher(TeacherBackend),
TeacherNonBatching(TeacherBackend),
Repair,
@@ -329,7 +334,7 @@ enum PredictionProvider {
impl Default for PredictionProvider {
fn default() -> Self {
- PredictionProvider::Zeta2(ZetaVersion::default())
+ PredictionProvider::Zeta2(ZetaFormat::default())
}
}
@@ -339,7 +344,7 @@ impl std::fmt::Display for PredictionProvider {
PredictionProvider::Sweep => write!(f, "sweep"),
PredictionProvider::Mercury => write!(f, "mercury"),
PredictionProvider::Zeta1 => write!(f, "zeta1"),
- PredictionProvider::Zeta2(version) => write!(f, "zeta2:{version}"),
+ PredictionProvider::Zeta2(format) => write!(f, "zeta2:{format}"),
PredictionProvider::Teacher(backend) => write!(f, "teacher:{backend}"),
PredictionProvider::TeacherNonBatching(backend) => {
write!(f, "teacher-non-batching:{backend}")
@@ -361,8 +366,8 @@ impl std::str::FromStr for PredictionProvider {
"mercury" => Ok(PredictionProvider::Mercury),
"zeta1" => Ok(PredictionProvider::Zeta1),
"zeta2" => {
- let version = arg.map(ZetaVersion::parse).transpose()?.unwrap_or_default();
- Ok(PredictionProvider::Zeta2(version))
+ let format = arg.map(ZetaFormat::parse).transpose()?.unwrap_or_default();
+ Ok(PredictionProvider::Zeta2(format))
}
"teacher" => {
let backend = arg
@@ -385,7 +390,7 @@ impl std::str::FromStr for PredictionProvider {
For zeta2, you can optionally specify a version like `zeta2:ordered` or `zeta2:V0113_Ordered`.\n\
For teacher, you can specify a backend like `teacher:sonnet45` or `teacher:gpt52`.\n\
Available zeta versions:\n{}",
- ZetaVersion::options_as_string()
+ ZetaFormat::options_as_string()
)
}
}
@@ -719,6 +724,13 @@ fn main() {
std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
return;
}
+ Command::PrintZetaFormats => {
+ use strum::IntoEnumIterator as _;
+ for format in ZetaFormat::iter() {
+ println!("{}", format.to_string().to_lowercase());
+ }
+ return;
+ }
Command::Synthesize(synth_args) => {
let Some(output_dir) = args.output else {
panic!("output dir is required");
@@ -953,7 +965,8 @@ fn main() {
| Command::Split(_)
| Command::TruncatePatch(_)
| Command::FilterLanguages(_)
- | Command::ImportBatch(_) => {
+ | Command::ImportBatch(_)
+ | Command::PrintZetaFormats => {
unreachable!()
}
}
@@ -5,7 +5,7 @@ use crate::{
repair,
};
use anyhow::{Context as _, Result};
-use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
+use zeta_prompt::{CURSOR_MARKER, ZetaFormat};
pub fn run_parse_output(example: &mut Example) -> Result<()> {
example
@@ -49,13 +49,13 @@ pub fn parse_prediction_output(
}
}
-fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result<String> {
- let (current_marker, end_marker) = match version {
- ZetaVersion::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
- ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
+fn extract_zeta2_current_region(prompt: &str, format: ZetaFormat) -> Result<String> {
+ let (current_marker, end_marker) = match format {
+ ZetaFormat::V0112MiddleAtEnd => ("<|fim_middle|>current\n", "<|fim_middle|>updated"),
+ ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => {
("<|fim_middle|>current\n", "<|fim_suffix|>")
}
- ZetaVersion::V0120GitMergeMarkers | ZetaVersion::V0131GitMergeMarkersPrefix => (
+ ZetaFormat::V0120GitMergeMarkers | ZetaFormat::V0131GitMergeMarkersPrefix => (
zeta_prompt::v0120_git_merge_markers::START_MARKER,
zeta_prompt::v0120_git_merge_markers::SEPARATOR,
),
@@ -82,7 +82,7 @@ fn extract_zeta2_current_region(prompt: &str, version: ZetaVersion) -> Result<St
fn parse_zeta2_output(
example: &Example,
actual_output: &str,
- version: ZetaVersion,
+ format: ZetaFormat,
) -> Result<(String, Option<ActualCursor>)> {
let prompt = &example.prompt.as_ref().context("prompt required")?.input;
let prompt_inputs = example
@@ -90,7 +90,7 @@ fn parse_zeta2_output(
.as_ref()
.context("prompt_inputs required")?;
- let old_text = extract_zeta2_current_region(prompt, version)?;
+ let old_text = extract_zeta2_current_region(prompt, format)?;
let mut new_text = actual_output.to_string();
let cursor_offset = if let Some(offset) = new_text.find(CURSOR_MARKER) {
@@ -100,11 +100,11 @@ fn parse_zeta2_output(
None
};
- let suffix = match version {
- ZetaVersion::V0131GitMergeMarkersPrefix => {
+ let suffix = match format {
+ ZetaFormat::V0131GitMergeMarkersPrefix => {
zeta_prompt::v0131_git_merge_markers_prefix::END_MARKER
}
- ZetaVersion::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER,
+ ZetaFormat::V0120GitMergeMarkers => zeta_prompt::v0120_git_merge_markers::END_MARKER,
_ => "",
};
if !suffix.is_empty() {
@@ -184,7 +184,7 @@ mod tests {
<|fim_middle|>updated
"};
- let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap();
+ let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap();
assert_eq!(region, "println!(\"hello\");\n");
}
@@ -201,7 +201,7 @@ mod tests {
<|fim_middle|>updated
"};
- let region = extract_zeta2_current_region(prompt, ZetaVersion::V0112MiddleAtEnd).unwrap();
+ let region = extract_zeta2_current_region(prompt, ZetaFormat::V0112MiddleAtEnd).unwrap();
assert_eq!(region, "println!(\"hello\");\n");
}
@@ -218,7 +218,7 @@ mod tests {
<|fim_middle|>updated
"};
- let region = extract_zeta2_current_region(prompt, ZetaVersion::V0113Ordered).unwrap();
+ let region = extract_zeta2_current_region(prompt, ZetaFormat::V0113Ordered).unwrap();
assert_eq!(region, "println!(\"hello\");\n");
}
@@ -236,7 +236,7 @@ mod tests {
"};
let region =
- extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap();
+ extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap();
assert_eq!(region, "println!(\"hello\");\n");
}
@@ -254,7 +254,7 @@ mod tests {
"};
let region =
- extract_zeta2_current_region(prompt, ZetaVersion::V0120GitMergeMarkers).unwrap();
+ extract_zeta2_current_region(prompt, ZetaFormat::V0120GitMergeMarkers).unwrap();
assert_eq!(region, "println!(\"hello\");\n");
}
}
@@ -11,7 +11,7 @@ use crate::{
retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
-use edit_prediction::{DebugEvent, EditPredictionStore};
+use edit_prediction::{DebugEvent, EditPredictionStore, Zeta2RawConfig};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
@@ -21,6 +21,7 @@ use std::{
atomic::{AtomicUsize, Ordering::SeqCst},
},
};
+use zeta_prompt::ZetaFormat;
static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
static OPENAI_CLIENT: OnceLock<OpenAiClient> = OnceLock::new();
@@ -103,9 +104,7 @@ pub async fn run_prediction(
ep_store.update(&mut cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
- PredictionProvider::Zeta2(version) => {
- edit_prediction::EditPredictionModel::Zeta2 { version }
- }
+ PredictionProvider::Zeta2(_) => edit_prediction::EditPredictionModel::Zeta2,
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher(..)
@@ -115,6 +114,15 @@ pub async fn run_prediction(
}
};
store.set_edit_prediction_model(model);
+
+ // If user specified a non-default Zeta2 version, configure raw endpoint.
+ // ZED_ZETA_MODEL env var is optional.
+ if let PredictionProvider::Zeta2(format) = provider {
+ if format != ZetaFormat::default() {
+ let model_id = std::env::var("ZED_ZETA_MODEL").ok();
+ store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format });
+ }
+ }
});
step_progress.set_substatus("configuring model");
let state = example.state.as_ref().context("state must be set")?;
@@ -217,9 +217,7 @@ fn assign_edit_prediction_provider(
if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME
&& cx.has_flag::<Zeta2FeatureFlag>() =>
{
- edit_prediction::EditPredictionModel::Zeta2 {
- version: Default::default(),
- }
+ edit_prediction::EditPredictionModel::Zeta2
}
EditPredictionProvider::Zed
if user_store.read(cx).current_user().is_some() =>
@@ -39,7 +39,7 @@ pub struct ZetaPromptInput {
Deserialize,
)]
#[allow(non_camel_case_types)]
-pub enum ZetaVersion {
+pub enum ZetaFormat {
V0112MiddleAtEnd,
V0113Ordered,
#[default]
@@ -48,28 +48,28 @@ pub enum ZetaVersion {
V0131GitMergeMarkersPrefix,
}
-impl std::fmt::Display for ZetaVersion {
+impl std::fmt::Display for ZetaFormat {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", <&'static str>::from(self))
}
}
-impl ZetaVersion {
- pub fn parse(version_string: &str) -> Result<Self> {
- let mut results = ZetaVersion::iter().filter(|version| {
+impl ZetaFormat {
+ pub fn parse(format_name: &str) -> Result<Self> {
+ let mut results = ZetaFormat::iter().filter(|version| {
<&'static str>::from(version)
.to_lowercase()
- .contains(&version_string.to_lowercase())
+ .contains(&format_name.to_lowercase())
});
let Some(result) = results.next() else {
anyhow::bail!(
- "`{version_string}` did not match any of:\n{}",
+ "`{format_name}` did not match any of:\n{}",
Self::options_as_string()
);
};
if results.next().is_some() {
anyhow::bail!(
- "`{version_string}` matched more than one of:\n{}",
+ "`{format_name}` matched more than one of:\n{}",
Self::options_as_string()
);
}
@@ -77,8 +77,8 @@ impl ZetaVersion {
}
pub fn options_as_string() -> String {
- ZetaVersion::iter()
- .map(|version| format!("- {}\n", <&'static str>::from(version)))
+ ZetaFormat::iter()
+ .map(|format| format!("- {}\n", <&'static str>::from(format)))
.collect::<Vec<_>>()
.concat()
}
@@ -137,27 +137,40 @@ pub struct RelatedExcerpt {
pub text: Arc<str>,
}
-pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
- format_zeta_prompt_with_budget(input, version, MAX_PROMPT_TOKENS)
+pub fn format_zeta_prompt(input: &ZetaPromptInput, format: ZetaFormat) -> String {
+ format_zeta_prompt_with_budget(input, format, MAX_PROMPT_TOKENS)
+}
+
+/// Post-processes model output for the given zeta format by stripping format-specific suffixes.
+pub fn clean_zeta2_model_output(output: &str, format: ZetaFormat) -> &str {
+ match format {
+ ZetaFormat::V0120GitMergeMarkers => output
+ .strip_suffix(v0120_git_merge_markers::END_MARKER)
+ .unwrap_or(output),
+ ZetaFormat::V0131GitMergeMarkersPrefix => output
+ .strip_suffix(v0131_git_merge_markers_prefix::END_MARKER)
+ .unwrap_or(output),
+ _ => output,
+ }
}
fn format_zeta_prompt_with_budget(
input: &ZetaPromptInput,
- version: ZetaVersion,
+ format: ZetaFormat,
max_tokens: usize,
) -> String {
let mut cursor_section = String::new();
- match version {
- ZetaVersion::V0112MiddleAtEnd => {
+ match format {
+ ZetaFormat::V0112MiddleAtEnd => {
v0112_middle_at_end::write_cursor_excerpt_section(&mut cursor_section, input);
}
- ZetaVersion::V0113Ordered | ZetaVersion::V0114180EditableRegion => {
+ ZetaFormat::V0113Ordered | ZetaFormat::V0114180EditableRegion => {
v0113_ordered::write_cursor_excerpt_section(&mut cursor_section, input)
}
- ZetaVersion::V0120GitMergeMarkers => {
+ ZetaFormat::V0120GitMergeMarkers => {
v0120_git_merge_markers::write_cursor_excerpt_section(&mut cursor_section, input)
}
- ZetaVersion::V0131GitMergeMarkersPrefix => {
+ ZetaFormat::V0131GitMergeMarkersPrefix => {
v0131_git_merge_markers_prefix::write_cursor_excerpt_section(&mut cursor_section, input)
}
}
@@ -563,7 +576,7 @@ mod tests {
}
fn format_with_budget(input: &ZetaPromptInput, max_tokens: usize) -> String {
- format_zeta_prompt_with_budget(input, ZetaVersion::V0114180EditableRegion, max_tokens)
+ format_zeta_prompt_with_budget(input, ZetaFormat::V0114180EditableRegion, max_tokens)
}
#[test]