1use std::fmt;
2
3pub use anthropic::Model as AnthropicModel;
4use gpui::Pixels;
5pub use ollama::Model as OllamaModel;
6pub use open_ai::Model as OpenAiModel;
7use schemars::{
8 schema::{InstanceType, Metadata, Schema, SchemaObject},
9 JsonSchema,
10};
11use serde::{
12 de::{self, Visitor},
13 Deserialize, Deserializer, Serialize, Serializer,
14};
15use settings::{Settings, SettingsSources};
16use strum::{EnumIter, IntoEnumIterator};
17
18use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
19
/// A language model served through the zed.dev cloud service.
///
/// Variant order matters: `EnumIter` iteration order is used by the
/// serde/schemars impls below to enumerate the known model ids, and
/// `Custom` carries any id not in the known set.
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
    Gpt3Point5Turbo,
    Gpt4,
    Gpt4Turbo,
    /// Model used when none is configured.
    #[default]
    Gpt4Omni,
    Claude3_5Sonnet,
    Claude3Opus,
    Claude3Sonnet,
    Claude3Haiku,
    /// Any model id that does not match a known variant (see `id()`).
    Custom(String),
}
33
impl Serialize for CloudModel {
    // Serialize as the bare model id string (e.g. "gpt-4o") rather than an
    // enum representation, matching the string-based Deserialize impl.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(self.id())
    }
}
42
43impl<'de> Deserialize<'de> for CloudModel {
44 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
45 where
46 D: Deserializer<'de>,
47 {
48 struct ZedDotDevModelVisitor;
49
50 impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
51 type Value = CloudModel;
52
53 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
54 formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
55 }
56
57 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
58 where
59 E: de::Error,
60 {
61 let model = CloudModel::iter()
62 .find(|model| model.id() == value)
63 .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
64 Ok(model)
65 }
66 }
67
68 deserializer.deserialize_str(ZedDotDevModelVisitor)
69 }
70}
71
72impl JsonSchema for CloudModel {
73 fn schema_name() -> String {
74 "ZedDotDevModel".to_owned()
75 }
76
77 fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
78 let variants = CloudModel::iter()
79 .filter_map(|model| {
80 let id = model.id();
81 if id.is_empty() {
82 None
83 } else {
84 Some(id.to_string())
85 }
86 })
87 .collect::<Vec<_>>();
88 Schema::Object(SchemaObject {
89 instance_type: Some(InstanceType::String.into()),
90 enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
91 metadata: Some(Box::new(Metadata {
92 title: Some("ZedDotDevModel".to_owned()),
93 default: Some(CloudModel::default().id().into()),
94 examples: variants.into_iter().map(Into::into).collect(),
95 ..Default::default()
96 })),
97 ..Default::default()
98 })
99 }
100}
101
102impl CloudModel {
103 pub fn id(&self) -> &str {
104 match self {
105 Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
106 Self::Gpt4 => "gpt-4",
107 Self::Gpt4Turbo => "gpt-4-turbo-preview",
108 Self::Gpt4Omni => "gpt-4o",
109 Self::Claude3_5Sonnet => "claude-3-5-sonnet",
110 Self::Claude3Opus => "claude-3-opus",
111 Self::Claude3Sonnet => "claude-3-sonnet",
112 Self::Claude3Haiku => "claude-3-haiku",
113 Self::Custom(id) => id,
114 }
115 }
116
117 pub fn display_name(&self) -> &str {
118 match self {
119 Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
120 Self::Gpt4 => "GPT 4",
121 Self::Gpt4Turbo => "GPT 4 Turbo",
122 Self::Gpt4Omni => "GPT 4 Omni",
123 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
124 Self::Claude3Opus => "Claude 3 Opus",
125 Self::Claude3Sonnet => "Claude 3 Sonnet",
126 Self::Claude3Haiku => "Claude 3 Haiku",
127 Self::Custom(id) => id.as_str(),
128 }
129 }
130
131 pub fn max_token_count(&self) -> usize {
132 match self {
133 Self::Gpt3Point5Turbo => 2048,
134 Self::Gpt4 => 4096,
135 Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
136 Self::Claude3_5Sonnet
137 | Self::Claude3Opus
138 | Self::Claude3Sonnet
139 | Self::Claude3Haiku => 200000,
140 Self::Custom(_) => 4096, // TODO: Make this configurable
141 }
142 }
143
144 pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
145 match self {
146 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
147 preprocess_anthropic_request(request)
148 }
149 _ => {}
150 }
151 }
152}
153
/// Which edge of the workspace the assistant panel is docked to.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    /// Default position.
    #[default]
    Right,
    Bottom,
}
162
/// A fully-resolved assistant provider configuration.
///
/// Unlike the file-format type `AssistantProviderContent`, the model and API
/// URL here are always concrete; only the low-speed timeout stays optional.
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    /// The hosted zed.dev service.
    ZedDotDev {
        model: CloudModel,
    },
    OpenAi {
        model: OpenAiModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    Anthropic {
        model: AnthropicModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    Ollama {
        model: OllamaModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
184
185impl Default for AssistantProvider {
186 fn default() -> Self {
187 Self::OpenAi {
188 model: OpenAiModel::default(),
189 api_url: open_ai::OPEN_AI_API_URL.into(),
190 low_speed_timeout_in_seconds: None,
191 }
192 }
193}
194
/// Provider configuration as written in the settings file.
///
/// The `name` tag selects the provider (e.g. `{ "name": "openai", ... }`).
/// All fields are optional; missing values are filled in with defaults by
/// `AssistantSettings::load`.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    #[serde(rename = "zed.dev")]
    ZedDotDev { default_model: Option<CloudModel> },
    #[serde(rename = "openai")]
    OpenAi {
        default_model: Option<OpenAiModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "anthropic")]
    Anthropic {
        default_model: Option<AnthropicModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "ollama")]
    Ollama {
        default_model: Option<OllamaModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
219
/// The resolved assistant settings used at runtime, produced by
/// `Settings::load` below from the file-format content types.
#[derive(Debug, Default)]
pub struct AssistantSettings {
    // Whether the assistant is enabled at all.
    pub enabled: bool,
    // Whether to show the assistant panel button in the status bar.
    pub button: bool,
    // Which edge of the workspace the panel docks to.
    pub dock: AssistantDockPosition,
    // Panel width when docked left/right.
    pub default_width: Pixels,
    // Panel height when docked at the bottom.
    pub default_height: Pixels,
    // Fully-resolved provider configuration.
    pub provider: AssistantProvider,
}
229
/// Assistant panel settings
///
/// NOTE: with `#[serde(untagged)]`, serde tries the variants in declaration
/// order — `Versioned` (which requires a `"version"` key) first, then the
/// `Legacy` fallback. Keep the variant order.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}
237
impl JsonSchema for AssistantSettingsContent {
    // Expose only the versioned schema: legacy settings are still accepted
    // at runtime for backward compatibility, but are not advertised.
    fn schema_name() -> String {
        VersionedAssistantSettingsContent::schema_name()
    }

    fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
        VersionedAssistantSettingsContent::json_schema(gen)
    }

    fn is_referenceable() -> bool {
        VersionedAssistantSettingsContent::is_referenceable()
    }
}
251
impl Default for AssistantSettingsContent {
    // Default to the newest (versioned) settings format.
    fn default() -> Self {
        Self::Versioned(VersionedAssistantSettingsContent::default())
    }
}
257
impl AssistantSettingsContent {
    /// Normalizes any settings format (versioned or legacy) into the latest
    /// versioned shape, `AssistantSettingsContentV1`.
    fn upgrade(&self) -> AssistantSettingsContentV1 {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
            },
            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
                enabled: None,
                button: settings.button,
                dock: settings.dock,
                default_width: settings.default_width,
                default_height: settings.default_height,
                // Legacy settings could only describe OpenAI; synthesize an
                // OpenAI provider when either the API URL or a default model
                // was set, otherwise leave the provider unset.
                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
                    Some(AssistantProviderContent::OpenAi {
                        default_model: settings.default_open_ai_model.clone(),
                        api_url: Some(open_ai_api_url.clone()),
                        low_speed_timeout_in_seconds: None,
                    })
                } else {
                    settings.default_open_ai_model.clone().map(|open_ai_model| {
                        AssistantProviderContent::OpenAi {
                            default_model: Some(open_ai_model),
                            api_url: None,
                            low_speed_timeout_in_seconds: None,
                        }
                    })
                },
            },
        }
    }

    /// Sets the dock position, regardless of settings version.
    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => {
                    settings.dock = Some(dock);
                }
            },
            AssistantSettingsContent::Legacy(settings) => {
                settings.dock = Some(dock);
            }
        }
    }

    /// Sets the default model.
    ///
    /// If the configured provider matches the new model's provider, only the
    /// default model is replaced (other fields like `api_url` are kept);
    /// otherwise the provider entry is replaced wholesale with one for the
    /// new model's provider. Legacy settings can only store an OpenAI model,
    /// so other providers are silently ignored there.
    pub fn set_model(&mut self, new_model: LanguageModel) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
                    Some(AssistantProviderContent::ZedDotDev {
                        default_model: model,
                    }) => {
                        if let LanguageModel::Cloud(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::OpenAi {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::OpenAi(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::Anthropic {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::Anthropic(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::Ollama {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::Ollama(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    // No provider configured (or a mismatched one fell
                    // through above is not possible — Some arms are
                    // exhaustive): install a fresh provider entry for the
                    // new model's provider with unset extras.
                    provider => match new_model {
                        LanguageModel::Cloud(model) => {
                            *provider = Some(AssistantProviderContent::ZedDotDev {
                                default_model: Some(model),
                            })
                        }
                        LanguageModel::OpenAi(model) => {
                            *provider = Some(AssistantProviderContent::OpenAi {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                        LanguageModel::Anthropic(model) => {
                            *provider = Some(AssistantProviderContent::Anthropic {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                        LanguageModel::Ollama(model) => {
                            *provider = Some(AssistantProviderContent::Ollama {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                    },
                },
            },
            AssistantSettingsContent::Legacy(settings) => {
                if let LanguageModel::OpenAi(model) = new_model {
                    settings.default_open_ai_model = Some(model);
                }
            }
        }
    }
}
375
/// Versioned wrapper around the settings file format; the `"version"` key in
/// the settings JSON selects the variant.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
382
383impl Default for VersionedAssistantSettingsContent {
384 fn default() -> Self {
385 Self::V1(AssistantSettingsContentV1 {
386 enabled: None,
387 button: None,
388 dock: None,
389 default_width: None,
390 default_height: None,
391 provider: None,
392 })
393 }
394}
395
/// Version 1 of the assistant settings file format.
///
/// Every field is optional; unset fields fall back to defaults when folded
/// into `AssistantSettings` during `load`.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can either be the internal `zed.dev` service or an external `openai` service,
    /// each with their respective default models and configurations.
    provider: Option<AssistantProviderContent>,
}
424
/// Pre-versioning settings format, kept for backward compatibility and
/// upgraded to `AssistantSettingsContentV1` by `upgrade()`. Only OpenAI was
/// configurable in this format.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
452
impl Settings for AssistantSettings {
    // Settings live under the "assistant" key of the settings file.
    const KEY: Option<&'static str> = Some("assistant");

    type FileContent = AssistantSettingsContent;

    /// Builds the effective settings by folding defaults and user
    /// customizations, in precedence order, onto `AssistantSettings::default()`.
    /// Each layer is first upgraded to the latest (V1) file format.
    fn load(
        sources: SettingsSources<Self::FileContent>,
        _: &mut gpui::AppContext,
    ) -> anyhow::Result<Self> {
        let mut settings = AssistantSettings::default();

        for value in sources.defaults_and_customizations() {
            let value = value.upgrade();
            merge(&mut settings.enabled, value.enabled);
            merge(&mut settings.button, value.button);
            merge(&mut settings.dock, value.dock);
            merge(
                &mut settings.default_width,
                value.default_width.map(Into::into),
            );
            merge(
                &mut settings.default_height,
                value.default_height.map(Into::into),
            );
            if let Some(provider) = value.provider.clone() {
                // If the layer's provider matches the current one, merge
                // field-by-field (keeping anything the layer left unset);
                // if it names a different provider, replace wholesale,
                // filling in defaults for unset fields (final catch-all arm).
                match (&mut settings.provider, provider) {
                    (
                        AssistantProvider::ZedDotDev { model },
                        AssistantProviderContent::ZedDotDev {
                            default_model: model_override,
                        },
                    ) => {
                        merge(model, model_override);
                    }
                    (
                        AssistantProvider::OpenAi {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::OpenAi {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (
                        AssistantProvider::Ollama {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::Ollama {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (
                        AssistantProvider::Anthropic {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::Anthropic {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    // Provider kind changed: replace, defaulting unset fields.
                    (provider, provider_override) => {
                        *provider = match provider_override {
                            AssistantProviderContent::ZedDotDev {
                                default_model: model,
                            } => AssistantProvider::ZedDotDev {
                                model: model.unwrap_or_default(),
                            },
                            AssistantProviderContent::OpenAi {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::OpenAi {
                                model: model.unwrap_or_default(),
                                api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                            AssistantProviderContent::Anthropic {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::Anthropic {
                                model: model.unwrap_or_default(),
                                api_url: api_url
                                    .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                            AssistantProviderContent::Ollama {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::Ollama {
                                model: model.unwrap_or_default(),
                                api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                        };
                    }
                }
            }
        }

        Ok(settings)
    }
}
594
/// Overwrites `target` with `value` when a value is present; a `None` leaves
/// `target` untouched. Used by `load` to layer settings customizations.
fn merge<T>(target: &mut T, value: Option<T>) {
    match value {
        Some(new_value) => *target = new_value,
        None => {}
    }
}
600
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to OpenAI's GPT-4 Omni with the public API URL.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // Ensure backward-compatibility: legacy `openai_api_url` is honored.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
            }
        );
        // Legacy `default_open_ai_model` is also honored.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );
    }
}