1use anyhow::{anyhow, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use isahc::config::Configurable;
5use serde::{Deserialize, Serialize};
6use std::{convert::TryFrom, time::Duration};
7use strum::EnumIter;
8
/// Base URL of the public Anthropic API; `stream_completion` appends `/v1/messages` to the
/// `api_url` it is given.
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
10
11#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
12#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
13pub enum Model {
14 #[default]
15 #[serde(alias = "claude-3-5-sonnet", rename = "claude-3-5-sonnet-20240620")]
16 Claude3_5Sonnet,
17 #[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
18 Claude3Opus,
19 #[serde(alias = "claude-3-sonnet", rename = "claude-3-sonnet-20240229")]
20 Claude3Sonnet,
21 #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
22 Claude3Haiku,
23}
24
25impl Model {
26 pub fn from_id(id: &str) -> Result<Self> {
27 if id.starts_with("claude-3-5-sonnet") {
28 Ok(Self::Claude3_5Sonnet)
29 } else if id.starts_with("claude-3-opus") {
30 Ok(Self::Claude3Opus)
31 } else if id.starts_with("claude-3-sonnet") {
32 Ok(Self::Claude3Sonnet)
33 } else if id.starts_with("claude-3-haiku") {
34 Ok(Self::Claude3Haiku)
35 } else {
36 Err(anyhow!("Invalid model id: {}", id))
37 }
38 }
39
40 pub fn id(&self) -> &'static str {
41 match self {
42 Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
43 Model::Claude3Opus => "claude-3-opus-20240229",
44 Model::Claude3Sonnet => "claude-3-sonnet-20240229",
45 Model::Claude3Haiku => "claude-3-opus-20240307",
46 }
47 }
48
49 pub fn display_name(&self) -> &'static str {
50 match self {
51 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
52 Self::Claude3Opus => "Claude 3 Opus",
53 Self::Claude3Sonnet => "Claude 3 Sonnet",
54 Self::Claude3Haiku => "Claude 3 Haiku",
55 }
56 }
57
58 pub fn max_token_count(&self) -> usize {
59 200_000
60 }
61}
62
63#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
64#[serde(rename_all = "lowercase")]
65pub enum Role {
66 User,
67 Assistant,
68}
69
70impl TryFrom<String> for Role {
71 type Error = anyhow::Error;
72
73 fn try_from(value: String) -> Result<Self> {
74 match value.as_str() {
75 "user" => Ok(Self::User),
76 "assistant" => Ok(Self::Assistant),
77 _ => Err(anyhow!("invalid role '{value}'")),
78 }
79 }
80}
81
82impl From<Role> for String {
83 fn from(val: Role) -> Self {
84 match val {
85 Role::User => "user".to_owned(),
86 Role::Assistant => "assistant".to_owned(),
87 }
88 }
89}
90
91#[derive(Debug, Serialize)]
92pub struct Request {
93 pub model: Model,
94 pub messages: Vec<RequestMessage>,
95 pub stream: bool,
96 pub system: String,
97 pub max_tokens: u32,
98}
99
100#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
101pub struct RequestMessage {
102 pub role: Role,
103 pub content: String,
104}
105
106#[derive(Deserialize, Debug)]
107#[serde(tag = "type", rename_all = "snake_case")]
108pub enum ResponseEvent {
109 MessageStart {
110 message: ResponseMessage,
111 },
112 ContentBlockStart {
113 index: u32,
114 content_block: ContentBlock,
115 },
116 Ping {},
117 ContentBlockDelta {
118 index: u32,
119 delta: TextDelta,
120 },
121 ContentBlockStop {
122 index: u32,
123 },
124 MessageDelta {
125 delta: ResponseMessage,
126 usage: Usage,
127 },
128 MessageStop {},
129}
130
131#[derive(Deserialize, Debug)]
132pub struct ResponseMessage {
133 #[serde(rename = "type")]
134 pub message_type: Option<String>,
135 pub id: Option<String>,
136 pub role: Option<String>,
137 pub content: Option<Vec<String>>,
138 pub model: Option<String>,
139 pub stop_reason: Option<String>,
140 pub stop_sequence: Option<String>,
141 pub usage: Option<Usage>,
142}
143
144#[derive(Deserialize, Debug)]
145pub struct Usage {
146 pub input_tokens: Option<u32>,
147 pub output_tokens: Option<u32>,
148}
149
150#[derive(Deserialize, Debug)]
151#[serde(tag = "type", rename_all = "snake_case")]
152pub enum ContentBlock {
153 Text { text: String },
154}
155
156#[derive(Deserialize, Debug)]
157#[serde(tag = "type", rename_all = "snake_case")]
158pub enum TextDelta {
159 TextDelta { text: String },
160}
161
162pub async fn stream_completion(
163 client: &dyn HttpClient,
164 api_url: &str,
165 api_key: &str,
166 request: Request,
167 low_speed_timeout: Option<Duration>,
168) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
169 let uri = format!("{api_url}/v1/messages");
170 let mut request_builder = HttpRequest::builder()
171 .method(Method::POST)
172 .uri(uri)
173 .header("Anthropic-Version", "2023-06-01")
174 .header("Anthropic-Beta", "tools-2024-04-04")
175 .header("X-Api-Key", api_key)
176 .header("Content-Type", "application/json");
177 if let Some(low_speed_timeout) = low_speed_timeout {
178 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
179 }
180 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
181 let mut response = client.send(request).await?;
182 if response.status().is_success() {
183 let reader = BufReader::new(response.into_body());
184 Ok(reader
185 .lines()
186 .filter_map(|line| async move {
187 match line {
188 Ok(line) => {
189 let line = line.strip_prefix("data: ")?;
190 match serde_json::from_str(line) {
191 Ok(response) => Some(Ok(response)),
192 Err(error) => Some(Err(anyhow!(error))),
193 }
194 }
195 Err(error) => Some(Err(anyhow!(error))),
196 }
197 })
198 .boxed())
199 } else {
200 let mut body = Vec::new();
201 response.body_mut().read_to_end(&mut body).await?;
202
203 let body_str = std::str::from_utf8(&body)?;
204
205 match serde_json::from_str::<ResponseEvent>(body_str) {
206 Ok(_) => Err(anyhow!(
207 "Unexpected success response while expecting an error: {}",
208 body_str,
209 )),
210 Err(_) => Err(anyhow!(
211 "Failed to connect to API: {} {}",
212 response.status(),
213 body_str,
214 )),
215 }
216 }
217}
218
219// #[cfg(test)]
220// mod tests {
221// use super::*;
222// use http::IsahcHttpClient;
223
224// #[tokio::test]
225// async fn stream_completion_success() {
226// let http_client = IsahcHttpClient::new().unwrap();
227
228// let request = Request {
229// model: Model::Claude3Opus,
230// messages: vec![RequestMessage {
231// role: Role::User,
232// content: "Ping".to_string(),
233// }],
234// stream: true,
235// system: "Respond to ping with pong".to_string(),
236// max_tokens: 4096,
237// };
238
239// let stream = stream_completion(
240// &http_client,
241// "https://api.anthropic.com",
242// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
243// request,
244// )
245// .await
246// .unwrap();
247
248// stream
249// .for_each(|event| async {
250// match event {
251// Ok(event) => println!("{:?}", event),
252// Err(e) => eprintln!("Error: {:?}", e),
253// }
254// })
255// .await;
256// }
257// }