1mod supported_countries;
2
3use std::{pin::Pin, str::FromStr};
4
5use anyhow::{anyhow, Context, Result};
6use chrono::{DateTime, Utc};
7use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
8use http_client::http::{HeaderMap, HeaderValue};
9use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
10use serde::{Deserialize, Serialize};
11use strum::{EnumIter, EnumString};
12use thiserror::Error;
13use util::ResultExt as _;
14
15pub use supported_countries::*;
16
17pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
18
/// Prompt-caching settings associated with a model.
///
/// Built-in models obtain theirs from `Model::cache_configuration`; a
/// `Model::Custom` entry may carry its own value.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct AnthropicModelCacheConfiguration {
    // NOTE(review): the consumer of these fields is not in this file, so the
    // notes below are inferred from the field names — confirm against the caller.

    // Presumably the minimum total token count before caching is attempted.
    // Built-in caching models use 2_048.
    pub min_total_token: usize,
    // Presumably enables speculative cache population. Built-in caching models use `true`.
    pub should_speculate: bool,
    // Presumably the maximum number of cache anchors per request.
    // Built-in caching models use 4.
    pub max_cache_anchors: usize,
}
26
/// The Anthropic models known to this crate, plus a user-defined `Custom` escape hatch.
///
/// Each built-in variant serializes under its short name (e.g. `claude-3-opus`) and
/// additionally deserializes from the corresponding `-latest` alias.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    /// Claude 3.5 Sonnet; this is the `Default` model.
    #[default]
    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-latest")]
    Claude3_5Sonnet,
    /// Claude 3 Opus.
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-latest")]
    Claude3Opus,
    /// Claude 3 Sonnet.
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-latest")]
    Claude3Sonnet,
    /// Claude 3 Haiku.
    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-latest")]
    Claude3Haiku,
    /// A model described entirely by its fields rather than by a built-in variant.
    #[serde(rename = "custom")]
    Custom {
        /// The model identifier; returned verbatim by `Model::id`.
        name: String,
        /// The value reported by `Model::max_token_count`.
        max_tokens: usize,
        /// The name displayed in the UI, such as in the assistant panel model dropdown menu.
        display_name: Option<String>,
        /// Override this model with a different Anthropic model for tool calls.
        tool_override: Option<String>,
        /// Indicates whether this custom model supports caching.
        cache_configuration: Option<AnthropicModelCacheConfiguration>,
        /// Output-token limit; `Model::max_output_tokens` falls back to 4_096 when unset.
        max_output_tokens: Option<u32>,
        /// Sampling temperature; `Model::default_temperature` falls back to 1.0 when unset.
        default_temperature: Option<f32>,
    },
}
53
54impl Model {
55 pub fn from_id(id: &str) -> Result<Self> {
56 if id.starts_with("claude-3-5-sonnet") {
57 Ok(Self::Claude3_5Sonnet)
58 } else if id.starts_with("claude-3-opus") {
59 Ok(Self::Claude3Opus)
60 } else if id.starts_with("claude-3-sonnet") {
61 Ok(Self::Claude3Sonnet)
62 } else if id.starts_with("claude-3-haiku") {
63 Ok(Self::Claude3Haiku)
64 } else {
65 Err(anyhow!("invalid model id"))
66 }
67 }
68
69 pub fn id(&self) -> &str {
70 match self {
71 Model::Claude3_5Sonnet => "claude-3-5-sonnet-latest",
72 Model::Claude3Opus => "claude-3-opus-latest",
73 Model::Claude3Sonnet => "claude-3-sonnet-latest",
74 Model::Claude3Haiku => "claude-3-haiku-latest",
75 Self::Custom { name, .. } => name,
76 }
77 }
78
79 pub fn display_name(&self) -> &str {
80 match self {
81 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
82 Self::Claude3Opus => "Claude 3 Opus",
83 Self::Claude3Sonnet => "Claude 3 Sonnet",
84 Self::Claude3Haiku => "Claude 3 Haiku",
85 Self::Custom {
86 name, display_name, ..
87 } => display_name.as_ref().unwrap_or(name),
88 }
89 }
90
91 pub fn cache_configuration(&self) -> Option<AnthropicModelCacheConfiguration> {
92 match self {
93 Self::Claude3_5Sonnet | Self::Claude3Haiku => Some(AnthropicModelCacheConfiguration {
94 min_total_token: 2_048,
95 should_speculate: true,
96 max_cache_anchors: 4,
97 }),
98 Self::Custom {
99 cache_configuration,
100 ..
101 } => cache_configuration.clone(),
102 _ => None,
103 }
104 }
105
106 pub fn max_token_count(&self) -> usize {
107 match self {
108 Self::Claude3_5Sonnet
109 | Self::Claude3Opus
110 | Self::Claude3Sonnet
111 | Self::Claude3Haiku => 200_000,
112 Self::Custom { max_tokens, .. } => *max_tokens,
113 }
114 }
115
116 pub fn max_output_tokens(&self) -> u32 {
117 match self {
118 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 4_096,
119 Self::Claude3_5Sonnet => 8_192,
120 Self::Custom {
121 max_output_tokens, ..
122 } => max_output_tokens.unwrap_or(4_096),
123 }
124 }
125
126 pub fn default_temperature(&self) -> f32 {
127 match self {
128 Self::Claude3_5Sonnet
129 | Self::Claude3Opus
130 | Self::Claude3Sonnet
131 | Self::Claude3Haiku => 1.0,
132 Self::Custom {
133 default_temperature,
134 ..
135 } => default_temperature.unwrap_or(1.0),
136 }
137 }
138
139 pub fn tool_model_id(&self) -> &str {
140 if let Self::Custom {
141 tool_override: Some(tool_override),
142 ..
143 } = self
144 {
145 tool_override
146 } else {
147 self.id()
148 }
149 }
150}
151
152pub async fn complete(
153 client: &dyn HttpClient,
154 api_url: &str,
155 api_key: &str,
156 request: Request,
157) -> Result<Response, AnthropicError> {
158 let uri = format!("{api_url}/v1/messages");
159 let request_builder = HttpRequest::builder()
160 .method(Method::POST)
161 .uri(uri)
162 .header("Anthropic-Version", "2023-06-01")
163 .header("Anthropic-Beta", "prompt-caching-2024-07-31")
164 .header("X-Api-Key", api_key)
165 .header("Content-Type", "application/json");
166
167 let serialized_request =
168 serde_json::to_string(&request).context("failed to serialize request")?;
169 let request = request_builder
170 .body(AsyncBody::from(serialized_request))
171 .context("failed to construct request body")?;
172
173 let mut response = client
174 .send(request)
175 .await
176 .context("failed to send request to Anthropic")?;
177 if response.status().is_success() {
178 let mut body = Vec::new();
179 response
180 .body_mut()
181 .read_to_end(&mut body)
182 .await
183 .context("failed to read response body")?;
184 let response_message: Response =
185 serde_json::from_slice(&body).context("failed to deserialize response body")?;
186 Ok(response_message)
187 } else {
188 let mut body = Vec::new();
189 response
190 .body_mut()
191 .read_to_end(&mut body)
192 .await
193 .context("failed to read response body")?;
194 let body_str =
195 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
196 Err(AnthropicError::Other(anyhow!(
197 "Failed to connect to API: {} {}",
198 response.status(),
199 body_str
200 )))
201 }
202}
203
204pub async fn stream_completion(
205 client: &dyn HttpClient,
206 api_url: &str,
207 api_key: &str,
208 request: Request,
209) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
210 stream_completion_with_rate_limit_info(client, api_url, api_key, request)
211 .await
212 .map(|output| output.0)
213}
214
/// Rate-limit state read from the `anthropic-ratelimit-*` response headers.
///
/// Each field is populated from the header of the matching name
/// (e.g. `requests_limit` from `anthropic-ratelimit-requests-limit`).
///
/// https://docs.anthropic.com/en/api/rate-limits#response-headers
#[derive(Debug)]
pub struct RateLimitInfo {
    /// Value of `anthropic-ratelimit-requests-limit`.
    pub requests_limit: usize,
    /// Value of `anthropic-ratelimit-requests-remaining`.
    pub requests_remaining: usize,
    /// Value of `anthropic-ratelimit-requests-reset`, parsed as RFC 3339 and converted to UTC.
    pub requests_reset: DateTime<Utc>,
    /// Value of `anthropic-ratelimit-tokens-limit`.
    pub tokens_limit: usize,
    /// Value of `anthropic-ratelimit-tokens-remaining`.
    pub tokens_remaining: usize,
    /// Value of `anthropic-ratelimit-tokens-reset`, parsed as RFC 3339 and converted to UTC.
    pub tokens_reset: DateTime<Utc>,
}
225
226impl RateLimitInfo {
227 fn from_headers(headers: &HeaderMap<HeaderValue>) -> Result<Self> {
228 let tokens_limit = get_header("anthropic-ratelimit-tokens-limit", headers)?.parse()?;
229 let requests_limit = get_header("anthropic-ratelimit-requests-limit", headers)?.parse()?;
230 let tokens_remaining =
231 get_header("anthropic-ratelimit-tokens-remaining", headers)?.parse()?;
232 let requests_remaining =
233 get_header("anthropic-ratelimit-requests-remaining", headers)?.parse()?;
234 let requests_reset = get_header("anthropic-ratelimit-requests-reset", headers)?;
235 let tokens_reset = get_header("anthropic-ratelimit-tokens-reset", headers)?;
236 let requests_reset = DateTime::parse_from_rfc3339(requests_reset)?.to_utc();
237 let tokens_reset = DateTime::parse_from_rfc3339(tokens_reset)?.to_utc();
238
239 Ok(Self {
240 requests_limit,
241 tokens_limit,
242 requests_remaining,
243 tokens_remaining,
244 requests_reset,
245 tokens_reset,
246 })
247 }
248}
249
250fn get_header<'a>(key: &str, headers: &'a HeaderMap) -> Result<&'a str, anyhow::Error> {
251 Ok(headers
252 .get(key)
253 .ok_or_else(|| anyhow!("missing header `{key}`"))?
254 .to_str()?)
255}
256
257pub async fn stream_completion_with_rate_limit_info(
258 client: &dyn HttpClient,
259 api_url: &str,
260 api_key: &str,
261 request: Request,
262) -> Result<
263 (
264 BoxStream<'static, Result<Event, AnthropicError>>,
265 Option<RateLimitInfo>,
266 ),
267 AnthropicError,
268> {
269 let request = StreamingRequest {
270 base: request,
271 stream: true,
272 };
273 let uri = format!("{api_url}/v1/messages");
274 let request_builder = HttpRequest::builder()
275 .method(Method::POST)
276 .uri(uri)
277 .header("Anthropic-Version", "2023-06-01")
278 .header(
279 "Anthropic-Beta",
280 "tools-2024-04-04,prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15",
281 )
282 .header("X-Api-Key", api_key)
283 .header("Content-Type", "application/json");
284 let serialized_request =
285 serde_json::to_string(&request).context("failed to serialize request")?;
286 let request = request_builder
287 .body(AsyncBody::from(serialized_request))
288 .context("failed to construct request body")?;
289
290 let mut response = client
291 .send(request)
292 .await
293 .context("failed to send request to Anthropic")?;
294 if response.status().is_success() {
295 let rate_limits = RateLimitInfo::from_headers(response.headers());
296 let reader = BufReader::new(response.into_body());
297 let stream = reader
298 .lines()
299 .filter_map(|line| async move {
300 match line {
301 Ok(line) => {
302 let line = line.strip_prefix("data: ")?;
303 match serde_json::from_str(line) {
304 Ok(response) => Some(Ok(response)),
305 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
306 }
307 }
308 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
309 }
310 })
311 .boxed();
312 Ok((stream, rate_limits.log_err()))
313 } else {
314 let mut body = Vec::new();
315 response
316 .body_mut()
317 .read_to_end(&mut body)
318 .await
319 .context("failed to read response body")?;
320
321 let body_str =
322 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
323
324 match serde_json::from_str::<Event>(body_str) {
325 Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
326 Ok(_) => Err(AnthropicError::Other(anyhow!(
327 "Unexpected success response while expecting an error: '{body_str}'",
328 ))),
329 Err(_) => Err(AnthropicError::Other(anyhow!(
330 "Failed to connect to API: {} {}",
331 response.status(),
332 body_str,
333 ))),
334 }
335 }
336}
337
338pub async fn extract_tool_args_from_events(
339 tool_name: String,
340 mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
341) -> Result<impl Send + Stream<Item = Result<String>>> {
342 let mut tool_use_index = None;
343 while let Some(event) = events.next().await {
344 if let Event::ContentBlockStart {
345 index,
346 content_block: ResponseContent::ToolUse { name, .. },
347 } = event?
348 {
349 if name == tool_name {
350 tool_use_index = Some(index);
351 break;
352 }
353 }
354 }
355
356 let Some(tool_use_index) = tool_use_index else {
357 return Err(anyhow!("tool not used"));
358 };
359
360 Ok(events.filter_map(move |event| {
361 let result = match event {
362 Err(error) => Some(Err(error)),
363 Ok(Event::ContentBlockDelta { index, delta }) => match delta {
364 ContentDelta::TextDelta { .. } => None,
365 ContentDelta::InputJsonDelta { partial_json } => {
366 if index == tool_use_index {
367 Some(Ok(partial_json))
368 } else {
369 None
370 }
371 }
372 },
373 _ => None,
374 };
375
376 async move { result }
377 }))
378}
379
/// The kind of cache control attached to a content block.
///
/// Serialized in lowercase, so `Ephemeral` becomes `"ephemeral"`.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
#[serde(rename_all = "lowercase")]
pub enum CacheControlType {
    Ephemeral,
}
385
/// Cache-control marker that can be attached to a [`RequestContent`] block.
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
pub struct CacheControl {
    /// Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub cache_type: CacheControlType,
}
391
/// A single message in a [`Request`]: who sent it and its content blocks.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// The author of the message.
    pub role: Role,
    /// The content blocks that make up the message.
    pub content: Vec<RequestContent>,
}
397
/// The author of a message; serialized in lowercase (`"user"` / `"assistant"`).
#[derive(Debug, Serialize, Deserialize, Eq, PartialEq, Hash)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
404
/// A content block sent as part of a request [`Message`].
///
/// Internally tagged: the variant is selected by the JSON `type` field
/// (`text`, `image`, `tool_use`, or `tool_result`). In every variant,
/// `cache_control` is omitted from the serialized output when it is `None`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum RequestContent {
    /// Plain text content.
    #[serde(rename = "text")]
    Text {
        text: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// Image content described by an [`ImageSource`].
    #[serde(rename = "image")]
    Image {
        source: ImageSource,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// A tool invocation, carrying the tool's name and its JSON input.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
    /// The result of a tool invocation.
    // NOTE(review): `tool_use_id` presumably refers to the `id` of a `ToolUse` block — confirm.
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        is_error: bool,
        content: String,
        #[serde(skip_serializing_if = "Option::is_none")]
        cache_control: Option<CacheControl>,
    },
}
437
/// A content block received from the API, in a [`Response`] or in a
/// `content_block_start` [`Event`].
///
/// Internally tagged: the variant is selected by the JSON `type` field
/// (`text` or `tool_use`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseContent {
    /// Plain text content.
    #[serde(rename = "text")]
    Text { text: String },
    /// A tool invocation, carrying the tool's name and its JSON input.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
