Implement Anthropic prompt caching (#16274)

Roy Williams created

Release Notes:

- Adds support for Prompt Caching in Anthropic. For models that support
it this can dramatically lower cost while improving performance.

Change summary

crates/anthropic/src/anthropic.rs                 |  67 +++++
crates/assistant/src/context.rs                   | 176 ++++++++++++++++
crates/assistant/src/inline_assistant.rs          |   1 
crates/assistant/src/prompt_library.rs            |   1 
crates/assistant/src/terminal_inline_assistant.rs |   1 
crates/assistant/src/workflow.rs                  |   1 
crates/language_model/src/language_model.rs       |  14 +
crates/language_model/src/provider/anthropic.rs   |  24 ++
crates/language_model/src/provider/cloud.rs       |  12 
crates/language_model/src/request.rs              |  84 ++++---
crates/language_model/src/settings.rs             |  27 +
11 files changed, 338 insertions(+), 70 deletions(-)

Detailed changes

crates/anthropic/src/anthropic.rs 🔗

@@ -14,6 +14,14 @@ pub use supported_countries::*;
 
 pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
 
+#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub struct AnthropicModelCacheConfiguration {
+    pub min_total_token: usize,
+    pub should_speculate: bool,
+    pub max_cache_anchors: usize,
+}
+
 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 pub enum Model {
@@ -32,6 +40,8 @@ pub enum Model {
         max_tokens: usize,
         /// Override this model with a different Anthropic model for tool calls.
         tool_override: Option<String>,
+        /// Indicates whether this custom model supports caching.
+        cache_configuration: Option<AnthropicModelCacheConfiguration>,
     },
 }
 
@@ -70,6 +80,21 @@ impl Model {
         }
     }
 
