anthropic.rs

  1mod supported_countries;
  2
  3use anyhow::{anyhow, Context, Result};
  4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  6use isahc::config::Configurable;
  7use serde::{Deserialize, Serialize};
  8use std::str::FromStr;
  9use std::time::Duration;
 10use strum::{EnumIter, EnumString};
 11use thiserror::Error;
 12
 13pub use supported_countries::*;
 14
 15pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
 16
 17#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 18#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 19pub enum Model {
 20    #[default]
 21    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-20240620")]
 22    Claude3_5Sonnet,
 23    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
 24    Claude3Opus,
 25    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
 26    Claude3Sonnet,
 27    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
 28    Claude3Haiku,
 29    #[serde(rename = "custom")]
 30    Custom {
 31        name: String,
 32        max_tokens: usize,
 33        /// Override this model with a different Anthropic model for tool calls.
 34        tool_override: Option<String>,
 35    },
 36}
 37
 38impl Model {
 39    pub fn from_id(id: &str) -> Result<Self> {
 40        if id.starts_with("claude-3-5-sonnet") {
 41            Ok(Self::Claude3_5Sonnet)
 42        } else if id.starts_with("claude-3-opus") {
 43            Ok(Self::Claude3Opus)
 44        } else if id.starts_with("claude-3-sonnet") {
 45            Ok(Self::Claude3Sonnet)
 46        } else if id.starts_with("claude-3-haiku") {
 47            Ok(Self::Claude3Haiku)
 48        } else {
 49            Err(anyhow!("invalid model id"))
 50        }
 51    }
 52
 53    pub fn id(&self) -> &str {
 54        match self {
 55            Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
 56            Model::Claude3Opus => "claude-3-opus-20240229",
 57            Model::Claude3Sonnet => "claude-3-sonnet-20240229",
 58            Model::Claude3Haiku => "claude-3-haiku-20240307",
 59            Self::Custom { name, .. } => name,
 60        }
 61    }
 62
 63    pub fn display_name(&self) -> &str {
 64        match self {
 65            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
 66            Self::Claude3Opus => "Claude 3 Opus",
 67            Self::Claude3Sonnet => "Claude 3 Sonnet",
 68            Self::Claude3Haiku => "Claude 3 Haiku",
 69            Self::Custom { name, .. } => name,
 70        }
 71    }
 72
 73    pub fn max_token_count(&self) -> usize {
 74        match self {
 75            Self::Claude3_5Sonnet
 76            | Self::Claude3Opus
 77            | Self::Claude3Sonnet
 78            | Self::Claude3Haiku => 200_000,
 79            Self::Custom { max_tokens, .. } => *max_tokens,
 80        }
 81    }
 82
 83    pub fn tool_model_id(&self) -> &str {
 84        if let Self::Custom {
 85            tool_override: Some(tool_override),
 86            ..
 87        } = self
 88        {
 89            tool_override
 90        } else {
 91            self.id()
 92        }
 93    }
 94}
 95
 96pub async fn complete(
 97    client: &dyn HttpClient,
 98    api_url: &str,
 99    api_key: &str,
100    request: Request,
101) -> Result<Response, AnthropicError> {
102    let uri = format!("{api_url}/v1/messages");
103    let request_builder = HttpRequest::builder()
104        .method(Method::POST)
105        .uri(uri)
106        .header("Anthropic-Version", "2023-06-01")
107        .header("Anthropic-Beta", "tools-2024-04-04")
108        .header("X-Api-Key", api_key)
109        .header("Content-Type", "application/json");
110
111    let serialized_request =
112        serde_json::to_string(&request).context("failed to serialize request")?;
113    let request = request_builder
114        .body(AsyncBody::from(serialized_request))
115        .context("failed to construct request body")?;
116
117    let mut response = client
118        .send(request)
119        .await
120        .context("failed to send request to Anthropic")?;
121    if response.status().is_success() {
122        let mut body = Vec::new();
123        response
124            .body_mut()
125            .read_to_end(&mut body)
126            .await
127            .context("failed to read response body")?;
128        let response_message: Response =
129            serde_json::from_slice(&body).context("failed to deserialize response body")?;
130        Ok(response_message)
131    } else {
132        let mut body = Vec::new();
133        response
134            .body_mut()
135            .read_to_end(&mut body)
136            .await
137            .context("failed to read response body")?;
138        let body_str =
139            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
140        Err(AnthropicError::Other(anyhow!(
141            "Failed to connect to API: {} {}",
142            response.status(),
143            body_str
144        )))
145    }
146}
147
148pub async fn stream_completion(
149    client: &dyn HttpClient,
150    api_url: &str,
151    api_key: &str,
152    request: Request,
153    low_speed_timeout: Option<Duration>,
154) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
155    let request = StreamingRequest {
156        base: request,
157        stream: true,
158    };
159    let uri = format!("{api_url}/v1/messages");
160    let mut request_builder = HttpRequest::builder()
161        .method(Method::POST)
162        .uri(uri)
163        .header("Anthropic-Version", "2023-06-01")
164        .header("Anthropic-Beta", "tools-2024-04-04")
165        .header("X-Api-Key", api_key)
166        .header("Content-Type", "application/json");
167    if let Some(low_speed_timeout) = low_speed_timeout {
168        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
169    }
170    let serialized_request =
171        serde_json::to_string(&request).context("failed to serialize request")?;
172    let request = request_builder
173        .body(AsyncBody::from(serialized_request))
174        .context("failed to construct request body")?;
175
176    let mut response = client
177        .send(request)
178        .await
179        .context("failed to send request to Anthropic")?;
180    if response.status().is_success() {
181        let reader = BufReader::new(response.into_body());
182        Ok(reader
183            .lines()
184            .filter_map(|line| async move {
185                match line {
186                    Ok(line) => {
187                        let line = line.strip_prefix("data: ")?;
188                        match serde_json::from_str(line) {
189                            Ok(response) => Some(Ok(response)),
190                            Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
191                        }
192                    }
193                    Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
194                }
195            })
196            .boxed())
197    } else {
198        let mut body = Vec::new();
199        response
200            .body_mut()
201            .read_to_end(&mut body)
202            .await
203            .context("failed to read response body")?;
204
205        let body_str =
206            std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
207
208        match serde_json::from_str::<Event>(body_str) {
209            Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
210            Ok(_) => Err(AnthropicError::Other(anyhow!(
211                "Unexpected success response while expecting an error: '{body_str}'",
212            ))),
213            Err(_) => Err(AnthropicError::Other(anyhow!(
214                "Failed to connect to API: {} {}",
215                response.status(),
216                body_str,
217            ))),
218        }
219    }
220}
221
222pub fn extract_text_from_events(
223    response: impl Stream<Item = Result<Event, AnthropicError>>,
224) -> impl Stream<Item = Result<String, AnthropicError>> {
225    response.filter_map(|response| async move {
226        match response {
227            Ok(response) => match response {
228                Event::ContentBlockStart { content_block, .. } => match content_block {
229                    Content::Text { text } => Some(Ok(text)),
230                    _ => None,
231                },
232                Event::ContentBlockDelta { delta, .. } => match delta {
233                    ContentDelta::TextDelta { text } => Some(Ok(text)),
234                    _ => None,
235                },
236                Event::Error { error } => Some(Err(AnthropicError::ApiError(error))),
237                _ => None,
238            },
239            Err(error) => Some(Err(error)),
240        }
241    })
242}
243
244#[derive(Debug, Serialize, Deserialize)]
245pub struct Message {
246    pub role: Role,
247    pub content: Vec<Content>,
248}
249
250#[derive(Debug, Serialize, Deserialize)]
251#[serde(rename_all = "lowercase")]
252pub enum Role {
253    User,
254    Assistant,
255}
256
257#[derive(Debug, Serialize, Deserialize)]
258#[serde(tag = "type")]
259pub enum Content {
260    #[serde(rename = "text")]
261    Text { text: String },
262    #[serde(rename = "image")]
263    Image { source: ImageSource },
264    #[serde(rename = "tool_use")]
265    ToolUse {
266        id: String,
267        name: String,
268        input: serde_json::Value,
269    },
270    #[serde(rename = "tool_result")]
271    ToolResult {
272        tool_use_id: String,
273        content: String,
274    },
275}
276
277#[derive(Debug, Serialize, Deserialize)]
278pub struct ImageSource {
279    #[serde(rename = "type")]
280    pub source_type: String,
281    pub media_type: String,
282    pub data: String,
283}
284
285#[derive(Debug, Serialize, Deserialize)]
286pub struct Tool {
287    pub name: String,
288    pub description: String,
289    pub input_schema: serde_json::Value,
290}
291
292#[derive(Debug, Serialize, Deserialize)]
293#[serde(tag = "type", rename_all = "lowercase")]
294pub enum ToolChoice {
295    Auto,
296    Any,
297    Tool { name: String },
298}
299
300#[derive(Debug, Serialize, Deserialize)]
301pub struct Request {
302    pub model: String,
303    pub max_tokens: u32,
304    pub messages: Vec<Message>,
305    #[serde(default, skip_serializing_if = "Vec::is_empty")]
306    pub tools: Vec<Tool>,
307    #[serde(default, skip_serializing_if = "Option::is_none")]
308    pub tool_choice: Option<ToolChoice>,
309    #[serde(default, skip_serializing_if = "Option::is_none")]
310    pub system: Option<String>,
311    #[serde(default, skip_serializing_if = "Option::is_none")]
312    pub metadata: Option<Metadata>,
313    #[serde(default, skip_serializing_if = "Vec::is_empty")]
314    pub stop_sequences: Vec<String>,
315    #[serde(default, skip_serializing_if = "Option::is_none")]
316    pub temperature: Option<f32>,
317    #[serde(default, skip_serializing_if = "Option::is_none")]
318    pub top_k: Option<u32>,
319    #[serde(default, skip_serializing_if = "Option::is_none")]
320    pub top_p: Option<f32>,
321}
322
323#[derive(Debug, Serialize, Deserialize)]
324struct StreamingRequest {
325    #[serde(flatten)]
326    pub base: Request,
327    pub stream: bool,
328}
329
330#[derive(Debug, Serialize, Deserialize)]
331pub struct Metadata {
332    pub user_id: Option<String>,
333}
334
335#[derive(Debug, Serialize, Deserialize)]
336pub struct Usage {
337    #[serde(default, skip_serializing_if = "Option::is_none")]
338    pub input_tokens: Option<u32>,
339    #[serde(default, skip_serializing_if = "Option::is_none")]
340    pub output_tokens: Option<u32>,
341}
342
343#[derive(Debug, Serialize, Deserialize)]
344pub struct Response {
345    pub id: String,
346    #[serde(rename = "type")]
347    pub response_type: String,
348    pub role: Role,
349    pub content: Vec<Content>,
350    pub model: String,
351    #[serde(default, skip_serializing_if = "Option::is_none")]
352    pub stop_reason: Option<String>,
353    #[serde(default, skip_serializing_if = "Option::is_none")]
354    pub stop_sequence: Option<String>,
355    pub usage: Usage,
356}
357
358#[derive(Debug, Serialize, Deserialize)]
359#[serde(tag = "type")]
360pub enum Event {
361    #[serde(rename = "message_start")]
362    MessageStart { message: Response },
363    #[serde(rename = "content_block_start")]
364    ContentBlockStart {
365        index: usize,
366        content_block: Content,
367    },
368    #[serde(rename = "content_block_delta")]
369    ContentBlockDelta { index: usize, delta: ContentDelta },
370    #[serde(rename = "content_block_stop")]
371    ContentBlockStop { index: usize },
372    #[serde(rename = "message_delta")]
373    MessageDelta { delta: MessageDelta, usage: Usage },
374    #[serde(rename = "message_stop")]
375    MessageStop,
376    #[serde(rename = "ping")]
377    Ping,
378    #[serde(rename = "error")]
379    Error { error: ApiError },
380}
381
382#[derive(Debug, Serialize, Deserialize)]
383#[serde(tag = "type")]
384pub enum ContentDelta {
385    #[serde(rename = "text_delta")]
386    TextDelta { text: String },
387    #[serde(rename = "input_json_delta")]
388    InputJsonDelta { partial_json: String },
389}
390
391#[derive(Debug, Serialize, Deserialize)]
392pub struct MessageDelta {
393    pub stop_reason: Option<String>,
394    pub stop_sequence: Option<String>,
395}
396
397#[derive(Error, Debug)]
398pub enum AnthropicError {
399    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
400    ApiError(ApiError),
401    #[error("{0}")]
402    Other(#[from] anyhow::Error),
403}
404
405#[derive(Debug, Serialize, Deserialize)]
406pub struct ApiError {
407    #[serde(rename = "type")]
408    pub error_type: String,
409    pub message: String,
410}
411
412/// An Anthropic API error code.
413/// https://docs.anthropic.com/en/api/errors#http-errors
414#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
415#[strum(serialize_all = "snake_case")]
416pub enum ApiErrorCode {
417    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
418    InvalidRequestError,
419    /// 401 - `authentication_error`: There's an issue with your API key.
420    AuthenticationError,
421    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
422    PermissionError,
423    /// 404 - `not_found_error`: The requested resource was not found.
424    NotFoundError,
425    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
426    RequestTooLarge,
427    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
428    RateLimitError,
429    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
430    ApiError,
431    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
432    OverloadedError,
433}
434
435impl ApiError {
436    pub fn code(&self) -> Option<ApiErrorCode> {
437        ApiErrorCode::from_str(&self.error_type).ok()
438    }
439
440    pub fn is_rate_limit_error(&self) -> bool {
441        match self.error_type.as_str() {
442            "rate_limit_error" => true,
443            _ => false,
444        }
445    }
446}