1mod supported_countries;
2
3use std::time::Duration;
4use std::{pin::Pin, str::FromStr};
5
6use anyhow::{anyhow, Context, Result};
7use chrono::{DateTime, Utc};
8use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
10use isahc::config::Configurable;
11use isahc::http::{HeaderMap, HeaderValue};
12use serde::{Deserialize, Serialize};
13use strum::{EnumIter, EnumString};
14use thiserror::Error;
15use util::ResultExt as _;
16
17pub use supported_countries::*;
18
/// Base URL of the public Anthropic API, usable as the `api_url` argument of the
/// request functions in this module. It has no trailing slash because those
/// functions append paths such as `/v1/messages` directly.
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
20
/// Prompt-caching parameters for a model.
///
/// Produced by [`Model::cache_configuration`]: hard-coded for the built-in models
/// that support caching, and taken from settings for [`Model::Custom`].
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct AnthropicModelCacheConfiguration {
    /// Presumably the minimum total prompt size (in tokens) before caching is
    /// attempted; built-in models use 2_048. NOTE(review): confirm against the consumer.
    pub min_total_token: usize,
    /// Whether cache anchors may be placed speculatively. NOTE(review): semantics
    /// are defined by the consumer of this struct, not in this file — confirm.
    pub should_speculate: bool,
    /// Presumably the maximum number of `cache_control` anchors per request;
    /// built-in models use 4.
    pub max_cache_anchors: usize,
}
28
/// An Anthropic model that can be selected for completions.
///
/// Built-in variants serialize under short, undated names (e.g. `"claude-3-opus"`)
/// and also accept the dated API ids as aliases when deserializing.
/// [`Model::Custom`] describes a user-configured model.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    /// Claude 3.5 Sonnet (the default model).
    #[default]
    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-20240620")]
    Claude3_5Sonnet,
    /// Claude 3 Opus.
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
    Claude3Opus,
    /// Claude 3 Sonnet.
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
    Claude3Sonnet,
    /// Claude 3 Haiku.
    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
    Claude3Haiku,
    /// A user-configured model; unset optional fields fall back to the defaults
    /// applied in the `impl Model` accessors.
    #[serde(rename = "custom")]
    Custom {
        /// Model id sent to the API (returned by [`Model::id`]).
        name: String,
        /// Value returned by [`Model::max_token_count`] (presumably the context window size).
        max_tokens: usize,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Override this model with a different Anthropic model for tool calls.
        tool_override: Option<String>,
        /// Prompt-caching settings for this custom model; `None` means no caching configuration.
        cache_configuration: Option<AnthropicModelCacheConfiguration>,
        /// Output token limit; [`Model::max_output_tokens`] falls back to 4_096 when `None`.
        max_output_tokens: Option<u32>,
        /// Sampling temperature; [`Model::default_temperature`] falls back to 1.0 when `None`.
        default_temperature: Option<f32>,
    },
}
55
56impl Model {
57 pub fn from_id(id: &str) -> Result<Self> {
58 if id.starts_with("claude-3-5-sonnet") {
59 Ok(Self::Claude3_5Sonnet)
60 } else if id.starts_with("claude-3-opus") {
61 Ok(Self::Claude3Opus)
62 } else if id.starts_with("claude-3-sonnet") {
63 Ok(Self::Claude3Sonnet)
64 } else if id.starts_with("claude-3-haiku") {
65 Ok(Self::Claude3Haiku)
66 } else {
67 Err(anyhow!("invalid model id"))
68 }
69 }
70
71 pub fn id(&self) -> &str {
72 match self {
73 Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
74 Model::Claude3Opus => "claude-3-opus-20240229",
75 Model::Claude3Sonnet => "claude-3-sonnet-20240229",
76 Model::Claude3Haiku => "claude-3-haiku-20240307",
77 Self::Custom { name, .. } => name,
78 }
79 }
80
81 pub fn display_name(&self) -> &str {
82 match self {
83 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
84 Self::Claude3Opus => "Claude 3 Opus",
85 Self::Claude3Sonnet => "Claude 3 Sonnet",
86 Self::Claude3Haiku => "Claude 3 Haiku",
87 Self::Custom {
88 name, display_name, ..
89 } => display_name.as_ref().unwrap_or(name),
90 }
91 }
92
93 pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
94 match self {
95 Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
96 min_total_token: 2_048,
97 should_speculate: true,
98 max_cache_anchors: 4,
99 }),
100 Self::Custom {
101 cache_configuration,
102 ..
103 } => cache_configuration.clone(),
104 _ => None,
105 }
106 }
107
108 pub fn max_token_count(&self) -> usize {
109 match self {
110 Self::Claude3_5Sonnet
111 | Self::Claude3Opus
112 | Self::Claude3Sonnet
113 | Self::Claude3Haiku => 200_000,
114 Self::Custom { max_tokens, .. } => *max_tokens,
115 }
116 }
117
118 pub fn max_output_tokens(&self) -> u32 {
119 match self {
120 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
121 Self::Claude3_5Sonnet => 8_192,
122 Self::Custom {
123 max_output_tokens, ..
124 } => max_output_tokens.unwrap_or(4_096),
125 }
126 }
127
128 pub fn default_temperature(&self) -> f32 {
129 match self {
130 Self::Claude3_5Sonnet
131 | Self::Claude3Opus
132 | Self::Claude3Sonnet
133 | Self::Claude3Haiku => 1.0,
134 Self::Custom {
135 default_temperature,
136 ..
137 } => default_temperature.unwrap_or(1.0),
138 }
139 }
140
141 pub fn tool_model_id(&self) -> &str {
142 if let Self::Custom {
143 tool_override: Some(tool_override),
144 ..
145 } = self
146 {
147 tool_override
148 } else {
149 self.id()
150 }
151 }
152}
153
154pub async fn complete(
155 client: &dyn HttpClient,
156 api_url: &str,
157 api_key: &str,
158 request: Request,
159) -> Result<Response, AnthropicError> {
160 let uri = format!("{api_url}/v1/messages");
161 let request_builder = HttpRequest::builder()
162 .method(Method::POST)
163 .uri(uri)
164 .header("Anthropic-Version", "2023-06-01")
165 .header(
166 "Anthropic-Beta",
167 "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
168 )
169 .header("X-Api-Key", api_key)
170 .header("Content-Type", "application/json");
171
172 let serialized_request =
173 serde_json::to_string(&request).context("failed to serialize request")?;
174 let request = request_builder
175 .body(AsyncBody::from(serialized_request))
176 .context("failed to construct request body")?;
177
178 let mut response = client
179 .send(request)
180 .await
181 .context("failed to send request to Anthropic")?;
182 if response.status().is_success() {
183 let mut body = Vec::new();
184 response
185 .body_mut()
186 .read_to_end(&mut body)
187 .await
188 .context("failed to read response body")?;
189 let response_message: Response =
190 serde_json::from_slice(&body).context("failed to deserialize response body")?;
191 Ok(response_message)
192 } else {
193 let mut body = Vec::new();
194 response
195 .body_mut()
196 .read_to_end(&mut body)
197 .await
198 .context("failed to read response body")?;
199 let body_str =
200 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
201 Err(AnthropicError::Other(anyhow!(
202 "Failed to connect to API: {} {}",
203 response.status(),
204 body_str
205 )))
206 }
207}
208
209pub async fn stream_completion(
210 client: &dyn HttpClient,
211 api_url: &str,
212 api_key: &str,
213 request: Request,
214 low_speed_timeout: Option<Duration>,
215) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
216 stream_completion_with_rate_limit_info(client, api_url, api_key, request, low_speed_timeout)
217 .await
218 .map(|output| output.0)
219}
220
221/// https://docs.anthropic.com/en/api/rate-limits#response-headers
222#[derive(Debug)]
223pub struct RateLimitInfo {
224 pub requests_limit: usize,
225 pub requests_remaining: usize,
226 pub requests_reset: DateTime<Utc>,
227 pub tokens_limit: usize,
228 pub tokens_remaining: usize,
229 pub tokens_reset: DateTime<Utc>,
230}
231
232impl RateLimitInfo {
233 fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
234 let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
235 let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
236 let tokens_remaining =
237 get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
238 let requests_remaining =
239 get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
240 let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
241 let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
242 let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
243 let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
244
245 Ok(Self {
246 requests_limit,
247 tokens_limit,
248 requests_remaining,
249 tokens_remaining,
250 requests_reset,
251 tokens_reset,
252 })
253 }
254}
255
256fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
257 Ok(headers
258 .get(key)
259 .ok_or_else(|| anyhow!("missing header `{key}`"))?
260 .to_str()?)
261}
262
263pub async fn stream_completion_with_rate_limit_info(
264 client: &dyn HttpClient,
265 api_url: &str,
266 api_key: &str,
267 request: Request,
268 low_speed_timeout: Option<Duration>,
269) -> Result<
270 (
271 BoxStream<'static, Result<Event, AnthropicError>>,
272 Option<RateLimitInfo>,
273 ),
274 AnthropicError,
275> {
276 let request = StreamingRequest {
277 base: request,
278 stream: true,
279 };
280 let uri = format!("{api_url}/v1/messages");
281 let mut request_builder = HttpRequest::builder()
282 .method(Method::POST)
283 .uri(uri)
284 .header("Anthropic-Version", "2023-06-01")
285 .header(
286 "Anthropic-Beta",
287 "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
288 )
289 .header("X-Api-Key", api_key)
290 .header("Content-Type", "application/json");
291 if let Some(low_speed_timeout) = low_speed_timeout {
292 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
293 }
294 let serialized_request =
295 serde_json::to_string(&request).context("failed to serialize request")?;
296 let request = request_builder
297 .body(AsyncBody::from(serialized_request))
298 .context("failed to construct request body")?;
299
300 let mut response = client
301 .send(request)
302 .await
303 .context("failed to send request to Anthropic")?;
304 if response.status().is_success() {
305 let rate_limits = RateLimitInfo::from_headers(response.headers());
306 let reader = BufReader::new(response.into_body());
307 let stream = reader
308 .lines()
309 .filter_map(|line| async move {
310 match line {
311 Ok(line) => {
312 let line = line.strip_prefix("data: ")?;
313 match serde_json::from_str(line) {
314 Ok(response) => Some(Ok(response)),
315 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
316 }
317 }
318 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
319 }
320 })
321 .boxed();
322 Ok((stream, rate_limits.log_err()))
323 } else {
324 let mut body = Vec::new();
325 response
326 .body_mut()
327 .read_to_end(&mut body)
328 .await
329 .context("failed to read response body")?;
330
331 let body_str =
332 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
333
334 match serde_json::from_str::<Event>(body_str) {
335 Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
336 Ok(_) => Err(AnthropicError::Other(anyhow!(
337 "Unexpected success response while expecting an error: '{body_str}'",
338 ))),
339 Err(_) => Err(AnthropicError::Other(anyhow!(
340 "Failed to connect to API: {} {}",
341 response.status(),
342 body_str,
343 ))),
344 }
345 }
346}
347
348pub async fn extract_tool_args_from_events(
349 tool_name: String,
350 mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
351) -> Result<impl Send + Stream<Item = Result<String>>> {
352 let mut tool_use_index = None;
353 while let Some(event) = events.next().await {
354 if let Event::ContentBlockStart {
355 index,
356 content_block: ResponseContent::ToolUse { name, .. },
357 } = event?
358 {
359 if name == tool_name {
360 tool_use_index = Some(index);
361 break;
362 }
363 }
364 }
365
366 let Some(tool_use_index) = tool_use_index else {
367 return Err(anyhow!("tool not used"));
368 };
369
370 Ok(events.filter_map(move |event| {
371 let result = match event {
372 Err(error) => Some(Err(error)),
373 Ok(Event::ContentBlockDelta { index, delta }) => match delta {
374 ContentDelta::TextDelta { .. } => None,
375 ContentDelta::InputJsonDelta { partial_json } => {
376 if index == tool_use_index {
377 Some(Ok(partial_json))
378 } else {
379 None
380 }
381 }
382 },
383 _ => None,
384 };
385
386 async move { result }
387 }))
388}
389
/// Kind of prompt cache requested for a content block; serialized in lowercase
/// (`"ephemeral"`).
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
#[serde(rename_all = "lowercase")]
pub enum CacheControlType {
    Ephemeral,
}

