Move cloud request building code to zeta2 + other misc changes

Michael Sloan created

Change summary

Cargo.lock                                                    |   4 
crates/cloud_llm_client/src/predict_edits_v3.rs               |  12 
crates/edit_prediction_context/src/declaration_scoring.rs     |   6 
crates/edit_prediction_context/src/edit_prediction_context.rs | 133 ----
crates/edit_prediction_context/src/syntax_index.rs            |   7 
crates/edit_prediction_context/src/text_similarity.rs         |   4 
crates/edit_prediction_tools/src/edit_prediction_tools.rs     |   2 
crates/zeta2/Cargo.toml                                       |   4 
crates/zeta2/src/zeta2.rs                                     | 126 ++++
9 files changed, 160 insertions(+), 138 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -21639,11 +21639,13 @@ name = "zeta2"
 version = "0.1.0"
 dependencies = [
  "client",
+ "cloud_llm_client",
  "edit_prediction",
+ "edit_prediction_context",
  "gpui",
  "language",
+ "log",
  "project",
- "util",
  "workspace-hack",
 ]
 

crates/cloud_llm_client/src/predict_edits_v3.rs 🔗

@@ -6,7 +6,7 @@ use crate::PredictEditsGitInfo;
 // TODO: snippet ordering within file / relative to excerpt
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct Body {
+pub struct PredictEditsRequest {
     pub excerpt: String,
     /// Within `signatures`
     pub excerpt_parent: Option<usize>,
@@ -15,8 +15,8 @@ pub struct Body {
     pub events: Vec<Event>,
     #[serde(default)]
     pub can_collect_data: bool,
-    #[serde(skip_serializing_if = "Option::is_none", default)]
-    pub diagnostic_groups: Option<Vec<(String, serde_json::Value)>>,
+    #[serde(skip_serializing_if = "Vec::is_empty", default)]
+    pub diagnostic_groups: Vec<DiagnosticGroup>,
     /// Info about the git repository state, only present when can_collect_data is true.
     #[serde(skip_serializing_if = "Option::is_none", default)]
     pub git_info: Option<PredictEditsGitInfo>,
@@ -68,6 +68,12 @@ pub struct ScoreComponents {
     pub adjacent_vs_signature_weighted_overlap: f32,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct DiagnosticGroup {
+    pub language_server: String,
+    pub diagnostic_group: serde_json::Value,
+}
+
 /*
 #[derive(Debug, Clone)]
 pub struct SerializedJson<T> {

crates/edit_prediction_context/src/declaration_scoring.rs 🔗

@@ -16,10 +16,6 @@ use crate::{
 
 const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
 
-// TODO:
-//
-// * Consider adding declaration_file_count
-
 #[derive(Clone, Debug)]
 pub struct ScoredSnippet {
     pub identifier: Identifier,
@@ -28,7 +24,6 @@ pub struct ScoredSnippet {
     pub scores: Scores,
 }
 
-// TODO: Consider having "Concise" style corresponding to `concise_text`
 #[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
 pub enum SnippetStyle {
     Signature,
@@ -244,6 +239,7 @@ fn score_snippet(
     let adjacent_vs_signature_weighted_overlap =
         weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
 
+    // TODO: Consider adding declaration_file_count
     let score_components = ScoreComponents {
         is_same_file,
         is_referenced_nearby,

crates/edit_prediction_context/src/edit_prediction_context.rs 🔗

@@ -6,23 +6,15 @@ mod reference;
 mod syntax_index;
 mod text_similarity;
 
-use cloud_llm_client::predict_edits_v3::{self, Signature};
-use collections::HashMap;
-pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier};
-pub use declaration_scoring::SnippetStyle;
-pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
-
 use gpui::{App, AppContext as _, Entity, Task};
 use language::BufferSnapshot;
-pub use reference::references_in_excerpt;
-pub use syntax_index::SyntaxIndex;
 use text::{Point, ToOffset as _};
 
-use crate::{
-    declaration::DeclarationId,
-    declaration_scoring::{ScoredSnippet, scored_snippets},
-    syntax_index::SyntaxIndexState,
-};
+pub use declaration::*;
+pub use declaration_scoring::*;
+pub use excerpt::*;
+pub use reference::*;
+pub use syntax_index::*;
 
 #[derive(Debug)]
 pub struct EditPredictionContext {
@@ -32,7 +24,7 @@ pub struct EditPredictionContext {
 }
 
 impl EditPredictionContext {
-    pub fn gather(
+    pub fn gather_context_in_background(
         cursor_point: Point,
         buffer: BufferSnapshot,
         excerpt_options: EditPredictionExcerptOptions,
@@ -42,25 +34,25 @@ impl EditPredictionContext {
         let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
         cx.background_spawn(async move {
             let index_state = index_state.lock().await;
-            Self::gather_context(cursor_point, buffer, excerpt_options, &index_state)
+            Self::gather_context(cursor_point, &buffer, &excerpt_options, &index_state)
         })
     }
 
-    fn gather_context(
+    pub fn gather_context(
         cursor_point: Point,
-        buffer: BufferSnapshot,
-        excerpt_options: EditPredictionExcerptOptions,
+        buffer: &BufferSnapshot,
+        excerpt_options: &EditPredictionExcerptOptions,
         index_state: &SyntaxIndexState,
     ) -> Option<Self> {
         let excerpt = EditPredictionExcerpt::select_from_buffer(
             cursor_point,
-            &buffer,
-            &excerpt_options,
+            buffer,
+            excerpt_options,
             Some(index_state),
         )?;
-        let excerpt_text = excerpt.text(&buffer);
-        let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer);
-        let cursor_offset = cursor_point.to_offset(&buffer);
+        let excerpt_text = excerpt.text(buffer);
+        let references = references_in_excerpt(&excerpt, &excerpt_text, buffer);
+        let cursor_offset = cursor_point.to_offset(buffer);
 
         let snippets = scored_snippets(
             &index_state,
@@ -68,7 +60,7 @@ impl EditPredictionContext {
             &excerpt_text,
             references,
             cursor_offset,
-            &buffer,
+            buffer,
         );
 
         Some(Self {
@@ -77,97 +69,6 @@ impl EditPredictionContext {
             snippets,
         })
     }
-
-    pub fn cloud_request(
-        cursor_point: Point,
-        buffer: BufferSnapshot,
-        excerpt_options: EditPredictionExcerptOptions,
-        syntax_index: Entity<SyntaxIndex>,
-        cx: &mut App,
-    ) -> Task<Option<predict_edits_v3::Body>> {
-        let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
-        cx.background_spawn(async move {
-            let index_state = index_state.lock().await;
-            Self::gather_context(cursor_point, buffer, excerpt_options, &index_state)
-                .map(|context| context.into_cloud_request(&index_state))
-        })
-    }
-
-    pub fn into_cloud_request(self, index: &SyntaxIndexState) -> predict_edits_v3::Body {
-        let mut signatures = Vec::new();
-        let mut declaration_to_signature_index = HashMap::default();
-        let mut referenced_declarations = Vec::new();
-        let excerpt_parent = self
-            .excerpt
-            .parent_declarations
-            .last()
-            .and_then(|(parent, _)| {
-                add_signature(
-                    *parent,
-                    &mut declaration_to_signature_index,
-                    &mut signatures,
-                    index,
-                )
-            });
-        for snippet in self.snippets {
-            let parent_index = snippet.declaration.parent().and_then(|parent| {
-                add_signature(
-                    parent,
-                    &mut declaration_to_signature_index,
-                    &mut signatures,
-                    index,
-                )
-            });
-            let (text, text_is_truncated) = snippet.declaration.item_text();
-            referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
-                text: text.into(),
-                text_is_truncated,
-                signature_range: snippet.declaration.signature_range_in_item_text(),
-                parent_index,
-                score_components: snippet.score_components,
-                signature_score: snippet.scores.signature,
-                declaration_score: snippet.scores.declaration,
-            });
-        }
-        predict_edits_v3::Body {
-            excerpt: self.excerpt_text.body,
-            referenced_declarations,
-            signatures,
-            excerpt_parent,
-            // todo!
-            events: vec![],
-            can_collect_data: false,
-            diagnostic_groups: None,
-            git_info: None,
-        }
-    }
-}
-
-fn add_signature(
-    declaration_id: DeclarationId,
-    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
-    signatures: &mut Vec<Signature>,
-    index: &SyntaxIndexState,
-) -> Option<usize> {
-    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
-        return Some(*signature_index);
-    }
-    let Some(parent_declaration) = index.declaration(declaration_id) else {
-        log::error!("bug: missing parent declaration");
-        return None;
-    };
-    let parent_index = parent_declaration.parent().and_then(|parent| {
-        add_signature(parent, declaration_to_signature_index, signatures, index)
-    });
-    let (text, text_is_truncated) = parent_declaration.signature_text();
-    let signature_index = signatures.len();
-    signatures.push(Signature {
-        text: text.into(),
-        text_is_truncated,
-        parent_index,
-    });
-    declaration_to_signature_index.insert(declaration_id, signature_index);
-    Some(signature_index)
 }
 
 #[cfg(test)]
@@ -205,7 +106,7 @@ mod tests {
 
         let context = cx
             .update(|cx| {
-                EditPredictionContext::gather(
+                EditPredictionContext::gather_context_in_background(
                     cursor_point,
                     buffer_snapshot,
                     EditPredictionExcerptOptions {

crates/edit_prediction_context/src/syntax_index.rs 🔗

@@ -17,12 +17,6 @@ use crate::declaration::{
 };
 use crate::outline::declarations_in_buffer;
 
-// TODO:
-//
-// * Skip for remote projects
-//
-// * Consider making SyntaxIndex not an Entity.
-
 // Potential future improvements:
 //
 // * Send multiple selected excerpt ranges. Challenge is that excerpt ranges influence which
@@ -41,7 +35,6 @@ use crate::outline::declarations_in_buffer;
 // * Concurrent slotmap
 //
 // * Use queue for parsing
-//
 
 pub struct SyntaxIndex {
     state: Arc<Mutex<SyntaxIndexState>>,

crates/edit_prediction_context/src/text_similarity.rs 🔗

@@ -9,8 +9,12 @@ use crate::reference::Reference;
 // That implementation could actually be more efficient - no need to track words in the window that
 // are not in the query.
 
+// TODO: Consider a flat sorted Vec<(String, usize)> representation. Intersection can just walk the
+// two in parallel.
+
 static IDENTIFIER_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\b\w+\b").unwrap());
 
+// TODO: use &str or Cow<str> keys?
 #[derive(Debug)]
 pub struct IdentifierOccurrences {
     identifier_to_count: HashMap<String, usize>,

crates/zeta2/Cargo.toml 🔗

@@ -13,9 +13,11 @@ path = "src/zeta2.rs"
 
 [dependencies]
 client.workspace = true
+cloud_llm_client.workspace = true
 edit_prediction.workspace = true
+edit_prediction_context.workspace = true
 gpui.workspace = true
 language.workspace = true
+log.workspace = true
 project.workspace = true
 workspace-hack.workspace = true
-util.workspace = true

crates/zeta2/src/zeta2.rs 🔗

@@ -1,9 +1,14 @@
-use std::{ops::Range, sync::Arc};
-
-use gpui::{App, Entity, EntityId, Task, prelude::*};
-
+use cloud_llm_client::predict_edits_v3::{self, Signature};
 use edit_prediction::{DataCollectionState, Direction, EditPrediction, EditPredictionProvider};
+use edit_prediction_context::{
+    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
+    SyntaxIndexState,
+};
+use gpui::{App, Entity, EntityId, Task, prelude::*};
 use language::{Anchor, ToPoint};
+use language::{BufferSnapshot, Point};
+use std::collections::HashMap;
+use std::{ops::Range, sync::Arc};
 
 pub struct Zeta2EditPredictionProvider {
     current: Option<CurrentEditPrediction>,
@@ -152,3 +157,116 @@ impl EditPredictionProvider for Zeta2EditPredictionProvider {
         Some(current_prediction.prediction)
     }
 }
+
+pub fn make_cloud_request_in_background(
+    cursor_point: Point,
+    buffer: BufferSnapshot,
+    events: Vec<predict_edits_v3::Event>,
+    can_collect_data: bool,
+    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
+    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
+    excerpt_options: EditPredictionExcerptOptions,
+    syntax_index: Entity<SyntaxIndex>,
+    cx: &mut App,
+) -> Task<Option<predict_edits_v3::PredictEditsRequest>> {
+    let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone());
+    cx.background_spawn(async move {
+        let index_state = index_state.lock().await;
+        EditPredictionContext::gather_context(cursor_point, &buffer, &excerpt_options, &index_state)
+            .map(|context| {
+                make_cloud_request(
+                    context,
+                    events,
+                    can_collect_data,
+                    diagnostic_groups,
+                    git_info,
+                    &index_state,
+                )
+            })
+    })
+}
+
+pub fn make_cloud_request(
+    context: EditPredictionContext,
+    events: Vec<predict_edits_v3::Event>,
+    can_collect_data: bool,
+    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
+    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
+    index_state: &SyntaxIndexState,
+) -> predict_edits_v3::PredictEditsRequest {
+    let mut signatures = Vec::new();
+    let mut declaration_to_signature_index = HashMap::default();
+    let mut referenced_declarations = Vec::new();
+    for snippet in context.snippets {
+        let parent_index = snippet.declaration.parent().and_then(|parent| {
+            add_signature(
+                parent,
+                &mut declaration_to_signature_index,
+                &mut signatures,
+                index_state,
+            )
+        });
+        let (text, text_is_truncated) = snippet.declaration.item_text();
+        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
+            text: text.into(),
+            text_is_truncated,
+            signature_range: snippet.declaration.signature_range_in_item_text(),
+            parent_index,
+            score_components: snippet.score_components,
+            signature_score: snippet.scores.signature,
+            declaration_score: snippet.scores.declaration,
+        });
+    }
+
+    let excerpt_parent = context
+        .excerpt
+        .parent_declarations
+        .last()
+        .and_then(|(parent, _)| {
+            add_signature(
+                *parent,
+                &mut declaration_to_signature_index,
+                &mut signatures,
+                index_state,
+            )
+        });
+
+    predict_edits_v3::PredictEditsRequest {
+        excerpt: context.excerpt_text.body,
+        referenced_declarations,
+        signatures,
+        excerpt_parent,
+        // todo!
+        events,
+        can_collect_data,
+        diagnostic_groups,
+        git_info,
+    }
+}
+
+fn add_signature(
+    declaration_id: DeclarationId,
+    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
+    signatures: &mut Vec<Signature>,
+    index: &SyntaxIndexState,
+) -> Option<usize> {
+    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
+        return Some(*signature_index);
+    }
+    let Some(parent_declaration) = index.declaration(declaration_id) else {
+        log::error!("bug: missing parent declaration");
+        return None;
+    };
+    let parent_index = parent_declaration.parent().and_then(|parent| {
+        add_signature(parent, declaration_to_signature_index, signatures, index)
+    });
+    let (text, text_is_truncated) = parent_declaration.signature_text();
+    let signature_index = signatures.len();
+    signatures.push(Signature {
+        text: text.into(),
+        text_is_truncated,
+        parent_index,
+    });
+    declaration_to_signature_index.insert(declaration_id, signature_index);
+    Some(signature_index)
+}