lib.rs

  1pub mod api;
  2pub mod auth;
  3pub mod db;
  4pub mod env;
  5pub mod executor;
  6pub mod llm;
  7pub mod migrations;
  8pub mod rpc;
  9pub mod seed;
 10pub mod user_backfiller;
 11
 12#[cfg(test)]
 13mod tests;
 14
 15use anyhow::Context as _;
 16use aws_config::{BehaviorVersion, Region};
 17use axum::{
 18    http::{HeaderMap, StatusCode},
 19    response::IntoResponse,
 20};
 21use db::{ChannelId, Database};
 22use executor::Executor;
 23use llm::db::LlmDatabase;
 24use serde::Deserialize;
 25use std::{path::PathBuf, sync::Arc};
 26use util::ResultExt;
 27
 28pub type Result<T, E = Error> = std::result::Result<T, E>;
 29
 30pub enum Error {
 31    Http(StatusCode, String, HeaderMap),
 32    Database(sea_orm::error::DbErr),
 33    Internal(anyhow::Error),
 34}
 35
 36impl From<anyhow::Error> for Error {
 37    fn from(error: anyhow::Error) -> Self {
 38        Self::Internal(error)
 39    }
 40}
 41
 42impl From<sea_orm::error::DbErr> for Error {
 43    fn from(error: sea_orm::error::DbErr) -> Self {
 44        Self::Database(error)
 45    }
 46}
 47
 48impl From<axum::Error> for Error {
 49    fn from(error: axum::Error) -> Self {
 50        Self::Internal(error.into())
 51    }
 52}
 53
 54impl From<axum::http::Error> for Error {
 55    fn from(error: axum::http::Error) -> Self {
 56        Self::Internal(error.into())
 57    }
 58}
 59
 60impl From<serde_json::Error> for Error {
 61    fn from(error: serde_json::Error) -> Self {
 62        Self::Internal(error.into())
 63    }
 64}
 65
 66impl Error {
 67    fn http(code: StatusCode, message: String) -> Self {
 68        Self::Http(code, message, HeaderMap::default())
 69    }
 70}
 71
 72impl IntoResponse for Error {
 73    fn into_response(self) -> axum::response::Response {
 74        match self {
 75            Error::Http(code, message, headers) => {
 76                log::error!("HTTP error {}: {}", code, &message);
 77                (code, headers, message).into_response()
 78            }
 79            Error::Database(error) => {
 80                log::error!(
 81                    "HTTP error {}: {:?}",
 82                    StatusCode::INTERNAL_SERVER_ERROR,
 83                    &error
 84                );
 85                (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
 86            }
 87            Error::Internal(error) => {
 88                log::error!(
 89                    "HTTP error {}: {:?}",
 90                    StatusCode::INTERNAL_SERVER_ERROR,
 91                    &error
 92                );
 93                (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response()
 94            }
 95        }
 96    }
 97}
 98
 99impl std::fmt::Debug for Error {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        match self {
102            Error::Http(code, message, _headers) => (code, message).fmt(f),
103            Error::Database(error) => error.fmt(f),
104            Error::Internal(error) => error.fmt(f),
105        }
106    }
107}
108
109impl std::fmt::Display for Error {
110    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111        match self {
112            Error::Http(code, message, _) => write!(f, "{code}: {message}"),
113            Error::Database(error) => error.fmt(f),
114            Error::Internal(error) => error.fmt(f),
115        }
116    }
117}
118
119impl std::error::Error for Error {}
120
/// Server configuration, deserialized with `serde`.
///
/// NOTE(review): the deserialization source (environment variables, a file, ...)
/// is not visible in this file — confirm at the call site.
///
/// NOTE(review): there is no `Debug` derive; presumably intentional because
/// several fields hold credentials and API keys — confirm before adding one.
#[derive(Clone, Deserialize)]
pub struct Config {
    pub http_port: u16,
    /// Primary database connection string; used by `AppState::new`.
    pub database_url: String,
    pub migrations_path: Option<PathBuf>,
    pub seed_path: Option<PathBuf>,
    /// Maximum number of connections for the primary database.
    pub database_max_connections: u32,
    pub api_token: String,
    pub invite_link_prefix: String,
    // LiveKit: `AppState::new` only creates a client when all three are set.
    pub livekit_server: Option<String>,
    pub livekit_key: Option<String>,
    pub livekit_secret: Option<String>,
    // LLM database: `AppState::new` only connects when both
    // `llm_database_url` and `llm_database_max_connections` are set.
    pub llm_database_url: Option<String>,
    pub llm_database_max_connections: Option<u32>,
    pub llm_database_migrations_path: Option<PathBuf>,
    pub llm_api_secret: Option<String>,
    pub rust_log: Option<String>,
    pub log_json: Option<bool>,
    // Blob store (S3 API): `build_blob_store_client` requires the url, region,
    // access key and secret key. `blob_store_bucket` is not read in this file.
    pub blob_store_url: Option<String>,
    pub blob_store_region: Option<String>,
    pub blob_store_access_key: Option<String>,
    pub blob_store_secret_key: Option<String>,
    pub blob_store_bucket: Option<String>,
    // Kinesis: a client is only built when `kinesis_access_key` is set; the
    // secret key and region are then also required. `kinesis_stream` is not
    // read in this file.
    pub kinesis_region: Option<String>,
    pub kinesis_stream: Option<String>,
    pub kinesis_access_key: Option<String>,
    pub kinesis_secret_key: Option<String>,
    /// Deployment environment name. `is_development` checks for
    /// `"development"`; `zed_dot_dev_url` also recognizes `"staging"`.
    pub zed_environment: Arc<str>,
    pub openai_api_key: Option<Arc<str>>,
    pub google_ai_api_key: Option<Arc<str>>,
    pub anthropic_api_key: Option<Arc<str>>,
    pub anthropic_staff_api_key: Option<Arc<str>>,
    pub llm_closed_beta_model_name: Option<Arc<str>>,
    pub prediction_api_url: Option<Arc<str>>,
    pub prediction_api_key: Option<Arc<str>>,
    pub prediction_model: Option<Arc<str>>,
    pub zed_client_checksum_seed: Option<String>,
    pub slack_panics_webhook: Option<String>,
    pub auto_join_channel_id: Option<ChannelId>,
    pub supermaven_admin_api_key: Option<Arc<str>>,
    pub user_backfiller_github_access_token: Option<Arc<str>>,
}
163
164impl Config {
165    pub fn is_development(&self) -> bool {
166        self.zed_environment == "development".into()
167    }
168
169    /// Returns the base `zed.dev` URL.
170    pub fn zed_dot_dev_url(&self) -> &str {
171        match self.zed_environment.as_ref() {
172            "development" => "http://localhost:3000",
173            "staging" => "https://staging.zed.dev",
174            _ => "https://zed.dev",
175        }
176    }
177
178    #[cfg(test)]
179    pub fn test() -> Self {
180        Self {
181            http_port: 0,
182            database_url: "".into(),
183            database_max_connections: 0,
184            api_token: "".into(),
185            invite_link_prefix: "".into(),
186            livekit_server: None,
187            livekit_key: None,
188            livekit_secret: None,
189            llm_database_url: None,
190            llm_database_max_connections: None,
191            llm_database_migrations_path: None,
192            llm_api_secret: None,
193            rust_log: None,
194            log_json: None,
195            zed_environment: "test".into(),
196            blob_store_url: None,
197            blob_store_region: None,
198            blob_store_access_key: None,
199            blob_store_secret_key: None,
200            blob_store_bucket: None,
201            openai_api_key: None,
202            google_ai_api_key: None,
203            anthropic_api_key: None,
204            anthropic_staff_api_key: None,
205            llm_closed_beta_model_name: None,
206            prediction_api_url: None,
207            prediction_api_key: None,
208            prediction_model: None,
209            zed_client_checksum_seed: None,
210            slack_panics_webhook: None,
211            auto_join_channel_id: None,
212            migrations_path: None,
213            seed_path: None,
214            supermaven_admin_api_key: None,
215            user_backfiller_github_access_token: None,
216            kinesis_region: None,
217            kinesis_access_key: None,
218            kinesis_secret_key: None,
219            kinesis_stream: None,
220        }
221    }
222}
223
224/// The service mode that collab should run in.
225#[derive(Debug, PartialEq, Eq, Clone, Copy, strum::Display)]
226#[strum(serialize_all = "snake_case")]
227pub enum ServiceMode {
228    Api,
229    Collab,
230    All,
231}
232
233impl ServiceMode {
234    pub fn is_collab(&self) -> bool {
235        matches!(self, Self::Collab | Self::All)
236    }
237
238    pub fn is_api(&self) -> bool {
239        matches!(self, Self::Api | Self::All)
240    }
241}
242
/// Shared application state, built by `AppState::new` and returned as an
/// `Arc<AppState>`.
pub struct AppState {
    /// Primary database.
    pub db: Arc<Database>,
    /// LLM database; `None` unless both `llm_database_url` and
    /// `llm_database_max_connections` are configured.
    pub llm_db: Option<Arc<LlmDatabase>>,
    /// LiveKit client; `None` unless `livekit_server`, `livekit_key` and
    /// `livekit_secret` are all configured.
    pub livekit_client: Option<Arc<dyn livekit_api::Client>>,
    /// S3 client for the blob store; `None` if `build_blob_store_client` failed.
    pub blob_store_client: Option<aws_sdk_s3::Client>,
    pub executor: Executor,
    /// Kinesis client; `None` if `kinesis_access_key` is unset or
    /// `build_kinesis_client` failed.
    pub kinesis_client: Option<::aws_sdk_kinesis::Client>,
    /// The configuration this state was built from.
    pub config: Config,
}
252
253impl AppState {
254    pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
255        let mut db_options = db::ConnectOptions::new(config.database_url.clone());
256        db_options.max_connections(config.database_max_connections);
257        let mut db = Database::new(db_options).await?;
258        db.initialize_notification_kinds().await?;
259
260        let llm_db = if let Some((llm_database_url, llm_database_max_connections)) = config
261            .llm_database_url
262            .clone()
263            .zip(config.llm_database_max_connections)
264        {
265            let mut llm_db_options = db::ConnectOptions::new(llm_database_url);
266            llm_db_options.max_connections(llm_database_max_connections);
267            let mut llm_db = LlmDatabase::new(llm_db_options, executor.clone()).await?;
268            llm_db.initialize().await?;
269            Some(Arc::new(llm_db))
270        } else {
271            None
272        };
273
274        let livekit_client = if let Some(((server, key), secret)) = config
275            .livekit_server
276            .as_ref()
277            .zip(config.livekit_key.as_ref())
278            .zip(config.livekit_secret.as_ref())
279        {
280            Some(Arc::new(livekit_api::LiveKitClient::new(
281                server.clone(),
282                key.clone(),
283                secret.clone(),
284            )) as Arc<dyn livekit_api::Client>)
285        } else {
286            None
287        };
288
289        let db = Arc::new(db);
290        let this = Self {
291            db: db.clone(),
292            llm_db,
293            livekit_client,
294            blob_store_client: build_blob_store_client(&config).await.log_err(),
295            executor,
296            kinesis_client: if config.kinesis_access_key.is_some() {
297                build_kinesis_client(&config).await.log_err()
298            } else {
299                None
300            },
301            config,
302        };
303        Ok(Arc::new(this))
304    }
305}
306
307async fn build_blob_store_client(config: &Config) -> anyhow::Result<aws_sdk_s3::Client> {
308    let keys = aws_sdk_s3::config::Credentials::new(
309        config
310            .blob_store_access_key
311            .clone()
312            .context("missing blob_store_access_key")?,
313        config
314            .blob_store_secret_key
315            .clone()
316            .context("missing blob_store_secret_key")?,
317        None,
318        None,
319        "env",
320    );
321
322    let s3_config = aws_config::defaults(BehaviorVersion::latest())
323        .endpoint_url(
324            config
325                .blob_store_url
326                .as_ref()
327                .context("missing blob_store_url")?,
328        )
329        .region(Region::new(
330            config
331                .blob_store_region
332                .clone()
333                .context("missing blob_store_region")?,
334        ))
335        .credentials_provider(keys)
336        .load()
337        .await;
338
339    Ok(aws_sdk_s3::Client::new(&s3_config))
340}
341
342async fn build_kinesis_client(config: &Config) -> anyhow::Result<aws_sdk_kinesis::Client> {
343    let keys = aws_sdk_s3::config::Credentials::new(
344        config
345            .kinesis_access_key
346            .clone()
347            .context("missing kinesis_access_key")?,
348        config
349            .kinesis_secret_key
350            .clone()
351            .context("missing kinesis_secret_key")?,
352        None,
353        None,
354        "env",
355    );
356
357    let kinesis_config = aws_config::defaults(BehaviorVersion::latest())
358        .region(Region::new(
359            config
360                .kinesis_region
361                .clone()
362                .context("missing kinesis_region")?,
363        ))
364        .credentials_provider(keys)
365        .load()
366        .await;
367
368    Ok(aws_sdk_kinesis::Client::new(&kinesis_config))
369}