1use anyhow::anyhow;
2use axum::{extract::MatchedPath, routing::get, Extension, Router};
3use collab::{
4 api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
5 Config, MigrateConfig, Result,
6};
7use db::Database;
8use hyper::Request;
9use std::{
10 env::args,
11 net::{SocketAddr, TcpListener},
12 path::Path,
13 sync::Arc,
14};
15#[cfg(unix)]
16use tokio::signal::unix::SignalKind;
17use tower_http::trace::{self, TraceLayer};
18use tracing::Level;
19use tracing_log::LogTracer;
20use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer};
21use util::ResultExt;
22
23const VERSION: &'static str = env!("CARGO_PKG_VERSION");
24const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
25
26#[tokio::main]
27async fn main() -> Result<()> {
28 if let Err(error) = env::load_dotenv() {
29 eprintln!(
30 "error loading .env.toml (this is expected in production): {}",
31 error
32 );
33 }
34
35 let mut args = args().skip(1);
36 match args.next().as_deref() {
37 Some("version") => {
38 println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
39 }
40 Some("migrate") => {
41 run_migrations().await?;
42 }
43 Some("serve") => {
44 let (is_api, is_collab) = if let Some(next) = args.next() {
45 (next == "api", next == "collab")
46 } else {
47 (true, true)
48 };
49 if !is_api && !is_collab {
50 Err(anyhow!(
51 "usage: collab <version | migrate | serve [api|collab]>"
52 ))?;
53 }
54
55 let config = envy::from_env::<Config>().expect("error loading config");
56 init_tracing(&config);
57
58 run_migrations().await?;
59
60 let state = AppState::new(config).await?;
61
62 let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
63 .expect("failed to bind TCP listener");
64
65 let rpc_server = if is_collab {
66 let epoch = state
67 .db
68 .create_server(&state.config.zed_environment)
69 .await?;
70 let rpc_server =
71 collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
72 rpc_server.start().await?;
73
74 Some(rpc_server)
75 } else {
76 None
77 };
78
79 if is_api {
80 fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production);
81 }
82
83 let mut app = collab::api::routes(rpc_server.clone(), state.clone());
84 if let Some(rpc_server) = rpc_server.clone() {
85 app = app.merge(collab::rpc::routes(rpc_server))
86 }
87 app = app
88 .merge(
89 Router::new()
90 .route("/", get(handle_root))
91 .route("/healthz", get(handle_liveness_probe))
92 .merge(collab::api::extensions::router())
93 .merge(collab::api::events::router())
94 .layer(Extension(state.clone())),
95 )
96 .layer(
97 TraceLayer::new_for_http()
98 .make_span_with(|request: &Request<_>| {
99 let matched_path = request
100 .extensions()
101 .get::<MatchedPath>()
102 .map(MatchedPath::as_str);
103
104 tracing::info_span!(
105 "http_request",
106 method = ?request.method(),
107 matched_path,
108 )
109 })
110 .on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
111 );
112
113 #[cfg(unix)]
114 axum::Server::from_tcp(listener)?
115 .serve(app.into_make_service_with_connect_info::<SocketAddr>())
116 .with_graceful_shutdown(async move {
117 let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
118 .expect("failed to listen for interrupt signal");
119 let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
120 .expect("failed to listen for interrupt signal");
121 let sigterm = sigterm.recv();
122 let sigint = sigint.recv();
123 futures::pin_mut!(sigterm, sigint);
124 futures::future::select(sigterm, sigint).await;
125 tracing::info!("Received interrupt signal");
126
127 if let Some(rpc_server) = rpc_server {
128 rpc_server.teardown();
129 }
130 })
131 .await?;
132
133 // todo!("windows")
134 #[cfg(windows)]
135 unimplemented!();
136 }
137 _ => {
138 Err(anyhow!(
139 "usage: collab <version | migrate | serve [api|collab]>"
140 ))?;
141 }
142 }
143 Ok(())
144}
145
146async fn run_migrations() -> Result<()> {
147 let config = envy::from_env::<MigrateConfig>().expect("error loading config");
148 let db_options = db::ConnectOptions::new(config.database_url.clone());
149 let db = Database::new(db_options, Executor::Production).await?;
150
151 let migrations_path = config
152 .migrations_path
153 .as_deref()
154 .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
155
156 let migrations = db.migrate(&migrations_path, false).await?;
157 for (migration, duration) in migrations {
158 log::info!(
159 "Migrated {} {} {:?}",
160 migration.version,
161 migration.description,
162 duration
163 );
164 }
165
166 return Ok(());
167}
168
169async fn handle_root() -> String {
170 format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
171}
172
173async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
174 state.db.get_all_users(0, 1).await?;
175 Ok("ok".to_string())
176}
177
178pub fn init_tracing(config: &Config) -> Option<()> {
179 use std::str::FromStr;
180 use tracing_subscriber::layer::SubscriberExt;
181 let rust_log = config.rust_log.clone()?;
182
183 LogTracer::init().log_err()?;
184
185 let subscriber = tracing_subscriber::Registry::default()
186 .with(if config.log_json.unwrap_or(false) {
187 Box::new(
188 tracing_subscriber::fmt::layer()
189 .fmt_fields(JsonFields::default())
190 .event_format(
191 tracing_subscriber::fmt::format()
192 .json()
193 .flatten_event(true)
194 .with_span_list(true),
195 ),
196 ) as Box<dyn Layer<_> + Send + Sync>
197 } else {
198 Box::new(
199 tracing_subscriber::fmt::layer()
200 .event_format(tracing_subscriber::fmt::format().pretty()),
201 )
202 })
203 .with(EnvFilter::from_str(rust_log.as_str()).log_err()?);
204
205 tracing::subscriber::set_global_default(subscriber).unwrap();
206
207 None
208}