main.rs

  1use anyhow::anyhow;
  2use axum::{extract::MatchedPath, routing::get, Extension, Router};
  3use collab::{
  4    api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
  5    Config, MigrateConfig, Result,
  6};
  7use db::Database;
  8use hyper::Request;
  9use std::{
 10    env::args,
 11    net::{SocketAddr, TcpListener},
 12    path::Path,
 13    sync::Arc,
 14};
 15use tokio::signal::unix::SignalKind;
 16use tower_http::trace::{self, TraceLayer};
 17use tracing::Level;
 18use tracing_log::LogTracer;
 19use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer};
 20use util::ResultExt;
 21
 22const VERSION: &'static str = env!("CARGO_PKG_VERSION");
 23const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
 24
 25#[tokio::main]
 26async fn main() -> Result<()> {
 27    if let Err(error) = env::load_dotenv() {
 28        eprintln!(
 29            "error loading .env.toml (this is expected in production): {}",
 30            error
 31        );
 32    }
 33
 34    let mut args = args().skip(1);
 35    match args.next().as_deref() {
 36        Some("version") => {
 37            println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
 38        }
 39        Some("migrate") => {
 40            run_migrations().await?;
 41        }
 42        Some("serve") => {
 43            let (is_api, is_collab) = if let Some(next) = args.next() {
 44                (next == "api", next == "collab")
 45            } else {
 46                (true, true)
 47            };
 48            if !is_api && !is_collab {
 49                Err(anyhow!(
 50                    "usage: collab <version | migrate | serve [api|collab]>"
 51                ))?;
 52            }
 53
 54            let config = envy::from_env::<Config>().expect("error loading config");
 55            init_tracing(&config);
 56
 57            run_migrations().await?;
 58
 59            let state = AppState::new(config).await?;
 60
 61            let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
 62                .expect("failed to bind TCP listener");
 63
 64            let rpc_server = if is_collab {
 65                let epoch = state
 66                    .db
 67                    .create_server(&state.config.zed_environment)
 68                    .await?;
 69                let rpc_server =
 70                    collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
 71                rpc_server.start().await?;
 72
 73                Some(rpc_server)
 74            } else {
 75                None
 76            };
 77
 78            if is_api {
 79                fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production);
 80            }
 81
 82            let mut app = collab::api::routes(rpc_server.clone(), state.clone());
 83            if let Some(rpc_server) = rpc_server.clone() {
 84                app = app.merge(collab::rpc::routes(rpc_server))
 85            }
 86            app = app
 87                .merge(
 88                    Router::new()
 89                        .route("/", get(handle_root))
 90                        .route("/healthz", get(handle_liveness_probe))
 91                        .merge(collab::api::extensions::router())
 92                        .merge(collab::api::events::router())
 93                        .layer(Extension(state.clone())),
 94                )
 95                .layer(
 96                    TraceLayer::new_for_http()
 97                        .make_span_with(|request: &Request<_>| {
 98                            let matched_path = request
 99                                .extensions()
100                                .get::<MatchedPath>()
101                                .map(MatchedPath::as_str);
102
103                            tracing::info_span!(
104                                "http_request",
105                                method = ?request.method(),
106                                matched_path,
107                            )
108                        })
109                        .on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
110                );
111
112            axum::Server::from_tcp(listener)?
113                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
114                .with_graceful_shutdown(async move {
115                    let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
116                        .expect("failed to listen for interrupt signal");
117                    let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
118                        .expect("failed to listen for interrupt signal");
119                    let sigterm = sigterm.recv();
120                    let sigint = sigint.recv();
121                    futures::pin_mut!(sigterm, sigint);
122                    futures::future::select(sigterm, sigint).await;
123                    tracing::info!("Received interrupt signal");
124
125                    if let Some(rpc_server) = rpc_server {
126                        rpc_server.teardown();
127                    }
128                })
129                .await?;
130        }
131        _ => {
132            Err(anyhow!(
133                "usage: collab <version | migrate | serve [api|collab]>"
134            ))?;
135        }
136    }
137    Ok(())
138}
139
140async fn run_migrations() -> Result<()> {
141    let config = envy::from_env::<MigrateConfig>().expect("error loading config");
142    let db_options = db::ConnectOptions::new(config.database_url.clone());
143    let db = Database::new(db_options, Executor::Production).await?;
144
145    let migrations_path = config
146        .migrations_path
147        .as_deref()
148        .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
149
150    let migrations = db.migrate(&migrations_path, false).await?;
151    for (migration, duration) in migrations {
152        log::info!(
153            "Migrated {} {} {:?}",
154            migration.version,
155            migration.description,
156            duration
157        );
158    }
159
160    return Ok(());
161}
162
163async fn handle_root() -> String {
164    format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
165}
166
167async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
168    state.db.get_all_users(0, 1).await?;
169    Ok("ok".to_string())
170}
171
172pub fn init_tracing(config: &Config) -> Option<()> {
173    use std::str::FromStr;
174    use tracing_subscriber::layer::SubscriberExt;
175    let rust_log = config.rust_log.clone()?;
176
177    LogTracer::init().log_err()?;
178
179    let subscriber = tracing_subscriber::Registry::default()
180        .with(if config.log_json.unwrap_or(false) {
181            Box::new(
182                tracing_subscriber::fmt::layer()
183                    .fmt_fields(JsonFields::default())
184                    .event_format(
185                        tracing_subscriber::fmt::format()
186                            .json()
187                            .flatten_event(true)
188                            .with_span_list(true),
189                    ),
190            ) as Box<dyn Layer<_> + Send + Sync>
191        } else {
192            Box::new(
193                tracing_subscriber::fmt::layer()
194                    .event_format(tracing_subscriber::fmt::format().pretty()),
195            )
196        })
197        .with(EnvFilter::from_str(rust_log.as_str()).log_err()?);
198
199    tracing::subscriber::set_global_default(subscriber).unwrap();
200
201    None
202}