1mod supported_countries;
2
3use anyhow::{anyhow, Context, Result};
4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
6use isahc::config::Configurable;
7use serde::{Deserialize, Serialize};
8use std::str::FromStr;
9use std::time::Duration;
10use strum::{EnumIter, EnumString};
11use thiserror::Error;
12
13pub use supported_countries::*;
14
/// Base URL of the Anthropic API, without a trailing slash.
///
/// Endpoint paths such as `/v1/messages` are appended to this value.
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
16
/// The Anthropic models known to this crate, plus a user-defined escape hatch.
///
/// Built-in variants serialize to their short name and additionally accept the
/// dated identifier as an alias when deserializing.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    /// Claude 3.5 Sonnet; this is the default model.
    #[default]
    #[serde(rename = "claude-3-5-sonnet", alias = "claude-3-5-sonnet-20240620")]
    Claude3_5Sonnet,
    /// Claude 3 Opus.
    #[serde(rename = "claude-3-opus", alias = "claude-3-opus-20240229")]
    Claude3Opus,
    /// Claude 3 Sonnet.
    #[serde(rename = "claude-3-sonnet", alias = "claude-3-sonnet-20240229")]
    Claude3Sonnet,
    /// Claude 3 Haiku.
    #[serde(rename = "claude-3-haiku", alias = "claude-3-haiku-20240307")]
    Claude3Haiku,
    /// A model described entirely by the caller.
    #[serde(rename = "custom")]
    Custom {
        /// Identifier of the model; also used as its display name.
        name: String,
        /// Token limit reported by `max_token_count` for this model.
        max_tokens: usize,
        /// Override this model with a different Anthropic model for tool calls.
        tool_override: Option<String>,
    },
}
37
38impl Model {
39 pub fn from_id(id: &str) -> Result<Self> {
40 if id.starts_with("claude-3-5-sonnet") {
41 Ok(Self::Claude3_5Sonnet)
42 } else if id.starts_with("claude-3-opus") {
43 Ok(Self::Claude3Opus)
44 } else if id.starts_with("claude-3-sonnet") {
45 Ok(Self::Claude3Sonnet)
46 } else if id.starts_with("claude-3-haiku") {
47 Ok(Self::Claude3Haiku)
48 } else {
49 Err(anyhow!("invalid model id"))
50 }
51 }
52
53 pub fn id(&self) -> &str {
54 match self {
55 Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
56 Model::Claude3Opus => "claude-3-opus-20240229",
57 Model::Claude3Sonnet => "claude-3-sonnet-20240229",
58 Model::Claude3Haiku => "claude-3-haiku-20240307",
59 Self::Custom { name, .. } => name,
60 }
61 }
62
63 pub fn display_name(&self) -> &str {
64 match self {
65 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
66 Self::Claude3Opus => "Claude 3 Opus",
67 Self::Claude3Sonnet => "Claude 3 Sonnet",
68 Self::Claude3Haiku => "Claude 3 Haiku",
69 Self::Custom { name, .. } => name,
70 }
71 }
72
73 pub fn max_token_count(&self) -> usize {
74 match self {
75 Self::Claude3_5Sonnet
76 | Self::Claude3Opus
77 | Self::Claude3Sonnet
78 | Self::Claude3Haiku => 200_000,
79 Self::Custom { max_tokens, .. } => *max_tokens,
80 }
81 }
82
83 pub fn tool_model_id(&self) -> &str {
84 if let Self::Custom {
85 tool_override: Some(tool_override),
86 ..
87 } = self
88 {
89 tool_override
90 } else {
91 self.id()
92 }
93 }
94}
95
96pub async fn complete(
97 client: &dyn HttpClient,
98 api_url: &str,
99 api_key: &str,
100 request: Request,
101) -> Result<Response, AnthropicError> {
102 let uri = format!("{api_url}/v1/messages");
103 let request_builder = HttpRequest::builder()
104 .method(Method::POST)
105 .uri(uri)
106 .header("Anthropic-Version", "2023-06-01")
107 .header("Anthropic-Beta", "tools-2024-04-04")
108 .header("X-Api-Key", api_key)
109 .header("Content-Type", "application/json");
110
111 let serialized_request =
112 serde_json::to_string(&request).context("failed to serialize request")?;
113 let request = request_builder
114 .body(AsyncBody::from(serialized_request))
115 .context("failed to construct request body")?;
116
117 let mut response = client
118 .send(request)
119 .await
120 .context("failed to send request to Anthropic")?;
121 if response.status().is_success() {
122 let mut body = Vec::new();
123 response
124 .body_mut()
125 .read_to_end(&mut body)
126 .await
127 .context("failed to read response body")?;
128 let response_message: Response =
129 serde_json::from_slice(&body).context("failed to deserialize response body")?;
130 Ok(response_message)
131 } else {
132 let mut body = Vec::new();
133 response
134 .body_mut()
135 .read_to_end(&mut body)
136 .await
137 .context("failed to read response body")?;
138 let body_str =
139 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
140 Err(AnthropicError::Other(anyhow!(
141 "Failed to connect to API: {} {}",
142 response.status(),
143 body_str
144 )))
145 }
146}
147
148pub async fn stream_completion(
149 client: &dyn HttpClient,
150 api_url: &str,
151 api_key: &str,
152 request: Request,
153 low_speed_timeout: Option<Duration>,
154) -> Result<BoxStream<'static, Result<Event, AnthropicError>>, AnthropicError> {
155 let request = StreamingRequest {
156 base: request,
157 stream: true,
158 };
159 let uri = format!("{api_url}/v1/messages");
160 let mut request_builder = HttpRequest::builder()
161 .method(Method::POST)
162 .uri(uri)
163 .header("Anthropic-Version", "2023-06-01")
164 .header("Anthropic-Beta", "tools-2024-04-04")
165 .header("X-Api-Key", api_key)
166 .header("Content-Type", "application/json");
167 if let Some(low_speed_timeout) = low_speed_timeout {
168 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
169 }
170 let serialized_request =
171 serde_json::to_string(&request).context("failed to serialize request")?;
172 let request = request_builder
173 .body(AsyncBody::from(serialized_request))
174 .context("failed to construct request body")?;
175
176 let mut response = client
177 .send(request)
178 .await
179 .context("failed to send request to Anthropic")?;
180 if response.status().is_success() {
181 let reader = BufReader::new(response.into_body());
182 Ok(reader
183 .lines()
184 .filter_map(|line| async move {
185 match line {
186 Ok(line) => {
187 let line = line.strip_prefix("data: ")?;
188 match serde_json::from_str(line) {
189 Ok(response) => Some(Ok(response)),
190 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
191 }
192 }
193 Err(error) => Some(Err(AnthropicError::Other(anyhow!(error)))),
194 }
195 })
196 .boxed())
197 } else {
198 let mut body = Vec::new();
199 response
200 .body_mut()
201 .read_to_end(&mut body)
202 .await
203 .context("failed to read response body")?;
204
205 let body_str =
206 std::str::from_utf8(&body).context("failed to parse response body as UTF-8")?;
207
208 match serde_json::from_str::<Event>(body_str) {
209 Ok(Event::Error { error }) => Err(AnthropicError::ApiError(error)),
210 Ok(_) => Err(AnthropicError::Other(anyhow!(
211 "Unexpected success response while expecting an error: '{body_str}'",
212 ))),
213 Err(_) => Err(AnthropicError::Other(anyhow!(
214 "Failed to connect to API: {} {}",
215 response.status(),
216 body_str,
217 ))),
218 }
219 }
220}
221
222pub fn extract_text_from_events(
223 response: impl Stream<Item = Result<Event, AnthropicError>>,
224) -> impl Stream<Item = Result<String, AnthropicError>> {
225 response.filter_map(|response| async move {
226 match response {
227 Ok(response) => match response {
228 Event::ContentBlockStart { content_block, .. } => match content_block {
229 Content::Text { text } => Some(Ok(text)),
230 _ => None,
231 },
232 Event::ContentBlockDelta { delta, .. } => match delta {
233 ContentDelta::TextDelta { text } => Some(Ok(text)),
234 _ => None,
235 },
236 Event::Error { error } => Some(Err(AnthropicError::ApiError(error))),
237 _ => None,
238 },
239 Err(error) => Some(Err(error)),
240 }
241 })
242}
243
/// A single conversation message: who produced it and its content blocks.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// The author of the message.
    pub role: Role,
    /// The content blocks that make up the message, in order.
    pub content: Vec<Content>,
}
249
/// The author of a message.
///
/// Variants are (de)serialized in lowercase: `"user"` and `"assistant"`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Serialized as `"user"`.
    User,
    /// Serialized as `"assistant"`.
    Assistant,
}
256
257#[derive(Debug, Serialize, Deserialize)]
258#[serde(tag = "type")]
259pub enum Content {
260 #[serde(rename = "text")]
261 Text { text: String },
262 #[serde(rename = "image")]
263 Image { source: ImageSource },
264 #[serde(rename = "tool_use")]
265 ToolUse {
266 id: String,
267 name: String,
268 input: serde_json::Value,
269 },
270 #[serde(rename = "tool_result")]
271 ToolResult {
272 tool_use_id: String,
273 content: String,
274 },
275}
276
/// The source of an image content block.
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
    /// Kind of source; (de)serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub source_type: String,
    /// Media type of the image data.
    // NOTE(review): presumably a MIME type such as `image/png`; confirm against the API reference.
    pub media_type: String,
    /// The image payload, as a string.
    // NOTE(review): presumably base64-encoded; confirm against the API reference.
    pub data: String,
}
284
/// A tool definition that can be attached to a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    /// Name of the tool.
    pub name: String,
    /// Description of the tool.
    pub description: String,
    /// Arbitrary JSON describing the tool's input.
    // NOTE(review): presumably a JSON Schema object; confirm against the API reference.
    pub input_schema: serde_json::Value,
}
291
/// The tool-selection setting sent as a request's `tool_choice`.
///
/// Internally tagged: serialized as an object whose `type` field is the
/// lowercase variant name (`auto`, `any`, `tool`).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    /// Serialized as `{"type": "auto"}`.
    Auto,
    /// Serialized as `{"type": "any"}`.
    Any,
    /// Serialized as `{"type": "tool", "name": ...}`.
    Tool { name: String },
}
299
/// The JSON body sent to the messages endpoint.
///
/// Optional fields and empty lists are left out of the serialized JSON, and
/// default to `None` / empty when absent during deserialization.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Identifier of the model to use.
    pub model: String,
    /// Value sent as `max_tokens`.
    // NOTE(review): presumably the cap on generated tokens; confirm against the API reference.
    pub max_tokens: u32,
    /// The conversation so far.
    pub messages: Vec<Message>,
    /// Tool definitions; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    /// Tool-selection setting; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    /// System prompt; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// Request metadata; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    /// Stop sequences; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    /// Sampling temperature; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Top-k sampling parameter; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    /// Top-p sampling parameter; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
