@@ -289,6 +289,22 @@ pub struct UsageMetadata {
pub total_token_count: Option<usize>,
}
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ThinkingConfig {
+ pub thinking_budget: u32,
+}
+
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize, PartialEq, Eq)]
+pub enum GoogleModelMode {
+ #[default]
+ Default,
+ Thinking {
+ budget_tokens: Option<u32>,
+ },
+}
+
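Note on the wire shape: `GoogleModelMode` carries no serde tag attributes, so it keeps serde's externally tagged layout; this is the form persisted for `Model::Custom` below, while the settings-facing `ModelMode` further down defines a friendlier `{"type": ...}` shape and converts via `From`. A minimal sketch, assuming `serde_json` is available as a dev-dependency:

    let mode = GoogleModelMode::Thinking { budget_tokens: Some(1024) };
    // Externally tagged: the variant name wraps its fields.
    assert_eq!(
        serde_json::to_string(&mode).unwrap(),
        r#"{"Thinking":{"budget_tokens":1024}}"#
    );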
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerationConfig {
@@ -304,6 +320,8 @@ pub struct GenerationConfig {
pub top_p: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<usize>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ pub thinking_config: Option<ThinkingConfig>,
}
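Because `GenerationConfig` is renamed to camelCase and skips `None` fields, an explicit budget surfaces as a nested `thinkingConfig` object and unset knobs stay out of the payload entirely. A sketch under the same `serde_json` assumption:

    let config = GenerationConfig {
        stop_sequences: None,
        max_output_tokens: None,
        temperature: Some(1.0),
        top_p: None,
        top_k: None,
        thinking_config: Some(ThinkingConfig { thinking_budget: 8192 }),
    };
    let json = serde_json::to_string(&config).unwrap();
    // The wire field is `thinkingConfig.thinkingBudget`, per the camelCase rename.
    assert!(json.contains(r#""thinkingConfig":{"thinkingBudget":8192}"#));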
#[derive(Debug, Serialize, Deserialize)]
@@ -496,6 +514,8 @@ pub enum Model {
/// The name displayed in the UI, such as in the assistant panel model dropdown menu.
display_name: Option<String>,
max_tokens: usize,
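+ /// The model's mode, e.g. whether it supports thinking with a token budget.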
+ #[serde(default)]
+ mode: GoogleModelMode,
},
}
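The `#[serde(default)]` on the new field is what keeps previously saved custom-model entries parsing: a missing `mode` key falls back to the `#[default]` variant instead of failing deserialization.

    // The fallback used for entries that predate the `mode` field.
    assert_eq!(GoogleModelMode::default(), GoogleModelMode::Default);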
@@ -552,6 +572,21 @@ impl Model {
Model::Custom { max_tokens, .. } => *max_tokens,
}
}
+
+ pub fn mode(&self) -> GoogleModelMode {
+ match self {
+ Self::Gemini15Pro
+ | Self::Gemini15Flash
+ | Self::Gemini20Pro
+ | Self::Gemini20Flash
+ | Self::Gemini20FlashThinking
+ | Self::Gemini20FlashLite
+ | Self::Gemini25ProExp0325
+ | Self::Gemini25ProPreview0325
+ | Self::Gemini25FlashPreview0417 => GoogleModelMode::Default,
+ Self::Custom { mode, .. } => *mode,
+ }
+ }
}
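One subtlety: every built-in variant reports `Default` here, including `Gemini20FlashThinking`; only `Custom` entries can carry an explicit budget, so built-in thinking models are left to the API's own defaults.

    // Built-ins always map to Default; only Custom carries a configured budget.
    assert_eq!(
        Model::Gemini20FlashThinking.mode(),
        GoogleModelMode::Default
    );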
impl std::fmt::Display for Model {
@@ -4,6 +4,7 @@ use client::{Client, UserStore, zed_urls};
use futures::{
AsyncBufReadExt, FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream,
};
+use google_ai::GoogleModelMode;
use gpui::{
AnyElement, AnyView, App, AsyncApp, Context, Entity, SemanticVersion, Subscription, Task,
};
@@ -750,7 +751,8 @@ impl LanguageModel for CloudLanguageModel {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
let model_id = self.model.id.to_string();
- let generate_content_request = into_google(request, model_id.clone());
+ let generate_content_request =
+ into_google(request, model_id.clone(), GoogleModelMode::Default);
async move {
let http_client = &client.http_client();
let token = llm_api_token.acquire(&client).await?;
@@ -922,7 +924,8 @@ impl LanguageModel for CloudLanguageModel {
}
zed_llm_client::LanguageModelProvider::Google => {
let client = self.client.clone();
- let request = into_google(request, self.model.id.to_string());
+ let request =
+ into_google(request, self.model.id.to_string(), GoogleModelMode::Default);
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream(async move {
let PerformLlmCompletionResponse {
@@ -4,7 +4,8 @@ use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use google_ai::{
- FunctionDeclaration, GenerateContentResponse, Part, SystemInstruction, UsageMetadata,
+ FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
+ ThinkingConfig, UsageMetadata,
};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
@@ -45,11 +46,41 @@ pub struct GoogleSettings {
pub available_models: Vec<AvailableModel>,
}
+#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum ModelMode {
+ #[default]
+ Default,
+ Thinking {
+ /// The maximum number of tokens to use for reasoning. Must be lower than the model's `max_output_tokens`.
+ budget_tokens: Option<u32>,
+ },
+}
+
+impl From<ModelMode> for GoogleModelMode {
+ fn from(value: ModelMode) -> Self {
+ match value {
+ ModelMode::Default => GoogleModelMode::Default,
+ ModelMode::Thinking { budget_tokens } => GoogleModelMode::Thinking { budget_tokens },
+ }
+ }
+}
+
+impl From<GoogleModelMode> for ModelMode {
+ fn from(value: GoogleModelMode) -> Self {
+ match value {
+ GoogleModelMode::Default => ModelMode::Default,
+ GoogleModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
+ }
+ }
+}
+
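The two `From` impls are deliberately symmetrical, so conversion is lossless in both directions. A round-trip check (a test sketch, not part of the change):

    #[test]
    fn model_mode_round_trips() {
        let mode = ModelMode::Thinking { budget_tokens: Some(8192) };
        let google: GoogleModelMode = mode.into();
        assert_eq!(ModelMode::from(google), mode);
    }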
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
name: String,
display_name: Option<String>,
max_tokens: usize,
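+ /// The mode to use for this model, e.g. thinking with an optional token budget.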
+ mode: Option<ModelMode>,
}
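Given the internally tagged, lowercased representation on `ModelMode`, a user-facing entry would deserialize roughly as follows; the model name and limits are placeholders, and `serde_json` is again assumed:

    let entry: AvailableModel = serde_json::from_str(
        r#"{
            "name": "gemini-2.5-flash-preview",
            "display_name": "Gemini 2.5 Flash (thinking)",
            "max_tokens": 1000000,
            "mode": { "type": "thinking", "budget_tokens": 8192 }
        }"#,
    )
    .unwrap();
    assert_eq!(
        entry.mode,
        Some(ModelMode::Thinking { budget_tokens: Some(8192) })
    );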
pub struct GoogleLanguageModelProvider {
@@ -216,6 +247,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
+ mode: model.mode.unwrap_or_default().into(),
},
);
}
@@ -343,7 +375,7 @@ impl LanguageModel for GoogleLanguageModel {
cx: &App,
) -> BoxFuture<'static, Result<usize>> {
let model_id = self.model.id().to_string();
- let request = into_google(request, model_id.clone());
+ let request = into_google(request, model_id.clone(), self.model.mode());
let http_client = self.http_client.clone();
let api_key = self.state.read(cx).api_key.clone();
@@ -379,7 +411,7 @@ impl LanguageModel for GoogleLanguageModel {
>,
>,
> {
- let request = into_google(request, self.model.id().to_string());
+ let request = into_google(request, self.model.id().to_string(), self.model.mode());
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request
@@ -394,6 +426,7 @@ impl LanguageModel for GoogleLanguageModel {
pub fn into_google(
mut request: LanguageModelRequest,
model_id: String,
+ mode: GoogleModelMode,
) -> google_ai::GenerateContentRequest {
fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
content
@@ -504,6 +537,12 @@ pub fn into_google(
stop_sequences: Some(request.stop),
max_output_tokens: None,
temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
+ thinking_config: match mode {
+ GoogleModelMode::Thinking { budget_tokens } => {
+ budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
+ }
+ GoogleModelMode::Default => None,
+ },
top_p: None,
top_k: None,
}),
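Worth calling out: `Thinking { budget_tokens: None }` yields no `thinkingConfig` at all, deferring to the API's default budget rather than sending zero. A hypothetical extraction of that match (the helper name is illustrative, not in the diff) with all three cases spelled out:

    // Hypothetical helper mirroring the match above.
    fn thinking_config_for(mode: GoogleModelMode) -> Option<ThinkingConfig> {
        match mode {
            GoogleModelMode::Thinking { budget_tokens } => {
                budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
            }
            GoogleModelMode::Default => None,
        }
    }

    assert!(thinking_config_for(GoogleModelMode::Default).is_none());
    // No explicit budget means no thinkingConfig on the wire.
    assert!(thinking_config_for(GoogleModelMode::Thinking { budget_tokens: None }).is_none());
    assert_eq!(
        thinking_config_for(GoogleModelMode::Thinking { budget_tokens: Some(4096) })
            .map(|c| c.thinking_budget),
        Some(4096)
    );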