lib.rs

  1pub mod api;
  2pub mod auth;
  3pub mod db;
  4pub mod env;
  5pub mod executor;
  6pub mod llm;
  7mod rate_limiter;
  8pub mod rpc;
  9pub mod seed;
 10
 11#[cfg(test)]
 12mod tests;
 13
 14use anyhow::anyhow;
 15use aws_config::{BehaviorVersion, Region};
 16use axum::{
 17    http::{HeaderMap, StatusCode},
 18    response::IntoResponse,
 19};
 20use db::{ChannelId, Database};
 21use executor::Executor;
 22pub use rate_limiter::*;
 23use serde::Deserialize;
 24use std::{path::PathBuf, sync::Arc};
 25use util::ResultExt;
 26
/// Result type used throughout this crate; the error type defaults to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;

/// Errors produced by this crate.
///
/// `Http` errors become responses with their own status code and headers;
/// every other variant becomes a `500 Internal Server Error` (see the
/// `IntoResponse` impl for `Error`).
pub enum Error {
    /// An error with an explicit status code, message body, and extra response headers.
    Http(StatusCode, String, HeaderMap),
    /// A database error reported by `sea_orm`.
    Database(sea_orm::error::DbErr),
    /// Any other failure, carried as an `anyhow::Error`.
    Internal(anyhow::Error),
    /// An error returned by the Stripe client.
    Stripe(stripe::StripeError),
}
 35
 36impl From<anyhow::Error> for Error {
 37    fn from(error: anyhow::Error) -> Self {
 38        Self::Internal(error)
 39    }
 40}
 41
 42impl From<sea_orm::error::DbErr> for Error {
 43    fn from(error: sea_orm::error::DbErr) -> Self {
 44        Self::Database(error)
 45    }
 46}
 47
 48impl From<stripe::StripeError> for Error {
 49    fn from(error: stripe::StripeError) -> Self {
 50        Self::Stripe(error)
 51    }
 52}
 53
 54impl From<axum::Error> for Error {
 55    fn from(error: axum::Error) -> Self {
 56        Self::Internal(error.into())
 57    }
 58}
 59
 60impl From<axum::http::Error> for Error {
 61    fn from(error: axum::http::Error) -> Self {
 62        Self::Internal(error.into())
 63    }
 64}
 65
 66impl From<serde_json::Error> for Error {
 67    fn from(error: serde_json::Error) -> Self {
 68        Self::Internal(error.into())
 69    }
 70}
 71
 72impl Error {
 73    fn http(code: StatusCode, message: String) -> Self {
 74        Self::Http(code, message, HeaderMap::default())
 75    }
 76}
 77
 78impl IntoResponse for Error {
 79    fn into_response(self) -> axum::response::Response {
 80        match self {
 81            Error::Http(code, message, headers) => {
 82                log::error!("HTTP error {}: {}", code, &message);
 83                (code, headers, message).into_response()
 84            }
 85            Error::Database(error) => {
 86                log::error!(
 87                    "HTTP error {}: {:?}",
 88                    StatusCode::INTERNAL_SERVER_ERROR,
 89                    &error
 90                );
 91                (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
 92            }
 93            Error::Internal(error) => {
 94                log::error!(
 95                    "HTTP error {}: {:?}",
 96                    StatusCode::INTERNAL_SERVER_ERROR,
 97                    &error
 98                );
 99                (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
100            }
101            Error::Stripe(error) => {
102                log::error!(
103                    "HTTP error {}: {:?}",
104                    StatusCode::INTERNAL_SERVER_ERROR,
105                    &error
106                );
107                (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
108            }
109        }
110    }
111}
112
113impl std::fmt::Debug for Error {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        match self {
116            Error::Http(code, message, _headers) => (code, message).fmt(f),
117            Error::Database(error) => error.fmt(f),
118            Error::Internal(error) => error.fmt(f),
119            Error::Stripe(error) => error.fmt(f),
120        }
121    }
122}
123
124impl std::fmt::Display for Error {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        match self {
127            Error::Http(code, message, _) => write!(f, "{code}: {message}"),
128            Error::Database(error) => error.fmt(f),
129            Error::Internal(error) => error.fmt(f),
130            Error::Stripe(error) => error.fmt(f),
131        }
132    }
133}
134
135impl std::error::Error for Error {}
136
/// Server configuration, deserializable via `serde` (see the `Deserialize` derive).
///
/// Several optional services are only set up by `AppState::new` when the
/// settings they need are present.
#[derive(Clone, Deserialize)]
pub struct Config {
    pub http_port: u16,
    /// Database connection URL; passed to `db::ConnectOptions::new` in `AppState::new`.
    pub database_url: String,
    pub migrations_path: Option<PathBuf>,
    pub seed_path: Option<PathBuf>,
    /// Applied as `ConnectOptions::max_connections` for the database pool.
    pub database_max_connections: u32,
    pub api_token: String,
    /// When set, `AppState::new` tries to build a ClickHouse client; that also
    /// requires `clickhouse_user`, `clickhouse_password`, and `clickhouse_database`.
    pub clickhouse_url: Option<String>,
    pub clickhouse_user: Option<String>,
    pub clickhouse_password: Option<String>,
    pub clickhouse_database: Option<String>,
    pub invite_link_prefix: String,
    /// The LiveKit client is only built when `live_kit_server`, `live_kit_key`,
    /// and `live_kit_secret` are all set.
    pub live_kit_server: Option<String>,
    pub live_kit_key: Option<String>,
    pub live_kit_secret: Option<String>,
    pub llm_api_secret: Option<String>,
    pub rust_log: Option<String>,
    pub log_json: Option<bool>,
    /// Endpoint URL for the `aws_sdk_s3` blob store client. The URL, region,
    /// access key, and secret key are all required to build that client.
    pub blob_store_url: Option<String>,
    pub blob_store_region: Option<String>,
    pub blob_store_access_key: Option<String>,
    pub blob_store_secret_key: Option<String>,
    pub blob_store_bucket: Option<String>,
    /// Deployment environment name. `"development"` makes `is_development`
    /// return true; `"development"` and `"staging"` select non-production
    /// `zed.dev` URLs in `zed_dot_dev_url`.
    pub zed_environment: Arc<str>,
    pub openai_api_key: Option<Arc<str>>,
    pub google_ai_api_key: Option<Arc<str>>,
    pub anthropic_api_key: Option<Arc<str>>,
    pub qwen2_7b_api_key: Option<Arc<str>>,
    pub qwen2_7b_api_url: Option<Arc<str>>,
    pub zed_client_checksum_seed: Option<String>,
    pub slack_panics_webhook: Option<String>,
    pub auto_join_channel_id: Option<ChannelId>,
    /// Required to build the Stripe client in `AppState::new`.
    pub stripe_api_key: Option<String>,
    pub stripe_price_id: Option<Arc<str>>,
    pub supermaven_admin_api_key: Option<Arc<str>>,
}
174
175impl Config {
176    pub fn is_development(&self) -> bool {
177        self.zed_environment == "development".into()
178    }
179
180    /// Returns the base `zed.dev` URL.
181    pub fn zed_dot_dev_url(&self) -> &str {
182        match self.zed_environment.as_ref() {
183            "development" => "http://localhost:3000",
184            "staging" => "https://staging.zed.dev",
185            _ => "https://zed.dev",
186        }
187    }
188
189    #[cfg(test)]
190    pub fn test() -> Self {
191        Self {
192            http_port: 0,
193            database_url: "".into(),
194            database_max_connections: 0,
195            api_token: "".into(),
196            invite_link_prefix: "".into(),
197            live_kit_server: None,
198            live_kit_key: None,
199            live_kit_secret: None,
200            llm_api_secret: None,
201            rust_log: None,
202            log_json: None,
203            zed_environment: "test".into(),
204            blob_store_url: None,
205            blob_store_region: None,
206            blob_store_access_key: None,
207            blob_store_secret_key: None,
208            blob_store_bucket: None,
209            openai_api_key: None,
210            google_ai_api_key: None,
211            anthropic_api_key: None,
212            clickhouse_url: None,
213            clickhouse_user: None,
214            clickhouse_password: None,
215            clickhouse_database: None,
216            zed_client_checksum_seed: None,
217            slack_panics_webhook: None,
218            auto_join_channel_id: None,
219            migrations_path: None,
220            seed_path: None,
221            stripe_api_key: None,
222            stripe_price_id: None,
223            supermaven_admin_api_key: None,
224            qwen2_7b_api_key: None,
225            qwen2_7b_api_url: None,
226        }
227    }
228}
229
/// The service mode that collab should run in.
///
/// `All` enables every service, so each `is_*` query also returns `true` for it.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ServiceMode {
    Api,
    Collab,
    Llm,
    All,
}

