assistant_settings.rs

  1use std::{sync::Arc, time::Duration};
  2
  3use anthropic::Model as AnthropicModel;
  4use client::Client;
  5use completion::{
  6    AnthropicCompletionProvider, CloudCompletionProvider, CompletionProvider,
  7    LanguageModelCompletionProvider, OllamaCompletionProvider, OpenAiCompletionProvider,
  8};
  9use gpui::{AppContext, Pixels};
 10use language_model::{CloudModel, LanguageModel};
 11use ollama::Model as OllamaModel;
 12use open_ai::Model as OpenAiModel;
 13use parking_lot::RwLock;
 14use schemars::{schema::Schema, JsonSchema};
 15use serde::{Deserialize, Serialize};
 16use settings::{Settings, SettingsSources};
 17
// Where the assistant panel is docked. Serialized in snake_case
// ("left", "right", "bottom"); `Right` is the `Default` variant.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}
 26
/// Resolved assistant provider configuration.
///
/// Produced by `AssistantSettings::load` from `AssistantProviderContent`, with
/// defaults filled in for fields no settings source sets.
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    /// Zed's hosted service (`"zed.dev"` in settings files).
    ZedDotDev {
        model: CloudModel,
    },
    /// OpenAI (`"openai"` in settings files).
    OpenAi {
        model: OpenAiModel,
        /// Base URL of the API; defaults to `open_ai::OPEN_AI_API_URL`.
        api_url: String,
        /// Optional timeout in seconds, converted with `Duration::from_secs`
        /// before it is handed to the completion provider.
        low_speed_timeout_in_seconds: Option<u64>,
        /// Configured model list; `choose_openai_model` picks the active model
        /// against it.
        available_models: Vec<OpenAiModel>,
    },
    /// Anthropic (`"anthropic"` in settings files).
    Anthropic {
        model: AnthropicModel,
        /// Base URL of the API; defaults to `anthropic::ANTHROPIC_API_URL`.
        api_url: String,
        /// Optional timeout in seconds, converted with `Duration::from_secs`.
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// Ollama (`"ollama"` in settings files).
    Ollama {
        model: OllamaModel,
        /// Base URL of the API; defaults to `ollama::OLLAMA_API_URL`.
        api_url: String,
        /// Optional timeout in seconds, converted with `Duration::from_secs`.
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
 49
 50impl Default for AssistantProvider {
 51    fn default() -> Self {
 52        Self::OpenAi {
 53            model: OpenAiModel::default(),
 54            api_url: open_ai::OPEN_AI_API_URL.into(),
 55            low_speed_timeout_in_seconds: None,
 56            available_models: Default::default(),
 57        }
 58    }
 59}
 60
// Provider section of a V1 assistant settings file. Internally tagged by its
// "name" field: "zed.dev", "openai", "anthropic" or "ollama". Every variant sets
// its name explicitly, so `rename_all` does not affect the variant names here.
//
// All fields are optional. `AssistantSettings::load` keeps the previously loaded
// value for an unset field when the provider kind is unchanged, and falls back to
// defaults when a source switches to a different provider kind.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    #[serde(rename = "zed.dev")]
    ZedDotDev { default_model: Option<CloudModel> },
    #[serde(rename = "openai")]
    OpenAi {
        default_model: Option<OpenAiModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
        // Candidate models: `choose_openai_model` keeps `default_model` when it is
        // listed here, otherwise falls back to the first entry.
        available_models: Option<Vec<OpenAiModel>>,
    },
    #[serde(rename = "anthropic")]
    Anthropic {
        default_model: Option<AnthropicModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "ollama")]
    Ollama {
        default_model: Option<OllamaModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
 86
/// Effective assistant settings after every settings source has been merged
/// (see `AssistantSettings::load`).
///
/// NOTE(review): the derived `Default` (false / zero sizes) only seeds `load`;
/// the documented user-facing defaults presumably come from the default
/// settings source — confirm.
#[derive(Debug, Default)]
pub struct AssistantSettings {
    /// Whether the assistant is enabled.
    pub enabled: bool,
    /// Whether the assistant panel button is shown in the status bar.
    pub button: bool,
    /// Where the assistant panel is docked.
    pub dock: AssistantDockPosition,
    /// Panel width when docked to the left or right.
    pub default_width: Pixels,
    /// Panel height when docked to the bottom.
    pub default_height: Pixels,
    /// The active provider and its configuration.
    pub provider: AssistantProvider,
}
 96
/// Assistant panel settings as written in a settings file.
///
/// Deserialization is `untagged`: serde tries the versioned layout first (it
/// requires a `"version"` field) and otherwise falls back to the legacy layout.
/// `upgrade` normalizes either form to `AssistantSettingsContentV1`.
//
// NOTE(review): the legacy layout has only optional fields and does not deny
// unknown fields, so content carrying an unrecognized "version" presumably
// parses as legacy and silently loses its provider section — confirm intended.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}
104
105impl JsonSchema for AssistantSettingsContent {
106    fn schema_name() -> String {
107        VersionedAssistantSettingsContent::schema_name()
108    }
109
110    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
111        VersionedAssistantSettingsContent::json_schema(gen)
112    }
113
114    fn is_referenceable() -> bool {
115        VersionedAssistantSettingsContent::is_referenceable()
116    }
117}
118
119impl Default for AssistantSettingsContent {
120    fn default() -> Self {
121        Self::Versioned(VersionedAssistantSettingsContent::default())
122    }
123}
124
125impl AssistantSettingsContent {
126    fn upgrade(&self) -> AssistantSettingsContentV1 {
127        match self {
128            AssistantSettingsContent::Versioned(settings) => match settings {
129                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
130            },
131            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
132                enabled: None,
133                button: settings.button,
134                dock: settings.dock,
135                default_width: settings.default_width,
136                default_height: settings.default_height,
137                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
138                    Some(AssistantProviderContent::OpenAi {
139                        default_model: settings.default_open_ai_model.clone(),
140                        api_url: Some(open_ai_api_url.clone()),
141                        low_speed_timeout_in_seconds: None,
142                        available_models: Some(Default::default()),
143                    })
144                } else {
145                    settings.default_open_ai_model.clone().map(|open_ai_model| {
146                        AssistantProviderContent::OpenAi {
147                            default_model: Some(open_ai_model),
148                            api_url: None,
149                            low_speed_timeout_in_seconds: None,
150                            available_models: Some(Default::default()),
151                        }
152                    })
153                },
154            },
155        }
156    }
157
158    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
159        match self {
160            AssistantSettingsContent::Versioned(settings) => match settings {
161                VersionedAssistantSettingsContent::V1(settings) => {
162                    settings.dock = Some(dock);
163                }
164            },
165            AssistantSettingsContent::Legacy(settings) => {
166                settings.dock = Some(dock);
167            }
168        }
169    }
170
171    pub fn set_model(&mut self, new_model: LanguageModel) {
172        match self {
173            AssistantSettingsContent::Versioned(settings) => match settings {
174                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
175                    Some(AssistantProviderContent::ZedDotDev {
176                        default_model: model,
177                    }) => {
178                        if let LanguageModel::Cloud(new_model) = new_model {
179                            *model = Some(new_model);
180                        }
181                    }
182                    Some(AssistantProviderContent::OpenAi {
183                        default_model: model,
184                        ..
185                    }) => {
186                        if let LanguageModel::OpenAi(new_model) = new_model {
187                            *model = Some(new_model);
188                        }
189                    }
190                    Some(AssistantProviderContent::Anthropic {
191                        default_model: model,
192                        ..
193                    }) => {
194                        if let LanguageModel::Anthropic(new_model) = new_model {
195                            *model = Some(new_model);
196                        }
197                    }
198                    Some(AssistantProviderContent::Ollama {
199                        default_model: model,
200                        ..
201                    }) => {
202                        if let LanguageModel::Ollama(new_model) = new_model {
203                            *model = Some(new_model);
204                        }
205                    }
206                    provider => match new_model {
207                        LanguageModel::Cloud(model) => {
208                            *provider = Some(AssistantProviderContent::ZedDotDev {
209                                default_model: Some(model),
210                            })
211                        }
212                        LanguageModel::OpenAi(model) => {
213                            *provider = Some(AssistantProviderContent::OpenAi {
214                                default_model: Some(model),
215                                api_url: None,
216                                low_speed_timeout_in_seconds: None,
217                                available_models: Some(Default::default()),
218                            })
219                        }
220                        LanguageModel::Anthropic(model) => {
221                            *provider = Some(AssistantProviderContent::Anthropic {
222                                default_model: Some(model),
223                                api_url: None,
224                                low_speed_timeout_in_seconds: None,
225                            })
226                        }
227                        LanguageModel::Ollama(model) => {
228                            *provider = Some(AssistantProviderContent::Ollama {
229                                default_model: Some(model),
230                                api_url: None,
231                                low_speed_timeout_in_seconds: None,
232                            })
233                        }
234                    },
235                },
236            },
237            AssistantSettingsContent::Legacy(settings) => {
238                if let LanguageModel::OpenAi(model) = new_model {
239                    settings.default_open_ai_model = Some(model);
240                }
241            }
242        }
243    }
244}
245
// Versioned settings layout, internally tagged by its "version" field.
// "1" is the only version defined here.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    // Selected by `"version": "1"`.
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
252
253impl Default for VersionedAssistantSettingsContent {
254    fn default() -> Self {
255        Self::V1(AssistantSettingsContentV1 {
256            enabled: None,
257            button: None,
258            dock: None,
259            default_width: None,
260            default_height: None,
261            provider: None,
262        })
263    }
264}
265
266#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
267pub struct AssistantSettingsContentV1 {
268    /// Whether the Assistant is enabled.
269    ///
270    /// Default: true
271    enabled: Option<bool>,
272    /// Whether to show the assistant panel button in the status bar.
273    ///
274    /// Default: true
275    button: Option<bool>,
276    /// Where to dock the assistant.
277    ///
278    /// Default: right
279    dock: Option<AssistantDockPosition>,
280    /// Default width in pixels when the assistant is docked to the left or right.
281    ///
282    /// Default: 640
283    default_width: Option<f32>,
284    /// Default height in pixels when the assistant is docked to the bottom.
285    ///
286    /// Default: 320
287    default_height: Option<f32>,
288    /// The provider of the assistant service.
289    ///
290    /// This can either be the internal `zed.dev` service or an external `openai` service,
291    /// each with their respective default models and configurations.
292    provider: Option<AssistantProviderContent>,
293}
294
// Pre-versioning (legacy) assistant settings layout, used when content does not
// parse as the versioned layout. It can only configure OpenAI;
// `AssistantSettingsContent::upgrade` converts it to `AssistantSettingsContentV1`.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    // NOTE(review): the "Default" stated below may be stale — this file's settings
    // test expects the effective default OpenAI model to be `OpenAiModel::FourOmni`;
    // confirm before relying on it.
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
322
323impl Settings for AssistantSettings {
324    const KEY: Option<&'static str> = Some("assistant");
325
326    type FileContent = AssistantSettingsContent;
327
328    fn load(
329        sources: SettingsSources<Self::FileContent>,
330        _: &mut gpui::AppContext,
331    ) -> anyhow::Result<Self> {
332        let mut settings = AssistantSettings::default();
333
334        for value in sources.defaults_and_customizations() {
335            let value = value.upgrade();
336            merge(&mut settings.enabled, value.enabled);
337            merge(&mut settings.button, value.button);
338            merge(&mut settings.dock, value.dock);
339            merge(
340                &mut settings.default_width,
341                value.default_width.map(Into::into),
342            );
343            merge(
344                &mut settings.default_height,
345                value.default_height.map(Into::into),
346            );
347            if let Some(provider) = value.provider.clone() {
348                match (&mut settings.provider, provider) {
349                    (
350                        AssistantProvider::ZedDotDev { model },
351                        AssistantProviderContent::ZedDotDev {
352                            default_model: model_override,
353                        },
354                    ) => {
355                        merge(model, model_override);
356                    }
357                    (
358                        AssistantProvider::OpenAi {
359                            model,
360                            api_url,
361                            low_speed_timeout_in_seconds,
362                            available_models,
363                        },
364                        AssistantProviderContent::OpenAi {
365                            default_model: model_override,
366                            api_url: api_url_override,
367                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
368                            available_models: available_models_override,
369                        },
370                    ) => {
371                        merge(model, model_override);
372                        merge(api_url, api_url_override);
373                        merge(available_models, available_models_override);
374                        if let Some(low_speed_timeout_in_seconds_override) =
375                            low_speed_timeout_in_seconds_override
376                        {
377                            *low_speed_timeout_in_seconds =
378                                Some(low_speed_timeout_in_seconds_override);
379                        }
380                    }
381                    (
382                        AssistantProvider::Ollama {
383                            model,
384                            api_url,
385                            low_speed_timeout_in_seconds,
386                        },
387                        AssistantProviderContent::Ollama {
388                            default_model: model_override,
389                            api_url: api_url_override,
390                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
391                        },
392                    ) => {
393                        merge(model, model_override);
394                        merge(api_url, api_url_override);
395                        if let Some(low_speed_timeout_in_seconds_override) =
396                            low_speed_timeout_in_seconds_override
397                        {
398                            *low_speed_timeout_in_seconds =
399                                Some(low_speed_timeout_in_seconds_override);
400                        }
401                    }
402                    (
403                        AssistantProvider::Anthropic {
404                            model,
405                            api_url,
406                            low_speed_timeout_in_seconds,
407                        },
408                        AssistantProviderContent::Anthropic {
409                            default_model: model_override,
410                            api_url: api_url_override,
411                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
412                        },
413                    ) => {
414                        merge(model, model_override);
415                        merge(api_url, api_url_override);
416                        if let Some(low_speed_timeout_in_seconds_override) =
417                            low_speed_timeout_in_seconds_override
418                        {
419                            *low_speed_timeout_in_seconds =
420                                Some(low_speed_timeout_in_seconds_override);
421                        }
422                    }
423                    (provider, provider_override) => {
424                        *provider = match provider_override {
425                            AssistantProviderContent::ZedDotDev {
426                                default_model: model,
427                            } => AssistantProvider::ZedDotDev {
428                                model: model.unwrap_or_default(),
429                            },
430                            AssistantProviderContent::OpenAi {
431                                default_model: model,
432                                api_url,
433                                low_speed_timeout_in_seconds,
434                                available_models,
435                            } => AssistantProvider::OpenAi {
436                                model: model.unwrap_or_default(),
437                                api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
438                                low_speed_timeout_in_seconds,
439                                available_models: available_models.unwrap_or_default(),
440                            },
441                            AssistantProviderContent::Anthropic {
442                                default_model: model,
443                                api_url,
444                                low_speed_timeout_in_seconds,
445                            } => AssistantProvider::Anthropic {
446                                model: model.unwrap_or_default(),
447                                api_url: api_url
448                                    .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
449                                low_speed_timeout_in_seconds,
450                            },
451                            AssistantProviderContent::Ollama {
452                                default_model: model,
453                                api_url,
454                                low_speed_timeout_in_seconds,
455                            } => AssistantProvider::Ollama {
456                                model: model.unwrap_or_default(),
457                                api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
458                                low_speed_timeout_in_seconds,
459                            },
460                        };
461                    }
462                }
463            }
464        }
465
466        Ok(settings)
467    }
468}
469
/// Overwrites `target` with the payload of `value` when it is `Some`;
/// a `None` leaves `target` untouched.
fn merge<T>(target: &mut T, value: Option<T>) {
    let Some(replacement) = value else {
        return;
    };
    *target = replacement;
}
475
476pub fn update_completion_provider_settings(
477    provider: &mut CompletionProvider,
478    version: usize,
479    cx: &mut AppContext,
480) {
481    let updated = match &AssistantSettings::get_global(cx).provider {
482        AssistantProvider::ZedDotDev { model } => provider
483            .update_current_as::<_, CloudCompletionProvider>(|provider| {
484                provider.update(model.clone(), version);
485            }),
486        AssistantProvider::OpenAi {
487            model,
488            api_url,
489            low_speed_timeout_in_seconds,
490            available_models,
491        } => provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
492            provider.update(
493                choose_openai_model(&model, &available_models),
494                api_url.clone(),
495                low_speed_timeout_in_seconds.map(Duration::from_secs),
496                version,
497            );
498        }),
499        AssistantProvider::Anthropic {
500            model,
501            api_url,
502            low_speed_timeout_in_seconds,
503        } => provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
504            provider.update(
505                model.clone(),
506                api_url.clone(),
507                low_speed_timeout_in_seconds.map(Duration::from_secs),
508                version,
509            );
510        }),
511        AssistantProvider::Ollama {
512            model,
513            api_url,
514            low_speed_timeout_in_seconds,
515        } => provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
516            provider.update(
517                model.clone(),
518                api_url.clone(),
519                low_speed_timeout_in_seconds.map(Duration::from_secs),
520                version,
521                cx,
522            );
523        }),
524    };
525
526    // Previously configured provider was changed to another one
527    if updated.is_none() {
528        provider.update_provider(|client| create_provider_from_settings(client, version, cx));
529    }
530}
531
532pub(crate) fn create_provider_from_settings(
533    client: Arc<Client>,
534    settings_version: usize,
535    cx: &mut AppContext,
536) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
537    match &AssistantSettings::get_global(cx).provider {
538        AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
539            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
540        )),
541        AssistantProvider::OpenAi {
542            model,
543            api_url,
544            low_speed_timeout_in_seconds,
545            available_models,
546        } => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
547            choose_openai_model(&model, &available_models),
548            api_url.clone(),
549            client.http_client(),
550            low_speed_timeout_in_seconds.map(Duration::from_secs),
551            settings_version,
552            available_models.clone(),
553        ))),
554        AssistantProvider::Anthropic {
555            model,
556            api_url,
557            low_speed_timeout_in_seconds,
558        } => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
559            model.clone(),
560            api_url.clone(),
561            client.http_client(),
562            low_speed_timeout_in_seconds.map(Duration::from_secs),
563            settings_version,
564        ))),
565        AssistantProvider::Ollama {
566            model,
567            api_url,
568            low_speed_timeout_in_seconds,
569        } => Arc::new(RwLock::new(OllamaCompletionProvider::new(
570            model.clone(),
571            api_url.clone(),
572            client.http_client(),
573            low_speed_timeout_in_seconds.map(Duration::from_secs),
574            settings_version,
575            cx,
576        ))),
577    }
578}
579
580/// Choose which model to use for openai provider.
581/// If the model is not available, try to use the first available model, or fallback to the original model.
582fn choose_openai_model(
583    model: &::open_ai::Model,
584    available_models: &[::open_ai::Model],
585) -> ::open_ai::Model {
586    available_models
587        .iter()
588        .find(|&m| m == model)
589        .or_else(|| available_models.first())
590        .unwrap_or_else(|| model)
591        .clone()
592}
593
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to OpenAI's `FourOmni` model.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // Ensure backward-compatibility.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": {
                                    "custom": {
                                        "name": "custom-provider"
                                    }
                                }
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom {
                    name: "custom-provider".into(),
                    max_tokens: None
                }
            }
        );
    }

    #[test]
    fn test_upgrade_legacy_settings_with_model_and_url() {
        // A legacy file setting both OpenAI options becomes one OpenAI provider.
        let legacy = AssistantSettingsContent::Legacy(LegacyAssistantSettingsContent {
            button: None,
            dock: None,
            default_width: None,
            default_height: None,
            default_open_ai_model: Some(OpenAiModel::Four),
            openai_api_url: Some("test-url".into()),
        });
        assert_eq!(
            legacy.upgrade().provider,
            Some(AssistantProviderContent::OpenAi {
                default_model: Some(OpenAiModel::Four),
                api_url: Some("test-url".into()),
                low_speed_timeout_in_seconds: None,
                available_models: Some(Vec::new()),
            })
        );

        // A legacy file setting neither option configures no provider.
        let empty = AssistantSettingsContent::Legacy(LegacyAssistantSettingsContent {
            button: Some(false),
            dock: None,
            default_width: None,
            default_height: None,
            default_open_ai_model: None,
            openai_api_url: None,
        });
        assert_eq!(empty.upgrade().provider, None);
    }

    #[test]
    fn test_choose_openai_model() {
        let available = vec![OpenAiModel::Four, OpenAiModel::FourOmni];
        // The configured model is kept when it is available.
        assert_eq!(
            choose_openai_model(&OpenAiModel::FourOmni, &available),
            OpenAiModel::FourOmni
        );
        // Otherwise the first available model is used.
        assert_eq!(
            choose_openai_model(&OpenAiModel::FourOmni, &[OpenAiModel::Four]),
            OpenAiModel::Four
        );
        // With no available models, the configured model is used as-is.
        assert_eq!(
            choose_openai_model(&OpenAiModel::Four, &[]),
            OpenAiModel::Four
        );
    }
}