Make scoring more precise by using floats when diffing AI refactors

Antonio Scandurra created

Change summary

Cargo.lock            |  1 +
crates/ai/Cargo.toml  |  1 +
crates/ai/src/diff.rs | 30 +++++++++++++++++-------------
3 files changed, 19 insertions(+), 13 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -113,6 +113,7 @@ dependencies = [
  "language",
  "log",
  "menu",
+ "ordered-float",
  "project",
  "rand 0.8.5",
  "regex",

crates/ai/Cargo.toml 🔗

@@ -26,6 +26,7 @@ chrono = { version = "0.4", features = ["serde"] }
 futures.workspace = true
 indoc.workspace = true
 isahc.workspace = true
+ordered-float.workspace = true
 regex.workspace = true
 schemars.workspace = true
 serde.workspace = true

crates/ai/src/diff.rs 🔗

@@ -1,11 +1,13 @@
 use collections::HashMap;
+use ordered_float::OrderedFloat;
 use std::{
+    cmp,
     fmt::{self, Debug},
     ops::Range,
 };
 
 struct Matrix {
-    cells: Vec<isize>,
+    cells: Vec<f64>,
     rows: usize,
     cols: usize,
 }
@@ -20,12 +22,12 @@ impl Matrix {
     }
 
     fn resize(&mut self, rows: usize, cols: usize) {
-        self.cells.resize(rows * cols, 0);
+        self.cells.resize(rows * cols, 0.);
         self.rows = rows;
         self.cols = cols;
     }
 
-    fn get(&self, row: usize, col: usize) -> isize {
+    fn get(&self, row: usize, col: usize) -> f64 {
         if row >= self.rows {
             panic!("row out of bounds")
         }
@@ -36,7 +38,7 @@ impl Matrix {
         self.cells[col * self.rows + row]
     }
 
-    fn set(&mut self, row: usize, col: usize, value: isize) {
+    fn set(&mut self, row: usize, col: usize, value: f64) {
         if row >= self.rows {
             panic!("row out of bounds")
         }
@@ -79,16 +81,17 @@ pub struct Diff {
 }
 
 impl Diff {
-    const INSERTION_SCORE: isize = -1;
-    const DELETION_SCORE: isize = -5;
-    const EQUALITY_BASE: isize = 2;
+    const INSERTION_SCORE: f64 = -1.;
+    const DELETION_SCORE: f64 = -5.;
+    const EQUALITY_BASE: f64 = 1.618;
+    const MAX_EQUALITY_EXPONENT: i32 = 32;
 
     pub fn new(old: String) -> Self {
         let old = old.chars().collect::<Vec<_>>();
         let mut scores = Matrix::new();
         scores.resize(old.len() + 1, 1);
         for i in 0..=old.len() {
-            scores.set(i, 0, i as isize * Self::DELETION_SCORE);
+            scores.set(i, 0, i as f64 * Self::DELETION_SCORE);
         }
         Self {
             old,
@@ -105,7 +108,7 @@ impl Diff {
         self.scores.resize(self.old.len() + 1, self.new.len() + 1);
 
         for j in self.new_text_ix + 1..=self.new.len() {
-            self.scores.set(0, j, j as isize * Self::INSERTION_SCORE);
+            self.scores.set(0, j, j as f64 * Self::INSERTION_SCORE);
             for i in 1..=self.old.len() {
                 let insertion_score = self.scores.get(i, j - 1) + Self::INSERTION_SCORE;
                 let deletion_score = self.scores.get(i - 1, j) + Self::DELETION_SCORE;
@@ -117,10 +120,11 @@ impl Diff {
                     if self.old[i - 1] == ' ' {
                         self.scores.get(i - 1, j - 1)
                     } else {
-                        self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.pow(equal_run / 3)
+                        let exponent = cmp::min(equal_run as i32 / 3, Self::MAX_EQUALITY_EXPONENT);
+                        self.scores.get(i - 1, j - 1) + Self::EQUALITY_BASE.powi(exponent)
                     }
                 } else {
-                    isize::MIN
+                    f64::NEG_INFINITY
                 };
 
                 let score = insertion_score.max(deletion_score).max(equality_score);
@@ -128,7 +132,7 @@ impl Diff {
             }
         }
 
-        let mut max_score = isize::MIN;
+        let mut max_score = f64::NEG_INFINITY;
         let mut next_old_text_ix = self.old_text_ix;
         let next_new_text_ix = self.new.len();
         for i in self.old_text_ix..=self.old.len() {
@@ -173,7 +177,7 @@ impl Diff {
 
             let (prev_i, prev_j) = [insertion_score, deletion_score, equality_score]
                 .iter()
-                .max_by_key(|cell| cell.map(|(i, j)| self.scores.get(i, j)))
+                .max_by_key(|cell| cell.map(|(i, j)| OrderedFloat(self.scores.get(i, j))))
                 .unwrap()
                 .unwrap();