anthropic.rs

  1mod supported_countries;
  2
  3use std::{pin::Pin, str::FromStr};
  4
  5use anyhow::{anyhow, Context as _, Result};
  6use chrono::{DateTime, Utc};
  7use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  8use http_client::http::{HeaderMap, HeaderValue};
  9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 10use serde::{Deserialize, Serialize};
 11use strum::{EnumIter, EnumString};
 12use thiserror::Error;
 13use util::ResultExt as _;
 14
 15pub use supported_countries::*;
 16
 /// Base URL of the public Anthropic API (scheme and host only, no trailing slash).
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
 18
 19#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 20#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 21pub struct AnthropicModelCacheConfiguration {
 22    pub min_total_token: usize,
 23    pub should_speculate: bool,
 24    pub max_cache_anchors: usize,
 25}
 26
 27#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 28#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 29pub enum Model {
 30    #[default]
 31    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")]
 32    Claude3_5Sonnet,
 33    #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
 34    Claude3_5Haiku,
 35    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
 36    Claude3Opus,
 37    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
 38    Claude3Sonnet,
 39    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-latest")]
 40    Claude3Haiku,
 41    #[serde(rename = "custom")]
 42    Custom {
 43        name: String,
 44        max_tokens: usize,
 45        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 46        display_name: Option<String>,
 47        /// Override this model with a different Anthropic model for tool calls.
 48        tool_override: Option<String>,
 49        /// Indicates whether this custom model supports caching.
 50        cache_configuration: Option<AnthropicModelCacheConfiguration>,
 51        max_output_tokens: Option<u32>,
 52        default_temperature: Option<f32>,
 53        #[serde(default)]
 54        extra_beta_headers: Vec<String>,
 55    },
 56}
 57
 58impl Model {
 59    pub fn from_id(id: &str) -> Result<Self> {
 60        if id.starts_with("claude-3-5-sonnet") {
 61            Ok(Self::Claude3_5Sonnet)
 62        } else if id.starts_with("claude-3-5-haiku") {
 63            Ok(Self::Claude3_5Haiku)
 64        } else if id.starts_with("claude-3-opus") {
 65            Ok(Self::Claude3Opus)
 66        } else if id.starts_with("claude-3-sonnet") {
 67            Ok(Self::Claude3Sonnet)
 68        } else if id.starts_with("claude-3-haiku") {
 69            Ok(Self::Claude3Haiku)
 70        } else {
 71            Err(anyhow!("invalid model id"))
 72        }
 73    }
 74
 75    pub fn id(&self) -> &str {
 76        match self {
 77            Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest",
 78            Model::Claude3_5Haiku => "claude-3-5-haiku-latest",
 79            Model::Claude3Opus => "claude-3-opus-latest",
 80            Model::Claude3Sonnet => "claude-3-sonnet-20240229",
 81            Model::Claude3Haiku => "claude-3-haiku-20240307",
 82            Self::Custom { name, .. } => name,
 83        }
 84    }
 85
 86    pub fn display_name(&self) -> &str {
 87        match self {
 88            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
 89            Self::Claude3_5Haiku => "Claude 3.5 Haiku",
 90            Self::Claude3Opus => "Claude 3 Opus",
 91            Self::Claude3Sonnet => "Claude 3 Sonnet",
 92            Self::Claude3Haiku => "Claude 3 Haiku",
 93            Self::Custom {
 94                name, display_name, ..
 95            } => display_name.as_ref().unwrap_or(name),
 96        }
 97    }
 98
 99    pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
100        match self {
101            Self::Claude3_5Sonnet | Self::Claude3_5Haiku | Self::Claude3Haiku => {
102                Some(AnthropicModelCacheConfiguration {
103                    min_total_token: 2_048,
104                    should_speculate: true,
105                    max_cache_anchors: 4,
106                })
107            }
108            Self::Custom {
109                cache_configuration,
110                ..
111            } => cache_configuration.clone(),
112            _ => None,
113        }
114    }
115
116    pub fn max_token_count(&self) -> usize {
117        match self {
118            Self::Claude3_5Sonnet
119            | Self::Claude3_5Haiku
120            | Self::Claude3Opus
121            | Self::Claude3Sonnet
122            | Self::Claude3Haiku => 200_000,
123            Self::Custom { max_tokens, .. } => *max_tokens,
124        }
125    }
126
127    pub fn max_output_tokens(&self) -> u32 {
128        match self {
129            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
130            Self::Claude3_5Sonnet | Self::Claude3_5Haiku => 8_192,
131            Self::Custom {
132                max_output_tokens, ..
133            } => max_output_tokens.unwrap_or(4_096),
134        }
135    }
136
137    pub fn default_temperature(&self) -> f32 {
138        match self {
139            Self::Claude3_5Sonnet
140            | Self::Claude3_5Haiku
141            | Self::Claude3Opus
142            | Self::Claude3Sonnet
143            | Self::Claude3Haiku => 1.0,
144            Self::Custom {
145                default_temperature,
146                ..
147            } => default_temperature.unwrap_or(1.0),
148        }
149    }
150
151    pub const DEFAULT_BETA_HEADERS: &[&str] = &["prompt-caching-2024-07-31"];
152
153    pub fn beta_headers(&self) -> String {
154        let mut headers = Self::DEFAULT_BETA_HEADERS
155            .into_iter()
156            .map(|header| header.to_string())
157            .collect::<Vec<_>>();
158
159        if let Self::Custom {
160            extra_beta_headers, ..
161        } = self
162        {
163            headers.extend(
164                extra_beta_headers
165                    .iter()
166                    .filter(|header| !header.trim().is_empty())
167                    .cloned(),
168            );
169        }
170
171        headers.join(",")
172    }
173
174    pub fn tool_model_id(&self) -> &str {
175        if let Self::Custom {
176            tool_override: Some(tool_override),
177            ..
178        } = self
179        {
180            tool_override
181        } else {
182            self.id()
183        }
184    }
185}
186
187pub async fn complete(
188    client: &dyn HttpClient,
189    api_url: &str,
190    api_key: &str,
191    request: Request,
192) -> Result<Response, AnthropicError> {
193    let uri = format!("{api_url}/v1/messages");
194    let beta_headers = Model::from_id(&request.model)
195        .map(|model| model.beta_headers())
196        .unwrap_or_else(|_err| Model::DEFAULT_BETA_HEADERS.join(","));
197    let request_builder = HttpRequest::builder()
198        .method(Method::POST)
199        .uri(uri)
200        .header("Anthropic-Version", "2023-06-01")
201        .header("Anthropic-Beta", beta_headers)
202        .header("X-Api-Key", api_key)
203        .header("Content-Type", "application/json");
204
205    let serialized_request =
206        serde_json::to_string(&request).context("failed to serialize request")?;
207    let request = request_builder
208        .body(AsyncBody::from(serialized_request))
209        .context("failed to construct request body")?;
210
211    let mut response = client
212        .send(request)
213        .await
214        .context("failed to send request to Anthropic")?;
215    if response.status().is_success() {
216        let mut body = Vec::new();
217        response
218            .body_mut()
219            .read_to_end(&mut body)
220            .await
221            .context("failed to read response body")?;
222        let response_message: Response =
223            serde_json::from_slice(&body).context("failed to deserialize response body")?;
224        Ok(response_message)
225    } else {
226        let mut body = Vec::new();
227        response
228            .body_mut()
229            .read_to_end(&mut body)
230            .await
231            .context("failed to read response body")?;
232        let body_str =
233            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
234        Err(AnthropicError::Other(anyhow!(
235            "Failed to connect to API: {} {}",
236            response.status(),
237            body_str
238        )))
239    }
240}
241
242pub async fn stream_completion(
243    client: &dyn HttpClient,
244    api_url: &str,
245    api_key: &str,
246    request: Request,
247) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
248    stream_completion_with_rate_limit_info(client, api_url, api_key, request)
249        .await
250        .map(|output| output.0)
251}
252
253/// <https://docs.anthropic.com/en/api/rate-limits#response-headers>
254#[derive(Debug)]
255pub struct RateLimitInfo {
256    pub requests_limit: usize,
257    pub requests_remaining: usize,
258    pub requests_reset: DateTime<Utc>,
259    pub tokens_limit: usize,
260    pub tokens_remaining: usize,
261    pub tokens_reset: DateTime<Utc>,
262}
263
264impl RateLimitInfo {
265    fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
266        let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
267        let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
268        let tokens_remaining =
269            get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
270        let requests_remaining =
271            get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
272        let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
273        let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
274        let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
275        let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
276
277        Ok(Self {
278            requests_limit,
279            tokens_limit,
280            requests_remaining,
281            tokens_remaining,
282            requests_reset,
283            tokens_reset,
284        })
285    }
286}
287
288fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
289    Ok(headers
290        .get(key)
291        .ok_or_else(|| anyhow!("missing header `{key}`"))?
292        .to_str()?)
293}
294
295pub async fn stream_completion_with_rate_limit_info(
296    client: &dyn HttpClient,
297    api_url: &str,
298    api_key: &str,
299    request: Request,
300) -> Result<
301    (
302        BoxStream<'static, Result<Event, AnthropicError>>,
303        Option<RateLimitInfo>,
304    ),
305    AnthropicError,
306> {
307    let request = StreamingRequest {
308        base: request,
309        stream: true,
310    };
311    let uri = format!("{api_url}/v1/messages");
312    let beta_headers = Model::from_id(&request.base.model)
313        .map(|model| model.beta_headers())
314        .unwrap_or_else(|_err| Model::DEFAULT_BETA_HEADERS.join(","));
315    let request_builder = HttpRequest::builder()
316        .method(Method::POST)
317        .uri(uri)
318        .header("Anthropic-Version", "2023-06-01")
319        .header("Anthropic-Beta", beta_headers)
320        .header("X-Api-Key", api_key)
321        .header("Content-Type", "application/json");
322    let serialized_request =
323        serde_json::to_string(&request).context("failed to serialize request")?;
324    let request = request_builder
325        .body(AsyncBody::from(serialized_request))
326        .context("failed to construct request body")?;
327
328    let mut response = client
329        .send(request)
330        .await
331        .context("failed to send request to Anthropic")?;
332    if response.status().is_success() {
333        let rate_limits = RateLimitInfo::from_headers(response.headers());
334        let reader = BufReader::new(response.into_body());
335        let stream = reader
336            .lines()
337            .filter_map(|line| async move {
338                match line {
339                    Ok(line) => {
340                        let line = line.strip_prefix("data: ")?;
341                        match serde_json::from_str(line) {
342                            Ok(response) => Some(Ok(response)),
343                            Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
344                        }
345                    }
346                    Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
347                }
348            })
349            .boxed();
350        Ok((stream, rate_limits.log_err()))
351    } else {
352        let mut body = Vec::new();
353        response
354            .body_mut()
355            .read_to_end(&mut body)
356            .await
357            .context("failed to read response body")?;
358
359        let body_str =
360            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
361
362        match serde_json::from_str::<Event>(body_str) {
363            Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
364            Ok(_) => Err(AnthropicError::Other(anyhow!(
365                "Unexpected success response while expecting an error: '{body_str}'",
366            ))),
367            Err(_) => Err(AnthropicError::Other(anyhow!(
368                "Failed to connect to API: {} {}",
369                response.status(),
370                body_str,
371            ))),
372        }
373    }
374}
375
376pub async fn extract_tool_args_from_events(
377    tool_name: String,
378    mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
379) -> Result<impl Send + Stream<Item = Result<String>>> {
380    let mut tool_use_index = None;
381    while let Some(event) = events.next().await {
382        if let Event::ContentBlockStart {
383            index,
384            content_block: ResponseContent::ToolUse { name, .. },
385        } = event?
386        {
387            if name == tool_name {
388                tool_use_index = Some(index);
389                break;
390            }
391        }
392    }
393
394    let Some(tool_use_index) = tool_use_index else {
395        return Err(anyhow!("tool not used"));
396    };
397
398    Ok(events.filter_map(move |event| {
399        let result = match event {
400            Err(error) => Some(Err(error)),
401            Ok(Event::ContentBlockDelta { index, delta }) => match delta {
402                ContentDelta::TextDelta { .. } => None,
403                ContentDelta::InputJsonDelta { partial_json } => {
404                    if index == tool_use_index {
405                        Some(Ok(partial_json))
406                    } else {
407                        None
408                    }
409                }
410            },
411            _ => None,
412        };
413
414        async move { result }
415    }))
416}
417
418#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
419#[serde(rename_all = "lowercase")]
420pub enum CacheControlType {
421    Ephemeral,
422}
423
424#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
425pub struct CacheControl {
426    #[serde(rename = "type")]
427    pub cache_type: CacheControlType,
428}
429
430#[derive(Debug, Serialize, Deserialize)]
431pub struct Message {
432    pub role: Role,
433    pub content: Vec<RequestContent>,
434}
435
436#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
437#[serde(rename_all = "lowercase")]
438pub enum Role {
439    User,
440    Assistant,
441}
442
443#[derive(Debug, Serialize, Deserialize)]
444#[serde(tag = "type")]
445pub enum RequestContent {
446    #[serde(rename = "text")]
447    Text {
448        text: String,
449        #[serde(skip_serializing_if = "Option::is_none")]
450        cache_control: Option<CacheControl>,
451    },
452    #[serde(rename = "image")]
453    Image {
454        source: ImageSource,
455        #[serde(skip_serializing_if = "Option::is_none")]
456        cache_control: Option<CacheControl>,
457    },
458    #[serde(rename = "tool_use")]
459    ToolUse {
460        id: String,
461        name: String,
462        input: serde_json::Value,
463        #[serde(skip_serializing_if = "Option::is_none")]
464        cache_control: Option<CacheControl>,
465    },
466    #[serde(rename = "tool_result")]
467    ToolResult {
468        tool_use_id: String,
469        is_error: bool,
470        content: String,
471        #[serde(skip_serializing_if = "Option::is_none")]
472        cache_control: Option<CacheControl>,
473    },
474}
475
476#[derive(Debug, Serialize, Deserialize)]
477#[serde(tag = "type")]
478pub enum ResponseContent {
479    #[serde(rename = "text")]
480    Text { text: String },
481    #[serde(rename = "tool_use")]
482    ToolUse {
483        id: String,
484        name: String,
485        input: serde_json::Value,
486    },
487}
488
489#[derive(Debug, Serialize, Deserialize)]
490pub struct ImageSource {
491    #[serde(rename = "type")]
492    pub source_type: String,
493    pub media_type: String,
494    pub data: String,
495}
496
497#[derive(Debug, Serialize, Deserialize)]
498pub struct Tool {
499    pub name: String,
500    pub description: String,
501    pub input_schema: serde_json::Value,
502}
503
504#[derive(Debug, Serialize, Deserialize)]
505#[serde(tag = "type", rename_all = "lowercase")]
506pub enum ToolChoice {
507    Auto,
508    Any,
509    Tool { name: String },
510}
511
512#[derive(Debug, Serialize, Deserialize)]
513pub struct Request {
514    pub model: String,
515    pub max_tokens: u32,
516    pub messages: Vec<Message>,
517    #[serde(default, skip_serializing_if = "Vec::is_empty")]
518    pub tools: Vec<Tool>,
519    #[serde(default, skip_serializing_if = "Option::is_none")]
520    pub tool_choice: Option<ToolChoice>,
521    #[serde(default, skip_serializing_if = "Option::is_none")]
522    pub system: Option<String>,
523    #[serde(default, skip_serializing_if = "Option::is_none")]
524    pub metadata: Option<Metadata>,
525    #[serde(default, skip_serializing_if = "Vec::is_empty")]
526    pub stop_sequences: Vec<String>,
527    #[serde(default, skip_serializing_if = "Option::is_none")]
528    pub temperature: Option<f32>,
529    #[serde(default, skip_serializing_if = "Option::is_none")]
530    pub top_k: Option<u32>,
531    #[serde(default, skip_serializing_if = "Option::is_none")]
532    pub top_p: Option<f32>,
533}
534
535#[derive(Debug, Serialize, Deserialize)]
536struct StreamingRequest {
537    #[serde(flatten)]
538    pub base: Request,
539    pub stream: bool,
540}
541
542#[derive(Debug, Serialize, Deserialize)]
543pub struct Metadata {
544    pub user_id: Option<String>,
545}
546
547#[derive(Debug, Serialize, Deserialize)]
548pub struct Usage {
549    #[serde(default, skip_serializing_if = "Option::is_none")]
550    pub input_tokens: Option<u32>,
551    #[serde(default, skip_serializing_if = "Option::is_none")]
552    pub output_tokens: Option<u32>,
553    #[serde(default, skip_serializing_if = "Option::is_none")]
554    pub cache_creation_input_tokens: Option<u32>,
555    #[serde(default, skip_serializing_if = "Option::is_none")]
556    pub cache_read_input_tokens: Option<u32>,
557}
558
559#[derive(Debug, Serialize, Deserialize)]
560pub struct Response {
561    pub id: String,
562    #[serde(rename = "type")]
563    pub response_type: String,
564    pub role: Role,
565    pub content: Vec<ResponseContent>,
566    pub model: String,
567    #[serde(default, skip_serializing_if = "Option::is_none")]
568    pub stop_reason: Option<String>,
569    #[serde(default, skip_serializing_if = "Option::is_none")]
570    pub stop_sequence: Option<String>,
571    pub usage: Usage,
572}
573
574#[derive(Debug, Serialize, Deserialize)]
575#[serde(tag = "type")]
576pub enum Event {
577    #[serde(rename = "message_start")]
578    MessageStart { message: Response },
579    #[serde(rename = "content_block_start")]
580    ContentBlockStart {
581        index: usize,
582        content_block: ResponseContent,
583    },
584    #[serde(rename = "content_block_delta")]
585    ContentBlockDelta { index: usize, delta: ContentDelta },
586    #[serde(rename = "content_block_stop")]
587    ContentBlockStop { index: usize },
588    #[serde(rename = "message_delta")]
589    MessageDelta { delta: MessageDelta, usage: Usage },
590    #[serde(rename = "message_stop")]
591    MessageStop,
592    #[serde(rename = "ping")]
593    Ping,
594    #[serde(rename = "error")]
595    Error { error: ApiError },
596}
597
598#[derive(Debug, Serialize, Deserialize)]
599#[serde(tag = "type")]
600pub enum ContentDelta {
601    #[serde(rename = "text_delta")]
602    TextDelta { text: String },
603    #[serde(rename = "input_json_delta")]
604    InputJsonDelta { partial_json: String },
605}
606
607#[derive(Debug, Serialize, Deserialize)]
608pub struct MessageDelta {
609    pub stop_reason: Option<String>,
610    pub stop_sequence: Option<String>,
611}
612
613#[derive(Error, Debug)]
614pub enum AnthropicError {
615    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
616    ApiError(ApiError),
617    #[error("{0}")]
618    Other(#[from] anyhow::Error),
619}
620
621#[derive(Debug, Serialize, Deserialize)]
622pub struct ApiError {
623    #[serde(rename = "type")]
624    pub error_type: String,
625    pub message: String,
626}
627
628/// An Anthropic API error code.
629/// <https://docs.anthropic.com/en/api/errors#http-errors>
630#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
631#[strum(serialize_all = "snake_case")]
632pub enum ApiErrorCode {
633    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
634    InvalidRequestError,
635    /// 401 - `authentication_error`: There's an issue with your API key.
636    AuthenticationError,
637    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
638    PermissionError,
639    /// 404 - `not_found_error`: The requested resource was not found.
640    NotFoundError,
641    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
642    RequestTooLarge,
643    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
644    RateLimitError,
645    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
646    ApiError,
647    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
648    OverloadedError,
649}
650
651impl ApiError {
652    pub fn code(&self) -> Option<ApiErrorCode> {
653        ApiErrorCode::from_str(&self.error_type).ok()
654    }
655
656    pub fn is_rate_limit_error(&self) -> bool {
657        matches!(self.error_type.as_str(), "rate_limit_error")
658    }
659}