1use std::fmt;
2
3pub use anthropic::Model as AnthropicModel;
4use gpui::Pixels;
5pub use open_ai::Model as OpenAiModel;
6use schemars::{
7 schema::{InstanceType, Metadata, Schema, SchemaObject},
8 JsonSchema,
9};
10use serde::{
11 de::{self, Visitor},
12 Deserialize, Deserializer, Serialize, Serializer,
13};
14use settings::{Settings, SettingsSources};
15use strum::{EnumIter, IntoEnumIterator};
16
17use crate::LanguageModel;
18
/// A model offered through the `zed.dev` assistant provider.
///
/// In settings files each variant is written as the string returned by `id()`; any other
/// string deserializes to [`ZedDotDevModel::Custom`] holding that string.
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum ZedDotDevModel {
    Gpt3Point5Turbo,
    Gpt4,
    Gpt4Turbo,
    #[default]
    Gpt4Omni,
    Claude3Opus,
    Claude3Sonnet,
    Claude3Haiku,
    /// A model not known to this enum, identified by its raw id.
    ///
    /// Note: `EnumIter` yields this variant with an empty id, so code iterating over the
    /// variants has to skip or tolerate it.
    Custom(String),
}
31
32impl Serialize for ZedDotDevModel {
33 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
34 where
35 S: Serializer,
36 {
37 serializer.serialize_str(self.id())
38 }
39}
40
41impl<'de> Deserialize<'de> for ZedDotDevModel {
42 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
43 where
44 D: Deserializer<'de>,
45 {
46 struct ZedDotDevModelVisitor;
47
48 impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
49 type Value = ZedDotDevModel;
50
51 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
52 formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
53 }
54
55 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
56 where
57 E: de::Error,
58 {
59 let model = ZedDotDevModel::iter()
60 .find(|model| model.id() == value)
61 .unwrap_or_else(|| ZedDotDevModel::Custom(value.to_string()));
62 Ok(model)
63 }
64 }
65
66 deserializer.deserialize_str(ZedDotDevModelVisitor)
67 }
68}
69
70impl JsonSchema for ZedDotDevModel {
71 fn schema_name() -> String {
72 "ZedDotDevModel".to_owned()
73 }
74
75 fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
76 let variants = ZedDotDevModel::iter()
77 .filter_map(|model| {
78 let id = model.id();
79 if id.is_empty() {
80 None
81 } else {
82 Some(id.to_string())
83 }
84 })
85 .collect::<Vec<_>>();
86 Schema::Object(SchemaObject {
87 instance_type: Some(InstanceType::String.into()),
88 enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
89 metadata: Some(Box::new(Metadata {
90 title: Some("ZedDotDevModel".to_owned()),
91 default: Some(ZedDotDevModel::default().id().into()),
92 examples: variants.into_iter().map(Into::into).collect(),
93 ..Default::default()
94 })),
95 ..Default::default()
96 })
97 }
98}
99
100impl ZedDotDevModel {
101 pub fn id(&self) -> &str {
102 match self {
103 Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
104 Self::Gpt4 => "gpt-4",
105 Self::Gpt4Turbo => "gpt-4-turbo-preview",
106 Self::Gpt4Omni => "gpt-4o",
107 Self::Claude3Opus => "claude-3-opus",
108 Self::Claude3Sonnet => "claude-3-sonnet",
109 Self::Claude3Haiku => "claude-3-haiku",
110 Self::Custom(id) => id,
111 }
112 }
113
114 pub fn display_name(&self) -> &str {
115 match self {
116 Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
117 Self::Gpt4 => "GPT 4",
118 Self::Gpt4Turbo => "GPT 4 Turbo",
119 Self::Gpt4Omni => "GPT 4 Omni",
120 Self::Claude3Opus => "Claude 3 Opus",
121 Self::Claude3Sonnet => "Claude 3 Sonnet",
122 Self::Claude3Haiku => "Claude 3 Haiku",
123 Self::Custom(id) => id.as_str(),
124 }
125 }
126
127 pub fn max_token_count(&self) -> usize {
128 match self {
129 Self::Gpt3Point5Turbo => 2048,
130 Self::Gpt4 => 4096,
131 Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
132 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
133 Self::Custom(_) => 4096, // TODO: Make this configurable
134 }
135 }
136}
137
/// Where the assistant panel is docked.
///
/// Written in settings as `"left"`, `"right"`, or `"bottom"`; defaults to `right`.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}
146
/// The effective assistant provider, resolved by `AssistantSettings::load` from every
/// settings source. Unlike `AssistantProviderContent`, every field has a concrete value.
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    /// The `zed.dev` service.
    ZedDotDev {
        model: ZedDotDevModel,
    },
    /// An OpenAI-compatible API reached at `api_url`.
    OpenAi {
        model: OpenAiModel,
        /// API base URL; defaults to `open_ai::OPEN_AI_API_URL`.
        api_url: String,
        /// Low-speed timeout in seconds, if one was configured.
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// The Anthropic API reached at `api_url`.
    Anthropic {
        model: AnthropicModel,
        /// API base URL; defaults to `anthropic::ANTHROPIC_API_URL`.
        api_url: String,
        /// Low-speed timeout in seconds, if one was configured.
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
163
164impl Default for AssistantProvider {
165 fn default() -> Self {
166 Self::OpenAi {
167 model: OpenAiModel::default(),
168 api_url: open_ai::OPEN_AI_API_URL.into(),
169 low_speed_timeout_in_seconds: None,
170 }
171 }
172}
173
/// A provider as written in a settings file, selected by its `"name"` field
/// (`"zed.dev"`, `"openai"`, or `"anthropic"`).
///
/// Every field is optional. When the named provider is already active, unset fields keep
/// their current values; otherwise they fall back to the provider's defaults.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    #[serde(rename = "zed.dev")]
    ZedDotDev {
        default_model: Option<ZedDotDevModel>,
    },
    #[serde(rename = "openai")]
    OpenAi {
        default_model: Option<OpenAiModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "anthropic")]
    Anthropic {
        default_model: Option<AnthropicModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
194
/// The effective assistant settings, produced by `AssistantSettings::load` by applying the
/// default settings and then each customization in order.
#[derive(Debug, Default)]
pub struct AssistantSettings {
    pub enabled: bool,
    pub button: bool,
    pub dock: AssistantDockPosition,
    pub default_width: Pixels,
    pub default_height: Pixels,
    pub provider: AssistantProvider,
}
204
/// Assistant panel settings, as written in a settings file.
///
/// Deserialized untagged, trying variants in order: the versioned format (which requires a
/// `"version"` field) first, then the legacy unversioned format.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}
212
213impl JsonSchema for AssistantSettingsContent {
214 fn schema_name() -> String {
215 VersionedAssistantSettingsContent::schema_name()
216 }
217
218 fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
219 VersionedAssistantSettingsContent::json_schema(gen)
220 }
221
222 fn is_referenceable() -> bool {
223 VersionedAssistantSettingsContent::is_referenceable()
224 }
225}
226
227impl Default for AssistantSettingsContent {
228 fn default() -> Self {
229 Self::Versioned(VersionedAssistantSettingsContent::default())
230 }
231}
232
233impl AssistantSettingsContent {
234 fn upgrade(&self) -> AssistantSettingsContentV1 {
235 match self {
236 AssistantSettingsContent::Versioned(settings) => match settings {
237 VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
238 },
239 AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
240 enabled: None,
241 button: settings.button,
242 dock: settings.dock,
243 default_width: settings.default_width,
244 default_height: settings.default_height,
245 provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
246 Some(AssistantProviderContent::OpenAi {
247 default_model: settings.default_open_ai_model.clone(),
248 api_url: Some(open_ai_api_url.clone()),
249 low_speed_timeout_in_seconds: None,
250 })
251 } else {
252 settings.default_open_ai_model.clone().map(|open_ai_model| {
253 AssistantProviderContent::OpenAi {
254 default_model: Some(open_ai_model),
255 api_url: None,
256 low_speed_timeout_in_seconds: None,
257 }
258 })
259 },
260 },
261 }
262 }
263
264 pub fn set_dock(&mut self, dock: AssistantDockPosition) {
265 match self {
266 AssistantSettingsContent::Versioned(settings) => match settings {
267 VersionedAssistantSettingsContent::V1(settings) => {
268 settings.dock = Some(dock);
269 }
270 },
271 AssistantSettingsContent::Legacy(settings) => {
272 settings.dock = Some(dock);
273 }
274 }
275 }
276
277 pub fn set_model(&mut self, new_model: LanguageModel) {
278 match self {
279 AssistantSettingsContent::Versioned(settings) => match settings {
280 VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
281 Some(AssistantProviderContent::ZedDotDev {
282 default_model: model,
283 }) => {
284 if let LanguageModel::ZedDotDev(new_model) = new_model {
285 *model = Some(new_model);
286 }
287 }
288 Some(AssistantProviderContent::OpenAi {
289 default_model: model,
290 ..
291 }) => {
292 if let LanguageModel::OpenAi(new_model) = new_model {
293 *model = Some(new_model);
294 }
295 }
296 Some(AssistantProviderContent::Anthropic {
297 default_model: model,
298 ..
299 }) => {
300 if let LanguageModel::Anthropic(new_model) = new_model {
301 *model = Some(new_model);
302 }
303 }
304 provider => match new_model {
305 LanguageModel::ZedDotDev(model) => {
306 *provider = Some(AssistantProviderContent::ZedDotDev {
307 default_model: Some(model),
308 })
309 }
310 LanguageModel::OpenAi(model) => {
311 *provider = Some(AssistantProviderContent::OpenAi {
312 default_model: Some(model),
313 api_url: None,
314 low_speed_timeout_in_seconds: None,
315 })
316 }
317 LanguageModel::Anthropic(model) => {
318 *provider = Some(AssistantProviderContent::Anthropic {
319 default_model: Some(model),
320 api_url: None,
321 low_speed_timeout_in_seconds: None,
322 })
323 }
324 },
325 },
326 },
327 AssistantSettingsContent::Legacy(settings) => {
328 if let LanguageModel::OpenAi(model) = new_model {
329 settings.default_open_ai_model = Some(model);
330 }
331 }
332 }
333 }
334}
335
/// Versioned assistant settings, selected by the `"version"` field (currently only `"1"`).
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
342
343impl Default for VersionedAssistantSettingsContent {
344 fn default() -> Self {
345 Self::V1(AssistantSettingsContentV1 {
346 enabled: None,
347 button: None,
348 dock: None,
349 default_width: None,
350 default_height: None,
351 provider: None,
352 })
353 }
354}
355
/// Version 1 of the assistant settings. Unset fields leave earlier settings sources in effect.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can be the internal `zed.dev` service, or an external `openai` or `anthropic`
    /// service, each with their respective default models and configurations.
    provider: Option<AssistantProviderContent>,
}
384
/// The legacy, unversioned assistant settings format. It can only configure the OpenAI
/// provider and is converted to the V1 format before being applied.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4o
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
412
413impl Settings for AssistantSettings {
414 const KEY: Option<&'static str> = Some("assistant");
415
416 type FileContent = AssistantSettingsContent;
417
418 fn load(
419 sources: SettingsSources<Self::FileContent>,
420 _: &mut gpui::AppContext,
421 ) -> anyhow::Result<Self> {
422 let mut settings = AssistantSettings::default();
423
424 for value in sources.defaults_and_customizations() {
425 let value = value.upgrade();
426 merge(&mut settings.enabled, value.enabled);
427 merge(&mut settings.button, value.button);
428 merge(&mut settings.dock, value.dock);
429 merge(
430 &mut settings.default_width,
431 value.default_width.map(Into::into),
432 );
433 merge(
434 &mut settings.default_height,
435 value.default_height.map(Into::into),
436 );
437 if let Some(provider) = value.provider.clone() {
438 match (&mut settings.provider, provider) {
439 (
440 AssistantProvider::ZedDotDev { model },
441 AssistantProviderContent::ZedDotDev {
442 default_model: model_override,
443 },
444 ) => {
445 merge(model, model_override);
446 }
447 (
448 AssistantProvider::OpenAi {
449 model,
450 api_url,
451 low_speed_timeout_in_seconds,
452 },
453 AssistantProviderContent::OpenAi {
454 default_model: model_override,
455 api_url: api_url_override,
456 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
457 },
458 ) => {
459 merge(model, model_override);
460 merge(api_url, api_url_override);
461 if let Some(low_speed_timeout_in_seconds_override) =
462 low_speed_timeout_in_seconds_override
463 {
464 *low_speed_timeout_in_seconds =
465 Some(low_speed_timeout_in_seconds_override);
466 }
467 }
468 (
469 AssistantProvider::Anthropic {
470 model,
471 api_url,
472 low_speed_timeout_in_seconds,
473 },
474 AssistantProviderContent::Anthropic {
475 default_model: model_override,
476 api_url: api_url_override,
477 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
478 },
479 ) => {
480 merge(model, model_override);
481 merge(api_url, api_url_override);
482 if let Some(low_speed_timeout_in_seconds_override) =
483 low_speed_timeout_in_seconds_override
484 {
485 *low_speed_timeout_in_seconds =
486 Some(low_speed_timeout_in_seconds_override);
487 }
488 }
489 (provider, provider_override) => {
490 *provider = match provider_override {
491 AssistantProviderContent::ZedDotDev {
492 default_model: model,
493 } => AssistantProvider::ZedDotDev {
494 model: model.unwrap_or_default(),
495 },
496 AssistantProviderContent::OpenAi {
497 default_model: model,
498 api_url,
499 low_speed_timeout_in_seconds,
500 } => AssistantProvider::OpenAi {
501 model: model.unwrap_or_default(),
502 api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
503 low_speed_timeout_in_seconds,
504 },
505 AssistantProviderContent::Anthropic {
506 default_model: model,
507 api_url,
508 low_speed_timeout_in_seconds,
509 } => AssistantProvider::Anthropic {
510 model: model.unwrap_or_default(),
511 api_url: api_url
512 .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
513 low_speed_timeout_in_seconds,
514 },
515 };
516 }
517 }
518 }
519 }
520
521 Ok(settings)
522 }
523}
524
/// Overwrites `target` with `value` when a value is present; `None` leaves `target` untouched.
fn merge<T>(target: &mut T, value: Option<T>) {
    let Some(new_value) = value else {
        return;
    };
    *target = new_value;
}
530
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    /// Checks how the provider resolves from the built-in defaults, from the legacy format's
    /// OpenAI-only fields, and from the versioned format with a custom zed.dev model.
    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to OpenAI's gpt-4o.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // Ensure backward-compatibility: legacy (unversioned) settings that only set the
        // OpenAI API URL keep the default model and use the given URL.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
            }
        );
        // A legacy default model alone also selects OpenAI. The previous user settings no
        // longer apply, so the API URL is back to its default.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: ZedDotDevModel::Custom("custom".into())
            }
        );
    }
}