assistant_settings.rs

  1use std::fmt;
  2
  3use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
  4pub use anthropic::Model as AnthropicModel;
  5use gpui::Pixels;
  6pub use ollama::Model as OllamaModel;
  7pub use open_ai::Model as OpenAiModel;
  8use schemars::{
  9    schema::{InstanceType, Metadata, Schema, SchemaObject},
 10    JsonSchema,
 11};
 12use serde::{
 13    de::{self, Visitor},
 14    Deserialize, Deserializer, Serialize, Serializer,
 15};
 16use settings::{Settings, SettingsSources};
 17use strum::{EnumIter, IntoEnumIterator};
 18
/// A model served through the `zed.dev` provider.
///
/// Serialized as its [`CloudModel::id`] string; ids that match no known variant deserialize to
/// [`CloudModel::Custom`]. Declaration order matters: `EnumIter` iterates variants in this
/// order, which is the order of the ids listed in the JSON schema.
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
    Gpt3Point5Turbo,
    Gpt4,
    Gpt4Turbo,
    #[default]
    Gpt4Omni,
    Gpt4OmniMini,
    Claude3_5Sonnet,
    Claude3Opus,
    Claude3Sonnet,
    Claude3Haiku,
    Gemini15Pro,
    Gemini15Flash,
    /// Any model id that does not match a known variant, kept verbatim.
    ///
    /// `EnumIter` yields this variant with an empty id, which the JSON schema filters out.
    Custom(String),
}
 35
 36impl Serialize for CloudModel {
 37    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
 38    where
 39        S: Serializer,
 40    {
 41        serializer.serialize_str(self.id())
 42    }
 43}
 44
 45impl<'de> Deserialize<'de> for CloudModel {
 46    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
 47    where
 48        D: Deserializer<'de>,
 49    {
 50        struct ZedDotDevModelVisitor;
 51
 52        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
 53            type Value = CloudModel;
 54
 55            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
 56                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
 57            }
 58
 59            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
 60            where
 61                E: de::Error,
 62            {
 63                let model = CloudModel::iter()
 64                    .find(|model| model.id() == value)
 65                    .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
 66                Ok(model)
 67            }
 68        }
 69
 70        deserializer.deserialize_str(ZedDotDevModelVisitor)
 71    }
 72}
 73
 74impl JsonSchema for CloudModel {
 75    fn schema_name() -> String {
 76        "ZedDotDevModel".to_owned()
 77    }
 78
 79    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
 80        let variants = CloudModel::iter()
 81            .filter_map(|model| {
 82                let id = model.id();
 83                if id.is_empty() {
 84                    None
 85                } else {
 86                    Some(id.to_string())
 87                }
 88            })
 89            .collect::<Vec<_>>();
 90        Schema::Object(SchemaObject {
 91            instance_type: Some(InstanceType::String.into()),
 92            enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
 93            metadata: Some(Box::new(Metadata {
 94                title: Some("ZedDotDevModel".to_owned()),
 95                default: Some(CloudModel::default().id().into()),
 96                examples: variants.into_iter().map(Into::into).collect(),
 97                ..Default::default()
 98            })),
 99            ..Default::default()
100        })
101    }
102}
103
104impl CloudModel {
105    pub fn id(&self) -> &str {
106        match self {
107            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
108            Self::Gpt4 => "gpt-4",
109            Self::Gpt4Turbo => "gpt-4-turbo-preview",
110            Self::Gpt4Omni => "gpt-4o",
111            Self::Gpt4OmniMini => "gpt-4o-mini",
112            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
113            Self::Claude3Opus => "claude-3-opus",
114            Self::Claude3Sonnet => "claude-3-sonnet",
115            Self::Claude3Haiku => "claude-3-haiku",
116            Self::Gemini15Pro => "gemini-1.5-pro",
117            Self::Gemini15Flash => "gemini-1.5-flash",
118            Self::Custom(id) => id,
119        }
120    }
121
122    pub fn display_name(&self) -> &str {
123        match self {
124            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
125            Self::Gpt4 => "GPT 4",
126            Self::Gpt4Turbo => "GPT 4 Turbo",
127            Self::Gpt4Omni => "GPT 4 Omni",
128            Self::Gpt4OmniMini => "GPT 4 Omni Mini",
129            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
130            Self::Claude3Opus => "Claude 3 Opus",
131            Self::Claude3Sonnet => "Claude 3 Sonnet",
132            Self::Claude3Haiku => "Claude 3 Haiku",
133            Self::Gemini15Pro => "Gemini 1.5 Pro",
134            Self::Gemini15Flash => "Gemini 1.5 Flash",
135            Self::Custom(id) => id.as_str(),
136        }
137    }
138
139    pub fn max_token_count(&self) -> usize {
140        match self {
141            Self::Gpt3Point5Turbo => 2048,
142            Self::Gpt4 => 4096,
143            Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
144            Self::Gpt4OmniMini => 128000,
145            Self::Claude3_5Sonnet
146            | Self::Claude3Opus
147            | Self::Claude3Sonnet
148            | Self::Claude3Haiku => 200000,
149            Self::Gemini15Pro => 128000,
150            Self::Gemini15Flash => 32000,
151            Self::Custom(_) => 4096, // TODO: Make this configurable
152        }
153    }
154
155    pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
156        match self {
157            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
158                preprocess_anthropic_request(request)
159            }
160            _ => {}
161        }
162    }
163}
164
// Where the assistant panel is docked. Serialized in snake_case ("left", "right", "bottom");
// `Right` is the `Default`.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}
173
/// The fully resolved provider configuration the assistant uses.
///
/// Produced by `AssistantSettings`' `load` from the [`AssistantProviderContent`] found in each
/// settings layer. Unlike the content type, every field holds a concrete value; when a layer
/// leaves `api_url` unset on a newly selected provider, that provider's default URL constant
/// is used.
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    /// Zed's hosted service (`"name": "zed.dev"` in settings).
    ZedDotDev {
        /// The default model for this provider.
        model: CloudModel,
    },
    /// OpenAI (`"name": "openai"` in settings).
    OpenAi {
        /// The default model for this provider.
        model: OpenAiModel,
        /// Base URL of the API.
        api_url: String,
        /// Optional low-speed timeout in seconds; `None` when not configured.
        /// NOTE(review): exact semantics are defined by the HTTP client that consumes it.
        low_speed_timeout_in_seconds: Option<u64>,
        /// Models listed under `available_models` in settings; empty when unset.
        available_models: Vec<OpenAiModel>,
    },
    /// Anthropic (`"name": "anthropic"` in settings).
    Anthropic {
        /// The default model for this provider.
        model: AnthropicModel,
        /// Base URL of the API.
        api_url: String,
        /// Optional low-speed timeout in seconds; `None` when not configured.
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// Ollama (`"name": "ollama"` in settings).
    Ollama {
        /// The default model for this provider.
        model: OllamaModel,
        /// Base URL of the API.
        api_url: String,
        /// Optional low-speed timeout in seconds; `None` when not configured.
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
196
197impl Default for AssistantProvider {
198    fn default() -> Self {
199        Self::OpenAi {
200            model: OpenAiModel::default(),
201            api_url: open_ai::OPEN_AI_API_URL.into(),
202            low_speed_timeout_in_seconds: None,
203            available_models: Default::default(),
204        }
205    }
206}
207
// One settings layer's provider configuration, selected by its `"name"` field ("zed.dev",
// "openai", "anthropic" or "ollama"). Every field is optional: `AssistantSettings::load`
// keeps earlier values for unset fields when the provider kind is unchanged, and uses
// per-provider defaults when this layer switches to a different provider.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    #[serde(rename = "zed.dev")]
    ZedDotDev { default_model: Option<CloudModel> },
    #[serde(rename = "openai")]
    OpenAi {
        default_model: Option<OpenAiModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
        // Unset becomes an empty list when this layer switches the provider to OpenAI.
        available_models: Option<Vec<OpenAiModel>>,
    },
    #[serde(rename = "anthropic")]
    Anthropic {
        default_model: Option<AnthropicModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "ollama")]
    Ollama {
        default_model: Option<OllamaModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
233
/// Resolved assistant settings.
///
/// Built by [`Settings::load`] by layering the default settings and then each customization
/// on top of `AssistantSettings::default()`.
#[derive(Debug, Default)]
pub struct AssistantSettings {
    /// Whether the assistant is enabled.
    pub enabled: bool,
    /// Whether to show the assistant panel button in the status bar.
    pub button: bool,
    /// Where the assistant panel is docked.
    pub dock: AssistantDockPosition,
    /// Panel width when docked to the left or right.
    pub default_width: Pixels,
    /// Panel height when docked to the bottom.
    pub default_height: Pixels,
    /// The resolved provider and its configuration.
    pub provider: AssistantProvider,
}
243
/// Assistant panel settings
///
/// The on-disk shape of the `"assistant"` settings key. It is deserialized `untagged`, so serde
/// tries the variants in declaration order: `Versioned` first (it requires a `"version"`
/// field), then `Legacy`, the unversioned format whose only provider fields are for OpenAI.
/// `upgrade` normalizes either shape to [`AssistantSettingsContentV1`].
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}
251
252impl JsonSchema for AssistantSettingsContent {
253    fn schema_name() -> String {
254        VersionedAssistantSettingsContent::schema_name()
255    }
256
257    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
258        VersionedAssistantSettingsContent::json_schema(gen)
259    }
260
261    fn is_referenceable() -> bool {
262        VersionedAssistantSettingsContent::is_referenceable()
263    }
264}
265
266impl Default for AssistantSettingsContent {
267    fn default() -> Self {
268        Self::Versioned(VersionedAssistantSettingsContent::default())
269    }
270}
271
272impl AssistantSettingsContent {
273    fn upgrade(&self) -> AssistantSettingsContentV1 {
274        match self {
275            AssistantSettingsContent::Versioned(settings) => match settings {
276                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
277            },
278            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
279                enabled: None,
280                button: settings.button,
281                dock: settings.dock,
282                default_width: settings.default_width,
283                default_height: settings.default_height,
284                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
285                    Some(AssistantProviderContent::OpenAi {
286                        default_model: settings.default_open_ai_model.clone(),
287                        api_url: Some(open_ai_api_url.clone()),
288                        low_speed_timeout_in_seconds: None,
289                        available_models: Some(Default::default()),
290                    })
291                } else {
292                    settings.default_open_ai_model.clone().map(|open_ai_model| {
293                        AssistantProviderContent::OpenAi {
294                            default_model: Some(open_ai_model),
295                            api_url: None,
296                            low_speed_timeout_in_seconds: None,
297                            available_models: Some(Default::default()),
298                        }
299                    })
300                },
301            },
302        }
303    }
304
305    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
306        match self {
307            AssistantSettingsContent::Versioned(settings) => match settings {
308                VersionedAssistantSettingsContent::V1(settings) => {
309                    settings.dock = Some(dock);
310                }
311            },
312            AssistantSettingsContent::Legacy(settings) => {
313                settings.dock = Some(dock);
314            }
315        }
316    }
317
318    pub fn set_model(&mut self, new_model: LanguageModel) {
319        match self {
320            AssistantSettingsContent::Versioned(settings) => match settings {
321                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
322                    Some(AssistantProviderContent::ZedDotDev {
323                        default_model: model,
324                    }) => {
325                        if let LanguageModel::Cloud(new_model) = new_model {
326                            *model = Some(new_model);
327                        }
328                    }
329                    Some(AssistantProviderContent::OpenAi {
330                        default_model: model,
331                        ..
332                    }) => {
333                        if let LanguageModel::OpenAi(new_model) = new_model {
334                            *model = Some(new_model);
335                        }
336                    }
337                    Some(AssistantProviderContent::Anthropic {
338                        default_model: model,
339                        ..
340                    }) => {
341                        if let LanguageModel::Anthropic(new_model) = new_model {
342                            *model = Some(new_model);
343                        }
344                    }
345                    Some(AssistantProviderContent::Ollama {
346                        default_model: model,
347                        ..
348                    }) => {
349                        if let LanguageModel::Ollama(new_model) = new_model {
350                            *model = Some(new_model);
351                        }
352                    }
353                    provider => match new_model {
354                        LanguageModel::Cloud(model) => {
355                            *provider = Some(AssistantProviderContent::ZedDotDev {
356                                default_model: Some(model),
357                            })
358                        }
359                        LanguageModel::OpenAi(model) => {
360                            *provider = Some(AssistantProviderContent::OpenAi {
361                                default_model: Some(model),
362                                api_url: None,
363                                low_speed_timeout_in_seconds: None,
364                                available_models: Some(Default::default()),
365                            })
366                        }
367                        LanguageModel::Anthropic(model) => {
368                            *provider = Some(AssistantProviderContent::Anthropic {
369                                default_model: Some(model),
370                                api_url: None,
371                                low_speed_timeout_in_seconds: None,
372                            })
373                        }
374                        LanguageModel::Ollama(model) => {
375                            *provider = Some(AssistantProviderContent::Ollama {
376                                default_model: Some(model),
377                                api_url: None,
378                                low_speed_timeout_in_seconds: None,
379                            })
380                        }
381                    },
382                },
383            },
384            AssistantSettingsContent::Legacy(settings) => {
385                if let LanguageModel::OpenAi(model) = new_model {
386                    settings.default_open_ai_model = Some(model);
387                }
388            }
389        }
390    }
391}
392
// Versioned assistant settings, discriminated by the `"version"` field (currently only "1").
// Its JSON schema is also the schema published for `AssistantSettingsContent`.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
399
400impl Default for VersionedAssistantSettingsContent {
401    fn default() -> Self {
402        Self::V1(AssistantSettingsContentV1 {
403            enabled: None,
404            button: None,
405            dock: None,
406            default_width: None,
407            default_height: None,
408            provider: None,
409        })
410    }
411}
412
413#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
414pub struct AssistantSettingsContentV1 {
415    /// Whether the Assistant is enabled.
416    ///
417    /// Default: true
418    enabled: Option<bool>,
419    /// Whether to show the assistant panel button in the status bar.
420    ///
421    /// Default: true
422    button: Option<bool>,
423    /// Where to dock the assistant.
424    ///
425    /// Default: right
426    dock: Option<AssistantDockPosition>,
427    /// Default width in pixels when the assistant is docked to the left or right.
428    ///
429    /// Default: 640
430    default_width: Option<f32>,
431    /// Default height in pixels when the assistant is docked to the bottom.
432    ///
433    /// Default: 320
434    default_height: Option<f32>,
435    /// The provider of the assistant service.
436    ///
437    /// This can either be the internal `zed.dev` service or an external `openai` service,
438    /// each with their respective default models and configurations.
439    provider: Option<AssistantProviderContent>,
440}
441
// The pre-versioning assistant settings format. It is still accepted when parsing, and
// `AssistantSettingsContent::upgrade` maps its OpenAI fields onto an `openai` provider in the
// V1 shape.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    // NOTE(review): the "Default:" line below looks stale; the settings test in this file
    // expects `OpenAiModel::FourOmni` when no model is configured. Confirm before relying on it.
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
469
470impl Settings for AssistantSettings {
471    const KEY: Option<&'static str> = Some("assistant");
472
473    type FileContent = AssistantSettingsContent;
474
475    fn load(
476        sources: SettingsSources<Self::FileContent>,
477        _: &mut gpui::AppContext,
478    ) -> anyhow::Result<Self> {
479        let mut settings = AssistantSettings::default();
480
481        for value in sources.defaults_and_customizations() {
482            let value = value.upgrade();
483            merge(&mut settings.enabled, value.enabled);
484            merge(&mut settings.button, value.button);
485            merge(&mut settings.dock, value.dock);
486            merge(
487                &mut settings.default_width,
488                value.default_width.map(Into::into),
489            );
490            merge(
491                &mut settings.default_height,
492                value.default_height.map(Into::into),
493            );
494            if let Some(provider) = value.provider.clone() {
495                match (&mut settings.provider, provider) {
496                    (
497                        AssistantProvider::ZedDotDev { model },
498                        AssistantProviderContent::ZedDotDev {
499                            default_model: model_override,
500                        },
501                    ) => {
502                        merge(model, model_override);
503                    }
504                    (
505                        AssistantProvider::OpenAi {
506                            model,
507                            api_url,
508                            low_speed_timeout_in_seconds,
509                            available_models,
510                        },
511                        AssistantProviderContent::OpenAi {
512                            default_model: model_override,
513                            api_url: api_url_override,
514                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
515                            available_models: available_models_override,
516                        },
517                    ) => {
518                        merge(model, model_override);
519                        merge(api_url, api_url_override);
520                        merge(available_models, available_models_override);
521                        if let Some(low_speed_timeout_in_seconds_override) =
522                            low_speed_timeout_in_seconds_override
523                        {
524                            *low_speed_timeout_in_seconds =
525                                Some(low_speed_timeout_in_seconds_override);
526                        }
527                    }
528                    (
529                        AssistantProvider::Ollama {
530                            model,
531                            api_url,
532                            low_speed_timeout_in_seconds,
533                        },
534                        AssistantProviderContent::Ollama {
535                            default_model: model_override,
536                            api_url: api_url_override,
537                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
538                        },
539                    ) => {
540                        merge(model, model_override);
541                        merge(api_url, api_url_override);
542                        if let Some(low_speed_timeout_in_seconds_override) =
543                            low_speed_timeout_in_seconds_override
544                        {
545                            *low_speed_timeout_in_seconds =
546                                Some(low_speed_timeout_in_seconds_override);
547                        }
548                    }
549                    (
550                        AssistantProvider::Anthropic {
551                            model,
552                            api_url,
553                            low_speed_timeout_in_seconds,
554                        },
555                        AssistantProviderContent::Anthropic {
556                            default_model: model_override,
557                            api_url: api_url_override,
558                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
559                        },
560                    ) => {
561                        merge(model, model_override);
562                        merge(api_url, api_url_override);
563                        if let Some(low_speed_timeout_in_seconds_override) =
564                            low_speed_timeout_in_seconds_override
565                        {
566                            *low_speed_timeout_in_seconds =
567                                Some(low_speed_timeout_in_seconds_override);
568                        }
569                    }
570                    (provider, provider_override) => {
571                        *provider = match provider_override {
572                            AssistantProviderContent::ZedDotDev {
573                                default_model: model,
574                            } => AssistantProvider::ZedDotDev {
575                                model: model.unwrap_or_default(),
576                            },
577                            AssistantProviderContent::OpenAi {
578                                default_model: model,
579                                api_url,
580                                low_speed_timeout_in_seconds,
581                                available_models,
582                            } => AssistantProvider::OpenAi {
583                                model: model.unwrap_or_default(),
584                                api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
585                                low_speed_timeout_in_seconds,
586                                available_models: available_models.unwrap_or_default(),
587                            },
588                            AssistantProviderContent::Anthropic {
589                                default_model: model,
590                                api_url,
591                                low_speed_timeout_in_seconds,
592                            } => AssistantProvider::Anthropic {
593                                model: model.unwrap_or_default(),
594                                api_url: api_url
595                                    .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
596                                low_speed_timeout_in_seconds,
597                            },
598                            AssistantProviderContent::Ollama {
599                                default_model: model,
600                                api_url,
601                                low_speed_timeout_in_seconds,
602                            } => AssistantProvider::Ollama {
603                                model: model.unwrap_or_default(),
604                                api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
605                                low_speed_timeout_in_seconds,
606                            },
607                        };
608                    }
609                }
610            }
611        }
612
613        Ok(settings)
614    }
615}
616
/// Overwrites `target` with the contents of `value` when it is `Some`.
///
/// A `None` leaves `target` untouched. This is how a settings layer that leaves a field unset
/// keeps the value from earlier layers.
fn merge<T>(target: &mut T, value: Option<T>) {
    let Some(new_value) = value else {
        return;
    };
    *target = new_value;
}
622
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    // Checks how the provider resolves from the built-in defaults, from the legacy
    // (unversioned) settings format, and from the versioned ("version": "1") format.
    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // With no user settings, the provider resolves to OpenAI with
        // `OpenAiModel::FourOmni` at the default OpenAI API URL.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // Ensure backward-compatibility: the legacy `openai_api_url` key still sets the
        // OpenAI API URL, and the model stays at its default.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );
        // The legacy `default_open_ai_model` key selects the model. The API URL is back to
        // the default because these user settings no longer set `openai_api_url`.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // The new version supports setting a custom model when using zed.dev; an id that
        // matches no known `CloudModel` variant becomes `CloudModel::Custom`.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );
    }
}