assistant_settings.rs

  1use std::fmt;
  2
  3use gpui::Pixels;
  4pub use open_ai::Model as OpenAiModel;
  5use schemars::{
  6    schema::{InstanceType, Metadata, Schema, SchemaObject},
  7    JsonSchema,
  8};
  9use serde::{
 10    de::{self, Visitor},
 11    Deserialize, Deserializer, Serialize, Serializer,
 12};
 13use settings::Settings;
 14
/// A model selectable for the `zed.dev` assistant provider.
///
/// The well-known models have dedicated variants; any other model id read from settings is
/// preserved verbatim as [`ZedDotDevModel::Custom`]. (De)serialized as a plain id string
/// (see [`ZedDotDevModel::id`]).
#[derive(Clone, Debug, Default, PartialEq)]
pub enum ZedDotDevModel {
    /// Id `gpt-3.5-turbo`.
    GptThreePointFiveTurbo,
    /// Id `gpt-4`.
    GptFour,
    /// Id `gpt-4-turbo-preview`; this is the default model.
    #[default]
    GptFourTurbo,
    /// Any model id not covered by the variants above, stored as written.
    Custom(String),
}
 23
 24impl Serialize for ZedDotDevModel {
 25    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 26    where
 27        S: Serializer,
 28    {
 29        serializer.serialize_str(self.id())
 30    }
 31}
 32
 33impl<'de> Deserialize<'de> for ZedDotDevModel {
 34    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 35    where
 36        D: Deserializer<'de>,
 37    {
 38        struct ZedDotDevModelVisitor;
 39
 40        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
 41            type Value = ZedDotDevModel;
 42
 43            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
 44                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
 45            }
 46
 47            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
 48            where
 49                E: de::Error,
 50            {
 51                match value {
 52                    "gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo),
 53                    "gpt-4" => Ok(ZedDotDevModel::GptFour),
 54                    "gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo),
 55                    _ => Ok(ZedDotDevModel::Custom(value.to_owned())),
 56                }
 57            }
 58        }
 59
 60        deserializer.deserialize_str(ZedDotDevModelVisitor)
 61    }
 62}
 63
 64impl JsonSchema for ZedDotDevModel {
 65    fn schema_name() -> String {
 66        "ZedDotDevModel".to_owned()
 67    }
 68
 69    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
 70        let variants = vec![
 71            "gpt-3.5-turbo".to_owned(),
 72            "gpt-4".to_owned(),
 73            "gpt-4-turbo-preview".to_owned(),
 74        ];
 75        Schema::Object(SchemaObject {
 76            instance_type: Some(InstanceType::String.into()),
 77            enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
 78            metadata: Some(Box::new(Metadata {
 79                title: Some("ZedDotDevModel".to_owned()),
 80                default: Some(serde_json::json!("gpt-4-turbo-preview")),
 81                examples: vec![
 82                    serde_json::json!("gpt-3.5-turbo"),
 83                    serde_json::json!("gpt-4"),
 84                    serde_json::json!("gpt-4-turbo-preview"),
 85                    serde_json::json!("custom-model-name"),
 86                ],
 87                ..Default::default()
 88            })),
 89            ..Default::default()
 90        })
 91    }
 92}
 93
 94impl ZedDotDevModel {
 95    pub fn id(&self) -> &str {
 96        match self {
 97            Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
 98            Self::GptFour => "gpt-4",
 99            Self::GptFourTurbo => "gpt-4-turbo-preview",
100            Self::Custom(id) => id,
101        }
102    }
103
104    pub fn display_name(&self) -> &str {
105        match self {
106            Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
107            Self::GptFour => "gpt-4",
108            Self::GptFourTurbo => "gpt-4-turbo",
109            Self::Custom(id) => id.as_str(),
110        }
111    }
112}
113
/// Where the assistant is docked. Serialized in snake_case (`"left"`, `"right"`,
/// `"bottom"`); defaults to `right`.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}
122
/// The service backing the assistant, together with its provider-specific configuration.
///
/// Serialized as an internally tagged object whose `name` field selects the variant, e.g.
/// `{ "name": "zed.dev", "default_model": "gpt-4" }`. The explicit `rename`s on each variant
/// take precedence over `rename_all`. Omitted variant fields fall back to their defaults.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProvider {
    /// The internal `zed.dev` service.
    #[serde(rename = "zed.dev")]
    ZedDotDev {
        /// The default model; falls back to `ZedDotDevModel::default()` when omitted.
        #[serde(default)]
        default_model: ZedDotDevModel,
    },
    /// The external `openai` service.
    #[serde(rename = "openai")]
    OpenAi {
        /// The default OpenAI model; falls back to `OpenAiModel::default()` when omitted.
        #[serde(default)]
        default_model: OpenAiModel,
        /// Base URL of the API; falls back to `https://api.openai.com/v1` when omitted.
        #[serde(default = "open_ai_url")]
        api_url: String,
    },
}
139
140impl Default for AssistantProvider {
141    fn default() -> Self {
142        Self::ZedDotDev {
143            default_model: ZedDotDevModel::default(),
144        }
145    }
146}
147
/// Base URL of the official OpenAI API.
///
/// Used as the serde default for the `openai` provider's `api_url` and when upgrading
/// legacy settings that name a model but no URL.
fn open_ai_url() -> String {
    String::from("https://api.openai.com/v1")
}
151
/// The effective assistant settings, built by `Settings::load` from the default settings
/// with each user settings layer merged on top.
#[derive(Default, Debug, Deserialize, Serialize)]
pub struct AssistantSettings {
    /// Whether to show the assistant panel button in the status bar.
    pub button: bool,
    /// Where to dock the assistant.
    pub dock: AssistantDockPosition,
    /// Width of the assistant when docked to the left or right.
    pub default_width: Pixels,
    /// Height of the assistant when docked to the bottom.
    pub default_height: Pixels,
    /// The provider of the assistant service and its configuration.
    pub provider: AssistantProvider,
}
160
/// Assistant panel settings as written in a settings file.
///
/// Deserialization is untagged: serde first tries the versioned format
/// ([`VersionedAssistantSettingsContent`], which requires a `"version"` key) and falls back to
/// [`LegacyAssistantSettingsContent`] if that fails. `upgrade` normalizes either form to V1.
///
/// NOTE(review): the legacy struct does not use `deny_unknown_fields`. As a result, an object
/// that has `"version": "1"` but fails to parse as V1 (e.g. an invalid `provider`) will
/// presumably be accepted as legacy content, with `version`/`provider` silently ignored.
/// Confirm whether this fallback is intended.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}
168
169impl JsonSchema for AssistantSettingsContent {
170    fn schema_name() -> String {
171        VersionedAssistantSettingsContent::schema_name()
172    }
173
174    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
175        VersionedAssistantSettingsContent::json_schema(gen)
176    }
177
178    fn is_referenceable() -> bool {
179        VersionedAssistantSettingsContent::is_referenceable()
180    }
181}
182
183impl Default for AssistantSettingsContent {
184    fn default() -> Self {
185        Self::Versioned(VersionedAssistantSettingsContent::default())
186    }
187}
188
189impl AssistantSettingsContent {
190    fn upgrade(&self) -> AssistantSettingsContentV1 {
191        match self {
192            AssistantSettingsContent::Versioned(settings) => match settings {
193                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
194            },
195            AssistantSettingsContent::Legacy(settings) => {
196                if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
197                    AssistantSettingsContentV1 {
198                        button: settings.button,
199                        dock: settings.dock,
200                        default_width: settings.default_width,
201                        default_height: settings.default_height,
202                        provider: Some(AssistantProvider::OpenAi {
203                            default_model: settings
204                                .default_open_ai_model
205                                .clone()
206                                .unwrap_or_default(),
207                            api_url: open_ai_api_url.clone(),
208                        }),
209                    }
210                } else if let Some(open_ai_model) = settings.default_open_ai_model.clone() {
211                    AssistantSettingsContentV1 {
212                        button: settings.button,
213                        dock: settings.dock,
214                        default_width: settings.default_width,
215                        default_height: settings.default_height,
216                        provider: Some(AssistantProvider::OpenAi {
217                            default_model: open_ai_model,
218                            api_url: open_ai_url(),
219                        }),
220                    }
221                } else {
222                    AssistantSettingsContentV1 {
223                        button: settings.button,
224                        dock: settings.dock,
225                        default_width: settings.default_width,
226                        default_height: settings.default_height,
227                        provider: None,
228                    }
229                }
230            }
231        }
232    }
233
234    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
235        match self {
236            AssistantSettingsContent::Versioned(settings) => match settings {
237                VersionedAssistantSettingsContent::V1(settings) => {
238                    settings.dock = Some(dock);
239                }
240            },
241            AssistantSettingsContent::Legacy(settings) => {
242                settings.dock = Some(dock);
243            }
244        }
245    }
246}
247
/// Settings content in the versioned format. The variant is selected by the string
/// `"version"` field; `"1"` is currently the only version.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    /// Selected by `"version": "1"`.
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
254
255impl Default for VersionedAssistantSettingsContent {
256    fn default() -> Self {
257        Self::V1(AssistantSettingsContentV1 {
258            button: None,
259            dock: None,
260            default_width: None,
261            default_height: None,
262            provider: None,
263        })
264    }
265}
266
/// Version 1 of the assistant settings file format.
///
/// Every field is optional. When settings are loaded, an unset field leaves the value from
/// lower-priority settings (e.g. the defaults) unchanged.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can either be the internal `zed.dev` service or an external `openai` service,
    /// each with their respective default models and configurations.
    provider: Option<AssistantProvider>,
}
291
/// The pre-versioning settings format.
///
/// Used as the fallback when content does not parse as the versioned format (e.g. it has no
/// `"version"` key). When upgraded to V1, the OpenAI-specific fields below are translated into
/// an `openai` provider. If neither is set, no provider is configured.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when starting new conversations.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when starting new conversations.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
319
320impl Settings for AssistantSettings {
321    const KEY: Option<&'static str> = Some("assistant");
322
323    type FileContent = AssistantSettingsContent;
324
325    fn load(
326        default_value: &Self::FileContent,
327        user_values: &[&Self::FileContent],
328        _: &mut gpui::AppContext,
329    ) -> anyhow::Result<Self> {
330        let mut settings = AssistantSettings::default();
331
332        for value in [default_value].iter().chain(user_values) {
333            let value = value.upgrade();
334            merge(&mut settings.button, value.button);
335            merge(&mut settings.dock, value.dock);
336            merge(
337                &mut settings.default_width,
338                value.default_width.map(Into::into),
339            );
340            merge(
341                &mut settings.default_height,
342                value.default_height.map(Into::into),
343            );
344            if let Some(provider) = value.provider.clone() {
345                match (&mut settings.provider, provider) {
346                    (
347                        AssistantProvider::ZedDotDev { default_model },
348                        AssistantProvider::ZedDotDev {
349                            default_model: default_model_override,
350                        },
351                    ) => {
352                        *default_model = default_model_override;
353                    }
354                    (
355                        AssistantProvider::OpenAi {
356                            default_model,
357                            api_url,
358                        },
359                        AssistantProvider::OpenAi {
360                            default_model: default_model_override,
361                            api_url: api_url_override,
362                        },
363                    ) => {
364                        *default_model = default_model_override;
365                        *api_url = api_url_override;
366                    }
367                    (merged, provider_override) => {
368                        *merged = provider_override;
369                    }
370                }
371            }
372        }
373
374        Ok(settings)
375    }
376}
377
/// Overwrites `target` with `value` when `value` is `Some`, and leaves it untouched otherwise.
///
/// Used to layer optional settings from higher-priority sources over already-loaded values.
/// The value is moved into place, so `T` need not be `Copy`.
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(value) = value {
        *target = value;
    }
}
383
// Tests for loading assistant settings through the global `SettingsStore`, covering both the
// legacy (unversioned) and the versioned settings formats.
#[cfg(test)]
mod tests {
    use gpui::AppContext;
    use settings::SettingsStore;

