db.rs

 1pub mod kvp;
 2pub mod workspace;
 3
 4use std::fs;
 5use std::ops::Deref;
 6use std::path::Path;
 7
 8use anyhow::Result;
 9use indoc::indoc;
10use sqlez::connection::Connection;
11use sqlez::domain::Domain;
12use sqlez::thread_safe_connection::ThreadSafeConnection;
13
14pub use workspace::*;
15
16const INITIALIZE_QUERY: &'static str = indoc! {"
17    PRAGMA journal_mode=WAL;
18    PRAGMA synchronous=NORMAL;
19    PRAGMA foreign_keys=TRUE;
20    PRAGMA case_sensitive_like=TRUE;
21"};
22
23#[derive(Clone)]
24pub struct Db<D: Domain>(ThreadSafeConnection<D>);
25
26impl<D: Domain> Deref for Db<D> {
27    type Target = sqlez::connection::Connection;
28
29    fn deref(&self) -> &Self::Target {
30        &self.0.deref()
31    }
32}
33
34impl<D: Domain> Db<D> {
35    /// Open or create a database at the given directory path.
36    pub fn open(db_dir: &Path, channel: &'static str) -> Self {
37        // Use 0 for now. Will implement incrementing and clearing of old db files soon TM
38        let current_db_dir = db_dir.join(Path::new(&format!("0-{}", channel)));
39        fs::create_dir_all(&current_db_dir)
40            .expect("Should be able to create the database directory");
41        let db_path = current_db_dir.join(Path::new("db.sqlite"));
42
43        Db(
44            ThreadSafeConnection::new(db_path.to_string_lossy().as_ref(), true)
45                .with_initialize_query(INITIALIZE_QUERY),
46        )
47    }
48
49    /// Open a in memory database for testing and as a fallback.
50    pub fn open_in_memory(db_name: &str) -> Self {
51        Db(ThreadSafeConnection::new(db_name, false).with_initialize_query(INITIALIZE_QUERY))
52    }
53
54    pub fn persisting(&self) -> bool {
55        self.persistent()
56    }
57
58    pub fn write_to<P: AsRef<Path>>(&self, dest: P) -> Result<()> {
59        let destination = Connection::open_file(dest.as_ref().to_string_lossy().as_ref());
60        self.backup_main(&destination)
61    }
62
63    pub fn open_as<D2: Domain>(&self) -> Db<D2> {
64        Db(self.0.for_domain())
65    }
66}