Remove code paths that skip LLM db in prod (#16008)

Max Brunsfeld created

Release Notes:

- N/A

Change summary

crates/collab/src/llm.rs  | 77 +++++++++++++++-------------------------
crates/collab/src/main.rs | 12 +++---
2 files changed, 35 insertions(+), 54 deletions(-)

Detailed changes

crates/collab/src/llm.rs 🔗

@@ -27,7 +27,7 @@ pub use token::*;
 pub struct LlmState {
     pub config: Config,
     pub executor: Executor,
-    pub db: Option<Arc<LlmDatabase>>,
+    pub db: Arc<LlmDatabase>,
     pub http_client: IsahcHttpClient,
     active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
 }
@@ -36,25 +36,20 @@ const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
 
 impl LlmState {
     pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
-        // TODO: This is temporary until we have the LLM database stood up.
-        let db = if config.is_development() {
-            let database_url = config
-                .llm_database_url
-                .as_ref()
-                .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
-            let max_connections = config
-                .llm_database_max_connections
-                .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
-
-            let mut db_options = db::ConnectOptions::new(database_url);
-            db_options.max_connections(max_connections);
-            let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
-            db.initialize().await?;
-
-            Some(Arc::new(db))
-        } else {
-            None
-        };
+        let database_url = config
+            .llm_database_url
+            .as_ref()
+            .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
+        let max_connections = config
+            .llm_database_max_connections
+            .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
+
+        let mut db_options = db::ConnectOptions::new(database_url);
+        db_options.max_connections(max_connections);
+        let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
+        db.initialize().await?;
+
+        let db = Arc::new(db);
 
         let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
         let http_client = IsahcHttpClient::builder()
@@ -62,11 +57,8 @@ impl LlmState {
             .build()
             .context("failed to construct http client")?;
 
-        let initial_active_user_count = if let Some(db) = &db {
-            Some((Utc::now(), db.get_active_user_count(Utc::now()).await?))
-        } else {
-            None
-        };
+        let initial_active_user_count =
+            Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));
 
         let this = Self {
             config,
@@ -88,14 +80,10 @@ impl LlmState {
             }
         }
 
-        if let Some(db) = &self.db {
-            let mut cache = self.active_user_count.write().await;
-            let new_count = db.get_active_user_count(now).await?;
-            *cache = Some((now, new_count));
-            Ok(new_count)
-        } else {
-            Ok(ActiveUserCount::default())
-        }
+        let mut cache = self.active_user_count.write().await;
+        let new_count = self.db.get_active_user_count(now).await?;
+        *cache = Some((now, new_count));
+        Ok(new_count)
     }
 }
 
@@ -165,9 +153,7 @@ async fn perform_completion(
 
     let user_id = claims.user_id as i32;
 
-    if state.db.is_some() {
-        check_usage_limit(&state, params.provider, &model, &claims).await?;
-    }
+    check_usage_limit(&state, params.provider, &model, &claims).await?;
 
     match params.provider {
         LanguageModelProvider::Anthropic => {
@@ -199,14 +185,14 @@ async fn perform_completion(
             )
             .await?;
 
-            let mut recorder = state.db.clone().map(|db| UsageRecorder {
-                db,
+            let mut recorder = UsageRecorder {
+                db: state.db.clone(),
                 executor: state.executor.clone(),
                 user_id,
                 provider: params.provider,
                 model,
                 token_count: 0,
-            });
+            };
 
             let stream = chunks.map(move |event| {
                 let mut buffer = Vec::new();
@@ -216,10 +202,8 @@ async fn perform_completion(
                             message: anthropic::Response { usage, .. },
                         }
                         | anthropic::Event::MessageDelta { usage, .. } => {
-                            if let Some(recorder) = &mut recorder {
-                                recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
-                                recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
-                            }
+                            recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
+                            recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
                         }
                         _ => {}
                     }
@@ -349,12 +333,9 @@ async fn check_usage_limit(
     model_name: &str,
     claims: &LlmTokenClaims,
 ) -> Result<()> {
-    let db = state
+    let model = state.db.model(provider, model_name)?;
+    let usage = state
         .db
-        .as_ref()
-        .ok_or_else(|| anyhow!("LLM database not configured"))?;
-    let model = db.model(provider, model_name)?;
-    let usage = db
         .get_usage(claims.user_id as i32, provider, model_name, Utc::now())
         .await?;
 

crates/collab/src/main.rs 🔗

@@ -248,11 +248,6 @@ async fn setup_app_database(config: &Config) -> Result<()> {
 }
 
 async fn setup_llm_database(config: &Config) -> Result<()> {
-    // TODO: This is temporary until we have the LLM database stood up.
-    if !config.is_development() {
-        return Ok(());
-    }
-
     let database_url = config
         .llm_database_url
         .as_ref()
@@ -298,7 +293,12 @@ async fn handle_liveness_probe(
         state.db.get_all_users(0, 1).await?;
     }
 
-    if let Some(_llm_state) = llm_state {}
+    if let Some(llm_state) = llm_state {
+        llm_state
+            .db
+            .get_active_user_count(chrono::Utc::now())
+            .await?;
+    }
 
     Ok("ok".to_string())
 }