1mod supported_countries;
2
3use std::time::Duration;
4use std::{pin::Pin, str::FromStr};
5
6use anyhow::{anyhow, Context, Result};
7use chrono::{DateTime, Utc};
8use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
10use isahc::config::Configurable;
11use isahc::http::{HeaderMap, HeaderValue};
12use serde::{Deserialize, Serialize};
13use strum::{EnumIter, EnumString};
14use thiserror::Error;
15use util::ResultExt as _;
16
17pub use supported_countries::*;
18
/// Base URL of Anthropic's hosted API.
///
/// The request functions in this module take the base URL as a parameter and append
/// `/v1/messages` to it.
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
20
/// Prompt-caching parameters for a model, as returned by `Model::cache_configuration`.
///
/// The built-in models that report a configuration use `min_total_token: 2_048`,
/// `should_speculate: true` and `max_cache_anchors: 4`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct AnthropicModelCacheConfiguration {
    // NOTE(review): presumably the minimum prompt size (in tokens) before caching is
    // attempted; the consumer lives outside this file — confirm there.
    pub min_total_token: usize,
    // NOTE(review): meaning not visible in this file; confirm with the caching caller.
    pub should_speculate: bool,
    // NOTE(review): presumably the maximum number of `cache_control` markers placed per
    // request — confirm with the caller that inserts them.
    pub max_cache_anchors: usize,
}
28
/// An Anthropic model: one of the built-in Claude 3 family models, or a user-configured
/// custom model.
///
/// Built-in variants serialize under a short family name and also accept their dated
/// snapshot id when deserializing.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    /// Claude 3.5 Sonnet; the default model.
    #[default]
    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-20240620")]
    Claude3_5Sonnet,
    /// Claude 3 Opus.
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
    Claude3Opus,
    /// Claude 3 Sonnet.
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
    Claude3Sonnet,
    /// Claude 3 Haiku.
    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
    Claude3Haiku,
    /// A model described entirely by user configuration.
    #[serde(rename = "custom")]
    Custom {
        /// Model identifier; returned verbatim by `Model::id`.
        name: String,
        /// Token limit reported by `Model::max_token_count`.
        max_tokens: usize,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Override this model with a different Anthropic model for tool calls.
        tool_override: Option<String>,
        /// Prompt-caching parameters for this model; `None` when it has no cache configuration.
        cache_configuration: Option<AnthropicModelCacheConfiguration>,
        /// Output token limit; `Model::max_output_tokens` falls back to 4096 when unset.
        max_output_tokens: Option<u32>,
    },
}
54
55impl Model {
56 pub fn from_id(id: &str) -> Result<Self> {
57 if id.starts_with("claude-3-5-sonnet") {
58 Ok(Self::Claude3_5Sonnet)
59 } else if id.starts_with("claude-3-opus") {
60 Ok(Self::Claude3Opus)
61 } else if id.starts_with("claude-3-sonnet") {
62 Ok(Self::Claude3Sonnet)
63 } else if id.starts_with("claude-3-haiku") {
64 Ok(Self::Claude3Haiku)
65 } else {
66 Err(anyhow!("invalid model id"))
67 }
68 }
69
70 pub fn id(&self) -> &str {
71 match self {
72 Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
73 Model::Claude3Opus => "claude-3-opus-20240229",
74 Model::Claude3Sonnet => "claude-3-sonnet-20240229",
75 Model::Claude3Haiku => "claude-3-haiku-20240307",
76 Self::Custom { name, .. } => name,
77 }
78 }
79
80 pub fn display_name(&self) -> &str {
81 match self {
82 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
83 Self::Claude3Opus => "Claude 3 Opus",
84 Self::Claude3Sonnet => "Claude 3 Sonnet",
85 Self::Claude3Haiku => "Claude 3 Haiku",
86 Self::Custom {
87 name, display_name, ..
88 } => display_name.as_ref().unwrap_or(name),
89 }
90 }
91
92 pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
93 match self {
94 Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
95 min_total_token: 2_048,
96 should_speculate: true,
97 max_cache_anchors: 4,
98 }),
99 Self::Custom {
100 cache_configuration,
101 ..
102 } => cache_configuration.clone(),
103 _ => None,
104 }
105 }
106
107 pub fn max_token_count(&self) -> usize {
108 match self {
109 Self::Claude3_5Sonnet
110 | Self::Claude3Opus
111 | Self::Claude3Sonnet
112 | Self::Claude3Haiku => 200_000,
113 Self::Custom { max_tokens, .. } => *max_tokens,
114 }
115 }
116
117 pub fn max_output_tokens(&self) -> u32 {
118 match self {
119 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
120 Self::Claude3_5Sonnet => 8_192,
121 Self::Custom {
122 max_output_tokens, ..
123 } => max_output_tokens.unwrap_or(4_096),
124 }
125 }
126
127 pub fn tool_model_id(&self) -> &str {
128 if let Self::Custom {
129 tool_override: Some(tool_override),
130 ..
131 } = self
132 {
133 tool_override
134 } else {
135 self.id()
136 }
137 }
138}
139
140pub async fn complete(
141 client: &dyn HttpClient,
142 api_url: &str,
143 api_key: &str,
144 request: Request,
145) -> Result<Response, AnthropicError> {
146 let uri = format!("{api_url}/v1/messages");
147 let request_builder = HttpRequest::builder()
148 .method(Method::POST)
149 .uri(uri)
150 .header("Anthropic-Version", "2023-06-01")
151 .header(
152 "Anthropic-Beta",
153 "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
154 )
155 .header("X-Api-Key", api_key)
156 .header("Content-Type", "application/json");
157
158 let serialized_request =
159 serde_json::to_string(&request).context("failed to serialize request")?;
160 let request = request_builder
161 .body(AsyncBody::from(serialized_request))
162 .context("failed to construct request body")?;
163
164 let mut response = client
165 .send(request)
166 .await
167 .context("failed to send request to Anthropic")?;
168 if response.status().is_success() {
169 let mut body = Vec::new();
170 response
171 .body_mut()
172 .read_to_end(&mut body)
173 .await
174 .context("failed to read response body")?;
175 let response_message: Response =
176 serde_json::from_slice(&body).context("failed to deserialize response body")?;
177 Ok(response_message)
178 } else {
179 let mut body = Vec::new();
180 response
181 .body_mut()
182 .read_to_end(&mut body)
183 .await
184 .context("failed to read response body")?;
185 let body_str =
186 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
187 Err(AnthropicError::Other(anyhow!(
188 "Failed to connect to API: {} {}",
189 response.status(),
190 body_str
191 )))
192 }
193}
194
195pub async fn stream_completion(
196 client: &dyn HttpClient,
197 api_url: &str,
198 api_key: &str,
199 request: Request,
200 low_speed_timeout: Option<Duration>,
201) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
202 stream_completion_with_rate_limit_info(client, api_url, api_key, request, low_speed_timeout)
203 .await
204 .map(|output| output.0)
205}
206
/// Rate-limit state parsed from the `anthropic-ratelimit-*` headers of a response.
///
/// https://docs.anthropic.com/en/api/rate-limits#response-headers
#[derive(Debug)]
pub struct RateLimitInfo {
    /// Value of `anthropic-ratelimit-requests-limit`.
    pub requests_limit: usize,
    /// Value of `anthropic-ratelimit-requests-remaining`.
    pub requests_remaining: usize,
    /// `anthropic-ratelimit-requests-reset`, parsed as RFC 3339 and converted to UTC.
    pub requests_reset: DateTime<Utc>,
    /// Value of `anthropic-ratelimit-tokens-limit`.
    pub tokens_limit: usize,
    /// Value of `anthropic-ratelimit-tokens-remaining`.
    pub tokens_remaining: usize,
    /// `anthropic-ratelimit-tokens-reset`, parsed as RFC 3339 and converted to UTC.
    pub tokens_reset: DateTime<Utc>,
}
217
218impl RateLimitInfo {
219 fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
220 let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
221 let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
222 let tokens_remaining =
223 get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
224 let requests_remaining =
225 get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
226 let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
227 let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
228 let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
229 let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
230
231 Ok(Self {
232 requests_limit,
233 tokens_limit,
234 requests_remaining,
235 tokens_remaining,
236 requests_reset,
237 tokens_reset,
238 })
239 }
240}
241
242fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
243 Ok(headers
244 .get(key)
245 .ok_or_else(|| anyhow!("missing header `{key}`"))?
246 .to_str()?)
247}
248
249pub async fn stream_completion_with_rate_limit_info(
250 client: &dyn HttpClient,
251 api_url: &str,
252 api_key: &str,
253 request: Request,
254 low_speed_timeout: Option<Duration>,
255) -> Result<
256 (
257 BoxStream<'static, Result<Event, AnthropicError>>,
258 Option<RateLimitInfo>,
259 ),
260 AnthropicError,
261> {
262 let request = StreamingRequest {
263 base: request,
264 stream: true,
265 };
266 let uri = format!("{api_url}/v1/messages");
267 let mut request_builder = HttpRequest::builder()
268 .method(Method::POST)
269 .uri(uri)
270 .header("Anthropic-Version", "2023-06-01")
271 .header(
272 "Anthropic-Beta",
273 "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
274 )
275 .header("X-Api-Key", api_key)
276 .header("Content-Type", "application/json");
277 if let Some(low_speed_timeout) = low_speed_timeout {
278 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
279 }
280 let serialized_request =
281 serde_json::to_string(&request).context("failed to serialize request")?;
282 let request = request_builder
283 .body(AsyncBody::from(serialized_request))
284 .context("failed to construct request body")?;
285
286 let mut response = client
287 .send(request)
288 .await
289 .context("failed to send request to Anthropic")?;
290 if response.status().is_success() {
291 let rate_limits = RateLimitInfo::from_headers(response.headers());
292 let reader = BufReader::new(response.into_body());
293 let stream = reader
294 .lines()
295 .filter_map(|line| async move {
296 match line {
297 Ok(line) => {
298 let line = line.strip_prefix("data: ")?;
299 match serde_json::from_str(line) {
300 Ok(response) => Some(Ok(response)),
301 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
302 }
303 }
304 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
305 }
306 })
307 .boxed();
308 Ok((stream, rate_limits.log_err()))
309 } else {
310 let mut body = Vec::new();
311 response
312 .body_mut()
313 .read_to_end(&mut body)
314 .await
315 .context("failed to read response body")?;
316
317 let body_str =
318 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
319
320 match serde_json::from_str::<Event>(body_str) {
321 Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
322 Ok(_) => Err(AnthropicError::Other(anyhow!(
323 "Unexpected success response while expecting an error: '{body_str}'",
324 ))),
325 Err(_) => Err(AnthropicError::Other(anyhow!(
326 "Failed to connect to API: {} {}",
327 response.status(),
328 body_str,
329 ))),
330 }
331 }
332}
333
334pub async fn extract_tool_args_from_events(
335 tool_name: String,
336 mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
337) -> Result<impl Send + Stream<Item = Result<String>>> {
338 let mut tool_use_index = None;
339 while let Some(event) = events.next().await {
340 if let Event::ContentBlockStart {
341 index,
342 content_block: ResponseContent::ToolUse { name, .. },
343 } = event?
344 {
345 if name == tool_name {
346 tool_use_index = Some(index);
347 break;
348 }
349 }
350 }
351
352 let Some(tool_use_index) = tool_use_index else {
353 return Err(anyhow!("tool not used"));
354 };
355
356 Ok(events.filter_map(move |event| {
357 let result = match event {
358 Err(error) => Some(Err(error)),
359 Ok(Event::ContentBlockDelta { index, delta }) => match delta {
360 ContentDelta::TextDelta { .. } => None,
361 ContentDelta::InputJsonDelta { partial_json } => {
362 if index == tool_use_index {
363 Some(Ok(partial_json))
364 } else {
365 None
366 }
367 }
368 },
369 _ => None,
370 };
371
372 async move { result }
373 }))
374}
375
/// Cache-control mode attached to request content, serialized in lowercase (`"ephemeral"`).
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
#[serde(rename_all = "lowercase")]
pub enum CacheControlType {
    Ephemeral,
}
381
/// `cache_control` marker for a request content block, e.g. `{"type": "ephemeral"}`.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct CacheControl {
    /// Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub cache_type: CacheControlType,
}
387
/// One conversation turn in `Request::messages`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    pub role: Role,
    /// Ordered content blocks that make up this turn.
    pub content: Vec<RequestContent>,
}
393
/// Author of a message, serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
400
/// A content block inside a request [`Message`], tagged by its JSON `type` field.
///
/// Every variant's `cache_control` is omitted from the serialized JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum RequestContent {
    /// Plain text (`"type": "text"`).
    #[serde(rename = "text")]
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image (`"type": "image"`).
    #[serde(rename = "image")]
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation previously made by the assistant (`"type": "tool_use"`).
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        /// Arbitrary JSON arguments passed to the tool.
        input: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// The result of running a tool (`"type": "tool_result"`).
    #[serde(rename = "tool_result")]
    ToolResult {
        // NOTE(review): presumably the `id` of the `ToolUse` block this answers — per the
        // Messages API; not enforced here.
        tool_use_id: String,
        /// Always serialized, including when `false`.
        is_error: bool,
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
433
/// A content block returned by the API, tagged by its JSON `type` field.
///
/// Appears in `Response::content` and in `Event::ContentBlockStart`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseContent {
    /// Generated text (`"type": "text"`).
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool call requested by the model (`"type": "tool_use"`).
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
446
/// Source of an image content block.
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
    /// Serialized under the JSON key `type`.
    // NOTE(review): presumably `"base64"` per the Messages API — confirm with callers.
    #[serde(rename = "type")]
    pub source_type: String,
    /// MIME type of the image, e.g. `image/png`.
    pub media_type: String,
    // NOTE(review): presumably the encoded image bytes matching `source_type`.
    pub data: String,
}
454
/// A tool definition offered to the model in `Request::tools`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    pub name: String,
    pub description: String,
    // NOTE(review): presumably a JSON Schema object describing the tool's arguments;
    // passed through as-is.
    pub input_schema: serde_json::Value,
}
461
/// How the model may choose tools.
///
/// Serialized as `{"type": "auto"}`, `{"type": "any"}` or `{"type": "tool", "name": ...}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    /// Force use of the tool called `name`.
    Tool { name: String },
}
469
/// Body of a `/v1/messages` request.
///
/// Empty vectors and `None` options are omitted from the serialized JSON, and default to
/// empty/`None` when absent on deserialization.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model identifier, e.g. the value of `Model::id`.
    pub model: String,
    /// Maximum number of tokens to generate.
    pub max_tokens: u32,
    pub messages: Vec<Message>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// System prompt.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
