From 806e944e25bb60133b4ce5a353ba09e2e64e4800 Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Mon, 16 Feb 2026 23:33:34 -0800 Subject: [PATCH] Fix EP CLI issues found when generating new teacher predictions (#49327) Release Notes: - N/A --- crates/edit_prediction/src/example_spec.rs | 2 + .../src/anthropic_client.rs | 165 +++++++++++++----- .../src/reversal_tracking.rs | 73 +++----- crates/language/src/language.rs | 2 +- crates/language/src/text_diff.rs | 44 +++++ 5 files changed, 196 insertions(+), 90 deletions(-) diff --git a/crates/edit_prediction/src/example_spec.rs b/crates/edit_prediction/src/example_spec.rs index c6609e5f1f42f21eb165488f85575f2c50fcd1e0..81e786670056814482fc0642a8ea79546366f2ed 100644 --- a/crates/edit_prediction/src/example_spec.rs +++ b/crates/edit_prediction/src/example_spec.rs @@ -66,6 +66,7 @@ pub struct CapturedPromptInput { pub excerpt_start_row: Option, pub events: Vec, pub related_files: Vec, + #[serde(default)] pub in_open_source_repo: bool, } @@ -75,6 +76,7 @@ pub struct CapturedEvent { pub old_path: Arc, pub diff: String, pub predicted: bool, + #[serde(default)] pub in_open_source_repo: bool, } diff --git a/crates/edit_prediction_cli/src/anthropic_client.rs b/crates/edit_prediction_cli/src/anthropic_client.rs index 20f40a240e819885179d7175d6c7e4d9e266130d..9758419860146c87de667ffd6f71278504912962 100644 --- a/crates/edit_prediction_cli/src/anthropic_client.rs +++ b/crates/edit_prediction_cli/src/anthropic_client.rs @@ -511,9 +511,12 @@ impl BatchingLlmClient { async fn upload_pending_requests(&self) -> Result> { const BATCH_CHUNK_SIZE: i32 = 16_000; + const MAX_BATCH_SIZE_BYTES: usize = 200 * 1024 * 1024; // 200MB (buffer below 256MB limit) let mut all_batch_ids = Vec::new(); let mut total_uploaded = 0; + let mut current_batch_rows = Vec::new(); + let mut current_batch_size = 0usize; loop { let rows: Vec<(String, String)> = { let connection = self.connection.lock().unwrap(); @@ -529,52 +532,130 @@ impl BatchingLlmClient { break; } - let request_hashes: Vec = rows.iter().map(|(hash, _)| hash.clone()).collect(); + // Split rows into sub-batches based on size + let mut batches_to_upload = Vec::new(); + + for row in rows { + let (hash, request_str) = row; + let serializable_request: SerializableRequest = serde_json::from_str(&request_str)?; + + let messages: Vec = serializable_request + .messages + .into_iter() + .map(|msg| Message { + role: match msg.role.as_str() { + "user" => Role::User, + "assistant" => Role::Assistant, + _ => Role::User, + }, + content: vec![RequestContent::Text { + text: msg.content, + cache_control: None, + }], + }) + .collect(); + + let params = AnthropicRequest { + model: serializable_request.model, + max_tokens: serializable_request.max_tokens, + messages, + tools: Vec::new(), + thinking: None, + tool_choice: None, + system: None, + metadata: None, + output_config: None, + stop_sequences: Vec::new(), + temperature: None, + top_k: None, + top_p: None, + }; + + let custom_id = format!("req_hash_{}", hash); + let batch_request = anthropic::batches::BatchRequest { custom_id, params }; + + // Estimate the serialized size of this request + let estimated_size = serde_json::to_string(&batch_request)?.len(); + + // If adding this request would exceed the limit, start a new batch + if !current_batch_rows.is_empty() + && current_batch_size + estimated_size > MAX_BATCH_SIZE_BYTES + { + batches_to_upload.push((current_batch_rows, current_batch_size)); + current_batch_rows = Vec::new(); + current_batch_size = 0; + } + + current_batch_rows.push((hash, batch_request)); + current_batch_size += estimated_size; + } - let batch_requests = rows + // Only upload full batches, keep the partial batch for the next iteration + // Upload each sub-batch + for (batch_rows, batch_size) in batches_to_upload { + let request_hashes: Vec = + batch_rows.iter().map(|(hash, _)| hash.clone()).collect(); + let batch_requests: Vec = + batch_rows.into_iter().map(|(_, req)| req).collect(); + + let batch_len = batch_requests.len(); + log::info!( + "Uploading batch with {} requests (~{:.2} MB)", + batch_len, + batch_size as f64 / (1024.0 * 1024.0) + ); + + let batch = anthropic::batches::create_batch( + self.http_client.as_ref(), + ANTHROPIC_API_URL, + &self.api_key, + anthropic::batches::CreateBatchRequest { + requests: batch_requests, + }, + ) + .await + .map_err(|e| anyhow::anyhow!("{:?}", e))?; + + { + let connection = self.connection.lock().unwrap(); + connection.with_savepoint("batch_upload", || { + let q = sql!(UPDATE cache SET batch_id = ? WHERE request_hash = ?); + let mut exec = connection.exec_bound::<(&str, &str)>(q)?; + for hash in &request_hashes { + exec((batch.id.as_str(), hash.as_str()))?; + } + Ok(()) + })?; + } + + total_uploaded += batch_len; + log::info!( + "Uploaded batch {} with {} requests ({} total)", + batch.id, + batch_len, + total_uploaded + ); + + all_batch_ids.push(batch.id); + } + } + + // Upload any remaining partial batch at the end + if !current_batch_rows.is_empty() { + let request_hashes: Vec = current_batch_rows .iter() - .map(|(hash, request_str)| { - let serializable_request: SerializableRequest = - serde_json::from_str(&request_str).unwrap(); - - let messages: Vec = serializable_request - .messages - .into_iter() - .map(|msg| Message { - role: match msg.role.as_str() { - "user" => Role::User, - "assistant" => Role::Assistant, - _ => Role::User, - }, - content: vec![RequestContent::Text { - text: msg.content, - cache_control: None, - }], - }) - .collect(); - - let params = AnthropicRequest { - model: serializable_request.model, - max_tokens: serializable_request.max_tokens, - messages, - tools: Vec::new(), - thinking: None, - tool_choice: None, - system: None, - metadata: None, - output_config: None, - stop_sequences: Vec::new(), - temperature: None, - top_k: None, - top_p: None, - }; - - let custom_id = format!("req_hash_{}", hash); - anthropic::batches::BatchRequest { custom_id, params } - }) - .collect::>(); + .map(|(hash, _)| hash.clone()) + .collect(); + let batch_requests: Vec = + current_batch_rows.into_iter().map(|(_, req)| req).collect(); let batch_len = batch_requests.len(); + log::info!( + "Uploading final batch with {} requests (~{:.2} MB)", + batch_len, + current_batch_size as f64 / (1024.0 * 1024.0) + ); + let batch = anthropic::batches::create_batch( self.http_client.as_ref(), ANTHROPIC_API_URL, diff --git a/crates/edit_prediction_cli/src/reversal_tracking.rs b/crates/edit_prediction_cli/src/reversal_tracking.rs index 139730bbe3a8f5788e7ef0aa8aaf32344802c685..2c82b749e929ee85bde3541565f1d80d81285f37 100644 --- a/crates/edit_prediction_cli/src/reversal_tracking.rs +++ b/crates/edit_prediction_cli/src/reversal_tracking.rs @@ -3,7 +3,7 @@ use std::path::Path; use std::sync::Arc; use edit_prediction::udiff::apply_diff_to_string; -use language::text_diff; +use language::{char_diff, text_diff}; use crate::example::ExamplePromptInputs; @@ -417,17 +417,6 @@ impl ReversalOverlap { } } -/// Check if `needle` is a subsequence of `haystack` (characters appear in order, not necessarily contiguous). -fn is_subsequence(needle: &str, haystack: &str) -> bool { - let mut needle_chars = needle.chars().peekable(); - for c in haystack.chars() { - if needle_chars.peek() == Some(&c) { - needle_chars.next(); - } - } - needle_chars.peek().is_none() -} - /// Normalize edits where `old_text` appears as a subsequence within `new_text` (extension), /// or where `new_text` appears as a subsequence within `old_text` (reduction). /// @@ -442,31 +431,35 @@ fn is_subsequence(needle: &str, haystack: &str) -> bool { fn normalize_extension_edits(edits: Vec) -> Vec { edits .into_iter() - .map(|edit| { + .flat_map(|edit| { if edit.old_text.is_empty() || edit.new_text.is_empty() { - return edit; + return vec![edit]; } - if is_subsequence(&edit.old_text, &edit.new_text) { - let inserted_char_count = - edit.new_text.chars().count() - edit.old_text.chars().count(); - GranularEdit { - range: edit.range.start..edit.range.start, - old_text: String::new(), - new_text: edit.new_text.chars().take(inserted_char_count).collect(), - } - } else if is_subsequence(&edit.new_text, &edit.old_text) { - let deleted_char_count = - edit.old_text.chars().count() - edit.new_text.chars().count(); - let deleted_text: String = edit.old_text.chars().take(deleted_char_count).collect(); - GranularEdit { - range: edit.range.start..edit.range.start + deleted_text.len(), - old_text: deleted_text, - new_text: String::new(), - } - } else { - edit + // Use character-wise diff to find exact byte ranges of changes + let char_edits = char_diff(&edit.old_text, &edit.new_text); + + let all_deletions = !char_edits.is_empty() + && char_edits + .iter() + .all(|(range, replacement)| !range.is_empty() && replacement.is_empty()); + let all_insertions = !char_edits.is_empty() + && char_edits + .iter() + .all(|(range, replacement)| range.is_empty() && !replacement.is_empty()); + if all_deletions || all_insertions { + return char_edits + .into_iter() + .map(|(range, replacement)| GranularEdit { + range: edit.range.start + range.start..edit.range.start + range.end, + old_text: edit.old_text[range].to_string(), + new_text: replacement.to_string(), + }) + .collect(); } + + // Otherwise, keep the original edit (mixed changes) + vec![edit] }) .collect() } @@ -1718,20 +1711,6 @@ mod tests { } } - #[test] - fn test_is_subsequence() { - assert!(is_subsequence("", "anything")); - assert!(is_subsequence("", "")); - assert!(is_subsequence("abc", "abc")); - assert!(is_subsequence("abc", "aXbXc")); - assert!(is_subsequence("ac", "abc")); - assert!(!is_subsequence("abc", "ab")); - assert!(!is_subsequence("abc", "cba")); - assert!(!is_subsequence("abc", "")); - assert!(is_subsequence("日本", "日X本Y語")); - assert!(!is_subsequence("日本語", "日本")); - } - #[test] fn test_compute_lcs_length() { assert_eq!(compute_lcs_length("", ""), 0); diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index 20dd639506afec2cbbab0d7cd5b7c2a94032b752..fd14f42a93179ae0423f5acfa6ede3cceec94935 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -66,7 +66,7 @@ use syntax_map::{QueryCursorHandle, SyntaxSnapshot}; use task::RunnableTag; pub use task_context::{ContextLocation, ContextProvider, RunnableRange}; pub use text_diff::{ - DiffOptions, apply_diff_patch, apply_reversed_diff_patch, line_diff, text_diff, + DiffOptions, apply_diff_patch, apply_reversed_diff_patch, char_diff, line_diff, text_diff, text_diff_with_options, unified_diff, unified_diff_with_context, unified_diff_with_offsets, word_diff_ranges, }; diff --git a/crates/language/src/text_diff.rs b/crates/language/src/text_diff.rs index c46796827242f0a483c0b31416b89f3e73fa14e8..c084446ea9e41cc980b0ee8b59aec0034b06cb30 100644 --- a/crates/language/src/text_diff.rs +++ b/crates/language/src/text_diff.rs @@ -219,6 +219,25 @@ pub fn word_diff_ranges( (old_ranges, new_ranges) } +/// Computes character-level diff between two strings. +/// +/// Usually, you should use `text_diff`, which performs a word-wise diff. +pub fn char_diff<'a>(old_text: &'a str, new_text: &'a str) -> Vec<(Range, &'a str)> { + let mut input: InternedInput<&str> = InternedInput::default(); + input.update_before(tokenize_chars(old_text)); + input.update_after(tokenize_chars(new_text)); + let mut edits: Vec<(Range, &str)> = Vec::new(); + diff_internal(&input, |old_byte_range, new_byte_range, _, _| { + let replacement = if new_byte_range.is_empty() { + "" + } else { + &new_text[new_byte_range] + }; + edits.push((old_byte_range, replacement)); + }); + edits +} + pub struct DiffOptions { pub language_scope: Option, pub max_word_diff_len: usize, @@ -361,6 +380,14 @@ fn diff_internal( ); } +fn tokenize_chars(text: &str) -> impl Iterator { + let mut chars = text.char_indices().peekable(); + iter::from_fn(move || { + let (start, c) = chars.next()?; + Some(&text[start..start + c.len_utf8()]) + }) +} + fn tokenize(text: &str, language_scope: Option) -> impl Iterator { let classifier = CharClassifier::new(language_scope).scope_context(Some(CharScopeContext::Completion)); @@ -479,6 +506,23 @@ mod tests { ); } + #[test] + fn test_char_diff() { + assert_eq!(char_diff("", ""), vec![]); + assert_eq!(char_diff("", "abc"), vec![(0..0, "abc")]); + assert_eq!(char_diff("abc", ""), vec![(0..3, "")]); + assert_eq!(char_diff("ac", "abc"), vec![(1..1, "b")]); // "b" inserted + assert_eq!(char_diff("abc", "ac"), vec![(1..2, "")]); // "b" deleted + assert_eq!(char_diff("abc", "adc"), vec![(1..2, "d")]); // "b" replaced with "d" + assert_eq!(char_diff("日", "日本語"), vec![(3..3, "本語")]); // "本語" inserted + assert_eq!(char_diff("日本語", "日"), vec![(3..9, "")]); // "本語" deleted + assert_eq!(char_diff("🎉", "🎉🎊🎈"), vec![(4..4, "🎊🎈")]); // "🎊🎈" inserted + assert_eq!( + char_diff("test日本", "test日本語です"), + vec![(10..10, "語です")] + ); + } + #[test] fn test_unified_diff_with_offsets() { let old_text = "foo\nbar\nbaz\n";