anthropic.rs

  1mod supported_countries;
  2
  3use std::time::Duration;
  4use std::{pin::Pin, str::FromStr};
  5
  6use anyhow::{anyhow, Context, Result};
  7use chrono::{DateTime, Utc};
  8use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 10use isahc::config::Configurable;
 11use isahc::http::{HeaderMap, HeaderValue};
 12use serde::{Deserialize, Serialize};
 13use strum::{EnumIter, EnumString};
 14use thiserror::Error;
 15use util::ResultExt as _;
 16
 17pub use supported_countries::*;
 18
/// Default base URL for the Anthropic API.
///
/// The request helpers in this module take the base URL as a parameter and append
/// `/v1/messages` to it; this constant is presumably what callers pass when no override
/// is configured.
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
 20
/// Prompt-caching parameters for a model.
///
/// Built-in values come from `Model::cache_configuration`; custom models supply their own.
/// The code that consumes these fields is not in this file, so the per-field notes below
/// are inferred from the names.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct AnthropicModelCacheConfiguration {
    /// Presumably the minimum total prompt size, in tokens, before caching is attempted
    /// (TODO confirm against the consumer).
    pub min_total_token: usize,
    /// Presumably whether cache anchors may be placed speculatively (TODO confirm).
    pub should_speculate: bool,
    /// Presumably the maximum number of `cache_control` anchors per request (TODO confirm).
    pub max_cache_anchors: usize,
}
 28
/// An Anthropic model: one of the built-in Claude 3 models or a user-configured custom model.
///
/// Built-in variants serialize under an undated name (e.g. `claude-3-opus`). The serde
/// `alias` attributes also let deserialization accept the dated snapshot id. The default
/// is [`Model::Claude3_5Sonnet`].
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[default]
    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-20240620")]
    Claude3_5Sonnet,
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
    Claude3Opus,
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
    Claude3Sonnet,
    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
    Claude3Haiku,
    /// A model described entirely by user-supplied settings.
    #[serde(rename = "custom")]
    Custom {
        /// The model id; returned by `Model::id`.
        name: String,
        /// Total token limit reported by `Model::max_token_count` (presumably the context window).
        max_tokens: usize,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Override this model with a different Anthropic model for tool calls.
        tool_override: Option<String>,
        /// Indicates whether this custom model supports caching.
        cache_configuration: Option<AnthropicModelCacheConfiguration>,
        /// Output token limit; `Model::max_output_tokens` falls back to 4096 when unset.
        max_output_tokens: Option<u32>,
        /// Default temperature; `Model::default_temperature` falls back to 1.0 when unset.
        default_temperature: Option<f32>,
    },
}
 55
 56impl Model {
 57    pub fn from_id(id: &str) -> Result<Self> {
 58        if id.starts_with("claude-3-5-sonnet") {
 59            Ok(Self::Claude3_5Sonnet)
 60        } else if id.starts_with("claude-3-opus") {
 61            Ok(Self::Claude3Opus)
 62        } else if id.starts_with("claude-3-sonnet") {
 63            Ok(Self::Claude3Sonnet)
 64        } else if id.starts_with("claude-3-haiku") {
 65            Ok(Self::Claude3Haiku)
 66        } else {
 67            Err(anyhow!("invalid model id"))
 68        }
 69    }
 70
 71    pub fn id(&self) -> &str {
 72        match self {
 73            Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
 74            Model::Claude3Opus => "claude-3-opus-20240229",
 75            Model::Claude3Sonnet => "claude-3-sonnet-20240229",
 76            Model::Claude3Haiku => "claude-3-haiku-20240307",
 77            Self::Custom { name, .. } => name,
 78        }
 79    }
 80
 81    pub fn display_name(&self) -> &str {
 82        match self {
 83            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
 84            Self::Claude3Opus => "Claude 3 Opus",
 85            Self::Claude3Sonnet => "Claude 3 Sonnet",
 86            Self::Claude3Haiku => "Claude 3 Haiku",
 87            Self::Custom {
 88                name, display_name, ..
 89            } => display_name.as_ref().unwrap_or(name),
 90        }
 91    }
 92
 93    pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
 94        match self {
 95            Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
 96                min_total_token: 2_048,
 97                should_speculate: true,
 98                max_cache_anchors: 4,
 99            }),
