lib.rs

  1pub mod api;
  2pub mod auth;
  3pub mod db;
  4pub mod env;
  5pub mod executor;
  6pub mod llm;
  7mod rate_limiter;
  8pub mod rpc;
  9pub mod seed;
 10
 11#[cfg(test)]
 12mod tests;
 13
 14use anyhow::anyhow;
 15use aws_config::{BehaviorVersion, Region};
 16use axum::{http::StatusCode, response::IntoResponse};
 17use db::{ChannelId, Database};
 18use executor::Executor;
 19pub use rate_limiter::*;
 20use serde::Deserialize;
 21use std::{path::PathBuf, sync::Arc};
 22use util::ResultExt;
 23
 24pub type Result<T, E = Error> = std::result::Result<T, E>;
 25
/// Error type for collab request handling; converts into an HTTP response via
/// `IntoResponse`.
pub enum Error {
    /// An explicit HTTP status code and message; the response uses both as-is.
    Http(StatusCode, String),
    /// A sea-orm database error; answered with `500 Internal Server Error`.
    Database(sea_orm::error::DbErr),
    /// Any other internal failure; answered with `500 Internal Server Error`.
    Internal(anyhow::Error),
    /// A Stripe API error; answered with `500 Internal Server Error`.
    Stripe(stripe::StripeError),
}
 32
 33impl From<anyhow::Error> for Error {
 34    fn from(error: anyhow::Error) -> Self {
 35        Self::Internal(error)
 36    }
 37}
 38
 39impl From<sea_orm::error::DbErr> for Error {
 40    fn from(error: sea_orm::error::DbErr) -> Self {
 41        Self::Database(error)
 42    }
 43}
 44
 45impl From<stripe::StripeError> for Error {
 46    fn from(error: stripe::StripeError) -> Self {
 47        Self::Stripe(error)
 48    }
 49}
 50
 51impl From<axum::Error> for Error {
 52    fn from(error: axum::Error) -> Self {
 53        Self::Internal(error.into())
 54    }
 55}
 56
 57impl From<axum::http::Error> for Error {
 58    fn from(error: axum::http::Error) -> Self {
 59        Self::Internal(error.into())
 60    }
 61}
 62
 63impl From<serde_json::Error> for Error {
 64    fn from(error: serde_json::Error) -> Self {
 65        Self::Internal(error.into())
 66    }
 67}
 68
 69impl IntoResponse for Error {
 70    fn into_response(self) -> axum::response::Response {
 71        match self {
 72            Error::Http(code, message) => {
 73                log::error!("HTTP error {}: {}", code, &message);
 74                (code, message).into_response()
 75            }
 76            Error::Database(error) => {
 77                log::error!(
 78                    "HTTP error {}: {:?}",
 79                    StatusCode::INTERNAL_SERVER_ERROR,
 80                    &error
 81                );
 82                (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
 83            }
 84            Error::Internal(error) => {
 85                log::error!(
 86                    "HTTP error {}: {:?}",
 87                    StatusCode::INTERNAL_SERVER_ERROR,
 88                    &error
 89                );
 90                (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
 91            }
 92            Error::Stripe(error) => {
 93                log::error!(
 94                    "HTTP error {}: {:?}",
 95                    StatusCode::INTERNAL_SERVER_ERROR,
 96                    &error
 97                );
 98                (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
 99            }
100        }
101    }
102}
103
104impl std::fmt::Debug for Error {
105    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106        match self {
107            Error::Http(code, message) => (code, message).fmt(f),
108            Error::Database(error) => error.fmt(f),
109            Error::Internal(error) => error.fmt(f),
110            Error::Stripe(error) => error.fmt(f),
111        }
112    }
113}
114
115impl std::fmt::Display for Error {
116    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117        match self {
118            Error::Http(code, message) => write!(f, "{code}: {message}"),
119            Error::Database(error) => error.fmt(f),
120            Error::Internal(error) => error.fmt(f),
121            Error::Stripe(error) => error.fmt(f),
122        }
123    }
124}
125
126impl std::error::Error for Error {}
127
/// Collab server configuration.
///
/// Deserialized with serde. NOTE(review): presumably populated from environment
/// variables (see the `env` module); confirm the source.
#[derive(Clone, Deserialize)]
pub struct Config {
    pub http_port: u16,
    /// Database connection URL; passed to `db::ConnectOptions::new` in `AppState::new`.
    pub database_url: String,
    pub migrations_path: Option<PathBuf>,
    pub seed_path: Option<PathBuf>,
    /// Passed to `ConnectOptions::max_connections` when connecting to the database.
    pub database_max_connections: u32,
    pub api_token: String,
    // ClickHouse: `AppState::new` only attempts a client when `clickhouse_url` is
    // set, and `build_clickhouse_client` fails unless all four values are present.
    pub clickhouse_url: Option<String>,
    pub clickhouse_user: Option<String>,
    pub clickhouse_password: Option<String>,
    pub clickhouse_database: Option<String>,
    pub invite_link_prefix: String,
    // LiveKit: a LiveKit client is created only when server, key, and secret are all set.
    pub live_kit_server: Option<String>,
    pub live_kit_key: Option<String>,
    pub live_kit_secret: Option<String>,
    pub rust_log: Option<String>,
    pub log_json: Option<bool>,
    // Blob store (S3-compatible, via `aws_sdk_s3`): `build_blob_store_client`
    // requires url, region, access key, and secret key. `blob_store_bucket` is
    // not read in this file.
    pub blob_store_url: Option<String>,
    pub blob_store_region: Option<String>,
    pub blob_store_access_key: Option<String>,
    pub blob_store_secret_key: Option<String>,
    pub blob_store_bucket: Option<String>,
    /// Deployment environment name. `"development"` makes `is_development` return
    /// `true`; `"development"`, `"staging"`, and anything else select the URL
    /// returned by `zed_dot_dev_url`.
    pub zed_environment: Arc<str>,
    pub openai_api_key: Option<Arc<str>>,
    pub google_ai_api_key: Option<Arc<str>>,
    pub anthropic_api_key: Option<Arc<str>>,
    pub qwen2_7b_api_key: Option<Arc<str>>,
    pub qwen2_7b_api_url: Option<Arc<str>>,
    pub zed_client_checksum_seed: Option<String>,
    pub slack_panics_webhook: Option<String>,
    pub auto_join_channel_id: Option<ChannelId>,
    /// Required by `build_stripe_client`; without it `AppState::stripe_client` is `None`.
    pub stripe_api_key: Option<String>,
    pub stripe_price_id: Option<Arc<str>>,
    pub supermaven_admin_api_key: Option<Arc<str>>,
}
164
165impl Config {
166    pub fn is_development(&self) -> bool {
167        self.zed_environment == "development".into()
168    }
169
170    /// Returns the base `zed.dev` URL.
171    pub fn zed_dot_dev_url(&self) -> &str {
172        match self.zed_environment.as_ref() {
173            "development" => "http://localhost:3000",
174            "staging" => "https://staging.zed.dev",
175            _ => "https://zed.dev",
176        }
177    }
178}
179
/// The service mode that collab should run in.
///
/// `All` enables every service, so each `is_*` predicate returns `true` for it.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ServiceMode {
    Api,
    Collab,
    Llm,
    All,
}

