//! completion.rs — converts `language_model_core` requests into Anthropic API
//! requests and maps streamed Anthropic events back into completion events.

  1use anyhow::Result;
  2use collections::HashMap;
  3use futures::{Stream, StreamExt};
  4use language_model_core::{
  5    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
  6    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
  7    Role, StopReason, TokenUsage,
  8    util::{fix_streamed_json, parse_tool_arguments},
  9};
 10use std::pin::Pin;
 11use std::str::FromStr;
 12
 13use crate::{
 14    AnthropicError, AnthropicModelMode, CacheControl, CacheControlType, ContentDelta, Event,
 15    ImageSource, Message, RequestContent, ResponseContent, StringOrContents, Thinking, Tool,
 16    ToolChoice, ToolResultContent, ToolResultPart, Usage,
 17};
 18
 19fn to_anthropic_content(content: MessageContent) -> Option<RequestContent> {
 20    match content {
 21        MessageContent::Text(text) => {
 22            let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
 23                text.trim_end().to_string()
 24            } else {
 25                text
 26            };
 27            if !text.is_empty() {
 28                Some(RequestContent::Text {
 29                    text,
 30                    cache_control: None,
 31                })
 32            } else {
 33                None
 34            }
 35        }
 36        MessageContent::Thinking {
 37            text: thinking,
 38            signature,
 39        } => {
 40            if let Some(signature) = signature
 41                && !thinking.is_empty()
 42            {
 43                Some(RequestContent::Thinking {
 44                    thinking,
 45                    signature,
 46                    cache_control: None,
 47                })
 48            } else {
 49                None
 50            }
 51        }
 52        MessageContent::RedactedThinking(data) => {
 53            if !data.is_empty() {
 54                Some(RequestContent::RedactedThinking { data })
 55            } else {
 56                None
 57            }
 58        }
 59        MessageContent::Image(image) => Some(RequestContent::Image {
 60            source: ImageSource {
 61                source_type: "base64".to_string(),
 62                media_type: "image/png".to_string(),
 63                data: image.source.to_string(),
 64            },
 65            cache_control: None,
 66        }),
 67        MessageContent::ToolUse(tool_use) => Some(RequestContent::ToolUse {
 68            id: tool_use.id.to_string(),
 69            name: tool_use.name.to_string(),
 70            input: tool_use.input,
 71            cache_control: None,
 72        }),
 73        MessageContent::ToolResult(tool_result) => Some(RequestContent::ToolResult {
 74            tool_use_id: tool_result.tool_use_id.to_string(),
 75            is_error: tool_result.is_error,
 76            content: match tool_result.content {
 77                LanguageModelToolResultContent::Text(text) => {
 78                    ToolResultContent::Plain(text.to_string())
 79                }
 80                LanguageModelToolResultContent::Image(image) => {
 81                    ToolResultContent::Multipart(vec![ToolResultPart::Image {
 82                        source: ImageSource {
 83                            source_type: "base64".to_string(),
 84                            media_type: "image/png".to_string(),
 85                            data: image.source.to_string(),
 86                        },
 87                    }])
 88                }
 89            },
 90            cache_control: None,
 91        }),
 92    }
 93}
 94
