Add support for getting the token count for all parts of Gemini generation requests (#29630)

Michael Sloan created

* `CountTokensRequest` now takes a full `GenerateContentRequest` instead
of just content.

* Fixes handling of the `models/` prefix in the `model` field of
`GenerateContentRequest`: the prefixed form is required when the request
is embedded in a `CountTokensRequest`. This wasn't an issue before
because the `model` field was always cleared before serialization, with
its value used only in the request URL path.

Release Notes:

- N/A

Change summary

crates/google_ai/src/google_ai.rs             | 107 ++++++++++++++++----
crates/language_models/src/provider/cloud.rs  |   9 
crates/language_models/src/provider/google.rs |   7 
3 files changed, 90 insertions(+), 33 deletions(-)

Detailed changes

crates/google_ai/src/google_ai.rs 🔗

@@ -1,7 +1,9 @@
+use std::mem;
+
 use anyhow::{Result, anyhow, bail};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
 
 pub const API_URL: &str = "https://generativelanguage.googleapis.com";
 
@@ -11,25 +13,13 @@ pub async fn stream_generate_content(
     api_key: &str,
     mut request: GenerateContentRequest,
 ) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
-    if request.contents.is_empty() {
-        bail!("Request must contain at least one content item");
-    }
+    validate_generate_content_request(&request)?;
 
-    if let Some(user_content) = request
-        .contents
-        .iter()
-        .find(|content| content.role == Role::User)
-    {
-        if user_content.parts.is_empty() {
-            bail!("User content must contain at least one part");
-        }
-    }
+    // The `model` field is emptied as it is provided as a path parameter.
+    let model_id = mem::take(&mut request.model.model_id);
 
-    let uri = format!(
-        "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
-        model = request.model
-    );
-    request.model.clear();
+    let uri =
+        format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
 
     let request_builder = HttpRequest::builder()
         .method(Method::POST)
@@ -76,18 +66,22 @@ pub async fn count_tokens(
     client: &dyn HttpClient,
     api_url: &str,
     api_key: &str,
-    model_id: &str,
     request: CountTokensRequest,
 ) -> Result<CountTokensResponse> {
-    let uri = format!("{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",);
-    let request = serde_json::to_string(&request)?;
+    validate_generate_content_request(&request.generate_content_request)?;
+
+    let uri = format!(
+        "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
+        model_id = &request.generate_content_request.model.model_id,
+    );
 
+    let request = serde_json::to_string(&request)?;
     let request_builder = HttpRequest::builder()
         .method(Method::POST)
         .uri(&uri)
         .header("Content-Type", "application/json");
-
     let http_request = request_builder.body(AsyncBody::from(request))?;
+
     let mut response = client.send(http_request).await?;
     let mut text = String::new();
     response.body_mut().read_to_string(&mut text).await?;
@@ -102,6 +96,28 @@ pub async fn count_tokens(
     }
 }
 
+pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
+    if request.model.is_empty() {
+        bail!("Model must be specified");
+    }
+
+    if request.contents.is_empty() {
+        bail!("Request must contain at least one content item");
+    }
+
+    if let Some(user_content) = request
+        .contents
+        .iter()
+        .find(|content| content.role == Role::User)
+    {
+        if user_content.parts.is_empty() {
+            bail!("User content must contain at least one part");
+        }
+    }
+
+    Ok(())
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 pub enum Task {
     #[serde(rename = "generateContent")]
@@ -119,8 +135,8 @@ pub enum Task {
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct GenerateContentRequest {
-    #[serde(default, skip_serializing_if = "String::is_empty")]
-    pub model: String,
+    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
+    pub model: ModelName,
     pub contents: Vec<Content>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub system_instruction: Option<SystemInstruction>,
@@ -350,7 +366,7 @@ pub struct SafetyRating {
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct CountTokensRequest {
-    pub contents: Vec<Content>,
+    pub generate_content_request: GenerateContentRequest,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -406,6 +422,47 @@ pub struct FunctionDeclaration {
     pub parameters: serde_json::Value,
 }
 
+#[derive(Debug, Default)]
+pub struct ModelName {
+    pub model_id: String,
+}
+
+impl ModelName {
+    pub fn is_empty(&self) -> bool {
+        self.model_id.is_empty()
+    }
+}
+
+const MODEL_NAME_PREFIX: &str = "models/";
+
+impl Serialize for ModelName {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
+    }
+}
+
+impl<'de> Deserialize<'de> for ModelName {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let string = String::deserialize(deserializer)?;
+        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
+            Ok(Self {
+                model_id: id.to_string(),
+            })
+        } else {
+            return Err(serde::de::Error::custom(format!(
+                "Expected model name to begin with {}, got: {}",
+                MODEL_NAME_PREFIX, string
+            )));
+        }
+    }
+}
+
 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 #[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
 pub enum Model {

crates/language_models/src/provider/cloud.rs 🔗

@@ -718,7 +718,8 @@ impl LanguageModel for CloudLanguageModel {
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
-                let request = into_google(request, model.id().into());
+                let model_id = model.id().to_string();
+                let generate_content_request = into_google(request, model_id.clone());
                 async move {
                     let http_client = &client.http_client();
                     let token = llm_api_token.acquire(&client).await?;
@@ -736,9 +737,9 @@ impl LanguageModel for CloudLanguageModel {
                         };
                     let request_body = CountTokensBody {
                         provider: zed_llm_client::LanguageModelProvider::Google,
-                        model: model.id().into(),
+                        model: model_id,
                         provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
-                            contents: request.contents,
+                            generate_content_request,
                         })?,
                     };
                     let request = request_builder
@@ -895,7 +896,7 @@ impl LanguageModel for CloudLanguageModel {
                             prompt_id,
                             mode,
                             provider: zed_llm_client::LanguageModelProvider::Google,
-                            model: request.model.clone(),
+                            model: request.model.model_id.clone(),
                             provider_request: serde_json::to_value(&request)?,
                         },
                     )

crates/language_models/src/provider/google.rs 🔗

@@ -344,9 +344,8 @@ impl LanguageModel for GoogleLanguageModel {
                 http_client.as_ref(),
                 &api_url,
                 &api_key,
-                &model_id,
                 google_ai::CountTokensRequest {
-                    contents: request.contents,
+                    generate_content_request: request,
                 },
             )
             .await?;
@@ -382,7 +381,7 @@ impl LanguageModel for GoogleLanguageModel {
 
 pub fn into_google(
     mut request: LanguageModelRequest,
-    model: String,
+    model_id: String,
 ) -> google_ai::GenerateContentRequest {
     fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
         content
@@ -442,7 +441,7 @@ pub fn into_google(
     };
 
     google_ai::GenerateContentRequest {
-        model,
+        model: google_ai::ModelName { model_id },
         system_instruction: system_instructions,
         contents: request
             .messages