zeta2: Max retrieved definitions option (#40515)

Agus Zubiaga created

Release Notes:

- N/A

Change summary

crates/edit_prediction_context/src/edit_prediction_context.rs | 13 ++-
crates/zeta2/src/zeta2.rs                                     |  2 
crates/zeta2_tools/src/zeta2_tools.rs                         | 14 +++++
crates/zeta_cli/src/main.rs                                   |  6 +-
4 files changed, 26 insertions(+), 9 deletions(-)

Detailed changes

crates/edit_prediction_context/src/edit_prediction_context.rs 🔗

@@ -27,9 +27,9 @@ pub use predict_edits_v3::Line;
 #[derive(Clone, Debug, PartialEq)]
 pub struct EditPredictionContextOptions {
     pub use_imports: bool,
-    pub use_references: bool,
     pub excerpt: EditPredictionExcerptOptions,
     pub score: EditPredictionScoreOptions,
+    pub max_retrieved_declarations: u8,
 }
 
 #[derive(Clone, Debug)]
@@ -118,7 +118,7 @@ impl EditPredictionContext {
         )?;
         let excerpt_text = excerpt.text(buffer);
 
-        let declarations = if options.use_references
+        let declarations = if options.max_retrieved_declarations > 0
             && let Some(index_state) = index_state
         {
             let excerpt_occurrences =
@@ -136,7 +136,7 @@ impl EditPredictionContext {
 
             let references = get_references(&excerpt, &excerpt_text, buffer);
 
-            scored_declarations(
+            let mut declarations = scored_declarations(
                 &options.score,
                 &index_state,
                 &excerpt,
@@ -146,7 +146,10 @@ impl EditPredictionContext {
                 references,
                 cursor_offset_in_file,
                 buffer,
-            )
+            );
+            // TODO [zeta2] if we need this when we ship, we should probably do it in a smarter way
+            declarations.truncate(options.max_retrieved_declarations as usize);
+            declarations
         } else {
             vec![]
         };
@@ -200,7 +203,6 @@ mod tests {
                     buffer_snapshot,
                     EditPredictionContextOptions {
                         use_imports: true,
-                        use_references: true,
                         excerpt: EditPredictionExcerptOptions {
                             max_bytes: 60,
                             min_bytes: 10,
@@ -209,6 +211,7 @@ mod tests {
                         score: EditPredictionScoreOptions {
                             omit_excerpt_overlaps: true,
                         },
+                        max_retrieved_declarations: u8::MAX,
                     },
                     Some(index.clone()),
                     cx,

crates/zeta2/src/zeta2.rs 🔗

@@ -48,7 +48,7 @@ const MAX_EVENT_COUNT: usize = 16;
 
 pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
     use_imports: true,
-    use_references: false,
+    max_retrieved_declarations: 0,
     excerpt: EditPredictionExcerptOptions {
         max_bytes: 512,
         min_bytes: 128,

crates/zeta2_tools/src/zeta2_tools.rs 🔗

@@ -68,6 +68,7 @@ pub struct Zeta2Inspector {
     min_excerpt_bytes_input: Entity<SingleLineInput>,
     cursor_context_ratio_input: Entity<SingleLineInput>,
     max_prompt_bytes_input: Entity<SingleLineInput>,
+    max_retrieved_declarations: Entity<SingleLineInput>,
     active_view: ActiveView,
     zeta: Entity<Zeta>,
     _active_editor_subscription: Option<Subscription>,
@@ -133,6 +134,7 @@ impl Zeta2Inspector {
             min_excerpt_bytes_input: Self::number_input("Min Excerpt Bytes", window, cx),
             cursor_context_ratio_input: Self::number_input("Cursor Context Ratio", window, cx),
             max_prompt_bytes_input: Self::number_input("Max Prompt Bytes", window, cx),
+            max_retrieved_declarations: Self::number_input("Max Retrieved Definitions", window, cx),
             zeta: zeta.clone(),
             _active_editor_subscription: None,
             _update_state_task: Task::ready(()),
@@ -170,6 +172,13 @@ impl Zeta2Inspector {
         self.max_prompt_bytes_input.update(cx, |input, cx| {
             input.set_text(options.max_prompt_bytes.to_string(), window, cx);
         });
+        self.max_retrieved_declarations.update(cx, |input, cx| {
+            input.set_text(
+                options.context.max_retrieved_declarations.to_string(),
+                window,
+                cx,
+            );
+        });
         cx.notify();
     }
 
@@ -246,6 +255,10 @@ impl Zeta2Inspector {
                             cx,
                         ),
                     },
+                    max_retrieved_declarations: number_input_value(
+                        &this.max_retrieved_declarations,
+                        cx,
+                    ),
                     ..zeta_options.context
                 };
 
@@ -536,6 +549,7 @@ impl Zeta2Inspector {
                         h_flex()
                             .gap_2()
                             .items_end()
+                            .child(self.max_retrieved_declarations.clone())
                             .child(self.max_prompt_bytes_input.clone())
                             .child(self.render_prompt_format_dropdown(window, cx)),
                     ),

crates/zeta_cli/src/main.rs 🔗

@@ -94,8 +94,8 @@ struct Zeta2Args {
     file_indexing_parallelism: usize,
     #[arg(long, default_value_t = false)]
     disable_imports_gathering: bool,
-    #[arg(long, default_value_t = false)]
-    disable_reference_retrieval: bool,
+    #[arg(long, default_value_t = u8::MAX)]
+    max_retrieved_definitions: u8,
 }
 
 #[derive(clap::ValueEnum, Default, Debug, Clone)]
@@ -302,7 +302,7 @@ impl Zeta2Args {
     fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions {
         zeta2::ZetaOptions {
             context: EditPredictionContextOptions {
-                use_references: !self.disable_reference_retrieval,
+                max_retrieved_declarations: self.max_retrieved_definitions,
                 use_imports: !self.disable_imports_gathering,
                 excerpt: EditPredictionExcerptOptions {
                     max_bytes: self.max_excerpt_bytes,