Created by Max Brunsfeld, Agus Zubiaga, and Ben Kunkle.
This one emits `fim_prefix`, `fim_middle`, and `fim_suffix` in that order in the
prompt, instead of putting the current middle last; see the sketch below.
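
For illustration, a minimal standalone sketch of the two layouts (marker strings and ordering follow the `zeta_prompt` code in this diff; the helper names, file path, and excerpt are invented):

```rust
use std::ops::Range;

const CURSOR_MARKER: &str = "<|user_cursor|>";

/// Push `s`, then guarantee the prompt ends with a newline (mirroring the
/// `ends_with('\n')` checks in the real code).
fn push_ending_newline(prompt: &mut String, s: &str) {
    prompt.push_str(s);
    if !prompt.ends_with('\n') {
        prompt.push('\n');
    }
}

/// `ordered == false`: V0112_MiddleAtEnd (prefix, suffix, then current middle).
/// `ordered == true`:  V0113_Ordered (prefix, current middle, suffix, in
/// document order). Both end with the `updated` marker for the model to complete.
fn cursor_excerpt_section(
    excerpt: &str,
    editable: Range<usize>,
    cursor: usize,
    ordered: bool,
) -> String {
    let mut prompt = String::from("<|file_sep|>src/example.rs\n<|fim_prefix|>\n");
    let middle = format!(
        "{}{CURSOR_MARKER}{}",
        &excerpt[editable.start..cursor],
        &excerpt[cursor..editable.end]
    );
    if ordered {
        push_ending_newline(&mut prompt, &excerpt[..editable.start]);
        prompt.push_str("<|fim_middle|>current\n");
        push_ending_newline(&mut prompt, &middle);
        prompt.push_str("<|fim_suffix|>\n");
        push_ending_newline(&mut prompt, &excerpt[editable.end..]);
    } else {
        prompt.push_str(&excerpt[..editable.start]);
        prompt.push_str("<|fim_suffix|>\n");
        push_ending_newline(&mut prompt, &excerpt[editable.end..]);
        prompt.push_str("<|fim_middle|>current\n");
        push_ending_newline(&mut prompt, &middle);
    }
    prompt.push_str("<|fim_middle|>updated\n");
    prompt
}

fn main() {
    let excerpt = "fn main() {\n    let x = 1;\n}\n";
    let editable = 12..27; // "    let x = 1;\n"
    let cursor = 25; // right after the `1`
    println!("{}", cursor_excerpt_section(excerpt, editable.clone(), cursor, false));
    println!("{}", cursor_excerpt_section(excerpt, editable, cursor, true));
}
```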
Release Notes:
- N/A
---------
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Kunkle <ben@zed.dev>
Cargo.lock | 2
crates/edit_prediction/src/edit_prediction.rs | 22 +
crates/edit_prediction/src/edit_prediction_tests.rs | 86 ++++++++
crates/edit_prediction/src/zeta2.rs | 15 +
crates/edit_prediction_cli/src/example.rs | 26 -
crates/edit_prediction_cli/src/format_prompt.rs | 122 +++++------
crates/edit_prediction_cli/src/load_project.rs | 23 +
crates/edit_prediction_cli/src/main.rs | 35 ++-
crates/edit_prediction_cli/src/predict.rs | 26 +
crates/edit_prediction_cli/src/pull_examples.rs | 3
crates/edit_prediction_cli/src/retrieve_context.rs | 14
crates/edit_prediction_cli/src/score.rs | 11
crates/zed/src/zed/edit_prediction_registry.rs | 4
crates/zeta_prompt/Cargo.toml | 4
crates/zeta_prompt/src/zeta_prompt.rs | 149 ++++++++++++--
15 files changed, 382 insertions(+), 160 deletions(-)
@@ -21179,7 +21179,9 @@ dependencies = [
name = "zeta_prompt"
version = "0.1.0"
dependencies = [
+ "anyhow",
"serde",
+ "strum 0.27.2",
]
[[package]]
@@ -38,6 +38,7 @@ use settings::{EditPredictionProvider, SettingsStore, update_settings_file};
use std::collections::{VecDeque, hash_map};
use text::Edit;
use workspace::Workspace;
+use zeta_prompt::ZetaVersion;
use std::ops::Range;
use std::path::Path;
@@ -183,7 +184,9 @@ pub struct EditPredictionStore {
pub enum EditPredictionModel {
#[default]
Zeta1,
- Zeta2,
+ Zeta2 {
+ version: ZetaVersion,
+ },
Sweep,
Mercury,
}
@@ -654,7 +657,9 @@ impl EditPredictionStore {
update_required: false,
#[cfg(feature = "cli-support")]
eval_cache: None,
- edit_prediction_model: EditPredictionModel::Zeta2,
+ edit_prediction_model: EditPredictionModel::Zeta2 {
+ version: Default::default(),
+ },
sweep_ai: SweepAi::new(cx),
mercury: Mercury::new(cx),
data_collection_choice,
@@ -794,7 +799,10 @@ impl EditPredictionStore {
}
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
- if self.edit_prediction_model == EditPredictionModel::Zeta2 {
+ if matches!(
+ self.edit_prediction_model,
+ EditPredictionModel::Zeta2 { .. }
+ ) {
self.user_store.read(cx).edit_prediction_usage()
} else {
None
@@ -1204,7 +1212,7 @@ impl EditPredictionStore {
sweep_ai::edit_prediction_accepted(self, current_prediction, cx)
}
EditPredictionModel::Mercury => {}
- EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
+ EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
zeta2::edit_prediction_accepted(self, current_prediction, cx)
}
}
@@ -1338,7 +1346,7 @@ impl EditPredictionStore {
was_shown: bool,
) {
match self.edit_prediction_model {
- EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
+ EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 { .. } => {
if self.custom_predict_edits_url.is_some() {
return;
}
@@ -1773,7 +1781,9 @@ impl EditPredictionStore {
}
let task = match self.edit_prediction_model {
EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
- EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
+ EditPredictionModel::Zeta2 { version } => {
+ zeta2::request_prediction_with_zeta2(self, inputs, version, cx)
+ }
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
};
@@ -1332,12 +1332,20 @@ fn model_response(request: RawCompletionRequest, diff_to_apply: &str) -> RawComp
let current_marker = "<|fim_middle|>current\n";
let updated_marker = "<|fim_middle|>updated\n";
+ let suffix_marker = "<|fim_suffix|>\n";
let cursor = "<|user_cursor|>";
let start_ix = current_marker.len() + prompt.find(current_marker).unwrap();
let end_ix = start_ix + &prompt[start_ix..].find(updated_marker).unwrap();
let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
- let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
+ // In v0113_ordered format, the excerpt contains <|fim_suffix|> and suffix content.
+ // Strip that out to get just the editable region.
+ let excerpt = if let Some(suffix_pos) = excerpt.find(suffix_marker) {
+ &excerpt[..suffix_pos]
+ } else {
+ &excerpt
+ };
+ let new_excerpt = apply_diff_to_string(diff_to_apply, excerpt).unwrap();
RawCompletionResponse {
id: Uuid::new_v4().to_string(),
@@ -1629,6 +1637,82 @@ async fn test_edit_prediction_end_of_buffer(cx: &mut TestAppContext) {
);
}
+#[gpui::test]
+async fn test_edit_prediction_no_spurious_trailing_newline(cx: &mut TestAppContext) {
+ // Test that zeta2's newline normalization logic doesn't insert spurious newlines.
+ // When the buffer ends without a trailing newline, but the model returns output
+ // with a trailing newline, zeta2 should normalize both sides before diffing
+ // so no spurious newline is inserted.
+ let (ep_store, mut requests) = init_test_with_fake_client(cx);
+ let fs = FakeFs::new(cx.executor());
+
+ // Single line buffer with no trailing newline
+ fs.insert_tree(
+ "/root",
+ json!({
+ "foo.txt": "hello"
+ }),
+ )
+ .await;
+ let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project
+ .find_project_path(path!("root/foo.txt"), cx)
+ .unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
+ let position = snapshot.anchor_before(language::Point::new(0, 5));
+
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
+ });
+
+ let (_request, respond_tx) = requests.predict.next().await.unwrap();
+
+ // Model returns output WITH a trailing newline, even though the buffer doesn't have one.
+ // Zeta2 should normalize both sides before diffing, so no spurious newline is inserted.
+ let response = RawCompletionResponse {
+ id: Uuid::new_v4().to_string(),
+ object: "text_completion".into(),
+ created: 0,
+ model: "model".into(),
+ choices: vec![RawCompletionChoice {
+ text: "hello world\n".to_string(),
+ finish_reason: None,
+ }],
+ usage: RawCompletionUsage {
+ prompt_tokens: 0,
+ completion_tokens: 0,
+ total_tokens: 0,
+ },
+ };
+ respond_tx.send(response).unwrap();
+
+ cx.run_until_parked();
+
+ // The prediction should insert " world" without adding a newline
+ ep_store.update(cx, |ep_store, cx| {
+ let prediction = ep_store
+ .prediction_at(&buffer, None, &project, cx)
+ .expect("should have prediction");
+ let edits: Vec<_> = prediction
+ .edits
+ .iter()
+ .map(|(range, text)| {
+ let snapshot = buffer.read(cx).snapshot();
+ (range.to_offset(&snapshot), text.clone())
+ })
+ .collect();
+ assert_eq!(edits, vec![(5..5, " world".into())]);
+ });
+}
+
#[gpui::test]
async fn test_can_collect_data(cx: &mut TestAppContext) {
init_test(cx);
@@ -15,8 +15,8 @@ use release_channel::AppVersion;
use std::env;
use std::{path::Path, sync::Arc, time::Instant};
-use zeta_prompt::CURSOR_MARKER;
use zeta_prompt::format_zeta_prompt;
+use zeta_prompt::{CURSOR_MARKER, ZetaVersion};
pub const MAX_CONTEXT_TOKENS: usize = 350;
pub const MAX_EDITABLE_TOKENS: usize = 150;
@@ -32,6 +32,7 @@ pub fn request_prediction_with_zeta2(
debug_tx,
..
}: EditPredictionModelInput,
+ zeta_version: ZetaVersion,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
let buffer_snapshotted_at = Instant::now();
@@ -62,7 +63,7 @@ pub fn request_prediction_with_zeta2(
cursor_offset,
);
- let prompt = format_zeta_prompt(&prompt_input);
+ let prompt = format_zeta_prompt(&prompt_input, zeta_version);
if let Some(debug_tx) = &debug_tx {
debug_tx
@@ -125,9 +126,17 @@ pub fn request_prediction_with_zeta2(
output_text = output_text.replace(CURSOR_MARKER, "");
}
- let old_text = snapshot
+ let mut old_text = snapshot
.text_for_range(editable_offset_range.clone())
.collect::<String>();
+
+ if !output_text.is_empty() && !output_text.ends_with('\n') {
+ output_text.push('\n');
+ }
+ if !old_text.is_empty() && !old_text.ends_with('\n') {
+ old_text.push('\n');
+ }
+
let edits: Vec<_> = language::text_diff(&old_text, &output_text)
.into_iter()
.map(|(range, text)| {
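
The trailing-newline normalization above keeps the diff from inserting a newline the user never typed when the model appends one. A standalone sketch of the invariant (toy helper, not zeta2's code):

```rust
/// Mirror of the normalization in the hunk above: non-empty text always ends
/// with a newline before diffing.
fn normalize_trailing_newline(s: &mut String) {
    if !s.is_empty() && !s.ends_with('\n') {
        s.push('\n');
    }
}

fn main() {
    let mut old_text = String::from("hello"); // buffer has no final newline
    let mut output_text = String::from("hello world\n"); // model added one
    normalize_trailing_newline(&mut old_text);
    normalize_trailing_newline(&mut output_text);
    // Both sides now share the trailing '\n', so diffing them yields only the
    // " world" insertion, not a spurious trailing newline.
    assert_eq!(old_text, "hello\n");
    assert_eq!(output_text, "hello world\n");
}
```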
@@ -1,5 +1,5 @@
+use crate::PredictionProvider;
use crate::paths::WORKTREES_DIR;
-use crate::{PredictionProvider, PromptFormat};
use anyhow::{Context as _, Result};
use collections::HashMap;
use edit_prediction::example_spec::ExampleSpec;
@@ -9,11 +9,12 @@ use http_client::Url;
use language::{Anchor, Buffer};
use project::Project;
use serde::{Deserialize, Serialize};
-use std::ops::Range;
use std::{
borrow::Cow,
io::Read,
+ ops::Range,
path::{Path, PathBuf},
+ sync::Arc,
};
use zeta_prompt::RelatedFile;
@@ -25,12 +26,7 @@ pub struct Example {
/// The full content of the file where an edit is being predicted, and the
/// actual cursor offset.
#[serde(skip_serializing_if = "Option::is_none")]
- pub buffer: Option<ExampleBuffer>,
-
- /// The context retrieved for the prediction. This requires the worktree to
- /// be loaded and the language server to be started.
- #[serde(skip_serializing_if = "Option::is_none")]
- pub context: Option<ExampleContext>,
+ pub prompt_inputs: Option<ExamplePromptInputs>,
/// The input and expected output from the edit prediction model.
#[serde(skip_serializing_if = "Option::is_none")]
@@ -59,25 +55,22 @@ pub struct ExampleState {
}
#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct ExampleContext {
- pub files: Vec<RelatedFile>,
-}
-
-#[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct ExampleBuffer {
+pub struct ExamplePromptInputs {
pub content: String,
pub cursor_row: u32,
pub cursor_column: u32,
pub cursor_offset: usize,
pub context_range: Range<usize>,
pub editable_range: Range<usize>,
+ pub edit_history: Vec<Arc<zeta_prompt::Event>>,
+ pub related_files: Option<Vec<RelatedFile>>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
pub input: String,
pub expected_output: String,
- pub format: PromptFormat,
+ pub provider: PredictionProvider,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
@@ -239,8 +232,7 @@ fn parse_markdown_example(input: &str) -> Result<Example> {
let spec = ExampleSpec::from_markdown(input)?;
Ok(Example {
spec,
- buffer: None,
- context: None,
+ prompt_inputs: None,
prompt: None,
predictions: Vec::new(),
score: Vec::new(),
@@ -1,14 +1,12 @@
use crate::{
- PromptFormat,
+ FormatPromptArgs, PredictionProvider,
example::{Example, ExamplePrompt},
headless::EpAppState,
- load_project::run_load_project,
progress::{Progress, Step},
retrieve_context::run_context_retrieval,
};
use anyhow::{Context as _, Result};
-use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
-use gpui::{AsyncApp, Entity};
+use gpui::AsyncApp;
use similar::DiffableStr;
use std::fmt::Write as _;
use std::sync::Arc;
@@ -16,16 +14,21 @@ use zeta_prompt::format_zeta_prompt;
pub async fn run_format_prompt(
example: &mut Example,
- prompt_format: PromptFormat,
+ args: &FormatPromptArgs,
app_state: Arc<EpAppState>,
- mut cx: AsyncApp,
+ cx: AsyncApp,
) -> Result<()> {
- run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
+ run_context_retrieval(example, app_state, cx).await?;
let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
- match prompt_format {
- PromptFormat::Teacher => {
+ let prompt_inputs = example
+ .prompt_inputs
+ .as_ref()
+ .context("prompt_inputs must be set after context retrieval")?;
+
+ match args.provider {
+ PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
step_progress.set_substatus("formatting teacher prompt");
let prompt = TeacherPrompt::format_prompt(example);
example.prompt = Some(ExamplePrompt {
@@ -36,47 +39,27 @@ pub async fn run_format_prompt(
.first()
.cloned()
.unwrap_or_default(),
- format: prompt_format,
+ provider: args.provider,
});
}
- PromptFormat::Zeta2 => {
- step_progress.set_substatus("loading project");
- run_load_project(example, app_state, cx.clone()).await?;
-
+ PredictionProvider::Zeta2 => {
step_progress.set_substatus("formatting zeta2 prompt");
- let ep_store: Entity<EditPredictionStore> = cx.update(|cx| {
- EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
- })?;
-
- let state = example.state.as_ref().context("state must be set")?;
- let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot());
- let project = state.project.clone();
- let (_, input) =
- ep_store.update(&mut cx, |ep_store: &mut EditPredictionStore, cx| {
- let events = ep_store
- .edit_history_for_project(&project, cx)
- .into_iter()
- .map(|e| e.event)
- .collect();
- anyhow::Ok(zeta2_prompt_input(
- &snapshot,
- example
- .context
- .as_ref()
- .context("context must be set")?
- .files
- .clone(),
- events,
- example.spec.cursor_path.clone(),
- example
- .buffer
- .as_ref()
- .context("buffer must be set")?
- .cursor_offset,
- ))
- })?;
- let prompt = format_zeta_prompt(&input);
+ let context_start = prompt_inputs.context_range.start;
+ let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
+ let editable_range_in_excerpt = (prompt_inputs.editable_range.start - context_start)
+ ..(prompt_inputs.editable_range.end - context_start);
+ let input = zeta_prompt::ZetaPromptInput {
+ cursor_path: example.spec.cursor_path.clone(),
+ cursor_excerpt: prompt_inputs.content[prompt_inputs.context_range.clone()]
+ .to_string()
+ .into(),
+ editable_range_in_excerpt,
+ cursor_offset_in_excerpt,
+ events: prompt_inputs.edit_history.clone(),
+ related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
+ };
+ let prompt = format_zeta_prompt(&input, args.version);
let expected_output = zeta2_output_for_patch(
&input,
&example
@@ -89,9 +72,12 @@ pub async fn run_format_prompt(
example.prompt = Some(ExamplePrompt {
input: prompt,
expected_output,
- format: prompt_format,
+ provider: args.provider,
});
}
+ _ => {
+ panic!("Cannot format prompt for {:?}", args.provider);
+ }
};
Ok(())
}
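
The offset arithmetic in the zeta2 branch just rebases the buffer-relative offsets stored in `ExamplePromptInputs` onto the excerpt. A minimal sketch with made-up numbers:

```rust
fn main() {
    // Hypothetical buffer-relative values, as stored in ExamplePromptInputs.
    let context_range = 100..400; // the excerpt's byte range in the buffer
    let editable_range = 180..320;
    let cursor_offset = 250;

    // Rebase onto the excerpt, as run_format_prompt does above.
    let context_start = context_range.start;
    let cursor_offset_in_excerpt = cursor_offset - context_start;
    let editable_range_in_excerpt =
        (editable_range.start - context_start)..(editable_range.end - context_start);

    assert_eq!(cursor_offset_in_excerpt, 150);
    assert_eq!(editable_range_in_excerpt, 80..220);
}
```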
@@ -144,10 +130,10 @@ impl TeacherPrompt {
// 2. Context retriever just didn't include cursor line.
//
// In that case, fallback to using `cursor_position` as excerpt.
- let example_buffer = example
- .buffer
+ let prompt_inputs = example
+ .prompt_inputs
.as_ref()
- .context("`buffer` should be filled in in the context collection step")?;
+ .context("`prompt_inputs` should be filled in in the context collection step")?;
// Extract updated (new) editable region from the model response.
// The model may include editable region markers in its output, so we need to strip them.
@@ -155,7 +141,7 @@ impl TeacherPrompt {
let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
let old_editable_region =
- example_buffer.content[example_buffer.editable_range.clone()].to_string();
+ prompt_inputs.content[prompt_inputs.editable_range.clone()].to_string();
// Normalize leading newlines: if old starts with newline but new doesn't,
// prepend newline to new to preserve whitespace structure.
@@ -164,8 +150,8 @@ impl TeacherPrompt {
new_editable_region.insert(0, '\n');
}
- let editable_region_start_line = example_buffer.content
- [..example_buffer.editable_range.start]
+ let editable_region_start_line = prompt_inputs.content
+ [..prompt_inputs.editable_range.start]
.matches('\n')
.count();
@@ -208,17 +194,21 @@ impl TeacherPrompt {
}
fn format_context(example: &Example) -> String {
- let context = example
- .context
+ let related_files = example
+ .prompt_inputs
.as_ref()
- .expect("Missing context retriever step");
+ .and_then(|pi| pi.related_files.as_ref());
+
+ let Some(related_files) = related_files else {
+ return "(No context)".to_string();
+ };
- if context.files.is_empty() {
+ if related_files.is_empty() {
return "(No context)".to_string();
}
let mut prompt = String::new();
- for file in context.files.iter() {
+ for file in related_files {
let path_str = file.path.to_string_lossy();
writeln!(&mut prompt, "`````{path_str}").ok();
let mut prev_row = 0;
@@ -242,28 +232,26 @@ impl TeacherPrompt {
fn format_cursor_excerpt(example: &Example) -> String {
let mut result = String::new();
- let example_buffer = example.buffer.as_ref().unwrap();
+ let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
let path_str = example.spec.cursor_path.to_string_lossy();
result.push_str(&format!("`````{path_str}\n"));
result.push_str(
- &example_buffer.content
- [example_buffer.context_range.start..example_buffer.editable_range.start],
+ &prompt_inputs.content
+ [prompt_inputs.context_range.start..prompt_inputs.editable_range.start],
);
result.push_str(Self::EDITABLE_REGION_START);
result.push_str(
- &example_buffer.content
- [example_buffer.editable_range.start..example_buffer.cursor_offset],
+ &prompt_inputs.content[prompt_inputs.editable_range.start..prompt_inputs.cursor_offset],
);
result.push_str(Self::USER_CURSOR_MARKER);
result.push_str(
- &example_buffer.content
- [example_buffer.cursor_offset..example_buffer.editable_range.end],
+ &prompt_inputs.content[prompt_inputs.cursor_offset..prompt_inputs.editable_range.end],
);
result.push_str(Self::EDITABLE_REGION_END);
result.push_str(
- &example_buffer.content
- [example_buffer.editable_range.end..example_buffer.context_range.end],
+ &prompt_inputs.content
+ [prompt_inputs.editable_range.end..prompt_inputs.context_range.end],
);
result.push_str("\n`````");
@@ -1,5 +1,5 @@
use crate::{
- example::{Example, ExampleBuffer, ExampleState},
+ example::{Example, ExamplePromptInputs, ExampleState},
git,
headless::EpAppState,
progress::{InfoStyle, Progress, Step, StepProgress},
@@ -38,7 +38,20 @@ pub async fn run_load_project(
buffer
.read_with(&cx, |buffer, _| buffer.parsing_idle())
.await;
- let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| {
+
+ let ep_store = cx
+ .update(|cx| EditPredictionStore::try_global(cx))
+ .context("EditPredictionStore not initialized")?;
+
+ let edit_history = ep_store.update(&mut cx, |store, cx| {
+ store
+ .edit_history_for_project(&project, cx)
+ .into_iter()
+ .map(|e| e.event)
+ .collect()
+ });
+
+ let (prompt_inputs, language_name) = buffer.read_with(&cx, |buffer, _cx| {
let cursor_point = cursor_position.to_point(&buffer);
let snapshot = buffer.snapshot();
let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
@@ -54,13 +67,15 @@ pub async fn run_load_project(
.map(|l| l.name().to_string())
.unwrap_or_else(|| "Unknown".to_string());
(
- ExampleBuffer {
+ ExamplePromptInputs {
content: buffer.text(),
cursor_row: cursor_point.row,
cursor_column: cursor_point.column,
cursor_offset: cursor_position.to_offset(&buffer),
context_range,
editable_range,
+ edit_history,
+ related_files: None,
},
language_name,
)
@@ -68,7 +83,7 @@ pub async fn run_load_project(
progress.set_info(language_name, InfoStyle::Normal);
- example.buffer = Some(example_buffer);
+ example.prompt_inputs = Some(prompt_inputs);
example.state = Some(ExampleState {
buffer,
project,
@@ -22,6 +22,7 @@ use edit_prediction::EditPredictionStore;
use futures::channel::mpsc;
use futures::{SinkExt as _, StreamExt as _};
use gpui::{AppContext as _, Application};
+use zeta_prompt::ZetaVersion;
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
@@ -155,7 +156,7 @@ impl Display for Command {
f,
"format-prompt --prompt-format={}",
format_prompt_args
- .prompt_format
+ .provider
.to_possible_value()
.unwrap()
.get_name()
@@ -204,22 +205,31 @@ impl Display for Command {
#[derive(Debug, Args, Clone)]
struct FormatPromptArgs {
- #[clap(long, short('p'))]
- prompt_format: PromptFormat,
-}
-
-#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
-enum PromptFormat {
- Teacher,
- Zeta2,
+ #[clap(long, short)]
+ provider: PredictionProvider,
+ #[clap(
+ long,
+ short,
+ help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
+ value_parser = ZetaVersion::parse,
+ default_value_t = ZetaVersion::default(),
+ )]
+ version: ZetaVersion,
}
#[derive(Debug, Args, Clone)]
struct PredictArgs {
- #[clap(long)]
+ #[clap(long, short)]
provider: PredictionProvider,
#[clap(long, default_value_t = 1)]
repetitions: usize,
+ #[clap(
+ long,
+ short,
+ help = "(only for --provider zeta2) A substring of a zeta_prompt::ZetaVersion variant to use",
+ value_parser = ZetaVersion::parse,
+ )]
+ version: ZetaVersion,
}
#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)]
@@ -514,7 +524,7 @@ fn main() {
Command::FormatPrompt(args) => {
run_format_prompt(
example,
- args.prompt_format,
+ args,
app_state.clone(),
cx.clone(),
)
@@ -523,8 +533,7 @@ fn main() {
Command::Predict(args) => {
run_prediction(
example,
- Some(args.provider),
- args.repetitions,
+ args,
app_state.clone(),
cx.clone(),
)
@@ -1,5 +1,5 @@
use crate::{
- PredictionProvider, PromptFormat,
+ FormatPromptArgs, PredictArgs, PredictionProvider,
anthropic_client::AnthropicClient,
example::{Example, ExamplePrediction, ExamplePrompt},
format_prompt::{TeacherPrompt, run_format_prompt},
@@ -25,12 +25,13 @@ static ANTHROPIC_CLIENT: OnceLock<AnthropicClient> = OnceLock::new();
pub async fn run_prediction(
example: &mut Example,
- provider: Option<PredictionProvider>,
- repetition_count: usize,
+ args: &PredictArgs,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) -> anyhow::Result<()> {
- let provider = provider.context("provider is required")?;
+ let provider = args.provider;
+ let repetition_count = args.repetitions;
+ let zeta_version = args.version;
if let Some(existing_prediction) = example.predictions.first() {
if existing_prediction.provider == provider {
@@ -48,7 +49,16 @@ pub async fn run_prediction(
) {
let _step_progress = Progress::global().start(Step::Predict, &example.spec.name);
- run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
+ run_format_prompt(
+ example,
+ &FormatPromptArgs {
+ provider,
+ version: args.version,
+ },
+ app_state.clone(),
+ cx,
+ )
+ .await?;
let batched = matches!(provider, PredictionProvider::Teacher);
return predict_anthropic(example, repetition_count, batched).await;
@@ -85,7 +95,9 @@ pub async fn run_prediction(
ep_store.update(&mut cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
- PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
+ PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2 {
+ version: zeta_version,
+ },
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
@@ -127,7 +139,7 @@ pub async fn run_prediction(
updated_example.prompt.get_or_insert(ExamplePrompt {
input: prompt,
expected_output: String::new(),
- format: PromptFormat::Zeta2,
+ provider,
});
}
}
@@ -149,8 +149,7 @@ fn examples_from_response(
match parse_result {
Ok(spec) => Some(Example {
spec,
- buffer: None,
- context: None,
+ prompt_inputs: None,
prompt: None,
predictions: Vec::new(),
score: Vec::new(),
@@ -1,5 +1,5 @@
use crate::{
- example::{Example, ExampleContext},
+ example::Example,
headless::EpAppState,
load_project::run_load_project,
progress::{InfoStyle, Progress, Step, StepProgress},
@@ -19,7 +19,11 @@ pub async fn run_context_retrieval(
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) -> anyhow::Result<()> {
- if example.context.is_some() {
+ if example
+ .prompt_inputs
+ .as_ref()
+ .is_some_and(|inputs| inputs.related_files.is_some())
+ {
return Ok(());
}
@@ -63,9 +67,9 @@ pub async fn run_context_retrieval(
let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
- example.context = Some(ExampleContext {
- files: context_files,
- });
+ if let Some(prompt_inputs) = example.prompt_inputs.as_mut() {
+ prompt_inputs.related_files = Some(context_files);
+ }
Ok(())
}
@@ -17,19 +17,12 @@ pub async fn run_scoring(
app_state: Arc<EpAppState>,
cx: AsyncApp,
) -> anyhow::Result<()> {
- run_prediction(
- example,
- Some(args.provider),
- args.repetitions,
- app_state,
- cx,
- )
- .await?;
+ run_prediction(example, args, app_state, cx).await?;
let progress = Progress::global().start(Step::Score, &example.spec.name);
progress.set_substatus("applying patches");
- let original_text = &example.buffer.as_ref().unwrap().content;
+ let original_text = &example.prompt_inputs.as_ref().unwrap().content;
let expected_texts: Vec<String> = example
.spec
.expected_patches
@@ -204,7 +204,9 @@ fn assign_edit_prediction_provider(
} else if name == EXPERIMENTAL_ZETA2_EDIT_PREDICTION_PROVIDER_NAME
&& cx.has_flag::<Zeta2FeatureFlag>()
{
- edit_prediction::EditPredictionModel::Zeta2
+ edit_prediction::EditPredictionModel::Zeta2 {
+ version: Default::default(),
+ }
} else if name == EXPERIMENTAL_MERCURY_EDIT_PREDICTION_PROVIDER_NAME
&& cx.has_flag::<MercuryFeatureFlag>()
{
@@ -12,4 +12,6 @@ workspace = true
path = "src/zeta_prompt.rs"
[dependencies]
-serde.workspace = true
+anyhow.workspace = true
+serde.workspace = true
+strum.workspace = true
@@ -1,8 +1,10 @@
+use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::ops::Range;
use std::path::Path;
use std::sync::Arc;
+use strum::{EnumIter, IntoEnumIterator as _, IntoStaticStr};
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
@@ -16,6 +18,54 @@ pub struct ZetaPromptInput {
pub related_files: Vec<RelatedFile>,
}
+#[derive(Default, Clone, Copy, Debug, PartialEq, Eq, EnumIter, IntoStaticStr)]
+#[allow(non_camel_case_types)]
+pub enum ZetaVersion {
+ V0112_MiddleAtEnd,
+ #[default]
+ V0113_Ordered,
+}
+
+impl std::fmt::Display for ZetaVersion {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}", <&'static str>::from(self))
+ }
+}
+
+impl ZetaVersion {
+ pub fn parse(version_string: &str) -> Result<Self> {
+ let mut results = ZetaVersion::iter().filter(|version| {
+ <&'static str>::from(version)
+ .to_lowercase()
+ .contains(&version_string.to_lowercase())
+ });
+ let Some(result) = results.next() else {
+ anyhow::bail!(
+ "`{version_string}` did not match any of:\n{}",
+ Self::options_as_string()
+ );
+ };
+ if results.next().is_some() {
+ anyhow::bail!(
+ "`{version_string}` matched more than one of:\n{}",
+ Self::options_as_string()
+ );
+ }
+ Ok(result)
+ }
+
+ fn options_as_string() -> String {
+ ZetaVersion::iter()
+ .map(|version| format!("- {}\n", <&'static str>::from(version)))
+ .collect::<Vec<_>>()
+ .concat()
+ }
+
+ pub fn default_as_string() -> String {
+ <&'static str>::from(Self::default()).to_string()
+ }
+}
+
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
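
For reference, how `ZetaVersion::parse`'s case-insensitive substring matching behaves, as a hypothetical test (not part of the diff):

```rust
#[test]
fn zeta_version_parse_substring_matching() {
    // A substring unique to one variant selects it.
    assert_eq!(
        ZetaVersion::parse("ordered").unwrap(),
        ZetaVersion::V0113_Ordered
    );
    assert_eq!(
        ZetaVersion::parse("0112").unwrap(),
        ZetaVersion::V0112_MiddleAtEnd
    );
    // A substring matching both variants is rejected as ambiguous.
    assert!(ZetaVersion::parse("v011").is_err());
}
```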
@@ -69,11 +119,20 @@ pub struct RelatedExcerpt {
pub text: Arc<str>,
}
-pub fn format_zeta_prompt(input: &ZetaPromptInput) -> String {
+pub fn format_zeta_prompt(input: &ZetaPromptInput, version: ZetaVersion) -> String {
let mut prompt = String::new();
write_related_files(&mut prompt, &input.related_files);
write_edit_history_section(&mut prompt, input);
- write_cursor_excerpt_section(&mut prompt, input);
+
+ match version {
+ ZetaVersion::V0112_MiddleAtEnd => {
+ v0112_middle_at_end::write_cursor_excerpt_section(&mut prompt, input);
+ }
+ ZetaVersion::V0113_Ordered => {
+ v0113_ordered::write_cursor_excerpt_section(&mut prompt, input)
+ }
+ }
+
prompt
}
@@ -100,31 +159,73 @@ fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
}
}
-fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
- let path_str = input.cursor_path.to_string_lossy();
- write!(prompt, "<|file_sep|>{}\n", path_str).ok();
+mod v0112_middle_at_end {
+ use super::*;
- prompt.push_str("<|fim_prefix|>\n");
- prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
+ pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
+ let path_str = input.cursor_path.to_string_lossy();
+ write!(prompt, "<|file_sep|>{}\n", path_str).ok();
- prompt.push_str("<|fim_suffix|>\n");
- prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
- if !prompt.ends_with('\n') {
- prompt.push('\n');
- }
+ prompt.push_str("<|fim_prefix|>\n");
+ prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
- prompt.push_str("<|fim_middle|>current\n");
- prompt.push_str(
- &input.cursor_excerpt
- [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
- );
- prompt.push_str(CURSOR_MARKER);
- prompt.push_str(
- &input.cursor_excerpt[input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
- );
- if !prompt.ends_with('\n') {
- prompt.push('\n');
+ prompt.push_str("<|fim_suffix|>\n");
+ prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
+ if !prompt.ends_with('\n') {
+ prompt.push('\n');
+ }
+
+ prompt.push_str("<|fim_middle|>current\n");
+ prompt.push_str(
+ &input.cursor_excerpt
+ [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
+ );
+ prompt.push_str(CURSOR_MARKER);
+ prompt.push_str(
+ &input.cursor_excerpt
+ [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
+ );
+ if !prompt.ends_with('\n') {
+ prompt.push('\n');
+ }
+
+ prompt.push_str("<|fim_middle|>updated\n");
}
+}
+
+mod v0113_ordered {
+ use super::*;
+
+ pub fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
+ let path_str = input.cursor_path.to_string_lossy();
+ write!(prompt, "<|file_sep|>{}\n", path_str).ok();
- prompt.push_str("<|fim_middle|>updated\n");
+ prompt.push_str("<|fim_prefix|>\n");
+ prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
+ if !prompt.ends_with('\n') {
+ prompt.push('\n');
+ }
+
+ prompt.push_str("<|fim_middle|>current\n");
+ prompt.push_str(
+ &input.cursor_excerpt
+ [input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
+ );
+ prompt.push_str(CURSOR_MARKER);
+ prompt.push_str(
+ &input.cursor_excerpt
+ [input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
+ );
+ if !prompt.ends_with('\n') {
+ prompt.push('\n');
+ }
+
+ prompt.push_str("<|fim_suffix|>\n");
+ prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
+ if !prompt.ends_with('\n') {
+ prompt.push('\n');
+ }
+
+ prompt.push_str("<|fim_middle|>updated\n");
+ }
}