anthropic.rs

  1mod supported_countries;
  2
  3use anyhow::{anyhow, Context, Result};
  4use chrono::{DateTime, Utc};
  5use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  6use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  7use isahc::config::Configurable;
  8use isahc::http::{HeaderMap, HeaderValue};
  9use serde::{Deserialize, Serialize};
 10use std::time::Duration;
 11use std::{pin::Pin, str::FromStr};
 12use strum::{EnumIter, EnumString};
 13use thiserror::Error;
 14use util::ResultExt as _;
 15
 16pub use supported_countries::*;
 17
 18pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
 19
 20#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 21#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 22pub struct AnthropicModelCacheConfiguration {
 23    pub min_total_token: usize,
 24    pub should_speculate: bool,
 25    pub max_cache_anchors: usize,
 26}
 27
 28#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 29#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 30pub enum Model {
 31    #[default]
 32    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-20240620")]
 33    Claude3_5Sonnet,
 34    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
 35    Claude3Opus,
 36    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
 37    Claude3Sonnet,
 38    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
 39    Claude3Haiku,
 40    #[serde(rename = "custom")]
 41    Custom {
 42        name: String,
 43        max_tokens: usize,
 44        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 45        display_name: Option<String>,
 46        /// Override this model with a different Anthropic model for tool calls.
 47        tool_override: Option<String>,
 48        /// Indicates whether this custom model supports caching.
 49        cache_configuration: Option<AnthropicModelCacheConfiguration>,
 50        max_output_tokens: Option<u32>,
 51    },
 52}
 53
 54impl Model {
 55    pub fn from_id(id: &str) -> Result<Self> {
 56        if id.starts_with("claude-3-5-sonnet") {
 57            Ok(Self::Claude3_5Sonnet)
 58        } else if id.starts_with("claude-3-opus") {
 59            Ok(Self::Claude3Opus)
 60        } else if id.starts_with("claude-3-sonnet") {
 61            Ok(Self::Claude3Sonnet)
 62        } else if id.starts_with("claude-3-haiku") {
 63            Ok(Self::Claude3Haiku)
 64        } else {
 65            Err(anyhow!("invalid model id"))
 66        }
 67    }
 68
 69    pub fn id(&self) -> &str {
 70        match self {
 71            Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
 72            Model::Claude3Opus => "claude-3-opus-20240229",
 73            Model::Claude3Sonnet => "claude-3-sonnet-20240229",
 74            Model::Claude3Haiku => "claude-3-haiku-20240307",
 75            Self::Custom { name, .. } => name,
 76        }
 77    }
 78
 79    pub fn display_name(&self) -> &str {
 80        match self {
 81            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
 82            Self::Claude3Opus => "Claude 3 Opus",
 83            Self::Claude3Sonnet => "Claude 3 Sonnet",
 84            Self::Claude3Haiku => "Claude 3 Haiku",
 85            Self::Custom {
 86                name, display_name, ..
 87            } => display_name.as_ref().unwrap_or(name),
 88        }
 89    }
 90
 91    pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
 92        match self {
 93            Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
 94                min_total_token: 2_048,
 95                should_speculate: true,
 96                max_cache_anchors: 4,
 97            }),
 98            Self::Custom {
 99                cache_configuration,
100                ..
101            } => cache_configuration.clone(),
102            _ => None,
103        }
104    }
105
106    pub fn max_token_count(&self) -> usize {
107        match self {
108            Self::Claude3_5Sonnet
109            | Self::Claude3Opus
110            | Self::Claude3Sonnet
111            | Self::Claude3Haiku => 200_000,
112            Self::Custom { max_tokens, .. } => *max_tokens,
113        }
114    }
115
116    pub fn max_output_tokens(&self) -> u32 {
117        match self {
118            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
119            Self::Claude3_5Sonnet => 8_192,
120            Self::Custom {
121                max_output_tokens, ..
122            } => max_output_tokens.unwrap_or(4_096),
123        }
124    }
125
126    pub fn tool_model_id(&self) -> &str {
127        if let Self::Custom {
128            tool_override: Some(tool_override),
129            ..
130        } = self
131        {
132            tool_override
133        } else {
134            self.id()
135        }
136    }
137}
138
139pub async fn complete(
140    client: &dyn HttpClient,
141    api_url: &str,
142    api_key: &str,
143    request: Request,
144) -> Result<Response, AnthropicError> {
145    let uri = format!("{api_url}/v1/messages");
146    let request_builder = HttpRequest::builder()
147        .method(Method::POST)
148        .uri(uri)
149        .header("Anthropic-Version", "2023-06-01")
150        .header(
151            "Anthropic-Beta",
152            "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
153        )
154        .header("X-Api-Key", api_key)
155        .header("Content-Type", "application/json");
156
157    let serialized_request =
158        serde_json::to_string(&request).context("failed to serialize request")?;
159    let request = request_builder
160        .body(AsyncBody::from(serialized_request))
161        .context("failed to construct request body")?;
162
163    let mut response = client
164        .send(request)
165        .await
166        .context("failed to send request to Anthropic")?;
167    if response.status().is_success() {
168        let mut body = Vec::new();
169        response
170            .body_mut()
171            .read_to_end(&mut body)
172            .await
173            .context("failed to read response body")?;
174        let response_message: Response =
175            serde_json::from_slice(&body).context("failed to deserialize response body")?;
176        Ok(response_message)
177    } else {
178        let mut body = Vec::new();
179        response
180            .body_mut()
181            .read_to_end(&mut body)
182            .await
183            .context("failed to read response body")?;
184        let body_str =
185            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
186        Err(AnthropicError::Other(anyhow!(
187            "Failed to connect to API: {} {}",
188            response.status(),
189            body_str
190        )))
191    }
192}
193
194pub async fn stream_completion(
195    client: &dyn HttpClient,
196    api_url: &str,
197    api_key: &str,
198    request: Request,
199    low_speed_timeout: Option<Duration>,
200) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
201    stream_completion_with_rate_limit_info(client, api_url, api_key, request, low_speed_timeout)
202        .await
203        .map(|output| output.0)
204}
205
206/// https://docs.anthropic.com/en/api/rate-limits#response-headers
207#[derive(Debug)]
208pub struct RateLimitInfo {
209    pub requests_limit: usize,
210    pub requests_remaining: usize,
211    pub requests_reset: DateTime<Utc>,
212    pub tokens_limit: usize,
213    pub tokens_remaining: usize,
214    pub tokens_reset: DateTime<Utc>,
215}
216
217impl RateLimitInfo {
218    fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
219        let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
220        let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
221        let tokens_remaining =
222            get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
223        let requests_remaining =
224            get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
225        let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
226        let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
227        let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
228        let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
229
230        Ok(Self {
231            requests_limit,
232            tokens_limit,
233            requests_remaining,
234            tokens_remaining,
235            requests_reset,
236            tokens_reset,
237        })
238    }
239}
240
241fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
242    Ok(headers
243        .get(key)
244        .ok_or_else(|| anyhow!("missing header `{key}`"))?
245        .to_str()?)
246}
247
248pub async fn stream_completion_with_rate_limit_info(
249    client: &dyn HttpClient,
250    api_url: &str,
251    api_key: &str,
252    request: Request,
253    low_speed_timeout: Option<Duration>,
254) -> Result<
255    (
256        BoxStream<'static, Result<Event, AnthropicError>>,
257        Option<RateLimitInfo>,
258    ),
259    AnthropicError,
260> {
261    let request = StreamingRequest {
262        base: request,
263        stream: true,
264    };
265    let uri = format!("{api_url}/v1/messages");
266    let mut request_builder = HttpRequest::builder()
267        .method(Method::POST)
268        .uri(uri)
269        .header("Anthropic-Version", "2023-06-01")
270        .header(
271            "Anthropic-Beta",
272            "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
273        )
274        .header("X-Api-Key", api_key)
275        .header("Content-Type", "application/json");
276    if let Some(low_speed_timeout) = low_speed_timeout {
277        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
278    }
279    let serialized_request =
280        serde_json::to_string(&request).context("failed to serialize request")?;
281    let request = request_builder
282        .body(AsyncBody::from(serialized_request))
283        .context("failed to construct request body")?;
284
285    let mut response = client
286        .send(request)
287        .await
288        .context("failed to send request to Anthropic")?;
289    if response.status().is_success() {
290        let rate_limits = RateLimitInfo::from_headers(response.headers());
291        let reader = BufReader::new(response.into_body());
292        let stream = reader
293            .lines()
294            .filter_map(|line| async move {
295                match line {
296                    Ok(line) => {
297                        let line = line.strip_prefix("data: ")?;
298                        match serde_json::from_str(line) {
299                            Ok(response) => Some(Ok(response)),
300                            Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
301                        }
302                    }
303                    Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
304                }
305            })
306            .boxed();
307        Ok((stream, rate_limits.log_err()))
308    } else {
309        let mut body = Vec::new();
310        response
311            .body_mut()
312            .read_to_end(&mut body)
313            .await
314            .context("failed to read response body")?;
315
316        let body_str =
317            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
318
319        match serde_json::from_str::<Event>(body_str) {
320            Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
321            Ok(_) => Err(AnthropicError::Other(anyhow!(
322                "Unexpected success response while expecting an error: '{body_str}'",
323            ))),
324            Err(_) => Err(AnthropicError::Other(anyhow!(
325                "Failed to connect to API: {} {}",
326                response.status(),
327                body_str,
328            ))),
329        }
330    }
331}
332
333pub fn extract_text_from_events(
334    response: impl Stream<Item = Result<Event, AnthropicError>>,
335) -> impl Stream<Item = Result<String, AnthropicError>> {
336    response.filter_map(|response| async move {
337        match response {
338            Ok(response) => match response {
339                Event::ContentBlockStart { content_block, .. } => match content_block {
340                    Content::Text { text, .. } => Some(Ok(text)),
341                    _ => None,
342                },
343                Event::ContentBlockDelta { delta, .. } => match delta {
344                    ContentDelta::TextDelta { text } => Some(Ok(text)),
345                    _ => None,
346                },
347                Event::Error { error } => Some(Err(AnthropicError::ApiError(error))),
348                _ => None,
349            },
350            Err(error) => Some(Err(error)),
351        }
352    })
353}
354
355pub async fn extract_tool_args_from_events(
356    tool_name: String,
357    mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
358) -> Result<impl Send + Stream<Item = Result<String>>> {
359    let mut tool_use_index = None;
360    while let Some(event) = events.next().await {
361        if let Event::ContentBlockStart {
362            index,
363            content_block,
364        } = event?
365        {
366            if let Content::ToolUse { name, .. } = content_block {
367                if name == tool_name {
368                    tool_use_index = Some(index);
369                    break;
370                }
371            }
372        }
373    }
374
375    let Some(tool_use_index) = tool_use_index else {
376        return Err(anyhow!("tool not used"));
377    };
378
379    Ok(events.filter_map(move |event| {
380        let result = match event {
381            Err(error) => Some(Err(error)),
382            Ok(Event::ContentBlockDelta { index, delta }) => match delta {
383                ContentDelta::TextDelta { .. } => None,
384                ContentDelta::InputJsonDelta { partial_json } => {
385                    if index == tool_use_index {
386                        Some(Ok(partial_json))
387                    } else {
388                        None
389                    }
390                }
391            },
392            _ => None,
393        };
394
395        async move { result }
396    }))
397}
398
399#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
400#[serde(rename_all = "lowercase")]
401pub enum CacheControlType {
402    Ephemeral,
403}
404
405#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
406pub struct CacheControl {
407    #[serde(rename = "type")]
408    pub cache_type: CacheControlType,
409}
410
411#[derive(Debug, Serialize, Deserialize)]
412pub struct Message {
413    pub role: Role,
414    pub content: Vec<Content>,
415}
416
417#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
418#[serde(rename_all = "lowercase")]
419pub enum Role {
420    User,
421    Assistant,
422}
423
424#[derive(Debug, Serialize, Deserialize)]
425#[serde(tag = "type")]
426pub enum Content {
427    #[serde(rename = "text")]
428    Text {
429        text: String,
430        #[serde(skip_serializing_if = "Option::is_none")]
431        cache_control: Option<CacheControl>,
432    },
433    #[serde(rename = "image")]
434    Image {
435        source: ImageSource,
436        #[serde(skip_serializing_if = "Option::is_none")]
437        cache_control: Option<CacheControl>,
438    },
439    #[serde(rename = "tool_use")]
440    ToolUse {
441        id: String,
442        name: String,
443        input: serde_json::Value,
444        #[serde(skip_serializing_if = "Option::is_none")]
445        cache_control: Option<CacheControl>,
446    },
447    #[serde(rename = "tool_result")]
448    ToolResult {
449        tool_use_id: String,
450        content: String,
451        #[serde(skip_serializing_if = "Option::is_none")]
452        cache_control: Option<CacheControl>,
453    },
454}
455
456#[derive(Debug, Serialize, Deserialize)]
457pub struct ImageSource {
458    #[serde(rename = "type")]
459    pub source_type: String,
460    pub media_type: String,
461    pub data: String,
462}
463
464#[derive(Debug, Serialize, Deserialize)]
465pub struct Tool {
466    pub name: String,
467    pub description: String,
468    pub input_schema: serde_json::Value,
469}
470
471#[derive(Debug, Serialize, Deserialize)]
472#[serde(tag = "type", rename_all = "lowercase")]
473pub enum ToolChoice {
474    Auto,
475    Any,
476    Tool { name: String },
477}
478
479#[derive(Debug, Serialize, Deserialize)]
480pub struct Request {
481    pub model: String,
482    pub max_tokens: u32,
483    pub messages: Vec<Message>,
484    #[serde(default, skip_serializing_if = "Vec::is_empty")]
485    pub tools: Vec<Tool>,
486    #[serde(default, skip_serializing_if = "Option::is_none")]
487    pub tool_choice: Option<ToolChoice>,
488    #[serde(default, skip_serializing_if = "Option::is_none")]
489    pub system: Option<String>,
490    #[serde(default, skip_serializing_if = "Option::is_none")]
491    pub metadata: Option<Metadata>,
492    #[serde(default, skip_serializing_if = "Vec::is_empty")]
493    pub stop_sequences: Vec<String>,
494    #[serde(default, skip_serializing_if = "Option::is_none")]
495    pub temperature: Option<f32>,
496    #[serde(default, skip_serializing_if = "Option::is_none")]
497    pub top_k: Option<u32>,
498    #[serde(default, skip_serializing_if = "Option::is_none")]
499    pub top_p: Option<f32>,
500}
501
502#[derive(Debug, Serialize, Deserialize)]
503struct StreamingRequest {
504    #[serde(flatten)]
505    pub base: Request,
506    pub stream: bool,
507}
508
509#[derive(Debug, Serialize, Deserialize)]
510pub struct Metadata {
511    pub user_id: Option<String>,
512}
513
514#[derive(Debug, Serialize, Deserialize)]
515pub struct Usage {
516    #[serde(default, skip_serializing_if = "Option::is_none")]
517    pub input_tokens: Option<u32>,
518    #[serde(default, skip_serializing_if = "Option::is_none")]
519    pub output_tokens: Option<u32>,
520}
521
522#[derive(Debug, Serialize, Deserialize)]
523pub struct Response {
524    pub id: String,
525    #[serde(rename = "type")]
526    pub response_type: String,
527    pub role: Role,
528    pub content: Vec<Content>,
529    pub model: String,
530    #[serde(default, skip_serializing_if = "Option::is_none")]
531    pub stop_reason: Option<String>,
532    #[serde(default, skip_serializing_if = "Option::is_none")]
533    pub stop_sequence: Option<String>,
534    pub usage: Usage,
535}
536
537#[derive(Debug, Serialize, Deserialize)]
538#[serde(tag = "type")]
539pub enum Event {
540    #[serde(rename = "message_start")]
541    MessageStart { message: Response },
542    #[serde(rename = "content_block_start")]
543    ContentBlockStart {
544        index: usize,
545        content_block: Content,
546    },
547    #[serde(rename = "content_block_delta")]
548    ContentBlockDelta { index: usize, delta: ContentDelta },
549    #[serde(rename = "content_block_stop")]
550    ContentBlockStop { index: usize },
551    #[serde(rename = "message_delta")]
552    MessageDelta { delta: MessageDelta, usage: Usage },
553    #[serde(rename = "message_stop")]
554    MessageStop,
555    #[serde(rename = "ping")]
556    Ping,
557    #[serde(rename = "error")]
558    Error { error: ApiError },
559}
560
561#[derive(Debug, Serialize, Deserialize)]
562#[serde(tag = "type")]
563pub enum ContentDelta {
564    #[serde(rename = "text_delta")]
565    TextDelta { text: String },
566    #[serde(rename = "input_json_delta")]
567    InputJsonDelta { partial_json: String },
568}
569
570#[derive(Debug, Serialize, Deserialize)]
571pub struct MessageDelta {
572    pub stop_reason: Option<String>,
573    pub stop_sequence: Option<String>,
574}
575
576#[derive(Error, Debug)]
577pub enum AnthropicError {
578    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
579    ApiError(ApiError),
580    #[error("{0}")]
581    Other(#[from] anyhow::Error),
582}
583
584#[derive(Debug, Serialize, Deserialize)]
585pub struct ApiError {
586    #[serde(rename = "type")]
587    pub error_type: String,
588    pub message: String,
589}
590
591/// An Anthropic API error code.
592/// https://docs.anthropic.com/en/api/errors#http-errors
593#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
594#[strum(serialize_all = "snake_case")]
595pub enum ApiErrorCode {
596    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
597    InvalidRequestError,
598    /// 401 - `authentication_error`: There's an issue with your API key.
599    AuthenticationError,
600    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
601    PermissionError,
602    /// 404 - `not_found_error`: The requested resource was not found.
603    NotFoundError,
604    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
605    RequestTooLarge,
606    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
607    RateLimitError,
608    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
609    ApiError,
610    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
611    OverloadedError,
612}
613
614impl ApiError {
615    pub fn code(&self) -> Option<ApiErrorCode> {
616        ApiErrorCode::from_str(&self.error_type).ok()
617    }
618
619    pub fn is_rate_limit_error(&self) -> bool {
620        match self.error_type.as_str() {
621            "rate_limit_error" => true,
622            _ => false,
623        }
624    }
625}