zeta2: Collect nearby diagnostics (#38732)

Michael Sloan and Bennet created

Release Notes:

- N/A

Co-authored-by: Bennet <bennet@zed.dev>

Change summary

crates/cloud_llm_client/src/predict_edits_v3.rs |  12 +
crates/language/src/buffer.rs                   |   6 +
crates/zeta2/src/zeta2.rs                       | 102 ++++++++++++++++--
crates/zeta2_tools/src/zeta2_tools.rs           |  36 ++++--
crates/zeta_cli/src/main.rs                     |  17 ++-
5 files changed, 135 insertions(+), 38 deletions(-)

Detailed changes

crates/cloud_llm_client/src/predict_edits_v3.rs 🔗

@@ -24,6 +24,8 @@ pub struct PredictEditsRequest {
     pub can_collect_data: bool,
     #[serde(skip_serializing_if = "Vec::is_empty", default)]
     pub diagnostic_groups: Vec<DiagnosticGroup>,
+    #[serde(skip_serializing_if = "is_default", default)]
+    pub diagnostic_groups_truncated: bool,
     /// Info about the git repository state, only present when can_collect_data is true.
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub git_info: Option<PredictEditsGitInfo>,
@@ -92,10 +94,8 @@ pub struct ScoreComponents {
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct DiagnosticGroup {
-    pub language_server: String,
-    pub diagnostic_group: serde_json::Value,
-}
+#[serde(transparent)]
+pub struct DiagnosticGroup(pub Box<serde_json::value::RawValue>);
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct PredictEditsResponse {
@@ -119,3 +119,7 @@ pub struct Edit {
     pub range: Range<usize>,
     pub content: String,
 }
+
+fn is_default<T: Default + PartialEq>(value: &T) -> bool {
+    *value == T::default()
+}

crates/language/src/buffer.rs 🔗

@@ -4558,6 +4558,12 @@ impl BufferSnapshot {
         })
     }
 
+    /// Raw access to the diagnostic sets. Typically `diagnostic_groups` or `diagnostic_group`
+    /// should be used instead.
+    pub fn diagnostic_sets(&self) -> &SmallVec<[(LanguageServerId, DiagnosticSet); 2]> {
+        &self.diagnostics
+    }
+
     /// Returns all the diagnostic groups associated with the given
     /// language server ID. If no language server ID is provided,
     /// all diagnostics groups are returned.

crates/zeta2/src/zeta2.rs 🔗

@@ -18,7 +18,9 @@ use gpui::{
     App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
     http_client, prelude::*,
 };
-use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
+use language::{
+    Anchor, Buffer, DiagnosticSet, LanguageServerId, OffsetRangeExt as _, ToOffset as _, ToPoint,
+};
 use language::{BufferSnapshot, EditPreview};
 use language_model::{LlmApiToken, RefreshLlmTokenListener};
 use project::Project;
@@ -45,6 +47,11 @@ pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPrediction
     target_before_cursor_over_total_bytes: 0.5,
 };
 
+pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
+    excerpt: DEFAULT_EXCERPT_OPTIONS,
+    max_diagnostic_bytes: 2048,
+};
+
 #[derive(Clone)]
 struct ZetaGlobal(Entity<Zeta>);
 
@@ -56,11 +63,17 @@ pub struct Zeta {
     llm_token: LlmApiToken,
     _llm_token_subscription: Subscription,
     projects: HashMap<EntityId, ZetaProject>,
-    pub excerpt_options: EditPredictionExcerptOptions,
+    options: ZetaOptions,
     update_required: bool,
     debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
 }
 
