1use anyhow::Result;
2use futures::{Stream, StreamExt};
3use language_model_core::{
4 LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
5 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
6 StopReason, TokenUsage,
7};
8use std::pin::Pin;
9use std::sync::Arc;
10use std::sync::atomic::{self, AtomicU64};
11
12use crate::{
13 Content, FunctionCallingConfig, FunctionCallingMode, FunctionDeclaration,
14 GenerateContentResponse, GenerationConfig, GenerativeContentBlob, GoogleModelMode,
15 InlineDataPart, ModelName, Part, SystemInstruction, TextPart, ThinkingConfig, ToolConfig,
16 UsageMetadata,
17};
18
/// Converts a provider-agnostic [`LanguageModelRequest`] into Google's
/// `GenerateContentRequest` wire format for the given `model_id`.
///
/// A leading `Role::System` message (if any) is lifted out of the message
/// list and sent as the dedicated `system_instruction`; all remaining
/// messages become `contents`. Messages whose content maps to zero parts
/// are dropped entirely, since the API rejects empty `parts` arrays.
pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> crate::GenerateContentRequest {
    // Flattens one message's content items into Google `Part`s. Empty text
    // and signature-less/empty thinking produce no parts at all.
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                MessageContent::Text(text) => {
                    // Skip empty strings: the API rejects empty text parts.
                    if !text.is_empty() {
                        vec![Part::TextPart(TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    // Only the opaque signature is round-tripped to the API;
                    // the human-readable thinking text is intentionally dropped.
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(crate::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                // Thinking without a signature cannot be replayed to Google.
                MessageContent::Thinking { .. } => {
                    vec![]
                }
                MessageContent::RedactedThinking(_) => vec![],
                MessageContent::Image(image) => {
                    // NOTE(review): mime type is hardcoded to PNG — confirm
                    // that images are always converted to PNG upstream.
                    vec![Part::InlineDataPart(InlineDataPart {
                        inline_data: GenerativeContentBlob {
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                MessageContent::ToolUse(tool_use) => {
                    // Normalize empty string signatures to None
                    let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());

                    vec![Part::FunctionCallPart(crate::FunctionCallPart {
                        function_call: crate::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                        thought_signature,
                    })]
                }
                MessageContent::ToolResult(tool_result) => {
                    match tool_result.content {
                        language_model_core::LanguageModelToolResultContent::Text(text) => {
                            vec![Part::FunctionResponsePart(crate::FunctionResponsePart {
                                function_response: crate::FunctionResponse {
                                    name: tool_result.tool_name.to_string(),
                                    // The API expects a valid JSON object
                                    response: serde_json::json!({
                                        "output": text
                                    }),
                                },
                            })]
                        }
                        language_model_core::LanguageModelToolResultContent::Image(image) => {
                            // Image results become two parts: a placeholder
                            // function response (the response field must be a
                            // JSON object) plus the image as inline data.
                            vec![
                                Part::FunctionResponsePart(crate::FunctionResponsePart {
                                    function_response: crate::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": "Tool responded with an image"
                                        }),
                                    },
                                }),
                                Part::InlineDataPart(InlineDataPart {
                                    inline_data: GenerativeContentBlob {
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                }),
                            ]
                        }
                    }
                }
            })
            .collect()
    }

    // Only a system message in the *first* position is promoted to the
    // request-level system instruction; it is removed from the message list.
    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    crate::GenerateContentRequest {
        model: ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            .filter_map(|message| {
                let parts = map_content(message.content);
                // Drop messages that produced no parts (e.g. empty text only).
                if parts.is_empty() {
                    None
                } else {
                    Some(Content {
                        parts,
                        role: match message.role {
                            Role::User => crate::Role::User,
                            Role::Assistant => crate::Role::Model,
                            Role::System => crate::Role::User, // Google AI doesn't have a system role
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            // Default to 1.0 when the caller leaves temperature unset.
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
            // A thinking budget is only forwarded when the request allows
            // thinking AND the model is configured in thinking mode.
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        // Google expects all function declarations inside a single Tool entry;
        // omit the field entirely when there are no tools.
        tools: (!request.tools.is_empty()).then(|| {
            vec![crate::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| ToolConfig {
            function_calling_config: FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}
185
/// Stateful adapter that folds streamed `GenerateContentResponse` chunks
/// into [`LanguageModelCompletionEvent`]s.
pub struct GoogleEventMapper {
    // Running token-usage totals, merged across chunks as they arrive.
    usage: UsageMetadata,
    // Most recently observed stop reason; emitted when the stream ends.
    stop_reason: StopReason,
}
190
191impl GoogleEventMapper {
192 pub fn new() -> Self {
193 Self {
194 usage: UsageMetadata::default(),
195 stop_reason: StopReason::EndTurn,
196 }
197 }
198
199 pub fn map_stream(
200 mut self,
201 events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
202 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
203 {
204 events
205 .map(Some)
206 .chain(futures::stream::once(async { None }))
207 .flat_map(move |event| {
208 futures::stream::iter(match event {
209 Some(Ok(event)) => self.map_event(event),
210 Some(Err(error)) => {
211 vec![Err(LanguageModelCompletionError::from(error))]
212 }
213 None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
214 })
215 })
216 }
217
218 pub fn map_event(
219 &mut self,
220 event: GenerateContentResponse,
221 ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
222 static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
223
224 let mut events: Vec<_> = Vec::new();
225 let mut wants_to_use_tool = false;
226 if let Some(usage_metadata) = event.usage_metadata {
227 update_usage(&mut self.usage, &usage_metadata);
228 events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
229 convert_usage(&self.usage),
230 )))
231 }
232
233 if let Some(prompt_feedback) = event.prompt_feedback
234 && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
235 {
236 self.stop_reason = match block_reason {
237 "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
238 StopReason::Refusal
239 }
240 _ => {
241 log::error!("Unexpected Google block_reason: {block_reason}");
242 StopReason::Refusal
243 }
244 };
245 events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
246
247 return events;
248 }
249
250 if let Some(candidates) = event.candidates {
251 for candidate in candidates {
252 if let Some(finish_reason) = candidate.finish_reason.as_deref() {
253 self.stop_reason = match finish_reason {
254 "STOP" => StopReason::EndTurn,
255 "MAX_TOKENS" => StopReason::MaxTokens,
256 _ => {
257 log::error!("Unexpected google finish_reason: {finish_reason}");
258 StopReason::EndTurn
259 }
260 };
261 }
262 candidate
263 .content
264 .parts
265 .into_iter()
266 .for_each(|part| match part {
267 Part::TextPart(text_part) => {
268 events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
269 }
270 Part::InlineDataPart(_) => {}
271 Part::FunctionCallPart(function_call_part) => {
272 wants_to_use_tool = true;
273 let name: Arc<str> = function_call_part.function_call.name.into();
274 let next_tool_id =
275 TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
276 let id: LanguageModelToolUseId =
277 format!("{}-{}", name, next_tool_id).into();
278
279 // Normalize empty string signatures to None
280 let thought_signature = function_call_part
281 .thought_signature
282 .filter(|s| !s.is_empty());
283
284 events.push(Ok(LanguageModelCompletionEvent::ToolUse(
285 LanguageModelToolUse {
286 id,
287 name,
288 is_input_complete: true,
289 raw_input: function_call_part.function_call.args.to_string(),
290 input: function_call_part.function_call.args,
291 thought_signature,
292 },
293 )));
294 }
295 Part::FunctionResponsePart(_) => {}
296 Part::ThoughtPart(part) => {
297 events.push(Ok(LanguageModelCompletionEvent::Thinking {
298 text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
299 signature: Some(part.thought_signature),
300 }));
301 }
302 });
303 }
304 }
305
306 // Even when Gemini wants to use a Tool, the API
307 // responds with `finish_reason: STOP`
308 if wants_to_use_tool {
309 self.stop_reason = StopReason::ToolUse;
310 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
311 }
312 events
313 }
314}
315
316/// Count tokens for a Google AI model using tiktoken. This is synchronous;
317/// callers should spawn it on a background thread if needed.
318pub fn count_google_tokens(request: LanguageModelRequest) -> Result<u64> {
319 let messages = request
320 .messages
321 .into_iter()
322 .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
323 role: match message.role {
324 Role::User => "user".into(),
325 Role::Assistant => "assistant".into(),
326 Role::System => "system".into(),
327 },
328 content: Some(message.string_contents()),
329 name: None,
330 function_call: None,
331 })
332 .collect::<Vec<_>>();
333
334 // Tiktoken doesn't yet support these models, so we manually use the
335 // same tokenizer as GPT-4.
336 tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
337}
338
339fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
340 if let Some(prompt_token_count) = new.prompt_token_count {
341 usage.prompt_token_count = Some(prompt_token_count);
342 }
343 if let Some(cached_content_token_count) = new.cached_content_token_count {
344 usage.cached_content_token_count = Some(cached_content_token_count);
345 }
346 if let Some(candidates_token_count) = new.candidates_token_count {
347 usage.candidates_token_count = Some(candidates_token_count);
348 }
349 if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
350 usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
351 }
352 if let Some(thoughts_token_count) = new.thoughts_token_count {
353 usage.thoughts_token_count = Some(thoughts_token_count);
354 }
355 if let Some(total_token_count) = new.total_token_count {
356 usage.total_token_count = Some(total_token_count);
357 }
358}
359
360fn convert_usage(usage: &UsageMetadata) -> TokenUsage {
361 let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
362 let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
363 let input_tokens = prompt_tokens - cached_tokens;
364 let output_tokens = usage.candidates_token_count.unwrap_or(0);
365
366 TokenUsage {
367 input_tokens,
368 output_tokens,
369 cache_read_input_tokens: cached_tokens,
370 cache_creation_input_tokens: 0,
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
        Part, Role as GoogleRole,
    };
    use serde_json::json;

    /// Builds a single-candidate response containing one function call with
    /// the given thought signature, mirroring what the API streams back.
    fn function_call_response(thought_signature: Option<String>) -> GenerateContentResponse {
        GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
                        function_call: FunctionCall {
                            name: "test_function".to_string(),
                            args: json!({"arg": "value"}),
                        },
                        thought_signature,
                    })],
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        }
    }

    #[test]
    fn test_function_call_with_signature_creates_tool_use_with_signature() {
        let mut mapper = GoogleEventMapper::new();
        let response = function_call_response(Some("test_signature_123".to_string()));

        let events = mapper.map_event(response);
        assert_eq!(events.len(), 2);

        match &events[0] {
            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
                assert_eq!(tool_use.name.as_ref(), "test_function");
                assert_eq!(
                    tool_use.thought_signature.as_deref(),
                    Some("test_signature_123")
                );
            }
            _ => panic!("Expected ToolUse event"),
        }
    }

    #[test]
    fn test_function_call_without_signature_has_none() {
        let mut mapper = GoogleEventMapper::new();
        let response = function_call_response(None);

        let events = mapper.map_event(response);
        assert_eq!(events.len(), 2);

        match &events[0] {
            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
                assert!(tool_use.thought_signature.is_none());
            }
            _ => panic!("Expected ToolUse event"),
        }
    }

    #[test]
    fn test_empty_string_signature_normalized_to_none() {
        let mut mapper = GoogleEventMapper::new();
        let response = function_call_response(Some(String::new()));

        let events = mapper.map_event(response);
        match &events[0] {
            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
                assert!(tool_use.thought_signature.is_none());
            }
            _ => panic!("Expected ToolUse event"),
        }
    }
}