impl ServiceMode {
    /// Whether the collab service is enabled in this mode (`Collab` or `All`).
    pub fn is_collab(&self) -> bool {
        match self {
            ServiceMode::Collab | ServiceMode::All => true,
            ServiceMode::Api | ServiceMode::Llm => false,
        }
    }

    /// Whether the API service is enabled in this mode (`Api` or `All`).
    pub fn is_api(&self) -> bool {
        match self {
            ServiceMode::Api | ServiceMode::All => true,
            ServiceMode::Collab | ServiceMode::Llm => false,
        }
    }

    /// Whether the LLM service is enabled in this mode (`Llm` or `All`).
    pub fn is_llm(&self) -> bool {
        match self {
            ServiceMode::Llm | ServiceMode::All => true,
            ServiceMode::Api | ServiceMode::Collab => false,
        }
    }
}
252
/// Shared application state; `AppState::new` returns it wrapped in an `Arc`.
pub struct AppState {
    pub db: Arc<Database>,
    /// `Some` only when `live_kit_server`, `live_kit_key`, and `live_kit_secret`
    /// are all configured.
    pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// `None` when `build_blob_store_client` fails (e.g. a required blob store
    /// setting is missing).
    pub blob_store_client: Option<aws_sdk_s3::Client>,
    /// `None` when no `stripe_api_key` is configured.
    pub stripe_client: Option<Arc<stripe::Client>>,
    /// Constructed from the same database handle stored in `db`.
    pub rate_limiter: Arc<RateLimiter>,
    pub executor: Executor,
    /// `None` unless `clickhouse_url` is set and the client builds successfully.
    pub clickhouse_client: Option<clickhouse::Client>,
    pub config: Config,
}
263
264impl AppState {
265    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
266        let mut db_options = db::ConnectOptions::new(config.database_url.clone());
267        db_options.max_connections(config.database_max_connections);
268        let mut db = Database::new(db_options, Executor::Production).await?;
269        db.initialize_notification_kinds().await?;
270
271        let live_kit_client = if let Some(((server, key), secret)) = config
272            .live_kit_server
273            .as_ref()
274            .zip(config.live_kit_key.as_ref())
275            .zip(config.live_kit_secret.as_ref())
276        {
277            Some(Arc::new(live_kit_server::api::LiveKitClient::new(
278                server.clone(),
279                key.clone(),
280                secret.clone(),
281            )) as Arc<dyn live_kit_server::api::Client>)
282        } else {
283            None
284        };
285
286        let db = Arc::new(db);
287        let this = Self {
288            db: db.clone(),
289            live_kit_client,
290            blob_store_client: build_blob_store_client(&config).await.log_err(),
291            stripe_client: build_stripe_client(&config)
292                .await
293                .map(|client| Arc::new(client))
294                .log_err(),
295            rate_limiter: Arc::new(RateLimiter::new(db)),
296            executor,
297            clickhouse_client: config
298                .clickhouse_url
299                .as_ref()
300                .and_then(|_| build_clickhouse_client(&config).log_err()),
301            config,
302        };
303        Ok(Arc::new(this))
304    }
305}
306
307async fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
308    let api_key = config
309        .stripe_api_key
310        .as_ref()
311        .ok_or_else(|| anyhow!("missing stripe_api_key"))?;
312
313    Ok(stripe::Client::new(api_key))
314}
315
316async fn build_blob_store_client(config: &Config) -> anyhow::Result<aws_sdk_s3::Client> {
317    let keys = aws_sdk_s3::config::Credentials::new(
318        config
319            .blob_store_access_key
320            .clone()
321            .ok_or_else(|| anyhow!("missing blob_store_access_key"))?,
322        config
323            .blob_store_secret_key
324            .clone()
325            .ok_or_else(|| anyhow!("missing blob_store_secret_key"))?,
326        None,
327        None,
328        "env",
329    );
330
331    let s3_config = aws_config::defaults(BehaviorVersion::latest())
332        .endpoint_url(
333            config
334                .blob_store_url
335                .as_ref()
336                .ok_or_else(|| anyhow!("missing blob_store_url"))?,
337        )
338        .region(Region::new(
339            config
340                .blob_store_region
341                .clone()
342                .ok_or_else(|| anyhow!("missing blob_store_region"))?,
343        ))
344        .credentials_provider(keys)
345        .load()
346        .await;
347
348    Ok(aws_sdk_s3::Client::new(&s3_config))
349}
350
351fn build_clickhouse_client(config: &Config) -> anyhow::Result<clickhouse::Client> {
352    Ok(clickhouse::Client::default()
353        .with_url(
354            config
355                .clickhouse_url
356                .as_ref()
357                .ok_or_else(|| anyhow!("missing clickhouse_url"))?,
358        )
359        .with_user(
360            config
361                .clickhouse_user
362                .as_ref()
363                .ok_or_else(|| anyhow!("missing clickhouse_user"))?,
364        )
365        .with_password(
366            config
367                .clickhouse_password
368                .as_ref()
369                .ok_or_else(|| anyhow!("missing clickhouse_password"))?,
370        )
371        .with_database(
372            config
373                .clickhouse_database
374                .as_ref()
375                .ok_or_else(|| anyhow!("missing clickhouse_database"))?,
376        ))
377}