// assistant_settings.rs

  1use std::fmt;
  2
  3pub use anthropic::Model as AnthropicModel;
  4use gpui::Pixels;
  5pub use ollama::Model as OllamaModel;
  6pub use open_ai::Model as OpenAiModel;
  7use schemars::{
  8    schema::{InstanceType, Metadata, Schema, SchemaObject},
  9    JsonSchema,
 10};
 11use serde::{
 12    de::{self, Visitor},
 13    Deserialize, Deserializer, Serialize, Serializer,
 14};
 15use settings::{Settings, SettingsSources};
 16use strum::{EnumIter, IntoEnumIterator};
 17
 18use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
 19
/// A model served through the zed.dev cloud provider.
///
/// Known variants carry a fixed wire id (see `id()`); any id read from the
/// settings file that does not match a known variant is preserved verbatim
/// as `Custom`.
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
    Gpt3Point5Turbo,
    Gpt4,
    Gpt4Turbo,
    #[default]
    Gpt4Omni,
    Claude3Opus,
    Claude3Sonnet,
    Claude3Haiku,
    /// A model id not in the list above, stored as-is.
    Custom(String),
}
 32
 33impl Serialize for CloudModel {
 34    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 35    where
 36        S: Serializer,
 37    {
 38        serializer.serialize_str(self.id())
 39    }
 40}
 41
 42impl<'de> Deserialize<'de> for CloudModel {
 43    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 44    where
 45        D: Deserializer<'de>,
 46    {
 47        struct ZedDotDevModelVisitor;
 48
 49        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
 50            type Value = CloudModel;
 51
 52            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
 53                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
 54            }
 55
 56            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
 57            where
 58                E: de::Error,
 59            {
 60                let model = CloudModel::iter()
 61                    .find(|model| model.id() == value)
 62                    .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
 63                Ok(model)
 64            }
 65        }
 66
 67        deserializer.deserialize_str(ZedDotDevModelVisitor)
 68    }
 69}
 70
 71impl JsonSchema for CloudModel {
 72    fn schema_name() -> String {
 73        "ZedDotDevModel".to_owned()
 74    }
 75
 76    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
 77        let variants = CloudModel::iter()
 78            .filter_map(|model| {
 79                let id = model.id();
 80                if id.is_empty() {
 81                    None
 82                } else {
 83                    Some(id.to_string())
 84                }
 85            })
 86            .collect::<Vec<_>>();
 87        Schema::Object(SchemaObject {
 88            instance_type: Some(InstanceType::String.into()),
 89            enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
 90            metadata: Some(Box::new(Metadata {
 91                title: Some("ZedDotDevModel".to_owned()),
 92                default: Some(CloudModel::default().id().into()),
 93                examples: variants.into_iter().map(Into::into).collect(),
 94                ..Default::default()
 95            })),
 96            ..Default::default()
 97        })
 98    }
 99}
