assistant_settings.rs

  1use std::fmt;
  2
  3use gpui::Pixels;
  4pub use open_ai::Model as OpenAiModel;
  5use schemars::{
  6    schema::{InstanceType, Metadata, Schema, SchemaObject},
  7    JsonSchema,
  8};
  9use serde::{
 10    de::{self, Visitor},
 11    Deserialize, Deserializer, Serialize, Serializer,
 12};
 13use settings::{Settings, SettingsSources};
 14
 15#[derive(Clone, Debug, Default, PartialEq)]
 16pub enum ZedDotDevModel {
 17    Gpt3Point5Turbo,
 18    Gpt4,
 19    #[default]
 20    Gpt4Turbo,
 21    Claude3Opus,
 22    Claude3Sonnet,
 23    Claude3Haiku,
 24    Custom(String),
 25}
 26
 27impl Serialize for ZedDotDevModel {
 28    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 29    where
 30        S: Serializer,
 31    {
 32        serializer.serialize_str(self.id())
 33    }
 34}
 35
 36impl<'de> Deserialize<'de> for ZedDotDevModel {
 37    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 38    where
 39        D: Deserializer<'de>,
 40    {
 41        struct ZedDotDevModelVisitor;
 42
 43        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
 44            type Value = ZedDotDevModel;
 45
 46            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
 47                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
 48            }
 49
 50            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
 51            where
 52                E: de::Error,
 53            {
 54                match value {
 55                    "gpt-3.5-turbo" => Ok(ZedDotDevModel::Gpt3Point5Turbo),
 56                    "gpt-4" => Ok(ZedDotDevModel::Gpt4),
 57                    "gpt-4-turbo-preview" => Ok(ZedDotDevModel::Gpt4Turbo),
 58                    _ => Ok(ZedDotDevModel::Custom(value.to_owned())),
 59                }
 60            }
 61        }
 62
 63        deserializer.deserialize_str(ZedDotDevModelVisitor)
 64    }
 65}
 66
 67impl JsonSchema for ZedDotDevModel {
 68    fn schema_name() -> String {
 69        "ZedDotDevModel".to_owned()
 70    }
 71
 72    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
 73        let variants = vec![
 74            "gpt-3.5-turbo".to_owned(),
 75            "gpt-4".to_owned(),
 76            "gpt-4-turbo-preview".to_owned(),
 77        ];
 78        Schema::Object(SchemaObject {
 79            instance_type: Some(InstanceType::String.into()),
 80            enum_values: Some(variants.into_iter().map(|s| s.into()).collect()),
 81            metadata: Some(Box::new(Metadata {
 82                title: Some("ZedDotDevModel".to_owned()),
 83                default: Some(serde_json::json!("gpt-4-turbo-preview")),
 84                examples: vec![
 85                    serde_json::json!("gpt-3.5-turbo"),
 86                    serde_json::json!("gpt-4"),
 87                    serde_json::json!("gpt-4-turbo-preview"),
 88                    serde_json::json!("custom-model-name"),
 89                ],
 90                ..Default::default()
 91            })),
 92            ..Default::default()
 93        })
 94    }
 95}
 96
 97impl ZedDotDevModel {
 98    pub fn id(&self) -> &str {
 99        match self {
100            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
101            Self::Gpt4 => "gpt-4",
102            Self::Gpt4Turbo => "gpt-4-turbo-preview",
103            Self::Claude3Opus => "claude-3-opus",
104            Self::Claude3Sonnet => "claude-3-sonnet",
105            Self::Claude3Haiku => "claude-3-haiku",
106            Self::Custom(id) => id,
107        }
108    }
109
110    pub fn display_name(&self) -> &str {
111        match self {
112            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
113            Self::Gpt4 => "GPT 4",
114            Self::Gpt4Turbo => "GPT 4 Turbo",
115            Self::Claude3Opus => "Claude 3 Opus",
116            Self::Claude3Sonnet => "Claude 3 Sonnet",
117            Self::Claude3Haiku => "Claude 3 Haiku",
118            Self::Custom(id) => id.as_str(),
119        }
120    }
121
122    pub fn max_token_count(&self) -> usize {
123        match self {
124            Self::Gpt3Point5Turbo => 2048,
125            Self::Gpt4 => 4096,
126            Self::Gpt4Turbo => 128000,
127            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
128            Self::Custom(_) => 4096, // TODO: Make this configurable
129        }
130    }
131}
132
// Where the assistant panel is docked. Variants are (de)serialized in
// snake_case ("left", "right", "bottom"), and `Right` is the default.
// Plain `//` comments are used here because `///` docs on a type deriving
// `JsonSchema` become descriptions in the generated settings schema.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}
141
// The service backing the assistant. It is internally tagged by a `name` field
// whose value is either "zed.dev" or "openai". Omitted fields take their serde
// defaults: the variant's default model, `open_ai_url()` for `api_url`, and
// `None` for the timeout.
// NOTE(review): `rename_all` is redundant here, because every variant carries
// an explicit `rename`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProvider {
    #[serde(rename = "zed.dev")]
    ZedDotDev {
        #[serde(default)]
        default_model: ZedDotDevModel,
    },
    #[serde(rename = "openai")]
    OpenAi {
        #[serde(default)]
        default_model: OpenAiModel,
        #[serde(default = "open_ai_url")]
        api_url: String,
        // Presumably a timeout applied to slow OpenAI responses. Its exact
        // semantics are defined by whoever consumes it, not here; TODO confirm.
        #[serde(default)]
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
160
161impl Default for AssistantProvider {
162    fn default() -> Self {
163        Self::ZedDotDev {
164            default_model: ZedDotDevModel::default(),
165        }
166    }
167}
168
/// Base URL of the public OpenAI API.
///
/// Serves as the serde default for the `openai` provider's `api_url`, and as
/// the fallback when upgrading legacy settings that omit a URL.
fn open_ai_url() -> String {
    String::from("https://api.openai.com/v1")
}
172
173#[derive(Default, Debug, Deserialize, Serialize)]
174pub struct AssistantSettings {
175    pub enabled: bool,
176    pub button: bool,
177    pub dock: AssistantDockPosition,
178    pub default_width: Pixels,
179    pub default_height: Pixels,
180    pub provider: AssistantProvider,
181}
182
183/// Assistant panel settings
184#[derive(Clone, Serialize, Deserialize, Debug)]
185#[serde(untagged)]
186pub enum AssistantSettingsContent {
187    Versioned(VersionedAssistantSettingsContent),
188    Legacy(LegacyAssistantSettingsContent),
189}
190
191impl JsonSchema for AssistantSettingsContent {
192    fn schema_name() -> String {
193        VersionedAssistantSettingsContent::schema_name()
194    }
195
196    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
197        VersionedAssistantSettingsContent::json_schema(gen)
198    }
199
200    fn is_referenceable() -> bool {
201        VersionedAssistantSettingsContent::is_referenceable()
202    }
203}
204
205impl Default for AssistantSettingsContent {
206    fn default() -> Self {
207        Self::Versioned(VersionedAssistantSettingsContent::default())
208    }
209}
210
211impl AssistantSettingsContent {
212    fn upgrade(&self) -> AssistantSettingsContentV1 {
213        match self {
214            AssistantSettingsContent::Versioned(settings) => match settings {
215                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
216            },
217            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
218                enabled: None,
219                button: settings.button,
220                dock: settings.dock,
221                default_width: settings.default_width,
222                default_height: settings.default_height,
223                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
224                    Some(AssistantProvider::OpenAi {
225                        default_model: settings.default_open_ai_model.clone().unwrap_or_default(),
226                        api_url: open_ai_api_url.clone(),
227                        low_speed_timeout_in_seconds: None,
228                    })
229                } else {
230                    settings.default_open_ai_model.clone().map(|open_ai_model| {
231                        AssistantProvider::OpenAi {
232                            default_model: open_ai_model,
233                            api_url: open_ai_url(),
234                            low_speed_timeout_in_seconds: None,
235                        }
236                    })
237                },
238            },
239        }
240    }
241
242    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
243        match self {
244            AssistantSettingsContent::Versioned(settings) => match settings {
245                VersionedAssistantSettingsContent::V1(settings) => {
246                    settings.dock = Some(dock);
247                }
248            },
249            AssistantSettingsContent::Legacy(settings) => {
250                settings.dock = Some(dock);
251            }
252        }
253    }
254}
255
// Assistant settings that carry an explicit `"version"` tag. The enum is
// internally tagged, and so far only version "1" exists. Add new shapes as
// further variants so that `AssistantSettingsContent::upgrade` can convert them.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
262
263impl Default for VersionedAssistantSettingsContent {
264    fn default() -> Self {
265        Self::V1(AssistantSettingsContentV1 {
266            enabled: None,
267            button: None,
268            dock: None,
269            default_width: None,
270            default_height: None,
271            provider: None,
272        })
273    }
274}
275
// Version 1 of the `assistant` settings shape. Every field is optional: `None`
// means "not set by this source", and `AssistantSettings::load` then keeps the
// value produced by earlier sources. Fields are private and are only read
// within this module.
// The `///` docs below become descriptions in the generated JSON schema,
// because this type derives `JsonSchema`; edit them with that in mind.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can either be the internal `zed.dev` service or an external `openai` service,
    /// each with their respective default models and configurations.
    provider: Option<AssistantProvider>,
}
304
// The pre-versioning `assistant` settings shape, still accepted for backward
// compatibility. `AssistantSettingsContent::upgrade` converts it to V1 and
// turns `default_open_ai_model` / `openai_api_url` into an `openai` provider.
// The `///` docs below become JSON schema descriptions via the `JsonSchema` derive.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    // NOTE(review): the tests in this file expect `OpenAiModel::FourTurbo` as
    // the default model. Confirm that the id documented below still names it.
    /// The default OpenAI model to use when starting new conversations.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when starting new conversations.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
