WIP: fix Anthropic regressions

Created by Richard Feldman

Change summary

crates/extension_host/src/wasm_host/llm_provider.rs |   2 
extensions/anthropic/src/anthropic.rs               | 432 +++++++++++++-
2 files changed, 404 insertions(+), 30 deletions(-)

Detailed changes

crates/extension_host/src/wasm_host/llm_provider.rs

@@ -1880,7 +1880,7 @@ fn convert_request_to_wit(request: LanguageModelRequest) -> LlmCompletionRequest
         tool_choice,
         stop_sequences: request.stop,
         temperature: request.temperature,
-        thinking_allowed: false,
+        thinking_allowed: request.thinking_allowed,
         max_tokens: None,
     }
 }

extensions/anthropic/src/anthropic.rs

@@ -34,9 +34,12 @@ struct ModelDefinition {
     supports_thinking: bool,
     is_default: bool,
     is_default_fast: bool,
+    beta_headers: Option<&'static str>,
+    supports_caching: bool,
 }
 
 const MODELS: &[ModelDefinition] = &[
+    // Claude Opus 4.5
     ModelDefinition {
         real_id: "claude-opus-4-5-20251101",
         display_name: "Claude Opus 4.5",
@@ -46,6 +49,8 @@ const MODELS: &[ModelDefinition] = &[
         supports_thinking: false,
         is_default: false,
         is_default_fast: false,
+        beta_headers: Some("fine-grained-tool-streaming-2025-05-14"),
+        supports_caching: true,
     },
     ModelDefinition {
         real_id: "claude-opus-4-5-20251101",
@@ -56,7 +61,60 @@ const MODELS: &[ModelDefinition] = &[
         supports_thinking: true,
         is_default: false,
         is_default_fast: false,
+        beta_headers: Some("fine-grained-tool-streaming-2025-05-14"),
+        supports_caching: true,
     },
+    // Claude Opus 4.1
+    ModelDefinition {
+        real_id: "claude-opus-4-1-20250805",
+        display_name: "Claude Opus 4.1",
+        max_tokens: 200_000,
+        max_output_tokens: 8_192,
+        supports_images: true,
+        supports_thinking: false,
+        is_default: false,
+        is_default_fast: false,
+        beta_headers: Some("fine-grained-tool-streaming-2025-05-14"),
+        supports_caching: true,
+    },
+    ModelDefinition {
+        real_id: "claude-opus-4-1-20250805",
+        display_name: "Claude Opus 4.1 Thinking",
+        max_tokens: 200_000,
+        max_output_tokens: 8_192,
+        supports_images: true,
+        supports_thinking: true,
+        is_default: false,
+        is_default_fast: false,
+        beta_headers: Some("fine-grained-tool-streaming-2025-05-14"),
+        supports_caching: true,
+    },
+    // Claude Opus 4
+    ModelDefinition {
+        real_id: "claude-opus-4-20250514",
+        display_name: "Claude Opus 4",
+        max_tokens: 200_000,
+        max_output_tokens: 8_192,
+        supports_images: true,
+        supports_thinking: false,
+        is_default: false,
+        is_default_fast: false,
+        beta_headers: Some("fine-grained-tool-streaming-2025-05-14"),
+        supports_caching: true,
+    },
+    ModelDefinition {
+        real_id: "claude-opus-4-20250514",
+        display_name: "Claude Opus 4 Thinking",
+        max_tokens: 200_000,
+        max_output_tokens: 8_192,
+        supports_images: true,
+        supports_thinking: true,
+        is_default: false,
+        is_default_fast: false,
+        beta_headers: Some("fine-grained-tool-streaming-2025-05-14"),
+        supports_caching: true,
+    },
+    // Claude Sonnet 4.5
     ModelDefinition {
         real_id: "claude-sonnet-4-5-20250929",
         display_name: "Claude Sonnet 4.5",
@@ -66,6 +124,8 @@ const MODELS: &[ModelDefinition] = &[
         supports_thinking: false,
         is_default: true,
         is_default_fast: false,
+        beta_headers: Some("fine-grained-tool-streaming-2025-05-14"),
+        supports_caching: true,
     },
     ModelDefinition {
         real_id: "claude-sonnet-4-5-20250929",
@@ -76,7 +136,10 @@ const MODELS: &[ModelDefinition] = &[
         supports_thinking: true,
         is_default: false,
         is_default_fast: false,
+        beta_headers: Some("fine-grained-tool-streaming-2025-05-14"),
+        supports_caching: true,
     },
+    // Claude Sonnet 4
     ModelDefinition {
         real_id: "claude-sonnet-4-20250514",
         display_name: "Claude Sonnet 4",
@@ -86,6 +149,8 @@ const MODELS: &[ModelDefinition] = &[
         supports_thinking: false,
         is_default: false,
         is_default_fast: false,
+        beta_headers: Some("fine-grained-tool-streaming-2025-05-14"),
+        supports_caching: true,
     },
     ModelDefinition {
         real_id: "claude-sonnet-4-20250514",
@@ -96,7 +161,52 @@ const MODELS: &[ModelDefinition] = &[
         supports_thinking: true,
         is_default: false,
         is_default_fast: false,
+        beta_headers: Some("fine-grained-tool-streaming-2025-05-14"),
+        supports_caching: true,
+    },
+    // Claude 3.7 Sonnet
+    ModelDefinition {
+        real_id: "claude-3-7-sonnet-latest",
+        display_name: "Claude 3.7 Sonnet",
+        max_tokens: 200_000,
+        max_output_tokens: 8_192,
+        supports_images: true,
+        supports_thinking: false,
+        is_default: false,
+        is_default_fast: false,
+        beta_headers: Some(
+            "token-efficient-tools-2025-02-19,fine-grained-tool-streaming-2025-05-14",
+        ),
+        supports_caching: true,
+    },
+    ModelDefinition {
+        real_id: "claude-3-7-sonnet-latest",
+        display_name: "Claude 3.7 Sonnet Thinking",
+        max_tokens: 200_000,
+        max_output_tokens: 8_192,
+        supports_images: true,
+        supports_thinking: true,
+        is_default: false,
+        is_default_fast: false,
+        beta_headers: Some(
+            "token-efficient-tools-2025-02-19,fine-grained-tool-streaming-2025-05-14",
+        ),
+        supports_caching: true,
+    },
+    // Claude 3.5 Sonnet
+    ModelDefinition {
+        real_id: "claude-3-5-sonnet-latest",
+        display_name: "Claude 3.5 Sonnet",
+        max_tokens: 200_000,
+        max_output_tokens: 8_192,
+        supports_images: true,
+        supports_thinking: false,
+        is_default: false,
+        is_default_fast: false,
+        beta_headers: None,
+        supports_caching: true,
     },
+    // Claude Haiku 4.5
     ModelDefinition {
         real_id: "claude-haiku-4-5-20251001",
         display_name: "Claude Haiku 4.5",
@@ -106,6 +216,8 @@ const MODELS: &[ModelDefinition] = &[
         supports_thinking: false,
         is_default: false,
         is_default_fast: true,
+        beta_headers: None,
+        supports_caching: true,
     },
     ModelDefinition {
         real_id: "claude-haiku-4-5-20251001",
@@ -116,26 +228,60 @@ const MODELS: &[ModelDefinition] = &[
         supports_thinking: true,
         is_default: false,
         is_default_fast: false,
+        beta_headers: None,
+        supports_caching: true,
     },
+    // Claude 3.5 Haiku
     ModelDefinition {
-        real_id: "claude-3-5-sonnet-latest",
-        display_name: "Claude 3.5 Sonnet",
+        real_id: "claude-3-5-haiku-latest",
+        display_name: "Claude 3.5 Haiku",
         max_tokens: 200_000,
         max_output_tokens: 8_192,
         supports_images: true,
         supports_thinking: false,
         is_default: false,
         is_default_fast: false,
+        beta_headers: None,
+        supports_caching: true,
     },
+    // Claude 3 Opus
     ModelDefinition {
-        real_id: "claude-3-5-haiku-latest",
-        display_name: "Claude 3.5 Haiku",
+        real_id: "claude-3-opus-latest",
+        display_name: "Claude 3 Opus",
         max_tokens: 200_000,
-        max_output_tokens: 8_192,
+        max_output_tokens: 4_096,
         supports_images: true,
         supports_thinking: false,
         is_default: false,
         is_default_fast: false,
+        beta_headers: None,
+        supports_caching: false,
+    },
+    // Claude 3 Sonnet
+    ModelDefinition {
+        real_id: "claude-3-sonnet-20240229",
+        display_name: "Claude 3 Sonnet",
+        max_tokens: 200_000,
+        max_output_tokens: 4_096,
+        supports_images: true,
+        supports_thinking: false,
+        is_default: false,
+        is_default_fast: false,
+        beta_headers: None,
+        supports_caching: false,
+    },
+    // Claude 3 Haiku
+    ModelDefinition {
+        real_id: "claude-3-haiku-20240307",
+        display_name: "Claude 3 Haiku",
+        max_tokens: 200_000,
+        max_output_tokens: 4_096,
+        supports_images: true,
+        supports_thinking: false,
+        is_default: false,
+        is_default_fast: false,
+        beta_headers: None,
+        supports_caching: true,
     },
 ];
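
For context, a quick sketch of what the doubled-up entries buy (illustrative only; it assumes it runs inside this module, where `MODELS` and the `get_model_definition` display-name lookup shown below live): each dated Anthropic model ID appears twice, and the "… Thinking" variant shares `real_id`, `beta_headers`, and `supports_caching`, differing only in `supports_thinking`.

```rust
let plain = get_model_definition("Claude Opus 4.1").expect("known model");
let thinking = get_model_definition("Claude Opus 4.1 Thinking").expect("known model");

// Both display names resolve to the same dated Anthropic model ID…
assert_eq!(plain.real_id, "claude-opus-4-1-20250805");
assert_eq!(plain.real_id, thinking.real_id);
// …and only the thinking flag differs between the two entries.
assert!(!plain.supports_thinking);
assert!(thinking.supports_thinking);
```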
 
@@ -145,13 +291,23 @@ fn get_model_definition(display_name: &str) -> Option<&'static ModelDefinition>
 
 // Anthropic API Request Types
 
+#[derive(Serialize, Clone)]
+struct CacheControl {
+    #[serde(rename = "type")]
+    cache_type: &'static str,
+}
+
+const EPHEMERAL_CACHE: CacheControl = CacheControl {
+    cache_type: "ephemeral",
+};
+
 #[derive(Serialize)]
 struct AnthropicRequest {
     model: String,
     max_tokens: u64,
     messages: Vec<AnthropicMessage>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    system: Option<String>,
+    system: Option<AnthropicSystemContent>,
     #[serde(skip_serializing_if = "Option::is_none")]
     thinking: Option<AnthropicThinking>,
     #[serde(skip_serializing_if = "Vec::is_empty")]
@@ -165,6 +321,22 @@ struct AnthropicRequest {
     stream: bool,
 }
 
+#[derive(Serialize)]
+#[serde(untagged)]
+enum AnthropicSystemContent {
+    String(String),
+    Blocks(Vec<AnthropicSystemBlock>),
+}
+
+#[derive(Serialize)]
+struct AnthropicSystemBlock {
+    #[serde(rename = "type")]
+    block_type: &'static str,
+    text: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    cache_control: Option<CacheControl>,
+}
+
 #[derive(Serialize)]
 struct AnthropicThinking {
     #[serde(rename = "type")]
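
As a rough sketch of the wire format these serde attributes produce, the toy types below mirror `CacheControl`, `AnthropicSystemContent`, and `AnthropicSystemBlock` from the diff (assuming `serde` and `serde_json` are available, as they already are to this extension): the untagged enum keeps `system` a bare string in the uncached case, and `skip_serializing_if` drops `cache_control` entirely when it is `None`.

```rust
use serde::Serialize;

#[derive(Serialize)]
struct Cache {
    #[serde(rename = "type")]
    cache_type: &'static str,
}

#[derive(Serialize)]
#[serde(untagged)]
enum System {
    String(String),
    Blocks(Vec<Block>),
}

#[derive(Serialize)]
struct Block {
    #[serde(rename = "type")]
    block_type: &'static str,
    text: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    cache_control: Option<Cache>,
}

fn main() {
    let plain = System::String("You are a coding assistant.".into());
    let cached = System::Blocks(vec![Block {
        block_type: "text",
        text: "You are a coding assistant.".into(),
        cache_control: Some(Cache { cache_type: "ephemeral" }),
    }]);

    // Prints: "You are a coding assistant."
    println!("{}", serde_json::to_string(&plain).unwrap());
    // Prints: [{"type":"text","text":"You are a coding assistant.","cache_control":{"type":"ephemeral"}}]
    println!("{}", serde_json::to_string(&cached).unwrap());
}
```

The same `cache_control: Option<CacheControl>` pattern is reused on the message content blocks and tool definitions below, so uncached requests serialize exactly as they did before this change.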
@@ -183,27 +355,58 @@ struct AnthropicMessage {
 #[serde(tag = "type")]
 enum AnthropicContent {
     #[serde(rename = "text")]
-    Text { text: String },
+    Text {
+        text: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
+    },
     #[serde(rename = "thinking")]
-    Thinking { thinking: String, signature: String },
+    Thinking {
+        thinking: String,
+        signature: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
+    },
     #[serde(rename = "redacted_thinking")]
     RedactedThinking { data: String },
     #[serde(rename = "image")]
-    Image { source: AnthropicImageSource },
+    Image {
+        source: AnthropicImageSource,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
+    },
     #[serde(rename = "tool_use")]
     ToolUse {
         id: String,
         name: String,
         input: serde_json::Value,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
     },
     #[serde(rename = "tool_result")]
     ToolResult {
         tool_use_id: String,
         is_error: bool,
-        content: String,
+        content: AnthropicToolResultContent,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        cache_control: Option<CacheControl>,
     },
 }
 
+#[derive(Serialize, Clone)]
+#[serde(untagged)]
+enum AnthropicToolResultContent {
+    Plain(String),
+    Multipart(Vec<AnthropicToolResultPart>),
+}
+
+#[derive(Serialize, Clone)]
+#[serde(tag = "type", rename_all = "lowercase")]
+enum AnthropicToolResultPart {
+    Text { text: String },
+    Image { source: AnthropicImageSource },
+}
+
 #[derive(Serialize, Clone)]
 struct AnthropicImageSource {
     #[serde(rename = "type")]
@@ -217,6 +420,8 @@ struct AnthropicTool {
     name: String,
     description: String,
     input_schema: serde_json::Value,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    cache_control: Option<CacheControl>,
 }
 
 #[derive(Serialize)]
@@ -318,15 +523,83 @@ struct AnthropicApiError {
     message: String,
 }
 
+fn detect_media_type(data: &str) -> String {
+    if let Some(decoded) = base64_decode_first_bytes(data, 12) {
+        if decoded.starts_with(&[0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]) {
+            return "image/png".to_string();
+        }
+        if decoded.starts_with(&[0xFF, 0xD8, 0xFF]) {
+            return "image/jpeg".to_string();
+        }
+        if decoded.starts_with(&[0x47, 0x49, 0x46, 0x38]) {
+            return "image/gif".to_string();
+        }
+        if decoded.len() >= 12 && &decoded[0..4] == b"RIFF" && &decoded[8..12] == b"WEBP" {
+            return "image/webp".to_string();
+        }
+    }
+    "image/png".to_string()
+}
+
+fn base64_decode_first_bytes(data: &str, num_bytes: usize) -> Option<Vec<u8>> {
+    let chars_needed = ((num_bytes + 2) / 3) * 4;
+    let prefix: String = data.chars().take(chars_needed).collect();
+
+    const BASE64_TABLE: &[u8; 64] =
+        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
+
+    fn decode_char(c: char) -> Option<u8> {
+        if c == '=' {
+            return Some(0);
+        }
+        BASE64_TABLE
+            .iter()
+            .position(|&b| b == c as u8)
+            .map(|p| p as u8)
+    }
+
+    let chars: Vec<char> = prefix.chars().filter(|c| !c.is_whitespace()).collect();
+    if chars.len() < 4 {
+        return None;
+    }
+
+    let mut result = Vec::new();
+    for chunk in chars.chunks(4) {
+        if chunk.len() < 4 {
+            break;
+        }
+        let a = decode_char(chunk[0])?;
+        let b = decode_char(chunk[1])?;
+        let c = decode_char(chunk[2])?;
+        let d = decode_char(chunk[3])?;
+
+        result.push((a << 2) | (b >> 4));
+        if chunk[2] != '=' {
+            result.push((b << 4) | (c >> 2));
+        }
+        if chunk[3] != '=' {
+            result.push((c << 6) | d);
+        }
+
+        if result.len() >= num_bytes {
+            break;
+        }
+    }
+
+    Some(result)
+}
+
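
A quick sanity check for the sniffing above (assuming it sits in the same module as `detect_media_type`; the base64 prefixes are the standard encodings of the PNG, JPEG, and GIF magic bytes): only a short prefix of the payload needs decoding to pick a media type, and anything unrecognized or too short falls back to `image/png`.

```rust
#[cfg(test)]
mod media_type_tests {
    use super::*;

    #[test]
    fn sniffs_media_type_from_base64_prefix() {
        // "iVBORw0KGgo…" is base64 for the PNG signature 0x89 "PNG" \r \n 0x1A \n.
        assert_eq!(detect_media_type("iVBORw0KGgoAAAANSUhEUg"), "image/png");
        // "/9j/…" is base64 for the JPEG SOI marker FF D8 FF.
        assert_eq!(detect_media_type("/9j/4AAQSkZJRgABAQEA"), "image/jpeg");
        // "R0lGODlh…" is base64 for "GIF89a".
        assert_eq!(detect_media_type("R0lGODlhAQABAIAAAP"), "image/gif");
        // Unknown or too-short data falls back to PNG.
        assert_eq!(detect_media_type(""), "image/png");
    }
}
```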
 fn convert_request(
     model_id: &str,
     request: &LlmCompletionRequest,
+    enable_caching: bool,
 ) -> Result<AnthropicRequest, String> {
     let model_def =
         get_model_definition(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;
 
     let mut messages: Vec<AnthropicMessage> = Vec::new();
     let mut system_message = String::new();
+    let mut system_should_cache = false;
 
     for msg in &request.messages {
         match msg.role {
@@ -339,6 +612,9 @@ fn convert_request(
                         system_message.push_str(text);
                     }
                 }
+                if msg.cache {
+                    system_should_cache = true;
+                }
             }
             LlmMessageRole::User => {
                 let mut contents: Vec<AnthropicContent> = Vec::new();
@@ -347,27 +623,46 @@ fn convert_request(
                     match content {
                         LlmMessageContent::Text(text) => {
                             if !text.is_empty() {
-                                contents.push(AnthropicContent::Text { text: text.clone() });
+                                contents.push(AnthropicContent::Text {
+                                    text: text.clone(),
+                                    cache_control: None,
+                                });
                             }
                         }
                         LlmMessageContent::Image(img) => {
+                            let media_type = detect_media_type(&img.source);
                             contents.push(AnthropicContent::Image {
                                 source: AnthropicImageSource {
                                     source_type: "base64".to_string(),
-                                    media_type: "image/png".to_string(),
+                                    media_type,
                                     data: img.source.clone(),
                                 },
+                                cache_control: None,
                             });
                         }
                         LlmMessageContent::ToolResult(result) => {
-                            let content_text = match &result.content {
-                                LlmToolResultContent::Text(t) => t.clone(),
-                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
+                            let content = match &result.content {
+                                LlmToolResultContent::Text(t) => {
+                                    AnthropicToolResultContent::Plain(t.clone())
+                                }
+                                LlmToolResultContent::Image(img) => {
+                                    let media_type = detect_media_type(&img.source);
+                                    AnthropicToolResultContent::Multipart(vec![
+                                        AnthropicToolResultPart::Image {
+                                            source: AnthropicImageSource {
+                                                source_type: "base64".to_string(),
+                                                media_type,
+                                                data: img.source.clone(),
+                                            },
+                                        },
+                                    ])
+                                }
                             };
                             contents.push(AnthropicContent::ToolResult {
                                 tool_use_id: result.tool_use_id.clone(),
                                 is_error: result.is_error,
-                                content: content_text,
+                                content,
+                                cache_control: None,
                             });
                         }
                         _ => {}
@@ -375,6 +670,11 @@ fn convert_request(
                 }
 
                 if !contents.is_empty() {
+                    if enable_caching && msg.cache {
+                        if let Some(last) = contents.last_mut() {
+                            set_cache_control(last);
+                        }
+                    }
                     messages.push(AnthropicMessage {
                         role: "user".to_string(),
                         content: contents,
@@ -388,7 +688,10 @@ fn convert_request(
                     match content {
                         LlmMessageContent::Text(text) => {
                             if !text.is_empty() {
-                                contents.push(AnthropicContent::Text { text: text.clone() });
+                                contents.push(AnthropicContent::Text {
+                                    text: text.clone(),
+                                    cache_control: None,
+                                });
                             }
                         }
                         LlmMessageContent::ToolUse(tool_use) => {
@@ -398,6 +701,7 @@ fn convert_request(
                                 id: tool_use.id.clone(),
                                 name: tool_use.name.clone(),
                                 input,
+                                cache_control: None,
                             });
                         }
                         LlmMessageContent::Thinking(thinking) => {
@@ -405,6 +709,7 @@ fn convert_request(
                                 contents.push(AnthropicContent::Thinking {
                                     thinking: thinking.text.clone(),
                                     signature: thinking.signature.clone().unwrap_or_default(),
+                                    cache_control: None,
                                 });
                             }
                         }
@@ -420,6 +725,11 @@ fn convert_request(
                 }
 
                 if !contents.is_empty() {
+                    if enable_caching && msg.cache {
+                        if let Some(last) = contents.last_mut() {
+                            set_cache_control(last);
+                        }
+                    }
                     messages.push(AnthropicMessage {
                         role: "assistant".to_string(),
                         content: contents,
@@ -429,7 +739,7 @@ fn convert_request(
         }
     }
 
-    let tools: Vec<AnthropicTool> = request
+    let mut tools: Vec<AnthropicTool> = request
         .tools
         .iter()
         .map(|t| AnthropicTool {
@@ -437,9 +747,16 @@ fn convert_request(
             description: t.description.clone(),
             input_schema: serde_json::from_str(&t.input_schema)
                 .unwrap_or(serde_json::Value::Object(Default::default())),
+            cache_control: None,
         })
         .collect();
 
+    if enable_caching && !tools.is_empty() {
+        if let Some(last_tool) = tools.last_mut() {
+            last_tool.cache_control = Some(EPHEMERAL_CACHE);
+        }
+    }
+
     let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
         LlmToolChoice::Auto => AnthropicToolChoice::Auto,
         LlmToolChoice::Any => AnthropicToolChoice::Any,
@@ -455,15 +772,23 @@ fn convert_request(
         None
     };
 
+    let system = if system_message.is_empty() {
+        None
+    } else if enable_caching && system_should_cache {
+        Some(AnthropicSystemContent::Blocks(vec![AnthropicSystemBlock {
+            block_type: "text",
+            text: system_message,
+            cache_control: Some(EPHEMERAL_CACHE),
+        }]))
+    } else {
+        Some(AnthropicSystemContent::String(system_message))
+    };
+
     Ok(AnthropicRequest {
         model: model_def.real_id.to_string(),
         max_tokens: model_def.max_output_tokens,
         messages,
-        system: if system_message.is_empty() {
-            None
-        } else {
-            Some(system_message)
-        },
+        system,
         thinking,
         tools,
         tool_choice,
@@ -473,6 +798,27 @@ fn convert_request(
     })
 }
 
+fn set_cache_control(content: &mut AnthropicContent) {
+    match content {
+        AnthropicContent::Text { cache_control, .. } => {
+            *cache_control = Some(EPHEMERAL_CACHE);
+        }
+        AnthropicContent::Thinking { cache_control, .. } => {
+            *cache_control = Some(EPHEMERAL_CACHE);
+        }
+        AnthropicContent::Image { cache_control, .. } => {
+            *cache_control = Some(EPHEMERAL_CACHE);
+        }
+        AnthropicContent::ToolUse { cache_control, .. } => {
+            *cache_control = Some(EPHEMERAL_CACHE);
+        }
+        AnthropicContent::ToolResult { cache_control, .. } => {
+            *cache_control = Some(EPHEMERAL_CACHE);
+        }
+        AnthropicContent::RedactedThinking { .. } => {}
+    }
+}
+
 fn parse_sse_line(line: &str) -> Option<AnthropicEvent> {
     let data = line.strip_prefix("data: ")?;
     serde_json::from_str(data).ok()
@@ -531,6 +877,23 @@ impl zed::Extension for AnthropicProvider {
         llm_delete_credential("anthropic")
     }
 
+    fn llm_cache_configuration(
+        &self,
+        _provider_id: &str,
+        model_id: &str,
+    ) -> Option<LlmCacheConfiguration> {
+        let model_def = get_model_definition(model_id)?;
+        if model_def.supports_caching {
+            Some(LlmCacheConfiguration {
+                max_cache_anchors: 4,
+                should_cache_tool_definitions: true,
+                min_total_token_count: 2048,
+            })
+        } else {
+            None
+        }
+    }
+
     fn llm_stream_completion_start(
         &mut self,
         _provider_id: &str,
@@ -541,19 +904,29 @@ impl zed::Extension for AnthropicProvider {
             "No API key configured. Please add your Anthropic API key in settings.".to_string()
         })?;
 
-        let anthropic_request = convert_request(model_id, request)?;
+        let model_def =
+            get_model_definition(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;
+
+        let enable_caching = model_def.supports_caching;
+        let anthropic_request = convert_request(model_id, request, enable_caching)?;
 
         let body = serde_json::to_vec(&anthropic_request)
             .map_err(|e| format!("Failed to serialize request: {}", e))?;
 
+        let mut headers = vec![
+            ("Content-Type".to_string(), "application/json".to_string()),
+            ("x-api-key".to_string(), api_key),
+            ("anthropic-version".to_string(), "2023-06-01".to_string()),
+        ];
+
+        if let Some(beta) = model_def.beta_headers {
+            headers.push(("anthropic-beta".to_string(), beta.to_string()));
+        }
+
         let http_request = HttpRequest {
             method: HttpMethod::Post,
             url: "https://api.anthropic.com/v1/messages".to_string(),
-            headers: vec![
-                ("Content-Type".to_string(), "application/json".to_string()),
-                ("x-api-key".to_string(), api_key),
-                ("anthropic-version".to_string(), "2023-06-01".to_string()),
-            ],
+            headers,
             body: Some(body),
             redirect_policy: RedirectPolicy::FollowAll,
         };
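
To make the header change concrete, for a model whose entry carries both beta flags (Claude 3.7 Sonnet above), the outgoing header set would look roughly like this sketch (the API key is a placeholder):

```rust
fn main() {
    let headers: Vec<(String, String)> = vec![
        ("Content-Type".into(), "application/json".into()),
        ("x-api-key".into(), "<api key>".into()),
        ("anthropic-version".into(), "2023-06-01".into()),
        (
            // Comma-joined beta flags travel in a single anthropic-beta header.
            "anthropic-beta".into(),
            "token-efficient-tools-2025-02-19,fine-grained-tool-streaming-2025-05-14".into(),
        ),
    ];
    assert_eq!(headers.len(), 4);
}
```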
@@ -701,6 +1074,7 @@ impl zed::Extension for AnthropicProvider {
                                     "end_turn" => LlmStopReason::EndTurn,
                                     "max_tokens" => LlmStopReason::MaxTokens,
                                     "tool_use" => LlmStopReason::ToolUse,
+                                    "refusal" => LlmStopReason::Refusal,
                                     _ => LlmStopReason::EndTurn,
                                 });
                             }