1use anyhow::anyhow;
2use axum::{extract::MatchedPath, routing::get, Extension, Router};
3use collab::{
4 api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor, AppState,
5 Config, MigrateConfig, Result,
6};
7use db::Database;
8use hyper::Request;
9use std::{
10 env::args,
11 net::{SocketAddr, TcpListener},
12 path::Path,
13 sync::Arc,
14};
15#[cfg(unix)]
16use tokio::signal::unix::SignalKind;
17use tower_http::trace::{self, TraceLayer};
18use tracing::Level;
19use tracing_log::LogTracer;
20use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer};
21use util::ResultExt;
22
/// Crate version, baked in at compile time from Cargo's `CARGO_PKG_VERSION`.
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Git commit this binary was built from, taken from `GITHUB_SHA` at compile time.
/// `None` when that variable was not set during the build (presumably only set in CI — TODO confirm).
const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
25
26#[tokio::main]
27async fn main() -> Result<()> {
28 console_subscriber::init();
29 if let Err(error) = env::load_dotenv() {
30 eprintln!(
31 "error loading .env.toml (this is expected in production): {}",
32 error
33 );
34 }
35
36 let mut args = args().skip(1);
37 match args.next().as_deref() {
38 Some("version") => {
39 println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
40 }
41 Some("migrate") => {
42 run_migrations().await?;
43 }
44 Some("serve") => {
45 let (is_api, is_collab) = if let Some(next) = args.next() {
46 (next == "api", next == "collab")
47 } else {
48 (true, true)
49 };
50 if !is_api && !is_collab {
51 Err(anyhow!(
52 "usage: collab <version | migrate | serve [api|collab]>"
53 ))?;
54 }
55
56 let config = envy::from_env::<Config>().expect("error loading config");
57 init_tracing(&config);
58
59 run_migrations().await?;
60
61 let state = AppState::new(config).await?;
62
63 let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port))
64 .expect("failed to bind TCP listener");
65
66 let rpc_server = if is_collab {
67 let epoch = state
68 .db
69 .create_server(&state.config.zed_environment)
70 .await?;
71 let rpc_server =
72 collab::rpc::Server::new(epoch, state.clone(), Executor::Production);
73 rpc_server.start().await?;
74
75 Some(rpc_server)
76 } else {
77 None
78 };
79
80 if is_api {
81 fetch_extensions_from_blob_store_periodically(state.clone(), Executor::Production);
82 }
83
84 let mut app = collab::api::routes(rpc_server.clone(), state.clone());
85 if let Some(rpc_server) = rpc_server.clone() {
86 app = app.merge(collab::rpc::routes(rpc_server))
87 }
88 app = app
89 .merge(
90 Router::new()
91 .route("/", get(handle_root))
92 .route("/healthz", get(handle_liveness_probe))
93 .merge(collab::api::extensions::router())
94 .merge(collab::api::events::router())
95 .layer(Extension(state.clone())),
96 )
97 .layer(
98 TraceLayer::new_for_http()
99 .make_span_with(|request: &Request<_>| {
100 let matched_path = request
101 .extensions()
102 .get::<MatchedPath>()
103 .map(MatchedPath::as_str);
104
105 tracing::info_span!(
106 "http_request",
107 method = ?request.method(),
108 matched_path,
109 )
110 })
111 .on_response(trace::DefaultOnResponse::new().level(Level::INFO)),
112 );
113
114 #[cfg(unix)]
115 axum::Server::from_tcp(listener)?
116 .serve(app.into_make_service_with_connect_info::<SocketAddr>())
117 .with_graceful_shutdown(async move {
118 let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
119 .expect("failed to listen for interrupt signal");
120 let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
121 .expect("failed to listen for interrupt signal");
122 let sigterm = sigterm.recv();
123 let sigint = sigint.recv();
124 futures::pin_mut!(sigterm, sigint);
125 futures::future::select(sigterm, sigint).await;
126 tracing::info!("Received interrupt signal");
127
128 if let Some(rpc_server) = rpc_server {
129 rpc_server.teardown();
130 }
131 })
132 .await?;
133
134 // todo("windows")
135 #[cfg(windows)]
136 unimplemented!();
137 }
138 _ => {
139 Err(anyhow!(
140 "usage: collab <version | migrate | serve [api|collab]>"
141 ))?;
142 }
143 }
144 Ok(())
145}
146
147async fn run_migrations() -> Result<()> {
148 let config = envy::from_env::<MigrateConfig>().expect("error loading config");
149 let db_options = db::ConnectOptions::new(config.database_url.clone());
150 let db = Database::new(db_options, Executor::Production).await?;
151
152 let migrations_path = config
153 .migrations_path
154 .as_deref()
155 .unwrap_or_else(|| Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")));
156
157 let migrations = db.migrate(&migrations_path, false).await?;
158 for (migration, duration) in migrations {
159 log::info!(
160 "Migrated {} {} {:?}",
161 migration.version,
162 migration.description,
163 duration
164 );
165 }
166
167 return Ok(());
168}
169
170async fn handle_root() -> String {
171 format!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"))
172}
173
174async fn handle_liveness_probe(Extension(state): Extension<Arc<AppState>>) -> Result<String> {
175 state.db.get_all_users(0, 1).await?;
176 Ok("ok".to_string())
177}
178
179pub fn init_tracing(config: &Config) -> Option<()> {
180 use std::str::FromStr;
181 use tracing_subscriber::layer::SubscriberExt;
182 let rust_log = config.rust_log.clone()?;
183
184 LogTracer::init().log_err()?;
185
186 let subscriber = tracing_subscriber::Registry::default()
187 .with(if config.log_json.unwrap_or(false) {
188 Box::new(
189 tracing_subscriber::fmt::layer()
190 .fmt_fields(JsonFields::default())
191 .event_format(
192 tracing_subscriber::fmt::format()
193 .json()
194 .flatten_event(true)
195 .with_span_list(true),
196 ),
197 ) as Box<dyn Layer<_> + Send + Sync>
198 } else {
199 Box::new(
200 tracing_subscriber::fmt::layer()
201 .event_format(tracing_subscriber::fmt::format().pretty()),
202 )
203 })
204 .with(EnvFilter::from_str(rust_log.as_str()).log_err()?);
205
206 tracing::subscriber::set_global_default(subscriber).unwrap();
207
208 None
209}