sliding_window.rs

  1use std::collections::VecDeque;
  2use std::fmt::Debug;
  3use util::debug_panic;
  4
  5use crate::{HashFrom, Occurrences};
  6
  7#[derive(Debug)]
  8pub struct SlidingWindow<D, T, S> {
  9    target: T,
 10    intersection: Occurrences<S>,
 11    regions: VecDeque<WeightedOverlapRegion<D, S>>,
 12    numerator: u32,
 13    window_count: u32,
 14    jaccard_denominator_part: u32,
 15}
 16
 17#[derive(Debug)]
 18struct WeightedOverlapRegion<D, S> {
 19    data: D,
 20    added_hashes: Vec<AddedHash<S>>,
 21    window_count_delta: u32,
 22}
 23
 24#[derive(Debug)]
 25struct AddedHash<S> {
 26    hash: HashFrom<S>,
 27    target_count: u32,
 28}
 29
 30impl<D, T: AsRef<Occurrences<S>>, S> SlidingWindow<D, T, S> {
 31    pub fn new(target: T) -> Self {
 32        Self::with_capacity(target, 0)
 33    }
 34
 35    pub fn with_capacity(target: T, capacity: usize) -> Self {
 36        let jaccard_denominator_part = target.as_ref().len();
 37        Self {
 38            target,
 39            intersection: Occurrences::default().into(),
 40            regions: VecDeque::with_capacity(capacity),
 41            numerator: 0,
 42            window_count: 0,
 43            jaccard_denominator_part,
 44        }
 45    }
 46
 47    pub fn clear(&mut self) {
 48        self.intersection.clear();
 49        self.regions.clear();
 50        self.numerator = 0;
 51        self.window_count = 0;
 52        self.jaccard_denominator_part = self.target.as_ref().len();
 53    }
 54
 55    pub fn push_back(&mut self, data: D, hashes: impl IntoIterator<Item = HashFrom<S>>) {
 56        let mut added_hashes = Vec::new();
 57        let mut window_count_delta = 0;
 58        for hash in hashes {
 59            window_count_delta += 1;
 60            let target_count = self.target.as_ref().get_count(hash);
 61            if target_count > 0 {
 62                added_hashes.push(AddedHash { hash, target_count });
 63                let window_hash_count = self.intersection.add_hash(hash);
 64                if window_hash_count <= target_count {
 65                    self.numerator += 1;
 66                } else {
 67                    self.jaccard_denominator_part += 1;
 68                }
 69            }
 70        }
 71        self.window_count += window_count_delta;
 72        self.regions.push_back(WeightedOverlapRegion {
 73            data,
 74            added_hashes,
 75            window_count_delta,
 76        });
 77    }
 78
 79    pub fn pop_front(&mut self) -> D {
 80        let removed = self
 81            .regions
 82            .pop_front()
 83            .expect("No sliding window region to remove");
 84
 85        for AddedHash { hash, target_count } in removed.added_hashes {
 86            let window_hash_count = self.intersection.remove_hash(hash);
 87            if window_hash_count < target_count {
 88                if let Some(numerator) = self.numerator.checked_sub(1) {
 89                    self.numerator = numerator;
 90                } else {
 91                    debug_panic!("bug: underflow in sliding window text similarity");
 92                }
 93            } else {
 94                if let Some(jaccard_denominator_part) = self.jaccard_denominator_part.checked_sub(1)
 95                {
 96                    self.jaccard_denominator_part = jaccard_denominator_part;
 97                } else {
 98                    debug_panic!("bug: underflow in sliding window text similarity");
 99                }
100            }
101        }
102
103        if let Some(window_count) = self.window_count.checked_sub(removed.window_count_delta) {
104            self.window_count = window_count;
105        } else {
106            debug_panic!("bug: underflow in sliding window text similarity");
107        }
108
109        removed.data
110    }
111
112    pub fn weighted_overlap_coefficient(&self) -> f32 {
113        let denominator = self.target.as_ref().len().min(self.window_count);
114        if denominator == 0 {
115            0.0
116        } else {
117            self.numerator as f32 / denominator as f32
118        }
119    }
120
121    pub fn weighted_jaccard_similarity(&self) -> f32 {
122        let mut denominator = self.jaccard_denominator_part;
123        if let Some(other_denominator_part) = self.window_count.checked_sub(self.intersection.len())
124        {
125            denominator += other_denominator_part;
126        } else {
127            debug_panic!("bug: underflow in sliding window text similarity");
128        }
129        if denominator == 0 {
130            0.0
131        } else {
132            self.numerator as f32 / denominator as f32
133        }
134    }
135}
136
#[cfg(test)]
mod test {
    use super::*;
    use crate::{IdentifierParts, OccurrenceSource, Occurrences, WeightedSimilarity};

