Detailed changes
@@ -20,6 +20,12 @@ pub enum Model {
Claude3Sonnet,
#[serde(alias = "claude-3-haiku", rename = "claude-3-haiku-20240307")]
Claude3Haiku,
+ #[serde(rename = "custom")]
+ Custom {
+ name: String,
+ #[serde(default)]
+ max_tokens: Option<usize>,
+ },
}
impl Model {
@@ -33,30 +39,41 @@ impl Model {
} else if id.starts_with("claude-3-haiku") {
Ok(Self::Claude3Haiku)
} else {
- Err(anyhow!("Invalid model id: {}", id))
+ Ok(Self::Custom {
+ name: id.to_string(),
+ max_tokens: None,
+ })
}
}
- pub fn id(&self) -> &'static str {
+ pub fn id(&self) -> &str {
match self {
Model::Claude3_5Sonnet => "claude-3-5-sonnet-20240620",
Model::Claude3Opus => "claude-3-opus-20240229",
Model::Claude3Sonnet => "claude-3-sonnet-20240229",
Model::Claude3Haiku => "claude-3-haiku-20240307",
+ Model::Custom { name, .. } => name,
}
}
- pub fn display_name(&self) -> &'static str {
+ pub fn display_name(&self) -> &str {
match self {
Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
+ Self::Custom { name, .. } => name,
}
}
pub fn max_token_count(&self) -> usize {
- 200_000
+ match self {
+ Self::Claude3_5Sonnet
+ | Self::Claude3Opus
+ | Self::Claude3Sonnet
+ | Self::Claude3Haiku => 200_000,
+ Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
+ }
}
}
@@ -90,6 +107,7 @@ impl From<Role> for String {
#[derive(Debug, Serialize)]
pub struct Request {
+ #[serde(serialize_with = "serialize_request_model")]
pub model: Model,
pub messages: Vec<RequestMessage>,
pub stream: bool,
@@ -97,6 +115,13 @@ pub struct Request {
pub max_tokens: u32,
}
+fn serialize_request_model<S>(model: &Model, serializer: S) -> Result<S::Ok, S::Error>
+where
+ S: serde::Serializer,
+{
+    serializer.serialize_str(model.id())
+}
+
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct RequestMessage {
pub role: Role,
@@ -668,7 +668,11 @@ mod tests {
"version": "1",
"provider": {
"name": "zed.dev",
- "default_model": "custom"
+ "default_model": {
+ "custom": {
+ "name": "custom-provider"
+ }
+ }
}
}
}"#,
@@ -679,7 +683,10 @@ mod tests {
assert_eq!(
AssistantSettings::get_global(cx).provider,
AssistantProvider::ZedDotDev {
- model: CloudModel::Custom("custom".into())
+ model: CloudModel::Custom {
+ name: "custom-provider".into(),
+ max_tokens: None
+ }
}
);
}
@@ -4514,7 +4514,7 @@ impl RateLimit for CompleteWithLanguageModelRateLimit {
}
async fn complete_with_language_model(
- request: proto::CompleteWithLanguageModel,
+ mut request: proto::CompleteWithLanguageModel,
response: StreamingResponse<proto::CompleteWithLanguageModel>,
session: Session,
open_ai_api_key: Option<Arc<str>>,
@@ -4530,18 +4530,43 @@ async fn complete_with_language_model(
.check::<CompleteWithLanguageModelRateLimit>(session.user_id())
.await?;
- if request.model.starts_with("gpt") {
- let api_key =
- open_ai_api_key.ok_or_else(|| anyhow!("no OpenAI API key configured on the server"))?;
- complete_with_open_ai(request, response, session, api_key).await?;
- } else if request.model.starts_with("gemini") {
- let api_key = google_ai_api_key
- .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
- complete_with_google_ai(request, response, session, api_key).await?;
- } else if request.model.starts_with("claude") {
- let api_key = anthropic_api_key
- .ok_or_else(|| anyhow!("no Anthropic AI API key configured on the server"))?;
- complete_with_anthropic(request, response, session, api_key).await?;
+ let mut provider_and_model = request.model.split('/');
+ let (provider, model) = match (
+ provider_and_model.next().unwrap(),
+ provider_and_model.next(),
+ ) {
+ (provider, Some(model)) => (provider, model),
+ (model, None) => {
+ if model.starts_with("gpt") {
+ ("openai", model)
+ } else if model.starts_with("gemini") {
+ ("google", model)
+ } else if model.starts_with("claude") {
+ ("anthropic", model)
+ } else {
+ ("unknown", model)
+ }
+ }
+ };
+ let provider = provider.to_string();
+ request.model = model.to_string();
+
+ match provider.as_str() {
+ "openai" => {
+ let api_key = open_ai_api_key.context("no OpenAI API key configured on the server")?;
+ complete_with_open_ai(request, response, session, api_key).await?;
+ }
+ "anthropic" => {
+ let api_key =
+ anthropic_api_key.context("no Anthropic AI API key configured on the server")?;
+ complete_with_anthropic(request, response, session, api_key).await?;
+ }
+ "google" => {
+ let api_key =
+ google_ai_api_key.context("no Google AI API key configured on the server")?;
+ complete_with_google_ai(request, response, session, api_key).await?;
+ }
+        provider => Err(anyhow!("unknown provider {:?}", provider))?,
}
Ok(())
@@ -54,15 +54,15 @@ impl CloudCompletionProvider {
impl LanguageModelCompletionProvider for CloudCompletionProvider {
fn available_models(&self) -> Vec<LanguageModel> {
- let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
- Some(custom_model)
+ let mut custom_model = if matches!(self.model, CloudModel::Custom { .. }) {
+ Some(self.model.clone())
} else {
None
};
CloudModel::iter()
.filter_map(move |model| {
- if let CloudModel::Custom(_) = model {
- Some(CloudModel::Custom(custom_model.take()?))
+ if let CloudModel::Custom { .. } = model {
+ custom_model.take()
} else {
Some(model)
}
@@ -117,9 +117,9 @@ impl LanguageModelCompletionProvider for CloudCompletionProvider {
// Can't find a tokenizer for Claude 3, so for now just use the same as OpenAI's as an approximation.
count_open_ai_tokens(request, cx.background_executor())
}
- LanguageModel::Cloud(CloudModel::Custom(model)) => {
+ LanguageModel::Cloud(CloudModel::Custom { name, .. }) => {
let request = self.client.request(proto::CountTokensWithLanguageModel {
- model,
+ model: name,
messages: request
.messages
.iter()
@@ -241,6 +241,7 @@ pub fn count_open_ai_tokens(
| LanguageModel::Cloud(CloudModel::Claude3Opus)
| LanguageModel::Cloud(CloudModel::Claude3Sonnet)
| LanguageModel::Cloud(CloudModel::Claude3Haiku)
+ | LanguageModel::Cloud(CloudModel::Custom { .. })
| LanguageModel::OpenAi(OpenAiModel::Custom { .. }) => {
// Tiktoken doesn't yet support these models, so we manually use the
// same tokenizer as GPT-4.
@@ -2,100 +2,40 @@ use crate::LanguageModelRequest;
pub use anthropic::Model as AnthropicModel;
pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;
-use schemars::{
- schema::{InstanceType, Metadata, Schema, SchemaObject},
- JsonSchema,
-};
-use serde::{
- de::{self, Visitor},
- Deserialize, Deserializer, Serialize, Serializer,
-};
-use std::fmt;
-use strum::{EnumIter, IntoEnumIterator};
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
+use strum::EnumIter;
-#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
+#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
pub enum CloudModel {
+ #[serde(rename = "gpt-3.5-turbo")]
Gpt3Point5Turbo,
+ #[serde(rename = "gpt-4")]
Gpt4,
+ #[serde(rename = "gpt-4-turbo-preview")]
Gpt4Turbo,
+ #[serde(rename = "gpt-4o")]
#[default]
Gpt4Omni,
+ #[serde(rename = "gpt-4o-mini")]
Gpt4OmniMini,
+ #[serde(rename = "claude-3-5-sonnet")]
Claude3_5Sonnet,
+ #[serde(rename = "claude-3-opus")]
Claude3Opus,
+ #[serde(rename = "claude-3-sonnet")]
Claude3Sonnet,
+ #[serde(rename = "claude-3-haiku")]
Claude3Haiku,
+ #[serde(rename = "gemini-1.5-pro")]
Gemini15Pro,
+ #[serde(rename = "gemini-1.5-flash")]
Gemini15Flash,
- Custom(String),
-}
-
-impl Serialize for CloudModel {
- fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
- where
- S: Serializer,
- {
- serializer.serialize_str(self.id())
- }
-}
-
-impl<'de> Deserialize<'de> for CloudModel {
- fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
- where
- D: Deserializer<'de>,
- {
- struct ZedDotDevModelVisitor;
-
- impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
- type Value = CloudModel;
-
- fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
- formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
- }
-
- fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
- where
- E: de::Error,
- {
- let model = CloudModel::iter()
- .find(|model| model.id() == value)
- .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
- Ok(model)
- }
- }
-
- deserializer.deserialize_str(ZedDotDevModelVisitor)
- }
-}
-
-impl JsonSchema for CloudModel {
- fn schema_name() -> String {
- "ZedDotDevModel".to_owned()
- }
-
- fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
- let variants = CloudModel::iter()
- .filter_map(|model| {
- let id = model.id();
- if id.is_empty() {
- None
- } else {
- Some(id.to_string())
- }
- })
- .collect::<Vec<_>>();
- Schema::Object(SchemaObject {
- instance_type: Some(InstanceType::String.into()),
- enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
- metadata: Some(Box::new(Metadata {
- title: Some("ZedDotDevModel".to_owned()),
- default: Some(CloudModel::default().id().into()),
- examples: variants.into_iter().map(Into::into).collect(),
- ..Default::default()
- })),
- ..Default::default()
- })
- }
+ #[serde(rename = "custom")]
+ Custom {
+ name: String,
+ max_tokens: Option<usize>,
+ },
}
impl CloudModel {
@@ -112,7 +52,7 @@ impl CloudModel {
Self::Claude3Haiku => "claude-3-haiku",
Self::Gemini15Pro => "gemini-1.5-pro",
Self::Gemini15Flash => "gemini-1.5-flash",
- Self::Custom(id) => id,
+ Self::Custom { name, .. } => name,
}
}
@@ -129,7 +69,7 @@ impl CloudModel {
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Gemini15Pro => "Gemini 1.5 Pro",
Self::Gemini15Flash => "Gemini 1.5 Flash",
- Self::Custom(id) => id.as_str(),
+ Self::Custom { name, .. } => name,
}
}
@@ -145,14 +85,20 @@ impl CloudModel {
| Self::Claude3Haiku => 200000,
Self::Gemini15Pro => 128000,
Self::Gemini15Flash => 32000,
- Self::Custom(_) => 4096, // TODO: Make this configurable
+ Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
}
}
pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
match self {
- Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
- request.preprocess_anthropic()
+ Self::Claude3Opus
+ | Self::Claude3Sonnet
+ | Self::Claude3Haiku
+ | Self::Claude3_5Sonnet => {
+ request.preprocess_anthropic();
+ }
+ Self::Custom { name, .. } if name.starts_with("anthropic/") => {
+ request.preprocess_anthropic();
}
_ => {}
}