450
/// The payload of a [`RequestContent::Image`] block.
// NOTE(review): the accepted values are not constrained in this file. Presumably
// `source_type` is an encoding name and `data` the encoded image bytes — confirm
// against the API documentation.
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
    /// Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub source_type: String,
    pub media_type: String,
    pub data: String,
}
458
/// A tool definition sent in a [`Request`]'s `tools` list.
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    /// The tool's name.
    pub name: String,
    /// A description of the tool.
    pub description: String,
    /// Arbitrary JSON describing the tool's input; not validated by this type.
    pub input_schema: serde_json::Value,
}
465
/// The `tool_choice` setting of a [`Request`].
///
/// Internally tagged by the JSON `type` field with lowercase variant names
/// (`auto`, `any`, `tool`); the `Tool` variant additionally carries a `name`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    Tool { name: String },
}
473
/// The JSON body posted to the `/v1/messages` endpoint.
///
/// Optional fields are left out of the serialized body when they are `None`
/// (or empty, for the `Vec` fields) and default to that state when absent on
/// deserialization.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// The model identifier.
    pub model: String,
    pub max_tokens: u32,
    /// The conversation messages.
    pub messages: Vec<Message>,
    /// Tool definitions; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// The system prompt, if any.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    /// Stop sequences; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
496
/// A [`Request`] with an added top-level `stream` flag.
///
/// `base` is flattened, so its fields are serialized next to `stream` rather
/// than nested under a `base` key.
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
    #[serde(flatten)]
    pub base: Request,
    /// Set to `true` by `stream_completion_with_rate_limit_info`.
    pub stream: bool,
}
503
/// Optional metadata attached to a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    // No `skip_serializing_if` here: a `None` value is serialized as `null`.
    pub user_id: Option<String>,
}
508
/// Token usage counters reported by the API in a [`Response`] and in
/// `message_delta` events.
///
/// Every counter is optional: missing fields deserialize to `None`, and `None`
/// fields are omitted when serializing.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_creation_input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub cache_read_input_tokens: Option<u32>,
}
520
/// A message returned by the API: the result of [`complete`], and the payload
/// of a `message_start` [`Event`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    pub id: String,
    /// Serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub response_type: String,
    pub role: Role,
    /// The content blocks of the message.
    pub content: Vec<ResponseContent>,
    pub model: String,
    /// Optional; `None` when absent from the payload.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    /// Optional; `None` when absent from the payload.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}
