From 5e011ab0291869a9ee2eeaf1396304fd81ced517 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Thu, 1 Aug 2024 18:26:27 -0400 Subject: [PATCH] language_model: Denote the availability of language models (#15660) This PR updates the `LanguageModel` trait with a new method for denoting the availability of a model. Right now we have two variants: - `Public` for models that have no additional restrictions (other than their respective setup/authentication requirements) - `RequiresPlan` for models that require a specific Zed plan Release Notes: - N/A --- crates/assistant/src/model_selector.rs | 1 + crates/language_model/src/language_model.rs | 15 +++++ .../language_model/src/model/cloud_model.rs | 64 +++++++++++++++---- crates/language_model/src/provider/cloud.rs | 6 +- 4 files changed, 73 insertions(+), 13 deletions(-) diff --git a/crates/assistant/src/model_selector.rs b/crates/assistant/src/model_selector.rs index 727f51ea555ab7fd7c95d3e3fa723e4104511971..d3d23b148de8f373b10aef10b94b246d6e6fc3fd 100644 --- a/crates/assistant/src/model_selector.rs +++ b/crates/assistant/src/model_selector.rs @@ -109,6 +109,7 @@ impl RenderOnce for ModelSelector { let id = available_model.id(); let provider_id = available_model.provider_id(); let model_name = available_model.name().0.clone(); + let _availability = available_model.availability(); let selected_model = selected_model.clone(); let selected_provider = selected_provider.clone(); move |_| { diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs index e92cbd652b3c42b4e369d46aaa33370b4d4392c2..2fa82197ab546a6b55e5386e2ec86ebc33bfee84 100644 --- a/crates/language_model/src/language_model.rs +++ b/crates/language_model/src/language_model.rs @@ -14,6 +14,7 @@ use gpui::{ }; pub use model::*; use project::Fs; +use proto::Plan; pub(crate) use rate_limiter::*; pub use registry::*; pub use request::*; @@ -32,6 +33,15 @@ pub fn init( registry::init(user_store, client, cx); } +/// The availability of a [`LanguageModel`]. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum LanguageModelAvailability { + /// The language model is available to the general public. + Public, + /// The language model is available to users on the indicated plan. + RequiresPlan(Plan), +} + pub trait LanguageModel: Send + Sync { fn id(&self) -> LanguageModelId; fn name(&self) -> LanguageModelName; @@ -39,6 +49,11 @@ pub trait LanguageModel: Send + Sync { fn provider_name(&self) -> LanguageModelProviderName; fn telemetry_id(&self) -> String; + /// Returns the availability of this language model. + fn availability(&self) -> LanguageModelAvailability { + LanguageModelAvailability::Public + } + fn max_token_count(&self) -> usize; fn count_tokens( diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index c94870796e4fcdcabbaa6a55f218d565261aa4c5..8d6f53dbc636886e4ee561b0236940d82d58845d 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -1,7 +1,10 @@ +use proto::Plan; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use strum::EnumIter; +use crate::LanguageModelAvailability; + #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)] #[serde(tag = "provider", rename_all = "lowercase")] pub enum CloudModel { @@ -46,28 +49,65 @@ impl Default for CloudModel { impl CloudModel { pub fn id(&self) -> &str { match self { - CloudModel::Anthropic(model) => model.id(), - CloudModel::OpenAi(model) => model.id(), - CloudModel::Google(model) => model.id(), - CloudModel::Zed(model) => model.id(), + Self::Anthropic(model) => model.id(), + Self::OpenAi(model) => model.id(), + Self::Google(model) => model.id(), + Self::Zed(model) => model.id(), } } pub fn display_name(&self) -> &str { match self { - CloudModel::Anthropic(model) => model.display_name(), - CloudModel::OpenAi(model) => model.display_name(), - CloudModel::Google(model) => model.display_name(), - CloudModel::Zed(model) => model.display_name(), + Self::Anthropic(model) => model.display_name(), + Self::OpenAi(model) => model.display_name(), + Self::Google(model) => model.display_name(), + Self::Zed(model) => model.display_name(), } } pub fn max_token_count(&self) -> usize { match self { - CloudModel::Anthropic(model) => model.max_token_count(), - CloudModel::OpenAi(model) => model.max_token_count(), - CloudModel::Google(model) => model.max_token_count(), - CloudModel::Zed(model) => model.max_token_count(), + Self::Anthropic(model) => model.max_token_count(), + Self::OpenAi(model) => model.max_token_count(), + Self::Google(model) => model.max_token_count(), + Self::Zed(model) => model.max_token_count(), + } + } + + /// Returns the availability of this model. + pub fn availability(&self) -> LanguageModelAvailability { + match self { + Self::Anthropic(model) => match model { + anthropic::Model::Claude3_5Sonnet => { + LanguageModelAvailability::RequiresPlan(Plan::Free) + } + anthropic::Model::Claude3Opus + | anthropic::Model::Claude3Sonnet + | anthropic::Model::Claude3Haiku + | anthropic::Model::Custom { .. } => { + LanguageModelAvailability::RequiresPlan(Plan::ZedPro) + } + }, + Self::OpenAi(model) => match model { + open_ai::Model::ThreePointFiveTurbo + | open_ai::Model::Four + | open_ai::Model::FourTurbo + | open_ai::Model::FourOmni + | open_ai::Model::FourOmniMini + | open_ai::Model::Custom { .. } => { + LanguageModelAvailability::RequiresPlan(Plan::ZedPro) + } + }, + Self::Google(model) => match model { + google_ai::Model::Gemini15Pro + | google_ai::Model::Gemini15Flash + | google_ai::Model::Custom { .. } => { + LanguageModelAvailability::RequiresPlan(Plan::ZedPro) + } + }, + Self::Zed(model) => match model { + ZedModel::Qwen2_7bInstruct => LanguageModelAvailability::RequiresPlan(Plan::ZedPro), + }, } } } diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs index 0c6402c7ab5273c73fdf5dc521b69f13e395cddd..3dda8b24e1f40bda3567e198e42cb375af5e008c 100644 --- a/crates/language_model/src/provider/cloud.rs +++ b/crates/language_model/src/provider/cloud.rs @@ -18,7 +18,7 @@ use std::{future, sync::Arc}; use strum::IntoEnumIterator; use ui::prelude::*; -use crate::LanguageModelProvider; +use crate::{LanguageModelAvailability, LanguageModelProvider}; use super::anthropic::count_anthropic_tokens; @@ -236,6 +236,10 @@ impl LanguageModel for CloudLanguageModel { format!("zed.dev/{}", self.model.id()) } + fn availability(&self) -> LanguageModelAvailability { + self.model.availability() + } + fn max_token_count(&self) -> usize { self.model.max_token_count() }