100
101impl CloudModel {
102    pub fn id(&self) -> &str {
103        match self {
104            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
105            Self::Gpt4 => "gpt-4",
106            Self::Gpt4Turbo => "gpt-4-turbo-preview",
107            Self::Gpt4Omni => "gpt-4o",
108            Self::Claude3Opus => "claude-3-opus",
109            Self::Claude3Sonnet => "claude-3-sonnet",
110            Self::Claude3Haiku => "claude-3-haiku",
111            Self::Custom(id) => id,
112        }
113    }
114
115    pub fn display_name(&self) -> &str {
116        match self {
117            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
118            Self::Gpt4 => "GPT 4",
119            Self::Gpt4Turbo => "GPT 4 Turbo",
120            Self::Gpt4Omni => "GPT 4 Omni",
121            Self::Claude3Opus => "Claude 3 Opus",
122            Self::Claude3Sonnet => "Claude 3 Sonnet",
123            Self::Claude3Haiku => "Claude 3 Haiku",
124            Self::Custom(id) => id.as_str(),
125        }
126    }
127
128    pub fn max_token_count(&self) -> usize {
129        match self {
130            Self::Gpt3Point5Turbo => 2048,
131            Self::Gpt4 => 4096,
132            Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
133            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
134            Self::Custom(_) => 4096, // TODO: Make this configurable
135        }
136    }
137
138    pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
139        match self {
140            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
141                preprocess_anthropic_request(request)
142            }
143            _ => {}
144        }
145    }
146}
147
/// Which edge of the workspace the assistant panel is docked to.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    /// Docked to the left side.
    Left,
    /// Docked to the right side (the default).
    #[default]
    Right,
    /// Docked along the bottom.
    Bottom,
}
156
/// The resolved (post-merge) assistant provider configuration.
///
/// Unlike `AssistantProviderContent`, fields here are concrete: defaults
/// (e.g. the provider API URL) have already been filled in by
/// `AssistantSettings::load`.
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    /// Zed's hosted zed.dev service.
    ZedDotDev {
        model: CloudModel,
    },
    /// Direct access to the OpenAI API (or a compatible endpoint).
    OpenAi {
        model: OpenAiModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// Direct access to the Anthropic API.
    Anthropic {
        model: AnthropicModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// A local or remote Ollama server.
    Ollama {
        model: OllamaModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
178
179impl Default for AssistantProvider {
180    fn default() -> Self {
181        Self::OpenAi {
182            model: OpenAiModel::default(),
183            api_url: open_ai::OPEN_AI_API_URL.into(),
184            low_speed_timeout_in_seconds: None,
185        }
186    }
187}
188
/// The on-disk (settings file) representation of a provider choice.
///
/// Every field is optional so users can override only what they need;
/// missing values are filled with defaults in `AssistantSettings::load`.
/// The serde `name` tag selects the variant (e.g. `"name": "openai"`).
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    #[serde(rename = "zed.dev")]
    ZedDotDev { default_model: Option<CloudModel> },
    #[serde(rename = "openai")]
    OpenAi {
        default_model: Option<OpenAiModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "anthropic")]
    Anthropic {
        default_model: Option<AnthropicModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "ollama")]
    Ollama {
        default_model: Option<OllamaModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
213
/// Fully-resolved assistant settings, produced by `Settings::load` after
/// merging defaults with any user customizations.
#[derive(Debug, Default)]
pub struct AssistantSettings {
    /// Whether the assistant feature is enabled at all.
    pub enabled: bool,
    /// Whether to show the assistant button in the status bar.
    pub button: bool,
    /// Which edge of the workspace the panel docks to.
    pub dock: AssistantDockPosition,
    /// Panel width when docked left or right.
    pub default_width: Pixels,
    /// Panel height when docked at the bottom.
    pub default_height: Pixels,
    /// The resolved provider/model configuration.
    pub provider: AssistantProvider,
}
223
/// Assistant panel settings
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    /// The current, explicitly versioned settings format.
    Versioned(VersionedAssistantSettingsContent),
    /// The pre-versioning format, kept for backward compatibility.
    Legacy(LegacyAssistantSettingsContent),
}
231
impl JsonSchema for AssistantSettingsContent {
    // The schema is delegated entirely to the versioned form so that editors
    // only suggest the current settings shape, even though the untagged enum
    // still accepts the legacy format at deserialization time.
    fn schema_name() -> String {
        VersionedAssistantSettingsContent::schema_name()
    }

    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
        VersionedAssistantSettingsContent::json_schema(gen)
    }

    fn is_referenceable() -> bool {
        VersionedAssistantSettingsContent::is_referenceable()
    }
}
245
246impl Default for AssistantSettingsContent {
247    fn default() -> Self {
248        Self::Versioned(VersionedAssistantSettingsContent::default())
249    }
250}
251
252impl AssistantSettingsContent {
253    fn upgrade(&self) -> AssistantSettingsContentV1 {
254        match self {
255            AssistantSettingsContent::Versioned(settings) => match settings {
256                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
257            },
258            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
259                enabled: None,
260                button: settings.button,
261                dock: settings.dock,
262                default_width: settings.default_width,
263                default_height: settings.default_height,
264                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
265                    Some(AssistantProviderContent::OpenAi {
266                        default_model: settings.default_open_ai_model.clone(),
267                        api_url: Some(open_ai_api_url.clone()),
268                        low_speed_timeout_in_seconds: None,
269                    })
270                } else {
271                    settings.default_open_ai_model.clone().map(|open_ai_model| {
272                        AssistantProviderContent::OpenAi {
273                            default_model: Some(open_ai_model),
274                            api_url: None,
275                            low_speed_timeout_in_seconds: None,
276                        }
277                    })
278                },
279            },
280        }
281    }
282
283    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
284        match self {
285            AssistantSettingsContent::Versioned(settings) => match settings {
286                VersionedAssistantSettingsContent::V1(settings) => {
287                    settings.dock = Some(dock);
288                }
289            },
290            AssistantSettingsContent::Legacy(settings) => {
291                settings.dock = Some(dock);
292            }
293        }
294    }
295
296    pub fn set_model(&mut self, new_model: LanguageModel) {
297        match self {
298            AssistantSettingsContent::Versioned(settings) => match settings {
299                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
300                    Some(AssistantProviderContent::ZedDotDev {
301                        default_model: model,
302                    }) => {
303                        if let LanguageModel::Cloud(new_model) = new_model {
304                            *model = Some(new_model);
305                        }
306                    }
307                    Some(AssistantProviderContent::OpenAi {
308                        default_model: model,
309                        ..
310                    }) => {
311                        if let LanguageModel::OpenAi(new_model) = new_model {
312                            *model = Some(new_model);
313                        }
314                    }
315                    Some(AssistantProviderContent::Anthropic {
316                        default_model: model,
317                        ..
318                    }) => {
319                        if let LanguageModel::Anthropic(new_model) = new_model {
320                            *model = Some(new_model);
321                        }
322                    }
323                    provider => match new_model {
324                        LanguageModel::Cloud(model) => {
325                            *provider = Some(AssistantProviderContent::ZedDotDev {
326                                default_model: Some(model),
327                            })
328                        }
329                        LanguageModel::OpenAi(model) => {
330                            *provider = Some(AssistantProviderContent::OpenAi {
331                                default_model: Some(model),
332                                api_url: None,
333                                low_speed_timeout_in_seconds: None,
334                            })
335                        }
336                        LanguageModel::Anthropic(model) => {
337                            *provider = Some(AssistantProviderContent::Anthropic {
338                                default_model: Some(model),
339                                api_url: None,
340                                low_speed_timeout_in_seconds: None,
341                            })
342                        }
343                        LanguageModel::Ollama(model) => {
344                            *provider = Some(AssistantProviderContent::Ollama {
345                                default_model: Some(model),
346                                api_url: None,
347                                low_speed_timeout_in_seconds: None,
348                            })
349                        }
350                    },
351                },
352            },
353            AssistantSettingsContent::Legacy(settings) => {
354                if let LanguageModel::OpenAi(model) = new_model {
355                    settings.default_open_ai_model = Some(model);
356                }
357            }
358        }
359    }
360}
361
/// Tagged union over all versioned settings formats, selected by the
/// `"version"` field in the settings JSON.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
368
369impl Default for VersionedAssistantSettingsContent {
370    fn default() -> Self {
371        Self::V1(AssistantSettingsContentV1 {
372            enabled: None,
373            button: None,
374            dock: None,
375            default_width: None,
376            default_height: None,
377            provider: None,
378        })
379    }
380}
381
/// Version 1 of the assistant settings file format.
///
/// Every field is optional; unset fields fall back to the defaults applied
/// in `AssistantSettings::load`.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can either be the internal `zed.dev` service or an external `openai` service,
    /// each with their respective default models and configurations.
    provider: Option<AssistantProviderContent>,
}
410
/// The pre-versioning assistant settings format, which only supported
/// OpenAI. Retained so existing user settings keep deserializing; see
/// `AssistantSettingsContent::upgrade` for the migration to V1.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
438
impl Settings for AssistantSettings {
    // Settings live under the `"assistant"` key of the settings JSON.
    const KEY: Option<&'static str> = Some("assistant");