    /// Pushes and pops regions whose identifiers all occur in the target. This includes
    /// repeated identifiers whose window count exceeds their target count.
    #[test]
    fn test_sliding_window() {
        let target = Occurrences::new(IdentifierParts::occurrences_in_str("a b c d"));
        let mut checked_window = CheckedSlidingWindow::new(target);

        checked_window.push_back("a");
        checked_window.pop_front();

        checked_window.push_back("a b");
        checked_window.push_back("a");
        checked_window.pop_front();
        checked_window.pop_front();

        checked_window.push_back("a b");
        checked_window.push_back("a b c");
        checked_window.pop_front();
        checked_window.push_back("a b c d");
        checked_window.pop_front();
        checked_window.pop_front();

        checked_window.clear();
        checked_window.push_back("d d d");
        checked_window.pop_front();
    }

    /// Pushes and pops regions containing identifiers absent from the target. Such
    /// identifiers only grow the window size, and therefore both similarity denominators,
    /// without ever contributing to the numerator.
    #[test]
    fn test_sliding_window_with_identifiers_outside_target() {
        let target = Occurrences::new(IdentifierParts::occurrences_in_str("a b c d"));
        let mut checked_window = CheckedSlidingWindow::new(target);

        checked_window.push_back("x");
        checked_window.push_back("a x y");
        checked_window.pop_front();
        checked_window.push_back("b b x");
        checked_window.push_back("c d z z");
        checked_window.pop_front();
        checked_window.pop_front();
        checked_window.pop_front();

        checked_window.push_back("y");
        checked_window.clear();
        checked_window.push_back("a b c d e f");
        checked_window.pop_front();
    }

    /// Wraps a [`SlidingWindow`] and, after every mutation, checks its incrementally
    /// maintained similarities against a from-scratch computation over the same text.
    #[derive(Debug)]
    struct CheckedSlidingWindow {
        inner: SlidingWindow<u32, Occurrences<IdentifierParts>, IdentifierParts>,
        /// The window's lines, each terminated by `'\n'`.
        text: String,
        /// Line number that the next `pop_front` is expected to return.
        first_line: u32,
        /// Line number that the next `push_back` will attach as region data.
        last_line: u32,
    }

    impl CheckedSlidingWindow {
        fn new(target: Occurrences<IdentifierParts>) -> Self {
            CheckedSlidingWindow {
                inner: SlidingWindow::new(target),
                text: String::new(),
                first_line: 0,
                last_line: 0,
            }
        }

        #[track_caller]
        fn clear(&mut self) {
            self.inner.clear();
            self.text.clear();
            self.first_line = 0;
            self.last_line = 0;
            self.check_after_mutation();
        }

        #[track_caller]
        fn push_back(&mut self, line: &str) {
            self.inner
                .push_back(self.last_line, IdentifierParts::occurrences_in_str(line));
            self.text.push_str(line);
            self.text.push('\n');
            self.last_line += 1;
            self.check_after_mutation();
        }

        #[track_caller]
        fn pop_front(&mut self) {
            assert_eq!(self.inner.pop_front(), self.first_line);
            let first_line_end = self
                .text
                .find('\n')
                .expect("every pushed line is terminated by a newline");
            self.text.drain(..=first_line_end);
            self.first_line += 1;
            self.check_after_mutation();
        }

        #[track_caller]
        fn check_after_mutation(&self) {
            assert_eq!(
                self.inner.weighted_overlap_coefficient(),
                Occurrences::new(IdentifierParts::occurrences_in_str(&self.text))
                    .weighted_overlap_coefficient(&self.inner.target),
                "weighted_overlap_coefficient"
            );
            assert_eq!(
                self.inner.weighted_jaccard_similarity(),
                Occurrences::new(IdentifierParts::occurrences_in_str(&self.text))
                    .weighted_jaccard_similarity(&self.inner.target),
                "weighted_jaccard_similarity"
            );
        }
    }
}