main.rs

  1mod api;
  2mod auth;
  3mod db;
  4mod env;
  5mod rpc;
  6
  7#[cfg(test)]
  8mod integration_tests;
  9
 10use anyhow::anyhow;
 11use axum::{routing::get, Router};
 12use collab::{Error, Result};
 13use db::Database;
 14use serde::Deserialize;
 15use std::{
 16    env::args,
 17    net::{SocketAddr, TcpListener},
 18    path::{Path, PathBuf},
 19    sync::Arc,
 20};
 21use tracing_log::LogTracer;
 22use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer};
 23use util::ResultExt;
 24
 25const VERSION: &'static str = env!("CARGO_PKG_VERSION");
 26
 27#[derive(Default, Deserialize)]
 28pub struct Config {
 29    pub http_port: u16,
 30    pub database_url: String,
 31    pub api_token: String,
 32    pub invite_link_prefix: String,
 33    pub live_kit_server: Option<String>,
 34    pub live_kit_key: Option<String>,
 35    pub live_kit_secret: Option<String>,
 36    pub rust_log: Option<String>,
 37    pub log_json: Option<bool>,
 38}
 39
 40#[derive(Default, Deserialize)]
 41pub struct MigrateConfig {
 42    pub database_url: String,
 43    pub migrations_path: Option<PathBuf>,
 44}
 45
 46pub struct AppState {
 47    db: Arc<Database>,
 48    live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
 49    config: Config,
 50}
 51
 52impl AppState {
 53    async fn new(config: Config) -> Result<Arc<Self>> {
 54        let mut db_options = db::ConnectOptions::new(config.database_url.clone());
 55        db_options.max_connections(5);
 56        let db = Database::new(db_options).await?;
 57        let live_kit_client = if let Some(((server, key), secret)) = config
 58            .live_kit_server
 59            .as_ref()
 60            .zip(config.live_kit_key.as_ref())
 61            .zip(config.live_kit_secret.as_ref())
 62        {
 63            Some(Arc::new(live_kit_server::api::LiveKitClient::new(
 64                server.clone(),
 65                key.clone(),
 66                secret.clone(),
 67            )) as Arc<dyn live_kit_server::api::Client>)
 68        } else {
 69            None
 70        };
 71
 72        let this = Self {
 73            db: Arc::new(db),
 74            live_kit_client,
 75            config,
 76        };
 77        Ok(Arc::new(this))
 78    }
 79}
 80
 81#[tokio::main]
 82async fn main() -> Result<()> {
 83    if let Err(error) = env::load_dotenv() {
 84        eprintln!(
 85            "error loading .env.toml (this is expected in production): {}",
 86            error
 87        );
 88    }
 89
 90    match args().skip(1).next().as_deref() {
 91        Some("version") => {
 92            println!("collab v{VERSION}");
 93        }
 94        Some("migrate") => {
 95            let config = envy::from_env::<MigrateConfig>().expect("error loading config");
 96            let mut db_options = db::ConnectOptions::new(config.database_url.clone());
 97            db_options.max_connections(5);
 98            let db = Database::new(db_options).await?;
 99
100            let migrations_path = config
101                .migrations_path
102                .as_deref()
103                .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
104
105            let migrations = db.migrate(&migrations_path, false).await?;
106            for (migration, duration) in migrations {
107                println!(
108                    "Ran {} {} {:?}",
109                    migration.version, migration.description, duration
110                );
111            }
112
113            return Ok(());
114        }
115        Some("serve") => {
116            let config = envy::from_env::<Config>().expect("error loading config");
117            init_tracing(&config);
118
119            let state = AppState::new(config).await?;
120            let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
121                .expect("failed to bind TCP listener");
122
123            let rpc_server = rpc::Server::new(state.clone());
124
125            let app = api::routes(rpc_server.clone(), state.clone())
126                .merge(rpc::routes(rpc_server.clone()))
127                .merge(Router::new().route("/", get(handle_root)));
128
129            axum::Server::from_tcp(listener)?
130                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
131                .await?;
132        }
133        _ => {
134            Err(anyhow!("usage: collab <version | migrate | serve>"))?;
135        }
136    }
137    Ok(())
138}
139
140async fn handle_root() -> String {
141    format!("collab v{VERSION}")
142}
143
144pub fn init_tracing(config: &Config) -> Option<()> {
145    use std::str::FromStr;
146    use tracing_subscriber::layer::SubscriberExt;
147    let rust_log = config.rust_log.clone()?;
148
149    LogTracer::init().log_err()?;
150
151    let subscriber = tracing_subscriber::Registry::default()
152        .with(if config.log_json.unwrap_or(false) {
153            Box::new(
154                tracing_subscriber::fmt::layer()
155                    .fmt_fields(JsonFields::default())
156                    .event_format(
157                        tracing_subscriber::fmt::format()
158                            .json()
159                            .flatten_event(true)
160                            .with_span_list(true),
161                    ),
162            ) as Box<dyn Layer<_> + Send + Sync>
163        } else {
164            Box::new(
165                tracing_subscriber::fmt::layer()
166                    .event_format(tracing_subscriber::fmt::format().pretty()),
167            )
168        })
169        .with(EnvFilter::from_str(rust_log.as_str()).log_err()?);
170
171    tracing::subscriber::set_global_default(subscriber).unwrap();
172
173    None
174}