use std::{sync::Arc, time::Duration};

use anthropic::Model as AnthropicModel;
use client::Client;
use completion::{
    AnthropicCompletionProvider, CloudCompletionProvider, CompletionProvider,
    LanguageModelCompletionProvider, OllamaCompletionProvider, OpenAiCompletionProvider,
};
use gpui::{AppContext, Pixels};
use language_model::{CloudModel, LanguageModel};
use ollama::Model as OllamaModel;
use open_ai::Model as OpenAiModel;
use parking_lot::RwLock;
use schemars::{schema::Schema, JsonSchema};
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};

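/// Where the assistant panel is docked.
///
/// Serialized in snake_case, e.g. `"dock": "right"` in `settings.json`.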
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}

#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    ZedDotDev {
        model: CloudModel,
    },
    OpenAi {
        model: OpenAiModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
        available_models: Vec<OpenAiModel>,
    },
    Anthropic {
        model: AnthropicModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    Ollama {
        model: OllamaModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}

impl Default for AssistantProvider {
    fn default() -> Self {
        Self::OpenAi {
            model: OpenAiModel::default(),
            api_url: open_ai::OPEN_AI_API_URL.into(),
            low_speed_timeout_in_seconds: None,
            available_models: Default::default(),
        }
    }
}

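/// Provider configuration as written in the settings file, tagged by `name`.
///
/// An illustrative `openai` entry; field values here are examples, not
/// defaults, and the model string follows the `open_ai::Model` serialization
/// exercised in the tests below:
///
/// ```json
/// {
///     "name": "openai",
///     "default_model": "gpt-4-0613",
///     "api_url": "https://api.openai.com/v1",
///     "low_speed_timeout_in_seconds": 600,
///     "available_models": []
/// }
/// ```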
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    #[serde(rename = "zed.dev")]
    ZedDotDev { default_model: Option<CloudModel> },
    #[serde(rename = "openai")]
    OpenAi {
        default_model: Option<OpenAiModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
        available_models: Option<Vec<OpenAiModel>>,
    },
    #[serde(rename = "anthropic")]
    Anthropic {
        default_model: Option<AnthropicModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "ollama")]
    Ollama {
        default_model: Option<OllamaModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}

#[derive(Debug, Default)]
pub struct AssistantSettings {
    pub enabled: bool,
    pub button: bool,
    pub dock: AssistantDockPosition,
    pub default_width: Pixels,
    pub default_height: Pixels,
    pub provider: AssistantProvider,
}

/// Assistant panel settings
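///
/// Accepts either the versioned form or the legacy flat form; `serde(untagged)`
/// tries the variants in declaration order. Two illustrative `settings.json`
/// snippets (field values are examples taken from this file's defaults and
/// tests, not a complete reference):
///
/// ```json
/// { "assistant": { "version": "1", "dock": "right", "default_width": 640 } }
/// ```
///
/// ```json
/// { "assistant": { "dock": "right", "default_open_ai_model": "gpt-4-0613" } }
/// ```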
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}

impl JsonSchema for AssistantSettingsContent {
    fn schema_name() -> String {
        VersionedAssistantSettingsContent::schema_name()
    }

    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
        VersionedAssistantSettingsContent::json_schema(gen)
    }

    fn is_referenceable() -> bool {
        VersionedAssistantSettingsContent::is_referenceable()
    }
}

impl Default for AssistantSettingsContent {
    fn default() -> Self {
        Self::Versioned(VersionedAssistantSettingsContent::default())
    }
}

impl AssistantSettingsContent {
    fn upgrade(&self) -> AssistantSettingsContentV1 {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
            },
            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
                enabled: None,
                button: settings.button,
                dock: settings.dock,
                default_width: settings.default_width,
                default_height: settings.default_height,
                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
                    Some(AssistantProviderContent::OpenAi {
                        default_model: settings.default_open_ai_model.clone(),
                        api_url: Some(open_ai_api_url.clone()),
                        low_speed_timeout_in_seconds: None,
                        available_models: Some(Default::default()),
                    })
                } else {
                    settings.default_open_ai_model.clone().map(|open_ai_model| {
                        AssistantProviderContent::OpenAi {
                            default_model: Some(open_ai_model),
                            api_url: None,
                            low_speed_timeout_in_seconds: None,
                            available_models: Some(Default::default()),
                        }
                    })
                },
            },
        }
    }

    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => {
                    settings.dock = Some(dock);
                }
            },
            AssistantSettingsContent::Legacy(settings) => {
                settings.dock = Some(dock);
            }
        }
    }

    pub fn set_model(&mut self, new_model: LanguageModel) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
                    Some(AssistantProviderContent::ZedDotDev {
                        default_model: model,
                    }) => {
                        if let LanguageModel::Cloud(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::OpenAi {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::OpenAi(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::Anthropic {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::Anthropic(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::Ollama {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::Ollama(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    provider => match new_model {
                        LanguageModel::Cloud(model) => {
                            *provider = Some(AssistantProviderContent::ZedDotDev {
                                default_model: Some(model),
                            })
                        }
                        LanguageModel::OpenAi(model) => {
                            *provider = Some(AssistantProviderContent::OpenAi {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                                available_models: Some(Default::default()),
                            })
                        }
                        LanguageModel::Anthropic(model) => {
                            *provider = Some(AssistantProviderContent::Anthropic {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                        LanguageModel::Ollama(model) => {
                            *provider = Some(AssistantProviderContent::Ollama {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                    },
                },
            },
            AssistantSettingsContent::Legacy(settings) => {
                if let LanguageModel::OpenAi(model) = new_model {
                    settings.default_open_ai_model = Some(model);
                }
            }
        }
    }
}

#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}

impl Default for VersionedAssistantSettingsContent {
    fn default() -> Self {
        Self::V1(AssistantSettingsContentV1 {
            enabled: None,
            button: None,
            dock: None,
            default_width: None,
            default_height: None,
            provider: None,
        })
    }
}

#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can be the internal `zed.dev` service or an external provider
    /// such as `openai`, `anthropic`, or `ollama`, each with its own default
    /// model and configuration.
    provider: Option<AssistantProviderContent>,
}

#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4o
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
322
323impl Settings for AssistantSettings {
324 const KEY: Option<&'static str> = Some("assistant");
325
326 type FileContent = AssistantSettingsContent;
327
328 fn load(
329 sources: SettingsSources<Self::FileContent>,
330 _: &mut gpui::AppContext,
331 ) -> anyhow::Result<Self> {
332 let mut settings = AssistantSettings::default();
333
334 for value in sources.defaults_and_customizations() {
335 let value = value.upgrade();
336 merge(&mut settings.enabled, value.enabled);
337 merge(&mut settings.button, value.button);
338 merge(&mut settings.dock, value.dock);
339 merge(
340 &mut settings.default_width,
341 value.default_width.map(Into::into),
342 );
343 merge(
344 &mut settings.default_height,
345 value.default_height.map(Into::into),
346 );
347 if let Some(provider) = value.provider.clone() {
348 match (&mut settings.provider, provider) {
349 (
350 AssistantProvider::ZedDotDev { model },
351 AssistantProviderContent::ZedDotDev {
352 default_model: model_override,
353 },
354 ) => {
355 merge(model, model_override);
356 }
357 (
358 AssistantProvider::OpenAi {
359 model,
360 api_url,
361 low_speed_timeout_in_seconds,
362 available_models,
363 },
364 AssistantProviderContent::OpenAi {
365 default_model: model_override,
366 api_url: api_url_override,
367 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
368 available_models: available_models_override,
369 },
370 ) => {
371 merge(model, model_override);
372 merge(api_url, api_url_override);
373 merge(available_models, available_models_override);
374 if let Some(low_speed_timeout_in_seconds_override) =
375 low_speed_timeout_in_seconds_override
376 {
377 *low_speed_timeout_in_seconds =
378 Some(low_speed_timeout_in_seconds_override);
379 }
380 }
381 (
382 AssistantProvider::Ollama {
383 model,
384 api_url,
385 low_speed_timeout_in_seconds,
386 },
387 AssistantProviderContent::Ollama {
388 default_model: model_override,
389 api_url: api_url_override,
390 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
391 },
392 ) => {
393 merge(model, model_override);
394 merge(api_url, api_url_override);
395 if let Some(low_speed_timeout_in_seconds_override) =
396 low_speed_timeout_in_seconds_override
397 {
398 *low_speed_timeout_in_seconds =
399 Some(low_speed_timeout_in_seconds_override);
400 }
401 }
402 (
403 AssistantProvider::Anthropic {
404 model,
405 api_url,
406 low_speed_timeout_in_seconds,
407 },
408 AssistantProviderContent::Anthropic {
409 default_model: model_override,
410 api_url: api_url_override,
411 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
412 },
413 ) => {
414 merge(model, model_override);
415 merge(api_url, api_url_override);
416 if let Some(low_speed_timeout_in_seconds_override) =
417 low_speed_timeout_in_seconds_override
418 {
419 *low_speed_timeout_in_seconds =
420 Some(low_speed_timeout_in_seconds_override);
421 }
422 }
423 (provider, provider_override) => {
424 *provider = match provider_override {
425 AssistantProviderContent::ZedDotDev {
426 default_model: model,
427 } => AssistantProvider::ZedDotDev {
428 model: model.unwrap_or_default(),
429 },
430 AssistantProviderContent::OpenAi {
431 default_model: model,
432 api_url,
433 low_speed_timeout_in_seconds,
434 available_models,
435 } => AssistantProvider::OpenAi {
436 model: model.unwrap_or_default(),
437 api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
438 low_speed_timeout_in_seconds,
439 available_models: available_models.unwrap_or_default(),
440 },
441 AssistantProviderContent::Anthropic {
442 default_model: model,
443 api_url,
444 low_speed_timeout_in_seconds,
445 } => AssistantProvider::Anthropic {
446 model: model.unwrap_or_default(),
447 api_url: api_url
448 .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
449 low_speed_timeout_in_seconds,
450 },
451 AssistantProviderContent::Ollama {
452 default_model: model,
453 api_url,
454 low_speed_timeout_in_seconds,
455 } => AssistantProvider::Ollama {
456 model: model.unwrap_or_default(),
457 api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
458 low_speed_timeout_in_seconds,
459 },
460 };
461 }
462 }
463 }
464 }
465
466 Ok(settings)
467 }
468}
469
/// Overwrites `target` with `value` when an override is present.
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(value) = value {
        *target = value;
    }
}

/// Pushes the current `AssistantSettings` into the active completion provider,
/// updating it in place when the provider kind matches the settings.
pub fn update_completion_provider_settings(
    provider: &mut CompletionProvider,
    version: usize,
    cx: &mut AppContext,
) {
    let updated = match &AssistantSettings::get_global(cx).provider {
        AssistantProvider::ZedDotDev { model } => provider
            .update_current_as::<_, CloudCompletionProvider>(|provider| {
                provider.update(model.clone(), version);
            }),
        AssistantProvider::OpenAi {
            model,
            api_url,
            low_speed_timeout_in_seconds,
            available_models,
        } => provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
            provider.update(
                choose_openai_model(model, available_models),
                api_url.clone(),
                low_speed_timeout_in_seconds.map(Duration::from_secs),
                version,
            );
        }),
        AssistantProvider::Anthropic {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
            provider.update(
                model.clone(),
                api_url.clone(),
                low_speed_timeout_in_seconds.map(Duration::from_secs),
                version,
            );
        }),
        AssistantProvider::Ollama {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
            provider.update(
                model.clone(),
                api_url.clone(),
                low_speed_timeout_in_seconds.map(Duration::from_secs),
                version,
                cx,
            );
        }),
    };

    // The configured provider was changed to a different kind; rebuild the
    // provider from the current settings.
    if updated.is_none() {
        provider.update_provider(|client| create_provider_from_settings(client, version, cx));
    }
}

pub(crate) fn create_provider_from_settings(
    client: Arc<Client>,
    settings_version: usize,
    cx: &mut AppContext,
) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
    match &AssistantSettings::get_global(cx).provider {
        AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
        )),
        AssistantProvider::OpenAi {
            model,
            api_url,
            low_speed_timeout_in_seconds,
            available_models,
        } => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
            choose_openai_model(model, available_models),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
            available_models.clone(),
        ))),
        AssistantProvider::Anthropic {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
        ))),
        AssistantProvider::Ollama {
            model,
            api_url,
            low_speed_timeout_in_seconds,
        } => Arc::new(RwLock::new(OllamaCompletionProvider::new(
            model.clone(),
            api_url.clone(),
            client.http_client(),
            low_speed_timeout_in_seconds.map(Duration::from_secs),
            settings_version,
            cx,
        ))),
    }
}

/// Choose which model to use for the OpenAI provider.
///
/// Prefers the requested model when it appears in `available_models`;
/// otherwise falls back to the first available model, and finally to the
/// requested model itself.
fn choose_openai_model(
    model: &::open_ai::Model,
    available_models: &[::open_ai::Model],
) -> ::open_ai::Model {
    available_models
        .iter()
        .find(|&m| m == model)
        .or_else(|| available_models.first())
        .unwrap_or(model)
        .clone()
}

#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to `OpenAiModel::FourOmni` (gpt-4o).
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // Ensure backward-compatibility.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": {
                                    "custom": {
                                        "name": "custom-provider"
                                    }
                                }
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom {
                    name: "custom-provider".into(),
                    max_tokens: None
                }
            }
        );
    }
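
    // A minimal sketch of `choose_openai_model`'s fallback order, assuming the
    // `open_ai::Model` variants referenced elsewhere in this file: keep the
    // requested model when listed, otherwise take the first available model,
    // otherwise fall back to the requested model itself.
    #[test]
    fn test_choose_openai_model_fallback() {
        // Requested model is available: it is kept.
        assert_eq!(
            choose_openai_model(&OpenAiModel::Four, &[OpenAiModel::Four]),
            OpenAiModel::Four
        );
        // Requested model is unavailable: the first available model wins.
        assert_eq!(
            choose_openai_model(&OpenAiModel::Four, &[OpenAiModel::FourOmni]),
            OpenAiModel::FourOmni
        );
        // No models are available: fall back to the requested model.
        assert_eq!(
            choose_openai_model(&OpenAiModel::Four, &[]),
            OpenAiModel::Four
        );
    }

    // A sketch of the legacy upgrade path: a legacy `openai_api_url` alone
    // should surface as an OpenAi provider entry after `upgrade`.
    #[test]
    fn test_upgrade_legacy_openai_api_url() {
        let legacy = AssistantSettingsContent::Legacy(LegacyAssistantSettingsContent {
            button: None,
            dock: None,
            default_width: None,
            default_height: None,
            default_open_ai_model: None,
            openai_api_url: Some("test-url".into()),
        });
        match legacy.upgrade().provider {
            Some(AssistantProviderContent::OpenAi { api_url, .. }) => {
                assert_eq!(api_url.as_deref(), Some("test-url"));
            }
            other => panic!("expected an OpenAi provider, got {other:?}"),
        }
    }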
}