assistant_settings.rs

  1use std::fmt;
  2
  3use gpui::Pixels;
  4pub use open_ai::Model as OpenAiModel;
  5use schemars::{
  6    schema::{InstanceType, Metadata, Schema, SchemaObject},
  7    JsonSchema,
  8};
  9use serde::{
 10    de::{self, Visitor},
 11    Deserialize, Deserializer, Serialize, Serializer,
 12};
 13use settings::Settings;
 14
 15#[derive(Clone, Debug, Default, PartialEq)]
 16pub enum ZedDotDevModel {
 17    Gpt3Point5Turbo,
 18    Gpt4,
 19    #[default]
 20    Gpt4Turbo,
 21    Claude3Opus,
 22    Claude3Sonnet,
 23    Claude3Haiku,
 24    Custom(String),
 25}
 26
 27impl Serialize for ZedDotDevModel {
 28    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 29    where
 30        S: Serializer,
 31    {
 32        serializer.serialize_str(self.id())
 33    }
 34}
 35
 36impl<'de> Deserialize<'de> for ZedDotDevModel {
 37    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 38    where
 39        D: Deserializer<'de>,
 40    {
 41        struct ZedDotDevModelVisitor;
 42
 43        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
 44            type Value = ZedDotDevModel;
 45
 46            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
 47                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
 48            }
 49
 50            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
 51            where
 52                E: de::Error,
 53            {
 54                match value {
 55                    "gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo),
 56                    "gpt-4" => Ok(ZedDotDevModel::Gpt4),
 57                    "gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo),
 58                    _ => Ok(ZedDotDevModel::Custom(value.to_owned())),
 59                }
 60            }
 61        }
 62
 63        deserializer.deserialize_str(ZedDotDevModelVisitor)
 64    }
 65}
 66
 67impl JsonSchema for ZedDotDevModel {
 68    fn schema_name() -> String {
 69        "ZedDotDevModel".to_owned()
 70    }
 71
 72    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
 73        let variants = vec![
 74            "gpt-3.5-turbo".to_owned(),
 75            "gpt-4".to_owned(),
 76            "gpt-4-turbo-preview".to_owned(),
 77        ];
 78        Schema::Object(SchemaObject {
 79            instance_type: Some(InstanceType::String.into()),
 80            enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
 81            metadata: Some(Box::new(Metadata {
 82                title: Some("ZedDotDevModel".to_owned()),
 83                default: Some(serde_json::json!("gpt-4-turbo-preview")),
 84                examples: vec![
 85                    serde_json::json!("gpt-3.5-turbo"),
 86                    serde_json::json!("gpt-4"),
 87                    serde_json::json!("gpt-4-turbo-preview"),
 88                    serde_json::json!("custom-model-name"),
 89                ],
 90                ..Default::default()
 91            })),
 92            ..Default::default()
 93        })
 94    }
 95}
 96
 97impl ZedDotDevModel {
 98    pub fn id(&self) -> &str {
 99        match self {
100            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
101            Self::Gpt4 => "gpt-4",
102            Self::Gpt4Turbo => "gpt-4-turbo-preview",
103            Self::Claude3Opus => "claude-3-opus",
104            Self::Claude3Sonnet => "claude-3-sonnet",
105            Self::Claude3Haiku => "claude-3-haiku",
106            Self::Custom(id) => id,
107        }
108    }
109
110    pub fn display_name(&self) -> &str {
111        match self {
112            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
113            Self::Gpt4 => "GPT 4",
114            Self::Gpt4Turbo => "GPT 4 Turbo",
115            Self::Claude3Opus => "Claude 3 Opus",
116            Self::Claude3Sonnet => "Claude 3 Sonnet",
117            Self::Claude3Haiku => "Claude 3 Haiku",
118            Self::Custom(id) => id.as_str(),
119        }
120    }
121
122    pub fn max_token_count(&self) -> usize {
123        match self {
124            Self::Gpt3Point5Turbo => 2048,
125            Self::Gpt4 => 4096,
126            Self::Gpt4Turbo => 128000,
127            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
128            Self::Custom(_) => 4096, // TODO: Make this configurable
129        }
130    }
131}
132
133#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
134#[serde(rename_all = "snake_case")]
135pub enum AssistantDockPosition {
136    Left,
137    #[default]
138    Right,
139    Bottom,
140}
141
/// The backend the assistant talks to, together with that backend's configuration.
///
/// Serialized as an object whose `"name"` field selects the variant (`"zed.dev"` or
/// `"openai"`). Every other field has a serde default, so it may be omitted.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProvider {
    /// The internal zed.dev service.
    #[serde(rename = "zed.dev")]
    ZedDotDev {
        /// Model for new conversations; `ZedDotDevModel::default()` when omitted.
        #[serde(default)]
        default_model: ZedDotDevModel,
    },
    /// The OpenAI API (or whatever endpoint `api_url` points at).
    #[serde(rename = "openai")]
    OpenAi {
        /// Model for new conversations; `OpenAiModel::default()` when omitted.
        #[serde(default)]
        default_model: OpenAiModel,
        /// API base URL; `open_ai_url()` (`https://api.openai.com/v1`) when omitted.
        #[serde(default = "open_ai_url")]
        api_url: String,
    },
}
158
159impl Default for AssistantProvider {
160    fn default() -> Self {
161        Self::ZedDotDev {
162            default_model: ZedDotDevModel::default(),
163        }
164    }
165}
166
/// Base URL of the public OpenAI API.
///
/// Used as the serde default for `AssistantProvider::OpenAi::api_url`, and when
/// upgrading legacy settings that name a model but no URL.
fn open_ai_url() -> String {
    String::from("https://api.openai.com/v1")
}
170
/// Fully resolved assistant settings, produced by `Settings::load` from the default
/// settings and any user settings layers.
#[derive(Default, Debug, Deserialize, Serialize)]
pub struct AssistantSettings {
    /// Whether the assistant is enabled.
    pub enabled: bool,
    /// Whether the assistant panel button is shown in the status bar.
    pub button: bool,
    /// Where the assistant panel is docked.
    pub dock: AssistantDockPosition,
    /// Panel width when docked to the left or right.
    pub default_width: Pixels,
    /// Panel height when docked to the bottom.
    pub default_height: Pixels,
    /// The assistant backend and its configuration.
    pub provider: AssistantProvider,
}
180
/// Assistant panel settings as they appear in a settings file.
///
/// `#[serde(untagged)]` tries the variants in declaration order. An object with a
/// `"version"` tag is parsed as `Versioned`; anything else falls back to `Legacy`.
/// NOTE(review): every `Legacy` field is optional and unknown fields are not denied,
/// so a versioned file that fails to parse may be silently accepted as `Legacy`
/// instead of erroring — confirm this is intended.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}
188
189impl JsonSchema for AssistantSettingsContent {
190    fn schema_name() -> String {
191        VersionedAssistantSettingsContent::schema_name()
192    }
193
194    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
195        VersionedAssistantSettingsContent::json_schema(gen)
196    }
197
198    fn is_referenceable() -> bool {
199        VersionedAssistantSettingsContent::is_referenceable()
200    }
201}
202
203impl Default for AssistantSettingsContent {
204    fn default() -> Self {
205        Self::Versioned(VersionedAssistantSettingsContent::default())
206    }
207}
208
209impl AssistantSettingsContent {
210    fn upgrade(&self) -> AssistantSettingsContentV1 {
211        match self {
212            AssistantSettingsContent::Versioned(settings) => match settings {
213                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
214            },
215            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
216                enabled: None,
217                button: settings.button,
218                dock: settings.dock,
219                default_width: settings.default_width,
220                default_height: settings.default_height,
221                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
222                    Some(AssistantProvider::OpenAi {
223                        default_model: settings.default_open_ai_model.clone().unwrap_or_default(),
224                        api_url: open_ai_api_url.clone(),
225                    })
226                } else {
227                    settings.default_open_ai_model.clone().map(|open_ai_model| {
228                        AssistantProvider::OpenAi {
229                            default_model: open_ai_model,
230                            api_url: open_ai_url(),
231                        }
232                    })
233                },
234            },
235        }
236    }
237
238    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
239        match self {
240            AssistantSettingsContent::Versioned(settings) => match settings {
241                VersionedAssistantSettingsContent::V1(settings) => {
242                    settings.dock = Some(dock);
243                }
244            },
245            AssistantSettingsContent::Legacy(settings) => {
246                settings.dock = Some(dock);
247            }
248        }
249    }
250}
251
/// The versioned settings format; the `"version"` field of the object selects the variant.
///
/// Currently only `"version": "1"` exists.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
258
259impl Default for VersionedAssistantSettingsContent {
260    fn default() -> Self {
261        Self::V1(AssistantSettingsContentV1 {
262            enabled: None,
263            button: None,
264            dock: None,
265            default_width: None,
266            default_height: None,
267            provider: None,
268        })
269    }
270}
271
/// Version 1 of the assistant settings file. Every field is optional; unset fields
/// leave the previously resolved value in place when settings layers are merged.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can either be the internal `zed.dev` service or an external `openai` service,
    /// each with their respective default models and configurations.
    provider: Option<AssistantProvider>,
}
300
/// The original, unversioned settings format, still accepted so existing settings files
/// keep working. `AssistantSettingsContent::upgrade` converts it to the version-1 shape.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when starting new conversations.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when starting new conversations.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
328
329impl Settings for AssistantSettings {
330    const KEY: Option<&'static str> = Some("assistant");
331
332    type FileContent = AssistantSettingsContent;
333
334    fn load(
335        default_value: &Self::FileContent,
336        user_values: &[&Self::FileContent],
337        _: &mut gpui::AppContext,
338    ) -> anyhow::Result<Self> {
339        let mut settings = AssistantSettings::default();
340
341        for value in [default_value].iter().chain(user_values) {
342            let value = value.upgrade();
343            merge(&mut settings.enabled, value.enabled);
344            merge(&mut settings.button, value.button);
345            merge(&mut settings.dock, value.dock);
346            merge(
347                &mut settings.default_width,
348                value.default_width.map(Into::into),
349            );
350            merge(
351                &mut settings.default_height,
352                value.default_height.map(Into::into),
353            );
354            if let Some(provider) = value.provider.clone() {
355                match (&mut settings.provider, provider) {
356                    (
357                        AssistantProvider::ZedDotDev { default_model },
358                        AssistantProvider::ZedDotDev {
359                            default_model: default_model_override,
360                        },
361                    ) => {
362                        *default_model = default_model_override;
363                    }
364                    (
365                        AssistantProvider::OpenAi {
366                            default_model,
367                            api_url,
368                        },
369                        AssistantProvider::OpenAi {
370                            default_model: default_model_override,
371                            api_url: api_url_override,
372                        },
373                    ) => {
374                        *default_model = default_model_override;
375                        *api_url = api_url_override;
376                    }
377                    (merged, provider_override) => {
378                        *merged = provider_override;
379                    }
380                }
381            }
382        }
383
384        Ok(settings)
385    }
386}
387
/// Overwrites `target` with `value` when `value` is `Some`, and leaves it untouched
/// otherwise.
///
/// Used to layer optional settings-file fields over already-resolved settings. `value`
/// is taken by value and moved into `target`, so `T` need not be `Copy`; this lets
/// owned fields such as strings or enums with data be merged the same way.
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(value) = value {
        *target = value;
    }
}
393
#[cfg(test)]
mod tests {
    use gpui::{AppContext, BorrowAppContext};
    use settings::SettingsStore;

    use super::*;

    /// Loads `AssistantSettings` end to end through a test `SettingsStore`. It covers
    /// the initial defaults, each of the two legacy OpenAI keys on its own, and the
    /// versioned format with a custom zed.dev model.
    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to gpt-4-turbo.
        // NOTE(review): `AssistantProvider::default()` is zed.dev, so the OpenAI provider
        // expected here presumably comes from the bundled default settings file — confirm.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::FourTurbo,
                api_url: open_ai_url()
            }
        );

        // Ensure backward-compatibility.
        // Legacy file with only `openai_api_url`: OpenAI provider with that URL and the
        // default model. (The JSON below has a trailing comma; this presumably relies on
        // the settings parser tolerating it.)
        cx.update_global::<SettingsStore, _>(|store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::FourTurbo,
                api_url: "test-url".into()
            }
        );
        // Legacy file with only `default_open_ai_model`: OpenAI provider with that model
        // and the default API URL.
        cx.update_global::<SettingsStore, _>(|store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::Four,
                api_url: open_ai_url()
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        cx.update_global::<SettingsStore, _>(|store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                default_model: ZedDotDevModel::Custom("custom".into())
            }
        );
    }
}