anthropic.rs

  1mod supported_countries;
  2
  3use anyhow::{anyhow, Result};
  4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  6use isahc::config::Configurable;
  7use serde::{Deserialize, Serialize};
  8use std::time::Duration;
  9use strum::EnumIter;
 10
 11pub use supported_countries::*;
 12
 13pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
 14
 15#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 16#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 17pub enum Model {
 18    #[default]
 19    #[serde(alias = "claude-3-5-sonnet", rename = "claude-3-5-sonnet-20240620")]
 20    Claude3_5Sonnet,
 21    #[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
 22    Claude3Opus,
 23    #[serde(alias = "claude-3-sonnet", rename = "claude-3-sonnet-20240229")]
 24    Claude3Sonnet,
 25    #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
 26    Claude3Haiku,
 27    #[serde(rename = "custom")]
 28    Custom {
 29        name: String,
 30        max_tokens: usize,
 31        /// Override this model with a different Anthropic model for tool calls.
 32        tool_override: Option<String>,
 33    },
 34}
 35
 36impl Model {
 37    pub fn from_id(id: &str) -> Result<Self> {
 38        if id.starts_with("claude-3-5-sonnet") {
 39            Ok(Self::Claude3_5Sonnet)
 40        } else if id.starts_with("claude-3-opus") {
 41            Ok(Self::Claude3Opus)
 42        } else if id.starts_with("claude-3-sonnet") {
 43            Ok(Self::Claude3Sonnet)
 44        } else if id.starts_with("claude-3-haiku") {
 45            Ok(Self::Claude3Haiku)
 46        } else {
 47            Err(anyhow!("invalid model id"))
 48        }
 49    }
 50
 51    pub fn id(&self) -> &str {
 52        match self {
 53            Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
 54            Model::Claude3Opus => "claude-3-opus-20240229",
 55            Model::Claude3Sonnet => "claude-3-sonnet-20240229",
 56            Model::Claude3Haiku => "claude-3-opus-20240307",
 57            Self::Custom { name, .. } => name,
 58        }
 59    }
 60
 61    pub fn display_name(&self) -> &str {
 62        match self {
 63            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
 64            Self::Claude3Opus => "Claude 3 Opus",
 65            Self::Claude3Sonnet => "Claude 3 Sonnet",
 66            Self::Claude3Haiku => "Claude 3 Haiku",
 67            Self::Custom { name, .. } => name,
 68        }
 69    }
 70
 71    pub fn max_token_count(&self) -> usize {
 72        match self {
 73            Self::Claude3_5Sonnet
 74            | Self::Claude3Opus
 75            | Self::Claude3Sonnet
 76            | Self::Claude3Haiku => 200_000,
 77            Self::Custom { max_tokens, .. } => *max_tokens,
 78        }
 79    }
 80
 81    pub fn tool_model_id(&self) -> &str {
 82        if let Self::Custom {
 83            tool_override: Some(tool_override),
 84            ..
 85        } = self
 86        {
 87            tool_override
 88        } else {
 89            self.id()
 90        }
 91    }
 92}
 93
 94pub async fn complete(
 95    client: &dyn HttpClient,
 96    api_url: &str,
 97    api_key: &str,
 98    request: Request,
 99) -> Result<Response> {
100    let uri = format!("{api_url}/v1/messages");
101    let request_builder = HttpRequest::builder()
102        .method(Method::POST)
103        .uri(uri)
104        .header("Anthropic-Version", "2023-06-01")
105        .header("Anthropic-Beta", "tools-2024-04-04")
106        .header("X-Api-Key", api_key)
107        .header("Content-Type", "application/json");
108
109    let serialized_request = serde_json::to_string(&request)?;
110    let request = request_builder.body(AsyncBody::from(serialized_request))?;
111
112    let mut response = client.send(request).await?;
113    if response.status().is_success() {
114        let mut body = Vec::new();
115        response.body_mut().read_to_end(&mut body).await?;
116        let response_message: Response = serde_json::from_slice(&body)?;
117        Ok(response_message)
118    } else {
119        let mut body = Vec::new();
120        response.body_mut().read_to_end(&mut body).await?;
121        let body_str = std::str::from_utf8(&body)?;
122        Err(anyhow!(
123            "Failed to connect to API: {} {}",
124            response.status(),
125            body_str
126        ))
127    }
128}
129
130pub async fn stream_completion(
131    client: &dyn HttpClient,
132    api_url: &str,
133    api_key: &str,
134    request: Request,
135    low_speed_timeout: Option<Duration>,
136) -> Result<BoxStream<'static, Result<Event>>> {
137    let request = StreamingRequest {
138        base: request,
139        stream: true,
140    };
141    let uri = format!("{api_url}/v1/messages");
142    let mut request_builder = HttpRequest::builder()
143        .method(Method::POST)
144        .uri(uri)
145        .header("Anthropic-Version", "2023-06-01")
146        .header("Anthropic-Beta", "tools-2024-04-04")
147        .header("X-Api-Key", api_key)
148        .header("Content-Type", "application/json");
149    if let Some(low_speed_timeout) = low_speed_timeout {
150        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
151    }
152    let serialized_request = serde_json::to_string(&request)?;
153    let request = request_builder.body(AsyncBody::from(serialized_request))?;
154
155    let mut response = client.send(request).await?;
156    if response.status().is_success() {
157        let reader = BufReader::new(response.into_body());
158        Ok(reader
159            .lines()
160            .filter_map(|line| async move {
161                match line {
162                    Ok(line) => {
163                        let line = line.strip_prefix("data: ")?;
164                        match serde_json::from_str(line) {
165                            Ok(response) => Some(Ok(response)),
166                            Err(error) => Some(Err(anyhow!(error))),
167                        }
168                    }
169                    Err(error) => Some(Err(anyhow!(error))),
170                }
171            })
172            .boxed())
173    } else {
174        let mut body = Vec::new();
175        response.body_mut().read_to_end(&mut body).await?;
176
177        let body_str = std::str::from_utf8(&body)?;
178
179        match serde_json::from_str::<Event>(body_str) {
180            Ok(Event::Error { error }) => Err(api_error_to_err(error)),
181            Ok(_) => Err(anyhow!(
182                "Unexpected success response while expecting an error: '{body_str}'",
183            )),
184            Err(_) => Err(anyhow!(
185                "Failed to connect to API: {} {}",
186                response.status(),
187                body_str,
188            )),
189        }
190    }
191}
192
193pub fn extract_text_from_events(
194    response: impl Stream<Item = Result<Event>>,
195) -> impl Stream<Item = Result<String>> {
196    response.filter_map(|response| async move {
197        match response {
198            Ok(response) => match response {
199                Event::ContentBlockStart { content_block, .. } => match content_block {
200                    Content::Text { text } => Some(Ok(text)),
201                    _ => None,
202                },
203                Event::ContentBlockDelta { delta, .. } => match delta {
204                    ContentDelta::TextDelta { text } => Some(Ok(text)),
205                    _ => None,
206                },
207                Event::Error { error } => Some(Err(api_error_to_err(error))),
208                _ => None,
209            },
210            Err(error) => Some(Err(error)),
211        }
212    })
213}
214
215fn api_error_to_err(
216    ApiError {
217        error_type,
218        message,
219    }: ApiError,
220) -> anyhow::Error {
221    anyhow!("API error. Type: '{error_type}', message: '{message}'",)
222}
223
224#[derive(Debug, Serialize, Deserialize)]
225pub struct Message {
226    pub role: Role,
227    pub content: Vec<Content>,
228}
229
230#[derive(Debug, Serialize, Deserialize)]
231#[serde(rename_all = "lowercase")]
232pub enum Role {
233    User,
234    Assistant,
235}
236
237#[derive(Debug, Serialize, Deserialize)]
238#[serde(tag = "type")]
239pub enum Content {
240    #[serde(rename = "text")]
241    Text { text: String },
242    #[serde(rename = "image")]
243    Image { source: ImageSource },
244    #[serde(rename = "tool_use")]
245    ToolUse {
246        id: String,
247        name: String,
248        input: serde_json::Value,
249    },
250    #[serde(rename = "tool_result")]
251    ToolResult {
252        tool_use_id: String,
253        content: String,
254    },
255}
256
257#[derive(Debug, Serialize, Deserialize)]
258pub struct ImageSource {
259    #[serde(rename = "type")]
260    pub source_type: String,
261    pub media_type: String,
262    pub data: String,
263}
264
265#[derive(Debug, Serialize, Deserialize)]
266pub struct Tool {
267    pub name: String,
268    pub description: String,
269    pub input_schema: serde_json::Value,
270}
271
272#[derive(Debug, Serialize, Deserialize)]
273#[serde(tag = "type", rename_all = "lowercase")]
274pub enum ToolChoice {
275    Auto,
276    Any,
277    Tool { name: String },
278}
279
280#[derive(Debug, Serialize, Deserialize)]
281pub struct Request {
282    pub model: String,
283    pub max_tokens: u32,
284    pub messages: Vec<Message>,
285    #[serde(default, skip_serializing_if = "Vec::is_empty")]
286    pub tools: Vec<Tool>,
287    #[serde(default, skip_serializing_if = "Option::is_none")]
288    pub tool_choice: Option<ToolChoice>,
289    #[serde(default, skip_serializing_if = "Option::is_none")]
290    pub system: Option<String>,
291    #[serde(default, skip_serializing_if = "Option::is_none")]
292    pub metadata: Option<Metadata>,
293    #[serde(default, skip_serializing_if = "Vec::is_empty")]
294    pub stop_sequences: Vec<String>,
295    #[serde(default, skip_serializing_if = "Option::is_none")]
296    pub temperature: Option<f32>,
297    #[serde(default, skip_serializing_if = "Option::is_none")]
298    pub top_k: Option<u32>,
299    #[serde(default, skip_serializing_if = "Option::is_none")]
300    pub top_p: Option<f32>,
301}
302
303#[derive(Debug, Serialize, Deserialize)]
304struct StreamingRequest {
305    #[serde(flatten)]
306    pub base: Request,
307    pub stream: bool,
308}
309
310#[derive(Debug, Serialize, Deserialize)]
311pub struct Metadata {
312    pub user_id: Option<String>,
313}
314
315#[derive(Debug, Serialize, Deserialize)]
316pub struct Usage {
317    #[serde(default, skip_serializing_if = "Option::is_none")]
318    pub input_tokens: Option<u32>,
319    #[serde(default, skip_serializing_if = "Option::is_none")]
320    pub output_tokens: Option<u32>,
321}
322
323#[derive(Debug, Serialize, Deserialize)]
324pub struct Response {
325    pub id: String,
326    #[serde(rename = "type")]
327    pub response_type: String,
328    pub role: Role,
329    pub content: Vec<Content>,
330    pub model: String,
331    #[serde(default, skip_serializing_if = "Option::is_none")]
332    pub stop_reason: Option<String>,
333    #[serde(default, skip_serializing_if = "Option::is_none")]
334    pub stop_sequence: Option<String>,
335    pub usage: Usage,
336}
337
338#[derive(Debug, Serialize, Deserialize)]
339#[serde(tag = "type")]
340pub enum Event {
341    #[serde(rename = "message_start")]
342    MessageStart { message: Response },
343    #[serde(rename = "content_block_start")]
344    ContentBlockStart {
345        index: usize,
346        content_block: Content,
347    },
348    #[serde(rename = "content_block_delta")]
349    ContentBlockDelta { index: usize, delta: ContentDelta },
350    #[serde(rename = "content_block_stop")]
351    ContentBlockStop { index: usize },
352    #[serde(rename = "message_delta")]
353    MessageDelta { delta: MessageDelta, usage: Usage },
354    #[serde(rename = "message_stop")]
355    MessageStop,
356    #[serde(rename = "ping")]
357    Ping,
358    #[serde(rename = "error")]
359    Error { error: ApiError },
360}
361
362#[derive(Debug, Serialize, Deserialize)]
363#[serde(tag = "type")]
364pub enum ContentDelta {
365    #[serde(rename = "text_delta")]
366    TextDelta { text: String },
367    #[serde(rename = "input_json_delta")]
368    InputJsonDelta { partial_json: String },
369}
370
371#[derive(Debug, Serialize, Deserialize)]
372pub struct MessageDelta {
373    pub stop_reason: Option<String>,
374    pub stop_sequence: Option<String>,
375}
376
377#[derive(Debug, Serialize, Deserialize)]
378pub struct ApiError {
379    #[serde(rename = "type")]
380    pub error_type: String,
381    pub message: String,
382}