assistant_settings.rs

  1use std::fmt;
  2
  3pub use anthropic::Model as AnthropicModel;
  4use gpui::Pixels;
  5pub use ollama::Model as OllamaModel;
  6pub use open_ai::Model as OpenAiModel;
  7use schemars::{
  8    schema::{InstanceType, Metadata, Schema, SchemaObject},
  9    JsonSchema,
 10};
 11use serde::{
 12    de::{self, Visitor},
 13    Deserialize, Deserializer, Serialize, Serializer,
 14};
 15use settings::{Settings, SettingsSources};
 16use strum::{EnumIter, IntoEnumIterator};
 17
 18use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
 19
 /// A model served through the zed.dev service.
///
/// Known models map to fixed identifiers (see [`CloudModel::id`]); any other
/// identifier is kept verbatim as [`CloudModel::Custom`]. `Gpt4Omni` is the default.
///
/// Declaration order matters: `EnumIter` yields variants in this order, which
/// sets the order of the JSON schema's `enum` values and the order in which
/// identifiers are matched during deserialization.
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
    Gpt3Point5Turbo,
    Gpt4,
    Gpt4Turbo,
    #[default]
    Gpt4Omni,
    Claude3_5Sonnet,
    Claude3Opus,
    Claude3Sonnet,
    Claude3Haiku,
    /// A model identified by an arbitrary string.
    ///
    /// `EnumIter` yields this variant with an empty `String` (strum fills fields
    /// with `Default::default()`), which is why schema generation skips empty ids.
    Custom(String),
}
 33
 34impl Serialize for CloudModel {
 35    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 36    where
 37        S: Serializer,
 38    {
 39        serializer.serialize_str(self.id())
 40    }
 41}
 42
 43impl<'de> Deserialize<'de> for CloudModel {
 44    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 45    where
 46        D: Deserializer<'de>,
 47    {
 48        struct ZedDotDevModelVisitor;
 49
 50        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
 51            type Value = CloudModel;
 52
 53            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
 54                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
 55            }
 56
 57            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
 58            where
 59                E: de::Error,
 60            {
 61                let model = CloudModel::iter()
 62                    .find(|model| model.id() == value)
 63                    .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
 64                Ok(model)
 65            }
 66        }
 67
 68        deserializer.deserialize_str(ZedDotDevModelVisitor)
 69    }
 70}
 71
 72impl JsonSchema for CloudModel {
 73    fn schema_name() -> String {
 74        "ZedDotDevModel".to_owned()
 75    }
 76
 77    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
 78        let variants = CloudModel::iter()
 79            .filter_map(|model| {
 80                let id = model.id();
 81                if id.is_empty() {
 82                    None
 83                } else {
 84                    Some(id.to_string())
 85                }
 86            })
 87            .collect::<Vec<_>>();
 88        Schema::Object(SchemaObject {
 89            instance_type: Some(InstanceType::String.into()),
 90            enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
 91            metadata: Some(Box::new(Metadata {
 92                title: Some("ZedDotDevModel".to_owned()),
 93                default: Some(CloudModel::default().id().into()),
 94                examples: variants.into_iter().map(Into::into).collect(),
 95                ..Default::default()
 96            })),
 97            ..Default::default()
 98        })
 99    }
