1use anyhow::{anyhow, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use isahc::config::Configurable;
5use serde::{Deserialize, Serialize};
6use std::{convert::TryFrom, time::Duration};
7use strum::EnumIter;
8
/// Base URL of the public Anthropic API, suitable as the `api_url` argument of
/// [`stream_completion`].
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
10
11#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
12#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
13pub enum Model {
14 #[default]
15 #[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
16 Claude3Opus,
17 #[serde(alias = "claude-3-sonnet", rename = "claude-3-sonnet-20240229")]
18 Claude3Sonnet,
19 #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
20 Claude3Haiku,
21}
22
23impl Model {
24 pub fn from_id(id: &str) -> Result<Self> {
25 if id.starts_with("claude-3-opus") {
26 Ok(Self::Claude3Opus)
27 } else if id.starts_with("claude-3-sonnet") {
28 Ok(Self::Claude3Sonnet)
29 } else if id.starts_with("claude-3-haiku") {
30 Ok(Self::Claude3Haiku)
31 } else {
32 Err(anyhow!("Invalid model id: {}", id))
33 }
34 }
35
36 pub fn id(&self) -> &'static str {
37 match self {
38 Model::Claude3Opus => "claude-3-opus-20240229",
39 Model::Claude3Sonnet => "claude-3-sonnet-20240229",
40 Model::Claude3Haiku => "claude-3-opus-20240307",
41 }
42 }
43
44 pub fn display_name(&self) -> &'static str {
45 match self {
46 Self::Claude3Opus => "Claude 3 Opus",
47 Self::Claude3Sonnet => "Claude 3 Sonnet",
48 Self::Claude3Haiku => "Claude 3 Haiku",
49 }
50 }
51
52 pub fn max_token_count(&self) -> usize {
53 200_000
54 }
55}
56
57#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
58#[serde(rename_all = "lowercase")]
59pub enum Role {
60 User,
61 Assistant,
62}
63
64impl TryFrom<String> for Role {
65 type Error = anyhow::Error;
66
67 fn try_from(value: String) -> Result<Self> {
68 match value.as_str() {
69 "user" => Ok(Self::User),
70 "assistant" => Ok(Self::Assistant),
71 _ => Err(anyhow!("invalid role '{value}'")),
72 }
73 }
74}
75
76impl From<Role> for String {
77 fn from(val: Role) -> Self {
78 match val {
79 Role::User => "user".to_owned(),
80 Role::Assistant => "assistant".to_owned(),
81 }
82 }
83}
84
85#[derive(Debug, Serialize)]
86pub struct Request {
87 pub model: Model,
88 pub messages: Vec<RequestMessage>,
89 pub stream: bool,
90 pub system: String,
91 pub max_tokens: u32,
92}
93
94#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
95pub struct RequestMessage {
96 pub role: Role,
97 pub content: String,
98}
99
100#[derive(Deserialize, Debug)]
101#[serde(tag = "type", rename_all = "snake_case")]
102pub enum ResponseEvent {
103 MessageStart {
104 message: ResponseMessage,
105 },
106 ContentBlockStart {
107 index: u32,
108 content_block: ContentBlock,
109 },
110 Ping {},
111 ContentBlockDelta {
112 index: u32,
113 delta: TextDelta,
114 },
115 ContentBlockStop {
116 index: u32,
117 },
118 MessageDelta {
119 delta: ResponseMessage,
120 usage: Usage,
121 },
122 MessageStop {},
123}
124
125#[derive(Deserialize, Debug)]
126pub struct ResponseMessage {
127 #[serde(rename = "type")]
128 pub message_type: Option<String>,
129 pub id: Option<String>,
130 pub role: Option<String>,
131 pub content: Option<Vec<String>>,
132 pub model: Option<String>,
133 pub stop_reason: Option<String>,
134 pub stop_sequence: Option<String>,
135 pub usage: Option<Usage>,
136}
137
138#[derive(Deserialize, Debug)]
139pub struct Usage {
140 pub input_tokens: Option<u32>,
141 pub output_tokens: Option<u32>,
142}
143
144#[derive(Deserialize, Debug)]
145#[serde(tag = "type", rename_all = "snake_case")]
146pub enum ContentBlock {
147 Text { text: String },
148}
149
150#[derive(Deserialize, Debug)]
151#[serde(tag = "type", rename_all = "snake_case")]
152pub enum TextDelta {
153 TextDelta { text: String },
154}
155
156pub async fn stream_completion(
157 client: &dyn HttpClient,
158 api_url: &str,
159 api_key: &str,
160 request: Request,
161 low_speed_timeout: Option<Duration>,
162) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
163 let uri = format!("{api_url}/v1/messages");
164 let mut request_builder = HttpRequest::builder()
165 .method(Method::POST)
166 .uri(uri)
167 .header("Anthropic-Version", "2023-06-01")
168 .header("Anthropic-Beta", "tools-2024-04-04")
169 .header("X-Api-Key", api_key)
170 .header("Content-Type", "application/json");
171 if let Some(low_speed_timeout) = low_speed_timeout {
172 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
173 }
174 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
175 let mut response = client.send(request).await?;
176 if response.status().is_success() {
177 let reader = BufReader::new(response.into_body());
178 Ok(reader
179 .lines()
180 .filter_map(|line| async move {
181 match line {
182 Ok(line) => {
183 let line = line.strip_prefix("data: ")?;
184 match serde_json::from_str(line) {
185 Ok(response) => Some(Ok(response)),
186 Err(error) => Some(Err(anyhow!(error))),
187 }
188 }
189 Err(error) => Some(Err(anyhow!(error))),
190 }
191 })
192 .boxed())
193 } else {
194 let mut body = Vec::new();
195 response.body_mut().read_to_end(&mut body).await?;
196
197 let body_str = std::str::from_utf8(&body)?;
198
199 match serde_json::from_str::<ResponseEvent>(body_str) {
200 Ok(_) => Err(anyhow!(
201 "Unexpected success response while expecting an error: {}",
202 body_str,
203 )),
204 Err(_) => Err(anyhow!(
205 "Failed to connect to API: {} {}",
206 response.status(),
207 body_str,
208 )),
209 }
210 }
211}
212
213// #[cfg(test)]
214// mod tests {
215// use super::*;
216// use http::IsahcHttpClient;
217
218// #[tokio::test]
219// async fn stream_completion_success() {
220// let http_client = IsahcHttpClient::new().unwrap();
221
222// let request = Request {
223// model: Model::Claude3Opus,
224// messages: vec![RequestMessage {
225// role: Role::User,
226// content: "Ping".to_string(),
227// }],
228// stream: true,
229// system: "Respond to ping with pong".to_string(),
230// max_tokens: 4096,
231// };
232
//         let stream = stream_completion(
//             &http_client,
//             "https://api.anthropic.com",
//             &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
//             request,
//             None,
//         )
//         .await
//         .unwrap();
241
242// stream
243// .for_each(|event| async {
244// match event {
245// Ok(event) => println!("{:?}", event),
246// Err(e) => eprintln!("Error: {:?}", e),
247// }
248// })
249// .await;
250// }
251// }