@@ -11,6 +11,13 @@ pub enum BedrockModelMode {
},
}
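+
+/// Prompt caching limits for a Bedrock model.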
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub struct BedrockModelCacheConfiguration {
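+ /// Maximum number of cache checkpoints that may be placed in a single request.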
+ pub max_cache_anchors: usize,
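+ /// Minimum number of prompt tokens required before a cache checkpoint takes effect.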
+ pub min_total_token: u64,
+}
+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
@@ -104,6 +111,7 @@ pub enum Model {
display_name: Option<String>,
max_output_tokens: Option<u64>,
default_temperature: Option<f32>,
+ cache_configuration: Option<BedrockModelCacheConfiguration>,
},
}
@@ -401,6 +409,56 @@ impl Model {
}
}
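+
+ /// Whether prompt caching should be enabled for requests to this model.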
+ pub fn supports_caching(&self) -> bool {
+ match self {
+ // Prompt caching is enabled only for the Anthropic Claude models listed below.
+ // Amazon Nova models support caching of text content only, so they are not enabled here.
+ // https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html#prompt-caching-models
+ Self::Claude3_5Haiku
+ | Self::Claude3_7Sonnet
+ | Self::Claude3_7SonnetThinking
+ | Self::ClaudeSonnet4
+ | Self::ClaudeSonnet4Thinking
+ | Self::ClaudeOpus4
+ | Self::ClaudeOpus4Thinking => true,
+
+ // Custom models support caching only when a cache configuration is provided
+ Self::Custom {
+ cache_configuration,
+ ..
+ } => cache_configuration.is_some(),
+
+ // All other models don't support caching
+ _ => false,
+ }
+ }
+
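+ /// The prompt caching limits to advertise for this model, if it supports caching.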
+ pub fn cache_configuration(&self) -> Option<BedrockModelCacheConfiguration> {
+ match self {
+ Self::Claude3_7Sonnet
+ | Self::Claude3_7SonnetThinking
+ | Self::ClaudeSonnet4
+ | Self::ClaudeSonnet4Thinking
+ | Self::ClaudeOpus4
+ | Self::ClaudeOpus4Thinking => Some(BedrockModelCacheConfiguration {
+ max_cache_anchors: 4,
+ min_total_token: 1024,
+ }),
+
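+ // Claude 3.5 Haiku requires a larger minimum prompt (2,048 tokens) before a checkpoint can be cached.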
+ Self::Claude3_5Haiku => Some(BedrockModelCacheConfiguration {
+ max_cache_anchors: 4,
+ min_total_token: 2048,
+ }),
+
+ Self::Custom {
+ cache_configuration,
+ ..
+ } => cache_configuration.clone(),
+
+ _ => None,
+ }
+ }
+
pub fn mode(&self) -> BedrockModelMode {
match self {
Model::Claude3_7SonnetThinking => BedrockModelMode::Thinking {
@@ -660,6 +718,7 @@ mod tests {
display_name: Some("My Custom Model".to_string()),
max_output_tokens: Some(8192),
default_temperature: Some(0.7),
+ cache_configuration: None,
};
// Custom model should return its name unchanged
@@ -11,8 +11,8 @@ use aws_http_client::AwsHttpClient;
use bedrock::bedrock_client::Client as BedrockClient;
use bedrock::bedrock_client::config::timeout::TimeoutConfig;
use bedrock::bedrock_client::types::{
- ContentBlockDelta, ContentBlockStart, ConverseStreamOutput, ReasoningContentBlockDelta,
- StopReason,
+ CachePointBlock, CachePointType, ContentBlockDelta, ContentBlockStart, ConverseStreamOutput,
+ ReasoningContentBlockDelta, StopReason,
};
use bedrock::{
BedrockAnyToolChoice, BedrockAutoToolChoice, BedrockBlob, BedrockError, BedrockInnerContent,
@@ -48,7 +48,7 @@ use strum::{EnumIter, IntoEnumIterator, IntoStaticStr};
use theme::ThemeSettings;
use tokio::runtime::Handle;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
-use util::{ResultExt, default};
+use util::ResultExt;
use crate::AllLanguageModelSettings;
@@ -329,6 +329,12 @@ impl LanguageModelProvider for BedrockLanguageModelProvider {
max_tokens: model.max_tokens,
max_output_tokens: model.max_output_tokens,
default_temperature: model.default_temperature,
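+ // Forward any user-configured cache limits for custom models.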
+ cache_configuration: model.cache_configuration.as_ref().map(|config| {
+ bedrock::BedrockModelCacheConfiguration {
+ max_cache_anchors: config.max_cache_anchors,
+ min_total_token: config.min_total_token,
+ }
+ }),
},
);
}
@@ -558,6 +564,7 @@ impl LanguageModel for BedrockModel {
self.model.default_temperature(),
self.model.max_output_tokens(),
self.model.mode(),
+ self.model.supports_caching(),
) {
Ok(request) => request,
Err(err) => return futures::future::ready(Err(err.into())).boxed(),
@@ -581,7 +588,13 @@ impl LanguageModel for BedrockModel {
}
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
- None
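+ // Map the model's cache limits into the generic cache configuration;
+ // speculative caching is not used for Bedrock.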
+ self.model
+ .cache_configuration()
+ .map(|config| LanguageModelCacheConfiguration {
+ max_cache_anchors: config.max_cache_anchors,
+ should_speculate: false,
+ min_total_token: config.min_total_token,
+ })
}
}
@@ -608,6 +621,7 @@ pub fn into_bedrock(
default_temperature: f32,
max_output_tokens: u64,
mode: BedrockModelMode,
+ supports_caching: bool,
) -> Result<bedrock::Request> {
let mut new_messages: Vec<BedrockMessage> = Vec::new();
let mut system_message = String::new();
@@ -619,7 +633,7 @@ pub fn into_bedrock(
match message.role {
Role::User | Role::Assistant => {
- let bedrock_message_content: Vec<BedrockInnerContent> = message
+ let mut bedrock_message_content: Vec<BedrockInnerContent> = message
.content
.into_iter()
.filter_map(|content| match content {
@@ -703,6 +717,14 @@ pub fn into_bedrock(
_ => None,
})
.collect();
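+ // When this message is flagged as cacheable, append a cache checkpoint after its content.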
+ if message.cache && supports_caching {
+ bedrock_message_content.push(BedrockInnerContent::CachePoint(
+ CachePointBlock::builder()
+ .r#type(CachePointType::Default)
+ .build()
+ .context("failed to build cache point block")?,
+ ));
+ }
let bedrock_role = match message.role {
Role::User => bedrock::BedrockRole::User,
Role::Assistant => bedrock::BedrockRole::Assistant,
@@ -731,7 +753,7 @@ pub fn into_bedrock(
}
}
- let tool_spec: Vec<BedrockTool> = request
+ let mut tool_spec: Vec<BedrockTool> = request
.tools
.iter()
.filter_map(|tool| {
@@ -748,6 +770,15 @@ pub fn into_bedrock(
})
.collect();
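+ // Append a cache checkpoint after the tool definitions so they can be reused across requests.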
+ if !tool_spec.is_empty() && supports_caching {
+ tool_spec.push(BedrockTool::CachePoint(
+ CachePointBlock::builder()
+ .r#type(CachePointType::Default)
+ .build()
+ .context("failed to build cache point block")?,
+ ));
+ }
+
let tool_choice = match request.tool_choice {
Some(LanguageModelToolChoice::Auto) | None => {
BedrockToolChoice::Auto(BedrockAutoToolChoice::builder().build())
@@ -990,10 +1021,11 @@ pub fn map_to_language_model_completion_events(
LanguageModelCompletionEvent::UsageUpdate(
TokenUsage {
input_tokens: metadata.input_tokens as u64,
- output_tokens: metadata.output_tokens
- as u64,
- cache_creation_input_tokens: default(),
- cache_read_input_tokens: default(),
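+ // Surface cache write/read token counts from Bedrock's usage metadata (zero when absent).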
+ output_tokens: metadata.output_tokens as u64,
+ cache_creation_input_tokens:
+ metadata.cache_write_input_tokens.unwrap_or_default() as u64,
+ cache_read_input_tokens:
+ metadata.cache_read_input_tokens.unwrap_or_default() as u64,
},
);
return Some((Some(Ok(completion_event)), state));