1mod supported_countries;
2
3use std::time::Duration;
4use std::{pin::Pin, str::FromStr};
5
6use anyhow::{anyhow, Context, Result};
7use chrono::{DateTime, Utc};
8use collections::HashMap;
9use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
10use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
11use isahc::config::Configurable;
12use isahc::http::{HeaderMap, HeaderValue};
13use serde::{Deserialize, Serialize};
14use strum::{EnumIter, EnumString};
15use thiserror::Error;
16use util::{maybe, ResultExt as _};
17
18pub use supported_countries::*;
19
/// Base URL of Anthropic's hosted API (scheme and host only, no trailing slash).
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
21
// Prompt-caching settings attached to a model.
//
// `Model::cache_configuration` returns one of these for the built-in models
// that have caching configured, and passes through whatever a `Model::Custom`
// entry was configured with.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct AnthropicModelCacheConfiguration {
    // NOTE(review): presumably the minimum total token count before caching is
    // worthwhile — confirm against the code that consumes this value.
    pub min_total_token: usize,
    // NOTE(review): presumably enables speculative cache writes — confirm with the consumer.
    pub should_speculate: bool,
    // NOTE(review): presumably the maximum number of cache-control markers per
    // request — confirm with the consumer.
    pub max_cache_anchors: usize,
}
29
// The Anthropic models this file knows about, plus a user-defined `Custom` entry.
//
// Each built-in variant (de)serializes under its short name (for example
// `claude-3-5-sonnet`) and also accepts the dated identifier as an alias when
// deserializing. `Claude3_5Sonnet` is the `Default` variant.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[default]
    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-20240620")]
    Claude3_5Sonnet,
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
    Claude3Opus,
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
    Claude3Sonnet,
    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
    Claude3Haiku,
    // A model described entirely by configuration rather than by a built-in variant.
    #[serde(rename = "custom")]
    Custom {
        // Returned verbatim by `Model::id`.
        name: String,
        // Returned by `Model::max_token_count`.
        max_tokens: usize,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Override this model with a different Anthropic model for tool calls.
        tool_override: Option<String>,
        /// Indicates whether this custom model supports caching.
        cache_configuration: Option<AnthropicModelCacheConfiguration>,
        // Output-token cap; `Model::max_output_tokens` falls back to 4_096 when unset.
        max_output_tokens: Option<u32>,
    },
}
55
56impl Model {
57 pub fn from_id(id: &str) -> Result<Self> {
58 if id.starts_with("claude-3-5-sonnet") {
59 Ok(Self::Claude3_5Sonnet)
60 } else if id.starts_with("claude-3-opus") {
61 Ok(Self::Claude3Opus)
62 } else if id.starts_with("claude-3-sonnet") {
63 Ok(Self::Claude3Sonnet)
64 } else if id.starts_with("claude-3-haiku") {
65 Ok(Self::Claude3Haiku)
66 } else {
67 Err(anyhow!("invalid model id"))
68 }
69 }
70
71 pub fn id(&self) -> &str {
72 match self {
73 Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
74 Model::Claude3Opus => "claude-3-opus-20240229",
75 Model::Claude3Sonnet => "claude-3-sonnet-20240229",
76 Model::Claude3Haiku => "claude-3-haiku-20240307",
77 Self::Custom { name, .. } => name,
78 }
79 }
80
81 pub fn display_name(&self) -> &str {
82 match self {
83 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
84 Self::Claude3Opus => "Claude 3 Opus",
85 Self::Claude3Sonnet => "Claude 3 Sonnet",
86 Self::Claude3Haiku => "Claude 3 Haiku",
87 Self::Custom {
88 name, display_name, ..
89 } => display_name.as_ref().unwrap_or(name),
90 }
91 }
92
93 pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
94 match self {
95 Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
96 min_total_token: 2_048,
97 should_speculate: true,
98 max_cache_anchors: 4,
99 }),
100 Self::Custom {
101 cache_configuration,
102 ..
103 } => cache_configuration.clone(),
104 _ => None,
105 }
106 }
107
108 pub fn max_token_count(&self) -> usize {
109 match self {
110 Self::Claude3_5Sonnet
111 | Self::Claude3Opus
112 | Self::Claude3Sonnet
113 | Self::Claude3Haiku => 200_000,
114 Self::Custom { max_tokens, .. } => *max_tokens,
115 }
116 }
117
118 pub fn max_output_tokens(&self) -> u32 {
119 match self {
120 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
121 Self::Claude3_5Sonnet => 8_192,
122 Self::Custom {
123 max_output_tokens, ..
124 } => max_output_tokens.unwrap_or(4_096),
125 }
126 }
127
128 pub fn tool_model_id(&self) -> &str {
129 if let Self::Custom {
130 tool_override: Some(tool_override),
131 ..
132 } = self
133 {
134 tool_override
135 } else {
136 self.id()
137 }
138 }
139}
140
141pub async fn complete(
142 client: &dyn HttpClient,
143 api_url: &str,
144 api_key: &str,
145 request: Request,
146) -> Result<Response, AnthropicError> {
147 let uri = format!("{api_url}/v1/messages");
148 let request_builder = HttpRequest::builder()
149 .method(Method::POST)
150 .uri(uri)
151 .header("Anthropic-Version", "2023-06-01")
152 .header(
153 "Anthropic-Beta",
154 "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
155 )
156 .header("X-Api-Key", api_key)
157 .header("Content-Type", "application/json");
158
159 let serialized_request =
160 serde_json::to_string(&request).context("failed to serialize request")?;
161 let request = request_builder
162 .body(AsyncBody::from(serialized_request))
163 .context("failed to construct request body")?;
164
165 let mut response = client
166 .send(request)
167 .await
168 .context("failed to send request to Anthropic")?;
169 if response.status().is_success() {
170 let mut body = Vec::new();
171 response
172 .body_mut()
173 .read_to_end(&mut body)
174 .await
175 .context("failed to read response body")?;
176 let response_message: Response =
177 serde_json::from_slice(&body).context("failed to deserialize response body")?;
178 Ok(response_message)
179 } else {
180 let mut body = Vec::new();
181 response
182 .body_mut()
183 .read_to_end(&mut body)
184 .await
185 .context("failed to read response body")?;
186 let body_str =
187 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
188 Err(AnthropicError::Other(anyhow!(
189 "Failed to connect to API: {} {}",
190 response.status(),
191 body_str
192 )))
193 }
194}
195
196pub async fn stream_completion(
197 client: &dyn HttpClient,
198 api_url: &str,
199 api_key: &str,
200 request: Request,
201 low_speed_timeout: Option<Duration>,
202) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
203 stream_completion_with_rate_limit_info(client, api_url, api_key, request, low_speed_timeout)
204 .await
205 .map(|output| output.0)
206}
207
/// Rate-limit values read from the `anthropic-ratelimit-*` response headers.
///
/// See <https://docs.anthropic.com/en/api/rate-limits#response-headers>.
#[derive(Debug)]
pub struct RateLimitInfo {
    /// Value of the `anthropic-ratelimit-requests-limit` header.
    pub requests_limit: usize,
    /// Value of the `anthropic-ratelimit-requests-remaining` header.
    pub requests_remaining: usize,
    /// Value of the `anthropic-ratelimit-requests-reset` header, parsed as RFC 3339 and converted to UTC.
    pub requests_reset: DateTime<Utc>,
    /// Value of the `anthropic-ratelimit-tokens-limit` header.
    pub tokens_limit: usize,
    /// Value of the `anthropic-ratelimit-tokens-remaining` header.
    pub tokens_remaining: usize,
    /// Value of the `anthropic-ratelimit-tokens-reset` header, parsed as RFC 3339 and converted to UTC.
    pub tokens_reset: DateTime<Utc>,
}
218
219impl RateLimitInfo {
220 fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
221 let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
222 let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
223 let tokens_remaining =
224 get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
225 let requests_remaining =
226 get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
227 let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
228 let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
229 let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
230 let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
231
232 Ok(Self {
233 requests_limit,
234 tokens_limit,
235 requests_remaining,
236 tokens_remaining,
237 requests_reset,
238 tokens_reset,
239 })
240 }
241}
242
243fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
244 Ok(headers
245 .get(key)
246 .ok_or_else(|| anyhow!("missing header `{key}`"))?
247 .to_str()?)
248}
249
250pub async fn stream_completion_with_rate_limit_info(
251 client: &dyn HttpClient,
252 api_url: &str,
253 api_key: &str,
254 request: Request,
255 low_speed_timeout: Option<Duration>,
256) -> Result<
257 (
258 BoxStream<'static, Result<Event, AnthropicError>>,
259 Option<RateLimitInfo>,
260 ),
261 AnthropicError,
262> {
263 let request = StreamingRequest {
264 base: request,
265 stream: true,
266 };
267 let uri = format!("{api_url}/v1/messages");
268 let mut request_builder = HttpRequest::builder()
269 .method(Method::POST)
270 .uri(uri)
271 .header("Anthropic-Version", "2023-06-01")
272 .header(
273 "Anthropic-Beta",
274 "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
275 )
276 .header("X-Api-Key", api_key)
277 .header("Content-Type", "application/json");
278 if let Some(low_speed_timeout) = low_speed_timeout {
279 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
280 }
281 let serialized_request =
282 serde_json::to_string(&request).context("failed to serialize request")?;
283 let request = request_builder
284 .body(AsyncBody::from(serialized_request))
285 .context("failed to construct request body")?;
286
287 let mut response = client
288 .send(request)
289 .await
290 .context("failed to send request to Anthropic")?;
291 if response.status().is_success() {
292 let rate_limits = RateLimitInfo::from_headers(response.headers());
293 let reader = BufReader::new(response.into_body());
294 let stream = reader
295 .lines()
296 .filter_map(|line| async move {
297 match line {
298 Ok(line) => {
299 let line = line.strip_prefix("data: ")?;
300 match serde_json::from_str(line) {
301 Ok(response) => Some(Ok(response)),
302 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
303 }
304 }
305 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
306 }
307 })
308 .boxed();
309 Ok((stream, rate_limits.log_err()))
310 } else {
311 let mut body = Vec::new();
312 response
313 .body_mut()
314 .read_to_end(&mut body)
315 .await
316 .context("failed to read response body")?;
317
318 let body_str =
319 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
320
321 match serde_json::from_str::<Event>(body_str) {
322 Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
323 Ok(_) => Err(AnthropicError::Other(anyhow!(
324 "Unexpected success response while expecting an error: '{body_str}'",
325 ))),
326 Err(_) => Err(AnthropicError::Other(anyhow!(
327 "Failed to connect to API: {} {}",
328 response.status(),
329 body_str,
330 ))),
331 }
332 }
333}
334
335pub fn extract_content_from_events(
336 events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
337) -> impl Stream<Item = Result<ResponseContent, AnthropicError>> {
338 struct RawToolUse {
339 id: String,
340 name: String,
341 input_json: String,
342 }
343
344 struct State {
345 events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
346 tool_uses_by_index: HashMap<usize, RawToolUse>,
347 }
348
349 futures::stream::unfold(
350 State {
351 events,
352 tool_uses_by_index: HashMap::default(),
353 },
354 |mut state| async move {
355 while let Some(event) = state.events.next().await {
356 match event {
357 Ok(event) => match event {
358 Event::ContentBlockStart {
359 index,
360 content_block,
361 } => match content_block {
362 ResponseContent::Text { text } => {
363 return Some((Some(Ok(ResponseContent::Text { text })), state));
364 }
365 ResponseContent::ToolUse { id, name, .. } => {
366 state.tool_uses_by_index.insert(
367 index,
368 RawToolUse {
369 id,
370 name,
371 input_json: String::new(),
372 },
373 );
374
375 return Some((None, state));
376 }
377 },
378 Event::ContentBlockDelta { index, delta } => match delta {
379 ContentDelta::TextDelta { text } => {
380 return Some((Some(Ok(ResponseContent::Text { text })), state));
381 }
382 ContentDelta::InputJsonDelta { partial_json } => {
383 if let Some(tool_use) = state.tool_uses_by_index.get_mut(&index) {
384 tool_use.input_json.push_str(&partial_json);
385 return Some((None, state));
386 }
387 }
388 },
389 Event::ContentBlockStop { index } => {
390 if let Some(tool_use) = state.tool_uses_by_index.remove(&index) {
391 return Some((
392 Some(maybe!({
393 Ok(ResponseContent::ToolUse {
394 id: tool_use.id,
395 name: tool_use.name,
396 input: serde_json::Value::from_str(
397 &tool_use.input_json,
398 )
399 .map_err(|err| anyhow!(err))?,
400 })
401 })),
402 state,
403 ));
404 }
405 }
406 Event::Error { error } => {
407 return Some((Some(Err(AnthropicError::ApiError(error))), state));
408 }
409 _ => {}
410 },
411 Err(err) => {
412 return Some((Some(Err(err)), state));
413 }
414 }
415 }
416
417 None
418 },
419 )
420 .filter_map(|event| async move { event })
421}
422
423pub async fn extract_tool_args_from_events(
424 tool_name: String,
425 mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
426) -> Result<impl Send + Stream<Item = Result<String>>> {
427 let mut tool_use_index = None;
428 while let Some(event) = events.next().await {
429 if let Event::ContentBlockStart {
430 index,
431 content_block,
432 } = event?
433 {
434 if let ResponseContent::ToolUse { name, .. } = content_block {
435 if name == tool_name {
436 tool_use_index = Some(index);
437 break;
438 }
439 }
440 }
441 }
442
443 let Some(tool_use_index) = tool_use_index else {
444 return Err(anyhow!("tool not used"));
445 };
446
447 Ok(events.filter_map(move |event| {
448 let result = match event {
449 Err(error) => Some(Err(error)),
450 Ok(Event::ContentBlockDelta { index, delta }) => match delta {
451 ContentDelta::TextDelta { .. } => None,
452 ContentDelta::InputJsonDelta { partial_json } => {
453 if index == tool_use_index {
454 Some(Ok(partial_json))
455 } else {
456 None
457 }
458 }
459 },
460 _ => None,
461 };
462
463 async move { result }
464 }))
465}
466
/// Kind of cache control carried by a [`CacheControl`] marker.
///
/// Variant names are (de)serialized in lowercase.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
#[serde(rename_all = "lowercase")]
pub enum CacheControlType {
    /// Serialized as `"ephemeral"`.
    Ephemeral,
}
472
/// Cache-control marker that can be attached to a [`RequestContent`] item.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct CacheControl {
    /// The kind of cache control; appears under the JSON key `type`.
    #[serde(rename = "type")]
    pub cache_type: CacheControlType,
}
478
/// One conversation message in a [`Request`]: who said it and what was said.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// The author of the message.
    pub role: Role,
    /// The content items that make up the message, in order.
    pub content: Vec<RequestContent>,
}
484
/// Author of a message; (de)serialized in lowercase (`"user"`, `"assistant"`).
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
}
491
/// A content item inside a request [`Message`].
///
/// Internally tagged: the variant is selected by the JSON field `type`.
/// Every variant carries an optional `cache_control` marker that is left out
/// of the serialized form when it is `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum RequestContent {
    /// Plain text (`"type": "text"`).
    #[serde(rename = "text")]
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// An image described by an [`ImageSource`] (`"type": "image"`).
    #[serde(rename = "image")]
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation with its id, tool name and JSON input (`"type": "tool_use"`).
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
516
/// A content item produced by the model.
///
/// Internally tagged by the JSON field `type`. It appears in
/// [`Response::content`] and as the `content_block` of a
/// `content_block_start` event, and is also the item type yielded by
/// [`extract_content_from_events`].
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseContent {
    /// Plain text (`"type": "text"`).
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation with its id, tool name and JSON input (`"type": "tool_use"`).
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
529
/// Describes the image carried by [`RequestContent::Image`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
    /// How the image is supplied; appears under the JSON key `type`.
    // NOTE(review): presumably "base64" — confirm against the API reference.
    #[serde(rename = "type")]
    pub source_type: String,
    // NOTE(review): presumably a MIME type such as "image/png" — confirm with callers.
    pub media_type: String,
    // NOTE(review): presumably the encoded image bytes — confirm with callers.
    pub data: String,
}
537
/// A tool definition sent in [`Request::tools`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    /// The tool's name; tool-use content refers to a tool by this name.
    pub name: String,
    /// Free-form description of the tool.
    pub description: String,
    /// Arbitrary JSON describing the tool's input.
    // NOTE(review): presumably a JSON Schema object — confirm against the API reference.
    pub input_schema: serde_json::Value,
}
544
/// Value of [`Request::tool_choice`].
///
/// Internally tagged by the JSON field `type`, with lowercase variant names
/// (`"auto"`, `"any"`, `"tool"`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    /// Serialized as `{"type": "auto"}`.
    Auto,
    /// Serialized as `{"type": "any"}`.
    Any,
    /// Serialized as `{"type": "tool", "name": ...}`; names a specific tool.
    Tool { name: String },
}
552
/// JSON body posted to `/v1/messages` by [`complete`].
///
/// The streaming entry point sends the same fields plus `stream: true`.
/// Empty vectors and `None` options are omitted from the serialized form, and
/// those fields take their default value when absent during deserialization.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Identifier of the model to use.
    pub model: String,
    /// Sent as the `max_tokens` field of the request.
    pub max_tokens: u32,
    /// The conversation so far.
    pub messages: Vec<Message>,
    /// Tool definitions; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Tool selection; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// System prompt; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Request metadata; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    /// Stop sequences; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    /// Sampling temperature; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Top-k sampling parameter; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    /// Top-p sampling parameter; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
