gemini: Pass system prompt as system instructions (#28793)

Bennet Bo Fenner created

https://ai.google.dev/gemini-api/docs/text-generation#system-instructions

Release Notes:

- agent: Improve performance of Gemini models

Change summary

crates/google_ai/src/google_ai.rs             |  7 +
crates/language_models/src/provider/google.rs | 91 ++++++++++++--------
2 files changed, 62 insertions(+), 36 deletions(-)

Detailed changes

crates/google_ai/src/google_ai.rs 🔗

@@ -125,6 +125,7 @@ pub struct GenerateContentRequest {
     #[serde(default, skip_serializing_if = "String::is_empty")]
     pub model: String,
     pub contents: Vec<Content>,
+    pub system_instructions: Option<SystemInstructions>,
     pub generation_config: Option<GenerationConfig>,
     pub safety_settings: Option<Vec<SafetySetting>>,
     #[serde(skip_serializing_if = "Option::is_none")]
@@ -159,6 +160,12 @@ pub struct Content {
     pub role: Role,
 }
 
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SystemInstructions {
+    pub parts: Vec<Part>,
+}
+
 #[derive(Debug, PartialEq, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
 pub enum Role {

crates/language_models/src/provider/google.rs 🔗

@@ -3,14 +3,16 @@ use collections::BTreeMap;
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
-use google_ai::{FunctionDeclaration, GenerateContentResponse, Part, UsageMetadata};
+use google_ai::{
+    FunctionDeclaration, GenerateContentResponse, Part, SystemInstructions, UsageMetadata,
+};
 use gpui::{
     AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
 };
 use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModelCompletionEvent, LanguageModelToolSchemaFormat,
-    LanguageModelToolUse, LanguageModelToolUseId, StopReason,
+    LanguageModelToolUse, LanguageModelToolUseId, MessageContent, StopReason,
 };
 use language_model::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
@@ -359,48 +361,65 @@ impl LanguageModel for GoogleLanguageModel {
 }
 
 pub fn into_google(
-    request: LanguageModelRequest,
+    mut request: LanguageModelRequest,
     model: String,
 ) -> google_ai::GenerateContentRequest {
+    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
+        content
+            .into_iter()
+            .filter_map(|content| match content {
+                language_model::MessageContent::Text(text) => {
+                    if !text.is_empty() {
+                        Some(Part::TextPart(google_ai::TextPart { text }))
+                    } else {
+                        None
+                    }
+                }
+                language_model::MessageContent::Image(_) => None,
+                language_model::MessageContent::ToolUse(tool_use) => {
+                    Some(Part::FunctionCallPart(google_ai::FunctionCallPart {
+                        function_call: google_ai::FunctionCall {
+                            name: tool_use.name.to_string(),
+                            args: tool_use.input,
+                        },
+                    }))
+                }
+                language_model::MessageContent::ToolResult(tool_result) => Some(
+                    Part::FunctionResponsePart(google_ai::FunctionResponsePart {
+                        function_response: google_ai::FunctionResponse {
+                            name: tool_result.tool_name.to_string(),
+                            // The API expects a valid JSON object
+                            response: serde_json::json!({
+                                "output": tool_result.content
+                            }),
+                        },
+                    }),
+                ),
+            })
+            .collect()
+    }
+
+    let system_instructions = if request
+        .messages
+        .first()
+        .map_or(false, |msg| matches!(msg.role, Role::System))
+    {
+        let message = request.messages.remove(0);
+        Some(SystemInstructions {
+            parts: map_content(message.content),
+        })
+    } else {
+        None
+    };
+
     google_ai::GenerateContentRequest {
         model,
+        system_instructions,
         contents: request
             .messages
             .into_iter()
             .map(|message| google_ai::Content {
-                parts: message
-                    .content
-                    .into_iter()
-                    .filter_map(|content| match content {
-                        language_model::MessageContent::Text(text) => {
-                            if !text.is_empty() {
-                                Some(Part::TextPart(google_ai::TextPart { text }))
-                            } else {
-                                None
-                            }
-                        }
-                        language_model::MessageContent::Image(_) => None,
-                        language_model::MessageContent::ToolUse(tool_use) => {
-                            Some(Part::FunctionCallPart(google_ai::FunctionCallPart {
-                                function_call: google_ai::FunctionCall {
-                                    name: tool_use.name.to_string(),
-                                    args: tool_use.input,
-                                },
-                            }))
-                        }
-                        language_model::MessageContent::ToolResult(tool_result) => Some(
-                            Part::FunctionResponsePart(google_ai::FunctionResponsePart {
-                                function_response: google_ai::FunctionResponse {
-                                    name: tool_result.tool_name.to_string(),
-                                    // The API expects a valid JSON object
-                                    response: serde_json::json!({
-                                        "output": tool_result.content
-                                    }),
-                                },
-                            }),
-                        ),
-                    })
-                    .collect(),
+                parts: map_content(message.content),
                 role: match message.role {
                     Role::User => google_ai::Role::User,
                     Role::Assistant => google_ai::Role::Model,