Fix bugs preventing non-staff users from using LLM service (#16307)

Created by Max Brunsfeld, Marshall, and Joseph

- Database deadlock in `GetLlmToken` for non-staff users
- Typo in the allowed model name for non-staff users

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>
Co-authored-by: Joseph <joseph@zed.dev>

Change summary

crates/collab/src/llm/authorization.rs |  8 ++++----
crates/collab/src/rpc.rs               | 14 ++++++--------
2 files changed, 10 insertions(+), 12 deletions(-)

Detailed changes

crates/collab/src/llm/authorization.rs 🔗

@@ -26,7 +26,7 @@ fn authorize_access_to_model(
     }
 
     match (provider, model) {
-        (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3.5-sonnet") => {
+        (LanguageModelProvider::Anthropic, model) if model.starts_with("claude-3-5-sonnet") => {
             Ok(())
         }
         _ => Err(Error::http(
@@ -240,14 +240,14 @@ mod tests {
             (
                 Plan::ZedPro,
                 LanguageModelProvider::Anthropic,
-                "claude-3.5-sonnet",
+                "claude-3-5-sonnet",
                 true,
             ),
             // Free plan should have access to claude-3.5-sonnet
             (
                 Plan::Free,
                 LanguageModelProvider::Anthropic,
-                "claude-3.5-sonnet",
+                "claude-3-5-sonnet",
                 true,
             ),
             // Pro plan should NOT have access to other Anthropic models
@@ -303,7 +303,7 @@ mod tests {
 
         // Staff should have access to all models
         let test_cases = vec![
-            (LanguageModelProvider::Anthropic, "claude-3.5-sonnet"),
+            (LanguageModelProvider::Anthropic, "claude-3-5-sonnet"),
             (LanguageModelProvider::Anthropic, "claude-2"),
             (LanguageModelProvider::Anthropic, "claude-123-agi"),
             (LanguageModelProvider::OpenAi, "gpt-4"),

crates/collab/src/rpc.rs 🔗

@@ -71,7 +71,7 @@ use std::{
     time::{Duration, Instant},
 };
 use time::OffsetDateTime;
-use tokio::sync::{watch, Semaphore};
+use tokio::sync::{watch, MutexGuard, Semaphore};
 use tower::ServiceBuilder;
 use tracing::{
     field::{self},
@@ -192,7 +192,7 @@ impl Session {
         }
     }
 
-    pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> {
+    pub async fn current_plan(&self, db: MutexGuard<'_, DbHandle>) -> anyhow::Result<proto::Plan> {
         if self.is_staff() {
             return Ok(proto::Plan::ZedPro);
         }
@@ -201,7 +201,6 @@ impl Session {
             return Ok(proto::Plan::Free);
         };
 
-        let db = self.db().await;
         if db.has_active_billing_subscription(user_id).await? {
             Ok(proto::Plan::ZedPro)
         } else {
@@ -3500,7 +3499,7 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
 }
 
 async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
-    let plan = session.current_plan().await?;
+    let plan = session.current_plan(session.db().await).await?;
 
     session
         .peer
@@ -4503,7 +4502,7 @@ async fn count_language_model_tokens(
     };
     authorize_access_to_legacy_llm_endpoints(&session).await?;
 
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
+    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
         proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
         proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
     };
@@ -4623,7 +4622,7 @@ async fn compute_embeddings(
     let api_key = api_key.context("no OpenAI API key configured on the server")?;
     authorize_access_to_legacy_llm_endpoints(&session).await?;
 
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
+    let rate_limit: Box<dyn RateLimit> = match session.current_plan(session.db().await).await? {
         proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
         proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
     };
@@ -4940,11 +4939,10 @@ async fn get_llm_api_token(
     if Utc::now().naive_utc() - account_created_at < MIN_ACCOUNT_AGE_FOR_LLM_USE {
         Err(anyhow!("account too young"))?
     }
-
     let token = LlmTokenClaims::create(
         user.id,
         session.is_staff(),
-        session.current_plan().await?,
+        session.current_plan(db).await?,
         &session.app_state.config,
     )?;
     response.send(proto::GetLlmTokenResponse { token })?;