1use std::fmt;
2
3use gpui::Pixels;
4pub use open_ai::Model as OpenAiModel;
5use schemars::{
6 schema::{InstanceType, Metadata, Schema, SchemaObject},
7 JsonSchema,
8};
9use serde::{
10 de::{self, Visitor},
11 Deserialize, Deserializer, Serialize, Serializer,
12};
13use settings::Settings;
14
/// A model served through the `zed.dev` assistant provider.
///
/// The named variants correspond to fixed model ids (see `ZedDotDevModel::id`).
/// Any other id is preserved verbatim in `Custom`.
#[derive(Clone, Debug, Default, PartialEq)]
pub enum ZedDotDevModel {
    /// Model id `gpt-3.5-turbo`.
    GptThreePointFiveTurbo,
    /// Model id `gpt-4`.
    GptFour,
    /// Model id `gpt-4-turbo-preview`; this is the default variant.
    #[default]
    GptFourTurbo,
    /// Any model id not covered by the variants above.
    Custom(String),
}
23
24impl Serialize for ZedDotDevModel {
25 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
26 where
27 S: Serializer,
28 {
29 serializer.serialize_str(self.id())
30 }
31}
32
33impl<'de> Deserialize<'de> for ZedDotDevModel {
34 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
35 where
36 D: Deserializer<'de>,
37 {
38 struct ZedDotDevModelVisitor;
39
40 impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
41 type Value = ZedDotDevModel;
42
43 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
44 formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
45 }
46
47 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
48 where
49 E: de::Error,
50 {
51 match value {
52 "gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo),
53 "gpt-4" => Ok(ZedDotDevModel::GptFour),
54 "gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo),
55 _ => Ok(ZedDotDevModel::Custom(value.to_owned())),
56 }
57 }
58 }
59
60 deserializer.deserialize_str(ZedDotDevModelVisitor)
61 }
62}
63
64impl JsonSchema for ZedDotDevModel {
65 fn schema_name() -> String {
66 "ZedDotDevModel".to_owned()
67 }
68
69 fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
70 let variants = vec![
71 "gpt-3.5-turbo".to_owned(),
72 "gpt-4".to_owned(),
73 "gpt-4-turbo-preview".to_owned(),
74 ];
75 Schema::Object(SchemaObject {
76 instance_type: Some(InstanceType::String.into()),
77 enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
78 metadata: Some(Box::new(Metadata {
79 title: Some("ZedDotDevModel".to_owned()),
80 default: Some(serde_json::json!("gpt-4-turbo-preview")),
81 examples: vec![
82 serde_json::json!("gpt-3.5-turbo"),
83 serde_json::json!("gpt-4"),
84 serde_json::json!("gpt-4-turbo-preview"),
85 serde_json::json!("custom-model-name"),
86 ],
87 ..Default::default()
88 })),
89 ..Default::default()
90 })
91 }
92}
93
94impl ZedDotDevModel {
95 pub fn id(&self) -> &str {
96 match self {
97 Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
98 Self::GptFour => "gpt-4",
99 Self::GptFourTurbo => "gpt-4-turbo-preview",
100 Self::Custom(id) => id,
101 }
102 }
103
104 pub fn display_name(&self) -> &str {
105 match self {
106 Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
107 Self::GptFour => "gpt-4",
108 Self::GptFourTurbo => "gpt-4-turbo",
109 Self::Custom(id) => id.as_str(),
110 }
111 }
112
113 pub fn max_token_count(&self) -> usize {
114 match self {
115 Self::GptThreePointFiveTurbo => 2048,
116 Self::GptFour => 4096,
117 Self::GptFourTurbo => 128000,
118 Self::Custom(_) => 4096, // TODO: Make this configurable
119 }
120 }
121}
122
/// The edge of the workspace the assistant panel docks to.
///
/// Serialized in snake_case (`"left"`, `"right"`, `"bottom"`); defaults to `right`.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}
131
/// The service that handles assistant requests, with its per-provider options.
///
/// Internally tagged by `name`, e.g. `{"name": "zed.dev", ...}` or
/// `{"name": "openai", ...}`. Every field is `#[serde(default)]`, so omitted
/// fields fall back to their defaults.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProvider {
    /// The internal `zed.dev` service.
    #[serde(rename = "zed.dev")]
    ZedDotDev {
        /// The default model for this provider.
        #[serde(default)]
        default_model: ZedDotDevModel,
    },
    /// An external OpenAI service.
    #[serde(rename = "openai")]
    OpenAi {
        /// The default model for this provider.
        #[serde(default)]
        default_model: OpenAiModel,
        /// API base URL. Default: https://api.openai.com/v1
        #[serde(default = "open_ai_url")]
        api_url: String,
    },
}
148
149impl Default for AssistantProvider {
150 fn default() -> Self {
151 Self::ZedDotDev {
152 default_model: ZedDotDevModel::default(),
153 }
154 }
155}
156
/// Base URL of the OpenAI API.
///
/// Used when an `openai` provider omits `api_url`, and when legacy settings
/// pick a model without a URL.
fn open_ai_url() -> String {
    String::from("https://api.openai.com/v1")
}
160
/// Effective assistant settings, produced by layering every settings file's
/// content in `AssistantSettings::load`.
#[derive(Default, Debug, Deserialize, Serialize)]
pub struct AssistantSettings {
    /// Whether the assistant is enabled.
    pub enabled: bool,
    /// Whether to show the assistant panel button in the status bar.
    pub button: bool,
    /// Where the assistant panel docks.
    pub dock: AssistantDockPosition,
    /// Default width when docked to the left or right.
    pub default_width: Pixels,
    /// Default height when docked to the bottom.
    pub default_height: Pixels,
    /// The assistant service provider and its options.
    pub provider: AssistantProvider,
}
170
/// Assistant panel settings, as written in a settings file.
///
/// Deserialized untagged: the versioned shape (which requires a `version`
/// key) is tried first, and otherwise the content is read as the legacy shape.
/// serde tries untagged variants in declaration order, so `Versioned` must
/// stay before `Legacy`.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}
178
179impl JsonSchema for AssistantSettingsContent {
180 fn schema_name() -> String {
181 VersionedAssistantSettingsContent::schema_name()
182 }
183
184 fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
185 VersionedAssistantSettingsContent::json_schema(gen)
186 }
187
188 fn is_referenceable() -> bool {
189 VersionedAssistantSettingsContent::is_referenceable()
190 }
191}
192
193impl Default for AssistantSettingsContent {
194 fn default() -> Self {
195 Self::Versioned(VersionedAssistantSettingsContent::default())
196 }
197}
198
199impl AssistantSettingsContent {
200 fn upgrade(&self) -> AssistantSettingsContentV1 {
201 match self {
202 AssistantSettingsContent::Versioned(settings) => match settings {
203 VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
204 },
205 AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
206 enabled: None,
207 button: settings.button,
208 dock: settings.dock,
209 default_width: settings.default_width,
210 default_height: settings.default_height,
211 provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
212 Some(AssistantProvider::OpenAi {
213 default_model: settings.default_open_ai_model.clone().unwrap_or_default(),
214 api_url: open_ai_api_url.clone(),
215 })
216 } else {
217 settings.default_open_ai_model.clone().map(|open_ai_model| {
218 AssistantProvider::OpenAi {
219 default_model: open_ai_model,
220 api_url: open_ai_url(),
221 }
222 })
223 },
224 },
225 }
226 }
227
228 pub fn set_dock(&mut self, dock: AssistantDockPosition) {
229 match self {
230 AssistantSettingsContent::Versioned(settings) => match settings {
231 VersionedAssistantSettingsContent::V1(settings) => {
232 settings.dock = Some(dock);
233 }
234 },
235 AssistantSettingsContent::Legacy(settings) => {
236 settings.dock = Some(dock);
237 }
238 }
239 }
240}
241
/// Versioned assistant settings, discriminated by a `version` key
/// (currently only `"1"`).
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
248
249impl Default for VersionedAssistantSettingsContent {
250 fn default() -> Self {
251 Self::V1(AssistantSettingsContentV1 {
252 enabled: None,
253 button: None,
254 dock: None,
255 default_width: None,
256 default_height: None,
257 provider: None,
258 })
259 }
260}
261
/// Version 1 of the assistant settings format.
///
/// Every field is optional. A field left unset in one settings layer keeps the
/// value accumulated from the layers loaded before it.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can either be the internal `zed.dev` service or an external `openai` service,
    /// each with their respective default models and configurations.
    provider: Option<AssistantProvider>,
}
290
/// The original, unversioned assistant settings format.
///
/// Still accepted for backward compatibility; it is converted to the V1 shape
/// (see `AssistantSettingsContent::upgrade`) before settings are merged.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when starting new conversations.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when starting new conversations.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
318
319impl Settings for AssistantSettings {
320 const KEY: Option<&'static str> = Some("assistant");
321
322 type FileContent = AssistantSettingsContent;
323
324 fn load(
325 default_value: &Self::FileContent,
326 user_values: &[&Self::FileContent],
327 _: &mut gpui::AppContext,
328 ) -> anyhow::Result<Self> {
329 let mut settings = AssistantSettings::default();
330
331 for value in [default_value].iter().chain(user_values) {
332 let value = value.upgrade();
333 merge(&mut settings.enabled, value.enabled);
334 merge(&mut settings.button, value.button);
335 merge(&mut settings.dock, value.dock);
336 merge(
337 &mut settings.default_width,
338 value.default_width.map(Into::into),
339 );
340 merge(
341 &mut settings.default_height,
342 value.default_height.map(Into::into),
343 );
344 if let Some(provider) = value.provider.clone() {
345 match (&mut settings.provider, provider) {
346 (
347 AssistantProvider::ZedDotDev { default_model },
348 AssistantProvider::ZedDotDev {
349 default_model: default_model_override,
350 },
351 ) => {
352 *default_model = default_model_override;
353 }
354 (
355 AssistantProvider::OpenAi {
356 default_model,
357 api_url,
358 },
359 AssistantProvider::OpenAi {
360 default_model: default_model_override,
361 api_url: api_url_override,
362 },
363 ) => {
364 *default_model = default_model_override;
365 *api_url = api_url_override;
366 }
367 (merged, provider_override) => {
368 *merged = provider_override;
369 }
370 }
371 }
372 }
373
374 Ok(settings)
375 }
376}
377
/// Overwrites `target` with `value` when `value` is `Some`, and leaves it
/// untouched when it is `None`.
///
/// `load` uses this to let an optional setting from a later settings layer
/// override the value accumulated from earlier layers. No `Copy` bound is
/// needed: `value` is moved into `target`.
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(value) = value {
        *target = value;
    }
}
383
#[cfg(test)]
mod tests {
    use gpui::AppContext;
    use settings::SettingsStore;

    use super::*;

    /// Checks that the assistant provider resolves correctly from the default
    /// settings, from the legacy (unversioned) format, and from the versioned
    /// format.
    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to gpt-4-turbo.
        // NOTE(review): this expects the OpenAI provider even though
        // `AssistantProvider::default()` is zed.dev, so the value presumably
        // comes from the default settings content the store supplies; confirm.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::FourTurbo,
                api_url: open_ai_url()
            }
        );

        // Ensure backward-compatibility.
        // A legacy `openai_api_url` on its own selects the OpenAI provider with
        // that URL and the default model.
        cx.update_global::<SettingsStore, _>(|store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::FourTurbo,
                api_url: "test-url".into()
            }
        );
        // A legacy `default_open_ai_model` on its own selects the OpenAI
        // provider with that model and the default URL.
        cx.update_global::<SettingsStore, _>(|store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::Four,
                api_url: open_ai_url()
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        // An unrecognized model id is kept verbatim as `ZedDotDevModel::Custom`.
        cx.update_global::<SettingsStore, _>(|store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                default_model: ZedDotDevModel::Custom("custom".into())
            }
        );
    }
}