Allow using a custom model when using zed.dev (#14933)

Created by Antonio Scandurra

Release Notes:

- N/A

Change summary

crates/anthropic/src/anthropic.rs              |  33 +++++
crates/assistant/src/assistant_settings.rs     |  11 +
crates/collab/src/rpc.rs                       |  51 ++++++--
crates/completion/src/cloud.rs                 |  12 +-
crates/completion/src/open_ai.rs               |   1 +
crates/language_model/src/model/cloud_model.rs | 116 +++++--------------
6 files changed, 114 insertions(+), 110 deletions(-)

Detailed changes

crates/anthropic/src/anthropic.rs 🔗

@@ -20,6 +20,12 @@ pub enum Model {
     Claude3Sonnet,
     #[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
     Claude3Haiku,
+    #[serde(rename = "custom")]
+    Custom {
+        name: String,
+        #[serde(default)]
+        max_tokens: Option<usize>,
+    },
 }
 
 impl Model {
@@ -33,30 +39,41 @@ impl Model {
         } else if id.starts_with("claude-3-haiku") {
             Ok(Self::Claude3Haiku)
         } else {
-            Err(anyhow!("Invalid model id: {}", id))
+            Ok(Self::Custom {
+                name: id.to_string(),
+                max_tokens: None,
+            })
         }
     }
 
-    pub fn id(&self) -> &'static str {
+    pub fn id(&self) -> &str {
         match self {
             Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
             Model::Claude3Opus => "claude-3-opus-20240229",
             Model::Claude3Sonnet => "claude-3-sonnet-20240229",
             Model::Claude3Haiku => "claude-3-opus-20240307",
+            Model::Custom { name, .. } => name,
         }
     }
 
-    pub fn display_name(&self) -> &'static str {
+    pub fn display_name(&self) -> &str {
         match self {
             Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
             Self::Claude3Opus => "Claude 3 Opus",
             Self::Claude3Sonnet => "Claude 3 Sonnet",
             Self::Claude3Haiku => "Claude 3 Haiku",
+            Self::Custom { name, .. } => name,
         }
     }
 
     pub fn max_token_count(&self) -> usize {
-        200_000
+        match self {
+            Self::Claude3_5Sonnet
+            | Self::Claude3Opus
+            | Self::Claude3Sonnet
+            | Self::Claude3Haiku => 200_000,
+            Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
+        }
     }
 }
 
