zeta2: Add prompt planner and provide access via zeta_cli (#38691)

Michael Sloan created

Release Notes:

- N/A

Change summary

Cargo.lock                                                |  15 
Cargo.toml                                                |   2 
crates/cloud_llm_client/src/predict_edits_v3.rs           |   5 
crates/cloud_zeta2_prompt/Cargo.toml                      |  20 
crates/cloud_zeta2_prompt/LICENSE-GPL                     |   1 
crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs       | 396 +++++++++
crates/edit_prediction_context/src/declaration.rs         |  58 
crates/edit_prediction_context/src/declaration_scoring.rs |   5 
crates/edit_prediction_context/src/syntax_index.rs        |   4 
crates/edit_prediction_context/src/wip_requests.rs        |  35 
crates/zeta2/src/zeta2.rs                                 |  77 +
crates/zeta_cli/Cargo.toml                                |   5 
crates/zeta_cli/src/main.rs                               | 143 ++
13 files changed, 667 insertions(+), 99 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3225,6 +3225,18 @@ dependencies = [
  "workspace-hack",
 ]
 
+[[package]]
+name = "cloud_zeta2_prompt"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "cloud_llm_client",
+ "ordered-float 2.10.1",
+ "rustc-hash 2.1.1",
+ "strum 0.27.1",
+ "workspace-hack",
+]
+
 [[package]]
 name = "clru"
 version = "0.6.2"
