Enable Claude 3 models to be used via the Zed server if the "language-models" feature flag is enabled for the user (#10015)

Created by Nathan Sobo

Release Notes:

- N/A

Change summary

Cargo.lock                                      |  13 +
Cargo.toml                                      |   2 
crates/anthropic/Cargo.toml                     |  22 +
crates/anthropic/src/anthropic.rs               | 234 +++++++++++++++++++
crates/assistant/src/assistant_panel.rs         |  13 
crates/assistant/src/assistant_settings.rs      |  40 ++-
crates/assistant/src/completion_provider/zed.rs |  20 +
crates/collab/Cargo.toml                        |   1 
crates/collab/k8s/collab.template.yml           |   5 
crates/collab/src/lib.rs                        |   1 
crates/collab/src/rpc.rs                        | 121 +++++++++
crates/collab/src/tests/test_server.rs          |   1 
12 files changed, 447 insertions(+), 26 deletions(-)

Detailed changes

Cargo.lock

@@ -213,6 +213,18 @@ dependencies = [
  "windows-sys 0.48.0",
 ]
 
+[[package]]
+name = "anthropic"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "futures 0.3.28",
+ "serde",
+ "serde_json",
+ "tokio",
+ "util",
+]
+
 [[package]]
 name = "anyhow"
 version = "1.0.75"
@@ -2214,6 +2226,7 @@ dependencies = [
 name = "collab"
 version = "0.44.0"
 dependencies = [
+ "anthropic",
  "anyhow",
  "async-trait",
  "async-tungstenite",

Cargo.toml

@@ -1,6 +1,7 @@
 [workspace]
 members = [
     "crates/activity_indicator",
+    "crates/anthropic",
     "crates/assets",
     "crates/assistant",
     "crates/audio",
@@ -119,6 +120,7 @@ resolver = "2"
 [workspace.dependencies]
 activity_indicator = { path = "crates/activity_indicator" }
 ai = { path = "crates/ai" }
+anthropic = { path = "crates/anthropic" }
 assets = { path = "crates/assets" }
 assistant = { path = "crates/assistant" }
 audio = { path = "crates/audio" }

crates/anthropic/Cargo.toml

@@ -0,0 +1,22 @@
+[package]
+name = "anthropic"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[lib]
+path = "src/anthropic.rs"
+
+[dependencies]
+anyhow.workspace = true
+futures.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+util.workspace = true
+
+[dev-dependencies]
+tokio.workspace = true
+
+[lints]
+workspace = true

crates/anthropic/src/anthropic.rs

@@ -0,0 +1,234 @@
+use anyhow::{anyhow, Result};
+use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
+use serde::{Deserialize, Serialize};
+use std::convert::TryFrom;
+use util::http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
+
+#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
+pub enum Model {
+    #[default]
+    #[serde(rename = "claude-3-opus-20240229")]
+    Claude3Opus,
+    #[serde(rename = "claude-3-sonnet-20240229")]
+    Claude3Sonnet,
+    #[serde(rename = "claude-3-haiku-20240307")]
+    Claude3Haiku,
+}
+
+impl Model {
+    pub fn from_id(id: &str) -> Result<Self> {
+        if id.starts_with("claude-3-opus") {
+            Ok(Self::Claude3Opus)
+        } else if id.starts_with("claude-3-sonnet") {
+            Ok(Self::Claude3Sonnet)
+        } else if id.starts_with("claude-3-haiku") {
+            Ok(Self::Claude3Haiku)
+        } else {
+            Err(anyhow!("Invalid model id: {}", id))
+        }
+    }
+
+    pub fn display_name(&self) -> &'static str {
+        match self {
+            Self::Claude3Opus => "Claude 3 Opus",
+            Self::Claude3Sonnet => "Claude 3 Sonnet",
+            Self::Claude3Haiku => "Claude 3 Haiku",
+        }
+    }
+
+    pub fn max_token_count(&self) -> usize {
+        200_000
+    }
+}
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+}
+
+impl TryFrom<String> for Role {
+    type Error = anyhow::Error;
+
+    fn try_from(value: String) -> Result<Self> {
+        match value.as_str() {
+            "user" => Ok(Self::User),
+            "assistant" => Ok(Self::Assistant),
+            _ => Err(anyhow!("invalid role '{value}'")),
+        }
+    }
+}
+
+impl From<Role> for String {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => "user".to_owned(),
+            Role::Assistant => "assistant".to_owned(),
+        }
+    }
+}
+
+#[derive(Debug, Serialize)]
+pub struct Request {
+    pub model: Model,
+    pub messages: Vec<RequestMessage>,
+    pub stream: bool,
+    pub system: String,
+    pub max_tokens: u32,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct RequestMessage {
+    pub role: Role,
+    pub content: String,
+}
+
+#[derive(Deserialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ResponseEvent {
+    MessageStart {
+        message: ResponseMessage,
+    },
+    ContentBlockStart {
+        index: u32,
+        content_block: ContentBlock,
+    },
+    Ping {},
+    ContentBlockDelta {
+        index: u32,
+        delta: TextDelta,
+    },
+    ContentBlockStop {
+        index: u32,
+    },
+    MessageDelta {
+        delta: ResponseMessage,
+        usage: Usage,
+    },
+    MessageStop {},
+}
+
+#[derive(Deserialize, Debug)]
+pub struct ResponseMessage {
+    #[serde(rename = "type")]
+    pub message_type: Option<String>,
+    pub id: Option<String>,
+    pub role: Option<String>,
+    pub content: Option<Vec<String>>,
+    pub model: Option<String>,
+    pub stop_reason: Option<String>,
+    pub stop_sequence: Option<String>,
+    pub usage: Option<Usage>,
+}
+
+#[derive(Deserialize, Debug)]
+pub struct Usage {
+    pub input_tokens: Option<u32>,
+    pub output_tokens: Option<u32>,
+}
+
+#[derive(Deserialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ContentBlock {
+    Text { text: String },
+}
+
+#[derive(Deserialize, Debug)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum TextDelta {
+    TextDelta { text: String },
+}
+
+pub async fn stream_completion(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: Request,
+) -> Result<BoxStream<'static, Result<ResponseEvent>>> {
+    let uri = format!("{api_url}/v1/messages");
+    let request = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Anthropic-Version", "2023-06-01")
+        .header("Anthropic-Beta", "messages-2023-12-15")
+        .header("X-Api-Key", api_key)
+        .header("Content-Type", "application/json")
+        .body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let mut response = client.send(request).await?;
+    if response.status().is_success() {
+        let reader = BufReader::new(response.into_body());
+        Ok(reader
+            .lines()
+            .filter_map(|line| async move {
+                match line {
+                    Ok(line) => {
+                        let line = line.strip_prefix("data: ")?;
+                        match serde_json::from_str(line) {
+                            Ok(response) => Some(Ok(response)),
+                            Err(error) => Some(Err(anyhow!(error))),
+                        }
+                    }
+                    Err(error) => Some(Err(anyhow!(error))),
+                }
+            })
+            .boxed())
+    } else {
+        let mut body = Vec::new();
+        response.body_mut().read_to_end(&mut body).await?;
+
+        let body_str = std::str::from_utf8(&body)?;
+
+        match serde_json::from_str::<ResponseEvent>(body_str) {
+            Ok(_) => Err(anyhow!(
+                "Unexpected success response while expecting an error: {}",
+                body_str,
+            )),
+            Err(_) => Err(anyhow!(
+                "Failed to connect to API: {} {}",
+                response.status(),
+                body_str,
+            )),
+        }
+    }
+}
+
+// #[cfg(test)]
+// mod tests {
+//     use super::*;
+//     use util::http::IsahcHttpClient;
+
+//     #[tokio::test]
+//     async fn stream_completion_success() {
+//         let http_client = IsahcHttpClient::new().unwrap();
+
+//         let request = Request {
+//             model: Model::Claude3Opus,
+//             messages: vec![RequestMessage {
+//                 role: Role::User,
+//                 content: "Ping".to_string(),
+//             }],
+//             stream: true,
+//             system: "Respond to ping with pong".to_string(),
+//             max_tokens: 4096,
+//         };
+
+//         let stream = stream_completion(
+//             &http_client,
+//             "https://api.anthropic.com",
+//             &std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY not set"),
+//             request,
+//         )
+//         .await
+//         .unwrap();
+
+//         stream
+//             .for_each(|event| async {
+//                 match event {
+//                     Ok(event) => println!("{:?}", event),
+//                     Err(e) => eprintln!("Error: {:?}", e),
+//                 }
+//             })
+//             .await;
+//     }
+// }
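
For reference, a minimal sketch of how a caller might drain the event stream into a single string. It uses only the types added above; the `collect_text` helper itself is hypothetical and error handling is condensed:

```rust
use anthropic::{
    stream_completion, ContentBlock, Model, Request, RequestMessage, ResponseEvent, Role,
    TextDelta,
};
use futures::StreamExt;
use util::http::HttpClient;

// Hypothetical helper: issue one streaming request and concatenate the text.
async fn collect_text(client: &dyn HttpClient, api_key: &str) -> anyhow::Result<String> {
    let request = Request {
        model: Model::Claude3Sonnet,
        messages: vec![RequestMessage {
            role: Role::User,
            content: "Ping".into(),
        }],
        stream: true,
        system: String::new(),
        max_tokens: 1024,
    };
    let mut events =
        stream_completion(client, "https://api.anthropic.com", api_key, request).await?;
    let mut text = String::new();
    while let Some(event) = events.next().await {
        match event? {
            // The first chunk of a content block arrives in ContentBlockStart...
            ResponseEvent::ContentBlockStart {
                content_block: ContentBlock::Text { text: chunk },
                ..
            } => text.push_str(&chunk),
            // ...and the rest stream in as ContentBlockDelta events.
            ResponseEvent::ContentBlockDelta {
                delta: TextDelta::TextDelta { text: chunk },
                ..
            } => text.push_str(&chunk),
            _ => {}
        }
    }
    Ok(text)
}
```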

crates/assistant/src/assistant_panel.rs

@@ -768,15 +768,18 @@ impl AssistantPanel {
                 open_ai::Model::FourTurbo => open_ai::Model::ThreePointFiveTurbo,
             }),
             LanguageModel::ZedDotDev(model) => LanguageModel::ZedDotDev(match &model {
-                ZedDotDevModel::GptThreePointFiveTurbo => ZedDotDevModel::GptFour,
-                ZedDotDevModel::GptFour => ZedDotDevModel::GptFourTurbo,
-                ZedDotDevModel::GptFourTurbo => {
+                ZedDotDevModel::Gpt3Point5Turbo => ZedDotDevModel::Gpt4,
+                ZedDotDevModel::Gpt4 => ZedDotDevModel::Gpt4Turbo,
+                ZedDotDevModel::Gpt4Turbo => ZedDotDevModel::Claude3Opus,
+                ZedDotDevModel::Claude3Opus => ZedDotDevModel::Claude3Sonnet,
+                ZedDotDevModel::Claude3Sonnet => ZedDotDevModel::Claude3Haiku,
+                ZedDotDevModel::Claude3Haiku => {
                     match CompletionProvider::global(cx).default_model() {
                         LanguageModel::ZedDotDev(custom) => custom,
-                        _ => ZedDotDevModel::GptThreePointFiveTurbo,
+                        _ => ZedDotDevModel::Gpt3Point5Turbo,
                     }
                 }
-                ZedDotDevModel::Custom(_) => ZedDotDevModel::GptThreePointFiveTurbo,
+                ZedDotDevModel::Custom(_) => ZedDotDevModel::Gpt3Point5Turbo,
             }),
         };
 

crates/assistant/src/assistant_settings.rs

@@ -14,10 +14,13 @@ use settings::Settings;
 
 #[derive(Clone, Debug, Default, PartialEq)]
 pub enum ZedDotDevModel {
-    GptThreePointFiveTurbo,
-    GptFour,
+    Gpt3Point5Turbo,
+    Gpt4,
     #[default]
-    GptFourTurbo,
+    Gpt4Turbo,
+    Claude3Opus,
+    Claude3Sonnet,
+    Claude3Haiku,
     Custom(String),
 }
 
@@ -49,9 +52,9 @@ impl<'de> Deserialize<'de> for ZedDotDevModel {
                 E: de::Error,
             {
                 match value {
-                    "gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo),
-                    "gpt-4" => Ok(ZedDotDevModel::GptFour),
-                    "gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo),
+                    "gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo),
+                    "gpt-4" => Ok(ZedDotDevModel::Gpt4),
+                    "gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo),
                     _ => Ok(ZedDotDevModel::Custom(value.to_owned())),
                 }
             }
@@ -94,27 +97,34 @@ impl JsonSchema for ZedDotDevModel {
 impl ZedDotDevModel {
     pub fn id(&self) -> &str {
         match self {
-            Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
-            Self::GptFour => "gpt-4",
-            Self::GptFourTurbo => "gpt-4-turbo-preview",
+            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
+            Self::Gpt4 => "gpt-4",
+            Self::Gpt4Turbo => "gpt-4-turbo-preview",
+            Self::Claude3Opus => "claude-3-opus",
+            Self::Claude3Sonnet => "claude-3-sonnet",
+            Self::Claude3Haiku => "claude-3-haiku",
             Self::Custom(id) => id,
         }
     }
 
     pub fn display_name(&self) -> &str {
         match self {
-            Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
-            Self::GptFour => "gpt-4",
-            Self::GptFourTurbo => "gpt-4-turbo",
+            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
+            Self::Gpt4 => "GPT 4",
+            Self::Gpt4Turbo => "GPT 4 Turbo",
+            Self::Claude3Opus => "Claude 3 Opus",
+            Self::Claude3Sonnet => "Claude 3 Sonnet",
+            Self::Claude3Haiku => "Claude 3 Haiku",
             Self::Custom(id) => id.as_str(),
         }
     }
 
     pub fn max_token_count(&self) -> usize {
         match self {
-            Self::GptThreePointFiveTurbo => 2048,
-            Self::GptFour => 4096,
-            Self::GptFourTurbo => 128000,
+            Self::Gpt3Point5Turbo => 2048,
+            Self::Gpt4 => 4096,
+            Self::Gpt4Turbo => 128000,
+            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
             Self::Custom(_) => 4096, // TODO: Make this configurable
         }
     }
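
One behavior worth flagging from the `Deserialize` visitor above: only the GPT ids are special-cased, so if I read it correctly, a Claude id written in settings round-trips to the `Custom` variant rather than to one of the new Claude variants. A hypothetical test sketching that behavior:

```rust
#[test]
fn claude_ids_currently_deserialize_to_custom() -> anyhow::Result<()> {
    // The visitor maps "gpt-*" ids to their variants and every other
    // string, including the Claude ids emitted by `id()`, to Custom.
    let model: ZedDotDevModel = serde_json::from_str("\"claude-3-opus\"")?;
    assert_eq!(model, ZedDotDevModel::Custom("claude-3-opus".into()));
    Ok(())
}
```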

crates/assistant/src/completion_provider/zed.rs

@@ -1,5 +1,5 @@
 use crate::{
-    assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider,
+    assistant_settings::ZedDotDevModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
     LanguageModelRequest,
 };
 use anyhow::{anyhow, Result};
@@ -78,13 +78,21 @@ impl ZedDotDevCompletionProvider {
         cx: &AppContext,
     ) -> BoxFuture<'static, Result<usize>> {
         match request.model {
-            crate::LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
-            crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFour)
-            | crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptFourTurbo)
-            | crate::LanguageModel::ZedDotDev(ZedDotDevModel::GptThreePointFiveTurbo) => {
+            LanguageModel::OpenAi(_) => future::ready(Err(anyhow!("invalid model"))).boxed(),
+            LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4)
+            | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt4Turbo)
+            | LanguageModel::ZedDotDev(ZedDotDevModel::Gpt3Point5Turbo) => {
                 count_open_ai_tokens(request, cx.background_executor())
             }
-            crate::LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
+            LanguageModel::ZedDotDev(
+                ZedDotDevModel::Claude3Opus
+                | ZedDotDevModel::Claude3Sonnet
+                | ZedDotDevModel::Claude3Haiku,
+            ) => {
+                // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
+                count_open_ai_tokens(request, cx.background_executor())
+            }
+            LanguageModel::ZedDotDev(ZedDotDevModel::Custom(model)) => {
                 let request = self.client.request(proto::CountTokensWithLanguageModel {
                     model,
                     messages: request

crates/collab/Cargo.toml

@@ -18,6 +18,7 @@ sqlite = ["sea-orm/sqlx-sqlite", "sqlx/sqlite"]
 test-support = ["sqlite"]
 
 [dependencies]
+anthropic.workspace = true
 anyhow.workspace = true
 async-tungstenite = "0.16"
 aws-config = { version = "1.1.5" }

crates/collab/k8s/collab.template.yml

@@ -130,6 +130,11 @@ spec:
                 secretKeyRef:
                   name: openai
                   key: api_key
+            - name: ANTHROPIC_API_KEY
+              valueFrom:
+                secretKeyRef:
+                  name: anthropic
+                  key: api_key
             - name: BLOB_STORE_ACCESS_KEY
               valueFrom:
                 secretKeyRef:
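
The template surfaces the secret to the collab server as the `ANTHROPIC_API_KEY` environment variable. As an illustrative sketch only (the real wiring goes through collab's `Config`, whose new field is shown in the next file), the value would be picked up roughly like this:

```rust
use std::sync::Arc;

// Illustrative only: read the variable injected by the deployment above
// into the Option<Arc<str>> shape used by collab's Config.
fn anthropic_api_key_from_env() -> Option<Arc<str>> {
    std::env::var("ANTHROPIC_API_KEY").ok().map(Arc::from)
}
```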

crates/collab/src/lib.rs

@@ -134,6 +134,7 @@ pub struct Config {
     pub zed_environment: Arc<str>,
     pub openai_api_key: Option<Arc<str>>,
     pub google_ai_api_key: Option<Arc<str>>,
+    pub anthropic_api_key: Option<Arc<str>>,
     pub zed_client_checksum_seed: Option<String>,
     pub slack_panics_webhook: Option<String>,
     pub auto_join_channel_id: Option<ChannelId>,

crates/collab/src/rpc.rs

@@ -419,6 +419,7 @@ impl Server {
                         session,
                         app_state.config.openai_api_key.clone(),
                         app_state.config.google_ai_api_key.clone(),
+                        app_state.config.anthropic_api_key.clone(),
                     )
                 }
             })
@@ -3506,6 +3507,7 @@ async fn complete_with_language_model(
     session: Session,
     open_ai_api_key: Option<Arc<str>>,
     google_ai_api_key: Option<Arc<str>>,
+    anthropic_api_key: Option<Arc<str>>,
 ) -> Result<()> {
     let Some(session) = session.for_user() else {
         return Err(anyhow!("user not found"))?;
@@ -3524,6 +3526,10 @@ async fn complete_with_language_model(
         let api_key = google_ai_api_key
             .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
         complete_with_google_ai(request, response, session, api_key).await?;
+    } else if request.model.starts_with("claude") {
+        let api_key = anthropic_api_key
+            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
+        complete_with_anthropic(request, response, session, api_key).await?;
     }
 
     Ok(())
@@ -3621,6 +3627,121 @@ async fn complete_with_google_ai(
     Ok(())
 }
 
+async fn complete_with_anthropic(
+    request: proto::CompleteWithLanguageModel,
+    response: StreamingResponse<proto::CompleteWithLanguageModel>,
+    session: UserSession,
+    api_key: Arc<str>,
+) -> Result<()> {
+    let model = anthropic::Model::from_id(&request.model)?;
+
+    let mut system_message = String::new();
+    let messages = request
+        .messages
+        .into_iter()
+        .filter_map(|message| match message.role() {
+            LanguageModelRole::LanguageModelUser => Some(anthropic::RequestMessage {
+                role: anthropic::Role::User,
+                content: message.content,
+            }),
+            LanguageModelRole::LanguageModelAssistant => Some(anthropic::RequestMessage {
+                role: anthropic::Role::Assistant,
+                content: message.content,
+            }),
+            // Anthropic's API breaks system instructions out as a separate field rather
+            // than having a system message role.
+            LanguageModelRole::LanguageModelSystem => {
+                if !system_message.is_empty() {
+                    system_message.push_str("\n\n");
+                }
+                system_message.push_str(&message.content);
+
+                None
+            }
+        })
+        .collect();
+
+    let mut stream = anthropic::stream_completion(
+        &session.http_client,
+        "https://api.anthropic.com",
+        &api_key,
+        anthropic::Request {
+            model,
+            messages,
+            stream: true,
+            system: system_message,
+            max_tokens: 4092,
+        },
+    )
+    .await?;
+
+    let mut current_role = proto::LanguageModelRole::LanguageModelAssistant;
+
+    while let Some(event) = stream.next().await {
+        let event = event?;
+
+        match event {
+            anthropic::ResponseEvent::MessageStart { message } => {
+                if let Some(role) = message.role {
+                    if role == "assistant" {
+                        current_role = proto::LanguageModelRole::LanguageModelAssistant;
+                    } else if role == "user" {
+                        current_role = proto::LanguageModelRole::LanguageModelUser;
+                    }
+                }
+            }
+            anthropic::ResponseEvent::ContentBlockStart { content_block, .. } => {
+                match content_block {
+                    anthropic::ContentBlock::Text { text } => {
+                        if !text.is_empty() {
+                            response.send(proto::LanguageModelResponse {
+                                choices: vec![proto::LanguageModelChoiceDelta {
+                                    index: 0,
+                                    delta: Some(proto::LanguageModelResponseMessage {
+                                        role: Some(current_role as i32),
+                                        content: Some(text),
+                                    }),
+                                    finish_reason: None,
+                                }],
+                            })?;
+                        }
+                    }
+                }
+            }
+            anthropic::ResponseEvent::ContentBlockDelta { delta, .. } => match delta {
+                anthropic::TextDelta::TextDelta { text } => {
+                    response.send(proto::LanguageModelResponse {
+                        choices: vec![proto::LanguageModelChoiceDelta {
+                            index: 0,
+                            delta: Some(proto::LanguageModelResponseMessage {
+                                role: Some(current_role as i32),
+                                content: Some(text),
+                            }),
+                            finish_reason: None,
+                        }],
+                    })?;
+                }
+            },
+            anthropic::ResponseEvent::MessageDelta { delta, .. } => {
+                if let Some(stop_reason) = delta.stop_reason {
+                    response.send(proto::LanguageModelResponse {
+                        choices: vec![proto::LanguageModelChoiceDelta {
+                            index: 0,
+                            delta: None,
+                            finish_reason: Some(stop_reason),
+                        }],
+                    })?;
+                }
+            }
+            anthropic::ResponseEvent::ContentBlockStop { .. } => {}
+            anthropic::ResponseEvent::MessageStop {} => {}
+            anthropic::ResponseEvent::Ping {} => {}
+        }
+    }
+
+    Ok(())
+}
+
 struct CountTokensWithLanguageModelRateLimit;
 
 impl RateLimit for CountTokensWithLanguageModelRateLimit {
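
As the comment in `complete_with_anthropic` notes, Anthropic's Messages API takes system instructions as a top-level `system` string rather than as a message role, so the `filter_map` above folds system messages out of the conversation. The same logic in isolation, with a hypothetical `Msg` type standing in for the proto message:

```rust
// Hypothetical stand-ins for the proto types, to show the folding alone.
enum MsgRole {
    System,
    User,
    Assistant,
}

struct Msg {
    role: MsgRole,
    content: String,
}

// Concatenate system messages (blank-line separated) into one system
// prompt, keeping only user/assistant messages in the conversation.
fn split_system(messages: Vec<Msg>) -> (String, Vec<Msg>) {
    let mut system = String::new();
    let mut conversation = Vec::new();
    for message in messages {
        match message.role {
            MsgRole::System => {
                if !system.is_empty() {
                    system.push_str("\n\n");
                }
                system.push_str(&message.content);
            }
            _ => conversation.push(message),
        }
    }
    (system, conversation)
}
```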

crates/collab/src/tests/test_server.rs

@@ -512,6 +512,7 @@ impl TestServer {
                 blob_store_bucket: None,
                 openai_api_key: None,
                 google_ai_api_key: None,
+                anthropic_api_key: None,
                 clickhouse_url: None,
                 clickhouse_user: None,
                 clickhouse_password: None,