1use anyhow::anyhow;
2use axum::{extract::MatchedPath, http::Request, routing::get, Extension, Router};
3use collab::{
4 api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
5 Config, MigrateConfig, Result,
6};
7use db::Database;
8use std::{
9 env::args,
10 net::{SocketAddr, TcpListener},
11 path::Path,
12 sync::Arc,
13};
14#[cfg(unix)]
15use tokio::signal::unix::SignalKind;
16use tower_http::trace::{self, TraceLayer};
17use tracing::Level;
18use tracing_subscriber::{
19 filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer,
20};
21use util::ResultExt;
22
23const VERSION: &str = env!("CARGO_PKG_VERSION");
24const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
25
26#[tokio::main]
27async fn main() -> Result<()> {
28 if let Err(error) = env::load_dotenv() {
29 eprintln!(
30 "error loading .env.toml (this is expected in production): {}",
31 error
32 );
33 }
34
35 let mut args = args().skip(1);
36 match args.next().as_deref() {
37 Some("version") => {
38 println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
39 }
40 Some("migrate") => {
41 run_migrations().await?;
42 }
43 Some("serve") => {
44 let (is_api, is_collab) = if let Some(next) = args.next() {
45 (next == "api", next == "collab")
46 } else {
47 (true, true)
48 };
49 if !is_api && !is_collab {
50 Err(anyhow!(
51 "usage: collab <version | migrate | serve [api|collab]>"
52 ))?;
53 }
54
55 let config = envy::from_env::<Config>().expect("error loading config");
56 init_tracing(&config);
57
58 run_migrations().await?;
59
60 let state = AppState::new(config).await?;
61
62 let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
63 .expect("failed to bind TCP listener");
64
65 let rpc_server = if is_collab {
66 let epoch = state
67 .db
68 .create_server(&state.config.zed_environment)
69 .await?;
70 let rpc_server =
71 collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
72 rpc_server.start().await?;
73
74 Some(rpc_server)
75 } else {
76 None
77 };
78
79 if is_api {
80 fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production);
81 }
82
83 let mut app = collab::api::routes(rpc_server.clone(), state.clone());
84 if let Some(rpc_server) = rpc_server.clone() {
85 app = app.merge(collab::rpc::routes(rpc_server))
86 }
87 app = app
88 .merge(
89 Router::new()
90 .route("/", get(handle_root))
91 .route("/healthz", get(handle_liveness_probe))
92 .merge(collab::api::extensions::router())
93 .merge(collab::api::events::router())
94 .layer(Extension(state.clone())),
95 )
96 .layer(
97 TraceLayer::new_for_http()
98 .make_span_with(|request: &Request<_>| {
99 let matched_path = request
100 .extensions()
101 .get::<MatchedPath>()
102 .map(MatchedPath::as_str);
103
104 tracing::info_span!(
105 "http_request",
106 method = ?request.method(),
107 matched_path,
108 )
109 })
110 .on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
111 );
112
113 #[cfg(unix)]
114 axum::Server::from_tcp(listener)
115 .map_err(|e| anyhow!(e))?
116 .serve(app.into_make_service_with_connect_info::<SocketAddr>())
117 .with_graceful_shutdown(async move {
118 let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
119 .expect("failed to listen for interrupt signal");
120 let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
121 .expect("failed to listen for interrupt signal");
122 let sigterm = sigterm.recv();
123 let sigint = sigint.recv();
124 futures::pin_mut!(sigterm, sigint);
125 futures::future::select(sigterm, sigint).await;
126 tracing::info!("Received interrupt signal");
127
128 if let Some(rpc_server) = rpc_server {
129 rpc_server.teardown();
130 }
131 })
132 .await
133 .map_err(|e| anyhow!(e))?;
134
135 // todo("windows")
136 #[cfg(windows)]
137 unimplemented!();
138 }
139 _ => {
140 Err(anyhow!(
141 "usage: collab <version | migrate | serve [api|collab]>"
142 ))?;
143 }
144 }
145 Ok(())
146}
147
148async fn run_migrations() -> Result<()> {
149 let config = envy::from_env::<MigrateConfig>().expect("error loading config");
150 let db_options = db::ConnectOptions::new(config.database_url.clone());
151 let db = Database::new(db_options, Executor::Production).await?;
152
153 let migrations_path = config
154 .migrations_path
155 .as_deref()
156 .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
157
158 let migrations = db.migrate(&migrations_path, false).await?;
159 for (migration, duration) in migrations {
160 log::info!(
161 "Migrated {} {} {:?}",
162 migration.version,
163 migration.description,
164 duration
165 );
166 }
167
168 return Ok(());
169}
170
171async fn handle_root() -> String {
172 format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
173}
174
175async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
176 state.db.get_all_users(0, 1).await?;
177 Ok("ok".to_string())
178}
179
180pub fn init_tracing(config: &Config) -> Option<()> {
181 use std::str::FromStr;
182 use tracing_subscriber::layer::SubscriberExt;
183
184 let filter = EnvFilter::from_str(config.rust_log.as_deref()?).log_err()?;
185
186 tracing_subscriber::registry()
187 .with(if config.log_json.unwrap_or(false) {
188 Box::new(
189 tracing_subscriber::fmt::layer()
190 .fmt_fields(JsonFields::default())
191 .event_format(
192 tracing_subscriber::fmt::format()
193 .json()
194 .flatten_event(true)
195 .with_span_list(true),
196 )
197 .with_filter(filter),
198 ) as Box<dyn Layer<_> + Send + Sync>
199 } else {
200 Box::new(
201 tracing_subscriber::fmt::layer()
202 .event_format(tracing_subscriber::fmt::format().pretty())
203 .with_filter(filter),
204 )
205 })
206 .init();
207
208 None
209}