completion.rs

  1use anyhow::Result;
  2use collections::HashMap;
  3use futures::{Stream, StreamExt};
  4use language_model_core::{
  5    LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRequest,
  6    LanguageModelToolChoice, LanguageModelToolResultContent, LanguageModelToolUse, MessageContent,
  7    Role, StopReason, TokenUsage,
  8    util::{fix_streamed_json, parse_tool_arguments},
  9};
 10use std::pin::Pin;
 11use std::str::FromStr;
 12
 13use crate::{
 14    AnthropicError, AnthropicModelMode, CacheControl, CacheControlType, ContentDelta,
 15    CountTokensRequest, Event, ImageSource, Message, RequestContent, ResponseContent,
 16    StringOrContents, Thinking, Tool, ToolChoice, ToolResultContent, ToolResultPart, Usage,
 17};
 18
 19fn to_anthropic_content(content: MessageContent) -> Option<RequestContent> {
 20    match content {
 21        MessageContent::Text(text) => {
 22            let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
 23                text.trim_end().to_string()
 24            } else {
 25                text
 26            };
 27            if !text.is_empty() {
 28                Some(RequestContent::Text {
 29                    text,
 30                    cache_control: None,
 31                })
 32            } else {
 33                None
 34            }
 35        }
 36        MessageContent::Thinking {
 37            text: thinking,
 38            signature,
 39        } => {
 40            if let Some(signature) = signature
 41                && !thinking.is_empty()
 42            {
 43                Some(RequestContent::Thinking {
 44                    thinking,
 45                    signature,
 46                    cache_control: None,
 47                })
 48            } else {
 49                None
 50            }
 51        }
 52        MessageContent::RedactedThinking(data) => {
 53            if !data.is_empty() {
 54                Some(RequestContent::RedactedThinking { data })
 55            } else {
 56                None
 57            }
 58        }
 59        MessageContent::Image(image) => Some(RequestContent::Image {
 60            source: ImageSource {
 61                source_type: "base64".to_string(),
 62                media_type: "image/png".to_string(),
 63                data: image.source.to_string(),
 64            },
 65            cache_control: None,
 66        }),
 67        MessageContent::ToolUse(tool_use) => Some(RequestContent::ToolUse {
 68            id: tool_use.id.to_string(),
 69            name: tool_use.name.to_string(),
 70            input: tool_use.input,
 71            cache_control: None,
 72        }),
 73        MessageContent::ToolResult(tool_result) => Some(RequestContent::ToolResult {
 74            tool_use_id: tool_result.tool_use_id.to_string(),
 75            is_error: tool_result.is_error,
 76            content: match tool_result.content {
 77                LanguageModelToolResultContent::Text(text) => {
 78                    ToolResultContent::Plain(text.to_string())
 79                }
 80                LanguageModelToolResultContent::Image(image) => {
 81                    ToolResultContent::Multipart(vec![ToolResultPart::Image {
 82                        source: ImageSource {
 83                            source_type: "base64".to_string(),
 84                            media_type: "image/png".to_string(),
 85                            data: image.source.to_string(),
 86                        },
 87                    }])
 88                }
 89            },
 90            cache_control: None,
 91        }),
 92    }
 93}
 94
