Detailed changes
@@ -213,6 +213,18 @@ dependencies = [
"windows-sys 0.48.0",
]
+[[package]]
+name = "anthropic"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.28",
+ "serde",
+ "serde_json",
+ "tokio",
+ "util",
+]
+
[[package]]
name = "anyhow"
version = "1.0.75"
@@ -2214,6 +2226,7 @@ dependencies = [
name = "collab"
version = "0.44.0"
dependencies = [
+ "anthropic",
"anyhow",
"async-trait",
"async-tungstenite",
@@ -1,6 +1,7 @@
[workspace]
members = [
"crates/activity_indicator",
+ "crates/anthropic",
"crates/assets",
"crates/assistant",
"crates/audio",
@@ -119,6 +120,7 @@ resolver = "2"
[workspace.dependencies]
activity_indicator = { path = "crates/activity_indicator" }
ai = { path = "crates/ai" }
+anthropic = { path = "crates/anthropic" }
assets = { path = "crates/assets" }
assistant = { path = "crates/assistant" }
audio = { path = "crates/audio" }
@@ -0,0 +1,22 @@
+[package]
+name = "anthropic"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[lib]
+path = "src/anthropic.rs"
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+util.workspace = true
+
+[dev-dependencies]
+tokio.workspace = true
+
+[lints]
+workspace = true
@@ -0,0 +1,234 @@
+use anyhow::{anyhow, Result};
+use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use serde::{Deserialize, Serialize};
+use std::convert::TryFrom;
+use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+
+#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
+pub enum Model {
+    #[default]
+    #[serde(rename = "claude-3-opus-20240229")]
+    Claude3Opus,
+    #[serde(rename = "claude-3-sonnet-20240229")]
+    Claude3Sonnet,
+    #[serde(rename = "claude-3-haiku-20240307")]
+    Claude3Haiku,
+}
+
+impl Model {
+ pub fn from_id(id: &str) -> Result<Self> {
+ if id.starts_with("claude-3-opus") {
+ Ok(Self::Claude3Opus)
+ } else if id.starts_with("claude-3-sonnet") {
+ Ok(Self::Claude3Sonnet)
+ } else if id.starts_with("claude-3-haiku") {
+ Ok(Self::Claude3Haiku)
+ } else {
+ Err(anyhow!("Invalid model id: {}", id))
+ }
+ }
+
+ pub fn display_name(&self) -> &'static str {
+ match self {
+ Self::Claude3Opus => "Claude 3 Opus",
+ Self::Claude3Sonnet => "Claude 3 Sonnet",
+ Self::Claude3Haiku => "Claude 3 Haiku",
+ }
+ }
+
+ pub fn max_token_count(&self) -> usize {
+ 200_000
+ }
+}
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+ User,
+ Assistant,
+}
+
+impl TryFrom<String> for Role {
+ type Error = anyhow::Error;
+
+ fn try_from(value: String) -> Result<Self> {
+ match value.as_str() {
+ "user" => Ok(Self::User),
+ "assistant" => Ok(Self::Assistant),
+ _ => Err(anyhow!("invalid role '{value}'")),
+ }
+ }
+}
+
+impl From<Role> for String {
+ fn from(val: Role) -> Self {
+ match val {
+ Role::User => "user".to_owned(),
+ Role::Assistant => "assistant".to_owned(),
+ }
+ }
+}
+
+#[derive(Debug, Serialize)]
+pub struct Request {
+ pub model: Model,
+ pub messages: Vec<RequestMessage>,
+ pub stream: bool,
+ pub system: String,
+ pub max_tokens: u32,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct RequestMessage {
+ pub role: Role,
+ pub content: String,
+}
+
+#[derive(Deserialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ResponseEvent {
+ MessageStart {
+ message: ResponseMessage,
+ },
+ ContentBlockStart {
+ index: u32,
+ content_block: ContentBlock,
+ },
+ Ping {},
+ ContentBlockDelta {
+ index: u32,
+ delta: TextDelta,
+ },
+ ContentBlockStop {
+ index: u32,
+ },
+ MessageDelta {
+ delta: ResponseMessage,
+ usage: Usage,
+ },
+ MessageStop {},
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ResponseMessage {
+ #[serde(rename = "type")]
+ pub message_type: Option<String>,
+ pub id: Option<String>,
+ pub role: Option<String>,
+ pub content: Option<Vec<String>>,
+ pub model: Option<String>,
+ pub stop_reason: Option<String>,
+ pub stop_sequence: Option<String>,
+ pub usage: Option<Usage>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct Usage {
+ pub input_tokens: Option<u32>,
+ pub output_tokens: Option<u32>,
+}
+
+#[derive(Deserialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ContentBlock {
+ Text { text: String },
+}
+
+#[derive(Deserialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum TextDelta {
+ TextDelta { text: String },
+}
+
+pub async fn stream_completion(
+ client: &dyn HttpClient,
+ api_url: &str,
+ api_key: &str,
+ request: Request,
+) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
+ let uri = format!("{api_url}/v1/messages");
+ let request = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Anthropic-Version", "2023-06-01")
+ .header("Anthropic-Beta", "messages-2023-12-15")
+ .header("X-Api-Key", api_key)
+ .header("Content-Type", "application/json")
+ .body(AsyncBody::from(serde_json::to_string(&request)?))?;
+ let mut response = client.send(request).await?;
+ if response.status().is_success() {
+ let reader = BufReader::new(response.into_body());
+ Ok(reader
+ .lines()
+ .filter_map(|line| async move {
+ match line {
+ Ok(line) => {
+ let line = line.strip_prefix("data: ")?;
+ match serde_json::from_str(line) {
+ Ok(response) => Some(Ok(response)),
+ Err(error) => Some(Err(anyhow!(error))),
+ }
+ }
+ Err(error) => Some(Err(anyhow!(error))),
+ }
+ })
+ .boxed())
+ } else {
+ let mut body = Vec::new();
+ response.body_mut().read_to_end(&mut body).await?;
+
+ let body_str = std::str::from_utf8(&body)?;
+
+ match serde_json::from_str::<ResponseEvent>(body_str) {
+ Ok(_) => Err(anyhow!(
+ "Unexpected success response while expecting an error: {}",
+ body_str,
+ )),
+ Err(_) => Err(anyhow!(
+ "Failed to connect to API: {} {}",
+ response.status(),
+ body_str,
+ )),
+ }
+ }
+}
+
+// #[cfg(test)]
+// mod tests {
+// use super::*;
+// use util::http::IsahcHttpClient;
+
+// #[tokio::test]
+// async fn stream_completion_success() {
+// let http_client = IsahcHttpClient::new().unwrap();
+
+// let request = Request {
+// model: Model::Claude3Opus,
+// messages: vec![RequestMessage {
+// role: Role::User,
+// content: "Ping".to_string(),
+// }],
+// stream: true,
+// system: "Respond to ping with pong".to_string(),
+// max_tokens: 4096,
+// };
+
+// let stream = stream_completion(
+// &http_client,
+// "https://api.anthropic.com",
+// &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
+// request,
+// )
+// .await
+// .unwrap();
+
+// stream
+// .for_each(|event| async {
+// match event {
+// Ok(event) => println!("{:?}", event),
+// Err(e) => eprintln!("Error: {:?}", e),
+// }
+// })
+// .await;
+// }
+// }
@@ -768,15 +768,18 @@ impl AssistantPanel {
open_ai::Model::FourTurbo => open_ai::Model::ThreePointFiveTurbo,
}),
LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
- ZedDotDevModel::GptThreePointFiveTurbo => ZedDotDevModel::GptFour,
- ZedDotDevModel::GptFour => ZedDotDevModel::GptFourTurbo,
- ZedDotDevModel::GptFourTurbo => {
+ ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
+ ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
+ ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Claude3Opus,
+ ZedDotDevModel::Claude3Opus => ZedDotDevModel::Claude3Sonnet,
+ ZedDotDevModel::Claude3Sonnet => ZedDotDevModel::Claude3Haiku,
+ ZedDotDevModel::Claude3Haiku => {
match CompletionProvider::global(cx).default_model() {
LanguageModel::ZedDotDev(custom) => custom,
- _ => ZedDotDevModel::GptThreePointFiveTurbo,
+ _ => ZedDotDevModel::Gpt3Point5Turbo,
}
}
- ZedDotDevModel::Custom(_) => ZedDotDevModel::GptThreePointFiveTurbo,
+ ZedDotDevModel::Custom(_) => ZedDotDevModel::Gpt3Point5Turbo,
}),
};
@@ -14,10 +14,13 @@ use settings::Settings;
#[derive(Clone, Debug, Default, PartialEq)]
pub enum ZedDotDevModel {
- GptThreePointFiveTurbo,
- GptFour,
+ Gpt3Point5Turbo,
+ Gpt4,
#[default]
- GptFourTurbo,
+ Gpt4Turbo,
+ Claude3Opus,
+ Claude3Sonnet,
+ Claude3Haiku,
Custom(String),
}
@@ -49,9 +52,9 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
E: de::Error,
{
match value {
- "gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo),
- "gpt-4" => Ok(ZedDotDevModel::GptFour),
- "gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo),
+ "gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo),
+ "gpt-4" => Ok(ZedDotDevModel::Gpt4),
+ "gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo),
_ => Ok(ZedDotDevModel::Custom(value.to_owned())),
}
}
@@ -94,27 +97,34 @@ impl JsonSchema for ZedDotDevModel {
impl ZedDotDevModel {
pub fn id(&self) -> &str {
match self {
- Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
- Self::GptFour => "gpt-4",
- Self::GptFourTurbo => "gpt-4-turbo-preview",
+ Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
+ Self::Gpt4 => "gpt-4",
+ Self::Gpt4Turbo => "gpt-4-turbo-preview",
+ Self::Claude3Opus => "claude-3-opus",
+ Self::Claude3Sonnet => "claude-3-sonnet",
+ Self::Claude3Haiku => "claude-3-haiku",
Self::Custom(id) => id,
}
}
pub fn display_name(&self) -> &str {
match self {
- Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
- Self::GptFour => "gpt-4",
- Self::GptFourTurbo => "gpt-4-turbo",
+ Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
+ Self::Gpt4 => "GPT 4",
+ Self::Gpt4Turbo => "GPT 4 Turbo",
+ Self::Claude3Opus => "Claude 3 Opus",
+ Self::Claude3Sonnet => "Claude 3 Sonnet",
+ Self::Claude3Haiku => "Claude 3 Haiku",
Self::Custom(id) => id.as_str(),
}
}
pub fn max_token_count(&self) -> usize {
match self {
- Self::GptThreePointFiveTurbo => 2048,
- Self::GptFour => 4096,
- Self::GptFourTurbo => 128000,
+ Self::Gpt3Point5Turbo => 2048,
+ Self::Gpt4 => 4096,
+ Self::Gpt4Turbo => 128000,
+ Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
Self::Custom(_) => 4096, // TODO: Make this configurable
}
}
@@ -1,5 +1,5 @@
use crate::{
- assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider,
+ assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
LanguageModelRequest,
};
use anyhow::{anyhow, Result};
@@ -78,13 +78,21 @@ impl ZedDotDevCompletionProvider {
cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
match request.model {
- crate::LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
- crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFour)
- | crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFourTurbo)
- | crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptThreePointFiveTurbo) => {
+ LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
+ LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
+ | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
+ | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => {
count_open_ai_tokens(request, cx.background_executor())
}
- crate::LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
+ LanguageModel::ZedDotDev(
+ ZedDotDevModel::Claude3Opus
+ | ZedDotDevModel::Claude3Sonnet
+ | ZedDotDevModel::Claude3Haiku,
+ ) => {
+ // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
+ count_open_ai_tokens(request, cx.background_executor())
+ }
+ LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
let request = self.client.request(proto::CountTokensWithLanguageModel {
model,
messages: request
@@ -18,6 +18,7 @@ sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"]
test-support = ["sqlite"]
[dependencies]
+anthropic.workspace = true
anyhow.workspace = true
async-tungstenite = "0.16"
aws-config = { version = "1.1.5" }
@@ -130,6 +130,11 @@ spec:
secretKeyRef:
name: openai
key: api_key
+ - name: ANTHROPIC_API_KEY
+ valueFrom:
+ secretKeyRef:
+ name: anthropic
+ key: api_key
- name: BLOB_STORE_ACCESS_KEY
valueFrom:
secretKeyRef:
@@ -134,6 +134,7 @@ pub struct Config {
pub zed_environment: Arc<str>,
pub openai_api_key: Option<Arc<str>>,
pub google_ai_api_key: Option<Arc<str>>,
+ pub anthropic_api_key: Option<Arc<str>>,
pub zed_client_checksum_seed: Option<String>,
pub slack_panics_webhook: Option<String>,
pub auto_join_channel_id: Option<ChannelId>,
@@ -419,6 +419,7 @@ impl Server {
session,
app_state.config.openai_api_key.clone(),
app_state.config.google_ai_api_key.clone(),
+ app_state.config.anthropic_api_key.clone(),
)
}
})
@@ -3506,6 +3507,7 @@ async fn complete_with_language_model(
session: Session,
open_ai_api_key: Option<Arc<str>>,
google_ai_api_key: Option<Arc<str>>,
+ anthropic_api_key: Option<Arc<str>>,
) -> Result<()> {
let Some(session) = session.for_user() else {
return Err(anyhow!("user not found"))?;
@@ -3524,6 +3526,10 @@ async fn complete_with_language_model(
let api_key = google_ai_api_key
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
complete_with_google_ai(request, response, session, api_key).await?;
+    } else if request.model.starts_with("claude") {
+        let api_key = anthropic_api_key
+            .ok_or_else(|| anyhow!("no Anthropic API key configured on the server"))?;
+        complete_with_anthropic(request, response, session, api_key).await?;
}
Ok(())
@@ -3621,6 +3627,121 @@ async fn complete_with_google_ai(
Ok(())
}
+async fn complete_with_anthropic(
+    request: proto::CompleteWithLanguageModel,
+    response: StreamingResponse<proto::CompleteWithLanguageModel>,
+    session: UserSession,
+    api_key: Arc<str>,
+) -> Result<()> {
+    let model = anthropic::Model::from_id(&request.model)?;
+
+    let mut system_message = String::new();
+    let messages = request
+        .messages
+        .into_iter()
+        .filter_map(|message| match message.role() {
+            LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
+                role: anthropic::Role::User,
+                content: message.content,
+            }),
+            LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
+                role: anthropic::Role::Assistant,
+                content: message.content,
+            }),
+            // Anthropic's API breaks system instructions out as a separate field rather
+            // than having a system message role.
+            LanguageModelRole::LanguageModelSystem => {
+                if !system_message.is_empty() {
+                    system_message.push_str("\n\n");
+                }
+                system_message.push_str(&message.content);
+
+                None
+            }
+        })
+        .collect();
+
+    let mut stream = anthropic::stream_completion(
+        &session.http_client,
+        "https://api.anthropic.com",
+        &api_key,
+        anthropic::Request {
+            model,
+            messages,
+            stream: true,
+            system: system_message,
+            max_tokens: 4096, // Claude 3's maximum output-token limit (was a 4092 typo).
+        },
+    )
+    .await?;
+
+    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant; // responses default to assistant until MessageStart says otherwise
+
+    while let Some(event) = stream.next().await {
+        let event = event?;
+
+        match event {
+            anthropic::ResponseEvent::MessageStart { message } => {
+                if let Some(role) = message.role {
+                    if role == "assistant" {
+                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
+                    } else if role == "user" {
+                        current_role = proto::LanguageModelRole::LanguageModelUser;
+                    }
+                }
+            }
+            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
+                match content_block {
+                    anthropic::ContentBlock::Text { text } => {
+                        if !text.is_empty() {
+                            response.send(proto::LanguageModelResponse {
+                                choices: vec![proto::LanguageModelChoiceDelta {
+                                    index: 0,
+                                    delta: Some(proto::LanguageModelResponseMessage {
+                                        role: Some(current_role as i32),
+                                        content: Some(text),
+                                    }),
+                                    finish_reason: None,
+                                }],
+                            })?;
+                        }
+                    }
+                }
+            }
+            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
+                anthropic::TextDelta::TextDelta { text } => {
+                    response.send(proto::LanguageModelResponse {
+                        choices: vec![proto::LanguageModelChoiceDelta {
+                            index: 0,
+                            delta: Some(proto::LanguageModelResponseMessage {
+                                role: Some(current_role as i32),
+                                content: Some(text),
+                            }),
+                            finish_reason: None,
+                        }],
+                    })?;
+                }
+            },
+            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
+                if let Some(stop_reason) = delta.stop_reason {
+                    response.send(proto::LanguageModelResponse {
+                        choices: vec![proto::LanguageModelChoiceDelta {
+                            index: 0,
+                            delta: None,
+                            finish_reason: Some(stop_reason), // NOTE(review): Anthropic stop reasons ("end_turn") differ from OpenAI's ("stop") — confirm clients handle both
+                        }],
+                    })?;
+                }
+            }
+            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
+            anthropic::ResponseEvent::MessageStop {} => {}
+            anthropic::ResponseEvent::Ping {} => {}
+        }
+    }
+
+    Ok(())
+}
+
struct CountTokensWithLanguageModelRateLimit;
impl RateLimit for CountTokensWithLanguageModelRateLimit {
@@ -512,6 +512,7 @@ impl TestServer {
blob_store_bucket: None,
openai_api_key: None,
google_ai_api_key: None,
+ anthropic_api_key: None,
clickhouse_url: None,
clickhouse_user: None,
clickhouse_password: None,