1use std::fmt;
2
3pub use anthropic::Model as AnthropicModel;
4use gpui::Pixels;
5pub use ollama::Model as OllamaModel;
6pub use open_ai::Model as OpenAiModel;
7use schemars::{
8 schema::{InstanceType, Metadata, Schema, SchemaObject},
9 JsonSchema,
10};
11use serde::{
12 de::{self, Visitor},
13 Deserialize, Deserializer, Serialize, Serializer,
14};
15use settings::{Settings, SettingsSources};
16use strum::{EnumIter, IntoEnumIterator};
17
18use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
19
/// Models available through the hosted zed.dev ("cloud") provider.
///
/// Serialized as the raw model id string (see the `Serialize`/`Deserialize`
/// impls below); unknown ids round-trip through the `Custom` variant.
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
    Gpt3Point5Turbo,
    Gpt4,
    Gpt4Turbo,
    #[default]
    Gpt4Omni,
    Claude3_5Sonnet,
    Claude3Opus,
    Claude3Sonnet,
    Claude3Haiku,
    /// Any other model, carried by its raw id string.
    Custom(String),
}
33
34impl Serialize for CloudModel {
35 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
36 where
37 S: Serializer,
38 {
39 serializer.serialize_str(self.id())
40 }
41}
42
43impl<'de> Deserialize<'de> for CloudModel {
44 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
45 where
46 D: Deserializer<'de>,
47 {
48 struct ZedDotDevModelVisitor;
49
50 impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
51 type Value = CloudModel;
52
53 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
54 formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
55 }
56
57 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
58 where
59 E: de::Error,
60 {
61 let model = CloudModel::iter()
62 .find(|model| model.id() == value)
63 .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
64 Ok(model)
65 }
66 }
67
68 deserializer.deserialize_str(ZedDotDevModelVisitor)
69 }
70}
71
72impl JsonSchema for CloudModel {
73 fn schema_name() -> String {
74 "ZedDotDevModel".to_owned()
75 }
76
77 fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
78 let variants = CloudModel::iter()
79 .filter_map(|model| {
80 let id = model.id();
81 if id.is_empty() {
82 None
83 } else {
84 Some(id.to_string())
85 }
86 })
87 .collect::<Vec<_>>();
88 Schema::Object(SchemaObject {
89 instance_type: Some(InstanceType::String.into()),
90 enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
91 metadata: Some(Box::new(Metadata {
92 title: Some("ZedDotDevModel".to_owned()),
93 default: Some(CloudModel::default().id().into()),
94 examples: variants.into_iter().map(Into::into).collect(),
95 ..Default::default()
96 })),
97 ..Default::default()
98 })
99 }
100}
101
102impl CloudModel {
103 pub fn id(&self) -> &str {
104 match self {
105 Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
106 Self::Gpt4 => "gpt-4",
107 Self::Gpt4Turbo => "gpt-4-turbo-preview",
108 Self::Gpt4Omni => "gpt-4o",
109 Self::Claude3_5Sonnet => "claude-3-5-sonnet",
110 Self::Claude3Opus => "claude-3-opus",
111 Self::Claude3Sonnet => "claude-3-sonnet",
112 Self::Claude3Haiku => "claude-3-haiku",
113 Self::Custom(id) => id,
114 }
115 }
116
117 pub fn display_name(&self) -> &str {
118 match self {
119 Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
120 Self::Gpt4 => "GPT 4",
121 Self::Gpt4Turbo => "GPT 4 Turbo",
122 Self::Gpt4Omni => "GPT 4 Omni",
123 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
124 Self::Claude3Opus => "Claude 3 Opus",
125 Self::Claude3Sonnet => "Claude 3 Sonnet",
126 Self::Claude3Haiku => "Claude 3 Haiku",
127 Self::Custom(id) => id.as_str(),
128 }
129 }
130
131 pub fn max_token_count(&self) -> usize {
132 match self {
133 Self::Gpt3Point5Turbo => 2048,
134 Self::Gpt4 => 4096,
135 Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
136 Self::Claude3_5Sonnet
137 | Self::Claude3Opus
138 | Self::Claude3Sonnet
139 | Self::Claude3Haiku => 200000,
140 Self::Custom(_) => 4096, // TODO: Make this configurable
141 }
142 }
143
144 pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
145 match self {
146 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
147 preprocess_anthropic_request(request)
148 }
149 _ => {}
150 }
151 }
152}
153
/// Where the assistant panel is docked in the workspace.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}
162
/// The fully-resolved assistant backend, produced from
/// `AssistantProviderContent` with all optional fields filled in.
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    /// Zed's hosted zed.dev service.
    ZedDotDev {
        model: CloudModel,
    },
    OpenAi {
        model: OpenAiModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    Anthropic {
        model: AnthropicModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    Ollama {
        model: OllamaModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
184
185impl Default for AssistantProvider {
186 fn default() -> Self {
187 Self::OpenAi {
188 model: OpenAiModel::default(),
189 api_url: open_ai::OPEN_AI_API_URL.into(),
190 low_speed_timeout_in_seconds: None,
191 }
192 }
193}
194
/// Provider configuration as written in settings files: every field optional,
/// tagged by `"name"` (e.g. `{"name": "openai", ...}`).
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    #[serde(rename = "zed.dev")]
    ZedDotDev { default_model: Option<CloudModel> },
    #[serde(rename = "openai")]
    OpenAi {
        default_model: Option<OpenAiModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "anthropic")]
    Anthropic {
        default_model: Option<AnthropicModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "ollama")]
    Ollama {
        default_model: Option<OllamaModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
219
/// Resolved assistant settings, built by `Settings::load` from the layered
/// file contents below.
#[derive(Debug, Default)]
pub struct AssistantSettings {
    // Whether the assistant feature is enabled at all.
    pub enabled: bool,
    // Whether the status-bar button is shown.
    pub button: bool,
    pub dock: AssistantDockPosition,
    // Panel width when docked left/right.
    pub default_width: Pixels,
    // Panel height when docked at the bottom.
    pub default_height: Pixels,
    pub provider: AssistantProvider,
}
229
/// Assistant panel settings
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    /// Settings with an explicit `"version"` key.
    Versioned(VersionedAssistantSettingsContent),
    /// Pre-versioning settings shape, kept for backward compatibility.
    Legacy(LegacyAssistantSettingsContent),
}
237
// Advertise only the versioned schema: the legacy shape is accepted on input
// (via the untagged enum) but not suggested to users.
impl JsonSchema for AssistantSettingsContent {
    fn schema_name() -> String {
        VersionedAssistantSettingsContent::schema_name()
    }

    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
        VersionedAssistantSettingsContent::json_schema(gen)
    }

    fn is_referenceable() -> bool {
        VersionedAssistantSettingsContent::is_referenceable()
    }
}
251
252impl Default for AssistantSettingsContent {
253 fn default() -> Self {
254 Self::Versioned(VersionedAssistantSettingsContent::default())
255 }
256}
257
impl AssistantSettingsContent {
    /// Normalize any settings shape (versioned or legacy) into the latest V1 form.
    fn upgrade(&self) -> AssistantSettingsContentV1 {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
            },
            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
                enabled: None,
                button: settings.button,
                dock: settings.dock,
                default_width: settings.default_width,
                default_height: settings.default_height,
                // Legacy settings only knew about OpenAI: synthesize an OpenAI
                // provider entry if either the API URL or the model was set,
                // otherwise leave the provider unset.
                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
                    Some(AssistantProviderContent::OpenAi {
                        default_model: settings.default_open_ai_model.clone(),
                        api_url: Some(open_ai_api_url.clone()),
                        low_speed_timeout_in_seconds: None,
                    })
                } else {
                    settings.default_open_ai_model.clone().map(|open_ai_model| {
                        AssistantProviderContent::OpenAi {
                            default_model: Some(open_ai_model),
                            api_url: None,
                            low_speed_timeout_in_seconds: None,
                        }
                    })
                },
            },
        }
    }

    /// Record the dock position in whichever settings shape is present.
    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => {
                    settings.dock = Some(dock);
                }
            },
            AssistantSettingsContent::Legacy(settings) => {
                settings.dock = Some(dock);
            }
        }
    }

    /// Record `new_model` as the default model.
    ///
    /// If the current provider entry matches the model's provider, only its
    /// default model is replaced (other fields like `api_url` are kept);
    /// otherwise the provider entry is swapped wholesale for one matching the
    /// new model. Legacy settings can only store OpenAI models; other model
    /// kinds are silently ignored there.
    pub fn set_model(&mut self, new_model: LanguageModel) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
                    Some(AssistantProviderContent::ZedDotDev {
                        default_model: model,
                    }) => {
                        if let LanguageModel::Cloud(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::OpenAi {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::OpenAi(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::Anthropic {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::Anthropic(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    // No provider set yet, or the provider kind differs from
                    // the model's kind (including Ollama, which has no arm
                    // above): replace the whole entry.
                    provider => match new_model {
                        LanguageModel::Cloud(model) => {
                            *provider = Some(AssistantProviderContent::ZedDotDev {
                                default_model: Some(model),
                            })
                        }
                        LanguageModel::OpenAi(model) => {
                            *provider = Some(AssistantProviderContent::OpenAi {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                        LanguageModel::Anthropic(model) => {
                            *provider = Some(AssistantProviderContent::Anthropic {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                        LanguageModel::Ollama(model) => {
                            *provider = Some(AssistantProviderContent::Ollama {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                    },
                },
            },
            AssistantSettingsContent::Legacy(settings) => {
                if let LanguageModel::OpenAi(model) = new_model {
                    settings.default_open_ai_model = Some(model);
                }
            }
        }
    }
}
367
/// Versioned settings payload, tagged by a string `"version"` key so future
/// schema revisions can be added as new variants.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
374
375impl Default for VersionedAssistantSettingsContent {
376 fn default() -> Self {
377 Self::V1(AssistantSettingsContentV1 {
378 enabled: None,
379 button: None,
380 dock: None,
381 default_width: None,
382 default_height: None,
383 provider: None,
384 })
385 }
386}
387
/// Assistant settings file content, schema version 1.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can either be the internal `zed.dev` service or an external `openai` service,
    /// each with their respective default models and configurations.
    provider: Option<AssistantProviderContent>,
}
416
/// Pre-versioning assistant settings shape; OpenAI was the only provider.
/// Upgraded to V1 by `AssistantSettingsContent::upgrade`.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
444
impl Settings for AssistantSettings {
    const KEY: Option<&'static str> = Some("assistant");

    type FileContent = AssistantSettingsContent;

    /// Fold defaults and user customizations (lowest priority first) into the
    /// hard-coded `AssistantSettings::default()`.
    fn load(
        sources: SettingsSources<Self::FileContent>,
        _: &mut gpui::AppContext,
    ) -> anyhow::Result<Self> {
        let mut settings = AssistantSettings::default();

        for value in sources.defaults_and_customizations() {
            // Normalize legacy/versioned content to V1 before merging.
            let value = value.upgrade();
            merge(&mut settings.enabled, value.enabled);
            merge(&mut settings.button, value.button);
            merge(&mut settings.dock, value.dock);
            merge(
                &mut settings.default_width,
                value.default_width.map(Into::into),
            );
            merge(
                &mut settings.default_height,
                value.default_height.map(Into::into),
            );
            if let Some(provider) = value.provider.clone() {
                match (&mut settings.provider, provider) {
                    // Same provider kind: merge field-by-field so that fields
                    // left unset in the override keep their accumulated value.
                    (
                        AssistantProvider::ZedDotDev { model },
                        AssistantProviderContent::ZedDotDev {
                            default_model: model_override,
                        },
                    ) => {
                        merge(model, model_override);
                    }
                    (
                        AssistantProvider::OpenAi {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::OpenAi {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (
                        AssistantProvider::Ollama {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::Ollama {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (
                        AssistantProvider::Anthropic {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::Anthropic {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    // Provider kind changed: replace the resolved provider
                    // wholesale, filling unset fields with that provider's
                    // defaults (so earlier same-kind overrides are discarded).
                    (provider, provider_override) => {
                        *provider = match provider_override {
                            AssistantProviderContent::ZedDotDev {
                                default_model: model,
                            } => AssistantProvider::ZedDotDev {
                                model: model.unwrap_or_default(),
                            },
                            AssistantProviderContent::OpenAi {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::OpenAi {
                                model: model.unwrap_or_default(),
                                api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                            AssistantProviderContent::Anthropic {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::Anthropic {
                                model: model.unwrap_or_default(),
                                api_url: api_url
                                    .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                            AssistantProviderContent::Ollama {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::Ollama {
                                model: model.unwrap_or_default(),
                                api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                        };
                    }
                }
            }
        }

        Ok(settings)
    }
}
586
/// Overwrite `target` with `value` if one was provided; otherwise leave it as-is.
fn merge<T>(target: &mut T, value: Option<T>) {
    match value {
        Some(overriding) => *target = overriding,
        None => {}
    }
}
592
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // With no user settings, the provider is OpenAI with the default
        // model (gpt-4o — the stale "gpt-4-turbo" note here predated the
        // default change; the assertion below checks `FourOmni`).
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // Ensure backward-compatibility.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
            }
        );
        // Setting only the legacy default model replaces the provider entry,
        // so the api_url falls back to the stock endpoint.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );
    }
}