Add feature-flagged access to LLM service (#16136)

Marshall Bowers and Max created

This PR adds feature-flagged access to the LLM service.

We've repurposed the `language-models` feature flag to be used for
providing access to Claude 3.5 Sonnet through the Zed provider.

The remaining RPC endpoints that were previously behind the
`language-models` feature flag are now behind a staff check.

We also put some Zed Pro related messaging behind a feature flag.

Release Notes:

- N/A

---------

Co-authored-by: Max <max@zed.dev>

Change summary

crates/collab/src/rpc.rs                    | 28 +++++----
crates/language_model/src/provider/cloud.rs | 64 ++++++++++++----------
2 files changed, 49 insertions(+), 43 deletions(-)

Detailed changes

crates/collab/src/rpc.rs 🔗

@@ -4501,7 +4501,7 @@ async fn count_language_model_tokens(
     let Some(session) = session.for_user() else {
         return Err(anyhow!("user not found"))?;
     };
-    authorize_access_to_language_models(&session).await?;
+    authorize_access_to_legacy_llm_endpoints(&session).await?;
 
     let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
         proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
@@ -4621,7 +4621,7 @@ async fn compute_embeddings(
     api_key: Option<Arc<str>>,
 ) -> Result<()> {
     let api_key = api_key.context("no OpenAI API key configured on the server")?;
-    authorize_access_to_language_models(&session).await?;
+    authorize_access_to_legacy_llm_endpoints(&session).await?;
 
     let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
         proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
@@ -4685,7 +4685,7 @@ async fn get_cached_embeddings(
     response: Response<proto::GetCachedEmbeddings>,
     session: UserSession,
 ) -> Result<()> {
-    authorize_access_to_language_models(&session).await?;
+    authorize_access_to_legacy_llm_endpoints(&session).await?;
 
     let db = session.db().await;
     let embeddings = db.get_embeddings(&request.model, &request.digests).await?;
@@ -4699,14 +4699,15 @@ async fn get_cached_embeddings(
     Ok(())
 }
 
-async fn authorize_access_to_language_models(session: &UserSession) -> Result<(), Error> {
-    let db = session.db().await;
-    let flags = db.get_user_flags(session.user_id()).await?;
-    if flags.iter().any(|flag| flag == "language-models") {
-        return Ok(());
+/// This is leftover from before the LLM service.
+///
+/// The endpoints protected by this check will be moved there eventually.
+async fn authorize_access_to_legacy_llm_endpoints(session: &UserSession) -> Result<(), Error> {
+    if session.is_staff() {
+        Ok(())
+    } else {
+        Err(anyhow!("permission denied"))?
     }
-
-    Err(anyhow!("permission denied"))?
 }
 
 /// Get a Supermaven API key for the user
@@ -4915,12 +4916,13 @@ async fn get_llm_api_token(
     response: Response<proto::GetLlmToken>,
     session: UserSession,
 ) -> Result<()> {
-    if !session.is_staff() {
+    let db = session.db().await;
+
+    let flags = db.get_user_flags(session.user_id()).await?;
+    if !session.is_staff() && !flags.iter().any(|flag| flag == "language-models") {
         Err(anyhow!("permission denied"))?
     }
 
-    let db = session.db().await;
-
     let user_id = session.user_id();
     let user = db
         .get_user_by_id(user_id)

crates/language_model/src/provider/cloud.rs 🔗

@@ -8,7 +8,7 @@ use anthropic::AnthropicError;
 use anyhow::{anyhow, bail, Context as _, Result};
 use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use collections::BTreeMap;
-use feature_flags::{FeatureFlagAppExt, LanguageModels};
+use feature_flags::{FeatureFlagAppExt, ZedPro};
 use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
 use gpui::{
     AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
@@ -168,13 +168,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
         let mut models = BTreeMap::default();
 
-        let is_user = !cx.has_flag::<LanguageModels>();
-        if is_user {
-            models.insert(
-                anthropic::Model::Claude3_5Sonnet.id().to_string(),
-                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
-            );
-        } else {
+        if cx.is_staff() {
             for model in anthropic::Model::iter() {
                 if !matches!(model, anthropic::Model::Custom { .. }) {
                     models.insert(model.id().to_string(), CloudModel::Anthropic(model));
@@ -218,6 +212,11 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                 };
                 models.insert(model.id().to_string(), model.clone());
             }
+        } else {
+            models.insert(
+                anthropic::Model::Claude3_5Sonnet.id().to_string(),
+                CloudModel::Anthropic(anthropic::Model::Claude3_5Sonnet),
+            );
         }
 
         models
@@ -869,34 +868,39 @@ impl Render for ConfigurationView {
                     if is_pro {
                         "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
                     } else {
-                        "You have basic access to models from Anthropic, OpenAI, Google and more through the Zed AI Free plan."
+                        "You have basic access to models from Anthropic through the Zed AI Free plan."
                     }))
-                .child(
-                    if is_pro {
+                .children(if is_pro {
+                    Some(
                         h_flex().child(
-                        Button::new("manage_settings", "Manage Subscription")
-                            .style(ButtonStyle::Filled)
-                            .on_click(cx.listener(|_, _, cx| {
-                                cx.open_url(ACCOUNT_SETTINGS_URL)
-                            })))
-                    } else {
+                            Button::new("manage_settings", "Manage Subscription")
+                                .style(ButtonStyle::Filled)
+                                .on_click(
+                                    cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)),
+                                ),
+                        ),
+                    )
+                } else if cx.has_flag::<ZedPro>() {
+                    Some(
                         h_flex()
                             .gap_2()
                             .child(
-                        Button::new("learn_more", "Learn more")
-                            .style(ButtonStyle::Subtle)
-                            .on_click(cx.listener(|_, _, cx| {
-                                cx.open_url(ZED_AI_URL)
-                            })))
+                                Button::new("learn_more", "Learn more")
+                                    .style(ButtonStyle::Subtle)
+                                    .on_click(cx.listener(|_, _, cx| cx.open_url(ZED_AI_URL))),
+                            )
                             .child(
-                        Button::new("upgrade", "Upgrade")
-                            .style(ButtonStyle::Subtle)
-                            .color(Color::Accent)
-                            .on_click(cx.listener(|_, _, cx| {
-                                cx.open_url(ACCOUNT_SETTINGS_URL)
-                            })))
-                    },
-                )
+                                Button::new("upgrade", "Upgrade")
+                                    .style(ButtonStyle::Subtle)
+                                    .color(Color::Accent)
+                                    .on_click(
+                                        cx.listener(|_, _, cx| cx.open_url(ACCOUNT_SETTINGS_URL)),
+                                    ),
+                            ),
+                    )
+                } else {
+                    None
+                })
         } else {
             v_flex()
                 .gap_6()