main.rs

  1mod admin;
  2mod assets;
  3mod auth;
  4mod community;
  5mod db;
  6mod env;
  7mod errors;
  8mod expiring;
  9mod github;
 10mod home;
 11mod releases;
 12mod rpc;
 13mod team;
 14
 15use self::errors::TideResultExt as _;
 16use anyhow::Result;
 17use async_std::net::TcpListener;
 18use async_trait::async_trait;
 19use auth::RequestExt as _;
 20use db::Db;
 21use handlebars::{Handlebars, TemplateRenderError};
 22use parking_lot::RwLock;
 23use rust_embed::RustEmbed;
 24use serde::{Deserialize, Serialize};
 25use std::sync::Arc;
 26use surf::http::cookies::SameSite;
 27use tide::{log, sessions::SessionMiddleware};
 28use tide_compress::CompressMiddleware;
 29use zrpc::Peer;
 30
 31type Request = tide::Request<Arc<AppState>>;
 32
 33#[derive(RustEmbed)]
 34#[folder = "templates"]
 35struct Templates;
 36
 37#[derive(Default, Deserialize)]
 38pub struct Config {
 39    pub http_port: u16,
 40    pub database_url: String,
 41    pub session_secret: String,
 42    pub github_app_id: usize,
 43    pub github_client_id: String,
 44    pub github_client_secret: String,
 45    pub github_private_key: String,
 46}
 47
 48pub struct AppState {
 49    db: Db,
 50    handlebars: RwLock<Handlebars<'static>>,
 51    auth_client: auth::Client,
 52    github_client: Arc<github::AppClient>,
 53    repo_client: github::RepoClient,
 54    config: Config,
 55}
 56
 57impl AppState {
 58    async fn new(config: Config) -> tide::Result<Arc<Self>> {
 59        let db = Db::new(&config.database_url, 5).await?;
 60        let github_client =
 61            github::AppClient::new(config.github_app_id, config.github_private_key.clone());
 62        let repo_client = github_client
 63            .repo("zed-industries/zed".into())
 64            .await
 65            .context("failed to initialize github client")?;
 66
 67        let this = Self {
 68            db,
 69            handlebars: Default::default(),
 70            auth_client: auth::build_client(&config.github_client_id, &config.github_client_secret),
 71            github_client,
 72            repo_client,
 73            config,
 74        };
 75        this.register_partials();
 76        Ok(Arc::new(this))
 77    }
 78
 79    fn register_partials(&self) {
 80        for path in Templates::iter() {
 81            if let Some(partial_name) = path
 82                .strip_prefix("partials/")
 83                .and_then(|path| path.strip_suffix(".hbs"))
 84            {
 85                let partial = Templates::get(path.as_ref()).unwrap();
 86                self.handlebars
 87                    .write()
 88                    .register_partial(partial_name, std::str::from_utf8(&partial.data).unwrap())
 89                    .unwrap()
 90            }
 91        }
 92    }
 93
 94    fn render_template(
 95        &self,
 96        path: &'static str,
 97        data: &impl Serialize,
 98    ) -> Result<String, TemplateRenderError> {
 99        #[cfg(debug_assertions)]
100        self.register_partials();
101
102        self.handlebars.read().render_template(
103            std::str::from_utf8(&Templates::get(path).unwrap().data).unwrap(),
104            data,
105        )
106    }
107}
108
109#[async_trait]
110trait RequestExt {
111    async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>>;
112    fn db(&self) -> &Db;
113}
114
115#[async_trait]
116impl RequestExt for Request {
117    async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>> {
118        if self.ext::<Arc<LayoutData>>().is_none() {
119            self.set_ext(Arc::new(LayoutData {
120                current_user: self.current_user().await?,
121            }));
122        }
123        Ok(self.ext::<Arc<LayoutData>>().unwrap().clone())
124    }
125
126    fn db(&self) -> &Db {
127        &self.state().db
128    }
129}
130
131#[derive(Serialize)]
132struct LayoutData {
133    current_user: Option<auth::User>,
134}
135
136#[async_std::main]
137async fn main() -> tide::Result<()> {
138    log::start();
139
140    if let Err(error) = env::load_dotenv() {
141        log::error!(
142            "error loading .env.toml (this is expected in production): {}",
143            error
144        );
145    }
146
147    let config = envy::from_env::<Config>().expect("error loading config");
148    let state = AppState::new(config).await?;
149    let rpc = Peer::new();
150    run_server(
151        state.clone(),
152        rpc,
153        TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port)).await?,
154    )
155    .await?;
156    Ok(())
157}
158
159pub async fn run_server(
160    state: Arc<AppState>,
161    rpc: Arc<Peer>,
162    listener: TcpListener,
163) -> tide::Result<()> {
164    let mut web = tide::with_state(state.clone());
165    web.with(CompressMiddleware::new());
166    web.with(
167        SessionMiddleware::new(
168            db::SessionStore::new_with_table_name(&state.config.database_url, "sessions")
169                .await
170                .unwrap(),
171            state.config.session_secret.as_bytes(),
172        )
173        .with_same_site_policy(SameSite::Lax), // Required obtain our session in /auth_callback
174    );
175    web.with(errors::Middleware);
176    home::add_routes(&mut web);
177    team::add_routes(&mut web);
178    releases::add_routes(&mut web);
179    community::add_routes(&mut web);
180    admin::add_routes(&mut web);
181    auth::add_routes(&mut web);
182
183    let mut assets = tide::new();
184    assets.with(CompressMiddleware::new());
185    assets::add_routes(&mut assets);
186
187    let mut app = tide::with_state(state.clone());
188    rpc::add_routes(&mut app, &rpc);
189
190    app.at("/").nest(web);
191    app.at("/static").nest(assets);
192
193    app.listen(listener).await?;
194
195    Ok(())
196}