anthropic.rs

  1use anyhow::{anyhow, Result};
  2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use isahc::config::Configurable;
  5use serde::{Deserialize, Serialize};
  6use std::{convert::TryFrom, time::Duration};
  7use strum::EnumIter;
  8
  /// Base URL of the Anthropic API. Pass it as `api_url` to [`stream_completion`], which
/// appends `/v1/messages`.
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
 10
 11#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 12#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 13pub enum Model {
 14    #[default]
 15    #[serde(alias = "claude-3-5-sonnet", rename = "claude-3-5-sonnet-20240620")]
 16    Claude3_5Sonnet,
 17    #[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
 18    Claude3Opus,
 19    #[serde(alias = "claude-3-sonnet", rename = "claude-3-sonnet-20240229")]
 20    Claude3Sonnet,
 21    #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
 22    Claude3Haiku,
 23    #[serde(rename = "custom")]
 24    Custom { name: String, max_tokens: usize },
 25}
 26
 27impl Model {
 28    pub fn from_id(id: &str) -> Result<Self> {
 29        if id.starts_with("claude-3-5-sonnet") {
 30            Ok(Self::Claude3_5Sonnet)
 31        } else if id.starts_with("claude-3-opus") {
 32            Ok(Self::Claude3Opus)
 33        } else if id.starts_with("claude-3-sonnet") {
 34            Ok(Self::Claude3Sonnet)
 35        } else if id.starts_with("claude-3-haiku") {
 36            Ok(Self::Claude3Haiku)
 37        } else {
 38            Err(anyhow!("invalid model id"))
 39        }
 40    }
 41
 42    pub fn id(&self) -> &str {
 43        match self {
 44            Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
 45            Model::Claude3Opus => "claude-3-opus-20240229",
 46            Model::Claude3Sonnet => "claude-3-sonnet-20240229",
 47            Model::Claude3Haiku => "claude-3-opus-20240307",
 48            Self::Custom { name, .. } => name,
 49        }
 50    }
 51
 52    pub fn display_name(&self) -> &str {
 53        match self {
 54            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
 55            Self::Claude3Opus => "Claude 3 Opus",
 56            Self::Claude3Sonnet => "Claude 3 Sonnet",
 57            Self::Claude3Haiku => "Claude 3 Haiku",
 58            Self::Custom { name, .. } => name,
 59        }
 60    }
 61
 62    pub fn max_token_count(&self) -> usize {
 63        match self {
 64            Self::Claude3_5Sonnet
 65            | Self::Claude3Opus
 66            | Self::Claude3Sonnet
 67            | Self::Claude3Haiku => 200_000,
 68            Self::Custom { max_tokens, .. } => *max_tokens,
 69        }
 70    }
 71}
 72
 73#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 74#[serde(rename_all = "lowercase")]
 75pub enum Role {
 76    User,
 77    Assistant,
 78}
 79
 80impl TryFrom<String> for Role {
 81    type Error = anyhow::Error;
 82
 83    fn try_from(value: String) -> Result<Self> {
 84        match value.as_str() {
 85            "user" => Ok(Self::User),
 86            "assistant" => Ok(Self::Assistant),
 87            _ => Err(anyhow!("invalid role '{value}'")),
 88        }
 89    }
 90}
 91
 92impl From<Role> for String {
 93    fn from(val: Role) -> Self {
 94        match val {
 95            Role::User => "user".to_owned(),
 96            Role::Assistant => "assistant".to_owned(),
 97        }
 98    }
 99}