    type FileContent = AssistantSettingsContent;

    /// Builds the resolved `AssistantSettings` by folding defaults and user
    /// customizations (in precedence order) over `AssistantSettings::default()`.
    fn load(
        sources: SettingsSources<Self::FileContent>,
        _: &mut gpui::AppContext,
    ) -> anyhow::Result<Self> {
        let mut settings = AssistantSettings::default();

        for value in sources.defaults_and_customizations() {
            // Normalize legacy settings to the V1 shape before merging.
            let value = value.upgrade();
            merge(&mut settings.enabled, value.enabled);
            merge(&mut settings.button, value.button);
            merge(&mut settings.dock, value.dock);
            merge(
                &mut settings.default_width,
                value.default_width.map(Into::into),
            );
            merge(
                &mut settings.default_height,
                value.default_height.map(Into::into),
            );
            if let Some(provider) = value.provider.clone() {
                // When the incoming provider matches the current one, merge
                // field-by-field (so a later source can override just the
                // model or just the URL). When it differs, replace the whole
                // provider, filling unset fields with provider defaults.
                match (&mut settings.provider, provider) {
                    (
                        AssistantProvider::ZedDotDev { model },
                        AssistantProviderContent::ZedDotDev {
                            default_model: model_override,
                        },
                    ) => {
                        merge(model, model_override);
                    }
                    (
                        AssistantProvider::OpenAi {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::OpenAi {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        // The timeout is `Option` on both sides, so `merge`
                        // doesn't apply: only overwrite when an override is set.
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (
                        AssistantProvider::Ollama {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::Ollama {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (
                        AssistantProvider::Anthropic {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::Anthropic {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    // Provider kind changed: build a fresh provider, using
                    // each provider's stock API URL and default model for any
                    // fields the settings file left unset.
                    (provider, provider_override) => {
                        *provider = match provider_override {
                            AssistantProviderContent::ZedDotDev {
                                default_model: model,
                            } => AssistantProvider::ZedDotDev {
                                model: model.unwrap_or_default(),
                            },
                            AssistantProviderContent::OpenAi {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::OpenAi {
                                model: model.unwrap_or_default(),
                                api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                            AssistantProviderContent::Anthropic {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::Anthropic {
                                model: model.unwrap_or_default(),
                                api_url: api_url
                                    .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                            AssistantProviderContent::Ollama {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::Ollama {
                                model: model.unwrap_or_default(),
                                api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                        };
                    }
                }
            }
        }

        Ok(settings)
    }
}
580
/// Overwrites `target` with the inner value of `value` when it is `Some`;
/// a `None` leaves `target` untouched. Used to layer optional settings
/// overrides on top of defaults.
fn merge<T>(target: &mut T, value: Option<T>) {
    match value {
        Some(value) => *target = value,
        None => {}
    }
}
586
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    // Exercises the settings pipeline end-to-end: defaults, the legacy
    // (unversioned) format, and the versioned format with a custom model.
    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to gpt-4-turbo.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // Ensure backward-compatibility.
        // A legacy `openai_api_url` key alone should select OpenAI with the
        // given URL and the default model.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
            }
        );
        // A legacy `default_open_ai_model` key alone should select OpenAI
        // with that model and the stock API URL.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );
    }
}