anthropic.rs

  1mod supported_countries;
  2
  3use std::time::Duration;
  4use std::{pin::Pin, str::FromStr};
  5
  6use anyhow::{anyhow, Context, Result};
  7use chrono::{DateTime, Utc};
  8use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  9use http_client::http::{HeaderMap, HeaderValue};
 10use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
 11use serde::{Deserialize, Serialize};
 12use strum::{EnumIter, EnumString};
 13use thiserror::Error;
 14use util::ResultExt as _;
 15
 16pub use supported_countries::*;
 17
 18pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
 19
 20#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 21#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 22pub struct AnthropicModelCacheConfiguration {
 23    pub min_total_token: usize,
 24    pub should_speculate: bool,
 25    pub max_cache_anchors: usize,
 26}
 27
 28#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 29#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 30pub enum Model {
 31    #[default]
 32    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")]
 33    Claude3_5Sonnet,
 34    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
 35    Claude3Opus,
 36    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
 37    Claude3Sonnet,
 38    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-latest")]
 39    Claude3Haiku,
 40    #[serde(rename = "custom")]
 41    Custom {
 42        name: String,
 43        max_tokens: usize,
 44        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
 45        display_name: Option<String>,
 46        /// Override this model with a different Anthropic model for tool calls.
 47        tool_override: Option<String>,
 48        /// Indicates whether this custom model supports caching.
 49        cache_configuration: Option<AnthropicModelCacheConfiguration>,
 50        max_output_tokens: Option<u32>,
 51        default_temperature: Option<f32>,
 52    },
 53}
 54
 55impl Model {
 56    pub fn from_id(id: &str) -> Result<Self> {
 57        if id.starts_with("claude-3-5-sonnet") {
 58            Ok(Self::Claude3_5Sonnet)
 59        } else if id.starts_with("claude-3-opus") {
 60            Ok(Self::Claude3Opus)
 61        } else if id.starts_with("claude-3-sonnet") {
 62            Ok(Self::Claude3Sonnet)
 63        } else if id.starts_with("claude-3-haiku") {
 64            Ok(Self::Claude3Haiku)
 65        } else {
 66            Err(anyhow!("invalid model id"))
 67        }
 68    }
 69
 70    pub fn id(&self) -> &str {
 71        match self {
 72            Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest",
 73            Model::Claude3Opus => "claude-3-opus-latest",
 74            Model::Claude3Sonnet => "claude-3-sonnet-latest",
 75            Model::Claude3Haiku => "claude-3-haiku-latest",
 76            Self::Custom { name, .. } => name,
 77        }
 78    }
 79
 80    pub fn display_name(&self) -> &str {
 81        match self {
 82            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
 83            Self::Claude3Opus => "Claude 3 Opus",
 84            Self::Claude3Sonnet => "Claude 3 Sonnet",
 85            Self::Claude3Haiku => "Claude 3 Haiku",
 86            Self::Custom {
 87                name, display_name, ..
 88            } => display_name.as_ref().unwrap_or(name),
 89        }
 90    }
 91
 92    pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
 93        match self {
 94            Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
 95                min_total_token: 2_048,
 96                should_speculate: true,
 97                max_cache_anchors: 4,
 98            }),
 99            Self::Custom {
100                cache_configuration,
101                ..
102            } => cache_configuration.clone(),
103            _ => None,
104        }
105    }
106
107    pub fn max_token_count(&self) -> usize {
108        match self {
109            Self::Claude3_5Sonnet
110            | Self::Claude3Opus
111            | Self::Claude3Sonnet
112            | Self::Claude3Haiku => 200_000,
113            Self::Custom { max_tokens, .. } => *max_tokens,
114        }
115    }
116
117    pub fn max_output_tokens(&self) -> u32 {
118        match self {
119            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
120            Self::Claude3_5Sonnet => 8_192,
121            Self::Custom {
122                max_output_tokens, ..
123            } => max_output_tokens.unwrap_or(4_096),
124        }
125    }
126
127    pub fn default_temperature(&self) -> f32 {
128        match self {
129            Self::Claude3_5Sonnet
130            | Self::Claude3Opus
131            | Self::Claude3Sonnet
132            | Self::Claude3Haiku => 1.0,
133            Self::Custom {
134                default_temperature,
135                ..
136            } => default_temperature.unwrap_or(1.0),
137        }
138    }
139
140    pub fn tool_model_id(&self) -> &str {
141        if let Self::Custom {
142            tool_override: Some(tool_override),
143            ..
144        } = self
145        {
146            tool_override
147        } else {
148            self.id()
149        }
150    }
151}
152
153pub async fn complete(
154    client: &dyn HttpClient,
155    api_url: &str,
156    api_key: &str,
157    request: Request,
158) -> Result<Response, AnthropicError> {
159    let uri = format!("{api_url}/v1/messages");
160    let request_builder = HttpRequest::builder()
161        .method(Method::POST)
162        .uri(uri)
163        .header("Anthropic-Version", "2023-06-01")
164        .header(
165            "Anthropic-Beta",
166            "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
167        )
168        .header("X-Api-Key", api_key)
169        .header("Content-Type", "application/json");
170
171    let serialized_request =
172        serde_json::to_string(&request).context("failed to serialize request")?;
173    let request = request_builder
174        .body(AsyncBody::from(serialized_request))
175        .context("failed to construct request body")?;
176
177    let mut response = client
178        .send(request)
179        .await
180        .context("failed to send request to Anthropic")?;
181    if response.status().is_success() {
182        let mut body = Vec::new();
183        response
184            .body_mut()
185            .read_to_end(&mut body)
186            .await
187            .context("failed to read response body")?;
188        let response_message: Response =
189            serde_json::from_slice(&body).context("failed to deserialize response body")?;
190        Ok(response_message)
191    } else {
192        let mut body = Vec::new();
193        response
194            .body_mut()
195            .read_to_end(&mut body)
196            .await
197            .context("failed to read response body")?;
198        let body_str =
199            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
200        Err(AnthropicError::Other(anyhow!(
201            "Failed to connect to API: {} {}",
202            response.status(),
203            body_str
204        )))
205    }
206}
207
208pub async fn stream_completion(
209    client: &dyn HttpClient,
210    api_url: &str,
211    api_key: &str,
212    request: Request,
213    low_speed_timeout: Option<Duration>,
214) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
215    stream_completion_with_rate_limit_info(client, api_url, api_key, request, low_speed_timeout)
216        .await
217        .map(|output| output.0)
218}
219
220/// https://docs.anthropic.com/en/api/rate-limits#response-headers
221#[derive(Debug)]
222pub struct RateLimitInfo {
223    pub requests_limit: usize,
224    pub requests_remaining: usize,
225    pub requests_reset: DateTime<Utc>,
226    pub tokens_limit: usize,
227    pub tokens_remaining: usize,
228    pub tokens_reset: DateTime<Utc>,
229}
230
231impl RateLimitInfo {
232    fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
233        let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
234        let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
235        let tokens_remaining =
236            get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
237        let requests_remaining =
238            get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
239        let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
240        let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
241        let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
242        let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
243
244        Ok(Self {
245            requests_limit,
246            tokens_limit,
247            requests_remaining,
248            tokens_remaining,
249            requests_reset,
250            tokens_reset,
251        })
252    }
253}
254
255fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
256    Ok(headers
257        .get(key)
258        .ok_or_else(|| anyhow!("missing header `{key}`"))?
259        .to_str()?)
260}
261
262pub async fn stream_completion_with_rate_limit_info(
263    client: &dyn HttpClient,
264    api_url: &str,
265    api_key: &str,
266    request: Request,
267    low_speed_timeout: Option<Duration>,
268) -> Result<
269    (
270        BoxStream<'static, Result<Event, AnthropicError>>,
271        Option<RateLimitInfo>,
272    ),
273    AnthropicError,
274> {
275    let request = StreamingRequest {
276        base: request,
277        stream: true,
278    };
279    let uri = format!("{api_url}/v1/messages");
280    let mut request_builder = HttpRequest::builder()
281        .method(Method::POST)
282        .uri(uri)
283        .header("Anthropic-Version", "2023-06-01")
284        .header(
285            "Anthropic-Beta",
286            "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
287        )
288        .header("X-Api-Key", api_key)
289        .header("Content-Type", "application/json");
290    if let Some(low_speed_timeout) = low_speed_timeout {
291        request_builder = request_builder.read_timeout(low_speed_timeout);
292    }
293    let serialized_request =
294        serde_json::to_string(&request).context("failed to serialize request")?;
295    let request = request_builder
296        .body(AsyncBody::from(serialized_request))
297        .context("failed to construct request body")?;
298
299    let mut response = client
300        .send(request)
301        .await
302        .context("failed to send request to Anthropic")?;
303    if response.status().is_success() {
304        let rate_limits = RateLimitInfo::from_headers(response.headers());
305        let reader = BufReader::new(response.into_body());
306        let stream = reader
307            .lines()
308            .filter_map(|line| async move {
309                match line {
310                    Ok(line) => {
311                        let line = line.strip_prefix("data: ")?;
312                        match serde_json::from_str(line) {
313                            Ok(response) => Some(Ok(response)),
314                            Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
315                        }
316                    }
317                    Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
318                }
319            })
320            .boxed();
321        Ok((stream, rate_limits.log_err()))
322    } else {
323        let mut body = Vec::new();
324        response
325            .body_mut()
326            .read_to_end(&mut body)
327            .await
328            .context("failed to read response body")?;
329
330        let body_str =
331            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
332
333        match serde_json::from_str::<Event>(body_str) {
334            Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
335            Ok(_) => Err(AnthropicError::Other(anyhow!(
336                "Unexpected success response while expecting an error: '{body_str}'",
337            ))),
338            Err(_) => Err(AnthropicError::Other(anyhow!(
339                "Failed to connect to API: {} {}",
340                response.status(),
341                body_str,
342            ))),
343        }
344    }
345}
346
347pub async fn extract_tool_args_from_events(
348    tool_name: String,
349    mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
350) -> Result<impl Send + Stream<Item = Result<String>>> {
351    let mut tool_use_index = None;
352    while let Some(event) = events.next().await {
353        if let Event::ContentBlockStart {
354            index,
355            content_block: ResponseContent::ToolUse { name, .. },
356        } = event?
357        {
358            if name == tool_name {
359                tool_use_index = Some(index);
360                break;
361            }
362        }
363    }
364
365    let Some(tool_use_index) = tool_use_index else {
366        return Err(anyhow!("tool not used"));
367    };
368
369    Ok(events.filter_map(move |event| {
370        let result = match event {
371            Err(error) => Some(Err(error)),
372            Ok(Event::ContentBlockDelta { index, delta }) => match delta {
373                ContentDelta::TextDelta { .. } => None,
374                ContentDelta::InputJsonDelta { partial_json } => {
375                    if index == tool_use_index {
376                        Some(Ok(partial_json))
377                    } else {
378                        None
379                    }
380                }
381            },
382            _ => None,
383        };
384
385        async move { result }
386    }))
387}
388
389#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
390#[serde(rename_all = "lowercase")]
391pub enum CacheControlType {
392    Ephemeral,
393}
394
395#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
396pub struct CacheControl {
397    #[serde(rename = "type")]
398    pub cache_type: CacheControlType,
399}
400
401#[derive(Debug, Serialize, Deserialize)]
402pub struct Message {
403    pub role: Role,
404    pub content: Vec<RequestContent>,
405}
406
407#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
408#[serde(rename_all = "lowercase")]
409pub enum Role {
410    User,
411    Assistant,
412}
413
414#[derive(Debug, Serialize, Deserialize)]
415#[serde(tag = "type")]
416pub enum RequestContent {
417    #[serde(rename = "text")]
418    Text {
419        text: String,
420        #[serde(skip_serializing_if = "Option::is_none")]
421        cache_control: Option<CacheControl>,
422    },
423    #[serde(rename = "image")]
424    Image {
425        source: ImageSource,
426        #[serde(skip_serializing_if = "Option::is_none")]
427        cache_control: Option<CacheControl>,
428    },
429    #[serde(rename = "tool_use")]
430    ToolUse {
431        id: String,
432        name: String,
433        input: serde_json::Value,
434        #[serde(skip_serializing_if = "Option::is_none")]
435        cache_control: Option<CacheControl>,
436    },
437    #[serde(rename = "tool_result")]
438    ToolResult {
439        tool_use_id: String,
440        is_error: bool,
441        content: String,
442        #[serde(skip_serializing_if = "Option::is_none")]
443        cache_control: Option<CacheControl>,
444    },
445}
446
447#[derive(Debug, Serialize, Deserialize)]
448#[serde(tag = "type")]
449pub enum ResponseContent {
450    #[serde(rename = "text")]
451    Text { text: String },
452    #[serde(rename = "tool_use")]
453    ToolUse {
454        id: String,
455        name: String,
456        input: serde_json::Value,
457    },
458}
459
460#[derive(Debug, Serialize, Deserialize)]
461pub struct ImageSource {
462    #[serde(rename = "type")]
463    pub source_type: String,
464    pub media_type: String,
465    pub data: String,
466}
467
468#[derive(Debug, Serialize, Deserialize)]
469pub struct Tool {
470    pub name: String,
471    pub description: String,
472    pub input_schema: serde_json::Value,
473}
474
475#[derive(Debug, Serialize, Deserialize)]
476#[serde(tag = "type", rename_all = "lowercase")]
477pub enum ToolChoice {
478    Auto,
479    Any,
480    Tool { name: String },
481}
482
483#[derive(Debug, Serialize, Deserialize)]
484pub struct Request {
485    pub model: String,
486    pub max_tokens: u32,
487    pub messages: Vec<Message>,
488    #[serde(default, skip_serializing_if = "Vec::is_empty")]
489    pub tools: Vec<Tool>,
490    #[serde(default, skip_serializing_if = "Option::is_none")]
491    pub tool_choice: Option<ToolChoice>,
492    #[serde(default, skip_serializing_if = "Option::is_none")]
493    pub system: Option<String>,
494    #[serde(default, skip_serializing_if = "Option::is_none")]
495    pub metadata: Option<Metadata>,
496    #[serde(default, skip_serializing_if = "Vec::is_empty")]
497    pub stop_sequences: Vec<String>,
498    #[serde(default, skip_serializing_if = "Option::is_none")]
499    pub temperature: Option<f32>,
500    #[serde(default, skip_serializing_if = "Option::is_none")]
501    pub top_k: Option<u32>,
502    #[serde(default, skip_serializing_if = "Option::is_none")]
503    pub top_p: Option<f32>,
504}
505
506#[derive(Debug, Serialize, Deserialize)]
507struct StreamingRequest {
508    #[serde(flatten)]
509    pub base: Request,
510    pub stream: bool,
511}
512
513#[derive(Debug, Serialize, Deserialize)]
514pub struct Metadata {
515    pub user_id: Option<String>,
516}
517
518#[derive(Debug, Serialize, Deserialize)]
519pub struct Usage {
520    #[serde(default, skip_serializing_if = "Option::is_none")]
521    pub input_tokens: Option<u32>,
522    #[serde(default, skip_serializing_if = "Option::is_none")]
523    pub output_tokens: Option<u32>,
524    #[serde(default, skip_serializing_if = "Option::is_none")]
525    pub cache_creation_input_tokens: Option<u32>,
526    #[serde(default, skip_serializing_if = "Option::is_none")]
527    pub cache_read_input_tokens: Option<u32>,
528}
529
530#[derive(Debug, Serialize, Deserialize)]
531pub struct Response {
532    pub id: String,
533    #[serde(rename = "type")]
534    pub response_type: String,
535    pub role: Role,
536    pub content: Vec<ResponseContent>,
537    pub model: String,
538    #[serde(default, skip_serializing_if = "Option::is_none")]
539    pub stop_reason: Option<String>,
540    #[serde(default, skip_serializing_if = "Option::is_none")]
541    pub stop_sequence: Option<String>,
542    pub usage: Usage,
543}
544
545#[derive(Debug, Serialize, Deserialize)]
546#[serde(tag = "type")]
547pub enum Event {
548    #[serde(rename = "message_start")]
549    MessageStart { message: Response },
550    #[serde(rename = "content_block_start")]
551    ContentBlockStart {
552        index: usize,
553        content_block: ResponseContent,
554    },
555    #[serde(rename = "content_block_delta")]
556    ContentBlockDelta { index: usize, delta: ContentDelta },
557    #[serde(rename = "content_block_stop")]
558    ContentBlockStop { index: usize },
559    #[serde(rename = "message_delta")]
560    MessageDelta { delta: MessageDelta, usage: Usage },
561    #[serde(rename = "message_stop")]
562    MessageStop,
563    #[serde(rename = "ping")]
564    Ping,
565    #[serde(rename = "error")]
566    Error { error: ApiError },
567}
568
569#[derive(Debug, Serialize, Deserialize)]
570#[serde(tag = "type")]
571pub enum ContentDelta {
572    #[serde(rename = "text_delta")]
573    TextDelta { text: String },
574    #[serde(rename = "input_json_delta")]
575    InputJsonDelta { partial_json: String },
576}
577
578#[derive(Debug, Serialize, Deserialize)]
579pub struct MessageDelta {
580    pub stop_reason: Option<String>,
581    pub stop_sequence: Option<String>,
582}
583
584#[derive(Error, Debug)]
585pub enum AnthropicError {
586    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
587    ApiError(ApiError),
588    #[error("{0}")]
589    Other(#[from] anyhow::Error),
590}
591
592#[derive(Debug, Serialize, Deserialize)]
593pub struct ApiError {
594    #[serde(rename = "type")]
595    pub error_type: String,
596    pub message: String,
597}
598
599/// An Anthropic API error code.
600/// https://docs.anthropic.com/en/api/errors#http-errors
601#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
602#[strum(serialize_all = "snake_case")]
603pub enum ApiErrorCode {
604    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
605    InvalidRequestError,
606    /// 401 - `authentication_error`: There's an issue with your API key.
607    AuthenticationError,
608    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
609    PermissionError,
610    /// 404 - `not_found_error`: The requested resource was not found.
611    NotFoundError,
612    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
613    RequestTooLarge,
614    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
615    RateLimitError,
616    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
617    ApiError,
618    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
619    OverloadedError,
620}
621
622impl ApiError {
623    pub fn code(&self) -> Option<ApiErrorCode> {
624        ApiErrorCode::from_str(&self.error_type).ok()
625    }
626
627    pub fn is_rate_limit_error(&self) -> bool {
628        matches!(self.error_type.as_str(), "rate_limit_error")
629    }
630}