// anthropic.rs — Anthropic LLM provider extension for Zed.

  1use std::collections::HashMap;
  2use std::sync::Mutex;
  3
  4use serde::{Deserialize, Serialize};
  5use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
  6use zed_extension_api::{self as zed, *};
  7
/// Extension state for the Anthropic LLM provider.
struct AnthropicProvider {
    /// Active completion streams, keyed by the id returned from
    /// `llm_stream_completion_start`.
    streams: Mutex<HashMap<String, StreamState>>,
    /// Monotonic counter used to mint unique stream ids.
    next_stream_id: Mutex<u64>,
}
 12
/// Per-stream bookkeeping for one in-flight Anthropic completion.
struct StreamState {
    /// Underlying HTTP response stream carrying SSE data; `None` after close.
    response_stream: Option<HttpResponseStream>,
    /// Unconsumed SSE bytes, drained one newline-terminated line at a time.
    buffer: String,
    /// Whether the `Started` event has already been emitted to Zed.
    started: bool,
    /// Tool-use block currently being accumulated from `input_json_delta`s.
    current_tool_use: Option<ToolUseState>,
    /// Stop reason captured from `message_delta`, reported at `message_stop`.
    stop_reason: Option<LlmStopReason>,
    /// Thinking signature captured from a `signature_delta`; attached to the
    /// next completed tool use.
    pending_signature: Option<String>,
}
 21
/// Accumulator for a streamed `tool_use` content block.
struct ToolUseState {
    /// Tool-use id assigned by the API.
    id: String,
    /// Name of the tool being invoked.
    name: String,
    /// Concatenated `input_json_delta` fragments; complete JSON once the
    /// content block stops.
    input_json: String,
}
 27
/// Static description of one model entry offered to Zed.
struct ModelDefinition {
    /// Model id sent to the Anthropic API.
    real_id: &'static str,
    /// Human-readable name; also used as the model id within Zed, so it must
    /// be unique across `MODELS` even when `real_id` is shared.
    display_name: &'static str,
    /// Context window size in tokens.
    max_tokens: u64,
    /// Maximum tokens the model may generate (request `max_tokens` field).
    max_output_tokens: u64,
    /// Whether image content is accepted by this model.
    supports_images: bool,
    /// Whether extended thinking is requested for this entry.
    supports_thinking: bool,
    /// Marks Zed's default model.
    is_default: bool,
    /// Marks Zed's default "fast" model.
    is_default_fast: bool,
}
 38
