anthropic.rs

  1mod supported_countries;
  2
  3use std::time::Duration;
  4use std::{pin::Pin, str::FromStr};
  5
  6use anyhow::{anyhow, Context, Result};
  7use chrono::{DateTime, Utc};
  8use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 10use isahc::config::Configurable;
 11use isahc::http::{HeaderMap, HeaderValue};
 12use serde::{Deserialize, Serialize};
 13use strum::{EnumIter, EnumString};
 14use thiserror::Error;
 15use util::ResultExt as _;
 16
 17pub use supported_countries::*;
 18
/// Base URL of the public Anthropic API (no trailing slash).
///
/// NOTE(review): presumably the default value callers pass as `api_url` to the
/// request functions in this file, which append `/v1/messages` — confirm against callers.
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
 20
/// Prompt-caching parameters associated with a model.
///
/// Built-in values are returned by `Model::cache_configuration`; custom models
/// may supply their own through `Model::Custom::cache_configuration`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct AnthropicModelCacheConfiguration {
    // NOTE(review): presumably the minimum prompt size (in tokens) before
    // caching is attempted — this file only stores the value; confirm with the consumer.
    pub min_total_token: usize,
    // NOTE(review): meaning is defined by the consumer of this struct, not by
    // this file — confirm.
    pub should_speculate: bool,
    // NOTE(review): presumably the maximum number of cache breakpoints placed
    // in a single request — confirm with the consumer.
    pub max_cache_anchors: usize,
}
 28
