diff --git a/crates/edit_prediction_cli/src/anthropic_client.rs b/crates/edit_prediction_cli/src/anthropic_client.rs
index 9758419860146c87de667ffd6f71278504912962..941e82c9dbf39186cd4e061f81bcfe71b2ad5ee0 100644
--- a/crates/edit_prediction_cli/src/anthropic_client.rs
+++ b/crates/edit_prediction_cli/src/anthropic_client.rs
@@ -11,6 +11,7 @@ use reqwest_client::ReqwestClient;
 use sqlez::bindable::Bind;
 use sqlez::bindable::StaticColumnCount;
 use sqlez_macros::sql;
+use std::collections::HashSet;
 use std::hash::Hash;
 use std::hash::Hasher;
 use std::path::Path;
@@ -517,6 +518,7 @@ impl BatchingLlmClient {
 
         let mut current_batch_rows = Vec::new();
         let mut current_batch_size = 0usize;
+        let mut pending_hashes: HashSet<String> = HashSet::new();
         loop {
             let rows: Vec<(String, String)> = {
                 let connection = self.connection.lock().unwrap();
@@ -534,9 +536,15 @@
             // Split rows into sub-batches based on size
             let mut batches_to_upload = Vec::new();
+            let mut new_rows_added = 0;
 
             for row in rows {
                 let (hash, request_str) = row;
+
+                // Skip rows already added to current_batch_rows but not yet uploaded
+                if pending_hashes.contains(&hash) {
+                    continue;
+                }
                 let serializable_request: SerializableRequest =
                     serde_json::from_str(&request_str)?;
                 let messages: Vec = serializable_request
@@ -586,8 +594,16 @@
                     current_batch_size = 0;
                 }
 
+                pending_hashes.insert(hash.clone());
                 current_batch_rows.push((hash, batch_request));
                 current_batch_size += estimated_size;
+                new_rows_added += 1;
             }
+
+            // If no new rows were added this iteration, all pending requests are already
+            // in current_batch_rows, so we should break to avoid an infinite loop
+            if new_rows_added == 0 {
+                break;
+            }
 
             // Only upload full batches, keep the partial batch for the next iteration
@@ -595,6 +611,11 @@
             for (batch_rows, batch_size) in batches_to_upload {
                 let request_hashes: Vec<String> =
                     batch_rows.iter().map(|(hash, _)| hash.clone()).collect();
+
+                // Remove uploaded hashes from pending set
+                for hash in &request_hashes {
+                    pending_hashes.remove(hash);
+                }
                 let batch_requests: Vec =
                     batch_rows.into_iter().map(|(_, req)| req).collect();
diff --git a/crates/edit_prediction_cli/src/prompts/repair.md b/crates/edit_prediction_cli/src/prompts/repair.md
index 8ff3390af29cef83f4b23c141b6a772fb326cfc2..d3d77c03924344ec4d7d47c3e92e7b7cfda2084d 100644
--- a/crates/edit_prediction_cli/src/prompts/repair.md
+++ b/crates/edit_prediction_cli/src/prompts/repair.md
@@ -6,6 +6,8 @@ Your previous prediction has quality issues that need to be addressed. Please ge
 
 {quality_feedback}
 
+{token_change_info}
+
 ## Your Previous Prediction (word-diff format)
 
 `````
diff --git a/crates/edit_prediction_cli/src/repair.rs b/crates/edit_prediction_cli/src/repair.rs
index e1e588b0174ed9db5fdf52470cead38eea28019d..bf59604eb4b14d6860504151037fb09bffc37104 100644
--- a/crates/edit_prediction_cli/src/repair.rs
+++ b/crates/edit_prediction_cli/src/repair.rs
@@ -11,6 +11,7 @@ use crate::{
     anthropic_client::AnthropicClient,
     example::{ActualCursor, Example, ExamplePrediction},
     format_prompt::{TeacherPrompt, extract_last_codeblock},
+    metrics::count_patch_token_changes,
     openai_client::OpenAiClient,
     parse_output::run_parse_output,
     paths::LLM_CACHE_DB,
@@ -168,10 +169,26 @@ pub fn build_repair_message(example: &Example) -> Result<String> {
 
     let actual_patch_word_diff = unified_to_word_diff(actual_patch);
 
+    let token_counts = count_patch_token_changes(actual_patch);
+    let mut token_change_info = format!(
+        "\n## Token Change Statistics\n\n\
+         - **Deleted tokens**: {}\n\
+         - **Inserted tokens**: {}",
+        token_counts.deleted_tokens, token_counts.inserted_tokens,
+    );
+    if token_counts.deleted_tokens > 100 || token_counts.inserted_tokens > 100 {
+        token_change_info.push_str(
+            "\n\n> **Note:** The token change count is high. \
+             Consider producing a more scoped edit that targets only the lines \
+             that truly need to change, rather than rewriting large sections.",
+        );
+    }
+
     let prompt_template = crate::prompt_assets::get_prompt("repair.md");
     Ok(prompt_template
         .replace("{actual_patch_word_diff}", &actual_patch_word_diff)
-        .replace("{quality_feedback}", &quality_feedback))
+        .replace("{quality_feedback}", &quality_feedback)
+        .replace("{token_change_info}", &token_change_info))
 }
 
 /// Check if an example needs repair based on QA feedback or computed scores.