100}
101
102impl CloudModel {
103    pub fn id(&self) -> &str {
104        match self {
105            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
106            Self::Gpt4 => "gpt-4",
107            Self::Gpt4Turbo => "gpt-4-turbo-preview",
108            Self::Gpt4Omni => "gpt-4o",
109            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
110            Self::Claude3Opus => "claude-3-opus",
111            Self::Claude3Sonnet => "claude-3-sonnet",
112            Self::Claude3Haiku => "claude-3-haiku",
113            Self::Custom(id) => id,
114        }
115    }
116
117    pub fn display_name(&self) -> &str {
118        match self {
119            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
120            Self::Gpt4 => "GPT 4",
121            Self::Gpt4Turbo => "GPT 4 Turbo",
122            Self::Gpt4Omni => "GPT 4 Omni",
123            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
124            Self::Claude3Opus => "Claude 3 Opus",
125            Self::Claude3Sonnet => "Claude 3 Sonnet",
126            Self::Claude3Haiku => "Claude 3 Haiku",
127            Self::Custom(id) => id.as_str(),
128        }
129    }
130
131    pub fn max_token_count(&self) -> usize {
132        match self {
133            Self::Gpt3Point5Turbo => 2048,
134            Self::Gpt4 => 4096,
135            Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
136            Self::Claude3_5Sonnet
137            | Self::Claude3Opus
138            | Self::Claude3Sonnet
139            | Self::Claude3Haiku => 200000,
140            Self::Custom(_) => 4096, // TODO: Make this configurable
141        }
142    }
143
144    pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
145        match self {
146            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
147                preprocess_anthropic_request(request)
148            }
149            _ => {}
150        }
151    }
152}
153
/// Where the assistant panel is docked.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    // Serialized in snake_case: "left", "right", "bottom"; `Right` is the default.
    Left,
    #[default]
    Right,
    Bottom,
}
162
/// The resolved assistant provider, with every setting filled in.
///
/// Built from `AssistantProviderContent` when settings are loaded. Fields that no
/// settings source sets take their defaults: the model type's `Default` and the
/// provider's built-in API URL constant. `low_speed_timeout_in_seconds` stays
/// optional.
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    /// The zed.dev service.
    ZedDotDev {
        model: CloudModel,
    },
    /// OpenAI, reached at `api_url`.
    OpenAi {
        model: OpenAiModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// Anthropic, reached at `api_url`.
    Anthropic {
        model: AnthropicModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// Ollama, reached at `api_url`.
    Ollama {
        model: OllamaModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
184
185impl Default for AssistantProvider {
186    fn default() -> Self {
187        Self::OpenAi {
188            model: OpenAiModel::default(),
189            api_url: open_ai::OPEN_AI_API_URL.into(),
190            low_speed_timeout_in_seconds: None,
191        }
192    }
193}
194
// The `provider` object from a settings file. Every field is optional: when
// settings are loaded, an unset field keeps the value from an earlier settings
// source if that source selected the same provider, and otherwise falls back to
// the provider's default. Each variant's explicit `rename` sets its `name` tag.
/// Which assistant provider to use, selected by `name`, and its configuration.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    /// The zed.dev service.
    #[serde(rename = "zed.dev")]
    ZedDotDev { default_model: Option<CloudModel> },
    /// The OpenAI API.
    #[serde(rename = "openai")]
    OpenAi {
        default_model: Option<OpenAiModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// The Anthropic API.
    #[serde(rename = "anthropic")]
    Anthropic {
        default_model: Option<AnthropicModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// An Ollama server.
    #[serde(rename = "ollama")]
    Ollama {
        default_model: Option<OllamaModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
219
/// Effective assistant settings, produced by applying every settings source in
/// order (defaults first, then customizations) when the settings are loaded.
#[derive(Debug, Default)]
pub struct AssistantSettings {
    /// Whether the assistant is enabled.
    pub enabled: bool,
    /// Whether to show the assistant panel button in the status bar.
    pub button: bool,
    /// Where the assistant panel is docked.
    pub dock: AssistantDockPosition,
    /// Panel width when docked to the left or right.
    pub default_width: Pixels,
    /// Panel height when docked to the bottom.
    pub default_height: Pixels,
    /// The selected provider and its resolved configuration.
    pub provider: AssistantProvider,
}
229
/// Assistant panel settings as they appear in a settings file.
///
/// Both the current versioned format and the legacy unversioned format are
/// accepted. Because the enum is `untagged`, serde tries the variants in
/// declaration order. `Versioned` requires a `"version"` key, while every
/// `Legacy` field is optional, so `Legacy` would also accept versioned input.
/// `Versioned` must therefore stay first.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    /// The current format, selected by its `"version"` key.
    Versioned(VersionedAssistantSettingsContent),
    /// The original unversioned format; converted by `upgrade` before use.
    Legacy(LegacyAssistantSettingsContent),
}
237
238impl JsonSchema for AssistantSettingsContent {
239    fn schema_name() -> String {
240        VersionedAssistantSettingsContent::schema_name()
241    }
242
243    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
244        VersionedAssistantSettingsContent::json_schema(gen)
245    }
246
247    fn is_referenceable() -> bool {
248        VersionedAssistantSettingsContent::is_referenceable()
249    }
250}
251
252impl Default for AssistantSettingsContent {
253    fn default() -> Self {
254        Self::Versioned(VersionedAssistantSettingsContent::default())
255    }
256}
257
258impl AssistantSettingsContent {
259    fn upgrade(&self) -> AssistantSettingsContentV1 {
260        match self {
261            AssistantSettingsContent::Versioned(settings) => match settings {
262                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
263            },
264            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
265                enabled: None,
266                button: settings.button,
267                dock: settings.dock,
268                default_width: settings.default_width,
269                default_height: settings.default_height,
270                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
271                    Some(AssistantProviderContent::OpenAi {
272                        default_model: settings.default_open_ai_model.clone(),
273                        api_url: Some(open_ai_api_url.clone()),
274                        low_speed_timeout_in_seconds: None,
275                    })
276                } else {
277                    settings.default_open_ai_model.clone().map(|open_ai_model| {
278                        AssistantProviderContent::OpenAi {
279                            default_model: Some(open_ai_model),
280                            api_url: None,
281                            low_speed_timeout_in_seconds: None,
282                        }
283                    })
284                },
285            },
286        }
287    }
288
289    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
290        match self {
291            AssistantSettingsContent::Versioned(settings) => match settings {
292                VersionedAssistantSettingsContent::V1(settings) => {
293                    settings.dock = Some(dock);
294                }
295            },
296            AssistantSettingsContent::Legacy(settings) => {
297                settings.dock = Some(dock);
298            }
299        }
300    }
301
302    pub fn set_model(&mut self, new_model: LanguageModel) {
303        match self {
304            AssistantSettingsContent::Versioned(settings) => match settings {
305                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
306                    Some(AssistantProviderContent::ZedDotDev {
307                        default_model: model,
308                    }) => {
309                        if let LanguageModel::Cloud(new_model) = new_model {
310                            *model = Some(new_model);
311                        }
312                    }
313                    Some(AssistantProviderContent::OpenAi {
314                        default_model: model,
315                        ..
316                    }) => {
317                        if let LanguageModel::OpenAi(new_model) = new_model {
318                            *model = Some(new_model);
319                        }
320                    }
321                    Some(AssistantProviderContent::Anthropic {
322                        default_model: model,
323                        ..
324                    }) => {
325                        if let LanguageModel::Anthropic(new_model) = new_model {
326                            *model = Some(new_model);
327                        }
328                    }
329                    Some(AssistantProviderContent::Ollama {
330                        default_model: model,
331                        ..
332                    }) => {
333                        if let LanguageModel::Ollama(new_model) = new_model {
334                            *model = Some(new_model);
335                        }
336                    }
337                    provider => match new_model {
338                        LanguageModel::Cloud(model) => {
339                            *provider = Some(AssistantProviderContent::ZedDotDev {
340                                default_model: Some(model),
341                            })
342                        }
343                        LanguageModel::OpenAi(model) => {
344                            *provider = Some(AssistantProviderContent::OpenAi {
345                                default_model: Some(model),
346                                api_url: None,
347                                low_speed_timeout_in_seconds: None,
348                            })
349                        }
350                        LanguageModel::Anthropic(model) => {
351                            *provider = Some(AssistantProviderContent::Anthropic {
352                                default_model: Some(model),
353                                api_url: None,
354                                low_speed_timeout_in_seconds: None,
355                            })
356                        }
357                        LanguageModel::Ollama(model) => {
358                            *provider = Some(AssistantProviderContent::Ollama {
359                                default_model: Some(model),
360                                api_url: None,
361                                low_speed_timeout_in_seconds: None,
362                            })
363                        }
364                    },
365                },
366            },
367            AssistantSettingsContent::Legacy(settings) => {
368                if let LanguageModel::OpenAi(model) = new_model {
369                    settings.default_open_ai_model = Some(model);
370                }
371            }
372        }
373    }
374}
375
// Versioned assistant settings, internally tagged by the `"version"` key:
// `"version": "1"` selects `V1`. Older or different formats are normalized to the
// latest variant by `AssistantSettingsContent::upgrade`.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
382
383impl Default for VersionedAssistantSettingsContent {
384    fn default() -> Self {
385        Self::V1(AssistantSettingsContentV1 {
386            enabled: None,
387            button: None,
388            dock: None,
389            default_width: None,
390            default_height: None,
391            provider: None,
392        })
393    }
394}
395
// Version 1 of the assistant settings file format. Every field is optional; a
// field left unset does not override the value from an earlier settings source.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can be the `zed.dev` service or one of the external `openai`,
    /// `anthropic`, or `ollama` services, each with their respective default
    /// models and configurations.
    provider: Option<AssistantProviderContent>,
}
424
// The original, unversioned settings format. It is still accepted when
// deserializing `AssistantSettingsContent` and is converted to
// `AssistantSettingsContentV1` by `AssistantSettingsContent::upgrade`. There, its
// OpenAI model and API URL become an `openai` provider entry.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4-1106-preview
    // NOTE(review): when unset, the loaded model appears to be
    // `OpenAiModel::default()` (the settings tests expect `OpenAiModel::FourOmni`),
    // not gpt-4-1106-preview — confirm and update the default stated above.
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
452
453impl Settings for AssistantSettings {
454    const KEY: Option<&'static str> = Some("assistant");
455
456    type FileContent = AssistantSettingsContent;
457
458    fn load(
459        sources: SettingsSources<Self::FileContent>,
460        _: &mut gpui::AppContext,
461    ) -> anyhow::Result<Self> {
462        let mut settings = AssistantSettings::default();
463
464        for value in sources.defaults_and_customizations() {
465            let value = value.upgrade();
466            merge(&mut settings.enabled, value.enabled);
467            merge(&mut settings.button, value.button);
468            merge(&mut settings.dock, value.dock);
469            merge(
470                &mut settings.default_width,
471                value.default_width.map(Into::into),
472            );
473            merge(
474                &mut settings.default_height,
475                value.default_height.map(Into::into),
476            );
477            if let Some(provider) = value.provider.clone() {
478                match (&mut settings.provider, provider) {
479                    (
480                        AssistantProvider::ZedDotDev { model },
481                        AssistantProviderContent::ZedDotDev {
482                            default_model: model_override,
483                        },
484                    ) => {
485                        merge(model, model_override);
486                    }
487                    (
488                        AssistantProvider::OpenAi {
489                            model,
490                            api_url,
491                            low_speed_timeout_in_seconds,
492                        },
493                        AssistantProviderContent::OpenAi {
494                            default_model: model_override,
495                            api_url: api_url_override,
496                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
497                        },
498                    ) => {
499                        merge(model, model_override);
500                        merge(api_url, api_url_override);
501                        if let Some(low_speed_timeout_in_seconds_override) =
502                            low_speed_timeout_in_seconds_override
503                        {
504                            *low_speed_timeout_in_seconds =
505                                Some(low_speed_timeout_in_seconds_override);
506                        }
507                    }
508                    (
509                        AssistantProvider::Ollama {
510                            model,
511                            api_url,
512                            low_speed_timeout_in_seconds,
513                        },
514                        AssistantProviderContent::Ollama {
515                            default_model: model_override,
516                            api_url: api_url_override,
517                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
518                        },
519                    ) => {
520                        merge(model, model_override);
521                        merge(api_url, api_url_override);
522                        if let Some(low_speed_timeout_in_seconds_override) =
523                            low_speed_timeout_in_seconds_override
524                        {
525                            *low_speed_timeout_in_seconds =
526                                Some(low_speed_timeout_in_seconds_override);
527                        }
528                    }
529                    (
530                        AssistantProvider::Anthropic {
531                            model,
532                            api_url,
533                            low_speed_timeout_in_seconds,
534                        },
535                        AssistantProviderContent::Anthropic {
536                            default_model: model_override,
537                            api_url: api_url_override,
538                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
539                        },
540                    ) => {
541                        merge(model, model_override);
542                        merge(api_url, api_url_override);
543                        if let Some(low_speed_timeout_in_seconds_override) =
544                            low_speed_timeout_in_seconds_override
545                        {
546                            *low_speed_timeout_in_seconds =
547                                Some(low_speed_timeout_in_seconds_override);
548                        }
549                    }
550                    (provider, provider_override) => {
551                        *provider = match provider_override {
552                            AssistantProviderContent::ZedDotDev {
553                                default_model: model,
554                            } => AssistantProvider::ZedDotDev {
555                                model: model.unwrap_or_default(),
556                            },
557                            AssistantProviderContent::OpenAi {
558                                default_model: model,
559                                api_url,
560                                low_speed_timeout_in_seconds,
561                            } => AssistantProvider::OpenAi {
562                                model: model.unwrap_or_default(),
563                                api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
564                                low_speed_timeout_in_seconds,
565                            },
566                            AssistantProviderContent::Anthropic {
567                                default_model: model,
568                                api_url,
569                                low_speed_timeout_in_seconds,
570                            } => AssistantProvider::Anthropic {
571                                model: model.unwrap_or_default(),
572                                api_url: api_url
573                                    .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
574                                low_speed_timeout_in_seconds,
575                            },
576                            AssistantProviderContent::Ollama {
577                                default_model: model,
578                                api_url,
579                                low_speed_timeout_in_seconds,
580                            } => AssistantProvider::Ollama {
581                                model: model.unwrap_or_default(),
582                                api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
583                                low_speed_timeout_in_seconds,
584                            },
585                        };
586                    }
587                }
588            }
589        }
590
591        Ok(settings)
592    }
593}
594
/// Overwrites `target` with `value` when `value` is `Some`; otherwise leaves
/// `target` unchanged.
///
/// Used to layer settings sources: an unset (`None`) field never clobbers an
/// earlier value.
fn merge<T>(target: &mut T, value: Option<T>) {
    let Some(replacement) = value else {
        return;
    };
    *target = replacement;
}
600
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = SettingsStore::test(cx);
        cx.set_global(store);

        // With no user settings, the provider is OpenAI with GPT-4 Omni.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // Ensure backward-compatibility: legacy (unversioned) settings still apply.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
            }
        );
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );

        // Known zed.dev model ids resolve to their own variant, not to `Custom`.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "claude-3-5-sonnet"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Claude3_5Sonnet
            }
        );
    }
}