+#[derive(Debug, Clone, PartialEq)]
+pub struct ZetaOptions {
+    pub excerpt: EditPredictionExcerptOptions,
+    pub max_diagnostic_bytes: usize,
+}
+
 pub struct PredictionDebugInfo {
     pub context: EditPredictionContext,
     pub retrieval_time: TimeDelta,
@@ -113,7 +126,7 @@ impl Zeta {
             projects: HashMap::new(),
             client,
             user_store,
-            excerpt_options: DEFAULT_EXCERPT_OPTIONS,
+            options: DEFAULT_OPTIONS,
             llm_token: LlmApiToken::default(),
             _llm_token_subscription: cx.subscribe(
                 &refresh_llm_token_listener,
@@ -138,12 +151,12 @@ impl Zeta {
         debug_watch_rx
     }
 
-    pub fn excerpt_options(&self) -> &EditPredictionExcerptOptions {
-        &self.excerpt_options
+    pub fn options(&self) -> &ZetaOptions {
+        &self.options
     }
 
-    pub fn set_excerpt_options(&mut self, options: EditPredictionExcerptOptions) {
-        self.excerpt_options = options;
+    pub fn set_options(&mut self, options: ZetaOptions) {
+        self.options = options;
     }
 
     pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
@@ -290,7 +303,7 @@ impl Zeta {
                 .syntax_index
                 .read_with(cx, |index, _cx| index.state().clone())
         });
-        let excerpt_options = self.excerpt_options.clone();
+        let options = self.options.clone();
         let snapshot = buffer.read(cx).snapshot();
         let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
             return Task::ready(Err(anyhow!("No file path for excerpt")));
@@ -343,6 +356,8 @@ impl Zeta {
             })
             .unwrap_or_default();
 
+        let diagnostics = snapshot.diagnostic_sets().clone();
+
         let request_task = cx.background_spawn({
             let snapshot = snapshot.clone();
             let buffer = buffer.clone();
@@ -353,14 +368,15 @@ impl Zeta {
                     None
                 };
 
-                let cursor_point = position.to_point(&snapshot);
+                let cursor_offset = position.to_offset(&snapshot);
+                let cursor_point = cursor_offset.to_point(&snapshot);
 
                 let before_retrieval = chrono::Utc::now();
 
                 let Some(context) = EditPredictionContext::gather_context(
                     cursor_point,
                     &snapshot,
-                    &excerpt_options,
+                    &options.excerpt,
                     index_state.as_deref(),
                 ) else {
                     return Ok(None);
@@ -372,13 +388,22 @@ impl Zeta {
                     None
                 };
 
+                let (diagnostic_groups, diagnostic_groups_truncated) =
+                    Self::gather_nearby_diagnostics(
+                        cursor_offset,
+                        &diagnostics,
+                        &snapshot,
+                        options.max_diagnostic_bytes,
+                    );
+
                 let request = make_cloud_request(
                     excerpt_path.clone(),
                     context,
                     events,
                     // TODO data collection
                     false,
-                    Vec::new(),
+                    diagnostic_groups,
+                    diagnostic_groups_truncated,
                     None,
                     debug_context.is_some(),
                     &worktree_snapshots,
@@ -575,6 +600,52 @@ impl Zeta {
         }
     }
 
+    fn gather_nearby_diagnostics(
+        cursor_offset: usize,
+        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
+        snapshot: &BufferSnapshot,
+        max_diagnostics_bytes: usize,
+    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
+        // TODO: Could make this more efficient
+        let mut diagnostic_groups = Vec::new();
+        for (language_server_id, diagnostics) in diagnostic_sets {
+            let mut groups = Vec::new();
+            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
+            diagnostic_groups.extend(
+                groups
+                    .into_iter()
+                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
+            );
+        }
+
+        // sort by proximity to cursor
+        diagnostic_groups.sort_by_key(|group| {
+            let range = &group.entries[group.primary_ix].range;
+            if range.start >= cursor_offset {
+                range.start - cursor_offset
+            } else if cursor_offset >= range.end {
+                cursor_offset - range.end
+            } else {
+                (cursor_offset - range.start).min(range.end - cursor_offset)
+            }
+        });
+
+        let mut results = Vec::new();
+        let mut diagnostic_groups_truncated = false;
+        let mut diagnostics_byte_count = 0;
+        for group in diagnostic_groups {
+            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
+            diagnostics_byte_count += raw_value.get().len();
+            if diagnostics_byte_count > max_diagnostics_bytes {
+                diagnostic_groups_truncated = true;
+                break;
+            }
+            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
+        }
+
+        (results, diagnostic_groups_truncated)
+    }
+
     // TODO: Dedupe with similar code in request_prediction?
     pub fn cloud_request_for_zeta_cli(
         &mut self,
@@ -590,7 +661,7 @@ impl Zeta {
                 .syntax_index
                 .read_with(cx, |index, _cx| index.state().clone())
         });
-        let excerpt_options = self.excerpt_options.clone();
+        let options = self.options.clone();
         let snapshot = buffer.read(cx).snapshot();
         let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
             return Task::ready(Err(anyhow!("No file path for excerpt")));
@@ -614,7 +685,7 @@ impl Zeta {
             EditPredictionContext::gather_context(
                 cursor_point,
                 &snapshot,
-                &excerpt_options,
+                &options.excerpt,
                 index_state.as_deref(),
             )
             .context("Failed to select excerpt")
@@ -626,6 +697,7 @@ impl Zeta {
                     Vec::new(),
                     false,
                     Vec::new(),
+                    false,
                     None,
                     debug_info,
                     &worktree_snapshots,
@@ -985,6 +1057,7 @@ fn make_cloud_request(
     events: Vec<predict_edits_v3::Event>,
     can_collect_data: bool,
     diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
+    diagnostic_groups_truncated: bool,
     git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
     debug_info: bool,
     worktrees: &Vec<worktree::Snapshot>,
@@ -1058,6 +1131,8 @@ fn make_cloud_request(
         events,
         can_collect_data,
         diagnostic_groups,
+        diagnostic_groups_truncated,
+
         git_info,
         debug_info,
     }
@@ -1141,7 +1216,6 @@ fn interpolate(
 mod tests {
     use super::*;
     use gpui::TestAppContext;
-    use language::ToOffset as _;
 
     #[gpui::test]
     async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {

crates/zeta2_tools/src/zeta2_tools.rs 🔗

@@ -22,7 +22,7 @@ use ui::prelude::*;
 use ui_input::SingleLineInput;
 use util::ResultExt;
 use workspace::{Item, SplitDirection, Workspace};
-use zeta2::Zeta;
+use zeta2::{Zeta, ZetaOptions};
 
 use edit_prediction_context::{EditPredictionExcerptOptions, SnippetStyle};
 
@@ -137,25 +137,28 @@ impl Zeta2Inspector {
             _update_state_task: Task::ready(()),
             _receive_task: receive_task,
         };
-        this.set_input_options(&zeta.read(cx).excerpt_options().clone(), window, cx);
+        this.set_input_options(&zeta.read(cx).options().clone(), window, cx);
         this
     }
 
     fn set_input_options(
         &mut self,
-        options: &EditPredictionExcerptOptions,
+        options: &ZetaOptions,
         window: &mut Window,
         cx: &mut Context<Self>,
     ) {
         self.max_bytes_input.update(cx, |input, cx| {
-            input.set_text(options.max_bytes.to_string(), window, cx);
+            input.set_text(options.excerpt.max_bytes.to_string(), window, cx);
         });
         self.min_bytes_input.update(cx, |input, cx| {
-            input.set_text(options.min_bytes.to_string(), window, cx);
+            input.set_text(options.excerpt.min_bytes.to_string(), window, cx);
         });
         self.cursor_context_ratio_input.update(cx, |input, cx| {
             input.set_text(
-                format!("{:.2}", options.target_before_cursor_over_total_bytes),
+                format!(
+                    "{:.2}",
+                    options.excerpt.target_before_cursor_over_total_bytes
+                ),
                 window,
                 cx,
             );
@@ -163,9 +166,8 @@ impl Zeta2Inspector {
         cx.notify();
     }
 
-    fn set_options(&mut self, options: EditPredictionExcerptOptions, cx: &mut Context<Self>) {
-        self.zeta
-            .update(cx, |this, _cx| this.set_excerpt_options(options));
+    fn set_options(&mut self, options: ZetaOptions, cx: &mut Context<Self>) {
+        self.zeta.update(cx, |this, _cx| this.set_options(options));
 
         const THROTTLE_TIME: Duration = Duration::from_millis(100);
 
@@ -233,7 +235,7 @@ impl Zeta2Inspector {
                         .unwrap_or_default()
                 }
 
-                let options = EditPredictionExcerptOptions {
+                let excerpt_options = EditPredictionExcerptOptions {
                     max_bytes: number_input_value(&this.max_bytes_input, cx),
                     min_bytes: number_input_value(&this.min_bytes_input, cx),
                     target_before_cursor_over_total_bytes: number_input_value(
@@ -242,7 +244,13 @@ impl Zeta2Inspector {
                     ),
                 };
 
-                this.set_options(options, cx);
+                this.set_options(
+                    ZetaOptions {
+                        excerpt: excerpt_options,
+                        ..this.zeta.read(cx).options().clone()
+                    },
+                    cx,
+                );
             },
         )
         .detach();
@@ -525,15 +533,15 @@ impl Render for Zeta2Inspector {
                                             .child(
                                                 ui::Button::new("reset-options", "Reset")
                                                     .disabled(
-                                                        self.zeta.read(cx).excerpt_options()
-                                                            == &zeta2::DEFAULT_EXCERPT_OPTIONS,
+                                                        self.zeta.read(cx).options()
+                                                            == &zeta2::DEFAULT_OPTIONS,
                                                     )
                                                     .style(ButtonStyle::Outlined)
                                                     .size(ButtonSize::Large)
                                                     .on_click(cx.listener(
                                                         |this, _, window, cx| {
                                                             this.set_input_options(
-                                                                &zeta2::DEFAULT_EXCERPT_OPTIONS,
+                                                                &zeta2::DEFAULT_OPTIONS,
                                                                 window,
                                                                 cx,
                                                             );

crates/zeta_cli/src/main.rs 🔗

@@ -70,6 +70,8 @@ struct Zeta2Args {
     excerpt_min_bytes: usize,
     #[arg(long, default_value_t = 0.66)]
     target_before_cursor_over_total_bytes: f32,
+    #[arg(long, default_value_t = 1024)]
+    max_diagnostic_bytes: usize,
 }
 
 #[derive(Debug, Clone)]
@@ -221,12 +223,15 @@ async fn get_context(
                 });
                 zeta.update(cx, |zeta, cx| {
                     zeta.register_buffer(&buffer, &project, cx);
-                    zeta.excerpt_options = EditPredictionExcerptOptions {
-                        max_bytes: zeta2_args.excerpt_max_bytes,
-                        min_bytes: zeta2_args.excerpt_min_bytes,
-                        target_before_cursor_over_total_bytes: zeta2_args
-                            .target_before_cursor_over_total_bytes,
-                    }
+                    zeta.set_options(zeta2::ZetaOptions {
+                        excerpt: EditPredictionExcerptOptions {
+                            max_bytes: zeta2_args.excerpt_max_bytes,
+                            min_bytes: zeta2_args.excerpt_min_bytes,
+                            target_before_cursor_over_total_bytes: zeta2_args
+                                .target_before_cursor_over_total_bytes,
+                        },
+                        max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes,
+                    })
                 });
                 // TODO: Actually wait for indexing.
                 let timer = cx.background_executor().timer(Duration::from_secs(5));