anthropic.rs

  1mod supported_countries;
  2
  3use std::time::Duration;
  4use std::{pin::Pin, str::FromStr};
  5
  6use anyhow::{anyhow, Context, Result};
  7use chrono::{DateTime, Utc};
  8use collections::HashMap;
  9use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
 10use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 11use isahc::config::Configurable;
 12use isahc::http::{HeaderMap, HeaderValue};
 13use serde::{Deserialize, Serialize};
 14use strum::{EnumIter, EnumString};
 15use thiserror::Error;
 16use util::{maybe, ResultExt as _};
 17
 18pub use supported_countries::*;
 19
 20pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
 21
/// Prompt-caching settings attached to a model (see `Model::cache_configuration`).
///
/// NOTE(review): nothing in this file reads the individual fields, so their exact
/// semantics are inferred from their names — confirm against the consumers.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct AnthropicModelCacheConfiguration {
    // Presumably the minimum number of tokens before caching applies — TODO confirm.
    pub min_total_token: usize,
    // Presumably enables speculative cache warming — TODO confirm.
    pub should_speculate: bool,
    // Presumably the maximum number of cache breakpoints per request — TODO confirm.
    pub max_cache_anchors: usize,
}
 29
/// The Anthropic models known to this crate, plus a user-defined `Custom` escape hatch.
///
/// Built-in variants (de)serialize under their short name (e.g. `claude-3-opus`) and
/// additionally accept the dated id as an alias when deserializing.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    // Used when no model is specified (`#[default]`).
    #[default]
    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-20240620")]
    Claude3_5Sonnet,
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
    Claude3Opus,
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
    Claude3Sonnet,
    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
    Claude3Haiku,
    #[serde(rename = "custom")]
    Custom {
        /// The model id; returned verbatim by `Model::id`.
        name: String,
        /// Returned by `Model::max_token_count` for this model.
        max_tokens: usize,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Override this model with a different Anthropic model for tool calls.
        tool_override: Option<String>,
        /// Indicates whether this custom model supports caching.
        cache_configuration: Option<AnthropicModelCacheConfiguration>,
        /// Output token limit; `Model::max_output_tokens` falls back to 4096 when unset.
        max_output_tokens: Option<u32>,
    },
}
 55
 56impl Model {
 57    pub fn from_id(id: &str) -> Result<Self> {
 58        if id.starts_with("claude-3-5-sonnet") {
 59            Ok(Self::Claude3_5Sonnet)
 60        } else if id.starts_with("claude-3-opus") {
 61            Ok(Self::Claude3Opus)
 62        } else if id.starts_with("claude-3-sonnet") {
 63            Ok(Self::Claude3Sonnet)
 64        } else if id.starts_with("claude-3-haiku") {
 65            Ok(Self::Claude3Haiku)
 66        } else {
 67            Err(anyhow!("invalid model id"))
 68        }
 69    }
 70
 71    pub fn id(&self) -> &str {
 72        match self {
 73            Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
 74            Model::Claude3Opus => "claude-3-opus-20240229",
 75            Model::Claude3Sonnet => "claude-3-sonnet-20240229",
 76            Model::Claude3Haiku => "claude-3-haiku-20240307",
 77            Self::Custom { name, .. } => name,
 78        }
 79    }
 80
 81    pub fn display_name(&self) -> &str {
 82        match self {
 83            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
 84            Self::Claude3Opus => "Claude 3 Opus",
 85            Self::Claude3Sonnet => "Claude 3 Sonnet",
 86            Self::Claude3Haiku => "Claude 3 Haiku",
 87            Self::Custom {
 88                name, display_name, ..
 89            } => display_name.as_ref().unwrap_or(name),
 90        }
 91    }
 92
 93    pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
 94        match self {
 95            Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
 96                min_total_token: 2_048,
 97                should_speculate: true,
 98                max_cache_anchors: 4,
 99            }),
