main.rs

  1use anyhow::anyhow;
  2use axum::{
  3    extract::MatchedPath,
  4    http::{Request, Response},
  5    routing::get,
  6    Extension, Router,
  7};
  8use collab::{api::billing::poll_stripe_events_periodically, llm::LlmState, ServiceMode};
  9use collab::{
 10    api::fetch_extensions_from_blob_store_periodically, db, env, executor::Executor,
 11    rpc::ResultExt, AppState, Config, RateLimiter, Result,
 12};
 13use db::Database;
 14use std::{
 15    env::args,
 16    net::{SocketAddr, TcpListener},
 17    path::Path,
 18    sync::Arc,
 19    time::Duration,
 20};
 21#[cfg(unix)]
 22use tokio::signal::unix::SignalKind;
 23use tower_http::trace::TraceLayer;
 24use tracing_subscriber::{
 25    filter::EnvFilter, fmt::format::JsonFields, util::SubscriberInitExt, Layer,
 26};
 27use util::ResultExt as _;
 28
/// Crate version, captured at compile time from Cargo's `CARGO_PKG_VERSION`.
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Git revision, captured at compile time from the `GITHUB_SHA` environment variable.
/// `None` when that variable was not set during the build (callers in this file then
/// display "unknown").
const REVISION: Option<&'static str> = option_env!("GITHUB_SHA");
 31
 32#[tokio::main]
 33async fn main() -> Result<()> {
 34    if let Err(error) = env::load_dotenv() {
 35        eprintln!(
 36            "error loading .env.toml (this is expected in production): {}",
 37            error
 38        );
 39    }
 40
 41    let mut args = args().skip(1);
 42    match args.next().as_deref() {
 43        Some("version") => {
 44            println!("collab v{} ({})", VERSION, REVISION.unwrap_or("unknown"));
 45        }
 46        Some("migrate") => {
 47            let config = envy::from_env::<Config>().expect("error loading config");
 48            run_migrations(&config).await?;
 49        }
 50        Some("seed") => {
 51            let config = envy::from_env::<Config>().expect("error loading config");
 52            let db_options = db::ConnectOptions::new(config.database_url.clone());
 53            let mut db = Database::new(db_options, Executor::Production).await?;
 54            db.initialize_notification_kinds().await?;
 55
 56            collab::seed::seed(&config, &db, true).await?;
 57        }
 58        Some("serve") => {
 59            let mode = match args.next().as_deref() {
 60                Some("collab") => ServiceMode::Collab,
 61                Some("api") => ServiceMode::Api,
 62                Some("llm") => ServiceMode::Llm,
 63                Some("all") => ServiceMode::All,
 64                _ => {
 65                    return Err(anyhow!(
 66                        "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
 67                    ))?;
 68                }
 69            };
 70
 71            let config = envy::from_env::<Config>().expect("error loading config");
 72            init_tracing(&config);
 73            let mut app = Router::new()
 74                .route("/", get(handle_root))
 75                .route("/healthz", get(handle_liveness_probe))
 76                .layer(Extension(mode));
 77
 78            let listener = TcpListener::bind(&format!("0.0.0.0:{}", config.http_port))
 79                .expect("failed to bind TCP listener");
 80
 81            let mut on_shutdown = None;
 82
 83            if mode.is_llm() {
 84                let state = LlmState::new(config.clone(), Executor::Production).await?;
 85
 86                app = app.layer(Extension(state.clone()));
 87            }
 88
 89            if mode.is_collab() || mode.is_api() {
 90                run_migrations(&config).await?;
 91
 92                let state = AppState::new(config, Executor::Production).await?;
 93
 94                if mode.is_collab() {
 95                    state.db.purge_old_embeddings().await.trace_err();
 96                    RateLimiter::save_periodically(
 97                        state.rate_limiter.clone(),
 98                        state.executor.clone(),
 99                    );
100
101                    let epoch = state
102                        .db
103                        .create_server(&state.config.zed_environment)
104                        .await?;
105                    let rpc_server = collab::rpc::Server::new(epoch, state.clone());
106                    rpc_server.start().await?;
107
108                    app = app
109                        .merge(collab::api::routes(rpc_server.clone()))
110                        .merge(collab::rpc::routes(rpc_server.clone()));
111
112                    on_shutdown = Some(Box::new(move || rpc_server.teardown()));
113                }
114
115                if mode.is_api() {
116                    poll_stripe_events_periodically(state.clone());
117                    fetch_extensions_from_blob_store_periodically(state.clone());
118
119                    app = app
120                        .merge(collab::api::events::router())
121                        .merge(collab::api::extensions::router())
122                }
123
124                app = app.layer(Extension(state.clone()));
125            }
126
127            app = app.layer(
128                TraceLayer::new_for_http()
129                    .make_span_with(|request: &Request<_>| {
130                        let matched_path = request
131                            .extensions()
132                            .get::<MatchedPath>()
133                            .map(MatchedPath::as_str);
134
135                        tracing::info_span!(
136                            "http_request",
137                            method = ?request.method(),
138                            matched_path,
139                        )
140                    })
141                    .on_response(
142                        |response: &Response<_>, latency: Duration, _: &tracing::Span| {
143                            let duration_ms = latency.as_micros() as f64 / 1000.;
144                            tracing::info!(
145                                duration_ms,
146                                status = response.status().as_u16(),
147                                "finished processing request"
148                            );
149                        },
150                    ),
151            );
152
153            #[cfg(unix)]
154            let signal = async move {
155                let mut sigterm = tokio::signal::unix::signal(SignalKind::terminate())
156                    .expect("failed to listen for interrupt signal");
157                let mut sigint = tokio::signal::unix::signal(SignalKind::interrupt())
158                    .expect("failed to listen for interrupt signal");
159                let sigterm = sigterm.recv();
160                let sigint = sigint.recv();
161                futures::pin_mut!(sigterm, sigint);
162                futures::future::select(sigterm, sigint).await;
163            };
164
165            #[cfg(windows)]
166            let signal = async move {
167                // todo(windows):
168                // `ctrl_close` does not work well, because tokio's signal handler always returns soon,
169                // but system terminates the application soon after returning CTRL+CLOSE handler.
170                // So we should implement blocking handler to treat CTRL+CLOSE signal.
171                let mut ctrl_break = tokio::signal::windows::ctrl_break()
172                    .expect("failed to listen for interrupt signal");
173                let mut ctrl_c = tokio::signal::windows::ctrl_c()
174                    .expect("failed to listen for interrupt signal");
175                let ctrl_break = ctrl_break.recv();
176                let ctrl_c = ctrl_c.recv();
177                futures::pin_mut!(ctrl_break, ctrl_c);
178                futures::future::select(ctrl_break, ctrl_c).await;
179            };
180
181            axum::Server::from_tcp(listener)
182                .map_err(|e| anyhow!(e))?
183                .serve(app.into_make_service_with_connect_info::<SocketAddr>())
184                .with_graceful_shutdown(async move {
185                    signal.await;
186                    tracing::info!("Received interrupt signal");
187
188                    if let Some(on_shutdown) = on_shutdown {
189                        on_shutdown();
190                    }
191                })
192                .await
193                .map_err(|e| anyhow!(e))?;
194        }
195        _ => {
196            Err(anyhow!(
197                "usage: collab <version | migrate | seed | serve <api|collab|llm|all>>"
198            ))?;
199        }
200    }
201    Ok(())
202}
203
204async fn run_migrations(config: &Config) -> Result<()> {
205    let db_options = db::ConnectOptions::new(config.database_url.clone());
206    let mut db = Database::new(db_options, Executor::Production).await?;
207
208    let migrations_path = config.migrations_path.as_deref().unwrap_or_else(|| {
209        #[cfg(feature = "sqlite")]
210        let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite");
211        #[cfg(not(feature = "sqlite"))]
212        let default_migrations = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
213
214        Path::new(default_migrations)
215    });
216
217    let migrations = db.migrate(&migrations_path, false).await?;
218    for (migration, duration) in migrations {
219        log::info!(
220            "Migrated {} {} {:?}",
221            migration.version,
222            migration.description,
223            duration
224        );
225    }
226
227    db.initialize_notification_kinds().await?;
228
229    if config.seed_path.is_some() {
230        collab::seed::seed(&config, &db, false).await?;
231    }
232
233    return Ok(());
234}
235
236async fn handle_root(Extension(mode): Extension<ServiceMode>) -> String {
237    format!(
238        "collab {mode:?} v{VERSION} ({})",
239        REVISION.unwrap_or("unknown")
240    )
241}
242
243async fn handle_liveness_probe(
244    app_state: Option<Extension<Arc<AppState>>>,
245    llm_state: Option<Extension<Arc<LlmState>>>,
246) -> Result<String> {
247    if let Some(state) = app_state {
248        state.db.get_all_users(0, 1).await?;
249    }
250
251    if let Some(_llm_state) = llm_state {}
252
253    Ok("ok".to_string())
254}
255
256pub fn init_tracing(config: &Config) -> Option<()> {
257    use std::str::FromStr;
258    use tracing_subscriber::layer::SubscriberExt;
259
260    let filter = EnvFilter::from_str(config.rust_log.as_deref()?).log_err()?;
261
262    tracing_subscriber::registry()
263        .with(if config.log_json.unwrap_or(false) {
264            Box::new(
265                tracing_subscriber::fmt::layer()
266                    .fmt_fields(JsonFields::default())
267                    .event_format(
268                        tracing_subscriber::fmt::format()
269                            .json()
270                            .flatten_event(true)
271                            .with_span_list(false),
272                    )
273                    .with_filter(filter),
274            ) as Box<dyn Layer<_> + Send + Sync>
275        } else {
276            Box::new(
277                tracing_subscriber::fmt::layer()
278                    .event_format(tracing_subscriber::fmt::format().pretty())
279                    .with_filter(filter),
280            )
281        })
282        .init();
283
284    None
285}