100            Self::Custom {
101                cache_configuration,
102                ..
103            } => cache_configuration.clone(),
104            _ => None,
105        }
106    }
107
108    pub fn max_token_count(&self) -> usize {
109        match self {
110            Self::Claude3_5Sonnet
111            | Self::Claude3Opus
112            | Self::Claude3Sonnet
113            | Self::Claude3Haiku => 200_000,
114            Self::Custom { max_tokens, .. } => *max_tokens,
115        }
116    }
117
118    pub fn max_output_tokens(&self) -> u32 {
119        match self {
120            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
121            Self::Claude3_5Sonnet => 8_192,
122            Self::Custom {
123                max_output_tokens, ..
124            } => max_output_tokens.unwrap_or(4_096),
125        }
126    }
127
128    pub fn default_temperature(&self) -> f32 {
129        match self {
130            Self::Claude3_5Sonnet
131            | Self::Claude3Opus
132            | Self::Claude3Sonnet
133            | Self::Claude3Haiku => 1.0,
134            Self::Custom {
135                default_temperature,
136                ..
137            } => default_temperature.unwrap_or(1.0),
138        }
139    }
140
141    pub fn tool_model_id(&self) -> &str {
142        if let Self::Custom {
143            tool_override: Some(tool_override),
144            ..
145        } = self
146        {
147            tool_override
148        } else {
149            self.id()
150        }
151    }
152}
153
154pub async fn complete(
155    client: &dyn HttpClient,
156    api_url: &str,
157    api_key: &str,
158    request: Request,
159) -> Result<Response, AnthropicError> {
160    let uri = format!("{api_url}/v1/messages");
161    let request_builder = HttpRequest::builder()
162        .method(Method::POST)
163        .uri(uri)
164        .header("Anthropic-Version", "2023-06-01")
165        .header(
166            "Anthropic-Beta",
167            "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
168        )
169        .header("X-Api-Key", api_key)
170        .header("Content-Type", "application/json");
171
172    let serialized_request =
173        serde_json::to_string(&request).context("failed to serialize request")?;
174    let request = request_builder
175        .body(AsyncBody::from(serialized_request))
176        .context("failed to construct request body")?;
177
178    let mut response = client
179        .send(request)
180        .await
181        .context("failed to send request to Anthropic")?;
182    if response.status().is_success() {
183        let mut body = Vec::new();
184        response
185            .body_mut()
186            .read_to_end(&mut body)
187            .await
188            .context("failed to read response body")?;
189        let response_message: Response =
190            serde_json::from_slice(&body).context("failed to deserialize response body")?;
191        Ok(response_message)
192    } else {
193        let mut body = Vec::new();
194        response
195            .body_mut()
196            .read_to_end(&mut body)
197            .await
198            .context("failed to read response body")?;
199        let body_str =
200            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
201        Err(AnthropicError::Other(anyhow!(
202            "Failed to connect to API: {} {}",
203            response.status(),
204            body_str
205        )))
206    }
207}
208
209pub async fn stream_completion(
210    client: &dyn HttpClient,
211    api_url: &str,
212    api_key: &str,
213    request: Request,
214    low_speed_timeout: Option<Duration>,
215) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
216    stream_completion_with_rate_limit_info(client, api_url, api_key, request, low_speed_timeout)
217        .await
218        .map(|output| output.0)
219}
220
/// Rate-limit state reported by the Anthropic API in the `anthropic-ratelimit-*` response
/// headers. Built by `RateLimitInfo::from_headers`.
///
/// <https://docs.anthropic.com/en/api/rate-limits#response-headers>
#[derive(Debug)]
pub struct RateLimitInfo {
    /// Parsed from `anthropic-ratelimit-requests-limit`.
    pub requests_limit: usize,
    /// Parsed from `anthropic-ratelimit-requests-remaining`.
    pub requests_remaining: usize,
    /// Parsed from `anthropic-ratelimit-requests-reset` (RFC 3339), converted to UTC.
    pub requests_reset: DateTime<Utc>,
    /// Parsed from `anthropic-ratelimit-tokens-limit`.
    pub tokens_limit: usize,
    /// Parsed from `anthropic-ratelimit-tokens-remaining`.
    pub tokens_remaining: usize,
    /// Parsed from `anthropic-ratelimit-tokens-reset` (RFC 3339), converted to UTC.
    pub tokens_reset: DateTime<Utc>,
}
231
232impl RateLimitInfo {
233    fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
234        let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
235        let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
236        let tokens_remaining =
237            get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
238        let requests_remaining =
239            get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
240        let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
241        let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
242        let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
243        let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
244
245        Ok(Self {
246            requests_limit,
247            tokens_limit,
248            requests_remaining,
249            tokens_remaining,
250            requests_reset,
251            tokens_reset,
252        })
253    }
254}
255
256fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
257    Ok(headers
258        .get(key)
259        .ok_or_else(|| anyhow!("missing header `{key}`"))?
260        .to_str()?)
261}
262
263pub async fn stream_completion_with_rate_limit_info(
264    client: &dyn HttpClient,
265    api_url: &str,
266    api_key: &str,
267    request: Request,
268    low_speed_timeout: Option<Duration>,
269) -> Result<
270    (
271        BoxStream<'static, Result<Event, AnthropicError>>,
272        Option<RateLimitInfo>,
273    ),
274    AnthropicError,
275> {
276    let request = StreamingRequest {
277        base: request,
278        stream: true,
279    };
280    let uri = format!("{api_url}/v1/messages");
281    let mut request_builder = HttpRequest::builder()
282        .method(Method::POST)
283        .uri(uri)
284        .header("Anthropic-Version", "2023-06-01")
285        .header(
286            "Anthropic-Beta",
287            "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
288        )
289        .header("X-Api-Key", api_key)
290        .header("Content-Type", "application/json");
291    if let Some(low_speed_timeout) = low_speed_timeout {
292        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
293    }
294    let serialized_request =
295        serde_json::to_string(&request).context("failed to serialize request")?;
296    let request = request_builder
297        .body(AsyncBody::from(serialized_request))
298        .context("failed to construct request body")?;
299
300    let mut response = client
301        .send(request)
302        .await
303        .context("failed to send request to Anthropic")?;
304    if response.status().is_success() {
305        let rate_limits = RateLimitInfo::from_headers(response.headers());
306        let reader = BufReader::new(response.into_body());
307        let stream = reader
308            .lines()
309            .filter_map(|line| async move {
310                match line {
311                    Ok(line) => {
312                        let line = line.strip_prefix("data: ")?;
313                        match serde_json::from_str(line) {
314                            Ok(response) => Some(Ok(response)),
315                            Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
316                        }
317                    }
318                    Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
319                }
320            })
321            .boxed();
322        Ok((stream, rate_limits.log_err()))
323    } else {
324        let mut body = Vec::new();
325        response
326            .body_mut()
327            .read_to_end(&mut body)
328            .await
329            .context("failed to read response body")?;
330
331        let body_str =
332            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
333
334        match serde_json::from_str::<Event>(body_str) {
335            Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
336            Ok(_) => Err(AnthropicError::Other(anyhow!(
337                "Unexpected success response while expecting an error: '{body_str}'",
338            ))),
339            Err(_) => Err(AnthropicError::Other(anyhow!(
340                "Failed to connect to API: {} {}",
341                response.status(),
342                body_str,
343            ))),
344        }
345    }
346}
347
348pub async fn extract_tool_args_from_events(
349    tool_name: String,
350    mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
351) -> Result<impl Send + Stream<Item = Result<String>>> {
352    let mut tool_use_index = None;
353    while let Some(event) = events.next().await {
354        if let Event::ContentBlockStart {
355            index,
356            content_block: ResponseContent::ToolUse { name, .. },
357        } = event?
358        {
359            if name == tool_name {
360                tool_use_index = Some(index);
361                break;
362            }
363        }
364    }
365
366    let Some(tool_use_index) = tool_use_index else {
367        return Err(anyhow!("tool not used"));
368    };
369
370    Ok(events.filter_map(move |event| {
371        let result = match event {
372            Err(error) => Some(Err(error)),
373            Ok(Event::ContentBlockDelta { index, delta }) => match delta {
374                ContentDelta::TextDelta { .. } => None,
375                ContentDelta::InputJsonDelta { partial_json } => {
376                    if index == tool_use_index {
377                        Some(Ok(partial_json))
378                    } else {
379                        None
380                    }
381                }
382            },
383            _ => None,
384        };
385
386        async move { result }
387    }))
388}
389
/// The kind of prompt caching requested for a content block; serialized in lowercase
/// (`"ephemeral"`).
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
#[serde(rename_all = "lowercase")]
pub enum CacheControlType {
    Ephemeral,
}

