1mod supported_countries;
2
3use anyhow::{anyhow, Result};
4use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, Stream, StreamExt};
5use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
6use isahc::config::Configurable;
7use serde::{Deserialize, Serialize};
8use std::time::Duration;
9use strum::EnumIter;
10
11pub use supported_countries::*;
12
/// Base URL of the public Anthropic API, without a trailing slash.
///
/// Request paths such as `/v1/messages` are appended to a base URL like this one.
pub const ANTHROPIC_API_URL: &str = "https://api.anthropic.com";
14
/// An Anthropic model known to this crate.
///
/// Built-in variants serialize to their dated snapshot id (`rename`). When deserializing,
/// they also accept the undated alias (`alias`). `Claude3_5Sonnet` is the default.
/// `Custom` describes any other model by name.
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq, EnumIter)]
pub enum Model {
    #[default]
    #[serde(alias = "claude-3-5-sonnet", rename = "claude-3-5-sonnet-20240620")]
    Claude3_5Sonnet,
    #[serde(alias = "claude-3-opus", rename = "claude-3-opus-20240229")]
    Claude3Opus,
    #[serde(alias = "claude-3-sonnet", rename = "claude-3-sonnet-20240229")]
    Claude3Sonnet,
    #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
    Claude3Haiku,
    #[serde(rename = "custom")]
    Custom {
        /// Model name; used both as the API model id and as the display name.
        name: String,
        /// Value reported by `Model::max_token_count` for this model.
        max_tokens: usize,
        /// Override this model with a different Anthropic model for tool calls.
        tool_override: Option<String>,
    },
}
35
36impl Model {
37 pub fn from_id(id: &str) -> Result<Self> {
38 if id.starts_with("claude-3-5-sonnet") {
39 Ok(Self::Claude3_5Sonnet)
40 } else if id.starts_with("claude-3-opus") {
41 Ok(Self::Claude3Opus)
42 } else if id.starts_with("claude-3-sonnet") {
43 Ok(Self::Claude3Sonnet)
44 } else if id.starts_with("claude-3-haiku") {
45 Ok(Self::Claude3Haiku)
46 } else {
47 Err(anyhow!("invalid model id"))
48 }
49 }
50
51 pub fn id(&self) -> &str {
52 match self {
53 Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
54 Model::Claude3Opus => "claude-3-opus-20240229",
55 Model::Claude3Sonnet => "claude-3-sonnet-20240229",
56 Model::Claude3Haiku => "claude-3-opus-20240307",
57 Self::Custom { name, .. } => name,
58 }
59 }
60
61 pub fn display_name(&self) -> &str {
62 match self {
63 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
64 Self::Claude3Opus => "Claude 3 Opus",
65 Self::Claude3Sonnet => "Claude 3 Sonnet",
66 Self::Claude3Haiku => "Claude 3 Haiku",
67 Self::Custom { name, .. } => name,
68 }
69 }
70
71 pub fn max_token_count(&self) -> usize {
72 match self {
73 Self::Claude3_5Sonnet
74 | Self::Claude3Opus
75 | Self::Claude3Sonnet
76 | Self::Claude3Haiku => 200_000,
77 Self::Custom { max_tokens, .. } => *max_tokens,
78 }
79 }
80
81 pub fn tool_model_id(&self) -> &str {
82 if let Self::Custom {
83 tool_override: Some(tool_override),
84 ..
85 } = self
86 {
87 tool_override
88 } else {
89 self.id()
90 }
91 }
92}
93
94pub async fn complete(
95 client: &dyn HttpClient,
96 api_url: &str,
97 api_key: &str,
98 request: Request,
99) -> Result<Response> {
100 let uri = format!("{api_url}/v1/messages");
101 let request_builder = HttpRequest::builder()
102 .method(Method::POST)
103 .uri(uri)
104 .header("Anthropic-Version", "2023-06-01")
105 .header("Anthropic-Beta", "tools-2024-04-04")
106 .header("X-Api-Key", api_key)
107 .header("Content-Type", "application/json");
108
109 let serialized_request = serde_json::to_string(&request)?;
110 let request = request_builder.body(AsyncBody::from(serialized_request))?;
111
112 let mut response = client.send(request).await?;
113 if response.status().is_success() {
114 let mut body = Vec::new();
115 response.body_mut().read_to_end(&mut body).await?;
116 let response_message: Response = serde_json::from_slice(&body)?;
117 Ok(response_message)
118 } else {
119 let mut body = Vec::new();
120 response.body_mut().read_to_end(&mut body).await?;
121 let body_str = std::str::from_utf8(&body)?;
122 Err(anyhow!(
123 "Failed to connect to API: {} {}",
124 response.status(),
125 body_str
126 ))
127 }
128}
129
130pub async fn stream_completion(
131 client: &dyn HttpClient,
132 api_url: &str,
133 api_key: &str,
134 request: Request,
135 low_speed_timeout: Option<Duration>,
136) -> Result<BoxStream<'static, Result<Event>>> {
137 let request = StreamingRequest {
138 base: request,
139 stream: true,
140 };
141 let uri = format!("{api_url}/v1/messages");
142 let mut request_builder = HttpRequest::builder()
143 .method(Method::POST)
144 .uri(uri)
145 .header("Anthropic-Version", "2023-06-01")
146 .header("Anthropic-Beta", "tools-2024-04-04")
147 .header("X-Api-Key", api_key)
148 .header("Content-Type", "application/json");
149 if let Some(low_speed_timeout) = low_speed_timeout {
150 request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
151 }
152 let serialized_request = serde_json::to_string(&request)?;
153 let request = request_builder.body(AsyncBody::from(serialized_request))?;
154
155 let mut response = client.send(request).await?;
156 if response.status().is_success() {
157 let reader = BufReader::new(response.into_body());
158 Ok(reader
159 .lines()
160 .filter_map(|line| async move {
161 match line {
162 Ok(line) => {
163 let line = line.strip_prefix("data: ")?;
164 match serde_json::from_str(line) {
165 Ok(response) => Some(Ok(response)),
166 Err(error) => Some(Err(anyhow!(error))),
167 }
168 }
169 Err(error) => Some(Err(anyhow!(error))),
170 }
171 })
172 .boxed())
173 } else {
174 let mut body = Vec::new();
175 response.body_mut().read_to_end(&mut body).await?;
176
177 let body_str = std::str::from_utf8(&body)?;
178
179 match serde_json::from_str::<Event>(body_str) {
180 Ok(Event::Error { error }) => Err(api_error_to_err(error)),
181 Ok(_) => Err(anyhow!(
182 "Unexpected success response while expecting an error: '{body_str}'",
183 )),
184 Err(_) => Err(anyhow!(
185 "Failed to connect to API: {} {}",
186 response.status(),
187 body_str,
188 )),
189 }
190 }
191}
192
193pub fn extract_text_from_events(
194 response: impl Stream<Item = Result<Event>>,
195) -> impl Stream<Item = Result<String>> {
196 response.filter_map(|response| async move {
197 match response {
198 Ok(response) => match response {
199 Event::ContentBlockStart { content_block, .. } => match content_block {
200 Content::Text { text } => Some(Ok(text)),
201 _ => None,
202 },
203 Event::ContentBlockDelta { delta, .. } => match delta {
204 ContentDelta::TextDelta { text } => Some(Ok(text)),
205 _ => None,
206 },
207 Event::Error { error } => Some(Err(api_error_to_err(error))),
208 _ => None,
209 },
210 Err(error) => Some(Err(error)),
211 }
212 })
213}
214
215fn api_error_to_err(
216 ApiError {
217 error_type,
218 message,
219 }: ApiError,
220) -> anyhow::Error {
221 anyhow!("API error. Type: '{error_type}', message: '{message}'",)
222}
223
/// One conversational turn in a Messages API request.
#[derive(Debug, Serialize, Deserialize)]
pub struct Message {
    /// Author of this turn.
    pub role: Role,
    /// Ordered content blocks that make up this turn.
    pub content: Vec<Content>,
}
229
/// Author of a message. Serialized in lowercase: `"user"` / `"assistant"`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
}
236
237#[derive(Debug, Serialize, Deserialize)]
238#[serde(tag = "type")]
239pub enum Content {
240 #[serde(rename = "text")]
241 Text { text: String },
242 #[serde(rename = "image")]
243 Image { source: ImageSource },
244 #[serde(rename = "tool_use")]
245 ToolUse {
246 id: String,
247 name: String,
248 input: serde_json::Value,
249 },
250 #[serde(rename = "tool_result")]
251 ToolResult {
252 tool_use_id: String,
253 content: String,
254 },
255}
256
/// Source payload of an image content block.
#[derive(Debug, Serialize, Deserialize)]
pub struct ImageSource {
    /// Serialized as `type`; names how `data` is encoded (presumably `"base64"` — confirm
    /// against callers).
    #[serde(rename = "type")]
    pub source_type: String,
    /// MIME type of the image (NOTE(review): presumably e.g. `"image/png"` — confirm).
    pub media_type: String,
    /// Encoded image bytes, in the encoding named by `source_type`.
    pub data: String,
}
264
/// A tool definition offered to the model in a [`Request`].
#[derive(Debug, Serialize, Deserialize)]
pub struct Tool {
    /// Name the model uses to refer to the tool.
    pub name: String,
    /// Description of the tool, given to the model.
    pub description: String,
    /// Schema of the tool's input as raw JSON (presumably a JSON Schema object — confirm).
    pub input_schema: serde_json::Value,
}
271
/// How the model may choose among the request's tools.
///
/// Serialized as `{"type": "auto"}`, `{"type": "any"}` or
/// `{"type": "tool", "name": "..."}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ToolChoice {
    Auto,
    Any,
    /// Names one specific tool.
    Tool { name: String },
}
279
/// Request body for the Messages endpoint (`/v1/messages`).
///
/// Field names mirror the Anthropic Messages API. Every optional field is omitted from the
/// serialized JSON when it is `None` or an empty `Vec`. Missing fields deserialize to their
/// defaults.
#[derive(Debug, Serialize, Deserialize)]
pub struct Request {
    /// Model id, e.g. the value returned by `Model::id`.
    pub model: String,
    pub max_tokens: u32,
    /// Conversation history, in order.
    pub messages: Vec<Message>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub tools: Vec<Tool>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<ToolChoice>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Metadata>,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub stop_sequences: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_k: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
}
302
/// A [`Request`] plus the `stream` flag.
///
/// The fields of `base` are flattened into the same JSON object as `stream`.
/// `stream_completion` builds this with `stream: true`.
#[derive(Debug, Serialize, Deserialize)]
struct StreamingRequest {
    #[serde(flatten)]
    pub base: Request,
    pub stream: bool,
}
309
/// Optional request metadata.
#[derive(Debug, Serialize, Deserialize)]
pub struct Metadata {
    /// Caller-supplied user identifier. With no `skip_serializing_if`, `None` is serialized
    /// as `"user_id": null`.
    pub user_id: Option<String>,
}
314
/// Token usage reported by the API.
///
/// Both counts are optional because a given payload may report only one of them. A count
/// that is `None` is omitted when serializing.
#[derive(Debug, Serialize, Deserialize)]
pub struct Usage {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_tokens: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub output_tokens: Option<u32>,
}
322
/// A message returned by the API.
///
/// This is the body of a non-streaming `complete` call and the payload of the
/// `message_start` streaming event.
#[derive(Debug, Serialize, Deserialize)]
pub struct Response {
    pub id: String,
    /// The response's `type` field.
    #[serde(rename = "type")]
    pub response_type: String,
    pub role: Role,
    /// Content blocks generated so far (or in full, for non-streaming responses).
    pub content: Vec<Content>,
    /// Id of the model that produced the response.
    pub model: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_sequence: Option<String>,
    pub usage: Usage,
}
337
338#[derive(Debug, Serialize, Deserialize)]
339#[serde(tag = "type")]
340pub enum Event {
341 #[serde(rename = "message_start")]
342 MessageStart { message: Response },
343 #[serde(rename = "content_block_start")]
344 ContentBlockStart {
345 index: usize,
346 content_block: Content,
347 },
348 #[serde(rename = "content_block_delta")]
349 ContentBlockDelta { index: usize, delta: ContentDelta },
350 #[serde(rename = "content_block_stop")]
351 ContentBlockStop { index: usize },
352 #[serde(rename = "message_delta")]
353 MessageDelta { delta: MessageDelta, usage: Usage },
354 #[serde(rename = "message_stop")]
355 MessageStop,
356 #[serde(rename = "ping")]
357 Ping,
358 #[serde(rename = "error")]
359 Error { error: ApiError },
360}
361
362#[derive(Debug, Serialize, Deserialize)]
363#[serde(tag = "type")]
364pub enum ContentDelta {
365 #[serde(rename = "text_delta")]
366 TextDelta { text: String },
367 #[serde(rename = "input_json_delta")]
368 InputJsonDelta { partial_json: String },
369}
370
/// Message-level fields updated by a `message_delta` event.
#[derive(Debug, Serialize, Deserialize)]
pub struct MessageDelta {
    pub stop_reason: Option<String>,
    pub stop_sequence: Option<String>,
}
376
/// Error payload carried by an `error` event or an error response body.
#[derive(Debug, Serialize, Deserialize)]
pub struct ApiError {
    /// Serialized as `type`; the API's error category.
    #[serde(rename = "type")]
    pub error_type: String,
    /// Human-readable error description from the API.
    pub message: String,
}