    use super::*;

    /// Loads settings in several formats and checks the resulting `provider`.
    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to gpt-4-turbo.
        // NOTE(review): this expects the `openai` provider rather than
        // `AssistantProvider::default()` (`zed.dev`), so it presumably relies on the bundled
        // default settings selecting `openai`. Confirm against the defaults file.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::FourTurbo,
                api_url: open_ai_url()
            }
        );

        // Ensure backward-compatibility.
        // Legacy format with only `openai_api_url`: expect an `openai` provider with that URL
        // and the default OpenAI model.
        // NOTE(review): the JSON below has a trailing comma; this relies on the settings
        // parser tolerating it.
        cx.update_global::<SettingsStore, _>(|store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::FourTurbo,
                api_url: "test-url".into()
            }
        );
        // Legacy format with only `default_open_ai_model`: expect an `openai` provider with that
        // model and the default API URL. The previous step's "test-url" is not expected here,
        // so each `set_user_settings` call presumably replaces the earlier user settings.
        cx.update_global::<SettingsStore, _>(|store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::Four,
                api_url: open_ai_url()
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        // An unrecognized model id must round-trip as `ZedDotDevModel::Custom`.
        cx.update_global::<SettingsStore, _>(|store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                default_model: ZedDotDevModel::Custom("custom".into())
            }
        );
    }
}