Fix EP CLI issues found when generating new teacher predictions (#49327)

Max Brunsfeld created

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/example_spec.rs          |   2 
crates/edit_prediction_cli/src/anthropic_client.rs  | 165 +++++++++++---
crates/edit_prediction_cli/src/reversal_tracking.rs |  73 ++----
crates/language/src/language.rs                     |   2 
crates/language/src/text_diff.rs                    |  44 ++++
5 files changed, 196 insertions(+), 90 deletions(-)

Detailed changes

crates/edit_prediction/src/example_spec.rs πŸ”—

@@ -66,6 +66,7 @@ pub struct CapturedPromptInput {
     pub excerpt_start_row: Option<u32>,
     pub events: Vec<CapturedEvent>,
     pub related_files: Vec<CapturedRelatedFile>,
+    #[serde(default)]
     pub in_open_source_repo: bool,
 }
 
@@ -75,6 +76,7 @@ pub struct CapturedEvent {
     pub old_path: Arc<Path>,
     pub diff: String,
     pub predicted: bool,
+    #[serde(default)]
     pub in_open_source_repo: bool,
 }
 

crates/edit_prediction_cli/src/anthropic_client.rs πŸ”—

@@ -511,9 +511,12 @@ impl BatchingLlmClient {
 
     async fn upload_pending_requests(&self) -> Result<Vec<String>> {
         const BATCH_CHUNK_SIZE: i32 = 16_000;
+        const MAX_BATCH_SIZE_BYTES: usize = 200 * 1024 * 1024; // 200MB (buffer below 256MB limit)
         let mut all_batch_ids = Vec::new();
         let mut total_uploaded = 0;
 
+        let mut current_batch_rows = Vec::new();
+        let mut current_batch_size = 0usize;
         loop {
             let rows: Vec<(String, String)> = {
                 let connection = self.connection.lock().unwrap();
@@ -529,52 +532,130 @@ impl BatchingLlmClient {
                 break;
             }
 
-            let request_hashes: Vec<String> = rows.iter().map(|(hash, _)| hash.clone()).collect();
+            // Split rows into sub-batches based on size
+            let mut batches_to_upload = Vec::new();
+
+            for row in rows {
+                let (hash, request_str) = row;
+                let serializable_request: SerializableRequest = serde_json::from_str(&request_str)?;
+
+                let messages: Vec<Message> = serializable_request
+                    .messages
+                    .into_iter()
+                    .map(|msg| Message {
+                        role: match msg.role.as_str() {
+                            "user" => Role::User,
+                            "assistant" => Role::Assistant,
+                            _ => Role::User,
+                        },
+                        content: vec![RequestContent::Text {
+                            text: msg.content,
+                            cache_control: None,
+                        }],
+                    })
+                    .collect();
+
+                let params = AnthropicRequest {
+                    model: serializable_request.model,
+                    max_tokens: serializable_request.max_tokens,
+                    messages,
+                    tools: Vec::new(),
+                    thinking: None,
+                    tool_choice: None,
+                    system: None,
+                    metadata: None,
+                    output_config: None,
+                    stop_sequences: Vec::new(),
+                    temperature: None,
+                    top_k: None,
+                    top_p: None,
+                };
+
+                let custom_id = format!("req_hash_{}", hash);
+                let batch_request = anthropic::batches::BatchRequest { custom_id, params };
+
+                // Estimate the serialized size of this request
+                let estimated_size = serde_json::to_string(&batch_request)?.len();
+
+                // If adding this request would exceed the limit, start a new batch
+                if !current_batch_rows.is_empty()
+                    && current_batch_size + estimated_size > MAX_BATCH_SIZE_BYTES
+                {
+                    batches_to_upload.push((current_batch_rows, current_batch_size));
+                    current_batch_rows = Vec::new();
+                    current_batch_size = 0;
+                }
+
+                current_batch_rows.push((hash, batch_request));
+                current_batch_size += estimated_size;
+            }
 
-            let batch_requests = rows
+            // Only upload full batches, keep the partial batch for the next iteration
+            // Upload each sub-batch
+            for (batch_rows, batch_size) in batches_to_upload {
+                let request_hashes: Vec<String> =
+                    batch_rows.iter().map(|(hash, _)| hash.clone()).collect();
+                let batch_requests: Vec<anthropic::batches::BatchRequest> =
+                    batch_rows.into_iter().map(|(_, req)| req).collect();
+
+                let batch_len = batch_requests.len();
+                log::info!(
+                    "Uploading batch with {} requests (~{:.2} MB)",
+                    batch_len,
+                    batch_size as f64 / (1024.0 * 1024.0)
+                );
+
+                let batch = anthropic::batches::create_batch(
+                    self.http_client.as_ref(),
+                    ANTHROPIC_API_URL,
+                    &self.api_key,
+                    anthropic::batches::CreateBatchRequest {
+                        requests: batch_requests,
+                    },
+                )
+                .await
+                .map_err(|e| anyhow::anyhow!("{:?}", e))?;
+
+                {
+                    let connection = self.connection.lock().unwrap();
+                    connection.with_savepoint("batch_upload", || {
+                        let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?);
+                        let mut exec = connection.exec_bound::<(&str, &str)>(q)?;
+                        for hash in &request_hashes {
+                            exec((batch.id.as_str(), hash.as_str()))?;
+                        }
+                        Ok(())
+                    })?;
+                }
+
+                total_uploaded += batch_len;
+                log::info!(
+                    "Uploaded batch {} with {} requests ({} total)",
+                    batch.id,
+                    batch_len,
+                    total_uploaded
+                );
+
+                all_batch_ids.push(batch.id);
+            }
+        }
+
+        // Upload any remaining partial batch at the end
+        if !current_batch_rows.is_empty() {
+            let request_hashes: Vec<String> = current_batch_rows
                 .iter()
-                .map(|(hash, request_str)| {
-                    let serializable_request: SerializableRequest =
-                        serde_json::from_str(&request_str).unwrap();
-
-                    let messages: Vec<Message> = serializable_request
-                        .messages
-                        .into_iter()
-                        .map(|msg| Message {
-                            role: match msg.role.as_str() {
-                                "user" => Role::User,
-                                "assistant" => Role::Assistant,
-                                _ => Role::User,
-                            },
-                            content: vec![RequestContent::Text {
-                                text: msg.content,
-                                cache_control: None,
-                            }],
-                        })
-                        .collect();
-
-                    let params = AnthropicRequest {
-                        model: serializable_request.model,
-                        max_tokens: serializable_request.max_tokens,
-                        messages,
-                        tools: Vec::new(),
-                        thinking: None,
-                        tool_choice: None,
-                        system: None,
-                        metadata: None,
-                        output_config: None,
-                        stop_sequences: Vec::new(),
-                        temperature: None,
-                        top_k: None,
-                        top_p: None,
-                    };
-
-                    let custom_id = format!("req_hash_{}", hash);
-                    anthropic::batches::BatchRequest { custom_id, params }
-                })
-                .collect::<Vec<_>>();
+                .map(|(hash, _)| hash.clone())
+                .collect();
+            let batch_requests: Vec<anthropic::batches::BatchRequest> =
+                current_batch_rows.into_iter().map(|(_, req)| req).collect();
 
             let batch_len = batch_requests.len();
+            log::info!(
+                "Uploading final batch with {} requests (~{:.2} MB)",
+                batch_len,
+                current_batch_size as f64 / (1024.0 * 1024.0)
+            );
+
             let batch = anthropic::batches::create_batch(
                 self.http_client.as_ref(),
                 ANTHROPIC_API_URL,

crates/edit_prediction_cli/src/reversal_tracking.rs πŸ”—

@@ -3,7 +3,7 @@ use std::path::Path;
 use std::sync::Arc;
 
 use edit_prediction::udiff::apply_diff_to_string;
-use language::text_diff;
+use language::{char_diff, text_diff};
 
 use crate::example::ExamplePromptInputs;
 
@@ -417,17 +417,6 @@ impl ReversalOverlap {
     }
 }
 
-/// Check if `needle` is a subsequence of `haystack` (characters appear in order, not necessarily contiguous).
-fn is_subsequence(needle: &str, haystack: &str) -> bool {
-    let mut needle_chars = needle.chars().peekable();
-    for c in haystack.chars() {
-        if needle_chars.peek() == Some(&c) {
-            needle_chars.next();
-        }
-    }
-    needle_chars.peek().is_none()
-}
-
 /// Normalize edits where `old_text` appears as a subsequence within `new_text` (extension),
 /// or where `new_text` appears as a subsequence within `old_text` (reduction).
 ///
@@ -442,31 +431,35 @@ fn is_subsequence(needle: &str, haystack: &str) -> bool {
 fn normalize_extension_edits(edits: Vec<GranularEdit>) -> Vec<GranularEdit> {
     edits
         .into_iter()
-        .map(|edit| {
+        .flat_map(|edit| {
             if edit.old_text.is_empty() || edit.new_text.is_empty() {
-                return edit;
+                return vec![edit];
             }
 
-            if is_subsequence(&edit.old_text, &edit.new_text) {
-                let inserted_char_count =
-                    edit.new_text.chars().count() - edit.old_text.chars().count();
-                GranularEdit {
-                    range: edit.range.start..edit.range.start,
-                    old_text: String::new(),
-                    new_text: edit.new_text.chars().take(inserted_char_count).collect(),
-                }
-            } else if is_subsequence(&edit.new_text, &edit.old_text) {
-                let deleted_char_count =
-                    edit.old_text.chars().count() - edit.new_text.chars().count();
-                let deleted_text: String = edit.old_text.chars().take(deleted_char_count).collect();
-                GranularEdit {
-                    range: edit.range.start..edit.range.start + deleted_text.len(),
-                    old_text: deleted_text,
-                    new_text: String::new(),
-                }
-            } else {
-                edit
+            // Use character-wise diff to find exact byte ranges of changes
+            let char_edits = char_diff(&edit.old_text, &edit.new_text);
+
+            let all_deletions = !char_edits.is_empty()
+                && char_edits
+                    .iter()
+                    .all(|(range, replacement)| !range.is_empty() && replacement.is_empty());
+            let all_insertions = !char_edits.is_empty()
+                && char_edits
+                    .iter()
+                    .all(|(range, replacement)| range.is_empty() && !replacement.is_empty());
+            if all_deletions || all_insertions {
+                return char_edits
+                    .into_iter()
+                    .map(|(range, replacement)| GranularEdit {
+                        range: edit.range.start + range.start..edit.range.start + range.end,
+                        old_text: edit.old_text[range].to_string(),
+                        new_text: replacement.to_string(),
+                    })
+                    .collect();
             }
+
+            // Otherwise, keep the original edit (mixed changes)
+            vec![edit]
         })
         .collect()
 }
@@ -1718,20 +1711,6 @@ mod tests {
         }
     }
 
-    #[test]
-    fn test_is_subsequence() {
-        assert!(is_subsequence("", "anything"));
-        assert!(is_subsequence("", ""));
-        assert!(is_subsequence("abc", "abc"));
-        assert!(is_subsequence("abc", "aXbXc"));
-        assert!(is_subsequence("ac", "abc"));
-        assert!(!is_subsequence("abc", "ab"));
-        assert!(!is_subsequence("abc", "cba"));
-        assert!(!is_subsequence("abc", ""));
-        assert!(is_subsequence("ζ—₯本", "ζ—₯X本Yθͺž"));
-        assert!(!is_subsequence("ζ—₯本θͺž", "ζ—₯本"));
-    }
-
     #[test]
     fn test_compute_lcs_length() {
         assert_eq!(compute_lcs_length("", ""), 0);

crates/language/src/language.rs πŸ”—

@@ -66,7 +66,7 @@ use syntax_map::{QueryCursorHandle, SyntaxSnapshot};
 use task::RunnableTag;
 pub use task_context::{ContextLocation, ContextProvider, RunnableRange};
 pub use text_diff::{
-    DiffOptions, apply_diff_patch, apply_reversed_diff_patch, line_diff, text_diff,
+    DiffOptions, apply_diff_patch, apply_reversed_diff_patch, char_diff, line_diff, text_diff,
     text_diff_with_options, unified_diff, unified_diff_with_context, unified_diff_with_offsets,
     word_diff_ranges,
 };

crates/language/src/text_diff.rs πŸ”—

@@ -219,6 +219,25 @@ pub fn word_diff_ranges(
     (old_ranges, new_ranges)
 }
 
+/// Computes character-level diff between two strings.
+///
+/// Usually, you should use `text_diff`, which performs a word-wise diff.
+pub fn char_diff<'a>(old_text: &'a str, new_text: &'a str) -> Vec<(Range<usize>, &'a str)> {
+    let mut input: InternedInput<&str> = InternedInput::default();
+    input.update_before(tokenize_chars(old_text));
+    input.update_after(tokenize_chars(new_text));
+    let mut edits: Vec<(Range<usize>, &str)> = Vec::new();
+    diff_internal(&input, |old_byte_range, new_byte_range, _, _| {
+        let replacement = if new_byte_range.is_empty() {
+            ""
+        } else {
+            &new_text[new_byte_range]
+        };
+        edits.push((old_byte_range, replacement));
+    });
+    edits
+}
+
 pub struct DiffOptions {
     pub language_scope: Option<LanguageScope>,
     pub max_word_diff_len: usize,
@@ -361,6 +380,14 @@ fn diff_internal(
     );
 }
 
+fn tokenize_chars(text: &str) -> impl Iterator<Item = &str> {
+    let mut chars = text.char_indices().peekable();
+    iter::from_fn(move || {
+        let (start, c) = chars.next()?;
+        Some(&text[start..start + c.len_utf8()])
+    })
+}
+
 fn tokenize(text: &str, language_scope: Option<LanguageScope>) -> impl Iterator<Item = &str> {
     let classifier =
         CharClassifier::new(language_scope).scope_context(Some(CharScopeContext::Completion));
@@ -479,6 +506,23 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_char_diff() {
+        assert_eq!(char_diff("", ""), vec![]);
+        assert_eq!(char_diff("", "abc"), vec![(0..0, "abc")]);
+        assert_eq!(char_diff("abc", ""), vec![(0..3, "")]);
+        assert_eq!(char_diff("ac", "abc"), vec![(1..1, "b")]); // "b" inserted
+        assert_eq!(char_diff("abc", "ac"), vec![(1..2, "")]); // "b" deleted
+        assert_eq!(char_diff("abc", "adc"), vec![(1..2, "d")]); // "b" replaced with "d"
+        assert_eq!(char_diff("ζ—₯", "ζ—₯本θͺž"), vec![(3..3, "本θͺž")]); // "本θͺž" inserted
+        assert_eq!(char_diff("ζ—₯本θͺž", "ζ—₯"), vec![(3..9, "")]); // "本θͺž" deleted
+        assert_eq!(char_diff("πŸŽ‰", "πŸŽ‰πŸŽŠπŸŽˆ"), vec![(4..4, "🎊🎈")]); // "🎊🎈" inserted
+        assert_eq!(
+            char_diff("testζ—₯本", "testζ—₯本θͺžγ§γ™"),
+            vec![(10..10, "θͺžγ§γ™")]
+        );
+    }
+
     #[test]
     fn test_unified_diff_with_offsets() {
         let old_text = "foo\nbar\nbaz\n";