/// Convert a LanguageModelRequest to an Anthropic CountTokensRequest.
///
/// Applies the same message normalization as [`into_anthropic`]: system
/// messages are concatenated into one system prompt, consecutive messages
/// sharing a role are merged, and messages whose content all filters out
/// are dropped — so the token count reflects the request actually sent.
pub fn into_anthropic_count_tokens_request(
    request: LanguageModelRequest,
    model: String,
    mode: AnthropicModelMode,
) -> CountTokensRequest {
    let mut new_messages: Vec<Message> = Vec::new();
    let mut system_message = String::new();

    for message in request.messages {
        // Skip messages that carry no content at all.
        if message.contents_empty() {
            continue;
        }

        match message.role {
            Role::User | Role::Assistant => {
                let anthropic_message_content: Vec<RequestContent> = message
                    .content
                    .into_iter()
                    .filter_map(to_anthropic_content)
                    .collect();
                let anthropic_role = match message.role {
                    Role::User => crate::Role::User,
                    Role::Assistant => crate::Role::Assistant,
                    Role::System => unreachable!("System role should never occur here"),
                };
                // Everything may have been filtered out (e.g. only unsigned
                // thinking blocks); omit the message entirely in that case.
                if anthropic_message_content.is_empty() {
                    continue;
                }

                // Merge consecutive messages that share a role into one.
                if let Some(last_message) = new_messages.last_mut()
                    && last_message.role == anthropic_role
                {
                    last_message.content.extend(anthropic_message_content);
                    continue;
                }

                new_messages.push(Message {
                    role: anthropic_role,
                    content: anthropic_message_content,
                });
            }
            Role::System => {
                // Fold all system messages into a single prompt, separated
                // by blank lines.
                if !system_message.is_empty() {
                    system_message.push_str("\n\n");
                }
                system_message.push_str(&message.string_contents());
            }
        }
    }

    CountTokensRequest {
        model,
        messages: new_messages,
        system: if system_message.is_empty() {
            None
        } else {
            Some(StringOrContents::String(system_message))
        },
        // Thinking configuration is forwarded exactly as it would appear on
        // the real completion request, since it affects the token count.
        thinking: if request.thinking_allowed {
            match mode {
                AnthropicModelMode::Thinking { budget_tokens } => {
                    Some(Thinking::Enabled { budget_tokens })
                }
                AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive),
                AnthropicModelMode::Default => None,
            }
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| Tool {
                name: tool.name,
                description: tool.description,
                input_schema: tool.input_schema,
                eager_input_streaming: tool.use_input_streaming,
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => ToolChoice::Auto,
            LanguageModelToolChoice::Any => ToolChoice::Any,
            LanguageModelToolChoice::None => ToolChoice::None,
        }),
    }
}
182
183/// Estimate tokens using tiktoken. Used as a fallback when the API is unavailable,
184/// or by providers (like Zed Cloud) that don't have direct Anthropic API access.
185pub fn count_anthropic_tokens_with_tiktoken(request: LanguageModelRequest) -> Result<u64> {
186    let messages = request.messages;
187    let mut tokens_from_images = 0;
188    let mut string_messages = Vec::with_capacity(messages.len());
189
190    for message in messages {
191        let mut string_contents = String::new();
192
193        for content in message.content {
194            match content {
195                MessageContent::Text(text) => {
196                    string_contents.push_str(&text);
197                }
198                MessageContent::Thinking { .. } => {
199                    // Thinking blocks are not included in the input token count.
200                }
201                MessageContent::RedactedThinking(_) => {
202                    // Thinking blocks are not included in the input token count.
203                }
204                MessageContent::Image(image) => {
205                    tokens_from_images += image.estimate_tokens();
206                }
207                MessageContent::ToolUse(_tool_use) => {
208                    // TODO: Estimate token usage from tool uses.
209                }
210                MessageContent::ToolResult(tool_result) => match &tool_result.content {
211                    LanguageModelToolResultContent::Text(text) => {
212                        string_contents.push_str(text);
213                    }
214                    LanguageModelToolResultContent::Image(image) => {
215                        tokens_from_images += image.estimate_tokens();
216                    }
217                },
218            }
219        }
220
221        if !string_contents.is_empty() {
222            string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
223                role: match message.role {
224                    Role::User => "user".into(),
225                    Role::Assistant => "assistant".into(),
226                    Role::System => "system".into(),
227                },
228                content: Some(string_contents),
229                name: None,
230                function_call: None,
231            });
232        }
233    }
234
235    // Tiktoken doesn't yet support these models, so we manually use the
236    // same tokenizer as GPT-4.
237    tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
238        .map(|tokens| (tokens + tokens_from_images) as u64)
239}
240
/// Convert a LanguageModelRequest into a complete Anthropic completion request.
///
/// Normalization performed:
/// - System messages are folded into one blank-line-separated system prompt.
/// - Consecutive messages sharing a role are merged into one message.
/// - Messages whose content all filters out (see `to_anthropic_content`)
///   are dropped.
/// - When `message.cache` is set, the last cacheable content segment of that
///   message is marked with an ephemeral cache-control breakpoint.
pub fn into_anthropic(
    request: LanguageModelRequest,
    model: String,
    default_temperature: f32,
    max_output_tokens: u64,
    mode: AnthropicModelMode,
) -> crate::Request {
    let mut new_messages: Vec<Message> = Vec::new();
    let mut system_message = String::new();

    for message in request.messages {
        // Skip messages that carry no content at all.
        if message.contents_empty() {
            continue;
        }

        match message.role {
            Role::User | Role::Assistant => {
                let mut anthropic_message_content: Vec<RequestContent> = message
                    .content
                    .into_iter()
                    .filter_map(to_anthropic_content)
                    .collect();
                let anthropic_role = match message.role {
                    Role::User => crate::Role::User,
                    Role::Assistant => crate::Role::Assistant,
                    Role::System => unreachable!("System role should never occur here"),
                };
                if anthropic_message_content.is_empty() {
                    continue;
                }

                // Merge into the previous message when roles match. Note:
                // in this case `message.cache` of the merged-in message is
                // not applied below.
                if let Some(last_message) = new_messages.last_mut()
                    && last_message.role == anthropic_role
                {
                    last_message.content.extend(anthropic_message_content);
                    continue;
                }

                // Mark the last segment of the message as cached
                if message.cache {
                    let cache_control_value = Some(CacheControl {
                        cache_type: CacheControlType::Ephemeral,
                    });
                    // Walk backwards to the nearest segment that can carry a
                    // cache_control field; redacted thinking cannot.
                    for message_content in anthropic_message_content.iter_mut().rev() {
                        match message_content {
                            RequestContent::RedactedThinking { .. } => {
                                // Caching is not possible, fallback to next message
                            }
                            RequestContent::Text { cache_control, .. }
                            | RequestContent::Thinking { cache_control, .. }
                            | RequestContent::Image { cache_control, .. }
                            | RequestContent::ToolUse { cache_control, .. }
                            | RequestContent::ToolResult { cache_control, .. } => {
                                *cache_control = cache_control_value;
                                break;
                            }
                        }
                    }
                }

                new_messages.push(Message {
                    role: anthropic_role,
                    content: anthropic_message_content,
                });
            }
            Role::System => {
                // Fold all system messages into a single prompt.
                if !system_message.is_empty() {
                    system_message.push_str("\n\n");
                }
                system_message.push_str(&message.string_contents());
            }
        }
    }

    crate::Request {
        model,
        messages: new_messages,
        max_tokens: max_output_tokens,
        system: if system_message.is_empty() {
            None
        } else {
            Some(StringOrContents::String(system_message))
        },
        thinking: if request.thinking_allowed {
            match mode {
                AnthropicModelMode::Thinking { budget_tokens } => {
                    Some(Thinking::Enabled { budget_tokens })
                }
                AnthropicModelMode::AdaptiveThinking => Some(Thinking::Adaptive),
                AnthropicModelMode::Default => None,
            }
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| Tool {
                name: tool.name,
                description: tool.description,
                input_schema: tool.input_schema,
                eager_input_streaming: tool.use_input_streaming,
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => ToolChoice::Auto,
            LanguageModelToolChoice::Any => ToolChoice::Any,
            LanguageModelToolChoice::None => ToolChoice::None,
        }),
        metadata: None,
        // Effort levels only apply to adaptive-thinking mode; unrecognized
        // effort strings are silently dropped (no output_config).
        output_config: if request.thinking_allowed
            && matches!(mode, AnthropicModelMode::AdaptiveThinking)
        {
            request.thinking_effort.as_deref().and_then(|effort| {
                let effort = match effort {
                    "low" => Some(crate::Effort::Low),
                    "medium" => Some(crate::Effort::Medium),
                    "high" => Some(crate::Effort::High),
                    "max" => Some(crate::Effort::Max),
                    _ => None,
                };
                effort.map(|effort| crate::OutputConfig {
                    effort: Some(effort),
                })
            })
        } else {
            None
        },
        stop_sequences: Vec::new(),
        speed: request.speed.map(Into::into),
        // Explicit request temperature wins; otherwise the model default.
        temperature: request.temperature.or(Some(default_temperature)),
        top_k: None,
        top_p: None,
    }
}
376
/// Accumulates state across a stream of Anthropic events and maps each one
/// into zero or more `LanguageModelCompletionEvent`s.
pub struct AnthropicEventMapper {
    // Tool uses whose input JSON is still streaming in, keyed by the
    // content-block index from the Anthropic event stream.
    tool_uses_by_index: HashMap<usize, RawToolUse>,
    // Running usage totals, updated as usage data arrives.
    usage: Usage,
    // Most recently reported stop reason; emitted on MessageStop.
    stop_reason: StopReason,
}
382
impl AnthropicEventMapper {
    /// Creates a mapper with empty usage and a default stop reason of
    /// `EndTurn`.
    pub fn new() -> Self {
        Self {
            tool_uses_by_index: HashMap::default(),
            usage: Usage::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

    /// Consumes a stream of raw Anthropic events and produces a stream of
    /// mapped completion events. One input event may expand into zero or
    /// more output events.
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(error.into())],
            })
        })
    }

    /// Maps a single Anthropic event into completion events, updating the
    /// accumulated usage, stop reason, and in-flight tool-use state.
    pub fn map_event(
        &mut self,
        event: Event,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        match event {
            Event::ContentBlockStart {
                index,
                content_block,
            } => match content_block {
                ResponseContent::Text { text } => {
                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
                }
                ResponseContent::Thinking { thinking } => {
                    vec![Ok(LanguageModelCompletionEvent::Thinking {
                        text: thinking,
                        signature: None,
                    })]
                }
                ResponseContent::RedactedThinking { data } => {
                    vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
                }
                ResponseContent::ToolUse { id, name, .. } => {
                    // Begin accumulating this tool use's streamed input JSON;
                    // nothing is emitted until some of it parses (see
                    // InputJsonDelta below).
                    self.tool_uses_by_index.insert(
                        index,
                        RawToolUse {
                            id,
                            name,
                            input_json: String::new(),
                        },
                    );
                    Vec::new()
                }
            },
            Event::ContentBlockDelta { index, delta } => match delta {
                ContentDelta::TextDelta { text } => {
                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
                }
                ContentDelta::ThinkingDelta { thinking } => {
                    vec![Ok(LanguageModelCompletionEvent::Thinking {
                        text: thinking,
                        signature: None,
                    })]
                }
                ContentDelta::SignatureDelta { signature } => {
                    // Signatures arrive as their own delta; forward them on a
                    // thinking event carrying empty text.
                    vec![Ok(LanguageModelCompletionEvent::Thinking {
                        text: "".to_string(),
                        signature: Some(signature),
                    })]
                }
                ContentDelta::InputJsonDelta { partial_json } => {
                    if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
                        tool_use.input_json.push_str(&partial_json);

                        // Try to convert invalid (incomplete) JSON into
                        // valid JSON that serde can accept, e.g. by closing
                        // unclosed delimiters. This way, we can update the
                        // UI with whatever has been streamed back so far.
                        if let Ok(input) =
                            serde_json::Value::from_str(&fix_streamed_json(&tool_use.input_json))
                        {
                            return vec![Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id: tool_use.id.clone().into(),
                                    name: tool_use.name.clone().into(),
                                    // Partial: more input deltas may follow.
                                    is_input_complete: false,
                                    raw_input: tool_use.input_json.clone(),
                                    input,
                                    thought_signature: None,
                                },
                            ))];
                        }
                    }
                    vec![]
                }
            },
            Event::ContentBlockStop { index } => {
                // Only tool-use blocks need finalization; text/thinking
                // blocks were already forwarded as they streamed.
                if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
                    let input_json = tool_use.input_json.trim();
                    let event_result = match parse_tool_arguments(input_json) {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_use.id.into(),
                                name: tool_use.name.into(),
                                is_input_complete: true,
                                input,
                                // NOTE(review): the success path reports the
                                // untrimmed JSON while the error path below
                                // reports the trimmed input — confirm this
                                // asymmetry is intentional.
                                raw_input: tool_use.input_json.clone(),
                                thought_signature: None,
                            },
                        )),
                        Err(json_parse_err) => {
                            // Surface parse failures as an event (not a
                            // stream Err) so the stream keeps flowing.
                            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                id: tool_use.id.into(),
                                tool_name: tool_use.name.into(),
                                raw_input: input_json.into(),
                                json_parse_error: json_parse_err.to_string(),
                            })
                        }
                    };

                    vec![event_result]
                } else {
                    Vec::new()
                }
            }
            Event::MessageStart { message } => {
                update_usage(&mut self.usage, &message.usage);
                vec![
                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
                        &self.usage,
                    ))),
                    Ok(LanguageModelCompletionEvent::StartMessage {
                        message_id: message.id,
                    }),
                ]
            }
            Event::MessageDelta { delta, usage } => {
                update_usage(&mut self.usage, &usage);
                if let Some(stop_reason) = delta.stop_reason.as_deref() {
                    self.stop_reason = match stop_reason {
                        "end_turn" => StopReason::EndTurn,
                        "max_tokens" => StopReason::MaxTokens,
                        "tool_use" => StopReason::ToolUse,
                        "refusal" => StopReason::Refusal,
                        _ => {
                            // Unknown reasons are logged and treated as a
                            // normal end of turn.
                            log::error!("Unexpected anthropic stop_reason: {stop_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
                    convert_usage(&self.usage),
                ))]
            }
            Event::MessageStop => {
                // The stop reason was recorded by the preceding MessageDelta.
                vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
            }
            Event::Error { error } => {
                vec![Err(error.into())]
            }
            // All other events produce no output.
            _ => Vec::new(),
        }
    }
}
548
/// A tool use whose input JSON is still being accumulated from the stream.
struct RawToolUse {
    id: String,
    name: String,
    // Concatenation of all `InputJsonDelta` fragments received so far.
    input_json: String,
}
554
555/// Updates usage data by preferring counts from `new`.
556fn update_usage(usage: &mut Usage, new: &Usage) {
557    if let Some(input_tokens) = new.input_tokens {
558        usage.input_tokens = Some(input_tokens);
559    }
560    if let Some(output_tokens) = new.output_tokens {
561        usage.output_tokens = Some(output_tokens);
562    }
563    if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
564        usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
565    }
566    if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
567        usage.cache_read_input_tokens = Some(cache_read_input_tokens);
568    }
569}
570
571fn convert_usage(usage: &Usage) -> TokenUsage {
572    TokenUsage {
573        input_tokens: usage.input_tokens.unwrap_or(0),
574        output_tokens: usage.output_tokens.unwrap_or(0),
575        cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
576        cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
577    }
578}
579
#[cfg(test)]
mod tests {
    use super::*;
    use crate::AnthropicModelMode;
    use language_model_core::{LanguageModelImage, LanguageModelRequestMessage, MessageContent};

