assistant_settings.rs

  1use std::fmt;
  2
  3use gpui::Pixels;
  4pub use open_ai::Model as OpenAiModel;
  5use schemars::{
  6    schema::{InstanceType, Metadata, Schema, SchemaObject},
  7    JsonSchema,
  8};
  9use serde::{
 10    de::{self, Visitor},
 11    Deserialize, Deserializer, Serialize, Serializer,
 12};
 13use settings::Settings;
 14
 15#[derive(Clone, Debug, Default, PartialEq)]
 16pub enum ZedDotDevModel {
 17    GptThreePointFiveTurbo,
 18    GptFour,
 19    #[default]
 20    GptFourTurbo,
 21    Custom(String),
 22}
 23
 24impl Serialize for ZedDotDevModel {
 25    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 26    where
 27        S: Serializer,
 28    {
 29        serializer.serialize_str(self.id())
 30    }
 31}
 32
 33impl<'de> Deserialize<'de> for ZedDotDevModel {
 34    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 35    where
 36        D: Deserializer<'de>,
 37    {
 38        struct ZedDotDevModelVisitor;
 39
 40        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
 41            type Value = ZedDotDevModel;
 42
 43            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
 44                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
 45            }
 46
 47            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
 48            where
 49                E: de::Error,
 50            {
 51                match value {
 52                    "gpt-3.5-turbo" => Ok(ZedDotDevModel::GptThreePointFiveTurbo),
 53                    "gpt-4" => Ok(ZedDotDevModel::GptFour),
 54                    "gpt-4-turbo-preview" => Ok(ZedDotDevModel::GptFourTurbo),
 55                    _ => Ok(ZedDotDevModel::Custom(value.to_owned())),
 56                }
 57            }
 58        }
 59
 60        deserializer.deserialize_str(ZedDotDevModelVisitor)
 61    }
 62}
 63
 64impl JsonSchema for ZedDotDevModel {
 65    fn schema_name() -> String {
 66        "ZedDotDevModel".to_owned()
 67    }
 68
 69    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
 70        let variants = vec![
 71            "gpt-3.5-turbo".to_owned(),
 72            "gpt-4".to_owned(),
 73            "gpt-4-turbo-preview".to_owned(),
 74        ];
 75        Schema::Object(SchemaObject {
 76            instance_type: Some(InstanceType::String.into()),
 77            enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
 78            metadata: Some(Box::new(Metadata {
 79                title: Some("ZedDotDevModel".to_owned()),
 80                default: Some(serde_json::json!("gpt-4-turbo-preview")),
 81                examples: vec![
 82                    serde_json::json!("gpt-3.5-turbo"),
 83                    serde_json::json!("gpt-4"),
 84                    serde_json::json!("gpt-4-turbo-preview"),
 85                    serde_json::json!("custom-model-name"),
 86                ],
 87                ..Default::default()
 88            })),
 89            ..Default::default()
 90        })
 91    }
 92}
 93
 94impl ZedDotDevModel {
 95    pub fn id(&self) -> &str {
 96        match self {
 97            Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
 98            Self::GptFour => "gpt-4",
 99            Self::GptFourTurbo => "gpt-4-turbo-preview",
100            Self::Custom(id) => id,
101        }
102    }
103
104    pub fn display_name(&self) -> &str {
105        match self {
106            Self::GptThreePointFiveTurbo => "gpt-3.5-turbo",
107            Self::GptFour => "gpt-4",
108            Self::GptFourTurbo => "gpt-4-turbo",
109            Self::Custom(id) => id.as_str(),
110        }
111    }
112
113    pub fn max_token_count(&self) -> usize {
114        match self {
115            Self::GptThreePointFiveTurbo => 2048,
116            Self::GptFour => 4096,
117            Self::GptFourTurbo => 128000,
118            Self::Custom(_) => 4096, // TODO: Make this configurable
119        }
120    }
121}
122
123#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
124#[serde(rename_all = "snake_case")]
125pub enum AssistantDockPosition {
126    Left,
127    #[default]
128    Right,
129    Bottom,
130}
131
132#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
133#[serde(tag = "name", rename_all = "snake_case")]
134pub enum AssistantProvider {
135    #[serde(rename = "zed.dev")]
136    ZedDotDev {
137        #[serde(default)]
138        default_model: ZedDotDevModel,
139    },
140    #[serde(rename = "openai")]
141    OpenAi {
142        #[serde(default)]
143        default_model: OpenAiModel,
144        #[serde(default = "open_ai_url")]
145        api_url: String,
146    },
147}
148
149impl Default for AssistantProvider {
150    fn default() -> Self {
151        Self::ZedDotDev {
152            default_model: ZedDotDevModel::default(),
153        }
154    }
155}
156
/// Base URL of the public OpenAI API.
///
/// Used as the serde default for `AssistantProvider::OpenAi::api_url` and when
/// upgrading legacy settings that do not set `openai_api_url`.
fn open_ai_url() -> String {
    String::from("https://api.openai.com/v1")
}
160
161#[derive(Default, Debug, Deserialize, Serialize)]
162pub struct AssistantSettings {
163    pub enabled: bool,
164    pub button: bool,
165    pub dock: AssistantDockPosition,
166    pub default_width: Pixels,
167    pub default_height: Pixels,
168    pub provider: AssistantProvider,
169}
170
171/// Assistant panel settings
172#[derive(Clone, Serialize, Deserialize, Debug)]
173#[serde(untagged)]
174pub enum AssistantSettingsContent {
175    Versioned(VersionedAssistantSettingsContent),
176    Legacy(LegacyAssistantSettingsContent),
177}
178
179impl JsonSchema for AssistantSettingsContent {
180    fn schema_name() -> String {
181        VersionedAssistantSettingsContent::schema_name()
182    }
183
184    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
185        VersionedAssistantSettingsContent::json_schema(gen)
186    }
187
188    fn is_referenceable() -> bool {
189        VersionedAssistantSettingsContent::is_referenceable()
190    }
191}
192
193impl Default for AssistantSettingsContent {
194    fn default() -> Self {
195        Self::Versioned(VersionedAssistantSettingsContent::default())
196    }
197}
198
199impl AssistantSettingsContent {
200    fn upgrade(&self) -> AssistantSettingsContentV1 {
201        match self {
202            AssistantSettingsContent::Versioned(settings) => match settings {
203                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
204            },
205            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
206                enabled: None,
207                button: settings.button,
208                dock: settings.dock,
209                default_width: settings.default_width,
210                default_height: settings.default_height,
211                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
212                    Some(AssistantProvider::OpenAi {
213                        default_model: settings.default_open_ai_model.clone().unwrap_or_default(),
214                        api_url: open_ai_api_url.clone(),
215                    })
216                } else {
217                    settings.default_open_ai_model.clone().map(|open_ai_model| {
218                        AssistantProvider::OpenAi {
219                            default_model: open_ai_model,
220                            api_url: open_ai_url(),
221                        }
222                    })
223                },
224            },
225        }
226    }
227
228    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
229        match self {
230            AssistantSettingsContent::Versioned(settings) => match settings {
231                VersionedAssistantSettingsContent::V1(settings) => {
232                    settings.dock = Some(dock);
233                }
234            },
235            AssistantSettingsContent::Legacy(settings) => {
236                settings.dock = Some(dock);
237            }
238        }
239    }
240}
241
242#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
243#[serde(tag = "version")]
244pub enum VersionedAssistantSettingsContent {
245    #[serde(rename = "1")]
246    V1(AssistantSettingsContentV1),
247}
248
249impl Default for VersionedAssistantSettingsContent {
250    fn default() -> Self {
251        Self::V1(AssistantSettingsContentV1 {
252            enabled: None,
253            button: None,
254            dock: None,
255            default_width: None,
256            default_height: None,
257            provider: None,
258        })
259    }
260}
261
262#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
263pub struct AssistantSettingsContentV1 {
264    /// Whether the Assistant is enabled.
265    ///
266    /// Default: true
267    enabled: Option<bool>,
268    /// Whether to show the assistant panel button in the status bar.
269    ///
270    /// Default: true
271    button: Option<bool>,
272    /// Where to dock the assistant.
273    ///
274    /// Default: right
275    dock: Option<AssistantDockPosition>,
276    /// Default width in pixels when the assistant is docked to the left or right.
277    ///
278    /// Default: 640
279    default_width: Option<f32>,
280    /// Default height in pixels when the assistant is docked to the bottom.
281    ///
282    /// Default: 320
283    default_height: Option<f32>,
284    /// The provider of the assistant service.
285    ///
286    /// This can either be the internal `zed.dev` service or an external `openai` service,
287    /// each with their respective default models and configurations.
288    provider: Option<AssistantProvider>,
289}
290
291#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
292pub struct LegacyAssistantSettingsContent {
293    /// Whether to show the assistant panel button in the status bar.
294    ///
295    /// Default: true
296    pub button: Option<bool>,
297    /// Where to dock the assistant.
298    ///
299    /// Default: right
300    pub dock: Option<AssistantDockPosition>,
301    /// Default width in pixels when the assistant is docked to the left or right.
302    ///
303    /// Default: 640
304    pub default_width: Option<f32>,
305    /// Default height in pixels when the assistant is docked to the bottom.
306    ///
307    /// Default: 320
308    pub default_height: Option<f32>,
309    /// The default OpenAI model to use when starting new conversations.
310    ///
311    /// Default: gpt-4-1106-preview
312    pub default_open_ai_model: Option<OpenAiModel>,
313    /// OpenAI API base URL to use when starting new conversations.
314    ///
315    /// Default: https://api.openai.com/v1
316    pub openai_api_url: Option<String>,
317}
318
319impl Settings for AssistantSettings {
320    const KEY: Option<&'static str> = Some("assistant");
321
322    type FileContent = AssistantSettingsContent;
323
324    fn load(
325        default_value: &Self::FileContent,
326        user_values: &[&Self::FileContent],
327        _: &mut gpui::AppContext,
328    ) -> anyhow::Result<Self> {
329        let mut settings = AssistantSettings::default();
330
331        for value in [default_value].iter().chain(user_values) {
332            let value = value.upgrade();
333            merge(&mut settings.enabled, value.enabled);
334            merge(&mut settings.button, value.button);
335            merge(&mut settings.dock, value.dock);
336            merge(
337                &mut settings.default_width,
338                value.default_width.map(Into::into),
339            );
340            merge(
341                &mut settings.default_height,
342                value.default_height.map(Into::into),
343            );
344            if let Some(provider) = value.provider.clone() {
345                match (&mut settings.provider, provider) {
346                    (
347                        AssistantProvider::ZedDotDev { default_model },
348                        AssistantProvider::ZedDotDev {
349                            default_model: default_model_override,
350                        },
351                    ) => {
352                        *default_model = default_model_override;
353                    }
354                    (
355                        AssistantProvider::OpenAi {
356                            default_model,
357                            api_url,
358                        },
359                        AssistantProvider::OpenAi {
360                            default_model: default_model_override,
361                            api_url: api_url_override,
362                        },
363                    ) => {
364                        *default_model = default_model_override;
365                        *api_url = api_url_override;
366                    }
367                    (merged, provider_override) => {
368                        *merged = provider_override;
369                    }
370                }
371            }
372        }
373
374        Ok(settings)
375    }
376}
377
/// Overwrites `target` with `value` when `value` is `Some`; leaves it untouched
/// on `None`.
///
/// Used to layer optional settings-file values over already-merged settings.
/// The value is moved into place, so no `Copy` bound is needed and owned types
/// such as `String` work too.
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(value) = value {
        *target = value;
    }
}
383
#[cfg(test)]
mod tests {
    use gpui::{AppContext, BorrowAppContext};
    use settings::SettingsStore;