100            Self::Custom {
101                cache_configuration,
102                ..
103            } => cache_configuration.clone(),
104            _ => None,
105        }
106    }
107
108    pub fn max_token_count(&self) -> usize {
109        match self {
110            Self::Claude3_5Sonnet
111            | Self::Claude3Opus
112            | Self::Claude3Sonnet
113            | Self::Claude3Haiku => 200_000,
114            Self::Custom { max_tokens, .. } => *max_tokens,
115        }
116    }
117
118    pub fn max_output_tokens(&self) -> u32 {
119        match self {
120            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
121            Self::Claude3_5Sonnet => 8_192,
122            Self::Custom {
123                max_output_tokens, ..
124            } => max_output_tokens.unwrap_or(4_096),
125        }
126    }
127
128    pub fn tool_model_id(&self) -> &str {
129        if let Self::Custom {
130            tool_override: Some(tool_override),
131            ..
132        } = self
133        {
134            tool_override
135        } else {
136            self.id()
137        }
138    }
139}
140
141pub async fn complete(
142    client: &dyn HttpClient,
143    api_url: &str,
144    api_key: &str,
145    request: Request,
146) -> Result<Response, AnthropicError> {
147    let uri = format!("{api_url}/v1/messages");
148    let request_builder = HttpRequest::builder()
149        .method(Method::POST)
150        .uri(uri)
151        .header("Anthropic-Version", "2023-06-01")
152        .header(
153            "Anthropic-Beta",
154            "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
155        )
156        .header("X-Api-Key", api_key)
157        .header("Content-Type", "application/json");
158
159    let serialized_request =
160        serde_json::to_string(&request).context("failed to serialize request")?;
161    let request = request_builder
162        .body(AsyncBody::from(serialized_request))
163        .context("failed to construct request body")?;
164
165    let mut response = client
166        .send(request)
167        .await
168        .context("failed to send request to Anthropic")?;
169    if response.status().is_success() {
170        let mut body = Vec::new();
171        response
172            .body_mut()
173            .read_to_end(&mut body)
174            .await
175            .context("failed to read response body")?;
176        let response_message: Response =
177            serde_json::from_slice(&body).context("failed to deserialize response body")?;
178        Ok(response_message)
179    } else {
180        let mut body = Vec::new();
181        response
182            .body_mut()
183            .read_to_end(&mut body)
184            .await
185            .context("failed to read response body")?;
186        let body_str =
187            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
188        Err(AnthropicError::Other(anyhow!(
189            "Failed to connect to API: {} {}",
190            response.status(),
191            body_str
192        )))
193    }
194}
195
196pub async fn stream_completion(
197    client: &dyn HttpClient,
198    api_url: &str,
199    api_key: &str,
200    request: Request,
201    low_speed_timeout: Option<Duration>,
202) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
203    stream_completion_with_rate_limit_info(client, api_url, api_key, request, low_speed_timeout)
204        .await
205        .map(|output| output.0)
206}
207
/// Rate-limit state parsed from the `anthropic-ratelimit-*` response headers.
///
/// https://docs.anthropic.com/en/api/rate-limits#response-headers
#[derive(Debug)]
pub struct RateLimitInfo {
    /// Value of `anthropic-ratelimit-requests-limit`.
    pub requests_limit: usize,
    /// Value of `anthropic-ratelimit-requests-remaining`.
    pub requests_remaining: usize,
    /// Value of `anthropic-ratelimit-requests-reset` (RFC 3339), converted to UTC.
    pub requests_reset: DateTime<Utc>,
    /// Value of `anthropic-ratelimit-tokens-limit`.
    pub tokens_limit: usize,
    /// Value of `anthropic-ratelimit-tokens-remaining`.
    pub tokens_remaining: usize,
    /// Value of `anthropic-ratelimit-tokens-reset` (RFC 3339), converted to UTC.
    pub tokens_reset: DateTime<Utc>,
}
218
219impl RateLimitInfo {
220    fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
221        let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
222        let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
223        let tokens_remaining =
224            get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
225        let requests_remaining =
226            get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
227        let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
228        let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
229        let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
230        let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
231
232        Ok(Self {
233            requests_limit,
234            tokens_limit,
235            requests_remaining,
236            tokens_remaining,
237            requests_reset,
238            tokens_reset,
239        })
240    }
241}
242
243fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
244    Ok(headers
245        .get(key)
246        .ok_or_else(|| anyhow!("missing header `{key}`"))?
247        .to_str()?)
248}
249
250pub async fn stream_completion_with_rate_limit_info(
251    client: &dyn HttpClient,
252    api_url: &str,
253    api_key: &str,
254    request: Request,
255    low_speed_timeout: Option<Duration>,
256) -> Result<
257    (
258        BoxStream<'static, Result<Event, AnthropicError>>,
259        Option<RateLimitInfo>,
260    ),
261    AnthropicError,
262> {
263    let request = StreamingRequest {
264        base: request,
265        stream: true,
266    };
267    let uri = format!("{api_url}/v1/messages");
268    let mut request_builder = HttpRequest::builder()
269        .method(Method::POST)
270        .uri(uri)
271        .header("Anthropic-Version", "2023-06-01")
272        .header(
273            "Anthropic-Beta",
274            "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
275        )
276        .header("X-Api-Key", api_key)
277        .header("Content-Type", "application/json");
278    if let Some(low_speed_timeout) = low_speed_timeout {
279        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
280    }
281    let serialized_request =
282        serde_json::to_string(&request).context("failed to serialize request")?;
283    let request = request_builder
284        .body(AsyncBody::from(serialized_request))
285        .context("failed to construct request body")?;
286
287    let mut response = client
288        .send(request)
289        .await
290        .context("failed to send request to Anthropic")?;
291    if response.status().is_success() {
292        let rate_limits = RateLimitInfo::from_headers(response.headers());
293        let reader = BufReader::new(response.into_body());
294        let stream = reader
295            .lines()
296            .filter_map(|line| async move {
297                match line {
298                    Ok(line) => {
299                        let line = line.strip_prefix("data: ")?;
300                        match serde_json::from_str(line) {
301                            Ok(response) => Some(Ok(response)),
302                            Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
303                        }
304                    }
305                    Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
306                }
307            })
308            .boxed();
309        Ok((stream, rate_limits.log_err()))
310    } else {
311        let mut body = Vec::new();
312        response
313            .body_mut()
314            .read_to_end(&mut body)
315            .await
316            .context("failed to read response body")?;
317
318        let body_str =
319            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
320
321        match serde_json::from_str::<Event>(body_str) {
322            Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
323            Ok(_) => Err(AnthropicError::Other(anyhow!(
324                "Unexpected success response while expecting an error: '{body_str}'",
325            ))),
326            Err(_) => Err(AnthropicError::Other(anyhow!(
327                "Failed to connect to API: {} {}",
328                response.status(),
329                body_str,
330            ))),
331        }
332    }
333}
334
335pub fn extract_content_from_events(
336    events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
337) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> {
338    struct RawToolUse {
339        id: String,
340        name: String,
341        input_json: String,
342    }
343
344    struct State {
345        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
346        tool_uses_by_index: HashMap<usize, RawToolUse>,
347    }
348
349    futures::stream::unfold(
350        State {
351            events,
352            tool_uses_by_index: HashMap::default(),
353        },
354        |mut state| async move {
355            while let Some(event) = state.events.next().await {
356                match event {
357                    Ok(event) => match event {
358                        Event::ContentBlockStart {
359                            index,
360                            content_block,
361                        } => match content_block {
362                            ResponseContent::Text { text } => {
363                                return Some((Some(Ok(ResponseContent::Text { text })), state));
364                            }
365                            ResponseContent::ToolUse { id, name, .. } => {
366                                state.tool_uses_by_index.insert(
367                                    index,
368                                    RawToolUse {
369                                        id,
370                                        name,
371                                        input_json: String::new(),
372                                    },
373                                );
374
375                                return Some((None, state));
376                            }
377                        },
378                        Event::ContentBlockDelta { index, delta } => match delta {
379                            ContentDelta::TextDelta { text } => {
380                                return Some((Some(Ok(ResponseContent::Text { text })), state));
381                            }
382                            ContentDelta::InputJsonDelta { partial_json } => {
383                                if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
384                                    tool_use.input_json.push_str(&partial_json);
385                                    return Some((None, state));
386                                }
387                            }
388                        },
389                        Event::ContentBlockStop { index } => {
390                            if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
391                                return Some((
392                                    Some(maybe!({
393                                        Ok(ResponseContent::ToolUse {
394                                            id: tool_use.id,
395                                            name: tool_use.name,
396                                            input: serde_json::Value::from_str(
397                                                &tool_use.input_json,
398                                            )
399                                            .map_err(|err| anyhow!(err))?,
400                                        })
401                                    })),
402                                    state,
403                                ));
404                            }
405                        }
406                        Event::Error { error } => {
407                            return Some((Some(Err(AnthropicError::ApiError(error))), state));
408                        }
409                        _ => {}
410                    },
411                    Err(err) => {
412                        return Some((Some(Err(err)), state));
413                    }
414                }
415            }
416
417            None
418        },
419    )
420    .filter_map(|event| async move { event })
421}
422
423pub async fn extract_tool_args_from_events(
424    tool_name: String,
425    mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
426) -> Result<impl Send + Stream<Item = Result<String>>> {
427    let mut tool_use_index = None;
428    while let Some(event) = events.next().await {
429        if let Event::ContentBlockStart {
430            index,
431            content_block,
432        } = event?
433        {
434            if let ResponseContent::ToolUse { name, .. } = content_block {
435                if name == tool_name {
436                    tool_use_index = Some(index);
437                    break;
438                }
439            }
440        }
441    }
442
443    let Some(tool_use_index) = tool_use_index else {
444        return Err(anyhow!("tool not used"));
445    };
446
447    Ok(events.filter_map(move |event| {
448        let result = match event {
449            Err(error) => Some(Err(error)),
450            Ok(Event::ContentBlockDelta { index, delta }) => match delta {
451                ContentDelta::TextDelta { .. } => None,
452                ContentDelta::InputJsonDelta { partial_json } => {
453                    if index == tool_use_index {
454                        Some(Ok(partial_json))
455                    } else {
456                        None
457                    }
458                }
459            },
460            _ => None,
461        };
462
463        async move { result }
464    }))
465}
466
/// The kind of cache control attached to a piece of request content.
///
/// Serialized in lowercase, so `Ephemeral` appears on the wire as `"ephemeral"`.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
#[serde(rename_all = "lowercase")]
pub enum CacheControlType {
    Ephemeral,
}
472
/// Cache-control marker that can be attached to a [`RequestContent`] item.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct CacheControl {
    /// Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub cache_type: CacheControlType,
}
478
/// A single conversation message in a [`Request`]: who said it and its content parts.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    pub role: Role,
    pub content: Vec<RequestContent>,
}
484
/// The author of a message. Serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
491
/// One content part of a request [`Message`].
///
/// Internally tagged: the variant is selected by the JSON `type` field (`text`,
/// `image`, or `tool_use`). In every variant `cache_control` is omitted from the
/// serialized output when it is `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum RequestContent {
    #[serde(rename = "text")]
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    #[serde(rename = "image")]
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
516
/// One content part produced by the model.
///
/// Internally tagged by the JSON `type` field (`text` or `tool_use`). Appears in
/// [`Response::content`] and as the `content_block` of a `content_block_start` event.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
529
/// The payload of a [`RequestContent::Image`] part.
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
    /// Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub source_type: String,
    // Presumably a MIME type such as `image/png` — TODO confirm against callers.
    pub media_type: String,
    // Presumably the encoded image bytes (e.g. base64) — TODO confirm against callers.
    pub data: String,
}
537
/// A tool definition sent with a [`Request`] (see [`Request::tools`]).
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    pub name: String,
    pub description: String,
    /// Arbitrary JSON describing the tool's input; passed through unvalidated here.
    pub input_schema: serde_json::Value,
}
544
/// The `tool_choice` setting of a [`Request`].
///
/// Internally tagged by the JSON `type` field with lowercase variant names
/// (`auto`, `any`, `tool`); the `tool` variant additionally carries the tool's `name`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    Tool { name: String },
}
552
/// The JSON body posted to `/v1/messages` by [`complete`] and, wrapped with a `stream`
/// flag, by [`stream_completion_with_rate_limit_info`].
///
/// Optional fields default when absent on deserialization and are left out of the
/// serialized output when empty (`Vec`) or `None` (`Option`).
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub max_tokens: u32,
    pub messages: Vec<Message>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
