1mod supported_countries;
2
3use std::{pin::Pin, str::FromStr};
4
5use anyhow::{anyhow, Context as _, Result};
6use chrono::{DateTime, Utc};
7use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
8use http_client::http::{HeaderMap, HeaderValue};
9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
10use serde::{Deserialize, Serialize};
11use strum::{EnumIter, EnumString};
12use thiserror::Error;
13use util::ResultExt as _;
14
15pub use supported_countries::*;
16
17pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
18
// Prompt-caching parameters associated with a `Model`.
//
// `Model::cache_configuration` returns a fixed instance of this for the
// built-in models it lists and the user-supplied value for `Model::Custom`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct AnthropicModelCacheConfiguration {
    // NOTE(review): presumably the minimum prompt size (in tokens) before
    // caching is attempted — confirm against the consumer of this struct.
    pub min_total_token: usize,
    // NOTE(review): looks like a flag enabling speculative cache writes; it is
    // not read by any code visible here — confirm.
    pub should_speculate: bool,
    // NOTE(review): presumably the maximum number of `cache_control` markers
    // placed in a single request — confirm.
    pub max_cache_anchors: usize,
}
26
// The Anthropic models known to this crate, plus a user-defined `Custom`
// variant.
//
// Each built-in variant (de)serializes under its short name (the `rename`) and
// additionally accepts the `-latest` alias when deserializing.
// `Claude3_5Sonnet` is the `Default`.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[default]
    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")]
    Claude3_5Sonnet,
    #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
    Claude3_5Haiku,
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
    Claude3Opus,
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
    Claude3Sonnet,
    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-latest")]
    Claude3Haiku,
    // A model described entirely by user-provided settings.
    #[serde(rename = "custom")]
    Custom {
        // Returned verbatim by `Model::id`.
        name: String,
        // Returned by `Model::max_token_count`.
        max_tokens: usize,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Override this model with a different Anthropic model for tool calls.
        tool_override: Option<String>,
        /// Indicates whether this custom model supports caching.
        cache_configuration: Option<AnthropicModelCacheConfiguration>,
        // `Model::max_output_tokens` falls back to 4_096 when this is `None`.
        max_output_tokens: Option<u32>,
        // `Model::default_temperature` falls back to 1.0 when this is `None`.
        default_temperature: Option<f32>,
        // Appended after `Model::DEFAULT_BETA_HEADERS` by `Model::beta_headers`;
        // defaults to an empty list when the key is absent.
        #[serde(default)]
        extra_beta_headers: Vec<String>,
    },
}
57
58impl Model {
59 pub fn from_id(id: &str) -> Result<Self> {
60 if id.starts_with("claude-3-5-sonnet") {
61 Ok(Self::Claude3_5Sonnet)
62 } else if id.starts_with("claude-3-5-haiku") {
63 Ok(Self::Claude3_5Haiku)
64 } else if id.starts_with("claude-3-opus") {
65 Ok(Self::Claude3Opus)
66 } else if id.starts_with("claude-3-sonnet") {
67 Ok(Self::Claude3Sonnet)
68 } else if id.starts_with("claude-3-haiku") {
69 Ok(Self::Claude3Haiku)
70 } else {
71 Err(anyhow!("invalid model id"))
72 }
73 }
74
75 pub fn id(&self) -> &str {
76 match self {
77 Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest",
78 Model::Claude3_5Haiku => "claude-3-5-haiku-latest",
79 Model::Claude3Opus => "claude-3-opus-latest",
80 Model::Claude3Sonnet => "claude-3-sonnet-20240229",
81 Model::Claude3Haiku => "claude-3-haiku-20240307",
82 Self::Custom { name, .. } => name,
83 }
84 }
85
86 pub fn display_name(&self) -> &str {
87 match self {
88 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
89 Self::Claude3_5Haiku => "Claude 3.5 Haiku",
90 Self::Claude3Opus => "Claude 3 Opus",
91 Self::Claude3Sonnet => "Claude 3 Sonnet",
92 Self::Claude3Haiku => "Claude 3 Haiku",
93 Self::Custom {
94 name, display_name, ..
95 } => display_name.as_ref().unwrap_or(name),
96 }
97 }
98
99 pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
100 match self {
101 Self::Claude3_5Sonnet | Self::Claude3_5Haiku | Self::Claude3Haiku => {
102 Some(AnthropicModelCacheConfiguration {
103 min_total_token: 2_048,
104 should_speculate: true,
105 max_cache_anchors: 4,
106 })
107 }
108 Self::Custom {
109 cache_configuration,
110 ..
111 } => cache_configuration.clone(),
112 _ => None,
113 }
114 }
115
116 pub fn max_token_count(&self) -> usize {
117 match self {
118 Self::Claude3_5Sonnet
119 | Self::Claude3_5Haiku
120 | Self::Claude3Opus
121 | Self::Claude3Sonnet
122 | Self::Claude3Haiku => 200_000,
123 Self::Custom { max_tokens, .. } => *max_tokens,
124 }
125 }
126
127 pub fn max_output_tokens(&self) -> u32 {
128 match self {
129 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
130 Self::Claude3_5Sonnet | Self::Claude3_5Haiku => 8_192,
131 Self::Custom {
132 max_output_tokens, ..
133 } => max_output_tokens.unwrap_or(4_096),
134 }
135 }
136
137 pub fn default_temperature(&self) -> f32 {
138 match self {
139 Self::Claude3_5Sonnet
140 | Self::Claude3_5Haiku
141 | Self::Claude3Opus
142 | Self::Claude3Sonnet
143 | Self::Claude3Haiku => 1.0,
144 Self::Custom {
145 default_temperature,
146 ..
147 } => default_temperature.unwrap_or(1.0),
148 }
149 }
150
151 pub const DEFAULT_BETA_HEADERS: &[&str] = &["prompt-caching-2024-07-31"];
152
153 pub fn beta_headers(&self) -> String {
154 let mut headers = Self::DEFAULT_BETA_HEADERS
155 .into_iter()
156 .map(|header| header.to_string())
157 .collect::<Vec<_>>();
158
159 if let Self::Custom {
160 extra_beta_headers, ..
161 } = self
162 {
163 headers.extend(
164 extra_beta_headers
165 .iter()
166 .filter(|header| !header.trim().is_empty())
167 .cloned(),
168 );
169 }
170
171 headers.join(",")
172 }
173
174 pub fn tool_model_id(&self) -> &str {
175 if let Self::Custom {
176 tool_override: Some(tool_override),
177 ..
178 } = self
179 {
180 tool_override
181 } else {
182 self.id()
183 }
184 }
185}
186
187pub async fn complete(
188 client: &dyn HttpClient,
189 api_url: &str,
190 api_key: &str,
191 request: Request,
192) -> Result<Response, AnthropicError> {
193 let uri = format!("{api_url}/v1/messages");
194 let beta_headers = Model::from_id(&request.model)
195 .map(|model| model.beta_headers())
196 .unwrap_or_else(|_err| Model::DEFAULT_BETA_HEADERS.join(","));
197 let request_builder = HttpRequest::builder()
198 .method(Method::POST)
199 .uri(uri)
200 .header("Anthropic-Version", "2023-06-01")
201 .header("Anthropic-Beta", beta_headers)
202 .header("X-Api-Key", api_key)
203 .header("Content-Type", "application/json");
204
205 let serialized_request =
206 serde_json::to_string(&request).context("failed to serialize request")?;
207 let request = request_builder
208 .body(AsyncBody::from(serialized_request))
209 .context("failed to construct request body")?;
210
211 let mut response = client
212 .send(request)
213 .await
214 .context("failed to send request to Anthropic")?;
215 if response.status().is_success() {
216 let mut body = Vec::new();
217 response
218 .body_mut()
219 .read_to_end(&mut body)
220 .await
221 .context("failed to read response body")?;
222 let response_message: Response =
223 serde_json::from_slice(&body).context("failed to deserialize response body")?;
224 Ok(response_message)
225 } else {
226 let mut body = Vec::new();
227 response
228 .body_mut()
229 .read_to_end(&mut body)
230 .await
231 .context("failed to read response body")?;
232 let body_str =
233 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
234 Err(AnthropicError::Other(anyhow!(
235 "Failed to connect to API: {} {}",
236 response.status(),
237 body_str
238 )))
239 }
240}
241
242pub async fn stream_completion(
243 client: &dyn HttpClient,
244 api_url: &str,
245 api_key: &str,
246 request: Request,
247) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
248 stream_completion_with_rate_limit_info(client, api_url, api_key, request)
249 .await
250 .map(|output| output.0)
251}
252
/// Rate-limit state reported by the API in response headers.
///
/// Each field is parsed from the `anthropic-ratelimit-*` header named in its
/// comment; the two reset fields are parsed as RFC 3339 timestamps and
/// converted to UTC.
///
/// <https://docs.anthropic.com/en/api/rate-limits#response-headers>
#[derive(Debug)]
pub struct RateLimitInfo {
    /// From `anthropic-ratelimit-requests-limit`.
    pub requests_limit: usize,
    /// From `anthropic-ratelimit-requests-remaining`.
    pub requests_remaining: usize,
    /// From `anthropic-ratelimit-requests-reset`.
    pub requests_reset: DateTime<Utc>,
    /// From `anthropic-ratelimit-tokens-limit`.
    pub tokens_limit: usize,
    /// From `anthropic-ratelimit-tokens-remaining`.
    pub tokens_remaining: usize,
    /// From `anthropic-ratelimit-tokens-reset`.
    pub tokens_reset: DateTime<Utc>,
}
263
264impl RateLimitInfo {
265 fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
266 let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
267 let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
268 let tokens_remaining =
269 get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
270 let requests_remaining =
271 get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
272 let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
273 let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
274 let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
275 let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
276
277 Ok(Self {
278 requests_limit,
279 tokens_limit,
280 requests_remaining,
281 tokens_remaining,
282 requests_reset,
283 tokens_reset,
284 })
285 }
286}
287
288fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
289 Ok(headers
290 .get(key)
291 .ok_or_else(|| anyhow!("missing header `{key}`"))?
292 .to_str()?)
293}
294
295pub async fn stream_completion_with_rate_limit_info(
296 client: &dyn HttpClient,
297 api_url: &str,
298 api_key: &str,
299 request: Request,
300) -> Result<
301 (
302 BoxStream<'static, Result<Event, AnthropicError>>,
303 Option<RateLimitInfo>,
304 ),
305 AnthropicError,
306> {
307 let request = StreamingRequest {
308 base: request,
309 stream: true,
310 };
311 let uri = format!("{api_url}/v1/messages");
312 let beta_headers = Model::from_id(&request.base.model)
313 .map(|model| model.beta_headers())
314 .unwrap_or_else(|_err| Model::DEFAULT_BETA_HEADERS.join(","));
315 let request_builder = HttpRequest::builder()
316 .method(Method::POST)
317 .uri(uri)
318 .header("Anthropic-Version", "2023-06-01")
319 .header("Anthropic-Beta", beta_headers)
320 .header("X-Api-Key", api_key)
321 .header("Content-Type", "application/json");
322 let serialized_request =
323 serde_json::to_string(&request).context("failed to serialize request")?;
324 let request = request_builder
325 .body(AsyncBody::from(serialized_request))
326 .context("failed to construct request body")?;
327
328 let mut response = client
329 .send(request)
330 .await
331 .context("failed to send request to Anthropic")?;
332 if response.status().is_success() {
333 let rate_limits = RateLimitInfo::from_headers(response.headers());
334 let reader = BufReader::new(response.into_body());
335 let stream = reader
336 .lines()
337 .filter_map(|line| async move {
338 match line {
339 Ok(line) => {
340 let line = line.strip_prefix("data: ")?;
341 match serde_json::from_str(line) {
342 Ok(response) => Some(Ok(response)),
343 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
344 }
345 }
346 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
347 }
348 })
349 .boxed();
350 Ok((stream, rate_limits.log_err()))
351 } else {
352 let mut body = Vec::new();
353 response
354 .body_mut()
355 .read_to_end(&mut body)
356 .await
357 .context("failed to read response body")?;
358
359 let body_str =
360 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
361
362 match serde_json::from_str::<Event>(body_str) {
363 Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
364 Ok(_) => Err(AnthropicError::Other(anyhow!(
365 "Unexpected success response while expecting an error: '{body_str}'",
366 ))),
367 Err(_) => Err(AnthropicError::Other(anyhow!(
368 "Failed to connect to API: {} {}",
369 response.status(),
370 body_str,
371 ))),
372 }
373 }
374}
375
376pub async fn extract_tool_args_from_events(
377 tool_name: String,
378 mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
379) -> Result<impl Send + Stream<Item = Result<String>>> {
380 let mut tool_use_index = None;
381 while let Some(event) = events.next().await {
382 if let Event::ContentBlockStart {
383 index,
384 content_block: ResponseContent::ToolUse { name, .. },
385 } = event?
386 {
387 if name == tool_name {
388 tool_use_index = Some(index);
389 break;
390 }
391 }
392 }
393
394 let Some(tool_use_index) = tool_use_index else {
395 return Err(anyhow!("tool not used"));
396 };
397
398 Ok(events.filter_map(move |event| {
399 let result = match event {
400 Err(error) => Some(Err(error)),
401 Ok(Event::ContentBlockDelta { index, delta }) => match delta {
402 ContentDelta::TextDelta { .. } => None,
403 ContentDelta::InputJsonDelta { partial_json } => {
404 if index == tool_use_index {
405 Some(Ok(partial_json))
406 } else {
407 None
408 }
409 }
410 },
411 _ => None,
412 };
413
414 async move { result }
415 }))
416}
417
418#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
419#[serde(rename_all = "lowercase")]
420pub enum CacheControlType {
421 Ephemeral,
422}
423
/// Cache-control marker that can be attached to a [`RequestContent`] block.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct CacheControl {
    /// (De)serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub cache_type: CacheControlType,
}
429
/// A single entry of [`Request::messages`]: a role plus its content blocks.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Who authored this message.
    pub role: Role,
    /// The content blocks making up the message.
    pub content: Vec<RequestContent>,
}
435
436#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
437#[serde(rename_all = "lowercase")]
438pub enum Role {
439 User,
440 Assistant,
441}
442
443#[derive(Debug, Serialize, Deserialize)]
444#[serde(tag = "type")]
445pub enum RequestContent {
446 #[serde(rename = "text")]
447 Text {
448 text: String,
449 #[serde(skip_serializing_if = "Option::is_none")]
450 cache_control: Option<CacheControl>,
451 },
452 #[serde(rename = "image")]
453 Image {
454 source: ImageSource,
455 #[serde(skip_serializing_if = "Option::is_none")]
456 cache_control: Option<CacheControl>,
457 },
458 #[serde(rename = "tool_use")]
459 ToolUse {
460 id: String,
461 name: String,
462 input: serde_json::Value,
463 #[serde(skip_serializing_if = "Option::is_none")]
464 cache_control: Option<CacheControl>,
465 },
466 #[serde(rename = "tool_result")]
467 ToolResult {
468 tool_use_id: String,
469 is_error: bool,
470 content: String,
471 #[serde(skip_serializing_if = "Option::is_none")]
472 cache_control: Option<CacheControl>,
473 },
474}
475
476#[derive(Debug, Serialize, Deserialize)]
477#[serde(tag = "type")]
478pub enum ResponseContent {
479 #[serde(rename = "text")]
480 Text { text: String },
481 #[serde(rename = "tool_use")]
482 ToolUse {
483 id: String,
484 name: String,
485 input: serde_json::Value,
486 },
487}
488
/// Describes the image carried by [`RequestContent::Image`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
    /// (De)serialized under the JSON key `type`.
    // NOTE(review): presumably an encoding name such as "base64" — confirm against the API docs.
    #[serde(rename = "type")]
    pub source_type: String,
    // NOTE(review): presumably a MIME type such as "image/png" — confirm.
    pub media_type: String,
    // NOTE(review): presumably the encoded image payload — confirm.
    pub data: String,
}
496
/// A tool definition listed in [`Request::tools`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    /// Name of the tool.
    pub name: String,
    /// Description of the tool.
    pub description: String,
    /// Arbitrary JSON describing the tool's input.
    // NOTE(review): presumably a JSON Schema object — confirm against the API docs.
    pub input_schema: serde_json::Value,
}
503
504#[derive(Debug, Serialize, Deserialize)]
505#[serde(tag = "type", rename_all = "lowercase")]
506pub enum ToolChoice {
507 Auto,
508 Any,
509 Tool { name: String },
510}
511
/// Body of a request to the `/v1/messages` endpoint.
///
/// All fields other than `model`, `max_tokens` and `messages` default when
/// absent on deserialization and are omitted from the JSON when empty or
/// `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id; also used to pick the `Anthropic-Beta` header value.
    pub model: String,
    pub max_tokens: u32,
    pub messages: Vec<Message>,
    /// Omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    /// Omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    /// Omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
