assistant_settings.rs

  1use std::fmt;
  2
  3pub use anthropic::Model as AnthropicModel;
  4use gpui::Pixels;
  5pub use open_ai::Model as OpenAiModel;
  6use schemars::{
  7    schema::{InstanceType, Metadata, Schema, SchemaObject},
  8    JsonSchema,
  9};
 10use serde::{
 11    de::{self, Visitor},
 12    Deserialize, Deserializer, Serialize, Serializer,
 13};
 14use settings::{Settings, SettingsSources};
 15use strum::{EnumIter, IntoEnumIterator};
 16
 17use crate::LanguageModel;
 18
/// A model offered through the `zed.dev` assistant provider.
///
/// Built-in variants (de)serialize as the string returned by [`ZedDotDevModel::id`]; any other
/// string is preserved verbatim in [`ZedDotDevModel::Custom`]. Variant order is the order
/// `ZedDotDevModel::iter()` yields, which the deserializer (first matching id wins) and the
/// JSON schema (order of listed ids) depend on.
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum ZedDotDevModel {
    Gpt3Point5Turbo,
    Gpt4,
    Gpt4Turbo,
    /// The model used when none is configured.
    #[default]
    Gpt4Omni,
    Claude3Opus,
    Claude3Sonnet,
    Claude3Haiku,
    /// A model id not known to this enum, kept as written in the settings file.
    Custom(String),
}
 31
 32impl Serialize for ZedDotDevModel {
 33    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 34    where
 35        S: Serializer,
 36    {
 37        serializer.serialize_str(self.id())
 38    }
 39}
 40
 41impl<'de> Deserialize<'de> for ZedDotDevModel {
 42    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 43    where
 44        D: Deserializer<'de>,
 45    {
 46        struct ZedDotDevModelVisitor;
 47
 48        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
 49            type Value = ZedDotDevModel;
 50
 51            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
 52                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
 53            }
 54
 55            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
 56            where
 57                E: de::Error,
 58            {
 59                let model = ZedDotDevModel::iter()
 60                    .find(|model| model.id() == value)
 61                    .unwrap_or_else(|| ZedDotDevModel::Custom(value.to_string()));
 62                Ok(model)
 63            }
 64        }
 65
 66        deserializer.deserialize_str(ZedDotDevModelVisitor)
 67    }
 68}
 69
 70impl JsonSchema for ZedDotDevModel {
 71    fn schema_name() -> String {
 72        "ZedDotDevModel".to_owned()
 73    }
 74
 75    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
 76        let variants = ZedDotDevModel::iter()
 77            .filter_map(|model| {
 78                let id = model.id();
 79                if id.is_empty() {
 80                    None
 81                } else {
 82                    Some(id.to_string())
 83                }
 84            })
 85            .collect::<Vec<_>>();
 86        Schema::Object(SchemaObject {
 87            instance_type: Some(InstanceType::String.into()),
 88            enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
 89            metadata: Some(Box::new(Metadata {
 90                title: Some("ZedDotDevModel".to_owned()),
 91                default: Some(ZedDotDevModel::default().id().into()),
 92                examples: variants.into_iter().map(Into::into).collect(),
 93                ..Default::default()
 94            })),
 95            ..Default::default()
 96        })
 97    }
 98}
 99
