main.rs

  1use anyhow::anyhow;
  2use axum::{extract::MatchedPath, routing::get, Extension, Router};
  3use collab::{
  4    api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
  5    Config, MigrateConfig, Result,
  6};
  7use db::Database;
  8use hyper::Request;
  9use std::{
 10    env::args,
 11    net::{SocketAddr, TcpListener},
 12    path::Path,
 13    sync::Arc,
 14};
 15#[cfg(unix)]
 16use tokio::signal::unix::SignalKind;
 17use tower_http::trace::{self, TraceLayer};
 18use tracing::Level;
 19use tracing_log::LogTracer;
 20use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer};
 21use util::ResultExt;
 22
 23const VERSION: &str = env!("CARGO_PKG_VERSION");
 24const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
 25
 26#[tokio::main]
 27async fn main() -> Result<()> {
 28    console_subscriber::init();
 29    if let Err(error) = env::load_dotenv() {
 30        eprintln!(
 31            "error loading .env.toml (this is expected in production): {}",
 32            error
 33        );
 34    }
 35
 36    let mut args = args().skip(1);
 37    match args.next().as_deref() {
 38        Some("version") => {
 39            println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
 40        }
 41        Some("migrate") => {
 42            run_migrations().await?;
 43        }
 44        Some("serve") => {
 45            let (is_api, is_collab) = if let Some(next) = args.next() {
 46                (next == "api", next == "collab")
 47            } else {
 48                (true, true)
 49            };
 50            if !is_api && !is_collab {
 51                Err(anyhow!(
 52                    "usage: collab <version | migrate | serve [api|collab]>"
 53                ))?;
 54            }
 55
 56            let config = envy::from_env::<Config>().expect("error loading config");
 57            init_tracing(&config);
 58
 59            run_migrations().await?;
 60
 61            let state = AppState::new(config).await?;
 62
 63            let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
 64                .expect("failed to bind TCP listener");
 65
 66            let rpc_server = if is_collab {
 67                let epoch = state
 68                    .db
 69                    .create_server(&state.config.zed_environment)
 70                    .await?;
 71                let rpc_server =
 72                    collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
 73                rpc_server.start().await?;
 74
 75                Some(rpc_server)
 76            } else {
 77                None
 78            };
 79
 80            if is_api {
 81                fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production);
 82            }
 83
 84            let mut app = collab::api::routes(rpc_server.clone(), state.clone());
 85            if let Some(rpc_server) = rpc_server.clone() {
 86                app = app.merge(collab::rpc::routes(rpc_server))
 87            }
 88            app = app
 89                .merge(
 90                    Router::new()
 91                        .route("/", get(handle_root))
 92                        .route("/healthz", get(handle_liveness_probe))
 93                        .merge(collab::api::extensions::router())
 94                        .merge(collab::api::events::router())
 95                        .layer(Extension(state.clone())),
 96                )
 97                .layer(
 98                    TraceLayer::new_for_http()
 99                        .make_span_with(|request: &Request<_>| {
100                            let matched_path = request
101                                .extensions()
102                                .get::<MatchedPath>()
103                                .map(MatchedPath::as_str);
104
105                            tracing::info_span!(
106                                "http_request",
107                                method = ?request.method(),
108                                matched_path,
109                            )
110                        })
111                        .on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
112                );
113
114            #[cfg(unix)]
115            axum::Server::from_tcp(listener)?
116                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
117                .with_graceful_shutdown(async move {
118                    let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
119                        .expect("failed to listen for interrupt signal");
120                    let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
121                        .expect("failed to listen for interrupt signal");
122                    let sigterm = sigterm.recv();
123                    let sigint = sigint.recv();
124                    futures::pin_mut!(sigterm, sigint);
125                    futures::future::select(sigterm, sigint).await;
126                    tracing::info!("Received interrupt signal");
127
128                    if let Some(rpc_server) = rpc_server {
129                        rpc_server.teardown();
130                    }
131                })
132                .await?;
133
134            // todo("windows")
135            #[cfg(windows)]
136            unimplemented!();
137        }
138        _ => {
139            Err(anyhow!(
140                "usage: collab <version | migrate | serve [api|collab]>"
141            ))?;
142        }
143    }
144    Ok(())
145}
146
147async fn run_migrations() -> Result<()> {
148    let config = envy::from_env::<MigrateConfig>().expect("error loading config");
149    let db_options = db::ConnectOptions::new(config.database_url.clone());
150    let db = Database::new(db_options, Executor::Production).await?;
151
152    let migrations_path = config
153        .migrations_path
154        .as_deref()
155        .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
156
157    let migrations = db.migrate(&migrations_path, false).await?;
158    for (migration, duration) in migrations {
159        log::info!(
160            "Migrated {} {} {:?}",
161            migration.version,
162            migration.description,
163            duration
164        );
165    }
166
167    return Ok(());
168}
169
170async fn handle_root() -> String {
171    format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
172}
173
174async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
175    state.db.get_all_users(0, 1).await?;
176    Ok("ok".to_string())
177}
178
179pub fn init_tracing(config: &Config) -> Option<()> {
180    use std::str::FromStr;
181    use tracing_subscriber::layer::SubscriberExt;
182    let rust_log = config.rust_log.clone()?;
183
184    LogTracer::init().log_err()?;
185
186    let subscriber = tracing_subscriber::Registry::default()
187        .with(if config.log_json.unwrap_or(false) {
188            Box::new(
189                tracing_subscriber::fmt::layer()
190                    .fmt_fields(JsonFields::default())
191                    .event_format(
192                        tracing_subscriber::fmt::format()
193                            .json()
194                            .flatten_event(true)
195                            .with_span_list(true),
196                    ),
197            ) as Box<dyn Layer<_> + Send + Sync>
198        } else {
199            Box::new(
200                tracing_subscriber::fmt::layer()
201                    .event_format(tracing_subscriber::fmt::format().pretty()),
202            )
203        })
204        .with(EnvFilter::from_str(rust_log.as_str()).log_err()?);
205
206    tracing::subscriber::set_global_default(subscriber).unwrap();
207
208    None
209}