492
/// A [`Request`] plus the `stream` flag, serialized as one flat JSON object.
///
/// `stream_completion_with_rate_limit_info` always sends it with `stream: true`.
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
    /// Flattened so its fields sit alongside `stream` at the top level.
    #[serde(flatten)]
    pub base: Request,
    pub stream: bool,
}
499
/// Optional request metadata.
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    /// Not skipped when `None`, so it serializes as `null` in that case.
    pub user_id: Option<String>,
}
504
/// Token counts reported by the API.
///
/// Each field is `None` when absent from the payload, and is omitted when serialized as `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
}
512
/// A complete message response.
///
/// Returned by `complete`, and embedded in `Event::MessageStart` when streaming.
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    pub id: String,
    /// Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub response_type: String,
    pub role: Role,
    pub content: Vec<ResponseContent>,
    pub model: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}
527
/// A server-sent event from a streaming `/v1/messages` request, tagged by its JSON `type`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Event {
    /// Carries the message metadata as a [`Response`].
    #[serde(rename = "message_start")]
    MessageStart { message: Response },
    /// Opens content block `index`.
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: ResponseContent,
    },
    /// Incremental text or tool-input JSON for content block `index`.
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: ContentDelta },
    /// Closes content block `index`.
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    /// Message-level update (stop reason / sequence) with token usage.
    #[serde(rename = "message_delta")]
    MessageDelta { delta: MessageDelta, usage: Usage },
    #[serde(rename = "message_stop")]
    MessageStop,
    #[serde(rename = "ping")]
    Ping,
    /// An API error.
    ///
    /// `stream_completion_with_rate_limit_info` also decodes non-success HTTP bodies
    /// as this variant.
    #[serde(rename = "error")]
    Error { error: ApiError },
}
551
/// Payload of `Event::ContentBlockDelta`, tagged by its JSON `type`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentDelta {
    /// A chunk of generated text.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    /// A fragment of a tool call's JSON input. The fragment is not valid JSON on its own;
    /// `extract_tool_args_from_events` forwards these fragments unparsed.
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}
560
/// Message-level fields updated by `Event::MessageDelta`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
566
/// Error type returned by this module's request functions.
#[derive(Error, Debug)]
pub enum AnthropicError {
    /// A structured error payload decoded from an API response body.
    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
    ApiError(ApiError),
    /// Any other failure (serialization, transport, decoding, unrecognized error bodies).
    /// Converted automatically from `anyhow::Error`, so `?` works on `anyhow` results.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
574
/// Structured error payload returned by the API.
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
    /// Error category (JSON key `type`), e.g. `rate_limit_error`. See `ApiError::code`.
    #[serde(rename = "type")]
    pub error_type: String,
    /// Human-readable description from the server.
    pub message: String,
}
581
/// An Anthropic API error code.
/// https://docs.anthropic.com/en/api/errors#http-errors
///
/// Parsed (case-sensitively, in `snake_case`) from `ApiError::error_type` via `ApiError::code`;
/// strings that match no variant yield `None` there.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
    InvalidRequestError,
    /// 401 - `authentication_error`: There's an issue with your API key.
    AuthenticationError,
    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
    PermissionError,
    /// 404 - `not_found_error`: The requested resource was not found.
    NotFoundError,
    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
    RequestTooLarge,
    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
    RateLimitError,
    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
    ApiError,
    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
    OverloadedError,
}
604
605impl ApiError {
606 pub fn code(&self) -> Option<ApiErrorCode> {
607 ApiErrorCode::from_str(&self.error_type).ok()
608 }
609
610 pub fn is_rate_limit_error(&self) -> bool {
611 matches!(self.error_type.as_str(), "rate_limit_error")
612 }
613}