1use anyhow::Result;
2use futures::{Stream, StreamExt};
3use language_model_core::{
4 LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
5 LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
6 StopReason, TokenUsage,
7};
8use std::pin::Pin;
9use std::sync::Arc;
10use std::sync::atomic::{self, AtomicU64};
11
12use crate::{
13 Content, FunctionCallingConfig, FunctionCallingMode, FunctionDeclaration,
14 GenerateContentResponse, GenerationConfig, GenerativeContentBlob, GoogleModelMode,
15 InlineDataPart, ModelName, Part, SystemInstruction, TextPart, ThinkingConfig, ToolConfig,
16 UsageMetadata,
17};
18
/// Converts an abstract `LanguageModelRequest` into Google's
/// `GenerateContentRequest` wire format.
///
/// A leading `Role::System` message (if any) is lifted out into the
/// request's `system_instruction`; all remaining messages become
/// `contents` entries. Messages whose content maps to zero parts are
/// dropped entirely, since the API rejects empty content objects.
pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> crate::GenerateContentRequest {
    // Maps one message's content items to zero or more Google `Part`s.
    // Empty text and empty/absent thinking signatures are filtered out.
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                MessageContent::Text(text) => {
                    if !text.is_empty() {
                        vec![Part::TextPart(TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                // Only the opaque signature survives round-tripping; the
                // thinking text itself is not sent back to the API.
                MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(crate::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                // Thinking without a signature cannot be replayed; drop it.
                MessageContent::Thinking { .. } => {
                    vec![]
                }
                MessageContent::RedactedThinking(_) => vec![],
                // NOTE(review): the mime type is hard-coded; assumes image
                // sources are always PNG — TODO confirm upstream guarantees.
                MessageContent::Image(image) => {
                    vec![Part::InlineDataPart(InlineDataPart {
                        inline_data: GenerativeContentBlob {
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                MessageContent::ToolUse(tool_use) => {
                    // Normalize empty string signatures to None
                    let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());

                    vec![Part::FunctionCallPart(crate::FunctionCallPart {
                        function_call: crate::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                        thought_signature,
                    })]
                }
                MessageContent::ToolResult(tool_result) => {
                    // Concatenate all text fragments; collect images so they
                    // can be attached after the function response part.
                    let mut text_output = String::new();
                    let mut images: Vec<InlineDataPart> = Vec::new();
                    for part in tool_result.content {
                        match part {
                            language_model_core::LanguageModelToolResultContent::Text(text) => {
                                text_output.push_str(&text);
                            }
                            language_model_core::LanguageModelToolResultContent::Image(image) => {
                                images.push(InlineDataPart {
                                    inline_data: GenerativeContentBlob {
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                });
                            }
                        }
                    }
                    // Image-only results still need a textual placeholder in
                    // the function response object.
                    let output = if text_output.is_empty() && !images.is_empty() {
                        "Tool responded with an image".to_string()
                    } else {
                        text_output
                    };
                    let mut parts = vec![Part::FunctionResponsePart(crate::FunctionResponsePart {
                        function_response: crate::FunctionResponse {
                            name: tool_result.tool_name.to_string(),
                            // The API expects a valid JSON object
                            response: serde_json::json!({
                                "output": output
                            }),
                        },
                    })];
                    parts.extend(images.into_iter().map(Part::InlineDataPart));
                    parts
                }
            })
            .collect()
    }

    // Pull a leading system message (only the first) out of the message list
    // and send it as the dedicated system instruction.
    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    crate::GenerateContentRequest {
        model: ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            .filter_map(|message| {
                let parts = map_content(message.content);
                if parts.is_empty() {
                    None
                } else {
                    Some(Content {
                        parts,
                        role: match message.role {
                            Role::User => crate::Role::User,
                            Role::Assistant => crate::Role::Model,
                            Role::System => crate::Role::User, // Google AI doesn't have a system role
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            // Default to 1.0 when the caller didn't specify a temperature.
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
            // Thinking is only configured when the caller allows it AND the
            // model mode carries an explicit token budget.
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        // Omit the tools field entirely when no tools are declared.
        tools: (!request.tools.is_empty()).then(|| {
            vec![crate::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| ToolConfig {
            function_calling_config: FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}
186
/// Stateful translator from Google `GenerateContentResponse` chunks to
/// `LanguageModelCompletionEvent`s, accumulating usage and the stop reason
/// across a streamed response.
pub struct GoogleEventMapper {
    // Running token counts; fields reported by later chunks overwrite
    // earlier values (see `update_usage`).
    usage: UsageMetadata,
    // Most recently observed stop reason; emitted when the stream ends.
    stop_reason: StopReason,
}
191
192impl GoogleEventMapper {
193 pub fn new() -> Self {
194 Self {
195 usage: UsageMetadata::default(),
196 stop_reason: StopReason::EndTurn,
197 }
198 }
199
200 pub fn map_stream(
201 mut self,
202 events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
203 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
204 {
205 events
206 .map(Some)
207 .chain(futures::stream::once(async { None }))
208 .flat_map(move |event| {
209 futures::stream::iter(match event {
210 Some(Ok(event)) => self.map_event(event),
211 Some(Err(error)) => {
212 vec![Err(LanguageModelCompletionError::from(error))]
213 }
214 None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
215 })
216 })
217 }
218
219 pub fn map_event(
220 &mut self,
221 event: GenerateContentResponse,
222 ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
223 static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
224
225 let mut events: Vec<_> = Vec::new();
226 let mut wants_to_use_tool = false;
227 if let Some(usage_metadata) = event.usage_metadata {
228 update_usage(&mut self.usage, &usage_metadata);
229 events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
230 convert_usage(&self.usage),
231 )))
232 }
233
234 if let Some(prompt_feedback) = event.prompt_feedback
235 && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
236 {
237 self.stop_reason = match block_reason {
238 "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
239 StopReason::Refusal
240 }
241 _ => {
242 log::error!("Unexpected Google block_reason: {block_reason}");
243 StopReason::Refusal
244 }
245 };
246 events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
247
248 return events;
249 }
250
251 if let Some(candidates) = event.candidates {
252 for candidate in candidates {
253 if let Some(finish_reason) = candidate.finish_reason.as_deref() {
254 self.stop_reason = match finish_reason {
255 "STOP" => StopReason::EndTurn,
256 "MAX_TOKENS" => StopReason::MaxTokens,
257 _ => {
258 log::error!("Unexpected google finish_reason: {finish_reason}");
259 StopReason::EndTurn
260 }
261 };
262 }
263 candidate
264 .content
265 .parts
266 .into_iter()
267 .for_each(|part| match part {
268 Part::TextPart(text_part) => {
269 events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
270 }
271 Part::InlineDataPart(_) => {}
272 Part::FunctionCallPart(function_call_part) => {
273 wants_to_use_tool = true;
274 let name: Arc<str> = function_call_part.function_call.name.into();
275 let next_tool_id =
276 TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
277 let id: LanguageModelToolUseId =
278 format!("{}-{}", name, next_tool_id).into();
279
280 // Normalize empty string signatures to None
281 let thought_signature = function_call_part
282 .thought_signature
283 .filter(|s| !s.is_empty());
284
285 events.push(Ok(LanguageModelCompletionEvent::ToolUse(
286 LanguageModelToolUse {
287 id,
288 name,
289 is_input_complete: true,
290 raw_input: function_call_part.function_call.args.to_string(),
291 input: function_call_part.function_call.args,
292 thought_signature,
293 },
294 )));
295 }
296 Part::FunctionResponsePart(_) => {}
297 Part::ThoughtPart(part) => {
298 events.push(Ok(LanguageModelCompletionEvent::Thinking {
299 text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
300 signature: Some(part.thought_signature),
301 }));
302 }
303 });
304 }
305 }
306
307 // Even when Gemini wants to use a Tool, the API
308 // responds with `finish_reason: STOP`
309 if wants_to_use_tool {
310 self.stop_reason = StopReason::ToolUse;
311 events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
312 }
313 events
314 }
315}
316
317fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
318 if let Some(prompt_token_count) = new.prompt_token_count {
319 usage.prompt_token_count = Some(prompt_token_count);
320 }
321 if let Some(cached_content_token_count) = new.cached_content_token_count {
322 usage.cached_content_token_count = Some(cached_content_token_count);
323 }
324 if let Some(candidates_token_count) = new.candidates_token_count {
325 usage.candidates_token_count = Some(candidates_token_count);
326 }
327 if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
328 usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
329 }
330 if let Some(thoughts_token_count) = new.thoughts_token_count {
331 usage.thoughts_token_count = Some(thoughts_token_count);
332 }
333 if let Some(total_token_count) = new.total_token_count {
334 usage.total_token_count = Some(total_token_count);
335 }
336}
337
338fn convert_usage(usage: &UsageMetadata) -> TokenUsage {
339 let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
340 let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
341 let input_tokens = prompt_tokens - cached_tokens;
342 let output_tokens = usage.candidates_token_count.unwrap_or(0);
343
344 TokenUsage {
345 input_tokens,
346 output_tokens,
347 cache_read_input_tokens: cached_tokens,
348 cache_creation_input_tokens: 0,
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
        Part, Role as GoogleRole,
    };
    use serde_json::json;

    /// Builds a single-candidate response containing one `test_function`
    /// call carrying the given thought signature.
    fn function_call_response(thought_signature: Option<String>) -> GenerateContentResponse {
        GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
                        function_call: FunctionCall {
                            name: "test_function".to_string(),
                            args: json!({"arg": "value"}),
                        },
                        thought_signature,
                    })],
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        }
    }

    /// Maps the response and returns the signature of the leading ToolUse
    /// event, panicking if the first event is not a ToolUse.
    fn map_to_tool_use_signature(response: GenerateContentResponse) -> Option<String> {
        let mut mapper = GoogleEventMapper::new();
        let events = mapper.map_event(response);
        // A function call always yields ToolUse followed by Stop(ToolUse).
        assert_eq!(events.len(), 2);

        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
            assert_eq!(tool_use.name.as_ref(), "test_function");
            tool_use.thought_signature.as_deref().map(str::to_string)
        } else {
            panic!("Expected ToolUse event");
        }
    }

    #[test]
    fn test_function_call_with_signature_creates_tool_use_with_signature() {
        let signature =
            map_to_tool_use_signature(function_call_response(Some("test_signature_123".into())));
        assert_eq!(signature.as_deref(), Some("test_signature_123"));
    }

    #[test]
    fn test_function_call_without_signature_has_none() {
        let signature = map_to_tool_use_signature(function_call_response(None));
        assert!(signature.is_none());
    }

    #[test]
    fn test_empty_string_signature_normalized_to_none() {
        let signature = map_to_tool_use_signature(function_call_response(Some(String::new())));
        assert!(signature.is_none());
    }
}