    use super::*;

    /// Replaces the user settings with `json` and fails the test if it is rejected.
    fn set_user_json(cx: &mut AppContext, json: &str) {
        cx.update_global::<SettingsStore, _>(|store, cx| {
            store.set_user_settings(json, cx).unwrap();
        });
    }

    /// The provider currently resolved in the global `AssistantSettings`.
    fn current_provider(cx: &AppContext) -> AssistantProvider {
        AssistantSettings::get_global(cx).provider.clone()
    }

    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to gpt-4-turbo.
        AssistantSettings::register(cx);
        assert_eq!(
            current_provider(cx),
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::FourTurbo,
                api_url: open_ai_url()
            }
        );

        // Ensure backward-compatibility.
        set_user_json(
            cx,
            r#"{
                "assistant": {
                    "openai_api_url": "test-url",
                }
            }"#,
        );
        assert_eq!(
            current_provider(cx),
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::FourTurbo,
                api_url: "test-url".into()
            }
        );

        set_user_json(
            cx,
            r#"{
                "assistant": {
                    "default_open_ai_model": "gpt-4-0613"
                }
            }"#,
        );
        assert_eq!(
            current_provider(cx),
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::Four,
                api_url: open_ai_url()
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        set_user_json(
            cx,
            r#"{
                "assistant": {
                    "version": "1",
                    "provider": {
                        "name": "zed.dev",
                        "default_model": "custom"
                    }
                }
            }"#,
        );
        assert_eq!(
            current_provider(cx),
            AssistantProvider::ZedDotDev {
                default_model: ZedDotDevModel::Custom("custom".into())
            }
        );
    }
}