main.rs

  1mod api;
  2mod auth;
  3mod db;
  4mod db2;
  5mod env;
  6mod rpc;
  7
  8#[cfg(test)]
  9mod integration_tests;
 10
 11use anyhow::anyhow;
 12use axum::{routing::get, Router};
 13use collab::{Error, Result};
 14use db::DefaultDb as Db;
 15use serde::Deserialize;
 16use std::{
 17    env::args,
 18    net::{SocketAddr, TcpListener},
 19    path::{Path, PathBuf},
 20    sync::Arc,
 21};
 22use tracing_log::LogTracer;
 23use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer};
 24use util::ResultExt;
 25
 26const VERSION: &'static str = env!("CARGO_PKG_VERSION");
 27
 28#[derive(Default, Deserialize)]
 29pub struct Config {
 30    pub http_port: u16,
 31    pub database_url: String,
 32    pub api_token: String,
 33    pub invite_link_prefix: String,
 34    pub live_kit_server: Option<String>,
 35    pub live_kit_key: Option<String>,
 36    pub live_kit_secret: Option<String>,
 37    pub rust_log: Option<String>,
 38    pub log_json: Option<bool>,
 39}
 40
 41#[derive(Default, Deserialize)]
 42pub struct MigrateConfig {
 43    pub database_url: String,
 44    pub migrations_path: Option<PathBuf>,
 45}
 46
 47pub struct AppState {
 48    db: Arc<Db>,
 49    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 50    config: Config,
 51}
 52
 53impl AppState {
 54    async fn new(config: Config) -> Result<Arc<Self>> {
 55        let db = Db::new(&config.database_url, 5).await?;
 56        let live_kit_client = if let Some(((server, key), secret)) = config
 57            .live_kit_server
 58            .as_ref()
 59            .zip(config.live_kit_key.as_ref())
 60            .zip(config.live_kit_secret.as_ref())
 61        {
 62            Some(Arc::new(live_kit_server::api::LiveKitClient::new(
 63                server.clone(),
 64                key.clone(),
 65                secret.clone(),
 66            )) as Arc<dyn live_kit_server::api::Client>)
 67        } else {
 68            None
 69        };
 70
 71        let this = Self {
 72            db: Arc::new(db),
 73            live_kit_client,
 74            config,
 75        };
 76        Ok(Arc::new(this))
 77    }
 78}
 79
 80#[tokio::main]
 81async fn main() -> Result<()> {
 82    if let Err(error) = env::load_dotenv() {
 83        eprintln!(
 84            "error loading .env.toml (this is expected in production): {}",
 85            error
 86        );
 87    }
 88
 89    match args().skip(1).next().as_deref() {
 90        Some("version") => {
 91            println!("collab v{VERSION}");
 92        }
 93        Some("migrate") => {
 94            let config = envy::from_env::<MigrateConfig>().expect("error loading config");
 95            let db = Db::new(&config.database_url, 5).await?;
 96
 97            let migrations_path = config
 98                .migrations_path
 99                .as_deref()
100                .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
101
102            let migrations = db.migrate(&migrations_path, false).await?;
103            for (migration, duration) in migrations {
104                println!(
105                    "Ran {} {} {:?}",
106                    migration.version, migration.description, duration
107                );
108            }
109
110            return Ok(());
111        }
112        Some("serve") => {
113            let config = envy::from_env::<Config>().expect("error loading config");
114            init_tracing(&config);
115
116            let state = AppState::new(config).await?;
117            let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
118                .expect("failed to bind TCP listener");
119
120            let rpc_server = rpc::Server::new(state.clone());
121
122            let app = api::routes(rpc_server.clone(), state.clone())
123                .merge(rpc::routes(rpc_server.clone()))
124                .merge(Router::new().route("/", get(handle_root)));
125
126            axum::Server::from_tcp(listener)?
127                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
128                .await?;
129        }
130        _ => {
131            Err(anyhow!("usage: collab <version | migrate | serve>"))?;
132        }
133    }
134    Ok(())
135}
136
137async fn handle_root() -> String {
138    format!("collab v{VERSION}")
139}
140
141pub fn init_tracing(config: &Config) -> Option<()> {
142    use std::str::FromStr;
143    use tracing_subscriber::layer::SubscriberExt;
144    let rust_log = config.rust_log.clone()?;
145
146    LogTracer::init().log_err()?;
147
148    let subscriber = tracing_subscriber::Registry::default()
149        .with(if config.log_json.unwrap_or(false) {
150            Box::new(
151                tracing_subscriber::fmt::layer()
152                    .fmt_fields(JsonFields::default())
153                    .event_format(
154                        tracing_subscriber::fmt::format()
155                            .json()
156                            .flatten_event(true)
157                            .with_span_list(true),
158                    ),
159            ) as Box<dyn Layer<_> + Send + Sync>
160        } else {
161            Box::new(
162                tracing_subscriber::fmt::layer()
163                    .event_format(tracing_subscriber::fmt::format().pretty()),
164            )
165        })
166        .with(EnvFilter::from_str(rust_log.as_str()).log_err()?);
167
168    tracing::subscriber::set_global_default(subscriber).unwrap();
169
170    None
171}