@@ -21683,7 +21695,9 @@ dependencies = [
  "anyhow",
  "clap",
  "client",
+ "cloud_zeta2_prompt",
  "debug_adapter_extension",
+ "edit_prediction_context",
  "extension",
  "fs",
  "futures 0.3.31",
@@ -21710,6 +21724,7 @@ dependencies = [
  "watch",
  "workspace-hack",
  "zeta",
+ "zeta2",
 ]
 
 [[package]]

Cargo.toml 🔗

@@ -35,6 +35,7 @@ members = [
     "crates/cloud_api_client",
     "crates/cloud_api_types",
     "crates/cloud_llm_client",
+    "crates/cloud_zeta2_prompt",
     "crates/collab",
     "crates/collab_ui",
     "crates/collections",
@@ -271,6 +272,7 @@ clock = { path = "crates/clock" }
 cloud_api_client = { path = "crates/cloud_api_client" }
 cloud_api_types = { path = "crates/cloud_api_types" }
 cloud_llm_client = { path = "crates/cloud_llm_client" }
+cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
 collab = { path = "crates/collab" }
 collab_ui = { path = "crates/collab_ui" }
 collections = { path = "crates/collections" }

crates/cloud_llm_client/src/predict_edits_v3.rs 🔗

@@ -48,6 +48,9 @@ pub struct Signature {
     pub text_is_truncated: bool,
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub parent_index: Option<usize>,
+    /// Range of `text` within the file, possibly truncated according to `text_is_truncated`. The
+    /// file is implicitly the file that contains the descendant declaration or excerpt.
+    pub range: Range<usize>,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -55,7 +58,7 @@ pub struct ReferencedDeclaration {
     pub path: PathBuf,
     pub text: String,
     pub text_is_truncated: bool,
-    /// Range of `text` within file, potentially truncated according to `text_is_truncated`
+    /// Range of `text` within file, possibly truncated according to `text_is_truncated`
     pub range: Range<usize>,
     /// Range within `text`
     pub signature_range: Range<usize>,

crates/cloud_zeta2_prompt/Cargo.toml 🔗

@@ -0,0 +1,20 @@
+[package]
+name = "cloud_zeta2_prompt"
+version = "0.1.0"
+publish.workspace = true
+edition.workspace = true
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/cloud_zeta2_prompt.rs"
+
+[dependencies]
+anyhow.workspace = true
+cloud_llm_client.workspace = true
+ordered-float.workspace = true
+rustc-hash.workspace = true
+strum.workspace = true
+workspace-hack.workspace = true

crates/cloud_zeta2_prompt/src/cloud_zeta2_prompt.rs 🔗

@@ -0,0 +1,396 @@
+//! Zeta2 prompt planning and generation code shared with cloud.
+
+use anyhow::{Result, anyhow};
+use cloud_llm_client::predict_edits_v3::{self, ReferencedDeclaration};
+use ordered_float::OrderedFloat;
+use rustc_hash::{FxHashMap, FxHashSet};
+use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
+use strum::{EnumIter, IntoEnumIterator};
+
+pub const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
+/// NOTE: Differs from zed version of constant - includes a newline
+pub const EDITABLE_REGION_START_MARKER: &str = "<|editable_region_start|>\n";
+/// NOTE: Differs from zed version of constant - includes a newline
+pub const EDITABLE_REGION_END_MARKER: &str = "<|editable_region_end|>\n";
+
+pub struct PlannedPrompt<'a> {
+    request: &'a predict_edits_v3::PredictEditsRequest,
+    /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
+    /// `to_prompt_string`.
+    snippets: Vec<PlannedSnippet<'a>>,
+    budget_used: usize,
+}
+
+pub struct PlanOptions {
+    pub max_bytes: usize,
+}
+
+#[derive(Clone, Debug)]
+pub struct PlannedSnippet<'a> {
+    path: &'a Path,
+    range: Range<usize>,
+    text: &'a str,
+    // TODO: Indicate this in the output
+    #[allow(dead_code)]
+    text_is_truncated: bool,
+}
+
+#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
+pub enum SnippetStyle {
+    Signature,
+    Declaration,
+}
+
+impl<'a> PlannedPrompt<'a> {
+    /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
+    ///
+    /// Initializes a priority queue by populating it with each snippet, finding the SnippetStyle
+    /// that minimizes `score_density = score / snippet.range(style).len()`. When a "signature"
+    /// snippet is popped, insert an entry for the "declaration" variant that reflects the cost of
+    /// upgrade.
+    ///
+    /// TODO: Implement an early halting condition. One option might be to have another priority
+    /// queue where the score is the size, and update it accordingly. Another option might be to
+    /// have some simpler heuristic like bailing after N failed insertions, or based on how much
+    /// budget is left.
+    ///
+    /// TODO: Has the current known sources of imprecision:
+    ///
+    /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
+    /// plan even though the containing struct is already included.
+    ///
+    /// * Does not consider cost of signatures when ranking snippets - this is tricky since
+    /// signatures may be shared by multiple snippets.
+    ///
+    /// * Does not include file paths / other text when considering max_bytes.
+    pub fn populate(
+        request: &'a predict_edits_v3::PredictEditsRequest,
+        options: &PlanOptions,
+    ) -> Result<Self> {
+        let mut this = PlannedPrompt {
+            request,
+            snippets: Vec::new(),
+            budget_used: request.excerpt.len(),
+        };
+        let mut included_parents = FxHashSet::default();
+        let additional_parents = this.additional_parent_signatures(
+            &request.excerpt_path,
+            request.excerpt_parent,
+            &included_parents,
+        )?;
+        this.add_parents(&mut included_parents, additional_parents);
+
+        if this.budget_used > options.max_bytes {
+            return Err(anyhow!(
+                "Excerpt + signatures size of {} already exceeds budget of {}",
+                this.budget_used,
+                options.max_bytes
+            ));
+        }
+
+        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
+        struct QueueEntry {
+            score_density: OrderedFloat<f32>,
+            declaration_index: usize,
+            style: SnippetStyle,
+        }
+
+        // Initialize priority queue with the best score for each snippet.
+        let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
+        for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
+            let (style, score_density) = SnippetStyle::iter()
+                .map(|style| {
+                    (
+                        style,
+                        OrderedFloat(declaration_score_density(&declaration, style)),
+                    )
+                })
+                .max_by_key(|(_, score_density)| *score_density)
+                .unwrap();
+            queue.push(QueueEntry {
+                score_density,
+                declaration_index,
+                style,
+            });
+        }
+
+        // Knapsack selection loop
+        while let Some(queue_entry) = queue.pop() {
+            let Some(declaration) = request
+                .referenced_declarations
+                .get(queue_entry.declaration_index)
+            else {
+                return Err(anyhow!(
+                    "Invalid declaration index {}",
+                    queue_entry.declaration_index
+                ));
+            };
+
+            let mut additional_bytes = declaration_size(declaration, queue_entry.style);
+            if this.budget_used + additional_bytes > options.max_bytes {
+                continue;
+            }
+
+            let additional_parents = this.additional_parent_signatures(
+                &declaration.path,
+                declaration.parent_index,
+                &mut included_parents,
+            )?;
+            additional_bytes += additional_parents
+                .iter()
+                .map(|(_, snippet)| snippet.text.len())
+                .sum::<usize>();
+            if this.budget_used + additional_bytes > options.max_bytes {
+                continue;
+            }
+
+            this.budget_used += additional_bytes;
+            this.add_parents(&mut included_parents, additional_parents);
+            let planned_snippet = match queue_entry.style {
+                SnippetStyle::Signature => {
+                    let Some(text) = declaration.text.get(declaration.signature_range.clone())
+                    else {
+                        return Err(anyhow!(
+                            "Invalid declaration signature_range {:?} with text.len() = {}",
+                            declaration.signature_range,
+                            declaration.text.len()
+                        ));
+                    };
+                    PlannedSnippet {
+                        path: &declaration.path,
+                        range: (declaration.signature_range.start + declaration.range.start)
+                            ..(declaration.signature_range.end + declaration.range.start),
+                        text,
+                        text_is_truncated: declaration.text_is_truncated,
+                    }
+                }
+                SnippetStyle::Declaration => PlannedSnippet {
+                    path: &declaration.path,
+                    range: declaration.range.clone(),
+                    text: &declaration.text,
+                    text_is_truncated: declaration.text_is_truncated,
+                },
+            };
+            this.snippets.push(planned_snippet);
+
+            // When a Signature is consumed, insert an entry for Definition style.
+            if queue_entry.style == SnippetStyle::Signature {
+                let signature_size = declaration_size(&declaration, SnippetStyle::Signature);
+                let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration);
+                let signature_score = declaration_score(&declaration, SnippetStyle::Signature);
+                let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration);
+
+                let score_diff = declaration_score - signature_score;
+                let size_diff = declaration_size.saturating_sub(signature_size);
+                if score_diff > 0.0001 && size_diff > 0 {
+                    queue.push(QueueEntry {
+                        declaration_index: queue_entry.declaration_index,
+                        score_density: OrderedFloat(score_diff / (size_diff as f32)),
+                        style: SnippetStyle::Declaration,
+                    });
+                }
+            }
+        }
+
+        anyhow::Ok(this)
+    }
+
+    fn add_parents(
+        &mut self,
+        included_parents: &mut FxHashSet<usize>,
+        snippets: Vec<(usize, PlannedSnippet<'a>)>,
+    ) {
+        for (parent_index, snippet) in snippets {
+            included_parents.insert(parent_index);
+            self.budget_used += snippet.text.len();
+            self.snippets.push(snippet);
+        }
+    }
+
+    fn additional_parent_signatures(
+        &self,
+        path: &'a Path,
+        parent_index: Option<usize>,
+        included_parents: &FxHashSet<usize>,
+    ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
+        let mut results = Vec::new();
+        self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
+        Ok(results)
+    }
+
+    fn additional_parent_signatures_impl(
+        &self,
+        path: &'a Path,
+        parent_index: Option<usize>,
+        included_parents: &FxHashSet<usize>,
+        results: &mut Vec<(usize, PlannedSnippet<'a>)>,
+    ) -> Result<()> {
+        let Some(parent_index) = parent_index else {
+            return Ok(());
+        };
+        if included_parents.contains(&parent_index) {
+            return Ok(());
+        }
+        let Some(parent_signature) = self.request.signatures.get(parent_index) else {
+            return Err(anyhow!("Invalid parent index {}", parent_index));
+        };
+        results.push((
+            parent_index,
+            PlannedSnippet {
+                path,
+                range: parent_signature.range.clone(),
+                text: &parent_signature.text,
+                text_is_truncated: parent_signature.text_is_truncated,
+            },
+        ));
+        self.additional_parent_signatures_impl(
+            path,
+            parent_signature.parent_index,
+            included_parents,
+            results,
+        )
+    }
+
+    /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
+    /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
+    /// chunks.
+    pub fn to_prompt_string(&self) -> String {
+        let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
+            FxHashMap::default();
+        for snippet in &self.snippets {
+            file_to_snippets
+                .entry(&snippet.path)
+                .or_default()
+                .push(snippet);
+        }
+
+        // Reorder so that file with cursor comes last
+        let mut file_snippets = Vec::new();
+        let mut excerpt_file_snippets = Vec::new();
+        for (file_path, snippets) in file_to_snippets {
+            if file_path == &self.request.excerpt_path {
+                excerpt_file_snippets = snippets;
+            } else {
+                file_snippets.push((file_path, snippets, false));
+            }
+        }
+        let excerpt_snippet = PlannedSnippet {
+            path: &self.request.excerpt_path,
+            range: self.request.excerpt_range.clone(),
+            text: &self.request.excerpt,
+            text_is_truncated: false,
+        };
+        excerpt_file_snippets.push(&excerpt_snippet);
+        file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
+
+        let mut excerpt_file_insertions = vec![
+            (
+                self.request.excerpt_range.start,
+                EDITABLE_REGION_START_MARKER,
+            ),
+            (
+                self.request.excerpt_range.start + self.request.cursor_offset,
+                CURSOR_MARKER,
+            ),
+            (
+                self.request
+                    .excerpt_range
+                    .end
+                    .saturating_sub(0)
+                    .max(self.request.excerpt_range.start),
+                EDITABLE_REGION_END_MARKER,
+            ),
+        ];
+
+        fn push_excerpt_file_range(
+            range: Range<usize>,
+            text: &str,
+            excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
+            output: &mut String,
+        ) {
+            let mut last_offset = range.start;
+            let mut i = 0;
+            while i < excerpt_file_insertions.len() {
+                let (offset, insertion) = &excerpt_file_insertions[i];
+                let found = *offset >= range.start && *offset <= range.end;
+                if found {
+                    output.push_str(&text[last_offset - range.start..offset - range.start]);
+                    output.push_str(insertion);
+                    last_offset = *offset;
+                    excerpt_file_insertions.remove(i);
+                    continue;
+                }
+                i += 1;
+            }
+            output.push_str(&text[last_offset - range.start..]);
+        }
+
+        let mut output = String::new();
+        for (file_path, mut snippets, is_excerpt_file) in file_snippets {
+            output.push_str(&format!("```{}\n", file_path.display()));
+
+            let mut last_included_range: Option<Range<usize>> = None;
+            snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
+            for snippet in snippets {
+                if let Some(last_range) = &last_included_range
+                    && snippet.range.start < last_range.end
+                {
+                    if snippet.range.end <= last_range.end {
+                        continue;
+                    }
+                    // TODO: Should probably also handle case where there is just one char (newline)
+                    // between snippets - assume it's a newline.
+                    let text = &snippet.text[last_range.end - snippet.range.start..];
+                    if is_excerpt_file {
+                        push_excerpt_file_range(
+                            last_range.end..snippet.range.end,
+                            text,
+                            &mut excerpt_file_insertions,
+                            &mut output,
+                        );
+                    } else {
+                        output.push_str(text);
+                    }
+                    last_included_range = Some(last_range.start..snippet.range.end);
+                    continue;
+                }
+                if last_included_range.is_some() {
+                    output.push_str("…\n");
+                }
+                if is_excerpt_file {
+                    push_excerpt_file_range(
+                        snippet.range.clone(),
+                        snippet.text,
+                        &mut excerpt_file_insertions,
+                        &mut output,
+                    );
+                } else {
+                    output.push_str(snippet.text);
+                }
+                last_included_range = Some(snippet.range.clone());
+            }
+
+            output.push_str("```\n\n");
+        }
+
+        output
+    }
+}
+
+fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
+    declaration_score(declaration, style) / declaration_size(declaration, style) as f32
+}
+
+fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
+    match style {
+        SnippetStyle::Signature => declaration.signature_score,
+        SnippetStyle::Declaration => declaration.declaration_score,
+    }
+}
+
+fn declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize {
+    match style {
+        SnippetStyle::Signature => declaration.signature_range.len(),
+        SnippetStyle::Declaration => declaration.text.len(),
+    }
+}

