anthropic.rs

  1use anyhow::{anyhow, Result};
  2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  3use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  4use isahc::config::Configurable;
  5use serde::{Deserialize, Serialize};
  6use std::{convert::TryFrom, time::Duration};
  7use strum::EnumIter;
  8
  9pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
 10
 11#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 12#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
 13pub enum Model {
 14    #[default]
 15    #[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
 16    Claude3Opus,
 17    #[serde(alias = "claude-3-sonnet", rename = "claude-3-sonnet-20240229")]
 18    Claude3Sonnet,
 19    #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
 20    Claude3Haiku,
 21}
 22
 23impl Model {
 24    pub fn from_id(id: &str) -> Result<Self> {
 25        if id.starts_with("claude-3-opus") {
 26            Ok(Self::Claude3Opus)
 27        } else if id.starts_with("claude-3-sonnet") {
 28            Ok(Self::Claude3Sonnet)
 29        } else if id.starts_with("claude-3-haiku") {
 30            Ok(Self::Claude3Haiku)
 31        } else {
 32            Err(anyhow!("Invalid model id: {}", id))
 33        }
 34    }
 35
 36    pub fn id(&self) -> &'static str {
 37        match self {
 38            Model::Claude3Opus => "claude-3-opus-20240229",
 39            Model::Claude3Sonnet => "claude-3-sonnet-20240229",
 40            Model::Claude3Haiku => "claude-3-opus-20240307",
 41        }
 42    }
 43
 44    pub fn display_name(&self) -> &'static str {
 45        match self {
 46            Self::Claude3Opus => "Claude 3 Opus",
 47            Self::Claude3Sonnet => "Claude 3 Sonnet",
 48            Self::Claude3Haiku => "Claude 3 Haiku",
 49        }
 50    }
 51
 52    pub fn max_token_count(&self) -> usize {
 53        200_000
 54    }
 55}
 56
 57#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 58#[serde(rename_all = "lowercase")]
 59pub enum Role {
 60    User,
 61    Assistant,
 62}
 63
 64impl TryFrom<String> for Role {
 65    type Error = anyhow::Error;
 66
 67    fn try_from(value: String) -> Result<Self> {
 68        match value.as_str() {
 69            "user" => Ok(Self::User),
 70            "assistant" => Ok(Self::Assistant),
 71            _ => Err(anyhow!("invalid role '{value}'")),
 72        }
 73    }
 74}
 75
 76impl From<Role> for String {
 77    fn from(val: Role) -> Self {
 78        match val {
 79            Role::User => "user".to_owned(),
 80            Role::Assistant => "assistant".to_owned(),
 81        }
 82    }
 83}
 84
 85#[derive(Debug, Serialize)]
 86pub struct Request {
 87    pub model: Model,
 88    pub messages: Vec<RequestMessage>,
 89    pub stream: bool,
 90    pub system: String,
 91    pub max_tokens: u32,
 92}
 93
 94#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 95pub struct RequestMessage {
 96    pub role: Role,
 97    pub content: String,
 98}
 99
100#[derive(Deserialize, Debug)]
101#[serde(tag = "type", rename_all = "snake_case")]
102pub enum ResponseEvent {
103    MessageStart {
104        message: ResponseMessage,
105    },
106    ContentBlockStart {
107        index: u32,
108        content_block: ContentBlock,
109    },
110    Ping {},
111    ContentBlockDelta {
112        index: u32,
113        delta: TextDelta,
114    },
115    ContentBlockStop {
116        index: u32,
117    },
118    MessageDelta {
119        delta: ResponseMessage,
120        usage: Usage,
121    },
122    MessageStop {},
123}
124
125#[derive(Deserialize, Debug)]
126pub struct ResponseMessage {
127    #[serde(rename = "type")]
128    pub message_type: Option<String>,
129    pub id: Option<String>,
130    pub role: Option<String>,
131    pub content: Option<Vec<String>>,
132    pub model: Option<String>,
133    pub stop_reason: Option<String>,
134    pub stop_sequence: Option<String>,
135    pub usage: Option<Usage>,
136}
137
138#[derive(Deserialize, Debug)]
139pub struct Usage {
140    pub input_tokens: Option<u32>,
141    pub output_tokens: Option<u32>,
142}
143
144#[derive(Deserialize, Debug)]
145#[serde(tag = "type", rename_all = "snake_case")]
146pub enum ContentBlock {
147    Text { text: String },
148}
149
150#[derive(Deserialize, Debug)]
151#[serde(tag = "type", rename_all = "snake_case")]
152pub enum TextDelta {
153    TextDelta { text: String },
154}
155
156pub async fn stream_completion(
157    client: &dyn HttpClient,
158    api_url: &str,
159    api_key: &str,
160    request: Request,
161    low_speed_timeout: Option<Duration>,
162) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
163    let uri = format!("{api_url}/v1/messages");
164    let mut request_builder = HttpRequest::builder()
165        .method(Method::POST)
166        .uri(uri)
167        .header("Anthropic-Version", "2023-06-01")
168        .header("Anthropic-Beta", "tools-2024-04-04")
169        .header("X-Api-Key", api_key)
170        .header("Content-Type", "application/json");
171    if let Some(low_speed_timeout) = low_speed_timeout {
172        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
173    }
174    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
175    let mut response = client.send(request).await?;
176    if response.status().is_success() {
177        let reader = BufReader::new(response.into_body());
178        Ok(reader
179            .lines()
180            .filter_map(|line| async move {
181                match line {
182                    Ok(line) => {
183                        let line = line.strip_prefix("data: ")?;
184                        match serde_json::from_str(line) {
185                            Ok(response) => Some(Ok(response)),
186                            Err(error) => Some(Err(anyhow!(error))),
187                        }
188                    }
189                    Err(error) => Some(Err(anyhow!(error))),
190                }
191            })
192            .boxed())
193    } else {
194        let mut body = Vec::new();
195        response.body_mut().read_to_end(&mut body).await?;
196
197        let body_str = std::str::from_utf8(&body)?;
198
199        match serde_json::from_str::<ResponseEvent>(body_str) {
200            Ok(_) => Err(anyhow!(
201                "Unexpected success response while expecting an error: {}",
202                body_str,
203            )),
204            Err(_) => Err(anyhow!(
205                "Failed to connect to API: {} {}",
206                response.status(),
207                body_str,
208            )),
209        }
210    }
211}
212
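// Commented out: this test sends a real request to https://api.anthropic.com
// and reads ANTHROPIC_API_KEY from the environment.
// #[cfg(test)]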
214// mod tests {
215//     use super::*;
216//     use http::IsahcHttpClient;
217
218//     #[tokio::test]
219//     async fn stream_completion_success() {
220//         let http_client = IsahcHttpClient::new().unwrap();
221
222//         let request = Request {
223//             model: Model::Claude3Opus,
224//             messages: vec![RequestMessage {
225//                 role: Role::User,
226//                 content: "Ping".to_string(),
227//             }],
228//             stream: true,
229//             system: "Respond to ping with pong".to_string(),
230//             max_tokens: 4096,
231//         };
232
233//         let stream = stream_completion(
234//             &http_client,
235//             "https://api.anthropic.com",
236//             &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
//             request,
//             None,
238//         )
239//         .await
240//         .unwrap();
241
242//         stream
243//             .for_each(|event| async {
244//                 match event {
245//                     Ok(event) => println!("{:?}", event),
246//                     Err(e) => eprintln!("Error: {:?}", e),
247//                 }
248//             })
249//             .await;
250//     }
251// }