+    pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
+        match self {
+            Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
+                min_total_token: 2_048,
+                should_speculate: true,
+                max_cache_anchors: 4,
+            }),
+            Self::Custom {
+                cache_configuration,
+                ..
+            } => cache_configuration.clone(),
+            _ => None,
+        }
+    }
+
     pub fn max_token_count(&self) -> usize {
         match self {
             Self::Claude3_5Sonnet
@@ -104,7 +129,10 @@ pub async fn complete(
         .method(Method::POST)
         .uri(uri)
         .header("Anthropic-Version", "2023-06-01")
-        .header("Anthropic-Beta", "tools-2024-04-04")
+        .header(
+            "Anthropic-Beta",
+            "tools-2024-04-04,prompt-caching-2024-07-31",
+        )
         .header("X-Api-Key", api_key)
         .header("Content-Type", "application/json");
 
@@ -161,7 +189,10 @@ pub async fn stream_completion(
         .method(Method::POST)
         .uri(uri)
         .header("Anthropic-Version", "2023-06-01")
-        .header("Anthropic-Beta", "tools-2024-04-04")
+        .header(
+            "Anthropic-Beta",
+            "tools-2024-04-04,prompt-caching-2024-07-31",
+        )
         .header("X-Api-Key", api_key)
         .header("Content-Type", "application/json");
     if let Some(low_speed_timeout) = low_speed_timeout {
@@ -226,7 +257,7 @@ pub fn extract_text_from_events(
         match response {
             Ok(response) => match response {
                 Event::ContentBlockStart { content_block, .. } => match content_block {
-                    Content::Text { text } => Some(Ok(text)),
+                    Content::Text { text, .. } => Some(Ok(text)),
                     _ => None,
                 },
                 Event::ContentBlockDelta { delta, .. } => match delta {
@@ -285,13 +316,25 @@ pub async fn extract_tool_args_from_events(
     }))
 }
 
+#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
+#[serde(rename_all = "lowercase")]
+pub enum CacheControlType {
+    Ephemeral,
+}
+
+#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
+pub struct CacheControl {
+    #[serde(rename = "type")]
+    pub cache_type: CacheControlType,
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 pub struct Message {
     pub role: Role,
     pub content: Vec<Content>,
 }
 
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
 #[serde(rename_all = "lowercase")]
 pub enum Role {
     User,
@@ -302,19 +345,31 @@ pub enum Role {
 #[serde(tag = "type")]
 pub enum Content {
     #[serde(rename = "text")]
-    Text { text: String },
+    Text {
+        text: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
+    },
     #[serde(rename = "image")]
-    Image { source: ImageSource },
+    Image {
+        source: ImageSource,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
+    },
     #[serde(rename = "tool_use")]
     ToolUse {
         id: String,
         name: String,
         input: serde_json::Value,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
     },
     #[serde(rename = "tool_result")]
     ToolResult {
         tool_use_id: String,
         content: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
     },
 }
 

crates/assistant/src/context.rs 🔗

@@ -21,8 +21,8 @@ use gpui::{
 
 use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
 use language_model::{
-    LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
-    Role,
+    LanguageModel, LanguageModelCacheConfiguration, LanguageModelImage, LanguageModelRegistry,
+    LanguageModelRequest, LanguageModelRequestMessage, Role,
 };
 use open_ai::Model as OpenAiModel;
 use paths::{context_images_dir, contexts_dir};
@@ -30,7 +30,7 @@ use project::Project;
 use serde::{Deserialize, Serialize};
 use smallvec::SmallVec;
 use std::{
-    cmp::Ordering,
+    cmp::{max, Ordering},
     collections::hash_map,
     fmt::Debug,
     iter, mem,
@@ -107,6 +107,8 @@ impl ContextOperation {
                             message.status.context("invalid status")?,
                         ),
                         timestamp: id.0,
+                        should_cache: false,
+                        is_cache_anchor: false,
                     },
                     version: language::proto::deserialize_version(&insert.version),
                 })
@@ -121,6 +123,8 @@ impl ContextOperation {
                     timestamp: language::proto::deserialize_timestamp(
                         update.timestamp.context("invalid timestamp")?,
                     ),
+                    should_cache: false,
+                    is_cache_anchor: false,
                 },
                 version: language::proto::deserialize_version(&update.version),
             }),
@@ -313,6 +317,8 @@ pub struct MessageMetadata {
     pub role: Role,
     pub status: MessageStatus,
     timestamp: clock::Lamport,
+    should_cache: bool,
+    is_cache_anchor: bool,
 }
 
 #[derive(Clone, Debug)]
@@ -338,6 +344,7 @@ pub struct Message {
     pub anchor: language::Anchor,
     pub role: Role,
     pub status: MessageStatus,
+    pub cache: bool,
 }
 
 impl Message {
@@ -373,6 +380,7 @@ impl Message {
         LanguageModelRequestMessage {
             role: self.role,
             content,
+            cache: self.cache,
         }
     }
 }
@@ -421,6 +429,7 @@ pub struct Context {
     token_count: Option<usize>,
     pending_token_count: Task<Option<()>>,
     pending_save: Task<Result<()>>,
+    pending_cache_warming_task: Task<Option<()>>,
     path: Option<PathBuf>,
     _subscriptions: Vec<Subscription>,
     telemetry: Option<Arc<Telemetry>>,
@@ -498,6 +507,7 @@ impl Context {
             pending_completions: Default::default(),
             token_count: None,
             pending_token_count: Task::ready(None),
+            pending_cache_warming_task: Task::ready(None),
             _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             pending_save: Task::ready(Ok(())),
             path: None,
@@ -524,6 +534,8 @@ impl Context {
                 role: Role::User,
                 status: MessageStatus::Done,
                 timestamp: first_message_id.0,
+                should_cache: false,
+                is_cache_anchor: false,
             },
         );
         this.message_anchors.push(message);
@@ -948,6 +960,7 @@ impl Context {
                 let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
                 this.update(&mut cx, |this, cx| {
                     this.token_count = Some(token_count);
+                    this.start_cache_warming(&model, cx);
                     cx.notify()
                 })
             }
@@ -955,6 +968,121 @@ impl Context {
         });
     }
 
+    pub fn mark_longest_messages_for_cache(
+        &mut self,
+        cache_configuration: &Option<LanguageModelCacheConfiguration>,
+        speculative: bool,
+        cx: &mut ModelContext<Self>,
+    ) -> bool {
+        let cache_configuration =
+            cache_configuration
+                .as_ref()
+                .unwrap_or(&LanguageModelCacheConfiguration {
+                    max_cache_anchors: 0,
+                    should_speculate: false,
+                    min_total_token: 0,
+                });
+
+        let messages: Vec<Message> = self
+            .messages_from_anchors(
+                self.message_anchors.iter().take(if speculative {
+                    self.message_anchors.len().saturating_sub(1)
+                } else {
+                    self.message_anchors.len()
+                }),
+                cx,
+            )
+            .filter(|message| message.offset_range.len() >= 5_000)
+            .collect();
+
+        let mut sorted_messages = messages.clone();
+        sorted_messages.sort_by(|a, b| b.offset_range.len().cmp(&a.offset_range.len()));
+        if cache_configuration.max_cache_anchors == 0 && cache_configuration.should_speculate {
+            // Some models support caching, but don't support anchors.  In that case we want to
+            // mark the largest message as needing to be cached, but we will not mark it as an
+            // anchor.
+            sorted_messages.truncate(1);
+        } else {
+            // Save 1 anchor for the inline assistant.
+            sorted_messages.truncate(max(cache_configuration.max_cache_anchors, 1) - 1);
+        }
+
+        let longest_message_ids: HashSet<MessageId> = sorted_messages
+            .into_iter()
+            .map(|message| message.id)
+            .collect();
+
+        let cache_deltas: HashSet<MessageId> = self
+            .messages_metadata
+            .iter()
+            .filter_map(|(id, metadata)| {
+                let should_cache = longest_message_ids.contains(id);
+                let should_be_anchor = should_cache && cache_configuration.max_cache_anchors > 0;
+                if metadata.should_cache != should_cache
+                    || metadata.is_cache_anchor != should_be_anchor
+                {
+                    Some(*id)
+                } else {
+                    None
+                }
+            })
+            .collect();
+
+        let mut newly_cached_item = false;
+        for id in cache_deltas {
+            newly_cached_item = newly_cached_item || longest_message_ids.contains(&id);
+            self.update_metadata(id, cx, |metadata| {
+                metadata.should_cache = longest_message_ids.contains(&id);
+                metadata.is_cache_anchor =
+                    metadata.should_cache && (cache_configuration.max_cache_anchors > 0);
+            });
+        }
+        newly_cached_item
+    }
+
+    fn start_cache_warming(&mut self, model: &Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
+        let cache_configuration = model.cache_configuration();
+        if !self.mark_longest_messages_for_cache(&cache_configuration, true, cx) {
+            return;
+        }
+        if let Some(cache_configuration) = cache_configuration {
+            if !cache_configuration.should_speculate {
+                return;
+            }
+        }
+
+        let request = {
+            let mut req = self.to_completion_request(cx);
+            // Skip the last message because it's likely to change and
+            // therefore would be a waste to cache.
+            req.messages.pop();
+            req.messages.push(LanguageModelRequestMessage {
+                role: Role::User,
+                content: vec!["Respond only with OK, nothing else.".into()],
+                cache: false,
+            });
+            req
+        };
+
+        let model = Arc::clone(model);
+        self.pending_cache_warming_task = cx.spawn(|_, cx| {
+            async move {
+                match model.stream_completion(request, &cx).await {
+                    Ok(mut stream) => {
+                        stream.next().await;
+                        log::info!("Cache warming completed successfully");
+                    }
+                    Err(e) => {
+                        log::warn!("Cache warming failed: {}", e);
+                    }
+                };
+
+                anyhow::Ok(())
+            }
+            .log_err()
+        });
+    }
+
     pub fn reparse_slash_commands(&mut self, cx: &mut ModelContext<Self>) {
         let buffer = self.buffer.read(cx);
         let mut row_ranges = self
@@ -1352,20 +1480,26 @@ impl Context {
         self.count_remaining_tokens(cx);
     }
 
-    pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
-        let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
-        let model = LanguageModelRegistry::read_global(cx).active_model()?;
-        let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
+    fn get_last_valid_message_id(&self, cx: &ModelContext<Self>) -> Option<MessageId> {
+        self.message_anchors.iter().rev().find_map(|message| {
             message
                 .start
                 .is_valid(self.buffer.read(cx))
                 .then_some(message.id)
-        })?;
+        })
+    }
+
+    pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
+        let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
+        let model = LanguageModelRegistry::read_global(cx).active_model()?;
+        let last_message_id = self.get_last_valid_message_id(cx)?;
 
         if !provider.is_authenticated(cx) {
             log::info!("completion provider has no credentials");
             return None;
         }
+        // Compute which messages to cache, including the last one.
+        self.mark_longest_messages_for_cache(&model.cache_configuration(), false, cx);
 
         let request = self.to_completion_request(cx);
         let assistant_message = self
@@ -1580,6 +1714,8 @@ impl Context {
                 role,
                 status,
                 timestamp: anchor.id.0,
+                should_cache: false,
+                is_cache_anchor: false,
             };
             self.insert_message(anchor.clone(), metadata.clone(), cx);
             self.push_op(
@@ -1696,6 +1832,8 @@ impl Context {
                 role,
                 status: MessageStatus::Done,
                 timestamp: suffix.id.0,
+                should_cache: false,
+                is_cache_anchor: false,
             };
             self.insert_message(suffix.clone(), suffix_metadata.clone(), cx);
             self.push_op(
@@ -1745,6 +1883,8 @@ impl Context {
                         role,
                         status: MessageStatus::Done,
                         timestamp: selection.id.0,
+                        should_cache: false,
+                        is_cache_anchor: false,
                     };
                     self.insert_message(selection.clone(), selection_metadata.clone(), cx);
                     self.push_op(
@@ -1811,6 +1951,7 @@ impl Context {
                     content: vec![
                         "Summarize the context into a short title without punctuation.".into(),
                     ],
+                    cache: false,
                 }));
             let request = LanguageModelRequest {
                 messages: messages.collect(),
@@ -1910,14 +2051,22 @@ impl Context {
         result
     }
 
-    pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
+    fn messages_from_anchors<'a>(
+        &'a self,
+        message_anchors: impl Iterator<Item = &'a MessageAnchor> + 'a,
+        cx: &'a AppContext,
+    ) -> impl 'a + Iterator<Item = Message> {
         let buffer = self.buffer.read(cx);
-        let messages = self.message_anchors.iter().enumerate();
+        let messages = message_anchors.enumerate();
         let images = self.image_anchors.iter();
 
         Self::messages_from_iters(buffer, &self.messages_metadata, messages, images)
     }
 
+    pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
+        self.messages_from_anchors(self.message_anchors.iter(), cx)
+    }
+
     pub fn messages_from_iters<'a>(
         buffer: &'a Buffer,
         metadata: &'a HashMap<MessageId, MessageMetadata>,
@@ -1969,6 +2118,7 @@ impl Context {
                     anchor: message_anchor.start,
                     role: metadata.role,
                     status: metadata.status.clone(),
+                    cache: metadata.is_cache_anchor,
                     image_offsets,
                 });
             }
@@ -2215,6 +2365,8 @@ impl SavedContext {
                         role: message.metadata.role,
                         status: message.metadata.status,
                         timestamp: message.metadata.timestamp,
+                        should_cache: false,
+                        is_cache_anchor: false,
                     },
                     version: version.clone(),
                 });
