Improve EP teacher prompt, add some CLI features (#47814)

Max Brunsfeld , Oleksiy Syvokon , Ben Kunkle , and Oleksiy created

Release Notes:

- N/A

---------

Co-authored-by: Oleksiy Syvokon <oleksiy.syvokon@gmail.com>
Co-authored-by: Ben Kunkle <ben@zed.dev>
Co-authored-by: Oleksiy <oleksiy@zed.dev>

Change summary

Cargo.lock                                          |   1 
crates/edit_prediction/Cargo.toml                   |   1 
crates/edit_prediction/src/cursor_excerpt.rs        | 599 +++++++++++++-
crates/edit_prediction/src/edit_prediction_tests.rs |   3 
crates/edit_prediction/src/zeta1.rs                 |  18 
crates/edit_prediction_cli/src/anthropic_client.rs  |  31 
crates/edit_prediction_cli/src/main.rs              | 119 ++
crates/edit_prediction_cli/src/openai_client.rs     |  26 
crates/edit_prediction_cli/src/predict.rs           | 179 ++--
crates/edit_prediction_cli/src/prompts/teacher.md   |  90 ++
crates/edit_prediction_cli/src/pull_examples.rs     | 252 ++++++
crates/edit_prediction_cli/src/qa.rs                |   4 
crates/edit_prediction_cli/src/repair.rs            |   4 
13 files changed, 1,119 insertions(+), 208 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -5300,6 +5300,7 @@ dependencies = [
  "thiserror 2.0.17",
  "time",
  "toml 0.8.23",
+ "tree-sitter-rust",
  "ui",
  "util",
  "uuid",

crates/edit_prediction/Cargo.toml 🔗

@@ -80,4 +80,5 @@ parking_lot.workspace = true
 project = { workspace = true, features = ["test-support"] }
 settings = { workspace = true, features = ["test-support"] }
 workspace = { workspace = true, features = ["test-support"] }
+tree-sitter-rust.workspace = true
 zlog.workspace = true

crates/edit_prediction/src/cursor_excerpt.rs 🔗

@@ -7,66 +7,199 @@ pub fn editable_and_context_ranges_for_cursor_position(
     editable_region_token_limit: usize,
     context_token_limit: usize,
 ) -> (Range<Point>, Range<Point>) {
-    let mut scope_range = position..position;
-    let mut remaining_edit_tokens = editable_region_token_limit;
-
-    while let Some(parent) = snapshot.syntax_ancestor(scope_range.clone()) {
-        let parent_tokens = guess_token_count(parent.byte_range().len());
-        let parent_point_range = Point::new(
-            parent.start_position().row as u32,
-            parent.start_position().column as u32,
-        )
-            ..Point::new(
-                parent.end_position().row as u32,
-                parent.end_position().column as u32,
-            );
-        if parent_point_range == scope_range {
-            break;
-        } else if parent_tokens <= editable_region_token_limit {
-            scope_range = parent_point_range;
-            remaining_edit_tokens = editable_region_token_limit - parent_tokens;
+    let editable_range = compute_editable_range(snapshot, position, editable_region_token_limit);
+
+    let context_range = expand_context_syntactically_then_linewise(
+        snapshot,
+        editable_range.clone(),
+        context_token_limit,
+    );
+
+    (editable_range, context_range)
+}
+
+/// Computes the editable range using a three-phase approach:
+/// 1. Expand symmetrically from cursor (75% of budget)
+/// 2. Expand to syntax boundaries
+/// 3. Continue line-wise in the least-expanded direction
+fn compute_editable_range(
+    snapshot: &BufferSnapshot,
+    cursor: Point,
+    token_limit: usize,
+) -> Range<Point> {
+    // Phase 1: Expand symmetrically from cursor using 75% of budget.
+    let initial_budget = (token_limit * 3) / 4;
+    let (mut start_row, mut end_row, mut remaining_tokens) =
+        expand_symmetric_from_cursor(snapshot, cursor.row, initial_budget);
+
+    // Add remaining budget from phase 1.
+    remaining_tokens += token_limit.saturating_sub(initial_budget);
+
+    let original_start = start_row;
+    let original_end = end_row;
+
+    // Phase 2: Expand to syntax boundaries that fit within budget.
+    for (boundary_start, boundary_end) in containing_syntax_boundaries(snapshot, start_row, end_row)
+    {
+        let tokens_for_start = if boundary_start < start_row {
+            estimate_tokens_for_rows(snapshot, boundary_start, start_row)
+        } else {
+            0
+        };
+        let tokens_for_end = if boundary_end > end_row {
+            estimate_tokens_for_rows(snapshot, end_row + 1, boundary_end + 1)
+        } else {
+            0
+        };
+
+        let total_needed = tokens_for_start + tokens_for_end;
+
+        if total_needed <= remaining_tokens {
+            if boundary_start < start_row {
+                start_row = boundary_start;
+            }
+            if boundary_end > end_row {
+                end_row = boundary_end;
+            }
+            remaining_tokens = remaining_tokens.saturating_sub(total_needed);
         } else {
             break;
         }
     }
 
-    let editable_range = expand_range(snapshot, scope_range, remaining_edit_tokens);
-    let context_range = expand_range(snapshot, editable_range.clone(), context_token_limit);
-    (editable_range, context_range)
+    // Phase 3: Continue line-wise in the direction we expanded least during syntax phase.
+    let expanded_up = original_start.saturating_sub(start_row);
+    let expanded_down = end_row.saturating_sub(original_end);
+
+    (start_row, end_row, _) = expand_linewise_biased(
+        snapshot,
+        start_row,
+        end_row,
+        remaining_tokens,
+        expanded_up <= expanded_down, // prefer_up if we expanded less upward
+    );
+
+    let start = Point::new(start_row, 0);
+    let end = Point::new(end_row, snapshot.line_len(end_row));
+    start..end
 }
 
-fn expand_range(
+/// Expands symmetrically from cursor, one line at a time, alternating down then up.
+/// Returns (start_row, end_row, remaining_tokens).
+fn expand_symmetric_from_cursor(
     snapshot: &BufferSnapshot,
-    range: Range<Point>,
+    cursor_row: u32,
+    mut token_budget: usize,
+) -> (u32, u32, usize) {
+    let mut start_row = cursor_row;
+    let mut end_row = cursor_row;
+
+    // Account for the cursor's line.
+    let cursor_line_tokens = line_token_count(snapshot, cursor_row);
+    token_budget = token_budget.saturating_sub(cursor_line_tokens);
+
+    loop {
+        let can_expand_up = start_row > 0;
+        let can_expand_down = end_row < snapshot.max_point().row;
+
+        if token_budget == 0 || (!can_expand_up && !can_expand_down) {
+            break;
+        }
+
+        // Expand down first (slight forward bias for edit prediction).
+        if can_expand_down {
+            let next_row = end_row + 1;
+            let line_tokens = line_token_count(snapshot, next_row);
+            if line_tokens <= token_budget {
+                end_row = next_row;
+                token_budget = token_budget.saturating_sub(line_tokens);
+            } else {
+                break;
+            }
+        }
+
+        // Then expand up.
+        if can_expand_up && token_budget > 0 {
+            let next_row = start_row - 1;
+            let line_tokens = line_token_count(snapshot, next_row);
+            if line_tokens <= token_budget {
+                start_row = next_row;
+                token_budget = token_budget.saturating_sub(line_tokens);
+            } else {
+                break;
+            }
+        }
+    }
+
+    (start_row, end_row, token_budget)
+}
+
+/// Expands line-wise with a bias toward one direction.
+/// Returns (start_row, end_row, remaining_tokens).
+fn expand_linewise_biased(
+    snapshot: &BufferSnapshot,
+    mut start_row: u32,
+    mut end_row: u32,
     mut remaining_tokens: usize,
-) -> Range<Point> {
-    let mut expanded_range = range;
-    expanded_range.start.column = 0;
-    expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
+    prefer_up: bool,
+) -> (u32, u32, usize) {
     loop {
-        let mut expanded = false;
+        let can_expand_up = start_row > 0;
+        let can_expand_down = end_row < snapshot.max_point().row;
 
-        if remaining_tokens > 0 && expanded_range.start.row > 0 {
-            expanded_range.start.row -= 1;
-            let line_tokens =
-                guess_token_count(snapshot.line_len(expanded_range.start.row) as usize);
-            remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
-            expanded = true;
+        if remaining_tokens == 0 || (!can_expand_up && !can_expand_down) {
+            break;
         }
 
-        if remaining_tokens > 0 && expanded_range.end.row < snapshot.max_point().row {
-            expanded_range.end.row += 1;
-            expanded_range.end.column = snapshot.line_len(expanded_range.end.row);
-            let line_tokens = guess_token_count(expanded_range.end.column as usize);
-            remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
-            expanded = true;
+        let mut expanded = false;
+
+        // Try preferred direction first.
+        if prefer_up {
+            if can_expand_up {
+                let next_row = start_row - 1;
+                let line_tokens = line_token_count(snapshot, next_row);
+                if line_tokens <= remaining_tokens {
+                    start_row = next_row;
+                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
+                    expanded = true;
+                }
+            }
+            if can_expand_down && remaining_tokens > 0 {
+                let next_row = end_row + 1;
+                let line_tokens = line_token_count(snapshot, next_row);
+                if line_tokens <= remaining_tokens {
+                    end_row = next_row;
+                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
+                    expanded = true;
+                }
+            }
+        } else {
+            if can_expand_down {
+                let next_row = end_row + 1;
+                let line_tokens = line_token_count(snapshot, next_row);
+                if line_tokens <= remaining_tokens {
+                    end_row = next_row;
+                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
+                    expanded = true;
+                }
+            }
+            if can_expand_up && remaining_tokens > 0 {
+                let next_row = start_row - 1;
+                let line_tokens = line_token_count(snapshot, next_row);
+                if line_tokens <= remaining_tokens {
+                    start_row = next_row;
+                    remaining_tokens = remaining_tokens.saturating_sub(line_tokens);
+                    expanded = true;
+                }
+            }
         }
 
         if !expanded {
             break;
         }
     }
-    expanded_range
+
+    (start_row, end_row, remaining_tokens)
 }
 
 /// Typical number of string bytes per token for the purposes of limiting model input. This is
@@ -76,3 +209,387 @@ pub(crate) const BYTES_PER_TOKEN_GUESS: usize = 3;
 pub fn guess_token_count(bytes: usize) -> usize {
     bytes / BYTES_PER_TOKEN_GUESS
 }
+
+fn line_token_count(snapshot: &BufferSnapshot, row: u32) -> usize {
+    guess_token_count(snapshot.line_len(row) as usize).max(1)
+}
+
+/// Estimates token count for rows in range [start_row, end_row).
+fn estimate_tokens_for_rows(snapshot: &BufferSnapshot, start_row: u32, end_row: u32) -> usize {
+    let mut tokens = 0;
+    for row in start_row..end_row {
+        tokens += line_token_count(snapshot, row);
+    }
+    tokens
+}
+
+/// Returns an iterator of (start_row, end_row) for successively larger syntax nodes
+/// containing the given row range. Smallest containing node first.
+fn containing_syntax_boundaries(
+    snapshot: &BufferSnapshot,
+    start_row: u32,
+    end_row: u32,
+) -> impl Iterator<Item = (u32, u32)> {
+    let range = Point::new(start_row, 0)..Point::new(end_row, snapshot.line_len(end_row));
+    let mut current = snapshot.syntax_ancestor(range);
+    let mut last_rows: Option<(u32, u32)> = None;
+
+    std::iter::from_fn(move || {
+        while let Some(node) = current.take() {
+            let node_start_row = node.start_position().row as u32;
+            let node_end_row = node.end_position().row as u32;
+            let rows = (node_start_row, node_end_row);
+
+            current = node.parent();
+
+            // Skip nodes that don't extend beyond our range.
+            if node_start_row >= start_row && node_end_row <= end_row {
+                continue;
+            }
+
+            // Skip if same as last returned (some nodes have same span).
+            if last_rows == Some(rows) {
+                continue;
+            }
+
+            last_rows = Some(rows);
+            return Some(rows);
+        }
+        None
+    })
+}
+
+/// Expands context by first trying to reach syntax boundaries,
+/// then expanding line-wise only if no syntax expansion occurred.
+fn expand_context_syntactically_then_linewise(
+    snapshot: &BufferSnapshot,
+    editable_range: Range<Point>,
+    context_token_limit: usize,
+) -> Range<Point> {
+    let mut start_row = editable_range.start.row;
+    let mut end_row = editable_range.end.row;
+    let mut remaining_tokens = context_token_limit;
+    let mut did_syntax_expand = false;
+
+    // Phase 1: Try to expand to containing syntax boundaries, picking the largest that fits.
+    for (boundary_start, boundary_end) in containing_syntax_boundaries(snapshot, start_row, end_row)
+    {
+        let tokens_for_start = if boundary_start < start_row {
+            estimate_tokens_for_rows(snapshot, boundary_start, start_row)
+        } else {
+            0
+        };
+        let tokens_for_end = if boundary_end > end_row {
+            estimate_tokens_for_rows(snapshot, end_row + 1, boundary_end + 1)
+        } else {
+            0
+        };
+
+        let total_needed = tokens_for_start + tokens_for_end;
+
+        if total_needed <= remaining_tokens {
+            if boundary_start < start_row {
+                start_row = boundary_start;
+            }
+            if boundary_end > end_row {
+                end_row = boundary_end;
+            }
+            remaining_tokens = remaining_tokens.saturating_sub(total_needed);
+            did_syntax_expand = true;
+        } else {
+            break;
+        }
+    }
+
+    // Phase 2: Only expand line-wise if no syntax expansion occurred.
+    if !did_syntax_expand {
+        (start_row, end_row, _) =
+            expand_linewise_biased(snapshot, start_row, end_row, remaining_tokens, true);
+    }
+
+    let start = Point::new(start_row, 0);
+    let end = Point::new(end_row, snapshot.line_len(end_row));
+    start..end
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use gpui::{App, AppContext};
+    use indoc::indoc;
+    use language::{Buffer, rust_lang};
+    use util::test::{TextRangeMarker, marked_text_ranges_by};
+
+    struct TestCase {
+        name: &'static str,
+        marked_text: &'static str,
+        editable_token_limit: usize,
+        context_token_limit: usize,
+    }
+
+    #[gpui::test]
+    fn test_editable_and_context_ranges(cx: &mut App) {
+        // Markers:
+        // ˇ = cursor position
+        // « » = expected editable range
+        // [ ] = expected context range
+        let test_cases = vec![
+            TestCase {
+                name: "cursor near end of function - expands to syntax boundaries",
+                marked_text: indoc! {r#"
+                    [fn first() {
+                        let a = 1;
+                        let b = 2;
+                    }
+
+                    fn foo() {
+                    «    let x = 1;
+                        let y = 2;
+                        println!("{}", x + y);ˇ
+                    }»]
+                "#},
+                // 18 tokens - expands symmetrically then to syntax boundaries
+                editable_token_limit: 18,
+                context_token_limit: 35,
+            },
+            TestCase {
+                name: "cursor at function start - expands to syntax boundaries",
+                marked_text: indoc! {r#"
+                    [fn before() {
+                    «    let a = 1;
+                    }
+
+                    fn foo() {ˇ
+                        let x = 1;
+                        let y = 2;
+                        let z = 3;
+                    }
+                    »
+                    fn after() {
+                        let b = 2;
+                    }]
+                "#},
+                // 25 tokens - expands symmetrically then to syntax boundaries
+                editable_token_limit: 25,
+                context_token_limit: 50,
+            },
+            TestCase {
+                name: "tiny budget - just lines around cursor",
+                marked_text: indoc! {r#"
+                    fn outer() {
+                    [    let line1 = 1;
+                        let line2 = 2;
+                    «    let line3 = 3;
+                        let line4 = 4;ˇ»
+                        let line5 = 5;
+                        let line6 = 6;]
+                        let line7 = 7;
+                    }
+                "#},
+                // 12 tokens (~36 bytes) = just the cursor line with tiny budget
+                editable_token_limit: 12,
+                context_token_limit: 24,
+            },
+            TestCase {
+                name: "small function fits entirely",
+                marked_text: indoc! {r#"
+                    [«fn foo() {
+                        let x = 1;ˇ
+                        let y = 2;
+                    }»]
+                "#},
+                // Plenty of budget for this small function
+                editable_token_limit: 30,
+                context_token_limit: 60,
+            },
+            TestCase {
+                name: "context extends beyond editable",
+                marked_text: indoc! {r#"
+                    [fn first() { let a = 1; }
+                    «fn second() { let b = 2; }
+                    fn third() { let c = 3; }ˇ
+                    fn fourth() { let d = 4; }»
+                    fn fifth() { let e = 5; }]
+                "#},
+                // Small editable, larger context
+                editable_token_limit: 25,
+                context_token_limit: 45,
+            },
+            // Tests for syntax-aware editable and context expansion
+            TestCase {
+                name: "cursor in first if-statement - expands to syntax boundaries",
+                marked_text: indoc! {r#"
+                    [«fn before() { }
+
+                    fn process() {
+                        if condition1 {
+                            let a = 1;ˇ
+                            let b = 2;
+                        }
+                        if condition2 {»
+                            let c = 3;
+                            let d = 4;
+                        }
+                        if condition3 {
+                            let e = 5;
+                            let f = 6;
+                        }
+                    }
+
+                    fn after() { }]
+                "#},
+                // 35 tokens allows expansion to include function header and first two if blocks
+                editable_token_limit: 35,
+                // 60 tokens allows context to include the whole file
+                context_token_limit: 60,
+            },
+            TestCase {
+                name: "cursor in middle if-statement - expands to syntax boundaries",
+                marked_text: indoc! {r#"
+                    [fn before() { }
+
+                    fn process() {
+                        if condition1 {
+                            let a = 1;
+                    «        let b = 2;
+                        }
+                        if condition2 {
+                            let c = 3;ˇ
+                            let d = 4;
+                        }
+                        if condition3 {
+                            let e = 5;»
+                            let f = 6;
+                        }
+                    }
+
+                    fn after() { }]
+                "#},
+                // 40 tokens allows expansion to surrounding if blocks
+                editable_token_limit: 40,
+                // 60 tokens allows context to include the whole file
+                context_token_limit: 60,
+            },
+            TestCase {
+                name: "cursor near bottom of long function - editable expands toward syntax, context reaches function",
+                marked_text: indoc! {r#"
+                    [fn other() { }
+
+                    fn long_function() {
+                        let line1 = 1;
+                        let line2 = 2;
+                        let line3 = 3;
+                        let line4 = 4;
+                        let line5 = 5;
+                        let line6 = 6;
+                    «    let line7 = 7;
+                        let line8 = 8;
+                        let line9 = 9;
+                        let line10 = 10;ˇ
+                        let line11 = 11;
+                    }
+
+                    fn another() { }»]
+                "#},
+                // 40 tokens for editable - allows several lines plus syntax expansion
+                editable_token_limit: 40,
+                // 55 tokens - enough for function but not whole file
+                context_token_limit: 55,
+            },
+        ];
+
+        for test_case in test_cases {
+            let cursor_marker: TextRangeMarker = 'ˇ'.into();
+            let editable_marker: TextRangeMarker = ('«', '»').into();
+            let context_marker: TextRangeMarker = ('[', ']').into();
+
+            let (text, mut ranges) = marked_text_ranges_by(
+                test_case.marked_text,
+                vec![
+                    cursor_marker.clone(),
+                    editable_marker.clone(),
+                    context_marker.clone(),
+                ],
+            );
+
+            let cursor_ranges = ranges.remove(&cursor_marker).unwrap_or_default();
+            let expected_editable = ranges.remove(&editable_marker).unwrap_or_default();
+            let expected_context = ranges.remove(&context_marker).unwrap_or_default();
+            assert_eq!(expected_editable.len(), 1);
+            assert_eq!(expected_context.len(), 1);
+
+            cx.new(|cx| {
+                let text = text.trim_end_matches('\n');
+                let buffer = Buffer::local(text, cx).with_language(rust_lang(), cx);
+                let snapshot = buffer.snapshot();
+
+                let cursor_offset = cursor_ranges[0].start;
+                let cursor_point = snapshot.offset_to_point(cursor_offset);
+                let expected_editable_start = snapshot.offset_to_point(expected_editable[0].start);
+                let expected_editable_end = snapshot.offset_to_point(expected_editable[0].end);
+                let expected_context_start = snapshot.offset_to_point(expected_context[0].start);
+                let expected_context_end = snapshot.offset_to_point(expected_context[0].end);
+
+                let (actual_editable, actual_context) =
+                    editable_and_context_ranges_for_cursor_position(
+                        cursor_point,
+                        &snapshot,
+                        test_case.editable_token_limit,
+                        test_case.context_token_limit,
+                    );
+
+                let range_text = |start: Point, end: Point| -> String {
+                    snapshot.text_for_range(start..end).collect()
+                };
+
+                let editable_match = actual_editable.start == expected_editable_start
+                    && actual_editable.end == expected_editable_end;
+                let context_match = actual_context.start == expected_context_start
+                    && actual_context.end == expected_context_end;
+
+                if !editable_match || !context_match {
+                    println!("\n=== FAILED: {} ===", test_case.name);
+                    if !editable_match {
+                        println!(
+                            "\nExpected editable ({:?}..{:?}):",
+                            expected_editable_start, expected_editable_end
+                        );
+                        println!(
+                            "---\n{}---",
+                            range_text(expected_editable_start, expected_editable_end)
+                        );
+                        println!(
+                            "\nActual editable ({:?}..{:?}):",
+                            actual_editable.start, actual_editable.end
+                        );
+                        println!(
+                            "---\n{}---",
+                            range_text(actual_editable.start, actual_editable.end)
+                        );
+                    }
+                    if !context_match {
+                        println!(
+                            "\nExpected context ({:?}..{:?}):",
+                            expected_context_start, expected_context_end
+                        );
+                        println!(
+                            "---\n{}---",
+                            range_text(expected_context_start, expected_context_end)
+                        );
+                        println!(
+                            "\nActual context ({:?}..{:?}):",
+                            actual_context.start, actual_context.end
+                        );
+                        println!(
+                            "---\n{}---",
+                            range_text(actual_context.start, actual_context.end)
+                        );
+                    }
+                    panic!("Test '{}' failed - see output above", test_case.name);
+                }
+
+                buffer
+            });
+        }
+    }
+}

crates/edit_prediction/src/edit_prediction_tests.rs 🔗

@@ -12,12 +12,13 @@ use futures::{
     AsyncReadExt, StreamExt,
     channel::{mpsc, oneshot},
 };
+use gpui::App;
 use gpui::{
     Entity, TestAppContext,
     http_client::{FakeHttpClient, Response},
 };
 use indoc::indoc;
-use language::Point;
+use language::{Buffer, Point};
 use lsp::LanguageServerId;
 use parking_lot::Mutex;
 use pretty_assertions::{assert_eq, assert_matches};

crates/edit_prediction/src/zeta1.rs 🔗

@@ -610,20 +610,17 @@ mod tests {
         let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(language::rust_lang(), cx));
         let snapshot = buffer.read(cx).snapshot();
 
-        // Ensure we try to fit the largest possible syntax scope, resorting to line-based expansion
-        // when a larger scope doesn't fit the editable region.
+        // The excerpt expands to syntax boundaries.
+        // With 50 token editable limit, we get a region that expands to syntax nodes.
         let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 50, 32);
         assert_eq!(
             excerpt.prompt,
             indoc! {r#"
             ```main.rs
-                let x = 42;
-                println!("Hello, world!");
-            <|editable_region_start|>
-            }
 
             fn bar() {
                 let x = 42;
+            <|editable_region_start|>
                 let mut sum = 0;
                 for i in 0..x {
                     sum += i;
@@ -639,7 +636,7 @@ mod tests {
             ```"#}
         );
 
-        // The `bar` function won't fit within the editable region, so we resort to line-based expansion.
+        // With smaller budget, the region expands to syntax boundaries but is tighter.
         let excerpt = excerpt_for_cursor_position(Point::new(12, 5), "main.rs", &snapshot, 40, 32);
         assert_eq!(
             excerpt.prompt,
@@ -648,8 +645,8 @@ mod tests {
             fn bar() {
                 let x = 42;
                 let mut sum = 0;
-            <|editable_region_start|>
                 for i in 0..x {
+            <|editable_region_start|>
                     sum += i;
                 }
                 println!("Sum: {}", sum);
@@ -657,11 +654,8 @@ mod tests {
             }
 
             fn generate_random_numbers() -> Vec<i32> {
-                let mut rng = rand::thread_rng();
             <|editable_region_end|>
-                let mut numbers = Vec::new();
-                for _ in 0..5 {
-                    numbers.push(rng.random_range(1..101));
+                let mut rng = rand::thread_rng();
             ```"#}
         );
     }

crates/edit_prediction_cli/src/anthropic_client.rs 🔗

@@ -208,8 +208,9 @@ impl BatchingLlmClient {
         model: &str,
         max_tokens: u64,
         messages: &[Message],
+        seed: Option<usize>,
     ) -> Result<Option<AnthropicResponse>> {
-        let request_hash_str = Self::request_hash(model, max_tokens, messages);
+        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
         let connection = self.connection.lock().unwrap();
         let response: Vec<String> = connection.select_bound(
             &sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
@@ -220,8 +221,14 @@ impl BatchingLlmClient {
             .and_then(|text| serde_json::from_str(&text).ok()))
     }
 
-    pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> {
-        let request_hash = Self::request_hash(model, max_tokens, messages);
+    pub fn mark_for_batch(
+        &self,
+        model: &str,
+        max_tokens: u64,
+        messages: &[Message],
+        seed: Option<usize>,
+    ) -> Result<()> {
+        let request_hash = Self::request_hash(model, max_tokens, messages, seed);
 
         let serializable_messages: Vec<SerializableMessage> = messages
             .iter()
@@ -259,13 +266,14 @@ impl BatchingLlmClient {
         model: &str,
         max_tokens: u64,
         messages: Vec<Message>,
+        seed: Option<usize>,
     ) -> Result<Option<AnthropicResponse>> {
-        let response = self.lookup(model, max_tokens, &messages)?;
+        let response = self.lookup(model, max_tokens, &messages, seed)?;
         if let Some(response) = response {
             return Ok(Some(response));
         }
 
-        self.mark_for_batch(model, max_tokens, &messages)?;
+        self.mark_for_batch(model, max_tokens, &messages, seed)?;
 
         Ok(None)
     }
@@ -606,13 +614,21 @@ impl BatchingLlmClient {
         Ok(all_batch_ids)
     }
 
-    fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {
+    fn request_hash(
+        model: &str,
+        max_tokens: u64,
+        messages: &[Message],
+        seed: Option<usize>,
+    ) -> String {
         let mut hasher = std::hash::DefaultHasher::new();
         model.hash(&mut hasher);
         max_tokens.hash(&mut hasher);
         for msg in messages {
             message_content_to_string(&msg.content).hash(&mut hasher);
         }
+        if let Some(seed) = seed {
+            seed.hash(&mut hasher);
+        }
         let request_hash = hasher.finish();
         format!("{request_hash:016x}")
     }
@@ -655,6 +671,7 @@ impl AnthropicClient {
         model: &str,
         max_tokens: u64,
         messages: Vec<Message>,
+        seed: Option<usize>,
     ) -> Result<Option<AnthropicResponse>> {
         match self {
             AnthropicClient::Plain(plain_llm_client) => plain_llm_client
@@ -663,7 +680,7 @@ impl AnthropicClient {
                 .map(Some),
             AnthropicClient::Batch(batching_llm_client) => {
                 batching_llm_client
-                    .generate(model, max_tokens, messages)
+                    .generate(model, max_tokens, messages, seed)
                     .await
             }
             AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),

crates/edit_prediction_cli/src/main.rs 🔗

@@ -85,6 +85,10 @@ struct EpArgs {
     /// Failed examples are always logged to the run's failed directory.
     #[arg(long, global = true, default_value = "keep")]
     failed: FailedHandling,
+    /// Output as markdown files instead of JSONL. When set, -o specifies a directory
+    /// where one .md file per example will be written (named after each example).
+    #[arg(long, short, global = true)]
+    markdown: bool,
 }
 
 /// Controls whether failed examples are included in the main output.
@@ -431,6 +435,7 @@ async fn load_examples(
 ) -> anyhow::Result<Vec<Example>> {
     let mut captured_after_timestamps = Vec::new();
     let mut rejected_after_timestamps = Vec::new();
+    let mut requested_after_timestamps = Vec::new();
     let mut file_inputs = Vec::new();
 
     for input in &args.inputs {
@@ -441,6 +446,10 @@ async fn load_examples(
             pull_examples::parse_rejected_after_input(input_string.as_ref())
         {
             rejected_after_timestamps.push(timestamp.to_string());
+        } else if let Some(timestamp) =
+            pull_examples::parse_requested_after_input(input_string.as_ref())
+        {
+            requested_after_timestamps.push(timestamp.to_string());
         } else {
             file_inputs.push(input.clone());
         }
@@ -477,14 +486,27 @@ async fn load_examples(
             rejected_after_timestamps.sort();
 
             let mut rejected_examples = pull_examples::fetch_rejected_examples_after(
-                http_client,
+                http_client.clone(),
                 &rejected_after_timestamps,
                 max_rows_per_timestamp,
-                background_executor,
+                background_executor.clone(),
             )
             .await?;
             examples.append(&mut rejected_examples);
         }
+
+        if !requested_after_timestamps.is_empty() {
+            requested_after_timestamps.sort();
+
+            let mut requested_examples = pull_examples::fetch_requested_examples_after(
+                http_client,
+                &requested_after_timestamps,
+                max_rows_per_timestamp,
+                background_executor,
+            )
+            .await?;
+            examples.append(&mut requested_examples);
+        }
     }
 
     crate::example::sort_examples_by_repo_and_rev(&mut examples);
@@ -583,6 +605,12 @@ fn main() {
     }
 
     let output = args.output_path();
+
+    if args.markdown && output.is_none() {
+        eprintln!("--markdown requires -o to specify the output directory");
+        std::process::exit(1);
+    }
+
     let command = match &args.command {
         Some(cmd) => cmd.clone(),
         None => {
@@ -756,6 +784,18 @@ fn main() {
 
                 let failfast_on_single_example = examples.len() == 1;
 
+                // For --markdown mode, create the output directory if it doesn't exist
+                let markdown_output_dir = if args.markdown {
+                    let dir = output.as_ref().expect("--markdown requires -o");
+                    if !dir.exists() {
+                        std::fs::create_dir_all(dir)
+                            .expect("Failed to create markdown output directory");
+                    }
+                    Some(dir.clone())
+                } else {
+                    None
+                };
+
                 // For --in-place, write to a temp file and rename at the end to avoid data loss on interruption
                 let in_place_temp_path = if args.in_place {
                     output.as_ref().map(|path| {
@@ -767,40 +807,41 @@ fn main() {
                     None
                 };
 
-                let output_sender: Option<mpsc::UnboundedSender<String>> =
-                    if args.output.is_some() || !matches!(command, Command::Eval(_)) {
-                        let write_path = in_place_temp_path.as_ref().or(output.as_ref());
-                        write_path.map(|path| {
-                            let file = if args.in_place {
-                                // For --in-place, write to temp file (truncate if exists)
-                                OpenOptions::new()
-                                    .create(true)
-                                    .write(true)
-                                    .truncate(true)
-                                    .open(path)
-                                    .expect("Failed to open temp output file")
-                            } else {
-                                // For regular output, append to support resuming
-                                OpenOptions::new()
-                                    .create(true)
-                                    .append(true)
-                                    .open(path)
-                                    .expect("Failed to open output file")
-                            };
-                            let mut writer = BufWriter::new(file);
-                            let (sender, mut receiver) = mpsc::unbounded::<String>();
-                            cx.background_spawn(async move {
-                                while let Some(line) = receiver.next().await {
-                                    writeln!(writer, "{}", line).expect("Failed to write example");
-                                    writer.flush().expect("Failed to flush output");
-                                }
-                            })
-                            .detach();
-                            sender
+                let output_sender: Option<mpsc::UnboundedSender<String>> = if !args.markdown
+                    && (args.output.is_some() || !matches!(command, Command::Eval(_)))
+                {
+                    let write_path = in_place_temp_path.as_ref().or(output.as_ref());
+                    write_path.map(|path| {
+                        let file = if args.in_place {
+                            // For --in-place, write to temp file (truncate if exists)
+                            OpenOptions::new()
+                                .create(true)
+                                .write(true)
+                                .truncate(true)
+                                .open(path)
+                                .expect("Failed to open temp output file")
+                        } else {
+                            // For regular output, append to support resuming
+                            OpenOptions::new()
+                                .create(true)
+                                .append(true)
+                                .open(path)
+                                .expect("Failed to open output file")
+                        };
+                        let mut writer = BufWriter::new(file);
+                        let (sender, mut receiver) = mpsc::unbounded::<String>();
+                        cx.background_spawn(async move {
+                            while let Some(line) = receiver.next().await {
+                                writeln!(writer, "{}", line).expect("Failed to write example");
+                                writer.flush().expect("Failed to flush output");
+                            }
                         })
-                    } else {
-                        None
-                    };
+                        .detach();
+                        sender
+                    })
+                } else {
+                    None
+                };
 
                 let grouped_examples = Mutex::new(group_examples_by_repo(examples));
                 let finished_examples = Mutex::new(Vec::new());
@@ -917,7 +958,13 @@ fn main() {
 
                                 let should_write = !failed || args.failed == FailedHandling::Keep;
                                 if should_write {
-                                    if let Some(ref mut sender) = output_sender.clone() {
+                                    if let Some(ref markdown_dir) = markdown_output_dir {
+                                        let filename = format!("{}.md", example.spec.filename());
+                                        let path = markdown_dir.join(&filename);
+                                        let markdown = example.spec.to_markdown();
+                                        std::fs::write(&path, &markdown)
+                                            .expect("Failed to write markdown file");
+                                    } else if let Some(ref mut sender) = output_sender.clone() {
                                         let line = serde_json::to_string(&example).unwrap();
                                         sender
                                             .send(line)

crates/edit_prediction_cli/src/openai_client.rs 🔗

@@ -138,8 +138,9 @@ impl BatchingOpenAiClient {
         model: &str,
         max_tokens: u64,
         messages: &[RequestMessage],
+        seed: Option<usize>,
     ) -> Result<Option<OpenAiResponse>> {
-        let request_hash_str = Self::request_hash(model, max_tokens, messages);
+        let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
         let connection = self.connection.lock().unwrap();
         let response: Vec<String> = connection.select_bound(
             &sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;),
@@ -155,8 +156,9 @@ impl BatchingOpenAiClient {
         model: &str,
         max_tokens: u64,
         messages: &[RequestMessage],
+        seed: Option<usize>,
     ) -> Result<()> {
-        let request_hash = Self::request_hash(model, max_tokens, messages);
+        let request_hash = Self::request_hash(model, max_tokens, messages, seed);
 
         let serializable_messages: Vec<SerializableMessage> = messages
             .iter()
@@ -191,13 +193,14 @@ impl BatchingOpenAiClient {
         model: &str,
         max_tokens: u64,
         messages: Vec<RequestMessage>,
+        seed: Option<usize>,
     ) -> Result<Option<OpenAiResponse>> {
-        let response = self.lookup(model, max_tokens, &messages)?;
+        let response = self.lookup(model, max_tokens, &messages, seed)?;
         if let Some(response) = response {
             return Ok(Some(response));
         }
 
-        self.mark_for_batch(model, max_tokens, &messages)?;
+        self.mark_for_batch(model, max_tokens, &messages, seed)?;
 
         Ok(None)
     }
@@ -558,7 +561,12 @@ impl BatchingOpenAiClient {
         Ok(all_batch_ids)
     }
 
-    fn request_hash(model: &str, max_tokens: u64, messages: &[RequestMessage]) -> String {
+    fn request_hash(
+        model: &str,
+        max_tokens: u64,
+        messages: &[RequestMessage],
+        seed: Option<usize>,
+    ) -> String {
         let mut hasher = std::hash::DefaultHasher::new();
         "openai".hash(&mut hasher);
         model.hash(&mut hasher);
@@ -566,6 +574,9 @@ impl BatchingOpenAiClient {
         for msg in messages {
             message_content_to_string(msg).hash(&mut hasher);
         }
+        if let Some(seed) = seed {
+            seed.hash(&mut hasher);
+        }
         let request_hash = hasher.finish();
         format!("{request_hash:016x}")
     }
@@ -631,6 +642,7 @@ impl OpenAiClient {
         model: &str,
         max_tokens: u64,
         messages: Vec<RequestMessage>,
+        seed: Option<usize>,
     ) -> Result<Option<OpenAiResponse>> {
         match self {
             OpenAiClient::Plain(plain_client) => plain_client
@@ -638,7 +650,9 @@ impl OpenAiClient {
                 .await
                 .map(Some),
             OpenAiClient::Batch(batching_client) => {
-                batching_client.generate(model, max_tokens, messages).await
+                batching_client
+                    .generate(model, max_tokens, messages, seed)
+                    .await
             }
             OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
         }

crates/edit_prediction_cli/src/predict.rs 🔗

@@ -69,7 +69,7 @@ pub async fn run_prediction(
         .await?;
 
         let batched = matches!(provider, PredictionProvider::Teacher(..));
-        return predict_teacher(example, backend, batched).await;
+        return predict_teacher(example, backend, batched, repetition_count).await;
     }
 
     run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
@@ -261,10 +261,13 @@ async fn predict_teacher(
     example: &mut Example,
     backend: TeacherBackend,
     batched: bool,
+    repetition_count: usize,
 ) -> anyhow::Result<()> {
     match backend {
-        TeacherBackend::Sonnet45 => predict_anthropic(example, backend, batched).await,
-        TeacherBackend::Gpt52 => predict_openai(example, backend, batched).await,
+        TeacherBackend::Sonnet45 => {
+            predict_anthropic(example, backend, batched, repetition_count).await
+        }
+        TeacherBackend::Gpt52 => predict_openai(example, backend, batched, repetition_count).await,
     }
 }
 
@@ -272,6 +275,7 @@ async fn predict_anthropic(
     example: &mut Example,
     backend: TeacherBackend,
     batched: bool,
+    repetition_count: usize,
 ) -> anyhow::Result<()> {
     let llm_model_name = backend.model_name();
     let max_tokens = 16384;
@@ -286,46 +290,49 @@ async fn predict_anthropic(
 
     let prompt = example.prompt.as_ref().context("Prompt is required")?;
 
-    let messages = vec![anthropic::Message {
-        role: anthropic::Role::User,
-        content: vec![anthropic::RequestContent::Text {
-            text: prompt.input.clone(),
-            cache_control: None,
-        }],
-    }];
-
-    let Some(response) = llm_client
-        .generate(llm_model_name, max_tokens, messages)
-        .await?
-    else {
-        // Request stashed for batched processing
-        return Ok(());
-    };
+    for ix in 0..repetition_count {
+        let messages = vec![anthropic::Message {
+            role: anthropic::Role::User,
+            content: vec![anthropic::RequestContent::Text {
+                text: prompt.input.clone(),
+                cache_control: None,
+            }],
+        }];
+
+        let seed = if repetition_count > 1 { Some(ix) } else { None };
+        let Some(response) = llm_client
+            .generate(llm_model_name, max_tokens, messages, seed)
+            .await?
+        else {
+            // Request stashed for batched processing
+            return Ok(());
+        };
 
-    let actual_output = response
-        .content
-        .into_iter()
-        .filter_map(|content| match content {
-            anthropic::ResponseContent::Text { text } => Some(text),
-            _ => None,
-        })
-        .collect::<Vec<String>>()
-        .join("\n");
-
-    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
-
-    let prediction = ExamplePrediction {
-        actual_patch: Some(actual_patch),
-        actual_output,
-        error: None,
-        provider: if batched {
-            PredictionProvider::Teacher(backend)
-        } else {
-            PredictionProvider::TeacherNonBatching(backend)
-        },
-    };
+        let actual_output = response
+            .content
+            .into_iter()
+            .filter_map(|content| match content {
+                anthropic::ResponseContent::Text { text } => Some(text),
+                _ => None,
+            })
+            .collect::<Vec<String>>()
+            .join("\n");
+
+        let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
+
+        let prediction = ExamplePrediction {
+            actual_patch: Some(actual_patch),
+            actual_output,
+            error: None,
+            provider: if batched {
+                PredictionProvider::Teacher(backend)
+            } else {
+                PredictionProvider::TeacherNonBatching(backend)
+            },
+        };
 
-    example.predictions.push(prediction);
+        example.predictions.push(prediction);
+    }
     Ok(())
 }
 
@@ -333,6 +340,7 @@ async fn predict_openai(
     example: &mut Example,
     backend: TeacherBackend,
     batched: bool,
+    repetition_count: usize,
 ) -> anyhow::Result<()> {
     let llm_model_name = backend.model_name();
     let max_tokens = 16384;
@@ -347,52 +355,55 @@ async fn predict_openai(
 
     let prompt = example.prompt.as_ref().context("Prompt is required")?;
 
-    let messages = vec![open_ai::RequestMessage::User {
-        content: open_ai::MessageContent::Plain(prompt.input.clone()),
-    }];
+    for ix in 0..repetition_count {
+        let messages = vec![open_ai::RequestMessage::User {
+            content: open_ai::MessageContent::Plain(prompt.input.clone()),
+        }];
+
+        let seed = if repetition_count > 1 { Some(ix) } else { None };
+        let Some(response) = llm_client
+            .generate(llm_model_name, max_tokens, messages, seed)
+            .await?
+        else {
+            // Request stashed for batched processing
+            return Ok(());
+        };
 
-    let Some(response) = llm_client
-        .generate(llm_model_name, max_tokens, messages)
-        .await?
-    else {
-        // Request stashed for batched processing
-        return Ok(());
-    };
+        let actual_output = response
+            .choices
+            .into_iter()
+            .filter_map(|choice| match choice.message {
+                open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
+                    open_ai::MessageContent::Plain(text) => text,
+                    open_ai::MessageContent::Multipart(parts) => parts
+                        .into_iter()
+                        .filter_map(|p| match p {
+                            open_ai::MessagePart::Text { text } => Some(text),
+                            _ => None,
+                        })
+                        .collect::<Vec<_>>()
+                        .join(""),
+                }),
+                _ => None,
+            })
+            .collect::<Vec<String>>()
+            .join("\n");
 
-    let actual_output = response
-        .choices
-        .into_iter()
-        .filter_map(|choice| match choice.message {
-            open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
-                open_ai::MessageContent::Plain(text) => text,
-                open_ai::MessageContent::Multipart(parts) => parts
-                    .into_iter()
-                    .filter_map(|p| match p {
-                        open_ai::MessagePart::Text { text } => Some(text),
-                        _ => None,
-                    })
-                    .collect::<Vec<_>>()
-                    .join(""),
-            }),
-            _ => None,
-        })
-        .collect::<Vec<String>>()
-        .join("\n");
-
-    let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
-
-    let prediction = ExamplePrediction {
-        actual_patch: Some(actual_patch),
-        actual_output,
-        error: None,
-        provider: if batched {
-            PredictionProvider::Teacher(backend)
-        } else {
-            PredictionProvider::TeacherNonBatching(backend)
-        },
-    };
+        let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
+
+        let prediction = ExamplePrediction {
+            actual_patch: Some(actual_patch),
+            actual_output,
+            error: None,
+            provider: if batched {
+                PredictionProvider::Teacher(backend)
+            } else {
+                PredictionProvider::TeacherNonBatching(backend)
+            },
+        };
 
-    example.predictions.push(prediction);
+        example.predictions.push(prediction);
+    }
     Ok(())
 }
 

crates/edit_prediction_cli/src/prompts/teacher.md 🔗

@@ -17,8 +17,10 @@ You are an edit prediction assistant in a code editor. Your task is to predict t
 - Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
 - Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
 - Keep existing formatting unless it's absolutely necessary
-- Don't write a lot of code if you're not sure what to do
+- When edit history and surrounding code suggest different edits, prioritize the most recent edits in the history as they best reflect current intent.
+- When uncertain, predict only the minimal, high-confidence portion of the edit. Prefer a small, correct prediction over a large, speculative one
 - Do not delete or remove text that was just added in the edit history. If a recent edit introduces incomplete or incorrect code, finish or fix it in place, or simply do nothing rather than removing it. Only remove a recent edit if the history explicitly shows the user undoing it themselves.
+- Treat partial text at or near the cursor as the beginning of something the user is actively typing. Complete the code the user appears to be creating based on context.
 
 # Input Format
 
@@ -33,20 +35,41 @@ You will be provided with:
 # Output Format
 
 - Briefly explain the user's current intent based on the edit history and their current cursor location.
-- Output the entire editable region, applying the edits that you predict the user will make next.
-- If you're unsure some portion of the next edit, you may still predict the surrounding code (such as a function definition, `for` loop, etc) and place the `<|user_cursor|>` within it for the user to fill in.
-- Wrap the edited code in a codeblock with exactly five backticks.
+- Output a markdown codeblock containing **only** the editable region with your predicted edits applied. The codeblock must start with `<|editable_region_start|>` and end with `<|editable_region_end|>`. Do not include any content before or after these tags.
+- If the next edit has some uncertainty, you may still predict the surrounding code (such as a function definition, `for` loop, etc) and place the `<|user_cursor|>` within it for the user to fill in.
+  -e.g. if a user is typing `func<|user_cursor|>`, but you don't know what the function name should be, you can predict `function <|user_cursor|>() {}`
 
-## Example
+## Example 1
 
-### Input
+There is code missing at the cursor location. The related excerpts includes the definition of a relevant type. You should fill in the missing code.
+
+### Related Excerpts
 
 `````
 struct Product {
     name: String,
     price: u32,
 }
+`````
+
+### User Edit History
 
+`````
+--- a/src/calculate.rs
++++ b/src/calculate.rs
+@@ -100,6 +100,7 @@
+ fn calculate_total(products: &[Product]) -> u32 {
+     let mut total = 0;
+     for product in products {
++        total += ;
+     }
+     total
+ }
+`````
+
+### Current File
+
+`````src/calculate.rs
 fn calculate_total(products: &[Product]) -> u32 {
 <|editable_region_start|>
     let mut total = 0;
@@ -63,14 +86,60 @@ fn calculate_total(products: &[Product]) -> u32 {
 The user is computing a sum based on a list of products. The only numeric field on `Product` is `price`, so they must intend to sum the prices.
 
 `````
+<|editable_region_start|>
     let mut total = 0;
     for product in products {
         total += product.price;
     }
     total
+<|editable_region_end|>
+`````
+
+## Example 2
+
+The user appears to be in the process of typing an eprintln call. Rather than fixing the spelling issue by deleting the newly-inserted content, you must continue the user's trajectory. It's not clear what data they intend to print. You should fill in as much code as is obviously intended, and position the cursor so that the user can fill in the rest.
+
+### User Edit History
+
+`````
+--- a/src/modal.rs
++++ b/src/modal.rs
+@@ -100,4 +100,4 @@
+ fn handle_close_button_click(modal_state: &mut ModalState, evt: &Event) {
+     modal_state.close();
+-     modal_state.dismiss();
++     eprmodal_state.dismiss();
+ }
+`````
+
+### Current File
+
+`````src/modal.rs
+// handle the close button click
+<|editable_region_start|>
+fn handle_close_button_click(modal_state: &mut ModalState, evt: &Event) {
+    modal_state.close();
+    epr<|user_cursor|>modal_state.dismiss();
+<|editable_region_end|>
+}
+`````
+
+### Output
+
+The user is clearly starting to type `eprintln!()`, however, what they intend to print is not obvious. I should fill in the print call and string literal, with the cursor positioned inside the string literal so the user can print whatever they want.
+
+`````
+<|editable_region_start|>
+fn handle_close_button_click(modal_state: &mut ModalState, evt: &Event) {
+    modal_state.close();
+    eprintln!("<|user_cursor|>");
+<|editable_region_end|>
 `````
 
-# 1. User Edits History
+
+# Your task:
+
+# 1. User Edit History
 
 `````
 {{edit_history}}
@@ -83,3 +152,10 @@ The user is computing a sum based on a list of products. The only numeric field
 # 3. Current File
 
 {{cursor_excerpt}}
+
+
+
+
+-----
+
+Based on the edit history and context above, predict the user's next edit within the editable region.

crates/edit_prediction_cli/src/pull_examples.rs 🔗

@@ -39,6 +39,11 @@ pub fn parse_rejected_after_input(input: &str) -> Option<&str> {
     input.strip_prefix("rejected-after:")
 }
 
+/// Parse an input token of the form `requested-after:{timestamp}`.
+pub fn parse_requested_after_input(input: &str) -> Option<&str> {
+    input.strip_prefix("requested-after:")
+}
+
 pub async fn fetch_captured_examples_after(
     http_client: Arc<dyn HttpClient>,
     after_timestamps: &[String],
@@ -556,6 +561,204 @@ pub async fn fetch_rejected_examples_after(
     Ok(all_examples)
 }
 
+pub async fn fetch_requested_examples_after(
+    http_client: Arc<dyn HttpClient>,
+    after_timestamps: &[String],
+    max_rows_per_timestamp: usize,
+    background_executor: BackgroundExecutor,
+) -> Result<Vec<Example>> {
+    if after_timestamps.is_empty() {
+        return Ok(Vec::new());
+    }
+
+    let progress = Progress::global();
+
+    let token = std::env::var("EP_SNOWFLAKE_API_KEY")
+        .context("missing required environment variable EP_SNOWFLAKE_API_KEY")?;
+    let base_url = std::env::var("EP_SNOWFLAKE_BASE_URL").context(
+        "missing required environment variable EP_SNOWFLAKE_BASE_URL (e.g. https://<account>.snowflakecomputing.com)",
+    )?;
+    let role = std::env::var("EP_SNOWFLAKE_ROLE").ok();
+
+    let mut all_examples = Vec::new();
+
+    for after_date in after_timestamps.iter() {
+        let step_progress_name = format!("requested>{after_date}");
+        let step_progress = progress.start(Step::PullExamples, &step_progress_name);
+        step_progress.set_substatus("querying");
+
+        let statement = indoc! {r#"
+            SELECT
+                req.event_properties:request_id::string AS request_id,
+                req.device_id::string AS device_id,
+                req.time::string AS time,
+                req.event_properties:input AS input
+            FROM events req
+            WHERE req.event_type = ?
+                AND req.event_properties:version = 'V3'
+                AND req.time > TRY_TO_TIMESTAMP_NTZ(?)
+            ORDER BY req.time ASC
+            LIMIT ?
+        "#};
+
+        let request = json!({
+            "statement": statement,
+            "timeout": DEFAULT_STATEMENT_TIMEOUT_SECONDS,
+            "database": "EVENTS",
+            "schema": "PUBLIC",
+            "warehouse": "DBT",
+            "role": role,
+            "bindings": {
+                "1": { "type": "TEXT", "value": PREDICTIVE_EDIT_REQUESTED_EVENT },
+                "2": { "type": "TEXT", "value": after_date },
+                "3": { "type": "FIXED", "value": max_rows_per_timestamp.to_string() }
+            }
+        });
+
+        let response = run_sql_with_polling(
+            http_client.clone(),
+            &base_url,
+            &token,
+            &request,
+            &step_progress,
+            background_executor.clone(),
+        )
+        .await?;
+
+        let total_rows = response
+            .result_set_meta_data
+            .as_ref()
+            .and_then(|m| m.num_rows)
+            .unwrap_or(response.data.len() as i64);
+
+        let num_partitions = response
+            .result_set_meta_data
+            .as_ref()
+            .map(|m| m.partition_info.len())
+            .unwrap_or(1)
+            .max(1);
+
+        step_progress.set_info(format!("{} rows", total_rows), InfoStyle::Normal);
+        step_progress.set_substatus("parsing");
+
+        let column_indices = get_column_indices(
+            &response.result_set_meta_data,
+            &["request_id", "device_id", "time", "input"],
+        );
+
+        all_examples.extend(requested_examples_from_response(
+            &response,
+            &column_indices,
+        )?);
+
+        if num_partitions > 1 {
+            let statement_handle = response
+                .statement_handle
+                .as_ref()
+                .context("response has multiple partitions but no statementHandle")?;
+
+            for partition in 1..num_partitions {
+                step_progress.set_substatus(format!(
+                    "fetching partition {}/{}",
+                    partition + 1,
+                    num_partitions
+                ));
+
+                let partition_response = fetch_partition(
+                    http_client.clone(),
+                    &base_url,
+                    &token,
+                    statement_handle,
+                    partition,
+                )
+                .await?;
+
+                all_examples.extend(requested_examples_from_response(
+                    &partition_response,
+                    &column_indices,
+                )?);
+            }
+        }
+
+        step_progress.set_substatus("done");
+    }
+
+    Ok(all_examples)
+}
+
+fn requested_examples_from_response<'a>(
+    response: &'a SnowflakeStatementResponse,
+    column_indices: &'a std::collections::HashMap<String, usize>,
+) -> Result<impl Iterator<Item = Example> + 'a> {
+    if let Some(code) = &response.code {
+        if code != SNOWFLAKE_SUCCESS_CODE {
+            anyhow::bail!(
+                "snowflake sql api returned error code={code} message={}",
+                response.message.as_deref().unwrap_or("<no message>")
+            );
+        }
+    }
+
+    let iter = response
+        .data
+        .iter()
+        .enumerate()
+        .filter_map(move |(row_index, data_row)| {
+            let get_string = |name: &str| -> Option<String> {
+                let index = column_indices.get(name).copied()?;
+                match data_row.get(index)? {
+                    JsonValue::String(s) => Some(s.clone()),
+                    JsonValue::Null => None,
+                    other => Some(other.to_string()),
+                }
+            };
+
+            let get_json = |name: &str| -> Option<JsonValue> {
+                let index = column_indices.get(name).copied()?;
+                let value = data_row.get(index)?;
+                if value.is_null() {
+                    return None;
+                }
+                match value {
+                    JsonValue::String(s) => serde_json::from_str(s).ok(),
+                    other => Some(other.clone()),
+                }
+            };
+
+            let request_id_str = get_string("request_id");
+            let device_id = get_string("device_id");
+            let time = get_string("time");
+            let input_json = get_json("input");
+            let input: Option<ZetaPromptInput> =
+                input_json.clone().and_then(|v| serde_json::from_value(v).ok());
+
+            match (request_id_str.clone(), device_id.clone(), time.clone(), input) {
+                (Some(request_id), Some(device_id), Some(time), Some(input)) => {
+                    Some(build_example_from_snowflake(
+                        request_id,
+                        device_id,
+                        time,
+                        input,
+                        vec!["requested".to_string()],
+                        None,
+                    ))
+                }
+                _ => {
+                    log::warn!(
+                        "skipping row {row_index}: missing fields - request_id={:?} device_id={:?} time={:?} input={:?}",
+                        request_id_str.is_some(),
+                        device_id.is_some(),
+                        time.is_some(),
+                        input_json.is_some(),
+                    );
+                    None
+                }
+            }
+        });
+
+    Ok(iter)
+}
+
 fn rejected_examples_from_response(
     response: &SnowflakeStatementResponse,
 ) -> Result<impl Iterator<Item = Example> + '_> {
@@ -666,6 +869,37 @@ fn build_rejected_example(
     output: String,
     was_shown: bool,
     reason: String,
+) -> Example {
+    let rejected_patch = build_output_patch(
+        &input.cursor_path,
+        input.cursor_excerpt.as_ref(),
+        &input.editable_range_in_excerpt,
+        &output,
+    );
+    let mut example = build_example_from_snowflake(
+        request_id,
+        device_id,
+        time,
+        input,
+        vec![format!("rejection:{}", reason.to_lowercase())],
+        Some(RejectionInfo { reason, was_shown }),
+    );
+    example.spec.rejected_patch = Some(rejected_patch);
+    example
+}
+
+struct RejectionInfo {
+    reason: String,
+    was_shown: bool,
+}
+
+fn build_example_from_snowflake(
+    request_id: String,
+    device_id: String,
+    time: String,
+    input: ZetaPromptInput,
+    tags: Vec<String>,
+    rejection: Option<RejectionInfo>,
 ) -> Example {
     let events: Vec<CapturedEvent> = input
         .events
@@ -715,25 +949,23 @@ fn build_rejected_example(
         edit_history.push('\n');
     }
 
-    let rejected_patch = build_rejected_patch(
-        &input.cursor_path,
-        cursor_excerpt,
-        &input.editable_range_in_excerpt,
-        &output,
-    );
+    let (rejection_reason, was_shown) = match &rejection {
+        Some(r) => (r.reason.clone(), r.was_shown),
+        None => (String::new(), false),
+    };
 
     let spec = ExampleSpec {
         name: request_id.clone(),
         repository_url: String::new(),
         revision: String::new(),
-        tags: vec![format!("rejection:{}", reason.to_lowercase())],
+        tags,
         reasoning: None,
         uncommitted_diff: String::new(),
         cursor_path: input.cursor_path.clone(),
         cursor_position: build_cursor_position(cursor_excerpt, cursor_offset),
         edit_history,
         expected_patches: Vec::new(),
-        rejected_patch: Some(rejected_patch),
+        rejected_patch: None,
         captured_prompt_input: Some(CapturedPromptInput {
             cursor_file_content: cursor_excerpt.to_string(),
             cursor_offset,
@@ -746,7 +978,7 @@ fn build_rejected_example(
             request_id,
             device_id,
             time,
-            rejection_reason: reason,
+            rejection_reason,
             was_shown,
         }),
     };
@@ -784,7 +1016,7 @@ fn build_cursor_position(excerpt: &str, cursor_offset: usize) -> String {
     format!("{}[CURSOR_POSITION]{}", before, after)
 }
 
-fn build_rejected_patch(
+fn build_output_patch(
     cursor_path: &std::path::Path,
     cursor_excerpt: &str,
     editable_range: &std::ops::Range<usize>,

crates/edit_prediction_cli/src/qa.rs 🔗

@@ -172,7 +172,7 @@ impl QaClient {
                         cache_control: None,
                     }],
                 }];
-                let response = client.generate(model, max_tokens, messages).await?;
+                let response = client.generate(model, max_tokens, messages, None).await?;
                 Ok(response.map(|r| {
                     r.content
                         .iter()
@@ -188,7 +188,7 @@ impl QaClient {
                 let messages = vec![open_ai::RequestMessage::User {
                     content: open_ai::MessageContent::Plain(prompt.to_string()),
                 }];
-                let response = client.generate(model, max_tokens, messages).await?;
+                let response = client.generate(model, max_tokens, messages, None).await?;
                 Ok(response.map(|r| {
                     r.choices
                         .into_iter()

crates/edit_prediction_cli/src/repair.rs 🔗

@@ -152,7 +152,7 @@ impl RepairClient {
                         cache_control: None,
                     }],
                 }];
-                let response = client.generate(model, max_tokens, messages).await?;
+                let response = client.generate(model, max_tokens, messages, None).await?;
                 Ok(response.map(|r| {
                     r.content
                         .iter()
@@ -168,7 +168,7 @@ impl RepairClient {
                 let messages = vec![open_ai::RequestMessage::User {
                     content: open_ai::MessageContent::Plain(prompt.to_string()),
                 }];
-                let response = client.generate(model, max_tokens, messages).await?;
+                let response = client.generate(model, max_tokens, messages, None).await?;
                 Ok(response.map(|r| {
                     r.choices
                         .into_iter()