assistant_settings.rs

  1use std::fmt;
  2
  3pub use anthropic::Model as AnthropicModel;
  4use gpui::Pixels;
  5pub use ollama::Model as OllamaModel;
  6pub use open_ai::Model as OpenAiModel;
  7use schemars::{
  8    schema::{InstanceType, Metadata, Schema, SchemaObject},
  9    JsonSchema,
 10};
 11use serde::{
 12    de::{self, Visitor},
 13    Deserialize, Deserializer, Serialize, Serializer,
 14};
 15use settings::{Settings, SettingsSources};
 16use strum::{EnumIter, IntoEnumIterator};
 17
 18use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
 19
 20#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
 21pub enum CloudModel {
 22    Gpt3Point5Turbo,
 23    Gpt4,
 24    Gpt4Turbo,
 25    #[default]
 26    Gpt4Omni,
 27    Claude3_5Sonnet,
 28    Claude3Opus,
 29    Claude3Sonnet,
 30    Claude3Haiku,
 31    Custom(String),
 32}
 33
 34impl Serialize for CloudModel {
 35    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 36    where
 37        S: Serializer,
 38    {
 39        serializer.serialize_str(self.id())
 40    }
 41}
 42
 43impl<'de> Deserialize<'de> for CloudModel {
 44    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 45    where
 46        D: Deserializer<'de>,
 47    {
 48        struct ZedDotDevModelVisitor;
 49
 50        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
 51            type Value = CloudModel;
 52
 53            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
 54                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
 55            }
 56
 57            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
 58            where
 59                E: de::Error,
 60            {
 61                let model = CloudModel::iter()
 62                    .find(|model| model.id() == value)
 63                    .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
 64                Ok(model)
 65            }
 66        }
 67
 68        deserializer.deserialize_str(ZedDotDevModelVisitor)
 69    }
 70}
 71
 72impl JsonSchema for CloudModel {
 73    fn schema_name() -> String {
 74        "ZedDotDevModel".to_owned()
 75    }
 76
 77    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
 78        let variants = CloudModel::iter()
 79            .filter_map(|model| {
 80                let id = model.id();
 81                if id.is_empty() {
 82                    None
 83                } else {
 84                    Some(id.to_string())
 85                }
 86            })
 87            .collect::<Vec<_>>();
 88        Schema::Object(SchemaObject {
 89            instance_type: Some(InstanceType::String.into()),
 90            enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
 91            metadata: Some(Box::new(Metadata {
 92                title: Some("ZedDotDevModel".to_owned()),
 93                default: Some(CloudModel::default().id().into()),
 94                examples: variants.into_iter().map(Into::into).collect(),
 95                ..Default::default()
 96            })),
 97            ..Default::default()
 98        })
 99    }
