1mod supported_countries;
2
3use std::{pin::Pin, str::FromStr};
4
5use anyhow::{anyhow, Context, Result};
6use chrono::{DateTime, Utc};
7use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
8use http_client::http::{HeaderMap, HeaderValue};
9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
10use serde::{Deserialize, Serialize};
11use strum::{EnumIter, EnumString};
12use thiserror::Error;
13use util::ResultExt as _;
14
15pub use supported_countries::*;
16
17pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
18
19#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
20#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
21pub struct AnthropicModelCacheConfiguration {
22 pub min_total_token: usize,
23 pub should_speculate: bool,
24 pub max_cache_anchors: usize,
25}
26
27#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
28#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
29pub enum Model {
30 #[default]
31 #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")]
32 Claude3_5Sonnet,
33 #[serde(rename = "claude-3-5-haiku", alias = "claude-3-5-haiku-latest")]
34 Claude3_5Haiku,
35 #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
36 Claude3Opus,
37 #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
38 Claude3Sonnet,
39 #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-latest")]
40 Claude3Haiku,
41 #[serde(rename = "custom")]
42 Custom {
43 name: String,
44 max_tokens: usize,
45 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
46 display_name: Option<String>,
47 /// Override this model with a different Anthropic model for tool calls.
48 tool_override: Option<String>,
49 /// Indicates whether this custom model supports caching.
50 cache_configuration: Option<AnthropicModelCacheConfiguration>,
51 max_output_tokens: Option<u32>,
52 default_temperature: Option<f32>,
53 #[serde(default)]
54 extra_beta_headers: Vec<String>,
55 },
56}
57
58impl Model {
59 pub fn from_id(id: &str) -> Result<Self> {
60 if id.starts_with("claude-3-5-sonnet") {
61 Ok(Self::Claude3_5Sonnet)
62 } else if id.starts_with("claude-3-5-haiku") {
63 Ok(Self::Claude3_5Haiku)
64 } else if id.starts_with("claude-3-opus") {
65 Ok(Self::Claude3Opus)
66 } else if id.starts_with("claude-3-sonnet") {
67 Ok(Self::Claude3Sonnet)
68 } else if id.starts_with("claude-3-haiku") {
69 Ok(Self::Claude3Haiku)
70 } else {
71 Err(anyhow!("invalid model id"))
72 }
73 }
74
75 pub fn id(&self) -> &str {
76 match self {
77 Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest",
78 Model::Claude3_5Haiku => "claude-3-5-haiku-latest",
79 Model::Claude3Opus => "claude-3-opus-latest",
80 Model::Claude3Sonnet => "claude-3-sonnet-latest",
81 Model::Claude3Haiku => "claude-3-haiku-latest",
82 Self::Custom { name, .. } => name,
83 }
84 }
85
86 pub fn display_name(&self) -> &str {
87 match self {
88 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
89 Self::Claude3_5Haiku => "Claude 3.5 Haiku",
90 Self::Claude3Opus => "Claude 3 Opus",
91 Self::Claude3Sonnet => "Claude 3 Sonnet",
92 Self::Claude3Haiku => "Claude 3 Haiku",
93 Self::Custom {
94 name, display_name, ..
95 } => display_name.as_ref().unwrap_or(name),
96 }
97 }
98
99 pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
100 match self {
101 Self::Claude3_5Sonnet | Self::Claude3_5Haiku | Self::Claude3Haiku => {
102 Some(AnthropicModelCacheConfiguration {
103 min_total_token: 2_048,
104 should_speculate: true,
105 max_cache_anchors: 4,
106 })
107 }
108 Self::Custom {
109 cache_configuration,
110 ..
111 } => cache_configuration.clone(),
112 _ => None,
113 }
114 }
115
116 pub fn max_token_count(&self) -> usize {
117 match self {
118 Self::Claude3_5Sonnet
119 | Self::Claude3_5Haiku
120 | Self::Claude3Opus
121 | Self::Claude3Sonnet
122 | Self::Claude3Haiku => 200_000,
123 Self::Custom { max_tokens, .. } => *max_tokens,
124 }
125 }
126
127 pub fn max_output_tokens(&self) -> u32 {
128 match self {
129 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
130 Self::Claude3_5Sonnet | Self::Claude3_5Haiku => 8_192,
131 Self::Custom {
132 max_output_tokens, ..
133 } => max_output_tokens.unwrap_or(4_096),
134 }
135 }
136
137 pub fn default_temperature(&self) -> f32 {
138 match self {
139 Self::Claude3_5Sonnet
140 | Self::Claude3_5Haiku
141 | Self::Claude3Opus
142 | Self::Claude3Sonnet
143 | Self::Claude3Haiku => 1.0,
144 Self::Custom {
145 default_temperature,
146 ..
147 } => default_temperature.unwrap_or(1.0),
148 }
149 }
150
151 pub fn beta_headers(&self) -> String {
152 let mut headers = vec!["prompt-caching-2024-07-31".to_string()];
153
154 if let Self::Custom {
155 extra_beta_headers, ..
156 } = self
157 {
158 headers.extend(
159 extra_beta_headers
160 .iter()
161 .filter(|header| !header.trim().is_empty())
162 .cloned(),
163 );
164 }
165
166 headers.join(",")
167 }
168
169 pub fn tool_model_id(&self) -> &str {
170 if let Self::Custom {
171 tool_override: Some(tool_override),
172 ..
173 } = self
174 {
175 tool_override
176 } else {
177 self.id()
178 }
179 }
180}
181
182pub async fn complete(
183 client: &dyn HttpClient,
184 api_url: &str,
185 api_key: &str,
186 request: Request,
187) -> Result<Response, AnthropicError> {
188 let uri = format!("{api_url}/v1/messages");
189 let model = Model::from_id(&request.model)?;
190 let request_builder = HttpRequest::builder()
191 .method(Method::POST)
192 .uri(uri)
193 .header("Anthropic-Version", "2023-06-01")
194 .header("Anthropic-Beta", model.beta_headers())
195 .header("X-Api-Key", api_key)
196 .header("Content-Type", "application/json");
197
198 let serialized_request =
199 serde_json::to_string(&request).context("failed to serialize request")?;
200 let request = request_builder
201 .body(AsyncBody::from(serialized_request))
202 .context("failed to construct request body")?;
203
204 let mut response = client
205 .send(request)
206 .await
207 .context("failed to send request to Anthropic")?;
208 if response.status().is_success() {
209 let mut body = Vec::new();
210 response
211 .body_mut()
212 .read_to_end(&mut body)
213 .await
214 .context("failed to read response body")?;
215 let response_message: Response =
216 serde_json::from_slice(&body).context("failed to deserialize response body")?;
217 Ok(response_message)
218 } else {
219 let mut body = Vec::new();
220 response
221 .body_mut()
222 .read_to_end(&mut body)
223 .await
224 .context("failed to read response body")?;
225 let body_str =
226 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
227 Err(AnthropicError::Other(anyhow!(
228 "Failed to connect to API: {} {}",
229 response.status(),
230 body_str
231 )))
232 }
233}
234
235pub async fn stream_completion(
236 client: &dyn HttpClient,
237 api_url: &str,
238 api_key: &str,
239 request: Request,
240) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
241 stream_completion_with_rate_limit_info(client, api_url, api_key, request)
242 .await
243 .map(|output| output.0)
244}
245
246/// https://docs.anthropic.com/en/api/rate-limits#response-headers
247#[derive(Debug)]
248pub struct RateLimitInfo {
249 pub requests_limit: usize,
250 pub requests_remaining: usize,
251 pub requests_reset: DateTime<Utc>,
252 pub tokens_limit: usize,
253 pub tokens_remaining: usize,
254 pub tokens_reset: DateTime<Utc>,
255}
256
257impl RateLimitInfo {
258 fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
259 let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
260 let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
261 let tokens_remaining =
262 get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
263 let requests_remaining =
264 get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
265 let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
266 let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
267 let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
268 let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
269
270 Ok(Self {
271 requests_limit,
272 tokens_limit,
273 requests_remaining,
274 tokens_remaining,
275 requests_reset,
276 tokens_reset,
277 })
278 }
279}
280
281fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
282 Ok(headers
283 .get(key)
284 .ok_or_else(|| anyhow!("missing header `{key}`"))?
285 .to_str()?)
286}
287
288pub async fn stream_completion_with_rate_limit_info(
289 client: &dyn HttpClient,
290 api_url: &str,
291 api_key: &str,
292 request: Request,
293) -> Result<
294 (
295 BoxStream<'static, Result<Event, AnthropicError>>,
296 Option<RateLimitInfo>,
297 ),
298 AnthropicError,
299> {
300 let request = StreamingRequest {
301 base: request,
302 stream: true,
303 };
304 let uri = format!("{api_url}/v1/messages");
305 let model = Model::from_id(&request.base.model)?;
306 let request_builder = HttpRequest::builder()
307 .method(Method::POST)
308 .uri(uri)
309 .header("Anthropic-Version", "2023-06-01")
310 .header("Anthropic-Beta", model.beta_headers())
311 .header("X-Api-Key", api_key)
312 .header("Content-Type", "application/json");
313 let serialized_request =
314 serde_json::to_string(&request).context("failed to serialize request")?;
315 let request = request_builder
316 .body(AsyncBody::from(serialized_request))
317 .context("failed to construct request body")?;
318
319 let mut response = client
320 .send(request)
321 .await
322 .context("failed to send request to Anthropic")?;
323 if response.status().is_success() {
324 let rate_limits = RateLimitInfo::from_headers(response.headers());
325 let reader = BufReader::new(response.into_body());
326 let stream = reader
327 .lines()
328 .filter_map(|line| async move {
329 match line {
330 Ok(line) => {
331 let line = line.strip_prefix("data: ")?;
332 match serde_json::from_str(line) {
333 Ok(response) => Some(Ok(response)),
334 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
335 }
336 }
337 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
338 }
339 })
340 .boxed();
341 Ok((stream, rate_limits.log_err()))
342 } else {
343 let mut body = Vec::new();
344 response
345 .body_mut()
346 .read_to_end(&mut body)
347 .await
348 .context("failed to read response body")?;
349
350 let body_str =
351 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
352
353 match serde_json::from_str::<Event>(body_str) {
354 Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
355 Ok(_) => Err(AnthropicError::Other(anyhow!(
356 "Unexpected success response while expecting an error: '{body_str}'",
357 ))),
358 Err(_) => Err(AnthropicError::Other(anyhow!(
359 "Failed to connect to API: {} {}",
360 response.status(),
361 body_str,
362 ))),
363 }
364 }
365}
366
367pub async fn extract_tool_args_from_events(
368 tool_name: String,
369 mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
370) -> Result<impl Send + Stream<Item = Result<String>>> {
371 let mut tool_use_index = None;
372 while let Some(event) = events.next().await {
373 if let Event::ContentBlockStart {
374 index,
375 content_block: ResponseContent::ToolUse { name, .. },
376 } = event?
377 {
378 if name == tool_name {
379 tool_use_index = Some(index);
380 break;
381 }
382 }
383 }
384
385 let Some(tool_use_index) = tool_use_index else {
386 return Err(anyhow!("tool not used"));
387 };
388
389 Ok(events.filter_map(move |event| {
390 let result = match event {
391 Err(error) => Some(Err(error)),
392 Ok(Event::ContentBlockDelta { index, delta }) => match delta {
393 ContentDelta::TextDelta { .. } => None,
394 ContentDelta::InputJsonDelta { partial_json } => {
395 if index == tool_use_index {
396 Some(Ok(partial_json))
397 } else {
398 None
399 }
400 }
401 },
402 _ => None,
403 };
404
405 async move { result }
406 }))
407}
408
409#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
410#[serde(rename_all = "lowercase")]
411pub enum CacheControlType {
412 Ephemeral,
413}
414
415#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
416pub struct CacheControl {
417 #[serde(rename = "type")]
418 pub cache_type: CacheControlType,
419}
420
421#[derive(Debug, Serialize, Deserialize)]
422pub struct Message {
423 pub role: Role,
424 pub content: Vec<RequestContent>,
425}
426
427#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
428#[serde(rename_all = "lowercase")]
429pub enum Role {
430 User,
431 Assistant,
432}
433
434#[derive(Debug, Serialize, Deserialize)]
435#[serde(tag = "type")]
436pub enum RequestContent {
437 #[serde(rename = "text")]
438 Text {
439 text: String,
440 #[serde(skip_serializing_if = "Option::is_none")]
441 cache_control: Option<CacheControl>,
442 },
443 #[serde(rename = "image")]
444 Image {
445 source: ImageSource,
446 #[serde(skip_serializing_if = "Option::is_none")]
447 cache_control: Option<CacheControl>,
448 },
449 #[serde(rename = "tool_use")]
450 ToolUse {
451 id: String,
452 name: String,
453 input: serde_json::Value,
454 #[serde(skip_serializing_if = "Option::is_none")]
455 cache_control: Option<CacheControl>,
456 },
457 #[serde(rename = "tool_result")]
458 ToolResult {
459 tool_use_id: String,
460 is_error: bool,
461 content: String,
462 #[serde(skip_serializing_if = "Option::is_none")]
463 cache_control: Option<CacheControl>,
464 },
465}
466
467#[derive(Debug, Serialize, Deserialize)]
468#[serde(tag = "type")]
469pub enum ResponseContent {
470 #[serde(rename = "text")]
471 Text { text: String },
472 #[serde(rename = "tool_use")]
473 ToolUse {
474 id: String,
475 name: String,
476 input: serde_json::Value,
477 },
478}
479
480#[derive(Debug, Serialize, Deserialize)]
481pub struct ImageSource {
482 #[serde(rename = "type")]
483 pub source_type: String,
484 pub media_type: String,
485 pub data: String,
486}
487
488#[derive(Debug, Serialize, Deserialize)]
489pub struct Tool {
490 pub name: String,
491 pub description: String,
492 pub input_schema: serde_json::Value,
493}
494
495#[derive(Debug, Serialize, Deserialize)]
496#[serde(tag = "type", rename_all = "lowercase")]
497pub enum ToolChoice {
498 Auto,
499 Any,
500 Tool { name: String },
501}
502
503#[derive(Debug, Serialize, Deserialize)]
504pub struct Request {
505 pub model: String,
506 pub max_tokens: u32,
507 pub messages: Vec<Message>,
508 #[serde(default, skip_serializing_if = "Vec::is_empty")]
509 pub tools: Vec<Tool>,
510 #[serde(default, skip_serializing_if = "Option::is_none")]
511 pub tool_choice: Option<ToolChoice>,
512 #[serde(default, skip_serializing_if = "Option::is_none")]
513 pub system: Option<String>,
514 #[serde(default, skip_serializing_if = "Option::is_none")]
515 pub metadata: Option<Metadata>,
516 #[serde(default, skip_serializing_if = "Vec::is_empty")]
517 pub stop_sequences: Vec<String>,
518 #[serde(default, skip_serializing_if = "Option::is_none")]
519 pub temperature: Option<f32>,
520 #[serde(default, skip_serializing_if = "Option::is_none")]
521 pub top_k: Option<u32>,
522 #[serde(default, skip_serializing_if = "Option::is_none")]
523 pub top_p: Option<f32>,
524}
525
526#[derive(Debug, Serialize, Deserialize)]
527struct StreamingRequest {
528 #[serde(flatten)]
529 pub base: Request,
530 pub stream: bool,
531}
532
533#[derive(Debug, Serialize, Deserialize)]
534pub struct Metadata {
535 pub user_id: Option<String>,
536}
537
538#[derive(Debug, Serialize, Deserialize)]
539pub struct Usage {
540 #[serde(default, skip_serializing_if = "Option::is_none")]
541 pub input_tokens: Option<u32>,
542 #[serde(default, skip_serializing_if = "Option::is_none")]
543 pub output_tokens: Option<u32>,
544 #[serde(default, skip_serializing_if = "Option::is_none")]
545 pub cache_creation_input_tokens: Option<u32>,
546 #[serde(default, skip_serializing_if = "Option::is_none")]
547 pub cache_read_input_tokens: Option<u32>,
548}
549
550#[derive(Debug, Serialize, Deserialize)]
551pub struct Response {
552 pub id: String,
553 #[serde(rename = "type")]
554 pub response_type: String,
555 pub role: Role,
556 pub content: Vec<ResponseContent>,
557 pub model: String,
558 #[serde(default, skip_serializing_if = "Option::is_none")]
559 pub stop_reason: Option<String>,
560 #[serde(default, skip_serializing_if = "Option::is_none")]
561 pub stop_sequence: Option<String>,
562 pub usage: Usage,
563}
564
565#[derive(Debug, Serialize, Deserialize)]
566#[serde(tag = "type")]
567pub enum Event {
568 #[serde(rename = "message_start")]
569 MessageStart { message: Response },
570 #[serde(rename = "content_block_start")]
571 ContentBlockStart {
572 index: usize,
573 content_block: ResponseContent,
574 },
575 #[serde(rename = "content_block_delta")]
576 ContentBlockDelta { index: usize, delta: ContentDelta },
577 #[serde(rename = "content_block_stop")]
578 ContentBlockStop { index: usize },
579 #[serde(rename = "message_delta")]
580 MessageDelta { delta: MessageDelta, usage: Usage },
581 #[serde(rename = "message_stop")]
582 MessageStop,
583 #[serde(rename = "ping")]
584 Ping,
585 #[serde(rename = "error")]
586 Error { error: ApiError },
587}
588
589#[derive(Debug, Serialize, Deserialize)]
590#[serde(tag = "type")]
591pub enum ContentDelta {
592 #[serde(rename = "text_delta")]
593 TextDelta { text: String },
594 #[serde(rename = "input_json_delta")]
595 InputJsonDelta { partial_json: String },
596}
597
598#[derive(Debug, Serialize, Deserialize)]
599pub struct MessageDelta {
600 pub stop_reason: Option<String>,
601 pub stop_sequence: Option<String>,
602}
603
604#[derive(Error, Debug)]
605pub enum AnthropicError {
606 #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
607 ApiError(ApiError),
608 #[error("{0}")]
609 Other(#[from] anyhow::Error),
610}
611
612#[derive(Debug, Serialize, Deserialize)]
613pub struct ApiError {
614 #[serde(rename = "type")]
615 pub error_type: String,
616 pub message: String,
617}
618
619/// An Anthropic API error code.
620/// https://docs.anthropic.com/en/api/errors#http-errors
621#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
622#[strum(serialize_all = "snake_case")]
623pub enum ApiErrorCode {
624 /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
625 InvalidRequestError,
626 /// 401 - `authentication_error`: There's an issue with your API key.
627 AuthenticationError,
628 /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
629 PermissionError,
630 /// 404 - `not_found_error`: The requested resource was not found.
631 NotFoundError,
632 /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
633 RequestTooLarge,
634 /// 429 - `rate_limit_error`: Your account has hit a rate limit.
635 RateLimitError,
636 /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
637 ApiError,
638 /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
639 OverloadedError,
640}
641
642impl ApiError {
643 pub fn code(&self) -> Option<ApiErrorCode> {
644 ApiErrorCode::from_str(&self.error_type).ok()
645 }
646
647 pub fn is_rate_limit_error(&self) -> bool {
648 matches!(self.error_type.as_str(), "rate_limit_error")
649 }
650}