100
101#[derive(Debug, Serialize)]
102pub struct Request {
103    #[serde(serialize_with = "serialize_request_model")]
104    pub model: Model,
105    pub messages: Vec<RequestMessage>,
106    pub stream: bool,
107    pub system: String,
108    pub max_tokens: u32,
109}
110
111fn serialize_request_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
112where
113    S: serde::Serializer,
114{
115    serializer.serialize_str(&model.id())
116}
117
118#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
119pub struct RequestMessage {
120    pub role: Role,
121    pub content: String,
122}
123
124#[derive(Deserialize, Debug)]
125#[serde(tag = "type", rename_all = "snake_case")]
126pub enum ResponseEvent {
127    MessageStart {
128        message: ResponseMessage,
129    },
130    ContentBlockStart {
131        index: u32,
132        content_block: ContentBlock,
133    },
134    Ping {},
135    ContentBlockDelta {
136        index: u32,
137        delta: TextDelta,
138    },
139    ContentBlockStop {
140        index: u32,
141    },
142    MessageDelta {
143        delta: ResponseMessage,
144        usage: Usage,
145    },
146    MessageStop {},
147}
148
149#[derive(Deserialize, Debug)]
150pub struct ResponseMessage {
151    #[serde(rename = "type")]
152    pub message_type: Option<String>,
153    pub id: Option<String>,
154    pub role: Option<String>,
155    pub content: Option<Vec<String>>,
156    pub model: Option<String>,
157    pub stop_reason: Option<String>,
158    pub stop_sequence: Option<String>,
159    pub usage: Option<Usage>,
160}
161
162#[derive(Deserialize, Debug)]
163pub struct Usage {
164    pub input_tokens: Option<u32>,
165    pub output_tokens: Option<u32>,
166}
167
168#[derive(Deserialize, Debug)]
169#[serde(tag = "type", rename_all = "snake_case")]
170pub enum ContentBlock {
171    Text { text: String },
172}
173
174#[derive(Deserialize, Debug)]
175#[serde(tag = "type", rename_all = "snake_case")]
176pub enum TextDelta {
177    TextDelta { text: String },
178}
179
180pub async fn stream_completion(
181    client: &dyn HttpClient,
182    api_url: &str,
183    api_key: &str,
184    request: Request,
185    low_speed_timeout: Option<Duration>,
186) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
187    let uri = format!("{api_url}/v1/messages");
188    let mut request_builder = HttpRequest::builder()
189        .method(Method::POST)
190        .uri(uri)
191        .header("Anthropic-Version", "2023-06-01")
192        .header("Anthropic-Beta", "tools-2024-04-04")
193        .header("X-Api-Key", api_key)
194        .header("Content-Type", "application/json");
195    if let Some(low_speed_timeout) = low_speed_timeout {
196        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
197    }
198    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
199    let mut response = client.send(request).await?;
200    if response.status().is_success() {
201        let reader = BufReader::new(response.into_body());
202        Ok(reader
203            .lines()
204            .filter_map(|line| async move {
205                match line {
206                    Ok(line) => {
207                        let line = line.strip_prefix("data: ")?;
208                        match serde_json::from_str(line) {
209                            Ok(response) => Some(Ok(response)),
210                            Err(error) => Some(Err(anyhow!(error))),
211                        }
212                    }
213                    Err(error) => Some(Err(anyhow!(error))),
214                }
215            })
216            .boxed())
217    } else {
218        let mut body = Vec::new();
219        response.body_mut().read_to_end(&mut body).await?;
220
221        let body_str = std::str::from_utf8(&body)?;
222
223        match serde_json::from_str::<ResponseEvent>(body_str) {
224            Ok(_) => Err(anyhow!(
225                "Unexpected success response while expecting an error: {}",
226                body_str,
227            )),
228            Err(_) => Err(anyhow!(
229                "Failed to connect to API: {} {}",
230                response.status(),
231                body_str,
232            )),
233        }
234    }
235}
236
237// #[cfg(test)]
238// mod tests {
239//     use super::*;
240//     use http::IsahcHttpClient;
241
242//     #[tokio::test]
243//     async fn stream_completion_success() {
244//         let http_client = IsahcHttpClient::new().unwrap();
245
246//         let request = Request {
247//             model: Model::Claude3Opus,
248//             messages: vec![RequestMessage {
249//                 role: Role::User,
250//                 content: "Ping".to_string(),
251//             }],
252//             stream: true,
253//             system: "Respond to ping with pong".to_string(),
254//             max_tokens: 4096,
255//         };
256
257//         let stream = stream_completion(
258//             &http_client,
259//             "https://api.anthropic.com",
260//             &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
//             request,
//             None,
262//         )
263//         .await
264//         .unwrap();
265
266//         stream
267//             .for_each(|event| async {
268//                 match event {
269//                     Ok(event) => println!("{:?}", event),
270//                     Err(e) => eprintln!("Error: {:?}", e),
271//                 }
272//             })
273//             .await;
274//     }
275// }