1use anyhow::{anyhow, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use isahc::config::Configurable;
5use serde::{Deserialize, Serialize};
6use std::{convert::TryFrom, time::Duration};
7
/// Default base URL of the Anthropic HTTP API (no trailing slash).
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
9
10#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
11#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
12pub enum Model {
13 #[default]
14 #[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
15 Claude3Opus,
16 #[serde(alias = "claude-3-sonnet", rename = "claude-3-sonnet-20240229")]
17 Claude3Sonnet,
18 #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
19 Claude3Haiku,
20}
21
22impl Model {
23 pub fn from_id(id: &str) -> Result<Self> {
24 if id.starts_with("claude-3-opus") {
25 Ok(Self::Claude3Opus)
26 } else if id.starts_with("claude-3-sonnet") {
27 Ok(Self::Claude3Sonnet)
28 } else if id.starts_with("claude-3-haiku") {
29 Ok(Self::Claude3Haiku)
30 } else {
31 Err(anyhow!("Invalid model id: {}", id))
32 }
33 }
34
35 pub fn id(&self) -> &'static str {
36 match self {
37 Model::Claude3Opus => "claude-3-opus-20240229",
38 Model::Claude3Sonnet => "claude-3-sonnet-20240229",
39 Model::Claude3Haiku => "claude-3-opus-20240307",
40 }
41 }
42
43 pub fn display_name(&self) -> &'static str {
44 match self {
45 Self::Claude3Opus => "Claude 3 Opus",
46 Self::Claude3Sonnet => "Claude 3 Sonnet",
47 Self::Claude3Haiku => "Claude 3 Haiku",
48 }
49 }
50
51 pub fn max_token_count(&self) -> usize {
52 200_000
53 }
54}
55
56#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
57#[serde(rename_all = "lowercase")]
58pub enum Role {
59 User,
60 Assistant,
61}
62
63impl TryFrom<String> for Role {
64 type Error = anyhow::Error;
65
66 fn try_from(value: String) -> Result<Self> {
67 match value.as_str() {
68 "user" => Ok(Self::User),
69 "assistant" => Ok(Self::Assistant),
70 _ => Err(anyhow!("invalid role '{value}'")),
71 }
72 }
73}
74
75impl From<Role> for String {
76 fn from(val: Role) -> Self {
77 match val {
78 Role::User => "user".to_owned(),
79 Role::Assistant => "assistant".to_owned(),
80 }
81 }
82}
83
84#[derive(Debug, Serialize)]
85pub struct Request {
86 pub model: Model,
87 pub messages: Vec<RequestMessage>,
88 pub stream: bool,
89 pub system: String,
90 pub max_tokens: u32,
91}
92
93#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
94pub struct RequestMessage {
95 pub role: Role,
96 pub content: String,
97}
98
99#[derive(Deserialize, Debug)]
100#[serde(tag = "type", rename_all = "snake_case")]
101pub enum ResponseEvent {
102 MessageStart {
103 message: ResponseMessage,
104 },
105 ContentBlockStart {
106 index: u32,
107 content_block: ContentBlock,
108 },
109 Ping {},
110 ContentBlockDelta {
111 index: u32,
112 delta: TextDelta,
113 },
114 ContentBlockStop {
115 index: u32,
116 },
117 MessageDelta {
118 delta: ResponseMessage,
119 usage: Usage,
120 },
121 MessageStop {},
122}
123
124#[derive(Deserialize, Debug)]
125pub struct ResponseMessage {
126 #[serde(rename = "type")]
127 pub message_type: Option<String>,
128 pub id: Option<String>,
129 pub role: Option<String>,
130 pub content: Option<Vec<String>>,
131 pub model: Option<String>,
132 pub stop_reason: Option<String>,
133 pub stop_sequence: Option<String>,
134 pub usage: Option<Usage>,
135}
136
137#[derive(Deserialize, Debug)]
138pub struct Usage {
139 pub input_tokens: Option<u32>,
140 pub output_tokens: Option<u32>,
141}
142
143#[derive(Deserialize, Debug)]
144#[serde(tag = "type", rename_all = "snake_case")]
145pub enum ContentBlock {
146 Text { text: String },
147}
148
149#[derive(Deserialize, Debug)]
150#[serde(tag = "type", rename_all = "snake_case")]
151pub enum TextDelta {
152 TextDelta { text: String },
153}
154
155pub async fn stream_completion(
156 client: &dyn HttpClient,
157 api_url: &str,
158 api_key: &str,
159 request: Request,
160 low_speed_timeout: Option<Duration>,
161) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
162 let uri = format!("{api_url}/v1/messages");
163 let mut request_builder = HttpRequest::builder()
164 .method(Method::POST)
165 .uri(uri)
166 .header("Anthropic-Version", "2023-06-01")
167 .header("Anthropic-Beta", "tools-2024-04-04")
168 .header("X-Api-Key", api_key)
169 .header("Content-Type", "application/json");
170 if let Some(low_speed_timeout) = low_speed_timeout {
171 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
172 }
173 let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
174 let mut response = client.send(request).await?;
175 if response.status().is_success() {
176 let reader = BufReader::new(response.into_body());
177 Ok(reader
178 .lines()
179 .filter_map(|line| async move {
180 match line {
181 Ok(line) => {
182 let line = line.strip_prefix("data: ")?;
183 match serde_json::from_str(line) {
184 Ok(response) => Some(Ok(response)),
185 Err(error) => Some(Err(anyhow!(error))),
186 }
187 }
188 Err(error) => Some(Err(anyhow!(error))),
189 }
190 })
191 .boxed())
192 } else {
193 let mut body = Vec::new();
194 response.body_mut().read_to_end(&mut body).await?;
195
196 let body_str = std::str::from_utf8(&body)?;
197
198 match serde_json::from_str::<ResponseEvent>(body_str) {
199 Ok(_) => Err(anyhow!(
200 "Unexpected success response while expecting an error: {}",
201 body_str,
202 )),
203 Err(_) => Err(anyhow!(
204 "Failed to connect to API: {} {}",
205 response.status(),
206 body_str,
207 )),
208 }
209 }
210}
211
212// #[cfg(test)]
213// mod tests {
214// use super::*;
215// use http::IsahcHttpClient;
216
217// #[tokio::test]
218// async fn stream_completion_success() {
219// let http_client = IsahcHttpClient::new().unwrap();
220
221// let request = Request {
222// model: Model::Claude3Opus,
223// messages: vec![RequestMessage {
224// role: Role::User,
225// content: "Ping".to_string(),
226// }],
227// stream: true,
228// system: "Respond to ping with pong".to_string(),
229// max_tokens: 4096,
230// };
231
// let stream = stream_completion(
// &http_client,
// "https://api.anthropic.com",
// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
// request,
// None,
// )
// .await
// .unwrap();
240
241// stream
242// .for_each(|event| async {
243// match event {
244// Ok(event) => println!("{:?}", event),
245// Err(e) => eprintln!("Error: {:?}", e),
246// }
247// })
248// .await;
249// }
250// }