1use crate::LanguageModelRequest;
2pub use anthropic::Model as AnthropicModel;
3pub use ollama::Model as OllamaModel;
4pub use open_ai::Model as OpenAiModel;
5use schemars::{
6 schema::{InstanceType, Metadata, Schema, SchemaObject},
7 JsonSchema,
8};
9use serde::{
10 de::{self, Visitor},
11 Deserialize, Deserializer, Serialize, Serializer,
12};
13use std::fmt;
14use strum::{EnumIter, IntoEnumIterator};
15
/// A model reachable through the cloud provider, identified by a string id
/// (see [`CloudModel::id`]).
///
/// Declaration order is observable: `CloudModel::iter()` (from `EnumIter`) is
/// used both to resolve ids during deserialization and to list the known ids
/// in the generated JSON schema, so reordering variants changes that output.
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
    /// Id `gpt-3.5-turbo`.
    Gpt3Point5Turbo,
    /// Id `gpt-4`.
    Gpt4,
    /// Id `gpt-4-turbo-preview`.
    Gpt4Turbo,
    /// Id `gpt-4o`; this is the `Default` variant.
    #[default]
    Gpt4Omni,
    /// Id `gpt-4o-mini`.
    Gpt4OmniMini,
    /// Id `claude-3-5-sonnet`.
    Claude3_5Sonnet,
    /// Id `claude-3-opus`.
    Claude3Opus,
    /// Id `claude-3-sonnet`.
    Claude3Sonnet,
    /// Id `claude-3-haiku`.
    Claude3Haiku,
    /// Id `gemini-1.5-pro`.
    Gemini15Pro,
    /// Id `gemini-1.5-flash`.
    Gemini15Flash,
    /// Any id that does not match a built-in variant. The stored string is
    /// used verbatim as both the id and the display name.
    Custom(String),
}
32
33impl Serialize for CloudModel {
34 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
35 where
36 S: Serializer,
37 {
38 serializer.serialize_str(self.id())
39 }
40}
41
42impl<'de> Deserialize<'de> for CloudModel {
43 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
44 where
45 D: Deserializer<'de>,
46 {
47 struct ZedDotDevModelVisitor;
48
49 impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
50 type Value = CloudModel;
51
52 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
53 formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
54 }
55
56 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
57 where
58 E: de::Error,
59 {
60 let model = CloudModel::iter()
61 .find(|model| model.id() == value)
62 .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
63 Ok(model)
64 }
65 }
66
67 deserializer.deserialize_str(ZedDotDevModelVisitor)
68 }
69}
70
71impl JsonSchema for CloudModel {
72 fn schema_name() -> String {
73 "ZedDotDevModel".to_owned()
74 }
75
76 fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
77 let variants = CloudModel::iter()
78 .filter_map(|model| {
79 let id = model.id();
80 if id.is_empty() {
81 None
82 } else {
83 Some(id.to_string())
84 }
85 })
86 .collect::<Vec<_>>();
87 Schema::Object(SchemaObject {
88 instance_type: Some(InstanceType::String.into()),
89 enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
90 metadata: Some(Box::new(Metadata {
91 title: Some("ZedDotDevModel".to_owned()),
92 default: Some(CloudModel::default().id().into()),
93 examples: variants.into_iter().map(Into::into).collect(),
94 ..Default::default()
95 })),
96 ..Default::default()
97 })
98 }
99}
100
101impl CloudModel {
102 pub fn id(&self) -> &str {
103 match self {
104 Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
105 Self::Gpt4 => "gpt-4",
106 Self::Gpt4Turbo => "gpt-4-turbo-preview",
107 Self::Gpt4Omni => "gpt-4o",
108 Self::Gpt4OmniMini => "gpt-4o-mini",
109 Self::Claude3_5Sonnet => "claude-3-5-sonnet",
110 Self::Claude3Opus => "claude-3-opus",
111 Self::Claude3Sonnet => "claude-3-sonnet",
112 Self::Claude3Haiku => "claude-3-haiku",
113 Self::Gemini15Pro => "gemini-1.5-pro",
114 Self::Gemini15Flash => "gemini-1.5-flash",
115 Self::Custom(id) => id,
116 }
117 }
118
119 pub fn display_name(&self) -> &str {
120 match self {
121 Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
122 Self::Gpt4 => "GPT 4",
123 Self::Gpt4Turbo => "GPT 4 Turbo",
124 Self::Gpt4Omni => "GPT 4 Omni",
125 Self::Gpt4OmniMini => "GPT 4 Omni Mini",
126 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
127 Self::Claude3Opus => "Claude 3 Opus",
128 Self::Claude3Sonnet => "Claude 3 Sonnet",
129 Self::Claude3Haiku => "Claude 3 Haiku",
130 Self::Gemini15Pro => "Gemini 1.5 Pro",
131 Self::Gemini15Flash => "Gemini 1.5 Flash",
132 Self::Custom(id) => id.as_str(),
133 }
134 }
135
136 pub fn max_token_count(&self) -> usize {
137 match self {
138 Self::Gpt3Point5Turbo => 2048,
139 Self::Gpt4 => 4096,
140 Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
141 Self::Gpt4OmniMini => 128000,
142 Self::Claude3_5Sonnet
143 | Self::Claude3Opus
144 | Self::Claude3Sonnet
145 | Self::Claude3Haiku => 200000,
146 Self::Gemini15Pro => 128000,
147 Self::Gemini15Flash => 32000,
148 Self::Custom(_) => 4096, // TODO: Make this configurable
149 }
150 }
151
152 pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
153 match self {
154 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
155 request.preprocess_anthropic()
156 }
157 _ => {}
158 }
159 }
160}