crates/edit_prediction_context/src/declaration.rs 🔗

@@ -68,7 +68,7 @@ impl Declaration {
 
     pub fn item_range(&self) -> Range<usize> {
         match self {
-            Declaration::File { declaration, .. } => declaration.item_range_in_file.clone(),
+            Declaration::File { declaration, .. } => declaration.item_range.clone(),
             Declaration::Buffer { declaration, .. } => declaration.item_range.clone(),
         }
     }
@@ -92,7 +92,7 @@ impl Declaration {
     pub fn signature_text(&self) -> (Cow<'_, str>, bool) {
         match self {
             Declaration::File { declaration, .. } => (
-                declaration.text[declaration.signature_range_in_text.clone()].into(),
+                declaration.text[self.signature_range_in_item_text()].into(),
                 declaration.signature_is_truncated,
             ),
             Declaration::Buffer {
@@ -105,15 +105,19 @@ impl Declaration {
         }
     }
 
-    pub fn signature_range_in_item_text(&self) -> Range<usize> {
+    pub fn signature_range(&self) -> Range<usize> {
         match self {
-            Declaration::File { declaration, .. } => declaration.signature_range_in_text.clone(),
-            Declaration::Buffer { declaration, .. } => {
-                declaration.signature_range.start - declaration.item_range.start
-                    ..declaration.signature_range.end - declaration.item_range.start
-            }
+            Declaration::File { declaration, .. } => declaration.signature_range.clone(),
+            Declaration::Buffer { declaration, .. } => declaration.signature_range.clone(),
         }
     }
+
+    pub fn signature_range_in_item_text(&self) -> Range<usize> {
+        let signature_range = self.signature_range();
+        let item_range = self.item_range();
+        signature_range.start.saturating_sub(item_range.start)
+            ..(signature_range.end.saturating_sub(item_range.start)).min(item_range.len())
+    }
 }
 
 fn expand_range_to_line_boundaries_and_truncate(
@@ -141,13 +145,13 @@ pub struct FileDeclaration {
     pub parent: Option<DeclarationId>,
     pub identifier: Identifier,
     /// offset range of the declaration in the file, expanded to line boundaries and truncated
-    pub item_range_in_file: Range<usize>,
-    /// text of `item_range_in_file`
+    pub item_range: Range<usize>,
+    /// text of `item_range`
     pub text: Arc<str>,
     /// whether `text` was truncated
     pub text_is_truncated: bool,
-    /// offset range of the signature within `text`
-    pub signature_range_in_text: Range<usize>,
+    /// offset range of the signature in the file, expanded to line boundaries and truncated
+    pub signature_range: Range<usize>,
     /// whether `signature` was truncated
     pub signature_is_truncated: bool,
 }
@@ -160,31 +164,33 @@ impl FileDeclaration {
             rope,
         );
 
-        // TODO: consider logging if unexpected
-        let signature_start = declaration
-            .signature_range
-            .start
-            .saturating_sub(item_range_in_file.start);
-        let mut signature_end = declaration
-            .signature_range
-            .end
-            .saturating_sub(item_range_in_file.start);
-        let signature_is_truncated = signature_end > item_range_in_file.len();
-        if signature_is_truncated {
-            signature_end = item_range_in_file.len();
+        let (mut signature_range_in_file, mut signature_is_truncated) =
+            expand_range_to_line_boundaries_and_truncate(
+                &declaration.signature_range,
+                ITEM_TEXT_TRUNCATION_LENGTH,
+                rope,
+            );
+
+        if signature_range_in_file.start < item_range_in_file.start {
+            signature_range_in_file.start = item_range_in_file.start;
+            signature_is_truncated = true;
+        }
+        if signature_range_in_file.end > item_range_in_file.end {
+            signature_range_in_file.end = item_range_in_file.end;
+            signature_is_truncated = true;
         }
 
         FileDeclaration {
             parent: None,
             identifier: declaration.identifier,
-            signature_range_in_text: signature_start..signature_end,
+            signature_range: signature_range_in_file,
             signature_is_truncated,
             text: rope
                 .chunks_in_range(item_range_in_file.clone())
                 .collect::<String>()
                 .into(),
             text_is_truncated,
-            item_range_in_file,
+            item_range: item_range_in_file,
         }
     }
 }

crates/edit_prediction_context/src/declaration_scoring.rs 🔗

@@ -40,10 +40,9 @@ impl ScoredSnippet {
     }
 
     pub fn size(&self, style: SnippetStyle) -> usize {
-        // TODO: how to handle truncation?
         match &self.declaration {
             Declaration::File { declaration, .. } => match style {
-                SnippetStyle::Signature => declaration.signature_range_in_text.len(),
+                SnippetStyle::Signature => declaration.signature_range.len(),
                 SnippetStyle::Declaration => declaration.text.len(),
             },
             Declaration::Buffer { declaration, .. } => match style {
@@ -276,6 +275,8 @@ pub struct Scores {
 
 impl Scores {
     fn score(components: &ScoreComponents) -> Scores {
+        // TODO: handle truncation
+
         // Score related to how likely this is the correct declaration, range 0 to 1
         let accuracy_score = if components.is_same_file {
             // TODO: use declaration_line_distance_rank

crates/edit_prediction_context/src/syntax_index.rs 🔗

@@ -578,11 +578,11 @@ mod tests {
 
             let decl = expect_file_decl("c.rs", &decls[0].1, &project, cx);
             assert_eq!(decl.identifier, main.clone());
-            assert_eq!(decl.item_range_in_file, 32..280);
+            assert_eq!(decl.item_range, 32..280);
 
             let decl = expect_file_decl("a.rs", &decls[1].1, &project, cx);
             assert_eq!(decl.identifier, main);
-            assert_eq!(decl.item_range_in_file, 0..98);
+            assert_eq!(decl.item_range, 0..98);
         });
     }
 

crates/edit_prediction_context/src/wip_requests.rs 🔗

@@ -1,35 +0,0 @@
-// To discuss: What to send to the new endpoint? Thinking it'd make sense to put `prompt.rs` from
-// `zeta_context.rs` in cloud.
-//
-// * Run excerpt selection at several different sizes, send the largest size with offsets within for
-// the smaller sizes.
-//
-// * Longer event history.
-//
-// * Many more snippets than could fit in model context - allows ranking experimentation.
-
-pub struct Zeta2Request {
-    pub event_history: Vec<Event>,
-    pub excerpt: String,
-    pub excerpt_subsets: Vec<Zeta2ExcerptSubset>,
-    /// Within `excerpt`
-    pub cursor_position: usize,
-    pub signatures: Vec<String>,
-    pub retrieved_declarations: Vec<ReferencedDeclaration>,
-}
-
-pub struct Zeta2ExcerptSubset {
-    /// Within `excerpt` text.
-    pub excerpt_range: Range<usize>,
-    /// Within `signatures`.
-    pub parent_signatures: Vec<usize>,
-}
-
-pub struct ReferencedDeclaration {
-    pub text: Arc<str>,
-    /// Range within `text`
-    pub signature_range: Range<usize>,
-    /// Indices within `signatures`.
-    pub parent_signatures: Vec<usize>,
-    // A bunch of score metrics
-}

crates/zeta2/src/zeta2.rs 🔗

@@ -48,7 +48,7 @@ pub struct Zeta {
     llm_token: LlmApiToken,
     _llm_token_subscription: Subscription,
     projects: HashMap<EntityId, ZetaProject>,
-    excerpt_options: EditPredictionExcerptOptions,
+    pub excerpt_options: EditPredictionExcerptOptions,
     update_required: bool,
 }
 
@@ -87,7 +87,7 @@ impl Zeta {
             })
     }
 
-    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
+    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
         let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 
         Self {
@@ -478,6 +478,66 @@ impl Zeta {
             }
         }
     }
+
+    // TODO: Dedupe with similar code in request_prediction?
+    pub fn cloud_request_for_zeta_cli(
+        &mut self,
+        project: &Entity<Project>,
+        buffer: &Entity<Buffer>,
+        position: language::Anchor,
+        cx: &mut Context<Self>,
+    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
+        let project_state = self.projects.get(&project.entity_id());
+
+        let index_state = project_state.map(|state| {
+            state
+                .syntax_index
+                .read_with(cx, |index, _cx| index.state().clone())
+        });
+        let excerpt_options = self.excerpt_options.clone();
+        let snapshot = buffer.read(cx).snapshot();
+        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
+            return Task::ready(Err(anyhow!("No file path for excerpt")));
+        };
+        let worktree_snapshots = project
+            .read(cx)
+            .worktrees(cx)
+            .map(|worktree| worktree.read(cx).snapshot())
+            .collect::<Vec<_>>();
+
+        cx.background_spawn(async move {
+            let index_state = if let Some(index_state) = index_state {
+                Some(index_state.lock_owned().await)
+            } else {
+                None
+            };
+
+            let cursor_point = position.to_point(&snapshot);
+
+            let debug_info = true;
+            EditPredictionContext::gather_context(
+                cursor_point,
+                &snapshot,
+                &excerpt_options,
+                index_state.as_deref(),
+            )
+            .context("Failed to select excerpt")
+            .map(|context| {
+                make_cloud_request(
+                    excerpt_path.clone(),
+                    context,
+                    // TODO pass everything
+                    Vec::new(),
+                    false,
+                    Vec::new(),
+                    None,
+                    debug_info,
+                    &worktree_snapshots,
+                    index_state.as_deref(),
+                )
+            })
+        })
+    }
 }
 
 #[derive(Error, Debug)]
@@ -840,13 +900,13 @@ fn make_cloud_request(
 
     for snippet in context.snippets {
         let project_entry_id = snippet.declaration.project_entry_id();
-        // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot.
-        // Note that currently full_path is currently being used for excerpt_path.
         let Some(path) = worktrees.iter().find_map(|worktree| {
-            let abs_path = worktree.abs_path();
-            worktree
-                .entry_for_id(project_entry_id)
-                .map(|e| abs_path.join(&e.path))
+            worktree.entry_for_id(project_entry_id).map(|entry| {
+                let mut full_path = PathBuf::new();
+                full_path.push(worktree.root_name());
+                full_path.push(&entry.path);
+                full_path
+            })
         }) else {
             continue;
         };
@@ -929,6 +989,7 @@ fn add_signature(
         text: text.into(),
         text_is_truncated,
         parent_index,
+        range: parent_declaration.signature_range(),
     });
     declaration_to_signature_index.insert(declaration_id, signature_index);
     Some(signature_index)

crates/zeta_cli/Cargo.toml 🔗

@@ -16,7 +16,9 @@ path = "src/main.rs"
 anyhow.workspace = true
 clap.workspace = true
 client.workspace = true
+cloud_zeta2_prompt.workspace= true
 debug_adapter_extension.workspace = true
+edit_prediction_context.workspace = true
 extension.workspace = true
 fs.workspace = true
 futures.workspace = true
@@ -37,9 +39,10 @@ serde.workspace = true
 serde_json.workspace = true
 settings.workspace = true
 shellexpand.workspace = true
+smol.workspace = true
 terminal_view.workspace = true
 util.workspace = true
 watch.workspace = true
 workspace-hack.workspace = true
 zeta.workspace = true
-smol.workspace = true
+zeta2.workspace = true

crates/zeta_cli/src/main.rs 🔗

@@ -2,6 +2,7 @@ mod headless;
 
 use anyhow::{Result, anyhow};
 use clap::{Args, Parser, Subcommand};
+use edit_prediction_context::EditPredictionExcerptOptions;
 use futures::channel::mpsc;
 use futures::{FutureExt as _, StreamExt as _};
 use gpui::{AppContext, Application, AsyncApp};
@@ -18,7 +19,7 @@ use std::process::exit;
 use std::str::FromStr;
 use std::sync::Arc;
 use std::time::Duration;
-use zeta::{GatherContextOutput, PerformPredictEditsParams, Zeta, gather_context};
+use zeta::{PerformPredictEditsParams, Zeta};
 
 use crate::headless::ZetaCliAppState;
 
@@ -32,6 +33,12 @@ struct ZetaCliArgs {
 #[derive(Subcommand, Debug)]
 enum Commands {
     Context(ContextArgs),
+    Zeta2Context {
+        #[clap(flatten)]
+        zeta2_args: Zeta2Args,
+        #[clap(flatten)]
+        context_args: ContextArgs,
+    },
     Predict {
         #[arg(long)]
         predict_edits_body: Option<FileOrStdin>,
@@ -53,6 +60,18 @@ struct ContextArgs {
     events: Option<FileOrStdin>,
 }
 
+#[derive(Debug, Args)]
+struct Zeta2Args {
+    #[arg(long, default_value_t = 8192)]
+    prompt_max_bytes: usize,
+    #[arg(long, default_value_t = 2048)]
+    excerpt_max_bytes: usize,
+    #[arg(long, default_value_t = 1024)]
+    excerpt_min_bytes: usize,
+    #[arg(long, default_value_t = 0.66)]
+    target_before_cursor_over_total_bytes: f32,
+}
+
 #[derive(Debug, Clone)]
 enum FileOrStdin {
     File(PathBuf),
@@ -112,11 +131,17 @@ impl FromStr for CursorPosition {
     }
 }
 
+enum GetContextOutput {
+    Zeta1(zeta::GatherContextOutput),
+    Zeta2(String),
+}
+
 async fn get_context(
+    zeta2_args: Option<Zeta2Args>,
     args: ContextArgs,
     app_state: &Arc<ZetaCliAppState>,
     cx: &mut AsyncApp,
-) -> Result<GatherContextOutput> {
+) -> Result<GetContextOutput> {
     let ContextArgs {
         worktree: worktree_path,
         cursor,
@@ -152,9 +177,7 @@ async fn get_context(
             open_buffer_with_language_server(&project, &worktree, &cursor.path, cx).await?;
         (Some(lsp_open_handle), buffer)
     } else {
-        let abs_path = worktree_path.join(&cursor.path);
-        let content = smol::fs::read_to_string(&abs_path).await?;
-        let buffer = cx.new(|cx| Buffer::local(content, cx))?;
+        let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?;
         (None, buffer)
     };
 
@@ -189,33 +212,83 @@ async fn get_context(
         Some(events) => events.read_to_string().await?,
         None => String::new(),
     };
-    let prompt_for_events = move || (events, 0);
-    cx.update(|cx| {
-        gather_context(
-            full_path_str,
-            &snapshot,
-            clipped_cursor,
-            prompt_for_events,
-            cx,
-        )
-    })?
-    .await
+
+    if let Some(zeta2_args) = zeta2_args {
+        Ok(GetContextOutput::Zeta2(
+            cx.update(|cx| {
+                let zeta = cx.new(|cx| {
+                    zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx)
+                });
+                zeta.update(cx, |zeta, cx| {
+                    zeta.register_buffer(&buffer, &project, cx);
+                    zeta.excerpt_options = EditPredictionExcerptOptions {
+                        max_bytes: zeta2_args.excerpt_max_bytes,
+                        min_bytes: zeta2_args.excerpt_min_bytes,
+                        target_before_cursor_over_total_bytes: zeta2_args
+                            .target_before_cursor_over_total_bytes,
+                    }
+                });
+                // TODO: Actually wait for indexing.
+                let timer = cx.background_executor().timer(Duration::from_secs(5));
+                cx.spawn(async move |cx| {
+                    timer.await;
+                    let request = zeta
+                        .update(cx, |zeta, cx| {
+                            let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
+                            zeta.cloud_request_for_zeta_cli(&project, &buffer, cursor, cx)
+                        })?
+                        .await?;
+                    let planned_prompt = cloud_zeta2_prompt::PlannedPrompt::populate(
+                        &request,
+                        &cloud_zeta2_prompt::PlanOptions {
+                            max_bytes: zeta2_args.prompt_max_bytes,
+                        },
+                    )?;
+                    anyhow::Ok(planned_prompt.to_prompt_string())
+                })
+            })?
+            .await?,
+        ))
+    } else {
+        let prompt_for_events = move || (events, 0);
+        Ok(GetContextOutput::Zeta1(
+            cx.update(|cx| {
+                zeta::gather_context(
+                    full_path_str,
+                    &snapshot,
+                    clipped_cursor,
+                    prompt_for_events,
+                    cx,
+                )
+            })?
+            .await?,
+        ))
+    }
 }
 
-pub async fn open_buffer_with_language_server(
+pub async fn open_buffer(
     project: &Entity<Project>,
     worktree: &Entity<Worktree>,
     path: &Path,
     cx: &mut AsyncApp,
-) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
+) -> Result<Entity<Buffer>> {
     let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
         worktree_id: worktree.id(),
         path: path.to_path_buf().into(),
     })?;
 
-    let buffer = project
+    project
         .update(cx, |project, cx| project.open_buffer(project_path, cx))?
-        .await?;
+        .await
+}
+
+pub async fn open_buffer_with_language_server(
+    project: &Entity<Project>,
+    worktree: &Entity<Worktree>,
+    path: &Path,
+    cx: &mut AsyncApp,
+) -> Result<(Entity<Entity<Buffer>>, Entity<Buffer>)> {
+    let buffer = open_buffer(project, worktree, path, cx).await?;
 
     let lsp_open_handle = project.update(cx, |project, cx| {
         project.register_buffer_with_language_servers(&buffer, cx)
@@ -319,11 +392,26 @@ fn main() {
 
     app.run(move |cx| {
         let app_state = Arc::new(headless::init(cx));
+        let is_zeta2_context_command = matches!(args.command, Commands::Zeta2Context { .. });
         cx.spawn(async move |cx| {
             let result = match args.command {
-                Commands::Context(context_args) => get_context(context_args, &app_state, cx)
-                    .await
-                    .map(|output| serde_json::to_string_pretty(&output.body).unwrap()),
+                Commands::Zeta2Context {
+                    zeta2_args,
+                    context_args,
+                } => match get_context(Some(zeta2_args), context_args, &app_state, cx).await {
+                    Ok(GetContextOutput::Zeta1 { .. }) => unreachable!(),
+                    Ok(GetContextOutput::Zeta2(output)) => Ok(output),
+                    Err(err) => Err(err),
+                },
+                Commands::Context(context_args) => {
+                    match get_context(None, context_args, &app_state, cx).await {
+                        Ok(GetContextOutput::Zeta1(output)) => {
+                            Ok(serde_json::to_string_pretty(&output.body).unwrap())
+                        }
+                        Ok(GetContextOutput::Zeta2 { .. }) => unreachable!(),
+                        Err(err) => Err(err),
+                    }
+                }
                 Commands::Predict {
                     predict_edits_body,
                     context_args,
@@ -338,7 +426,10 @@ fn main() {
                             if let Some(predict_edits_body) = predict_edits_body {
                                 serde_json::from_str(&predict_edits_body.read_to_string().await?)?
                             } else if let Some(context_args) = context_args {
-                                get_context(context_args, &app_state, cx).await?.body
+                                match get_context(None, context_args, &app_state, cx).await? {
+                                    GetContextOutput::Zeta1(output) => output.body,
+                                    GetContextOutput::Zeta2 { .. } => unreachable!(),
+                                }
                             } else {
                                 return Err(anyhow!(
                                     "Expected either --predict-edits-body-file \
@@ -363,6 +454,10 @@ fn main() {
             match result {
                 Ok(output) => {
                     println!("{}", output);
+                    // TODO: Remove this once the 5 second delay is properly replaced.
+                    if is_zeta2_context_command {
+                        eprintln!("Note that zeta2-context doesn't yet wait for indexing, instead waits 5 seconds.");
+                    }
                     let _ = cx.update(|cx| cx.quit());
                 }
                 Err(e) => {