1use anyhow::{anyhow, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use isahc::config::Configurable;
5use serde::{Deserialize, Serialize};
6use std::{convert::TryFrom, time::Duration};
7use strum::EnumIter;
8
/// Base URL of the public Anthropic API; pass it as `api_url` to [`stream_completion`].
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
10
11#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
12#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
13pub enum Model {
14 #[default]
15 #[serde(alias = "claude-3-5-sonnet", rename = "claude-3-5-sonnet-20240620")]
16 Claude3_5Sonnet,
17 #[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
18 Claude3Opus,
19 #[serde(alias = "claude-3-sonnet", rename = "claude-3-sonnet-20240229")]
20 Claude3Sonnet,
21 #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
22 Claude3Haiku,
23 #[serde(rename = "custom")]
24 Custom { name: String, max_tokens: usize },
25}
26
27impl Model {
28 pub fn from_id(id: &str) -> Result<Self> {
29 if id.starts_with("claude-3-5-sonnet") {
30 Ok(Self::Claude3_5Sonnet)
31 } else if id.starts_with("claude-3-opus") {
32 Ok(Self::Claude3Opus)
33 } else if id.starts_with("claude-3-sonnet") {
34 Ok(Self::Claude3Sonnet)
35 } else if id.starts_with("claude-3-haiku") {
36 Ok(Self::Claude3Haiku)
37 } else {
38 Err(anyhow!("invalid model id"))
39 }
40 }
41
42 pub fn id(&self) -> &str {
43 match self {
44 Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
45 Model::Claude3Opus => "claude-3-opus-20240229",
46 Model::Claude3Sonnet => "claude-3-sonnet-20240229",
47 Model::Claude3Haiku => "claude-3-opus-20240307",
48 Self::Custom { name, .. } => name,
49 }
50 }
51
52 pub fn display_name(&self) -> &str {
53 match self {
54 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
55 Self::Claude3Opus => "Claude 3 Opus",
56 Self::Claude3Sonnet => "Claude 3 Sonnet",
57 Self::Claude3Haiku => "Claude 3 Haiku",
58 Self::Custom { name, .. } => name,
59 }
60 }
61
62 pub fn max_token_count(&self) -> usize {
63 match self {
64 Self::Claude3_5Sonnet
65 | Self::Claude3Opus
66 | Self::Claude3Sonnet
67 | Self::Claude3Haiku => 200_000,
68 Self::Custom { max_tokens, .. } => *max_tokens,
69 }
70 }
71}
72
73#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
74#[serde(rename_all = "lowercase")]
75pub enum Role {
76 User,
77 Assistant,
78}
79
80impl TryFrom<String> for Role {
81 type Error = anyhow::Error;
82
83 fn try_from(value: String) -> Result<Self> {
84 match value.as_str() {
85 "user" => Ok(Self::User),
86 "assistant" => Ok(Self::Assistant),
87 _ => Err(anyhow!("invalid role '{value}'")),
88 }
89 }
90}
91
92impl From<Role> for String {
93 fn from(val: Role) -> Self {
94 match val {
95 Role::User => "user".to_owned(),
96 Role::Assistant => "assistant".to_owned(),
97 }
98 }
99}
100
/// JSON body of the `POST {api_url}/v1/messages` request built by
/// `stream_completion`.
///
/// Fields are serialized in declaration order.
#[derive(Debug, Serialize)]
pub struct Request {
    /// Serialized as the plain id string returned by `Model::id`
    /// (via `serialize_request_model`), not through `Model`'s derived `Serialize`.
    #[serde(serialize_with = "serialize_request_model")]
    pub model: Model,
    /// Conversation turns sent to the model.
    pub messages: Vec<RequestMessage>,
    /// Value of the API's `stream` flag. `stream_completion` parses the response as
    /// `data: ` lines, so callers presumably set this to `true` — confirm at call sites.
    pub stream: bool,
    /// System prompt; always serialized, even when empty.
    pub system: String,
    /// Value of the API's `max_tokens` field.
    pub max_tokens: u32,
}
110
111fn serialize_request_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
112where
113 S: serde::Serializer,
114{
115 serializer.serialize_str(&model.id())
116}
117
/// One conversation turn in [`Request::messages`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    /// Author of the turn; serialized as `"user"` or `"assistant"`.
    pub role: Role,
    /// Plain-text content of the turn.
    pub content: String,
}
123
/// One event from the streaming Messages API, decoded from the JSON payload that
/// follows a `data: ` line prefix.
///
/// The payload's `type` field selects the variant by its snake_case name (for
/// example `"content_block_delta"`). A payload whose `type` matches none of these
/// variants fails to deserialize.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseEvent {
    /// `message_start`: carries the initial message envelope.
    MessageStart {
        message: ResponseMessage,
    },
    /// `content_block_start`: a new content block begins at `index`.
    ContentBlockStart {
        index: u32,
        content_block: ContentBlock,
    },
    /// `ping`: no payload beyond the tag.
    Ping {},
    /// `content_block_delta`: incremental text for the block at `index`.
    ContentBlockDelta {
        index: u32,
        delta: TextDelta,
    },
    /// `content_block_stop`: the block at `index` is complete.
    ContentBlockStop {
        index: u32,
    },
    /// `message_delta`: partial message fields plus token usage.
    MessageDelta {
        delta: ResponseMessage,
        usage: Usage,
    },
    /// `message_stop`: no payload beyond the tag.
    MessageStop {},
}
148
/// Message fields as they appear in streaming events.
///
/// Every field is optional because the same type is used for the full `message`
/// in `ResponseEvent::MessageStart` and for the partial `delta` in
/// `ResponseEvent::MessageDelta`.
#[derive(Deserialize, Debug)]
pub struct ResponseMessage {
    /// The JSON `type` field (renamed because `type` is a Rust keyword).
    #[serde(rename = "type")]
    pub message_type: Option<String>,
    pub id: Option<String>,
    /// Role as a raw string; not parsed into `Role` here.
    pub role: Option<String>,
    /// NOTE(review): modeled as a list of strings — confirm this matches the
    /// payloads actually received, since content blocks are objects elsewhere.
    pub content: Option<Vec<String>>,
    pub model: Option<String>,
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
    pub usage: Option<Usage>,
}
161
/// Token counts reported by the API; either count may be absent from a given event.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub input_tokens: Option<u32>,
    pub output_tokens: Option<u32>,
}
167
/// Content block announced by `ResponseEvent::ContentBlockStart`, selected by its
/// `type` field. Only `"text"` blocks are modeled; other types fail to deserialize.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    Text { text: String },
}
173
/// Delta payload of `ResponseEvent::ContentBlockDelta`, selected by its `type`
/// field. Only `"text_delta"` is modeled; other delta types fail to deserialize.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TextDelta {
    TextDelta { text: String },
}
179
180pub async fn stream_completion(
181 client: &dyn HttpClient,
182 api_url: &str,
183 api_key: &str,
184 request: Request,
185 low_speed_timeout: Option<Duration>,
186) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
187 let uri = format!("{api_url}/v1/messages");
188 let mut request_builder = HttpRequest::builder()
189 .method(Method::POST)
190 .uri(uri)
191 .header("Anthropic-Version", "2023-06-01")
192 .header("Anthropic-Beta", "tools-2024-04-04")
193 .header("X-Api-Key", api_key)
194 .header("Content-Type", "application/json");
195 if let Some(low_speed_timeout) = low_speed_timeout {
196 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
197 }
198 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
199 let mut response = client.send(request).await?;
200 if response.status().is_success() {
201 let reader = BufReader::new(response.into_body());
202 Ok(reader
203 .lines()
204 .filter_map(|line| async move {
205 match line {
206 Ok(line) => {
207 let line = line.strip_prefix("data: ")?;
208 match serde_json::from_str(line) {
209 Ok(response) => Some(Ok(response)),
210 Err(error) => Some(Err(anyhow!(error))),
211 }
212 }
213 Err(error) => Some(Err(anyhow!(error))),
214 }
215 })
216 .boxed())
217 } else {
218 let mut body = Vec::new();
219 response.body_mut().read_to_end(&mut body).await?;
220
221 let body_str = std::str::from_utf8(&body)?;
222
223 match serde_json::from_str::<ResponseEvent>(body_str) {
224 Ok(_) => Err(anyhow!(
225 "Unexpected success response while expecting an error: {}",
226 body_str,
227 )),
228 Err(_) => Err(anyhow!(
229 "Failed to connect to API: {} {}",
230 response.status(),
231 body_str,
232 )),
233 }
234 }
235}
236
237// #[cfg(test)]
238// mod tests {
239// use super::*;
240// use http::IsahcHttpClient;
241
242// #[tokio::test]
243// async fn stream_completion_success() {
244// let http_client = IsahcHttpClient::new().unwrap();
245
246// let request = Request {
247// model: Model::Claude3Opus,
248// messages: vec![RequestMessage {
249// role: Role::User,
250// content: "Ping".to_string(),
251// }],
252// stream: true,
253// system: "Respond to ping with pong".to_string(),
254// max_tokens: 4096,
255// };
256
// let stream = stream_completion(
// &http_client,
// "https://api.anthropic.com",
// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
// request,
// None,
// )
// .await
// .unwrap();
265
266// stream
267// .for_each(|event| async {
268// match event {
269// Ok(event) => println!("{:?}", event),
270// Err(e) => eprintln!("Error: {:?}", e),
271// }
272// })
273// .await;
274// }
275// }