/// The `cache_control` marker that can be attached to a [`RequestContent`] block.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct CacheControl {
    /// Serialized as the JSON `type` field.
    #[serde(rename = "type")]
    pub cache_type: CacheControlType,
}
401
/// One conversation turn in a [`Request`]: who authored it and its content blocks.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    pub role: Role,
    pub content: Vec<RequestContent>,
}

/// Author of a message; serialized in lowercase (`"user"`, `"assistant"`).
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
414
/// A content block sent to the API inside a [`Message`].
///
/// Serialized as an internally tagged object whose `type` field is `text`,
/// `image`, `tool_use`, or `tool_result`. `cache_control` is omitted from the
/// JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum RequestContent {
    /// Plain text.
    #[serde(rename = "text")]
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image described by an [`ImageSource`].
    #[serde(rename = "image")]
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation: call `id`, tool `name`, and its JSON `input`.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// The outcome of a tool call; `tool_use_id` presumably refers to the `id` of
    /// the corresponding `ToolUse` block.
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        is_error: bool,
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}

/// A content block returned by the API, either in a [`Response`] or in an
/// [`Event::ContentBlockStart`]. Tagged by its `type` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}

/// Where an image's data comes from.
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
    /// Serialized as the JSON `type` field.
    #[serde(rename = "type")]
    pub source_type: String,
    pub media_type: String,
    /// Presumably the encoded image payload (e.g. base64) — confirm with callers.
    pub data: String,
}

