main.rs

  1mod api;
  2mod auth;
  3mod db;
  4mod env;
  5mod rpc;
  6
  7#[cfg(test)]
  8mod db_tests;
  9#[cfg(test)]
 10mod integration_tests;
 11
 12use crate::rpc::ResultExt as _;
 13use anyhow::anyhow;
 14use axum::{routing::get, Router};
 15use collab::{Error, Result};
 16use db::DefaultDb as Db;
 17use serde::Deserialize;
 18use std::{
 19    env::args,
 20    net::{SocketAddr, TcpListener},
 21    path::{Path, PathBuf},
 22    sync::Arc,
 23    time::Duration,
 24};
 25use tokio::signal;
 26use tracing_log::LogTracer;
 27use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer};
 28use util::ResultExt;
 29
 30const VERSION: &'static str = env!("CARGO_PKG_VERSION");
 31
 32#[derive(Default, Deserialize)]
 33pub struct Config {
 34    pub http_port: u16,
 35    pub database_url: String,
 36    pub api_token: String,
 37    pub invite_link_prefix: String,
 38    pub live_kit_server: Option<String>,
 39    pub live_kit_key: Option<String>,
 40    pub live_kit_secret: Option<String>,
 41    pub rust_log: Option<String>,
 42    pub log_json: Option<bool>,
 43}
 44
 45#[derive(Default, Deserialize)]
 46pub struct MigrateConfig {
 47    pub database_url: String,
 48    pub migrations_path: Option<PathBuf>,
 49}
 50
 51pub struct AppState {
 52    db: Arc<Db>,
 53    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 54    config: Config,
 55}
 56
 57impl AppState {
 58    async fn new(config: Config) -> Result<Arc<Self>> {
 59        let db = Db::new(&config.database_url, 5).await?;
 60        let live_kit_client = if let Some(((server, key), secret)) = config
 61            .live_kit_server
 62            .as_ref()
 63            .zip(config.live_kit_key.as_ref())
 64            .zip(config.live_kit_secret.as_ref())
 65        {
 66            Some(Arc::new(live_kit_server::api::LiveKitClient::new(
 67                server.clone(),
 68                key.clone(),
 69                secret.clone(),
 70            )) as Arc<dyn live_kit_server::api::Client>)
 71        } else {
 72            None
 73        };
 74
 75        let this = Self {
 76            db: Arc::new(db),
 77            live_kit_client,
 78            config,
 79        };
 80        Ok(Arc::new(this))
 81    }
 82}
 83
 84#[tokio::main]
 85async fn main() -> Result<()> {
 86    if let Err(error) = env::load_dotenv() {
 87        eprintln!(
 88            "error loading .env.toml (this is expected in production): {}",
 89            error
 90        );
 91    }
 92
 93    match args().skip(1).next().as_deref() {
 94        Some("version") => {
 95            println!("collab v{VERSION}");
 96        }
 97        Some("migrate") => {
 98            let config = envy::from_env::<MigrateConfig>().expect("error loading config");
 99            let db = Db::new(&config.database_url, 5).await?;
100
101            let migrations_path = config
102                .migrations_path
103                .as_deref()
104                .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
105
106            let migrations = db.migrate(&migrations_path, false).await?;
107            for (migration, duration) in migrations {
108                println!(
109                    "Ran {} {} {:?}",
110                    migration.version, migration.description, duration
111                );
112            }
113
114            return Ok(());
115        }
116        Some("serve") => {
117            let config = envy::from_env::<Config>().expect("error loading config");
118            init_tracing(&config);
119
120            let state = AppState::new(config).await?;
121            let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
122                .expect("failed to bind TCP listener");
123
124            let rpc_server = rpc::Server::new(state.clone());
125
126            let app = api::routes(rpc_server.clone(), state.clone())
127                .merge(rpc::routes(rpc_server.clone()))
128                .merge(Router::new().route("/", get(handle_root)));
129
130            axum::Server::from_tcp(listener)?
131                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
132                .with_graceful_shutdown(graceful_shutdown(rpc_server, state))
133                .await?;
134        }
135        _ => {
136            Err(anyhow!("usage: collab <version | migrate | serve>"))?;
137        }
138    }
139    Ok(())
140}
141
142async fn handle_root() -> String {
143    format!("collab v{VERSION}")
144}
145
146pub fn init_tracing(config: &Config) -> Option<()> {
147    use std::str::FromStr;
148    use tracing_subscriber::layer::SubscriberExt;
149    let rust_log = config.rust_log.clone()?;
150
151    LogTracer::init().log_err()?;
152
153    let subscriber = tracing_subscriber::Registry::default()
154        .with(if config.log_json.unwrap_or(false) {
155            Box::new(
156                tracing_subscriber::fmt::layer()
157                    .fmt_fields(JsonFields::default())
158                    .event_format(
159                        tracing_subscriber::fmt::format()
160                            .json()
161                            .flatten_event(true)
162                            .with_span_list(true),
163                    ),
164            ) as Box<dyn Layer<_> + Send + Sync>
165        } else {
166            Box::new(
167                tracing_subscriber::fmt::layer()
168                    .event_format(tracing_subscriber::fmt::format().pretty()),
169            )
170        })
171        .with(EnvFilter::from_str(rust_log.as_str()).log_err()?);
172
173    tracing::subscriber::set_global_default(subscriber).unwrap();
174
175    None
176}
177
178async fn graceful_shutdown(rpc_server: Arc<rpc::Server>, state: Arc<AppState>) {
179    let ctrl_c = async {
180        signal::ctrl_c()
181            .await
182            .expect("failed to install Ctrl+C handler");
183    };
184
185    #[cfg(unix)]
186    let terminate = async {
187        signal::unix::signal(signal::unix::SignalKind::terminate())
188            .expect("failed to install signal handler")
189            .recv()
190            .await;
191    };
192
193    #[cfg(not(unix))]
194    let terminate = std::future::pending::<()>();
195
196    tokio::select! {
197        _ = ctrl_c => {},
198        _ = terminate => {},
199    }
200
201    if let Some(live_kit) = state.live_kit_client.as_ref() {
202        let deletions = rpc_server
203            .store()
204            .await
205            .rooms()
206            .values()
207            .map(|room| {
208                let name = room.live_kit_room.clone();
209                async {
210                    live_kit.delete_room(name).await.trace_err();
211                }
212            })
213            .collect::<Vec<_>>();
214
215        tracing::info!("deleting all live-kit rooms");
216        if let Err(_) = tokio::time::timeout(
217            Duration::from_secs(10),
218            futures::future::join_all(deletions),
219        )
220        .await
221        {
222            tracing::error!("timed out waiting for live-kit room deletion");
223        }
224    }
225}