1use std::fmt;
2
3use gpui::Pixels;
4pub use open_ai::Model as OpenAiModel;
5use schemars::{
6 schema::{InstanceType, Metadata, Schema, SchemaObject},
7 JsonSchema,
8};
9use serde::{
10 de::{self, Visitor},
11 Deserialize, Deserializer, Serialize, Serializer,
12};
13use settings::Settings;
14
/// A model available through the zed.dev assistant service.
///
/// Well-known models get their own variant; any other model id is carried
/// verbatim in [`ZedDotDevModel::Custom`].
#[derive(PartialEq, Default, Debug, Clone)]
pub enum ZedDotDevModel {
    /// Model id `gpt-3.5-turbo`.
    GptThreePointFiveTurbo,
    /// Model id `gpt-4`.
    GptFour,
    /// Model id `gpt-4-turbo-preview`; this is the `Default` variant.
    #[default]
    GptFourTurbo,
    /// Any model id not covered by the variants above.
    Custom(String),
}
23
24impl Serialize for ZedDotDevModel {
25 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
26 where
27 S: Serializer,
28 {
29 serializer.serialize_str(self.id())
30 }
31}
32
33impl<'de> Deserialize<'de> for ZedDotDevModel {
34 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
35 where
36 D: Deserializer<'de>,
37 {
38 struct ZedDotDevModelVisitor;
39
40 impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
41 type Value = ZedDotDevModel;
42
43 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
44 formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
45 }
46
47 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
48 where
49 E: de::Error,
50 {
51 match value {
52 "gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo),
53 "gpt-4" => Ok(ZedDotDevModel::GptFour),
54 "gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo),
55 _ => Ok(ZedDotDevModel::Custom(value.to_owned())),
56 }
57 }
58 }
59
60 deserializer.deserialize_str(ZedDotDevModelVisitor)
61 }
62}
63
64impl JsonSchema for ZedDotDevModel {
65 fn schema_name() -> String {
66 "ZedDotDevModel".to_owned()
67 }
68
69 fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
70 let variants = vec![
71 "gpt-3.5-turbo".to_owned(),
72 "gpt-4".to_owned(),
73 "gpt-4-turbo-preview".to_owned(),
74 ];
75 Schema::Object(SchemaObject {
76 instance_type: Some(InstanceType::String.into()),
77 enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
78 metadata: Some(Box::new(Metadata {
79 title: Some("ZedDotDevModel".to_owned()),
80 default: Some(serde_json::json!("gpt-4-turbo-preview")),
81 examples: vec![
82 serde_json::json!("gpt-3.5-turbo"),
83 serde_json::json!("gpt-4"),
84 serde_json::json!("gpt-4-turbo-preview"),
85 serde_json::json!("custom-model-name"),
86 ],
87 ..Default::default()
88 })),
89 ..Default::default()
90 })
91 }
92}
93
94impl ZedDotDevModel {
95 pub fn id(&self) -> &str {
96 match self {
97 Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
98 Self::GptFour => "gpt-4",
99 Self::GptFourTurbo => "gpt-4-turbo-preview",
100 Self::Custom(id) => id,
101 }
102 }
103
104 pub fn display_name(&self) -> &str {
105 match self {
106 Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
107 Self::GptFour => "gpt-4",
108 Self::GptFourTurbo => "gpt-4-turbo",
109 Self::Custom(id) => id.as_str(),
110 }
111 }
112}
113
// Where the assistant panel is docked. Serialized in snake_case
// ("left", "right", "bottom"); `Right` is the `Default` variant.
// (Plain `//` comments are used so the derived JSON schema is unchanged.)
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}
122
// The assistant service to use, together with that service's options.
//
// Serialized as an internally tagged object: the `name` field selects the
// variant ("zed.dev" or "openai"), and the remaining fields are that variant's
// options. Every option has a serde default, so an object containing only
// `name` is valid.
//
// NOTE(review): the container-level `rename_all` has no effect here, because
// every variant carries an explicit `rename`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProvider {
    #[serde(rename = "zed.dev")]
    ZedDotDev {
        #[serde(default)]
        default_model: ZedDotDevModel,
    },
    #[serde(rename = "openai")]
    OpenAi {
        #[serde(default)]
        default_model: OpenAiModel,
        // Falls back to `open_ai_url()` when omitted.
        #[serde(default = "open_ai_url")]
        api_url: String,
    },
}
139
140impl Default for AssistantProvider {
141 fn default() -> Self {
142 Self::ZedDotDev {
143 default_model: ZedDotDevModel::default(),
144 }
145 }
146}
147
/// Base URL of the public OpenAI API.
///
/// Used as the serde default for `AssistantProvider::OpenAi::api_url`, and when
/// upgrading legacy settings that name a model but no URL.
fn open_ai_url() -> String {
    String::from("https://api.openai.com/v1")
}
151
/// Effective assistant settings.
///
/// Resolved from the default settings content plus any user overrides by
/// `<AssistantSettings as Settings>::load`.
#[derive(Default, Debug, Deserialize, Serialize)]
pub struct AssistantSettings {
    /// Whether to show the assistant panel button in the status bar.
    pub button: bool,
    /// Where the assistant panel is docked.
    pub dock: AssistantDockPosition,
    /// Panel width when docked to the left or right.
    pub default_width: Pixels,
    /// Panel height when docked to the bottom.
    pub default_height: Pixels,
    /// The assistant service and its model/endpoint configuration.
    pub provider: AssistantProvider,
}
160
/// Assistant panel settings as they appear in a settings file.
///
/// Deserialized without a tag. The versioned layout is tried first (it requires a
/// `version` field); anything else falls back to the legacy layout.
///
/// Keep `Legacy` last: all of its fields are optional and unknown fields are not
/// rejected, so it accepts practically any JSON object. A consequence is that an
/// object whose `version` contents fail to parse also falls back to `Legacy`
/// instead of reporting an error.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}
168
169impl JsonSchema for AssistantSettingsContent {
170 fn schema_name() -> String {
171 VersionedAssistantSettingsContent::schema_name()
172 }
173
174 fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
175 VersionedAssistantSettingsContent::json_schema(gen)
176 }
177
178 fn is_referenceable() -> bool {
179 VersionedAssistantSettingsContent::is_referenceable()
180 }
181}
182
183impl Default for AssistantSettingsContent {
184 fn default() -> Self {
185 Self::Versioned(VersionedAssistantSettingsContent::default())
186 }
187}
188
189impl AssistantSettingsContent {
190 fn upgrade(&self) -> AssistantSettingsContentV1 {
191 match self {
192 AssistantSettingsContent::Versioned(settings) => match settings {
193 VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
194 },
195 AssistantSettingsContent::Legacy(settings) => {
196 if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
197 AssistantSettingsContentV1 {
198 button: settings.button,
199 dock: settings.dock,
200 default_width: settings.default_width,
201 default_height: settings.default_height,
202 provider: Some(AssistantProvider::OpenAi {
203 default_model: settings
204 .default_open_ai_model
205 .clone()
206 .unwrap_or_default(),
207 api_url: open_ai_api_url.clone(),
208 }),
209 }
210 } else if let Some(open_ai_model) = settings.default_open_ai_model.clone() {
211 AssistantSettingsContentV1 {
212 button: settings.button,
213 dock: settings.dock,
214 default_width: settings.default_width,
215 default_height: settings.default_height,
216 provider: Some(AssistantProvider::OpenAi {
217 default_model: open_ai_model,
218 api_url: open_ai_url(),
219 }),
220 }
221 } else {
222 AssistantSettingsContentV1 {
223 button: settings.button,
224 dock: settings.dock,
225 default_width: settings.default_width,
226 default_height: settings.default_height,
227 provider: None,
228 }
229 }
230 }
231 }
232 }
233
234 pub fn set_dock(&mut self, dock: AssistantDockPosition) {
235 match self {
236 AssistantSettingsContent::Versioned(settings) => match settings {
237 VersionedAssistantSettingsContent::V1(settings) => {
238 settings.dock = Some(dock);
239 }
240 },
241 AssistantSettingsContent::Legacy(settings) => {
242 settings.dock = Some(dock);
243 }
244 }
245 }
246}
247
// Versioned settings layout. The `version` field selects the variant; currently
// only "1" exists. Adding a new version means adding a variant here and teaching
// `AssistantSettingsContent::upgrade` how to convert it.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
254
255impl Default for VersionedAssistantSettingsContent {
256 fn default() -> Self {
257 Self::V1(AssistantSettingsContentV1 {
258 button: None,
259 dock: None,
260 default_width: None,
261 default_height: None,
262 provider: None,
263 })
264 }
265}
266
// The current (`"version": "1"`) settings layout. Every field is optional: when
// `AssistantSettings::load` layers settings files, an unset field keeps whatever
// an earlier layer resolved.
// (The `///` field docs below feed the generated JSON schema descriptions.)
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can either be the internal `zed.dev` service or an external `openai` service,
    /// each with their respective default models and configurations.
    // When set, this replaces the provider resolved from earlier layers as a whole.
    provider: Option<AssistantProvider>,
}
291
// Pre-versioning settings layout, which could only configure OpenAI.
// `AssistantSettingsContent::upgrade` converts it to `AssistantSettingsContentV1`.
// (The `///` field docs below feed the generated JSON schema descriptions.)
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when starting new conversations.
    ///
    /// Default: gpt-4-1106-preview
    // NOTE(review): when this field is unset but a URL is given, `upgrade` falls
    // back to `OpenAiModel::default()`. Confirm that model still corresponds to the
    // "gpt-4-1106-preview" documented above.
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when starting new conversations.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
319
320impl Settings for AssistantSettings {
321 const KEY: Option<&'static str> = Some("assistant");
322
323 type FileContent = AssistantSettingsContent;
324
325 fn load(
326 default_value: &Self::FileContent,
327 user_values: &[&Self::FileContent],
328 _: &mut gpui::AppContext,
329 ) -> anyhow::Result<Self> {
330 let mut settings = AssistantSettings::default();
331
332 for value in [default_value].iter().chain(user_values) {
333 let value = value.upgrade();
334 merge(&mut settings.button, value.button);
335 merge(&mut settings.dock, value.dock);
336 merge(
337 &mut settings.default_width,
338 value.default_width.map(Into::into),
339 );
340 merge(
341 &mut settings.default_height,
342 value.default_height.map(Into::into),
343 );
344 if let Some(provider) = value.provider.clone() {
345 match (&mut settings.provider, provider) {
346 (
347 AssistantProvider::ZedDotDev { default_model },
348 AssistantProvider::ZedDotDev {
349 default_model: default_model_override,
350 },
351 ) => {
352 *default_model = default_model_override;
353 }
354 (
355 AssistantProvider::OpenAi {
356 default_model,
357 api_url,
358 },
359 AssistantProvider::OpenAi {
360 default_model: default_model_override,
361 api_url: api_url_override,
362 },
363 ) => {
364 *default_model = default_model_override;
365 *api_url = api_url_override;
366 }
367 (merged, provider_override) => {
368 *merged = provider_override;
369 }
370 }
371 }
372 }
373
374 Ok(settings)
375 }
376}
377
/// Overwrites `target` with `value` when one is provided; `None` leaves it untouched.
///
/// Used to layer optional per-file settings on top of already-resolved values.
/// The value is moved in, so no `Copy` bound is needed. Any type can be merged,
/// including non-`Copy` ones.
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(value) = value {
        *target = value;
    }
}
383
#[cfg(test)]
mod tests {
    use gpui::AppContext;
    use settings::SettingsStore;

