main.rs

  1mod api;
  2mod auth;
  3mod db;
  4mod env;
  5mod rpc;
  6
  7#[cfg(test)]
  8mod db_tests;
  9#[cfg(test)]
 10mod integration_tests;
 11
 12use anyhow::anyhow;
 13use axum::{routing::get, Router};
 14use collab::{Error, Result};
 15use db::DefaultDb as Db;
 16use serde::Deserialize;
 17use std::{
 18    env::args,
 19    net::{SocketAddr, TcpListener},
 20    path::{Path, PathBuf},
 21    sync::Arc,
 22};
 23use tracing_log::LogTracer;
 24use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer};
 25use util::ResultExt;
 26
 27const VERSION: &'static str = env!("CARGO_PKG_VERSION");
 28
 29#[derive(Default, Deserialize)]
 30pub struct Config {
 31    pub http_port: u16,
 32    pub database_url: String,
 33    pub api_token: String,
 34    pub invite_link_prefix: String,
 35    pub live_kit_server: Option<String>,
 36    pub live_kit_key: Option<String>,
 37    pub live_kit_secret: Option<String>,
 38    pub rust_log: Option<String>,
 39    pub log_json: Option<bool>,
 40}
 41
 42#[derive(Default, Deserialize)]
 43pub struct MigrateConfig {
 44    pub database_url: String,
 45    pub migrations_path: Option<PathBuf>,
 46}
 47
 48pub struct AppState {
 49    db: Arc<Db>,
 50    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 51    config: Config,
 52}
 53
 54impl AppState {
 55    async fn new(config: Config) -> Result<Arc<Self>> {
 56        let db = Db::new(&config.database_url, 5).await?;
 57        let live_kit_client = if let Some(((server, key), secret)) = config
 58            .live_kit_server
 59            .as_ref()
 60            .zip(config.live_kit_key.as_ref())
 61            .zip(config.live_kit_secret.as_ref())
 62        {
 63            Some(Arc::new(live_kit_server::api::LiveKitClient::new(
 64                server.clone(),
 65                key.clone(),
 66                secret.clone(),
 67            )) as Arc<dyn live_kit_server::api::Client>)
 68        } else {
 69            None
 70        };
 71
 72        let this = Self {
 73            db: Arc::new(db),
 74            live_kit_client,
 75            config,
 76        };
 77        Ok(Arc::new(this))
 78    }
 79}
 80
 81#[tokio::main]
 82async fn main() -> Result<()> {
 83    if let Err(error) = env::load_dotenv() {
 84        eprintln!(
 85            "error loading .env.toml (this is expected in production): {}",
 86            error
 87        );
 88    }
 89
 90    match args().skip(1).next().as_deref() {
 91        Some("version") => {
 92            println!("collab v{VERSION}");
 93        }
 94        Some("migrate") => {
 95            let config = envy::from_env::<MigrateConfig>().expect("error loading config");
 96            let db = Db::new(&config.database_url, 5).await?;
 97
 98            let migrations_path = config
 99                .migrations_path
100                .as_deref()
101                .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
102
103            let migrations = db.migrate(&migrations_path, false).await?;
104            for (migration, duration) in migrations {
105                println!(
106                    "Ran {} {} {:?}",
107                    migration.version, migration.description, duration
108                );
109            }
110
111            return Ok(());
112        }
113        Some("serve") => {
114            let config = envy::from_env::<Config>().expect("error loading config");
115            init_tracing(&config);
116
117            let state = AppState::new(config).await?;
118            let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
119                .expect("failed to bind TCP listener");
120
121            let rpc_server = rpc::Server::new(state.clone());
122
123            let app = api::routes(rpc_server.clone(), state.clone())
124                .merge(rpc::routes(rpc_server.clone()))
125                .merge(Router::new().route("/", get(handle_root)));
126
127            axum::Server::from_tcp(listener)?
128                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
129                .await?;
130        }
131        _ => {
132            Err(anyhow!("usage: collab <version | migrate | serve>"))?;
133        }
134    }
135    Ok(())
136}
137
138async fn handle_root() -> String {
139    format!("collab v{VERSION}")
140}
141
142pub fn init_tracing(config: &Config) -> Option<()> {
143    use std::str::FromStr;
144    use tracing_subscriber::layer::SubscriberExt;
145    let rust_log = config.rust_log.clone()?;
146
147    LogTracer::init().log_err()?;
148
149    let subscriber = tracing_subscriber::Registry::default()
150        .with(if config.log_json.unwrap_or(false) {
151            Box::new(
152                tracing_subscriber::fmt::layer()
153                    .fmt_fields(JsonFields::default())
154                    .event_format(
155                        tracing_subscriber::fmt::format()
156                            .json()
157                            .flatten_event(true)
158                            .with_span_list(true),
159                    ),
160            ) as Box<dyn Layer<_> + Send + Sync>
161        } else {
162            Box::new(
163                tracing_subscriber::fmt::layer()
164                    .event_format(tracing_subscriber::fmt::format().pretty()),
165            )
166        })
167        .with(EnvFilter::from_str(rust_log.as_str()).log_err()?);
168
169    tracing::subscriber::set_global_default(subscriber).unwrap();
170
171    None
172}