anthropic.rs

  1use anyhow::{anyhow, Result};
  2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  3use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use isahc::config::Configurable;
  5use serde::{Deserialize, Serialize};
  6use std::{convert::TryFrom, time::Duration};
  7use strum::EnumIter;
  8
  9pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
 10
 11#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 12#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 13pub enum Model {
 14    #[default]
 15    #[serde(alias = "claude-3-5-sonnet", rename = "claude-3-5-sonnet-20240620")]
 16    Claude3_5Sonnet,
 17    #[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
 18    Claude3Opus,
 19    #[serde(alias = "claude-3-sonnet", rename = "claude-3-sonnet-20240229")]
 20    Claude3Sonnet,
 21    #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
 22    Claude3Haiku,
 23}
 24
 25impl Model {
 26    pub fn from_id(id: &str) -> Result<Self> {
 27        if id.starts_with("claude-3-5-sonnet") {
 28            Ok(Self::Claude3_5Sonnet)
 29        } else if id.starts_with("claude-3-opus") {
 30            Ok(Self::Claude3Opus)
 31        } else if id.starts_with("claude-3-sonnet") {
 32            Ok(Self::Claude3Sonnet)
 33        } else if id.starts_with("claude-3-haiku") {
 34            Ok(Self::Claude3Haiku)
 35        } else {
 36            Err(anyhow!("Invalid model id: {}", id))
 37        }
 38    }
 39
 40    pub fn id(&self) -> &'static str {
 41        match self {
 42            Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
 43            Model::Claude3Opus => "claude-3-opus-20240229",
 44            Model::Claude3Sonnet => "claude-3-sonnet-20240229",
 45            Model::Claude3Haiku => "claude-3-opus-20240307",
 46        }
 47    }
 48
 49    pub fn display_name(&self) -> &'static str {
 50        match self {
 51            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
 52            Self::Claude3Opus => "Claude 3 Opus",
 53            Self::Claude3Sonnet => "Claude 3 Sonnet",
 54            Self::Claude3Haiku => "Claude 3 Haiku",
 55        }
 56    }
 57
 58    pub fn max_token_count(&self) -> usize {
 59        200_000
 60    }
 61}
 62
 63#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 64#[serde(rename_all = "lowercase")]
 65pub enum Role {
 66    User,
 67    Assistant,
 68}
 69
 70impl TryFrom<String> for Role {
 71    type Error = anyhow::Error;
 72
 73    fn try_from(value: String) -> Result<Self> {
 74        match value.as_str() {
 75            "user" => Ok(Self::User),
 76            "assistant" => Ok(Self::Assistant),
 77            _ => Err(anyhow!("invalid role '{value}'")),
 78        }
 79    }
 80}
 81
 82impl From<Role> for String {
 83    fn from(val: Role) -> Self {
 84        match val {
 85            Role::User => "user".to_owned(),
 86            Role::Assistant => "assistant".to_owned(),
 87        }
 88    }
 89}
 90
 91#[derive(Debug, Serialize)]
 92pub struct Request {
 93    pub model: Model,
 94    pub messages: Vec<RequestMessage>,
 95    pub stream: bool,
 96    pub system: String,
 97    pub max_tokens: u32,
 98}
 99
