anthropic.rs

  1mod supported_countries;
  2
  3use anyhow::{anyhow, Context, Result};
  4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  6use isahc::config::Configurable;
  7use serde::{Deserialize, Serialize};
  8use std::time::Duration;
  9use std::{pin::Pin, str::FromStr};
 10use strum::{EnumIter, EnumString};
 11use thiserror::Error;
 12
 13pub use supported_countries::*;
 14
 15pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
 16
// Prompt-caching parameters associated with a model; returned by
// `Model::cache_configuration`.
//
// Plain `//` comments are used on purpose: `///` doc comments would be picked up by the
// optional `schemars::JsonSchema` derive as schema descriptions.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct AnthropicModelCacheConfiguration {
    // NOTE(review): presumably the minimum prompt size, in tokens, before caching is
    // attempted — confirm against the consumer of this struct.
    pub min_total_token: usize,
    // NOTE(review): presumably toggles speculative cache warming — confirm.
    pub should_speculate: bool,
    // NOTE(review): presumably the maximum number of cache breakpoints per request — confirm.
    pub max_cache_anchors: usize,
}
 24
// The Anthropic models known to this crate, plus a user-defined `Custom` entry.
//
// Each built-in variant (de)serializes under its short family name and additionally
// accepts the dated release id as an alias when deserializing. `Claude3_5Sonnet` is the
// `Default`.
//
// New commentary uses plain `//` so the optional schemars-derived schema is unchanged.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[default]
    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-20240620")]
    Claude3_5Sonnet,
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
    Claude3Opus,
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
    Claude3Sonnet,
    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
    Claude3Haiku,
    // A model described entirely by configuration.
    #[serde(rename = "custom")]
    Custom {
        // Used verbatim as both the model id and the display name.
        name: String,
        // Returned as-is by `Model::max_token_count`.
        max_tokens: usize,
        /// Override this model with a different Anthropic model for tool calls.
        tool_override: Option<String>,
        /// Indicates whether this custom model supports caching.
        cache_configuration: Option<AnthropicModelCacheConfiguration>,
        // When `None`, `Model::max_output_tokens` falls back to 4_096.
        max_output_tokens: Option<u32>,
    },
}
 48
 49impl Model {
 50    pub fn from_id(id: &str) -> Result<Self> {
 51        if id.starts_with("claude-3-5-sonnet") {
 52            Ok(Self::Claude3_5Sonnet)
 53        } else if id.starts_with("claude-3-opus") {
 54            Ok(Self::Claude3Opus)
 55        } else if id.starts_with("claude-3-sonnet") {
 56            Ok(Self::Claude3Sonnet)
 57        } else if id.starts_with("claude-3-haiku") {
 58            Ok(Self::Claude3Haiku)
 59        } else {
 60            Err(anyhow!("invalid model id"))
 61        }
 62    }
 63
 64    pub fn id(&self) -> &str {
 65        match self {
 66            Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
 67            Model::Claude3Opus => "claude-3-opus-20240229",
 68            Model::Claude3Sonnet => "claude-3-sonnet-20240229",
 69            Model::Claude3Haiku => "claude-3-haiku-20240307",
 70            Self::Custom { name, .. } => name,
 71        }
 72    }
 73
 74    pub fn display_name(&self) -> &str {
 75        match self {
 76            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
 77            Self::Claude3Opus => "Claude 3 Opus",
 78            Self::Claude3Sonnet => "Claude 3 Sonnet",
 79            Self::Claude3Haiku => "Claude 3 Haiku",
 80            Self::Custom { name, .. } => name,
 81        }
 82    }
 83
 84    pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
 85        match self {
 86            Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
 87                min_total_token: 2_048,
 88                should_speculate: true,
 89                max_cache_anchors: 4,
 90            }),
 91            Self::Custom {
 92                cache_configuration,
 93                ..
 94            } => cache_configuration.clone(),
 95            _ => None,
 96        }
 97    }
 98
 99    pub fn max_token_count(&self) -> usize {
100        match self {
101            Self::Claude3_5Sonnet
102            | Self::Claude3Opus
103            | Self::Claude3Sonnet
104            | Self::Claude3Haiku => 200_000,
105            Self::Custom { max_tokens, .. } => *max_tokens,
106        }
107    }
108
109    pub fn max_output_tokens(&self) -> u32 {
110        match self {
111            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
112            Self::Claude3_5Sonnet => 8_192,
113            Self::Custom {
114                max_output_tokens, ..
115            } => max_output_tokens.unwrap_or(4_096),
116        }
117    }
118
119    pub fn tool_model_id(&self) -> &str {
120        if let Self::Custom {
121            tool_override: Some(tool_override),
122            ..
123        } = self
124        {
125            tool_override
126        } else {
127            self.id()
128        }
129    }
130}
131
132pub async fn complete(
133    client: &dyn HttpClient,
134    api_url: &str,
135    api_key: &str,
136    request: Request,
137) -> Result<Response, AnthropicError> {
138    let uri = format!("{api_url}/v1/messages");
139    let request_builder = HttpRequest::builder()
140        .method(Method::POST)
141        .uri(uri)
142        .header("Anthropic-Version", "2023-06-01")
143        .header(
144            "Anthropic-Beta",
145            "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
146        )
147        .header("X-Api-Key", api_key)
148        .header("Content-Type", "application/json");
149
150    let serialized_request =
151        serde_json::to_string(&request).context("failed to serialize request")?;
152    let request = request_builder
153        .body(AsyncBody::from(serialized_request))
154        .context("failed to construct request body")?;
155
156    let mut response = client
157        .send(request)
158        .await
159        .context("failed to send request to Anthropic")?;
160    if response.status().is_success() {
161        let mut body = Vec::new();
162        response
163            .body_mut()
164            .read_to_end(&mut body)
165            .await
166            .context("failed to read response body")?;
167        let response_message: Response =
168            serde_json::from_slice(&body).context("failed to deserialize response body")?;
169        Ok(response_message)
170    } else {
171        let mut body = Vec::new();
172        response
173            .body_mut()
174            .read_to_end(&mut body)
175            .await
176            .context("failed to read response body")?;
177        let body_str =
178            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
179        Err(AnthropicError::Other(anyhow!(
180            "Failed to connect to API: {} {}",
181            response.status(),
182            body_str
183        )))
184    }
185}
186
187pub async fn stream_completion(
188    client: &dyn HttpClient,
189    api_url: &str,
190    api_key: &str,
191    request: Request,
192    low_speed_timeout: Option<Duration>,
193) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
194    let request = StreamingRequest {
195        base: request,
196        stream: true,
197    };
198    let uri = format!("{api_url}/v1/messages");
199    let mut request_builder = HttpRequest::builder()
200        .method(Method::POST)
201        .uri(uri)
202        .header("Anthropic-Version", "2023-06-01")
203        .header(
204            "Anthropic-Beta",
205            "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
206        )
207        .header("X-Api-Key", api_key)
208        .header("Content-Type", "application/json");
209    if let Some(low_speed_timeout) = low_speed_timeout {
210        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
211    }
212    let serialized_request =
213        serde_json::to_string(&request).context("failed to serialize request")?;
214    let request = request_builder
215        .body(AsyncBody::from(serialized_request))
216        .context("failed to construct request body")?;
217
218    let mut response = client
219        .send(request)
220        .await
221        .context("failed to send request to Anthropic")?;
222    if response.status().is_success() {
223        let reader = BufReader::new(response.into_body());
224        Ok(reader
225            .lines()
226            .filter_map(|line| async move {
227                match line {
228                    Ok(line) => {
229                        let line = line.strip_prefix("data: ")?;
230                        match serde_json::from_str(line) {
231                            Ok(response) => Some(Ok(response)),
232                            Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
233                        }
234                    }
235                    Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
236                }
237            })
238            .boxed())
239    } else {
240        let mut body = Vec::new();
241        response
242            .body_mut()
243            .read_to_end(&mut body)
244            .await
245            .context("failed to read response body")?;
246
247        let body_str =
248            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
249
250        match serde_json::from_str::<Event>(body_str) {
251            Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
252            Ok(_) => Err(AnthropicError::Other(anyhow!(
253                "Unexpected success response while expecting an error: '{body_str}'",
254            ))),
255            Err(_) => Err(AnthropicError::Other(anyhow!(
256                "Failed to connect to API: {} {}",
257                response.status(),
258                body_str,
259            ))),
260        }
261    }
262}
263
264pub fn extract_text_from_events(
265    response: impl Stream<Item = Result<Event, AnthropicError>>,
266) -> impl Stream<Item = Result<String, AnthropicError>> {
267    response.filter_map(|response| async move {
268        match response {
269            Ok(response) => match response {
270                Event::ContentBlockStart { content_block, .. } => match content_block {
271                    Content::Text { text, .. } => Some(Ok(text)),
272                    _ => None,
273                },
274                Event::ContentBlockDelta { delta, .. } => match delta {
275                    ContentDelta::TextDelta { text } => Some(Ok(text)),
276                    _ => None,
277                },
278                Event::Error { error } => Some(Err(AnthropicError::ApiError(error))),
279                _ => None,
280            },
281            Err(error) => Some(Err(error)),
282        }
283    })
284}
285
286pub async fn extract_tool_args_from_events(
287    tool_name: String,
288    mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
289) -> Result<impl Send + Stream<Item = Result<String>>> {
290    let mut tool_use_index = None;
291    while let Some(event) = events.next().await {
292        if let Event::ContentBlockStart {
293            index,
294            content_block,
295        } = event?
296        {
297            if let Content::ToolUse { name, .. } = content_block {
298                if name == tool_name {
299                    tool_use_index = Some(index);
300                    break;
301                }
302            }
303        }
304    }
305
306    let Some(tool_use_index) = tool_use_index else {
307        return Err(anyhow!("tool not used"));
308    };
309
310    Ok(events.filter_map(move |event| {
311        let result = match event {
312            Err(error) => Some(Err(error)),
313            Ok(Event::ContentBlockDelta { index, delta }) => match delta {
314                ContentDelta::TextDelta { .. } => None,
315                ContentDelta::InputJsonDelta { partial_json } => {
316                    if index == tool_use_index {
317                        Some(Ok(partial_json))
318                    } else {
319                        None
320                    }
321                }
322            },
323            _ => None,
324        };
325
326        async move { result }
327    }))
328}
329
/// The kind of cache control requested; serialized in lowercase (`"ephemeral"`).
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
#[serde(rename_all = "lowercase")]
pub enum CacheControlType {
    Ephemeral,
}
335
/// Cache-control marker that can be attached to a [`Content`] block.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct CacheControl {
    /// Appears as the `type` key on the wire.
    #[serde(rename = "type")]
    pub cache_type: CacheControlType,
}
341
/// One conversation message: who sent it and the content blocks it consists of.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    pub role: Role,
    pub content: Vec<Content>,
}
347
/// The author of a message; serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
354
/// A single content block, discriminated on the wire by its `type` key.
///
/// Every variant carries an optional `cache_control` that is left out of the serialized
/// form when it is `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Content {
    /// Plain text (`"type": "text"`).
    #[serde(rename = "text")]
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image (`"type": "image"`).
    #[serde(rename = "image")]
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation (`"type": "tool_use"`) with its JSON input.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// The result of a tool invocation (`"type": "tool_result"`); `tool_use_id`
    /// presumably refers to the `id` of the matching `ToolUse` block.
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
386
/// The payload of an image content block.
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
    /// Appears as the `type` key on the wire.
    #[serde(rename = "type")]
    pub source_type: String,
    pub media_type: String,
    /// NOTE(review): presumably the encoded image bytes in the format named by
    /// `source_type` — confirm against the caller that builds this value.
    pub data: String,
}
394
/// A tool definition: its name, a description, and a JSON value describing its input.
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}
401
/// Tool-selection setting, discriminated by a lowercase `type` key
/// (`"auto"`, `"any"`, or `"tool"` with a `name`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    Tool { name: String },
}
409
/// The JSON body sent to the `/v1/messages` endpoint by `complete` and
/// `stream_completion`.
///
/// Optional fields default when absent on deserialization and are left out of the
/// serialized body when they are `None` or empty.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    pub model: String,
    pub max_tokens: u32,
    pub messages: Vec<Message>,
    /// Omitted from the serialized body when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    /// Omitted from the serialized body when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
