1mod supported_countries;
2
3use anyhow::{anyhow, Context, Result};
4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
6use isahc::config::Configurable;
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9use std::{pin::Pin, str::FromStr};
10use strum::{EnumIter, EnumString};
11use thiserror::Error;
12
13pub use supported_countries::*;
14
/// Default base URL of the Anthropic API.
///
/// The request functions in this file take a base URL such as this one and
/// append the endpoint path (for example `/v1/messages`) to it.
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
16
17#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
18#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
19pub struct AnthropicModelCacheConfiguration {
20 pub min_total_token: usize,
21 pub should_speculate: bool,
22 pub max_cache_anchors: usize,
23}
24
25#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
26#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
27pub enum Model {
28 #[default]
29 #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-20240620")]
30 Claude3_5Sonnet,
31 #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
32 Claude3Opus,
33 #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
34 Claude3Sonnet,
35 #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
36 Claude3Haiku,
37 #[serde(rename = "custom")]
38 Custom {
39 name: String,
40 max_tokens: usize,
41 /// Override this model with a different Anthropic model for tool calls.
42 tool_override: Option<String>,
43 /// Indicates whether this custom model supports caching.
44 cache_configuration: Option<AnthropicModelCacheConfiguration>,
45 max_output_tokens: Option<u32>,
46 },
47}
48
49impl Model {
50 pub fn from_id(id: &str) -> Result<Self> {
51 if id.starts_with("claude-3-5-sonnet") {
52 Ok(Self::Claude3_5Sonnet)
53 } else if id.starts_with("claude-3-opus") {
54 Ok(Self::Claude3Opus)
55 } else if id.starts_with("claude-3-sonnet") {
56 Ok(Self::Claude3Sonnet)
57 } else if id.starts_with("claude-3-haiku") {
58 Ok(Self::Claude3Haiku)
59 } else {
60 Err(anyhow!("invalid model id"))
61 }
62 }
63
64 pub fn id(&self) -> &str {
65 match self {
66 Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
67 Model::Claude3Opus => "claude-3-opus-20240229",
68 Model::Claude3Sonnet => "claude-3-sonnet-20240229",
69 Model::Claude3Haiku => "claude-3-haiku-20240307",
70 Self::Custom { name, .. } => name,
71 }
72 }
73
74 pub fn display_name(&self) -> &str {
75 match self {
76 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
77 Self::Claude3Opus => "Claude 3 Opus",
78 Self::Claude3Sonnet => "Claude 3 Sonnet",
79 Self::Claude3Haiku => "Claude 3 Haiku",
80 Self::Custom { name, .. } => name,
81 }
82 }
83
84 pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
85 match self {
86 Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
87 min_total_token: 2_048,
88 should_speculate: true,
89 max_cache_anchors: 4,
90 }),
91 Self::Custom {
92 cache_configuration,
93 ..
94 } => cache_configuration.clone(),
95 _ => None,
96 }
97 }
98
99 pub fn max_token_count(&self) -> usize {
100 match self {
101 Self::Claude3_5Sonnet
102 | Self::Claude3Opus
103 | Self::Claude3Sonnet
104 | Self::Claude3Haiku => 200_000,
105 Self::Custom { max_tokens, .. } => *max_tokens,
106 }
107 }
108
109 pub fn max_output_tokens(&self) -> u32 {
110 match self {
111 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
112 Self::Claude3_5Sonnet => 8_192,
113 Self::Custom {
114 max_output_tokens, ..
115 } => max_output_tokens.unwrap_or(4_096),
116 }
117 }
118
119 pub fn tool_model_id(&self) -> &str {
120 if let Self::Custom {
121 tool_override: Some(tool_override),
122 ..
123 } = self
124 {
125 tool_override
126 } else {
127 self.id()
128 }
129 }
130}
131
132pub async fn complete(
133 client: &dyn HttpClient,
134 api_url: &str,
135 api_key: &str,
136 request: Request,
137) -> Result<Response, AnthropicError> {
138 let uri = format!("{api_url}/v1/messages");
139 let request_builder = HttpRequest::builder()
140 .method(Method::POST)
141 .uri(uri)
142 .header("Anthropic-Version", "2023-06-01")
143 .header(
144 "Anthropic-Beta",
145 "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
146 )
147 .header("X-Api-Key", api_key)
148 .header("Content-Type", "application/json");
149
150 let serialized_request =
151 serde_json::to_string(&request).context("failed to serialize request")?;
152 let request = request_builder
153 .body(AsyncBody::from(serialized_request))
154 .context("failed to construct request body")?;
155
156 let mut response = client
157 .send(request)
158 .await
159 .context("failed to send request to Anthropic")?;
160 if response.status().is_success() {
161 let mut body = Vec::new();
162 response
163 .body_mut()
164 .read_to_end(&mut body)
165 .await
166 .context("failed to read response body")?;
167 let response_message: Response =
168 serde_json::from_slice(&body).context("failed to deserialize response body")?;
169 Ok(response_message)
170 } else {
171 let mut body = Vec::new();
172 response
173 .body_mut()
174 .read_to_end(&mut body)
175 .await
176 .context("failed to read response body")?;
177 let body_str =
178 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
179 Err(AnthropicError::Other(anyhow!(
180 "Failed to connect to API: {} {}",
181 response.status(),
182 body_str
183 )))
184 }
185}
186
187pub async fn stream_completion(
188 client: &dyn HttpClient,
189 api_url: &str,
190 api_key: &str,
191 request: Request,
192 low_speed_timeout: Option<Duration>,
193) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
194 let request = StreamingRequest {
195 base: request,
196 stream: true,
197 };
198 let uri = format!("{api_url}/v1/messages");
199 let mut request_builder = HttpRequest::builder()
200 .method(Method::POST)
201 .uri(uri)
202 .header("Anthropic-Version", "2023-06-01")
203 .header(
204 "Anthropic-Beta",
205 "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
206 )
207 .header("X-Api-Key", api_key)
208 .header("Content-Type", "application/json");
209 if let Some(low_speed_timeout) = low_speed_timeout {
210 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
211 }
212 let serialized_request =
213 serde_json::to_string(&request).context("failed to serialize request")?;
214 let request = request_builder
215 .body(AsyncBody::from(serialized_request))
216 .context("failed to construct request body")?;
217
218 let mut response = client
219 .send(request)
220 .await
221 .context("failed to send request to Anthropic")?;
222 if response.status().is_success() {
223 let reader = BufReader::new(response.into_body());
224 Ok(reader
225 .lines()
226 .filter_map(|line| async move {
227 match line {
228 Ok(line) => {
229 let line = line.strip_prefix("data: ")?;
230 match serde_json::from_str(line) {
231 Ok(response) => Some(Ok(response)),
232 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
233 }
234 }
235 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
236 }
237 })
238 .boxed())
239 } else {
240 let mut body = Vec::new();
241 response
242 .body_mut()
243 .read_to_end(&mut body)
244 .await
245 .context("failed to read response body")?;
246
247 let body_str =
248 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
249
250 match serde_json::from_str::<Event>(body_str) {
251 Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
252 Ok(_) => Err(AnthropicError::Other(anyhow!(
253 "Unexpected success response while expecting an error: '{body_str}'",
254 ))),
255 Err(_) => Err(AnthropicError::Other(anyhow!(
256 "Failed to connect to API: {} {}",
257 response.status(),
258 body_str,
259 ))),
260 }
261 }
262}
263
264pub fn extract_text_from_events(
265 response: impl Stream<Item = Result<Event, AnthropicError>>,
266) -> impl Stream<Item = Result<String, AnthropicError>> {
267 response.filter_map(|response| async move {
268 match response {
269 Ok(response) => match response {
270 Event::ContentBlockStart { content_block, .. } => match content_block {
271 Content::Text { text, .. } => Some(Ok(text)),
272 _ => None,
273 },
274 Event::ContentBlockDelta { delta, .. } => match delta {
275 ContentDelta::TextDelta { text } => Some(Ok(text)),
276 _ => None,
277 },
278 Event::Error { error } => Some(Err(AnthropicError::ApiError(error))),
279 _ => None,
280 },
281 Err(error) => Some(Err(error)),
282 }
283 })
284}
285
286pub async fn extract_tool_args_from_events(
287 tool_name: String,
288 mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
289) -> Result<impl Send + Stream<Item = Result<String>>> {
290 let mut tool_use_index = None;
291 while let Some(event) = events.next().await {
292 if let Event::ContentBlockStart {
293 index,
294 content_block,
295 } = event?
296 {
297 if let Content::ToolUse { name, .. } = content_block {
298 if name == tool_name {
299 tool_use_index = Some(index);
300 break;
301 }
302 }
303 }
304 }
305
306 let Some(tool_use_index) = tool_use_index else {
307 return Err(anyhow!("tool not used"));
308 };
309
310 Ok(events.filter_map(move |event| {
311 let result = match event {
312 Err(error) => Some(Err(error)),
313 Ok(Event::ContentBlockDelta { index, delta }) => match delta {
314 ContentDelta::TextDelta { .. } => None,
315 ContentDelta::InputJsonDelta { partial_json } => {
316 if index == tool_use_index {
317 Some(Ok(partial_json))
318 } else {
319 None
320 }
321 }
322 },
323 _ => None,
324 };
325
326 async move { result }
327 }))
328}
329
330#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
331#[serde(rename_all = "lowercase")]
332pub enum CacheControlType {
333 Ephemeral,
334}
335
336#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
337pub struct CacheControl {
338 #[serde(rename = "type")]
339 pub cache_type: CacheControlType,
340}
341
342#[derive(Debug, Serialize, Deserialize)]
343pub struct Message {
344 pub role: Role,
345 pub content: Vec<Content>,
346}
347
348#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
349#[serde(rename_all = "lowercase")]
350pub enum Role {
351 User,
352 Assistant,
353}
354
355#[derive(Debug, Serialize, Deserialize)]
356#[serde(tag = "type")]
357pub enum Content {
358 #[serde(rename = "text")]
359 Text {
360 text: String,
361 #[serde(skip_serializing_if = "Option::is_none")]
362 cache_control: Option<CacheControl>,
363 },
364 #[serde(rename = "image")]
365 Image {
366 source: ImageSource,
367 #[serde(skip_serializing_if = "Option::is_none")]
368 cache_control: Option<CacheControl>,
369 },
370 #[serde(rename = "tool_use")]
371 ToolUse {
372 id: String,
373 name: String,
374 input: serde_json::Value,
375 #[serde(skip_serializing_if = "Option::is_none")]
376 cache_control: Option<CacheControl>,
377 },
378 #[serde(rename = "tool_result")]
379 ToolResult {
380 tool_use_id: String,
381 content: String,
382 #[serde(skip_serializing_if = "Option::is_none")]
383 cache_control: Option<CacheControl>,
384 },
385}
386
387#[derive(Debug, Serialize, Deserialize)]
388pub struct ImageSource {
389 #[serde(rename = "type")]
390 pub source_type: String,
391 pub media_type: String,
392 pub data: String,
393}
394
395#[derive(Debug, Serialize, Deserialize)]
396pub struct Tool {
397 pub name: String,
398 pub description: String,
399 pub input_schema: serde_json::Value,
400}
401
402#[derive(Debug, Serialize, Deserialize)]
403#[serde(tag = "type", rename_all = "lowercase")]
404pub enum ToolChoice {
405 Auto,
406 Any,
407 Tool { name: String },
408}
409
410#[derive(Debug, Serialize, Deserialize)]
411pub struct Request {
412 pub model: String,
413 pub max_tokens: u32,
414 pub messages: Vec<Message>,
415 #[serde(default, skip_serializing_if = "Vec::is_empty")]
416 pub tools: Vec<Tool>,
417 #[serde(default, skip_serializing_if = "Option::is_none")]
418 pub tool_choice: Option<ToolChoice>,
419 #[serde(default, skip_serializing_if = "Option::is_none")]
420 pub system: Option<String>,
421 #[serde(default, skip_serializing_if = "Option::is_none")]
422 pub metadata: Option<Metadata>,
423 #[serde(default, skip_serializing_if = "Vec::is_empty")]
424 pub stop_sequences: Vec<String>,
425 #[serde(default, skip_serializing_if = "Option::is_none")]
426 pub temperature: Option<f32>,
427 #[serde(default, skip_serializing_if = "Option::is_none")]
428 pub top_k: Option<u32>,
429 #[serde(default, skip_serializing_if = "Option::is_none")]
430 pub top_p: Option<f32>,
431}
432
433#[derive(Debug, Serialize, Deserialize)]
434struct StreamingRequest {
435 #[serde(flatten)]
436 pub base: Request,
437 pub stream: bool,
438}
439
440#[derive(Debug, Serialize, Deserialize)]
441pub struct Metadata {
442 pub user_id: Option<String>,
443}
444
445#[derive(Debug, Serialize, Deserialize)]
446pub struct Usage {
447 #[serde(default, skip_serializing_if = "Option::is_none")]
448 pub input_tokens: Option<u32>,
449 #[serde(default, skip_serializing_if = "Option::is_none")]
450 pub output_tokens: Option<u32>,
451}
452
453#[derive(Debug, Serialize, Deserialize)]
454pub struct Response {
455 pub id: String,
456 #[serde(rename = "type")]
457 pub response_type: String,
458 pub role: Role,
459 pub content: Vec<Content>,
460 pub model: String,
461 #[serde(default, skip_serializing_if = "Option::is_none")]
462 pub stop_reason: Option<String>,
463 #[serde(default, skip_serializing_if = "Option::is_none")]
464 pub stop_sequence: Option<String>,
465 pub usage: Usage,
466}
467
468#[derive(Debug, Serialize, Deserialize)]
469#[serde(tag = "type")]
470pub enum Event {
471 #[serde(rename = "message_start")]
472 MessageStart { message: Response },
473 #[serde(rename = "content_block_start")]
474 ContentBlockStart {
475 index: usize,
476 content_block: Content,
477 },
478 #[serde(rename = "content_block_delta")]
479 ContentBlockDelta { index: usize, delta: ContentDelta },
480 #[serde(rename = "content_block_stop")]
481 ContentBlockStop { index: usize },
482 #[serde(rename = "message_delta")]
483 MessageDelta { delta: MessageDelta, usage: Usage },
484 #[serde(rename = "message_stop")]
485 MessageStop,
486 #[serde(rename = "ping")]
487 Ping,
488 #[serde(rename = "error")]
489 Error { error: ApiError },
490}
491
492#[derive(Debug, Serialize, Deserialize)]
493#[serde(tag = "type")]
494pub enum ContentDelta {
495 #[serde(rename = "text_delta")]
496 TextDelta { text: String },
497 #[serde(rename = "input_json_delta")]
498 InputJsonDelta { partial_json: String },
499}
500
501#[derive(Debug, Serialize, Deserialize)]
502pub struct MessageDelta {
503 pub stop_reason: Option<String>,
504 pub stop_sequence: Option<String>,
505}
506
507#[derive(Error, Debug)]
508pub enum AnthropicError {
509 #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
510 ApiError(ApiError),
511 #[error("{0}")]
512 Other(#[from] anyhow::Error),
513}
514
515#[derive(Debug, Serialize, Deserialize)]
516pub struct ApiError {
517 #[serde(rename = "type")]
518 pub error_type: String,
519 pub message: String,
520}
521
522/// An Anthropic API error code.
523/// https://docs.anthropic.com/en/api/errors#http-errors
524#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
525#[strum(serialize_all = "snake_case")]
526pub enum ApiErrorCode {
527 /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
528 InvalidRequestError,
529 /// 401 - `authentication_error`: There's an issue with your API key.
530 AuthenticationError,
531 /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
532 PermissionError,
533 /// 404 - `not_found_error`: The requested resource was not found.
534 NotFoundError,
535 /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
536 RequestTooLarge,
537 /// 429 - `rate_limit_error`: Your account has hit a rate limit.
538 RateLimitError,
539 /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
540 ApiError,
541 /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
542 OverloadedError,
543}
544
545impl ApiError {
546 pub fn code(&self) -> Option<ApiErrorCode> {
547 ApiErrorCode::from_str(&self.error_type).ok()
548 }
549
550 pub fn is_rate_limit_error(&self) -> bool {
551 match self.error_type.as_str() {
552 "rate_limit_error" => true,
553 _ => false,
554 }
555 }
556}