assistant_settings.rs

use std::fmt;

pub use anthropic::Model as AnthropicModel;
use gpui::Pixels;
pub use open_ai::Model as OpenAiModel;
use schemars::{
    schema::{InstanceType, Metadata, Schema, SchemaObject},
    JsonSchema,
};
use serde::{
    de::{self, Visitor},
    Deserialize, Deserializer, Serialize, Serializer,
};
use settings::{Settings, SettingsSources};
use strum::{EnumIter, IntoEnumIterator};

use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};

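/// A language model available through the zed.dev cloud provider.
/// `Custom` holds an arbitrary model identifier that is passed through as-is.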
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
    Gpt3Point5Turbo,
    Gpt4,
    Gpt4Turbo,
    #[default]
    Gpt4Omni,
    Claude3Opus,
    Claude3Sonnet,
    Claude3Haiku,
    Custom(String),
}

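// Cloud models serialize as their string id; deserialization first tries to match a
// known id and falls back to `CloudModel::Custom` for anything unrecognized.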
impl Serialize for CloudModel {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(self.id())
    }
}

impl<'de> Deserialize<'de> for CloudModel {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct ZedDotDevModelVisitor;

        impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
            type Value = CloudModel;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
            }

            fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                let model = CloudModel::iter()
                    .find(|model| model.id() == value)
                    .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
                Ok(model)
            }
        }

        deserializer.deserialize_str(ZedDotDevModelVisitor)
    }
}

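// Manual `JsonSchema` impl so the generated settings schema lists the known model ids
// along with the default id.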
impl JsonSchema for CloudModel {
    fn schema_name() -> String {
        "ZedDotDevModel".to_owned()
    }

    fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
        let variants = CloudModel::iter()
            .filter_map(|model| {
                let id = model.id();
                if id.is_empty() {
                    None
                } else {
                    Some(id.to_string())
                }
            })
            .collect::<Vec<_>>();
        Schema::Object(SchemaObject {
            instance_type: Some(InstanceType::String.into()),
            enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
            metadata: Some(Box::new(Metadata {
                title: Some("ZedDotDevModel".to_owned()),
                default: Some(CloudModel::default().id().into()),
                examples: variants.into_iter().map(Into::into).collect(),
                ..Default::default()
            })),
            ..Default::default()
        })
    }
}

impl CloudModel {
    pub fn id(&self) -> &str {
        match self {
            Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
            Self::Gpt4 => "gpt-4",
            Self::Gpt4Turbo => "gpt-4-turbo-preview",
            Self::Gpt4Omni => "gpt-4o",
            Self::Claude3Opus => "claude-3-opus",
            Self::Claude3Sonnet => "claude-3-sonnet",
            Self::Claude3Haiku => "claude-3-haiku",
            Self::Custom(id) => id,
        }
    }

    pub fn display_name(&self) -> &str {
        match self {
            Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
            Self::Gpt4 => "GPT 4",
            Self::Gpt4Turbo => "GPT 4 Turbo",
            Self::Gpt4Omni => "GPT 4 Omni",
            Self::Claude3Opus => "Claude 3 Opus",
            Self::Claude3Sonnet => "Claude 3 Sonnet",
            Self::Claude3Haiku => "Claude 3 Haiku",
            Self::Custom(id) => id.as_str(),
        }
    }

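    /// The maximum number of tokens allowed in a request for this model.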
    pub fn max_token_count(&self) -> usize {
        match self {
            Self::Gpt3Point5Turbo => 2048,
            Self::Gpt4 => 4096,
            Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
            Self::Custom(_) => 4096, // TODO: Make this configurable
        }
    }

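    /// Applies provider-specific preprocessing to the request; currently only the
    /// Anthropic models need it.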
    pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
        match self {
            Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
                preprocess_anthropic_request(request)
            }
            _ => {}
        }
    }
}

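/// Where the assistant panel is docked in the workspace.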
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}

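/// The resolved provider configuration, built by layering `AssistantProviderContent`
/// overrides onto the defaults in `AssistantSettings::load`.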
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    ZedDotDev {
        model: CloudModel,
    },
    OpenAi {
        model: OpenAiModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    Anthropic {
        model: AnthropicModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}

impl Default for AssistantProvider {
    fn default() -> Self {
        Self::OpenAi {
            model: OpenAiModel::default(),
            api_url: open_ai::OPEN_AI_API_URL.into(),
            low_speed_timeout_in_seconds: None,
        }
    }
}

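/// Provider configuration as written in the settings file; every field is optional
/// and is merged onto the active provider when settings are loaded.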
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    #[serde(rename = "zed.dev")]
    ZedDotDev { default_model: Option<CloudModel> },
    #[serde(rename = "openai")]
    OpenAi {
        default_model: Option<OpenAiModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "anthropic")]
    Anthropic {
        default_model: Option<AnthropicModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}

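/// The fully-resolved assistant settings, assembled in the `Settings::load`
/// implementation below.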
#[derive(Debug, Default)]
pub struct AssistantSettings {
    pub enabled: bool,
    pub button: bool,
    pub dock: AssistantDockPosition,
    pub default_width: Pixels,
    pub default_height: Pixels,
    pub provider: AssistantProvider,
}

/// Assistant panel settings
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}

impl JsonSchema for AssistantSettingsContent {
    fn schema_name() -> String {
        VersionedAssistantSettingsContent::schema_name()
    }

    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
        VersionedAssistantSettingsContent::json_schema(gen)
    }

    fn is_referenceable() -> bool {
        VersionedAssistantSettingsContent::is_referenceable()
    }
}

impl Default for AssistantSettingsContent {
    fn default() -> Self {
        Self::Versioned(VersionedAssistantSettingsContent::default())
    }
}

impl AssistantSettingsContent {
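    /// Flattens either the versioned or legacy format into the latest versioned
    /// shape (currently V1).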
    fn upgrade(&self) -> AssistantSettingsContentV1 {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
            },
            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
                enabled: None,
                button: settings.button,
                dock: settings.dock,
                default_width: settings.default_width,
                default_height: settings.default_height,
                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
                    Some(AssistantProviderContent::OpenAi {
                        default_model: settings.default_open_ai_model.clone(),
                        api_url: Some(open_ai_api_url.clone()),
                        low_speed_timeout_in_seconds: None,
                    })
                } else {
                    settings.default_open_ai_model.clone().map(|open_ai_model| {
                        AssistantProviderContent::OpenAi {
                            default_model: Some(open_ai_model),
                            api_url: None,
                            low_speed_timeout_in_seconds: None,
                        }
                    })
                },
            },
        }
    }

    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => {
                    settings.dock = Some(dock);
                }
            },
            AssistantSettingsContent::Legacy(settings) => {
                settings.dock = Some(dock);
            }
        }
    }

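    /// Sets the default model for the configured provider, replacing the provider
    /// entry when the new model belongs to a different provider. Legacy settings
    /// only track an OpenAI model.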
    pub fn set_model(&mut self, new_model: LanguageModel) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
                    Some(AssistantProviderContent::ZedDotDev {
                        default_model: model,
                    }) => {
                        if let LanguageModel::Cloud(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::OpenAi {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::OpenAi(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::Anthropic {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::Anthropic(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    provider => match new_model {
                        LanguageModel::Cloud(model) => {
                            *provider = Some(AssistantProviderContent::ZedDotDev {
                                default_model: Some(model),
                            })
                        }
                        LanguageModel::OpenAi(model) => {
                            *provider = Some(AssistantProviderContent::OpenAi {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                        LanguageModel::Anthropic(model) => {
                            *provider = Some(AssistantProviderContent::Anthropic {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                    },
                },
            },
            AssistantSettingsContent::Legacy(settings) => {
                if let LanguageModel::OpenAi(model) = new_model {
                    settings.default_open_ai_model = Some(model);
                }
            }
        }
    }
}

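/// Assistant settings content, tagged by schema version. A minimal example of the
/// corresponding `settings.json` entry (the same shape the tests below exercise):
///
/// ```json
/// {
///     "assistant": {
///         "version": "1",
///         "provider": {
///             "name": "zed.dev",
///             "default_model": "custom"
///         }
///     }
/// }
/// ```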
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}

impl Default for VersionedAssistantSettingsContent {
    fn default() -> Self {
        Self::V1(AssistantSettingsContentV1 {
            enabled: None,
            button: None,
            dock: None,
            default_width: None,
            default_height: None,
            provider: None,
        })
    }
}

#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can be the internal `zed.dev` service or an external `openai` or
    /// `anthropic` service, each with its own default model and configuration.
    provider: Option<AssistantProviderContent>,
}

#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}

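// Resolved settings start from `AssistantSettings::default()` and fold in each
// defaults/customization layer in order: scalar fields are overwritten via `merge`,
// while provider overrides merge field-by-field when the provider kind matches and
// replace the provider wholesale when it does not.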
impl Settings for AssistantSettings {
    const KEY: Option<&'static str> = Some("assistant");

    type FileContent = AssistantSettingsContent;

    fn load(
        sources: SettingsSources<Self::FileContent>,
        _: &mut gpui::AppContext,
    ) -> anyhow::Result<Self> {
        let mut settings = AssistantSettings::default();

        for value in sources.defaults_and_customizations() {
            let value = value.upgrade();
            merge(&mut settings.enabled, value.enabled);
            merge(&mut settings.button, value.button);
            merge(&mut settings.dock, value.dock);
            merge(
                &mut settings.default_width,
                value.default_width.map(Into::into),
            );
            merge(
                &mut settings.default_height,
                value.default_height.map(Into::into),
            );
            if let Some(provider) = value.provider.clone() {
                match (&mut settings.provider, provider) {
                    (
                        AssistantProvider::ZedDotDev { model },
                        AssistantProviderContent::ZedDotDev {
                            default_model: model_override,
                        },
                    ) => {
                        merge(model, model_override);
                    }
                    (
                        AssistantProvider::OpenAi {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::OpenAi {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (
                        AssistantProvider::Anthropic {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::Anthropic {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (provider, provider_override) => {
                        *provider = match provider_override {
                            AssistantProviderContent::ZedDotDev {
                                default_model: model,
                            } => AssistantProvider::ZedDotDev {
                                model: model.unwrap_or_default(),
                            },
                            AssistantProviderContent::OpenAi {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::OpenAi {
                                model: model.unwrap_or_default(),
                                api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                            AssistantProviderContent::Anthropic {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::Anthropic {
                                model: model.unwrap_or_default(),
                                api_url: api_url
                                    .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                        };
                    }
                }
            }
        }

        Ok(settings)
    }
}

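/// Overwrites `target` only when an overriding value is present.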
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(value) = value {
        *target = value;
    }
}

#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to gpt-4o.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // Ensure backward-compatibility.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
            }
        );
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );
    }
}