432
/// A [`Request`] plus a `stream` flag, used internally by `stream_completion`.
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
    /// Flattened so the request's fields sit at the top level next to `stream`.
    #[serde(flatten)]
    pub base: Request,
    pub stream: bool,
}
439
/// Optional metadata attached to a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    pub user_id: Option<String>,
}
444
/// Token counts reported by the API.
///
/// Either count may be absent; absent counts deserialize to `None` and are skipped when
/// serializing.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
}
452
/// A complete message returned by the API. It is the result of `complete` and also the
/// payload of [`Event::MessageStart`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    pub id: String,
    /// Appears as the `type` key on the wire.
    #[serde(rename = "type")]
    pub response_type: String,
    pub role: Role,
    pub content: Vec<Content>,
    pub model: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}
467
/// One event decoded from a streaming response, discriminated by its `type` key.
///
/// `stream_completion` also uses the `Error` variant to decode the body of a
/// non-success HTTP response.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Event {
    #[serde(rename = "message_start")]
    MessageStart { message: Response },
    /// A new content block begins at position `index`.
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: Content,
    },
    /// An incremental update for the content block at `index`.
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: ContentDelta },
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    #[serde(rename = "message_delta")]
    MessageDelta { delta: MessageDelta, usage: Usage },
    #[serde(rename = "message_stop")]
    MessageStop,
    #[serde(rename = "ping")]
    Ping,
    /// An error reported by the API.
    #[serde(rename = "error")]
    Error { error: ApiError },
}
491
/// The incremental payload of an [`Event::ContentBlockDelta`], discriminated by `type`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentDelta {
    /// A fragment of text (`"type": "text_delta"`).
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    /// A fragment of a JSON document (`"type": "input_json_delta"`);
    /// `extract_tool_args_from_events` forwards these for the selected tool use.
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}
500
/// The payload of an [`Event::MessageDelta`]: optional stop information for the message.
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
506
/// Errors produced by this crate's API calls and event streams.
#[derive(Error, Debug)]
pub enum AnthropicError {
    /// A structured error reported by the API itself.
    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
    ApiError(ApiError),
    /// Any other failure; `#[from]` lets `?` convert an `anyhow::Error` into this variant.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
514
/// The error object carried by [`Event::Error`] and [`AnthropicError::ApiError`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
    /// The error's `type` string as sent by the API; see `ApiError::code` for a parsed form.
    #[serde(rename = "type")]
    pub error_type: String,
    pub message: String,
}
521
/// An Anthropic API error code.
///
/// Parsed from its snake_case string form (e.g. `"rate_limit_error"`) through the
/// derived `FromStr` implementation.
///
/// <https://docs.anthropic.com/en/api/errors#http-errors>
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
    InvalidRequestError,
    /// 401 - `authentication_error`: There's an issue with your API key.
    AuthenticationError,
    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
    PermissionError,
    /// 404 - `not_found_error`: The requested resource was not found.
    NotFoundError,
    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
    RequestTooLarge,
    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
    RateLimitError,
    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
    ApiError,
    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
    OverloadedError,
}
544
545impl ApiError {
546    pub fn code(&self) -> Option<ApiErrorCode> {
547        ApiErrorCode::from_str(&self.error_type).ok()
548    }
549
550    pub fn is_rate_limit_error(&self) -> bool {
551        match self.error_type.as_str() {
552            "rate_limit_error" => true,
553            _ => false,
554        }
555    }
556}