language_models: Count Google AI tokens through LLM service (#29319)

Marshall Bowers created

This PR wires the counting of Google AI tokens back up.

It now goes through the LLM service instead of collab's RPC.

Still only available for Zed staff.

Release Notes:

- N/A

Change summary

Cargo.lock                                   |  4 ++--
Cargo.toml                                   |  2 +-
crates/language_models/src/provider/cloud.rs | 59 ++++++++++++++++++++-
3 files changed, 58 insertions(+), 7 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -18536,9 +18536,9 @@ dependencies = [
 
 [[package]]
 name = "zed_llm_client"
-version = "0.7.0"
+version = "0.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3c1666cd923c5eb4635f3743e69c6920d0ed71f29b26920616a5d220607df7c4"
+checksum = "cc9ec491b7112cb8c2fba3c17d9a349d8ab695fb1a4ef6c5c4b9fd8d7aa975c1"
 dependencies = [
  "anyhow",
  "serde",

Cargo.toml 🔗

@@ -606,7 +606,7 @@ wasmtime-wasi = "29"
 which = "6.0.0"
 wit-component = "0.221"
 workspace-hack = "0.1.0"
-zed_llm_client = "0.7.0"
+zed_llm_client = "0.7.1"
 zstd = "0.11"
 metal = "0.29"
 

crates/language_models/src/provider/cloud.rs 🔗

@@ -35,9 +35,9 @@ use strum::IntoEnumIterator;
 use thiserror::Error;
 use ui::{TintColor, prelude::*};
 use zed_llm_client::{
-    CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionMode, EXPIRED_LLM_TOKEN_HEADER_NAME,
-    MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME, MODEL_REQUESTS_RESOURCE_HEADER_VALUE,
-    SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
+    CURRENT_PLAN_HEADER_NAME, CompletionBody, CompletionMode, CountTokensBody, CountTokensResponse,
+    EXPIRED_LLM_TOKEN_HEADER_NAME, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME,
+    MODEL_REQUESTS_RESOURCE_HEADER_VALUE, SUBSCRIPTION_LIMIT_RESOURCE_HEADER_NAME,
 };
 
 use crate::AllLanguageModelSettings;
@@ -686,7 +686,58 @@ impl LanguageModel for CloudLanguageModel {
         match self.model.clone() {
             CloudModel::Anthropic(_) => count_anthropic_tokens(request, cx),
             CloudModel::OpenAi(model) => count_open_ai_tokens(request, model, cx),
-            CloudModel::Google(_model) => async move { Ok(0) }.boxed(),
+            CloudModel::Google(model) => {
+                let client = self.client.clone();
+                let llm_api_token = self.llm_api_token.clone();
+                let request = into_google(request, model.id().into());
+                async move {
+                    let http_client = &client.http_client();
+                    let token = llm_api_token.acquire(&client).await?;
+
+                    let request_builder = http_client::Request::builder().method(Method::POST);
+                    let request_builder =
+                        if let Ok(completions_url) = std::env::var("ZED_COUNT_TOKENS_URL") {
+                            request_builder.uri(completions_url)
+                        } else {
+                            request_builder.uri(
+                                http_client
+                                    .build_zed_llm_url("/count_tokens", &[])?
+                                    .as_ref(),
+                            )
+                        };
+                    let request_body = CountTokensBody {
+                        provider: zed_llm_client::LanguageModelProvider::Google,
+                        model: model.id().into(),
+                        provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
+                            contents: request.contents,
+                        })?,
+                    };
+                    let request = request_builder
+                        .header("Content-Type", "application/json")
+                        .header("Authorization", format!("Bearer {token}"))
+                        .body(serde_json::to_string(&request_body)?.into())?;
+                    let mut response = http_client.send(request).await?;
+                    let status = response.status();
+                    let mut response_body = String::new();
+                    response
+                        .body_mut()
+                        .read_to_string(&mut response_body)
+                        .await?;
+
+                    if status.is_success() {
+                        let response_body: CountTokensResponse =
+                            serde_json::from_str(&response_body)?;
+
+                        Ok(response_body.tokens)
+                    } else {
+                        Err(anyhow!(ApiError {
+                            status,
+                            body: response_body
+                        }))
+                    }
+                }
+                .boxed()
+            }
         }
     }