/// The Anthropic models this file knows about, plus a user-defined `Custom` model.
///
/// Each built-in variant (de)serializes under its short name; the dated
/// identifier is additionally accepted as an alias when deserializing.
/// `Claude3_5Sonnet` is the `Default` variant.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[default]
    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-20240620")]
    Claude3_5Sonnet,
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
    Claude3Opus,
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
    Claude3Sonnet,
    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
    Claude3Haiku,
    /// A model described entirely by configuration rather than by a built-in variant.
    #[serde(rename = "custom")]
    Custom {
        /// Model identifier; this is what `Model::id` returns for a custom model.
        name: String,
        /// Value returned by `Model::max_token_count` for this model.
        max_tokens: usize,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Override this model with a different Anthropic model for tool calls.
        tool_override: Option<String>,
        /// Indicates whether this custom model supports caching.
        cache_configuration: Option<AnthropicModelCacheConfiguration>,
        /// Value returned by `Model::max_output_tokens`; 4096 is used when this is `None`.
        max_output_tokens: Option<u32>,
    },
}
 54
 55impl Model {
 56    pub fn from_id(id: &str) -> Result<Self> {
 57        if id.starts_with("claude-3-5-sonnet") {
 58            Ok(Self::Claude3_5Sonnet)
 59        } else if id.starts_with("claude-3-opus") {
 60            Ok(Self::Claude3Opus)
 61        } else if id.starts_with("claude-3-sonnet") {
 62            Ok(Self::Claude3Sonnet)
 63        } else if id.starts_with("claude-3-haiku") {
 64            Ok(Self::Claude3Haiku)
 65        } else {
 66            Err(anyhow!("invalid model id"))
 67        }
 68    }
 69
 70    pub fn id(&self) -> &str {
 71        match self {
 72            Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
 73            Model::Claude3Opus => "claude-3-opus-20240229",
 74            Model::Claude3Sonnet => "claude-3-sonnet-20240229",
 75            Model::Claude3Haiku => "claude-3-haiku-20240307",
 76            Self::Custom { name, .. } => name,
 77        }
 78    }
 79
 80    pub fn display_name(&self) -> &str {
 81        match self {
 82            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
 83            Self::Claude3Opus => "Claude 3 Opus",
 84            Self::Claude3Sonnet => "Claude 3 Sonnet",
 85            Self::Claude3Haiku => "Claude 3 Haiku",
 86            Self::Custom {
 87                name, display_name, ..
 88            } => display_name.as_ref().unwrap_or(name),
 89        }
 90    }
 91
 92    pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
 93        match self {
 94            Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
 95                min_total_token: 2_048,
 96                should_speculate: true,
 97                max_cache_anchors: 4,
 98            }),
 99            Self::Custom {
100                cache_configuration,
101                ..
102            } => cache_configuration.clone(),
103            _ => None,
104        }
105    }
106
107    pub fn max_token_count(&self) -> usize {
108        match self {
109            Self::Claude3_5Sonnet
110            | Self::Claude3Opus
111            | Self::Claude3Sonnet
112            | Self::Claude3Haiku => 200_000,
113            Self::Custom { max_tokens, .. } => *max_tokens,
114        }
115    }
116
117    pub fn max_output_tokens(&self) -> u32 {
118        match self {
119            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
120            Self::Claude3_5Sonnet => 8_192,
121            Self::Custom {
122                max_output_tokens, ..
123            } => max_output_tokens.unwrap_or(4_096),
124        }
125    }
126
127    pub fn tool_model_id(&self) -> &str {
128        if let Self::Custom {
129            tool_override: Some(tool_override),
130            ..
131        } = self
132        {
133            tool_override
134        } else {
135            self.id()
136        }
137    }
138}
139
140pub async fn complete(
141    client: &dyn HttpClient,
142    api_url: &str,
143    api_key: &str,
144    request: Request,
145) -> Result<Response, AnthropicError> {
146    let uri = format!("{api_url}/v1/messages");
147    let request_builder = HttpRequest::builder()
148        .method(Method::POST)
149        .uri(uri)
150        .header("Anthropic-Version", "2023-06-01")
151        .header(
152            "Anthropic-Beta",
153            "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
154        )
155        .header("X-Api-Key", api_key)
156        .header("Content-Type", "application/json");
157
158    let serialized_request =
159        serde_json::to_string(&request).context("failed to serialize request")?;
160    let request = request_builder
161        .body(AsyncBody::from(serialized_request))
162        .context("failed to construct request body")?;
163
164    let mut response = client
165        .send(request)
166        .await
167        .context("failed to send request to Anthropic")?;
168    if response.status().is_success() {
169        let mut body = Vec::new();
170        response
171            .body_mut()
172            .read_to_end(&mut body)
173            .await
174            .context("failed to read response body")?;
175        let response_message: Response =
176            serde_json::from_slice(&body).context("failed to deserialize response body")?;
177        Ok(response_message)
178    } else {
179        let mut body = Vec::new();
180        response
181            .body_mut()
182            .read_to_end(&mut body)
183            .await
184            .context("failed to read response body")?;
185        let body_str =
186            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
187        Err(AnthropicError::Other(anyhow!(
188            "Failed to connect to API: {} {}",
189            response.status(),
190            body_str
191        )))
192    }
193}
194
195pub async fn stream_completion(
196    client: &dyn HttpClient,
197    api_url: &str,
198    api_key: &str,
199    request: Request,
200    low_speed_timeout: Option<Duration>,
201) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
202    stream_completion_with_rate_limit_info(client, api_url, api_key, request, low_speed_timeout)
203        .await
204        .map(|output| output.0)
205}
206
/// Rate-limit state parsed from the `anthropic-ratelimit-*` response headers.
///
/// https://docs.anthropic.com/en/api/rate-limits#response-headers
#[derive(Debug)]
pub struct RateLimitInfo {
    /// Parsed from `anthropic-ratelimit-requests-limit`.
    pub requests_limit: usize,
    /// Parsed from `anthropic-ratelimit-requests-remaining`.
    pub requests_remaining: usize,
    /// Parsed from `anthropic-ratelimit-requests-reset` (RFC 3339), converted to UTC.
    pub requests_reset: DateTime<Utc>,
    /// Parsed from `anthropic-ratelimit-tokens-limit`.
    pub tokens_limit: usize,
    /// Parsed from `anthropic-ratelimit-tokens-remaining`.
    pub tokens_remaining: usize,
    /// Parsed from `anthropic-ratelimit-tokens-reset` (RFC 3339), converted to UTC.
    pub tokens_reset: DateTime<Utc>,
}
217
218impl RateLimitInfo {
219    fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
220        let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
221        let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
222        let tokens_remaining =
223            get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
224        let requests_remaining =
225            get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
226        let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
227        let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
228        let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
229        let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
230
231        Ok(Self {
232            requests_limit,
233            tokens_limit,
234            requests_remaining,
235            tokens_remaining,
236            requests_reset,
237            tokens_reset,
238        })
239    }
240}
241
242fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
243    Ok(headers
244        .get(key)
245        .ok_or_else(|| anyhow!("missing header `{key}`"))?
246        .to_str()?)
247}
248
249pub async fn stream_completion_with_rate_limit_info(
250    client: &dyn HttpClient,
251    api_url: &str,
252    api_key: &str,
253    request: Request,
254    low_speed_timeout: Option<Duration>,
255) -> Result<
256    (
257        BoxStream<'static, Result<Event, AnthropicError>>,
258        Option<RateLimitInfo>,
259    ),
260    AnthropicError,
261> {
262    let request = StreamingRequest {
263        base: request,
264        stream: true,
265    };
266    let uri = format!("{api_url}/v1/messages");
267    let mut request_builder = HttpRequest::builder()
268        .method(Method::POST)
269        .uri(uri)
270        .header("Anthropic-Version", "2023-06-01")
271        .header(
272            "Anthropic-Beta",
273            "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
274        )
275        .header("X-Api-Key", api_key)
276        .header("Content-Type", "application/json");
277    if let Some(low_speed_timeout) = low_speed_timeout {
278        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
279    }
280    let serialized_request =
281        serde_json::to_string(&request).context("failed to serialize request")?;
282    let request = request_builder
283        .body(AsyncBody::from(serialized_request))
284        .context("failed to construct request body")?;
285
286    let mut response = client
287        .send(request)
288        .await
289        .context("failed to send request to Anthropic")?;
290    if response.status().is_success() {
291        let rate_limits = RateLimitInfo::from_headers(response.headers());
292        let reader = BufReader::new(response.into_body());
293        let stream = reader
294            .lines()
295            .filter_map(|line| async move {
296                match line {
297                    Ok(line) => {
298                        let line = line.strip_prefix("data: ")?;
299                        match serde_json::from_str(line) {
300                            Ok(response) => Some(Ok(response)),
301                            Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
302                        }
303                    }
304                    Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
305                }
306            })
307            .boxed();
308        Ok((stream, rate_limits.log_err()))
309    } else {
310        let mut body = Vec::new();
311        response
312            .body_mut()
313            .read_to_end(&mut body)
314            .await
315            .context("failed to read response body")?;
316
317        let body_str =
318            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
319
320        match serde_json::from_str::<Event>(body_str) {
321            Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
322            Ok(_) => Err(AnthropicError::Other(anyhow!(
323                "Unexpected success response while expecting an error: '{body_str}'",
324            ))),
325            Err(_) => Err(AnthropicError::Other(anyhow!(
326                "Failed to connect to API: {} {}",
327                response.status(),
328                body_str,
329            ))),
330        }
331    }
332}
333
334pub async fn extract_tool_args_from_events(
335    tool_name: String,
336    mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
337) -> Result<impl Send + Stream<Item = Result<String>>> {
338    let mut tool_use_index = None;
339    while let Some(event) = events.next().await {
340        if let Event::ContentBlockStart {
341            index,
342            content_block: ResponseContent::ToolUse { name, .. },
343        } = event?
344        {
345            if name == tool_name {
346                tool_use_index = Some(index);
347                break;
348            }
349        }
350    }
351
352    let Some(tool_use_index) = tool_use_index else {
353        return Err(anyhow!("tool not used"));
354    };
355
356    Ok(events.filter_map(move |event| {
357        let result = match event {
358            Err(error) => Some(Err(error)),
359            Ok(Event::ContentBlockDelta { index, delta }) => match delta {
360                ContentDelta::TextDelta { .. } => None,
361                ContentDelta::InputJsonDelta { partial_json } => {
362                    if index == tool_use_index {
363                        Some(Ok(partial_json))
364                    } else {
365                        None
366                    }
367                }
368            },
369            _ => None,
370        };
371
372        async move { result }
373    }))
374}
375
/// The kind of cache control requested for a content block.
///
/// Variants are (de)serialized in lowercase, so `Ephemeral` is `"ephemeral"` on the wire.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
#[serde(rename_all = "lowercase")]
pub enum CacheControlType {
    Ephemeral,
}
381
/// Cache-control marker that can be attached to a [`RequestContent`] block.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct CacheControl {
    /// Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub cache_type: CacheControlType,
}
387
/// A single conversation message sent in a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Who authored the message.
    pub role: Role,
    /// The content blocks that make up the message.
    pub content: Vec<RequestContent>,
}
393
/// Author of a message; (de)serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
400
/// A content block inside a request [`Message`].
///
/// Internally tagged: the variant is selected by the JSON `type` field
/// (`text`, `image`, `tool_use`, `tool_result`). On every variant,
/// `cache_control` is omitted from the serialized output when it is `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum RequestContent {
    /// Plain text content.
    #[serde(rename = "text")]
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Image content described by an [`ImageSource`].
    #[serde(rename = "image")]
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation: its id, the tool name and the JSON input.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// The result of a tool invocation, linked to it through `tool_use_id`.
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        is_error: bool,
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
433
/// A content block inside a [`Response`] or a `content_block_start` [`Event`].
///
/// Internally tagged by the JSON `type` field (`text` or `tool_use`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseContent {
    /// Plain text content.
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation: its id, the tool name and the JSON input.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
446
/// Describes the image carried by a [`RequestContent::Image`] block.
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
    /// Serialized under the JSON key `type`.
    // NOTE(review): presumably "base64" — this file never sets it; confirm against callers.
    #[serde(rename = "type")]
    pub source_type: String,
    // NOTE(review): presumably a MIME type such as "image/png" — confirm against callers.
    pub media_type: String,
    // NOTE(review): presumably the encoded image payload — confirm against callers.
    pub data: String,
}
454
/// A tool definition sent in [`Request::tools`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    /// Tool name; `extract_tool_args_from_events` matches tool-use blocks by name.
    pub name: String,
    /// Description of the tool.
    pub description: String,
    /// Arbitrary JSON value describing the tool input.
    // NOTE(review): presumably a JSON Schema object — confirm against the API docs.
    pub input_schema: serde_json::Value,
}
461
/// Tool-selection setting sent in [`Request::tool_choice`].
///
/// Internally tagged by the JSON `type` field with lowercase variant names
/// (`auto`, `any`, `tool`); `Tool` additionally carries the tool `name`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    Tool { name: String },
}
469
/// Body of a `/v1/messages` request.
///
/// `model`, `max_tokens` and `messages` are always serialized. Every other
/// field is omitted when empty / `None` and defaults to empty / `None` when
/// absent during deserialization.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model identifier string.
    pub model: String,
    /// Value sent as `max_tokens`.
    pub max_tokens: u32,
    /// The conversation messages.
    pub messages: Vec<Message>,
    /// Tool definitions; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Tool-selection setting; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// System prompt; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Request metadata; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    /// Stop sequences; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    /// Sampling temperature; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Value sent as `top_k`; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    /// Value sent as `top_p`; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
