assistant_settings.rs

use std::fmt;

use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
pub use anthropic::Model as AnthropicModel;
use gpui::Pixels;
pub use ollama::Model as OllamaModel;
pub use open_ai::Model as OpenAiModel;
use schemars::{
    schema::{InstanceType, Metadata, Schema, SchemaObject},
    JsonSchema,
};
use serde::{
    de::{self, Visitor},
    Deserialize, Deserializer, Serialize, Serializer,
};
use settings::{Settings, SettingsSources};
use strum::{EnumIter, IntoEnumIterator};

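/// A language model served through the zed.dev cloud provider (historically
/// named `ZedDotDevModel`, which the JSON schema name below still reflects).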
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
    Gpt3Point5Turbo,
    Gpt4,
    Gpt4Turbo,
    #[default]
    Gpt4Omni,
    Claude3_5Sonnet,
    Claude3Opus,
    Claude3Sonnet,
    Claude3Haiku,
    Gemini15Pro,
    Gemini15Flash,
    Custom(String),
}

impl Serialize for CloudModel {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(self.id())
    }
}

impl<'de> Deserialize<'de> for CloudModel {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct ZedDotDevModelVisitor;

        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
            type Value = CloudModel;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
            }

            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                let model = CloudModel::iter()
                    .find(|model| model.id() == value)
                    .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
                Ok(model)
            }
        }

        deserializer.deserialize_str(ZedDotDevModelVisitor)
    }
}
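// Known model ids (e.g. "gpt-4o") deserialize to their named variants; any other
// string is preserved as `CloudModel::Custom`, so unrecognized ids still round-trip.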

impl JsonSchema for CloudModel {
    fn schema_name() -> String {
        "ZedDotDevModel".to_owned()
    }

    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
        let variants = CloudModel::iter()
            .filter_map(|model| {
                let id = model.id();
                if id.is_empty() {
                    None
                } else {
                    Some(id.to_string())
                }
            })
            .collect::<Vec<_>>();
        Schema::Object(SchemaObject {
            instance_type: Some(InstanceType::String.into()),
            enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
            metadata: Some(Box::new(Metadata {
                title: Some("ZedDotDevModel".to_owned()),
                default: Some(CloudModel::default().id().into()),
                examples: variants.into_iter().map(Into::into).collect(),
                ..Default::default()
            })),
            ..Default::default()
        })
    }
}

impl CloudModel {
    pub fn id(&self) -> &str {
        match self {
            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
            Self::Gpt4 => "gpt-4",
            Self::Gpt4Turbo => "gpt-4-turbo-preview",
            Self::Gpt4Omni => "gpt-4o",
            Self::Claude3_5Sonnet => "claude-3-5-sonnet",
            Self::Claude3Opus => "claude-3-opus",
            Self::Claude3Sonnet => "claude-3-sonnet",
            Self::Claude3Haiku => "claude-3-haiku",
            Self::Gemini15Pro => "gemini-1.5-pro",
            Self::Gemini15Flash => "gemini-1.5-flash",
            Self::Custom(id) => id,
        }
    }

    pub fn display_name(&self) -> &str {
        match self {
            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
            Self::Gpt4 => "GPT 4",
            Self::Gpt4Turbo => "GPT 4 Turbo",
            Self::Gpt4Omni => "GPT 4 Omni",
            Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
            Self::Claude3Opus => "Claude 3 Opus",
            Self::Claude3Sonnet => "Claude 3 Sonnet",
            Self::Claude3Haiku => "Claude 3 Haiku",
            Self::Gemini15Pro => "Gemini 1.5 Pro",
            Self::Gemini15Flash => "Gemini 1.5 Flash",
            Self::Custom(id) => id.as_str(),
        }
    }

    pub fn max_token_count(&self) -> usize {
        match self {
            Self::Gpt3Point5Turbo => 2048,
            Self::Gpt4 => 4096,
            Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
            Self::Claude3_5Sonnet
            | Self::Claude3Opus
            | Self::Claude3Sonnet
            | Self::Claude3Haiku => 200000,
            Self::Gemini15Pro => 128000,
            Self::Gemini15Flash => 32000,
            Self::Custom(_) => 4096, // TODO: Make this configurable
        }
    }

    pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
        match self {
            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
                preprocess_anthropic_request(request)
            }
            _ => {}
        }
    }
}

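/// Where the assistant panel is docked in the workspace.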
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}

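/// The provider configuration actually used at runtime, produced by merging the
/// built-in defaults with any `AssistantProviderContent` overrides from settings.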
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    ZedDotDev {
        model: CloudModel,
    },
    OpenAi {
        model: OpenAiModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
        available_models: Vec<OpenAiModel>,
    },
    Anthropic {
        model: AnthropicModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    Ollama {
        model: OllamaModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}

impl Default for AssistantProvider {
    fn default() -> Self {
        Self::OpenAi {
            model: OpenAiModel::default(),
            api_url: open_ai::OPEN_AI_API_URL.into(),
            low_speed_timeout_in_seconds: None,
            available_models: Default::default(),
        }
    }
}

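// A minimal sketch of the JSON the enum below deserializes from (tagged by
// "name", per the serde attributes; "gpt-4o" is one of the `CloudModel::id`
// values defined above):
//
//     "provider": {
//         "name": "zed.dev",
//         "default_model": "gpt-4o"
//     }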
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    #[serde(rename = "zed.dev")]
    ZedDotDev { default_model: Option<CloudModel> },
    #[serde(rename = "openai")]
    OpenAi {
        default_model: Option<OpenAiModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
        available_models: Option<Vec<OpenAiModel>>,
    },
    #[serde(rename = "anthropic")]
    Anthropic {
        default_model: Option<AnthropicModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "ollama")]
    Ollama {
        default_model: Option<OllamaModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}

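/// The resolved assistant settings, after all defaults and user overrides have
/// been merged by `Settings::load`.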
#[derive(Debug, Default)]
pub struct AssistantSettings {
    pub enabled: bool,
    pub button: bool,
    pub dock: AssistantDockPosition,
    pub default_width: Pixels,
    pub default_height: Pixels,
    pub provider: AssistantProvider,
}

/// Assistant panel settings
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}

impl JsonSchema for AssistantSettingsContent {
    fn schema_name() -> String {
        VersionedAssistantSettingsContent::schema_name()
    }

    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
        VersionedAssistantSettingsContent::json_schema(gen)
    }

    fn is_referenceable() -> bool {
        VersionedAssistantSettingsContent::is_referenceable()
    }
}

impl Default for AssistantSettingsContent {
    fn default() -> Self {
        Self::Versioned(VersionedAssistantSettingsContent::default())
    }
}

impl AssistantSettingsContent {
    fn upgrade(&self) -> AssistantSettingsContentV1 {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
            },
            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
                enabled: None,
                button: settings.button,
                dock: settings.dock,
                default_width: settings.default_width,
                default_height: settings.default_height,
                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
                    Some(AssistantProviderContent::OpenAi {
                        default_model: settings.default_open_ai_model.clone(),
                        api_url: Some(open_ai_api_url.clone()),
                        low_speed_timeout_in_seconds: None,
                        available_models: Some(Default::default()),
                    })
                } else {
                    settings.default_open_ai_model.clone().map(|open_ai_model| {
                        AssistantProviderContent::OpenAi {
                            default_model: Some(open_ai_model),
                            api_url: None,
                            low_speed_timeout_in_seconds: None,
                            available_models: Some(Default::default()),
                        }
                    })
                },
            },
        }
    }

    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => {
                    settings.dock = Some(dock);
                }
            },
            AssistantSettingsContent::Legacy(settings) => {
                settings.dock = Some(dock);
            }
        }
    }

    pub fn set_model(&mut self, new_model: LanguageModel) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
                    Some(AssistantProviderContent::ZedDotDev {
                        default_model: model,
                    }) => {
                        if let LanguageModel::Cloud(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::OpenAi {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::OpenAi(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::Anthropic {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::Anthropic(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::Ollama {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::Ollama(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    provider => match new_model {
                        LanguageModel::Cloud(model) => {
                            *provider = Some(AssistantProviderContent::ZedDotDev {
                                default_model: Some(model),
                            })
                        }
                        LanguageModel::OpenAi(model) => {
                            *provider = Some(AssistantProviderContent::OpenAi {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                                available_models: Some(Default::default()),
                            })
                        }
                        LanguageModel::Anthropic(model) => {
                            *provider = Some(AssistantProviderContent::Anthropic {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                        LanguageModel::Ollama(model) => {
                            *provider = Some(AssistantProviderContent::Ollama {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                    },
                },
            },
            AssistantSettingsContent::Legacy(settings) => {
                if let LanguageModel::OpenAi(model) = new_model {
                    settings.default_open_ai_model = Some(model);
                }
            }
        }
    }
}

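/// Assistant settings tagged with an explicit `"version"` field, so the on-disk
/// format can evolve without breaking existing configurations.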
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}

impl Default for VersionedAssistantSettingsContent {
    fn default() -> Self {
        Self::V1(AssistantSettingsContentV1 {
            enabled: None,
            button: None,
            dock: None,
            default_width: None,
            default_height: None,
            provider: None,
        })
    }
}

#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can be the internal `zed.dev` service or an external provider such as
    /// `openai`, `anthropic`, or `ollama`, each with its own default model and configuration.
    provider: Option<AssistantProviderContent>,
}

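// A sketch of the settings shape used before versioning was introduced; the model
// id and API URL below are the ones referenced by the field docs and by the
// backward-compatibility test at the bottom of this file:
//
//     "assistant": {
//         "default_open_ai_model": "gpt-4-0613",
//         "openai_api_url": "https://api.openai.com/v1"
//     }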
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}

impl Settings for AssistantSettings {
    const KEY: Option<&'static str> = Some("assistant");

    type FileContent = AssistantSettingsContent;

    fn load(
        sources: SettingsSources<Self::FileContent>,
        _: &mut gpui::AppContext,
    ) -> anyhow::Result<Self> {
        let mut settings = AssistantSettings::default();

        for value in sources.defaults_and_customizations() {
            let value = value.upgrade();
            merge(&mut settings.enabled, value.enabled);
            merge(&mut settings.button, value.button);
            merge(&mut settings.dock, value.dock);
            merge(
                &mut settings.default_width,
                value.default_width.map(Into::into),
            );
            merge(
                &mut settings.default_height,
                value.default_height.map(Into::into),
            );
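            // When the override names the same provider variant that is currently
            // configured, merge it field by field; otherwise replace the provider
            // wholesale, filling unspecified fields with that provider's defaults.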
            if let Some(provider) = value.provider.clone() {
                match (&mut settings.provider, provider) {
                    (
                        AssistantProvider::ZedDotDev { model },
                        AssistantProviderContent::ZedDotDev {
                            default_model: model_override,
                        },
                    ) => {
                        merge(model, model_override);
                    }
                    (
                        AssistantProvider::OpenAi {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                            available_models,
                        },
                        AssistantProviderContent::OpenAi {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                            available_models: available_models_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        merge(available_models, available_models_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (
                        AssistantProvider::Ollama {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::Ollama {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (
                        AssistantProvider::Anthropic {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::Anthropic {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (provider, provider_override) => {
                        *provider = match provider_override {
                            AssistantProviderContent::ZedDotDev {
                                default_model: model,
                            } => AssistantProvider::ZedDotDev {
                                model: model.unwrap_or_default(),
                            },
                            AssistantProviderContent::OpenAi {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                                available_models,
                            } => AssistantProvider::OpenAi {
                                model: model.unwrap_or_default(),
                                api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
                                low_speed_timeout_in_seconds,
                                available_models: available_models.unwrap_or_default(),
                            },
                            AssistantProviderContent::Anthropic {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::Anthropic {
                                model: model.unwrap_or_default(),
                                api_url: api_url
                                    .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                            AssistantProviderContent::Ollama {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::Ollama {
                                model: model.unwrap_or_default(),
                                api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                        };
                    }
                }
            }
        }

        Ok(settings)
    }
}

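/// Overwrites `target` with `value` when a value was provided.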
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(value) = value {
        *target = value;
    }
}

#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to gpt-4o.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // Ensure backward-compatibility.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );
    }
}