/// A tool the model may call. `input_schema` is presumably a JSON Schema object.
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    pub name: String,
    pub description: String,
    pub input_schema: serde_json::Value,
}

/// Tool-selection policy, tagged by a lowercase `type` field: `auto`, `any`, or
/// `tool` with a specific `name`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    Tool { name: String },
}
483
/// Body of a Messages API request.
///
/// Optional fields use `default` when missing on deserialization and are
/// omitted from the JSON when empty or `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id, presumably obtained from [`Model::id`] or [`Model::tool_model_id`].
    pub model: String,
    /// Output token limit (same type as [`Model::max_output_tokens`]).
    pub max_tokens: u32,
    pub messages: Vec<Message>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// System prompt.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}

/// A [`Request`] with a top-level `stream` flag. `base`'s fields are flattened
/// into the same JSON object.
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
    #[serde(flatten)]
    pub base: Request,
    pub stream: bool,
}

/// Request metadata. `user_id` has no skip attribute, so `None` serializes as `null`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    pub user_id: Option<String>,
}
518
/// Token usage counts. Missing fields default to `None`, and `None` fields are
/// omitted when serializing.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
}

/// A complete Messages API response. It is also the payload of
/// [`Event::MessageStart`] when streaming.
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    pub id: String,
    /// Serialized as the JSON `type` field.
    #[serde(rename = "type")]
    pub response_type: String,
    pub role: Role,
    pub content: Vec<ResponseContent>,
    pub model: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}