@@ -90,6 +107,7 @@ impl From<Role> for String {
 
 #[derive(Debug, Serialize)]
 pub struct Request {
+    #[serde(serialize_with = "serialize_request_model")]
     pub model: Model,
     pub messages: Vec<RequestMessage>,
     pub stream: bool,
@@ -97,6 +115,13 @@ pub struct Request {
     pub max_tokens: u32,
 }
 
+fn serialize_request_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
+where
+    S: serde::Serializer,
+{
+    serializer.serialize_str(&model.id())
+}
+
 #[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 pub struct RequestMessage {
     pub role: Role,

crates/assistant/src/assistant_settings.rs 🔗

@@ -668,7 +668,11 @@ mod tests {
                             "version": "1",
                             "provider": {
                                 "name": "zed.dev",
-                                "default_model": "custom"
+                                "default_model": {
+                                    "custom": {
+                                        "name": "custom-provider"
+                                    }
+                                }
                             }
                         }
                     }"#,
@@ -679,7 +683,10 @@ mod tests {
         assert_eq!(
             AssistantSettings::get_global(cx).provider,
             AssistantProvider::ZedDotDev {
-                model: CloudModel::Custom("custom".into())
+                model: CloudModel::Custom {
+                    name: "custom-provider".into(),
+                    max_tokens: None
+                }
             }
         );
     }

crates/collab/src/rpc.rs 🔗

@@ -4514,7 +4514,7 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
 }
 
 async fn complete_with_language_model(
-    request: proto::CompleteWithLanguageModel,
+    mut request: proto::CompleteWithLanguageModel,
     response: StreamingResponse<proto::CompleteWithLanguageModel>,
     session: Session,
     open_ai_api_key: Option<Arc<str>>,
@@ -4530,18 +4530,43 @@ async fn complete_with_language_model(
         .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
         .await?;
 
-    if request.model.starts_with("gpt") {
-        let api_key =
-            open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
-        complete_with_open_ai(request, response, session, api_key).await?;
-    } else if request.model.starts_with("gemini") {
-        let api_key = google_ai_api_key
-            .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
-        complete_with_google_ai(request, response, session, api_key).await?;
-    } else if request.model.starts_with("claude") {
-        let api_key = anthropic_api_key
-            .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
-        complete_with_anthropic(request, response, session, api_key).await?;
+    let mut provider_and_model = request.model.split('/');
+    let (provider, model) = match (
+        provider_and_model.next().unwrap(),
+        provider_and_model.next(),
+    ) {
+        (provider, Some(model)) => (provider, model),
+        (model, None) => {
+            if model.starts_with("gpt") {
+                ("openai", model)
+            } else if model.starts_with("gemini") {
+                ("google", model)
+            } else if model.starts_with("claude") {
+                ("anthropic", model)
+            } else {
+                ("unknown", model)
+            }
+        }
+    };
+    let provider = provider.to_string();
+    request.model = model.to_string();
+
+    match provider.as_str() {
+        "openai" => {
+            let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
+            complete_with_open_ai(request, response, session, api_key).await?;
+        }
+        "anthropic" => {
+            let api_key =
+                anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
+            complete_with_anthropic(request, response, session, api_key).await?;
+        }
+        "google" => {
+            let api_key =
+                google_ai_api_key.context("no Google AI API key configured on the server")?;
+            complete_with_google_ai(request, response, session, api_key).await?;
+        }
+        provider => return Err(anyhow!("unknown provider {:?}", provider))?,
     }
 
     Ok(())

crates/completion/src/cloud.rs 🔗

@@ -54,15 +54,15 @@ impl CloudCompletionProvider {
 
 impl LanguageModelCompletionProvider for CloudCompletionProvider {
     fn available_models(&self) -> Vec<LanguageModel> {
-        let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
-            Some(custom_model)
+        let mut custom_model = if matches!(self.model, CloudModel::Custom { .. }) {
+            Some(self.model.clone())
         } else {
             None
         };
         CloudModel::iter()
             .filter_map(move |model| {
-                if let CloudModel::Custom(_) = model {
-                    Some(CloudModel::Custom(custom_model.take()?))
+                if let CloudModel::Custom { .. } = model {
+                    custom_model.take()
                 } else {
                     Some(model)
                 }
@@ -117,9 +117,9 @@ impl LanguageModelCompletionProvider for CloudCompletionProvider {
                 // Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
                 count_open_ai_tokens(request, cx.background_executor())
             }
-            LanguageModel::Cloud(CloudModel::Custom(model)) => {
+            LanguageModel::Cloud(CloudModel::Custom { name, .. }) => {
                 let request = self.client.request(proto::CountTokensWithLanguageModel {
-                    model,
+                    model: name,
                     messages: request
                         .messages
                         .iter()

crates/completion/src/open_ai.rs 🔗

@@ -241,6 +241,7 @@ pub fn count_open_ai_tokens(
                 | LanguageModel::Cloud(CloudModel::Claude3Opus)
                 | LanguageModel::Cloud(CloudModel::Claude3Sonnet)
                 | LanguageModel::Cloud(CloudModel::Claude3Haiku)
+                | LanguageModel::Cloud(CloudModel::Custom { .. })
                 | LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
                     // Tiktoken doesn't yet support these models, so we manually use the
                     // same tokenizer as GPT-4.

crates/language_model/src/model/cloud_model.rs 🔗

@@ -2,100 +2,40 @@ use crate::LanguageModelRequest;
 pub use anthropic::Model as AnthropicModel;
 pub use ollama::Model as OllamaModel;
 pub use open_ai::Model as OpenAiModel;
-use schemars::{
-    schema::{InstanceType, Metadata, Schema, SchemaObject},
-    JsonSchema,
-};
-use serde::{
-    de::{self, Visitor},
-    Deserialize, Deserializer, Serialize, Serializer,
-};
-use std::fmt;
-use strum::{EnumIter, IntoEnumIterator};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use strum::EnumIter;
 
-#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
+#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
 pub enum CloudModel {
+    #[serde(rename = "gpt-3.5-turbo")]
     Gpt3Point5Turbo,
+    #[serde(rename = "gpt-4")]
     Gpt4,
+    #[serde(rename = "gpt-4-turbo-preview")]
     Gpt4Turbo,
+    #[serde(rename = "gpt-4o")]
     #[default]
     Gpt4Omni,
+    #[serde(rename = "gpt-4o-mini")]
     Gpt4OmniMini,
+    #[serde(rename = "claude-3-5-sonnet")]
     Claude3_5Sonnet,
+    #[serde(rename = "claude-3-opus")]
     Claude3Opus,
+    #[serde(rename = "claude-3-sonnet")]
     Claude3Sonnet,
+    #[serde(rename = "claude-3-haiku")]
     Claude3Haiku,
+    #[serde(rename = "gemini-1.5-pro")]
     Gemini15Pro,
+    #[serde(rename = "gemini-1.5-flash")]
     Gemini15Flash,
-    Custom(String),
-}
-
-impl Serialize for CloudModel {
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: Serializer,
-    {
-        serializer.serialize_str(self.id())
-    }
-}
-
-impl<'de> Deserialize<'de> for CloudModel {
-    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
-    where
-        D: Deserializer<'de>,
-    {
-        struct ZedDotDevModelVisitor;
-
-        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
-            type Value = CloudModel;
-
-            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
-                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
-            }
-
-            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
-            where
-                E: de::Error,
-            {
-                let model = CloudModel::iter()
-                    .find(|model| model.id() == value)
-                    .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
-                Ok(model)
-            }
-        }
-
-        deserializer.deserialize_str(ZedDotDevModelVisitor)
-    }
-}
-
-impl JsonSchema for CloudModel {
-    fn schema_name() -> String {
-        "ZedDotDevModel".to_owned()
-    }
-
-    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
-        let variants = CloudModel::iter()
-            .filter_map(|model| {
-                let id = model.id();
-                if id.is_empty() {
-                    None
-                } else {
-                    Some(id.to_string())
-                }
-            })
-            .collect::<Vec<_>>();
-        Schema::Object(SchemaObject {
-            instance_type: Some(InstanceType::String.into()),
-            enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
-            metadata: Some(Box::new(Metadata {
-                title: Some("ZedDotDevModel".to_owned()),
-                default: Some(CloudModel::default().id().into()),
-                examples: variants.into_iter().map(Into::into).collect(),
-                ..Default::default()
-            })),
-            ..Default::default()
-        })
-    }
+    #[serde(rename = "custom")]
+    Custom {
+        name: String,
+        max_tokens: Option<usize>,
+    },
 }
 
 impl CloudModel {
@@ -112,7 +52,7 @@ impl CloudModel {
             Self::Claude3Haiku => "claude-3-haiku",
             Self::Gemini15Pro => "gemini-1.5-pro",
             Self::Gemini15Flash => "gemini-1.5-flash",
-            Self::Custom(id) => id,
+            Self::Custom { name, .. } => name,
         }
     }
 
@@ -129,7 +69,7 @@ impl CloudModel {
             Self::Claude3Haiku => "Claude 3 Haiku",
             Self::Gemini15Pro => "Gemini 1.5 Pro",
             Self::Gemini15Flash => "Gemini 1.5 Flash",
-            Self::Custom(id) => id.as_str(),
+            Self::Custom { name, .. } => name,
         }
     }
 
@@ -145,14 +85,20 @@ impl CloudModel {
             | Self::Claude3Haiku => 200000,
             Self::Gemini15Pro => 128000,
             Self::Gemini15Flash => 32000,
-            Self::Custom(_) => 4096, // TODO: Make this configurable
+            Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
         }
     }
 
     pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
         match self {
-            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
-                request.preprocess_anthropic()
+            Self::Claude3Opus
+            | Self::Claude3Sonnet
+            | Self::Claude3Haiku
+            | Self::Claude3_5Sonnet => {
+                request.preprocess_anthropic();
+            }
+            Self::Custom { name, .. } if name.starts_with("anthropic/") => {
+                request.preprocess_anthropic();
             }
             _ => {}
         }