100}
101
102impl CloudModel {
103    pub fn id(&self) -> &str {
104        match self {
105            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
106            Self::Gpt4 => "gpt-4",
107            Self::Gpt4Turbo => "gpt-4-turbo-preview",
108            Self::Gpt4Omni => "gpt-4o",
109            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
110            Self::Claude3Opus => "claude-3-opus",
111            Self::Claude3Sonnet => "claude-3-sonnet",
112            Self::Claude3Haiku => "claude-3-haiku",
113            Self::Custom(id) => id,
114        }
115    }
116
117    pub fn display_name(&self) -> &str {
118        match self {
119            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
120            Self::Gpt4 => "GPT 4",
121            Self::Gpt4Turbo => "GPT 4 Turbo",
122            Self::Gpt4Omni => "GPT 4 Omni",
123            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
124            Self::Claude3Opus => "Claude 3 Opus",
125            Self::Claude3Sonnet => "Claude 3 Sonnet",
126            Self::Claude3Haiku => "Claude 3 Haiku",
127            Self::Custom(id) => id.as_str(),
128        }
129    }
130
131    pub fn max_token_count(&self) -> usize {
132        match self {
133            Self::Gpt3Point5Turbo => 2048,
134            Self::Gpt4 => 4096,
135            Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
136            Self::Claude3_5Sonnet
137            | Self::Claude3Opus
138            | Self::Claude3Sonnet
139            | Self::Claude3Haiku => 200000,
140            Self::Custom(_) => 4096, // TODO: Make this configurable
141        }
142    }
143
144    pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
145        match self {
146            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
147                preprocess_anthropic_request(request)
148            }
149            _ => {}
150        }
151    }
152}
153
154#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
155#[serde(rename_all = "snake_case")]
156pub enum AssistantDockPosition {
157    Left,
158    #[default]
159    Right,
160    Bottom,
161}
162
163#[derive(Debug, PartialEq)]
164pub enum AssistantProvider {
165    ZedDotDev {
166        model: CloudModel,
167    },
168    OpenAi {
169        model: OpenAiModel,
170        api_url: String,
171        low_speed_timeout_in_seconds: Option<u64>,
172    },
173    Anthropic {
174        model: AnthropicModel,
175        api_url: String,
176        low_speed_timeout_in_seconds: Option<u64>,
177    },
178    Ollama {
179        model: OllamaModel,
180        api_url: String,
181        low_speed_timeout_in_seconds: Option<u64>,
182    },
183}
184
185impl Default for AssistantProvider {
186    fn default() -> Self {
187        Self::OpenAi {
188            model: OpenAiModel::default(),
189            api_url: open_ai::OPEN_AI_API_URL.into(),
190            low_speed_timeout_in_seconds: None,
191        }
192    }
193}
194
195#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
196#[serde(tag = "name", rename_all = "snake_case")]
197pub enum AssistantProviderContent {
198    #[serde(rename = "zed.dev")]
199    ZedDotDev { default_model: Option<CloudModel> },
200    #[serde(rename = "openai")]
201    OpenAi {
202        default_model: Option<OpenAiModel>,
203        api_url: Option<String>,
204        low_speed_timeout_in_seconds: Option<u64>,
205    },
206    #[serde(rename = "anthropic")]
207    Anthropic {
208        default_model: Option<AnthropicModel>,
209        api_url: Option<String>,
210        low_speed_timeout_in_seconds: Option<u64>,
211    },
212    #[serde(rename = "ollama")]
213    Ollama {
214        default_model: Option<OllamaModel>,
215        api_url: Option<String>,
216        low_speed_timeout_in_seconds: Option<u64>,
217    },
218}
219
220#[derive(Debug, Default)]
221pub struct AssistantSettings {
222    pub enabled: bool,
223    pub button: bool,
224    pub dock: AssistantDockPosition,
225    pub default_width: Pixels,
226    pub default_height: Pixels,
227    pub provider: AssistantProvider,
228}
229
230/// Assistant panel settings
231#[derive(Clone, Serialize, Deserialize, Debug)]
232#[serde(untagged)]
233pub enum AssistantSettingsContent {
234    Versioned(VersionedAssistantSettingsContent),
235    Legacy(LegacyAssistantSettingsContent),
236}
237
238impl JsonSchema for AssistantSettingsContent {
239    fn schema_name() -> String {
240        VersionedAssistantSettingsContent::schema_name()
241    }
242
243    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
244        VersionedAssistantSettingsContent::json_schema(gen)
245    }
246
247    fn is_referenceable() -> bool {
248        VersionedAssistantSettingsContent::is_referenceable()
249    }
250}
251
252impl Default for AssistantSettingsContent {
253    fn default() -> Self {
254        Self::Versioned(VersionedAssistantSettingsContent::default())
255    }
256}
257
258impl AssistantSettingsContent {
259    fn upgrade(&self) -> AssistantSettingsContentV1 {
260        match self {
261            AssistantSettingsContent::Versioned(settings) => match settings {
262                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
263            },
264            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
265                enabled: None,
266                button: settings.button,
267                dock: settings.dock,
268                default_width: settings.default_width,
269                default_height: settings.default_height,
270                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
271                    Some(AssistantProviderContent::OpenAi {
272                        default_model: settings.default_open_ai_model.clone(),
273                        api_url: Some(open_ai_api_url.clone()),
274                        low_speed_timeout_in_seconds: None,
275                    })
276                } else {
277                    settings.default_open_ai_model.clone().map(|open_ai_model| {
278                        AssistantProviderContent::OpenAi {
279                            default_model: Some(open_ai_model),
280                            api_url: None,
281                            low_speed_timeout_in_seconds: None,
282                        }
283                    })
284                },
285            },
286        }
287    }
288
289    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
290        match self {
291            AssistantSettingsContent::Versioned(settings) => match settings {
292                VersionedAssistantSettingsContent::V1(settings) => {
293                    settings.dock = Some(dock);
294                }
295            },
296            AssistantSettingsContent::Legacy(settings) => {
297                settings.dock = Some(dock);
298            }
299        }
300    }
301
302    pub fn set_model(&mut self, new_model: LanguageModel) {
303        match self {
304            AssistantSettingsContent::Versioned(settings) => match settings {
305                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
306                    Some(AssistantProviderContent::ZedDotDev {
307                        default_model: model,
308                    }) => {
309                        if let LanguageModel::Cloud(new_model) = new_model {
310                            *model = Some(new_model);
311                        }
312                    }
313                    Some(AssistantProviderContent::OpenAi {
314                        default_model: model,
315                        ..
316                    }) => {
317                        if let LanguageModel::OpenAi(new_model) = new_model {
318                            *model = Some(new_model);
319                        }
320                    }
321                    Some(AssistantProviderContent::Anthropic {
322                        default_model: model,
323                        ..
324                    }) => {
325                        if let LanguageModel::Anthropic(new_model) = new_model {
326                            *model = Some(new_model);
327                        }
328                    }
329                    provider => match new_model {
330                        LanguageModel::Cloud(model) => {
331                            *provider = Some(AssistantProviderContent::ZedDotDev {
332                                default_model: Some(model),
333                            })
334                        }
335                        LanguageModel::OpenAi(model) => {
336                            *provider = Some(AssistantProviderContent::OpenAi {
337                                default_model: Some(model),
338                                api_url: None,
339                                low_speed_timeout_in_seconds: None,
340                            })
341                        }
342                        LanguageModel::Anthropic(model) => {
343                            *provider = Some(AssistantProviderContent::Anthropic {
344                                default_model: Some(model),
345                                api_url: None,
346                                low_speed_timeout_in_seconds: None,
347                            })
348                        }
349                        LanguageModel::Ollama(model) => {
350                            *provider = Some(AssistantProviderContent::Ollama {
351                                default_model: Some(model),
352                                api_url: None,
353                                low_speed_timeout_in_seconds: None,
354                            })
355                        }
356                    },
357                },
358            },
359            AssistantSettingsContent::Legacy(settings) => {
360                if let LanguageModel::OpenAi(model) = new_model {
361                    settings.default_open_ai_model = Some(model);
362                }
363            }
364        }
365    }
366}
367
368#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
369#[serde(tag = "version")]
370pub enum VersionedAssistantSettingsContent {
371    #[serde(rename = "1")]
372    V1(AssistantSettingsContentV1),
373}
374
375impl Default for VersionedAssistantSettingsContent {
376    fn default() -> Self {
377        Self::V1(AssistantSettingsContentV1 {
378            enabled: None,
379            button: None,
380            dock: None,
381            default_width: None,
382            default_height: None,
383            provider: None,
384        })
385    }
386}
387
388#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
389pub struct AssistantSettingsContentV1 {
390    /// Whether the Assistant is enabled.
391    ///
392    /// Default: true
393    enabled: Option<bool>,
394    /// Whether to show the assistant panel button in the status bar.
395    ///
396    /// Default: true
397    button: Option<bool>,
398    /// Where to dock the assistant.
399    ///
400    /// Default: right
401    dock: Option<AssistantDockPosition>,
402    /// Default width in pixels when the assistant is docked to the left or right.
403    ///
404    /// Default: 640
405    default_width: Option<f32>,
406    /// Default height in pixels when the assistant is docked to the bottom.
407    ///
408    /// Default: 320
409    default_height: Option<f32>,
410    /// The provider of the assistant service.
411    ///
412    /// This can either be the internal `zed.dev` service or an external `openai` service,
413    /// each with their respective default models and configurations.
414    provider: Option<AssistantProviderContent>,
415}
416
417#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
418pub struct LegacyAssistantSettingsContent {
419    /// Whether to show the assistant panel button in the status bar.
420    ///
421    /// Default: true
422    pub button: Option<bool>,
423    /// Where to dock the assistant.
424    ///
425    /// Default: right
426    pub dock: Option<AssistantDockPosition>,
427    /// Default width in pixels when the assistant is docked to the left or right.
428    ///
429    /// Default: 640
430    pub default_width: Option<f32>,
431    /// Default height in pixels when the assistant is docked to the bottom.
432    ///
433    /// Default: 320
434    pub default_height: Option<f32>,
435    /// The default OpenAI model to use when creating new contexts.
436    ///
437    /// Default: gpt-4-1106-preview
438    pub default_open_ai_model: Option<OpenAiModel>,
439    /// OpenAI API base URL to use when creating new contexts.
440    ///
441    /// Default: https://api.openai.com/v1
442    pub openai_api_url: Option<String>,
443}
444
445impl Settings for AssistantSettings {
446    const KEY: Option<&'static str> = Some("assistant");
447
448    type FileContent = AssistantSettingsContent;
449
450    fn load(
451        sources: SettingsSources<Self::FileContent>,
452        _: &mut gpui::AppContext,
453    ) -> anyhow::Result<Self> {
454        let mut settings = AssistantSettings::default();
455
456        for value in sources.defaults_and_customizations() {
457            let value = value.upgrade();
458            merge(&mut settings.enabled, value.enabled);
459            merge(&mut settings.button, value.button);
460            merge(&mut settings.dock, value.dock);
461            merge(
462                &mut settings.default_width,
463                value.default_width.map(Into::into),
464            );
465            merge(
466                &mut settings.default_height,
467                value.default_height.map(Into::into),
468            );
469            if let Some(provider) = value.provider.clone() {
470                match (&mut settings.provider, provider) {
471                    (
472                        AssistantProvider::ZedDotDev { model },
473                        AssistantProviderContent::ZedDotDev {
474                            default_model: model_override,
475                        },
476                    ) => {
477                        merge(model, model_override);
478                    }
479                    (
480                        AssistantProvider::OpenAi {
481                            model,
482                            api_url,
483                            low_speed_timeout_in_seconds,
484                        },
485                        AssistantProviderContent::OpenAi {
486                            default_model: model_override,
487                            api_url: api_url_override,
488                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
489                        },
490                    ) => {
491                        merge(model, model_override);
492                        merge(api_url, api_url_override);
493                        if let Some(low_speed_timeout_in_seconds_override) =
494                            low_speed_timeout_in_seconds_override
495                        {
496                            *low_speed_timeout_in_seconds =
497                                Some(low_speed_timeout_in_seconds_override);
498                        }
499                    }
500                    (
501                        AssistantProvider::Ollama {
502                            model,
503                            api_url,
504                            low_speed_timeout_in_seconds,
505                        },
506                        AssistantProviderContent::Ollama {
507                            default_model: model_override,
508                            api_url: api_url_override,
509                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
510                        },
511                    ) => {
512                        merge(model, model_override);
513                        merge(api_url, api_url_override);
514                        if let Some(low_speed_timeout_in_seconds_override) =
515                            low_speed_timeout_in_seconds_override
516                        {
517                            *low_speed_timeout_in_seconds =
518                                Some(low_speed_timeout_in_seconds_override);
519                        }
520                    }
521                    (
522                        AssistantProvider::Anthropic {
523                            model,
524                            api_url,
525                            low_speed_timeout_in_seconds,
526                        },
527                        AssistantProviderContent::Anthropic {
528                            default_model: model_override,
529                            api_url: api_url_override,
530                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
531                        },
532                    ) => {
533                        merge(model, model_override);
534                        merge(api_url, api_url_override);
535                        if let Some(low_speed_timeout_in_seconds_override) =
536                            low_speed_timeout_in_seconds_override
537                        {
538                            *low_speed_timeout_in_seconds =
539                                Some(low_speed_timeout_in_seconds_override);
540                        }
541                    }
542                    (provider, provider_override) => {
543                        *provider = match provider_override {
544                            AssistantProviderContent::ZedDotDev {
545                                default_model: model,
546                            } => AssistantProvider::ZedDotDev {
547                                model: model.unwrap_or_default(),
548                            },
549                            AssistantProviderContent::OpenAi {
550                                default_model: model,
551                                api_url,
552                                low_speed_timeout_in_seconds,
553                            } => AssistantProvider::OpenAi {
554                                model: model.unwrap_or_default(),
555                                api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
556                                low_speed_timeout_in_seconds,
557                            },
558                            AssistantProviderContent::Anthropic {
559                                default_model: model,
560                                api_url,
561                                low_speed_timeout_in_seconds,
562                            } => AssistantProvider::Anthropic {
563                                model: model.unwrap_or_default(),
564                                api_url: api_url
565                                    .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
566                                low_speed_timeout_in_seconds,
567                            },
568                            AssistantProviderContent::Ollama {
569                                default_model: model,
570                                api_url,
571                                low_speed_timeout_in_seconds,
572                            } => AssistantProvider::Ollama {
573                                model: model.unwrap_or_default(),
574                                api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
575                                low_speed_timeout_in_seconds,
576                            },
577                        };
578                    }
579                }
580            }
581        }
582
583        Ok(settings)
584    }
585}
586
/// Overwrites `target` with the contents of `value` when it is `Some`.
/// A `None` leaves `target` as it was, so optional settings fields can be
/// layered over already-resolved values.
fn merge<T>(target: &mut T, value: Option<T>) {
    let Some(replacement) = value else {
        return;
    };
    *target = replacement;
}
592
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    /// Checks provider resolution for default settings, for the legacy
    /// unversioned format, and for the versioned format with a custom
    /// zed.dev model.
    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // With no user settings, the provider defaults to OpenAI's GPT-4 Omni
        // (`OpenAiModel::FourOmni`) at the built-in OpenAI API URL.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // Ensure backward-compatibility: legacy settings that only set
        // `openai_api_url` keep the default model.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
            }
        );
        // Legacy `default_open_ai_model` selects the model; the URL falls back
        // to the default because the previous user settings were replaced.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );
    }
}