/// An event in a streaming response, tagged by its `type` field.
///
/// Also used to parse non-success response bodies, which arrive as an `error` event.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Event {
    #[serde(rename = "message_start")]
    MessageStart { message: Response },
    /// A new content block begins at position `index`.
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: ResponseContent,
    },
    /// Incremental content for the block at `index`.
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: ContentDelta },
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    /// Message-level updates (stop reason/sequence) plus usage.
    #[serde(rename = "message_delta")]
    MessageDelta { delta: MessageDelta, usage: Usage },
    #[serde(rename = "message_stop")]
    MessageStop,
    #[serde(rename = "ping")]
    Ping,
    /// An API error, either mid-stream or as a non-success response body.
    #[serde(rename = "error")]
    Error { error: ApiError },
}

/// Incremental content for a block: text, or a fragment of a tool call's JSON
/// input (the fragments [`extract_tool_args_from_events`] forwards).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentDelta {
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}

/// Message-level fields carried by [`Event::MessageDelta`].
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
580
/// Error returned by the request functions in this module.
#[derive(Error, Debug)]
pub enum AnthropicError {
    /// A structured error reported by the API.
    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
    ApiError(ApiError),
    /// Any other failure (transport, (de)serialization, unrecognized responses).
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}

/// The `error` object of an API error event or response.
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
    /// Error type string such as `rate_limit_error`; see [`ApiErrorCode`].
    #[serde(rename = "type")]
    pub error_type: String,
    pub message: String,
}
595
/// An Anthropic API error code.
/// https://docs.anthropic.com/en/api/errors#http-errors
///
/// Parsed from [`ApiError::error_type`] by [`ApiError::code`] using the
/// snake_case variant names; strings not listed here do not parse.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
    InvalidRequestError,
    /// 401 - `authentication_error`: There's an issue with your API key.
    AuthenticationError,
    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
    PermissionError,
    /// 404 - `not_found_error`: The requested resource was not found.
    NotFoundError,
    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
    RequestTooLarge,
    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
    RateLimitError,
    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
    ApiError,
    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
    OverloadedError,
}
618
619impl ApiError {
620 pub fn code(&self) -> Option<ApiErrorCode> {
621 ApiErrorCode::from_str(&self.error_type).ok()
622 }
623
624 pub fn is_rate_limit_error(&self) -> bool {
625 matches!(self.error_type.as_str(), "rate_limit_error")
626 }
627}