Detailed changes
@@ -3225,6 +3225,18 @@ dependencies = [
"workspace-hack",
]
+[[package]]
+name = "cloud_zeta2_prompt"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "cloud_llm_client",
+ "ordered-float 2.10.1",
+ "rustc-hash 2.1.1",
+ "strum 0.27.1",
+ "workspace-hack",
+]
+
[[package]]
name = "clru"
version = "0.6.2"
@@ -21683,7 +21695,9 @@ dependencies = [
"anyhow",
"clap",
"client",
+ "cloud_zeta2_prompt",
"debug_adapter_extension",
+ "edit_prediction_context",
"extension",
"fs",
"futures 0.3.31",
@@ -21710,6 +21724,7 @@ dependencies = [
"watch",
"workspace-hack",
"zeta",
+ "zeta2",
]
[[package]]
@@ -35,6 +35,7 @@ members = [
"crates/cloud_api_client",
"crates/cloud_api_types",
"crates/cloud_llm_client",
+ "crates/cloud_zeta2_prompt",
"crates/collab",
"crates/collab_ui",
"crates/collections",
@@ -271,6 +272,7 @@ clock = { path = "crates/clock" }
cloud_api_client = { path = "crates/cloud_api_client" }
cloud_api_types = { path = "crates/cloud_api_types" }
cloud_llm_client = { path = "crates/cloud_llm_client" }
+cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
collab = { path = "crates/collab" }
collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections" }
@@ -48,6 +48,9 @@ pub struct Signature {
pub text_is_truncated: bool,
#[serde(skip_serializing_if = "Option::is_none", default)]
pub parent_index: Option<usize>,
+ /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
+ /// file is implicitly the one containing the descendant declaration or excerpt.
+ pub range: Range<usize>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -55,7 +58,7 @@ pub struct ReferencedDeclaration {
pub path: PathBuf,
pub text: String,
pub text_is_truncated: bool,
- /// Range of `text` within file, potentially truncated according to `text_is_truncated`
+ /// Range of `text` within file, possibly truncated according to `text_is_truncated`
pub range: Range<usize>,
/// Range within `text`
pub signature_range: Range<usize>,
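As the doc comments above spell out, `Signature::range` is a file-relative offset range, while `ReferencedDeclaration::signature_range` is relative to its `text`; converting the latter to file offsets therefore means adding `range.start`, which is exactly what the prompt planner does. A tiny illustration with made-up offsets (not taken from the codebase):

```rust
use std::ops::Range;

fn main() {
    // Hypothetical declaration occupying bytes 200..500 of its file.
    let range: Range<usize> = 200..500;
    // Hypothetical signature at bytes 10..40 of the declaration's `text`.
    let signature_range: Range<usize> = 10..40;

    // File-relative position of the signature: shift by the declaration's start.
    let signature_in_file =
        (signature_range.start + range.start)..(signature_range.end + range.start);
    assert_eq!(signature_in_file, 210..240);
}
```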
@@ -0,0 +1,20 @@
+[package]
+name = "cloud_zeta2_prompt"
+version = "0.1.0"
+publish.workspace = true
+edition.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/cloud_zeta2_prompt.rs"
+
+[dependencies]
+anyhow.workspace = true
+cloud_llm_client.workspace = true
+ordered-float.workspace = true
+rustc-hash.workspace = true
+strum.workspace = true
+workspace-hack.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,396 @@
+//! Zeta2 prompt planning and generation code shared with cloud.
+
+use anyhow::{Result, anyhow};
+use cloud_llm_client::predict_edits_v3::{self, ReferencedDeclaration};
+use ordered_float::OrderedFloat;
+use rustc_hash::{FxHashMap, FxHashSet};
+use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
+use strum::{EnumIter, IntoEnumIterator};
+
+pub const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
+/// NOTE: Differs from the zed version of this constant - includes a trailing newline
+pub const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>\n";
+/// NOTE: Differs from the zed version of this constant - includes a trailing newline
+pub const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>\n";
+
+pub struct PlannedPrompt<'a> {
+ request: &'a predict_edits_v3::PredictEditsRequest,
+ /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
+ /// `to_prompt_string`.
+ snippets: Vec<PlannedSnippet<'a>>,
+ budget_used: usize,
+}
+
+pub struct PlanOptions {
+ pub max_bytes: usize,
+}
+
+#[derive(Clone, Debug)]
+pub struct PlannedSnippet<'a> {
+ path: &'a Path,
+ range: Range<usize>,
+ text: &'a str,
+ // TODO: Indicate this in the output
+ #[allow(dead_code)]
+ text_is_truncated: bool,
+}
+
+#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
+pub enum SnippetStyle {
+ Signature,
+ Declaration,
+}
+
+impl<'a> PlannedPrompt<'a> {
+ /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
+ ///
+ /// Initializes a priority queue by populating it with each snippet, using the SnippetStyle
+ /// that maximizes `score_density = score / snippet.range(style).len()`. When a "signature"
+ /// snippet is popped, an entry for the "declaration" variant is inserted that reflects the
+ /// marginal cost of the upgrade.
+ ///
+ /// TODO: Implement an early halting condition. One option might be to have another priority
+ /// queue where the score is the size, and update it accordingly. Another option might be to
+ /// have some simpler heuristic like bailing after N failed insertions, or based on how much
+ /// budget is left.
+ ///
+ /// TODO: Currently has the following known sources of imprecision:
+ ///
+ /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
+ /// plan even though the containing struct is already included.
+ ///
+ /// * Does not consider cost of signatures when ranking snippets - this is tricky since
+ /// signatures may be shared by multiple snippets.
+ ///
+ /// * Does not include file paths / other text when considering max_bytes.
+ pub fn populate(
+ request: &'a predict_edits_v3::PredictEditsRequest,
+ options: &PlanOptions,
+ ) -> Result<Self> {
+ let mut this = PlannedPrompt {
+ request,
+ snippets: Vec::new(),
+ budget_used: request.excerpt.len(),
+ };
+ let mut included_parents = FxHashSet::default();
+ let additional_parents = this.additional_parent_signatures(
+ &request.excerpt_path,
+ request.excerpt_parent,
+ &included_parents,
+ )?;
+ this.add_parents(&mut included_parents, additional_parents);
+
+ if this.budget_used > options.max_bytes {
+ return Err(anyhow!(
+ "Excerpt + signatures size of {} already exceeds budget of {}",
+ this.budget_used,
+ options.max_bytes
+ ));
+ }
+
+ #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
+ struct QueueEntry {
+ score_density: OrderedFloat<f32>,
+ declaration_index: usize,
+ style: SnippetStyle,
+ }
+
+ // Initialize priority queue with the best score for each snippet.
+ let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
+ for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
+ let (style, score_density) = SnippetStyle::iter()
+ .map(|style| {
+ (
+ style,
+ OrderedFloat(declaration_score_density(&declaration, style)),
+ )
+ })
+ .max_by_key(|(_, score_density)| *score_density)
+ .unwrap();
+ queue.push(QueueEntry {
+ score_density,
+ declaration_index,
+ style,
+ });
+ }
+
+ // Knapsack selection loop
+ while let Some(queue_entry) = queue.pop() {
+ let Some(declaration) = request
+ .referenced_declarations
+ .get(queue_entry.declaration_index)
+ else {
+ return Err(anyhow!(
+ "Invalid declaration index {}",
+ queue_entry.declaration_index
+ ));
+ };
+
+ let mut additional_bytes = declaration_size(declaration, queue_entry.style);
+ if this.budget_used + additional_bytes > options.max_bytes {
+ continue;
+ }
+
+ let additional_parents = this.additional_parent_signatures(
+ &declaration.path,
+ declaration.parent_index,
+ &mut included_parents,
+ )?;
+ additional_bytes += additional_parents
+ .iter()
+ .map(|(_, snippet)| snippet.text.len())
+ .sum::<usize>();
+ if this.budget_used + additional_bytes > options.max_bytes {
+ continue;
+ }
+
+ this.budget_used += additional_bytes;
+ this.add_parents(&mut included_parents, additional_parents);
+ let planned_snippet = match queue_entry.style {
+ SnippetStyle::Signature => {
+ let Some(text) = declaration.text.get(declaration.signature_range.clone())
+ else {
+ return Err(anyhow!(
+ "Invalid declaration signature_range {:?} with text.len() = {}",
+ declaration.signature_range,
+ declaration.text.len()
+ ));
+ };
+ PlannedSnippet {
+ path: &declaration.path,
+ range: (declaration.signature_range.start + declaration.range.start)
+ ..(declaration.signature_range.end + declaration.range.start),
+ text,
+ text_is_truncated: declaration.text_is_truncated,
+ }
+ }
+ SnippetStyle::Declaration => PlannedSnippet {
+ path: &declaration.path,
+ range: declaration.range.clone(),
+ text: &declaration.text,
+ text_is_truncated: declaration.text_is_truncated,
+ },
+ };
+ this.snippets.push(planned_snippet);
+
+ // When a Signature is consumed, insert an entry for the Declaration style.
+ if queue_entry.style == SnippetStyle::Signature {
+ let signature_size = declaration_size(&declaration, SnippetStyle::Signature);
+ let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration);
+ let signature_score = declaration_score(&declaration, SnippetStyle::Signature);
+ let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration);
+
+ let score_diff = declaration_score - signature_score;
+ let size_diff = declaration_size.saturating_sub(signature_size);
+ if score_diff > 0.0001 && size_diff > 0 {
+ queue.push(QueueEntry {
+ declaration_index: queue_entry.declaration_index,
+ score_density: OrderedFloat(score_diff / (size_diff as f32)),
+ style: SnippetStyle::Declaration,
+ });
+ }
+ }
+ }
+
+ anyhow::Ok(this)
+ }
+
+ fn add_parents(
+ &mut self,
+ included_parents: &mut FxHashSet<usize>,
+ snippets: Vec<(usize, PlannedSnippet<'a>)>,
+ ) {
+ for (parent_index, snippet) in snippets {
+ included_parents.insert(parent_index);
+ self.budget_used += snippet.text.len();
+ self.snippets.push(snippet);
+ }
+ }
+
+ fn additional_parent_signatures(
+ &self,
+ path: &'a Path,
+ parent_index: Option<usize>,
+ included_parents: &FxHashSet<usize>,
+ ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
+ let mut results = Vec::new();
+ self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
+ Ok(results)
+ }
+
+ fn additional_parent_signatures_impl(
+ &self,
+ path: &'a Path,
+ parent_index: Option<usize>,
+ included_parents: &FxHashSet<usize>,
+ results: &mut Vec<(usize, PlannedSnippet<'a>)>,
+ ) -> Result<()> {
+ let Some(parent_index) = parent_index else {
+ return Ok(());
+ };
+ if included_parents.contains(&parent_index) {
+ return Ok(());
+ }
+ let Some(parent_signature) = self.request.signatures.get(parent_index) else {
+ return Err(anyhow!("Invalid parent index {}", parent_index));
+ };
+ results.push((
+ parent_index,
+ PlannedSnippet {
+ path,
+ range: parent_signature.range.clone(),
+ text: &parent_signature.text,
+ text_is_truncated: parent_signature.text_is_truncated,
+ },
+ ));
+ self.additional_parent_signatures_impl(
+ path,
+ parent_signature.parent_index,
+ included_parents,
+ results,
+ )
+ }
+
+ /// Renders the planned context. Each file starts with "```FILE_PATH\n" and ends with triple
+ /// backticks, with a newline after each file. Outputs a "…" line between nonconsecutive
+ /// chunks.
+ pub fn to_prompt_string(&self) -> String {
+ let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
+ FxHashMap::default();
+ for snippet in &self.snippets {
+ file_to_snippets
+ .entry(&snippet.path)
+ .or_default()
+ .push(snippet);
+ }
+
+ // Reorder so that file with cursor comes last
+ let mut file_snippets = Vec::new();
+ let mut excerpt_file_snippets = Vec::new();
+ for (file_path, snippets) in file_to_snippets {
+ if file_path == &self.request.excerpt_path {
+ excerpt_file_snippets = snippets;
+ } else {
+ file_snippets.push((file_path, snippets, false));
+ }
+ }
+ let excerpt_snippet = PlannedSnippet {
+ path: &self.request.excerpt_path,
+ range: self.request.excerpt_range.clone(),
+ text: &self.request.excerpt,
+ text_is_truncated: false,
+ };
+ excerpt_file_snippets.push(&excerpt_snippet);
+ file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
+
+ let mut excerpt_file_insertions = vec![
+ (
+ self.request.excerpt_range.start,
+ EDITABLE_REGION_START_MARKER,
+ ),
+ (
+ self.request.excerpt_range.start + self.request.cursor_offset,
+ CURSOR_MARKER,
+ ),
+ (
+ self.request
+ .excerpt_range
+ .end
+ .saturating_sub(0)
+ .max(self.request.excerpt_range.start),
+ EDITABLE_REGION_END_MARKER,
+ ),
+ ];
+
+ fn push_excerpt_file_range(
+ range: Range<usize>,
+ text: &str,
+ excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
+ output: &mut String,
+ ) {
+ let mut last_offset = range.start;
+ let mut i = 0;
+ while i < excerpt_file_insertions.len() {
+ let (offset, insertion) = &excerpt_file_insertions[i];
+ let found = *offset >= range.start && *offset <= range.end;
+ if found {
+ output.push_str(&text[last_offset - range.start..offset - range.start]);
+ output.push_str(insertion);
+ last_offset = *offset;
+ excerpt_file_insertions.remove(i);
+ continue;
+ }
+ i += 1;
+ }
+ output.push_str(&text[last_offset - range.start..]);
+ }
+
+ let mut output = String::new();
+ for (file_path, mut snippets, is_excerpt_file) in file_snippets {
+ output.push_str(&format!("```{}\n", file_path.display()));
+
+ let mut last_included_range: Option<Range<usize>> = None;
+ snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
+ for snippet in snippets {
+ if let Some(last_range) = &last_included_range
+ && snippet.range.start < last_range.end
+ {
+ if snippet.range.end <= last_range.end {
+ continue;
+ }
+ // TODO: Should probably also handle case where there is just one char (newline)
+ // between snippets - assume it's a newline.
+ let text = &snippet.text[last_range.end - snippet.range.start..];
+ if is_excerpt_file {
+ push_excerpt_file_range(
+ last_range.end..snippet.range.end,
+ text,
+ &mut excerpt_file_insertions,
+ &mut output,
+ );
+ } else {
+ output.push_str(text);
+ }
+ last_included_range = Some(last_range.start..snippet.range.end);
+ continue;
+ }
+ if last_included_range.is_some() {
+ output.push_str("…\n");
+ }
+ if is_excerpt_file {
+ push_excerpt_file_range(
+ snippet.range.clone(),
+ snippet.text,
+ &mut excerpt_file_insertions,
+ &mut output,
+ );
+ } else {
+ output.push_str(snippet.text);
+ }
+ last_included_range = Some(snippet.range.clone());
+ }
+
+ output.push_str("```\n\n");
+ }
+
+ output
+ }
+}
+
+fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
+ declaration_score(declaration, style) / declaration_size(declaration, style) as f32
+}
+
+fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
+ match style {
+ SnippetStyle::Signature => declaration.signature_score,
+ SnippetStyle::Declaration => declaration.declaration_score,
+ }
+}
+
+fn declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize {
+ match style {
+ SnippetStyle::Signature => declaration.signature_range.len(),
+ SnippetStyle::Declaration => declaration.text.len(),
+ }
+}
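For reference, here is a minimal, self-contained sketch of the greedy "pick by score density, then offer an upgrade" selection that `PlannedPrompt::populate` implements. It uses simplified stand-in types (`Decl`, `Style`, and `select` are illustrative names, not part of the crate) and a plain `Vec` scan in place of the `BinaryHeap`/`OrderedFloat` machinery above:

```rust
// Minimal stand-in for ReferencedDeclaration: just the sizes and scores the
// selection loop looks at.
struct Decl {
    signature_len: usize,
    declaration_len: usize,
    signature_score: f32,
    declaration_score: f32,
}

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Style {
    Signature,
    Declaration,
}

fn size(d: &Decl, style: Style) -> usize {
    match style {
        Style::Signature => d.signature_len,
        Style::Declaration => d.declaration_len,
    }
}

fn score(d: &Decl, style: Style) -> f32 {
    match style {
        Style::Signature => d.signature_score,
        Style::Declaration => d.declaration_score,
    }
}

/// Greedily pick snippets by score density; when a signature is taken, add an
/// "upgrade to full declaration" candidate priced at its marginal cost.
fn select(decls: &[Decl], max_bytes: usize) -> Vec<(usize, Style)> {
    // (density, declaration index, style, cost in bytes)
    let mut candidates: Vec<(f32, usize, Style, usize)> = decls
        .iter()
        .enumerate()
        .map(|(i, d)| {
            // Start each declaration at whichever style has the higher density.
            let (style, density) = [Style::Signature, Style::Declaration]
                .into_iter()
                .map(|s| (s, score(d, s) / size(d, s) as f32))
                .max_by(|a, b| a.1.total_cmp(&b.1))
                .unwrap();
            (density, i, style, size(d, style))
        })
        .collect();

    let mut budget_used = 0;
    let mut selected = Vec::new();
    while !candidates.is_empty() {
        // Pop the densest remaining candidate (the real code uses a BinaryHeap).
        let best = candidates
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.0.total_cmp(&b.0))
            .map(|(pos, _)| pos)
            .unwrap();
        let (_, i, style, cost) = candidates.swap_remove(best);

        if budget_used + cost > max_bytes {
            continue; // Doesn't fit; keep trying less dense candidates.
        }
        budget_used += cost;
        selected.push((i, style));

        if style == Style::Signature {
            // Offer the upgrade: extra bytes and extra score beyond the signature.
            let d = &decls[i];
            let extra_bytes =
                size(d, Style::Declaration).saturating_sub(size(d, Style::Signature));
            let extra_score = score(d, Style::Declaration) - score(d, Style::Signature);
            if extra_bytes > 0 && extra_score > 0.0001 {
                candidates.push((
                    extra_score / extra_bytes as f32,
                    i,
                    Style::Declaration,
                    extra_bytes,
                ));
            }
        }
    }
    selected
}

fn main() {
    let decls = vec![
        Decl { signature_len: 40, declaration_len: 400, signature_score: 2.0, declaration_score: 5.0 },
        Decl { signature_len: 30, declaration_len: 900, signature_score: 1.5, declaration_score: 2.0 },
    ];
    // With a 500-byte budget, the first declaration gets upgraded and the
    // second stays a signature.
    println!("{:?}", select(&decls, 500));
}
```

As in the real code, a candidate that no longer fits is skipped rather than terminating the loop, and an upgrade entry is priced at its marginal bytes and marginal score, not its full size.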
@@ -68,7 +68,7 @@ impl Declaration {
pub fn item_range(&self) -> Range<usize> {
match self {
- Declaration::File { declaration, .. } => declaration.item_range_in_file.clone(),
+ Declaration::File { declaration, .. } => declaration.item_range.clone(),
Declaration::Buffer { declaration, .. } => declaration.item_range.clone(),
}
}
@@ -92,7 +92,7 @@ impl Declaration {
pub fn signature_text(&self) -> (Cow<'_, str>, bool) {
match self {
Declaration::File { declaration, .. } => (
- declaration.text[declaration.signature_range_in_text.clone()].into(),
+ declaration.text[self.signature_range_in_item_text()].into(),
declaration.signature_is_truncated,
),
Declaration::Buffer {
@@ -105,15 +105,19 @@ impl Declaration {
}
}
- pub fn signature_range_in_item_text(&self) -> Range<usize> {
+ pub fn signature_range(&self) -> Range<usize> {
match self {
- Declaration::File { declaration, .. } => declaration.signature_range_in_text.clone(),
- Declaration::Buffer { declaration, .. } => {
- declaration.signature_range.start - declaration.item_range.start
- ..declaration.signature_range.end - declaration.item_range.start
- }
+ Declaration::File { declaration, .. } => declaration.signature_range.clone(),
+ Declaration::Buffer { declaration, .. } => declaration.signature_range.clone(),
}
}
+
+ pub fn signature_range_in_item_text(&self) -> Range<usize> {
+ let signature_range = self.signature_range();
+ let item_range = self.item_range();
+ signature_range.start.saturating_sub(item_range.start)
+ ..(signature_range.end.saturating_sub(item_range.start)).min(item_range.len())
+ }
}
fn expand_range_to_line_boundaries_and_truncate(
@@ -141,13 +145,13 @@ pub struct FileDeclaration {
pub parent: Option<DeclarationId>,
pub identifier: Identifier,
/// offset range of the declaration in the file, expanded to line boundaries and truncated
- pub item_range_in_file: Range<usize>,
- /// text of `item_range_in_file`
+ pub item_range: Range<usize>,
+ /// text of `item_range`
pub text: Arc<str>,
/// whether `text` was truncated
pub text_is_truncated: bool,
- /// offset range of the signature within `text`
- pub signature_range_in_text: Range<usize>,
+ /// offset range of the signature in the file, expanded to line boundaries and truncated
+ pub signature_range: Range<usize>,
/// whether `signature` was truncated
pub signature_is_truncated: bool,
}
@@ -160,31 +164,33 @@ impl FileDeclaration {
rope,
);
- // TODO: consider logging if unexpected
- let signature_start = declaration
- .signature_range
- .start
- .saturating_sub(item_range_in_file.start);
- let mut signature_end = declaration
- .signature_range
- .end
- .saturating_sub(item_range_in_file.start);
- let signature_is_truncated = signature_end > item_range_in_file.len();
- if signature_is_truncated {
- signature_end = item_range_in_file.len();
+ let (mut signature_range_in_file, mut signature_is_truncated) =
+ expand_range_to_line_boundaries_and_truncate(
+ &declaration.signature_range,
+ ITEM_TEXT_TRUNCATION_LENGTH,
+ rope,
+ );
+
+ if signature_range_in_file.start < item_range_in_file.start {
+ signature_range_in_file.start = item_range_in_file.start;
+ signature_is_truncated = true;
+ }
+ if signature_range_in_file.end > item_range_in_file.end {
+ signature_range_in_file.end = item_range_in_file.end;
+ signature_is_truncated = true;
}
FileDeclaration {
parent: None,
identifier: declaration.identifier,
- signature_range_in_text: signature_start..signature_end,
+ signature_range: signature_range_in_file,
signature_is_truncated,
text: rope
.chunks_in_range(item_range_in_file.clone())
.collect::<String>()
.into(),
text_is_truncated,
- item_range_in_file,
+ item_range: item_range_in_file,
}
}
}
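A small sketch of the arithmetic above: the signature range stays in file offsets but is clamped into `item_range` (setting the truncation flag if anything is cut off), and `signature_range_in_item_text` then rebases it onto the item's text. The helper names and offsets here are made up for illustration:

```rust
use std::ops::Range;

/// Clamp a file-relative signature range to the item's file-relative range,
/// reporting whether anything was cut off.
fn clamp_to_item(signature: Range<usize>, item: &Range<usize>) -> (Range<usize>, bool) {
    let mut clamped = signature;
    let mut truncated = false;
    if clamped.start < item.start {
        clamped.start = item.start;
        truncated = true;
    }
    if clamped.end > item.end {
        clamped.end = item.end;
        truncated = true;
    }
    (clamped, truncated)
}

/// Rebase a file-relative signature range onto the item's text
/// (offset 0 = start of `item`).
fn signature_range_in_item_text(signature: &Range<usize>, item: &Range<usize>) -> Range<usize> {
    signature.start.saturating_sub(item.start)
        ..signature.end.saturating_sub(item.start).min(item.len())
}

fn main() {
    // Hypothetical offsets: an item at bytes 100..300 of a file whose signature
    // was expanded to 90..160 before clamping.
    let item = 100..300;
    let (signature, truncated) = clamp_to_item(90..160, &item);
    assert_eq!(signature, 100..160);
    assert!(truncated);
    assert_eq!(signature_range_in_item_text(&signature, &item), 0..60);
}
```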
@@ -40,10 +40,9 @@ impl ScoredSnippet {
}
pub fn size(&self, style: SnippetStyle) -> usize {
- // TODO: how to handle truncation?
match &self.declaration {
Declaration::File { declaration, .. } => match style {
- SnippetStyle::Signature => declaration.signature_range_in_text.len(),
+ SnippetStyle::Signature => declaration.signature_range.len(),
SnippetStyle::Declaration => declaration.text.len(),
},
Declaration::Buffer { declaration, .. } => match style {
@@ -276,6 +275,8 @@ pub struct Scores {
impl Scores {
fn score(components: &ScoreComponents) -> Scores {
+ // TODO: handle truncation
+
// Score related to how likely this is the correct declaration, range 0 to 1
let accuracy_score = if components.is_same_file {
// TODO: use declaration_line_distance_rank
@@ -578,11 +578,11 @@ mod tests {
let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx);
assert_eq!(decl.identifier, main.clone());
- assert_eq!(decl.item_range_in_file, 32..280);
+ assert_eq!(decl.item_range, 32..280);
let decl = expect_file_decl("a.rs", &decls[1].1, &project, cx);
assert_eq!(decl.identifier, main);
- assert_eq!(decl.item_range_in_file, 0..98);
+ assert_eq!(decl.item_range, 0..98);
});
}
@@ -1,35 +0,0 @@
-// To discuss: What to send to the new endpoint? Thinking it'd make sense to put `prompt.rs` from
-// `zeta_context.rs` in cloud.
-//
-// * Run excerpt selection at several different sizes, send the largest size with offsets within for
-// the smaller sizes.
-//
-// * Longer event history.
-//
-// * Many more snippets than could fit in model context - allows ranking experimentation.
-
-pub struct Zeta2Request {
- pub event_history: Vec<Event>,
- pub excerpt: String,
- pub excerpt_subsets: Vec<Zeta2ExcerptSubset>,
- /// Within `excerpt`
- pub cursor_position: usize,
- pub signatures: Vec<String>,
- pub retrieved_declarations: Vec<ReferencedDeclaration>,
-}
-
-pub struct Zeta2ExcerptSubset {
- /// Within `excerpt` text.
- pub excerpt_range: Range<usize>,
- /// Within `signatures`.
- pub parent_signatures: Vec<usize>,
-}
-
-pub struct ReferencedDeclaration {
- pub text: Arc<str>,
- /// Range within `text`
- pub signature_range: Range<usize>,
- /// Indices within `signatures`.
- pub parent_signatures: Vec<usize>,
- // A bunch of score metrics
-}
@@ -48,7 +48,7 @@ pub struct Zeta {
llm_token: LlmApiToken,
_llm_token_subscription: Subscription,
projects: HashMap<EntityId, ZetaProject>,
- excerpt_options: EditPredictionExcerptOptions,
+ pub excerpt_options: EditPredictionExcerptOptions,
update_required: bool,
}
@@ -87,7 +87,7 @@ impl Zeta {
})
}
- fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
+ pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
Self {
@@ -478,6 +478,66 @@ impl Zeta {
}
}
}
+
+ // TODO: Dedupe with similar code in request_prediction?
+ pub fn cloud_request_for_zeta_cli(
+ &mut self,
+ project: &Entity<Project>,
+ buffer: &Entity<Buffer>,
+ position: language::Anchor,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
+ let project_state = self.projects.get(&project.entity_id());
+
+ let index_state = project_state.map(|state| {
+ state
+ .syntax_index
+ .read_with(cx, |index, _cx| index.state().clone())
+ });
+ let excerpt_options = self.excerpt_options.clone();
+ let snapshot = buffer.read(cx).snapshot();
+ let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
+ return Task::ready(Err(anyhow!("No file path for excerpt")));
+ };
+ let worktree_snapshots = project
+ .read(cx)
+ .worktrees(cx)
+ .map(|worktree| worktree.read(cx).snapshot())
+ .collect::<Vec<_>>();
+
+ cx.background_spawn(async move {
+ let index_state = if let Some(index_state) = index_state {
+ Some(index_state.lock_owned().await)
+ } else {
+ None
+ };
+
+ let cursor_point = position.to_point(&snapshot);
+
+ let debug_info = true;
+ EditPredictionContext::gather_context(
+ cursor_point,
+ &snapshot,
+ &excerpt_options,
+ index_state.as_deref(),
+ )
+ .context("Failed to select excerpt")
+ .map(|context| {
+ make_cloud_request(
+ excerpt_path.clone(),
+ context,
+ // TODO pass everything
+ Vec::new(),
+ false,
+ Vec::new(),
+ None,
+ debug_info,
+ &worktree_snapshots,
+ index_state.as_deref(),
+ )
+ })
+ })
+ }
}
#[derive(Error, Debug)]
@@ -840,13 +900,13 @@ fn make_cloud_request(
for snippet in context.snippets {
let project_entry_id = snippet.declaration.project_entry_id();
- // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot.
- // Note that currently full_path is currently being used for excerpt_path.
let Some(path) = worktrees.iter().find_map(|worktree| {
- let abs_path = worktree.abs_path();
- worktree
- .entry_for_id(project_entry_id)
- .map(|e| abs_path.join(&e.path))
+ worktree.entry_for_id(project_entry_id).map(|entry| {
+ let mut full_path = PathBuf::new();
+ full_path.push(worktree.root_name());
+ full_path.push(&entry.path);
+ full_path
+ })
}) else {
continue;
};
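The snippet path sent to the cloud is now worktree-rooted rather than absolute: the worktree's root name is joined with the entry's worktree-relative path. A quick sketch with hypothetical values:

```rust
use std::path::PathBuf;

fn main() {
    // Hypothetical worktree root name and worktree-relative entry path.
    let root_name = "zed";
    let entry_path = "crates/zeta/src/zeta.rs";

    let mut full_path = PathBuf::new();
    full_path.push(root_name);
    full_path.push(entry_path);

    // Worktree-rooted path, independent of where the project lives on disk.
    assert_eq!(full_path, PathBuf::from("zed/crates/zeta/src/zeta.rs"));
}
```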
@@ -929,6 +989,7 @@ fn add_signature(
text: text.into(),
text_is_truncated,
parent_index,
+ range: parent_declaration.signature_range(),
});
declaration_to_signature_index.insert(declaration_id, signature_index);
Some(signature_index)
@@ -16,7 +16,9 @@ path = "src/main.rs"
anyhow.workspace = true
clap.workspace = true
client.workspace = true
+cloud_zeta2_prompt.workspace = true
debug_adapter_extension.workspace = true
+edit_prediction_context.workspace = true
extension.workspace = true
fs.workspace = true
futures.workspace = true
@@ -37,9 +39,10 @@ serde.workspace = true
serde_json.workspace = true
settings.workspace = true
shellexpand.workspace = true
+smol.workspace = true
terminal_view.workspace = true
util.workspace = true
watch.workspace = true
workspace-hack.workspace = true
zeta.workspace = true
-smol.workspace = true
+zeta2.workspace = true
@@ -2,6 +2,7 @@ mod headless;
use anyhow::{Result, anyhow};
use clap::{Args, Parser, Subcommand};
+use edit_prediction_context::EditPredictionExcerptOptions;
use futures::channel::mpsc;
use futures::{FutureExt as _, StreamExt as _};
use gpui::{AppContext, Application, AsyncApp};
@@ -18,7 +19,7 @@ use std::process::exit;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
-use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context};
+use zeta::{PerformPredictEditsParams, Zeta};
use crate::headless::ZetaCliAppState;
@@ -32,6 +33,12 @@ struct ZetaCliArgs {
#[derive(Subcommand, Debug)]
enum Commands {
Context(ContextArgs),
+ Zeta2Context {
+ #[clap(flatten)]
+ zeta2_args: Zeta2Args,
+ #[clap(flatten)]
+ context_args: ContextArgs,
+ },
Predict {
#[arg(long)]
predict_edits_body: Option<FileOrStdin>,
@@ -53,6 +60,18 @@ struct ContextArgs {
events: Option<FileOrStdin>,
}
+#[derive(Debug, Args)]
+struct Zeta2Args {
+ #[arg(long, default_value_t = 8192)]
+ prompt_max_bytes: usize,
+ #[arg(long, default_value_t = 2048)]
+ excerpt_max_bytes: usize,
+ #[arg(long, default_value_t = 1024)]
+ excerpt_min_bytes: usize,
+ #[arg(long, default_value_t = 0.66)]
+ target_before_cursor_over_total_bytes: f32,
+}
+
#[derive(Debug, Clone)]
enum FileOrStdin {
File(PathBuf),
@@ -112,11 +131,17 @@ impl FromStr for CursorPosition {
}
}
+enum GetContextOutput {
+ Zeta1(zeta::GatherContextOutput),
+ Zeta2(String),
+}
+
async fn get_context(
+ zeta2_args: Option<Zeta2Args>,
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
-) -> Result<GatherContextOutput> {
+) -> Result<GetContextOutput> {
let ContextArgs {
worktree: worktree_path,
cursor,
@@ -152,9 +177,7 @@ async fn get_context(
open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
(Some(lsp_open_handle), buffer)
} else {
- let abs_path = worktree_path.join(&cursor.path);
- let content = smol::fs::read_to_string(&abs_path).await?;
- let buffer = cx.new(|cx| Buffer::local(content, cx))?;
+ let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
(None, buffer)
};
@@ -189,33 +212,83 @@ async fn get_context(
Some(events) => events.read_to_string().await?,
None => String::new(),
};
- let prompt_for_events = move || (events, 0);
- cx.update(|cx| {
- gather_context(
- full_path_str,
- &snapshot,
- clipped_cursor,
- prompt_for_events,
- cx,
- )
- })?
- .await
+
+ if let Some(zeta2_args) = zeta2_args {
+ Ok(GetContextOutput::Zeta2(
+ cx.update(|cx| {
+ let zeta = cx.new(|cx| {
+ zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
+ });
+ zeta.update(cx, |zeta, cx| {
+ zeta.register_buffer(&buffer, &project, cx);
+ zeta.excerpt_options = EditPredictionExcerptOptions {
+ max_bytes: zeta2_args.excerpt_max_bytes,
+ min_bytes: zeta2_args.excerpt_min_bytes,
+ target_before_cursor_over_total_bytes: zeta2_args
+ .target_before_cursor_over_total_bytes,
+ }
+ });
+ // TODO: Actually wait for indexing.
+ let timer = cx.background_executor().timer(Duration::from_secs(5));
+ cx.spawn(async move |cx| {
+ timer.await;
+ let request = zeta
+ .update(cx, |zeta, cx| {
+ let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
+ zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
+ })?
+ .await?;
+ let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(
+ &request,
+ &cloud_zeta2_prompt::PlanOptions {
+ max_bytes: zeta2_args.prompt_max_bytes,
+ },
+ )?;
+ anyhow::Ok(planned_prompt.to_prompt_string())
+ })
+ })?
+ .await?,
+ ))
+ } else {
+ let prompt_for_events = move || (events, 0);
+ Ok(GetContextOutput::Zeta1(
+ cx.update(|cx| {
+ zeta::gather_context(
+ full_path_str,
+ &snapshot,
+ clipped_cursor,
+ prompt_for_events,
+ cx,
+ )
+ })?
+ .await?,
+ ))
+ }
}
-pub async fn open_buffer_with_language_server(
+pub async fn open_buffer(
project: &Entity<Project>,
worktree: &Entity<Worktree>,
path: &Path,
cx: &mut AsyncApp,
-) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
+) -> Result<Entity<Buffer>> {
let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
worktree_id: worktree.id(),
path: path.to_path_buf().into(),
})?;
- let buffer = project
+ project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
- .await?;
+ .await
+}
+
+pub async fn open_buffer_with_language_server(
+ project: &Entity<Project>,
+ worktree: &Entity<Worktree>,
+ path: &Path,
+ cx: &mut AsyncApp,
+) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
+ let buffer = open_buffer(project, worktree, path, cx).await?;
let lsp_open_handle = project.update(cx, |project, cx| {
project.register_buffer_with_language_servers(&buffer, cx)
@@ -319,11 +392,26 @@ fn main() {
app.run(move |cx| {
let app_state = Arc::new(headless::init(cx));
+ let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. });
cx.spawn(async move |cx| {
let result = match args.command {
- Commands::Context(context_args) => get_context(context_args, &app_state, cx)
- .await
- .map(|output| serde_json::to_string_pretty(&output.body).unwrap()),
+ Commands::Zeta2Context {
+ zeta2_args,
+ context_args,
+ } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
+ Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
+ Ok(GetContextOutput::Zeta2(output)) => Ok(output),
+ Err(err) => Err(err),
+ },
+ Commands::Context(context_args) => {
+ match get_context(None, context_args, &app_state, cx).await {
+ Ok(GetContextOutput::Zeta1(output)) => {
+ Ok(serde_json::to_string_pretty(&output.body).unwrap())
+ }
+ Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
+ Err(err) => Err(err),
+ }
+ }
Commands::Predict {
predict_edits_body,
context_args,
@@ -338,7 +426,10 @@ fn main() {
if let Some(predict_edits_body) = predict_edits_body {
serde_json::from_str(&predict_edits_body.read_to_string().await?)?
} else if let Some(context_args) = context_args {
- get_context(context_args, &app_state, cx).await?.body
+ match get_context(None, context_args, &app_state, cx).await? {
+ GetContextOutput::Zeta1(output) => output.body,
+ GetContextOutput::Zeta2 { .. } => unreachable!(),
+ }
} else {
return Err(anyhow!(
"Expected either --predict-edits-body-file \
@@ -363,6 +454,10 @@ fn main() {
match result {
Ok(output) => {
println!("{}", output);
+ // TODO: Remove this once the 5 second delay is properly replaced.
+ if is_zeta2_context_command {
+ eprintln!("Note that zeta2-context doesn't yet wait for indexing; it just waits 5 seconds.");
+ }
let _ = cx.update(|cx| cx.quit());
}
Err(e) => {