anthropic.rs

  1mod supported_countries;
  2
  3use std::{pin::Pin, str::FromStr};
  4
  5use anyhow::{anyhow, Context, Result};
  6use chrono::{DateTime, Utc};
  7use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  8use http_client::http::{HeaderMap, HeaderValue};
  9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
 10use serde::{Deserialize, Serialize};
 11use strum::{EnumIter, EnumString};
 12use thiserror::Error;
 13use util::ResultExt as _;
 14
 15pub use supported_countries::*;
 16
/// Base URL of the public Anthropic API, without a trailing slash.
///
/// NOTE(review): presumably passed by callers as the `api_url` argument of
/// [`complete`] / [`stream_completion`], which append `/v1/messages` to it — confirm.
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
 18
/// Prompt-caching parameters for a model, as returned by [`Model::cache_configuration`].
///
/// The built-in cache-capable models use `min_total_token: 2_048`,
/// `should_speculate: true` and `max_cache_anchors: 4`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct AnthropicModelCacheConfiguration {
    /// NOTE(review): presumably the minimum total token count a request must reach
    /// before cache anchors are placed — confirm with the consumer of this config.
    pub min_total_token: usize,
    /// NOTE(review): presumably whether cache anchors are placed speculatively — confirm.
    pub should_speculate: bool,
    /// NOTE(review): presumably the maximum number of `cache_control` markers to
    /// attach to a single request — confirm.
    pub max_cache_anchors: usize,
}
 26
