// anthropic.rs

use anyhow::{anyhow, Result};
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
use std::{convert::TryFrom, time::Duration};

  8pub const ANTHROPIC_API_URL: &'static str = "https://api.anthropic.com";
  9
 10#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 11#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 12pub enum Model {
 13    #[default]
 14    #[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
 15    Claude3Opus,
 16    #[serde(alias = "claude-3-sonnet", rename = "claude-3-sonnet-20240229")]
 17    Claude3Sonnet,
 18    #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
 19    Claude3Haiku,
 20}
 21
 22impl Model {
 23    pub fn from_id(id: &str) -> Result<Self> {
 24        if id.starts_with("claude-3-opus") {
 25            Ok(Self::Claude3Opus)
 26        } else if id.starts_with("claude-3-sonnet") {
 27            Ok(Self::Claude3Sonnet)
 28        } else if id.starts_with("claude-3-haiku") {
 29            Ok(Self::Claude3Haiku)
 30        } else {
 31            Err(anyhow!("Invalid model id: {}", id))
 32        }
 33    }
 34
 35    pub fn id(&self) -> &'static str {
 36        match self {
 37            Model::Claude3Opus => "claude-3-opus-20240229",
 38            Model::Claude3Sonnet => "claude-3-sonnet-20240229",
 39            Model::Claude3Haiku => "claude-3-opus-20240307",
 40        }
 41    }
 42
 43    pub fn display_name(&self) -> &'static str {
 44        match self {
 45            Self::Claude3Opus => "Claude 3 Opus",
 46            Self::Claude3Sonnet => "Claude 3 Sonnet",
 47            Self::Claude3Haiku => "Claude 3 Haiku",
 48        }
 49    }
 50
 51    pub fn max_token_count(&self) -> usize {
 52        200_000
 53    }
 54}
 55
 56#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 57#[serde(rename_all = "lowercase")]
 58pub enum Role {
 59    User,
 60    Assistant,
 61}
 62
 63impl TryFrom<String> for Role {
 64    type Error = anyhow::Error;
 65
 66    fn try_from(value: String) -> Result<Self> {
 67        match value.as_str() {
 68            "user" => Ok(Self::User),
 69            "assistant" => Ok(Self::Assistant),
 70            _ => Err(anyhow!("invalid role '{value}'")),
 71        }
 72    }
 73}
 74
 75impl From<Role> for String {
 76    fn from(val: Role) -> Self {
 77        match val {
 78            Role::User => "user".to_owned(),
 79            Role::Assistant => "assistant".to_owned(),
 80        }
 81    }
 82}
 83
 84#[derive(Debug, Serialize)]
 85pub struct Request {
 86    pub model: Model,
 87    pub messages: Vec<RequestMessage>,
 88    pub stream: bool,
 89    pub system: String,
 90    pub max_tokens: u32,
 91}
 92
 93#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 94pub struct RequestMessage {
 95    pub role: Role,
 96    pub content: String,
 97}
 98
 99#[derive(Deserialize, Debug)]
100#[serde(tag = "type", rename_all = "snake_case")]
101pub enum ResponseEvent {
102    MessageStart {
103        message: ResponseMessage,
104    },
105    ContentBlockStart {
106        index: u32,
107        content_block: ContentBlock,
108    },
109    Ping {},
110    ContentBlockDelta {
111        index: u32,
112        delta: TextDelta,
113    },
114    ContentBlockStop {
115        index: u32,
116    },
117    MessageDelta {
118        delta: ResponseMessage,
119        usage: Usage,
120    },
121    MessageStop {},
122}
123
124#[derive(Deserialize, Debug)]
125pub struct ResponseMessage {
126    #[serde(rename = "type")]
127    pub message_type: Option<String>,
128    pub id: Option<String>,
129    pub role: Option<String>,
130    pub content: Option<Vec<String>>,
131    pub model: Option<String>,
132    pub stop_reason: Option<String>,
133    pub stop_sequence: Option<String>,
134    pub usage: Option<Usage>,
135}
136
137#[derive(Deserialize, Debug)]
138pub struct Usage {
139    pub input_tokens: Option<u32>,
140    pub output_tokens: Option<u32>,
141}
142
143#[derive(Deserialize, Debug)]
144#[serde(tag = "type", rename_all = "snake_case")]
145pub enum ContentBlock {
146    Text { text: String },
147}
148
149#[derive(Deserialize, Debug)]
150#[serde(tag = "type", rename_all = "snake_case")]
151pub enum TextDelta {
152    TextDelta { text: String },
153}
154
155pub async fn stream_completion(
156    client: &dyn HttpClient,
157    api_url: &str,
158    api_key: &str,
159    request: Request,
160    low_speed_timeout: Option<Duration>,
161) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
162    let uri = format!("{api_url}/v1/messages");
163    let mut request_builder = HttpRequest::builder()
164        .method(Method::POST)
165        .uri(uri)
166        .header("Anthropic-Version", "2023-06-01")
167        .header("Anthropic-Beta", "tools-2024-04-04")
168        .header("X-Api-Key", api_key)
169        .header("Content-Type", "application/json");
170    if let Some(low_speed_timeout) = low_speed_timeout {
171        request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
172    }
173    let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
174    let mut response = client.send(request).await?;
175    if response.status().is_success() {
176        let reader = BufReader::new(response.into_body());
177        Ok(reader
178            .lines()
179            .filter_map(|line| async move {
180                match line {
181                    Ok(line) => {
182                        let line = line.strip_prefix("data: ")?;
183                        match serde_json::from_str(line) {
184                            Ok(response) => Some(Ok(response)),
185                            Err(error) => Some(Err(anyhow!(error))),
186                        }
187                    }
188                    Err(error) => Some(Err(anyhow!(error))),
189                }
190            })
191            .boxed())
192    } else {
193        let mut body = Vec::new();
194        response.body_mut().read_to_end(&mut body).await?;
195
196        let body_str = std::str::from_utf8(&body)?;
197
198        match serde_json::from_str::<ResponseEvent>(body_str) {
199            Ok(_) => Err(anyhow!(
200                "Unexpected success response while expecting an error: {}",
201                body_str,
202            )),
203            Err(_) => Err(anyhow!(
204                "Failed to connect to API: {} {}",
205                response.status(),
206                body_str,
207            )),
208        }
209    }
210}
211
// NOTE: disabled because it calls the live Anthropic API and needs the
// ANTHROPIC_API_KEY environment variable; re-enable locally to smoke-test.
// #[cfg(test)]
// mod tests {
//     use super::*;
//     use http::IsahcHttpClient;

//     #[tokio::test]
//     async fn stream_completion_success() {
//         let http_client = IsahcHttpClient::new().unwrap();

//         let request = Request {
//             model: Model::Claude3Opus,
//             messages: vec![RequestMessage {
//                 role: Role::User,
//                 content: "Ping".to_string(),
//             }],
//             stream: true,
//             system: "Respond to ping with pong".to_string(),
//             max_tokens: 4096,
//         };

//         let stream = stream_completion(
//             &http_client,
//             "https://api.anthropic.com",
//             &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
//             request,
//             None, // low_speed_timeout
//         )
//         .await
//         .unwrap();

//         stream
//             .for_each(|event| async {
//                 match event {
//                     Ok(event) => println!("{:?}", event),
//                     Err(e) => eprintln!("Error: {:?}", e),
//                 }
//             })
//             .await;
//     }
// }