/// Catalog of models exposed by this provider.
///
/// "Thinking" variants intentionally share a `real_id` with their base model;
/// the distinct `display_name` is what Zed uses to select an entry, and the
/// `supports_thinking` flag is what changes the request that gets sent.
const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        real_id: "claude-opus-4-5-20251101",
        display_name: "Claude Opus 4.5",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-opus-4-5-20251101",
        display_name: "Claude Opus 4.5 Thinking",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-sonnet-4-5-20250929",
        display_name: "Claude Sonnet 4.5",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: false,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-sonnet-4-5-20250929",
        display_name: "Claude Sonnet 4.5 Thinking",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-sonnet-4-20250514",
        display_name: "Claude Sonnet 4",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-sonnet-4-20250514",
        display_name: "Claude Sonnet 4 Thinking",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-haiku-4-5-20251001",
        display_name: "Claude Haiku 4.5",
        max_tokens: 200_000,
        max_output_tokens: 64_000,
        supports_images: true,
        supports_thinking: false,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        real_id: "claude-haiku-4-5-20251001",
        display_name: "Claude Haiku 4.5 Thinking",
        max_tokens: 200_000,
        max_output_tokens: 64_000,
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-3-5-sonnet-latest",
        display_name: "Claude 3.5 Sonnet",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-3-5-haiku-latest",
        display_name: "Claude 3.5 Haiku",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: false,
        is_default: false,
        is_default_fast: false,
    },
];
141
142fn get_model_definition(display_name: &str) -> Option<&'static ModelDefinition> {
143    MODELS.iter().find(|m| m.display_name == display_name)
144}
145
146// Anthropic API Request Types
147
/// Request body for `POST /v1/messages`.
#[derive(Serialize)]
struct AnthropicRequest {
    model: String,
    /// Maximum number of tokens the model may generate.
    max_tokens: u64,
    messages: Vec<AnthropicMessage>,
    /// Top-level system prompt; Anthropic keeps this out of `messages`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    /// Extended-thinking configuration, serialized only when enabled.
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking: Option<AnthropicThinking>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<AnthropicTool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<AnthropicToolChoice>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop_sequences: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Always `true` here — this extension only supports streaming responses.
    stream: bool,
}
167
/// The `thinking` request field, e.g. `{"type": "enabled", "budget_tokens": 4096}`.
#[derive(Serialize)]
struct AnthropicThinking {
    /// Discriminator value; this extension only ever sends "enabled".
    #[serde(rename = "type")]
    thinking_type: String,
    /// Token budget the model may spend on thinking before responding.
    #[serde(skip_serializing_if = "Option::is_none")]
    budget_tokens: Option<u32>,
}
175
/// One conversation turn ("user" or "assistant") with its content blocks.
#[derive(Serialize)]
struct AnthropicMessage {
    role: String,
    content: Vec<AnthropicContent>,
}
181
/// Outgoing content blocks, discriminated by the `type` field.
#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum AnthropicContent {
    /// Plain text.
    #[serde(rename = "text")]
    Text { text: String },
    /// Previously emitted thinking, echoed back with its signature.
    #[serde(rename = "thinking")]
    Thinking { thinking: String, signature: String },
    /// Opaque redacted-thinking payload, echoed back verbatim.
    #[serde(rename = "redacted_thinking")]
    RedactedThinking { data: String },
    /// Inline image (base64-encoded — see `AnthropicImageSource`).
    #[serde(rename = "image")]
    Image { source: AnthropicImageSource },
    /// A prior assistant tool invocation being replayed in history.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// The result of a tool invocation, sent in a user turn.
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        is_error: bool,
        content: String,
    },
}
206
/// Image payload: base64 data plus its media type.
#[derive(Serialize, Clone)]
struct AnthropicImageSource {
    /// Discriminator; this extension only sends "base64".
    #[serde(rename = "type")]
    source_type: String,
    /// MIME type of the encoded image, e.g. "image/png".
    media_type: String,
    /// Base64-encoded image bytes.
    data: String,
}
214
/// Tool definition advertised to the model.
#[derive(Serialize)]
struct AnthropicTool {
    name: String,
    description: String,
    /// JSON Schema describing the tool's expected input object.
    input_schema: serde_json::Value,
}
221
/// `tool_choice` request field; serializes as `{"type": "auto"}` etc.
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "lowercase")]
enum AnthropicToolChoice {
    /// Model decides whether to call a tool.
    Auto,
    /// Model must call some tool.
    Any,
    /// Model must not call tools.
    None,
}
229
230// Anthropic API Response Types
231
/// Server-sent event payloads from the streaming Messages API,
/// discriminated by the `type` field.
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
#[allow(dead_code)]
enum AnthropicEvent {
    /// First event of a stream; carries initial usage figures.
    #[serde(rename = "message_start")]
    MessageStart { message: AnthropicMessageResponse },
    /// Opens content block `index` (text, thinking, or tool use).
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: AnthropicContentBlock,
    },
    /// Incremental update to an open content block.
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: AnthropicDelta },
    /// Closes content block `index`.
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    /// Top-level message update: stop reason and cumulative usage.
    #[serde(rename = "message_delta")]
    MessageDelta {
        delta: AnthropicMessageDelta,
        usage: AnthropicUsage,
    },
    /// Final event of a successful stream.
    #[serde(rename = "message_stop")]
    MessageStop,
    /// Keep-alive; carries no data.
    #[serde(rename = "ping")]
    Ping,
    /// In-stream error reported by the API.
    #[serde(rename = "error")]
    Error { error: AnthropicApiError },
}
259
/// Message envelope carried by a `message_start` event.
#[derive(Deserialize, Debug)]
struct AnthropicMessageResponse {
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    role: String,
    /// Initial token usage; defaults to all-`None` when absent.
    #[serde(default)]
    usage: AnthropicUsage,
}
269
/// Initial content carried by a `content_block_start` event.
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
enum AnthropicContentBlock {
    /// Text block; may begin with non-empty text.
    #[serde(rename = "text")]
    Text { text: String },
    /// Thinking block; deltas follow with the remaining text and signature.
    #[serde(rename = "thinking")]
    Thinking { thinking: String },
    /// Redacted thinking delivered whole as an opaque payload.
    #[serde(rename = "redacted_thinking")]
    RedactedThinking { data: String },
    /// Tool invocation; its input arrives via `input_json_delta`s.
    #[serde(rename = "tool_use")]
    ToolUse { id: String, name: String },
}
282
/// Incremental payload carried by a `content_block_delta` event.
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
enum AnthropicDelta {
    /// Appends to a text block.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    /// Appends to a thinking block.
    #[serde(rename = "thinking_delta")]
    ThinkingDelta { thinking: String },
    /// Cryptographic signature finalizing a thinking block.
    #[serde(rename = "signature_delta")]
    SignatureDelta { signature: String },
    /// Partial JSON fragment of a tool-use input object.
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}
295
/// `delta` payload of a `message_delta` event.
#[derive(Deserialize, Debug)]
struct AnthropicMessageDelta {
    /// e.g. "end_turn", "max_tokens", "tool_use"; `None` until known.
    stop_reason: Option<String>,
}
300
/// Token accounting attached to `message_start` and `message_delta` events.
/// All fields are optional because the API reports them piecemeal.
#[derive(Deserialize, Debug, Default)]
struct AnthropicUsage {
    #[serde(default)]
    input_tokens: Option<u64>,
    #[serde(default)]
    output_tokens: Option<u64>,
    #[serde(default)]
    cache_creation_input_tokens: Option<u64>,
    #[serde(default)]
    cache_read_input_tokens: Option<u64>,
}
312
/// Error object carried by an in-stream `error` event.
#[derive(Deserialize, Debug)]
struct AnthropicApiError {
    #[serde(rename = "type")]
    #[allow(dead_code)]
    error_type: String,
    /// Human-readable description surfaced to the user.
    message: String,
}
320
321fn convert_request(
322    model_id: &str,
323    request: &LlmCompletionRequest,
324) -> Result<AnthropicRequest, String> {
325    let model_def =
326        get_model_definition(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;
327
328    let mut messages: Vec<AnthropicMessage> = Vec::new();
329    let mut system_message = String::new();
330
331    for msg in &request.messages {
332        match msg.role {
333            LlmMessageRole::System => {
334                for content in &msg.content {
335                    if let LlmMessageContent::Text(text) = content {
336                        if !system_message.is_empty() {
337                            system_message.push('\n');
338                        }
339                        system_message.push_str(text);
340                    }
341                }
342            }
343            LlmMessageRole::User => {
344                let mut contents: Vec<AnthropicContent> = Vec::new();
345
346                for content in &msg.content {
347                    match content {
348                        LlmMessageContent::Text(text) => {
349                            if !text.is_empty() {
350                                contents.push(AnthropicContent::Text { text: text.clone() });
351                            }
352                        }
353                        LlmMessageContent::Image(img) => {
354                            contents.push(AnthropicContent::Image {
355                                source: AnthropicImageSource {
356                                    source_type: "base64".to_string(),
357                                    media_type: "image/png".to_string(),
358                                    data: img.source.clone(),
359                                },
360                            });
361                        }
362                        LlmMessageContent::ToolResult(result) => {
363                            let content_text = match &result.content {
364                                LlmToolResultContent::Text(t) => t.clone(),
365                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
366                            };
367                            contents.push(AnthropicContent::ToolResult {
368                                tool_use_id: result.tool_use_id.clone(),
369                                is_error: result.is_error,
370                                content: content_text,
371                            });
372                        }
373                        _ => {}
374                    }
375                }
376
377                if !contents.is_empty() {
378                    messages.push(AnthropicMessage {
379                        role: "user".to_string(),
380                        content: contents,
381                    });
382                }
383            }
384            LlmMessageRole::Assistant => {
385                let mut contents: Vec<AnthropicContent> = Vec::new();
386
387                for content in &msg.content {
388                    match content {
389                        LlmMessageContent::Text(text) => {
390                            if !text.is_empty() {
391                                contents.push(AnthropicContent::Text { text: text.clone() });
392                            }
393                        }
394                        LlmMessageContent::ToolUse(tool_use) => {
395                            let input: serde_json::Value =
396                                serde_json::from_str(&tool_use.input).unwrap_or_default();
397                            contents.push(AnthropicContent::ToolUse {
398                                id: tool_use.id.clone(),
399                                name: tool_use.name.clone(),
400                                input,
401                            });
402                        }
403                        LlmMessageContent::Thinking(thinking) => {
404                            if !thinking.text.is_empty() {
405                                contents.push(AnthropicContent::Thinking {
406                                    thinking: thinking.text.clone(),
407                                    signature: thinking.signature.clone().unwrap_or_default(),
408                                });
409                            }
410                        }
411                        LlmMessageContent::RedactedThinking(data) => {
412                            if !data.is_empty() {
413                                contents.push(AnthropicContent::RedactedThinking {
414                                    data: data.clone(),
415                                });
416                            }
417                        }
418                        _ => {}
419                    }
420                }
421
422                if !contents.is_empty() {
423                    messages.push(AnthropicMessage {
424                        role: "assistant".to_string(),
425                        content: contents,
426                    });
427                }
428            }
429        }
430    }
431
432    let tools: Vec<AnthropicTool> = request
433        .tools
434        .iter()
435        .map(|t| AnthropicTool {
436            name: t.name.clone(),
437            description: t.description.clone(),
438            input_schema: serde_json::from_str(&t.input_schema)
439                .unwrap_or(serde_json::Value::Object(Default::default())),
440        })
441        .collect();
442
443    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
444        LlmToolChoice::Auto => AnthropicToolChoice::Auto,
445        LlmToolChoice::Any => AnthropicToolChoice::Any,
446        LlmToolChoice::None => AnthropicToolChoice::None,
447    });
448
449    let thinking = if model_def.supports_thinking && request.thinking_allowed {
450        Some(AnthropicThinking {
451            thinking_type: "enabled".to_string(),
452            budget_tokens: Some(4096),
453        })
454    } else {
455        None
456    };
457
458    Ok(AnthropicRequest {
459        model: model_def.real_id.to_string(),
460        max_tokens: model_def.max_output_tokens,
461        messages,
462        system: if system_message.is_empty() {
463            None
464        } else {
465            Some(system_message)
466        },
467        thinking,
468        tools,
469        tool_choice,
470        stop_sequences: request.stop_sequences.clone(),
471        temperature: request.temperature,
472        stream: true,
473    })
474}
475
476fn parse_sse_line(line: &str) -> Option<AnthropicEvent> {
477    let data = line.strip_prefix("data: ")?;
478    serde_json::from_str(data).ok()
479}
480
impl zed::Extension for AnthropicProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    // Advertise the single "anthropic" provider entry to Zed.
    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "anthropic".into(),
            name: "Anthropic".into(),
            icon: Some("icons/anthropic.svg".into()),
        }]
    }

    // Expose the static MODELS catalog. Note the display name is used as the
    // Zed-facing model id; the Anthropic-side id is resolved later in
    // convert_request.
    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.display_name.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: Some(m.max_output_tokens),
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: m.supports_thinking,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    // Authenticated iff an API key is stored under the "anthropic" credential.
    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("anthropic").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            "[Create an API key](https://console.anthropic.com/settings/keys) to use Anthropic as your LLM provider.".to_string(),
        )
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("anthropic")
    }

    /// Start a streaming completion: build the Anthropic request, open the
    /// HTTP SSE connection, and register a fresh StreamState under a newly
    /// minted stream id, which is returned to Zed for subsequent
    /// `llm_stream_completion_next` calls.
    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("anthropic").ok_or_else(|| {
            "No API key configured. Please add your Anthropic API key in settings.".to_string()
        })?;

        let anthropic_request = convert_request(model_id, request)?;

        let body = serde_json::to_vec(&anthropic_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;

        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: "https://api.anthropic.com/v1/messages".to_string(),
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                ("x-api-key".to_string(), api_key),
                ("anthropic-version".to_string(), "2023-06-01".to_string()),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };

        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;

        // Mint a unique id from the monotonic counter.
        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("anthropic-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };

        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                current_tool_use: None,
                stop_reason: None,
                pending_signature: None,
            },
        );

        Ok(stream_id)
    }

    /// Pull the next completion event for `stream_id`.
    ///
    /// Emits `Started` exactly once, then drains buffered SSE lines —
    /// translating each Anthropic event into at most one Zed event — and
    /// reads more HTTP chunks when the buffer has no complete line left.
    /// Returns `Ok(None)` when the underlying stream is exhausted.
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;

        // First call on a stream: report Started before reading anything.
        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }

        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;

        loop {
            // Consume one complete line from the buffer, if available.
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();

                // Skip blank separators and SSE "event:" type lines; the
                // payload's own "type" field is used for dispatch instead.
                if line.trim().is_empty() || line.starts_with("event:") {
                    continue;
                }

                if let Some(event) = parse_sse_line(&line) {
                    match event {
                        AnthropicEvent::MessageStart { message } => {
                            // Only surface usage when both counts are present.
                            if let (Some(input), Some(output)) =
                                (message.usage.input_tokens, message.usage.output_tokens)
                            {
                                return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                                    input_tokens: input,
                                    output_tokens: output,
                                    cache_creation_input_tokens: message
                                        .usage
                                        .cache_creation_input_tokens,
                                    cache_read_input_tokens: message.usage.cache_read_input_tokens,
                                })));
                            }
                        }
                        AnthropicEvent::ContentBlockStart { content_block, .. } => {
                            match content_block {
                                AnthropicContentBlock::Text { text } => {
                                    if !text.is_empty() {
                                        return Ok(Some(LlmCompletionEvent::Text(text)));
                                    }
                                }
                                AnthropicContentBlock::Thinking { thinking } => {
                                    return Ok(Some(LlmCompletionEvent::Thinking(
                                        LlmThinkingContent {
                                            text: thinking,
                                            signature: None,
                                        },
                                    )));
                                }
                                AnthropicContentBlock::RedactedThinking { data } => {
                                    return Ok(Some(LlmCompletionEvent::RedactedThinking(data)));
                                }
                                AnthropicContentBlock::ToolUse { id, name } => {
                                    // Input JSON arrives via later deltas;
                                    // start accumulating, emit nothing yet.
                                    state.current_tool_use = Some(ToolUseState {
                                        id,
                                        name,
                                        input_json: String::new(),
                                    });
                                }
                            }
                        }
                        AnthropicEvent::ContentBlockDelta { delta, .. } => match delta {
                            AnthropicDelta::TextDelta { text } => {
                                if !text.is_empty() {
                                    return Ok(Some(LlmCompletionEvent::Text(text)));
                                }
                            }
                            AnthropicDelta::ThinkingDelta { thinking } => {
                                return Ok(Some(LlmCompletionEvent::Thinking(
                                    LlmThinkingContent {
                                        text: thinking,
                                        signature: None,
                                    },
                                )));
                            }
                            AnthropicDelta::SignatureDelta { signature } => {
                                // Remember the signature for the tool-use
                                // event, and also forward it as thinking.
                                state.pending_signature = Some(signature.clone());
                                return Ok(Some(LlmCompletionEvent::Thinking(
                                    LlmThinkingContent {
                                        text: String::new(),
                                        signature: Some(signature),
                                    },
                                )));
                            }
                            AnthropicDelta::InputJsonDelta { partial_json } => {
                                // Accumulate silently; the complete tool use
                                // is emitted at content_block_stop.
                                if let Some(ref mut tool_use) = state.current_tool_use {
                                    tool_use.input_json.push_str(&partial_json);
                                }
                            }
                        },
                        AnthropicEvent::ContentBlockStop { .. } => {
                            // A stop only produces an event when it closes a
                            // pending tool-use accumulation.
                            if let Some(tool_use) = state.current_tool_use.take() {
                                return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                    id: tool_use.id,
                                    name: tool_use.name,
                                    input: tool_use.input_json,
                                    is_input_complete: true,
                                    thought_signature: state.pending_signature.take(),
                                })));
                            }
                        }
                        AnthropicEvent::MessageDelta { delta, usage } => {
                            // Record the stop reason now; it is reported to
                            // Zed when message_stop arrives.
                            if let Some(reason) = delta.stop_reason {
                                state.stop_reason = Some(match reason.as_str() {
                                    "end_turn" => LlmStopReason::EndTurn,
                                    "max_tokens" => LlmStopReason::MaxTokens,
                                    "tool_use" => LlmStopReason::ToolUse,
                                    _ => LlmStopReason::EndTurn,
                                });
                            }
                            if let Some(output) = usage.output_tokens {
                                return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                                    input_tokens: usage.input_tokens.unwrap_or(0),
                                    output_tokens: output,
                                    cache_creation_input_tokens: usage.cache_creation_input_tokens,
                                    cache_read_input_tokens: usage.cache_read_input_tokens,
                                })));
                            }
                        }
                        AnthropicEvent::MessageStop => {
                            // Prefer the recorded stop reason; default EndTurn.
                            if let Some(stop_reason) = state.stop_reason.take() {
                                return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                            }
                            return Ok(Some(LlmCompletionEvent::Stop(LlmStopReason::EndTurn)));
                        }
                        AnthropicEvent::Ping => {}
                        AnthropicEvent::Error { error } => {
                            return Err(format!("API error: {}", error.message));
                        }
                    }
                }

                // Line handled without producing an event; keep draining.
                continue;
            }

            // No complete line buffered: read the next HTTP chunk.
            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    // End of HTTP stream: flush any recorded stop reason,
                    // then signal completion with None.
                    if let Some(stop_reason) = state.stop_reason.take() {
                        return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                    }
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    // Drop all state for the stream; the HTTP stream closes with it.
    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}
754
// Register this provider as the extension entry point with Zed.
zed::register_extension!(AnthropicProvider);