1use anyhow::Result;
2use futures::{Stream, StreamExt};
3use language_model_core::{
4 LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
5 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
6 StopReason, TokenUsage,
7};
8use std::pin::Pin;
9use std::sync::Arc;
10use std::sync::atomic::{self, AtomicU64};
11
12use crate::{
13 Content, FunctionCallingConfig, FunctionCallingMode, FunctionDeclaration,
14 GenerateContentResponse, GenerationConfig, GenerativeContentBlob, GoogleModelMode,
15 InlineDataPart, ModelName, Part, SystemInstruction, TextPart, ThinkingConfig, ToolConfig,
16 UsageMetadata,
17};
18
/// Converts an abstract `LanguageModelRequest` into Google's
/// `GenerateContentRequest` wire format.
///
/// A leading `Role::System` message (if any) is lifted out into the
/// request's `system_instruction`; all remaining messages become
/// `contents` entries. Messages whose content maps to zero parts are
/// dropped entirely, since the API rejects empty content objects.
pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> crate::GenerateContentRequest {
    // Maps one message's content items to zero or more Google `Part`s.
    // Empty text and empty/absent thinking signatures are filtered out.
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                MessageContent::Text(text) => {
                    if !text.is_empty() {
                        vec![Part::TextPart(TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                // Only the opaque signature survives round-tripping; the
                // thinking text itself is not sent back to the API.
                MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(crate::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                // Thinking without a signature cannot be replayed; drop it.
                MessageContent::Thinking { .. } => {
                    vec![]
                }
                MessageContent::RedactedThinking(_) => vec![],
                // NOTE(review): the mime type is hard-coded; assumes image
                // sources are always PNG — TODO confirm upstream guarantees.
                MessageContent::Image(image) => {
                    vec![Part::InlineDataPart(InlineDataPart {
                        inline_data: GenerativeContentBlob {
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                MessageContent::ToolUse(tool_use) => {
                    // Normalize empty string signatures to None
                    let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());

                    vec![Part::FunctionCallPart(crate::FunctionCallPart {
                        function_call: crate::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                        thought_signature,
                    })]
                }
                MessageContent::ToolResult(tool_result) => {
                    // Concatenate all text fragments; collect images so they
                    // can be attached after the function response part.
                    let mut text_output = String::new();
                    let mut images: Vec<InlineDataPart> = Vec::new();
                    for part in tool_result.content {
                        match part {
                            language_model_core::LanguageModelToolResultContent::Text(text) => {
                                text_output.push_str(&text);
                            }
                            language_model_core::LanguageModelToolResultContent::Image(image) => {
                                images.push(InlineDataPart {
                                    inline_data: GenerativeContentBlob {
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                });
                            }
                        }
                    }
                    // Image-only results still need a textual placeholder in
                    // the function response object.
                    let output = if text_output.is_empty() && !images.is_empty() {
                        "Tool responded with an image".to_string()
                    } else {
                        text_output
                    };
                    let mut parts = vec![Part::FunctionResponsePart(crate::FunctionResponsePart {
                        function_response: crate::FunctionResponse {
                            name: tool_result.tool_name.to_string(),
                            // The API expects a valid JSON object
                            response: serde_json::json!({
                                "output": output
                            }),
                        },
                    })];
                    parts.extend(images.into_iter().map(Part::InlineDataPart));
                    parts
                }
            })
            .collect()
    }

    // Pull a leading system message (only the first) out of the message list
    // and send it as the dedicated system instruction.
    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    crate::GenerateContentRequest {
        model: ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            .filter_map(|message| {
                let parts = map_content(message.content);
                if parts.is_empty() {
                    None
                } else {
                    Some(Content {
                        parts,
                        role: match message.role {
                            Role::User => crate::Role::User,
                            Role::Assistant => crate::Role::Model,
                            Role::System => crate::Role::User, // Google AI doesn't have a system role
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            // Default to 1.0 when the caller didn't specify a temperature.
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
            // Thinking is only configured when the caller allows it AND the
            // model mode carries an explicit token budget.
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        // Omit the tools field entirely when no tools are declared.
        tools: (!request.tools.is_empty()).then(|| {
            vec![crate::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| ToolConfig {
            function_calling_config: FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}
186
/// Stateful translator from Google `GenerateContentResponse` chunks to
/// `LanguageModelCompletionEvent`s, accumulating usage and the stop reason
/// across a streamed response.
pub struct GoogleEventMapper {
    // Running token counts; fields reported by later chunks overwrite
    // earlier values (see `update_usage`).
    usage: UsageMetadata,
    // Most recently observed stop reason; emitted when the stream ends.
    stop_reason: StopReason,
}
191
192impl GoogleEventMapper {
193 pub fn new() -> Self {
194 Self {
195 usage: UsageMetadata::default(),
196 stop_reason: StopReason::EndTurn,
197 }
198 }
199
200 pub fn map_stream(
201 mut self,
202 events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
203 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
204 {
205 events
206 .map(Some)
207 .chain(futures::stream::once(async { None }))
208 .flat_map(move |event| {
209 futures::stream::iter(match event {
210 Some(Ok(event)) => self.map_event(event),
211 Some(Err(error)) => {
212 vec![Err(LanguageModelCompletionError::from(error))]
213 }
214 None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
215 })
216 })
217 }
218
219 pub fn map_event(
220 &mut self,
221 event: GenerateContentResponse,
222 ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
223 static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
224
225 let mut events: Vec<_> = Vec::new();
226 let mut wants_to_use_tool = false;
227 if let Some(usage_metadata) = event.usage_metadata {
228 update_usage(&mut self.usage, &usage_metadata);
229 events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
230 convert_usage(&self.usage),
231 )))
232 }
233
234 if let Some(prompt_feedback) = event.prompt_feedback
235 && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
236 {
237 self.stop_reason = match block_reason {
238 "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
239 StopReason::Refusal
240 }
241 _ => {
242 log::error!("Unexpected Google block_reason: {block_reason}");
243 StopReason::Refusal
244 }
245 };
246 events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
247
248 return events;
249 }
250
251 if let Some(candidates) = event.candidates {
252 for candidate in candidates {
253 if let Some(finish_reason) = candidate.finish_reason.as_deref() {
254 self.stop_reason = match finish_reason {
255 "STOP" => StopReason::EndTurn,
256 "MAX_TOKENS" => StopReason::MaxTokens,
257 _ => {
258 log::error!("Unexpected google finish_reason: {finish_reason}");
259 StopReason::EndTurn
260 }
261 };
262 }
263 candidate
264 .content
265 .parts
266 .into_iter()
267 .for_each(|part| match part {
268 Part::TextPart(text_part) => {
269 events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
270 }
271 Part::InlineDataPart(_) => {}
272 Part::FunctionCallPart(function_call_part) => {
273 wants_to_use_tool = true;
274 let name: Arc<str> = function_call_part.function_call.name.into();
275 let next_tool_id =
276 TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
277 let id: LanguageModelToolUseId =
278 format!("{}-{}", name, next_tool_id).into();
279
280 // Normalize empty string signatures to None
281 let thought_signature = function_call_part
282 .thought_signature
283 .filter(|s| !s.is_empty());
284
285 events.push(Ok(LanguageModelCompletionEvent::ToolUse(
286 LanguageModelToolUse {
287 id,
288 name,
289 is_input_complete: true,
290 raw_input: function_call_part.function_call.args.to_string(),
291 input: function_call_part.function_call.args,
292 thought_signature,
293 },
294 )));
295 }
296 Part::FunctionResponsePart(_) => {}
297 Part::ThoughtPart(part) => {
298 events.push(Ok(LanguageModelCompletionEvent::Thinking {
299 text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
300 signature: Some(part.thought_signature),
301 }));
302 }
303 });
304 }
305 }
306
307 // Even when Gemini wants to use a Tool, the API
308 // responds with `finish_reason: STOP`
309 if wants_to_use_tool {
310 self.stop_reason = StopReason::ToolUse;
311 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
312 }
313 events
314 }
315}
316
317fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
318 if let Some(prompt_token_count) = new.prompt_token_count {
319 usage.prompt_token_count = Some(prompt_token_count);
320 }
321 if let Some(cached_content_token_count) = new.cached_content_token_count {
322 usage.cached_content_token_count = Some(cached_content_token_count);
323 }
324 if let Some(candidates_token_count) = new.candidates_token_count {
325 usage.candidates_token_count = Some(candidates_token_count);
326 }
327 if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
328 usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
329 }
330 if let Some(thoughts_token_count) = new.thoughts_token_count {
331 usage.thoughts_token_count = Some(thoughts_token_count);
332 }
333 if let Some(total_token_count) = new.total_token_count {
334 usage.total_token_count = Some(total_token_count);
335 }
336}
337
338fn convert_usage(usage: &UsageMetadata) -> TokenUsage {
339 let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
340 let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
341 let input_tokens = prompt_tokens - cached_tokens;
342 let output_tokens = usage.candidates_token_count.unwrap_or(0);
343
344 TokenUsage {
345 input_tokens,
346 output_tokens,
347 cache_read_input_tokens: cached_tokens,
348 cache_creation_input_tokens: 0,
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
        Part, Role as GoogleRole,
    };
    use serde_json::json;

    /// Builds a single-candidate response containing one `test_function`
    /// call carrying the given thought signature.
    fn function_call_response(thought_signature: Option<String>) -> GenerateContentResponse {
        GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
                        function_call: FunctionCall {
                            name: "test_function".to_string(),
                            args: json!({"arg": "value"}),
                        },
                        thought_signature,
                    })],
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        }
    }

    /// Maps the response and returns the signature of the leading ToolUse
    /// event, panicking if the first event is not a ToolUse.
    fn map_to_tool_use_signature(response: GenerateContentResponse) -> Option<String> {
        let mut mapper = GoogleEventMapper::new();
        let events = mapper.map_event(response);
        // A function call always yields ToolUse followed by Stop(ToolUse).
        assert_eq!(events.len(), 2);

        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
            assert_eq!(tool_use.name.as_ref(), "test_function");
            tool_use.thought_signature.as_deref().map(str::to_string)
        } else {
            panic!("Expected ToolUse event");
        }
    }

    #[test]
    fn test_function_call_with_signature_creates_tool_use_with_signature() {
        let signature =
            map_to_tool_use_signature(function_call_response(Some("test_signature_123".into())));
        assert_eq!(signature.as_deref(), Some("test_signature_123"));
    }

    #[test]
    fn test_function_call_without_signature_has_none() {
        let signature = map_to_tool_use_signature(function_call_response(None));
        assert!(signature.is_none());
    }

    #[test]
    fn test_empty_string_signature_normalized_to_none() {
        let signature = map_to_tool_use_signature(function_call_response(Some(String::new())));
        assert!(signature.is_none());
    }
}