lib.rs

  1pub mod api;
  2pub mod auth;
  3pub mod db;
  4pub mod env;
  5pub mod executor;
  6pub mod llm;
  7mod rate_limiter;
  8pub mod rpc;
  9pub mod seed;
 10
 11#[cfg(test)]
 12mod tests;
 13
 14use anyhow::anyhow;
 15use aws_config::{BehaviorVersion, Region};
 16use axum::{
 17    http::{HeaderMap, StatusCode},
 18    response::IntoResponse,
 19};
 20use db::{ChannelId, Database};
 21use executor::Executor;
 22pub use rate_limiter::*;
 23use serde::Deserialize;
 24use std::{path::PathBuf, sync::Arc};
 25use util::ResultExt;
 26
/// Result alias used throughout this crate; the error type defaults to [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;

/// Crate-wide error type.
///
/// Implements `IntoResponse`, so it can be returned from axum handlers. `Http`
/// errors keep their own status code; the other variants are reported as
/// `500 Internal Server Error`.
pub enum Error {
    /// An explicit HTTP failure: status code, message body, and response headers.
    Http(StatusCode, String, HeaderMap),
    /// An error reported by the `sea_orm` database layer.
    Database(sea_orm::error::DbErr),
    /// Any other failure, carried as an `anyhow::Error`.
    Internal(anyhow::Error),
    /// An error reported by the Stripe client.
    Stripe(stripe::StripeError),
}
 35
 36impl From<anyhow::Error> for Error {
 37    fn from(error: anyhow::Error) -> Self {
 38        Self::Internal(error)
 39    }
 40}
 41
 42impl From<sea_orm::error::DbErr> for Error {
 43    fn from(error: sea_orm::error::DbErr) -> Self {
 44        Self::Database(error)
 45    }
 46}
 47
 48impl From<stripe::StripeError> for Error {
 49    fn from(error: stripe::StripeError) -> Self {
 50        Self::Stripe(error)
 51    }
 52}
 53
 54impl From<axum::Error> for Error {
 55    fn from(error: axum::Error) -> Self {
 56        Self::Internal(error.into())
 57    }
 58}
 59
 60impl From<axum::http::Error> for Error {
 61    fn from(error: axum::http::Error) -> Self {
 62        Self::Internal(error.into())
 63    }
 64}
 65
 66impl From<serde_json::Error> for Error {
 67    fn from(error: serde_json::Error) -> Self {
 68        Self::Internal(error.into())
 69    }
 70}
 71
 72impl Error {
 73    fn http(code: StatusCode, message: String) -> Self {
 74        Self::Http(code, message, HeaderMap::default())
 75    }
 76}
 77
 78impl IntoResponse for Error {
 79    fn into_response(self) -> axum::response::Response {
 80        match self {
 81            Error::Http(code, message, headers) => {
 82                log::error!("HTTP error {}: {}", code, &message);
 83                (code, headers, message).into_response()
 84            }
 85            Error::Database(error) => {
 86                log::error!(
 87                    "HTTP error {}: {:?}",
 88                    StatusCode::INTERNAL_SERVER_ERROR,
 89                    &error
 90                );
 91                (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
 92            }
 93            Error::Internal(error) => {
 94                log::error!(
 95                    "HTTP error {}: {:?}",
 96                    StatusCode::INTERNAL_SERVER_ERROR,
 97                    &error
 98                );
 99                (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
100            }
101            Error::Stripe(error) => {
102                log::error!(
103                    "HTTP error {}: {:?}",
104                    StatusCode::INTERNAL_SERVER_ERROR,
105                    &error
106                );
107                (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
108            }
109        }
110    }
111}
112
113impl std::fmt::Debug for Error {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        match self {
116            Error::Http(code, message, _headers) => (code, message).fmt(f),
117            Error::Database(error) => error.fmt(f),
118            Error::Internal(error) => error.fmt(f),
119            Error::Stripe(error) => error.fmt(f),
120        }
121    }
122}
123
124impl std::fmt::Display for Error {
125    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126        match self {
127            Error::Http(code, message, _) => write!(f, "{code}: {message}"),
128            Error::Database(error) => error.fmt(f),
129            Error::Internal(error) => error.fmt(f),
130            Error::Stripe(error) => error.fmt(f),
131        }
132    }
133}
134
135impl std::error::Error for Error {}
136
/// Service configuration, deserialized via `serde`.
///
/// Several integrations are optional; `AppState::new` and the `build_*_client`
/// helpers determine which settings each one requires.
#[derive(Clone, Deserialize)]
pub struct Config {
    pub http_port: u16,
    /// Connection URL for the database.
    pub database_url: String,
    pub migrations_path: Option<PathBuf>,
    pub seed_path: Option<PathBuf>,
    /// Passed to the database connect options as `max_connections`.
    pub database_max_connections: u32,
    pub api_token: String,
    /// When set, a ClickHouse client is built; the user, password, and database
    /// settings below are then all required.
    pub clickhouse_url: Option<String>,
    pub clickhouse_user: Option<String>,
    pub clickhouse_password: Option<String>,
    pub clickhouse_database: Option<String>,
    pub invite_link_prefix: String,
    /// The LiveKit client is only built when server, key, and secret are all set.
    pub live_kit_server: Option<String>,
    pub live_kit_key: Option<String>,
    pub live_kit_secret: Option<String>,
    pub llm_api_secret: Option<String>,
    pub rust_log: Option<String>,
    pub log_json: Option<bool>,
    /// Blob store (S3-compatible) endpoint. Together with the region, access key, and
    /// secret key, it is required to build the blob store client.
    pub blob_store_url: Option<String>,
    pub blob_store_region: Option<String>,
    pub blob_store_access_key: Option<String>,
    pub blob_store_secret_key: Option<String>,
    /// Not read when building the blob store client itself.
    pub blob_store_bucket: Option<String>,
    /// Deployment environment name. `"development"` and `"staging"` are recognized
    /// explicitly; any other value is treated as production by `zed_dot_dev_url`.
    pub zed_environment: Arc<str>,
    pub openai_api_key: Option<Arc<str>>,
    pub google_ai_api_key: Option<Arc<str>>,
    pub anthropic_api_key: Option<Arc<str>>,
    pub qwen2_7b_api_key: Option<Arc<str>>,
    pub qwen2_7b_api_url: Option<Arc<str>>,
    pub zed_client_checksum_seed: Option<String>,
    pub slack_panics_webhook: Option<String>,
    pub auto_join_channel_id: Option<ChannelId>,
    /// Required to build the Stripe client.
    pub stripe_api_key: Option<String>,
    pub stripe_price_id: Option<Arc<str>>,
    pub supermaven_admin_api_key: Option<Arc<str>>,
}
174
175impl Config {
176    pub fn is_development(&self) -> bool {
177        self.zed_environment == "development".into()
178    }
179
180    /// Returns the base `zed.dev` URL.
181    pub fn zed_dot_dev_url(&self) -> &str {
182        match self.zed_environment.as_ref() {
183            "development" => "http://localhost:3000",
184            "staging" => "https://staging.zed.dev",
185            _ => "https://zed.dev",
186        }
187    }
188}
189
/// Which of collab's services this process runs.
///
/// `All` enables every service; each other variant enables only the service it names.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ServiceMode {
    Api,
    Collab,
    Llm,
    All,
}

