1mod supported_countries;
2
3use anyhow::{anyhow, Context, Result};
4use chrono::{DateTime, Utc};
5use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
6use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
7use isahc::config::Configurable;
8use isahc::http::{HeaderMap, HeaderValue};
9use serde::{Deserialize, Serialize};
10use std::time::Duration;
11use std::{pin::Pin, str::FromStr};
12use strum::{EnumIter, EnumString};
13use thiserror::Error;
14use util::ResultExt as _;
15
16pub use supported_countries::*;
17
/// Base URL of the public Anthropic API.
///
/// Presumably passed as the `api_url` argument of the request functions in this
/// module, which append paths such as `/v1/messages` to it.
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
19
20#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
21#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
22pub struct AnthropicModelCacheConfiguration {
23 pub min_total_token: usize,
24 pub should_speculate: bool,
25 pub max_cache_anchors: usize,
26}
27
28#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
29#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
30pub enum Model {
31 #[default]
32 #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-20240620")]
33 Claude3_5Sonnet,
34 #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
35 Claude3Opus,
36 #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
37 Claude3Sonnet,
38 #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
39 Claude3Haiku,
40 #[serde(rename = "custom")]
41 Custom {
42 name: String,
43 max_tokens: usize,
44 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
45 display_name: Option<String>,
46 /// Override this model with a different Anthropic model for tool calls.
47 tool_override: Option<String>,
48 /// Indicates whether this custom model supports caching.
49 cache_configuration: Option<AnthropicModelCacheConfiguration>,
50 max_output_tokens: Option<u32>,
51 },
52}
53
54impl Model {
55 pub fn from_id(id: &str) -> Result<Self> {
56 if id.starts_with("claude-3-5-sonnet") {
57 Ok(Self::Claude3_5Sonnet)
58 } else if id.starts_with("claude-3-opus") {
59 Ok(Self::Claude3Opus)
60 } else if id.starts_with("claude-3-sonnet") {
61 Ok(Self::Claude3Sonnet)
62 } else if id.starts_with("claude-3-haiku") {
63 Ok(Self::Claude3Haiku)
64 } else {
65 Err(anyhow!("invalid model id"))
66 }
67 }
68
69 pub fn id(&self) -> &str {
70 match self {
71 Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
72 Model::Claude3Opus => "claude-3-opus-20240229",
73 Model::Claude3Sonnet => "claude-3-sonnet-20240229",
74 Model::Claude3Haiku => "claude-3-haiku-20240307",
75 Self::Custom { name, .. } => name,
76 }
77 }
78
79 pub fn display_name(&self) -> &str {
80 match self {
81 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
82 Self::Claude3Opus => "Claude 3 Opus",
83 Self::Claude3Sonnet => "Claude 3 Sonnet",
84 Self::Claude3Haiku => "Claude 3 Haiku",
85 Self::Custom {
86 name, display_name, ..
87 } => display_name.as_ref().unwrap_or(name),
88 }
89 }
90
91 pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
92 match self {
93 Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
94 min_total_token: 2_048,
95 should_speculate: true,
96 max_cache_anchors: 4,
97 }),
98 Self::Custom {
99 cache_configuration,
100 ..
101 } => cache_configuration.clone(),
102 _ => None,
103 }
104 }
105
106 pub fn max_token_count(&self) -> usize {
107 match self {
108 Self::Claude3_5Sonnet
109 | Self::Claude3Opus
110 | Self::Claude3Sonnet
111 | Self::Claude3Haiku => 200_000,
112 Self::Custom { max_tokens, .. } => *max_tokens,
113 }
114 }
115
116 pub fn max_output_tokens(&self) -> u32 {
117 match self {
118 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
119 Self::Claude3_5Sonnet => 8_192,
120 Self::Custom {
121 max_output_tokens, ..
122 } => max_output_tokens.unwrap_or(4_096),
123 }
124 }
125
126 pub fn tool_model_id(&self) -> &str {
127 if let Self::Custom {
128 tool_override: Some(tool_override),
129 ..
130 } = self
131 {
132 tool_override
133 } else {
134 self.id()
135 }
136 }
137}
138
139pub async fn complete(
140 client: &dyn HttpClient,
141 api_url: &str,
142 api_key: &str,
143 request: Request,
144) -> Result<Response, AnthropicError> {
145 let uri = format!("{api_url}/v1/messages");
146 let request_builder = HttpRequest::builder()
147 .method(Method::POST)
148 .uri(uri)
149 .header("Anthropic-Version", "2023-06-01")
150 .header(
151 "Anthropic-Beta",
152 "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
153 )
154 .header("X-Api-Key", api_key)
155 .header("Content-Type", "application/json");
156
157 let serialized_request =
158 serde_json::to_string(&request).context("failed to serialize request")?;
159 let request = request_builder
160 .body(AsyncBody::from(serialized_request))
161 .context("failed to construct request body")?;
162
163 let mut response = client
164 .send(request)
165 .await
166 .context("failed to send request to Anthropic")?;
167 if response.status().is_success() {
168 let mut body = Vec::new();
169 response
170 .body_mut()
171 .read_to_end(&mut body)
172 .await
173 .context("failed to read response body")?;
174 let response_message: Response =
175 serde_json::from_slice(&body).context("failed to deserialize response body")?;
176 Ok(response_message)
177 } else {
178 let mut body = Vec::new();
179 response
180 .body_mut()
181 .read_to_end(&mut body)
182 .await
183 .context("failed to read response body")?;
184 let body_str =
185 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
186 Err(AnthropicError::Other(anyhow!(
187 "Failed to connect to API: {} {}",
188 response.status(),
189 body_str
190 )))
191 }
192}
193
194pub async fn stream_completion(
195 client: &dyn HttpClient,
196 api_url: &str,
197 api_key: &str,
198 request: Request,
199 low_speed_timeout: Option<Duration>,
200) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
201 stream_completion_with_rate_limit_info(client, api_url, api_key, request, low_speed_timeout)
202 .await
203 .map(|output| output.0)
204}
205
206/// https://docs.anthropic.com/en/api/rate-limits#response-headers
207#[derive(Debug)]
208pub struct RateLimitInfo {
209 pub requests_limit: usize,
210 pub requests_remaining: usize,
211 pub requests_reset: DateTime<Utc>,
212 pub tokens_limit: usize,
213 pub tokens_remaining: usize,
214 pub tokens_reset: DateTime<Utc>,
215}
216
217impl RateLimitInfo {
218 fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
219 let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
220 let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
221 let tokens_remaining =
222 get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
223 let requests_remaining =
224 get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
225 let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
226 let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
227 let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
228 let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
229
230 Ok(Self {
231 requests_limit,
232 tokens_limit,
233 requests_remaining,
234 tokens_remaining,
235 requests_reset,
236 tokens_reset,
237 })
238 }
239}
240
241fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
242 Ok(headers
243 .get(key)
244 .ok_or_else(|| anyhow!("missing header `{key}`"))?
245 .to_str()?)
246}
247
248pub async fn stream_completion_with_rate_limit_info(
249 client: &dyn HttpClient,
250 api_url: &str,
251 api_key: &str,
252 request: Request,
253 low_speed_timeout: Option<Duration>,
254) -> Result<
255 (
256 BoxStream<'static, Result<Event, AnthropicError>>,
257 Option<RateLimitInfo>,
258 ),
259 AnthropicError,
260> {
261 let request = StreamingRequest {
262 base: request,
263 stream: true,
264 };
265 let uri = format!("{api_url}/v1/messages");
266 let mut request_builder = HttpRequest::builder()
267 .method(Method::POST)
268 .uri(uri)
269 .header("Anthropic-Version", "2023-06-01")
270 .header(
271 "Anthropic-Beta",
272 "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
273 )
274 .header("X-Api-Key", api_key)
275 .header("Content-Type", "application/json");
276 if let Some(low_speed_timeout) = low_speed_timeout {
277 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
278 }
279 let serialized_request =
280 serde_json::to_string(&request).context("failed to serialize request")?;
281 let request = request_builder
282 .body(AsyncBody::from(serialized_request))
283 .context("failed to construct request body")?;
284
285 let mut response = client
286 .send(request)
287 .await
288 .context("failed to send request to Anthropic")?;
289 if response.status().is_success() {
290 let rate_limits = RateLimitInfo::from_headers(response.headers());
291 let reader = BufReader::new(response.into_body());
292 let stream = reader
293 .lines()
294 .filter_map(|line| async move {
295 match line {
296 Ok(line) => {
297 let line = line.strip_prefix("data: ")?;
298 match serde_json::from_str(line) {
299 Ok(response) => Some(Ok(response)),
300 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
301 }
302 }
303 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
304 }
305 })
306 .boxed();
307 Ok((stream, rate_limits.log_err()))
308 } else {
309 let mut body = Vec::new();
310 response
311 .body_mut()
312 .read_to_end(&mut body)
313 .await
314 .context("failed to read response body")?;
315
316 let body_str =
317 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
318
319 match serde_json::from_str::<Event>(body_str) {
320 Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
321 Ok(_) => Err(AnthropicError::Other(anyhow!(
322 "Unexpected success response while expecting an error: '{body_str}'",
323 ))),
324 Err(_) => Err(AnthropicError::Other(anyhow!(
325 "Failed to connect to API: {} {}",
326 response.status(),
327 body_str,
328 ))),
329 }
330 }
331}
332
333pub fn extract_text_from_events(
334 response: impl Stream<Item = Result<Event, AnthropicError>>,
335) -> impl Stream<Item = Result<String, AnthropicError>> {
336 response.filter_map(|response| async move {
337 match response {
338 Ok(response) => match response {
339 Event::ContentBlockStart { content_block, .. } => match content_block {
340 Content::Text { text, .. } => Some(Ok(text)),
341 _ => None,
342 },
343 Event::ContentBlockDelta { delta, .. } => match delta {
344 ContentDelta::TextDelta { text } => Some(Ok(text)),
345 _ => None,
346 },
347 Event::Error { error } => Some(Err(AnthropicError::ApiError(error))),
348 _ => None,
349 },
350 Err(error) => Some(Err(error)),
351 }
352 })
353}
354
355pub async fn extract_tool_args_from_events(
356 tool_name: String,
357 mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
358) -> Result<impl Send + Stream<Item = Result<String>>> {
359 let mut tool_use_index = None;
360 while let Some(event) = events.next().await {
361 if let Event::ContentBlockStart {
362 index,
363 content_block,
364 } = event?
365 {
366 if let Content::ToolUse { name, .. } = content_block {
367 if name == tool_name {
368 tool_use_index = Some(index);
369 break;
370 }
371 }
372 }
373 }
374
375 let Some(tool_use_index) = tool_use_index else {
376 return Err(anyhow!("tool not used"));
377 };
378
379 Ok(events.filter_map(move |event| {
380 let result = match event {
381 Err(error) => Some(Err(error)),
382 Ok(Event::ContentBlockDelta { index, delta }) => match delta {
383 ContentDelta::TextDelta { .. } => None,
384 ContentDelta::InputJsonDelta { partial_json } => {
385 if index == tool_use_index {
386 Some(Ok(partial_json))
387 } else {
388 None
389 }
390 }
391 },
392 _ => None,
393 };
394
395 async move { result }
396 }))
397}
398
399#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
400#[serde(rename_all = "lowercase")]
401pub enum CacheControlType {
402 Ephemeral,
403}
404
405#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
406pub struct CacheControl {
407 #[serde(rename = "type")]
408 pub cache_type: CacheControlType,
409}
410
411#[derive(Debug, Serialize, Deserialize)]
412pub struct Message {
413 pub role: Role,
414 pub content: Vec<Content>,
415}
416
417#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
418#[serde(rename_all = "lowercase")]
419pub enum Role {
420 User,
421 Assistant,
422}
423
424#[derive(Debug, Serialize, Deserialize)]
425#[serde(tag = "type")]
426pub enum Content {
427 #[serde(rename = "text")]
428 Text {
429 text: String,
430 #[serde(skip_serializing_if = "Option::is_none")]
431 cache_control: Option<CacheControl>,
432 },
433 #[serde(rename = "image")]
434 Image {
435 source: ImageSource,
436 #[serde(skip_serializing_if = "Option::is_none")]
437 cache_control: Option<CacheControl>,
438 },
439 #[serde(rename = "tool_use")]
440 ToolUse {
441 id: String,
442 name: String,
443 input: serde_json::Value,
444 #[serde(skip_serializing_if = "Option::is_none")]
445 cache_control: Option<CacheControl>,
446 },
447 #[serde(rename = "tool_result")]
448 ToolResult {
449 tool_use_id: String,
450 content: String,
451 #[serde(skip_serializing_if = "Option::is_none")]
452 cache_control: Option<CacheControl>,
453 },
454}
455
456#[derive(Debug, Serialize, Deserialize)]
457pub struct ImageSource {
458 #[serde(rename = "type")]
459 pub source_type: String,
460 pub media_type: String,
461 pub data: String,
462}
463
464#[derive(Debug, Serialize, Deserialize)]
465pub struct Tool {
466 pub name: String,
467 pub description: String,
468 pub input_schema: serde_json::Value,
469}
470
471#[derive(Debug, Serialize, Deserialize)]
472#[serde(tag = "type", rename_all = "lowercase")]
473pub enum ToolChoice {
474 Auto,
475 Any,
476 Tool { name: String },
477}
478
479#[derive(Debug, Serialize, Deserialize)]
480pub struct Request {
481 pub model: String,
482 pub max_tokens: u32,
483 pub messages: Vec<Message>,
484 #[serde(default, skip_serializing_if = "Vec::is_empty")]
485 pub tools: Vec<Tool>,
486 #[serde(default, skip_serializing_if = "Option::is_none")]
487 pub tool_choice: Option<ToolChoice>,
488 #[serde(default, skip_serializing_if = "Option::is_none")]
489 pub system: Option<String>,
490 #[serde(default, skip_serializing_if = "Option::is_none")]
491 pub metadata: Option<Metadata>,
492 #[serde(default, skip_serializing_if = "Vec::is_empty")]
493 pub stop_sequences: Vec<String>,
494 #[serde(default, skip_serializing_if = "Option::is_none")]
495 pub temperature: Option<f32>,
496 #[serde(default, skip_serializing_if = "Option::is_none")]
497 pub top_k: Option<u32>,
498 #[serde(default, skip_serializing_if = "Option::is_none")]
499 pub top_p: Option<f32>,
500}
501
502#[derive(Debug, Serialize, Deserialize)]
503struct StreamingRequest {
504 #[serde(flatten)]
505 pub base: Request,
506 pub stream: bool,
507}
508
509#[derive(Debug, Serialize, Deserialize)]
510pub struct Metadata {
511 pub user_id: Option<String>,
512}
513
514#[derive(Debug, Serialize, Deserialize)]
515pub struct Usage {
516 #[serde(default, skip_serializing_if = "Option::is_none")]
517 pub input_tokens: Option<u32>,
518 #[serde(default, skip_serializing_if = "Option::is_none")]
519 pub output_tokens: Option<u32>,
520}
521
522#[derive(Debug, Serialize, Deserialize)]
523pub struct Response {
524 pub id: String,
525 #[serde(rename = "type")]
526 pub response_type: String,
527 pub role: Role,
528 pub content: Vec<Content>,
529 pub model: String,
530 #[serde(default, skip_serializing_if = "Option::is_none")]
531 pub stop_reason: Option<String>,
532 #[serde(default, skip_serializing_if = "Option::is_none")]
533 pub stop_sequence: Option<String>,
534 pub usage: Usage,
535}
536
537#[derive(Debug, Serialize, Deserialize)]
538#[serde(tag = "type")]
539pub enum Event {
540 #[serde(rename = "message_start")]
541 MessageStart { message: Response },
542 #[serde(rename = "content_block_start")]
543 ContentBlockStart {
544 index: usize,
545 content_block: Content,
546 },
547 #[serde(rename = "content_block_delta")]
548 ContentBlockDelta { index: usize, delta: ContentDelta },
549 #[serde(rename = "content_block_stop")]
550 ContentBlockStop { index: usize },
551 #[serde(rename = "message_delta")]
552 MessageDelta { delta: MessageDelta, usage: Usage },
553 #[serde(rename = "message_stop")]
554 MessageStop,
555 #[serde(rename = "ping")]
556 Ping,
557 #[serde(rename = "error")]
558 Error { error: ApiError },
559}
560
561#[derive(Debug, Serialize, Deserialize)]
562#[serde(tag = "type")]
563pub enum ContentDelta {
564 #[serde(rename = "text_delta")]
565 TextDelta { text: String },
566 #[serde(rename = "input_json_delta")]
567 InputJsonDelta { partial_json: String },
568}
569
570#[derive(Debug, Serialize, Deserialize)]
571pub struct MessageDelta {
572 pub stop_reason: Option<String>,
573 pub stop_sequence: Option<String>,
574}
575
576#[derive(Error, Debug)]
577pub enum AnthropicError {
578 #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
579 ApiError(ApiError),
580 #[error("{0}")]
581 Other(#[from] anyhow::Error),
582}
583
584#[derive(Debug, Serialize, Deserialize)]
585pub struct ApiError {
586 #[serde(rename = "type")]
587 pub error_type: String,
588 pub message: String,
589}
590
591/// An Anthropic API error code.
592/// https://docs.anthropic.com/en/api/errors#http-errors
593#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
594#[strum(serialize_all = "snake_case")]
595pub enum ApiErrorCode {
596 /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
597 InvalidRequestError,
598 /// 401 - `authentication_error`: There's an issue with your API key.
599 AuthenticationError,
600 /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
601 PermissionError,
602 /// 404 - `not_found_error`: The requested resource was not found.
603 NotFoundError,
604 /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
605 RequestTooLarge,
606 /// 429 - `rate_limit_error`: Your account has hit a rate limit.
607 RateLimitError,
608 /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
609 ApiError,
610 /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
611 OverloadedError,
612}
613
614impl ApiError {
615 pub fn code(&self) -> Option<ApiErrorCode> {
616 ApiErrorCode::from_str(&self.error_type).ok()
617 }
618
619 pub fn is_rate_limit_error(&self) -> bool {
620 match self.error_type.as_str() {
621 "rate_limit_error" => true,
622 _ => false,
623 }
624 }
625}