1mod supported_countries;
2
3use std::{pin::Pin, str::FromStr};
4
5use anyhow::{anyhow, Context as _, Result};
6use chrono::{DateTime, Utc};
7use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
8use http_client::http::{HeaderMap, HeaderValue};
9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
10use serde::{Deserialize, Serialize};
11use strum::{EnumIter, EnumString};
12use thiserror::Error;
13use util::ResultExt as _;
14
15pub use supported_countries::*;
16
/// Base URL of Anthropic's public API.
///
/// The request functions in this module append `/v1/messages` to whatever `api_url` they are
/// given; this constant is presumably the default value callers pass for it.
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
18
/// Prompt-caching parameters for a model, as returned by `Model::cache_configuration`.
// NOTE(review): the code that interprets these fields is not in this file; the per-field
// notes below are inferred from names and built-in values — confirm against the consumer.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct AnthropicModelCacheConfiguration {
    // Presumably the minimum prompt size, in tokens, before caching is attempted
    // (built-in models use 2_048) — TODO confirm.
    pub min_total_token: usize,
    // Presumably whether cache anchors may be placed speculatively — TODO confirm.
    pub should_speculate: bool,
    // Presumably the maximum number of cache breakpoints per request
    // (built-in models use 4) — TODO confirm.
    pub max_cache_anchors: usize,
}
26
/// A Claude model selectable in settings: one of the built-in models, or a user-defined
/// [`Model::Custom`] entry.
///
/// Built-in variants serialize under a short name (e.g. `"claude-3-5-sonnet"`) and also accept
/// the corresponding `-latest` alias when deserializing. The default is Claude 3.5 Sonnet.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[default]
    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")]
    Claude3_5Sonnet,
    #[serde(rename = "claude-3-7-sonnet", alias = "claude-3-7-sonnet-latest")]
    Claude3_7Sonnet,
    #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
    Claude3_5Haiku,
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
    Claude3Opus,
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
    Claude3Sonnet,
    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-latest")]
    Claude3Haiku,
    #[serde(rename = "custom")]
    Custom {
        /// The model id sent to the API (returned by `Model::id`).
        name: String,
        /// Maximum total token count, reported by `Model::max_token_count`
        /// (built-in models use 200_000).
        max_tokens: usize,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Override this model with a different Anthropic model for tool calls.
        tool_override: Option<String>,
        /// Prompt-caching parameters for this custom model; `None` means
        /// `Model::cache_configuration` reports no caching support.
        cache_configuration: Option<AnthropicModelCacheConfiguration>,
        /// Maximum output tokens; `Model::max_output_tokens` falls back to 4_096 when unset.
        max_output_tokens: Option<u32>,
        /// Default sampling temperature; `Model::default_temperature` falls back to 1.0.
        default_temperature: Option<f32>,
        /// Extra `Anthropic-Beta` values added to the defaults by `Model::beta_headers`;
        /// blank entries are ignored.
        #[serde(default)]
        extra_beta_headers: Vec<String>,
    },
}
59
60impl Model {
61 pub fn from_id(id: &str) -> Result<Self> {
62 if id.starts_with("claude-3-5-sonnet") {
63 Ok(Self::Claude3_5Sonnet)
64 } else if id.starts_with("claude-3-7-sonnet") {
65 Ok(Self::Claude3_7Sonnet)
66 } else if id.starts_with("claude-3-5-haiku") {
67 Ok(Self::Claude3_5Haiku)
68 } else if id.starts_with("claude-3-opus") {
69 Ok(Self::Claude3Opus)
70 } else if id.starts_with("claude-3-sonnet") {
71 Ok(Self::Claude3Sonnet)
72 } else if id.starts_with("claude-3-haiku") {
73 Ok(Self::Claude3Haiku)
74 } else {
75 Err(anyhow!("invalid model id"))
76 }
77 }
78
79 pub fn id(&self) -> &str {
80 match self {
81 Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest",
82 Model::Claude3_7Sonnet => "claude-3-7-sonnet-latest",
83 Model::Claude3_5Haiku => "claude-3-5-haiku-latest",
84 Model::Claude3Opus => "claude-3-opus-latest",
85 Model::Claude3Sonnet => "claude-3-sonnet-20240229",
86 Model::Claude3Haiku => "claude-3-haiku-20240307",
87 Self::Custom { name, .. } => name,
88 }
89 }
90
91 pub fn display_name(&self) -> &str {
92 match self {
93 Self::Claude3_7Sonnet => "Claude 3.7 Sonnet",
94 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
95 Self::Claude3_5Haiku => "Claude 3.5 Haiku",
96 Self::Claude3Opus => "Claude 3 Opus",
97 Self::Claude3Sonnet => "Claude 3 Sonnet",
98 Self::Claude3Haiku => "Claude 3 Haiku",
99 Self::Custom {
100 name, display_name, ..
101 } => display_name.as_ref().unwrap_or(name),
102 }
103 }
104
105 pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
106 match self {
107 Self::Claude3_5Sonnet
108 | Self::Claude3_5Haiku
109 | Self::Claude3_7Sonnet
110 | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
111 min_total_token: 2_048,
112 should_speculate: true,
113 max_cache_anchors: 4,
114 }),
115 Self::Custom {
116 cache_configuration,
117 ..
118 } => cache_configuration.clone(),
119 _ => None,
120 }
121 }
122
123 pub fn max_token_count(&self) -> usize {
124 match self {
125 Self::Claude3_5Sonnet
126 | Self::Claude3_5Haiku
127 | Self::Claude3_7Sonnet
128 | Self::Claude3Opus
129 | Self::Claude3Sonnet
130 | Self::Claude3Haiku => 200_000,
131 Self::Custom { max_tokens, .. } => *max_tokens,
132 }
133 }
134
135 pub fn max_output_tokens(&self) -> u32 {
136 match self {
137 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
138 Self::Claude3_5Sonnet | Self::Claude3_7Sonnet | Self::Claude3_5Haiku => 8_192,
139 Self::Custom {
140 max_output_tokens, ..
141 } => max_output_tokens.unwrap_or(4_096),
142 }
143 }
144
145 pub fn default_temperature(&self) -> f32 {
146 match self {
147 Self::Claude3_5Sonnet
148 | Self::Claude3_7Sonnet
149 | Self::Claude3_5Haiku
150 | Self::Claude3Opus
151 | Self::Claude3Sonnet
152 | Self::Claude3Haiku => 1.0,
153 Self::Custom {
154 default_temperature,
155 ..
156 } => default_temperature.unwrap_or(1.0),
157 }
158 }
159
160 pub const DEFAULT_BETA_HEADERS: &[&str] = &["prompt-caching-2024-07-31"];
161
162 pub fn beta_headers(&self) -> String {
163 let mut headers = Self::DEFAULT_BETA_HEADERS
164 .into_iter()
165 .map(|header| header.to_string())
166 .collect::<Vec<_>>();
167
168 if let Self::Custom {
169 extra_beta_headers, ..
170 } = self
171 {
172 headers.extend(
173 extra_beta_headers
174 .iter()
175 .filter(|header| !header.trim().is_empty())
176 .cloned(),
177 );
178 }
179
180 headers.join(",")
181 }
182
183 pub fn tool_model_id(&self) -> &str {
184 if let Self::Custom {
185 tool_override: Some(tool_override),
186 ..
187 } = self
188 {
189 tool_override
190 } else {
191 self.id()
192 }
193 }
194}
195
196pub async fn complete(
197 client: &dyn HttpClient,
198 api_url: &str,
199 api_key: &str,
200 request: Request,
201) -> Result<Response, AnthropicError> {
202 let uri = format!("{api_url}/v1/messages");
203 let beta_headers = Model::from_id(&request.model)
204 .map(|model| model.beta_headers())
205 .unwrap_or_else(|_err| Model::DEFAULT_BETA_HEADERS.join(","));
206 let request_builder = HttpRequest::builder()
207 .method(Method::POST)
208 .uri(uri)
209 .header("Anthropic-Version", "2023-06-01")
210 .header("Anthropic-Beta", beta_headers)
211 .header("X-Api-Key", api_key)
212 .header("Content-Type", "application/json");
213
214 let serialized_request =
215 serde_json::to_string(&request).context("failed to serialize request")?;
216 let request = request_builder
217 .body(AsyncBody::from(serialized_request))
218 .context("failed to construct request body")?;
219
220 let mut response = client
221 .send(request)
222 .await
223 .context("failed to send request to Anthropic")?;
224 if response.status().is_success() {
225 let mut body = Vec::new();
226 response
227 .body_mut()
228 .read_to_end(&mut body)
229 .await
230 .context("failed to read response body")?;
231 let response_message: Response =
232 serde_json::from_slice(&body).context("failed to deserialize response body")?;
233 Ok(response_message)
234 } else {
235 let mut body = Vec::new();
236 response
237 .body_mut()
238 .read_to_end(&mut body)
239 .await
240 .context("failed to read response body")?;
241 let body_str =
242 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
243 Err(AnthropicError::Other(anyhow!(
244 "Failed to connect to API: {} {}",
245 response.status(),
246 body_str
247 )))
248 }
249}
250
251pub async fn stream_completion(
252 client: &dyn HttpClient,
253 api_url: &str,
254 api_key: &str,
255 request: Request,
256) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
257 stream_completion_with_rate_limit_info(client, api_url, api_key, request)
258 .await
259 .map(|output| output.0)
260}
261
/// Rate-limit state reported by Anthropic in the `anthropic-ratelimit-*` response headers.
///
/// <https://docs.anthropic.com/en/api/rate-limits#response-headers>
#[derive(Debug)]
pub struct RateLimitInfo {
    /// Value of `anthropic-ratelimit-requests-limit`.
    pub requests_limit: usize,
    /// Value of `anthropic-ratelimit-requests-remaining`.
    pub requests_remaining: usize,
    /// Value of `anthropic-ratelimit-requests-reset`, parsed as RFC 3339 and converted to UTC.
    pub requests_reset: DateTime<Utc>,
    /// Value of `anthropic-ratelimit-tokens-limit`.
    pub tokens_limit: usize,
    /// Value of `anthropic-ratelimit-tokens-remaining`.
    pub tokens_remaining: usize,
    /// Value of `anthropic-ratelimit-tokens-reset`, parsed as RFC 3339 and converted to UTC.
    pub tokens_reset: DateTime<Utc>,
}
272
273impl RateLimitInfo {
274 fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
275 let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
276 let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
277 let tokens_remaining =
278 get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
279 let requests_remaining =
280 get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
281 let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
282 let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
283 let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
284 let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
285
286 Ok(Self {
287 requests_limit,
288 tokens_limit,
289 requests_remaining,
290 tokens_remaining,
291 requests_reset,
292 tokens_reset,
293 })
294 }
295}
296
297fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
298 Ok(headers
299 .get(key)
300 .ok_or_else(|| anyhow!("missing header `{key}`"))?
301 .to_str()?)
302}
303
304pub async fn stream_completion_with_rate_limit_info(
305 client: &dyn HttpClient,
306 api_url: &str,
307 api_key: &str,
308 request: Request,
309) -> Result<
310 (
311 BoxStream<'static, Result<Event, AnthropicError>>,
312 Option<RateLimitInfo>,
313 ),
314 AnthropicError,
315> {
316 let request = StreamingRequest {
317 base: request,
318 stream: true,
319 };
320 let uri = format!("{api_url}/v1/messages");
321 let beta_headers = Model::from_id(&request.base.model)
322 .map(|model| model.beta_headers())
323 .unwrap_or_else(|_err| Model::DEFAULT_BETA_HEADERS.join(","));
324 let request_builder = HttpRequest::builder()
325 .method(Method::POST)
326 .uri(uri)
327 .header("Anthropic-Version", "2023-06-01")
328 .header("Anthropic-Beta", beta_headers)
329 .header("X-Api-Key", api_key)
330 .header("Content-Type", "application/json");
331 let serialized_request =
332 serde_json::to_string(&request).context("failed to serialize request")?;
333 let request = request_builder
334 .body(AsyncBody::from(serialized_request))
335 .context("failed to construct request body")?;
336
337 let mut response = client
338 .send(request)
339 .await
340 .context("failed to send request to Anthropic")?;
341 if response.status().is_success() {
342 let rate_limits = RateLimitInfo::from_headers(response.headers());
343 let reader = BufReader::new(response.into_body());
344 let stream = reader
345 .lines()
346 .filter_map(|line| async move {
347 match line {
348 Ok(line) => {
349 let line = line.strip_prefix("data: ")?;
350 match serde_json::from_str(line) {
351 Ok(response) => Some(Ok(response)),
352 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
353 }
354 }
355 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
356 }
357 })
358 .boxed();
359 Ok((stream, rate_limits.log_err()))
360 } else {
361 let mut body = Vec::new();
362 response
363 .body_mut()
364 .read_to_end(&mut body)
365 .await
366 .context("failed to read response body")?;
367
368 let body_str =
369 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
370
371 match serde_json::from_str::<Event>(body_str) {
372 Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
373 Ok(_) => Err(AnthropicError::Other(anyhow!(
374 "Unexpected success response while expecting an error: '{body_str}'",
375 ))),
376 Err(_) => Err(AnthropicError::Other(anyhow!(
377 "Failed to connect to API: {} {}",
378 response.status(),
379 body_str,
380 ))),
381 }
382 }
383}
384
385pub async fn extract_tool_args_from_events(
386 tool_name: String,
387 mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
388) -> Result<impl Send + Stream<Item = Result<String>>> {
389 let mut tool_use_index = None;
390 while let Some(event) = events.next().await {
391 if let Event::ContentBlockStart {
392 index,
393 content_block: ResponseContent::ToolUse { name, .. },
394 } = event?
395 {
396 if name == tool_name {
397 tool_use_index = Some(index);
398 break;
399 }
400 }
401 }
402
403 let Some(tool_use_index) = tool_use_index else {
404 return Err(anyhow!("tool not used"));
405 };
406
407 Ok(events.filter_map(move |event| {
408 let result = match event {
409 Err(error) => Some(Err(error)),
410 Ok(Event::ContentBlockDelta { index, delta }) => match delta {
411 ContentDelta::TextDelta { .. } => None,
412 ContentDelta::InputJsonDelta { partial_json } => {
413 if index == tool_use_index {
414 Some(Ok(partial_json))
415 } else {
416 None
417 }
418 }
419 },
420 _ => None,
421 };
422
423 async move { result }
424 }))
425}
426
/// Kind of prompt-cache breakpoint; serialized in lowercase (`"ephemeral"`).
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
#[serde(rename_all = "lowercase")]
pub enum CacheControlType {
    Ephemeral,
}
432
/// A `cache_control` marker attached to a request content block.
///
/// Serializes as `{"type": "ephemeral"}`.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct CacheControl {
    // Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub cache_type: CacheControlType,
}
438
/// One conversation turn in a [`Request`]: the author's role plus its content blocks.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    pub role: Role,
    pub content: Vec<RequestContent>,
}
444
/// Author of a message; serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
451
/// A content block sent to the API, tagged by its `"type"` field.
///
/// Every variant may carry a `cache_control` marker. It is omitted from the serialized JSON
/// when `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum RequestContent {
    /// Plain text (`"type": "text"`).
    #[serde(rename = "text")]
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image (`"type": "image"`); see [`ImageSource`].
    #[serde(rename = "image")]
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool call (`"type": "tool_use"`): call id, tool name and JSON input.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// The result of the tool call identified by `tool_use_id` (`"type": "tool_result"`).
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        is_error: bool,
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
484
/// A content block produced by the model, tagged by its `"type"` field.
///
/// Appears in [`Response::content`] and in `content_block_start` [`Event`]s.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseContent {
    /// Generated text (`"type": "text"`).
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool call requested by the model (`"type": "tool_use"`).
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
497
/// Source of an [`RequestContent::Image`] block.
// NOTE(review): callers presumably set `source_type` to "base64" and put base64-encoded bytes
// in `data` — confirm at the call sites.
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
    // Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub source_type: String,
    pub media_type: String,
    pub data: String,
}
505
/// A tool definition offered to the model in [`Request::tools`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    pub name: String,
    pub description: String,
    // Presumably a JSON Schema describing the tool's arguments — confirm at call sites.
    pub input_schema: serde_json::Value,
}
512
/// How the model may choose among the offered tools.
///
/// Serializes as `{"type": "auto"}`, `{"type": "any"}`, or `{"type": "tool", "name": ...}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    Tool { name: String },
}
520
/// Body of a Messages API request (`POST /v1/messages`).
///
/// `model`, `max_tokens` and `messages` are always serialized. Every other field is omitted
/// from the JSON when it is `None` or an empty list.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// API model id, e.g. the value of `Model::id`.
    pub model: String,
    pub max_tokens: u32,
    pub messages: Vec<Message>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
