assistant_settings.rs

  1use std::fmt;
  2
  3use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
  4pub use anthropic::Model as AnthropicModel;
  5use gpui::Pixels;
  6pub use ollama::Model as OllamaModel;
  7pub use open_ai::Model as OpenAiModel;
  8use schemars::{
  9    schema::{InstanceType, Metadata, Schema, SchemaObject},
 10    JsonSchema,
 11};
 12use serde::{
 13    de::{self, Visitor},
 14    Deserialize, Deserializer, Serialize, Serializer,
 15};
 16use settings::{Settings, SettingsSources};
 17use strum::{EnumIter, IntoEnumIterator};
 18
 19#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
 20pub enum CloudModel {
 21    Gpt3Point5Turbo,
 22    Gpt4,
 23    Gpt4Turbo,
 24    #[default]
 25    Gpt4Omni,
 26    Claude3_5Sonnet,
 27    Claude3Opus,
 28    Claude3Sonnet,
 29    Claude3Haiku,
 30    Custom(String),
 31}
 32
 33impl Serialize for CloudModel {
 34    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 35    where
 36        S: Serializer,
 37    {
 38        serializer.serialize_str(self.id())
 39    }
 40}
 41
 42impl<'de> Deserialize<'de> for CloudModel {
 43    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 44    where
 45        D: Deserializer<'de>,
 46    {
 47        struct ZedDotDevModelVisitor;
 48
 49        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
 50            type Value = CloudModel;
 51
 52            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
 53                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
 54            }
 55
 56            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
 57            where
 58                E: de::Error,
 59            {
 60                let model = CloudModel::iter()
 61                    .find(|model| model.id() == value)
 62                    .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
 63                Ok(model)
 64            }
 65        }
 66
 67        deserializer.deserialize_str(ZedDotDevModelVisitor)
 68    }
 69}
 70
 71impl JsonSchema for CloudModel {
 72    fn schema_name() -> String {
 73        "ZedDotDevModel".to_owned()
 74    }
 75
 76    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
 77        let variants = CloudModel::iter()
 78            .filter_map(|model| {
 79                let id = model.id();
 80                if id.is_empty() {
 81                    None
 82                } else {
 83                    Some(id.to_string())
 84                }
 85            })
 86            .collect::<Vec<_>>();
 87        Schema::Object(SchemaObject {
 88            instance_type: Some(InstanceType::String.into()),
 89            enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
 90            metadata: Some(Box::new(Metadata {
 91                title: Some("ZedDotDevModel".to_owned()),
 92                default: Some(CloudModel::default().id().into()),
 93                examples: variants.into_iter().map(Into::into).collect(),
 94                ..Default::default()
 95            })),
 96            ..Default::default()
 97        })
 98    }
 99}
