anthropic.rs

  1use anyhow::{anyhow, Result};
  2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use isahc::config::Configurable;
  5use serde::{Deserialize, Serialize};
  6use std::time::Duration;
  7use strum::EnumIter;
  8
  9pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
 10
/// Anthropic models known to this client, plus a user-configured `Custom` model.
///
/// Built-in variants serialize as their dated model id and, when deserializing, also
/// accept the undated alias (e.g. `claude-3-opus`). The default is Claude 3.5 Sonnet.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[default]
    #[serde(alias = "claude-3-5-sonnet", rename = "claude-3-5-sonnet-20240620")]
    Claude3_5Sonnet,
    #[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
    Claude3Opus,
    #[serde(alias = "claude-3-sonnet", rename = "claude-3-sonnet-20240229")]
    Claude3Sonnet,
    #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
    Claude3Haiku,
    /// A model not listed above, described entirely by configuration.
    #[serde(rename = "custom")]
    Custom {
        /// Model id; also used as the display name.
        name: String,
        /// Value reported by `max_token_count` for this model.
        max_tokens: usize,
        /// Override this model with a different Anthropic model for tool calls.
        tool_override: Option<String>,
    },
}
 31
 32impl Model {
 33    pub fn from_id(id: &str) -> Result<Self> {
 34        if id.starts_with("claude-3-5-sonnet") {
 35            Ok(Self::Claude3_5Sonnet)
 36        } else if id.starts_with("claude-3-opus") {
 37            Ok(Self::Claude3Opus)
 38        } else if id.starts_with("claude-3-sonnet") {
 39            Ok(Self::Claude3Sonnet)
 40        } else if id.starts_with("claude-3-haiku") {
 41            Ok(Self::Claude3Haiku)
 42        } else {
 43            Err(anyhow!("invalid model id"))
 44        }
 45    }
 46
 47    pub fn id(&self) -> &str {
 48        match self {
 49            Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
 50            Model::Claude3Opus => "claude-3-opus-20240229",
 51            Model::Claude3Sonnet => "claude-3-sonnet-20240229",
 52            Model::Claude3Haiku => "claude-3-opus-20240307",
 53            Self::Custom { name, .. } => name,
 54        }
 55    }
 56
 57    pub fn display_name(&self) -> &str {
 58        match self {
 59            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
 60            Self::Claude3Opus => "Claude 3 Opus",
 61            Self::Claude3Sonnet => "Claude 3 Sonnet",
 62            Self::Claude3Haiku => "Claude 3 Haiku",
 63            Self::Custom { name, .. } => name,
 64        }
 65    }
 66
 67    pub fn max_token_count(&self) -> usize {
 68        match self {
 69            Self::Claude3_5Sonnet
 70            | Self::Claude3Opus
 71            | Self::Claude3Sonnet
 72            | Self::Claude3Haiku => 200_000,
 73            Self::Custom { max_tokens, .. } => *max_tokens,
 74        }
 75    }
 76
 77    pub fn tool_model_id(&self) -> &str {
 78        if let Self::Custom {
 79            tool_override: Some(tool_override),
 80            ..
 81        } = self
 82        {
 83            tool_override
 84        } else {
 85            self.id()
 86        }
 87    }
 88}
 89
 90pub async fn complete(
 91    client: &dyn HttpClient,
 92    api_url: &str,
 93    api_key: &str,
 94    request: Request,
 95) -> Result<Response> {
 96    let uri = format!("{api_url}/v1/messages");
 97    let request_builder = HttpRequest::builder()
 98        .method(Method::POST)
 99        .uri(uri)
100        .header("Anthropic-Version", "2023-06-01")
101        .header("Anthropic-Beta", "tools-2024-04-04")
102        .header("X-Api-Key", api_key)
103        .header("Content-Type", "application/json");
104
105    let serialized_request = serde_json::to_string(&request)?;
106    let request = request_builder.body(AsyncBody::from(serialized_request))?;
107
108    let mut response = client.send(request).await?;
109    if response.status().is_success() {
110        let mut body = Vec::new();
111        response.body_mut().read_to_end(&mut body).await?;
112        let response_message: Response = serde_json::from_slice(&body)?;
113        Ok(response_message)
114    } else {
115        let mut body = Vec::new();
116        response.body_mut().read_to_end(&mut body).await?;
117        let body_str = std::str::from_utf8(&body)?;
118        Err(anyhow!(
119            "Failed to connect to API: {} {}",
120            response.status(),
121            body_str
122        ))
123    }
124}
125
126pub async fn stream_completion(
127    client: &dyn HttpClient,
128    api_url: &str,
129    api_key: &str,
130    request: Request,
131    low_speed_timeout: Option<Duration>,
132) -> Result<BoxStream<'static, Result<Event>>> {
133    let request = StreamingRequest {
134        base: request,
135        stream: true,
136    };
137    let uri = format!("{api_url}/v1/messages");
138    let mut request_builder = HttpRequest::builder()
139        .method(Method::POST)
140        .uri(uri)
141        .header("Anthropic-Version", "2023-06-01")
142        .header("Anthropic-Beta", "tools-2024-04-04")
143        .header("X-Api-Key", api_key)
144        .header("Content-Type", "application/json");
145    if let Some(low_speed_timeout) = low_speed_timeout {
146        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
147    }
148    let serialized_request = serde_json::to_string(&request)?;
149    let request = request_builder.body(AsyncBody::from(serialized_request))?;
150
151    let mut response = client.send(request).await?;
152    if response.status().is_success() {
153        let reader = BufReader::new(response.into_body());
154        Ok(reader
155            .lines()
156            .filter_map(|line| async move {
157                match line {
158                    Ok(line) => {
159                        let line = line.strip_prefix("data: ")?;
160                        match serde_json::from_str(line) {
161                            Ok(response) => Some(Ok(response)),
162                            Err(error) => Some(Err(anyhow!(error))),
163                        }
164                    }
165                    Err(error) => Some(Err(anyhow!(error))),
166                }
167            })
168            .boxed())
169    } else {
170        let mut body = Vec::new();
171        response.body_mut().read_to_end(&mut body).await?;
172
173        let body_str = std::str::from_utf8(&body)?;
174
175        match serde_json::from_str::<Event>(body_str) {
176            Ok(Event::Error { error }) => Err(api_error_to_err(error)),
177            Ok(_) => Err(anyhow!(
178                "Unexpected success response while expecting an error: '{body_str}'",
179            )),
180            Err(_) => Err(anyhow!(
181                "Failed to connect to API: {} {}",
182                response.status(),
183                body_str,
184            )),
185        }
186    }
187}
188
189pub fn extract_text_from_events(
190    response: impl Stream<Item = Result<Event>>,
191) -> impl Stream<Item = Result<String>> {
192    response.filter_map(|response| async move {
193        match response {
194            Ok(response) => match response {
195                Event::ContentBlockStart { content_block, .. } => match content_block {
196                    Content::Text { text } => Some(Ok(text)),
197                    _ => None,
198                },
199                Event::ContentBlockDelta { delta, .. } => match delta {
200                    ContentDelta::TextDelta { text } => Some(Ok(text)),
201                    _ => None,
202                },
203                Event::Error { error } => Some(Err(api_error_to_err(error))),
204                _ => None,
205            },
206            Err(error) => Some(Err(error)),
207        }
208    })
209}
210
211fn api_error_to_err(
212    ApiError {
213        error_type,
214        message,
215    }: ApiError,
216) -> anyhow::Error {
217    anyhow!("API error. Type: '{error_type}', message: '{message}'",)
218}
219
/// One message in a [`Request`]'s `messages` list: its author and its content blocks,
/// in order.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    pub role: Role,
    pub content: Vec<Content>,
}
225
/// Author of a [`Message`] or [`Response`]; serialized in lowercase
/// (`"user"` / `"assistant"`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
232
233#[derive(Debug, Serialize, Deserialize)]
234#[serde(tag = "type")]
235pub enum Content {
236    #[serde(rename = "text")]
237    Text { text: String },
238    #[serde(rename = "image")]
239    Image { source: ImageSource },
240    #[serde(rename = "tool_use")]
241    ToolUse {
242        id: String,
243        name: String,
244        input: serde_json::Value,
245    },
246    #[serde(rename = "tool_result")]
247    ToolResult {
248        tool_use_id: String,
249        content: String,
250    },
251}
252
/// Payload of an image content block ([`Content::Image`]).
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
    /// Serialized as `type`. NOTE(review): presumably names how `data` is encoded
    /// (e.g. `"base64"`) — confirm against the API docs.
    #[serde(rename = "type")]
    pub source_type: String,
    /// NOTE(review): presumably the image MIME type (e.g. `image/png`) — confirm.
    pub media_type: String,
    /// Image data, encoded as described by `source_type` (assumption, see above).
    pub data: String,
}
260
/// A tool definition, sent in [`Request::tools`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    pub name: String,
    pub description: String,
    /// Arbitrary JSON describing the tool's input. NOTE(review): presumably a JSON
    /// Schema object — confirm against the API docs.
    pub input_schema: serde_json::Value,
}
267
/// Tool-selection setting for a [`Request`]; internally tagged by a lowercase `type`
/// field (`"auto"`, `"any"`, or `"tool"`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    /// Serialized as `{"type": "tool", "name": ...}`, naming one specific tool.
    Tool { name: String },
}
275
/// Body of a request to the Messages API (`/v1/messages`).
///
/// Empty `Vec` fields and `None` options are left out of the serialized JSON. When
/// deserializing, missing ones default to empty / `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id, e.g. the value of `Model::id`.
    pub model: String,
    pub max_tokens: u32,
    pub messages: Vec<Message>,
    /// Tools the model may call; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// How tools may be selected; see [`ToolChoice`].
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    // The remaining options are only sent when explicitly set.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
298
/// Wire format for streaming requests. Every field of `base` is flattened into the same
/// JSON object, alongside a `"stream"` flag.
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
    #[serde(flatten)]
    pub base: Request,
    pub stream: bool,
}
305
306#[derive(Debug, Serialize, Deserialize)]
307pub struct Metadata {
308    pub user_id: Option<String>,
309}
310
/// Token counts reported by the API.
///
/// Either count may be missing: it is then `None`, and is omitted again on serialization.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
}
318
/// A message returned by the API.
///
/// This is both the result of `complete` and the payload of a `message_start` event.
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    pub id: String,
    /// Serialized as `type`.
    #[serde(rename = "type")]
    pub response_type: String,
    pub role: Role,
    pub content: Vec<Content>,
    pub model: String,
    /// Optional; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    /// Optional; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}