575
/// A [`Request`] with an extra top-level `stream` flag.
///
/// `base` is flattened, so its fields are serialized alongside `stream` rather than
/// nested under a `base` key.
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
    #[serde(flatten)]
    pub base: Request,
    pub stream: bool,
}
582
/// Optional metadata attached to a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    pub user_id: Option<String>,
}
587
/// Token usage reported in a [`Response`] and in `message_delta` events.
///
/// Either count may be absent; missing values deserialize to `None` and `None` values
/// are skipped when serializing.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
}
595
/// A message returned by the API: the body of a successful [`complete`] call and the
/// payload of a `message_start` stream event.
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    pub id: String,
    /// Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub response_type: String,
    pub role: Role,
    pub content: Vec<ResponseContent>,
    pub model: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}
610
/// A server-sent event from a streaming completion.
///
/// Internally tagged by the JSON `type` field. The `error` variant is also used to
/// decode the body of non-success HTTP responses in
/// [`stream_completion_with_rate_limit_info`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Event {
    #[serde(rename = "message_start")]
    MessageStart { message: Response },
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: ResponseContent,
    },
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: ContentDelta },
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    #[serde(rename = "message_delta")]
    MessageDelta { delta: MessageDelta, usage: Usage },
    #[serde(rename = "message_stop")]
    MessageStop,
    #[serde(rename = "ping")]
    Ping,
    #[serde(rename = "error")]
    Error { error: ApiError },
}
634
/// An incremental update to a content block, carried by `content_block_delta` events.
///
/// Internally tagged by the JSON `type` field: `text_delta` carries more text, and
/// `input_json_delta` carries a fragment of a tool use's input JSON.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentDelta {
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}
643
/// The `delta` payload of a `message_delta` event.
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
649
/// Errors returned by the functions in this module.
#[derive(Error, Debug)]
pub enum AnthropicError {
    /// A structured error reported by the API itself (an `error` event or error body).
    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
    ApiError(ApiError),
    /// Any other failure (serialization, transport, decoding, ...); `anyhow::Error`
    /// values convert into this variant automatically via `?`.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
657
/// The structured error object carried by an `error` event.
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
    /// Serialized under the JSON key `type`; see [`ApiError::code`] for a typed view.
    #[serde(rename = "type")]
    pub error_type: String,
    pub message: String,
}
664
/// An Anthropic API error code.
/// https://docs.anthropic.com/en/api/errors#http-errors
///
/// Parses from the snake_case error type string (e.g. `rate_limit_error`) via `FromStr`.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
    InvalidRequestError,
    /// 401 - `authentication_error`: There's an issue with your API key.
    AuthenticationError,
    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
    PermissionError,
    /// 404 - `not_found_error`: The requested resource was not found.
    NotFoundError,
    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
    RequestTooLarge,
    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
    RateLimitError,
    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
    ApiError,
    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
    OverloadedError,
}
687
688impl ApiError {
689    pub fn code(&self) -> Option<ApiErrorCode> {
690        ApiErrorCode::from_str(&self.error_type).ok()
691    }
692
693    pub fn is_rate_limit_error(&self) -> bool {
694        match self.error_type.as_str() {
695            "rate_limit_error" => true,
696            _ => false,
697        }
698    }
699}