1use anyhow::anyhow;
2use axum::{
3 extract::MatchedPath,
4 http::{Request, Response},
5 routing::get,
6 Extension, Router,
7};
8use collab::{
9 api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
10 Config, MigrateConfig, RateLimiter, Result,
11};
12use db::Database;
13use std::{
14 env::args,
15 net::{SocketAddr, TcpListener},
16 path::Path,
17 sync::Arc,
18 time::Duration,
19};
20#[cfg(unix)]
21use tokio::signal::unix::SignalKind;
22use tower_http::trace::TraceLayer;
23use tracing_subscriber::{
24 filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer,
25};
26use util::ResultExt;
27
28const VERSION: &str = env!("CARGO_PKG_VERSION");
29const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
30
31#[tokio::main]
32async fn main() -> Result<()> {
33 if let Err(error) = env::load_dotenv() {
34 eprintln!(
35 "error loading .env.toml (this is expected in production): {}",
36 error
37 );
38 }
39
40 let mut args = args().skip(1);
41 match args.next().as_deref() {
42 Some("version") => {
43 println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
44 }
45 Some("migrate") => {
46 run_migrations().await?;
47 }
48 Some("serve") => {
49 let (is_api, is_collab) = if let Some(next) = args.next() {
50 (next == "api", next == "collab")
51 } else {
52 (true, true)
53 };
54 if !is_api && !is_collab {
55 Err(anyhow!(
56 "usage: collab <version | migrate | serve [api|collab]>"
57 ))?;
58 }
59
60 let config = envy::from_env::<Config>().expect("error loading config");
61 init_tracing(&config);
62
63 run_migrations().await?;
64
65 let state = AppState::new(config, Executor::Production).await?;
66
67 let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
68 .expect("failed to bind TCP listener");
69
70 let rpc_server = if is_collab {
71 let epoch = state
72 .db
73 .create_server(&state.config.zed_environment)
74 .await?;
75 let rpc_server = collab::rpc::Server::new(epoch, state.clone());
76 rpc_server.start().await?;
77
78 Some(rpc_server)
79 } else {
80 None
81 };
82
83 if is_collab {
84 RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone());
85 }
86
87 if is_api {
88 fetch_extensions_from_blob_store_periodically(state.clone());
89 }
90
91 let mut app = collab::api::routes(rpc_server.clone(), state.clone());
92 if let Some(rpc_server) = rpc_server.clone() {
93 app = app.merge(collab::rpc::routes(rpc_server))
94 }
95 app = app
96 .merge(
97 Router::new()
98 .route("/", get(handle_root))
99 .route("/healthz", get(handle_liveness_probe))
100 .merge(collab::api::extensions::router())
101 .merge(collab::api::events::router())
102 .layer(Extension(state.clone())),
103 )
104 .layer(
105 TraceLayer::new_for_http()
106 .make_span_with(|request: &Request<_>| {
107 let matched_path = request
108 .extensions()
109 .get::<MatchedPath>()
110 .map(MatchedPath::as_str);
111
112 tracing::info_span!(
113 "http_request",
114 method = ?request.method(),
115 matched_path,
116 )
117 })
118 .on_response(
119 |response: &Response<_>, latency: Duration, _: &tracing::Span| {
120 let duration_ms = latency.as_micros() as f64 / 1000.;
121 tracing::info!(
122 duration_ms,
123 status = response.status().as_u16(),
124 "finished processing request"
125 );
126 },
127 ),
128 );
129
130 #[cfg(unix)]
131 axum::Server::from_tcp(listener)
132 .map_err(|e| anyhow!(e))?
133 .serve(app.into_make_service_with_connect_info::<SocketAddr>())
134 .with_graceful_shutdown(async move {
135 let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
136 .expect("failed to listen for interrupt signal");
137 let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
138 .expect("failed to listen for interrupt signal");
139 let sigterm = sigterm.recv();
140 let sigint = sigint.recv();
141 futures::pin_mut!(sigterm, sigint);
142 futures::future::select(sigterm, sigint).await;
143 tracing::info!("Received interrupt signal");
144
145 if let Some(rpc_server) = rpc_server {
146 rpc_server.teardown();
147 }
148 })
149 .await
150 .map_err(|e| anyhow!(e))?;
151
152 // todo("windows")
153 #[cfg(windows)]
154 unimplemented!();
155 }
156 _ => {
157 Err(anyhow!(
158 "usage: collab <version | migrate | serve [api|collab]>"
159 ))?;
160 }
161 }
162 Ok(())
163}
164
165async fn run_migrations() -> Result<()> {
166 let config = envy::from_env::<MigrateConfig>().expect("error loading config");
167 let db_options = db::ConnectOptions::new(config.database_url.clone());
168 let db = Database::new(db_options, Executor::Production).await?;
169
170 let migrations_path = config
171 .migrations_path
172 .as_deref()
173 .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
174
175 let migrations = db.migrate(&migrations_path, false).await?;
176 for (migration, duration) in migrations {
177 log::info!(
178 "Migrated {} {} {:?}",
179 migration.version,
180 migration.description,
181 duration
182 );
183 }
184
185 return Ok(());
186}
187
188async fn handle_root() -> String {
189 format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
190}
191
192async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
193 state.db.get_all_users(0, 1).await?;
194 Ok("ok".to_string())
195}
196
197pub fn init_tracing(config: &Config) -> Option<()> {
198 use std::str::FromStr;
199 use tracing_subscriber::layer::SubscriberExt;
200
201 let filter = EnvFilter::from_str(config.rust_log.as_deref()?).log_err()?;
202
203 tracing_subscriber::registry()
204 .with(if config.log_json.unwrap_or(false) {
205 Box::new(
206 tracing_subscriber::fmt::layer()
207 .fmt_fields(JsonFields::default())
208 .event_format(
209 tracing_subscriber::fmt::format()
210 .json()
211 .flatten_event(true)
212 .with_span_list(false),
213 )
214 .with_filter(filter),
215 ) as Box<dyn Layer<_> + Send + Sync>
216 } else {
217 Box::new(
218 tracing_subscriber::fmt::layer()
219 .event_format(tracing_subscriber::fmt::format().pretty())
220 .with_filter(filter),
221 )
222 })
223 .init();
224
225 None
226}