/// An Anthropic model selectable by the user.
///
/// Built-in variants serialize to a short name (e.g. `claude-3-5-sonnet`) and also
/// accept the `-latest` spelling when deserializing. Any other model is described
/// by [`Model::Custom`], which serializes with the tag `custom`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    /// Claude 3.5 Sonnet; the default model.
    #[default]
    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")]
    Claude3_5Sonnet,
    /// Claude 3 Opus.
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
    Claude3Opus,
    /// Claude 3 Sonnet.
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
    Claude3Sonnet,
    /// Claude 3 Haiku.
    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-latest")]
    Claude3Haiku,
    /// A user-configured model.
    #[serde(rename = "custom")]
    Custom {
        /// The model id sent to the API (returned by [`Model::id`]).
        name: String,
        /// Total token limit, returned by [`Model::max_token_count`].
        max_tokens: usize,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        /// Falls back to `name` when `None`.
        display_name: Option<String>,
        /// Override this model with a different Anthropic model for tool calls
        /// (returned by [`Model::tool_model_id`]).
        tool_override: Option<String>,
        /// Prompt-caching configuration for this custom model; `None` means
        /// [`Model::cache_configuration`] reports no caching configuration.
        cache_configuration: Option<AnthropicModelCacheConfiguration>,
        /// Output token limit; [`Model::max_output_tokens`] uses 4,096 when `None`.
        max_output_tokens: Option<u32>,
        /// Default sampling temperature; [`Model::default_temperature`] uses 1.0 when `None`.
        default_temperature: Option<f32>,
    },
}
 53
 54impl Model {
 55    pub fn from_id(id: &str) -> Result<Self> {
 56        if id.starts_with("claude-3-5-sonnet") {
 57            Ok(Self::Claude3_5Sonnet)
 58        } else if id.starts_with("claude-3-opus") {
 59            Ok(Self::Claude3Opus)
 60        } else if id.starts_with("claude-3-sonnet") {
 61            Ok(Self::Claude3Sonnet)
 62        } else if id.starts_with("claude-3-haiku") {
 63            Ok(Self::Claude3Haiku)
 64        } else {
 65            Err(anyhow!("invalid model id"))
 66        }
 67    }
 68
 69    pub fn id(&self) -> &str {
 70        match self {
 71            Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest",
 72            Model::Claude3Opus => "claude-3-opus-latest",
 73            Model::Claude3Sonnet => "claude-3-sonnet-latest",
 74            Model::Claude3Haiku => "claude-3-haiku-latest",
 75            Self::Custom { name, .. } => name,
 76        }
 77    }
 78
 79    pub fn display_name(&self) -> &str {
 80        match self {
 81            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
 82            Self::Claude3Opus => "Claude 3 Opus",
 83            Self::Claude3Sonnet => "Claude 3 Sonnet",
 84            Self::Claude3Haiku => "Claude 3 Haiku",
 85            Self::Custom {
 86                name, display_name, ..
 87            } => display_name.as_ref().unwrap_or(name),
 88        }
 89    }
 90
 91    pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
 92        match self {
 93            Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
 94                min_total_token: 2_048,
 95                should_speculate: true,
 96                max_cache_anchors: 4,
 97            }),
 98            Self::Custom {
 99                cache_configuration,
100                ..
101            } => cache_configuration.clone(),
102            _ => None,
103        }
104    }
105
106    pub fn max_token_count(&self) -> usize {
107        match self {
108            Self::Claude3_5Sonnet
109            | Self::Claude3Opus
110            | Self::Claude3Sonnet
111            | Self::Claude3Haiku => 200_000,
112            Self::Custom { max_tokens, .. } => *max_tokens,
113        }
114    }
115
116    pub fn max_output_tokens(&self) -> u32 {
117        match self {
118            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
119            Self::Claude3_5Sonnet => 8_192,
120            Self::Custom {
121                max_output_tokens, ..
122            } => max_output_tokens.unwrap_or(4_096),
123        }
124    }
125
126    pub fn default_temperature(&self) -> f32 {
127        match self {
128            Self::Claude3_5Sonnet
129            | Self::Claude3Opus
130            | Self::Claude3Sonnet
131            | Self::Claude3Haiku => 1.0,
132            Self::Custom {
133                default_temperature,
134                ..
135            } => default_temperature.unwrap_or(1.0),
136        }
137    }
138
139    pub fn tool_model_id(&self) -> &str {
140        if let Self::Custom {
141            tool_override: Some(tool_override),
142            ..
143        } = self
144        {
145            tool_override
146        } else {
147            self.id()
148        }
149    }
150}
151
152pub async fn complete(
153    client: &dyn HttpClient,
154    api_url: &str,
155    api_key: &str,
156    request: Request,
157) -> Result<Response, AnthropicError> {
158    let uri = format!("{api_url}/v1/messages");
159    let request_builder = HttpRequest::builder()
160        .method(Method::POST)
161        .uri(uri)
162        .header("Anthropic-Version", "2023-06-01")
163        .header("Anthropic-Beta", "prompt-caching-2024-07-31")
164        .header("X-Api-Key", api_key)
165        .header("Content-Type", "application/json");
166
167    let serialized_request =
168        serde_json::to_string(&request).context("failed to serialize request")?;
169    let request = request_builder
170        .body(AsyncBody::from(serialized_request))
171        .context("failed to construct request body")?;
172
173    let mut response = client
174        .send(request)
175        .await
176        .context("failed to send request to Anthropic")?;
177    if response.status().is_success() {
178        let mut body = Vec::new();
179        response
180            .body_mut()
181            .read_to_end(&mut body)
182            .await
183            .context("failed to read response body")?;
184        let response_message: Response =
185            serde_json::from_slice(&body).context("failed to deserialize response body")?;
186        Ok(response_message)
187    } else {
188        let mut body = Vec::new();
189        response
190            .body_mut()
191            .read_to_end(&mut body)
192            .await
193            .context("failed to read response body")?;
194        let body_str =
195            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
196        Err(AnthropicError::Other(anyhow!(
197            "Failed to connect to API: {} {}",
198            response.status(),
199            body_str
200        )))
201    }
202}
203
204pub async fn stream_completion(
205    client: &dyn HttpClient,
206    api_url: &str,
207    api_key: &str,
208    request: Request,
209) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
210    stream_completion_with_rate_limit_info(client, api_url, api_key, request)
211        .await
212        .map(|output| output.0)
213}
214
/// Rate-limit state reported by the Anthropic API in response headers.
///
/// See <https://docs.anthropic.com/en/api/rate-limits#response-headers>.
#[derive(Debug)]
pub struct RateLimitInfo {
    /// Value of the `anthropic-ratelimit-requests-limit` header.
    pub requests_limit: usize,
    /// Value of the `anthropic-ratelimit-requests-remaining` header.
    pub requests_remaining: usize,
    /// `anthropic-ratelimit-requests-reset`, parsed as RFC 3339 and converted to UTC.
    pub requests_reset: DateTime<Utc>,
    /// Value of the `anthropic-ratelimit-tokens-limit` header.
    pub tokens_limit: usize,
    /// Value of the `anthropic-ratelimit-tokens-remaining` header.
    pub tokens_remaining: usize,
    /// `anthropic-ratelimit-tokens-reset`, parsed as RFC 3339 and converted to UTC.
    pub tokens_reset: DateTime<Utc>,
}
225
226impl RateLimitInfo {
227    fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
228        let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
229        let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
230        let tokens_remaining =
231            get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
232        let requests_remaining =
233            get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
234        let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
235        let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
236        let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
237        let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
238
239        Ok(Self {
240            requests_limit,
241            tokens_limit,
242            requests_remaining,
243            tokens_remaining,
244            requests_reset,
245            tokens_reset,
246        })
247    }
248}
249
250fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
251    Ok(headers
252        .get(key)
253        .ok_or_else(|| anyhow!("missing header `{key}`"))?
254        .to_str()?)
255}
256
257pub async fn stream_completion_with_rate_limit_info(
258    client: &dyn HttpClient,
259    api_url: &str,
260    api_key: &str,
261    request: Request,
262) -> Result<
263    (
264        BoxStream<'static, Result<Event, AnthropicError>>,
265        Option<RateLimitInfo>,
266    ),
267    AnthropicError,
268> {
269    let request = StreamingRequest {
270        base: request,
271        stream: true,
272    };
273    let uri = format!("{api_url}/v1/messages");
274    let request_builder = HttpRequest::builder()
275        .method(Method::POST)
276        .uri(uri)
277        .header("Anthropic-Version", "2023-06-01")
278        .header(
279            "Anthropic-Beta",
280            "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
281        )
282        .header("X-Api-Key", api_key)
283        .header("Content-Type", "application/json");
284    let serialized_request =
285        serde_json::to_string(&request).context("failed to serialize request")?;
286    let request = request_builder
287        .body(AsyncBody::from(serialized_request))
288        .context("failed to construct request body")?;
289
290    let mut response = client
291        .send(request)
292        .await
293        .context("failed to send request to Anthropic")?;
294    if response.status().is_success() {
295        let rate_limits = RateLimitInfo::from_headers(response.headers());
296        let reader = BufReader::new(response.into_body());
297        let stream = reader
298            .lines()
299            .filter_map(|line| async move {
300                match line {
301                    Ok(line) => {
302                        let line = line.strip_prefix("data: ")?;
303                        match serde_json::from_str(line) {
304                            Ok(response) => Some(Ok(response)),
305                            Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
306                        }
307                    }
308                    Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
309                }
310            })
311            .boxed();
312        Ok((stream, rate_limits.log_err()))
313    } else {
314        let mut body = Vec::new();
315        response
316            .body_mut()
317            .read_to_end(&mut body)
318            .await
319            .context("failed to read response body")?;
320
321        let body_str =
322            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
323
324        match serde_json::from_str::<Event>(body_str) {
325            Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
326            Ok(_) => Err(AnthropicError::Other(anyhow!(
327                "Unexpected success response while expecting an error: '{body_str}'",
328            ))),
329            Err(_) => Err(AnthropicError::Other(anyhow!(
330                "Failed to connect to API: {} {}",
331                response.status(),
332                body_str,
333            ))),
334        }
335    }
336}
337
338pub async fn extract_tool_args_from_events(
339    tool_name: String,
340    mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
341) -> Result<impl Send + Stream<Item = Result<String>>> {
342    let mut tool_use_index = None;
343    while let Some(event) = events.next().await {
344        if let Event::ContentBlockStart {
345            index,
346            content_block: ResponseContent::ToolUse { name, .. },
347        } = event?
348        {
349            if name == tool_name {
350                tool_use_index = Some(index);
351                break;
352            }
353        }
354    }
355
356    let Some(tool_use_index) = tool_use_index else {
357        return Err(anyhow!("tool not used"));
358    };
359
360    Ok(events.filter_map(move |event| {
361        let result = match event {
362            Err(error) => Some(Err(error)),
363            Ok(Event::ContentBlockDelta { index, delta }) => match delta {
364                ContentDelta::TextDelta { .. } => None,
365                ContentDelta::InputJsonDelta { partial_json } => {
366                    if index == tool_use_index {
367                        Some(Ok(partial_json))
368                    } else {
369                        None
370                    }
371                }
372            },
373            _ => None,
374        };
375
376        async move { result }
377    }))
378}
379
/// Kind of prompt-cache marker; serialized in lowercase (`"ephemeral"`).
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
#[serde(rename_all = "lowercase")]
pub enum CacheControlType {
    /// Serialized as `"ephemeral"`.
    Ephemeral,
}
385
/// A `cache_control` marker attached to a [`RequestContent`] block,
/// serialized as `{"type": "ephemeral"}`.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct CacheControl {
    /// Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub cache_type: CacheControlType,
}
391
/// One conversation turn in a [`Request`]: a role plus its content blocks.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    pub role: Role,
    pub content: Vec<RequestContent>,
}
397
/// Author of a message; serialized as `"user"` or `"assistant"`.
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
404
/// A content block sent in a [`Message`], internally tagged by a `type` field.
///
/// Every variant carries an optional `cache_control` marker that is omitted
/// from the JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum RequestContent {
    /// Plain text (`"type": "text"`).
    #[serde(rename = "text")]
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image (`"type": "image"`) described by an [`ImageSource`].
    #[serde(rename = "image")]
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation previously made by the assistant (`"type": "tool_use"`).
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// The result of a tool invocation (`"type": "tool_result"`); `tool_use_id`
    /// presumably refers to the `id` of a matching `ToolUse` block — confirm.
    /// `is_error` is always serialized.
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        is_error: bool,
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
437
/// A content block produced by the model, internally tagged by `type`.
///
/// Appears in [`Response::content`] and in [`Event::ContentBlockStart`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseContent {
    /// Generated text (`"type": "text"`).
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool call requested by the model (`"type": "tool_use"`).
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
450
/// Source data for a [`RequestContent::Image`] block.
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
    /// Serialized under the JSON key `type`; presumably `"base64"` — confirm with callers.
    #[serde(rename = "type")]
    pub source_type: String,
    /// MIME type of the image, e.g. `"image/png"` (assumed; not validated here).
    pub media_type: String,
    /// Encoded image bytes; encoding is determined by `source_type`.
    pub data: String,
}
458
/// A tool the model may call, sent in [`Request::tools`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    pub name: String,
    pub description: String,
    /// Schema of the tool's input object (expected to be a JSON Schema document;
    /// not validated here).
    pub input_schema: serde_json::Value,
}
465
/// How the model should choose among [`Request::tools`].
///
/// Serialized as `{"type": "auto"}`, `{"type": "any"}` or
/// `{"type": "tool", "name": "..."}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    /// Force the tool with the given name.
    Tool { name: String },
}
473
/// Body of a Messages API request (`POST /v1/messages`).
///
/// Optional fields are omitted from the JSON when empty/`None` and default to
/// empty/`None` when missing during deserialization.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id, presumably obtained from [`Model::id`] or [`Model::tool_model_id`].
    pub model: String,
    /// Output token limit for this request.
    pub max_tokens: u32,
    pub messages: Vec<Message>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// System prompt text.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