100impl ZedDotDevModel {
101    pub fn id(&self) -> &str {
102        match self {
103            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
104            Self::Gpt4 => "gpt-4",
105            Self::Gpt4Turbo => "gpt-4-turbo-preview",
106            Self::Gpt4Omni => "gpt-4o",
107            Self::Claude3Opus => "claude-3-opus",
108            Self::Claude3Sonnet => "claude-3-sonnet",
109            Self::Claude3Haiku => "claude-3-haiku",
110            Self::Custom(id) => id,
111        }
112    }
113
114    pub fn display_name(&self) -> &str {
115        match self {
116            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
117            Self::Gpt4 => "GPT 4",
118            Self::Gpt4Turbo => "GPT 4 Turbo",
119            Self::Gpt4Omni => "GPT 4 Omni",
120            Self::Claude3Opus => "Claude 3 Opus",
121            Self::Claude3Sonnet => "Claude 3 Sonnet",
122            Self::Claude3Haiku => "Claude 3 Haiku",
123            Self::Custom(id) => id.as_str(),
124        }
125    }
126
127    pub fn max_token_count(&self) -> usize {
128        match self {
129            Self::Gpt3Point5Turbo => 2048,
130            Self::Gpt4 => 4096,
131            Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
132            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
133            Self::Custom(_) => 4096, // TODO: Make this configurable
134        }
135    }
136}
137
/// Where the assistant panel is docked.
///
/// Serialized in snake_case: `"left"`, `"right"` or `"bottom"`.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}
146
/// The resolved assistant provider, produced by merging every settings layer.
///
/// Unlike [`AssistantProviderContent`], the model and API URL always hold concrete values here;
/// only the low-speed timeout remains optional.
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    /// The `zed.dev` service.
    ZedDotDev {
        model: ZedDotDevModel,
    },
    /// OpenAI, reached at `api_url`.
    OpenAi {
        model: OpenAiModel,
        api_url: String,
        /// `None` unless some settings layer sets it.
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// Anthropic, reached at `api_url`.
    Anthropic {
        model: AnthropicModel,
        api_url: String,
        /// `None` unless some settings layer sets it.
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
163
164impl Default for AssistantProvider {
165    fn default() -> Self {
166        Self::OpenAi {
167            model: OpenAiModel::default(),
168            api_url: open_ai::OPEN_AI_API_URL.into(),
169            low_speed_timeout_in_seconds: None,
170        }
171    }
172}
173
/// Provider configuration as written in a settings file.
///
/// Internally tagged by `"name"` (`"zed.dev"`, `"openai"` or `"anthropic"`). Every other field is
/// optional; unset fields keep the value from an earlier settings layer, or the provider's
/// default when the provider itself changes.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    /// The `zed.dev` service.
    #[serde(rename = "zed.dev")]
    ZedDotDev {
        /// The default model to use with this provider.
        default_model: Option<ZedDotDevModel>,
    },
    /// The OpenAI API.
    #[serde(rename = "openai")]
    OpenAi {
        /// The default model to use with this provider.
        default_model: Option<OpenAiModel>,
        /// Base URL of the OpenAI API.
        api_url: Option<String>,
        /// Low-speed timeout, in seconds.
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// The Anthropic API.
    #[serde(rename = "anthropic")]
    Anthropic {
        /// The default model to use with this provider.
        default_model: Option<AnthropicModel>,
        /// Base URL of the Anthropic API.
        api_url: Option<String>,
        /// Low-speed timeout, in seconds.
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
194
/// The effective assistant settings, resolved from every settings layer by `Settings::load`.
#[derive(Debug, Default)]
pub struct AssistantSettings {
    /// Whether the assistant is enabled.
    pub enabled: bool,
    /// Whether to show the assistant panel button in the status bar.
    pub button: bool,
    /// Where the assistant panel is docked.
    pub dock: AssistantDockPosition,
    /// Panel width when docked to the left or right.
    pub default_width: Pixels,
    /// Panel height when docked to the bottom.
    pub default_height: Pixels,
    /// The resolved provider and its configuration.
    pub provider: AssistantProvider,
}
204
/// Assistant panel settings as they appear in a settings file.
///
/// Deserialized untagged, so variants are tried in order: first
/// [`VersionedAssistantSettingsContent`] (which requires a `"version"` key), then the
/// pre-versioning [`LegacyAssistantSettingsContent`] format.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}
212
213impl JsonSchema for AssistantSettingsContent {
214    fn schema_name() -> String {
215        VersionedAssistantSettingsContent::schema_name()
216    }
217
218    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
219        VersionedAssistantSettingsContent::json_schema(gen)
220    }
221
222    fn is_referenceable() -> bool {
223        VersionedAssistantSettingsContent::is_referenceable()
224    }
225}
226
227impl Default for AssistantSettingsContent {
228    fn default() -> Self {
229        Self::Versioned(VersionedAssistantSettingsContent::default())
230    }
231}
232
233impl AssistantSettingsContent {
234    fn upgrade(&self) -> AssistantSettingsContentV1 {
235        match self {
236            AssistantSettingsContent::Versioned(settings) => match settings {
237                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
238            },
239            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
240                enabled: None,
241                button: settings.button,
242                dock: settings.dock,
243                default_width: settings.default_width,
244                default_height: settings.default_height,
245                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
246                    Some(AssistantProviderContent::OpenAi {
247                        default_model: settings.default_open_ai_model.clone(),
248                        api_url: Some(open_ai_api_url.clone()),
249                        low_speed_timeout_in_seconds: None,
250                    })
251                } else {
252                    settings.default_open_ai_model.clone().map(|open_ai_model| {
253                        AssistantProviderContent::OpenAi {
254                            default_model: Some(open_ai_model),
255                            api_url: None,
256                            low_speed_timeout_in_seconds: None,
257                        }
258                    })
259                },
260            },
261        }
262    }
263
264    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
265        match self {
266            AssistantSettingsContent::Versioned(settings) => match settings {
267                VersionedAssistantSettingsContent::V1(settings) => {
268                    settings.dock = Some(dock);
269                }
270            },
271            AssistantSettingsContent::Legacy(settings) => {
272                settings.dock = Some(dock);
273            }
274        }
275    }
276
277    pub fn set_model(&mut self, new_model: LanguageModel) {
278        match self {
279            AssistantSettingsContent::Versioned(settings) => match settings {
280                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
281                    Some(AssistantProviderContent::ZedDotDev {
282                        default_model: model,
283                    }) => {
284                        if let LanguageModel::ZedDotDev(new_model) = new_model {
285                            *model = Some(new_model);
286                        }
287                    }
288                    Some(AssistantProviderContent::OpenAi {
289                        default_model: model,
290                        ..
291                    }) => {
292                        if let LanguageModel::OpenAi(new_model) = new_model {
293                            *model = Some(new_model);
294                        }
295                    }
296                    Some(AssistantProviderContent::Anthropic {
297                        default_model: model,
298                        ..
299                    }) => {
300                        if let LanguageModel::Anthropic(new_model) = new_model {
301                            *model = Some(new_model);
302                        }
303                    }
304                    provider => match new_model {
305                        LanguageModel::ZedDotDev(model) => {
306                            *provider = Some(AssistantProviderContent::ZedDotDev {
307                                default_model: Some(model),
308                            })
309                        }
310                        LanguageModel::OpenAi(model) => {
311                            *provider = Some(AssistantProviderContent::OpenAi {
312                                default_model: Some(model),
313                                api_url: None,
314                                low_speed_timeout_in_seconds: None,
315                            })
316                        }
317                        LanguageModel::Anthropic(model) => {
318                            *provider = Some(AssistantProviderContent::Anthropic {
319                                default_model: Some(model),
320                                api_url: None,
321                                low_speed_timeout_in_seconds: None,
322                            })
323                        }
324                    },
325                },
326            },
327            AssistantSettingsContent::Legacy(settings) => {
328                if let LanguageModel::OpenAi(model) = new_model {
329                    settings.default_open_ai_model = Some(model);
330                }
331            }
332        }
333    }
334}
335
/// Versioned assistant settings, internally tagged by a string `"version"` key.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    /// Settings declared with `"version": "1"`.
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
342
343impl Default for VersionedAssistantSettingsContent {
344    fn default() -> Self {
345        Self::V1(AssistantSettingsContentV1 {
346            enabled: None,
347            button: None,
348            dock: None,
349            default_width: None,
350            default_height: None,
351            provider: None,
352        })
353    }
354}
355
/// Version 1 of the assistant settings. Every field is optional; unset fields keep the value
/// from earlier settings layers.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can be the internal `zed.dev` service, or an external `openai` or `anthropic`
    /// service, each with their respective default models and configurations.
    provider: Option<AssistantProviderContent>,
}
384
/// The assistant settings format used before `"version"` was introduced. It can only configure
/// the OpenAI provider.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4o
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
412
413impl Settings for AssistantSettings {
414    const KEY: Option<&'static str> = Some("assistant");
415
416    type FileContent = AssistantSettingsContent;
417
418    fn load(
419        sources: SettingsSources<Self::FileContent>,
420        _: &mut gpui::AppContext,
421    ) -> anyhow::Result<Self> {
422        let mut settings = AssistantSettings::default();
423
424        for value in sources.defaults_and_customizations() {
425            let value = value.upgrade();
426            merge(&mut settings.enabled, value.enabled);
427            merge(&mut settings.button, value.button);
428            merge(&mut settings.dock, value.dock);
429            merge(
430                &mut settings.default_width,
431                value.default_width.map(Into::into),
432            );
433            merge(
434                &mut settings.default_height,
435                value.default_height.map(Into::into),
436            );
437            if let Some(provider) = value.provider.clone() {
438                match (&mut settings.provider, provider) {
439                    (
440                        AssistantProvider::ZedDotDev { model },
441                        AssistantProviderContent::ZedDotDev {
442                            default_model: model_override,
443                        },
444                    ) => {
445                        merge(model, model_override);
446                    }
447                    (
448                        AssistantProvider::OpenAi {
449                            model,
450                            api_url,
451                            low_speed_timeout_in_seconds,
452                        },
453                        AssistantProviderContent::OpenAi {
454                            default_model: model_override,
455                            api_url: api_url_override,
456                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
457                        },
458                    ) => {
459                        merge(model, model_override);
460                        merge(api_url, api_url_override);
461                        if let Some(low_speed_timeout_in_seconds_override) =
462                            low_speed_timeout_in_seconds_override
463                        {
464                            *low_speed_timeout_in_seconds =
465                                Some(low_speed_timeout_in_seconds_override);
466                        }
467                    }
468                    (
469                        AssistantProvider::Anthropic {
470                            model,
471                            api_url,
472                            low_speed_timeout_in_seconds,
473                        },
474                        AssistantProviderContent::Anthropic {
475                            default_model: model_override,
476                            api_url: api_url_override,
477                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
478                        },
479                    ) => {
480                        merge(model, model_override);
481                        merge(api_url, api_url_override);
482                        if let Some(low_speed_timeout_in_seconds_override) =
483                            low_speed_timeout_in_seconds_override
484                        {
485                            *low_speed_timeout_in_seconds =
486                                Some(low_speed_timeout_in_seconds_override);
487                        }
488                    }
489                    (provider, provider_override) => {
490                        *provider = match provider_override {
491                            AssistantProviderContent::ZedDotDev {
492                                default_model: model,
493                            } => AssistantProvider::ZedDotDev {
494                                model: model.unwrap_or_default(),
495                            },
496                            AssistantProviderContent::OpenAi {
497                                default_model: model,
498                                api_url,
499                                low_speed_timeout_in_seconds,
500                            } => AssistantProvider::OpenAi {
501                                model: model.unwrap_or_default(),
502                                api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
503                                low_speed_timeout_in_seconds,
504                            },
505                            AssistantProviderContent::Anthropic {
506                                default_model: model,
507                                api_url,
508                                low_speed_timeout_in_seconds,
509                            } => AssistantProvider::Anthropic {
510                                model: model.unwrap_or_default(),
511                                api_url: api_url
512                                    .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
513                                low_speed_timeout_in_seconds,
514                            },
515                        };
516                    }
517                }
518            }
519        }
520
521        Ok(settings)
522    }
523}
524
/// Overwrites `target` with the contained value when `value` is `Some`.
///
/// `None` leaves `target` untouched; this is how settings layers fall through to earlier ones.
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(new_value) = value {
        *target = new_value;
    }
}
530
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = SettingsStore::test(cx);
        cx.set_global(store);

        // With no user settings, the provider defaults to OpenAI with GPT-4 Omni.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // Ensure backward-compatibility with the legacy (unversioned) format.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
            }
        );
        // Replacing the user settings with only a legacy model restores the default API URL.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: ZedDotDevModel::Custom("custom".into())
            }
        );
    }
}