575
/// A [`Request`] with an added `stream` flag, used for the streaming endpoint.
///
/// `base` is flattened, so its fields appear at the top level of the JSON
/// object next to `stream`.
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
    /// The underlying request; its fields are inlined into this object.
    #[serde(flatten)]
    pub base: Request,
    /// Set to `true` by `stream_completion_with_rate_limit_info`.
    pub stream: bool,
}
582
/// Optional metadata attached to a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    /// Identifier of the user on whose behalf the request is made, if any.
    pub user_id: Option<String>,
}
587
/// Token counts reported in a [`Response`] and in `message_delta` events.
///
/// Both counts are optional: a missing field deserializes to `None`, and
/// `None` is omitted when serializing.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    /// Number of input tokens, when reported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    /// Number of output tokens, when reported.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
}
595
/// A message returned by the API.
///
/// [`complete`] decodes the whole response body into this type, and it is
/// also the payload of a `message_start` event.
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    /// Identifier of the message.
    pub id: String,
    /// Value of the JSON field `type`.
    #[serde(rename = "type")]
    pub response_type: String,
    /// Author of the message.
    pub role: Role,
    /// Content items of the message.
    pub content: Vec<ResponseContent>,
    /// Model identifier reported in the response.
    pub model: String,
    /// Why generation stopped, when reported; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    /// The stop sequence reported with the response, if any; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    /// Token usage for the message.
    pub usage: Usage,
}
610
/// One event decoded from a `data: ` line of a streaming response.
///
/// Internally tagged by the JSON field `type`. The same shape is used to
/// decode the body of a failed streaming request, where an `error` event
/// carries the API error.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Event {
    /// `message_start`: carries the initial [`Response`] message.
    #[serde(rename = "message_start")]
    MessageStart { message: Response },
    /// `content_block_start`: the content block at `index` begins.
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: ResponseContent,
    },
    /// `content_block_delta`: an incremental update for the block at `index`.
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: ContentDelta },
    /// `content_block_stop`: the content block at `index` is complete.
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    /// `message_delta`: message-level changes together with usage figures.
    #[serde(rename = "message_delta")]
    MessageDelta { delta: MessageDelta, usage: Usage },
    /// `message_stop`: carries no payload.
    #[serde(rename = "message_stop")]
    MessageStop,
    /// `ping`: carries no payload.
    #[serde(rename = "ping")]
    Ping,
    /// `error`: the API reported an error.
    #[serde(rename = "error")]
    Error { error: ApiError },
}
634
/// Incremental update carried by [`Event::ContentBlockDelta`].
///
/// Internally tagged by the JSON field `type`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentDelta {
    /// `text_delta`: a piece of text.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    /// `input_json_delta`: a fragment of a tool call's JSON input; fragments
    /// are concatenated before being parsed.
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}
643
/// Message-level fields carried by [`Event::MessageDelta`].
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
    /// Why generation stopped, when reported.
    pub stop_reason: Option<String>,
    /// The stop sequence reported with the event, if any.
    pub stop_sequence: Option<String>,
}
649
/// Error type returned by the request and streaming functions in this file.
#[derive(Error, Debug)]
pub enum AnthropicError {
    /// The API itself reported an error, via an `error` event or error body.
    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
    ApiError(ApiError),
    /// Any other failure (serialization, transport, decoding, ...); converts
    /// from `anyhow::Error`, so `?` can be used on `anyhow` results.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
657
/// Error payload reported by the API.
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
    /// Error kind as sent by the API under the JSON key `type`
    /// (for example `rate_limit_error`).
    #[serde(rename = "type")]
    pub error_type: String,
    /// Human-readable description of the error.
    pub message: String,
}
664
/// An Anthropic API error code.
///
/// Parsed from the snake_case error type string (for example
/// `rate_limit_error`); [`ApiError::code`] performs that conversion.
///
/// See <https://docs.anthropic.com/en/api/errors#http-errors>.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
    InvalidRequestError,
    /// 401 - `authentication_error`: There's an issue with your API key.
    AuthenticationError,
    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
    PermissionError,
    /// 404 - `not_found_error`: The requested resource was not found.
    NotFoundError,
    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
    RequestTooLarge,
    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
    RateLimitError,
    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
    ApiError,
    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
    OverloadedError,
}
687
688impl ApiError {
689 pub fn code(&self) -> Option<ApiErrorCode> {
690 ApiErrorCode::from_str(&self.error_type).ok()
691 }
692
693 pub fn is_rate_limit_error(&self) -> bool {
694 match self.error_type.as_str() {
695 "rate_limit_error" => true,
696 _ => false,
697 }
698 }
699}