bedrock: Add prompt caching support (#33194)

Created by Vladimir Kuznichenkov and Oleksiy Syvokon

Closes https://github.com/zed-industries/zed/issues/33221

Bedrock's prompt caching API is similar to Anthropic's: to cache messages up
to a certain point, we add a special cache point block to that message.
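As a rough sketch (using the `bedrock` crate re-exports that appear in the diff below; the helper name itself is hypothetical), the cache point is just an extra content block appended to the message that marks the cache boundary:

```rust
use anyhow::Context as _;
use bedrock::BedrockInnerContent;
use bedrock::bedrock_client::types::{CachePointBlock, CachePointType};

// Append a cache point so everything up to (and including) this message
// becomes eligible for Bedrock's prompt cache.
fn append_cache_point(content: &mut Vec<BedrockInnerContent>) -> anyhow::Result<()> {
    content.push(BedrockInnerContent::CachePoint(
        CachePointBlock::builder()
            .r#type(CachePointType::Default)
            .build()
            .context("failed to build cache point block")?,
    ));
    Ok(())
}
```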

Additionally, we can cache the tool definitions by adding a cache point block
after the tool spec.
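A similar sketch for the tool definitions (again, the helper name is hypothetical; the guard mirrors the change below, which only adds the block when the model supports caching and at least one tool is defined):

```rust
use anyhow::Context as _;
use bedrock::BedrockTool;
use bedrock::bedrock_client::types::{CachePointBlock, CachePointType};

// Cache the (typically large and stable) tool definitions by placing a
// cache point after the last tool spec.
fn append_tool_cache_point(
    tool_spec: &mut Vec<BedrockTool>,
    supports_caching: bool,
) -> anyhow::Result<()> {
    if !tool_spec.is_empty() && supports_caching {
        tool_spec.push(BedrockTool::CachePoint(
            CachePointBlock::builder()
                .r#type(CachePointType::Default)
                .build()
                .context("failed to build cache point block")?,
        ));
    }
    Ok(())
}
```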

See: [Bedrock User Guide: Prompt
Caching](https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models)

Release Notes:

- bedrock: Added prompt caching support

---------

Co-authored-by: Oleksiy Syvokon <oleksiy@zed.dev>

Change summary

crates/bedrock/src/models.rs                   | 59 ++++++++++++++++++++
crates/language_models/src/provider/bedrock.rs | 52 ++++++++++++++---
2 files changed, 101 insertions(+), 10 deletions(-)

Detailed changes

crates/bedrock/src/models.rs 🔗

@@ -11,6 +11,13 @@ pub enum BedrockModelMode {
     },
 }
 
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub struct BedrockModelCacheConfiguration {
+    pub max_cache_anchors: usize,
+    pub min_total_token: u64,
+}
+
 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 pub enum Model {
@@ -104,6 +111,7 @@ pub enum Model {
         display_name: Option<String>,
         max_output_tokens: Option<u64>,
         default_temperature: Option<f32>,
+        cache_configuration: Option<BedrockModelCacheConfiguration>,
     },
 }
 
@@ -401,6 +409,56 @@ impl Model {
         }
     }
 
+    pub fn supports_caching(&self) -> bool {
+        match self {
+            // Only Claude models on Bedrock support caching
+            // Nova models support only text caching
+            // https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models
+            Self::Claude3_5Haiku
+            | Self::Claude3_7Sonnet
+            | Self::Claude3_7SonnetThinking
+            | Self::ClaudeSonnet4
+            | Self::ClaudeSonnet4Thinking
+            | Self::ClaudeOpus4
+            | Self::ClaudeOpus4Thinking => true,
+
+            // Custom models - check if they have cache configuration
+            Self::Custom {
+                cache_configuration,
+                ..
+            } => cache_configuration.is_some(),
+
+            // All other models don't support caching
+            _ => false,
+        }
+    }
+
+    pub fn cache_configuration(&self) -> Option<BedrockModelCacheConfiguration> {
+        match self {
+            Self::Claude3_7Sonnet
+            | Self::Claude3_7SonnetThinking
+            | Self::ClaudeSonnet4
+            | Self::ClaudeSonnet4Thinking
+            | Self::ClaudeOpus4
+            | Self::ClaudeOpus4Thinking => Some(BedrockModelCacheConfiguration {
+                max_cache_anchors: 4,
+                min_total_token: 1024,
+            }),
+
+            Self::Claude3_5Haiku => Some(BedrockModelCacheConfiguration {
+                max_cache_anchors: 4,
+                min_total_token: 2048,
+            }),
+
+            Self::Custom {
+                cache_configuration,
+                ..
+            } => cache_configuration.clone(),
+
+            _ => None,
+        }
+    }
+
     pub fn mode(&self) -> BedrockModelMode {
         match self {
             Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking {
@@ -660,6 +718,7 @@ mod tests {
             display_name: Some("My Custom Model".to_string()),
             max_output_tokens: Some(8192),
             default_temperature: Some(0.7),
+            cache_configuration: None,
         };
 
         // Custom model should return its name unchanged

crates/language_models/src/provider/bedrock.rs 🔗

@@ -11,8 +11,8 @@ use aws_http_client::AwsHttpClient;
 use bedrock::bedrock_client::Client as BedrockClient;
 use bedrock::bedrock_client::config::timeout::TimeoutConfig;
 use bedrock::bedrock_client::types::{
-    ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, ReasoningContentBlockDelta,
-    StopReason,
+    CachePointBlock, CachePointType, ContentBlockDelta, ContentBlockStart, ConverseStreamOutput,
+    ReasoningContentBlockDelta, StopReason,
 };
 use bedrock::{
     BedrockAnyToolChoice, BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent,
@@ -48,7 +48,7 @@ use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
 use theme::ThemeSettings;
 use tokio::runtime::Handle;
 use ui::{Icon, IconName, List, Tooltip, prelude::*};
-use util::{ResultExt, default};
+use util::ResultExt;
 
 use crate::AllLanguageModelSettings;
 
@@ -329,6 +329,12 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
                     max_tokens: model.max_tokens,
                     max_output_tokens: model.max_output_tokens,
                     default_temperature: model.default_temperature,
+                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
+                        bedrock::BedrockModelCacheConfiguration {
+                            max_cache_anchors: config.max_cache_anchors,
+                            min_total_token: config.min_total_token,
+                        }
+                    }),
                 },
             );
         }
@@ -558,6 +564,7 @@ impl LanguageModel for BedrockModel {
             self.model.default_temperature(),
             self.model.max_output_tokens(),
             self.model.mode(),
+            self.model.supports_caching(),
         ) {
             Ok(request) => request,
             Err(err) => return futures::future::ready(Err(err.into())).boxed(),
@@ -581,7 +588,13 @@ impl LanguageModel for BedrockModel {
     }
 
     fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
-        None
+        self.model
+            .cache_configuration()
+            .map(|config| LanguageModelCacheConfiguration {
+                max_cache_anchors: config.max_cache_anchors,
+                should_speculate: false,
+                min_total_token: config.min_total_token,
+            })
     }
 }
 
@@ -608,6 +621,7 @@ pub fn into_bedrock(
     default_temperature: f32,
     max_output_tokens: u64,
     mode: BedrockModelMode,
+    supports_caching: bool,
 ) -> Result<bedrock::Request> {
     let mut new_messages: Vec<BedrockMessage> = Vec::new();
     let mut system_message = String::new();
@@ -619,7 +633,7 @@ pub fn into_bedrock(
 
         match message.role {
             Role::User | Role::Assistant => {
-                let bedrock_message_content: Vec<BedrockInnerContent> = message
+                let mut bedrock_message_content: Vec<BedrockInnerContent> = message
                     .content
                     .into_iter()
                     .filter_map(|content| match content {
@@ -703,6 +717,14 @@ pub fn into_bedrock(
                         _ => None,
                     })
                     .collect();
+                if message.cache && supports_caching {
+                    bedrock_message_content.push(BedrockInnerContent::CachePoint(
+                        CachePointBlock::builder()
+                            .r#type(CachePointType::Default)
+                            .build()
+                            .context("failed to build cache point block")?,
+                    ));
+                }
                 let bedrock_role = match message.role {
                     Role::User => bedrock::BedrockRole::User,
                     Role::Assistant => bedrock::BedrockRole::Assistant,
@@ -731,7 +753,7 @@ pub fn into_bedrock(
         }
     }
 
-    let tool_spec: Vec<BedrockTool> = request
+    let mut tool_spec: Vec<BedrockTool> = request
         .tools
         .iter()
         .filter_map(|tool| {
@@ -748,6 +770,15 @@ pub fn into_bedrock(
         })
         .collect();
 
+    if !tool_spec.is_empty() && supports_caching {
+        tool_spec.push(BedrockTool::CachePoint(
+            CachePointBlock::builder()
+                .r#type(CachePointType::Default)
+                .build()
+                .context("failed to build cache point block")?,
+        ));
+    }
+
     let tool_choice = match request.tool_choice {
         Some(LanguageModelToolChoice::Auto) | None => {
             BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
@@ -990,10 +1021,11 @@ pub fn map_to_language_model_completion_events(
                                                 LanguageModelCompletionEvent::UsageUpdate(
                                                     TokenUsage {
                                                         input_tokens: metadata.input_tokens as u64,
-                                                        output_tokens: metadata.output_tokens
-                                                            as u64,
-                                                        cache_creation_input_tokens: default(),
-                                                        cache_read_input_tokens: default(),
+                                                        output_tokens: metadata.output_tokens as u64,
+                                                        cache_creation_input_tokens:
+                                                            metadata.cache_write_input_tokens.unwrap_or_default() as u64,
+                                                        cache_read_input_tokens:
+                                                            metadata.cache_read_input_tokens.unwrap_or_default() as u64,
                                                     },
                                                 );
                                             return Some((Some(Ok(completion_event)), state));