cloud_model.rs

  1use crate::LanguageModelRequest;
  2pub use anthropic::Model as AnthropicModel;
  3pub use ollama::Model as OllamaModel;
  4pub use open_ai::Model as OpenAiModel;
  5use schemars::{
  6    schema::{InstanceType, Metadata, Schema, SchemaObject},
  7    JsonSchema,
  8};
  9use serde::{
 10    de::{self, Visitor},
 11    Deserialize, Deserializer, Serialize, Serializer,
 12};
 13use std::fmt;
 14use strum::{EnumIter, IntoEnumIterator};
 15
 16#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
 17pub enum CloudModel {
 18    Gpt3Point5Turbo,
 19    Gpt4,
 20    Gpt4Turbo,
 21    #[default]
 22    Gpt4Omni,
 23    Gpt4OmniMini,
 24    Claude3_5Sonnet,
 25    Claude3Opus,
 26    Claude3Sonnet,
 27    Claude3Haiku,
 28    Gemini15Pro,
 29    Gemini15Flash,
 30    Custom(String),
 31}
 32
 33impl Serialize for CloudModel {
 34    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 35    where
 36        S: Serializer,
 37    {
 38        serializer.serialize_str(self.id())
 39    }
 40}
 41
 42impl<'de> Deserialize<'de> for CloudModel {
 43    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 44    where
 45        D: Deserializer<'de>,
 46    {
 47        struct ZedDotDevModelVisitor;
 48
 49        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
 50            type Value = CloudModel;
 51
 52            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
 53                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
 54            }
 55
 56            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
 57            where
 58                E: de::Error,
 59            {
 60                let model = CloudModel::iter()
 61                    .find(|model| model.id() == value)
 62                    .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
 63                Ok(model)
 64            }
 65        }
 66
 67        deserializer.deserialize_str(ZedDotDevModelVisitor)
 68    }
 69}
 70
 71impl JsonSchema for CloudModel {
 72    fn schema_name() -> String {
 73        "ZedDotDevModel".to_owned()
 74    }
 75
 76    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
 77        let variants = CloudModel::iter()
 78            .filter_map(|model| {
 79                let id = model.id();
 80                if id.is_empty() {
 81                    None
 82                } else {
 83                    Some(id.to_string())
 84                }
 85            })
 86            .collect::<Vec<_>>();
 87        Schema::Object(SchemaObject {
 88            instance_type: Some(InstanceType::String.into()),
 89            enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
 90            metadata: Some(Box::new(Metadata {
 91                title: Some("ZedDotDevModel".to_owned()),
 92                default: Some(CloudModel::default().id().into()),
 93                examples: variants.into_iter().map(Into::into).collect(),
 94                ..Default::default()
 95            })),
 96            ..Default::default()
 97        })
 98    }
 99}
100
101impl CloudModel {
102    pub fn id(&self) -> &str {
103        match self {
104            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
105            Self::Gpt4 => "gpt-4",
106            Self::Gpt4Turbo => "gpt-4-turbo-preview",
107            Self::Gpt4Omni => "gpt-4o",
108            Self::Gpt4OmniMini => "gpt-4o-mini",
109            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
110            Self::Claude3Opus => "claude-3-opus",
111            Self::Claude3Sonnet => "claude-3-sonnet",
112            Self::Claude3Haiku => "claude-3-haiku",
113            Self::Gemini15Pro => "gemini-1.5-pro",
114            Self::Gemini15Flash => "gemini-1.5-flash",
115            Self::Custom(id) => id,
116        }
117    }
118
119    pub fn display_name(&self) -> &str {
120        match self {
121            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
122            Self::Gpt4 => "GPT 4",
123            Self::Gpt4Turbo => "GPT 4 Turbo",
124            Self::Gpt4Omni => "GPT 4 Omni",
125            Self::Gpt4OmniMini => "GPT 4 Omni Mini",
126            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
127            Self::Claude3Opus => "Claude 3 Opus",
128            Self::Claude3Sonnet => "Claude 3 Sonnet",
129            Self::Claude3Haiku => "Claude 3 Haiku",
130            Self::Gemini15Pro => "Gemini 1.5 Pro",
131            Self::Gemini15Flash => "Gemini 1.5 Flash",
132            Self::Custom(id) => id.as_str(),
133        }
134    }
135
136    pub fn max_token_count(&self) -> usize {
137        match self {
138            Self::Gpt3Point5Turbo => 2048,
139            Self::Gpt4 => 4096,
140            Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
141            Self::Gpt4OmniMini => 128000,
142            Self::Claude3_5Sonnet
143            | Self::Claude3Opus
144            | Self::Claude3Sonnet
145            | Self::Claude3Haiku => 200000,
146            Self::Gemini15Pro => 128000,
147            Self::Gemini15Flash => 32000,
148            Self::Custom(_) => 4096, // TODO: Make this configurable
149        }
150    }
151
152    pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
153        match self {
154            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
155                request.preprocess_anthropic()
156            }
157            _ => {}
158        }
159    }
160}