// completion.rs

  1use anyhow::Result;
  2use futures::{Stream, StreamExt};
  3use language_model_core::{
  4    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
  5    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
  6    StopReason, TokenUsage,
  7};
  8use std::pin::Pin;
  9use std::sync::Arc;
 10use std::sync::atomic::{self, AtomicU64};
 11
 12use crate::{
 13    Content, FunctionCallingConfig, FunctionCallingMode, FunctionDeclaration,
 14    GenerateContentResponse, GenerationConfig, GenerativeContentBlob, GoogleModelMode,
 15    InlineDataPart, ModelName, Part, SystemInstruction, TextPart, ThinkingConfig, ToolConfig,
 16    UsageMetadata,
 17};
 18
/// Converts a provider-agnostic [`LanguageModelRequest`] into the Google
/// Generative AI [`crate::GenerateContentRequest`] wire format.
///
/// A leading `Role::System` message, if present, is lifted out of the
/// message list into `system_instruction`; any other system messages are
/// sent with the `User` role because the Google API has no system role
/// inside `contents`. Messages whose content maps to zero parts are dropped.
pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> crate::GenerateContentRequest {
    // Maps generic message content items to Google `Part`s. Items the API
    // cannot represent (empty text, thinking blocks without a signature,
    // redacted thinking) are dropped.
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                MessageContent::Text(text) => {
                    if !text.is_empty() {
                        vec![Part::TextPart(TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    // Only the signature is round-tripped; the thought text
                    // itself is not sent back to the API.
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(crate::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                MessageContent::Thinking { .. } => {
                    vec![]
                }
                MessageContent::RedactedThinking(_) => vec![],
                MessageContent::Image(image) => {
                    // NOTE(review): assumes images are PNG and that
                    // `image.source` stringifies to the inline-data payload —
                    // confirm against the image type's contract.
                    vec![Part::InlineDataPart(InlineDataPart {
                        inline_data: GenerativeContentBlob {
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                MessageContent::ToolUse(tool_use) => {
                    // Normalize empty string signatures to None
                    let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());

                    vec![Part::FunctionCallPart(crate::FunctionCallPart {
                        function_call: crate::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                        thought_signature,
                    })]
                }
                MessageContent::ToolResult(tool_result) => {
                    match tool_result.content {
                        language_model_core::LanguageModelToolResultContent::Text(text) => {
                            vec![Part::FunctionResponsePart(crate::FunctionResponsePart {
                                function_response: crate::FunctionResponse {
                                    name: tool_result.tool_name.to_string(),
                                    // The API expects a valid JSON object
                                    response: serde_json::json!({
                                        "output": text
                                    }),
                                },
                            })]
                        }
                        language_model_core::LanguageModelToolResultContent::Image(image) => {
                            // Image results become two parts: a textual
                            // function response (required by the API) plus
                            // the image itself as inline data.
                            vec![
                                Part::FunctionResponsePart(crate::FunctionResponsePart {
                                    function_response: crate::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": "Tool responded with an image"
                                        }),
                                    },
                                }),
                                Part::InlineDataPart(InlineDataPart {
                                    inline_data: GenerativeContentBlob {
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                }),
                            ]
                        }
                    }
                }
            })
            .collect()
    }

    // Pull a leading system message (if any) out into `system_instruction`.
    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    crate::GenerateContentRequest {
        model: ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            .filter_map(|message| {
                let parts = map_content(message.content);
                if parts.is_empty() {
                    None
                } else {
                    Some(Content {
                        parts,
                        role: match message.role {
                            Role::User => crate::Role::User,
                            Role::Assistant => crate::Role::Model,
                            Role::System => crate::Role::User, // Google AI doesn't have a system role
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            // Defaults to 1.0 when the request does not specify a temperature.
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
            // A thinking config is only sent when thinking is allowed, the
            // model runs in thinking mode, and an explicit budget was given.
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        tools: (!request.tools.is_empty()).then(|| {
            vec![crate::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| ToolConfig {
            function_calling_config: FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}
185
/// Accumulates streaming state while translating Google
/// `GenerateContentResponse` chunks into [`LanguageModelCompletionEvent`]s.
pub struct GoogleEventMapper {
    // Running token-usage totals, merged from each chunk's `usage_metadata`.
    usage: UsageMetadata,
    // Most recently observed stop reason; emitted when the stream ends.
    stop_reason: StopReason,
}
190
impl GoogleEventMapper {
    /// Creates a mapper with empty usage totals and a default `EndTurn`
    /// stop reason.
    pub fn new() -> Self {
        Self {
            usage: UsageMetadata::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

    /// Adapts a stream of raw Google responses into a stream of completion
    /// events, appending a final `Stop` event (carrying the last stop reason
    /// observed) once the underlying stream is exhausted.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events
            .map(Some)
            // A trailing `None` marks end-of-stream so the final Stop event
            // can be emitted from within the same combinator chain.
            .chain(futures::stream::once(async { None }))
            .flat_map(move |event| {
                futures::stream::iter(match event {
                    Some(Ok(event)) => self.map_event(event),
                    Some(Err(error)) => {
                        vec![Err(LanguageModelCompletionError::from(error))]
                    }
                    None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
                })
            })
    }

    /// Translates a single `GenerateContentResponse` into zero or more
    /// completion events, updating the accumulated usage and stop reason as
    /// a side effect. A blocked prompt short-circuits with a `Refusal` stop.
    pub fn map_event(
        &mut self,
        event: GenerateContentResponse,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        // Process-wide counter used to mint unique tool-use ids, since the
        // Google API does not supply ids for function calls.
        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

        let mut events: Vec<_> = Vec::new();
        let mut wants_to_use_tool = false;
        if let Some(usage_metadata) = event.usage_metadata {
            update_usage(&mut self.usage, &usage_metadata);
            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
                convert_usage(&self.usage),
            )))
        }

        // A blocked prompt carries no candidates; map every block reason to
        // a refusal and stop immediately.
        if let Some(prompt_feedback) = event.prompt_feedback
            && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
        {
            self.stop_reason = match block_reason {
                "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
                    StopReason::Refusal
                }
                _ => {
                    log::error!("Unexpected Google block_reason: {block_reason}");
                    StopReason::Refusal
                }
            };
            events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));

            return events;
        }

        if let Some(candidates) = event.candidates {
            for candidate in candidates {
                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
                    self.stop_reason = match finish_reason {
                        "STOP" => StopReason::EndTurn,
                        "MAX_TOKENS" => StopReason::MaxTokens,
                        _ => {
                            log::error!("Unexpected google finish_reason: {finish_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                candidate
                    .content
                    .parts
                    .into_iter()
                    .for_each(|part| match part {
                        Part::TextPart(text_part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
                        }
                        // Inline data (e.g. images) in responses is ignored.
                        Part::InlineDataPart(_) => {}
                        Part::FunctionCallPart(function_call_part) => {
                            wants_to_use_tool = true;
                            let name: Arc<str> = function_call_part.function_call.name.into();
                            let next_tool_id =
                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
                            let id: LanguageModelToolUseId =
                                format!("{}-{}", name, next_tool_id).into();

                            // Normalize empty string signatures to None
                            let thought_signature = function_call_part
                                .thought_signature
                                .filter(|s| !s.is_empty());

                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id,
                                    name,
                                    is_input_complete: true,
                                    raw_input: function_call_part.function_call.args.to_string(),
                                    input: function_call_part.function_call.args,
                                    thought_signature,
                                },
                            )));
                        }
                        Part::FunctionResponsePart(_) => {}
                        Part::ThoughtPart(part) => {
                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
                                text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
                                signature: Some(part.thought_signature),
                            }));
                        }
                    });
            }
        }

        // Even when Gemini wants to use a Tool, the API
        // responds with `finish_reason: STOP`
        if wants_to_use_tool {
            self.stop_reason = StopReason::ToolUse;
            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
        }
        events
    }
}
315
316fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
317    if let Some(prompt_token_count) = new.prompt_token_count {
318        usage.prompt_token_count = Some(prompt_token_count);
319    }
320    if let Some(cached_content_token_count) = new.cached_content_token_count {
321        usage.cached_content_token_count = Some(cached_content_token_count);
322    }
323    if let Some(candidates_token_count) = new.candidates_token_count {
324        usage.candidates_token_count = Some(candidates_token_count);
325    }
326    if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
327        usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
328    }
329    if let Some(thoughts_token_count) = new.thoughts_token_count {
330        usage.thoughts_token_count = Some(thoughts_token_count);
331    }
332    if let Some(total_token_count) = new.total_token_count {
333        usage.total_token_count = Some(total_token_count);
334    }
335}
336
337fn convert_usage(usage: &UsageMetadata) -> TokenUsage {
338    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
339    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
340    let input_tokens = prompt_tokens - cached_tokens;
341    let output_tokens = usage.candidates_token_count.unwrap_or(0);
342
343    TokenUsage {
344        input_tokens,
345        output_tokens,
346        cache_read_input_tokens: cached_tokens,
347        cache_creation_input_tokens: 0,
348    }
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
        Part, Role as GoogleRole,
    };
    use serde_json::json;

    /// Builds a single-candidate response containing one function call with
    /// the given thought signature. Shared fixture for the signature tests,
    /// which previously duplicated this ~30-line construction three times.
    fn function_call_response(thought_signature: Option<String>) -> GenerateContentResponse {
        GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
                        function_call: FunctionCall {
                            name: "test_function".to_string(),
                            args: json!({"arg": "value"}),
                        },
                        thought_signature,
                    })],
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        }
    }

    #[test]
    fn test_function_call_with_signature_creates_tool_use_with_signature() {
        let mut mapper = GoogleEventMapper::new();

        let events =
            mapper.map_event(function_call_response(Some("test_signature_123".to_string())));
        // Expect exactly ToolUse followed by Stop(ToolUse).
        assert_eq!(events.len(), 2);

        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
            assert_eq!(tool_use.name.as_ref(), "test_function");
            assert_eq!(
                tool_use.thought_signature.as_deref(),
                Some("test_signature_123")
            );
        } else {
            panic!("Expected ToolUse event");
        }
    }

    #[test]
    fn test_function_call_without_signature_has_none() {
        let mut mapper = GoogleEventMapper::new();

        let events = mapper.map_event(function_call_response(None));
        assert_eq!(events.len(), 2);

        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
            assert!(tool_use.thought_signature.is_none());
        } else {
            panic!("Expected ToolUse event");
        }
    }

    #[test]
    fn test_empty_string_signature_normalized_to_none() {
        let mut mapper = GoogleEventMapper::new();

        let events = mapper.map_event(function_call_response(Some("".to_string())));
        if let Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) = &events[0] {
            assert!(tool_use.thought_signature.is_none());
        } else {
            panic!("Expected ToolUse event");
        }
    }
}