332
333impl Settings for AssistantSettings {
334    const KEY: Option<&'static str> = Some("assistant");
335
336    type FileContent = AssistantSettingsContent;
337
338    fn load(
339        sources: SettingsSources<Self::FileContent>,
340        _: &mut gpui::AppContext,
341    ) -> anyhow::Result<Self> {
342        let mut settings = AssistantSettings::default();
343
344        for value in sources.defaults_and_customizations() {
345            let value = value.upgrade();
346            merge(&mut settings.enabled, value.enabled);
347            merge(&mut settings.button, value.button);
348            merge(&mut settings.dock, value.dock);
349            merge(
350                &mut settings.default_width,
351                value.default_width.map(Into::into),
352            );
353            merge(
354                &mut settings.default_height,
355                value.default_height.map(Into::into),
356            );
357            if let Some(provider) = value.provider.clone() {
358                match (&mut settings.provider, provider) {
359                    (
360                        AssistantProvider::ZedDotDev { default_model },
361                        AssistantProvider::ZedDotDev {
362                            default_model: default_model_override,
363                        },
364                    ) => {
365                        *default_model = default_model_override;
366                    }
367                    (
368                        AssistantProvider::OpenAi {
369                            default_model,
370                            api_url,
371                            low_speed_timeout_in_seconds,
372                        },
373                        AssistantProvider::OpenAi {
374                            default_model: default_model_override,
375                            api_url: api_url_override,
376                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
377                        },
378                    ) => {
379                        *default_model = default_model_override;
380                        *api_url = api_url_override;
381                        *low_speed_timeout_in_seconds = low_speed_timeout_in_seconds_override;
382                    }
383                    (merged, provider_override) => {
384                        *merged = provider_override;
385                    }
386                }
387            }
388        }
389
390        Ok(settings)
391    }
392}
393
/// Overwrites `target` with `value` when it is `Some`, and leaves `target`
/// untouched when it is `None`.
///
/// This is how an optional field from one settings source overrides the value
/// accumulated from earlier sources. The value is moved into place, so `T`
/// does not need to be `Copy`.
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(new_value) = value {
        *target = new_value;
    }
}
399
#[cfg(test)]
mod tests {
    use gpui::{AppContext, BorrowAppContext};
    use settings::SettingsStore;

    use super::*;

    /// Replaces the test store's user settings with `json`, panicking if the
    /// store rejects it.
    fn update_user_settings(cx: &mut AppContext, json: &str) {
        cx.update_global::<SettingsStore, _>(|store, cx| {
            store.set_user_settings(json, cx).unwrap();
        });
    }

    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to gpt-4-turbo.
        // NOTE(review): `AssistantProvider::default()` is `ZedDotDev`, so this
        // presumably relies on the bundled default settings selecting OpenAI.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::FourTurbo,
                api_url: open_ai_url(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // Ensure backward-compatibility. A legacy `openai_api_url` on its own
        // selects OpenAI at that URL with the default model.
        update_user_settings(
            cx,
            r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
        );
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::FourTurbo,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // A legacy `default_open_ai_model` on its own selects OpenAI at the
        // default URL. The previous user settings are replaced, not merged.
        update_user_settings(
            cx,
            r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
        );
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                default_model: OpenAiModel::Four,
                api_url: open_ai_url(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        update_user_settings(
            cx,
            r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
        );
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                default_model: ZedDotDevModel::Custom("custom".into())
            }
        );
    }
}