1use std::fmt;
2
3pub use anthropic::Model as AnthropicModel;
4use gpui::Pixels;
5pub use ollama::Model as OllamaModel;
6pub use open_ai::Model as OpenAiModel;
7use schemars::{
8 schema::{InstanceType, Metadata, Schema, SchemaObject},
9 JsonSchema,
10};
11use serde::{
12 de::{self, Visitor},
13 Deserialize, Deserializer, Serialize, Serializer,
14};
15use settings::{Settings, SettingsSources};
16use strum::{EnumIter, IntoEnumIterator};
17
18use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
19
/// A model selectable for the `zed.dev` assistant provider.
///
/// Serialized as the string returned by [`CloudModel::id`]. When deserializing,
/// any string that does not match a known variant's id becomes
/// [`CloudModel::Custom`] holding that string verbatim.
// NOTE: variant order is significant: `EnumIter` order determines the order in
// which known ids are listed in the generated JSON schema.
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
    Gpt3Point5Turbo,
    Gpt4,
    Gpt4Turbo,
    /// The default model when none is configured.
    #[default]
    Gpt4Omni,
    Claude3_5Sonnet,
    Claude3Opus,
    Claude3Sonnet,
    Claude3Haiku,
    /// Any model id not covered by the variants above.
    // `EnumIter` yields this as `Custom(String::new())`, i.e. with an empty id.
    Custom(String),
}
33
34impl Serialize for CloudModel {
35 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
36 where
37 S: Serializer,
38 {
39 serializer.serialize_str(self.id())
40 }
41}
42
43impl<'de> Deserialize<'de> for CloudModel {
44 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
45 where
46 D: Deserializer<'de>,
47 {
48 struct ZedDotDevModelVisitor;
49
50 impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
51 type Value = CloudModel;
52
53 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
54 formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
55 }
56
57 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
58 where
59 E: de::Error,
60 {
61 let model = CloudModel::iter()
62 .find(|model| model.id() == value)
63 .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
64 Ok(model)
65 }
66 }
67
68 deserializer.deserialize_str(ZedDotDevModelVisitor)
69 }
70}
71
72impl JsonSchema for CloudModel {
73 fn schema_name() -> String {
74 "ZedDotDevModel".to_owned()
75 }
76
77 fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
78 let variants = CloudModel::iter()
79 .filter_map(|model| {
80 let id = model.id();
81 if id.is_empty() {
82 None
83 } else {
84 Some(id.to_string())
85 }
86 })
87 .collect::<Vec<_>>();
88 Schema::Object(SchemaObject {
89 instance_type: Some(InstanceType::String.into()),
90 enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
91 metadata: Some(Box::new(Metadata {
92 title: Some("ZedDotDevModel".to_owned()),
93 default: Some(CloudModel::default().id().into()),
94 examples: variants.into_iter().map(Into::into).collect(),
95 ..Default::default()
96 })),
97 ..Default::default()
98 })
99 }
100}
101
102impl CloudModel {
103 pub fn id(&self) -> &str {
104 match self {
105 Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
106 Self::Gpt4 => "gpt-4",
107 Self::Gpt4Turbo => "gpt-4-turbo-preview",
108 Self::Gpt4Omni => "gpt-4o",
109 Self::Claude3_5Sonnet => "claude-3-5-sonnet",
110 Self::Claude3Opus => "claude-3-opus",
111 Self::Claude3Sonnet => "claude-3-sonnet",
112 Self::Claude3Haiku => "claude-3-haiku",
113 Self::Custom(id) => id,
114 }
115 }
116
117 pub fn display_name(&self) -> &str {
118 match self {
119 Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
120 Self::Gpt4 => "GPT 4",
121 Self::Gpt4Turbo => "GPT 4 Turbo",
122 Self::Gpt4Omni => "GPT 4 Omni",
123 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
124 Self::Claude3Opus => "Claude 3 Opus",
125 Self::Claude3Sonnet => "Claude 3 Sonnet",
126 Self::Claude3Haiku => "Claude 3 Haiku",
127 Self::Custom(id) => id.as_str(),
128 }
129 }
130
131 pub fn max_token_count(&self) -> usize {
132 match self {
133 Self::Gpt3Point5Turbo => 2048,
134 Self::Gpt4 => 4096,
135 Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
136 Self::Claude3_5Sonnet
137 | Self::Claude3Opus
138 | Self::Claude3Sonnet
139 | Self::Claude3Haiku => 200000,
140 Self::Custom(_) => 4096, // TODO: Make this configurable
141 }
142 }
143
144 pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
145 match self {
146 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
147 preprocess_anthropic_request(request)
148 }
149 _ => {}
150 }
151 }
152}
153
/// Where the assistant panel is docked.
///
/// Serialized in snake_case (`"left"`, `"right"`, `"bottom"`); defaults to `right`.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}
162
/// The resolved assistant provider configuration.
///
/// Built by `AssistantSettings::load` from the optional fields of
/// `AssistantProviderContent`, so every field here has a concrete value.
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    /// The `zed.dev` provider.
    ZedDotDev {
        /// Model to use.
        model: CloudModel,
    },
    /// The OpenAI provider.
    OpenAi {
        /// Model to use.
        model: OpenAiModel,
        /// Base URL of the API.
        api_url: String,
        /// Low-speed timeout in seconds, if one was configured.
        low_speed_timeout_in_seconds: Option<u64>,
        /// The `available_models` list from settings (empty when unset).
        available_models: Vec<OpenAiModel>,
    },
    /// The Anthropic provider.
    Anthropic {
        /// Model to use.
        model: AnthropicModel,
        /// Base URL of the API.
        api_url: String,
        /// Low-speed timeout in seconds, if one was configured.
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// The Ollama provider.
    Ollama {
        /// Model to use.
        model: OllamaModel,
        /// Base URL of the API.
        api_url: String,
        /// Low-speed timeout in seconds, if one was configured.
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
185
186impl Default for AssistantProvider {
187 fn default() -> Self {
188 Self::OpenAi {
189 model: OpenAiModel::default(),
190 api_url: open_ai::OPEN_AI_API_URL.into(),
191 low_speed_timeout_in_seconds: None,
192 available_models: Default::default(),
193 }
194 }
195}
196
/// A provider entry as written in a settings file.
///
/// Internally tagged by `"name"`: one of `"zed.dev"`, `"openai"`,
/// `"anthropic"` or `"ollama"`. Every field is optional; unset fields keep the
/// value from earlier settings sources or fall back to a built-in default.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    /// The `zed.dev` provider.
    #[serde(rename = "zed.dev")]
    ZedDotDev { default_model: Option<CloudModel> },
    /// The OpenAI provider.
    #[serde(rename = "openai")]
    OpenAi {
        /// Model to use by default.
        default_model: Option<OpenAiModel>,
        /// Base URL of the API; the built-in OpenAI URL is used if no source sets it.
        api_url: Option<String>,
        /// Low-speed timeout, in seconds.
        low_speed_timeout_in_seconds: Option<u64>,
        /// OpenAI models to list; when set, replaces the list from earlier sources.
        available_models: Option<Vec<OpenAiModel>>,
    },
    /// The Anthropic provider.
    #[serde(rename = "anthropic")]
    Anthropic {
        /// Model to use by default.
        default_model: Option<AnthropicModel>,
        /// Base URL of the API; the built-in Anthropic URL is used if no source sets it.
        api_url: Option<String>,
        /// Low-speed timeout, in seconds.
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// The Ollama provider.
    #[serde(rename = "ollama")]
    Ollama {
        /// Model to use by default.
        default_model: Option<OllamaModel>,
        /// Base URL of the API; the built-in Ollama URL is used if no source sets it.
        api_url: Option<String>,
        /// Low-speed timeout, in seconds.
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
222
/// Effective assistant settings, produced by merging all settings sources
/// in `AssistantSettings::load`.
#[derive(Debug, Default)]
pub struct AssistantSettings {
    /// Whether the assistant is enabled.
    pub enabled: bool,
    /// Whether the assistant panel button is shown in the status bar.
    pub button: bool,
    /// Where the assistant panel is docked.
    pub dock: AssistantDockPosition,
    /// Panel width when docked left or right.
    pub default_width: Pixels,
    /// Panel height when docked at the bottom.
    pub default_height: Pixels,
    /// The resolved provider configuration.
    pub provider: AssistantProvider,
}
232
/// Assistant panel settings
///
/// Deserialized untagged, so serde tries the versioned format first and falls
/// back to the legacy (unversioned) format. Both are normalized to
/// `AssistantSettingsContentV1` by `upgrade`.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}
240
241impl JsonSchema for AssistantSettingsContent {
242 fn schema_name() -> String {
243 VersionedAssistantSettingsContent::schema_name()
244 }
245
246 fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
247 VersionedAssistantSettingsContent::json_schema(gen)
248 }
249
250 fn is_referenceable() -> bool {
251 VersionedAssistantSettingsContent::is_referenceable()
252 }
253}
254
255impl Default for AssistantSettingsContent {
256 fn default() -> Self {
257 Self::Versioned(VersionedAssistantSettingsContent::default())
258 }
259}
260
261impl AssistantSettingsContent {
262 fn upgrade(&self) -> AssistantSettingsContentV1 {
263 match self {
264 AssistantSettingsContent::Versioned(settings) => match settings {
265 VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
266 },
267 AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
268 enabled: None,
269 button: settings.button,
270 dock: settings.dock,
271 default_width: settings.default_width,
272 default_height: settings.default_height,
273 provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
274 Some(AssistantProviderContent::OpenAi {
275 default_model: settings.default_open_ai_model.clone(),
276 api_url: Some(open_ai_api_url.clone()),
277 low_speed_timeout_in_seconds: None,
278 available_models: Some(Default::default()),
279 })
280 } else {
281 settings.default_open_ai_model.clone().map(|open_ai_model| {
282 AssistantProviderContent::OpenAi {
283 default_model: Some(open_ai_model),
284 api_url: None,
285 low_speed_timeout_in_seconds: None,
286 available_models: Some(Default::default()),
287 }
288 })
289 },
290 },
291 }
292 }
293
294 pub fn set_dock(&mut self, dock: AssistantDockPosition) {
295 match self {
296 AssistantSettingsContent::Versioned(settings) => match settings {
297 VersionedAssistantSettingsContent::V1(settings) => {
298 settings.dock = Some(dock);
299 }
300 },
301 AssistantSettingsContent::Legacy(settings) => {
302 settings.dock = Some(dock);
303 }
304 }
305 }
306
307 pub fn set_model(&mut self, new_model: LanguageModel) {
308 match self {
309 AssistantSettingsContent::Versioned(settings) => match settings {
310 VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
311 Some(AssistantProviderContent::ZedDotDev {
312 default_model: model,
313 }) => {
314 if let LanguageModel::Cloud(new_model) = new_model {
315 *model = Some(new_model);
316 }
317 }
318 Some(AssistantProviderContent::OpenAi {
319 default_model: model,
320 ..
321 }) => {
322 if let LanguageModel::OpenAi(new_model) = new_model {
323 *model = Some(new_model);
324 }
325 }
326 Some(AssistantProviderContent::Anthropic {
327 default_model: model,
328 ..
329 }) => {
330 if let LanguageModel::Anthropic(new_model) = new_model {
331 *model = Some(new_model);
332 }
333 }
334 Some(AssistantProviderContent::Ollama {
335 default_model: model,
336 ..
337 }) => {
338 if let LanguageModel::Ollama(new_model) = new_model {
339 *model = Some(new_model);
340 }
341 }
342 provider => match new_model {
343 LanguageModel::Cloud(model) => {
344 *provider = Some(AssistantProviderContent::ZedDotDev {
345 default_model: Some(model),
346 })
347 }
348 LanguageModel::OpenAi(model) => {
349 *provider = Some(AssistantProviderContent::OpenAi {
350 default_model: Some(model),
351 api_url: None,
352 low_speed_timeout_in_seconds: None,
353 available_models: Some(Default::default()),
354 })
355 }
356 LanguageModel::Anthropic(model) => {
357 *provider = Some(AssistantProviderContent::Anthropic {
358 default_model: Some(model),
359 api_url: None,
360 low_speed_timeout_in_seconds: None,
361 })
362 }
363 LanguageModel::Ollama(model) => {
364 *provider = Some(AssistantProviderContent::Ollama {
365 default_model: Some(model),
366 api_url: None,
367 low_speed_timeout_in_seconds: None,
368 })
369 }
370 },
371 },
372 },
373 AssistantSettingsContent::Legacy(settings) => {
374 if let LanguageModel::OpenAi(model) = new_model {
375 settings.default_open_ai_model = Some(model);
376 }
377 }
378 }
379 }
380}
381
/// Assistant settings content with an explicit `"version"` tag.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    /// Selected by `"version": "1"`.
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
388
389impl Default for VersionedAssistantSettingsContent {
390 fn default() -> Self {
391 Self::V1(AssistantSettingsContentV1 {
392 enabled: None,
393 button: None,
394 dock: None,
395 default_width: None,
396 default_height: None,
397 provider: None,
398 })
399 }
400}
401
402#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
403pub struct AssistantSettingsContentV1 {
404 /// Whether the Assistant is enabled.
405 ///
406 /// Default: true
407 enabled: Option<bool>,
408 /// Whether to show the assistant panel button in the status bar.
409 ///
410 /// Default: true
411 button: Option<bool>,
412 /// Where to dock the assistant.
413 ///
414 /// Default: right
415 dock: Option<AssistantDockPosition>,
416 /// Default width in pixels when the assistant is docked to the left or right.
417 ///
418 /// Default: 640
419 default_width: Option<f32>,
420 /// Default height in pixels when the assistant is docked to the bottom.
421 ///
422 /// Default: 320
423 default_height: Option<f32>,
424 /// The provider of the assistant service.
425 ///
426 /// This can either be the internal `zed.dev` service or an external `openai` service,
427 /// each with their respective default models and configurations.
428 provider: Option<AssistantProviderContent>,
429}
430
/// The original, unversioned assistant settings format.
///
/// Still accepted when reading settings; `AssistantSettingsContent::upgrade`
/// converts it into `AssistantSettingsContentV1`.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    // NOTE(review): the test module expects `OpenAiModel::FourOmni` as the
    // effective default model, so the "Default" line below may be stale — confirm.
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
458
459impl Settings for AssistantSettings {
460 const KEY: Option<&'static str> = Some("assistant");
461
462 type FileContent = AssistantSettingsContent;
463
464 fn load(
465 sources: SettingsSources<Self::FileContent>,
466 _: &mut gpui::AppContext,
467 ) -> anyhow::Result<Self> {
468 let mut settings = AssistantSettings::default();
469
470 for value in sources.defaults_and_customizations() {
471 let value = value.upgrade();
472 merge(&mut settings.enabled, value.enabled);
473 merge(&mut settings.button, value.button);
474 merge(&mut settings.dock, value.dock);
475 merge(
476 &mut settings.default_width,
477 value.default_width.map(Into::into),
478 );
479 merge(
480 &mut settings.default_height,
481 value.default_height.map(Into::into),
482 );
483 if let Some(provider) = value.provider.clone() {
484 match (&mut settings.provider, provider) {
485 (
486 AssistantProvider::ZedDotDev { model },
487 AssistantProviderContent::ZedDotDev {
488 default_model: model_override,
489 },
490 ) => {
491 merge(model, model_override);
492 }
493 (
494 AssistantProvider::OpenAi {
495 model,
496 api_url,
497 low_speed_timeout_in_seconds,
498 available_models,
499 },
500 AssistantProviderContent::OpenAi {
501 default_model: model_override,
502 api_url: api_url_override,
503 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
504 available_models: available_models_override,
505 },
506 ) => {
507 merge(model, model_override);
508 merge(api_url, api_url_override);
509 merge(available_models, available_models_override);
510 if let Some(low_speed_timeout_in_seconds_override) =
511 low_speed_timeout_in_seconds_override
512 {
513 *low_speed_timeout_in_seconds =
514 Some(low_speed_timeout_in_seconds_override);
515 }
516 }
517 (
518 AssistantProvider::Ollama {
519 model,
520 api_url,
521 low_speed_timeout_in_seconds,
522 },
523 AssistantProviderContent::Ollama {
524 default_model: model_override,
525 api_url: api_url_override,
526 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
527 },
528 ) => {
529 merge(model, model_override);
530 merge(api_url, api_url_override);
531 if let Some(low_speed_timeout_in_seconds_override) =
532 low_speed_timeout_in_seconds_override
533 {
534 *low_speed_timeout_in_seconds =
535 Some(low_speed_timeout_in_seconds_override);
536 }
537 }
538 (
539 AssistantProvider::Anthropic {
540 model,
541 api_url,
542 low_speed_timeout_in_seconds,
543 },
544 AssistantProviderContent::Anthropic {
545 default_model: model_override,
546 api_url: api_url_override,
547 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
548 },
549 ) => {
550 merge(model, model_override);
551 merge(api_url, api_url_override);
552 if let Some(low_speed_timeout_in_seconds_override) =
553 low_speed_timeout_in_seconds_override
554 {
555 *low_speed_timeout_in_seconds =
556 Some(low_speed_timeout_in_seconds_override);
557 }
558 }
559 (provider, provider_override) => {
560 *provider = match provider_override {
561 AssistantProviderContent::ZedDotDev {
562 default_model: model,
563 } => AssistantProvider::ZedDotDev {
564 model: model.unwrap_or_default(),
565 },
566 AssistantProviderContent::OpenAi {
567 default_model: model,
568 api_url,
569 low_speed_timeout_in_seconds,
570 available_models,
571 } => AssistantProvider::OpenAi {
572 model: model.unwrap_or_default(),
573 api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
574 low_speed_timeout_in_seconds,
575 available_models: available_models.unwrap_or_default(),
576 },
577 AssistantProviderContent::Anthropic {
578 default_model: model,
579 api_url,
580 low_speed_timeout_in_seconds,
581 } => AssistantProvider::Anthropic {
582 model: model.unwrap_or_default(),
583 api_url: api_url
584 .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
585 low_speed_timeout_in_seconds,
586 },
587 AssistantProviderContent::Ollama {
588 default_model: model,
589 api_url,
590 low_speed_timeout_in_seconds,
591 } => AssistantProvider::Ollama {
592 model: model.unwrap_or_default(),
593 api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
594 low_speed_timeout_in_seconds,
595 },
596 };
597 }
598 }
599 }
600 }
601
602 Ok(settings)
603 }
604}
605
/// Overwrites `target` with `value` when `value` is `Some`; `None` leaves
/// `target` unchanged.
fn merge<T>(target: &mut T, value: Option<T>) {
    let Some(replacement) = value else {
        return;
    };
    *target = replacement;
}
611
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // With no user settings, the provider defaults to OpenAI using
        // `OpenAiModel::FourOmni`.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // Legacy (unversioned) settings are still understood: `openai_api_url`
        // overrides the API URL.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );
        // Legacy `default_open_ai_model` selects the OpenAI model.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );
    }
}