1use anyhow::anyhow;
2use axum::{
3 extract::MatchedPath,
4 http::{Request, Response},
5 routing::get,
6 Extension, Router,
7};
8use collab::{
9 api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
10 Config, MigrateConfig, RateLimiter, Result,
11};
12use db::Database;
13use std::{
14 env::args,
15 net::{SocketAddr, TcpListener},
16 path::Path,
17 sync::Arc,
18 time::Duration,
19};
20#[cfg(unix)]
21use tokio::signal::unix::SignalKind;
22use tower_http::trace::TraceLayer;
23use tracing_subscriber::{
24 filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer,
25};
26use util::ResultExt;
27
28const VERSION: &str = env!("CARGO_PKG_VERSION");
29const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
30
31#[tokio::main]
32async fn main() -> Result<()> {
33 if let Err(error) = env::load_dotenv() {
34 eprintln!(
35 "error loading .env.toml (this is expected in production): {}",
36 error
37 );
38 }
39
40 let mut args = args().skip(1);
41 match args.next().as_deref() {
42 Some("version") => {
43 println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
44 }
45 Some("migrate") => {
46 run_migrations().await?;
47 }
48 Some("serve") => {
49 let (is_api, is_collab) = if let Some(next) = args.next() {
50 (next == "api", next == "collab")
51 } else {
52 (true, true)
53 };
54 if !is_api && !is_collab {
55 Err(anyhow!(
56 "usage: collab <version | migrate | serve [api|collab]>"
57 ))?;
58 }
59
60 let config = envy::from_env::<Config>().expect("error loading config");
61 init_tracing(&config);
62
63 run_migrations().await?;
64
65 let state = AppState::new(config, Executor::Production).await?;
66
67 let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
68 .expect("failed to bind TCP listener");
69
70 let epoch = state
71 .db
72 .create_server(&state.config.zed_environment)
73 .await?;
74 let rpc_server = collab::rpc::Server::new(epoch, state.clone());
75 rpc_server.start().await?;
76
77 fetch_extensions_from_blob_store_periodically(state.clone());
78 RateLimiter::save_periodically(state.rate_limiter.clone(), state.executor.clone());
79
80 let rpc_server = if is_collab {
81 let epoch = state
82 .db
83 .create_server(&state.config.zed_environment)
84 .await?;
85 let rpc_server = collab::rpc::Server::new(epoch, state.clone());
86 rpc_server.start().await?;
87
88 Some(rpc_server)
89 } else {
90 None
91 };
92
93 if is_api {
94 fetch_extensions_from_blob_store_periodically(state.clone());
95 }
96
97 let mut app = collab::api::routes(rpc_server.clone(), state.clone());
98 if let Some(rpc_server) = rpc_server.clone() {
99 app = app.merge(collab::rpc::routes(rpc_server))
100 }
101 app = app
102 .merge(
103 Router::new()
104 .route("/", get(handle_root))
105 .route("/healthz", get(handle_liveness_probe))
106 .merge(collab::api::extensions::router())
107 .merge(collab::api::events::router())
108 .layer(Extension(state.clone())),
109 )
110 .layer(
111 TraceLayer::new_for_http()
112 .make_span_with(|request: &Request<_>| {
113 let matched_path = request
114 .extensions()
115 .get::<MatchedPath>()
116 .map(MatchedPath::as_str);
117
118 tracing::info_span!(
119 "http_request",
120 method = ?request.method(),
121 matched_path,
122 )
123 })
124 .on_response(
125 |response: &Response<_>, latency: Duration, _: &tracing::Span| {
126 let duration_ms = latency.as_micros() as f64 / 1000.;
127 tracing::info!(
128 duration_ms,
129 status = response.status().as_u16(),
130 "finished processing request"
131 );
132 },
133 ),
134 );
135
136 #[cfg(unix)]
137 axum::Server::from_tcp(listener)
138 .map_err(|e| anyhow!(e))?
139 .serve(app.into_make_service_with_connect_info::<SocketAddr>())
140 .with_graceful_shutdown(async move {
141 let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
142 .expect("failed to listen for interrupt signal");
143 let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
144 .expect("failed to listen for interrupt signal");
145 let sigterm = sigterm.recv();
146 let sigint = sigint.recv();
147 futures::pin_mut!(sigterm, sigint);
148 futures::future::select(sigterm, sigint).await;
149 tracing::info!("Received interrupt signal");
150
151 if let Some(rpc_server) = rpc_server {
152 rpc_server.teardown();
153 }
154 })
155 .await
156 .map_err(|e| anyhow!(e))?;
157
158 // todo("windows")
159 #[cfg(windows)]
160 unimplemented!();
161 }
162 _ => {
163 Err(anyhow!(
164 "usage: collab <version | migrate | serve [api|collab]>"
165 ))?;
166 }
167 }
168 Ok(())
169}
170
171async fn run_migrations() -> Result<()> {
172 let config = envy::from_env::<MigrateConfig>().expect("error loading config");
173 let db_options = db::ConnectOptions::new(config.database_url.clone());
174 let db = Database::new(db_options, Executor::Production).await?;
175
176 let migrations_path = config
177 .migrations_path
178 .as_deref()
179 .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
180
181 let migrations = db.migrate(&migrations_path, false).await?;
182 for (migration, duration) in migrations {
183 log::info!(
184 "Migrated {} {} {:?}",
185 migration.version,
186 migration.description,
187 duration
188 );
189 }
190
191 return Ok(());
192}
193
194async fn handle_root() -> String {
195 format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
196}
197
198async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
199 state.db.get_all_users(0, 1).await?;
200 Ok("ok".to_string())
201}
202
203pub fn init_tracing(config: &Config) -> Option<()> {
204 use std::str::FromStr;
205 use tracing_subscriber::layer::SubscriberExt;
206
207 let filter = EnvFilter::from_str(config.rust_log.as_deref()?).log_err()?;
208
209 tracing_subscriber::registry()
210 .with(if config.log_json.unwrap_or(false) {
211 Box::new(
212 tracing_subscriber::fmt::layer()
213 .fmt_fields(JsonFields::default())
214 .event_format(
215 tracing_subscriber::fmt::format()
216 .json()
217 .flatten_event(true)
218 .with_span_list(false),
219 )
220 .with_filter(filter),
221 ) as Box<dyn Layer<_> + Send + Sync>
222 } else {
223 Box::new(
224 tracing_subscriber::fmt::layer()
225 .event_format(tracing_subscriber::fmt::format().pretty())
226 .with_filter(filter),
227 )
228 })
229 .init();
230
231 None
232}