Replace `lazy_static!` with `OnceLock` in `collab` crate (#8677)

Marshall Bowers created

This PR replaces a `lazy_static!` usage in the `collab` crate with
`OnceLock` from the standard library.

This allows us to drop the `lazy_static` dependency from this crate.

Release Notes:

- N/A

Change summary

crates/collab/Cargo.toml                           |  1 
crates/collab/src/api/events.rs                    | 14 +---
crates/collab/src/auth.rs                          | 23 ++++---
crates/collab/src/rpc.rs                           | 40 ++++++-------
crates/collab/src/tests/randomized_test_helpers.rs | 45 +++++++++++----
5 files changed, 67 insertions(+), 56 deletions(-)

Detailed changes

crates/collab/Cargo.toml 🔗

@@ -33,7 +33,6 @@ envy = "0.4.2"
 futures.workspace = true
 hex.workspace = true
 hyper = "0.14"
-lazy_static.workspace = true
 live_kit_server.workspace = true
 log.workspace = true
 nanoid = "0.4"

crates/collab/src/api/events.rs 🔗

@@ -1,11 +1,10 @@
-use std::sync::Arc;
+use std::sync::{Arc, OnceLock};
 
 use anyhow::{anyhow, Context};
 use axum::{
     body::Bytes, headers::Header, http::HeaderName, routing::post, Extension, Router, TypedHeader,
 };
 use hyper::StatusCode;
-use lazy_static::lazy_static;
 use serde::{Serialize, Serializer};
 use sha2::{Digest, Sha256};
 use telemetry_events::{
@@ -19,16 +18,12 @@ pub fn router() -> Router {
     Router::new().route("/telemetry/events", post(post_events))
 }
 
-lazy_static! {
-    static ref ZED_CHECKSUM_HEADER: HeaderName = HeaderName::from_static("x-zed-checksum");
-    static ref CLOUDFLARE_IP_COUNTRY_HEADER: HeaderName = HeaderName::from_static("cf-ipcountry");
-}
-
 pub struct ZedChecksumHeader(Vec<u8>);
 
 impl Header for ZedChecksumHeader {
     fn name() -> &'static HeaderName {
-        &ZED_CHECKSUM_HEADER
+        static ZED_CHECKSUM_HEADER: OnceLock<HeaderName> = OnceLock::new();
+        ZED_CHECKSUM_HEADER.get_or_init(|| HeaderName::from_static("x-zed-checksum"))
     }
 
     fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
@@ -55,7 +50,8 @@ pub struct CloudflareIpCountryHeader(String);
 
 impl Header for CloudflareIpCountryHeader {
     fn name() -> &'static HeaderName {
-        &CLOUDFLARE_IP_COUNTRY_HEADER
+        static CLOUDFLARE_IP_COUNTRY_HEADER: OnceLock<HeaderName> = OnceLock::new();
+        CLOUDFLARE_IP_COUNTRY_HEADER.get_or_init(|| HeaderName::from_static("cf-ipcountry"))
     }
 
     fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>

crates/collab/src/auth.rs 🔗

@@ -8,7 +8,6 @@ use axum::{
     middleware::Next,
     response::IntoResponse,
 };
-use lazy_static::lazy_static;
 use prometheus::{exponential_buckets, register_histogram, Histogram};
 use rand::thread_rng;
 use scrypt::{
@@ -16,17 +15,9 @@ use scrypt::{
     Scrypt,
 };
 use serde::{Deserialize, Serialize};
+use std::sync::OnceLock;
 use std::{sync::Arc, time::Instant};
 
-lazy_static! {
-    static ref METRIC_ACCESS_TOKEN_HASHING_TIME: Histogram = register_histogram!(
-        "access_token_hashing_time",
-        "time spent hashing access tokens",
-        exponential_buckets(10.0, 2.0, 10).unwrap(),
-    )
-    .unwrap();
-}
-
 #[derive(Clone, Debug, Default, PartialEq, Eq)]
 pub struct Impersonator(pub Option<db::User>);
 
@@ -182,6 +173,16 @@ pub async fn verify_access_token(
     user_id: UserId,
     db: &Arc<Database>,
 ) -> Result<VerifyAccessTokenResult> {
+    static METRIC_ACCESS_TOKEN_HASHING_TIME: OnceLock<Histogram> = OnceLock::new();
+    let metric_access_token_hashing_time = METRIC_ACCESS_TOKEN_HASHING_TIME.get_or_init(|| {
+        register_histogram!(
+            "access_token_hashing_time",
+            "time spent hashing access tokens",
+            exponential_buckets(10.0, 2.0, 10).unwrap(),
+        )
+        .unwrap()
+    });
+
     let token: AccessTokenJson = serde_json::from_str(&token)?;
 
     let db_token = db.get_access_token(token.id).await?;
@@ -197,7 +198,7 @@ pub async fn verify_access_token(
         .is_ok();
     let duration = t0.elapsed();
     log::info!("hashed access token in {:?}", duration);
-    METRIC_ACCESS_TOKEN_HASHING_TIME.observe(duration.as_millis() as f64);
+    metric_access_token_hashing_time.observe(duration.as_millis() as f64);
     Ok(VerifyAccessTokenResult {
         is_valid,
         impersonator_id: if db_token.impersonated_user_id.is_some() {

crates/collab/src/rpc.rs 🔗

@@ -35,7 +35,6 @@ use futures::{
     stream::FuturesUnordered,
     FutureExt, SinkExt, StreamExt, TryStreamExt,
 };
-use lazy_static::lazy_static;
 use prometheus::{register_int_gauge, IntGauge};
 use rpc::{
     proto::{
@@ -56,7 +55,7 @@ use std::{
     rc::Rc,
     sync::{
         atomic::{AtomicBool, Ordering::SeqCst},
-        Arc,
+        Arc, OnceLock,
     },
     time::{Duration, Instant},
 };
@@ -73,16 +72,6 @@ const MESSAGE_COUNT_PER_PAGE: usize = 100;
 const MAX_MESSAGE_LEN: usize = 1024;
 const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
 
-lazy_static! {
-    static ref METRIC_CONNECTIONS: IntGauge =
-        register_int_gauge!("connections", "number of connections").unwrap();
-    static ref METRIC_SHARED_PROJECTS: IntGauge = register_int_gauge!(
-        "shared_projects",
-        "number of open projects with one or more guests"
-    )
-    .unwrap();
-}
-
 type MessageHandler =
     Box<dyn Send + Sync + Fn(Box<dyn AnyTypedEnvelope>, Session) -> BoxFuture<'static, ()>>;
 
@@ -793,16 +782,12 @@ fn broadcast<F>(
     }
 }
 
-lazy_static! {
-    static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version");
-    static ref ZED_APP_VERSION: HeaderName = HeaderName::from_static("x-zed-app-version");
-}
-
 pub struct ProtocolVersion(u32);
 
 impl Header for ProtocolVersion {
     fn name() -> &'static HeaderName {
-        &ZED_PROTOCOL_VERSION
+        static ZED_PROTOCOL_VERSION: OnceLock<HeaderName> = OnceLock::new();
+        ZED_PROTOCOL_VERSION.get_or_init(|| HeaderName::from_static("x-zed-protocol-version"))
     }
 
     fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
@@ -828,7 +813,8 @@ impl Header for ProtocolVersion {
 pub struct AppVersionHeader(SemanticVersion);
 impl Header for AppVersionHeader {
     fn name() -> &'static HeaderName {
-        &ZED_APP_VERSION
+        static ZED_APP_VERSION: OnceLock<HeaderName> = OnceLock::new();
+        ZED_APP_VERSION.get_or_init(|| HeaderName::from_static("x-zed-app-version"))
     }
 
     fn decode<'i, I>(values: &mut I) -> Result<Self, axum::headers::Error>
@@ -922,17 +908,29 @@ pub async fn handle_websocket_request(
 }
 
 pub async fn handle_metrics(Extension(server): Extension<Arc<Server>>) -> Result<String> {
+    static CONNECTIONS_METRIC: OnceLock<IntGauge> = OnceLock::new();
+    let connections_metric = CONNECTIONS_METRIC
+        .get_or_init(|| register_int_gauge!("connections", "number of connections").unwrap());
+
     let connections = server
         .connection_pool
         .lock()
         .connections()
         .filter(|connection| !connection.admin)
         .count();
+    connections_metric.set(connections as _);
 
-    METRIC_CONNECTIONS.set(connections as _);
+    static SHARED_PROJECTS_METRIC: OnceLock<IntGauge> = OnceLock::new();
+    let shared_projects_metric = SHARED_PROJECTS_METRIC.get_or_init(|| {
+        register_int_gauge!(
+            "shared_projects",
+            "number of open projects with one or more guests"
+        )
+        .unwrap()
+    });
 
     let shared_projects = server.app_state.db.project_count_excluding_admins().await?;
-    METRIC_SHARED_PROJECTS.set(shared_projects as _);
+    shared_projects_metric.set(shared_projects as _);
 
     let encoder = prometheus::TextEncoder::new();
     let metric_families = prometheus::gather();

crates/collab/src/tests/randomized_test_helpers.rs 🔗

@@ -11,6 +11,7 @@ use rand::prelude::*;
 use rpc::RECEIVE_TIMEOUT;
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use settings::SettingsStore;
+use std::sync::OnceLock;
 use std::{
     env,
     path::PathBuf,
@@ -21,16 +22,32 @@ use std::{
     },
 };
 
-lazy_static::lazy_static! {
-    static ref PLAN_LOAD_PATH: Option<PathBuf> = path_env_var("LOAD_PLAN");
-    static ref PLAN_SAVE_PATH: Option<PathBuf> = path_env_var("SAVE_PLAN");
-    static ref MAX_PEERS: usize = env::var("MAX_PEERS")
-        .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
-        .unwrap_or(3);
-    static ref MAX_OPERATIONS: usize = env::var("OPERATIONS")
-        .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
-        .unwrap_or(10);
+fn plan_load_path() -> &'static Option<PathBuf> {
+    static PLAN_LOAD_PATH: OnceLock<Option<PathBuf>> = OnceLock::new();
+    PLAN_LOAD_PATH.get_or_init(|| path_env_var("LOAD_PLAN"))
+}
+
+fn plan_save_path() -> &'static Option<PathBuf> {
+    static PLAN_SAVE_PATH: OnceLock<Option<PathBuf>> = OnceLock::new();
+    PLAN_SAVE_PATH.get_or_init(|| path_env_var("SAVE_PLAN"))
+}
+
+fn max_peers() -> usize {
+    static MAX_PEERS: OnceLock<usize> = OnceLock::new();
+    *MAX_PEERS.get_or_init(|| {
+        env::var("MAX_PEERS")
+            .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
+            .unwrap_or(3)
+    })
+}
 
+fn max_operations() -> usize {
+    static MAX_OPERATIONS: OnceLock<usize> = OnceLock::new();
+    *MAX_OPERATIONS.get_or_init(|| {
+        env::var("OPERATIONS")
+            .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
+            .unwrap_or(10)
+    })
 }
 
 static LOADED_PLAN_JSON: Mutex<Option<Vec<u8>>> = Mutex::new(None);
@@ -175,7 +192,7 @@ pub async fn run_randomized_test<T: RandomizedTest>(
     }
     executor.run_until_parked();
 
-    if let Some(path) = &*PLAN_SAVE_PATH {
+    if let Some(path) = plan_save_path() {
         eprintln!("saved test plan to path {:?}", path);
         std::fs::write(path, plan.lock().serialize()).unwrap();
     }
@@ -183,7 +200,7 @@ pub async fn run_randomized_test<T: RandomizedTest>(
 
 pub fn save_randomized_test_plan() {
     if let Some(serialize_plan) = LAST_PLAN.lock().take() {
-        if let Some(path) = &*PLAN_SAVE_PATH {
+        if let Some(path) = plan_save_path() {
             eprintln!("saved test plan to path {:?}", path);
             std::fs::write(path, serialize_plan()).unwrap();
         }
@@ -197,7 +214,7 @@ impl<T: RandomizedTest> TestPlan<T> {
         let allow_client_disconnection = rng.gen_bool(0.1);
 
         let mut users = Vec::new();
-        for ix in 0..*MAX_PEERS {
+        for ix in 0..max_peers() {
             let username = format!("user-{}", ix + 1);
             let user_id = server
                 .app_state
@@ -234,12 +251,12 @@ impl<T: RandomizedTest> TestPlan<T> {
             stored_operations: Vec::new(),
             operation_ix: 0,
             next_batch_id: 0,
-            max_operations: *MAX_OPERATIONS,
+            max_operations: max_operations(),
             users,
             rng,
         }));
 
-        if let Some(path) = &*PLAN_LOAD_PATH {
+        if let Some(path) = plan_load_path() {
             let json = LOADED_PLAN_JSON
                 .lock()
                 .get_or_insert_with(|| {