543
/// A [`Request`] plus the `"stream"` flag.
///
/// `base` is flattened, so its fields appear at the top level of the JSON next to `stream`.
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
    #[serde(flatten)]
    pub base: Request,
    pub stream: bool,
}
550
/// Request metadata.
///
/// `user_id` has no `skip_serializing_if`, so it serializes as `null` when `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    pub user_id: Option<String>,
}
555
/// Token accounting reported by the API.
///
/// Every count is optional: it defaults to `None` when absent and is omitted from serialized
/// JSON when `None`.
#[derive(Debug, Serialize, Deserialize, Default)]
pub struct Usage {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_creation_input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
}
567
/// A Messages API response.
///
/// Returned whole by `complete`, and embedded in the `message_start` streaming [`Event`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    pub id: String,
    // Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub response_type: String,
    pub role: Role,
    pub content: Vec<ResponseContent>,
    pub model: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}
582
/// A server-sent event from a streaming Messages request, tagged by its `"type"` field.
///
/// The `error` variant is also used to decode the body of non-success HTTP responses.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Event {
    #[serde(rename = "message_start")]
    MessageStart { message: Response },
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: ResponseContent,
    },
    /// An incremental update to the content block at `index`.
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: ContentDelta },
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    #[serde(rename = "message_delta")]
    MessageDelta { delta: MessageDelta, usage: Usage },
    #[serde(rename = "message_stop")]
    MessageStop,
    #[serde(rename = "ping")]
    Ping,
    #[serde(rename = "error")]
    Error { error: ApiError },
}
606
/// Payload of a `content_block_delta` event.
///
/// Carries either appended text or a fragment of a tool call's JSON input.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentDelta {
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    /// A fragment of a tool call's input JSON; not necessarily valid JSON on its own.
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}
615
/// Top-level message fields updated by a `message_delta` event.
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
621
/// Error type returned by the request functions in this module.
#[derive(Error, Debug)]
pub enum AnthropicError {
    /// A structured error object reported by the Anthropic API.
    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
    ApiError(ApiError),
    /// Any other failure: transport, (de)serialization, or an unrecognized error response.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
629
/// Error object returned by the Anthropic API.
///
/// Holds a machine-readable type string (see [`ApiErrorCode`]) and a human-readable message.
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
    // Serialized under the JSON key `type`, e.g. "rate_limit_error".
    #[serde(rename = "type")]
    pub error_type: String,
    pub message: String,
}
636
/// An Anthropic API error code.
/// <https://docs.anthropic.com/en/api/errors#http-errors>
///
/// Parsed from `ApiError::error_type` by `ApiError::code`. Each variant name maps to its
/// snake_case string (e.g. `RateLimitError` <-> `"rate_limit_error"`).
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
    InvalidRequestError,
    /// 401 - `authentication_error`: There's an issue with your API key.
    AuthenticationError,
    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
    PermissionError,
    /// 404 - `not_found_error`: The requested resource was not found.
    NotFoundError,
    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
    RequestTooLarge,
    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
    RateLimitError,
    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
    ApiError,
    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
    OverloadedError,
}
659
660impl ApiError {
661 pub fn code(&self) -> Option<ApiErrorCode> {
662 ApiErrorCode::from_str(&self.error_type).ok()
663 }
664
665 pub fn is_rate_limit_error(&self) -> bool {
666 matches!(self.error_type.as_str(), "rate_limit_error")
667 }
668}