1use anyhow::Result;
2use futures::{Stream, StreamExt};
3use language_model_core::{
4 LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
5 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
6 StopReason, TokenUsage,
7};
8use std::pin::Pin;
9use std::sync::Arc;
10use std::sync::atomic::{self, AtomicU64};
11
12use crate::{
13 Content, FunctionCallingConfig, FunctionCallingMode, FunctionDeclaration,
14 GenerateContentResponse, GenerationConfig, GenerativeContentBlob, GoogleModelMode,
15 InlineDataPart, ModelName, Part, SystemInstruction, TextPart, ThinkingConfig, ToolConfig,
16 UsageMetadata,
17};
18
/// Translates a model-agnostic [`LanguageModelRequest`] into Google's
/// `GenerateContentRequest` wire format.
///
/// A leading `Role::System` message is pulled out of `messages` and sent as
/// the dedicated `system_instruction`; every remaining message becomes one
/// `Content` entry. Messages whose content maps to zero parts are dropped,
/// and `mode` decides whether a thinking budget is attached.
pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> crate::GenerateContentRequest {
    // Converts one message's content items into Google `Part`s. Items with no
    // usable payload (empty text, thinking without a signature, redacted
    // thinking) map to no parts at all.
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                MessageContent::Text(text) => {
                    if !text.is_empty() {
                        vec![Part::TextPart(TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                // Thinking is round-tripped via its opaque signature only;
                // the thought text itself is not echoed back to the API.
                MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(crate::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                MessageContent::Thinking { .. } => {
                    vec![]
                }
                MessageContent::RedactedThinking(_) => vec![],
                MessageContent::Image(image) => {
                    vec![Part::InlineDataPart(InlineDataPart {
                        inline_data: GenerativeContentBlob {
                            // NOTE(review): assumes every image is PNG-encoded —
                            // confirm against how `image.source` is produced.
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                MessageContent::ToolUse(tool_use) => {
                    // Normalize empty string signatures to None
                    let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());

                    vec![Part::FunctionCallPart(crate::FunctionCallPart {
                        function_call: crate::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                        thought_signature,
                    })]
                }
                MessageContent::ToolResult(tool_result) => {
                    match tool_result.content {
                        language_model_core::LanguageModelToolResultContent::Text(text) => {
                            vec![Part::FunctionResponsePart(crate::FunctionResponsePart {
                                function_response: crate::FunctionResponse {
                                    name: tool_result.tool_name.to_string(),
                                    // The API expects a valid JSON object
                                    response: serde_json::json!({
                                        "output": text
                                    }),
                                },
                            })]
                        }
                        // Image results become two parts: a placeholder
                        // function response plus the inline image data.
                        language_model_core::LanguageModelToolResultContent::Image(image) => {
                            vec![
                                Part::FunctionResponsePart(crate::FunctionResponsePart {
                                    function_response: crate::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": "Tool responded with an image"
                                        }),
                                    },
                                }),
                                Part::InlineDataPart(InlineDataPart {
                                    inline_data: GenerativeContentBlob {
                                        // NOTE(review): same PNG assumption as above.
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                }),
                            ]
                        }
                    }
                }
            })
            .collect()
    }

    // Only a *leading* system message becomes `system_instruction`; any later
    // system message falls through to `contents` below.
    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    crate::GenerateContentRequest {
        model: ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            // Skip messages that produced no parts — the API rejects empty
            // content entries.
            .filter_map(|message| {
                let parts = map_content(message.content);
                if parts.is_empty() {
                    None
                } else {
                    Some(Content {
                        parts,
                        role: match message.role {
                            Role::User => crate::Role::User,
                            Role::Assistant => crate::Role::Model,
                            Role::System => crate::Role::User, // Google AI doesn't have a system role
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            // Default the temperature to 1.0 when the request leaves it unset.
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
            // A thinking budget is only forwarded when the request allows
            // thinking AND the model mode actually carries a budget.
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        tools: (!request.tools.is_empty()).then(|| {
            vec![crate::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| ToolConfig {
            function_calling_config: FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}
185
/// Accumulates state while mapping a streamed Google response into
/// [`LanguageModelCompletionEvent`]s: running token usage across chunks and
/// the stop reason to emit when the stream ends.
pub struct GoogleEventMapper {
    usage: UsageMetadata,
    stop_reason: StopReason,
}
190
191impl GoogleEventMapper {
192 pub fn new() -> Self {
193 Self {
194 usage: UsageMetadata::default(),
195 stop_reason: StopReason::EndTurn,
196 }
197 }
198
199 pub fn map_stream(
200 mut self,
201 events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
202 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
203 {
204 events
205 .map(Some)
206 .chain(futures::stream::once(async { None }))
207 .flat_map(move |event| {
208 futures::stream::iter(match event {
209 Some(Ok(event)) => self.map_event(event),
210 Some(Err(error)) => {
211 vec![Err(LanguageModelCompletionError::from(error))]
212 }
213 None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
214 })
215 })
216 }
217
218 pub fn map_event(
219 &mut self,
220 event: GenerateContentResponse,
221 ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
222 static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
223
224 let mut events: Vec<_> = Vec::new();
225 let mut wants_to_use_tool = false;
226 if let Some(usage_metadata) = event.usage_metadata {
227 update_usage(&mut self.usage, &usage_metadata);
228 events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
229 convert_usage(&self.usage),
230 )))
231 }
232
233 if let Some(prompt_feedback) = event.prompt_feedback
234 && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
235 {
236 self.stop_reason = match block_reason {
237 "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
238 StopReason::Refusal
239 }
240 _ => {
241 log::error!("Unexpected Google block_reason: {block_reason}");
242 StopReason::Refusal
243 }
244 };
245 events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
246
247 return events;
248 }
249
250 if let Some(candidates) = event.candidates {
251 for candidate in candidates {
252 if let Some(finish_reason) = candidate.finish_reason.as_deref() {
253 self.stop_reason = match finish_reason {
254 "STOP" => StopReason::EndTurn,
255 "MAX_TOKENS" => StopReason::MaxTokens,
256 _ => {
257 log::error!("Unexpected google finish_reason: {finish_reason}");
258 StopReason::EndTurn
259 }
260 };
261 }
262 candidate
263 .content
264 .parts
265 .into_iter()
266 .for_each(|part| match part {
267 Part::TextPart(text_part) => {
268 events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
269 }
270 Part::InlineDataPart(_) => {}
271 Part::FunctionCallPart(function_call_part) => {
272 wants_to_use_tool = true;
273 let name: Arc<str> = function_call_part.function_call.name.into();
274 let next_tool_id =
275 TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
276 let id: LanguageModelToolUseId =
277 format!("{}-{}", name, next_tool_id).into();
278
279 // Normalize empty string signatures to None
280 let thought_signature = function_call_part
281 .thought_signature
282 .filter(|s| !s.is_empty());
283
284 events.push(Ok(LanguageModelCompletionEvent::ToolUse(
285 LanguageModelToolUse {
286 id,
287 name,
288 is_input_complete: true,
289 raw_input: function_call_part.function_call.args.to_string(),
290 input: function_call_part.function_call.args,
291 thought_signature,
292 },
293 )));
294 }
295 Part::FunctionResponsePart(_) => {}
296 Part::ThoughtPart(part) => {
297 events.push(Ok(LanguageModelCompletionEvent::Thinking {
298 text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
299 signature: Some(part.thought_signature),
300 }));
301 }
302 });
303 }
304 }
305
306 // Even when Gemini wants to use a Tool, the API
307 // responds with `finish_reason: STOP`
308 if wants_to_use_tool {
309 self.stop_reason = StopReason::ToolUse;
310 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
311 }
312 events
313 }
314}
315
316fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
317 if let Some(prompt_token_count) = new.prompt_token_count {
318 usage.prompt_token_count = Some(prompt_token_count);
319 }
320 if let Some(cached_content_token_count) = new.cached_content_token_count {
321 usage.cached_content_token_count = Some(cached_content_token_count);
322 }
323 if let Some(candidates_token_count) = new.candidates_token_count {
324 usage.candidates_token_count = Some(candidates_token_count);
325 }
326 if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
327 usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
328 }
329 if let Some(thoughts_token_count) = new.thoughts_token_count {
330 usage.thoughts_token_count = Some(thoughts_token_count);
331 }
332 if let Some(total_token_count) = new.total_token_count {
333 usage.total_token_count = Some(total_token_count);
334 }
335}
336
337fn convert_usage(usage: &UsageMetadata) -> TokenUsage {
338 let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
339 let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
340 let input_tokens = prompt_tokens - cached_tokens;
341 let output_tokens = usage.candidates_token_count.unwrap_or(0);
342
343 TokenUsage {
344 input_tokens,
345 output_tokens,
346 cache_read_input_tokens: cached_tokens,
347 cache_creation_input_tokens: 0,
348 }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
        Part, Role as GoogleRole,
    };
    use serde_json::json;

    /// Builds a single-candidate response whose only part is a function call
    /// named `test_function`, carrying the given thought signature. Shared by
    /// the signature-normalization tests below to avoid repeating the fixture.
    fn function_call_response(thought_signature: Option<&str>) -> GenerateContentResponse {
        GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
                        function_call: FunctionCall {
                            name: "test_function".to_string(),
                            args: json!({"arg": "value"}),
                        },
                        thought_signature: thought_signature.map(str::to_string),
                    })],
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        }
    }

    #[test]
    fn test_function_call_with_signature_creates_tool_use_with_signature() {
        let mut mapper = GoogleEventMapper::new();
        let events = mapper.map_event(function_call_response(Some("test_signature_123")));
        // One ToolUse event plus the synthetic Stop(ToolUse).
        assert_eq!(events.len(), 2);

        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
            assert_eq!(tool_use.name.as_ref(), "test_function");
            assert_eq!(
                tool_use.thought_signature.as_deref(),
                Some("test_signature_123")
            );
        } else {
            panic!("Expected ToolUse event");
        }
    }

    #[test]
    fn test_function_call_without_signature_has_none() {
        let mut mapper = GoogleEventMapper::new();
        let events = mapper.map_event(function_call_response(None));
        assert_eq!(events.len(), 2);

        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
            assert!(tool_use.thought_signature.is_none());
        } else {
            panic!("Expected ToolUse event");
        }
    }

    #[test]
    fn test_empty_string_signature_normalized_to_none() {
        let mut mapper = GoogleEventMapper::new();
        let events = mapper.map_event(function_call_response(Some("")));

        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
            assert!(tool_use.thought_signature.is_none());
        } else {
            panic!("Expected ToolUse event");
        }
    }
}