main.rs

  1mod api;
  2mod auth;
  3mod db;
  4mod env;
  5mod rpc;
  6
  7#[cfg(test)]
  8mod db_tests;
  9#[cfg(test)]
 10mod integration_tests;
 11
 12use anyhow::anyhow;
 13use axum::{routing::get, Router};
 14use collab::{Error, Result};
 15use db::{Db, PostgresDb};
 16use serde::Deserialize;
 17use std::{
 18    env::args,
 19    net::{SocketAddr, TcpListener},
 20    path::PathBuf,
 21    sync::Arc,
 22    time::Duration,
 23};
 24use tracing_log::LogTracer;
 25use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer};
 26use util::ResultExt;
 27
 28const VERSION: &'static str = env!("CARGO_PKG_VERSION");
 29
/// Configuration for the `serve` subcommand.
///
/// Deserialized from environment variables with `envy::from_env` (for
/// example `HTTP_PORT` populates `http_port`).
#[derive(Default, Deserialize)]
pub struct Config {
    /// TCP port the server binds to (on `0.0.0.0`).
    pub http_port: u16,
    /// Connection URL passed to `PostgresDb::new`.
    pub database_url: String,
    /// API token; not read in this file — presumably consumed by the `api`
    /// module to authenticate requests (TODO confirm).
    pub api_token: String,
    /// Prefix for invite links; not read in this file — presumably used when
    /// generating invite URLs elsewhere (TODO confirm).
    pub invite_link_prefix: String,
    /// `tracing` filter directives. When `None`, `init_tracing` installs no
    /// subscriber.
    pub rust_log: Option<String>,
    /// When `Some(true)`, logs are formatted as JSON; otherwise they use the
    /// pretty text format.
    pub log_json: Option<bool>,
}
 39
/// Configuration for the `migrate` subcommand, deserialized from environment
/// variables with `envy::from_env`.
#[derive(Default, Deserialize)]
pub struct MigrateConfig {
    /// Connection URL of the database to migrate, passed to `PostgresDb::new`.
    pub database_url: String,
    /// Location of the migrations handed to `Db::migrate` (presumably a
    /// directory — confirm). When unset, `db::DEFAULT_MIGRATIONS_PATH` is used
    /// if it is available; otherwise the command fails.
    pub migrations_path: Option<PathBuf>,
}
 45
/// State shared (behind an `Arc`) by the HTTP API routes and the RPC server.
pub struct AppState {
    /// Database handle; the `serve` command stores a `PostgresDb` here.
    db: Arc<dyn Db>,
    /// Configuration the server was started with.
    config: Config,
}
 50
 51#[tokio::main]
 52async fn main() -> Result<()> {
 53    if let Err(error) = env::load_dotenv() {
 54        eprintln!(
 55            "error loading .env.toml (this is expected in production): {}",
 56            error
 57        );
 58    }
 59
 60    match args().skip(1).next().as_deref() {
 61        Some("version") => {
 62            println!("collab v{VERSION}");
 63        }
 64        Some("migrate") => {
 65            let config = envy::from_env::<MigrateConfig>().expect("error loading config");
 66            let db = PostgresDb::new(&config.database_url, 5).await?;
 67
 68            let migrations_path = config
 69                .migrations_path
 70                .as_deref()
 71                .or(db::DEFAULT_MIGRATIONS_PATH.map(|s| s.as_ref()))
 72                .ok_or_else(|| anyhow!("missing MIGRATIONS_PATH environment variable"))?;
 73
 74            let migrations = db.migrate(&migrations_path, false).await?;
 75            for (migration, duration) in migrations {
 76                println!(
 77                    "Ran {} {} {:?}",
 78                    migration.version, migration.description, duration
 79                );
 80            }
 81
 82            return Ok(());
 83        }
 84        Some("serve") => {
 85            let config = envy::from_env::<Config>().expect("error loading config");
 86            let db = PostgresDb::new(&config.database_url, 5).await?;
 87
 88            init_tracing(&config);
 89            let state = Arc::new(AppState {
 90                db: Arc::new(db),
 91                config,
 92            });
 93
 94            let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
 95                .expect("failed to bind TCP listener");
 96
 97            let rpc_server = rpc::Server::new(state.clone(), None);
 98            rpc_server
 99                .start_recording_project_activity(Duration::from_secs(5 * 60), rpc::RealExecutor);
100
101            let app = api::routes(&rpc_server, state.clone())
102                .merge(rpc::routes(rpc_server))
103                .merge(Router::new().route("/", get(handle_root)));
104
105            axum::Server::from_tcp(listener)?
106                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
107                .await?;
108        }
109        _ => {
110            Err(anyhow!("usage: collab <version | migrate | serve>"))?;
111        }
112    }
113    Ok(())
114}
115
116async fn handle_root() -> String {
117    format!("collab v{VERSION}")
118}
119
120pub fn init_tracing(config: &Config) -> Option<()> {
121    use std::str::FromStr;
122    use tracing_subscriber::layer::SubscriberExt;
123    let rust_log = config.rust_log.clone()?;
124
125    LogTracer::init().log_err()?;
126
127    let subscriber = tracing_subscriber::Registry::default()
128        .with(if config.log_json.unwrap_or(false) {
129            Box::new(
130                tracing_subscriber::fmt::layer()
131                    .fmt_fields(JsonFields::default())
132                    .event_format(
133                        tracing_subscriber::fmt::format()
134                            .json()
135                            .flatten_event(true)
136                            .with_span_list(true),
137                    ),
138            ) as Box<dyn Layer<_> + Send + Sync>
139        } else {
140            Box::new(
141                tracing_subscriber::fmt::layer()
142                    .event_format(tracing_subscriber::fmt::format().pretty()),
143            )
144        })
145        .with(EnvFilter::from_str(rust_log.as_str()).log_err()?);
146
147    tracing::subscriber::set_global_default(subscriber).unwrap();
148
149    None
150}