322
/// A [`Request`] plus a `stream` flag, used for the streaming endpoint call.
///
/// `base` is flattened, so its fields appear at the top level of the JSON
/// object alongside `stream`.
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
    /// The wrapped request; its fields are inlined into this object's JSON.
    #[serde(flatten)]
    pub base: Request,
    /// Value of the `stream` field in the request JSON.
    pub stream: bool,
}
329
/// Metadata that can be attached to a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    /// Optional user identifier; serialized as `null` when `None`, since no
    /// skip attribute is set on this field.
    pub user_id: Option<String>,
}
334
/// Token usage counters.
///
/// Either counter may be missing from the JSON; a missing counter
/// deserializes to `None` and a `None` counter is not serialized.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    /// Input token count, when present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    /// Output token count, when present.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
}
342
/// A message object as returned by the API.
///
/// Returned directly by `complete` and carried by `Event::MessageStart`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    /// Identifier of the response.
    pub id: String,
    /// Object type; (de)serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub response_type: String,
    /// Author of the message.
    pub role: Role,
    /// Content blocks of the message.
    pub content: Vec<Content>,
    /// Identifier of the model named in the response.
    pub model: String,
    /// Stop reason, when present; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    /// Stop sequence, when present; omitted from the JSON when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    /// Token usage counters for the response.
    pub usage: Usage,
}
357
358#[derive(Debug, Serialize, Deserialize)]
359#[serde(tag = "type")]
360pub enum Event {
361 #[serde(rename = "message_start")]
362 MessageStart { message: Response },
363 #[serde(rename = "content_block_start")]
364 ContentBlockStart {
365 index: usize,
366 content_block: Content,
367 },
368 #[serde(rename = "content_block_delta")]
369 ContentBlockDelta { index: usize, delta: ContentDelta },
370 #[serde(rename = "content_block_stop")]
371 ContentBlockStop { index: usize },
372 #[serde(rename = "message_delta")]
373 MessageDelta { delta: MessageDelta, usage: Usage },
374 #[serde(rename = "message_stop")]
375 MessageStop,
376 #[serde(rename = "ping")]
377 Ping,
378 #[serde(rename = "error")]
379 Error { error: ApiError },
380}
381
382#[derive(Debug, Serialize, Deserialize)]
383#[serde(tag = "type")]
384pub enum ContentDelta {
385 #[serde(rename = "text_delta")]
386 TextDelta { text: String },
387 #[serde(rename = "input_json_delta")]
388 InputJsonDelta { partial_json: String },
389}
390
/// Message-level changes delivered by `Event::MessageDelta`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
    /// Stop reason, if present.
    pub stop_reason: Option<String>,
    /// Stop sequence, if present.
    pub stop_sequence: Option<String>,
}
396
/// The error type returned by this module's request functions.
#[derive(Error, Debug)]
pub enum AnthropicError {
    /// A structured error reported by the API. The display text includes the
    /// payload's `error_type` and `message`.
    #[error("an error occurred while interacting with the Anthropic API: {error_type}: {message}", error_type = .0.error_type, message = .0.message)]
    ApiError(ApiError),
    /// Any other failure, carried as an `anyhow::Error`. The `#[from]`
    /// attribute lets `?` convert an `anyhow::Error` into this variant.
    #[error("{0}")]
    Other(#[from] anyhow::Error),
}
404
/// The error payload reported by the API.
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
    /// Error type string, e.g. `rate_limit_error`; (de)serialized under the
    /// JSON key `type`. See [`ApiErrorCode`] for the values this crate recognizes.
    #[serde(rename = "type")]
    pub error_type: String,
    /// Error message text.
    pub message: String,
}
411
/// An Anthropic API error code.
///
/// Parsed from the snake_case error type string (for example
/// `rate_limit_error`) through the derived `FromStr` implementation.
///
/// See <https://docs.anthropic.com/en/api/errors#http-errors>.
#[derive(Debug, PartialEq, Eq, Clone, Copy, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum ApiErrorCode {
    /// 400 - `invalid_request_error`: There was an issue with the format or content of your request.
    InvalidRequestError,
    /// 401 - `authentication_error`: There's an issue with your API key.
    AuthenticationError,
    /// 403 - `permission_error`: Your API key does not have permission to use the specified resource.
    PermissionError,
    /// 404 - `not_found_error`: The requested resource was not found.
    NotFoundError,
    /// 413 - `request_too_large`: Request exceeds the maximum allowed number of bytes.
    RequestTooLarge,
    /// 429 - `rate_limit_error`: Your account has hit a rate limit.
    RateLimitError,
    /// 500 - `api_error`: An unexpected error has occurred internal to Anthropic's systems.
    ApiError,
    /// 529 - `overloaded_error`: Anthropic's API is temporarily overloaded.
    OverloadedError,
}
434
435impl ApiError {
436 pub fn code(&self) -> Option<ApiErrorCode> {
437 ApiErrorCode::from_str(&self.error_type).ok()
438 }
439
440 pub fn is_rate_limit_error(&self) -> bool {
441 match self.error_type.as_str() {
442 "rate_limit_error" => true,
443 _ => false,
444 }
445 }
446}