/// The `cache_control` object attached to a request content block, serialized as
/// `{"type": "ephemeral"}`.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct CacheControl {
    #[serde(rename = "type")]
    pub cache_type: CacheControlType,
}
401
/// A single conversation turn in a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// The author of this turn.
    pub role: Role,
    /// The turn's content blocks, in order.
    pub content: Vec<RequestContent>,
}

/// The author of a [`Message`], serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
414
/// A content block sent in a request [`Message`].
///
/// The enum is internally tagged by a `"type"` field (`text`, `image`, `tool_use`,
/// `tool_result`). Every variant may carry an optional [`CacheControl`], which is omitted
/// from the serialized JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum RequestContent {
    /// Plain text.
    #[serde(rename = "text")]
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Inline image data.
    #[serde(rename = "image")]
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation (call `id`, tool `name`, JSON `input`), e.g. an earlier assistant
    /// tool call replayed in the conversation history.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// The outcome of a tool invocation, linked to it by `tool_use_id`.
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        is_error: bool,
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
447
/// A content block returned by the API, either inside a [`Response`] or announced by
/// [`Event::ContentBlockStart`]. Internally tagged by `"type"` (`text` or `tool_use`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseContent {
    /// Generated text.
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool call requested by the model, with its JSON `input`.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
460
/// Inline image payload for [`RequestContent::Image`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
    /// The source kind, serialized as the `type` field (presumably `"base64"`; confirm
    /// against callers).
    #[serde(rename = "type")]
    pub source_type: String,
    /// MIME type of the image, e.g. `image/png`.
    pub media_type: String,
    /// The encoded image data.
    pub data: String,
}

