language_model: Denote the availability of language models (#15660)

Marshall Bowers created

This PR updates the `LanguageModel` trait with a new method for denoting
the availability of a model.

Right now we have two variants:

- `Public` for models that have no additional restrictions (other than
their respective setup/authentication requirements)
- `RequiresPlan` for models that require a specific Zed plan

Release Notes:

- N/A

Change summary

crates/assistant/src/model_selector.rs         |  1 
crates/language_model/src/language_model.rs    | 15 ++++
crates/language_model/src/model/cloud_model.rs | 64 ++++++++++++++++---
crates/language_model/src/provider/cloud.rs    |  6 +
4 files changed, 73 insertions(+), 13 deletions(-)

Detailed changes

crates/assistant/src/model_selector.rs 🔗

@@ -109,6 +109,7 @@ impl<T: PopoverTrigger> RenderOnce for ModelSelector<T> {
                                 let id = available_model.id();
                                 let provider_id = available_model.provider_id();
                                 let model_name = available_model.name().0.clone();
+                                let _availability = available_model.availability();
                                 let selected_model = selected_model.clone();
                                 let selected_provider = selected_provider.clone();
                                 move |_| {

crates/language_model/src/language_model.rs 🔗

@@ -14,6 +14,7 @@ use gpui::{
 };
 pub use model::*;
 use project::Fs;
+use proto::Plan;
 pub(crate) use rate_limiter::*;
 pub use registry::*;
 pub use request::*;
@@ -32,6 +33,15 @@ pub fn init(
     registry::init(user_store, client, cx);
 }
 
+/// The availability of a [`LanguageModel`].
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum LanguageModelAvailability {
+    /// The language model is available to the general public.
+    Public,
+    /// The language model is available to users on the indicated plan.
+    RequiresPlan(Plan),
+}
+
 pub trait LanguageModel: Send + Sync {
     fn id(&self) -> LanguageModelId;
     fn name(&self) -> LanguageModelName;
@@ -39,6 +49,11 @@ pub trait LanguageModel: Send + Sync {
     fn provider_name(&self) -> LanguageModelProviderName;
     fn telemetry_id(&self) -> String;
 
+    /// Returns the availability of this language model.
+    fn availability(&self) -> LanguageModelAvailability {
+        LanguageModelAvailability::Public
+    }
+
     fn max_token_count(&self) -> usize;
 
     fn count_tokens(

crates/language_model/src/model/cloud_model.rs 🔗

@@ -1,7 +1,10 @@
+use proto::Plan;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use strum::EnumIter;
 
+use crate::LanguageModelAvailability;
+
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 #[serde(tag = "provider", rename_all = "lowercase")]
 pub enum CloudModel {
@@ -46,28 +49,65 @@ impl Default for CloudModel {
 impl CloudModel {
     pub fn id(&self) -> &str {
         match self {
-            CloudModel::Anthropic(model) => model.id(),
-            CloudModel::OpenAi(model) => model.id(),
-            CloudModel::Google(model) => model.id(),
-            CloudModel::Zed(model) => model.id(),
+            Self::Anthropic(model) => model.id(),
+            Self::OpenAi(model) => model.id(),
+            Self::Google(model) => model.id(),
+            Self::Zed(model) => model.id(),
         }
     }
 
     pub fn display_name(&self) -> &str {
         match self {
-            CloudModel::Anthropic(model) => model.display_name(),
-            CloudModel::OpenAi(model) => model.display_name(),
-            CloudModel::Google(model) => model.display_name(),
-            CloudModel::Zed(model) => model.display_name(),
+            Self::Anthropic(model) => model.display_name(),
+            Self::OpenAi(model) => model.display_name(),
+            Self::Google(model) => model.display_name(),
+            Self::Zed(model) => model.display_name(),
         }
     }
 
     pub fn max_token_count(&self) -> usize {
         match self {
-            CloudModel::Anthropic(model) => model.max_token_count(),
-            CloudModel::OpenAi(model) => model.max_token_count(),
-            CloudModel::Google(model) => model.max_token_count(),
-            CloudModel::Zed(model) => model.max_token_count(),
+            Self::Anthropic(model) => model.max_token_count(),
+            Self::OpenAi(model) => model.max_token_count(),
+            Self::Google(model) => model.max_token_count(),
+            Self::Zed(model) => model.max_token_count(),
+        }
+    }
+
+    /// Returns the availability of this model.
+    pub fn availability(&self) -> LanguageModelAvailability {
+        match self {
+            Self::Anthropic(model) => match model {
+                anthropic::Model::Claude3_5Sonnet => {
+                    LanguageModelAvailability::RequiresPlan(Plan::Free)
+                }
+                anthropic::Model::Claude3Opus
+                | anthropic::Model::Claude3Sonnet
+                | anthropic::Model::Claude3Haiku
+                | anthropic::Model::Custom { .. } => {
+                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
+                }
+            },
+            Self::OpenAi(model) => match model {
+                open_ai::Model::ThreePointFiveTurbo
+                | open_ai::Model::Four
+                | open_ai::Model::FourTurbo
+                | open_ai::Model::FourOmni
+                | open_ai::Model::FourOmniMini
+                | open_ai::Model::Custom { .. } => {
+                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
+                }
+            },
+            Self::Google(model) => match model {
+                google_ai::Model::Gemini15Pro
+                | google_ai::Model::Gemini15Flash
+                | google_ai::Model::Custom { .. } => {
+                    LanguageModelAvailability::RequiresPlan(Plan::ZedPro)
+                }
+            },
+            Self::Zed(model) => match model {
+                ZedModel::Qwen2_7bInstruct => LanguageModelAvailability::RequiresPlan(Plan::ZedPro),
+            },
         }
     }
 }

crates/language_model/src/provider/cloud.rs 🔗

@@ -18,7 +18,7 @@ use std::{future, sync::Arc};
 use strum::IntoEnumIterator;
 use ui::prelude::*;
 
-use crate::LanguageModelProvider;
+use crate::{LanguageModelAvailability, LanguageModelProvider};
 
 use super::anthropic::count_anthropic_tokens;
 
@@ -236,6 +236,10 @@ impl LanguageModel for CloudLanguageModel {
         format!("zed.dev/{}", self.model.id())
     }
 
+    fn availability(&self) -> LanguageModelAvailability {
+        self.model.availability()
+    }
+
     fn max_token_count(&self) -> usize {
         self.model.max_token_count()
     }