534
/// A [`Request`] with an added `stream` flag.
///
/// `base` is flattened, so its fields sit at the top level of the JSON next
/// to `stream`. `stream_completion_with_rate_limit_info` builds this with
/// `stream: true`.
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
    #[serde(flatten)]
    pub base: Request,
    pub stream: bool,
}
541
/// Value of [`Request::metadata`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    /// There is no `skip_serializing_if` here, so `None` is serialized as an
    /// explicit `null` rather than being omitted.
    pub user_id: Option<String>,
}
546
/// Token counters reported in a [`Response`] and in `message_delta` events.
///
/// Every counter is optional: an absent key deserializes to `None`, and a
/// `None` value is omitted on serialization.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_creation_input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
}
558
/// A message returned by the API: the result of [`complete`] and the payload
/// of [`Event::MessageStart`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    pub id: String,
    /// (De)serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub response_type: String,
    pub role: Role,
    pub content: Vec<ResponseContent>,
    pub model: String,
    /// Defaults to `None` when absent; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    /// Defaults to `None` when absent; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}
573
574#[derive(Debug, Serialize, Deserialize)]
575#[serde(tag = "type")]
576pub enum Event {
577 #[serde(rename = "message_start")]
578 MessageStart { message: Response },
579 #[serde(rename = "content_block_start")]
580 ContentBlockStart {
581 index: usize,
582 content_block: ResponseContent,
583 },
584 #[serde(rename = "content_block_delta")]
585 ContentBlockDelta { index: usize, delta: ContentDelta },
586 #[serde(rename = "content_block_stop")]
587 ContentBlockStop { index: usize },
588 #[serde(rename = "message_delta")]
589 MessageDelta { delta: MessageDelta, usage: Usage },
590 #[serde(rename = "message_stop")]
591 MessageStop,
592 #[serde(rename = "ping")]
593 Ping,
594 #[serde(rename = "error")]
595 Error { error: ApiError },
596}
597
598#[derive(Debug, Serialize, Deserialize)]
599#[serde(tag = "type")]
600pub enum ContentDelta {
601 #[serde(rename = "text_delta")]
602 TextDelta { text: String },
603 #[serde(rename = "input_json_delta")]
604 InputJsonDelta { partial_json: String },
605}
606
/// Payload of [`Event::MessageDelta`].
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
612
/// Errors returned by this module's request functions.
#[derive(Error, Debug)]
pub enum AnthropicError {
    /// A structured error object returned by the API; its display text
    /// includes the error's type and message.
    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
    ApiError(ApiError),
    /// Any other failure. Because of `#[from]`, `?` on an `anyhow::Error`
    /// converts into this variant.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
620
/// The error object carried by [`Event::Error`] and
/// [`AnthropicError::ApiError`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
    /// (De)serialized under the JSON key `type`. `ApiError::code` tries to
    /// parse this string into an [`ApiErrorCode`].
    #[serde(rename = "type")]
    pub error_type: String,
    /// Error message text.
    pub message: String,
}
627
/// An Anthropic API error code.
///
/// Parsed from its snake_case name (for example `rate_limit_error`) through
/// the `FromStr` implementation derived by `EnumString`.
///
/// <https://docs.anthropic.com/en/api/errors#http-errors>
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
    InvalidRequestError,
    /// 401 - `authentication_error`: There's an issue with your API key.
    AuthenticationError,
    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
    PermissionError,
    /// 404 - `not_found_error`: The requested resource was not found.
    NotFoundError,
    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
    RequestTooLarge,
    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
    RateLimitError,
    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
    ApiError,
    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
    OverloadedError,
}
650
651impl ApiError {
652 pub fn code(&self) -> Option<ApiErrorCode> {
653 ApiErrorCode::from_str(&self.error_type).ok()
654 }
655
656 pub fn is_rate_limit_error(&self) -> bool {
657 matches!(self.error_type.as_str(), "rate_limit_error")
658 }
659}