impl ServiceMode {
    /// Whether the collab service runs in this mode.
    pub fn is_collab(&self) -> bool {
        match *self {
            ServiceMode::Collab | ServiceMode::All => true,
            ServiceMode::Api | ServiceMode::Llm => false,
        }
    }

    /// Whether the API service runs in this mode.
    pub fn is_api(&self) -> bool {
        match *self {
            ServiceMode::Api | ServiceMode::All => true,
            ServiceMode::Collab | ServiceMode::Llm => false,
        }
    }

    /// Whether the LLM service runs in this mode.
    pub fn is_llm(&self) -> bool {
        match *self {
            ServiceMode::Llm | ServiceMode::All => true,
            ServiceMode::Api | ServiceMode::Collab => false,
        }
    }
}
212
/// Shared application state; see [`AppState::new`] for how each field is populated.
pub struct AppState {
    /// Database handle; the same `Arc` is also handed to `rate_limiter`.
    pub db: Arc<Database>,
    /// LiveKit API client; `Some` only when `live_kit_server`, `live_kit_key`, and
    /// `live_kit_secret` are all configured.
    pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// S3 client for the blob store; `None` when it could not be built from `config`.
    pub blob_store_client: Option<aws_sdk_s3::Client>,
    /// Stripe API client; `None` when `stripe_api_key` is not configured.
    pub stripe_client: Option<Arc<stripe::Client>>,
    /// Rate limiter backed by `db`.
    pub rate_limiter: Arc<RateLimiter>,
    /// The executor passed to [`AppState::new`].
    pub executor: Executor,
    /// ClickHouse client; `None` unless `clickhouse_url` is set and the client builds.
    pub clickhouse_client: Option<clickhouse::Client>,
    /// The configuration this state was built from.
    pub config: Config,
}
223
224impl AppState {
225    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
226        let mut db_options = db::ConnectOptions::new(config.database_url.clone());
227        db_options.max_connections(config.database_max_connections);
228        let mut db = Database::new(db_options, Executor::Production).await?;
229        db.initialize_notification_kinds().await?;
230
231        let live_kit_client = if let Some(((server, key), secret)) = config
232            .live_kit_server
233            .as_ref()
234            .zip(config.live_kit_key.as_ref())
235            .zip(config.live_kit_secret.as_ref())
236        {
237            Some(Arc::new(live_kit_server::api::LiveKitClient::new(
238                server.clone(),
239                key.clone(),
240                secret.clone(),
241            )) as Arc<dyn live_kit_server::api::Client>)
242        } else {
243            None
244        };
245
246        let db = Arc::new(db);
247        let this = Self {
248            db: db.clone(),
249            live_kit_client,
250            blob_store_client: build_blob_store_client(&config).await.log_err(),
251            stripe_client: build_stripe_client(&config)
252                .await
253                .map(|client| Arc::new(client))
254                .log_err(),
255            rate_limiter: Arc::new(RateLimiter::new(db)),
256            executor,
257            clickhouse_client: config
258                .clickhouse_url
259                .as_ref()
260                .and_then(|_| build_clickhouse_client(&config).log_err()),
261            config,
262        };
263        Ok(Arc::new(this))
264    }
265}
266
267async fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
268    let api_key = config
269        .stripe_api_key
270        .as_ref()
271        .ok_or_else(|| anyhow!("missing stripe_api_key"))?;
272
273    Ok(stripe::Client::new(api_key))
274}
275
276async fn build_blob_store_client(config: &Config) -> anyhow::Result<aws_sdk_s3::Client> {
277    let keys = aws_sdk_s3::config::Credentials::new(
278        config
279            .blob_store_access_key
280            .clone()
281            .ok_or_else(|| anyhow!("missing blob_store_access_key"))?,
282        config
283            .blob_store_secret_key
284            .clone()
285            .ok_or_else(|| anyhow!("missing blob_store_secret_key"))?,
286        None,
287        None,
288        "env",
289    );
290
291    let s3_config = aws_config::defaults(BehaviorVersion::latest())
292        .endpoint_url(
293            config
294                .blob_store_url
295                .as_ref()
296                .ok_or_else(|| anyhow!("missing blob_store_url"))?,
297        )
298        .region(Region::new(
299            config
300                .blob_store_region
301                .clone()
302                .ok_or_else(|| anyhow!("missing blob_store_region"))?,
303        ))
304        .credentials_provider(keys)
305        .load()
306        .await;
307
308    Ok(aws_sdk_s3::Client::new(&s3_config))
309}
310
311fn build_clickhouse_client(config: &Config) -> anyhow::Result<clickhouse::Client> {
312    Ok(clickhouse::Client::default()
313        .with_url(
314            config
315                .clickhouse_url
316                .as_ref()
317                .ok_or_else(|| anyhow!("missing clickhouse_url"))?,
318        )
319        .with_user(
320            config
321                .clickhouse_user
322                .as_ref()
323                .ok_or_else(|| anyhow!("missing clickhouse_user"))?,
324        )
325        .with_password(
326            config
327                .clickhouse_password
328                .as_ref()
329                .ok_or_else(|| anyhow!("missing clickhouse_password"))?,
330        )
331        .with_database(
332            config
333                .clickhouse_database
334                .as_ref()
335                .ok_or_else(|| anyhow!("missing clickhouse_database"))?,
336        ))
337}