use std::{fmt::Debug, ops::Add};
use sum_tree::{Dimension, Edit, Item, KeyedItem, SumTree, Summary};

/// An operation that can be buffered in an [`OperationQueue`].
///
/// The queue orders and keys operations solely by the value returned from
/// [`Operation::lamport_timestamp`].
pub trait Operation: Debug + Clone {
    /// Returns the Lamport timestamp used as this operation's key in the queue.
    fn lamport_timestamp(&self) -> clock::Lamport;
}

/// Private wrapper that lets an operation be stored as a `SumTree` item.
#[derive(Debug, Clone)]
struct OperationItem<T>(T);

/// A collection of operations keyed and ordered by Lamport timestamp,
/// backed by a [`SumTree`].
#[derive(Debug, Clone)]
pub struct OperationQueue<T: Operation>(SumTree<OperationItem<T>>);

/// Ordering key of a queued operation: its Lamport timestamp.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct OperationKey(clock::Lamport);

/// Aggregate over a contiguous run of queued operations.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct OperationSummary {
    /// Key of the last (greatest) operation covered by this summary.
    pub key: OperationKey,
    /// Number of operations covered by this summary.
    pub len: usize,
}

impl OperationKey {
    pub fn new(timestamp: clock::Lamport) -> Self {
        Self(timestamp)
    }
}

impl<T: Operation> Default for OperationQueue<T> {
    /// The default queue is empty; equivalent to [`OperationQueue::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Operation> OperationQueue<T> {
    /// Creates an empty queue.
    pub fn new() -> Self {
        OperationQueue(SumTree::new())
    }

    /// Returns the number of operations currently held in the queue.
    pub fn len(&self) -> usize {
        self.0.summary().len
    }

    /// Returns `true` if the queue holds no operations.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Inserts a batch of operations, keyed by their Lamport timestamp.
    ///
    /// The batch is sorted by timestamp (stable sort). Operations in the batch
    /// that share a timestamp are collapsed to the first one encountered. This
    /// is required because `OperationSummary::add_summary` asserts strictly
    /// increasing keys.
    ///
    /// NOTE(review): operations whose timestamp is already present in the
    /// queue are resolved by `SumTree::edit`. Presumably the new operation
    /// replaces the existing entry; confirm against `sum_tree`.
    pub fn insert(&mut self, mut ops: Vec<T>) {
        ops.sort_by_key(|op| op.lamport_timestamp());
        ops.dedup_by_key(|op| op.lamport_timestamp());
        self.0.edit(
            ops.into_iter()
                .map(|op| Edit::Insert(OperationItem(op)))
                .collect(),
            &(),
        );
    }

    /// Removes every operation from the queue and returns them as a new
    /// queue, leaving `self` empty.
    pub fn drain(&mut self) -> Self {
        // `OperationQueue` implements `Default` as an empty queue, so
        // `mem::take` moves the current tree out without cloning it.
        std::mem::take(self)
    }

    /// Iterates over the queued operations in ascending timestamp order.
    pub fn iter(&self) -> impl Iterator<Item = &T> {
        self.0.iter().map(|i| &i.0)
    }
}

impl Summary for OperationSummary {
    type Context = ();

    /// Folds `other`, which must cover strictly later operations, into `self`.
    ///
    /// # Panics
    ///
    /// Panics if `other.key` is not strictly greater than `self.key`.
    fn add_summary(&mut self, other: &Self, _cx: &Self::Context) {
        assert!(self.key < other.key);
        self.len += other.len;
        self.key = other.key;
    }
}

impl Add<&Self> for OperationSummary {
    type Output = Self;

    /// Combines two summaries, where `other` covers strictly later operations.
    ///
    /// # Panics
    ///
    /// Panics if `other.key` is not strictly greater than `self.key`.
    fn add(self, other: &Self) -> Self {
        assert!(self.key < other.key);
        let len = self.len + other.len;
        let key = other.key;
        OperationSummary { key, len }
    }
}

impl Dimension<'_, OperationSummary> for OperationKey {
    /// Advances this key to the last key covered by `summary`.
    ///
    /// # Panics
    ///
    /// Panics if `summary.key` is earlier than the current key.
    fn add_summary(&mut self, summary: &OperationSummary, _cx: &()) {
        assert!(*self <= summary.key);
        *self = summary.key;
    }
}

impl<T: Operation> Item for OperationItem<T> {
    type Summary = OperationSummary;

    /// Summarizes a single operation: its timestamp key and a count of one.
    fn summary(&self) -> Self::Summary {
        let key = OperationKey::new(self.0.lamport_timestamp());
        OperationSummary { key, len: 1 }
    }
}

impl<T: Operation> KeyedItem for OperationItem<T> {
    type Key = OperationKey;

    fn key(&self) -> Self::Key {
        OperationKey::new(self.0.lamport_timestamp())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_len() {
        let mut clock = clock::Lamport::new(0);

        let mut queue = OperationQueue::new();
        assert_eq!(queue.len(), 0);

        queue.insert(vec![
            TestOperation(clock.tick()),
            TestOperation(clock.tick()),
        ]);
        assert_eq!(queue.len(), 2);

        queue.insert(vec![TestOperation(clock.tick())]);
        assert_eq!(queue.len(), 3);

        drop(queue.drain());
        assert_eq!(queue.len(), 0);

        queue.insert(vec![TestOperation(clock.tick())]);
        assert_eq!(queue.len(), 1);
    }

    #[test]
    fn test_insert_sorts_and_dedups_within_batch() {
        let mut clock = clock::Lamport::new(0);
        let first = clock.tick();
        let second = clock.tick();
        let third = clock.tick();

        let mut queue = OperationQueue::new();
        // Out of order, with a repeated timestamp inside the same batch.
        queue.insert(vec![
            TestOperation(third),
            TestOperation(first),
            TestOperation(first),
        ]);
        assert_eq!(queue.len(), 2);

        // A later batch is merged into timestamp order.
        queue.insert(vec![TestOperation(second)]);
        assert_eq!(
            queue.iter().cloned().collect::<Vec<_>>(),
            vec![
                TestOperation(first),
                TestOperation(second),
                TestOperation(third),
            ]
        );
    }

    #[test]
    fn test_drain_returns_contents() {
        let mut clock = clock::Lamport::new(0);
        let first = clock.tick();
        let second = clock.tick();

        let mut queue = OperationQueue::new();
        queue.insert(vec![TestOperation(second), TestOperation(first)]);

        let drained = queue.drain();
        assert!(queue.is_empty());
        assert_eq!(drained.len(), 2);
        assert_eq!(
            drained.iter().cloned().collect::<Vec<_>>(),
            vec![TestOperation(first), TestOperation(second)]
        );
    }

    #[derive(Clone, Debug, Eq, PartialEq)]
    struct TestOperation(clock::Lamport);

    impl Operation for TestOperation {
        fn lamport_timestamp(&self) -> clock::Lamport {
            self.0
        }
    }
}