impl ServiceMode {
    /// Returns `true` for [`ServiceMode::Collab`] and [`ServiceMode::All`].
    pub fn is_collab(&self) -> bool {
        match self {
            Self::Collab | Self::All => true,
            Self::Api | Self::Llm => false,
        }
    }

    /// Returns `true` for [`ServiceMode::Api`] and [`ServiceMode::All`].
    pub fn is_api(&self) -> bool {
        match self {
            Self::Api | Self::All => true,
            Self::Collab | Self::Llm => false,
        }
    }

    /// Returns `true` for [`ServiceMode::Llm`] and [`ServiceMode::All`].
    pub fn is_llm(&self) -> bool {
        match self {
            Self::Llm | Self::All => true,
            Self::Api | Self::Collab => false,
        }
    }
}
202
/// Application state built by `AppState::new`, which returns it wrapped in an `Arc`.
pub struct AppState {
    /// Database handle; the same `Arc` also backs `rate_limiter`.
    pub db: Arc<Database>,
    /// Present only when `live_kit_server`, `live_kit_key`, and `live_kit_secret` are all configured.
    pub live_kit_client: Option<Arc<dyn live_kit_server::api::Client>>,
    /// `None` when `build_blob_store_client` fails (e.g. missing blob store settings).
    pub blob_store_client: Option<aws_sdk_s3::Client>,
    /// `None` when `build_stripe_client` fails (e.g. no `stripe_api_key`).
    pub stripe_client: Option<Arc<stripe::Client>>,
    /// Rate limiter constructed from `db`.
    pub rate_limiter: Arc<RateLimiter>,
    pub executor: Executor,
    /// Present only when `clickhouse_url` is set and `build_clickhouse_client` succeeds.
    pub clickhouse_client: Option<clickhouse::Client>,
    pub config: Config,
}
213
214impl AppState {
215    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
216        let mut db_options = db::ConnectOptions::new(config.database_url.clone());
217        db_options.max_connections(config.database_max_connections);
218        let mut db = Database::new(db_options, Executor::Production).await?;
219        db.initialize_notification_kinds().await?;
220
221        let live_kit_client = if let Some(((server, key), secret)) = config
222            .live_kit_server
223            .as_ref()
224            .zip(config.live_kit_key.as_ref())
225            .zip(config.live_kit_secret.as_ref())
226        {
227            Some(Arc::new(live_kit_server::api::LiveKitClient::new(
228                server.clone(),
229                key.clone(),
230                secret.clone(),
231            )) as Arc<dyn live_kit_server::api::Client>)
232        } else {
233            None
234        };
235
236        let db = Arc::new(db);
237        let this = Self {
238            db: db.clone(),
239            live_kit_client,
240            blob_store_client: build_blob_store_client(&config).await.log_err(),
241            stripe_client: build_stripe_client(&config)
242                .await
243                .map(|client| Arc::new(client))
244                .log_err(),
245            rate_limiter: Arc::new(RateLimiter::new(db)),
246            executor,
247            clickhouse_client: config
248                .clickhouse_url
249                .as_ref()
250                .and_then(|_| build_clickhouse_client(&config).log_err()),
251            config,
252        };
253        Ok(Arc::new(this))
254    }
255}
256
257async fn build_stripe_client(config: &Config) -> anyhow::Result<stripe::Client> {
258    let api_key = config
259        .stripe_api_key
260        .as_ref()
261        .ok_or_else(|| anyhow!("missing stripe_api_key"))?;
262
263    Ok(stripe::Client::new(api_key))
264}
265
266async fn build_blob_store_client(config: &Config) -> anyhow::Result<aws_sdk_s3::Client> {
267    let keys = aws_sdk_s3::config::Credentials::new(
268        config
269            .blob_store_access_key
270            .clone()
271            .ok_or_else(|| anyhow!("missing blob_store_access_key"))?,
272        config
273            .blob_store_secret_key
274            .clone()
275            .ok_or_else(|| anyhow!("missing blob_store_secret_key"))?,
276        None,
277        None,
278        "env",
279    );
280
281    let s3_config = aws_config::defaults(BehaviorVersion::latest())
282        .endpoint_url(
283            config
284                .blob_store_url
285                .as_ref()
286                .ok_or_else(|| anyhow!("missing blob_store_url"))?,
287        )
288        .region(Region::new(
289            config
290                .blob_store_region
291                .clone()
292                .ok_or_else(|| anyhow!("missing blob_store_region"))?,
293        ))
294        .credentials_provider(keys)
295        .load()
296        .await;
297
298    Ok(aws_sdk_s3::Client::new(&s3_config))
299}
300
301fn build_clickhouse_client(config: &Config) -> anyhow::Result<clickhouse::Client> {
302    Ok(clickhouse::Client::default()
303        .with_url(
304            config
305                .clickhouse_url
306                .as_ref()
307                .ok_or_else(|| anyhow!("missing clickhouse_url"))?,
308        )
309        .with_user(
310            config
311                .clickhouse_user
312                .as_ref()
313                .ok_or_else(|| anyhow!("missing clickhouse_user"))?,
314        )
315        .with_password(
316            config
317                .clickhouse_password
318                .as_ref()
319                .ok_or_else(|| anyhow!("missing clickhouse_password"))?,
320        )
321        .with_database(
322            config
323                .clickhouse_database
324                .as_ref()
325                .ok_or_else(|| anyhow!("missing clickhouse_database"))?,
326        ))
327}