assistant_settings.rs

  1use std::fmt;
  2
  3use gpui::Pixels;
  4pub use open_ai::Model as OpenAiModel;
  5use schemars::{
  6    schema::{InstanceType, Metadata, Schema, SchemaObject},
  7    JsonSchema,
  8};
  9use serde::{
 10    de::{self, Visitor},
 11    Deserialize, Deserializer, Serialize, Serializer,
 12};
 13use settings::Settings;
 14
 15#[derive(Clone, Debug, Default, PartialEq)]
 16pub enum ZedDotDevModel {
 17    GptThreePointFiveTurbo,
 18    GptFour,
 19    #[default]
 20    GptFourTurbo,
 21    Custom(String),
 22}
 23
 24impl Serialize for ZedDotDevModel {
 25    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 26    where
 27        S: Serializer,
 28    {
 29        serializer.serialize_str(self.id())
 30    }
 31}
 32
 33impl<'de> Deserialize<'de> for ZedDotDevModel {
 34    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 35    where
 36        D: Deserializer<'de>,
 37    {
 38        struct ZedDotDevModelVisitor;
 39
 40        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
 41            type Value = ZedDotDevModel;
 42
 43            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
 44                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
 45            }
 46
 47            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
 48            where
 49                E: de::Error,
 50            {
 51                match value {
 52                    "gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo),
 53                    "gpt-4" => Ok(ZedDotDevModel::GptFour),
 54                    "gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo),
 55                    _ => Ok(ZedDotDevModel::Custom(value.to_owned())),
 56                }
 57            }
 58        }
 59
 60        deserializer.deserialize_str(ZedDotDevModelVisitor)
 61    }
 62}
 63
 64impl JsonSchema for ZedDotDevModel {
 65    fn schema_name() -> String {
 66        "ZedDotDevModel".to_owned()
 67    }
 68
 69    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
 70        let variants = vec![
 71            "gpt-3.5-turbo".to_owned(),
 72            "gpt-4".to_owned(),
 73            "gpt-4-turbo-preview".to_owned(),
 74        ];
 75        Schema::Object(SchemaObject {
 76            instance_type: Some(InstanceType::String.into()),
 77            enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
 78            metadata: Some(Box::new(Metadata {
 79                title: Some("ZedDotDevModel".to_owned()),
 80                default: Some(serde_json::json!("gpt-4-turbo-preview")),
 81                examples: vec![
 82                    serde_json::json!("gpt-3.5-turbo"),
 83                    serde_json::json!("gpt-4"),
 84                    serde_json::json!("gpt-4-turbo-preview"),
 85                    serde_json::json!("custom-model-name"),
 86                ],
 87                ..Default::default()
 88            })),
 89            ..Default::default()
 90        })
 91    }
 92}
 93
 94impl ZedDotDevModel {
 95    pub fn id(&self) -> &str {
 96        match self {
 97            Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
 98            Self::GptFour => "gpt-4",
 99            Self::GptFourTurbo => "gpt-4-turbo-preview",
100            Self::Custom(id) => id,
101        }
102    }
103
104    pub fn display_name(&self) -> &str {
105        match self {
106            Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
107            Self::GptFour => "gpt-4",
108            Self::GptFourTurbo => "gpt-4-turbo",
109            Self::Custom(id) => id.as_str(),
110        }
111    }
112
113    pub fn max_token_count(&self) -> usize {
114        match self {
115            Self::GptThreePointFiveTurbo => 2048,
116            Self::GptFour => 4096,
117            Self::GptFourTurbo => 128000,
118            Self::Custom(_) => 4096, // TODO: Make this configurable
119        }
120    }
121}
122
123#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
124#[serde(rename_all = "snake_case")]
125pub enum AssistantDockPosition {
126    Left,
127    #[default]
128    Right,
129    Bottom,
130}
131
132#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
133#[serde(tag = "name", rename_all = "snake_case")]
134pub enum AssistantProvider {
135    #[serde(rename = "zed.dev")]
136    ZedDotDev {
137        #[serde(default)]
138        default_model: ZedDotDevModel,
139    },
140    #[serde(rename = "openai")]
141    OpenAi {
142        #[serde(default)]
143        default_model: OpenAiModel,
144        #[serde(default = "open_ai_url")]
145        api_url: String,
146    },
147}
148
149impl Default for AssistantProvider {
150    fn default() -> Self {
151        Self::ZedDotDev {
152            default_model: ZedDotDevModel::default(),
153        }
154    }
155}
156
/// Base URL of the public OpenAI API. Used when an OpenAI provider does not set `api_url`.
fn open_ai_url() -> String {
    String::from("https://api.openai.com/v1")
}
160
161#[derive(Default, Debug, Deserialize, Serialize)]
162pub struct AssistantSettings {
163    pub button: bool,
164    pub dock: AssistantDockPosition,
165    pub default_width: Pixels,
166    pub default_height: Pixels,
167    pub provider: AssistantProvider,
168}
169
170/// Assistant panel settings
171#[derive(Clone, Serialize, Deserialize, Debug)]
172#[serde(untagged)]
173pub enum AssistantSettingsContent {
174    Versioned(VersionedAssistantSettingsContent),
175    Legacy(LegacyAssistantSettingsContent),
176}
177
178impl JsonSchema for AssistantSettingsContent {
179    fn schema_name() -> String {
180        VersionedAssistantSettingsContent::schema_name()
181    }
182
183    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
184        VersionedAssistantSettingsContent::json_schema(gen)
185    }
186
187    fn is_referenceable() -> bool {
188        VersionedAssistantSettingsContent::is_referenceable()
189    }
190}
191
192impl Default for AssistantSettingsContent {
193    fn default() -> Self {
194        Self::Versioned(VersionedAssistantSettingsContent::default())
195    }
196}
197
198impl AssistantSettingsContent {
199    fn upgrade(&self) -> AssistantSettingsContentV1 {
200        match self {
201            AssistantSettingsContent::Versioned(settings) => match settings {
202                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
203            },
204            AssistantSettingsContent::Legacy(settings) => {
205                if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
206                    AssistantSettingsContentV1 {
207                        button: settings.button,
208                        dock: settings.dock,
209                        default_width: settings.default_width,
210                        default_height: settings.default_height,
211                        provider: Some(AssistantProvider::OpenAi {
212                            default_model: settings
213                                .default_open_ai_model
214                                .clone()
215                                .unwrap_or_default(),
216                            api_url: open_ai_api_url.clone(),
217                        }),
218                    }
219                } else if let Some(open_ai_model) = settings.default_open_ai_model.clone() {
220                    AssistantSettingsContentV1 {
221                        button: settings.button,
222                        dock: settings.dock,
223                        default_width: settings.default_width,
224                        default_height: settings.default_height,
225                        provider: Some(AssistantProvider::OpenAi {
226                            default_model: open_ai_model,
227                            api_url: open_ai_url(),
228                        }),
229                    }
230                } else {
231                    AssistantSettingsContentV1 {
232                        button: settings.button,
233                        dock: settings.dock,
234                        default_width: settings.default_width,
235                        default_height: settings.default_height,
236                        provider: None,
237                    }
238                }
239            }
240        }
241    }
242
243    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
244        match self {
245            AssistantSettingsContent::Versioned(settings) => match settings {
246                VersionedAssistantSettingsContent::V1(settings) => {
247                    settings.dock = Some(dock);
248                }
249            },
250            AssistantSettingsContent::Legacy(settings) => {
251                settings.dock = Some(dock);
252            }
253        }
254    }
255}
256
257#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
258#[serde(tag = "version")]
259pub enum VersionedAssistantSettingsContent {
260    #[serde(rename = "1")]
261    V1(AssistantSettingsContentV1),
262}
263
264impl Default for VersionedAssistantSettingsContent {
265    fn default() -> Self {
266        Self::V1(AssistantSettingsContentV1 {
267            button: None,
268            dock: None,
269            default_width: None,
270            default_height: None,
271            provider: None,
272        })
273    }
274}
275
276#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
277pub struct AssistantSettingsContentV1 {
278    /// Whether to show the assistant panel button in the status bar.
279    ///
280    /// Default: true
281    button: Option<bool>,
282    /// Where to dock the assistant.
283    ///
284    /// Default: right
285    dock: Option<AssistantDockPosition>,
286    /// Default width in pixels when the assistant is docked to the left or right.
287    ///
288    /// Default: 640
289    default_width: Option<f32>,
290    /// Default height in pixels when the assistant is docked to the bottom.
291    ///
292    /// Default: 320
293    default_height: Option<f32>,
294    /// The provider of the assistant service.
295    ///
296    /// This can either be the internal `zed.dev` service or an external `openai` service,
297    /// each with their respective default models and configurations.
298    provider: Option<AssistantProvider>,
299}
300
301#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
302pub struct LegacyAssistantSettingsContent {
303    /// Whether to show the assistant panel button in the status bar.
304    ///
305    /// Default: true
306    pub button: Option<bool>,
307    /// Where to dock the assistant.
308    ///
309    /// Default: right
310    pub dock: Option<AssistantDockPosition>,
311    /// Default width in pixels when the assistant is docked to the left or right.
312    ///
313    /// Default: 640
314    pub default_width: Option<f32>,
315    /// Default height in pixels when the assistant is docked to the bottom.
316    ///
317    /// Default: 320
318    pub default_height: Option<f32>,
319    /// The default OpenAI model to use when starting new conversations.
320    ///
321    /// Default: gpt-4-1106-preview
322    pub default_open_ai_model: Option<OpenAiModel>,
323    /// OpenAI API base URL to use when starting new conversations.
324    ///
325    /// Default: https://api.openai.com/v1
326    pub openai_api_url: Option<String>,
327}
328
329impl Settings for AssistantSettings {
330    const KEY: Option<&'static str> = Some("assistant");
331
332    type FileContent = AssistantSettingsContent;
333
334    fn load(
335        default_value: &Self::FileContent,
336        user_values: &[&Self::FileContent],
337        _: &mut gpui::AppContext,
338    ) -> anyhow::Result<Self> {
339        let mut settings = AssistantSettings::default();
340
341        for value in [default_value].iter().chain(user_values) {
342            let value = value.upgrade();
343            merge(&mut settings.button, value.button);
344            merge(&mut settings.dock, value.dock);
345            merge(
346                &mut settings.default_width,
347                value.default_width.map(Into::into),
348            );
349            merge(
350                &mut settings.default_height,
351                value.default_height.map(Into::into),
352            );
353            if let Some(provider) = value.provider.clone() {
354                match (&mut settings.provider, provider) {
355                    (
356                        AssistantProvider::ZedDotDev { default_model },
357                        AssistantProvider::ZedDotDev {
358                            default_model: default_model_override,
359                        },
360                    ) => {
361                        *default_model = default_model_override;
362                    }
363                    (
364                        AssistantProvider::OpenAi {
365                            default_model,
366                            api_url,
367                        },
368                        AssistantProvider::OpenAi {
369                            default_model: default_model_override,
370                            api_url: api_url_override,
371                        },
372                    ) => {
373                        *default_model = default_model_override;
374                        *api_url = api_url_override;
375                    }
376                    (merged, provider_override) => {
377                        *merged = provider_override;
378                    }
379                }
380            }
381        }
382
383        Ok(settings)
384    }
385}
386
/// Overwrites `target` with `value` when it is `Some`, and leaves `target` untouched on `None`.
///
/// Used to layer optional settings-file fields over already-resolved settings. The value is
/// moved into place, so no `Copy` bound is needed and non-`Copy` fields can use it too.
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(value) = value {
        *target = value;
    }
}
392
#[cfg(test)]
mod tests {
    use gpui::AppContext;
    use settings::SettingsStore;

    use super::*;

    // Replaces the user settings held by the global `SettingsStore` with `json`.
    // Panics if the store rejects the content.
    fn set_user_settings(json: &str, cx: &mut AppContext) {
        cx.update_global::<SettingsStore, _>(|store, cx| {
            store.set_user_settings(json, cx).unwrap();
        });
    }

    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = SettingsStore::test(cx);
        cx.set_global(store);
        AssistantSettings::register(cx);

        // With only the bundled defaults, the provider is OpenAI FourTurbo at the standard URL.
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::FourTurbo,
                api_url: open_ai_url()
            }
        );

        // Legacy (unversioned) keys are still honored: a bare API URL selects OpenAI.
        set_user_settings(
            r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
            cx,
        );
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::FourTurbo,
                api_url: "test-url".into()
            }
        );

        // A legacy model key on its own also selects OpenAI, with the default URL.
        set_user_settings(
            r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
            cx,
        );
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::Four,
                api_url: open_ai_url()
            }
        );

        // Version 1 lets zed.dev use an arbitrary model id, which deserializes as `Custom`.
        set_user_settings(
            r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
            cx,
        );
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                default_model: ZedDotDevModel::Custom("custom".into())
            }
        );
    }
}