Send paths and ranges in zeta2 requests + add debug_info

Michael Sloan created

Change summary

Cargo.lock                                                    |  2 
crates/cloud_llm_client/Cargo.toml                            |  1 
crates/cloud_llm_client/src/predict_edits_v3.rs               | 31 +++
crates/edit_prediction_context/src/declaration.rs             |  7 
crates/edit_prediction_context/src/edit_prediction_context.rs |  7 
crates/zeta2/Cargo.toml                                       |  1 
crates/zeta2/src/zeta2.rs                                     | 41 ++++
7 files changed, 85 insertions(+), 5 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -3213,6 +3213,7 @@ name = "cloud_llm_client"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "chrono",
  "pretty_assertions",
  "serde",
  "serde_json",
@@ -21656,6 +21657,7 @@ dependencies = [
  "util",
  "workspace",
  "workspace-hack",
+ "worktree",
 ]
 
 [[package]]

crates/cloud_llm_client/Cargo.toml 🔗

@@ -13,6 +13,7 @@ path = "src/cloud_llm_client.rs"
 
 [dependencies]
 anyhow.workspace = true
+chrono.workspace = true
 serde = { workspace = true, features = ["derive", "rc"] }
 serde_json.workspace = true
 strum = { workspace = true, features = ["derive"] }

crates/cloud_llm_client/src/predict_edits_v3.rs 🔗

@@ -1,3 +1,4 @@
+use chrono::Duration;
 use serde::{Deserialize, Serialize};
 use std::{ops::Range, path::PathBuf};
 use uuid::Uuid;