    /// A cached message gets exactly one ephemeral cache-control breakpoint,
    /// placed on its final content segment; all earlier segments stay
    /// uncached.
    #[test]
    fn test_cache_control_only_on_last_segment() {
        let request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    MessageContent::Text("Some prompt".to_string()),
                    MessageContent::Image(LanguageModelImage::empty()),
                    MessageContent::Image(LanguageModelImage::empty()),
                    MessageContent::Image(LanguageModelImage::empty()),
                    MessageContent::Image(LanguageModelImage::empty()),
                ],
                cache: true,
                reasoning_details: None,
            }],
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thinking_allowed: true,
            thinking_effort: None,
            speed: None,
        };

        let anthropic_request = into_anthropic(
            request,
            "claude-3-5-sonnet".to_string(),
            0.7,
            4096,
            AnthropicModelMode::Default,
        );

        assert_eq!(anthropic_request.messages.len(), 1);

        let message = &anthropic_request.messages[0];
        assert_eq!(message.content.len(), 5);

        assert!(matches!(
            message.content[0],
            RequestContent::Text {
                cache_control: None,
                ..
            }
        ));
        // Bug fix: this loop previously ran over `1..3`, so the image at
        // index 3 was never checked. Every non-final image (indices 1..=3)
        // must be uncached.
        for i in 1..4 {
            assert!(matches!(
                message.content[i],
                RequestContent::Image {
                    cache_control: None,
                    ..
                }
            ));
        }

        assert!(matches!(
            message.content[4],
            RequestContent::Image {
                cache_control: Some(CacheControl {
                    cache_type: CacheControlType::Ephemeral,
                }),
                ..
            }
        ));
    }

    /// Builds a request with one user message plus an assistant message
    /// containing `assistant_content`, then converts it with thinking mode
    /// enabled.
    fn request_with_assistant_content(assistant_content: Vec<MessageContent>) -> crate::Request {
        let mut request = LanguageModelRequest {
            messages: vec![LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::Text("Hello".to_string())],
                cache: false,
                reasoning_details: None,
            }],
            thinking_effort: None,
            thread_id: None,
            prompt_id: None,
            intent: None,
            stop: vec![],
            temperature: None,
            tools: vec![],
            tool_choice: None,
            thinking_allowed: true,
            speed: None,
        };
        request.messages.push(LanguageModelRequestMessage {
            role: Role::Assistant,
            content: assistant_content,
            cache: false,
            reasoning_details: None,
        });
        into_anthropic(
            request,
            "claude-sonnet-4-5".to_string(),
            1.0,
            16000,
            AnthropicModelMode::Thinking {
                budget_tokens: Some(10000),
            },
        )
    }

    /// Thinking blocks without a signature must not be sent back to the API.
    #[test]
    fn test_unsigned_thinking_blocks_stripped() {
        let result = request_with_assistant_content(vec![
            MessageContent::Thinking {
                text: "Cancelled mid-think, no signature".to_string(),
                signature: None,
            },
            MessageContent::Text("Some response text".to_string()),
        ]);

        let assistant_message = result
            .messages
            .iter()
            .find(|m| m.role == crate::Role::Assistant)
            .expect("assistant message should still exist");

        assert_eq!(
            assistant_message.content.len(),
            1,
            "Only the text content should remain; unsigned thinking block should be stripped"
        );
        assert!(matches!(
            &assistant_message.content[0],
            RequestContent::Text { text, .. } if text == "Some response text"
        ));
    }

    /// Signed thinking blocks are preserved alongside text content.
    #[test]
    fn test_signed_thinking_blocks_preserved() {
        let result = request_with_assistant_content(vec![
            MessageContent::Thinking {
                text: "Completed thinking".to_string(),
                signature: Some("valid-signature".to_string()),
            },
            MessageContent::Text("Response".to_string()),
        ]);

        let assistant_message = result
            .messages
            .iter()
            .find(|m| m.role == crate::Role::Assistant)
            .expect("assistant message should exist");

        assert_eq!(
            assistant_message.content.len(),
            2,
            "Both the signed thinking block and text should be preserved"
        );
        assert!(matches!(
            &assistant_message.content[0],
            RequestContent::Thinking { thinking, signature, .. }
                if thinking == "Completed thinking" && signature == "valid-signature"
        ));
    }

    /// If filtering removes all of a message's content, the message itself
    /// is dropped from the request.
    #[test]
    fn test_only_unsigned_thinking_block_omits_entire_message() {
        let result = request_with_assistant_content(vec![MessageContent::Thinking {
            text: "Cancelled before any text or signature".to_string(),
            signature: None,
        }]);

        let assistant_messages: Vec<_> = result
            .messages
            .iter()
            .filter(|m| m.role == crate::Role::Assistant)
            .collect();

        assert_eq!(
            assistant_messages.len(),
            0,
            "An assistant message whose only content was an unsigned thinking block \
             should be omitted entirely"
        );
    }
}