/// Converts a model-agnostic [`LanguageModelRequest`] into an Anthropic
/// [`crate::Request`].
///
/// User/assistant messages are converted segment-by-segment via
/// [`to_anthropic_content`] (dropping segments the API rejects), consecutive
/// messages with the same role are merged into one, and all system messages
/// are concatenated into a single system prompt separated by blank lines.
///
/// `default_temperature` is used only when the request does not specify a
/// temperature; `mode` controls the `thinking` and `output_config` fields.
pub fn into_anthropic(
    request: LanguageModelRequest,
    model: String,
    default_temperature: f32,
    max_output_tokens: u64,
    mode: AnthropicModelMode,
) -> crate::Request {
    let mut new_messages: Vec<Message> = Vec::new();
    let mut system_message = String::new();

    for message in request.messages {
        // Skip messages with no content at all.
        if message.contents_empty() {
            continue;
        }

        match message.role {
            Role::User | Role::Assistant => {
                // Convert each segment; segments that map to `None`
                // (empty text, unsigned thinking, ...) are dropped.
                let mut anthropic_message_content: Vec<RequestContent> = message
                    .content
                    .into_iter()
                    .filter_map(to_anthropic_content)
                    .collect();
                let anthropic_role = match message.role {
                    Role::User => crate::Role::User,
                    Role::Assistant => crate::Role::Assistant,
                    Role::System => unreachable!("System role should never occur here"),
                };
                // If every segment was filtered out, omit the whole message.
                if anthropic_message_content.is_empty() {
                    continue;
                }

                // Merge into the previous message when roles match.
                // NOTE(review): when merging, `message.cache` is ignored for
                // this message's segments — confirm this is intended.
                if let Some(last_message) = new_messages.last_mut()
                    && last_message.role == anthropic_role
                {
                    last_message.content.extend(anthropic_message_content);
                    continue;
                }

                // Mark the last segment of the message as cached
                if message.cache {
                    let cache_control_value = Some(CacheControl {
                        cache_type: CacheControlType::Ephemeral,
                    });
                    // Walk backwards so only the final cacheable segment is
                    // tagged; redacted-thinking blocks cannot carry
                    // `cache_control`, so fall back to the segment before.
                    for message_content in anthropic_message_content.iter_mut().rev() {
                        match message_content {
                            RequestContent::RedactedThinking { .. } => {
                                // Caching is not possible, fallback to next message
                            }
                            RequestContent::Text { cache_control, .. }
                            | RequestContent::Thinking { cache_control, .. }
                            | RequestContent::Image { cache_control, .. }
                            | RequestContent::ToolUse { cache_control, .. }
                            | RequestContent::ToolResult { cache_control, .. } => {
                                *cache_control = cache_control_value;
                                break;
                            }
                        }
                    }
                }

                new_messages.push(Message {
                    role: anthropic_role,
                    content: anthropic_message_content,
                });
            }
            Role::System => {
                // All system messages are folded into one prompt string.
                if !system_message.is_empty() {
                    system_message.push_str("\n\n");
                }
                system_message.push_str(&message.string_contents());
            }
        }
    }

    crate::Request {
        model,
        messages: new_messages,
        max_tokens: max_output_tokens,
        system: if system_message.is_empty() {
            None
        } else {
            Some(StringOrContents::String(system_message))
        },
        // Thinking is only enabled when the request allows it AND the model
        // mode supports it.
        thinking: if request.thinking_allowed {
            match mode {
                AnthropicModelMode::Thinking { budget_tokens } => {
                    Some(Thinking::Enabled { budget_tokens })
                }
                AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive),
                AnthropicModelMode::Default => None,
            }
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| Tool {
                name: tool.name,
                description: tool.description,
                input_schema: tool.input_schema,
                eager_input_streaming: tool.use_input_streaming,
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => ToolChoice::Auto,
            LanguageModelToolChoice::Any => ToolChoice::Any,
            LanguageModelToolChoice::None => ToolChoice::None,
        }),
        metadata: None,
        // Effort hints only apply to adaptive-thinking mode; unknown effort
        // strings are silently ignored.
        output_config: if request.thinking_allowed
            && matches!(mode, AnthropicModelMode::AdaptiveThinking)
        {
            request.thinking_effort.as_deref().and_then(|effort| {
                let effort = match effort {
                    "low" => Some(crate::Effort::Low),
                    "medium" => Some(crate::Effort::Medium),
                    "high" => Some(crate::Effort::High),
                    "max" => Some(crate::Effort::Max),
                    _ => None,
                };
                effort.map(|effort| crate::OutputConfig {
                    effort: Some(effort),
                })
            })
        } else {
            None
        },
        stop_sequences: Vec::new(),
        speed: request.speed.map(Into::into),
        temperature: request.temperature.or(Some(default_temperature)),
        top_k: None,
        top_p: None,
    }
}
230
/// Stateful translator from raw Anthropic stream [`Event`]s into
/// [`LanguageModelCompletionEvent`]s.
pub struct AnthropicEventMapper {
    // In-flight tool-use blocks keyed by content-block index, accumulating
    // streamed JSON input until the corresponding `ContentBlockStop`.
    tool_uses_by_index: HashMap<usize, RawToolUse>,
    // Running token usage, updated from `MessageStart`/`MessageDelta`.
    usage: Usage,
    // Most recent stop reason; emitted when `MessageStop` arrives.
    stop_reason: StopReason,
}
236
impl AnthropicEventMapper {
    /// Creates a mapper with no pending tool uses, zeroed usage, and a
    /// default stop reason of `EndTurn`.
    pub fn new() -> Self {
        Self {
            tool_uses_by_index: HashMap::default(),
            usage: Usage::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

    /// Consumes the mapper and a raw Anthropic event stream, yielding the
    /// mapped completion events. Each raw event may expand to zero or more
    /// output events; stream-level errors are passed through as `Err` items.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(error.into())],
            })
        })
    }

    /// Maps a single raw [`Event`] to zero or more completion events,
    /// mutating internal tool-use/usage/stop-reason state as a side effect.
    pub fn map_event(
        &mut self,
        event: Event,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        match event {
            Event::ContentBlockStart {
                index,
                content_block,
            } => match content_block {
                ResponseContent::Text { text } => {
                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
                }
                ResponseContent::Thinking { thinking } => {
                    // Signature arrives later via `SignatureDelta`.
                    vec![Ok(LanguageModelCompletionEvent::Thinking {
                        text: thinking,
                        signature: None,
                    })]
                }
                ResponseContent::RedactedThinking { data } => {
                    vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
                }
                ResponseContent::ToolUse { id, name, .. } => {
                    // Start accumulating input JSON for this block; no
                    // event is emitted until some input has streamed in.
                    self.tool_uses_by_index.insert(
                        index,
                        RawToolUse {
                            id,
                            name,
                            input_json: String::new(),
                        },
                    );
                    Vec::new()
                }
            },
            Event::ContentBlockDelta { index, delta } => match delta {
                ContentDelta::TextDelta { text } => {
                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
                }
                ContentDelta::ThinkingDelta { thinking } => {
                    vec![Ok(LanguageModelCompletionEvent::Thinking {
                        text: thinking,
                        signature: None,
                    })]
                }
                ContentDelta::SignatureDelta { signature } => {
                    // A signature-only delta: emit it with empty text so the
                    // consumer can attach it to the preceding thinking text.
                    vec![Ok(LanguageModelCompletionEvent::Thinking {
                        text: "".to_string(),
                        signature: Some(signature),
                    })]
                }
                ContentDelta::InputJsonDelta { partial_json } => {
                    if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
                        tool_use.input_json.push_str(&partial_json);

                        // Try to convert invalid (incomplete) JSON into
                        // valid JSON that serde can accept, e.g. by closing
                        // unclosed delimiters. This way, we can update the
                        // UI with whatever has been streamed back so far.
                        if let Ok(input) =
                            serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json))
                        {
                            return vec![Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id: tool_use.id.clone().into(),
                                    name: tool_use.name.clone().into(),
                                    is_input_complete: false,
                                    raw_input: tool_use.input_json.clone(),
                                    input,
                                    thought_signature: None,
                                },
                            ))];
                        }
                    }
                    // Unknown block index, or the partial JSON could not yet
                    // be repaired into something parseable: emit nothing.
                    vec![]
                }
            },
            Event::ContentBlockStop { index } => {
                // Finalize a pending tool use: parse the accumulated JSON and
                // emit either a complete ToolUse or a parse-error event.
                if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
                    let input_json = tool_use.input_json.trim();
                    let event_result = match parse_tool_arguments(input_json) {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_use.id.into(),
                                name: tool_use.name.into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_use.input_json.clone(),
                                thought_signature: None,
                            },
                        )),
                        Err(json_parse_err) => {
                            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                id: tool_use.id.into(),
                                tool_name: tool_use.name.into(),
                                raw_input: input_json.into(),
                                json_parse_error: json_parse_err.to_string(),
                            })
                        }
                    };

                    vec![event_result]
                } else {
                    Vec::new()
                }
            }
            Event::MessageStart { message } => {
                update_usage(&mut self.usage, &message.usage);
                vec![
                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
                        &self.usage,
                    ))),
                    Ok(LanguageModelCompletionEvent::StartMessage {
                        message_id: message.id,
                    }),
                ]
            }
            Event::MessageDelta { delta, usage } => {
                update_usage(&mut self.usage, &usage);
                if let Some(stop_reason) = delta.stop_reason.as_deref() {
                    self.stop_reason = match stop_reason {
                        "end_turn" => StopReason::EndTurn,
                        "max_tokens" => StopReason::MaxTokens,
                        "tool_use" => StopReason::ToolUse,
                        "refusal" => StopReason::Refusal,
                        _ => {
                            // Unknown reasons are logged and treated as a
                            // normal end of turn.
                            log::error!("Unexpected anthropic stop_reason: {stop_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
                    convert_usage(&self.usage),
                ))]
            }
            Event::MessageStop => {
                vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
            }
            Event::Error { error } => {
                vec![Err(error.into())]
            }
            // Remaining event types (e.g. pings) carry no information we map.
            _ => Vec::new(),
        }
    }
}
402
/// Accumulator for a streamed tool-use content block.
struct RawToolUse {
    id: String,
    name: String,
    // Concatenation of all `InputJsonDelta` fragments received so far;
    // may be incomplete JSON until the block stops.
    input_json: String,
}
408
409/// Updates usage data by preferring counts from `new`.
410fn update_usage(usage: &mut Usage, new: &Usage) {
411    if let Some(input_tokens) = new.input_tokens {
412        usage.input_tokens = Some(input_tokens);
413    }
414    if let Some(output_tokens) = new.output_tokens {
415        usage.output_tokens = Some(output_tokens);
416    }
417    if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
418        usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
419    }
420    if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
421        usage.cache_read_input_tokens = Some(cache_read_input_tokens);
422    }
423}
424
425fn convert_usage(usage: &Usage) -> TokenUsage {
426    TokenUsage {
427        input_tokens: usage.input_tokens.unwrap_or(0),
428        output_tokens: usage.output_tokens.unwrap_or(0),
429        cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
430        cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
431    }
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AnthropicModelMode;
    use language_model_core::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    /// The `cache: true` flag must tag only the final segment of a message
    /// with ephemeral cache control; all earlier segments stay untagged.
    #[test]
    fn test_cache_control_only_on_last_segment() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    MessageContent::Text("Some prompt".to_string()),
                    MessageContent::Image(LanguageModelImage::empty()),
                    MessageContent::Image(LanguageModelImage::empty()),
                    MessageContent::Image(LanguageModelImage::empty()),
                    MessageContent::Image(LanguageModelImage::empty()),
                ],
                cache: true,
                reasoning_details: None,
            }],
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thinking_allowed: true,
            thinking_effort: None,
            speed: None,
        };

        let anthropic_request = into_anthropic(
            request,
            "claude-3-5-sonnet".to_string(),
            0.7,
            4096,
            AnthropicModelMode::Default,
        );

        assert_eq!(anthropic_request.messages.len(), 1);

        let message = &anthropic_request.messages[0];
        assert_eq!(message.content.len(), 5);

        assert!(matches!(
            message.content[0],
            RequestContent::Text {
                cache_control: None,
                ..
            }
        ));
        // Fixed off-by-one: check ALL intermediate images (indices 1..=3),
        // not just 1 and 2 — previously `1..3` left content[3] unasserted.
        for i in 1..4 {
            assert!(matches!(
                message.content[i],
                RequestContent::Image {
                    cache_control: None,
                    ..
                }
            ));
        }

        // Only the final segment carries the ephemeral cache marker.
        assert!(matches!(
            message.content[4],
            RequestContent::Image {
                cache_control: Some(CacheControl {
                    cache_type: CacheControlType::Ephemeral,
                }),
                ..
            }
        ));
    }

    /// Builds a two-message request (user "Hello" + assistant with the given
    /// content) and converts it with thinking enabled.
    fn request_with_assistant_content(assistant_content: Vec<MessageContent>) -> crate::Request {
        let mut request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text("Hello".to_string())],
                cache: false,
                reasoning_details: None,
            }],
            thinking_effort: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thinking_allowed: true,
            speed: None,
        };
        request.messages.push(LanguageModelRequestMessage {
            role: Role::Assistant,
            content: assistant_content,
            cache: false,
            reasoning_details: None,
        });
        into_anthropic(
            request,
            "claude-sonnet-4-5".to_string(),
            1.0,
            16000,
            AnthropicModelMode::Thinking {
                budget_tokens: Some(10000),
            },
        )
    }

    /// Thinking blocks without a signature must be dropped from replayed
    /// assistant messages, keeping the other content.
    #[test]
    fn test_unsigned_thinking_blocks_stripped() {
        let result = request_with_assistant_content(vec![
            MessageContent::Thinking {
                text: "Cancelled mid-think, no signature".to_string(),
                signature: None,
            },
            MessageContent::Text("Some response text".to_string()),
        ]);

        let assistant_message = result
            .messages
            .iter()
            .find(|m| m.role == crate::Role::Assistant)
            .expect("assistant message should still exist");

        assert_eq!(
            assistant_message.content.len(),
            1,
            "Only the text content should remain; unsigned thinking block should be stripped"
        );
        assert!(matches!(
            &assistant_message.content[0],
            RequestContent::Text { text, .. } if text == "Some response text"
        ));
    }

    /// Signed thinking blocks are valid replay content and must survive.
    #[test]
    fn test_signed_thinking_blocks_preserved() {
        let result = request_with_assistant_content(vec![
            MessageContent::Thinking {
                text: "Completed thinking".to_string(),
                signature: Some("valid-signature".to_string()),
            },
            MessageContent::Text("Response".to_string()),
        ]);

        let assistant_message = result
            .messages
            .iter()
            .find(|m| m.role == crate::Role::Assistant)
            .expect("assistant message should exist");

        assert_eq!(
            assistant_message.content.len(),
            2,
            "Both the signed thinking block and text should be preserved"
        );
        assert!(matches!(
            &assistant_message.content[0],
            RequestContent::Thinking { thinking, signature, .. }
                if thinking == "Completed thinking" && signature == "valid-signature"
        ));
    }

    /// A message whose segments are all filtered out must not be sent at all.
    #[test]
    fn test_only_unsigned_thinking_block_omits_entire_message() {
        let result = request_with_assistant_content(vec![MessageContent::Thinking {
            text: "Cancelled before any text or signature".to_string(),
            signature: None,
        }]);

        let assistant_messages: Vec<_> = result
            .messages
            .iter()
            .filter(|m| m.role == crate::Role::Assistant)
            .collect();

        assert_eq!(
            assistant_messages.len(),
            0,
            "An assistant message whose only content was an unsigned thinking block \
             should be omitted entirely"
        );
    }
}