@@ -9,6 +10,11 @@ use crate::PredictEditsGitInfo;
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct PredictEditsRequest {
     pub excerpt: String,
+    pub excerpt_path: PathBuf,
+    /// Within file
+    pub excerpt_range: Range<usize>,
+    /// Within `excerpt`
+    pub cursor_offset: usize,
     /// Within `signatures`
     pub excerpt_parent: Option<usize>,
     pub signatures: Vec<Signature>,
@@ -21,10 +27,20 @@ pub struct PredictEditsRequest {
     /// Info about the git repository state, only present when can_collect_data is true.
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub git_info: Option<PredictEditsGitInfo>,
+    #[serde(default)]
+    pub debug_info: bool,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub enum Event {}
+#[serde(tag = "event")]
+pub enum Event {
+    BufferChange {
+        path: Option<PathBuf>,
+        old_path: Option<PathBuf>,
+        diff: String,
+        predicted: bool,
+    },
+}
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Signature {
@@ -36,8 +52,11 @@ pub struct Signature {
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct ReferencedDeclaration {
+    pub path: PathBuf,
     pub text: String,
     pub text_is_truncated: bool,
+    /// Range of `text` within file, potentially truncated according to `text_is_truncated`
+    pub range: Range<usize>,
     /// Range within `text`
     pub signature_range: Range<usize>,
     /// Index within `signatures`.
@@ -79,6 +98,16 @@ pub struct DiagnosticGroup {
 pub struct PredictEditsResponse {
     pub request_id: Uuid,
     pub edits: Vec<Edit>,
+    pub debug_info: Option<DebugInfo>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DebugInfo {
+    pub prompt: String,
+    pub prompt_planning_time: Duration,
+    pub model_response: String,
+    pub inference_time: Duration,
+    pub parsing_time: Duration,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]

crates/edit_prediction_context/src/declaration.rs 🔗

@@ -66,6 +66,13 @@ impl Declaration {
         }
     }
 
+    pub fn item_range(&self) -> Range<usize> {
+        match self {
+            Declaration::File { declaration, .. } => declaration.item_range_in_file.clone(),
+            Declaration::Buffer { declaration, .. } => declaration.item_range.clone(),
+        }
+    }
+
     pub fn item_text(&self) -> (Cow<'_, str>, bool) {
         match self {
             Declaration::File { declaration, .. } => (

crates/edit_prediction_context/src/edit_prediction_context.rs 🔗

@@ -20,6 +20,7 @@ pub use syntax_index::*;
 pub struct EditPredictionContext {
     pub excerpt: EditPredictionExcerpt,
     pub excerpt_text: EditPredictionExcerptText,
+    pub cursor_offset_in_excerpt: usize,
     pub snippets: Vec<ScoredSnippet>,
 }
 
@@ -57,17 +58,18 @@ impl EditPredictionContext {
             index_state,
         )?;
         let excerpt_text = excerpt.text(buffer);
+        let cursor_offset_in_file = cursor_point.to_offset(buffer);
+        let cursor_offset_in_excerpt = cursor_offset_in_file - excerpt.range.start;
 
         let snippets = if let Some(index_state) = index_state {
             let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
-            let cursor_offset = cursor_point.to_offset(buffer);
 
             scored_snippets(
                 &index_state,
                 &excerpt,
                 &excerpt_text,
                 references,
-                cursor_offset,
+                cursor_offset_in_file,
                 buffer,
             )
         } else {
@@ -77,6 +79,7 @@ impl EditPredictionContext {
         Some(Self {
             excerpt,
             excerpt_text,
+            cursor_offset_in_excerpt,
             snippets,
         })
     }

crates/zeta2/Cargo.toml 🔗

@@ -30,3 +30,4 @@ thiserror.workspace = true
 util.workspace = true
 workspace.workspace = true
 workspace-hack.workspace = true
+worktree.workspace = true

crates/zeta2/src/zeta2.rs 🔗

@@ -1,4 +1,4 @@
-use anyhow::{Context as _, Result};
+use anyhow::{Context as _, Result, anyhow};
 use arrayvec::ArrayVec;
 use client::{Client, EditPredictionUsage, UserStore};
 use cloud_llm_client::predict_edits_v3::{self, Signature};
@@ -22,6 +22,7 @@ use language_model::{LlmApiToken, RefreshLlmTokenListener};
 use project::Project;
 use release_channel::AppVersion;
 use std::collections::HashMap;
+use std::path::PathBuf;
 use std::str::FromStr as _;
 use std::time::{Duration, Instant};
 use std::{ops::Range, sync::Arc};
@@ -120,9 +121,17 @@ impl Zeta {
         });
         let excerpt_options = self.excerpt_options.clone();
         let snapshot = buffer.read(cx).snapshot();
+        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
+            return Task::ready(Err(anyhow!("No file path for excerpt")));
+        };
         let client = self.client.clone();
         let llm_token = self.llm_token.clone();
         let app_version = AppVersion::global(cx);
+        let worktree_snapshots = project
+            .read(cx)
+            .worktrees(cx)
+            .map(|worktree| worktree.read(cx).snapshot())
+            .collect::<Vec<_>>();
 
         let request_task = cx.background_spawn({
             let snapshot = snapshot.clone();
@@ -135,6 +144,9 @@ impl Zeta {
 
                 let cursor_point = position.to_point(&snapshot);
 
+                // TODO: make this only true if debug view is open
+                let debug_info = true;
+
                 let Some(request) = EditPredictionContext::gather_context(
                     cursor_point,
                     &snapshot,
@@ -143,12 +155,15 @@ impl Zeta {
                 )
                 .map(|context| {
                     make_cloud_request(
+                        excerpt_path.clone(),
                         context,
                         // TODO pass everything
                         Vec::new(),
                         false,
                         Vec::new(),
                         None,
+                        debug_info,
+                        &worktree_snapshots,
                         index_state.as_deref(),
                     )
                 }) else {
@@ -263,7 +278,7 @@ impl Zeta {
                 } else {
                     request_builder.uri(
                         http_client
-                            .build_zed_llm_url("/predict_edits/v2", &[])?
+                            .build_zed_llm_url("/predict_edits/v3", &[])?
                             .as_ref(),
                     )
                 };
@@ -585,11 +600,14 @@ impl EditPredictionProvider for ZetaEditPredictionProvider {
 }
 
 fn make_cloud_request(
+    excerpt_path: PathBuf,
     context: EditPredictionContext,
     events: Vec<predict_edits_v3::Event>,
     can_collect_data: bool,
     diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
     git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
+    debug_info: bool,
+    worktrees: &Vec<worktree::Snapshot>,
     index_state: Option<&SyntaxIndexState>,
 ) -> predict_edits_v3::PredictEditsRequest {
     let mut signatures = Vec::new();
@@ -597,6 +615,18 @@ fn make_cloud_request(
     let mut referenced_declarations = Vec::new();
 
     for snippet in context.snippets {
+        let project_entry_id = snippet.declaration.project_entry_id();
+        // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot.
+        // Note that currently full_path is currently being used for excerpt_path.
+        let Some(path) = worktrees.iter().find_map(|worktree| {
+            let abs_path = worktree.abs_path();
+            worktree
+                .entry_for_id(project_entry_id)
+                .map(|e| abs_path.join(&e.path))
+        }) else {
+            continue;
+        };
+
         let parent_index = index_state.and_then(|index_state| {
             snippet.declaration.parent().and_then(|parent| {
                 add_signature(
@@ -607,9 +637,12 @@ fn make_cloud_request(
                 )
             })
         });
+
         let (text, text_is_truncated) = snippet.declaration.item_text();
         referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
+            path,
             text: text.into(),
+            range: snippet.declaration.item_range(),
             text_is_truncated,
             signature_range: snippet.declaration.signature_range_in_item_text(),
             parent_index,
@@ -635,7 +668,10 @@ fn make_cloud_request(
     });
 
     predict_edits_v3::PredictEditsRequest {
+        excerpt_path,
         excerpt: context.excerpt_text.body,
+        excerpt_range: context.excerpt.range,
+        cursor_offset: context.cursor_offset_in_excerpt,
         referenced_declarations,
         signatures,
         excerpt_parent,
@@ -644,6 +680,7 @@ fn make_cloud_request(
         can_collect_data,
         diagnostic_groups,
         git_info,
+        debug_info,
     }
 }