1use anyhow::{anyhow, Result};
2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
3use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
4use isahc::config::Configurable;
5use serde::{Deserialize, Serialize};
6use std::time::Duration;
7use strum::EnumIter;
8
/// Base URL of the public Anthropic API.
///
/// [`complete`] and [`stream_completion`] append `/v1/messages` to the base URL they are given.
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
10
11#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
12#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
13pub enum Model {
14 #[default]
15 #[serde(alias = "claude-3-5-sonnet", rename = "claude-3-5-sonnet-20240620")]
16 Claude3_5Sonnet,
17 #[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
18 Claude3Opus,
19 #[serde(alias = "claude-3-sonnet", rename = "claude-3-sonnet-20240229")]
20 Claude3Sonnet,
21 #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
22 Claude3Haiku,
23 #[serde(rename = "custom")]
24 Custom {
25 name: String,
26 max_tokens: usize,
27 /// Override this model with a different Anthropic model for tool calls.
28 tool_override: Option<String>,
29 },
30}
31
32impl Model {
33 pub fn from_id(id: &str) -> Result<Self> {
34 if id.starts_with("claude-3-5-sonnet") {
35 Ok(Self::Claude3_5Sonnet)
36 } else if id.starts_with("claude-3-opus") {
37 Ok(Self::Claude3Opus)
38 } else if id.starts_with("claude-3-sonnet") {
39 Ok(Self::Claude3Sonnet)
40 } else if id.starts_with("claude-3-haiku") {
41 Ok(Self::Claude3Haiku)
42 } else {
43 Err(anyhow!("invalid model id"))
44 }
45 }
46
47 pub fn id(&self) -> &str {
48 match self {
49 Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
50 Model::Claude3Opus => "claude-3-opus-20240229",
51 Model::Claude3Sonnet => "claude-3-sonnet-20240229",
52 Model::Claude3Haiku => "claude-3-opus-20240307",
53 Self::Custom { name, .. } => name,
54 }
55 }
56
57 pub fn display_name(&self) -> &str {
58 match self {
59 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
60 Self::Claude3Opus => "Claude 3 Opus",
61 Self::Claude3Sonnet => "Claude 3 Sonnet",
62 Self::Claude3Haiku => "Claude 3 Haiku",
63 Self::Custom { name, .. } => name,
64 }
65 }
66
67 pub fn max_token_count(&self) -> usize {
68 match self {
69 Self::Claude3_5Sonnet
70 | Self::Claude3Opus
71 | Self::Claude3Sonnet
72 | Self::Claude3Haiku => 200_000,
73 Self::Custom { max_tokens, .. } => *max_tokens,
74 }
75 }
76
77 pub fn tool_model_id(&self) -> &str {
78 if let Self::Custom {
79 tool_override: Some(tool_override),
80 ..
81 } = self
82 {
83 tool_override
84 } else {
85 self.id()
86 }
87 }
88}
89
90pub async fn complete(
91 client: &dyn HttpClient,
92 api_url: &str,
93 api_key: &str,
94 request: Request,
95) -> Result<Response> {
96 let uri = format!("{api_url}/v1/messages");
97 let request_builder = HttpRequest::builder()
98 .method(Method::POST)
99 .uri(uri)
100 .header("Anthropic-Version", "2023-06-01")
101 .header("Anthropic-Beta", "tools-2024-04-04")
102 .header("X-Api-Key", api_key)
103 .header("Content-Type", "application/json");
104
105 let serialized_request = serde_json::to_string(&request)?;
106 let request = request_builder.body(AsyncBody::from(serialized_request))?;
107
108 let mut response = client.send(request).await?;
109 if response.status().is_success() {
110 let mut body = Vec::new();
111 response.body_mut().read_to_end(&mut body).await?;
112 let response_message: Response = serde_json::from_slice(&body)?;
113 Ok(response_message)
114 } else {
115 let mut body = Vec::new();
116 response.body_mut().read_to_end(&mut body).await?;
117 let body_str = std::str::from_utf8(&body)?;
118 Err(anyhow!(
119 "Failed to connect to API: {} {}",
120 response.status(),
121 body_str
122 ))
123 }
124}
125
126pub async fn stream_completion(
127 client: &dyn HttpClient,
128 api_url: &str,
129 api_key: &str,
130 request: Request,
131 low_speed_timeout: Option<Duration>,
132) -> Result<BoxStream<'static, Result<Event>>> {
133 let request = StreamingRequest {
134 base: request,
135 stream: true,
136 };
137 let uri = format!("{api_url}/v1/messages");
138 let mut request_builder = HttpRequest::builder()
139 .method(Method::POST)
140 .uri(uri)
141 .header("Anthropic-Version", "2023-06-01")
142 .header("Anthropic-Beta", "tools-2024-04-04")
143 .header("X-Api-Key", api_key)
144 .header("Content-Type", "application/json");
145 if let Some(low_speed_timeout) = low_speed_timeout {
146 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
147 }
148 let serialized_request = serde_json::to_string(&request)?;
149 let request = request_builder.body(AsyncBody::from(serialized_request))?;
150
151 let mut response = client.send(request).await?;
152 if response.status().is_success() {
153 let reader = BufReader::new(response.into_body());
154 Ok(reader
155 .lines()
156 .filter_map(|line| async move {
157 match line {
158 Ok(line) => {
159 let line = line.strip_prefix("data: ")?;
160 match serde_json::from_str(line) {
161 Ok(response) => Some(Ok(response)),
162 Err(error) => Some(Err(anyhow!(error))),
163 }
164 }
165 Err(error) => Some(Err(anyhow!(error))),
166 }
167 })
168 .boxed())
169 } else {
170 let mut body = Vec::new();
171 response.body_mut().read_to_end(&mut body).await?;
172
173 let body_str = std::str::from_utf8(&body)?;
174
175 match serde_json::from_str::<Event>(body_str) {
176 Ok(Event::Error { error }) => Err(api_error_to_err(error)),
177 Ok(_) => Err(anyhow!(
178 "Unexpected success response while expecting an error: '{body_str}'",
179 )),
180 Err(_) => Err(anyhow!(
181 "Failed to connect to API: {} {}",
182 response.status(),
183 body_str,
184 )),
185 }
186 }
187}
188
189pub fn extract_text_from_events(
190 response: impl Stream<Item = Result<Event>>,
191) -> impl Stream<Item = Result<String>> {
192 response.filter_map(|response| async move {
193 match response {
194 Ok(response) => match response {
195 Event::ContentBlockStart { content_block, .. } => match content_block {
196 Content::Text { text } => Some(Ok(text)),
197 _ => None,
198 },
199 Event::ContentBlockDelta { delta, .. } => match delta {
200 ContentDelta::TextDelta { text } => Some(Ok(text)),
201 _ => None,
202 },
203 Event::Error { error } => Some(Err(api_error_to_err(error))),
204 _ => None,
205 },
206 Err(error) => Some(Err(error)),
207 }
208 })
209}
210
211fn api_error_to_err(
212 ApiError {
213 error_type,
214 message,
215 }: ApiError,
216) -> anyhow::Error {
217 anyhow!("API error. Type: '{error_type}', message: '{message}'",)
218}
219
220#[derive(Debug, Serialize, Deserialize)]
221pub struct Message {
222 pub role: Role,
223 pub content: Vec<Content>,
224}
225
226#[derive(Debug, Serialize, Deserialize)]
227#[serde(rename_all = "lowercase")]
228pub enum Role {
229 User,
230 Assistant,
231}
232
233#[derive(Debug, Serialize, Deserialize)]
234#[serde(tag = "type")]
235pub enum Content {
236 #[serde(rename = "text")]
237 Text { text: String },
238 #[serde(rename = "image")]
239 Image { source: ImageSource },
240 #[serde(rename = "tool_use")]
241 ToolUse {
242 id: String,
243 name: String,
244 input: serde_json::Value,
245 },
246 #[serde(rename = "tool_result")]
247 ToolResult {
248 tool_use_id: String,
249 content: String,
250 },
251}
252
253#[derive(Debug, Serialize, Deserialize)]
254pub struct ImageSource {
255 #[serde(rename = "type")]
256 pub source_type: String,
257 pub media_type: String,
258 pub data: String,
259}
260
261#[derive(Debug, Serialize, Deserialize)]
262pub struct Tool {
263 pub name: String,
264 pub description: String,
265 pub input_schema: serde_json::Value,
266}
267
268#[derive(Debug, Serialize, Deserialize)]
269#[serde(tag = "type", rename_all = "lowercase")]
270pub enum ToolChoice {
271 Auto,
272 Any,
273 Tool { name: String },
274}
275
276#[derive(Debug, Serialize, Deserialize)]
277pub struct Request {
278 pub model: String,
279 pub max_tokens: u32,
280 pub messages: Vec<Message>,
281 #[serde(default, skip_serializing_if = "Vec::is_empty")]
282 pub tools: Vec<Tool>,
283 #[serde(default, skip_serializing_if = "Option::is_none")]
284 pub tool_choice: Option<ToolChoice>,
285 #[serde(default, skip_serializing_if = "Option::is_none")]
286 pub system: Option<String>,
287 #[serde(default, skip_serializing_if = "Option::is_none")]
288 pub metadata: Option<Metadata>,
289 #[serde(default, skip_serializing_if = "Vec::is_empty")]
290 pub stop_sequences: Vec<String>,
291 #[serde(default, skip_serializing_if = "Option::is_none")]
292 pub temperature: Option<f32>,
293 #[serde(default, skip_serializing_if = "Option::is_none")]
294 pub top_k: Option<u32>,
295 #[serde(default, skip_serializing_if = "Option::is_none")]
296 pub top_p: Option<f32>,
297}
298
299#[derive(Debug, Serialize, Deserialize)]
300struct StreamingRequest {
301 #[serde(flatten)]
302 pub base: Request,
303 pub stream: bool,
304}
305
306#[derive(Debug, Serialize, Deserialize)]
307pub struct Metadata {
308 pub user_id: Option<String>,
309}
310
311#[derive(Debug, Serialize, Deserialize)]
312pub struct Usage {
313 #[serde(default, skip_serializing_if = "Option::is_none")]
314 pub input_tokens: Option<u32>,
315 #[serde(default, skip_serializing_if = "Option::is_none")]
316 pub output_tokens: Option<u32>,
317}
318
319#[derive(Debug, Serialize, Deserialize)]
320pub struct Response {
321 pub id: String,
322 #[serde(rename = "type")]
323 pub response_type: String,
324 pub role: Role,
325 pub content: Vec<Content>,
326 pub model: String,
327 #[serde(default, skip_serializing_if = "Option::is_none")]
328 pub stop_reason: Option<String>,
329 #[serde(default, skip_serializing_if = "Option::is_none")]
330 pub stop_sequence: Option<String>,
331 pub usage: Usage,
332}
333
334#[derive(Debug, Serialize, Deserialize)]
335#[serde(tag = "type")]
336pub enum Event {
337 #[serde(rename = "message_start")]
338 MessageStart { message: Response },
339 #[serde(rename = "content_block_start")]
340 ContentBlockStart {
341 index: usize,
342 content_block: Content,
343 },
344 #[serde(rename = "content_block_delta")]
345 ContentBlockDelta { index: usize, delta: ContentDelta },
346 #[serde(rename = "content_block_stop")]
347 ContentBlockStop { index: usize },
348 #[serde(rename = "message_delta")]
349 MessageDelta { delta: MessageDelta, usage: Usage },
350 #[serde(rename = "message_stop")]
351 MessageStop,
352 #[serde(rename = "ping")]
353 Ping,
354 #[serde(rename = "error")]
355 Error { error: ApiError },
356}
357
358#[derive(Debug, Serialize, Deserialize)]
359#[serde(tag = "type")]
360pub enum ContentDelta {
361 #[serde(rename = "text_delta")]
362 TextDelta { text: String },
363 #[serde(rename = "input_json_delta")]
364 InputJsonDelta { partial_json: String },
365}
366
367#[derive(Debug, Serialize, Deserialize)]
368pub struct MessageDelta {
369 pub stop_reason: Option<String>,
370 pub stop_sequence: Option<String>,
371}
372
373#[derive(Debug, Serialize, Deserialize)]
374pub struct ApiError {
375 #[serde(rename = "type")]
376 pub error_type: String,
377 pub message: String,
378}