// completion.rs

  1use anyhow::Result;
  2use futures::{Stream, StreamExt};
  3use language_model_core::{
  4    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
  5    LanguageModelToolChoice, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
  6    StopReason, TokenUsage,
  7};
  8use std::pin::Pin;
  9use std::sync::Arc;
 10use std::sync::atomic::{self, AtomicU64};
 11
 12use crate::{
 13    Content, FunctionCallingConfig, FunctionCallingMode, FunctionDeclaration,
 14    GenerateContentResponse, GenerationConfig, GenerativeContentBlob, GoogleModelMode,
 15    InlineDataPart, ModelName, Part, SystemInstruction, TextPart, ThinkingConfig, ToolConfig,
 16    UsageMetadata,
 17};
 18
/// Translates a provider-agnostic [`LanguageModelRequest`] into Google's
/// `GenerateContentRequest` wire format.
///
/// A leading `Role::System` message (if present) is lifted out into
/// `system_instruction`; every remaining message becomes an entry in
/// `contents`, and messages whose content maps to zero parts are dropped.
/// A thinking budget is only forwarded when the request allows thinking
/// *and* `mode` is [`GoogleModelMode::Thinking`] with a budget set.
pub fn into_google(
    mut request: LanguageModelRequest,
    model_id: String,
    mode: GoogleModelMode,
) -> crate::GenerateContentRequest {
    // Converts one message's content items into Google `Part`s. Items with no
    // Google representation (empty text, thinking without a usable signature,
    // redacted thinking) expand to nothing via `flat_map`.
    fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
        content
            .into_iter()
            .flat_map(|content| match content {
                MessageContent::Text(text) => {
                    if !text.is_empty() {
                        vec![Part::TextPart(TextPart { text })]
                    } else {
                        vec![]
                    }
                }
                // Only the signature round-trips to the API; the thought text
                // itself is intentionally discarded here.
                MessageContent::Thinking {
                    text: _,
                    signature: Some(signature),
                } => {
                    if !signature.is_empty() {
                        vec![Part::ThoughtPart(crate::ThoughtPart {
                            thought: true,
                            thought_signature: signature,
                        })]
                    } else {
                        vec![]
                    }
                }
                // Thinking without a signature cannot be round-tripped.
                MessageContent::Thinking { .. } => {
                    vec![]
                }
                MessageContent::RedactedThinking(_) => vec![],
                MessageContent::Image(image) => {
                    // NOTE(review): mime type is hard-coded; assumes images
                    // arrive PNG-encoded — confirm against the image source.
                    vec![Part::InlineDataPart(InlineDataPart {
                        inline_data: GenerativeContentBlob {
                            mime_type: "image/png".to_string(),
                            data: image.source.to_string(),
                        },
                    })]
                }
                MessageContent::ToolUse(tool_use) => {
                    // Normalize empty string signatures to None
                    let thought_signature = tool_use.thought_signature.filter(|s| !s.is_empty());

                    vec![Part::FunctionCallPart(crate::FunctionCallPart {
                        function_call: crate::FunctionCall {
                            name: tool_use.name.to_string(),
                            args: tool_use.input,
                        },
                        thought_signature,
                    })]
                }
                MessageContent::ToolResult(tool_result) => {
                    match tool_result.content {
                        language_model_core::LanguageModelToolResultContent::Text(text) => {
                            vec![Part::FunctionResponsePart(crate::FunctionResponsePart {
                                function_response: crate::FunctionResponse {
                                    name: tool_result.tool_name.to_string(),
                                    // The API expects a valid JSON object
                                    response: serde_json::json!({
                                        "output": text
                                    }),
                                },
                            })]
                        }
                        language_model_core::LanguageModelToolResultContent::Image(image) => {
                            // Image results become two parts: a placeholder
                            // function response (the API requires one) plus the
                            // image bytes as inline data.
                            vec![
                                Part::FunctionResponsePart(crate::FunctionResponsePart {
                                    function_response: crate::FunctionResponse {
                                        name: tool_result.tool_name.to_string(),
                                        // The API expects a valid JSON object
                                        response: serde_json::json!({
                                            "output": "Tool responded with an image"
                                        }),
                                    },
                                }),
                                Part::InlineDataPart(InlineDataPart {
                                    inline_data: GenerativeContentBlob {
                                        mime_type: "image/png".to_string(),
                                        data: image.source.to_string(),
                                    },
                                }),
                            ]
                        }
                    }
                }
            })
            .collect()
    }

    // A leading system message becomes the dedicated system instruction.
    let system_instructions = if request
        .messages
        .first()
        .is_some_and(|msg| matches!(msg.role, Role::System))
    {
        let message = request.messages.remove(0);
        Some(SystemInstruction {
            parts: map_content(message.content),
        })
    } else {
        None
    };

    crate::GenerateContentRequest {
        model: ModelName { model_id },
        system_instruction: system_instructions,
        contents: request
            .messages
            .into_iter()
            .filter_map(|message| {
                let parts = map_content(message.content);
                // Messages that produce no parts are dropped entirely.
                if parts.is_empty() {
                    None
                } else {
                    Some(Content {
                        parts,
                        role: match message.role {
                            Role::User => crate::Role::User,
                            Role::Assistant => crate::Role::Model,
                            Role::System => crate::Role::User, // Google AI doesn't have a system role
                        },
                    })
                }
            })
            .collect(),
        generation_config: Some(GenerationConfig {
            candidate_count: Some(1),
            stop_sequences: Some(request.stop),
            max_output_tokens: None,
            // Fall back to a temperature of 1.0 when the request doesn't set one.
            temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
            // Only forward a thinking budget when the request allows thinking
            // and the model mode actually carries one.
            thinking_config: match (request.thinking_allowed, mode) {
                (true, GoogleModelMode::Thinking { budget_tokens }) => {
                    budget_tokens.map(|thinking_budget| ThinkingConfig { thinking_budget })
                }
                _ => None,
            },
            top_p: None,
            top_k: None,
        }),
        safety_settings: None,
        // Google expects a single Tool wrapping all function declarations.
        tools: (!request.tools.is_empty()).then(|| {
            vec![crate::Tool {
                function_declarations: request
                    .tools
                    .into_iter()
                    .map(|tool| FunctionDeclaration {
                        name: tool.name,
                        description: tool.description,
                        parameters: tool.input_schema,
                    })
                    .collect(),
            }]
        }),
        tool_config: request.tool_choice.map(|choice| ToolConfig {
            function_calling_config: FunctionCallingConfig {
                mode: match choice {
                    LanguageModelToolChoice::Auto => FunctionCallingMode::Auto,
                    LanguageModelToolChoice::Any => FunctionCallingMode::Any,
                    LanguageModelToolChoice::None => FunctionCallingMode::None,
                },
                allowed_function_names: None,
            },
        }),
    }
}
185
/// Accumulates per-stream state while translating Google
/// `GenerateContentResponse` chunks into [`LanguageModelCompletionEvent`]s.
pub struct GoogleEventMapper {
    /// Running usage totals, refreshed as each chunk's `usage_metadata` arrives.
    usage: UsageMetadata,
    /// Most recently observed stop reason; emitted when the stream ends.
    stop_reason: StopReason,
}
190
191impl GoogleEventMapper {
192    pub fn new() -> Self {
193        Self {
194            usage: UsageMetadata::default(),
195            stop_reason: StopReason::EndTurn,
196        }
197    }
198
199    pub fn map_stream(
200        mut self,
201        events: Pin<Box<dyn Send + Stream<Item = Result<GenerateContentResponse>>>>,
202    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
203    {
204        events
205            .map(Some)
206            .chain(futures::stream::once(async { None }))
207            .flat_map(move |event| {
208                futures::stream::iter(match event {
209                    Some(Ok(event)) => self.map_event(event),
210                    Some(Err(error)) => {
211                        vec![Err(LanguageModelCompletionError::from(error))]
212                    }
213                    None => vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))],
214                })
215            })
216    }
217
218    pub fn map_event(
219        &mut self,
220        event: GenerateContentResponse,
221    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
222        static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
223
224        let mut events: Vec<_> = Vec::new();
225        let mut wants_to_use_tool = false;
226        if let Some(usage_metadata) = event.usage_metadata {
227            update_usage(&mut self.usage, &usage_metadata);
228            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
229                convert_usage(&self.usage),
230            )))
231        }
232
233        if let Some(prompt_feedback) = event.prompt_feedback
234            && let Some(block_reason) = prompt_feedback.block_reason.as_deref()
235        {
236            self.stop_reason = match block_reason {
237                "SAFETY" | "OTHER" | "BLOCKLIST" | "PROHIBITED_CONTENT" | "IMAGE_SAFETY" => {
238                    StopReason::Refusal
239                }
240                _ => {
241                    log::error!("Unexpected Google block_reason: {block_reason}");
242                    StopReason::Refusal
243                }
244            };
245            events.push(Ok(LanguageModelCompletionEvent::Stop(self.stop_reason)));
246
247            return events;
248        }
249
250        if let Some(candidates) = event.candidates {
251            for candidate in candidates {
252                if let Some(finish_reason) = candidate.finish_reason.as_deref() {
253                    self.stop_reason = match finish_reason {
254                        "STOP" => StopReason::EndTurn,
255                        "MAX_TOKENS" => StopReason::MaxTokens,
256                        _ => {
257                            log::error!("Unexpected google finish_reason: {finish_reason}");
258                            StopReason::EndTurn
259                        }
260                    };
261                }
262                candidate
263                    .content
264                    .parts
265                    .into_iter()
266                    .for_each(|part| match part {
267                        Part::TextPart(text_part) => {
268                            events.push(Ok(LanguageModelCompletionEvent::Text(text_part.text)))
269                        }
270                        Part::InlineDataPart(_) => {}
271                        Part::FunctionCallPart(function_call_part) => {
272                            wants_to_use_tool = true;
273                            let name: Arc<str> = function_call_part.function_call.name.into();
274                            let next_tool_id =
275                                TOOL_CALL_COUNTER.fetch_add(1, atomic::Ordering::SeqCst);
276                            let id: LanguageModelToolUseId =
277                                format!("{}-{}", name, next_tool_id).into();
278
279                            // Normalize empty string signatures to None
280                            let thought_signature = function_call_part
281                                .thought_signature
282                                .filter(|s| !s.is_empty());
283
284                            events.push(Ok(LanguageModelCompletionEvent::ToolUse(
285                                LanguageModelToolUse {
286                                    id,
287                                    name,
288                                    is_input_complete: true,
289                                    raw_input: function_call_part.function_call.args.to_string(),
290                                    input: function_call_part.function_call.args,
291                                    thought_signature,
292                                },
293                            )));
294                        }
295                        Part::FunctionResponsePart(_) => {}
296                        Part::ThoughtPart(part) => {
297                            events.push(Ok(LanguageModelCompletionEvent::Thinking {
298                                text: "(Encrypted thought)".to_string(), // TODO: Can we populate this from thought summaries?
299                                signature: Some(part.thought_signature),
300                            }));
301                        }
302                    });
303            }
304        }
305
306        // Even when Gemini wants to use a Tool, the API
307        // responds with `finish_reason: STOP`
308        if wants_to_use_tool {
309            self.stop_reason = StopReason::ToolUse;
310            events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
311        }
312        events
313    }
314}
315
316/// Count tokens for a Google AI model using tiktoken. This is synchronous;
317/// callers should spawn it on a background thread if needed.
318pub fn count_google_tokens(request: LanguageModelRequest) -> Result<u64> {
319    let messages = request
320        .messages
321        .into_iter()
322        .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
323            role: match message.role {
324                Role::User => "user".into(),
325                Role::Assistant => "assistant".into(),
326                Role::System => "system".into(),
327            },
328            content: Some(message.string_contents()),
329            name: None,
330            function_call: None,
331        })
332        .collect::<Vec<_>>();
333
334    // Tiktoken doesn't yet support these models, so we manually use the
335    // same tokenizer as GPT-4.
336    tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
337}
338
339fn update_usage(usage: &mut UsageMetadata, new: &UsageMetadata) {
340    if let Some(prompt_token_count) = new.prompt_token_count {
341        usage.prompt_token_count = Some(prompt_token_count);
342    }
343    if let Some(cached_content_token_count) = new.cached_content_token_count {
344        usage.cached_content_token_count = Some(cached_content_token_count);
345    }
346    if let Some(candidates_token_count) = new.candidates_token_count {
347        usage.candidates_token_count = Some(candidates_token_count);
348    }
349    if let Some(tool_use_prompt_token_count) = new.tool_use_prompt_token_count {
350        usage.tool_use_prompt_token_count = Some(tool_use_prompt_token_count);
351    }
352    if let Some(thoughts_token_count) = new.thoughts_token_count {
353        usage.thoughts_token_count = Some(thoughts_token_count);
354    }
355    if let Some(total_token_count) = new.total_token_count {
356        usage.total_token_count = Some(total_token_count);
357    }
358}
359
360fn convert_usage(usage: &UsageMetadata) -> TokenUsage {
361    let prompt_tokens = usage.prompt_token_count.unwrap_or(0);
362    let cached_tokens = usage.cached_content_token_count.unwrap_or(0);
363    let input_tokens = prompt_tokens - cached_tokens;
364    let output_tokens = usage.candidates_token_count.unwrap_or(0);
365
366    TokenUsage {
367        input_tokens,
368        output_tokens,
369        cache_read_input_tokens: cached_tokens,
370        cache_creation_input_tokens: 0,
371    }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        Content, FunctionCall, FunctionCallPart, GenerateContentCandidate, GenerateContentResponse,
        Part, Role as GoogleRole,
    };
    use serde_json::json;

    /// Builds a single-candidate response containing one `test_function` call
    /// carrying the given thought signature. Shared by every test below so the
    /// fixture is defined once instead of triplicated.
    fn function_call_response(thought_signature: Option<String>) -> GenerateContentResponse {
        GenerateContentResponse {
            candidates: Some(vec![GenerateContentCandidate {
                index: Some(0),
                content: Content {
                    parts: vec![Part::FunctionCallPart(FunctionCallPart {
                        function_call: FunctionCall {
                            name: "test_function".to_string(),
                            args: json!({"arg": "value"}),
                        },
                        thought_signature,
                    })],
                    role: GoogleRole::Model,
                },
                finish_reason: None,
                finish_message: None,
                safety_ratings: None,
                citation_metadata: None,
            }]),
            prompt_feedback: None,
            usage_metadata: None,
        }
    }

    /// Asserts the first event is a `ToolUse` for `test_function` and returns
    /// its thought signature (as an owned `String` for easy comparison).
    fn tool_use_signature(
        events: &[Result<LanguageModelCompletionEvent, LanguageModelCompletionError>],
    ) -> Option<String> {
        match &events[0] {
            Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
                assert_eq!(tool_use.name.as_ref(), "test_function");
                tool_use.thought_signature.as_deref().map(String::from)
            }
            _ => panic!("Expected ToolUse event"),
        }
    }

    #[test]
    fn test_function_call_with_signature_creates_tool_use_with_signature() {
        let mut mapper = GoogleEventMapper::new();
        let events =
            mapper.map_event(function_call_response(Some("test_signature_123".to_string())));

        // One ToolUse event plus the trailing Stop(ToolUse).
        assert_eq!(events.len(), 2);
        assert_eq!(
            tool_use_signature(&events).as_deref(),
            Some("test_signature_123")
        );
    }

    #[test]
    fn test_function_call_without_signature_has_none() {
        let mut mapper = GoogleEventMapper::new();
        let events = mapper.map_event(function_call_response(None));

        assert_eq!(events.len(), 2);
        assert!(tool_use_signature(&events).is_none());
    }

    #[test]
    fn test_empty_string_signature_normalized_to_none() {
        let mut mapper = GoogleEventMapper::new();
        let events = mapper.map_event(function_call_response(Some(String::new())));

        assert!(tool_use_signature(&events).is_none());
    }
}
492}