535
/// An event decoded from the streaming response.
///
/// Internally tagged: the variant is selected by the JSON `type` field. The
/// streaming entry point decodes one `Event` from the JSON following each
/// `data: ` line prefix, and also uses this type to decode the body of a
/// non-success response.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum Event {
    /// Carries the initial [`Response`] message.
    #[serde(rename = "message_start")]
    MessageStart { message: Response },
    /// Announces the content block at position `index`.
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: ResponseContent,
    },
    /// An incremental update to the content block at position `index`.
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: ContentDelta },
    /// Marks the end of the content block at position `index`.
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    /// A message-level update, together with usage counters.
    #[serde(rename = "message_delta")]
    MessageDelta { delta: MessageDelta, usage: Usage },
    #[serde(rename = "message_stop")]
    MessageStop,
    #[serde(rename = "ping")]
    Ping,
    /// An error reported by the API.
    #[serde(rename = "error")]
    Error { error: ApiError },
}
559
/// The payload of a `content_block_delta` [`Event`].
///
/// Internally tagged by the JSON `type` field (`text_delta` or `input_json_delta`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentDelta {
    /// A fragment of text.
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    /// A fragment of a JSON document; `extract_tool_args_from_events` yields
    /// these fragments for the selected tool-use block.
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}
568
/// The `delta` payload of a `message_delta` [`Event`].
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
574
/// Errors returned by the request functions in this module.
#[derive(Error, Debug)]
pub enum AnthropicError {
    /// A structured error reported by the API; displayed with its type and message.
    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
    ApiError(ApiError),
    /// Any other failure. `#[from]` lets `?` convert an `anyhow::Error` into this variant.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
582
/// The structured error object carried by an `error` [`Event`].
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
    /// The error's type string, serialized under the JSON key `type`.
    /// `ApiError::code` attempts to parse it into an [`ApiErrorCode`].
    #[serde(rename = "type")]
    pub error_type: String,
    /// The error message.
    pub message: String,
}
589
/// An Anthropic API error code.
///
/// Parsed via `FromStr` (derived by `strum::EnumString`) from the snake_case
/// form of each variant name; `ApiError::code` applies it to the `error_type` string.
///
/// https://docs.anthropic.com/en/api/errors#http-errors
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
    InvalidRequestError,
    /// 401 - `authentication_error`: There's an issue with your API key.
    AuthenticationError,
    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
    PermissionError,
    /// 404 - `not_found_error`: The requested resource was not found.
    NotFoundError,
    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
    RequestTooLarge,
    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
    RateLimitError,
    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
    ApiError,
    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
    OverloadedError,
}
612
613impl ApiError {
614 pub fn code(&self) -> Option<ApiErrorCode> {
615 ApiErrorCode::from_str(&self.error_type).ok()
616 }
617
618 pub fn is_rate_limit_error(&self) -> bool {
619 matches!(self.error_type.as_str(), "rate_limit_error")
620 }
621}