1mod supported_countries;
2
3use std::time::Duration;
4use std::{pin::Pin, str::FromStr};
5
6use anyhow::{anyhow, Context, Result};
7use chrono::{DateTime, Utc};
8use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
9use http_client::http::{HeaderMap, HeaderValue};
10use http_client::{AsyncBody, HttpClient, HttpRequestExt, Method, Request as HttpRequest};
11use serde::{Deserialize, Serialize};
12use strum::{EnumIter, EnumString};
13use thiserror::Error;
14use util::ResultExt as _;
15
16pub use supported_countries::*;
17
18pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
19
20#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
21#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
22pub struct AnthropicModelCacheConfiguration {
23 pub min_total_token: usize,
24 pub should_speculate: bool,
25 pub max_cache_anchors: usize,
26}
27
28#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
29#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
30pub enum Model {
31 #[default]
32 #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")]
33 Claude3_5Sonnet,
34 #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
35 Claude3Opus,
36 #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
37 Claude3Sonnet,
38 #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-latest")]
39 Claude3Haiku,
40 #[serde(rename = "custom")]
41 Custom {
42 name: String,
43 max_tokens: usize,
44 /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
45 display_name: Option<String>,
46 /// Override this model with a different Anthropic model for tool calls.
47 tool_override: Option<String>,
48 /// Indicates whether this custom model supports caching.
49 cache_configuration: Option<AnthropicModelCacheConfiguration>,
50 max_output_tokens: Option<u32>,
51 default_temperature: Option<f32>,
52 },
53}
54
55impl Model {
56 pub fn from_id(id: &str) -> Result<Self> {
57 if id.starts_with("claude-3-5-sonnet") {
58 Ok(Self::Claude3_5Sonnet)
59 } else if id.starts_with("claude-3-opus") {
60 Ok(Self::Claude3Opus)
61 } else if id.starts_with("claude-3-sonnet") {
62 Ok(Self::Claude3Sonnet)
63 } else if id.starts_with("claude-3-haiku") {
64 Ok(Self::Claude3Haiku)
65 } else {
66 Err(anyhow!("invalid model id"))
67 }
68 }
69
70 pub fn id(&self) -> &str {
71 match self {
72 Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest",
73 Model::Claude3Opus => "claude-3-opus-latest",
74 Model::Claude3Sonnet => "claude-3-sonnet-latest",
75 Model::Claude3Haiku => "claude-3-haiku-latest",
76 Self::Custom { name, .. } => name,
77 }
78 }
79
80 pub fn display_name(&self) -> &str {
81 match self {
82 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
83 Self::Claude3Opus => "Claude 3 Opus",
84 Self::Claude3Sonnet => "Claude 3 Sonnet",
85 Self::Claude3Haiku => "Claude 3 Haiku",
86 Self::Custom {
87 name, display_name, ..
88 } => display_name.as_ref().unwrap_or(name),
89 }
90 }
91
92 pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
93 match self {
94 Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
95 min_total_token: 2_048,
96 should_speculate: true,
97 max_cache_anchors: 4,
98 }),
99 Self::Custom {
100 cache_configuration,
101 ..
102 } => cache_configuration.clone(),
103 _ => None,
104 }
105 }
106
107 pub fn max_token_count(&self) -> usize {
108 match self {
109 Self::Claude3_5Sonnet
110 | Self::Claude3Opus
111 | Self::Claude3Sonnet
112 | Self::Claude3Haiku => 200_000,
113 Self::Custom { max_tokens, .. } => *max_tokens,
114 }
115 }
116
117 pub fn max_output_tokens(&self) -> u32 {
118 match self {
119 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
120 Self::Claude3_5Sonnet => 8_192,
121 Self::Custom {
122 max_output_tokens, ..
123 } => max_output_tokens.unwrap_or(4_096),
124 }
125 }
126
127 pub fn default_temperature(&self) -> f32 {
128 match self {
129 Self::Claude3_5Sonnet
130 | Self::Claude3Opus
131 | Self::Claude3Sonnet
132 | Self::Claude3Haiku => 1.0,
133 Self::Custom {
134 default_temperature,
135 ..
136 } => default_temperature.unwrap_or(1.0),
137 }
138 }
139
140 pub fn tool_model_id(&self) -> &str {
141 if let Self::Custom {
142 tool_override: Some(tool_override),
143 ..
144 } = self
145 {
146 tool_override
147 } else {
148 self.id()
149 }
150 }
151}
152
153pub async fn complete(
154 client: &dyn HttpClient,
155 api_url: &str,
156 api_key: &str,
157 request: Request,
158) -> Result<Response, AnthropicError> {
159 let uri = format!("{api_url}/v1/messages");
160 let request_builder = HttpRequest::builder()
161 .method(Method::POST)
162 .uri(uri)
163 .header("Anthropic-Version", "2023-06-01")
164 .header(
165 "Anthropic-Beta",
166 "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
167 )
168 .header("X-Api-Key", api_key)
169 .header("Content-Type", "application/json");
170
171 let serialized_request =
172 serde_json::to_string(&request).context("failed to serialize request")?;
173 let request = request_builder
174 .body(AsyncBody::from(serialized_request))
175 .context("failed to construct request body")?;
176
177 let mut response = client
178 .send(request)
179 .await
180 .context("failed to send request to Anthropic")?;
181 if response.status().is_success() {
182 let mut body = Vec::new();
183 response
184 .body_mut()
185 .read_to_end(&mut body)
186 .await
187 .context("failed to read response body")?;
188 let response_message: Response =
189 serde_json::from_slice(&body).context("failed to deserialize response body")?;
190 Ok(response_message)
191 } else {
192 let mut body = Vec::new();
193 response
194 .body_mut()
195 .read_to_end(&mut body)
196 .await
197 .context("failed to read response body")?;
198 let body_str =
199 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
200 Err(AnthropicError::Other(anyhow!(
201 "Failed to connect to API: {} {}",
202 response.status(),
203 body_str
204 )))
205 }
206}
207
208pub async fn stream_completion(
209 client: &dyn HttpClient,
210 api_url: &str,
211 api_key: &str,
212 request: Request,
213 low_speed_timeout: Option<Duration>,
214) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
215 stream_completion_with_rate_limit_info(client, api_url, api_key, request, low_speed_timeout)
216 .await
217 .map(|output| output.0)
218}
219
220/// https://docs.anthropic.com/en/api/rate-limits#response-headers
221#[derive(Debug)]
222pub struct RateLimitInfo {
223 pub requests_limit: usize,
224 pub requests_remaining: usize,
225 pub requests_reset: DateTime<Utc>,
226 pub tokens_limit: usize,
227 pub tokens_remaining: usize,
228 pub tokens_reset: DateTime<Utc>,
229}
230
231impl RateLimitInfo {
232 fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
233 let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
234 let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
235 let tokens_remaining =
236 get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
237 let requests_remaining =
238 get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
239 let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
240 let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
241 let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
242 let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
243
244 Ok(Self {
245 requests_limit,
246 tokens_limit,
247 requests_remaining,
248 tokens_remaining,
249 requests_reset,
250 tokens_reset,
251 })
252 }
253}
254
255fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
256 Ok(headers
257 .get(key)
258 .ok_or_else(|| anyhow!("missing header `{key}`"))?
259 .to_str()?)
260}
261
262pub async fn stream_completion_with_rate_limit_info(
263 client: &dyn HttpClient,
264 api_url: &str,
265 api_key: &str,
266 request: Request,
267 low_speed_timeout: Option<Duration>,
268) -> Result<
269 (
270 BoxStream<'static, Result<Event, AnthropicError>>,
271 Option<RateLimitInfo>,
272 ),
273 AnthropicError,
274> {
275 let request = StreamingRequest {
276 base: request,
277 stream: true,
278 };
279 let uri = format!("{api_url}/v1/messages");
280 let mut request_builder = HttpRequest::builder()
281 .method(Method::POST)
282 .uri(uri)
283 .header("Anthropic-Version", "2023-06-01")
284 .header(
285 "Anthropic-Beta",
286 "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
287 )
288 .header("X-Api-Key", api_key)
289 .header("Content-Type", "application/json");
290 if let Some(low_speed_timeout) = low_speed_timeout {
291 request_builder = request_builder.read_timeout(low_speed_timeout);
292 }
293 let serialized_request =
294 serde_json::to_string(&request).context("failed to serialize request")?;
295 let request = request_builder
296 .body(AsyncBody::from(serialized_request))
297 .context("failed to construct request body")?;
298
299 let mut response = client
300 .send(request)
301 .await
302 .context("failed to send request to Anthropic")?;
303 if response.status().is_success() {
304 let rate_limits = RateLimitInfo::from_headers(response.headers());
305 let reader = BufReader::new(response.into_body());
306 let stream = reader
307 .lines()
308 .filter_map(|line| async move {
309 match line {
310 Ok(line) => {
311 let line = line.strip_prefix("data: ")?;
312 match serde_json::from_str(line) {
313 Ok(response) => Some(Ok(response)),
314 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
315 }
316 }
317 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
318 }
319 })
320 .boxed();
321 Ok((stream, rate_limits.log_err()))
322 } else {
323 let mut body = Vec::new();
324 response
325 .body_mut()
326 .read_to_end(&mut body)
327 .await
328 .context("failed to read response body")?;
329
330 let body_str =
331 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
332
333 match serde_json::from_str::<Event>(body_str) {
334 Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
335 Ok(_) => Err(AnthropicError::Other(anyhow!(
336 "Unexpected success response while expecting an error: '{body_str}'",
337 ))),
338 Err(_) => Err(AnthropicError::Other(anyhow!(
339 "Failed to connect to API: {} {}",
340 response.status(),
341 body_str,
342 ))),
343 }
344 }
345}
346
347pub async fn extract_tool_args_from_events(
348 tool_name: String,
349 mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
350) -> Result<impl Send + Stream<Item = Result<String>>> {
351 let mut tool_use_index = None;
352 while let Some(event) = events.next().await {
353 if let Event::ContentBlockStart {
354 index,
355 content_block: ResponseContent::ToolUse { name, .. },
356 } = event?
357 {
358 if name == tool_name {
359 tool_use_index = Some(index);
360 break;
361 }
362 }
363 }
364
365 let Some(tool_use_index) = tool_use_index else {
366 return Err(anyhow!("tool not used"));
367 };
368
369 Ok(events.filter_map(move |event| {
370 let result = match event {
371 Err(error) => Some(Err(error)),
372 Ok(Event::ContentBlockDelta { index, delta }) => match delta {
373 ContentDelta::TextDelta { .. } => None,
374 ContentDelta::InputJsonDelta { partial_json } => {
375 if index == tool_use_index {
376 Some(Ok(partial_json))
377 } else {
378 None
379 }
380 }
381 },
382 _ => None,
383 };
384
385 async move { result }
386 }))
387}
388
389#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
390#[serde(rename_all = "lowercase")]
391pub enum CacheControlType {
392 Ephemeral,
393}
394
395#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
396pub struct CacheControl {
397 #[serde(rename = "type")]
398 pub cache_type: CacheControlType,
399}
400
401#[derive(Debug, Serialize, Deserialize)]
402pub struct Message {
403 pub role: Role,
404 pub content: Vec<RequestContent>,
405}
406
407#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
408#[serde(rename_all = "lowercase")]
409pub enum Role {
410 User,
411 Assistant,
412}
413
414#[derive(Debug, Serialize, Deserialize)]
415#[serde(tag = "type")]
416pub enum RequestContent {
417 #[serde(rename = "text")]
418 Text {
419 text: String,
420 #[serde(skip_serializing_if = "Option::is_none")]
421 cache_control: Option<CacheControl>,
422 },
423 #[serde(rename = "image")]
424 Image {
425 source: ImageSource,
426 #[serde(skip_serializing_if = "Option::is_none")]
427 cache_control: Option<CacheControl>,
428 },
429 #[serde(rename = "tool_use")]
430 ToolUse {
431 id: String,
432 name: String,
433 input: serde_json::Value,
434 #[serde(skip_serializing_if = "Option::is_none")]
435 cache_control: Option<CacheControl>,
436 },
437 #[serde(rename = "tool_result")]
438 ToolResult {
439 tool_use_id: String,
440 is_error: bool,
441 content: String,
442 #[serde(skip_serializing_if = "Option::is_none")]
443 cache_control: Option<CacheControl>,
444 },
445}
446
447#[derive(Debug, Serialize, Deserialize)]
448#[serde(tag = "type")]
449pub enum ResponseContent {
450 #[serde(rename = "text")]
451 Text { text: String },
452 #[serde(rename = "tool_use")]
453 ToolUse {
454 id: String,
455 name: String,
456 input: serde_json::Value,
457 },
458}
459
460#[derive(Debug, Serialize, Deserialize)]
461pub struct ImageSource {
462 #[serde(rename = "type")]
463 pub source_type: String,
464 pub media_type: String,
465 pub data: String,
466}
467
468#[derive(Debug, Serialize, Deserialize)]
469pub struct Tool {
470 pub name: String,
471 pub description: String,
472 pub input_schema: serde_json::Value,
473}
474
475#[derive(Debug, Serialize, Deserialize)]
476#[serde(tag = "type", rename_all = "lowercase")]
477pub enum ToolChoice {
478 Auto,
479 Any,
480 Tool { name: String },
481}
482
483#[derive(Debug, Serialize, Deserialize)]
484pub struct Request {
485 pub model: String,
486 pub max_tokens: u32,
487 pub messages: Vec<Message>,
488 #[serde(default, skip_serializing_if = "Vec::is_empty")]
489 pub tools: Vec<Tool>,
490 #[serde(default, skip_serializing_if = "Option::is_none")]
491 pub tool_choice: Option<ToolChoice>,
492 #[serde(default, skip_serializing_if = "Option::is_none")]
493 pub system: Option<String>,
494 #[serde(default, skip_serializing_if = "Option::is_none")]
495 pub metadata: Option<Metadata>,
496 #[serde(default, skip_serializing_if = "Vec::is_empty")]
497 pub stop_sequences: Vec<String>,
498 #[serde(default, skip_serializing_if = "Option::is_none")]
499 pub temperature: Option<f32>,
500 #[serde(default, skip_serializing_if = "Option::is_none")]
501 pub top_k: Option<u32>,
502 #[serde(default, skip_serializing_if = "Option::is_none")]
503 pub top_p: Option<f32>,
504}
505
506#[derive(Debug, Serialize, Deserialize)]
507struct StreamingRequest {
508 #[serde(flatten)]
509 pub base: Request,
510 pub stream: bool,
511}
512
513#[derive(Debug, Serialize, Deserialize)]
514pub struct Metadata {
515 pub user_id: Option<String>,
516}
517
518#[derive(Debug, Serialize, Deserialize)]
519pub struct Usage {
520 #[serde(default, skip_serializing_if = "Option::is_none")]
521 pub input_tokens: Option<u32>,
522 #[serde(default, skip_serializing_if = "Option::is_none")]
523 pub output_tokens: Option<u32>,
524 #[serde(default, skip_serializing_if = "Option::is_none")]
525 pub cache_creation_input_tokens: Option<u32>,
526 #[serde(default, skip_serializing_if = "Option::is_none")]
527 pub cache_read_input_tokens: Option<u32>,
528}
529
530#[derive(Debug, Serialize, Deserialize)]
531pub struct Response {
532 pub id: String,
533 #[serde(rename = "type")]
534 pub response_type: String,
535 pub role: Role,
536 pub content: Vec<ResponseContent>,
537 pub model: String,
538 #[serde(default, skip_serializing_if = "Option::is_none")]
539 pub stop_reason: Option<String>,
540 #[serde(default, skip_serializing_if = "Option::is_none")]
541 pub stop_sequence: Option<String>,
542 pub usage: Usage,
543}
544
545#[derive(Debug, Serialize, Deserialize)]
546#[serde(tag = "type")]
547pub enum Event {
548 #[serde(rename = "message_start")]
549 MessageStart { message: Response },
550 #[serde(rename = "content_block_start")]
551 ContentBlockStart {
552 index: usize,
553 content_block: ResponseContent,
554 },
555 #[serde(rename = "content_block_delta")]
556 ContentBlockDelta { index: usize, delta: ContentDelta },
557 #[serde(rename = "content_block_stop")]
558 ContentBlockStop { index: usize },
559 #[serde(rename = "message_delta")]
560 MessageDelta { delta: MessageDelta, usage: Usage },
561 #[serde(rename = "message_stop")]
562 MessageStop,
563 #[serde(rename = "ping")]
564 Ping,
565 #[serde(rename = "error")]
566 Error { error: ApiError },
567}
568
569#[derive(Debug, Serialize, Deserialize)]
570#[serde(tag = "type")]
571pub enum ContentDelta {
572 #[serde(rename = "text_delta")]
573 TextDelta { text: String },
574 #[serde(rename = "input_json_delta")]
575 InputJsonDelta { partial_json: String },
576}
577
578#[derive(Debug, Serialize, Deserialize)]
579pub struct MessageDelta {
580 pub stop_reason: Option<String>,
581 pub stop_sequence: Option<String>,
582}
583
584#[derive(Error, Debug)]
585pub enum AnthropicError {
586 #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
587 ApiError(ApiError),
588 #[error("{0}")]
589 Other(#[from] anyhow::Error),
590}
591
592#[derive(Debug, Serialize, Deserialize)]
593pub struct ApiError {
594 #[serde(rename = "type")]
595 pub error_type: String,
596 pub message: String,
597}
598
599/// An Anthropic API error code.
600/// https://docs.anthropic.com/en/api/errors#http-errors
601#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
602#[strum(serialize_all = "snake_case")]
603pub enum ApiErrorCode {
604 /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
605 InvalidRequestError,
606 /// 401 - `authentication_error`: There's an issue with your API key.
607 AuthenticationError,
608 /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
609 PermissionError,
610 /// 404 - `not_found_error`: The requested resource was not found.
611 NotFoundError,
612 /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
613 RequestTooLarge,
614 /// 429 - `rate_limit_error`: Your account has hit a rate limit.
615 RateLimitError,
616 /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
617 ApiError,
618 /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
619 OverloadedError,
620}
621
622impl ApiError {
623 pub fn code(&self) -> Option<ApiErrorCode> {
624 ApiErrorCode::from_str(&self.error_type).ok()
625 }
626
627 pub fn is_rate_limit_error(&self) -> bool {
628 matches!(self.error_type.as_str(), "rate_limit_error")
629 }
630}