From 8fc7bd9ae8a940abbcf708846a4ecd70d3b236fe Mon Sep 17 00:00:00 2001 From: Michael Sloan Date: Wed, 24 Sep 2025 18:07:43 -0600 Subject: [PATCH] zeta2: Add labeled sections prompt format (#38828) Release Notes: - N/A Co-authored-by: Agus --- Cargo.lock | 1 + .../cloud_llm_client/src/predict_edits_v3.rs | 22 +- .../src/cloud_zeta2_prompt.rs | 251 +++++++++++------- crates/zeta2/src/prediction.rs | 2 +- crates/zeta2/src/zeta2.rs | 20 +- crates/zeta2_tools/src/zeta2_tools.rs | 4 +- crates/zeta_cli/Cargo.toml | 1 + crates/zeta_cli/src/main.rs | 22 +- 8 files changed, 201 insertions(+), 122 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a2ab191e59160fb257fcd70bbf6c3cbe0c628fbb..fa7b50381da8dbe68ebb3149192c1f7738c07d37 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20669,6 +20669,7 @@ dependencies = [ "anyhow", "clap", "client", + "cloud_llm_client", "cloud_zeta2_prompt", "debug_adapter_extension", "edit_prediction_context", diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index eeca7ed4e24d594d4a7e9555b3d9fbfd6f3706d2..9c5123fdb8e7aaddbda3bd7cd5d36b112de7538d 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -1,6 +1,10 @@ use chrono::Duration; use serde::{Deserialize, Serialize}; -use std::{ops::Range, path::PathBuf}; +use std::{ + ops::Range, + path::{Path, PathBuf}, + sync::Arc, +}; use uuid::Uuid; use crate::PredictEditsGitInfo; @@ -10,7 +14,7 @@ use crate::PredictEditsGitInfo; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PredictEditsRequest { pub excerpt: String, - pub excerpt_path: PathBuf, + pub excerpt_path: Arc, /// Within file pub excerpt_range: Range, /// Within `excerpt` @@ -32,7 +36,17 @@ pub struct PredictEditsRequest { // Only available to staff #[serde(default)] pub debug_info: bool, + #[serde(skip_serializing_if = "Option::is_none", default)] pub prompt_max_bytes: Option, + #[serde(default)] + pub prompt_format: PromptFormat, +} + +#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize, PartialEq)] +pub enum PromptFormat { + #[default] + MarkedExcerpt, + LabeledSections, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -59,7 +73,7 @@ pub struct Signature { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ReferencedDeclaration { - pub path: PathBuf, + pub path: Arc, pub text: String, pub text_is_truncated: bool, /// Range of `text` within file, possibly truncated according to `text_is_truncated` @@ -117,7 +131,7 @@ pub struct DebugInfo { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Edit { - pub path: PathBuf, + pub path: Arc, pub range: Range, pub content: String, } diff --git a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs index 89cfb4c41f2d05c521c81c6e136d7bc46861d7bc..cc5c8cb8b287e620e38910a6bc4408f67a5722aa 100644 --- a/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs +++ b/crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs @@ -1,29 +1,47 @@ //! Zeta2 prompt planning and generation code shared with cloud. -use anyhow::{Result, anyhow}; -use cloud_llm_client::predict_edits_v3::{self, Event, ReferencedDeclaration}; +use anyhow::{Context as _, Result, anyhow}; +use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, ReferencedDeclaration}; use indoc::indoc; use ordered_float::OrderedFloat; use rustc_hash::{FxHashMap, FxHashSet}; use std::fmt::Write; +use std::sync::Arc; use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path}; use strum::{EnumIter, IntoEnumIterator}; pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024; -pub const CURSOR_MARKER: &str = "<|user_cursor_is_here|>"; +pub const CURSOR_MARKER: &str = "<|cursor_position|>"; /// NOTE: Differs from zed version of constant - includes a newline pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n"; /// NOTE: Differs from zed version of constant - includes a newline pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n"; // TODO: use constants for markers? -pub const SYSTEM_PROMPT: &str = indoc! {" +const MARKED_EXCERPT_SYSTEM_PROMPT: &str = indoc! {" You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location. - The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|user_cursor_is_here|>. Please respond with edited code for that region. + The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|cursor_position|>. Please respond with edited code for that region. + + Other code is provided for context, and `…` indicates when code has been skipped. "}; +const LABELED_SECTIONS_SYSTEM_PROMPT: &str = indoc! {r#" + You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code. + + Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`). + + The cursor position is marked with `<|cursor_position|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it. + + Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example: + + <|current_section|> + for i in 0..16 { + println!("{i}"); + } +"#}; + pub struct PlannedPrompt<'a> { request: &'a predict_edits_v3::PredictEditsRequest, /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in @@ -32,13 +50,16 @@ pub struct PlannedPrompt<'a> { budget_used: usize, } -pub struct PlanOptions { - pub max_bytes: usize, +pub fn system_prompt(format: PromptFormat) -> &'static str { + match format { + PromptFormat::MarkedExcerpt => MARKED_EXCERPT_SYSTEM_PROMPT, + PromptFormat::LabeledSections => LABELED_SECTIONS_SYSTEM_PROMPT, + } } #[derive(Clone, Debug)] pub struct PlannedSnippet<'a> { - path: &'a Path, + path: Arc, range: Range, text: &'a str, // TODO: Indicate this in the output @@ -52,6 +73,12 @@ pub enum SnippetStyle { Declaration, } +#[derive(Clone, Debug)] +pub struct SectionLabels { + pub excerpt_index: usize, + pub section_ranges: Vec<(Arc, Range)>, +} + impl<'a> PlannedPrompt<'a> { /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following: /// @@ -74,10 +101,7 @@ impl<'a> PlannedPrompt<'a> { /// signatures may be shared by multiple snippets. /// /// * Does not include file paths / other text when considering max_bytes. - pub fn populate( - request: &'a predict_edits_v3::PredictEditsRequest, - options: &PlanOptions, - ) -> Result { + pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result { let mut this = PlannedPrompt { request, snippets: Vec::new(), @@ -91,11 +115,13 @@ impl<'a> PlannedPrompt<'a> { )?; this.add_parents(&mut included_parents, additional_parents); - if this.budget_used > options.max_bytes { + let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES); + + if this.budget_used > max_bytes { return Err(anyhow!( "Excerpt + signatures size of {} already exceeds budget of {}", this.budget_used, - options.max_bytes + max_bytes )); } @@ -138,7 +164,7 @@ impl<'a> PlannedPrompt<'a> { }; let mut additional_bytes = declaration_size(declaration, queue_entry.style); - if this.budget_used + additional_bytes > options.max_bytes { + if this.budget_used + additional_bytes > max_bytes { continue; } @@ -151,7 +177,7 @@ impl<'a> PlannedPrompt<'a> { .iter() .map(|(_, snippet)| snippet.text.len()) .sum::(); - if this.budget_used + additional_bytes > options.max_bytes { + if this.budget_used + additional_bytes > max_bytes { continue; } @@ -168,7 +194,7 @@ impl<'a> PlannedPrompt<'a> { )); }; PlannedSnippet { - path: &declaration.path, + path: declaration.path.clone(), range: (declaration.signature_range.start + declaration.range.start) ..(declaration.signature_range.end + declaration.range.start), text, @@ -176,7 +202,7 @@ impl<'a> PlannedPrompt<'a> { } } SnippetStyle::Declaration => PlannedSnippet { - path: &declaration.path, + path: declaration.path.clone(), range: declaration.range.clone(), text: &declaration.text, text_is_truncated: declaration.text_is_truncated, @@ -220,7 +246,7 @@ impl<'a> PlannedPrompt<'a> { fn additional_parent_signatures( &self, - path: &'a Path, + path: &Arc, parent_index: Option, included_parents: &FxHashSet, ) -> Result)>> { @@ -231,7 +257,7 @@ impl<'a> PlannedPrompt<'a> { fn additional_parent_signatures_impl( &self, - path: &'a Path, + path: &Arc, parent_index: Option, included_parents: &FxHashSet, results: &mut Vec<(usize, PlannedSnippet<'a>)>, @@ -248,7 +274,7 @@ impl<'a> PlannedPrompt<'a> { results.push(( parent_index, PlannedSnippet { - path, + path: path.clone(), range: parent_signature.range.clone(), text: &parent_signature.text, text_is_truncated: parent_signature.text_is_truncated, @@ -265,7 +291,7 @@ impl<'a> PlannedPrompt<'a> { /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive /// chunks. - pub fn to_prompt_string(&self) -> String { + pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> { let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> = FxHashMap::default(); for snippet in &self.snippets { @@ -279,14 +305,14 @@ impl<'a> PlannedPrompt<'a> { let mut file_snippets = Vec::new(); let mut excerpt_file_snippets = Vec::new(); for (file_path, snippets) in file_to_snippets { - if file_path == &self.request.excerpt_path { + if file_path == self.request.excerpt_path.as_ref() { excerpt_file_snippets = snippets; } else { file_snippets.push((file_path, snippets, false)); } } let excerpt_snippet = PlannedSnippet { - path: &self.request.excerpt_path, + path: self.request.excerpt_path.clone(), range: self.request.excerpt_range.clone(), text: &self.request.excerpt, text_is_truncated: false, @@ -294,32 +320,39 @@ impl<'a> PlannedPrompt<'a> { excerpt_file_snippets.push(&excerpt_snippet); file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true)); - let mut excerpt_file_insertions = vec![ - ( - self.request.excerpt_range.start, - EDITABLE_REGION_START_MARKER_WITH_NEWLINE, - ), - ( + let mut excerpt_file_insertions = match self.request.prompt_format { + PromptFormat::MarkedExcerpt => vec![ + ( + self.request.excerpt_range.start, + EDITABLE_REGION_START_MARKER_WITH_NEWLINE, + ), + ( + self.request.excerpt_range.start + self.request.cursor_offset, + CURSOR_MARKER, + ), + ( + self.request + .excerpt_range + .end + .saturating_sub(0) + .max(self.request.excerpt_range.start), + EDITABLE_REGION_END_MARKER_WITH_NEWLINE, + ), + ], + PromptFormat::LabeledSections => vec![( self.request.excerpt_range.start + self.request.cursor_offset, CURSOR_MARKER, - ), - ( - self.request - .excerpt_range - .end - .saturating_sub(0) - .max(self.request.excerpt_range.start), - EDITABLE_REGION_END_MARKER_WITH_NEWLINE, - ), - ]; - - let mut output = String::new(); - output.push_str("## User Edits\n\n"); - Self::push_events(&mut output, &self.request.events); - - output.push_str("\n## Code\n\n"); - Self::push_file_snippets(&mut output, &mut excerpt_file_insertions, file_snippets); - output + )], + }; + + let mut prompt = String::new(); + prompt.push_str("## User Edits\n\n"); + Self::push_events(&mut prompt, &self.request.events); + + prompt.push_str("\n## Code\n\n"); + let section_labels = + self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?; + Ok((prompt, section_labels)) } fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) { @@ -366,79 +399,93 @@ impl<'a> PlannedPrompt<'a> { } fn push_file_snippets( + &self, output: &mut String, excerpt_file_insertions: &mut Vec<(usize, &'static str)>, - file_snippets: Vec<(&Path, Vec<&PlannedSnippet>, bool)>, - ) { - fn push_excerpt_file_range( - range: Range, - text: &str, - excerpt_file_insertions: &mut Vec<(usize, &'static str)>, - output: &mut String, - ) { - let mut last_offset = range.start; - let mut i = 0; - while i < excerpt_file_insertions.len() { - let (offset, insertion) = &excerpt_file_insertions[i]; - let found = *offset >= range.start && *offset <= range.end; - if found { - output.push_str(&text[last_offset - range.start..offset - range.start]); - output.push_str(insertion); - last_offset = *offset; - excerpt_file_insertions.remove(i); - continue; - } - i += 1; - } - output.push_str(&text[last_offset - range.start..]); - } + file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>, + ) -> Result { + let mut section_ranges = Vec::new(); + let mut excerpt_index = None; for (file_path, mut snippets, is_excerpt_file) in file_snippets { - output.push_str(&format!("```{}\n", file_path.display())); - - let mut last_included_range: Option> = None; snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end))); + + // TODO: What if the snippets get expanded too large to be editable? + let mut current_snippet: Option<(&PlannedSnippet, Range)> = None; + let mut disjoint_snippets: Vec<(&PlannedSnippet, Range)> = Vec::new(); for snippet in snippets { - if let Some(last_range) = &last_included_range - && snippet.range.start < last_range.end + if let Some((_, current_snippet_range)) = current_snippet.as_mut() + && snippet.range.start < current_snippet_range.end { - if snippet.range.end <= last_range.end { - continue; - } - // TODO: Should probably also handle case where there is just one char (newline) - // between snippets - assume it's a newline. - let text = &snippet.text[last_range.end - snippet.range.start..]; - if is_excerpt_file { - push_excerpt_file_range( - last_range.end..snippet.range.end, - text, - excerpt_file_insertions, - output, - ); - } else { - output.push_str(text); + if snippet.range.end > current_snippet_range.end { + current_snippet_range.end = snippet.range.end; } - last_included_range = Some(last_range.start..snippet.range.end); continue; } - if last_included_range.is_some() { - output.push_str("…\n"); + if let Some(current_snippet) = current_snippet.take() { + disjoint_snippets.push(current_snippet); + } + current_snippet = Some((snippet, snippet.range.clone())); + } + if let Some(current_snippet) = current_snippet.take() { + disjoint_snippets.push(current_snippet); + } + + writeln!(output, "```{}", file_path.display()).ok(); + for (snippet, range) in disjoint_snippets { + let section_index = section_ranges.len(); + + match self.request.prompt_format { + PromptFormat::MarkedExcerpt => { + if range.start > 0 { + output.push_str("…\n"); + } + } + PromptFormat::LabeledSections => { + if is_excerpt_file + && range.start <= self.request.excerpt_range.start + && range.end >= self.request.excerpt_range.end + { + writeln!(output, "<|current_section|>").ok(); + } else { + writeln!(output, "<|section_{}|>", section_index).ok(); + } + } } + if is_excerpt_file { - push_excerpt_file_range( - snippet.range.clone(), - snippet.text, - excerpt_file_insertions, - output, - ); + excerpt_index = Some(section_index); + let mut last_offset = range.start; + let mut i = 0; + while i < excerpt_file_insertions.len() { + let (offset, insertion) = &excerpt_file_insertions[i]; + let found = *offset >= range.start && *offset <= range.end; + if found { + output.push_str( + &snippet.text[last_offset - range.start..offset - range.start], + ); + output.push_str(insertion); + last_offset = *offset; + excerpt_file_insertions.remove(i); + continue; + } + i += 1; + } + output.push_str(&snippet.text[last_offset - range.start..]); } else { output.push_str(snippet.text); } - last_included_range = Some(snippet.range.clone()); + + section_ranges.push((snippet.path.clone(), range)); } output.push_str("```\n\n"); } + + Ok(SectionLabels { + excerpt_index: excerpt_index.context("bug: no snippet found for excerpt")?, + section_ranges, + }) } } diff --git a/crates/zeta2/src/prediction.rs b/crates/zeta2/src/prediction.rs index d7b3c584a0324869921bff2868838a7ba09585ac..cca41efb7c62224e1601001a55d5c6c4c50ff47a 100644 --- a/crates/zeta2/src/prediction.rs +++ b/crates/zeta2/src/prediction.rs @@ -182,7 +182,7 @@ mod tests { // TODO cover more cases when multi-file is supported let big_edits = vec![predict_edits_v3::Edit { - path: PathBuf::from("test.txt"), + path: PathBuf::from("test.txt").into(), range: 0..old.len(), content: new.into(), }]; diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index db461f29b824edb43a2ba416a29e7eaecb276627..5f621d6acf11b1f42e5c2334b8cf03f8e1176d0a 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -1,7 +1,7 @@ use anyhow::{Context as _, Result, anyhow}; use chrono::TimeDelta; use client::{Client, EditPredictionUsage, UserStore}; -use cloud_llm_client::predict_edits_v3::{self, Signature}; +use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature}; use cloud_llm_client::{ EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME, }; @@ -23,7 +23,7 @@ use language_model::{LlmApiToken, RefreshLlmTokenListener}; use project::Project; use release_channel::AppVersion; use std::collections::{HashMap, VecDeque, hash_map}; -use std::path::PathBuf; +use std::path::Path; use std::str::FromStr as _; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -53,6 +53,7 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { excerpt: DEFAULT_EXCERPT_OPTIONS, max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES, max_diagnostic_bytes: 2048, + prompt_format: PromptFormat::MarkedExcerpt, }; #[derive(Clone)] @@ -76,6 +77,7 @@ pub struct ZetaOptions { pub excerpt: EditPredictionExcerptOptions, pub max_prompt_bytes: usize, pub max_diagnostic_bytes: usize, + pub prompt_format: predict_edits_v3::PromptFormat, } pub struct PredictionDebugInfo { @@ -319,7 +321,7 @@ impl Zeta { }); let options = self.options.clone(); let snapshot = buffer.read(cx).snapshot(); - let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else { + let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else { return Task::ready(Err(anyhow!("No file path for excerpt"))); }; let client = self.client.clone(); @@ -412,7 +414,7 @@ impl Zeta { ); let request = make_cloud_request( - excerpt_path.clone(), + excerpt_path, context, events, // TODO data collection @@ -424,6 +426,7 @@ impl Zeta { &worktree_snapshots, index_state.as_deref(), Some(options.max_prompt_bytes), + options.prompt_format, ); let retrieval_time = chrono::Utc::now() - before_retrieval; @@ -686,7 +689,7 @@ impl Zeta { .context("Failed to select excerpt") .map(|context| { make_cloud_request( - excerpt_path.clone(), + excerpt_path.into(), context, // TODO pass everything Vec::new(), @@ -698,6 +701,7 @@ impl Zeta { &worktree_snapshots, index_state.as_deref(), Some(options.max_prompt_bytes), + options.prompt_format, ) }) }) @@ -713,7 +717,7 @@ pub struct ZedUpdateRequiredError { } fn make_cloud_request( - excerpt_path: PathBuf, + excerpt_path: Arc, context: EditPredictionContext, events: Vec, can_collect_data: bool, @@ -724,6 +728,7 @@ fn make_cloud_request( worktrees: &Vec, index_state: Option<&SyntaxIndexState>, prompt_max_bytes: Option, + prompt_format: PromptFormat, ) -> predict_edits_v3::PredictEditsRequest { let mut signatures = Vec::new(); let mut declaration_to_signature_index = HashMap::default(); @@ -755,7 +760,7 @@ fn make_cloud_request( let (text, text_is_truncated) = snippet.declaration.item_text(); referenced_declarations.push(predict_edits_v3::ReferencedDeclaration { - path: path.as_std_path().to_path_buf(), + path: path.as_std_path().into(), text: text.into(), range: snippet.declaration.item_range(), text_is_truncated, @@ -797,6 +802,7 @@ fn make_cloud_request( git_info, debug_info, prompt_max_bytes, + prompt_format, } } diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs index 4b0d35936076c9f16a811d5c3a2a35643a768020..3913677915b3535b1e2be3993606457033edbf33 100644 --- a/crates/zeta2_tools/src/zeta2_tools.rs +++ b/crates/zeta2_tools/src/zeta2_tools.rs @@ -244,11 +244,13 @@ impl Zeta2Inspector { ), }; + let zeta_options = this.zeta.read(cx).options(); this.set_options( ZetaOptions { excerpt: excerpt_options, max_prompt_bytes: number_input_value(&this.max_prompt_bytes_input, cx), - max_diagnostic_bytes: this.zeta.read(cx).options().max_diagnostic_bytes, + max_diagnostic_bytes: zeta_options.max_diagnostic_bytes, + prompt_format: zeta_options.prompt_format, }, cx, ); diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml index 38b85d7c3ac583b25f72240bfde6109a04e30c10..7132340a4d884d4e85c8e67330ca01fb6315b514 100644 --- a/crates/zeta_cli/Cargo.toml +++ b/crates/zeta_cli/Cargo.toml @@ -16,6 +16,7 @@ path = "src/main.rs" anyhow.workspace = true clap.workspace = true client.workspace = true +cloud_llm_client.workspace= true cloud_zeta2_prompt.workspace= true debug_adapter_extension.workspace = true edit_prediction_context.workspace = true diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 4a6efbefb0407a1f835da95413b616a2dd0328ef..4460d660055bb5647ea5ef8f87d049d9c115b308 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -2,6 +2,7 @@ mod headless; use anyhow::{Result, anyhow}; use clap::{Args, Parser, Subcommand}; +use cloud_llm_client::predict_edits_v3::PromptFormat; use edit_prediction_context::EditPredictionExcerptOptions; use futures::channel::mpsc; use futures::{FutureExt as _, StreamExt as _}; @@ -74,6 +75,16 @@ struct Zeta2Args { target_before_cursor_over_total_bytes: f32, #[arg(long, default_value_t = 1024)] max_diagnostic_bytes: usize, + #[arg(long, value_parser = parse_format)] + format: PromptFormat, +} + +fn parse_format(s: &str) -> Result { + match s { + "marked_excerpt" => Ok(PromptFormat::MarkedExcerpt), + "labeled_sections" => Ok(PromptFormat::LabeledSections), + _ => Err(anyhow!("Invalid format: {}", s)), + } } #[derive(Debug, Clone)] @@ -228,6 +239,7 @@ async fn get_context( }, max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes, max_prompt_bytes: zeta2_args.max_prompt_bytes, + prompt_format: zeta2_args.format, }) }); // TODO: Actually wait for indexing. @@ -240,13 +252,9 @@ async fn get_context( zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx) })? .await?; - let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate( - &request, - &cloud_zeta2_prompt::PlanOptions { - max_bytes: zeta2_args.max_prompt_bytes, - }, - )?; - anyhow::Ok(planned_prompt.to_prompt_string()) + let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(&request)?; + // TODO: Output the section label ranges + anyhow::Ok(planned_prompt.to_prompt_string()?.0) }) })? .await?,