496
/// A [`Request`] plus the `stream` flag, serialized as one flat JSON object.
///
/// Built by [`stream_completion_with_rate_limit_info`] with `stream: true`.
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
    /// Fields of the underlying request, flattened into the top level.
    #[serde(flatten)]
    pub base: Request,
    pub stream: bool,
}
503
/// Optional request metadata sent in [`Request::metadata`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    /// Serialized as `null` when `None` (no `skip_serializing_if` here).
    pub user_id: Option<String>,
}
508
/// Token accounting reported by the API in [`Response::usage`] and
/// [`Event::MessageDelta`].
///
/// Every field is optional: missing fields deserialize to `None`, and `None`
/// fields are omitted when serializing.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    /// Prompt-caching counter.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_creation_input_tokens: Option<u32>,
    /// Prompt-caching counter.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
}
520
/// A message returned by [`complete`], also carried by [`Event::MessageStart`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    pub id: String,
    /// Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub response_type: String,
    pub role: Role,
    pub content: Vec<ResponseContent>,
    pub model: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}
535
/// A streaming event, internally tagged by its `type` field.
///
/// Each `data: ` line of a streaming response is decoded into one `Event`. Error
/// response bodies are also decoded as an `Event`, which is how
/// [`Event::Error`] becomes [`AnthropicError::ApiError`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Event {
    #[serde(rename = "message_start")]
    MessageStart { message: Response },
    /// A content block at position `index` begins.
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: ResponseContent,
    },
    /// Incremental content for the block at `index`.
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: ContentDelta },
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    #[serde(rename = "message_delta")]
    MessageDelta { delta: MessageDelta, usage: Usage },
    #[serde(rename = "message_stop")]
    MessageStop,
    #[serde(rename = "ping")]
    Ping,
    #[serde(rename = "error")]
    Error { error: ApiError },
}
559
/// Payload of an [`Event::ContentBlockDelta`], internally tagged by `type`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentDelta {
    /// A chunk of generated text.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    /// A fragment of a tool call's JSON input; fragments are not valid JSON on their
    /// own (see [`extract_tool_args_from_events`]).
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}
568
/// Message-level fields updated by an [`Event::MessageDelta`].
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
574
/// Error returned by the request functions in this module.
#[derive(Error, Debug)]
pub enum AnthropicError {
    /// A structured error reported by the API (an `error` body or event).
    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
    ApiError(ApiError),
    /// Any other failure: transport, (de)serialization, or an unrecognized response.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
582
/// The `error` object of an API error body or [`Event::Error`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
    /// Serialized under the JSON key `type`, e.g. `"rate_limit_error"`; see
    /// [`ApiError::code`] for the known values.
    #[serde(rename = "type")]
    pub error_type: String,
    pub message: String,
}
589
/// An Anthropic API error code.
///
/// Parsed from [`ApiError::error_type`] via `FromStr`; each variant matches its
/// snake_case name (e.g. `RateLimitError` ↔ `"rate_limit_error"`).
///
/// https://docs.anthropic.com/en/api/errors#http-errors
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
    InvalidRequestError,
    /// 401 - `authentication_error`: There's an issue with your API key.
    AuthenticationError,
    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
    PermissionError,
    /// 404 - `not_found_error`: The requested resource was not found.
    NotFoundError,
    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
    RequestTooLarge,
    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
    RateLimitError,
    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
    ApiError,
    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
    OverloadedError,
}
612
613impl ApiError {
614    pub fn code(&self) -> Option<ApiErrorCode> {
615        ApiErrorCode::from_str(&self.error_type).ok()
616    }
617
618    pub fn is_rate_limit_error(&self) -> bool {
619        matches!(self.error_type.as_str(), "rate_limit_error")
620    }
621}