1use anyhow::anyhow;
2use axum::{extract::MatchedPath, routing::get, Extension, Router};
3use collab::{
4 api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
5 Config, MigrateConfig, Result,
6};
7use db::Database;
8use hyper::Request;
9use std::{
10 env::args,
11 net::{SocketAddr, TcpListener},
12 path::Path,
13 sync::Arc,
14};
15use tokio::signal::unix::SignalKind;
16use tower_http::trace::{self, TraceLayer};
17use tracing::Level;
18use tracing_log::LogTracer;
19use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer};
20use util::ResultExt;
21
22const VERSION: &'static str = env!("CARGO_PKG_VERSION");
23const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
24
25#[tokio::main]
26async fn main() -> Result<()> {
27 if let Err(error) = env::load_dotenv() {
28 eprintln!(
29 "error loading .env.toml (this is expected in production): {}",
30 error
31 );
32 }
33
34 let mut args = args().skip(1);
35 match args.next().as_deref() {
36 Some("version") => {
37 println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
38 }
39 Some("migrate") => {
40 run_migrations().await?;
41 }
42 Some("serve") => {
43 let (is_api, is_collab) = if let Some(next) = args.next() {
44 (next == "api", next == "collab")
45 } else {
46 (true, true)
47 };
48 if !is_api && !is_collab {
49 Err(anyhow!(
50 "usage: collab <version | migrate | serve [api|collab]>"
51 ))?;
52 }
53
54 let config = envy::from_env::<Config>().expect("error loading config");
55 init_tracing(&config);
56
57 run_migrations().await?;
58
59 let state = AppState::new(config).await?;
60
61 let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
62 .expect("failed to bind TCP listener");
63
64 let rpc_server = if is_collab {
65 let epoch = state
66 .db
67 .create_server(&state.config.zed_environment)
68 .await?;
69 let rpc_server =
70 collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
71 rpc_server.start().await?;
72
73 Some(rpc_server)
74 } else {
75 None
76 };
77
78 if is_api {
79 fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production);
80 }
81
82 let mut app = collab::api::routes(rpc_server.clone(), state.clone());
83 if let Some(rpc_server) = rpc_server.clone() {
84 app = app.merge(collab::rpc::routes(rpc_server))
85 }
86 app = app
87 .merge(
88 Router::new()
89 .route("/", get(handle_root))
90 .route("/healthz", get(handle_liveness_probe))
91 .merge(collab::api::extensions::router())
92 .merge(collab::api::events::router())
93 .layer(Extension(state.clone())),
94 )
95 .layer(
96 TraceLayer::new_for_http()
97 .make_span_with(|request: &Request<_>| {
98 let matched_path = request
99 .extensions()
100 .get::<MatchedPath>()
101 .map(MatchedPath::as_str);
102
103 tracing::info_span!(
104 "http_request",
105 method = ?request.method(),
106 matched_path,
107 )
108 })
109 .on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
110 );
111
112 axum::Server::from_tcp(listener)?
113 .serve(app.into_make_service_with_connect_info::<SocketAddr>())
114 .with_graceful_shutdown(async move {
115 let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
116 .expect("failed to listen for interrupt signal");
117 let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
118 .expect("failed to listen for interrupt signal");
119 let sigterm = sigterm.recv();
120 let sigint = sigint.recv();
121 futures::pin_mut!(sigterm, sigint);
122 futures::future::select(sigterm, sigint).await;
123 tracing::info!("Received interrupt signal");
124
125 if let Some(rpc_server) = rpc_server {
126 rpc_server.teardown();
127 }
128 })
129 .await?;
130 }
131 _ => {
132 Err(anyhow!(
133 "usage: collab <version | migrate | serve [api|collab]>"
134 ))?;
135 }
136 }
137 Ok(())
138}
139
140async fn run_migrations() -> Result<()> {
141 let config = envy::from_env::<MigrateConfig>().expect("error loading config");
142 let db_options = db::ConnectOptions::new(config.database_url.clone());
143 let db = Database::new(db_options, Executor::Production).await?;
144
145 let migrations_path = config
146 .migrations_path
147 .as_deref()
148 .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
149
150 let migrations = db.migrate(&migrations_path, false).await?;
151 for (migration, duration) in migrations {
152 log::info!(
153 "Migrated {} {} {:?}",
154 migration.version,
155 migration.description,
156 duration
157 );
158 }
159
160 return Ok(());
161}
162
163async fn handle_root() -> String {
164 format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
165}
166
167async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
168 state.db.get_all_users(0, 1).await?;
169 Ok("ok".to_string())
170}
171
172pub fn init_tracing(config: &Config) -> Option<()> {
173 use std::str::FromStr;
174 use tracing_subscriber::layer::SubscriberExt;
175 let rust_log = config.rust_log.clone()?;
176
177 LogTracer::init().log_err()?;
178
179 let subscriber = tracing_subscriber::Registry::default()
180 .with(if config.log_json.unwrap_or(false) {
181 Box::new(
182 tracing_subscriber::fmt::layer()
183 .fmt_fields(JsonFields::default())
184 .event_format(
185 tracing_subscriber::fmt::format()
186 .json()
187 .flatten_event(true)
188 .with_span_list(true),
189 ),
190 ) as Box<dyn Layer<_> + Send + Sync>
191 } else {
192 Box::new(
193 tracing_subscriber::fmt::layer()
194 .event_format(tracing_subscriber::fmt::format().pretty()),
195 )
196 })
197 .with(EnvFilter::from_str(rust_log.as_str()).log_err()?);
198
199 tracing::subscriber::set_global_default(subscriber).unwrap();
200
201 None
202}