Add thinking budget for Gemini custom models (#31251)

Created by 90aca and Ben Brandt

Closes #31243

As described in the linked issue, Gemini chooses the [thinking
budget](https://ai.google.dev/gemini-api/docs/thinking) automatically
unless it is set explicitly. For fast responses (e.g. in the inline
assistant), I prefer to set it to 0.
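
For reference, here is a minimal sketch (not part of the diff) of the custom-model settings shape this enables. It mirrors the serde attributes on the `AvailableModel`/`ModelMode` types added below; the model name is only an example, and `display_name` is omitted for brevity:

```rust
// Sketch only: trimmed copies of the types added in this PR,
// deserialized with serde/serde_json.
use serde::Deserialize;

#[derive(Debug, Deserialize, PartialEq)]
#[serde(tag = "type", rename_all = "lowercase")]
enum ModelMode {
    Default,
    Thinking { budget_tokens: Option<u32> },
}

#[derive(Debug, Deserialize)]
struct AvailableModel {
    name: String,
    max_tokens: usize,
    mode: Option<ModelMode>,
}

fn main() {
    // A budget of 0 turns thinking off for fast inline-assistant replies.
    let json = r#"{
        "name": "gemini-2.5-flash-preview-04-17",
        "max_tokens": 1000000,
        "mode": { "type": "thinking", "budget_tokens": 0 }
    }"#;
    let model: AvailableModel = serde_json::from_str(json).unwrap();
    assert_eq!(
        model.mode,
        Some(ModelMode::Thinking { budget_tokens: Some(0) })
    );
}
```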

Release Notes:

- ai: Added `thinking` mode for custom Google models with configurable
token budget

---------

Co-authored-by: Ben Brandt <benjamin.j.brandt@gmail.com>

Change summary

crates/google_ai/src/google_ai.rs             | 35 ++++++++++++++++
crates/language_models/src/provider/cloud.rs  |  7 ++
crates/language_models/src/provider/google.rs | 45 +++++++++++++++++++-
3 files changed, 82 insertions(+), 5 deletions(-)

Detailed changes

crates/google_ai/src/google_ai.rs

@@ -289,6 +289,22 @@ pub struct UsageMetadata {
     pub total_token_count: Option<usize>,
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ThinkingConfig {
+    pub thinking_budget: u32,
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
+pub enum GoogleModelMode {
+    #[default]
+    Default,
+    Thinking {
+        budget_tokens: Option<u32>,
+    },
+}
+
 #[derive(Debug, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
 pub struct GenerationConfig {
@@ -304,6 +320,8 @@ pub struct GenerationConfig {
     pub top_p: Option<f64>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub top_k: Option<usize>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub thinking_config: Option<ThinkingConfig>,
 }
 
 #[derive(Debug, Serialize, Deserialize)]
@@ -496,6 +514,8 @@ pub enum Model {
         /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
         display_name: Option<String>,
         max_tokens: usize,
+        #[serde(default)]
+        mode: GoogleModelMode,
     },
 }
 
@@ -552,6 +572,21 @@ impl Model {
             Model::Custom { max_tokens, .. } => *max_tokens,
         }
     }
+
+    pub fn mode(&self) -> GoogleModelMode {
+        match self {
+            Self::Gemini15Pro
+            | Self::Gemini15Flash
+            | Self::Gemini20Pro
+            | Self::Gemini20Flash
+            | Self::Gemini20FlashThinking
+            | Self::Gemini20FlashLite
+            | Self::Gemini25ProExp0325
+            | Self::Gemini25ProPreview0325
+            | Self::Gemini25FlashPreview0417 => GoogleModelMode::Default,
+            Self::Custom { mode, .. } => *mode,
+        }
+    }
 }
 
 impl std::fmt::Display for Model {

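A minimal sketch (outside the diff, using trimmed copies of the structs above) of what the new fields serialize to on the wire: the camelCase renames produce the `thinkingConfig.thinkingBudget` keys the Gemini API expects, and `skip_serializing_if` keeps Default-mode requests byte-for-byte unchanged:

```rust
// Sketch only: demonstrates the serde behavior of the new fields.
use serde::Serialize;

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct ThinkingConfig {
    thinking_budget: u32,
}

#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking_config: Option<ThinkingConfig>,
}

fn main() {
    // An explicit budget of 0 disables thinking entirely.
    let with_budget = GenerationConfig {
        thinking_config: Some(ThinkingConfig { thinking_budget: 0 }),
    };
    assert_eq!(
        serde_json::to_string(&with_budget).unwrap(),
        r#"{"thinkingConfig":{"thinkingBudget":0}}"#
    );

    // Default mode: the field is skipped, so the request is unchanged from before.
    let default_mode = GenerationConfig { thinking_config: None };
    assert_eq!(serde_json::to_string(&default_mode).unwrap(), "{}");
}
```
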
crates/language_models/src/provider/cloud.rs

@@ -4,6 +4,7 @@ use client::{Client, UserStore, zed_urls};
 use futures::{
     AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
 };
+use google_ai::GoogleModelMode;
 use gpui::{
     AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
 };
@@ -750,7 +751,8 @@ impl LanguageModel for CloudLanguageModel {
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
                 let model_id = self.model.id.to_string();
-                let generate_content_request = into_google(request, model_id.clone());
+                let generate_content_request =
+                    into_google(request, model_id.clone(), GoogleModelMode::Default);
                 async move {
                     let http_client = &client.http_client();
                     let token = llm_api_token.acquire(&client).await?;
@@ -922,7 +924,8 @@ impl LanguageModel for CloudLanguageModel {
             }
             zed_llm_client::LanguageModelProvider::Google => {
                 let client = self.client.clone();
-                let request = into_google(request, self.model.id.to_string());
+                let request =
+                    into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
                     let PerformLlmCompletionResponse {

crates/language_models/src/provider/google.rs

@@ -4,7 +4,8 @@ use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
 use google_ai::{
-    FunctionDeclaration, GenerateContentResponse, Part, SystemInstruction, UsageMetadata,
+    FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
+    ThinkingConfig, UsageMetadata,
 };
 use gpui::{
     AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
@@ -45,11 +46,41 @@ pub struct GoogleSettings {
     pub available_models: Vec<AvailableModel>,
 }
 
+#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ModelMode {
+    #[default]
+    Default,
+    Thinking {
+        /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
+        budget_tokens: Option<u32>,
+    },
+}
+
+impl From<ModelMode> for GoogleModelMode {
+    fn from(value: ModelMode) -> Self {
+        match value {
+            ModelMode::Default => GoogleModelMode::Default,
+            ModelMode::Thinking { budget_tokens } => GoogleModelMode::Thinking { budget_tokens },
+        }
+    }
+}
+
+impl From<GoogleModelMode> for ModelMode {
+    fn from(value: GoogleModelMode) -> Self {
+        match value {
+            GoogleModelMode::Default => ModelMode::Default,
+            GoogleModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
+        }
+    }
+}
+
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 pub struct AvailableModel {
     name: String,
     display_name: Option<String>,
     max_tokens: usize,
+    mode: Option<ModelMode>,
 }
 
 pub struct GoogleLanguageModelProvider {
@@ -216,6 +247,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
                     name: model.name.clone(),
                     display_name: model.display_name.clone(),
                     max_tokens: model.max_tokens,
+                    mode: model.mode.unwrap_or_default().into(),
                 },
             );
         }
@@ -343,7 +375,7 @@ impl LanguageModel for GoogleLanguageModel {
         cx: &App,
     ) -> BoxFuture<'static, Result<usize>> {
         let model_id = self.model.id().to_string();
-        let request = into_google(request, model_id.clone());
+        let request = into_google(request, model_id.clone(), self.model.mode());
         let http_client = self.http_client.clone();
         let api_key = self.state.read(cx).api_key.clone();
 
@@ -379,7 +411,7 @@ impl LanguageModel for GoogleLanguageModel {
             >,
         >,
     > {
-        let request = into_google(request, self.model.id().to_string());
+        let request = into_google(request, self.model.id().to_string(), self.model.mode());
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
             let response = request
@@ -394,6 +426,7 @@ impl LanguageModel for GoogleLanguageModel {
 pub fn into_google(
     mut request: LanguageModelRequest,
     model_id: String,
+    mode: GoogleModelMode,
 ) -> google_ai::GenerateContentRequest {
     fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
         content
@@ -504,6 +537,12 @@ pub fn into_google(
             stop_sequences: Some(request.stop),
             max_output_tokens: None,
             temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
+            thinking_config: match mode {
+                GoogleModelMode::Thinking { budget_tokens } => {
+                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
+                }
+                GoogleModelMode::Default => None,
+            },
             top_p: None,
             top_k: None,
         }),