    use super::*;

    /// Checks how `AssistantSettings` resolves its provider from three sources:
    /// the default settings, legacy (unversioned) user settings, and
    /// version "1" user settings.
    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to gpt-4-turbo.
        // `AssistantProvider::default()` is zed.dev, so this OpenAI provider must come
        // from the default settings layer used by the test store.
        // NOTE(review): confirm against the bundled default settings content.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::FourTurbo,
                api_url: open_ai_url()
            }
        );

        // Ensure backward-compatibility.
        // A legacy file that only sets `openai_api_url` yields an OpenAI provider
        // whose model falls back to `OpenAiModel::default()` (see `upgrade`). This
        // test expects that default to be `FourTurbo`.
        // The JSON below has a trailing comma, so the settings store presumably
        // parses leniently — TODO confirm.
        cx.update_global::<SettingsStore, _>(|store, cx| {
            store
                .set_user_settings(
                    r#"{
                "assistant": {
                    "openai_api_url": "test-url",
                }
            }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::FourTurbo,
                api_url: "test-url".into()
            }
        );
        // A legacy file that only sets the model gets `open_ai_url()` as its URL.
        cx.update_global::<SettingsStore, _>(|store, cx| {
            store
                .set_user_settings(
                    r#"{
                "assistant": {
                    "default_open_ai_model": "gpt-4-0613"
                }
            }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::Four,
                api_url: open_ai_url()
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        // An unrecognized model id deserializes to `ZedDotDevModel::Custom`.
        cx.update_global::<SettingsStore, _>(|store, cx| {
            store
                .set_user_settings(
                    r#"{
                "assistant": {
                    "version": "1",
                    "provider": {
                        "name": "zed.dev",
                        "default_model": "custom"
                    }
                }
            }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                default_model: ZedDotDevModel::Custom("custom".into())
            }
        );
    }
}