seed.rs

  1use crate::db::{self, ChannelRole, NewUserParams};
  2
  3use anyhow::Context as _;
  4use chrono::{DateTime, Utc};
  5use db::Database;
  6use serde::{Deserialize, de::DeserializeOwned};
  7use std::{fs, path::Path};
  8
  9use crate::Config;
 10
 11/// A GitHub user.
 12///
 13/// This representation corresponds to the entries in the `seed/github_users.json` file.
 14#[derive(Debug, Deserialize)]
 15struct GithubUser {
 16    id: i32,
 17    login: String,
 18    email: Option<String>,
 19    name: Option<String>,
 20    created_at: DateTime<Utc>,
 21}
 22
 23#[derive(Deserialize)]
 24struct SeedConfig {
 25    /// Which users to create as admins.
 26    admins: Vec<String>,
 27    /// Which channels to create (all admins are invited to all channels).
 28    channels: Vec<String>,
 29}
 30
 31pub async fn seed(config: &Config, db: &Database, force: bool) -> anyhow::Result<()> {
 32    let client = reqwest::Client::new();
 33
 34    if !db.get_all_users(0, 1).await?.is_empty() && !force {
 35        return Ok(());
 36    }
 37
 38    let seed_path = config
 39        .seed_path
 40        .as_ref()
 41        .context("called seed with no SEED_PATH")?;
 42
 43    let seed_config = load_admins(seed_path)
 44        .context(format!("failed to load {}", seed_path.to_string_lossy()))?;
 45
 46    let mut first_user = None;
 47    let mut others = vec![];
 48
 49    let flag_names = ["language-models"];
 50    let mut flags = Vec::new();
 51
 52    let existing_feature_flags = db.list_feature_flags().await?;
 53
 54    for flag_name in flag_names {
 55        if existing_feature_flags
 56            .iter()
 57            .any(|flag| flag.flag == flag_name)
 58        {
 59            log::info!("Flag {flag_name:?} already exists");
 60            continue;
 61        }
 62
 63        let flag = db
 64            .create_user_flag(flag_name, false)
 65            .await
 66            .unwrap_or_else(|err| panic!("failed to create flag: '{flag_name}': {err}"));
 67        flags.push(flag);
 68    }
 69
 70    for admin_login in seed_config.admins {
 71        let user = fetch_github::<GithubUser>(
 72            &client,
 73            &format!("https://api.github.com/users/{admin_login}"),
 74        )
 75        .await;
 76        let user = db
 77            .create_user(
 78                &user.email.unwrap_or(format!("{admin_login}@example.com")),
 79                user.name.as_deref(),
 80                true,
 81                NewUserParams {
 82                    github_login: user.login,
 83                    github_user_id: user.id,
 84                },
 85            )
 86            .await
 87            .context("failed to create admin user")?;
 88        if first_user.is_none() {
 89            first_user = Some(user.user_id);
 90        } else {
 91            others.push(user.user_id)
 92        }
 93
 94        for flag in &flags {
 95            db.add_user_flag(user.user_id, *flag)
 96                .await
 97                .context(format!(
 98                    "Unable to enable flag '{}' for user '{}'",
 99                    flag, user.user_id
100                ))?;
101        }
102    }
103
104    for channel in seed_config.channels {
105        let (channel, _) = db
106            .create_channel(&channel, None, first_user.unwrap())
107            .await
108            .context("failed to create channel")?;
109
110        for user_id in &others {
111            db.invite_channel_member(
112                channel.id,
113                *user_id,
114                first_user.unwrap(),
115                ChannelRole::Admin,
116            )
117            .await
118            .context("failed to add user to channel")?;
119        }
120    }
121
122    let github_users_filepath = seed_path.parent().unwrap().join("seed/github_users.json");
123    let github_users: Vec<GithubUser> =
124        serde_json::from_str(&fs::read_to_string(github_users_filepath)?)?;
125
126    for github_user in github_users {
127        log::info!("Seeding {:?} from GitHub", github_user.login);
128
129        let user = db
130            .update_or_create_user_by_github_account(
131                &github_user.login,
132                github_user.id,
133                github_user.email.as_deref(),
134                github_user.name.as_deref(),
135                github_user.created_at,
136                None,
137            )
138            .await
139            .expect("failed to insert user");
140
141        for flag in &flags {
142            db.add_user_flag(user.id, *flag).await.context(format!(
143                "Unable to enable flag '{}' for user '{}'",
144                flag, user.id
145            ))?;
146        }
147    }
148
149    Ok(())
150}
151
152fn load_admins(path: impl AsRef<Path>) -> anyhow::Result<SeedConfig> {
153    let file_content = fs::read_to_string(path)?;
154    Ok(serde_json::from_str(&file_content)?)
155}
156
157async fn fetch_github<T: DeserializeOwned>(client: &reqwest::Client, url: &str) -> T {
158    let mut request_builder = client.get(url);
159    if let Ok(github_token) = std::env::var("GITHUB_TOKEN") {
160        request_builder =
161            request_builder.header("Authorization", format!("Bearer {}", github_token));
162    }
163    let response = request_builder
164        .header("user-agent", "zed")
165        .send()
166        .await
167        .unwrap_or_else(|error| panic!("failed to fetch '{url}': {error}"));
168    let response_text = response.text().await.unwrap_or_else(|error| {
169        panic!("failed to fetch '{url}': {error}");
170    });
171    serde_json::from_str(&response_text).unwrap_or_else(|error| {
172        panic!("failed to deserialize github user from '{url}'. Error: '{error}', text: '{response_text}'");
173    })
174}