1use std::fmt;
2
3use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
4pub use anthropic::Model as AnthropicModel;
5use gpui::Pixels;
6pub use ollama::Model as OllamaModel;
7pub use open_ai::Model as OpenAiModel;
8use schemars::{
9 schema::{InstanceType, Metadata, Schema, SchemaObject},
10 JsonSchema,
11};
12use serde::{
13 de::{self, Visitor},
14 Deserialize, Deserializer, Serialize, Serializer,
15};
16use settings::{Settings, SettingsSources};
17use strum::{EnumIter, IntoEnumIterator};
18
19#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
20pub enum CloudModel {
21 Gpt3Point5Turbo,
22 Gpt4,
23 Gpt4Turbo,
24 #[default]
25 Gpt4Omni,
26 Claude3_5Sonnet,
27 Claude3Opus,
28 Claude3Sonnet,
29 Claude3Haiku,
30 Gemini15Pro,
31 Gemini15Flash,
32 Custom(String),
33}
34
35impl Serialize for CloudModel {
36 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
37 where
38 S: Serializer,
39 {
40 serializer.serialize_str(self.id())
41 }
42}
43
44impl<'de> Deserialize<'de> for CloudModel {
45 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
46 where
47 D: Deserializer<'de>,
48 {
49 struct ZedDotDevModelVisitor;
50
51 impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
52 type Value = CloudModel;
53
54 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
55 formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
56 }
57
58 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
59 where
60 E: de::Error,
61 {
62 let model = CloudModel::iter()
63 .find(|model| model.id() == value)
64 .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
65 Ok(model)
66 }
67 }
68
69 deserializer.deserialize_str(ZedDotDevModelVisitor)
70 }
71}
72
73impl JsonSchema for CloudModel {
74 fn schema_name() -> String {
75 "ZedDotDevModel".to_owned()
76 }
77
78 fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
79 let variants = CloudModel::iter()
80 .filter_map(|model| {
81 let id = model.id();
82 if id.is_empty() {
83 None
84 } else {
85 Some(id.to_string())
86 }
87 })
88 .collect::<Vec<_>>();
89 Schema::Object(SchemaObject {
90 instance_type: Some(InstanceType::String.into()),
91 enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
92 metadata: Some(Box::new(Metadata {
93 title: Some("ZedDotDevModel".to_owned()),
94 default: Some(CloudModel::default().id().into()),
95 examples: variants.into_iter().map(Into::into).collect(),
96 ..Default::default()
97 })),
98 ..Default::default()
99 })
100 }
101}
102
103impl CloudModel {
104 pub fn id(&self) -> &str {
105 match self {
106 Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
107 Self::Gpt4 => "gpt-4",
108 Self::Gpt4Turbo => "gpt-4-turbo-preview",
109 Self::Gpt4Omni => "gpt-4o",
110 Self::Claude3_5Sonnet => "claude-3-5-sonnet",
111 Self::Claude3Opus => "claude-3-opus",
112 Self::Claude3Sonnet => "claude-3-sonnet",
113 Self::Claude3Haiku => "claude-3-haiku",
114 Self::Gemini15Pro => "gemini-1.5-pro",
115 Self::Gemini15Flash => "gemini-1.5-flash",
116 Self::Custom(id) => id,
117 }
118 }
119
120 pub fn display_name(&self) -> &str {
121 match self {
122 Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
123 Self::Gpt4 => "GPT 4",
124 Self::Gpt4Turbo => "GPT 4 Turbo",
125 Self::Gpt4Omni => "GPT 4 Omni",
126 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
127 Self::Claude3Opus => "Claude 3 Opus",
128 Self::Claude3Sonnet => "Claude 3 Sonnet",
129 Self::Claude3Haiku => "Claude 3 Haiku",
130 Self::Gemini15Pro => "Gemini 1.5 Pro",
131 Self::Gemini15Flash => "Gemini 1.5 Flash",
132 Self::Custom(id) => id.as_str(),
133 }
134 }
135
136 pub fn max_token_count(&self) -> usize {
137 match self {
138 Self::Gpt3Point5Turbo => 2048,
139 Self::Gpt4 => 4096,
140 Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
141 Self::Claude3_5Sonnet
142 | Self::Claude3Opus
143 | Self::Claude3Sonnet
144 | Self::Claude3Haiku => 200000,
145 Self::Gemini15Pro => 128000,
146 Self::Gemini15Flash => 32000,
147 Self::Custom(_) => 4096, // TODO: Make this configurable
148 }
149 }
150
151 pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
152 match self {
153 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
154 preprocess_anthropic_request(request)
155 }
156 _ => {}
157 }
158 }
159}
160
161#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
162#[serde(rename_all = "snake_case")]
163pub enum AssistantDockPosition {
164 Left,
165 #[default]
166 Right,
167 Bottom,
168}
169
170#[derive(Debug, PartialEq)]
171pub enum AssistantProvider {
172 ZedDotDev {
173 model: CloudModel,
174 },
175 OpenAi {
176 model: OpenAiModel,
177 api_url: String,
178 low_speed_timeout_in_seconds: Option<u64>,
179 available_models: Vec<OpenAiModel>,
180 },
181 Anthropic {
182 model: AnthropicModel,
183 api_url: String,
184 low_speed_timeout_in_seconds: Option<u64>,
185 },
186 Ollama {
187 model: OllamaModel,
188 api_url: String,
189 low_speed_timeout_in_seconds: Option<u64>,
190 },
191}
192
193impl Default for AssistantProvider {
194 fn default() -> Self {
195 Self::OpenAi {
196 model: OpenAiModel::default(),
197 api_url: open_ai::OPEN_AI_API_URL.into(),
198 low_speed_timeout_in_seconds: None,
199 available_models: Default::default(),
200 }
201 }
202}
203
204#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
205#[serde(tag = "name", rename_all = "snake_case")]
206pub enum AssistantProviderContent {
207 #[serde(rename = "zed.dev")]
208 ZedDotDev { default_model: Option<CloudModel> },
209 #[serde(rename = "openai")]
210 OpenAi {
211 default_model: Option<OpenAiModel>,
212 api_url: Option<String>,
213 low_speed_timeout_in_seconds: Option<u64>,
214 available_models: Option<Vec<OpenAiModel>>,
215 },
216 #[serde(rename = "anthropic")]
217 Anthropic {
218 default_model: Option<AnthropicModel>,
219 api_url: Option<String>,
220 low_speed_timeout_in_seconds: Option<u64>,
221 },
222 #[serde(rename = "ollama")]
223 Ollama {
224 default_model: Option<OllamaModel>,
225 api_url: Option<String>,
226 low_speed_timeout_in_seconds: Option<u64>,
227 },
228}
229
230#[derive(Debug, Default)]
231pub struct AssistantSettings {
232 pub enabled: bool,
233 pub button: bool,
234 pub dock: AssistantDockPosition,
235 pub default_width: Pixels,
236 pub default_height: Pixels,
237 pub provider: AssistantProvider,
238}
239
240/// Assistant panel settings
241#[derive(Clone, Serialize, Deserialize, Debug)]
242#[serde(untagged)]
243pub enum AssistantSettingsContent {
244 Versioned(VersionedAssistantSettingsContent),
245 Legacy(LegacyAssistantSettingsContent),
246}
247
248impl JsonSchema for AssistantSettingsContent {
249 fn schema_name() -> String {
250 VersionedAssistantSettingsContent::schema_name()
251 }
252
253 fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
254 VersionedAssistantSettingsContent::json_schema(gen)
255 }
256
257 fn is_referenceable() -> bool {
258 VersionedAssistantSettingsContent::is_referenceable()
259 }
260}
261
262impl Default for AssistantSettingsContent {
263 fn default() -> Self {
264 Self::Versioned(VersionedAssistantSettingsContent::default())
265 }
266}
267
268impl AssistantSettingsContent {
269 fn upgrade(&self) -> AssistantSettingsContentV1 {
270 match self {
271 AssistantSettingsContent::Versioned(settings) => match settings {
272 VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
273 },
274 AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
275 enabled: None,
276 button: settings.button,
277 dock: settings.dock,
278 default_width: settings.default_width,
279 default_height: settings.default_height,
280 provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
281 Some(AssistantProviderContent::OpenAi {
282 default_model: settings.default_open_ai_model.clone(),
283 api_url: Some(open_ai_api_url.clone()),
284 low_speed_timeout_in_seconds: None,
285 available_models: Some(Default::default()),
286 })
287 } else {
288 settings.default_open_ai_model.clone().map(|open_ai_model| {
289 AssistantProviderContent::OpenAi {
290 default_model: Some(open_ai_model),
291 api_url: None,
292 low_speed_timeout_in_seconds: None,
293 available_models: Some(Default::default()),
294 }
295 })
296 },
297 },
298 }
299 }
300
301 pub fn set_dock(&mut self, dock: AssistantDockPosition) {
302 match self {
303 AssistantSettingsContent::Versioned(settings) => match settings {
304 VersionedAssistantSettingsContent::V1(settings) => {
305 settings.dock = Some(dock);
306 }
307 },
308 AssistantSettingsContent::Legacy(settings) => {
309 settings.dock = Some(dock);
310 }
311 }
312 }
313
314 pub fn set_model(&mut self, new_model: LanguageModel) {
315 match self {
316 AssistantSettingsContent::Versioned(settings) => match settings {
317 VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
318 Some(AssistantProviderContent::ZedDotDev {
319 default_model: model,
320 }) => {
321 if let LanguageModel::Cloud(new_model) = new_model {
322 *model = Some(new_model);
323 }
324 }
325 Some(AssistantProviderContent::OpenAi {
326 default_model: model,
327 ..
328 }) => {
329 if let LanguageModel::OpenAi(new_model) = new_model {
330 *model = Some(new_model);
331 }
332 }
333 Some(AssistantProviderContent::Anthropic {
334 default_model: model,
335 ..
336 }) => {
337 if let LanguageModel::Anthropic(new_model) = new_model {
338 *model = Some(new_model);
339 }
340 }
341 Some(AssistantProviderContent::Ollama {
342 default_model: model,
343 ..
344 }) => {
345 if let LanguageModel::Ollama(new_model) = new_model {
346 *model = Some(new_model);
347 }
348 }
349 provider => match new_model {
350 LanguageModel::Cloud(model) => {
351 *provider = Some(AssistantProviderContent::ZedDotDev {
352 default_model: Some(model),
353 })
354 }
355 LanguageModel::OpenAi(model) => {
356 *provider = Some(AssistantProviderContent::OpenAi {
357 default_model: Some(model),
358 api_url: None,
359 low_speed_timeout_in_seconds: None,
360 available_models: Some(Default::default()),
361 })
362 }
363 LanguageModel::Anthropic(model) => {
364 *provider = Some(AssistantProviderContent::Anthropic {
365 default_model: Some(model),
366 api_url: None,
367 low_speed_timeout_in_seconds: None,
368 })
369 }
370 LanguageModel::Ollama(model) => {
371 *provider = Some(AssistantProviderContent::Ollama {
372 default_model: Some(model),
373 api_url: None,
374 low_speed_timeout_in_seconds: None,
375 })
376 }
377 },
378 },
379 },
380 AssistantSettingsContent::Legacy(settings) => {
381 if let LanguageModel::OpenAi(model) = new_model {
382 settings.default_open_ai_model = Some(model);
383 }
384 }
385 }
386 }
387}
388
389#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
390#[serde(tag = "version")]
391pub enum VersionedAssistantSettingsContent {
392 #[serde(rename = "1")]
393 V1(AssistantSettingsContentV1),
394}
395
396impl Default for VersionedAssistantSettingsContent {
397 fn default() -> Self {
398 Self::V1(AssistantSettingsContentV1 {
399 enabled: None,
400 button: None,
401 dock: None,
402 default_width: None,
403 default_height: None,
404 provider: None,
405 })
406 }
407}
408
409#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
410pub struct AssistantSettingsContentV1 {
411 /// Whether the Assistant is enabled.
412 ///
413 /// Default: true
414 enabled: Option<bool>,
415 /// Whether to show the assistant panel button in the status bar.
416 ///
417 /// Default: true
418 button: Option<bool>,
419 /// Where to dock the assistant.
420 ///
421 /// Default: right
422 dock: Option<AssistantDockPosition>,
423 /// Default width in pixels when the assistant is docked to the left or right.
424 ///
425 /// Default: 640
426 default_width: Option<f32>,
427 /// Default height in pixels when the assistant is docked to the bottom.
428 ///
429 /// Default: 320
430 default_height: Option<f32>,
431 /// The provider of the assistant service.
432 ///
433 /// This can either be the internal `zed.dev` service or an external `openai` service,
434 /// each with their respective default models and configurations.
435 provider: Option<AssistantProviderContent>,
436}
437
438#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
439pub struct LegacyAssistantSettingsContent {
440 /// Whether to show the assistant panel button in the status bar.
441 ///
442 /// Default: true
443 pub button: Option<bool>,
444 /// Where to dock the assistant.
445 ///
446 /// Default: right
447 pub dock: Option<AssistantDockPosition>,
448 /// Default width in pixels when the assistant is docked to the left or right.
449 ///
450 /// Default: 640
451 pub default_width: Option<f32>,
452 /// Default height in pixels when the assistant is docked to the bottom.
453 ///
454 /// Default: 320
455 pub default_height: Option<f32>,
456 /// The default OpenAI model to use when creating new contexts.
457 ///
458 /// Default: gpt-4-1106-preview
459 pub default_open_ai_model: Option<OpenAiModel>,
460 /// OpenAI API base URL to use when creating new contexts.
461 ///
462 /// Default: https://api.openai.com/v1
463 pub openai_api_url: Option<String>,
464}
465
466impl Settings for AssistantSettings {
467 const KEY: Option<&'static str> = Some("assistant");
468
469 type FileContent = AssistantSettingsContent;
470
471 fn load(
472 sources: SettingsSources<Self::FileContent>,
473 _: &mut gpui::AppContext,
474 ) -> anyhow::Result<Self> {
475 let mut settings = AssistantSettings::default();
476
477 for value in sources.defaults_and_customizations() {
478 let value = value.upgrade();
479 merge(&mut settings.enabled, value.enabled);
480 merge(&mut settings.button, value.button);
481 merge(&mut settings.dock, value.dock);
482 merge(
483 &mut settings.default_width,
484 value.default_width.map(Into::into),
485 );
486 merge(
487 &mut settings.default_height,
488 value.default_height.map(Into::into),
489 );
490 if let Some(provider) = value.provider.clone() {
491 match (&mut settings.provider, provider) {
492 (
493 AssistantProvider::ZedDotDev { model },
494 AssistantProviderContent::ZedDotDev {
495 default_model: model_override,
496 },
497 ) => {
498 merge(model, model_override);
499 }
500 (
501 AssistantProvider::OpenAi {
502 model,
503 api_url,
504 low_speed_timeout_in_seconds,
505 available_models,
506 },
507 AssistantProviderContent::OpenAi {
508 default_model: model_override,
509 api_url: api_url_override,
510 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
511 available_models: available_models_override,
512 },
513 ) => {
514 merge(model, model_override);
515 merge(api_url, api_url_override);
516 merge(available_models, available_models_override);
517 if let Some(low_speed_timeout_in_seconds_override) =
518 low_speed_timeout_in_seconds_override
519 {
520 *low_speed_timeout_in_seconds =
521 Some(low_speed_timeout_in_seconds_override);
522 }
523 }
524 (
525 AssistantProvider::Ollama {
526 model,
527 api_url,
528 low_speed_timeout_in_seconds,
529 },
530 AssistantProviderContent::Ollama {
531 default_model: model_override,
532 api_url: api_url_override,
533 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
534 },
535 ) => {
536 merge(model, model_override);
537 merge(api_url, api_url_override);
538 if let Some(low_speed_timeout_in_seconds_override) =
539 low_speed_timeout_in_seconds_override
540 {
541 *low_speed_timeout_in_seconds =
542 Some(low_speed_timeout_in_seconds_override);
543 }
544 }
545 (
546 AssistantProvider::Anthropic {
547 model,
548 api_url,
549 low_speed_timeout_in_seconds,
550 },
551 AssistantProviderContent::Anthropic {
552 default_model: model_override,
553 api_url: api_url_override,
554 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
555 },
556 ) => {
557 merge(model, model_override);
558 merge(api_url, api_url_override);
559 if let Some(low_speed_timeout_in_seconds_override) =
560 low_speed_timeout_in_seconds_override
561 {
562 *low_speed_timeout_in_seconds =
563 Some(low_speed_timeout_in_seconds_override);
564 }
565 }
566 (provider, provider_override) => {
567 *provider = match provider_override {
568 AssistantProviderContent::ZedDotDev {
569 default_model: model,
570 } => AssistantProvider::ZedDotDev {
571 model: model.unwrap_or_default(),
572 },
573 AssistantProviderContent::OpenAi {
574 default_model: model,
575 api_url,
576 low_speed_timeout_in_seconds,
577 available_models,
578 } => AssistantProvider::OpenAi {
579 model: model.unwrap_or_default(),
580 api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
581 low_speed_timeout_in_seconds,
582 available_models: available_models.unwrap_or_default(),
583 },
584 AssistantProviderContent::Anthropic {
585 default_model: model,
586 api_url,
587 low_speed_timeout_in_seconds,
588 } => AssistantProvider::Anthropic {
589 model: model.unwrap_or_default(),
590 api_url: api_url
591 .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
592 low_speed_timeout_in_seconds,
593 },
594 AssistantProviderContent::Ollama {
595 default_model: model,
596 api_url,
597 low_speed_timeout_in_seconds,
598 } => AssistantProvider::Ollama {
599 model: model.unwrap_or_default(),
600 api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
601 low_speed_timeout_in_seconds,
602 },
603 };
604 }
605 }
606 }
607 }
608
609 Ok(settings)
610 }
611}
612
/// Overwrites `target` with `value` when one is provided; `None` leaves `target` as it was.
fn merge<T>(target: &mut T, value: Option<T>) {
    let Some(replacement) = value else {
        return;
    };
    *target = replacement;
}
618
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    /// Checks the resolved provider for the default settings, for legacy (unversioned)
    /// user settings, and for versioned settings that select a custom zed.dev model.
    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // With no user settings, the provider defaults to OpenAI with GPT-4 Omni.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // Ensure backward-compatibility.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );
    }
}