100#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
101pub struct RequestMessage {
102    pub role: Role,
103    pub content: String,
104}
105
106#[derive(Deserialize, Debug)]
107#[serde(tag = "type", rename_all = "snake_case")]
108pub enum ResponseEvent {
109    MessageStart {
110        message: ResponseMessage,
111    },
112    ContentBlockStart {
113        index: u32,
114        content_block: ContentBlock,
115    },
116    Ping {},
117    ContentBlockDelta {
118        index: u32,
119        delta: TextDelta,
120    },
121    ContentBlockStop {
122        index: u32,
123    },
124    MessageDelta {
125        delta: ResponseMessage,
126        usage: Usage,
127    },
128    MessageStop {},
129}
130
131#[derive(Deserialize, Debug)]
132pub struct ResponseMessage {
133    #[serde(rename = "type")]
134    pub message_type: Option<String>,
135    pub id: Option<String>,
136    pub role: Option<String>,
137    pub content: Option<Vec<String>>,
138    pub model: Option<String>,
139    pub stop_reason: Option<String>,
140    pub stop_sequence: Option<String>,
141    pub usage: Option<Usage>,
142}
143
144#[derive(Deserialize, Debug)]
145pub struct Usage {
146    pub input_tokens: Option<u32>,
147    pub output_tokens: Option<u32>,
148}
149
150#[derive(Deserialize, Debug)]
151#[serde(tag = "type", rename_all = "snake_case")]
152pub enum ContentBlock {
153    Text { text: String },
154}
155
156#[derive(Deserialize, Debug)]
157#[serde(tag = "type", rename_all = "snake_case")]
158pub enum TextDelta {
159    TextDelta { text: String },
160}
161
162pub async fn stream_completion(
163    client: &dyn HttpClient,
164    api_url: &str,
165    api_key: &str,
166    request: Request,
167    low_speed_timeout: Option<Duration>,
168) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
169    let uri = format!("{api_url}/v1/messages");
170    let mut request_builder = HttpRequest::builder()
171        .method(Method::POST)
172        .uri(uri)
173        .header("Anthropic-Version", "2023-06-01")
174        .header("Anthropic-Beta", "tools-2024-04-04")
175        .header("X-Api-Key", api_key)
176        .header("Content-Type", "application/json");
177    if let Some(low_speed_timeout) = low_speed_timeout {
178        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
179    }
180    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
181    let mut response = client.send(request).await?;
182    if response.status().is_success() {
183        let reader = BufReader::new(response.into_body());
184        Ok(reader
185            .lines()
186            .filter_map(|line| async move {
187                match line {
188                    Ok(line) => {
189                        let line = line.strip_prefix("data: ")?;
190                        match serde_json::from_str(line) {
191                            Ok(response) => Some(Ok(response)),
192                            Err(error) => Some(Err(anyhow!(error))),
193                        }
194                    }
195                    Err(error) => Some(Err(anyhow!(error))),
196                }
197            })
198            .boxed())
199    } else {
200        let mut body = Vec::new();
201        response.body_mut().read_to_end(&mut body).await?;
202
203        let body_str = std::str::from_utf8(&body)?;
204
205        match serde_json::from_str::<ResponseEvent>(body_str) {
206            Ok(_) => Err(anyhow!(
207                "Unexpected success response while expecting an error: {}",
208                body_str,
209            )),
210            Err(_) => Err(anyhow!(
211                "Failed to connect to API: {} {}",
212                response.status(),
213                body_str,
214            )),
215        }
216    }
217}
218
219// #[cfg(test)]
220// mod tests {
221//     use super::*;
222//     use http::IsahcHttpClient;
223
224//     #[tokio::test]
225//     async fn stream_completion_success() {
226//         let http_client = IsahcHttpClient::new().unwrap();
227
228//         let request = Request {
229//             model: Model::Claude3Opus,
230//             messages: vec![RequestMessage {
231//                 role: Role::User,
232//                 content: "Ping".to_string(),
233//             }],
234//             stream: true,
235//             system: "Respond to ping with pong".to_string(),
236//             max_tokens: 4096,
237//         };
238
//         let stream = stream_completion(
//             &http_client,
//             "https://api.anthropic.com",
//             &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
//             request,
//             None,
//         )
245//         .await
246//         .unwrap();
247
248//         stream
249//             .for_each(|event| async {
250//                 match event {
251//                     Ok(event) => println!("{:?}", event),
252//                     Err(e) => eprintln!("Error: {:?}", e),
253//                 }
254//             })
255//             .await;
256//     }
257// }