333
334#[derive(Debug, Serialize, Deserialize)]
335#[serde(tag = "type")]
336pub enum Event {
337    #[serde(rename = "message_start")]
338    MessageStart { message: Response },
339    #[serde(rename = "content_block_start")]
340    ContentBlockStart {
341        index: usize,
342        content_block: Content,
343    },
344    #[serde(rename = "content_block_delta")]
345    ContentBlockDelta { index: usize, delta: ContentDelta },
346    #[serde(rename = "content_block_stop")]
347    ContentBlockStop { index: usize },
348    #[serde(rename = "message_delta")]
349    MessageDelta { delta: MessageDelta, usage: Usage },
350    #[serde(rename = "message_stop")]
351    MessageStop,
352    #[serde(rename = "ping")]
353    Ping,
354    #[serde(rename = "error")]
355    Error { error: ApiError },
356}
357
358#[derive(Debug, Serialize, Deserialize)]
359#[serde(tag = "type")]
360pub enum ContentDelta {
361    #[serde(rename = "text_delta")]
362    TextDelta { text: String },
363    #[serde(rename = "input_json_delta")]
364    InputJsonDelta { partial_json: String },
365}
366
/// The `delta` payload of a `message_delta` event.
///
/// Both fields may be absent or `null`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
372
/// Error payload carried by an `error` event (also the body shape decoded from
/// non-success responses).
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
    /// Serialized as `type`.
    #[serde(rename = "type")]
    pub error_type: String,
    pub message: String,
}