100
101impl CloudModel {
102    pub fn id(&self) -> &str {
103        match self {
104            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
105            Self::Gpt4 => "gpt-4",
106            Self::Gpt4Turbo => "gpt-4-turbo-preview",
107            Self::Gpt4Omni => "gpt-4o",
108            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
109            Self::Claude3Opus => "claude-3-opus",
110            Self::Claude3Sonnet => "claude-3-sonnet",
111            Self::Claude3Haiku => "claude-3-haiku",
112            Self::Custom(id) => id,
113        }
114    }
115
116    pub fn display_name(&self) -> &str {
117        match self {
118            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
119            Self::Gpt4 => "GPT 4",
120            Self::Gpt4Turbo => "GPT 4 Turbo",
121            Self::Gpt4Omni => "GPT 4 Omni",
122            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
123            Self::Claude3Opus => "Claude 3 Opus",
124            Self::Claude3Sonnet => "Claude 3 Sonnet",
125            Self::Claude3Haiku => "Claude 3 Haiku",
126            Self::Custom(id) => id.as_str(),
127        }
128    }
129
130    pub fn max_token_count(&self) -> usize {
131        match self {
132            Self::Gpt3Point5Turbo => 2048,
133            Self::Gpt4 => 4096,
134            Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
135            Self::Claude3_5Sonnet
136            | Self::Claude3Opus
137            | Self::Claude3Sonnet
138            | Self::Claude3Haiku => 200000,
139            Self::Custom(_) => 4096, // TODO: Make this configurable
140        }
141    }
142
143    pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
144        match self {
145            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
146                preprocess_anthropic_request(request)
147            }
148            _ => {}
149        }
150    }
151}
152
153#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
154#[serde(rename_all = "snake_case")]
155pub enum AssistantDockPosition {
156    Left,
157    #[default]
158    Right,
159    Bottom,
160}
161
162#[derive(Debug, PartialEq)]
163pub enum AssistantProvider {
164    ZedDotDev {
165        model: CloudModel,
166    },
167    OpenAi {
168        model: OpenAiModel,
169        api_url: String,
170        low_speed_timeout_in_seconds: Option<u64>,
171        available_models: Vec<OpenAiModel>,
172    },
173    Anthropic {
174        model: AnthropicModel,
175        api_url: String,
176        low_speed_timeout_in_seconds: Option<u64>,
177    },
178    Ollama {
179        model: OllamaModel,
180        api_url: String,
181        low_speed_timeout_in_seconds: Option<u64>,
182    },
183}
184
185impl Default for AssistantProvider {
186    fn default() -> Self {
187        Self::OpenAi {
188            model: OpenAiModel::default(),
189            api_url: open_ai::OPEN_AI_API_URL.into(),
190            low_speed_timeout_in_seconds: None,
191            available_models: Default::default(),
192        }
193    }
194}
195
196#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
197#[serde(tag = "name", rename_all = "snake_case")]
198pub enum AssistantProviderContent {
199    #[serde(rename = "zed.dev")]
200    ZedDotDev { default_model: Option<CloudModel> },
201    #[serde(rename = "openai")]
202    OpenAi {
203        default_model: Option<OpenAiModel>,
204        api_url: Option<String>,
205        low_speed_timeout_in_seconds: Option<u64>,
206        available_models: Option<Vec<OpenAiModel>>,
207    },
208    #[serde(rename = "anthropic")]
209    Anthropic {
210        default_model: Option<AnthropicModel>,
211        api_url: Option<String>,
212        low_speed_timeout_in_seconds: Option<u64>,
213    },
214    #[serde(rename = "ollama")]
215    Ollama {
216        default_model: Option<OllamaModel>,
217        api_url: Option<String>,
218        low_speed_timeout_in_seconds: Option<u64>,
219    },
220}
221
222#[derive(Debug, Default)]
223pub struct AssistantSettings {
224    pub enabled: bool,
225    pub button: bool,
226    pub dock: AssistantDockPosition,
227    pub default_width: Pixels,
228    pub default_height: Pixels,
229    pub provider: AssistantProvider,
230}
231
232/// Assistant panel settings
233#[derive(Clone, Serialize, Deserialize, Debug)]
234#[serde(untagged)]
235pub enum AssistantSettingsContent {
236    Versioned(VersionedAssistantSettingsContent),
237    Legacy(LegacyAssistantSettingsContent),
238}
239
240impl JsonSchema for AssistantSettingsContent {
241    fn schema_name() -> String {
242        VersionedAssistantSettingsContent::schema_name()
243    }
244
245    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
246        VersionedAssistantSettingsContent::json_schema(gen)
247    }
248
249    fn is_referenceable() -> bool {
250        VersionedAssistantSettingsContent::is_referenceable()
251    }
252}
253
254impl Default for AssistantSettingsContent {
255    fn default() -> Self {
256        Self::Versioned(VersionedAssistantSettingsContent::default())
257    }
258}
259
260impl AssistantSettingsContent {
261    fn upgrade(&self) -> AssistantSettingsContentV1 {
262        match self {
263            AssistantSettingsContent::Versioned(settings) => match settings {
264                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
265            },
266            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
267                enabled: None,
268                button: settings.button,
269                dock: settings.dock,
270                default_width: settings.default_width,
271                default_height: settings.default_height,
272                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
273                    Some(AssistantProviderContent::OpenAi {
274                        default_model: settings.default_open_ai_model.clone(),
275                        api_url: Some(open_ai_api_url.clone()),
276                        low_speed_timeout_in_seconds: None,
277                        available_models: Some(Default::default()),
278                    })
279                } else {
280                    settings.default_open_ai_model.clone().map(|open_ai_model| {
281                        AssistantProviderContent::OpenAi {
282                            default_model: Some(open_ai_model),
283                            api_url: None,
284                            low_speed_timeout_in_seconds: None,
285                            available_models: Some(Default::default()),
286                        }
287                    })
288                },
289            },
290        }
291    }
292
293    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
294        match self {
295            AssistantSettingsContent::Versioned(settings) => match settings {
296                VersionedAssistantSettingsContent::V1(settings) => {
297                    settings.dock = Some(dock);
298                }
299            },
300            AssistantSettingsContent::Legacy(settings) => {
301                settings.dock = Some(dock);
302            }
303        }
304    }
305
306    pub fn set_model(&mut self, new_model: LanguageModel) {
307        match self {
308            AssistantSettingsContent::Versioned(settings) => match settings {
309                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
310                    Some(AssistantProviderContent::ZedDotDev {
311                        default_model: model,
312                    }) => {
313                        if let LanguageModel::Cloud(new_model) = new_model {
314                            *model = Some(new_model);
315                        }
316                    }
317                    Some(AssistantProviderContent::OpenAi {
318                        default_model: model,
319                        ..
320                    }) => {
321                        if let LanguageModel::OpenAi(new_model) = new_model {
322                            *model = Some(new_model);
323                        }
324                    }
325                    Some(AssistantProviderContent::Anthropic {
326                        default_model: model,
327                        ..
328                    }) => {
329                        if let LanguageModel::Anthropic(new_model) = new_model {
330                            *model = Some(new_model);
331                        }
332                    }
333                    Some(AssistantProviderContent::Ollama {
334                        default_model: model,
335                        ..
336                    }) => {
337                        if let LanguageModel::Ollama(new_model) = new_model {
338                            *model = Some(new_model);
339                        }
340                    }
341                    provider => match new_model {
342                        LanguageModel::Cloud(model) => {
343                            *provider = Some(AssistantProviderContent::ZedDotDev {
344                                default_model: Some(model),
345                            })
346                        }
347                        LanguageModel::OpenAi(model) => {
348                            *provider = Some(AssistantProviderContent::OpenAi {
349                                default_model: Some(model),
350                                api_url: None,
351                                low_speed_timeout_in_seconds: None,
352                                available_models: Some(Default::default()),
353                            })
354                        }
355                        LanguageModel::Anthropic(model) => {
356                            *provider = Some(AssistantProviderContent::Anthropic {
357                                default_model: Some(model),
358                                api_url: None,
359                                low_speed_timeout_in_seconds: None,
360                            })
361                        }
362                        LanguageModel::Ollama(model) => {
363                            *provider = Some(AssistantProviderContent::Ollama {
364                                default_model: Some(model),
365                                api_url: None,
366                                low_speed_timeout_in_seconds: None,
367                            })
368                        }
369                    },
370                },
371            },
372            AssistantSettingsContent::Legacy(settings) => {
373                if let LanguageModel::OpenAi(model) = new_model {
374                    settings.default_open_ai_model = Some(model);
375                }
376            }
377        }
378    }
379}
380
381#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
382#[serde(tag = "version")]
383pub enum VersionedAssistantSettingsContent {
384    #[serde(rename = "1")]
385    V1(AssistantSettingsContentV1),
386}
387
388impl Default for VersionedAssistantSettingsContent {
389    fn default() -> Self {
390        Self::V1(AssistantSettingsContentV1 {
391            enabled: None,
392            button: None,
393            dock: None,
394            default_width: None,
395            default_height: None,
396            provider: None,
397        })
398    }
399}
400
401#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
402pub struct AssistantSettingsContentV1 {
403    /// Whether the Assistant is enabled.
404    ///
405    /// Default: true
406    enabled: Option<bool>,
407    /// Whether to show the assistant panel button in the status bar.
408    ///
409    /// Default: true
410    button: Option<bool>,
411    /// Where to dock the assistant.
412    ///
413    /// Default: right
414    dock: Option<AssistantDockPosition>,
415    /// Default width in pixels when the assistant is docked to the left or right.
416    ///
417    /// Default: 640
418    default_width: Option<f32>,
419    /// Default height in pixels when the assistant is docked to the bottom.
420    ///
421    /// Default: 320
422    default_height: Option<f32>,
423    /// The provider of the assistant service.
424    ///
425    /// This can either be the internal `zed.dev` service or an external `openai` service,
426    /// each with their respective default models and configurations.
427    provider: Option<AssistantProviderContent>,
428}
429
430#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
431pub struct LegacyAssistantSettingsContent {
432    /// Whether to show the assistant panel button in the status bar.
433    ///
434    /// Default: true
435    pub button: Option<bool>,
436    /// Where to dock the assistant.
437    ///
438    /// Default: right
439    pub dock: Option<AssistantDockPosition>,
440    /// Default width in pixels when the assistant is docked to the left or right.
441    ///
442    /// Default: 640
443    pub default_width: Option<f32>,
444    /// Default height in pixels when the assistant is docked to the bottom.
445    ///
446    /// Default: 320
447    pub default_height: Option<f32>,
448    /// The default OpenAI model to use when creating new contexts.
449    ///
450    /// Default: gpt-4-1106-preview
451    pub default_open_ai_model: Option<OpenAiModel>,
452    /// OpenAI API base URL to use when creating new contexts.
453    ///
454    /// Default: https://api.openai.com/v1
455    pub openai_api_url: Option<String>,
456}
457
458impl Settings for AssistantSettings {
459    const KEY: Option<&'static str> = Some("assistant");
460
461    type FileContent = AssistantSettingsContent;
462
463    fn load(
464        sources: SettingsSources<Self::FileContent>,
465        _: &mut gpui::AppContext,
466    ) -> anyhow::Result<Self> {
467        let mut settings = AssistantSettings::default();
468
469        for value in sources.defaults_and_customizations() {
470            let value = value.upgrade();
471            merge(&mut settings.enabled, value.enabled);
472            merge(&mut settings.button, value.button);
473            merge(&mut settings.dock, value.dock);
474            merge(
475                &mut settings.default_width,
476                value.default_width.map(Into::into),
477            );
478            merge(
479                &mut settings.default_height,
480                value.default_height.map(Into::into),
481            );
482            if let Some(provider) = value.provider.clone() {
483                match (&mut settings.provider, provider) {
484                    (
485                        AssistantProvider::ZedDotDev { model },
486                        AssistantProviderContent::ZedDotDev {
487                            default_model: model_override,
488                        },
489                    ) => {
490                        merge(model, model_override);
491                    }
492                    (
493                        AssistantProvider::OpenAi {
494                            model,
495                            api_url,
496                            low_speed_timeout_in_seconds,
497                            available_models,
498                        },
499                        AssistantProviderContent::OpenAi {
500                            default_model: model_override,
501                            api_url: api_url_override,
502                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
503                            available_models: available_models_override,
504                        },
505                    ) => {
506                        merge(model, model_override);
507                        merge(api_url, api_url_override);
508                        merge(available_models, available_models_override);
509                        if let Some(low_speed_timeout_in_seconds_override) =
510                            low_speed_timeout_in_seconds_override
511                        {
512                            *low_speed_timeout_in_seconds =
513                                Some(low_speed_timeout_in_seconds_override);
514                        }
515                    }
516                    (
517                        AssistantProvider::Ollama {
518                            model,
519                            api_url,
520                            low_speed_timeout_in_seconds,
521                        },
522                        AssistantProviderContent::Ollama {
523                            default_model: model_override,
524                            api_url: api_url_override,
525                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
526                        },
527                    ) => {
528                        merge(model, model_override);
529                        merge(api_url, api_url_override);
530                        if let Some(low_speed_timeout_in_seconds_override) =
531                            low_speed_timeout_in_seconds_override
532                        {
533                            *low_speed_timeout_in_seconds =
534                                Some(low_speed_timeout_in_seconds_override);
535                        }
536                    }
537                    (
538                        AssistantProvider::Anthropic {
539                            model,
540                            api_url,
541                            low_speed_timeout_in_seconds,
542                        },
543                        AssistantProviderContent::Anthropic {
544                            default_model: model_override,
545                            api_url: api_url_override,
546                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
547                        },
548                    ) => {
549                        merge(model, model_override);
550                        merge(api_url, api_url_override);
551                        if let Some(low_speed_timeout_in_seconds_override) =
552                            low_speed_timeout_in_seconds_override
553                        {
554                            *low_speed_timeout_in_seconds =
555                                Some(low_speed_timeout_in_seconds_override);
556                        }
557                    }
558                    (provider, provider_override) => {
559                        *provider = match provider_override {
560                            AssistantProviderContent::ZedDotDev {
561                                default_model: model,
562                            } => AssistantProvider::ZedDotDev {
563                                model: model.unwrap_or_default(),
564                            },
565                            AssistantProviderContent::OpenAi {
566                                default_model: model,
567                                api_url,
568                                low_speed_timeout_in_seconds,
569                                available_models,
570                            } => AssistantProvider::OpenAi {
571                                model: model.unwrap_or_default(),
572                                api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
573                                low_speed_timeout_in_seconds,
574                                available_models: available_models.unwrap_or_default(),
575                            },
576                            AssistantProviderContent::Anthropic {
577                                default_model: model,
578                                api_url,
579                                low_speed_timeout_in_seconds,
580                            } => AssistantProvider::Anthropic {
581                                model: model.unwrap_or_default(),
582                                api_url: api_url
583                                    .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
584                                low_speed_timeout_in_seconds,
585                            },
586                            AssistantProviderContent::Ollama {
587                                default_model: model,
588                                api_url,
589                                low_speed_timeout_in_seconds,
590                            } => AssistantProvider::Ollama {
591                                model: model.unwrap_or_default(),
592                                api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
593                                low_speed_timeout_in_seconds,
594                            },
595                        };
596                    }
597                }
598            }
599        }
600
601        Ok(settings)
602    }
603}
604
/// Overwrites `target` with the contents of `value` when it is `Some`;
/// leaves `target` unchanged when it is `None`.
fn merge<T>(target: &mut T, value: Option<T>) {
    match value {
        Some(replacement) => *target = replacement,
        None => {}
    }
}
610
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    /// Checks the resolved provider for the registered defaults, for the two
    /// legacy (unversioned) OpenAI fields, and for a versioned `zed.dev`
    /// provider with a custom model id.
    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to the OpenAI provider with the `FourOmni` model.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // Ensure backward-compatibility: a legacy `openai_api_url` alone
        // overrides only the API URL.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );
        // A legacy `default_open_ai_model` alone overrides only the model.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );
    }
}