ep: Add `qa` subcommand to check predictions quality (#47520)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/edit_prediction_cli/src/main.rs      |  36 ++
crates/edit_prediction_cli/src/qa.rs        | 395 +++++++++++++++++++++++
crates/edit_prediction_cli/src/word_diff.rs | 343 +++++++++++++++++++
3 files changed, 773 insertions(+), 1 deletion(-)

Detailed changes

crates/edit_prediction_cli/src/main.rs 🔗

@@ -12,12 +12,14 @@ mod paths;
 mod predict;
 mod progress;
 mod pull_examples;
+mod qa;
 mod reorder_patch;
 mod retrieve_context;
 mod score;
 mod split_commit;
 mod split_dataset;
 mod synthesize;
+mod word_diff;
 use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
 use collections::HashSet;
 use edit_prediction::EditPredictionStore;
@@ -159,6 +161,8 @@ enum Command {
     FilterLanguages(FilterLanguagesArgs),
     /// Import Anthropic batch results by batch IDs (useful for recovering after database loss)
     ImportBatch(ImportBatchArgs),
+    /// Assess the quality of predictions using LLM-as-a-judge
+    Qa(qa::QaArgs),
 }
 
 impl Display for Command {
@@ -194,6 +198,9 @@ impl Display for Command {
             Command::ImportBatch(args) => {
                 write!(f, "import-batch --batch-ids {}", args.batch_ids.join(" "))
             }
+            Command::Qa(_) => {
+                write!(f, "qa")
+            }
         }
     }
 }
@@ -558,6 +565,32 @@ fn main() {
             }
             return;
         }
+        Command::Qa(qa_args) => {
+            // Read examples from input files
+            let mut examples = example::read_example_files(&args.inputs);
+
+            // Apply filters
+            if let Some(name_filter) = &args.name {
+                examples.retain(|e| e.spec.name.contains(name_filter));
+            }
+            if let Some(repo_filter) = &args.repo {
+                examples.retain(|e| e.spec.repository_url.contains(repo_filter));
+            }
+            if let Some(offset) = args.offset {
+                examples.splice(0..offset, []);
+            }
+            if let Some(limit) = args.limit {
+                examples.truncate(limit);
+            }
+
+            smol::block_on(async {
+                if let Err(e) = qa::run_qa(&mut examples, qa_args, output.as_ref()).await {
+                    eprintln!("Error: {:?}", e);
+                    std::process::exit(1);
+                }
+            });
+            return;
+        }
         _ => {}
     }
 
