assistant_settings.rs

use std::fmt;

pub use anthropic::Model as AnthropicModel;
use gpui::Pixels;
pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;
use schemars::{
    schema::{InstanceType, Metadata, Schema, SchemaObject},
    JsonSchema,
};
use serde::{
    de::{self, Visitor},
    Deserialize, Deserializer, Serialize, Serializer,
};
use settings::{Settings, SettingsSources};
use strum::{EnumIter, IntoEnumIterator};

use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};

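/// Models available through Zed's hosted `zed.dev` provider. Unrecognized
/// model IDs deserialize into the `Custom` variant.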
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
    Gpt3Point5Turbo,
    Gpt4,
    Gpt4Turbo,
    #[default]
    Gpt4Omni,
    Claude3_5Sonnet,
    Claude3Opus,
    Claude3Sonnet,
    Claude3Haiku,
    Custom(String),
}

impl Serialize for CloudModel {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(self.id())
    }
}

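// Accepts any string: known model IDs map to their variant, anything else
// becomes `CloudModel::Custom`.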
impl<'de> Deserialize<'de> for CloudModel {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct ZedDotDevModelVisitor;

        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
            type Value = CloudModel;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
            }

            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                let model = CloudModel::iter()
                    .find(|model| model.id() == value)
                    .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
                Ok(model)
            }
        }

        deserializer.deserialize_str(ZedDotDevModelVisitor)
    }
}

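// The settings schema keeps the `ZedDotDevModel` name and enumerates the
// known model IDs.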
impl JsonSchema for CloudModel {
    fn schema_name() -> String {
        "ZedDotDevModel".to_owned()
    }

    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
        let variants = CloudModel::iter()
            .filter_map(|model| {
                let id = model.id();
                if id.is_empty() {
                    None
                } else {
                    Some(id.to_string())
                }
            })
            .collect::<Vec<_>>();
        Schema::Object(SchemaObject {
            instance_type: Some(InstanceType::String.into()),
            enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
            metadata: Some(Box::new(Metadata {
                title: Some("ZedDotDevModel".to_owned()),
                default: Some(CloudModel::default().id().into()),
                examples: variants.into_iter().map(Into::into).collect(),
                ..Default::default()
            })),
            ..Default::default()
        })
    }
}

impl CloudModel {
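    /// Returns the stable string identifier for this model, as used in settings.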
    pub fn id(&self) -> &str {
        match self {
            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
            Self::Gpt4 => "gpt-4",
            Self::Gpt4Turbo => "gpt-4-turbo-preview",
            Self::Gpt4Omni => "gpt-4o",
            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
            Self::Claude3Opus => "claude-3-opus",
            Self::Claude3Sonnet => "claude-3-sonnet",
            Self::Claude3Haiku => "claude-3-haiku",
            Self::Custom(id) => id,
        }
    }

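    /// Returns the human-readable name used for display.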
    pub fn display_name(&self) -> &str {
        match self {
            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
            Self::Gpt4 => "GPT 4",
            Self::Gpt4Turbo => "GPT 4 Turbo",
            Self::Gpt4Omni => "GPT 4 Omni",
            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
            Self::Claude3Opus => "Claude 3 Opus",
            Self::Claude3Sonnet => "Claude 3 Sonnet",
            Self::Claude3Haiku => "Claude 3 Haiku",
            Self::Custom(id) => id.as_str(),
        }
    }

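    /// Returns the maximum number of tokens this model accepts in a single
    /// request.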
    pub fn max_token_count(&self) -> usize {
        match self {
            Self::Gpt3Point5Turbo => 2048,
            Self::Gpt4 => 4096,
            Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
            Self::Claude3_5Sonnet
            | Self::Claude3Opus
            | Self::Claude3Sonnet
            | Self::Claude3Haiku => 200000,
            Self::Custom(_) => 4096, // TODO: Make this configurable
        }
    }

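    /// Applies Anthropic-specific preprocessing to the request when targeting
    /// a Claude model.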
    pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
        match self {
            Self::Claude3_5Sonnet
            | Self::Claude3Opus
            | Self::Claude3Sonnet
            | Self::Claude3Haiku => {
                preprocess_anthropic_request(request)
            }
            _ => {}
        }
    }
}

#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}

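/// The resolved provider configuration used at runtime, with defaults already
/// applied.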
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    ZedDotDev {
        model: CloudModel,
    },
    OpenAi {
        model: OpenAiModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
        available_models: Vec<OpenAiModel>,
    },
    Anthropic {
        model: AnthropicModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    Ollama {
        model: OllamaModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}

impl Default for AssistantProvider {
    fn default() -> Self {
        Self::OpenAi {
            model: OpenAiModel::default(),
            api_url: open_ai::OPEN_AI_API_URL.into(),
            low_speed_timeout_in_seconds: None,
            available_models: Default::default(),
        }
    }
}

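/// Provider configuration as written in the settings file. Every field is
/// optional and is merged over the built-in defaults when settings are loaded.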
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    #[serde(rename = "zed.dev")]
    ZedDotDev { default_model: Option<CloudModel> },
    #[serde(rename = "openai")]
    OpenAi {
        default_model: Option<OpenAiModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
        available_models: Option<Vec<OpenAiModel>>,
    },
    #[serde(rename = "anthropic")]
    Anthropic {
        default_model: Option<AnthropicModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "ollama")]
    Ollama {
        default_model: Option<OllamaModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}

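/// The resolved assistant panel settings, produced by merging defaults and
/// user customizations.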
#[derive(Debug, Default)]
pub struct AssistantSettings {
    pub enabled: bool,
    pub button: bool,
    pub dock: AssistantDockPosition,
    pub default_width: Pixels,
    pub default_height: Pixels,
    pub provider: AssistantProvider,
}

/// Assistant panel settings as they appear in the settings file, in either the
/// current versioned format or the legacy format.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}

impl JsonSchema for AssistantSettingsContent {
    fn schema_name() -> String {
        VersionedAssistantSettingsContent::schema_name()
    }

    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
        VersionedAssistantSettingsContent::json_schema(gen)
    }

    fn is_referenceable() -> bool {
        VersionedAssistantSettingsContent::is_referenceable()
    }
}

impl Default for AssistantSettingsContent {
    fn default() -> Self {
        Self::Versioned(VersionedAssistantSettingsContent::default())
    }
}

impl AssistantSettingsContent {
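    /// Upgrades either settings format to the latest versioned form (V1),
    /// mapping the legacy OpenAI fields onto an OpenAI provider entry.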
    fn upgrade(&self) -> AssistantSettingsContentV1 {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
            },
            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
                enabled: None,
                button: settings.button,
                dock: settings.dock,
                default_width: settings.default_width,
                default_height: settings.default_height,
                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
                    Some(AssistantProviderContent::OpenAi {
                        default_model: settings.default_open_ai_model.clone(),
                        api_url: Some(open_ai_api_url.clone()),
                        low_speed_timeout_in_seconds: None,
                        available_models: Some(Default::default()),
                    })
                } else {
                    settings.default_open_ai_model.clone().map(|open_ai_model| {
                        AssistantProviderContent::OpenAi {
                            default_model: Some(open_ai_model),
                            api_url: None,
                            low_speed_timeout_in_seconds: None,
                            available_models: Some(Default::default()),
                        }
                    })
                },
            },
        }
    }

    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => {
                    settings.dock = Some(dock);
                }
            },
            AssistantSettingsContent::Legacy(settings) => {
                settings.dock = Some(dock);
            }
        }
    }

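    /// Sets the default model. When the configured provider matches the
    /// model's provider, only its default model is updated; otherwise the
    /// provider entry is replaced.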
    pub fn set_model(&mut self, new_model: LanguageModel) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
                    Some(AssistantProviderContent::ZedDotDev {
                        default_model: model,
                    }) => {
                        if let LanguageModel::Cloud(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::OpenAi {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::OpenAi(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::Anthropic {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::Anthropic(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::Ollama {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::Ollama(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    provider => match new_model {
                        LanguageModel::Cloud(model) => {
                            *provider = Some(AssistantProviderContent::ZedDotDev {
                                default_model: Some(model),
                            })
                        }
                        LanguageModel::OpenAi(model) => {
                            *provider = Some(AssistantProviderContent::OpenAi {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                                available_models: Some(Default::default()),
                            })
                        }
                        LanguageModel::Anthropic(model) => {
                            *provider = Some(AssistantProviderContent::Anthropic {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                        LanguageModel::Ollama(model) => {
                            *provider = Some(AssistantProviderContent::Ollama {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                    },
                },
            },
            AssistantSettingsContent::Legacy(settings) => {
                if let LanguageModel::OpenAi(model) = new_model {
                    settings.default_open_ai_model = Some(model);
                }
            }
        }
    }
}

#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}

impl Default for VersionedAssistantSettingsContent {
    fn default() -> Self {
        Self::V1(AssistantSettingsContentV1 {
            enabled: None,
            button: None,
            dock: None,
            default_width: None,
            default_height: None,
            provider: None,
        })
    }
}

#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can be Zed's hosted `zed.dev` service or an external `openai`,
    /// `anthropic`, or `ollama` service, each with its own default model and
    /// configuration.
    provider: Option<AssistantProviderContent>,
}

#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}

impl Settings for AssistantSettings {
    const KEY: Option<&'static str> = Some("assistant");

    type FileContent = AssistantSettingsContent;

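    /// Builds the effective settings by upgrading each source to the latest
    /// format and merging customizations over the defaults, including
    /// per-provider overrides.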
    fn load(
        sources: SettingsSources<Self::FileContent>,
        _: &mut gpui::AppContext,
    ) -> anyhow::Result<Self> {
        let mut settings = AssistantSettings::default();

        for value in sources.defaults_and_customizations() {
            let value = value.upgrade();
            merge(&mut settings.enabled, value.enabled);
            merge(&mut settings.button, value.button);
            merge(&mut settings.dock, value.dock);
            merge(
                &mut settings.default_width,
                value.default_width.map(Into::into),
            );
            merge(
                &mut settings.default_height,
                value.default_height.map(Into::into),
            );
            if let Some(provider) = value.provider.clone() {
                match (&mut settings.provider, provider) {
                    (
                        AssistantProvider::ZedDotDev { model },
                        AssistantProviderContent::ZedDotDev {
                            default_model: model_override,
                        },
                    ) => {
                        merge(model, model_override);
                    }
                    (
                        AssistantProvider::OpenAi {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                            available_models,
                        },
                        AssistantProviderContent::OpenAi {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                            available_models: available_models_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        merge(available_models, available_models_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (
                        AssistantProvider::Ollama {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::Ollama {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (
                        AssistantProvider::Anthropic {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::Anthropic {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (provider, provider_override) => {
                        *provider = match provider_override {
                            AssistantProviderContent::ZedDotDev {
                                default_model: model,
                            } => AssistantProvider::ZedDotDev {
                                model: model.unwrap_or_default(),
                            },
                            AssistantProviderContent::OpenAi {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                                available_models,
                            } => AssistantProvider::OpenAi {
                                model: model.unwrap_or_default(),
                                api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
                                low_speed_timeout_in_seconds,
                                available_models: available_models.unwrap_or_default(),
                            },
                            AssistantProviderContent::Anthropic {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::Anthropic {
                                model: model.unwrap_or_default(),
                                api_url: api_url
                                    .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                            AssistantProviderContent::Ollama {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::Ollama {
                                model: model.unwrap_or_default(),
                                api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                        };
                    }
                }
            }
        }

        Ok(settings)
    }
}

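/// Overwrites `target` with `value` when an override is present.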
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(value) = value {
        *target = value;
    }
}

#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to the OpenAI provider with gpt-4o.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // Ensure backward-compatibility.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );
    }
}