anthropic.rs

  1use anyhow::{anyhow, Result};
  2use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
  3use serde::{Deserialize, Serialize};
  4use std::convert::TryFrom;
  5use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
  6
  7#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
  8pub enum Model {
  9    #[default]
 10    #[serde(rename = "claude-3-opus-20240229")]
 11    Claude3Opus,
 12    #[serde(rename = "claude-3-sonnet-20240229")]
 13    Claude3Sonnet,
 14    #[serde(rename = "claude-3-haiku-20240307")]
 15    Claude3Haiku,
 16}
 17
 18impl Model {
 19    pub fn from_id(id: &str) -> Result<Self> {
 20        if id.starts_with("claude-3-opus") {
 21            Ok(Self::Claude3Opus)
 22        } else if id.starts_with("claude-3-sonnet") {
 23            Ok(Self::Claude3Sonnet)
 24        } else if id.starts_with("claude-3-haiku") {
 25            Ok(Self::Claude3Haiku)
 26        } else {
 27            Err(anyhow!("Invalid model id: {}", id))
 28        }
 29    }
 30
 31    pub fn display_name(&self) -> &'static str {
 32        match self {
 33            Self::Claude3Opus => "Claude 3 Opus",
 34            Self::Claude3Sonnet => "Claude 3 Sonnet",
 35            Self::Claude3Haiku => "Claude 3 Haiku",
 36        }
 37    }
 38
 39    pub fn max_token_count(&self) -> usize {
 40        200_000
 41    }
 42}
 43
 44#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
 45#[serde(rename_all = "lowercase")]
 46pub enum Role {
 47    User,
 48    Assistant,
 49}
 50
 51impl TryFrom<String> for Role {
 52    type Error = anyhow::Error;
 53
 54    fn try_from(value: String) -> Result<Self> {
 55        match value.as_str() {
 56            "user" => Ok(Self::User),
 57            "assistant" => Ok(Self::Assistant),
 58            _ => Err(anyhow!("invalid role '{value}'")),
 59        }
 60    }
 61}
 62
 63impl From<Role> for String {
 64    fn from(val: Role) -> Self {
 65        match val {
 66            Role::User => "user".to_owned(),
 67            Role::Assistant => "assistant".to_owned(),
 68        }
 69    }
 70}
 71
 72#[derive(Debug, Serialize)]
 73pub struct Request {
 74    pub model: Model,
 75    pub messages: Vec<RequestMessage>,
 76    pub stream: bool,
 77    pub system: String,
 78    pub max_tokens: u32,
 79}
 80
 81#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 82pub struct RequestMessage {
 83    pub role: Role,
 84    pub content: String,
 85}
 86
 87#[derive(Deserialize, Debug)]
 88#[serde(tag = "type", rename_all = "snake_case")]
 89pub enum ResponseEvent {
 90    MessageStart {
 91        message: ResponseMessage,
 92    },
 93    ContentBlockStart {
 94        index: u32,
 95        content_block: ContentBlock,
 96    },
 97    Ping {},
 98    ContentBlockDelta {
 99        index: u32,
100        delta: TextDelta,
101    },
102    ContentBlockStop {
103        index: u32,
104    },
105    MessageDelta {
106        delta: ResponseMessage,
107        usage: Usage,
108    },
109    MessageStop {},
110}
111
112#[derive(Deserialize, Debug)]
113pub struct ResponseMessage {
114    #[serde(rename = "type")]
115    pub message_type: Option<String>,
116    pub id: Option<String>,
117    pub role: Option<String>,
118    pub content: Option<Vec<String>>,
119    pub model: Option<String>,
120    pub stop_reason: Option<String>,
121    pub stop_sequence: Option<String>,
122    pub usage: Option<Usage>,
123}
124
125#[derive(Deserialize, Debug)]
126pub struct Usage {
127    pub input_tokens: Option<u32>,
128    pub output_tokens: Option<u32>,
129}
130
131#[derive(Deserialize, Debug)]
132#[serde(tag = "type", rename_all = "snake_case")]
133pub enum ContentBlock {
134    Text { text: String },
135}
136
137#[derive(Deserialize, Debug)]
138#[serde(tag = "type", rename_all = "snake_case")]
139pub enum TextDelta {
140    TextDelta { text: String },
141}
142
143pub async fn stream_completion(
144    client: &dyn HttpClient,
145    api_url: &str,
146    api_key: &str,
147    request: Request,
148) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
149    let uri = format!("{api_url}/v1/messages");
150    let request = HttpRequest::builder()
151        .method(Method::POST)
152        .uri(uri)
153        .header("Anthropic-Version", "2023-06-01")
154        .header("Anthropic-Beta", "messages-2023-12-15")
155        .header("X-Api-Key", api_key)
156        .header("Content-Type", "application/json")
157        .body(AsyncBody::from(serde_json::to_string(&request)?))?;
158    let mut response = client.send(request).await?;
159    if response.status().is_success() {
160        let reader = BufReader::new(response.into_body());
161        Ok(reader
162            .lines()
163            .filter_map(|line| async move {
164                match line {
165                    Ok(line) => {
166                        let line = line.strip_prefix("data: ")?;
167                        match serde_json::from_str(line) {
168                            Ok(response) => Some(Ok(response)),
169                            Err(error) => Some(Err(anyhow!(error))),
170                        }
171                    }
172                    Err(error) => Some(Err(anyhow!(error))),
173                }
174            })
175            .boxed())
176    } else {
177        let mut body = Vec::new();
178        response.body_mut().read_to_end(&mut body).await?;
179
180        let body_str = std::str::from_utf8(&body)?;
181
182        match serde_json::from_str::<ResponseEvent>(body_str) {
183            Ok(_) => Err(anyhow!(
184                "Unexpected success response while expecting an error: {}",
185                body_str,
186            )),
187            Err(_) => Err(anyhow!(
188                "Failed to connect to API: {} {}",
189                response.status(),
190                body_str,
191            )),
192        }
193    }
194}
195
196// #[cfg(test)]
197// mod tests {
198//     use super::*;
199//     use util::http::IsahcHttpClient;
200
201//     #[tokio::test]
202//     async fn stream_completion_success() {
203//         let http_client = IsahcHttpClient::new().unwrap();
204
205//         let request = Request {
206//             model: Model::Claude3Opus,
207//             messages: vec![RequestMessage {
208//                 role: Role::User,
209//                 content: "Ping".to_string(),
210//             }],
211//             stream: true,
212//             system: "Respond to ping with pong".to_string(),
213//             max_tokens: 4096,
214//         };
215
216//         let stream = stream_completion(
217//             &http_client,
218//             "https://api.anthropic.com",
219//             &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
220//             request,
221//         )
222//         .await
223//         .unwrap();
224
225//         stream
226//             .for_each(|event| async {
227//                 match event {
228//                     Ok(event) => println!("{:?}", event),
229//                     Err(e) => eprintln!("Error: {:?}", e),
230//                 }
231//             })
232//             .await;
233//     }
234// }