1use std::fmt;
2
3use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
4pub use anthropic::Model as AnthropicModel;
5use gpui::Pixels;
6pub use ollama::Model as OllamaModel;
7pub use open_ai::Model as OpenAiModel;
8use schemars::{
9 schema::{InstanceType, Metadata, Schema, SchemaObject},
10 JsonSchema,
11};
12use serde::{
13 de::{self, Visitor},
14 Deserialize, Deserializer, Serialize, Serializer,
15};
16use settings::{Settings, SettingsSources};
17use strum::{EnumIter, IntoEnumIterator};
18
/// A model served through Zed's hosted (zed.dev) provider.
///
/// Serialized/deserialized as the model's string id (see [`CloudModel::id`]);
/// ids that don't match a named variant round-trip through `Custom`.
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
    Gpt3Point5Turbo,
    Gpt4,
    Gpt4Turbo,
    #[default]
    Gpt4Omni,
    Gpt4OmniMini,
    Claude3_5Sonnet,
    Claude3Opus,
    Claude3Sonnet,
    Claude3Haiku,
    Gemini15Pro,
    Gemini15Flash,
    /// Any model id not covered by the named variants above.
    Custom(String),
}
35
36impl Serialize for CloudModel {
37 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
38 where
39 S: Serializer,
40 {
41 serializer.serialize_str(self.id())
42 }
43}
44
45impl<'de> Deserialize<'de> for CloudModel {
46 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
47 where
48 D: Deserializer<'de>,
49 {
50 struct ZedDotDevModelVisitor;
51
52 impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
53 type Value = CloudModel;
54
55 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
56 formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
57 }
58
59 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
60 where
61 E: de::Error,
62 {
63 let model = CloudModel::iter()
64 .find(|model| model.id() == value)
65 .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
66 Ok(model)
67 }
68 }
69
70 deserializer.deserialize_str(ZedDotDevModelVisitor)
71 }
72}
73
74impl JsonSchema for CloudModel {
75 fn schema_name() -> String {
76 "ZedDotDevModel".to_owned()
77 }
78
79 fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
80 let variants = CloudModel::iter()
81 .filter_map(|model| {
82 let id = model.id();
83 if id.is_empty() {
84 None
85 } else {
86 Some(id.to_string())
87 }
88 })
89 .collect::<Vec<_>>();
90 Schema::Object(SchemaObject {
91 instance_type: Some(InstanceType::String.into()),
92 enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
93 metadata: Some(Box::new(Metadata {
94 title: Some("ZedDotDevModel".to_owned()),
95 default: Some(CloudModel::default().id().into()),
96 examples: variants.into_iter().map(Into::into).collect(),
97 ..Default::default()
98 })),
99 ..Default::default()
100 })
101 }
102}
103
104impl CloudModel {
105 pub fn id(&self) -> &str {
106 match self {
107 Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
108 Self::Gpt4 => "gpt-4",
109 Self::Gpt4Turbo => "gpt-4-turbo-preview",
110 Self::Gpt4Omni => "gpt-4o",
111 Self::Gpt4OmniMini => "gpt-4o-mini",
112 Self::Claude3_5Sonnet => "claude-3-5-sonnet",
113 Self::Claude3Opus => "claude-3-opus",
114 Self::Claude3Sonnet => "claude-3-sonnet",
115 Self::Claude3Haiku => "claude-3-haiku",
116 Self::Gemini15Pro => "gemini-1.5-pro",
117 Self::Gemini15Flash => "gemini-1.5-flash",
118 Self::Custom(id) => id,
119 }
120 }
121
122 pub fn display_name(&self) -> &str {
123 match self {
124 Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
125 Self::Gpt4 => "GPT 4",
126 Self::Gpt4Turbo => "GPT 4 Turbo",
127 Self::Gpt4Omni => "GPT 4 Omni",
128 Self::Gpt4OmniMini => "GPT 4 Omni Mini",
129 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
130 Self::Claude3Opus => "Claude 3 Opus",
131 Self::Claude3Sonnet => "Claude 3 Sonnet",
132 Self::Claude3Haiku => "Claude 3 Haiku",
133 Self::Gemini15Pro => "Gemini 1.5 Pro",
134 Self::Gemini15Flash => "Gemini 1.5 Flash",
135 Self::Custom(id) => id.as_str(),
136 }
137 }
138
139 pub fn max_token_count(&self) -> usize {
140 match self {
141 Self::Gpt3Point5Turbo => 2048,
142 Self::Gpt4 => 4096,
143 Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
144 Self::Gpt4OmniMini => 128000,
145 Self::Claude3_5Sonnet
146 | Self::Claude3Opus
147 | Self::Claude3Sonnet
148 | Self::Claude3Haiku => 200000,
149 Self::Gemini15Pro => 128000,
150 Self::Gemini15Flash => 32000,
151 Self::Custom(_) => 4096, // TODO: Make this configurable
152 }
153 }
154
155 pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
156 match self {
157 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
158 preprocess_anthropic_request(request)
159 }
160 _ => {}
161 }
162 }
163}
164
/// Which edge of the workspace the assistant panel docks to.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}
173
/// Fully-resolved provider configuration used at runtime.
///
/// Unlike [`AssistantProviderContent`] (the settings-file shape), every
/// required field here holds a concrete value — defaults have already been
/// applied during settings load.
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    /// Zed's hosted zed.dev service.
    ZedDotDev {
        model: CloudModel,
    },
    /// OpenAI, or an OpenAI-compatible endpoint selected via `api_url`.
    OpenAi {
        model: OpenAiModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
        available_models: Vec<OpenAiModel>,
    },
    /// Anthropic's API.
    Anthropic {
        model: AnthropicModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// An Ollama server reachable at `api_url`.
    Ollama {
        model: OllamaModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
196
197impl Default for AssistantProvider {
198 fn default() -> Self {
199 Self::OpenAi {
200 model: OpenAiModel::default(),
201 api_url: open_ai::OPEN_AI_API_URL.into(),
202 low_speed_timeout_in_seconds: None,
203 available_models: Default::default(),
204 }
205 }
206}
207
/// Provider configuration as it appears in settings files.
///
/// Every field is optional; `None`s are filled from defaults when the
/// settings are resolved into [`AssistantProvider`]. The serde `name` tag
/// selects the provider (e.g. `"name": "openai"`).
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    #[serde(rename = "zed.dev")]
    ZedDotDev { default_model: Option<CloudModel> },
    #[serde(rename = "openai")]
    OpenAi {
        default_model: Option<OpenAiModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
        available_models: Option<Vec<OpenAiModel>>,
    },
    #[serde(rename = "anthropic")]
    Anthropic {
        default_model: Option<AnthropicModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "ollama")]
    Ollama {
        default_model: Option<OllamaModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
233
/// Resolved assistant settings used by the rest of the app.
///
/// Built by [`Settings::load`] from the on-disk shapes
/// ([`AssistantSettingsContent`]), with defaults already applied.
#[derive(Debug, Default)]
pub struct AssistantSettings {
    pub enabled: bool,
    pub button: bool,
    pub dock: AssistantDockPosition,
    pub default_width: Pixels,
    pub default_height: Pixels,
    pub provider: AssistantProvider,
}
243
/// Assistant panel settings
///
/// Deserialized untagged: input is tried against the versioned shape first
/// (which requires a `"version"` field), then falls back to the legacy shape
/// for backward compatibility.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    /// Current, explicitly-versioned settings format.
    Versioned(VersionedAssistantSettingsContent),
    /// Pre-versioning settings format, still accepted from old files.
    Legacy(LegacyAssistantSettingsContent),
}
251
impl JsonSchema for AssistantSettingsContent {
    // The advertised JSON schema intentionally describes only the versioned
    // format; the legacy shape is still accepted at deserialization time but
    // is not surfaced to editors/autocomplete.
    fn schema_name() -> String {
        VersionedAssistantSettingsContent::schema_name()
    }

    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
        VersionedAssistantSettingsContent::json_schema(gen)
    }

    fn is_referenceable() -> bool {
        VersionedAssistantSettingsContent::is_referenceable()
    }
}
265
266impl Default for AssistantSettingsContent {
267 fn default() -> Self {
268 Self::Versioned(VersionedAssistantSettingsContent::default())
269 }
270}
271
impl AssistantSettingsContent {
    /// Normalizes any settings shape (versioned or legacy) into the latest
    /// versioned form, `AssistantSettingsContentV1`.
    fn upgrade(&self) -> AssistantSettingsContentV1 {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
            },
            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
                enabled: None,
                button: settings.button,
                dock: settings.dock,
                default_width: settings.default_width,
                default_height: settings.default_height,
                // The legacy format only knew about OpenAI; synthesize an
                // OpenAI provider entry if either legacy field was set,
                // otherwise leave the provider unset.
                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
                    Some(AssistantProviderContent::OpenAi {
                        default_model: settings.default_open_ai_model.clone(),
                        api_url: Some(open_ai_api_url.clone()),
                        low_speed_timeout_in_seconds: None,
                        available_models: Some(Default::default()),
                    })
                } else {
                    settings.default_open_ai_model.clone().map(|open_ai_model| {
                        AssistantProviderContent::OpenAi {
                            default_model: Some(open_ai_model),
                            api_url: None,
                            low_speed_timeout_in_seconds: None,
                            available_models: Some(Default::default()),
                        }
                    })
                },
            },
        }
    }

    /// Records a new dock position in whichever settings shape is present.
    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => {
                    settings.dock = Some(dock);
                }
            },
            AssistantSettingsContent::Legacy(settings) => {
                settings.dock = Some(dock);
            }
        }
    }

    /// Records `new_model` as the default model.
    ///
    /// If the configured provider already matches the model's provider, only
    /// its `default_model` is updated (other fields are preserved); otherwise
    /// the provider entry is replaced wholesale with one for the new model's
    /// provider. Legacy settings can only store OpenAI models — models for
    /// other providers are silently ignored there.
    pub fn set_model(&mut self, new_model: LanguageModel) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
                    Some(AssistantProviderContent::ZedDotDev {
                        default_model: model,
                    }) => {
                        if let LanguageModel::Cloud(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::OpenAi {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::OpenAi(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::Anthropic {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::Anthropic(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::Ollama {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::Ollama(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    // No provider configured, or it doesn't match the model's
                    // provider: replace the whole provider entry.
                    provider => match new_model {
                        LanguageModel::Cloud(model) => {
                            *provider = Some(AssistantProviderContent::ZedDotDev {
                                default_model: Some(model),
                            })
                        }
                        LanguageModel::OpenAi(model) => {
                            *provider = Some(AssistantProviderContent::OpenAi {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                                available_models: Some(Default::default()),
                            })
                        }
                        LanguageModel::Anthropic(model) => {
                            *provider = Some(AssistantProviderContent::Anthropic {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                        LanguageModel::Ollama(model) => {
                            *provider = Some(AssistantProviderContent::Ollama {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                    },
                },
            },
            AssistantSettingsContent::Legacy(settings) => {
                if let LanguageModel::OpenAi(model) = new_model {
                    settings.default_open_ai_model = Some(model);
                }
            }
        }
    }
}
392
/// Settings content tagged with an explicit `"version"` field, so future
/// format revisions can be added as new variants without breaking old files.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
399
400impl Default for VersionedAssistantSettingsContent {
401 fn default() -> Self {
402 Self::V1(AssistantSettingsContentV1 {
403 enabled: None,
404 button: None,
405 dock: None,
406 default_width: None,
407 default_height: None,
408 provider: None,
409 })
410 }
411}
412
/// Version 1 of the assistant settings format.
///
/// Every field is optional; unset fields fall back to defaults when the
/// settings are resolved into [`AssistantSettings`].
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can either be the internal `zed.dev` service or an external `openai` service,
    /// each with their respective default models and configurations.
    provider: Option<AssistantProviderContent>,
}
441
/// The pre-versioning assistant settings format, retained so existing user
/// settings files keep working. Superseded by [`AssistantSettingsContentV1`];
/// see [`AssistantSettingsContent::upgrade`] for the conversion.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
469
impl Settings for AssistantSettings {
    const KEY: Option<&'static str> = Some("assistant");

    type FileContent = AssistantSettingsContent;

    /// Resolves layered settings (defaults first, then user customizations)
    /// into a concrete [`AssistantSettings`].
    ///
    /// Each source is upgraded to the V1 shape and merged in order, so later
    /// sources override earlier ones field-by-field. For the provider: if a
    /// source configures the *same* provider kind as the accumulated value,
    /// only the fields it sets are overridden; if it configures a *different*
    /// kind, the provider is replaced wholesale with that source's values
    /// (missing fields filled from provider defaults).
    fn load(
        sources: SettingsSources<Self::FileContent>,
        _: &mut gpui::AppContext,
    ) -> anyhow::Result<Self> {
        let mut settings = AssistantSettings::default();

        for value in sources.defaults_and_customizations() {
            let value = value.upgrade();
            merge(&mut settings.enabled, value.enabled);
            merge(&mut settings.button, value.button);
            merge(&mut settings.dock, value.dock);
            merge(
                &mut settings.default_width,
                value.default_width.map(Into::into),
            );
            merge(
                &mut settings.default_height,
                value.default_height.map(Into::into),
            );
            if let Some(provider) = value.provider.clone() {
                match (&mut settings.provider, provider) {
                    // Same provider kind: merge overrides field-by-field.
                    (
                        AssistantProvider::ZedDotDev { model },
                        AssistantProviderContent::ZedDotDev {
                            default_model: model_override,
                        },
                    ) => {
                        merge(model, model_override);
                    }
                    (
                        AssistantProvider::OpenAi {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                            available_models,
                        },
                        AssistantProviderContent::OpenAi {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                            available_models: available_models_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        merge(available_models, available_models_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (
                        AssistantProvider::Ollama {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::Ollama {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (
                        AssistantProvider::Anthropic {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::Anthropic {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    // Provider kind changed: replace the provider entirely,
                    // filling unset fields with that provider's defaults.
                    (provider, provider_override) => {
                        *provider = match provider_override {
                            AssistantProviderContent::ZedDotDev {
                                default_model: model,
                            } => AssistantProvider::ZedDotDev {
                                model: model.unwrap_or_default(),
                            },
                            AssistantProviderContent::OpenAi {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                                available_models,
                            } => AssistantProvider::OpenAi {
                                model: model.unwrap_or_default(),
                                api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
                                low_speed_timeout_in_seconds,
                                available_models: available_models.unwrap_or_default(),
                            },
                            AssistantProviderContent::Anthropic {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::Anthropic {
                                model: model.unwrap_or_default(),
                                api_url: api_url
                                    .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                            AssistantProviderContent::Ollama {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::Ollama {
                                model: model.unwrap_or_default(),
                                api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                        };
                    }
                }
            }
        }

        Ok(settings)
    }
}
616
/// Overwrites `target` with `value` when one is present; leaves `target`
/// untouched otherwise. Used to layer optional settings over defaults.
fn merge<T>(target: &mut T, value: Option<T>) {
    match value {
        Some(value) => *target = value,
        None => {}
    }
}
622
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to OpenAI with GPT-4 Omni (`OpenAiModel::FourOmni`).
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // Ensure backward-compatibility: legacy `openai_api_url` is honored.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );
        // Legacy `default_open_ai_model` is also still honored.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );
    }
}