492
/// A [`Request`] with an additional `stream` flag.
///
/// `base` is flattened, so the serialized JSON is the plain request object
/// plus a `stream` key. `stream_completion_with_rate_limit_info` builds it
/// with `stream: true`.
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
    #[serde(flatten)]
    pub base: Request,
    pub stream: bool,
}
499
/// Optional metadata attached to a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    /// Has no `skip_serializing_if`, so `None` is serialized as JSON `null`.
    // NOTE(review): presumably an opaque identifier for the end user — confirm against the API docs.
    pub user_id: Option<String>,
}
504
/// Token usage reported in a [`Response`] or a `message_delta` [`Event`].
///
/// Both fields are optional: absent keys deserialize to `None`, and `None`
/// values are omitted when serializing.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
}
512
/// A message object as returned by `complete`, or carried by a
/// `message_start` [`Event`] when streaming.
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    pub id: String,
    /// Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub response_type: String,
    pub role: Role,
    /// The content blocks of the message.
    pub content: Vec<ResponseContent>,
    pub model: String,
    /// Absent keys deserialize to `None`; `None` is omitted when serializing.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    /// Absent keys deserialize to `None`; `None` is omitted when serializing.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}
527
/// An event decoded from a `data: ` line of the streaming response.
///
/// Internally tagged by the JSON `type` field; each variant's tag is given by
/// its `serde(rename)`. The `Error` variant is also used to decode the body of
/// a non-success HTTP response.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Event {
    /// Carries the initial message object.
    #[serde(rename = "message_start")]
    MessageStart { message: Response },
    /// A content block begins at position `index`.
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: ResponseContent,
    },
    /// An incremental update for the content block at `index`.
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: ContentDelta },
    /// The content block at `index` is finished.
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    /// Message-level update carrying stop information and usage.
    #[serde(rename = "message_delta")]
    MessageDelta { delta: MessageDelta, usage: Usage },
    #[serde(rename = "message_stop")]
    MessageStop,
    #[serde(rename = "ping")]
    Ping,
    /// An error reported by the API.
    #[serde(rename = "error")]
    Error { error: ApiError },
}
551
/// Payload of a `content_block_delta` [`Event`], internally tagged by the JSON `type` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentDelta {
    /// A fragment of text.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    /// A fragment of a tool's JSON input; `extract_tool_args_from_events`
    /// yields these fragments for the selected tool.
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}
560
/// Payload of a `message_delta` [`Event`].
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
566
/// Error type returned by the request functions in this file.
#[derive(Error, Debug)]
pub enum AnthropicError {
    /// A structured error reported by the API; displayed with its type and message.
    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
    ApiError(ApiError),
    /// Any other failure. `#[from]` lets `?` convert an `anyhow::Error` into this variant.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
574
/// A structured error object as carried by an `error` [`Event`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
    /// Error type string, serialized under the JSON key `type`;
    /// `ApiError::code` parses it into an [`ApiErrorCode`].
    #[serde(rename = "type")]
    pub error_type: String,
    /// Error message text.
    pub message: String,
}
581
/// An Anthropic API error code.
/// https://docs.anthropic.com/en/api/errors#http-errors
///
/// Derives `EnumString` with `snake_case` names, so each variant parses from
/// the snake_case form of its name (e.g. `rate_limit_error`).
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
    InvalidRequestError,
    /// 401 - `authentication_error`: There's an issue with your API key.
    AuthenticationError,
    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
    PermissionError,
    /// 404 - `not_found_error`: The requested resource was not found.
    NotFoundError,
    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
    RequestTooLarge,
    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
    RateLimitError,
    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
    ApiError,
    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
    OverloadedError,
}
604
605impl ApiError {
606    pub fn code(&self) -> Option<ApiErrorCode> {
607        ApiErrorCode::from_str(&self.error_type).ok()
608    }
609
610    pub fn is_rate_limit_error(&self) -> bool {
611        matches!(self.error_type.as_str(), "rate_limit_error")
612    }
613}