// completion.rs

  1use anyhow::Result;
  2use collections::HashMap;
  3use futures::{Stream, StreamExt};
  4use language_model_core::{
  5    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
  6    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
  7    Role, StopReason, TokenUsage,
  8    util::{fix_streamed_json, parse_tool_arguments},
  9};
 10use std::pin::Pin;
 11use std::str::FromStr;
 12
 13use crate::{
 14    AdaptiveThinkingDisplay, AnthropicError, AnthropicModelMode, CacheControl, CacheControlType,
 15    ContentDelta, Event, ImageSource, Message, RequestContent, ResponseContent, StringOrContents,
 16    Thinking, Tool, ToolChoice, ToolResultContent, ToolResultPart, Usage,
 17};
 18
 19fn to_anthropic_content(content: MessageContent) -> Option<RequestContent> {
 20    match content {
 21        MessageContent::Text(text) => {
 22            let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
 23                text.trim_end().to_string()
 24            } else {
 25                text
 26            };
 27            if !text.is_empty() {
 28                Some(RequestContent::Text {
 29                    text,
 30                    cache_control: None,
 31                })
 32            } else {
 33                None
 34            }
 35        }
 36        MessageContent::Thinking {
 37            text: thinking,
 38            signature,
 39        } => {
 40            if let Some(signature) = signature
 41                && !thinking.is_empty()
 42            {
 43                Some(RequestContent::Thinking {
 44                    thinking,
 45                    signature,
 46                    cache_control: None,
 47                })
 48            } else {
 49                None
 50            }
 51        }
 52        MessageContent::RedactedThinking(data) => {
 53            if !data.is_empty() {
 54                Some(RequestContent::RedactedThinking { data })
 55            } else {
 56                None
 57            }
 58        }
 59        MessageContent::Image(image) => Some(RequestContent::Image {
 60            source: ImageSource {
 61                source_type: "base64".to_string(),
 62                media_type: "image/png".to_string(),
 63                data: image.source.to_string(),
 64            },
 65            cache_control: None,
 66        }),
 67        MessageContent::ToolUse(tool_use) => Some(RequestContent::ToolUse {
 68            id: tool_use.id.to_string(),
 69            name: tool_use.name.to_string(),
 70            input: tool_use.input,
 71            cache_control: None,
 72        }),
 73        MessageContent::ToolResult(tool_result) => {
 74            let content = match tool_result.content.as_slice() {
 75                [LanguageModelToolResultContent::Text(text)] => {
 76                    ToolResultContent::Plain(text.to_string())
 77                }
 78                _ => {
 79                    let parts = tool_result
 80                        .content
 81                        .into_iter()
 82                        .map(|part| match part {
 83                            LanguageModelToolResultContent::Text(text) => ToolResultPart::Text {
 84                                text: text.to_string(),
 85                            },
 86                            LanguageModelToolResultContent::Image(image) => ToolResultPart::Image {
 87                                source: ImageSource {
 88                                    source_type: "base64".to_string(),
 89                                    media_type: "image/png".to_string(),
 90                                    data: image.source.to_string(),
 91                                },
 92                            },
 93                        })
 94                        .collect();
 95                    ToolResultContent::Multipart(parts)
 96                }
 97            };
 98            Some(RequestContent::ToolResult {
 99                tool_use_id: tool_result.tool_use_id.to_string(),
100                is_error: tool_result.is_error,
101                content,
102                cache_control: None,
103            })
104        }
105    }
106}
107
/// Translates a model-agnostic `LanguageModelRequest` into the
/// Anthropic-specific wire format (`crate::Request`).
///
/// Transformations applied:
/// - System messages are concatenated (blank-line separated) into the
///   request-level `system` field instead of appearing in `messages`.
/// - Consecutive messages with the same role are merged into one, since the
///   Anthropic API expects alternating user/assistant turns.
/// - Empty messages and content segments are dropped (via
///   `to_anthropic_content`).
/// - When a message is flagged `cache`, its last cacheable segment receives
///   an ephemeral cache-control breakpoint.
/// - Thinking configuration is derived from `mode` only when the request
///   allows thinking; adaptive mode additionally maps `thinking_effort`
///   into `output_config`.
pub fn into_anthropic(
    request: LanguageModelRequest,
    model: String,
    default_temperature: f32,
    max_output_tokens: u64,
    mode: AnthropicModelMode,
) -> crate::Request {
    let mut new_messages: Vec<Message> = Vec::new();
    let mut system_message = String::new();

    for message in request.messages {
        if message.contents_empty() {
            continue;
        }

        match message.role {
            Role::User | Role::Assistant => {
                // Convert each content item; items the API would reject are
                // filtered out here.
                let mut anthropic_message_content: Vec<RequestContent> = message
                    .content
                    .into_iter()
                    .filter_map(to_anthropic_content)
                    .collect();
                let anthropic_role = match message.role {
                    Role::User => crate::Role::User,
                    Role::Assistant => crate::Role::Assistant,
                    Role::System => unreachable!("System role should never occur here"),
                };
                if anthropic_message_content.is_empty() {
                    continue;
                }

                // Merge into the previous message when the role repeats, so
                // the output strictly alternates user/assistant.
                if let Some(last_message) = new_messages.last_mut()
                    && last_message.role == anthropic_role
                {
                    last_message.content.extend(anthropic_message_content);
                    continue;
                }

                // Mark the last segment of the message as cached
                if message.cache {
                    let cache_control_value = Some(CacheControl {
                        cache_type: CacheControlType::Ephemeral,
                    });
                    // Walk segments back-to-front so the breakpoint lands on
                    // the last segment that supports cache control.
                    for message_content in anthropic_message_content.iter_mut().rev() {
                        match message_content {
                            RequestContent::RedactedThinking { .. } => {
                                // Redacted thinking can't carry cache control;
                                // fall back to the previous segment.
                            }
                            RequestContent::Text { cache_control, .. }
                            | RequestContent::Thinking { cache_control, .. }
                            | RequestContent::Image { cache_control, .. }
                            | RequestContent::ToolUse { cache_control, .. }
                            | RequestContent::ToolResult { cache_control, .. } => {
                                *cache_control = cache_control_value;
                                break;
                            }
                        }
                    }
                }

                new_messages.push(Message {
                    role: anthropic_role,
                    content: anthropic_message_content,
                });
            }
            Role::System => {
                // All system messages collapse into one request-level string.
                if !system_message.is_empty() {
                    system_message.push_str("\n\n");
                }
                system_message.push_str(&message.string_contents());
            }
        }
    }

    crate::Request {
        model,
        messages: new_messages,
        max_tokens: max_output_tokens,
        system: if system_message.is_empty() {
            None
        } else {
            Some(StringOrContents::String(system_message))
        },
        thinking: if request.thinking_allowed {
            match mode {
                AnthropicModelMode::Thinking { budget_tokens } => {
                    Some(Thinking::Enabled { budget_tokens })
                }
                AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive {
                    display: Some(AdaptiveThinkingDisplay::Summarized),
                }),
                AnthropicModelMode::Default => None,
            }
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| Tool {
                name: tool.name,
                description: tool.description,
                input_schema: tool.input_schema,
                eager_input_streaming: tool.use_input_streaming,
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => ToolChoice::Auto,
            LanguageModelToolChoice::Any => ToolChoice::Any,
            LanguageModelToolChoice::None => ToolChoice::None,
        }),
        metadata: None,
        // Effort is only meaningful for adaptive thinking; unrecognized
        // effort strings are silently ignored.
        output_config: if request.thinking_allowed
            && matches!(mode, AnthropicModelMode::AdaptiveThinking)
        {
            request.thinking_effort.as_deref().and_then(|effort| {
                let effort = match effort {
                    "low" => Some(crate::Effort::Low),
                    "medium" => Some(crate::Effort::Medium),
                    "high" => Some(crate::Effort::High),
                    "max" => Some(crate::Effort::Max),
                    _ => None,
                };
                effort.map(|effort| crate::OutputConfig {
                    effort: Some(effort),
                })
            })
        } else {
            None
        },
        stop_sequences: Vec::new(),
        speed: request.speed.map(Into::into),
        // Explicit request temperature wins; otherwise use the model default.
        temperature: request.temperature.or(Some(default_temperature)),
        top_k: None,
        top_p: None,
    }
}
245
/// Stateful translator from raw Anthropic streaming `Event`s into
/// `LanguageModelCompletionEvent`s.
pub struct AnthropicEventMapper {
    // Tool-use blocks currently being streamed, keyed by content-block
    // index; each accumulates partial JSON input until its block stops.
    tool_uses_by_index: HashMap<usize, RawToolUse>,
    // Running token usage, merged from message_start/message_delta events.
    usage: Usage,
    // Most recent stop reason; emitted when the message stops.
    stop_reason: StopReason,
}
251
impl AnthropicEventMapper {
    /// Creates a mapper with no in-flight tool uses, default (empty) usage,
    /// and `EndTurn` as the initial stop reason.
    pub fn new() -> Self {
        Self {
            tool_uses_by_index: HashMap::default(),
            usage: Usage::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

    /// Consumes the mapper and adapts a stream of raw Anthropic events into
    /// a stream of completion events, forwarding stream errors inline.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(error.into())],
            })
        })
    }

    /// Maps a single Anthropic event to zero or more completion events,
    /// updating internal streaming state (tool-use accumulation, usage
    /// totals, stop reason) as a side effect.
    pub fn map_event(
        &mut self,
        event: Event,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        match event {
            Event::ContentBlockStart {
                index,
                content_block,
            } => match content_block {
                ResponseContent::Text { text } => {
                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
                }
                ResponseContent::Thinking { thinking } => {
                    vec![Ok(LanguageModelCompletionEvent::Thinking {
                        text: thinking,
                        signature: None,
                    })]
                }
                ResponseContent::RedactedThinking { data } => {
                    vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
                }
                ResponseContent::ToolUse { id, name, .. } => {
                    // Start accumulating this tool use's streamed JSON input;
                    // events are emitted only once input deltas arrive.
                    self.tool_uses_by_index.insert(
                        index,
                        RawToolUse {
                            id,
                            name,
                            input_json: String::new(),
                        },
                    );
                    Vec::new()
                }
            },
            Event::ContentBlockDelta { index, delta } => match delta {
                ContentDelta::TextDelta { text } => {
                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
                }
                ContentDelta::ThinkingDelta { thinking } => {
                    vec![Ok(LanguageModelCompletionEvent::Thinking {
                        text: thinking,
                        signature: None,
                    })]
                }
                ContentDelta::SignatureDelta { signature } => {
                    // Signatures arrive separately from thinking text; emit
                    // an empty-text event carrying just the signature.
                    vec![Ok(LanguageModelCompletionEvent::Thinking {
                        text: "".to_string(),
                        signature: Some(signature),
                    })]
                }
                ContentDelta::InputJsonDelta { partial_json } => {
                    if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
                        tool_use.input_json.push_str(&partial_json);

                        // Try to convert invalid (incomplete) JSON into
                        // valid JSON that serde can accept, e.g. by closing
                        // unclosed delimiters. This way, we can update the
                        // UI with whatever has been streamed back so far.
                        if let Ok(input) =
                            serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json))
                        {
                            return vec![Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id: tool_use.id.clone().into(),
                                    name: tool_use.name.clone().into(),
                                    is_input_complete: false,
                                    raw_input: tool_use.input_json.clone(),
                                    input,
                                    thought_signature: None,
                                },
                            ))];
                        }
                    }
                    vec![]
                }
            },
            Event::ContentBlockStop { index } => {
                // A stopped block only produces an event if it was a tool use
                // we were accumulating; other block types emitted their
                // content as it streamed.
                if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
                    let input_json = tool_use.input_json.trim();
                    let event_result = match parse_tool_arguments(input_json) {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_use.id.into(),
                                name: tool_use.name.into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_use.input_json.clone(),
                                thought_signature: None,
                            },
                        )),
                        Err(json_parse_err) => {
                            // Surface malformed tool input as a dedicated
                            // event (Ok, not Err) so callers can report it.
                            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                id: tool_use.id.into(),
                                tool_name: tool_use.name.into(),
                                raw_input: input_json.into(),
                                json_parse_error: json_parse_err.to_string(),
                            })
                        }
                    };

                    vec![event_result]
                } else {
                    Vec::new()
                }
            }
            Event::MessageStart { message } => {
                update_usage(&mut self.usage, &message.usage);
                vec![
                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
                        &self.usage,
                    ))),
                    Ok(LanguageModelCompletionEvent::StartMessage {
                        message_id: message.id,
                    }),
                ]
            }
            Event::MessageDelta { delta, usage } => {
                update_usage(&mut self.usage, &usage);
                // Record the stop reason now; it's emitted on MessageStop.
                if let Some(stop_reason) = delta.stop_reason.as_deref() {
                    self.stop_reason = match stop_reason {
                        "end_turn" => StopReason::EndTurn,
                        "max_tokens" => StopReason::MaxTokens,
                        "tool_use" => StopReason::ToolUse,
                        "refusal" => StopReason::Refusal,
                        _ => {
                            log::error!("Unexpected anthropic stop_reason: {stop_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
                    convert_usage(&self.usage),
                ))]
            }
            Event::MessageStop => {
                vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
            }
            Event::Error { error } => {
                vec![Err(error.into())]
            }
            // Other events (e.g. pings) carry no information we map.
            _ => Vec::new(),
        }
    }
}
417
/// A tool-use content block that is still being streamed: its JSON input
/// accumulates across deltas until the content block stops.
struct RawToolUse {
    // Tool-use id assigned by the API.
    id: String,
    // Name of the tool being invoked.
    name: String,
    // Concatenation of all `partial_json` fragments received so far.
    input_json: String,
}
423
424/// Updates usage data by preferring counts from `new`.
425fn update_usage(usage: &mut Usage, new: &Usage) {
426    if let Some(input_tokens) = new.input_tokens {
427        usage.input_tokens = Some(input_tokens);
428    }
429    if let Some(output_tokens) = new.output_tokens {
430        usage.output_tokens = Some(output_tokens);
431    }
432    if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
433        usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
434    }
435    if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
436        usage.cache_read_input_tokens = Some(cache_read_input_tokens);
437    }
438}
439
440fn convert_usage(usage: &Usage) -> TokenUsage {
441    TokenUsage {
442        input_tokens: usage.input_tokens.unwrap_or(0),
443        output_tokens: usage.output_tokens.unwrap_or(0),
444        cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
445        cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
446    }
447}
448
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AnthropicModelMode;
    use language_model_core::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    /// When a message is marked `cache`, only its final content segment may
    /// carry the ephemeral cache-control breakpoint; every earlier segment
    /// must remain uncached.
    #[test]
    fn test_cache_control_only_on_last_segment() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    MessageContent::Text("Some prompt".to_string()),
                    MessageContent::Image(LanguageModelImage::empty()),
                    MessageContent::Image(LanguageModelImage::empty()),
                    MessageContent::Image(LanguageModelImage::empty()),
                    MessageContent::Image(LanguageModelImage::empty()),
                ],
                cache: true,
                reasoning_details: None,
            }],
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thinking_allowed: true,
            thinking_effort: None,
            speed: None,
        };

        let anthropic_request = into_anthropic(
            request,
            "claude-3-5-sonnet".to_string(),
            0.7,
            4096,
            AnthropicModelMode::Default,
        );

        assert_eq!(anthropic_request.messages.len(), 1);

        let message = &anthropic_request.messages[0];
        assert_eq!(message.content.len(), 5);

        assert!(matches!(
            message.content[0],
            RequestContent::Text {
                cache_control: None,
                ..
            }
        ));
        // Check ALL intermediate image segments (indices 1 through 3).
        // The previous range `1..3` stopped at index 2 and silently skipped
        // verifying index 3.
        for i in 1..4 {
            assert!(matches!(
                message.content[i],
                RequestContent::Image {
                    cache_control: None,
                    ..
                }
            ));
        }

        assert!(matches!(
            message.content[4],
            RequestContent::Image {
                cache_control: Some(CacheControl {
                    cache_type: CacheControlType::Ephemeral,
                }),
                ..
            }
        ));
    }

    /// Builds a two-turn request (fixed user message + the given assistant
    /// content) and converts it with a thinking-enabled model mode.
    fn request_with_assistant_content(assistant_content: Vec<MessageContent>) -> crate::Request {
        let mut request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text("Hello".to_string())],
                cache: false,
                reasoning_details: None,
            }],
            thinking_effort: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thinking_allowed: true,
            speed: None,
        };
        request.messages.push(LanguageModelRequestMessage {
            role: Role::Assistant,
            content: assistant_content,
            cache: false,
            reasoning_details: None,
        });
        into_anthropic(
            request,
            "claude-sonnet-4-5".to_string(),
            1.0,
            16000,
            AnthropicModelMode::Thinking {
                budget_tokens: Some(10000),
            },
        )
    }

    /// Thinking blocks without a signature (e.g. cancelled mid-think) must
    /// be stripped from the request, leaving any other content intact.
    #[test]
    fn test_unsigned_thinking_blocks_stripped() {
        let result = request_with_assistant_content(vec![
            MessageContent::Thinking {
                text: "Cancelled mid-think, no signature".to_string(),
                signature: None,
            },
            MessageContent::Text("Some response text".to_string()),
        ]);

        let assistant_message = result
            .messages
            .iter()
            .find(|m| m.role == crate::Role::Assistant)
            .expect("assistant message should still exist");

        assert_eq!(
            assistant_message.content.len(),
            1,
            "Only the text content should remain; unsigned thinking block should be stripped"
        );
        assert!(matches!(
            &assistant_message.content[0],
            RequestContent::Text { text, .. } if text == "Some response text"
        ));
    }

    /// Signed thinking blocks are valid request content and must survive
    /// conversion alongside the text.
    #[test]
    fn test_signed_thinking_blocks_preserved() {
        let result = request_with_assistant_content(vec![
            MessageContent::Thinking {
                text: "Completed thinking".to_string(),
                signature: Some("valid-signature".to_string()),
            },
            MessageContent::Text("Response".to_string()),
        ]);

        let assistant_message = result
            .messages
            .iter()
            .find(|m| m.role == crate::Role::Assistant)
            .expect("assistant message should exist");

        assert_eq!(
            assistant_message.content.len(),
            2,
            "Both the signed thinking block and text should be preserved"
        );
        assert!(matches!(
            &assistant_message.content[0],
            RequestContent::Thinking { thinking, signature, .. }
                if thinking == "Completed thinking" && signature == "valid-signature"
        ));
    }

    /// If stripping unsigned thinking leaves a message with no content, the
    /// whole message must be omitted rather than sent empty.
    #[test]
    fn test_only_unsigned_thinking_block_omits_entire_message() {
        let result = request_with_assistant_content(vec![MessageContent::Thinking {
            text: "Cancelled before any text or signature".to_string(),
            signature: None,
        }]);

        let assistant_messages: Vec<_> = result
            .messages
            .iter()
            .filter(|m| m.role == crate::Role::Assistant)
            .collect();

        assert_eq!(
            assistant_messages.len(),
            0,
            "An assistant message whose only content was an unsigned thinking block \
             should be omitted entirely"
        );
    }
}