collab: Remove `CountLanguageModelTokens` RPC message (#29314)

Marshall Bowers created

This PR removes the `CountLanguageModelTokens` RPC message from collab.

We were only using this for Google AI models through the Zed provider
(which is only available to Zed staff).

For now we're returning `0`, but will bring back soon.

Release Notes:

- N/A

Change summary

Cargo.lock                                   |   1 
crates/collab/Cargo.toml                     |   1 
crates/collab/src/rpc.rs                     | 111 ---------------------
crates/language_models/src/provider/cloud.rs |  19 ---
crates/proto/proto/ai.proto                  |  16 ---
crates/proto/proto/zed.proto                 |   4 
crates/proto/src/proto.rs                    |   3 
7 files changed, 4 insertions(+), 151 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -2998,7 +2998,6 @@ dependencies = [
  "git",
  "git_hosting_providers",
  "git_ui",
- "google_ai",
  "gpui",
  "gpui_tokio",
  "hex",

crates/collab/Cargo.toml 🔗

@@ -34,7 +34,6 @@ dashmap.workspace = true
 derive_more.workspace = true
 envy = "0.4.2"
 futures.workspace = true
-google_ai.workspace = true
 hex.workspace = true
 http_client.workspace = true
 jsonwebtoken.workspace = true

crates/collab/src/rpc.rs 🔗

@@ -3,7 +3,7 @@ mod connection_pool;
 use crate::api::{CloudflareIpCountryHeader, SystemIdHeader};
 use crate::llm::LlmTokenClaims;
 use crate::{
-    AppState, Config, Error, RateLimit, Result, auth,
+    AppState, Error, Result, auth,
     db::{
         self, BufferId, Capability, Channel, ChannelId, ChannelRole, ChannelsForUser,
         CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
@@ -33,7 +33,6 @@ use chrono::Utc;
 use collections::{HashMap, HashSet};
 pub use connection_pool::{ConnectionPool, ZedVersion};
 use core::fmt::{self, Debug, Formatter};
-use http_client::HttpClient;
 use reqwest_client::ReqwestClient;
 use rpc::proto::split_repository_update;
 use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi};
@@ -132,7 +131,6 @@ struct Session {
     connection_pool: Arc<parking_lot::Mutex<ConnectionPool>>,
     app_state: Arc<AppState>,
     supermaven_client: Option<Arc<SupermavenAdminApi>>,
-    http_client: Arc<dyn HttpClient>,
     /// The GeoIP country code for the user.
     #[allow(unused)]
     geoip_country_code: Option<String>,
@@ -425,17 +423,7 @@ impl Server {
             .add_request_handler(forward_mutating_project_request::<proto::GitChangeBranch>)
             .add_request_handler(forward_mutating_project_request::<proto::CheckForPushedCommits>)
             .add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
-            .add_message_handler(update_context)
-            .add_request_handler({
-                let app_state = app_state.clone();
-                move |request, response, session| {
-                    let app_state = app_state.clone();
-                    async move {
-                        count_language_model_tokens(request, response, session, &app_state.config)
-                            .await
-                    }
-                }
-            });
+            .add_message_handler(update_context);
 
         Arc::new(server)
     }
@@ -764,7 +752,6 @@ impl Server {
                 peer: this.peer.clone(),
                 connection_pool: this.connection_pool.clone(),
                 app_state: this.app_state.clone(),
-                http_client,
                 geoip_country_code,
                 system_id,
                 _executor: executor.clone(),
@@ -3683,100 +3670,6 @@ async fn acknowledge_buffer_version(
     Ok(())
 }
 
-async fn count_language_model_tokens(
-    request: proto::CountLanguageModelTokens,
-    response: Response<proto::CountLanguageModelTokens>,
-    session: Session,
-    config: &Config,
-) -> Result<()> {
-    authorize_access_to_legacy_llm_endpoints(&session).await?;
-
-    let rate_limit: Box<dyn RateLimit> = match session.current_plan(&session.db().await).await? {
-        proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
-        proto::Plan::Free | proto::Plan::ZedProTrial => {
-            Box::new(FreeCountLanguageModelTokensRateLimit)
-        }
-    };
-
-    session
-        .app_state
-        .rate_limiter
-        .check(&*rate_limit, session.user_id())
-        .await?;
-
-    let result = match proto::LanguageModelProvider::from_i32(request.provider) {
-        Some(proto::LanguageModelProvider::Google) => {
-            let api_key = config
-                .google_ai_api_key
-                .as_ref()
-                .context("no Google AI API key configured on the server")?;
-            google_ai::count_tokens(
-                session.http_client.as_ref(),
-                google_ai::API_URL,
-                api_key,
-                serde_json::from_str(&request.request)?,
-            )
-            .await?
-        }
-        _ => return Err(anyhow!("unsupported provider"))?,
-    };
-
-    response.send(proto::CountLanguageModelTokensResponse {
-        token_count: result.total_tokens as u32,
-    })?;
-
-    Ok(())
-}
-
-struct ZedProCountLanguageModelTokensRateLimit;
-
-impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
-    fn capacity(&self) -> usize {
-        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
-            .ok()
-            .and_then(|v| v.parse().ok())
-            .unwrap_or(600) // Picked arbitrarily
-    }
-
-    fn refill_duration(&self) -> chrono::Duration {
-        chrono::Duration::hours(1)
-    }
-
-    fn db_name(&self) -> &'static str {
-        "zed-pro:count-language-model-tokens"
-    }
-}
-
-struct FreeCountLanguageModelTokensRateLimit;
-
-impl RateLimit for FreeCountLanguageModelTokensRateLimit {
-    fn capacity(&self) -> usize {
-        std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
-            .ok()
-            .and_then(|v| v.parse().ok())
-            .unwrap_or(600 / 10) // Picked arbitrarily
-    }
-
-    fn refill_duration(&self) -> chrono::Duration {
-        chrono::Duration::hours(1)
-    }
-
-    fn db_name(&self) -> &'static str {
-        "free:count-language-model-tokens"
-    }
-}
-
-/// This is leftover from before the LLM service.
-///
-/// The endpoints protected by this check will be moved there eventually.
-async fn authorize_access_to_legacy_llm_endpoints(session: &Session) -> Result<(), Error> {
-    if session.is_staff() {
-        Ok(())
-    } else {
-        Err(anyhow!("permission denied"))?
-    }
-}
-
 /// Get a Supermaven API key for the user
 async fn get_supermaven_api_key(
     _request: proto::GetSupermavenApiKey,

crates/language_models/src/provider/cloud.rs 🔗

@@ -686,24 +686,7 @@ impl LanguageModel for CloudLanguageModel {
         match self.model.clone() {
             CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
             CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
-            CloudModel::Google(model) => {
-                let client = self.client.clone();
-                let request = into_google(request, model.id().into());
-                let request = google_ai::CountTokensRequest {
-                    contents: request.contents,
-                };
-                async move {
-                    let request = serde_json::to_string(&request)?;
-                    let response = client
-                        .request(proto::CountLanguageModelTokens {
-                            provider: proto::LanguageModelProvider::Google as i32,
-                            request,
-                        })
-                        .await?;
-                    Ok(response.token_count as usize)
-                }
-                .boxed()
-            }
+            CloudModel::Google(_model) => async move { Ok(0) }.boxed(),
         }
     }
 

crates/proto/proto/ai.proto 🔗

@@ -172,19 +172,3 @@ enum LanguageModelRole {
     LanguageModelSystem = 2;
     reserved 3;
 }
-
-message CountLanguageModelTokens {
-    LanguageModelProvider provider = 1;
-    string request = 2;
-}
-
-message CountLanguageModelTokensResponse {
-    uint32 token_count = 1;
-}
-
-enum LanguageModelProvider {
-    Anthropic = 0;
-    OpenAI = 1;
-    Google = 2;
-    Zed = 3;
-}

crates/proto/proto/zed.proto 🔗

@@ -206,9 +206,6 @@ message Envelope {
         GetImplementation get_implementation = 162;
         GetImplementationResponse get_implementation_response = 163;
 
-        CountLanguageModelTokens count_language_model_tokens = 230;
-        CountLanguageModelTokensResponse count_language_model_tokens_response = 231;
-
         UpdateChannelMessage update_channel_message = 170;
         ChannelMessageUpdate channel_message_update = 171;
 
@@ -397,6 +394,7 @@ message Envelope {
     reserved 205 to 206;
     reserved 221;
     reserved 224 to 229;
+    reserved 230 to 231;
     reserved 246;
     reserved 270;
     reserved 247 to 254;

crates/proto/src/proto.rs 🔗

@@ -50,8 +50,6 @@ messages!(
     (CloseBuffer, Foreground),
     (Commit, Background),
     (CopyProjectEntry, Foreground),
-    (CountLanguageModelTokens, Background),
-    (CountLanguageModelTokensResponse, Background),
     (CreateBufferForPeer, Foreground),
     (CreateChannel, Foreground),
     (CreateChannelResponse, Foreground),
@@ -374,7 +372,6 @@ request_messages!(
     (PerformRename, PerformRenameResponse),
     (Ping, Ack),
     (PrepareRename, PrepareRenameResponse),
-    (CountLanguageModelTokens, CountLanguageModelTokensResponse),
     (RefreshInlayHints, Ack),
     (RefreshCodeLens, Ack),
     (RejoinChannelBuffers, RejoinChannelBuffersResponse),