@@ -2231,6 +2383,8 @@ impl SavedContext {
                     role: metadata.role,
                     status: metadata.status,
                     timestamp,
+                    should_cache: false,
+                    is_cache_anchor: false,
                 },
                 version: version.clone(),
             });
@@ -2325,6 +2479,8 @@ impl SavedContextV0_3_0 {
                             role: metadata.role,
                             status: metadata.status.clone(),
                             timestamp,
+                            should_cache: false,
+                            is_cache_anchor: false,
                         },
                         image_offsets: Vec::new(),
                     })

crates/assistant/src/inline_assistant.rs 🔗

@@ -2387,6 +2387,7 @@ impl Codegen {
         messages.push(LanguageModelRequestMessage {
             role: Role::User,
             content: vec![prompt.into()],
+            cache: false,
         });
 
         Ok(LanguageModelRequest {

crates/assistant/src/prompt_library.rs 🔗

@@ -784,6 +784,7 @@ impl PromptLibrary {
                                     messages: vec![LanguageModelRequestMessage {
                                         role: Role::System,
                                         content: vec![body.to_string().into()],
+                                        cache: false,
                                     }],
                                     stop: Vec::new(),
                                     temperature: 1.,

crates/assistant/src/workflow.rs 🔗

@@ -136,6 +136,7 @@ impl WorkflowStep {
                 request.messages.push(LanguageModelRequestMessage {
                     role: Role::User,
                     content: vec![prompt.into()],
+                    cache: false,
                 });
 
                 // Invoke the model to get its edit suggestions for this workflow step.

crates/language_model/src/language_model.rs 🔗

@@ -20,7 +20,7 @@ pub use registry::*;
 pub use request::*;
 pub use role::*;
 use schemars::JsonSchema;
-use serde::de::DeserializeOwned;
+use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use std::{future::Future, sync::Arc};
 use ui::IconName;
 
@@ -43,6 +43,14 @@ pub enum LanguageModelAvailability {
     RequiresPlan(Plan),
 }
 
+/// Configuration for caching language model messages.
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct LanguageModelCacheConfiguration {
+    pub max_cache_anchors: usize,
+    pub should_speculate: bool,
+    pub min_total_token: usize,
+}
+
 pub trait LanguageModel: Send + Sync {
     fn id(&self) -> LanguageModelId;
     fn name(&self) -> LanguageModelName;
@@ -78,6 +86,10 @@ pub trait LanguageModel: Send + Sync {
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 
+    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
+        None
+    }
+
     #[cfg(any(test, feature = "test-support"))]
     fn as_fake(&self) -> &provider::fake::FakeLanguageModel {
         unimplemented!()

crates/language_model/src/provider/anthropic.rs 🔗

@@ -1,7 +1,7 @@
 use crate::{
-    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
-    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+    settings::AllLanguageModelSettings, LanguageModel, LanguageModelCacheConfiguration,
+    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
+    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 };
 use anthropic::AnthropicError;
 use anyhow::{anyhow, Context as _, Result};
@@ -38,6 +38,7 @@ pub struct AvailableModel {
     pub name: String,
     pub max_tokens: usize,
     pub tool_override: Option<String>,
+    pub cache_configuration: Option<LanguageModelCacheConfiguration>,
 }
 
 pub struct AnthropicLanguageModelProvider {
@@ -171,6 +172,13 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
                     name: model.name.clone(),
                     max_tokens: model.max_tokens,
                     tool_override: model.tool_override.clone(),
+                    cache_configuration: model.cache_configuration.as_ref().map(|config| {
+                        anthropic::AnthropicModelCacheConfiguration {
+                            max_cache_anchors: config.max_cache_anchors,
+                            should_speculate: config.should_speculate,
+                            min_total_token: config.min_total_token,
+                        }
+                    }),
                 },
             );
         }
@@ -351,6 +359,16 @@ impl LanguageModel for AnthropicModel {
         .boxed()
     }
 
+    fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
+        self.model
+            .cache_configuration()
+            .map(|config| LanguageModelCacheConfiguration {
+                max_cache_anchors: config.max_cache_anchors,
+                should_speculate: config.should_speculate,
+                min_total_token: config.min_total_token,
+            })
+    }
+
     fn use_any_tool(
         &self,
         request: LanguageModelRequest,

crates/language_model/src/provider/cloud.rs 🔗

@@ -1,7 +1,7 @@
 use super::open_ai::count_open_ai_tokens;
 use crate::{
-    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
-    LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+    settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelCacheConfiguration,
+    LanguageModelId, LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
     LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
 };
 use anthropic::AnthropicError;
@@ -56,6 +56,7 @@ pub struct AvailableModel {
     name: String,
     max_tokens: usize,
     tool_override: Option<String>,
+    cache_configuration: Option<LanguageModelCacheConfiguration>,
 }
 
 pub struct CloudLanguageModelProvider {
@@ -202,6 +203,13 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                             name: model.name.clone(),
                             max_tokens: model.max_tokens,
                             tool_override: model.tool_override.clone(),
+                            cache_configuration: model.cache_configuration.as_ref().map(|config| {
+                                anthropic::AnthropicModelCacheConfiguration {
+                                    max_cache_anchors: config.max_cache_anchors,
+                                    should_speculate: config.should_speculate,
+                                    min_total_token: config.min_total_token,
+                                }
+                            }),
                         })
                     }
                     AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {

crates/language_model/src/request.rs 🔗

@@ -193,6 +193,7 @@ impl From<&str> for MessageContent {
 pub struct LanguageModelRequestMessage {
     pub role: Role,
     pub content: Vec<MessageContent>,
+    pub cache: bool,
 }
 
 impl LanguageModelRequestMessage {
@@ -213,7 +214,7 @@ impl LanguageModelRequestMessage {
                 .content
                 .get(0)
                 .map(|content| match content {
-                    MessageContent::Text(s) => s.is_empty(),
+                    MessageContent::Text(s) => s.trim().is_empty(),
                     MessageContent::Image(_) => true,
                 })
                 .unwrap_or(false)
@@ -286,7 +287,7 @@ impl LanguageModelRequest {
     }
 
     pub fn into_anthropic(self, model: String) -> anthropic::Request {
-        let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+        let mut new_messages: Vec<anthropic::Message> = Vec::new();
         let mut system_message = String::new();
 
         for message in self.messages {
@@ -296,18 +297,50 @@ impl LanguageModelRequest {
 
             match message.role {
                 Role::User | Role::Assistant => {
+                    let cache_control = if message.cache {
+                        Some(anthropic::CacheControl {
+                            cache_type: anthropic::CacheControlType::Ephemeral,
+                        })
+                    } else {
+                        None
+                    };
+                    let anthropic_message_content: Vec<anthropic::Content> = message
+                        .content
+                        .into_iter()
+                        // TODO: filter out the empty messages in the message construction step
+                        .filter_map(|content| match content {
+                            MessageContent::Text(t) if !t.is_empty() => {
+                                Some(anthropic::Content::Text {
+                                    text: t,
+                                    cache_control,
+                                })
+                            }
+                            MessageContent::Image(i) => Some(anthropic::Content::Image {
+                                source: anthropic::ImageSource {
+                                    source_type: "base64".to_string(),
+                                    media_type: "image/png".to_string(),
+                                    data: i.source.to_string(),
+                                },
+                                cache_control,
+                            }),
+                            _ => None,
+                        })
+                        .collect();
+                    let anthropic_role = match message.role {
+                        Role::User => anthropic::Role::User,
+                        Role::Assistant => anthropic::Role::Assistant,
+                        Role::System => unreachable!("System role should never occur here"),
+                    };
                     if let Some(last_message) = new_messages.last_mut() {
-                        if last_message.role == message.role {
-                            // TODO: is this append done properly?
-                            last_message.content.push(MessageContent::Text(format!(
-                                "\n\n{}",
-                                message.string_contents()
-                            )));
+                        if last_message.role == anthropic_role {
+                            last_message.content.extend(anthropic_message_content);
                             continue;
                         }
                     }
-
-                    new_messages.push(message);
+                    new_messages.push(anthropic::Message {
+                        role: anthropic_role,
+                        content: anthropic_message_content,
+                    });
                 }
                 Role::System => {
                     if !system_message.is_empty() {
@@ -320,36 +353,7 @@ impl LanguageModelRequest {
 
         anthropic::Request {
             model,
-            messages: new_messages
-                .into_iter()
-                .filter_map(|message| {
-                    Some(anthropic::Message {
-                        role: match message.role {
-                            Role::User => anthropic::Role::User,
-                            Role::Assistant => anthropic::Role::Assistant,
-                            Role::System => return None,
-                        },
-                        content: message
-                            .content
-                            .into_iter()
-                            // TODO: filter out the empty messages in the message construction step
-                            .filter_map(|content| match content {
-                                MessageContent::Text(t) if !t.is_empty() => {
-                                    Some(anthropic::Content::Text { text: t })
-                                }
-                                MessageContent::Image(i) => Some(anthropic::Content::Image {
-                                    source: anthropic::ImageSource {
-                                        source_type: "base64".to_string(),
-                                        media_type: "image/png".to_string(),
-                                        data: i.source.to_string(),
-                                    },
-                                }),
-                                _ => None,
-                            })
-                            .collect(),
-                    })
-                })
-                .collect(),
+            messages: new_messages,
             max_tokens: 4092,
             system: Some(system_message),
             tools: Vec::new(),

crates/language_model/src/settings.rs 🔗

@@ -7,14 +7,17 @@ use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{update_settings_file, Settings, SettingsSources};
 
-use crate::provider::{
-    self,
-    anthropic::AnthropicSettings,
-    cloud::{self, ZedDotDevSettings},
-    copilot_chat::CopilotChatSettings,
-    google::GoogleSettings,
-    ollama::OllamaSettings,
-    open_ai::OpenAiSettings,
+use crate::{
+    provider::{
+        self,
+        anthropic::AnthropicSettings,
+        cloud::{self, ZedDotDevSettings},
+        copilot_chat::CopilotChatSettings,
+        google::GoogleSettings,
+        ollama::OllamaSettings,
+        open_ai::OpenAiSettings,
+    },
+    LanguageModelCacheConfiguration,
 };
 
 /// Initializes the language model settings.
@@ -93,10 +96,18 @@ impl AnthropicSettingsContent {
                                     name,
                                     max_tokens,
                                     tool_override,
+                                    cache_configuration,
                                 } => Some(provider::anthropic::AvailableModel {
                                     name,
                                     max_tokens,
                                     tool_override,
+                                    cache_configuration: cache_configuration.as_ref().map(
+                                        |config| LanguageModelCacheConfiguration {
+                                            max_cache_anchors: config.max_cache_anchors,
+                                            should_speculate: config.should_speculate,
+                                            min_total_token: config.min_total_token,
+                                        },
+                                    ),
                                 }),
                                 _ => None,
                             })