@@ -724,7 +757,8 @@ fn main() {
                                         | Command::SplitCommit(_)
                                         | Command::Split(_)
                                         | Command::FilterLanguages(_)
-                                        | Command::ImportBatch(_) => {
+                                        | Command::ImportBatch(_)
+                                        | Command::Qa(_) => {
                                             unreachable!()
                                         }
                                     }

crates/edit_prediction_cli/src/qa.rs 🔗

@@ -0,0 +1,395 @@
+//! Quality assessment of predictions using LLM-as-a-judge.
+//!
+//! This module uses the Anthropic Batch API to evaluate prediction quality.
+//! Caching is handled by the underlying AnthropicClient.
+
+use crate::anthropic_client::AnthropicClient;
+use crate::example::Example;
+use crate::paths::CACHE_DIR;
+use crate::word_diff::unified_to_word_diff;
+use anthropic::{Message, RequestContent, Role};
+use anyhow::Result;
+use serde::{Deserialize, Serialize};
+use std::io::{BufWriter, Write};
+use std::path::PathBuf;
+use std::sync::LazyLock;
+
+/// Model to use for QA evaluation.
+const MODEL: &str = "claude-sonnet-4-5";
+
+/// Path to the QA cache database.
+pub static QA_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("qa_cache.sqlite"));
+
+/// Arguments for the QA command.
+#[derive(Debug, Clone, clap::Args)]
+pub struct QaArgs {
+    /// Use synchronous API instead of batch
+    #[clap(long)]
+    pub no_batch: bool,
+
+    /// Wait for batch to complete (polls every 30s)
+    #[clap(long)]
+    pub wait: bool,
+}
+
+/// Result of QA evaluation for a single prediction.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct QaResult {
+    /// Free-form reasoning from the judge.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub reasoning: Option<String>,
+
+    /// Does the prediction undo/revert changes the user intentionally made?
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub reverts_edits: Option<bool>,
+
+    /// Confidence score (1-5) for user acceptance likelihood.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub confidence: Option<u8>,
+
+    /// The raw response from the model.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub response: Option<String>,
+
+    /// Error message if parsing or request failed.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub error: Option<String>,
+}
+
+/// Build the assessment prompt for an example.
+pub fn build_prompt(example: &Example) -> Option<String> {
+    let prediction = example.predictions.first()?;
+    let actual_patch = prediction.actual_patch.as_ref()?;
+    let prompt_inputs = example.prompt_inputs.as_ref()?;
+
+    let actual_patch_word_diff = unified_to_word_diff(actual_patch);
+
+    let mut edit_history = String::new();
+    for event in &prompt_inputs.edit_history {
+        match event.as_ref() {
+            zeta_prompt::Event::BufferChange {
+                path,
+                old_path,
+                diff,
+                predicted: _,
+                in_open_source_repo: _,
+            } => {
+                edit_history.push_str(&format!("--- a{}\n", old_path.display()));
+                edit_history.push_str(&format!("+++ b{}\n", path.display()));
+                let diff_word_diff = unified_to_word_diff(diff);
+                edit_history.push_str(&diff_word_diff);
+                edit_history.push_str("\n\n");
+            }
+        }
+    }
+
+    Some(format!(
+        r#"
+You are evaluating an edit prediction model for a code editor. The model observes a programmer's recent edit history and predicts what edit they will make next.
+
+All diffs are in the word-diff format.
+
+The model is instructed to:
+- Complete partially-applied refactoring or changes
+- Maintain consistency with established patterns and style
+- NOT delete or revert text that was just added (unless the user explicitly undid it themselves)
+
+## Edit History (chronological)
+```````
+{edit_history}
+```````
+
+## Predicted Next Edit
+```````
+{actual_patch_word_diff}
+```````
+
+## Evaluate
+
+1. **reverts_edits**: Does the prediction undo, or revert changes the user intentionally made in the **edit history**?
+
+2. **confidence**: How likely is the user to accept this suggestion?
+   - 1 = Definitely reject (wrong, nonsensical, or harmful)
+   - 2 = Probably reject (doesn't fit intent or pattern)
+   - 3 = Uncertain (plausible but not clearly correct)
+   - 4 = Probably accept (reasonable next step)
+   - 5 = Definitely accept (obvious continuation)
+
+Output JSON in this format:
+
+```
+{{
+    "reasoning": "your reasoning here",
+    "reverts_edits": true/false,
+    "confidence": 1-5
+}}
+```
+"#
+    ))
+}
+
+/// Extract a code block from a response.
+fn extract_codeblock(response: &str) -> Option<String> {
+    let lines: Vec<&str> = response.lines().collect();
+    for (i, line) in lines.iter().enumerate() {
+        if line.starts_with("```") {
+            let start = i + 1;
+            for (j, end_line) in lines[start..].iter().enumerate() {
+                if end_line.starts_with("```") {
+                    return Some(lines[start..start + j].join("\n"));
+                }
+            }
+            return Some(lines[start..].join("\n"));
+        }
+    }
+    None
+}
+
+/// Parse the LLM response into a QaResult.
+fn parse_response(response_text: &str) -> QaResult {
+    let codeblock = extract_codeblock(response_text);
+
+    // Try parsing codeblock first, then fall back to raw response
+    for text_to_parse in [codeblock.as_deref(), Some(response_text.trim())] {
+        let Some(text) = text_to_parse else {
+            continue;
+        };
+
+        if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(text) {
+            return QaResult {
+                reasoning: parsed
+                    .get("reasoning")
+                    .and_then(|v| v.as_str())
+                    .map(|s| s.to_string()),
+                reverts_edits: parsed.get("reverts_edits").and_then(|v| v.as_bool()),
+                confidence: parsed
+                    .get("confidence")
+                    .and_then(|v| v.as_u64())
+                    .map(|v| v as u8),
+                response: Some(response_text.to_string()),
+                error: None,
+            };
+        }
+    }
+
+    // If all parsing attempts fail, return error
+    QaResult {
+        reasoning: Some(response_text.to_string()),
+        reverts_edits: None,
+        confidence: None,
+        response: Some(response_text.to_string()),
+        error: Some("Could not parse JSON from response".to_string()),
+    }
+}
+
+/// Run the QA evaluation on a set of examples.
+pub async fn run_qa(
+    examples: &mut [Example],
+    args: &QaArgs,
+    output_path: Option<&PathBuf>,
+) -> Result<()> {
+    let client = if args.no_batch {
+        AnthropicClient::plain()?
+    } else {
+        AnthropicClient::batch(&QA_CACHE_DB)?
+    };
+
+    eprintln!("Using model: {}, batching: {}", MODEL, !args.no_batch);
+
+    // First pass: send requests (client handles caching internally)
+    let mut prompts: Vec<(usize, String)> = Vec::new();
+    let mut skipped_count = 0;
+
+    for (idx, example) in examples.iter().enumerate() {
+        let Some(prompt) = build_prompt(example) else {
+            skipped_count += 1;
+            continue;
+        };
+        prompts.push((idx, prompt));
+    }
+
+    if skipped_count > 0 {
+        eprintln!("Skipping {} items with missing actual_patch", skipped_count);
+    }
+
+    eprintln!("{} items to process", prompts.len());
+
+    // Process all items
+    let mut results: Vec<(usize, Option<QaResult>)> = Vec::new();
+
+    if args.no_batch {
+        // Synchronous processing
+        for (i, (idx, prompt)) in prompts.iter().enumerate() {
+            eprint!("\rProcessing {}/{}", i + 1, prompts.len());
+
+            let messages = vec![Message {
+                role: Role::User,
+                content: vec![RequestContent::Text {
+                    text: prompt.clone(),
+                    cache_control: None,
+                }],
+            }];
+
+            let response = client.generate(MODEL, 1024, messages).await?;
+            let result = response.map(|r| {
+                let text = r
+                    .content
+                    .iter()
+                    .filter_map(|c| match c {
+                        anthropic::ResponseContent::Text { text } => Some(text.as_str()),
+                        _ => None,
+                    })
+                    .collect::<Vec<_>>()
+                    .join("");
+                parse_response(&text)
+            });
+            results.push((*idx, result));
+        }
+        eprintln!();
+    } else {
+        // Queue all for batching
+        for (idx, prompt) in &prompts {
+            let messages = vec![Message {
+                role: Role::User,
+                content: vec![RequestContent::Text {
+                    text: prompt.clone(),
+                    cache_control: None,
+                }],
+            }];
+
+            let response = client.generate(MODEL, 1024, messages).await?;
+            let result = response.map(|r| {
+                let text = r
+                    .content
+                    .iter()
+                    .filter_map(|c| match c {
+                        anthropic::ResponseContent::Text { text } => Some(text.as_str()),
+                        _ => None,
+                    })
+                    .collect::<Vec<_>>()
+                    .join("");
+                parse_response(&text)
+            });
+            results.push((*idx, result));
+        }
+
+        // Sync batches (upload pending, download finished)
+        client.sync_batches().await?;
+
+        if args.wait {
+            eprintln!("Waiting for batch to complete...");
+            loop {
+                std::thread::sleep(std::time::Duration::from_secs(30));
+                client.sync_batches().await?;
+
+                // Re-check all items that didn't have results
+                let mut all_done = true;
+                for (result_idx, (idx, prompt)) in prompts.iter().enumerate() {
+                    if results[result_idx].1.is_none() {
+                        let messages = vec![Message {
+                            role: Role::User,
+                            content: vec![RequestContent::Text {
+                                text: prompt.clone(),
+                                cache_control: None,
+                            }],
+                        }];
+
+                        let response = client.generate(MODEL, 1024, messages).await?;
+                        if let Some(r) = response {
+                            let text = r
+                                .content
+                                .iter()
+                                .filter_map(|c| match c {
+                                    anthropic::ResponseContent::Text { text } => {
+                                        Some(text.as_str())
+                                    }
+                                    _ => None,
+                                })
+                                .collect::<Vec<_>>()
+                                .join("");
+                            results[result_idx] = (*idx, Some(parse_response(&text)));
+                        } else {
+                            all_done = false;
+                        }
+                    }
+                }
+
+                let done_count = results.iter().filter(|(_, r)| r.is_some()).count();
+                if all_done {
+                    break;
+                }
+                eprintln!("Still waiting... {}/{} results", done_count, prompts.len());
+            }
+        } else {
+            let pending_count = results.iter().filter(|(_, r)| r.is_none()).count();
+            if pending_count > 0 {
+                eprintln!(
+                    "Batch submitted. {} pending. Run again later to retrieve results.",
+                    pending_count
+                );
+            }
+        }
+    }
+
+    // Build results map by index
+    let mut results_by_idx: std::collections::HashMap<usize, QaResult> =
+        std::collections::HashMap::new();
+    for (idx, result) in results {
+        if let Some(r) = result {
+            results_by_idx.insert(idx, r);
+        }
+    }
+
+    // Output results
+    let mut writer: Box<dyn Write> = if let Some(path) = output_path {
+        Box::new(BufWriter::new(std::fs::File::create(path)?))
+    } else {
+        Box::new(std::io::stdout())
+    };
+
+    let mut num_total = 0;
+    let mut num_reverts_edits = 0;
+
+    for (idx, example) in examples.iter_mut().enumerate() {
+        // Skip examples that couldn't be processed
+        if build_prompt(example).is_none() {
+            continue;
+        }
+
+        let result = results_by_idx
+            .get(&idx)
+            .cloned()
+            .unwrap_or_else(|| QaResult {
+                reasoning: None,
+                reverts_edits: None,
+                confidence: None,
+                response: None,
+                error: Some("Result not found".to_string()),
+            });
+
+        if result.reverts_edits == Some(true) {
+            num_reverts_edits += 1;
+        }
+        num_total += 1;
+
+        // Add QA result to example and output
+        let mut example_json = serde_json::to_value(&example)?;
+        example_json["qa"] = serde_json::to_value(&result)?;
+        writeln!(writer, "{}", serde_json::to_string(&example_json)?)?;
+    }
+
+    if let Some(path) = output_path {
+        eprintln!("Results written to {}", path.display());
+    }
+
+    eprintln!("Processed:     {} items", num_total);
+    if num_total > 0 {
+        eprintln!(
+            "Reverts edits: {} ({:.2}%)",
+            num_reverts_edits,
+            num_reverts_edits as f64 / num_total as f64 * 100.0
+        );
+    }
+
+    Ok(())
+}

crates/edit_prediction_cli/src/word_diff.rs 🔗

@@ -0,0 +1,343 @@
+//! Word-diff utilities for converting unified diffs to word-diff format.
+
+/// Convert unified diff to word-diff format.
+///
+/// This transforms line-based diffs into word-level diffs where:
+/// - Deletions are marked with `[-...-]`
+/// - Insertions are marked with `{+...+}`
+pub fn unified_to_word_diff(unified_diff: &str) -> String {
+    let lines: Vec<&str> = unified_diff.lines().collect();
+    let mut result = String::new();
+    let mut old_lines: Vec<&str> = Vec::new();
+    let mut new_lines: Vec<&str> = Vec::new();
+
+    let flush_changes =
+        |old_lines: &mut Vec<&str>, new_lines: &mut Vec<&str>, result: &mut String| {
+            if old_lines.is_empty() && new_lines.is_empty() {
+                return;
+            }
+
+            // Strip the leading '-' or '+' from each line
+            let old_text: String = old_lines
+                .iter()
+                .map(|line| if line.len() > 1 { &line[1..] } else { "" })
+                .collect::<Vec<_>>()
+                .join("\n");
+
+            let new_text: String = new_lines
+                .iter()
+                .map(|line| if line.len() > 1 { &line[1..] } else { "" })
+                .collect::<Vec<_>>()
+                .join("\n");
+
+            if !old_text.is_empty() || !new_text.is_empty() {
+                let word_diff = compute_word_diff(&old_text, &new_text);
+                result.push_str(&word_diff);
+            }
+
+            old_lines.clear();
+            new_lines.clear();
+        };
+
+    for line in lines {
+        if line.starts_with("---") || line.starts_with("+++") {
+            flush_changes(&mut old_lines, &mut new_lines, &mut result);
+            result.push_str(line);
+            result.push('\n');
+        } else if line.starts_with("@@") {
+            flush_changes(&mut old_lines, &mut new_lines, &mut result);
+            result.push_str(line);
+            result.push('\n');
+        } else if line.starts_with('-') {
+            old_lines.push(line);
+        } else if line.starts_with('+') {
+            new_lines.push(line);
+        } else if line.starts_with(' ') || line.is_empty() {
+            flush_changes(&mut old_lines, &mut new_lines, &mut result);
+            result.push_str(line);
+            result.push('\n');
+        } else {
+            // Header lines (diff --git, index, etc.)
+            flush_changes(&mut old_lines, &mut new_lines, &mut result);
+            result.push_str(line);
+            result.push('\n');
+        }
+    }
+
+    flush_changes(&mut old_lines, &mut new_lines, &mut result);
+    result
+}
+
+/// Compute word-level diff between two text blocks.
+///
+/// Words and whitespace are treated as separate tokens. The output uses:
+/// - `[-...-]` for deleted content
+/// - `{+...+}` for inserted content
+fn compute_word_diff(old_text: &str, new_text: &str) -> String {
+    // Split into words while preserving whitespace
+    let old_words = tokenize(old_text);
+    let new_words = tokenize(new_text);
+
+    let ops = diff_tokens(&old_words, &new_words);
+    let mut result = String::new();
+
+    for op in ops {
+        match op {
+            DiffOp::Equal(start, end) => {
+                for token in &old_words[start..end] {
+                    result.push_str(token);
+                }
+            }
+            DiffOp::Delete(start, end) => {
+                result.push_str("[-");
+                for token in &old_words[start..end] {
+                    result.push_str(token);
+                }
+                result.push_str("-]");
+            }
+            DiffOp::Insert(start, end) => {
+                result.push_str("{+");
+                for token in &new_words[start..end] {
+                    result.push_str(token);
+                }
+                result.push_str("+}");
+            }
+            DiffOp::Replace {
+                old_start,
+                old_end,
+                new_start,
+                new_end,
+            } => {
+                result.push_str("[-");
+                for token in &old_words[old_start..old_end] {
+                    result.push_str(token);
+                }
+                result.push_str("-]");
+                result.push_str("{+");
+                for token in &new_words[new_start..new_end] {
+                    result.push_str(token);
+                }
+                result.push_str("+}");
+            }
+        }
+    }
+
+    if !result.is_empty() && !result.ends_with('\n') {
+        result.push('\n');
+    }
+
+    result
+}
+
+/// Tokenize text into words and whitespace sequences.
+fn tokenize(text: &str) -> Vec<&str> {
+    let mut tokens = Vec::new();
+    let mut chars = text.char_indices().peekable();
+
+    while let Some((start, ch)) = chars.next() {
+        if ch.is_whitespace() {
+            // Collect contiguous whitespace
+            let mut end = start + ch.len_utf8();
+            while let Some(&(_, next_ch)) = chars.peek() {
+                if next_ch.is_whitespace() {
+                    end += next_ch.len_utf8();
+                    chars.next();
+                } else {
+                    break;
+                }
+            }
+            tokens.push(&text[start..end]);
+        } else {
+            // Collect contiguous non-whitespace
+            let mut end = start + ch.len_utf8();
+            while let Some(&(_, next_ch)) = chars.peek() {
+                if !next_ch.is_whitespace() {
+                    end += next_ch.len_utf8();
+                    chars.next();
+                } else {
+                    break;
+                }
+            }
+            tokens.push(&text[start..end]);
+        }
+    }
+
+    tokens
+}
+
+#[derive(Debug)]
+enum DiffOp {
+    Equal(usize, usize),
+    Delete(usize, usize),
+    Insert(usize, usize),
+    Replace {
+        old_start: usize,
+        old_end: usize,
+        new_start: usize,
+        new_end: usize,
+    },
+}
+
+/// Compute diff operations between two token sequences using a simple LCS-based algorithm.
+fn diff_tokens<'a>(old: &[&'a str], new: &[&'a str]) -> Vec<DiffOp> {
+    // Build LCS table
+    let m = old.len();
+    let n = new.len();
+
+    if m == 0 && n == 0 {
+        return vec![];
+    }
+    if m == 0 {
+        return vec![DiffOp::Insert(0, n)];
+    }
+    if n == 0 {
+        return vec![DiffOp::Delete(0, m)];
+    }
+
+    // LCS dynamic programming
+    let mut dp = vec![vec![0usize; n + 1]; m + 1];
+    for i in 1..=m {
+        for j in 1..=n {
+            if old[i - 1] == new[j - 1] {
+                dp[i][j] = dp[i - 1][j - 1] + 1;
+            } else {
+                dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
+            }
+        }
+    }
+
+    // Backtrack to find operations
+    let mut ops = Vec::new();
+    let mut i = m;
+    let mut j = n;
+
+    // We'll collect in reverse order, then reverse at the end
+    let mut stack: Vec<(usize, usize, bool)> = Vec::new(); // (index, end, is_old)
+
+    while i > 0 || j > 0 {
+        if i > 0 && j > 0 && old[i - 1] == new[j - 1] {
+            stack.push((i - 1, i, true)); // Equal marker (using old index)
+            stack.push((j - 1, j, false)); // Paired with new index
+            i -= 1;
+            j -= 1;
+        } else if j > 0 && (i == 0 || dp[i][j - 1] >= dp[i - 1][j]) {
+            // Insert from new
+            stack.push((j - 1, j, false));
+            j -= 1;
+        } else {
+            // Delete from old
+            stack.push((i - 1, i, true));
+            i -= 1;
+        }
+    }
+
+    // Process the stack to build proper DiffOps
+    // This is a simplified approach - just iterate through and build ops
+    let mut old_idx = 0;
+    let mut new_idx = 0;
+
+    while old_idx < m || new_idx < n {
+        // Find next matching pair
+        let mut old_match = None;
+        let mut new_match = None;
+
+        for oi in old_idx..m {
+            for ni in new_idx..n {
+                if old[oi] == new[ni] {
+                    old_match = Some(oi);
+                    new_match = Some(ni);
+                    break;
+                }
+            }
+            if old_match.is_some() {
+                break;
+            }
+        }
+
+        match (old_match, new_match) {
+            (Some(om), Some(nm)) => {
+                // Handle any deletions/insertions before the match
+                if old_idx < om && new_idx < nm {
+                    ops.push(DiffOp::Replace {
+                        old_start: old_idx,
+                        old_end: om,
+                        new_start: new_idx,
+                        new_end: nm,
+                    });
+                } else if old_idx < om {
+                    ops.push(DiffOp::Delete(old_idx, om));
+                } else if new_idx < nm {
+                    ops.push(DiffOp::Insert(new_idx, nm));
+                }
+
+                // Find the extent of the equal sequence
+                let mut eq_end_old = om;
+                let mut eq_end_new = nm;
+                while eq_end_old < m && eq_end_new < n && old[eq_end_old] == new[eq_end_new] {
+                    eq_end_old += 1;
+                    eq_end_new += 1;
+                }
+
+                ops.push(DiffOp::Equal(om, eq_end_old));
+                old_idx = eq_end_old;
+                new_idx = eq_end_new;
+            }
+            _ => {
+                // No more matches, handle remaining
+                if old_idx < m && new_idx < n {
+                    ops.push(DiffOp::Replace {
+                        old_start: old_idx,
+                        old_end: m,
+                        new_start: new_idx,
+                        new_end: n,
+                    });
+                } else if old_idx < m {
+                    ops.push(DiffOp::Delete(old_idx, m));
+                } else if new_idx < n {
+                    ops.push(DiffOp::Insert(new_idx, n));
+                }
+                break;
+            }
+        }
+    }
+
+    ops
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_tokenize() {
+        let tokens = tokenize("hello world");
+        assert_eq!(tokens, vec!["hello", " ", "world"]);
+
+        let tokens = tokenize("  multiple   spaces  ");
+        assert_eq!(tokens, vec!["  ", "multiple", "   ", "spaces", "  "]);
+    }
+
+    #[test]
+    fn test_compute_word_diff_simple() {
+        let result = compute_word_diff("hello world", "hello there");
+        assert!(result.contains("[-world-]"));
+        assert!(result.contains("{+there+}"));
+    }
+
+    #[test]
+    fn test_unified_to_word_diff() {
+        let unified = "\
+--- a/file.txt
++++ b/file.txt
+@@ -1,3 +1,3 @@
+ context line
+-old text here
++new text here
+ more context";
+
+        let result = unified_to_word_diff(unified);
+        assert!(result.contains("--- a/file.txt"));
+        assert!(result.contains("+++ b/file.txt"));
+        assert!(result.contains("@@"));
+    }
+}