Send additional context for edit predictions data via a telemetry event

Michael Sloan created

Also:

* Removes old PredictEditsBody fields that don't have anticipated future use

* Sorts diagnostics by proximity to cursor and truncates based on json byte count

* Brings back cursor_offset

Change summary

crates/cloud_llm_client/src/cloud_llm_client.rs |  37 +
crates/language/src/buffer.rs                   |   2 
crates/zeta/src/zeta.rs                         | 291 +++++++++++-------
crates/zeta_cli/src/main.rs                     |  31 -
4 files changed, 206 insertions(+), 155 deletions(-)

Detailed changes

crates/cloud_llm_client/src/cloud_llm_client.rs 🔗

@@ -138,30 +138,44 @@ pub enum LanguageModelProvider {
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct PredictEditsBody {
-    #[serde(skip_serializing_if = "Option::is_none", default)]
-    pub outline: Option<String>,
     pub input_events: String,
     pub input_excerpt: String,
-    #[serde(skip_serializing_if = "Option::is_none", default)]
-    pub speculated_output: Option<String>,
     /// Whether the user provided consent for sampling this interaction.
     #[serde(default, alias = "data_collection_permission")]
     pub can_collect_data: bool,
+    /// Note that this is no longer sent, in favor of `PredictEditsAdditionalContext`.
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
-    /// Info about the git repository state, only present when can_collect_data is true.
+    /// Info about the git repository state, only present when can_collect_data is true. Note that
+    /// this is no longer sent, in favor of `PredictEditsAdditionalContext`.
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub git_info: Option<PredictEditsGitInfo>,
 }
 
+/// Additional context only stored when can_collect_data is true for the corresponding edit
+/// predictions request.
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct PredictEditsGitInfo {
+pub struct PredictEditsAdditionalContext {
     /// Path to the file in the repository that contains the input excerpt.
-    #[serde(skip_serializing_if = "Option::is_none", default)]
-    pub input_path: Option<String>,
+    pub input_path: String,
     /// Cursor position within the file that contains the input excerpt.
-    #[serde(skip_serializing_if = "Option::is_none", default)]
-    pub cursor_point: Option<Point>,
+    pub cursor_point: Point,
+    /// Cursor offset in bytes within the file that contains the input excerpt.
+    pub cursor_offset: usize,
+    #[serde(flatten)]
+    pub git_info: PredictEditsGitInfo,
+    /// Diagnostic near the cursor position.
+    #[serde(skip_serializing_if = "Vec::is_empty", default)]
+    pub diagnostic_groups: Vec<(String, Box<serde_json::value::RawValue>)>,
+    /// True if the diagnostics were truncated.
+    pub diagnostic_groups_truncated: bool,
+    /// Recently active files that may be within this repository.
+    #[serde(skip_serializing_if = "Vec::is_empty", default)]
+    pub recent_files: Vec<PredictEditsRecentFile>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct PredictEditsGitInfo {
     /// SHA of git HEAD commit at time of prediction.
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub head_sha: Option<String>,
@@ -171,9 +185,6 @@ pub struct PredictEditsGitInfo {
     /// URL of the remote called `upstream`.
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub remote_upstream_url: Option<String>,
-    /// Recently active files that may be within this repository.
-    #[serde(skip_serializing_if = "Option::is_none", default)]
-    pub recent_files: Option<Vec<PredictEditsRecentFile>>,
 }
 
 /// A zero-indexed point in a text buffer consisting of a row and column.

crates/language/src/buffer.rs 🔗

@@ -146,7 +146,7 @@ pub struct BufferSnapshot {
     pub text: text::BufferSnapshot,
     pub(crate) syntax: SyntaxSnapshot,
     file: Option<Arc<dyn File>>,
-    diagnostics: SmallVec<[(LanguageServerId, DiagnosticSet); 2]>,
+    pub diagnostics: SmallVec<[(LanguageServerId, DiagnosticSet); 2]>,
     remote_selections: TreeMap<ReplicaId, SelectionSet>,
     language: Option<Arc<Language>>,
     non_text_state_update_count: usize,

crates/zeta/src/zeta.rs 🔗

@@ -20,8 +20,8 @@ use anyhow::{Context as _, Result, anyhow};
 use client::{Client, EditPredictionUsage, UserStore};
 use cloud_llm_client::{
     AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
-    PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile, PredictEditsResponse,
-    ZED_VERSION_HEADER_NAME,
+    PredictEditsAdditionalContext, PredictEditsBody, PredictEditsGitInfo, PredictEditsRecentFile,
+    PredictEditsResponse, ZED_VERSION_HEADER_NAME,
 };
 use collections::{HashMap, HashSet, VecDeque};
 use futures::AsyncReadExt;
@@ -68,7 +68,6 @@ const ZED_PREDICT_DATA_COLLECTION_CHOICE: &str = "zed_predict_data_collection_ch
 const MAX_CONTEXT_TOKENS: usize = 150;
 const MAX_REWRITE_TOKENS: usize = 350;
 const MAX_EVENT_TOKENS: usize = 500;
-const MAX_DIAGNOSTIC_GROUPS: usize = 10;
 
 /// Maximum number of events to track.
 const MAX_EVENT_COUNT: usize = 16;
@@ -76,12 +75,15 @@ const MAX_EVENT_COUNT: usize = 16;
 /// Maximum number of recent files to track.
 const MAX_RECENT_PROJECT_ENTRIES_COUNT: usize = 16;
 
-/// Minimum number of milliseconds between recent project entries to keep them
+/// Minimum number of milliseconds between recent file entries.
 const MIN_TIME_BETWEEN_RECENT_PROJECT_ENTRIES: Duration = Duration::from_millis(100);
 
 /// Maximum file path length to include in recent files list.
 const MAX_RECENT_FILE_PATH_LENGTH: usize = 512;
 
+/// Maximum number of JSON bytes for diagnostics in additional context.
+const MAX_DIAGNOSTICS_BYTES: usize = 4096;
+
 /// Maximum number of edit predictions to store for feedback.
 const MAX_SHOWN_COMPLETION_COUNT: usize = 50;
 
@@ -151,7 +153,6 @@ pub struct EditPrediction {
     edits: Arc<[(Range<Anchor>, String)]>,
     snapshot: BufferSnapshot,
     edit_preview: EditPreview,
-    input_outline: Arc<str>,
     input_events: Arc<str>,
     input_excerpt: Arc<str>,
     output_excerpt: Arc<str>,
@@ -407,7 +408,7 @@ impl Zeta {
     fn request_completion_impl<F, R>(
         &mut self,
         workspace: Option<Entity<Workspace>>,
-        project: Option<&Entity<Project>>,
+        project: Option<Entity<Project>>,
         buffer: &Entity<Buffer>,
         cursor: language::Anchor,
         can_collect_data: CanCollectData,
@@ -429,30 +430,21 @@ impl Zeta {
         let llm_token = self.llm_token.clone();
         let app_version = AppVersion::global(cx);
 
-        let cursor_point = cursor.to_point(&snapshot);
-        let cursor_offset = cursor_point.to_offset(&snapshot);
-        let git_info = if matches!(can_collect_data, CanCollectData(true)) {
-            self.gather_git_info(cursor_point, &buffer_snapshotted_at, &snapshot, project, cx)
-        } else {
-            None
-        };
-
         let full_path: Arc<Path> = snapshot
             .file()
             .map(|f| Arc::from(f.full_path(cx).as_path()))
             .unwrap_or_else(|| Arc::from(Path::new("untitled")));
         let full_path_str = full_path.to_string_lossy().to_string();
+        let cursor_point = cursor.to_point(&snapshot);
+        let cursor_offset = cursor_point.to_offset(&snapshot);
         let make_events_prompt = move || prompt_for_events(&events, MAX_EVENT_TOKENS);
-        let gather_task = gather_context(
-            project,
+        let gather_task = cx.background_spawn(gather_context(
             full_path_str,
-            &snapshot,
+            snapshot.clone(),
             cursor_point,
             make_events_prompt,
             can_collect_data,
-            git_info,
-            cx,
-        );
+        ));
 
         cx.spawn(async move |this, cx| {
             let GatherContextOutput {
@@ -461,13 +453,41 @@ impl Zeta {
             } = gather_task.await?;
             let done_gathering_context_at = Instant::now();
 
+            let additional_context_task = if matches!(can_collect_data, CanCollectData(true))
+                && let Some(file) = snapshot.file()
+                && let Ok(project_path) = cx.update(|cx| ProjectPath::from_file(file.as_ref(), cx))
+            {
+                // This is async to reduce latency of the edit predictions request. The downside is
+                // that it will see a slightly later state than was used when gathering context.
+                let snapshot = snapshot.clone();
+                let this = this.clone();
+                Some(cx.spawn(async move |cx| {
+                    if let Ok(Some(task)) = this.update(cx, |this, cx| {
+                        this.gather_additional_context(
+                            cursor_point,
+                            cursor_offset,
+                            snapshot,
+                            &buffer_snapshotted_at,
+                            project_path,
+                            project.as_ref(),
+                            cx,
+                        )
+                    }) {
+                        Some(task.await)
+                    } else {
+                        None
+                    }
+                }))
+            } else {
+                None
+            };
+
             log::debug!(
                 "Events:\n{}\nExcerpt:\n{:?}",
                 body.input_events,
                 body.input_excerpt
             );
 
-            let input_outline = body.outline.clone().unwrap_or_default();
             let input_events = body.input_events.clone();
             let input_excerpt = body.input_excerpt.clone();
 
@@ -524,6 +544,7 @@ impl Zeta {
                 .ok();
             }
 
+            let request_id = response.request_id.clone();
             let edit_prediction = Self::process_completion_response(
                 response,
                 buffer,
@@ -531,7 +552,6 @@ impl Zeta {
                 editable_range,
                 cursor_offset,
                 full_path,
-                input_outline,
                 input_events,
                 input_excerpt,
                 buffer_snapshotted_at,
@@ -555,6 +575,19 @@ impl Zeta {
                 );
             }
 
+            if let Some(additional_context_task) = additional_context_task {
+                cx.background_spawn(async move {
+                    if let Some(additional_context) = additional_context_task.await {
+                        telemetry::event!(
+                            "Edit Prediction Additional Context",
+                            request_id = request_id,
+                            additional_context = additional_context
+                        );
+                    }
+                })
+                .detach();
+            }
+
             edit_prediction
         })
     }
@@ -572,13 +605,12 @@ impl Zeta {
             and then another
             "#};
 
-        let project = None;
         let buffer = cx.new(|cx| Buffer::local(test_buffer_text, cx));
         let position = buffer.read(cx).anchor_before(Point::new(1, 0));
 
         let completion_tasks = vec![
             self.fake_completion(
-                project,
+                None,
                 &buffer,
                 position,
                 PredictEditsResponse {
@@ -590,12 +622,12 @@ And maybe a short line
 Then a few lines
 and then another
 {EDITABLE_REGION_END_MARKER}
-                        ", ),
+                        "),
                 },
                 cx,
             ),
             self.fake_completion(
-                project,
+                None,
                 &buffer,
                 position,
                 PredictEditsResponse {
@@ -612,7 +644,7 @@ and then another
                 cx,
             ),
             self.fake_completion(
-                project,
+                None,
                 &buffer,
                 position,
                 PredictEditsResponse {
@@ -630,7 +662,7 @@ and then another
                 cx,
             ),
             self.fake_completion(
-                project,
+                None,
                 &buffer,
                 position,
                 PredictEditsResponse {
@@ -648,7 +680,7 @@ and then another
                 cx,
             ),
             self.fake_completion(
-                project,
+                None,
                 &buffer,
                 position,
                 PredictEditsResponse {
@@ -665,7 +697,7 @@ and then another
                 cx,
             ),
             self.fake_completion(
-                project,
+                None,
                 &buffer,
                 position,
                 PredictEditsResponse {
@@ -681,7 +713,7 @@ and then another
                 cx,
             ),
             self.fake_completion(
-                project,
+                None,
                 &buffer,
                 position,
                 PredictEditsResponse {
@@ -715,7 +747,7 @@ and then another
     #[cfg(any(test, feature = "test-support"))]
     pub fn fake_completion(
         &mut self,
-        project: Option<&Entity<Project>>,
+        project: Option<Entity<Project>>,
         buffer: &Entity<Buffer>,
         position: language::Anchor,
         response: PredictEditsResponse,
@@ -736,7 +768,7 @@ and then another
 
     pub fn request_completion(
         &mut self,
-        project: Option<&Entity<Project>>,
+        project: Option<Entity<Project>>,
         buffer: &Entity<Buffer>,
         position: language::Anchor,
         can_collect_data: CanCollectData,
@@ -904,7 +936,6 @@ and then another
         editable_range: Range<usize>,
         cursor_offset: usize,
         path: Arc<Path>,
-        input_outline: String,
         input_events: String,
         input_excerpt: String,
         buffer_snapshotted_at: Instant,
@@ -949,7 +980,6 @@ and then another
                 edits,
                 edit_preview,
                 snapshot,
-                input_outline: input_outline.into(),
                 input_events: input_events.into(),
                 input_excerpt: input_excerpt.into(),
                 output_excerpt,
@@ -1078,7 +1108,6 @@ and then another
             rating,
             input_events = completion.input_events,
             input_excerpt = completion.input_excerpt,
-            input_outline = completion.input_outline,
             output_excerpt = completion.output_excerpt,
             feedback
         );
@@ -1136,17 +1165,17 @@ and then another
         }
     }
 
-    fn gather_git_info(
+    fn gather_additional_context(
         &mut self,
         cursor_point: language::Point,
+        cursor_offset: usize,
+        snapshot: BufferSnapshot,
         buffer_snapshotted_at: &Instant,
-        snapshot: &BufferSnapshot,
+        project_path: ProjectPath,
         project: Option<&Entity<Project>>,
         cx: &mut Context<Self>,
-    ) -> Option<PredictEditsGitInfo> {
+    ) -> Option<Task<PredictEditsAdditionalContext>> {
         let project = project?.read(cx);
-        let file = snapshot.file()?;
-        let project_path = ProjectPath::from_file(file.as_ref(), cx);
         let entry = project.entry_for_path(&project_path, cx)?;
         if !worktree_entry_eligible_for_collection(&entry) {
             return None;
@@ -1155,7 +1184,25 @@ and then another
         let git_store = project.git_store().read(cx);
         let (repository, repo_path) =
             git_store.repository_and_path_for_project_path(&project_path, cx)?;
-        let repo_path_str = repo_path.to_str()?;
+        let repo_path_string = repo_path.to_str()?.to_string();
+
+        let diagnostics = if let Some(local_lsp_store) = project.lsp_store().read(cx).as_local() {
+            snapshot
+                .diagnostics
+                .iter()
+                .filter_map(|(language_server_id, diagnostics)| {
+                    let language_server =
+                        local_lsp_store.running_language_server_for_id(*language_server_id)?;
+                    Some((
+                        *language_server_id,
+                        language_server.name(),
+                        diagnostics.clone(),
+                    ))
+                })
+                .collect()
+        } else {
+            Vec::new()
+        };
 
         repository.update(cx, |repository, cx| {
             let head_sha = repository.head_commit.as_ref()?.sha.to_string();
@@ -1163,14 +1210,61 @@ and then another
             let remote_upstream_url = repository.remote_upstream_url.clone();
             let recent_files = self.recent_files(&buffer_snapshotted_at, repository, cx);
 
-            Some(PredictEditsGitInfo {
-                input_path: Some(repo_path_str.to_string()),
-                cursor_point: Some(to_cloud_llm_client_point(cursor_point)),
-                head_sha: Some(head_sha),
-                remote_origin_url,
-                remote_upstream_url,
-                recent_files: Some(recent_files),
-            })
+            // group, resolve, and select diagnostics on a background thread
+            Some(cx.background_spawn(async move {
+                let mut diagnostic_groups_with_name = Vec::new();
+                for (language_server_id, language_server_name, diagnostics) in
+                    diagnostics.into_iter()
+                {
+                    let mut groups = Vec::new();
+                    diagnostics.groups(language_server_id, &mut groups, &snapshot);
+                    diagnostic_groups_with_name.extend(groups.into_iter().map(|(_, group)| {
+                        (
+                            language_server_name.clone(),
+                            group.resolve::<usize>(&snapshot),
+                        )
+                    }));
+                }
+
+                // sort by proximity to cursor
+                diagnostic_groups_with_name.sort_by_key(|(_, group)| {
+                    let range = &group.entries[group.primary_ix].range;
+                    if range.start >= cursor_offset {
+                        range.start - cursor_offset
+                    } else if cursor_offset >= range.end {
+                        cursor_offset - range.end
+                    } else {
+                        (cursor_offset - range.start).min(range.end - cursor_offset)
+                    }
+                });
+
+                let mut diagnostic_groups = Vec::new();
+                let mut diagnostic_groups_truncated = false;
+                let mut diagnostics_byte_count = 0;
+                for (name, group) in diagnostic_groups_with_name {
+                    let raw_value = serde_json::value::to_raw_value(&group).unwrap();
+                    diagnostics_byte_count += name.0.len() + raw_value.get().len();
+                    if diagnostics_byte_count > MAX_DIAGNOSTICS_BYTES {
+                        diagnostic_groups_truncated = true;
+                        break;
+                    }
+                    diagnostic_groups.push((name.to_string(), raw_value));
+                }
+
+                PredictEditsAdditionalContext {
+                    input_path: repo_path_string,
+                    cursor_point: to_cloud_llm_client_point(cursor_point),
+                    cursor_offset: cursor_offset,
+                    git_info: PredictEditsGitInfo {
+                        head_sha: Some(head_sha),
+                        remote_origin_url,
+                        remote_upstream_url,
+                    },
+                    diagnostic_groups,
+                    diagnostic_groups_truncated,
+                    recent_files,
+                }
+            }))
         })
     }
 
@@ -1319,74 +1413,34 @@ pub struct GatherContextOutput {
     pub editable_range: Range<usize>,
 }
 
-pub fn gather_context(
-    project: Option<&Entity<Project>>,
+pub async fn gather_context(
     full_path_str: String,
-    snapshot: &BufferSnapshot,
+    snapshot: BufferSnapshot,
     cursor_point: language::Point,
     make_events_prompt: impl FnOnce() -> String + Send + 'static,
     can_collect_data: CanCollectData,
-    git_info: Option<PredictEditsGitInfo>,
-    cx: &App,
-) -> Task<Result<GatherContextOutput>> {
-    let local_lsp_store =
-        project.and_then(|project| project.read(cx).lsp_store().read(cx).as_local());
-    let diagnostic_groups: Vec<(String, serde_json::Value)> =
-        if matches!(can_collect_data, CanCollectData(true))
-            && let Some(local_lsp_store) = local_lsp_store
-        {
-            snapshot
-                .diagnostic_groups(None)
-                .into_iter()
-                .filter_map(|(language_server_id, diagnostic_group)| {
-                    let language_server =
-                        local_lsp_store.running_language_server_for_id(language_server_id)?;
-                    let diagnostic_group = diagnostic_group.resolve::<usize>(snapshot);
-                    let language_server_name = language_server.name().to_string();
-                    let serialized = serde_json::to_value(diagnostic_group).unwrap();
-                    Some((language_server_name, serialized))
-                })
-                .collect::<Vec<_>>()
-        } else {
-            Vec::new()
-        };
-
-    cx.background_spawn({
-        let snapshot = snapshot.clone();
-        async move {
-            let diagnostic_groups = if diagnostic_groups.is_empty()
-                || diagnostic_groups.len() >= MAX_DIAGNOSTIC_GROUPS
-            {
-                None
-            } else {
-                Some(diagnostic_groups)
-            };
-
-            let input_excerpt = excerpt_for_cursor_position(
-                cursor_point,
-                &full_path_str,
-                &snapshot,
-                MAX_REWRITE_TOKENS,
-                MAX_CONTEXT_TOKENS,
-            );
-            let input_events = make_events_prompt();
-            let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
-
-            let body = PredictEditsBody {
-                input_events,
-                input_excerpt: input_excerpt.prompt,
-                can_collect_data: can_collect_data.0,
-                diagnostic_groups,
-                git_info,
-                outline: None,
-                speculated_output: None,
-            };
-
-            Ok(GatherContextOutput {
-                body,
-                editable_range,
-            })
-        }
+) -> Result<GatherContextOutput> {
+    let input_excerpt = excerpt_for_cursor_position(
+        cursor_point,
+        &full_path_str,
+        &snapshot,
+        MAX_REWRITE_TOKENS,
+        MAX_CONTEXT_TOKENS,
+    );
+    let input_events = make_events_prompt();
+    let editable_range = input_excerpt.editable_range.to_offset(&snapshot);
+
+    let body = PredictEditsBody {
+        input_events,
+        input_excerpt: input_excerpt.prompt,
+        can_collect_data: can_collect_data.0,
+        diagnostic_groups: None,
+        git_info: None,
+    };
+
+    Ok(GatherContextOutput {
+        body,
+        editable_range,
     })
 }
 
@@ -1763,13 +1817,7 @@ impl edit_prediction::EditPredictionProvider for ZetaEditPredictionProvider {
             let completion_request = this.update(cx, |this, cx| {
                 this.last_request_timestamp = Instant::now();
                 this.zeta.update(cx, |zeta, cx| {
-                    zeta.request_completion(
-                        project.as_ref(),
-                        &buffer,
-                        position,
-                        can_collect_data,
-                        cx,
-                    )
+                    zeta.request_completion(project, &buffer, position, can_collect_data, cx)
                 })
             });
 
@@ -1983,7 +2031,6 @@ mod tests {
             id: EditPredictionId(Uuid::new_v4()),
             excerpt_range: 0..0,
             cursor_offset: 0,
-            input_outline: "".into(),
             input_events: "".into(),
             input_excerpt: "".into(),
             output_excerpt: "".into(),

crates/zeta_cli/src/main.rs 🔗

@@ -129,15 +129,15 @@ async fn get_context(
         return Err(anyhow!("Absolute paths are not supported in --cursor"));
     }
 
-    let (project, _lsp_open_handle, buffer) = if use_language_server {
-        let (project, lsp_open_handle, buffer) =
+    let (_lsp_open_handle, buffer) = if use_language_server {
+        let (_project, lsp_open_handle, buffer) =
             open_buffer_with_language_server(&worktree_path, &cursor.path, app_state, cx).await?;
-        (Some(project), Some(lsp_open_handle), buffer)
+        (Some(lsp_open_handle), buffer)
     } else {
         let abs_path = worktree_path.join(&cursor.path);
         let content = smol::fs::read_to_string(&abs_path).await?;
         let buffer = cx.new(|cx| Buffer::local(content, cx))?;
-        (None, None, buffer)
+        (None, buffer)
     };
 
     let worktree_name = worktree_path
@@ -172,21 +172,14 @@ async fn get_context(
         None => String::new(),
     };
     // Enable gathering extra data not currently needed for edit predictions
-    let git_info = None;
-    let mut gather_context_output = cx
-        .update(|cx| {
-            gather_context(
-                project.as_ref(),
-                full_path_str,
-                &snapshot,
-                clipped_cursor,
-                move || events,
-                CanCollectData(true),
-                git_info,
-                cx,
-            )
-        })?
-        .await;
+    let mut gather_context_output = gather_context(
+        full_path_str,
+        snapshot,
+        clipped_cursor,
+        move || events,
+        CanCollectData(true),
+    )
+    .await;
 
     // Disable data collection for these requests, as this is currently just used for evals
     if let Ok(gather_context_output) = gather_context_output.as_mut() {