/// A tool the model may call.
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    pub name: String,
    pub description: String,
    /// Presumably a JSON Schema describing the tool's input object.
    pub input_schema: serde_json::Value,
}

/// How the model should choose among the request's tools.
///
/// The enum is internally tagged by `type` with lowercase names: `{"type": "auto"}`,
/// `{"type": "any"}`, or `{"type": "tool", "name": ...}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    Tool { name: String },
}
483
/// Body of a `POST /v1/messages` request.
///
/// Optional fields that are `None` and lists that are empty are omitted from the
/// serialized JSON. When deserializing, they default to `None`/empty.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub max_tokens: u32,
    pub messages: Vec<Message>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}

/// A [`Request`] plus a `stream` flag.
///
/// `#[serde(flatten)]` writes the base request's fields at the top level, next to
/// `stream`.
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
    #[serde(flatten)]
    pub base: Request,
    pub stream: bool,
}
513
/// Optional request metadata.
///
/// NOTE(review): unlike the optional fields of `Request`, `user_id` has no
/// `skip_serializing_if`, so `None` is sent as `null`. Confirm the API accepts that.
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    pub user_id: Option<String>,
}

/// Token counts reported by the API.
///
/// Either count may be absent. Absent counts deserialize as `None` and are omitted when
/// serializing.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
}
526
/// A Messages API message.
///
/// This is returned whole by the non-streaming `complete` and carried by
/// [`Event::MessageStart`] when streaming.
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    pub id: String,
    /// The object type, serialized as the `type` field.
    #[serde(rename = "type")]
    pub response_type: String,
    pub role: Role,
    pub content: Vec<ResponseContent>,
    pub model: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}
541
/// A server-sent event from a streaming Messages API response, internally tagged by `"type"`.
///
/// The streaming request function deserializes each `data:` line into one of these. It
/// also parses non-success HTTP bodies as an `Event` to extract an [`Event::Error`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Event {
    /// Carries the initial [`Response`] for the message.
    #[serde(rename = "message_start")]
    MessageStart { message: Response },
    /// A content block with position `index` begins.
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: ResponseContent,
    },
    /// An incremental update to the content block at `index`.
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: ContentDelta },
    /// The content block at `index` has ended.
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    /// Message-level updates plus token usage.
    #[serde(rename = "message_delta")]
    MessageDelta { delta: MessageDelta, usage: Usage },
    #[serde(rename = "message_stop")]
    MessageStop,
    #[serde(rename = "ping")]
    Ping,
    /// An error reported by the API.
    #[serde(rename = "error")]
    Error { error: ApiError },
}

/// Payload of an [`Event::ContentBlockDelta`].
///
/// It is either a text fragment or a fragment of a tool call's JSON input. The
/// `partial_json` fragment is not valid JSON on its own; `extract_tool_args_from_events`
/// forwards it as-is.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentDelta {
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}

/// Message-level fields carried by [`Event::MessageDelta`].
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
580
/// Errors returned by the request functions in this module.
#[derive(Error, Debug)]
pub enum AnthropicError {
    /// A structured error reported by the API in an error response body.
    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
    ApiError(ApiError),
    /// Any other failure, such as transport, I/O, (de)serialization, or an unstructured
    /// HTTP error.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}

/// The `error` object of an API error response.
///
/// `error_type` holds the raw `type` string; use [`ApiError::code`] to parse it into an
/// [`ApiErrorCode`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
    #[serde(rename = "type")]
    pub error_type: String,
    pub message: String,
}
595
/// An Anthropic API error code.
/// <https://docs.anthropic.com/en/api/errors#http-errors>
///
/// Parsed from [`ApiError::error_type`] by [`ApiError::code`] using strum's `EnumString`
/// with snake_case names, so e.g. `"rate_limit_error"` maps to
/// [`ApiErrorCode::RateLimitError`]. Unrecognized strings do not parse.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
    InvalidRequestError,
    /// 401 - `authentication_error`: There's an issue with your API key.
    AuthenticationError,
    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
    PermissionError,
    /// 404 - `not_found_error`: The requested resource was not found.
    NotFoundError,
    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
    RequestTooLarge,
    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
    RateLimitError,
    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
    ApiError,
    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
    OverloadedError,
}
618
619impl ApiError {
620    pub fn code(&self) -> Option<ApiErrorCode> {
621        ApiErrorCode::from_str(&self.error_type).ok()
622    }
623
624    pub fn is_rate_limit_error(&self) -> bool {
625        matches!(self.error_type.as_str(), "rate_limit_error")
626    }
627}