diff --git a/Cargo.lock b/Cargo.lock index 3283d32a94e374e9625698b2e1fac4bec4d6b4f2..2a4c6c4f4394a7c3e13dd00b2178640528b8d034 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -113,6 +113,7 @@ dependencies = [ "language", "log", "menu", + "ordered-float", "project", "rand 0.8.5", "regex", diff --git a/crates/ai/Cargo.toml b/crates/ai/Cargo.toml index db8772bcb1ef59155b1fbf8a66ba4f79c9adc2f2..b03405bb93fa72f6079a6ca32e661e7d37bf704d 100644 --- a/crates/ai/Cargo.toml +++ b/crates/ai/Cargo.toml @@ -26,6 +26,7 @@ chrono = { version = "0.4", features = ["serde"] } futures.workspace = true indoc.workspace = true isahc.workspace = true +ordered-float.workspace = true regex.workspace = true schemars.workspace = true serde.workspace = true diff --git a/crates/ai/src/diff.rs b/crates/ai/src/diff.rs index 378206497bf5687c702696fceb6ae8bb579e65ae..7c5af34ff540a0721a956ce5ce30b48165e76031 100644 --- a/crates/ai/src/diff.rs +++ b/crates/ai/src/diff.rs @@ -1,11 +1,13 @@ use collections::HashMap; +use ordered_float::OrderedFloat; use std::{ + cmp, fmt::{self, Debug}, ops::Range, }; struct Matrix { - cells: Vec, + cells: Vec, rows: usize, cols: usize, } @@ -20,12 +22,12 @@ impl Matrix { } fn resize(&mut self, rows: usize, cols: usize) { - self.cells.resize(rows * cols, 0); + self.cells.resize(rows * cols, 0.); self.rows = rows; self.cols = cols; } - fn get(&self, row: usize, col: usize) -> isize { + fn get(&self, row: usize, col: usize) -> f64 { if row >= self.rows { panic!("row out of bounds") } @@ -36,7 +38,7 @@ impl Matrix { self.cells[col * self.rows + row] } - fn set(&mut self, row: usize, col: usize, value: isize) { + fn set(&mut self, row: usize, col: usize, value: f64) { if row >= self.rows { panic!("row out of bounds") } @@ -79,16 +81,17 @@ pub struct Diff { } impl Diff { - const INSERTION_SCORE: isize = -1; - const DELETION_SCORE: isize = -5; - const EQUALITY_BASE: isize = 2; + const INSERTION_SCORE: f64 = -1.; + const DELETION_SCORE: f64 = -5.; + const EQUALITY_BASE: f64 = 1.618; + const MAX_EQUALITY_EXPONENT: i32 = 32; pub fn new(old: String) -> Self { let old = old.chars().collect::>(); let mut scores = Matrix::new(); scores.resize(old.len() + 1, 1); for i in 0..=old.len() { - scores.set(i, 0, i as isize * Self::DELETION_SCORE); + scores.set(i, 0, i as f64 * Self::DELETION_SCORE); } Self { old, @@ -105,7 +108,7 @@ impl Diff { self.scores.resize(self.old.len() + 1, self.new.len() + 1); for j in self.new_text_ix + 1..=self.new.len() { - self.scores.set(0, j, j as isize * Self::INSERTION_SCORE); + self.scores.set(0, j, j as f64 * Self::INSERTION_SCORE); for i in 1..=self.old.len() { let insertion_score = self.scores.get(i, j - 1) + Self::INSERTION_SCORE; let deletion_score = self.scores.get(i - 1, j) + Self::DELETION_SCORE; @@ -117,10 +120,11 @@ impl Diff { if self.old[i - 1] == ' ' { self.scores.get(i - 1, j - 1) } else { - self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.pow(equal_run / 3) + let exponent = cmp::min(equal_run as i32 / 3, Self::MAX_EQUALITY_EXPONENT); + self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.powi(exponent) } } else { - isize::MIN + f64::NEG_INFINITY }; let score = insertion_score.max(deletion_score).max(equality_score); @@ -128,7 +132,7 @@ impl Diff { } } - let mut max_score = isize::MIN; + let mut max_score = f64::NEG_INFINITY; let mut next_old_text_ix = self.old_text_ix; let next_new_text_ix = self.new.len(); for i in self.old_text_ix..=self.old.len() { @@ -173,7 +177,7 @@ impl Diff { let (prev_i, prev_j) = [insertion_score, deletion_score, equality_score] .iter() - .max_by_key(|cell| cell.map(|(i, j)| self.scores.get(i, j))) + .max_by_key(|cell| cell.map(|(i, j)| OrderedFloat(self.scores.get(i, j)))) .unwrap() .unwrap();