1use anyhow::anyhow;
2use axum::{
3 extract::MatchedPath,
4 http::{Request, Response},
5 routing::get,
6 Extension, Router,
7};
8use collab::{
9 api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
10 Config, MigrateConfig, Result,
11};
12use db::Database;
13use std::{
14 env::args,
15 net::{SocketAddr, TcpListener},
16 path::Path,
17 sync::Arc,
18 time::Duration,
19};
20#[cfg(unix)]
21use tokio::signal::unix::SignalKind;
22use tower_http::trace::TraceLayer;
23use tracing_subscriber::{
24 filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer,
25};
26use util::ResultExt;
27
28const VERSION: &str = env!("CARGO_PKG_VERSION");
29const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
30
31#[tokio::main]
32async fn main() -> Result<()> {
33 if let Err(error) = env::load_dotenv() {
34 eprintln!(
35 "error loading .env.toml (this is expected in production): {}",
36 error
37 );
38 }
39
40 let mut args = args().skip(1);
41 match args.next().as_deref() {
42 Some("version") => {
43 println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
44 }
45 Some("migrate") => {
46 run_migrations().await?;
47 }
48 Some("serve") => {
49 let (is_api, is_collab) = if let Some(next) = args.next() {
50 (next == "api", next == "collab")
51 } else {
52 (true, true)
53 };
54 if !is_api && !is_collab {
55 Err(anyhow!(
56 "usage: collab <version | migrate | serve [api|collab]>"
57 ))?;
58 }
59
60 let config = envy::from_env::<Config>().expect("error loading config");
61 init_tracing(&config);
62
63 run_migrations().await?;
64
65 let state = AppState::new(config).await?;
66
67 let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
68 .expect("failed to bind TCP listener");
69
70 let rpc_server = if is_collab {
71 let epoch = state
72 .db
73 .create_server(&state.config.zed_environment)
74 .await?;
75 let rpc_server =
76 collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
77 rpc_server.start().await?;
78
79 Some(rpc_server)
80 } else {
81 None
82 };
83
84 if is_api {
85 fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production);
86 }
87
88 let mut app = collab::api::routes(rpc_server.clone(), state.clone());
89 if let Some(rpc_server) = rpc_server.clone() {
90 app = app.merge(collab::rpc::routes(rpc_server))
91 }
92 app = app
93 .merge(
94 Router::new()
95 .route("/", get(handle_root))
96 .route("/healthz", get(handle_liveness_probe))
97 .merge(collab::api::extensions::router())
98 .merge(collab::api::events::router())
99 .layer(Extension(state.clone())),
100 )
101 .layer(
102 TraceLayer::new_for_http()
103 .make_span_with(|request: &Request<_>| {
104 let matched_path = request
105 .extensions()
106 .get::<MatchedPath>()
107 .map(MatchedPath::as_str);
108
109 tracing::info_span!(
110 "http_request",
111 method = ?request.method(),
112 matched_path,
113 )
114 })
115 .on_response(
116 |response: &Response<_>, latency: Duration, _: &tracing::Span| {
117 let duration_ms = latency.as_micros() as f64 / 1000.;
118 tracing::info!(
119 duration_ms,
120 status = response.status().as_u16(),
121 "finished processing request"
122 );
123 },
124 ),
125 );
126
127 #[cfg(unix)]
128 axum::Server::from_tcp(listener)
129 .map_err(|e| anyhow!(e))?
130 .serve(app.into_make_service_with_connect_info::<SocketAddr>())
131 .with_graceful_shutdown(async move {
132 let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
133 .expect("failed to listen for interrupt signal");
134 let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
135 .expect("failed to listen for interrupt signal");
136 let sigterm = sigterm.recv();
137 let sigint = sigint.recv();
138 futures::pin_mut!(sigterm, sigint);
139 futures::future::select(sigterm, sigint).await;
140 tracing::info!("Received interrupt signal");
141
142 if let Some(rpc_server) = rpc_server {
143 rpc_server.teardown();
144 }
145 })
146 .await
147 .map_err(|e| anyhow!(e))?;
148
149 // todo("windows")
150 #[cfg(windows)]
151 unimplemented!();
152 }
153 _ => {
154 Err(anyhow!(
155 "usage: collab <version | migrate | serve [api|collab]>"
156 ))?;
157 }
158 }
159 Ok(())
160}
161
162async fn run_migrations() -> Result<()> {
163 let config = envy::from_env::<MigrateConfig>().expect("error loading config");
164 let db_options = db::ConnectOptions::new(config.database_url.clone());
165 let db = Database::new(db_options, Executor::Production).await?;
166
167 let migrations_path = config
168 .migrations_path
169 .as_deref()
170 .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
171
172 let migrations = db.migrate(&migrations_path, false).await?;
173 for (migration, duration) in migrations {
174 log::info!(
175 "Migrated {} {} {:?}",
176 migration.version,
177 migration.description,
178 duration
179 );
180 }
181
182 return Ok(());
183}
184
185async fn handle_root() -> String {
186 format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
187}
188
189async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
190 state.db.get_all_users(0, 1).await?;
191 Ok("ok".to_string())
192}
193
194pub fn init_tracing(config: &Config) -> Option<()> {
195 use std::str::FromStr;
196 use tracing_subscriber::layer::SubscriberExt;
197
198 let filter = EnvFilter::from_str(config.rust_log.as_deref()?).log_err()?;
199
200 tracing_subscriber::registry()
201 .with(if config.log_json.unwrap_or(false) {
202 Box::new(
203 tracing_subscriber::fmt::layer()
204 .fmt_fields(JsonFields::default())
205 .event_format(
206 tracing_subscriber::fmt::format()
207 .json()
208 .flatten_event(true)
209 .with_span_list(false),
210 )
211 .with_filter(filter),
212 ) as Box<dyn Layer<_> + Send + Sync>
213 } else {
214 Box::new(
215 tracing_subscriber::fmt::layer()
216 .event_format(tracing_subscriber::fmt::format().pretty())
217 .with_filter(filter),
218 )
219 })
220 .init();
221
222 None
223}