1use anyhow::{anyhow, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
3use serde::{Deserialize, Serialize};
4use std::convert::TryFrom;
5use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
6
7#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
8pub enum Model {
9 #[default]
10 #[serde(rename = "claude-3-opus-20240229")]
11 Claude3Opus,
12 #[serde(rename = "claude-3-sonnet-20240229")]
13 Claude3Sonnet,
14 #[serde(rename = "claude-3-haiku-20240307")]
15 Claude3Haiku,
16}
17
18impl Model {
19 pub fn from_id(id: &str) -> Result<Self> {
20 if id.starts_with("claude-3-opus") {
21 Ok(Self::Claude3Opus)
22 } else if id.starts_with("claude-3-sonnet") {
23 Ok(Self::Claude3Sonnet)
24 } else if id.starts_with("claude-3-haiku") {
25 Ok(Self::Claude3Haiku)
26 } else {
27 Err(anyhow!("Invalid model id: {}", id))
28 }
29 }
30
31 pub fn display_name(&self) -> &'static str {
32 match self {
33 Self::Claude3Opus => "Claude 3 Opus",
34 Self::Claude3Sonnet => "Claude 3 Sonnet",
35 Self::Claude3Haiku => "Claude 3 Haiku",
36 }
37 }
38
39 pub fn max_token_count(&self) -> usize {
40 200_000
41 }
42}
43
/// Author of a conversation message.
///
/// On the wire each variant is its lowercase name (`"user"` / `"assistant"`),
/// as set by `rename_all = "lowercase"`.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
}
50
51impl TryFrom<String> for Role {
52 type Error = anyhow::Error;
53
54 fn try_from(value: String) -> Result<Self> {
55 match value.as_str() {
56 "user" => Ok(Self::User),
57 "assistant" => Ok(Self::Assistant),
58 _ => Err(anyhow!("invalid role '{value}'")),
59 }
60 }
61}
62
63impl From<Role> for String {
64 fn from(val: Role) -> Self {
65 match val {
66 Role::User => "user".to_owned(),
67 Role::Assistant => "assistant".to_owned(),
68 }
69 }
70}
71
/// Body of a `POST {api_url}/v1/messages` call; `stream_completion` serializes it to JSON.
///
/// serde's derive emits fields in declaration order, so this order is also
/// the key order of the JSON body.
#[derive(Debug, Serialize)]
pub struct Request {
    /// Serialized as the model's dated wire id.
    pub model: Model,
    /// Conversation turns, serialized in the order given.
    pub messages: Vec<RequestMessage>,
    /// `stream_completion` decodes the reply as `data: `-prefixed event
    /// lines, so callers presumably set this to `true` — confirm with callers.
    pub stream: bool,
    /// System prompt; always serialized, even when empty (there is no
    /// `skip_serializing_if`).
    pub system: String,
    /// Sent as the API's `max_tokens` field.
    pub max_tokens: u32,
}
80
/// One conversation turn inside [`Request::messages`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
    /// Who authored the turn; serialized as `"user"` or `"assistant"`.
    pub role: Role,
    /// The turn's content as a single plain string.
    pub content: String,
}
86
/// One decoded `data:` payload from the streamed Messages response.
///
/// The enum is internally tagged: the JSON `type` field selects the variant,
/// written in snake_case (e.g. `"message_start"`, `"content_block_delta"`).
/// A payload whose `type` is not listed here fails to deserialize. That
/// includes an `error` event, since no such variant exists.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseEvent {
    /// `message_start`: the initial message metadata.
    MessageStart {
        message: ResponseMessage,
    },
    /// `content_block_start`: opens the content block at `index`.
    ContentBlockStart {
        index: u32,
        content_block: ContentBlock,
    },
    /// `ping`: carries no payload beyond the tag.
    Ping {},
    /// `content_block_delta`: incremental text for the block at `index`.
    ContentBlockDelta {
        index: u32,
        delta: TextDelta,
    },
    /// `content_block_stop`: closes the block at `index`.
    ContentBlockStop {
        index: u32,
    },
    /// `message_delta`: partial message fields plus token usage.
    MessageDelta {
        delta: ResponseMessage,
        usage: Usage,
    },
    /// `message_stop`: carries no payload beyond the tag.
    MessageStop {},
}
111
/// Message fields as they appear in `message_start.message` and `message_delta.delta`.
///
/// Every field is an `Option`, so a payload may omit any of them. This lets
/// the same type decode both the full start message and the partial delta.
#[derive(Deserialize, Debug)]
pub struct ResponseMessage {
    /// The JSON `type` field of the message object.
    #[serde(rename = "type")]
    pub message_type: Option<String>,
    pub id: Option<String>,
    /// Kept as a raw string rather than a [`Role`].
    pub role: Option<String>,
    // NOTE(review): only arrays of plain strings decode here; a non-empty
    // array of content-block objects would fail. Presumably this is only ever
    // seen empty in `message_start` — confirm against the API.
    pub content: Option<Vec<String>>,
    pub model: Option<String>,
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
    pub usage: Option<Usage>,
}
124
/// Token counts reported by the API; either count may be absent from a payload.
#[derive(Deserialize, Debug)]
pub struct Usage {
    pub input_tokens: Option<u32>,
    pub output_tokens: Option<u32>,
}
130
/// Initial content of a block announced by `content_block_start`.
///
/// The block is tagged by its `type` field. Only `"text"` is recognized;
/// any other block type fails to deserialize.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ContentBlock {
    Text { text: String },
}
136
/// Incremental update carried by `content_block_delta`.
///
/// The delta is tagged by its `type` field. Only `"text_delta"` is
/// recognized; any other delta type fails to deserialize.
#[derive(Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TextDelta {
    TextDelta { text: String },
}
142
143pub async fn stream_completion(
144 client: &dyn HttpClient,
145 api_url: &str,
146 api_key: &str,
147 request: Request,
148) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
149 let uri = format!("{api_url}/v1/messages");
150 let request = HttpRequest::builder()
151 .method(Method::POST)
152 .uri(uri)
153 .header("Anthropic-Version", "2023-06-01")
154 .header("Anthropic-Beta", "messages-2023-12-15")
155 .header("X-Api-Key", api_key)
156 .header("Content-Type", "application/json")
157 .body(AsyncBody::from(serde_json::to_string(&request)?))?;
158 let mut response = client.send(request).await?;
159 if response.status().is_success() {
160 let reader = BufReader::new(response.into_body());
161 Ok(reader
162 .lines()
163 .filter_map(|line| async move {
164 match line {
165 Ok(line) => {
166 let line = line.strip_prefix("data: ")?;
167 match serde_json::from_str(line) {
168 Ok(response) => Some(Ok(response)),
169 Err(error) => Some(Err(anyhow!(error))),
170 }
171 }
172 Err(error) => Some(Err(anyhow!(error))),
173 }
174 })
175 .boxed())
176 } else {
177 let mut body = Vec::new();
178 response.body_mut().read_to_end(&mut body).await?;
179
180 let body_str = std::str::from_utf8(&body)?;
181
182 match serde_json::from_str::<ResponseEvent>(body_str) {
183 Ok(_) => Err(anyhow!(
184 "Unexpected success response while expecting an error: {}",
185 body_str,
186 )),
187 Err(_) => Err(anyhow!(
188 "Failed to connect to API: {} {}",
189 response.status(),
190 body_str,
191 )),
192 }
193 }
194}
195
196// #[cfg(test)]
197// mod tests {
198// use super::*;
199// use util::http::IsahcHttpClient;
200
201// #[tokio::test]
202// async fn stream_completion_success() {
203// let http_client = IsahcHttpClient::new().unwrap();
204
205// let request = Request {
206// model: Model::Claude3Opus,
207// messages: vec![RequestMessage {
208// role: Role::User,
209// content: "Ping".to_string(),
210// }],
211// stream: true,
212// system: "Respond to ping with pong".to_string(),
213// max_tokens: 4096,
214// };
215
216// let stream = stream_completion(
217// &http_client,
218// "https://api.anthropic.com",
219// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
220// request,
221// )
222// .await
223// .unwrap();
224
225// stream
226// .for_each(|event| async {
227// match event {
228// Ok(event) => println!("{:?}", event),
229// Err(e) => eprintln!("Error: {:?}", e),
230// }
231// })
232// .await;
233// }
234// }