main.rs

  1use anyhow::anyhow;
  2use axum::{
  3    extract::MatchedPath,
  4    http::{Request, Response},
  5    routing::get,
  6    Extension, Router,
  7};
  8use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
  9use collab::{
 10    api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
 11    rpc::ResultExt, AppState, Config, RateLimiter, Result,
 12};
 13use db::Database;
 14use std::{
 15    env::args,
 16    net::{SocketAddr, TcpListener},
 17    path::Path,
 18    sync::Arc,
 19    time::Duration,
 20};
 21#[cfg(unix)]
 22use tokio::signal::unix::SignalKind;
 23use tower_http::trace::TraceLayer;
 24use tracing_subscriber::{
 25    filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer,
 26};
 27use util::ResultExt as _;
 28
 29const VERSION: &str = env!("CARGO_PKG_VERSION");
 30const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
 31
 32#[tokio::main]
 33async fn main() -> Result<()> {
 34    if let Err(error) = env::load_dotenv() {
 35        eprintln!(
 36            "error loading .env.toml (this is expected in production): {}",
 37            error
 38        );
 39    }
 40
 41    let mut args = args().skip(1);
 42    match args.next().as_deref() {
 43        Some("version") => {
 44            println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
 45        }
 46        Some("migrate") => {
 47            let config = envy::from_env::<Config>().expect("error loading config");
 48            run_migrations(&config).await?;
 49        }
 50        Some("seed") => {
 51            let config = envy::from_env::<Config>().expect("error loading config");
 52            let db_options = db::ConnectOptions::new(config.database_url.clone());
 53            let mut db = Database::new(db_options, Executor::Production).await?;
 54            db.initialize_notification_kinds().await?;
 55
 56            collab::seed::seed(&config, &db, true).await?;
 57        }
 58        Some("serve") => {
 59            let mode = match args.next().as_deref() {
 60                Some("collab") => ServiceMode::Collab,
 61                Some("api") => ServiceMode::Api,
 62                Some("llm") => ServiceMode::Llm,
 63                Some("all") => ServiceMode::All,
 64                _ => {
 65                    return Err(anyhow!(
 66                        "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
 67                    ))?;
 68                }
 69            };
 70
 71            let config = envy::from_env::<Config>().expect("error loading config");
 72            init_tracing(&config);
 73            let mut app = Router::new()
 74                .route("/", get(handle_root))
 75                .route("/healthz", get(handle_liveness_probe))
 76                .layer(Extension(mode));
 77
 78            let listener = TcpListener::bind(&format!("0.0.0.0:{}", config.http_port))
 79                .expect("failed to bind TCP listener");
 80
 81            let mut on_shutdown = None;
 82
 83            if mode.is_llm() {
 84                let state = LlmState::new(config.clone(), Executor::Production).await?;
 85
 86                app = app
 87                    .merge(collab::llm::routes())
 88                    .layer(Extension(state.clone()));
 89            }
 90
 91            if mode.is_collab() || mode.is_api() {
 92                run_migrations(&config).await?;
 93
 94                let state = AppState::new(config, Executor::Production).await?;
 95
 96                if mode.is_collab() {
 97                    state.db.purge_old_embeddings().await.trace_err();
 98                    RateLimiter::save_periodically(
 99                        state.rate_limiter.clone(),
100                        state.executor.clone(),
101                    );
102
103                    let epoch = state
104                        .db
105                        .create_server(&state.config.zed_environment)
106                        .await?;
107                    let rpc_server = collab::rpc::Server::new(epoch, state.clone());
108                    rpc_server.start().await?;
109
110                    app = app
111                        .merge(collab::api::routes(rpc_server.clone()))
112                        .merge(collab::rpc::routes(rpc_server.clone()));
113
114                    on_shutdown = Some(Box::new(move || rpc_server.teardown()));
115                }
116
117                if mode.is_api() {
118                    poll_stripe_events_periodically(state.clone());
119                    fetch_extensions_from_blob_store_periodically(state.clone());
120
121                    app = app
122                        .merge(collab::api::events::router())
123                        .merge(collab::api::extensions::router())
124                }
125
126                app = app.layer(Extension(state.clone()));
127            }
128
129            app = app.layer(
130                TraceLayer::new_for_http()
131                    .make_span_with(|request: &Request<_>| {
132                        let matched_path = request
133                            .extensions()
134                            .get::<MatchedPath>()
135                            .map(MatchedPath::as_str);
136
137                        tracing::info_span!(
138                            "http_request",
139                            method = ?request.method(),
140                            matched_path,
141                        )
142                    })
143                    .on_response(
144                        |response: &Response<_>, latency: Duration, _: &tracing::Span| {
145                            let duration_ms = latency.as_micros() as f64 / 1000.;
146                            tracing::info!(
147                                duration_ms,
148                                status = response.status().as_u16(),
149                                "finished processing request"
150                            );
151                        },
152                    ),
153            );
154
155            #[cfg(unix)]
156            let signal = async move {
157                let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
158                    .expect("failed to listen for interrupt signal");
159                let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
160                    .expect("failed to listen for interrupt signal");
161                let sigterm = sigterm.recv();
162                let sigint = sigint.recv();
163                futures::pin_mut!(sigterm, sigint);
164                futures::future::select(sigterm, sigint).await;
165            };
166
167            #[cfg(windows)]
168            let signal = async move {
169                // todo(windows):
170                // `ctrl_close` does not work well, because tokio's signal handler always returns soon,
171                // but system terminates the application soon after returning CTRL+CLOSE handler.
172                // So we should implement blocking handler to treat CTRL+CLOSE signal.
173                let mut ctrl_break = tokio::signal::windows::ctrl_break()
174                    .expect("failed to listen for interrupt signal");
175                let mut ctrl_c = tokio::signal::windows::ctrl_c()
176                    .expect("failed to listen for interrupt signal");
177                let ctrl_break = ctrl_break.recv();
178                let ctrl_c = ctrl_c.recv();
179                futures::pin_mut!(ctrl_break, ctrl_c);
180                futures::future::select(ctrl_break, ctrl_c).await;
181            };
182
183            axum::Server::from_tcp(listener)
184                .map_err(|e| anyhow!(e))?
185                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
186                .with_graceful_shutdown(async move {
187                    signal.await;
188                    tracing::info!("Received interrupt signal");
189
190                    if let Some(on_shutdown) = on_shutdown {
191                        on_shutdown();
192                    }
193                })
194                .await
195                .map_err(|e| anyhow!(e))?;
196        }
197        _ => {
198            Err(anyhow!(
199                "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
200            ))?;
201        }
202    }
203    Ok(())
204}
205
206async fn run_migrations(config: &Config) -> Result<()> {
207    let db_options = db::ConnectOptions::new(config.database_url.clone());
208    let mut db = Database::new(db_options, Executor::Production).await?;
209
210    let migrations_path = config.migrations_path.as_deref().unwrap_or_else(|| {
211        #[cfg(feature = "sqlite")]
212        let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
213        #[cfg(not(feature = "sqlite"))]
214        let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
215
216        Path::new(default_migrations)
217    });
218
219    let migrations = db.migrate(&migrations_path, false).await?;
220    for (migration, duration) in migrations {
221        log::info!(
222            "Migrated {} {} {:?}",
223            migration.version,
224            migration.description,
225            duration
226        );
227    }
228
229    db.initialize_notification_kinds().await?;
230
231    if config.seed_path.is_some() {
232        collab::seed::seed(&config, &db, false).await?;
233    }
234
235    return Ok(());
236}
237
238async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
239    format!(
240        "collab {mode:?} v{VERSION} ({})",
241        REVISION.unwrap_or("unknown")
242    )
243}
244
245async fn handle_liveness_probe(
246    app_state: Option<Extension<Arc<AppState>>>,
247    llm_state: Option<Extension<Arc<LlmState>>>,
248) -> Result<String> {
249    if let Some(state) = app_state {
250        state.db.get_all_users(0, 1).await?;
251    }
252
253    if let Some(_llm_state) = llm_state {}
254
255    Ok("ok".to_string())
256}
257
258pub fn init_tracing(config: &Config) -> Option<()> {
259    use std::str::FromStr;
260    use tracing_subscriber::layer::SubscriberExt;
261
262    let filter = EnvFilter::from_str(config.rust_log.as_deref()?).log_err()?;
263
264    tracing_subscriber::registry()
265        .with(if config.log_json.unwrap_or(false) {
266            Box::new(
267                tracing_subscriber::fmt::layer()
268                    .fmt_fields(JsonFields::default())
269                    .event_format(
270                        tracing_subscriber::fmt::format()
271                            .json()
272                            .flatten_event(true)
273                            .with_span_list(false),
274                    )
275                    .with_filter(filter),
276            ) as Box<dyn Layer<_> + Send + Sync>
277        } else {
278            Box::new(
279                tracing_subscriber::fmt::layer()
280                    .event_format(tracing_subscriber::fmt::format().pretty())
281                    .with_filter(filter),
282            )
283        })
284        .init();
285
286    None
287}