1use std::fmt;
2
3pub use anthropic::Model as AnthropicModel;
4use gpui::Pixels;
5pub use ollama::Model as OllamaModel;
6pub use open_ai::Model as OpenAiModel;
7use schemars::{
8 schema::{InstanceType, Metadata, Schema, SchemaObject},
9 JsonSchema,
10};
11use serde::{
12 de::{self, Visitor},
13 Deserialize, Deserializer, Serialize, Serializer,
14};
15use settings::{Settings, SettingsSources};
16use strum::{EnumIter, IntoEnumIterator};
17
18use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
19
/// Models that can be selected for the `zed.dev` provider.
///
/// Every variant maps to a string identifier (see `CloudModel::id`), and that
/// identifier is the form used when a value is serialized or deserialized.
/// The variant order is the order in which `CloudModel::iter()` visits them.
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
    Gpt3Point5Turbo,
    Gpt4,
    Gpt4Turbo,
    /// The variant returned by `CloudModel::default()`.
    #[default]
    Gpt4Omni,
    Claude3Opus,
    Claude3Sonnet,
    Claude3Haiku,
    /// A model outside the built-in list; the payload is its raw identifier.
    Custom(String),
}
32
33impl Serialize for CloudModel {
34 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
35 where
36 S: Serializer,
37 {
38 serializer.serialize_str(self.id())
39 }
40}
41
42impl<'de> Deserialize<'de> for CloudModel {
43 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
44 where
45 D: Deserializer<'de>,
46 {
47 struct ZedDotDevModelVisitor;
48
49 impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
50 type Value = CloudModel;
51
52 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
53 formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
54 }
55
56 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
57 where
58 E: de::Error,
59 {
60 let model = CloudModel::iter()
61 .find(|model| model.id() == value)
62 .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
63 Ok(model)
64 }
65 }
66
67 deserializer.deserialize_str(ZedDotDevModelVisitor)
68 }
69}
70
71impl JsonSchema for CloudModel {
72 fn schema_name() -> String {
73 "ZedDotDevModel".to_owned()
74 }
75
76 fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
77 let variants = CloudModel::iter()
78 .filter_map(|model| {
79 let id = model.id();
80 if id.is_empty() {
81 None
82 } else {
83 Some(id.to_string())
84 }
85 })
86 .collect::<Vec<_>>();
87 Schema::Object(SchemaObject {
88 instance_type: Some(InstanceType::String.into()),
89 enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
90 metadata: Some(Box::new(Metadata {
91 title: Some("ZedDotDevModel".to_owned()),
92 default: Some(CloudModel::default().id().into()),
93 examples: variants.into_iter().map(Into::into).collect(),
94 ..Default::default()
95 })),
96 ..Default::default()
97 })
98 }
99}
100
101impl CloudModel {
102 pub fn id(&self) -> &str {
103 match self {
104 Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
105 Self::Gpt4 => "gpt-4",
106 Self::Gpt4Turbo => "gpt-4-turbo-preview",
107 Self::Gpt4Omni => "gpt-4o",
108 Self::Claude3Opus => "claude-3-opus",
109 Self::Claude3Sonnet => "claude-3-sonnet",
110 Self::Claude3Haiku => "claude-3-haiku",
111 Self::Custom(id) => id,
112 }
113 }
114
115 pub fn display_name(&self) -> &str {
116 match self {
117 Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
118 Self::Gpt4 => "GPT 4",
119 Self::Gpt4Turbo => "GPT 4 Turbo",
120 Self::Gpt4Omni => "GPT 4 Omni",
121 Self::Claude3Opus => "Claude 3 Opus",
122 Self::Claude3Sonnet => "Claude 3 Sonnet",
123 Self::Claude3Haiku => "Claude 3 Haiku",
124 Self::Custom(id) => id.as_str(),
125 }
126 }
127
128 pub fn max_token_count(&self) -> usize {
129 match self {
130 Self::Gpt3Point5Turbo => 2048,
131 Self::Gpt4 => 4096,
132 Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
133 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
134 Self::Custom(_) => 4096, // TODO: Make this configurable
135 }
136 }
137
138 pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
139 match self {
140 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
141 preprocess_anthropic_request(request)
142 }
143 _ => {}
144 }
145 }
146}
147
// The edge the assistant is docked to. With `rename_all = "snake_case"` the
// variants are written as `left`, `right` and `bottom` in settings.
//
// NOTE(review): plain `//` comments are used on purpose. This type derives
// `JsonSchema`, and schemars copies `///` doc comments into the generated
// schema, so converting these to doc comments would change that output.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    // The variant returned by `AssistantDockPosition::default()`.
    #[default]
    Right,
    Bottom,
}
156
/// The assistant provider after settings have been resolved by
/// `AssistantSettings::load`.
///
/// Unlike `AssistantProviderContent`, the model and API URL are always
/// present here; only the low-speed timeout stays optional.
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    /// The `zed.dev` service.
    ZedDotDev {
        model: CloudModel,
    },
    /// An OpenAI service reached at `api_url`.
    OpenAi {
        model: OpenAiModel,
        api_url: String,
        /// Optional low-speed timeout in seconds; how it is applied is decided
        /// by code outside this file.
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// An Anthropic service reached at `api_url`.
    Anthropic {
        model: AnthropicModel,
        api_url: String,
        /// Optional low-speed timeout in seconds; how it is applied is decided
        /// by code outside this file.
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// An Ollama service reached at `api_url`.
    Ollama {
        model: OllamaModel,
        api_url: String,
        /// Optional low-speed timeout in seconds; how it is applied is decided
        /// by code outside this file.
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
178
179impl Default for AssistantProvider {
180 fn default() -> Self {
181 Self::OpenAi {
182 model: OpenAiModel::default(),
183 api_url: open_ai::OPEN_AI_API_URL.into(),
184 low_speed_timeout_in_seconds: None,
185 }
186 }
187}
188
// Provider configuration as written in a settings file.
//
// The enum is internally tagged: the `name` field selects the variant
// (`zed.dev`, `openai`, `anthropic` or `ollama`). Every other field is
// optional; `AssistantSettings::load` fills in whatever is left unset.
//
// NOTE(review): plain `//` comments are used on purpose. This type derives
// `JsonSchema`, and schemars copies `///` doc comments into the generated
// schema, so converting these to doc comments would change that output.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    #[serde(rename = "zed.dev")]
    ZedDotDev { default_model: Option<CloudModel> },
    #[serde(rename = "openai")]
    OpenAi {
        default_model: Option<OpenAiModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "anthropic")]
    Anthropic {
        default_model: Option<AnthropicModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "ollama")]
    Ollama {
        default_model: Option<OllamaModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
213
/// The resolved assistant settings, produced by `Settings::load` after all
/// settings sources have been merged.
#[derive(Debug, Default)]
pub struct AssistantSettings {
    /// Whether the assistant is enabled.
    pub enabled: bool,
    /// Whether to show the assistant panel button in the status bar.
    pub button: bool,
    /// Where the assistant is docked.
    pub dock: AssistantDockPosition,
    /// Default width when the assistant is docked to the left or right.
    pub default_width: Pixels,
    /// Default height when the assistant is docked to the bottom.
    pub default_height: Pixels,
    /// The provider, including its model, that serves the assistant.
    pub provider: AssistantProvider,
}
223
/// Assistant panel settings as they appear in a settings file.
///
/// The enum is `untagged`, so serde attempts the variants in declaration
/// order: content is first read as the versioned format and, if that fails,
/// as the legacy (pre-versioning) format.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    /// Content carrying an explicit `version` tag.
    Versioned(VersionedAssistantSettingsContent),
    /// Content in the original, unversioned layout.
    Legacy(LegacyAssistantSettingsContent),
}
231
232impl JsonSchema for AssistantSettingsContent {
233 fn schema_name() -> String {
234 VersionedAssistantSettingsContent::schema_name()
235 }
236
237 fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
238 VersionedAssistantSettingsContent::json_schema(gen)
239 }
240
241 fn is_referenceable() -> bool {
242 VersionedAssistantSettingsContent::is_referenceable()
243 }
244}
245
246impl Default for AssistantSettingsContent {
247 fn default() -> Self {
248 Self::Versioned(VersionedAssistantSettingsContent::default())
249 }
250}
251
252impl AssistantSettingsContent {
253 fn upgrade(&self) -> AssistantSettingsContentV1 {
254 match self {
255 AssistantSettingsContent::Versioned(settings) => match settings {
256 VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
257 },
258 AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
259 enabled: None,
260 button: settings.button,
261 dock: settings.dock,
262 default_width: settings.default_width,
263 default_height: settings.default_height,
264 provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
265 Some(AssistantProviderContent::OpenAi {
266 default_model: settings.default_open_ai_model.clone(),
267 api_url: Some(open_ai_api_url.clone()),
268 low_speed_timeout_in_seconds: None,
269 })
270 } else {
271 settings.default_open_ai_model.clone().map(|open_ai_model| {
272 AssistantProviderContent::OpenAi {
273 default_model: Some(open_ai_model),
274 api_url: None,
275 low_speed_timeout_in_seconds: None,
276 }
277 })
278 },
279 },
280 }
281 }
282
283 pub fn set_dock(&mut self, dock: AssistantDockPosition) {
284 match self {
285 AssistantSettingsContent::Versioned(settings) => match settings {
286 VersionedAssistantSettingsContent::V1(settings) => {
287 settings.dock = Some(dock);
288 }
289 },
290 AssistantSettingsContent::Legacy(settings) => {
291 settings.dock = Some(dock);
292 }
293 }
294 }
295
296 pub fn set_model(&mut self, new_model: LanguageModel) {
297 match self {
298 AssistantSettingsContent::Versioned(settings) => match settings {
299 VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
300 Some(AssistantProviderContent::ZedDotDev {
301 default_model: model,
302 }) => {
303 if let LanguageModel::Cloud(new_model) = new_model {
304 *model = Some(new_model);
305 }
306 }
307 Some(AssistantProviderContent::OpenAi {
308 default_model: model,
309 ..
310 }) => {
311 if let LanguageModel::OpenAi(new_model) = new_model {
312 *model = Some(new_model);
313 }
314 }
315 Some(AssistantProviderContent::Anthropic {
316 default_model: model,
317 ..
318 }) => {
319 if let LanguageModel::Anthropic(new_model) = new_model {
320 *model = Some(new_model);
321 }
322 }
323 provider => match new_model {
324 LanguageModel::Cloud(model) => {
325 *provider = Some(AssistantProviderContent::ZedDotDev {
326 default_model: Some(model),
327 })
328 }
329 LanguageModel::OpenAi(model) => {
330 *provider = Some(AssistantProviderContent::OpenAi {
331 default_model: Some(model),
332 api_url: None,
333 low_speed_timeout_in_seconds: None,
334 })
335 }
336 LanguageModel::Anthropic(model) => {
337 *provider = Some(AssistantProviderContent::Anthropic {
338 default_model: Some(model),
339 api_url: None,
340 low_speed_timeout_in_seconds: None,
341 })
342 }
343 LanguageModel::Ollama(model) => {
344 *provider = Some(AssistantProviderContent::Ollama {
345 default_model: Some(model),
346 api_url: None,
347 low_speed_timeout_in_seconds: None,
348 })
349 }
350 },
351 },
352 },
353 AssistantSettingsContent::Legacy(settings) => {
354 if let LanguageModel::OpenAi(model) = new_model {
355 settings.default_open_ai_model = Some(model);
356 }
357 }
358 }
359 }
360}
361
// Settings content that carries an explicit `version` field. The enum is
// internally tagged on that field, and `"1"` is the only version defined.
//
// NOTE(review): plain `//` comments are used on purpose. This type derives
// `JsonSchema`, and schemars copies `///` doc comments into the generated
// schema, so converting these to doc comments would change that output.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
368
369impl Default for VersionedAssistantSettingsContent {
370 fn default() -> Self {
371 Self::V1(AssistantSettingsContentV1 {
372 enabled: None,
373 button: None,
374 dock: None,
375 default_width: None,
376 default_height: None,
377 provider: None,
378 })
379 }
380}
381
382#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
383pub struct AssistantSettingsContentV1 {
384 /// Whether the Assistant is enabled.
385 ///
386 /// Default: true
387 enabled: Option<bool>,
388 /// Whether to show the assistant panel button in the status bar.
389 ///
390 /// Default: true
391 button: Option<bool>,
392 /// Where to dock the assistant.
393 ///
394 /// Default: right
395 dock: Option<AssistantDockPosition>,
396 /// Default width in pixels when the assistant is docked to the left or right.
397 ///
398 /// Default: 640
399 default_width: Option<f32>,
400 /// Default height in pixels when the assistant is docked to the bottom.
401 ///
402 /// Default: 320
403 default_height: Option<f32>,
404 /// The provider of the assistant service.
405 ///
406 /// This can either be the internal `zed.dev` service or an external `openai` service,
407 /// each with their respective default models and configurations.
408 provider: Option<AssistantProviderContent>,
409}
410
// The original, unversioned assistant settings layout. It only knows about
// OpenAI; `AssistantSettingsContent::upgrade` converts it to the version 1
// layout before it is merged.
//
// The existing `///` field comments are left untouched because the derived
// `JsonSchema` exposes them as descriptions.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4-1106-preview
    // NOTE(review): the tests in this file expect `OpenAiModel::FourOmni` when
    // no model is configured, so the default stated above may be out of date.
    // Confirm against the default settings before relying on it.
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
438
439impl Settings for AssistantSettings {
440 const KEY: Option<&'static str> = Some("assistant");
441
442 type FileContent = AssistantSettingsContent;
443
444 fn load(
445 sources: SettingsSources<Self::FileContent>,
446 _: &mut gpui::AppContext,
447 ) -> anyhow::Result<Self> {
448 let mut settings = AssistantSettings::default();
449
450 for value in sources.defaults_and_customizations() {
451 let value = value.upgrade();
452 merge(&mut settings.enabled, value.enabled);
453 merge(&mut settings.button, value.button);
454 merge(&mut settings.dock, value.dock);
455 merge(
456 &mut settings.default_width,
457 value.default_width.map(Into::into),
458 );
459 merge(
460 &mut settings.default_height,
461 value.default_height.map(Into::into),
462 );
463 if let Some(provider) = value.provider.clone() {
464 match (&mut settings.provider, provider) {
465 (
466 AssistantProvider::ZedDotDev { model },
467 AssistantProviderContent::ZedDotDev {
468 default_model: model_override,
469 },
470 ) => {
471 merge(model, model_override);
472 }
473 (
474 AssistantProvider::OpenAi {
475 model,
476 api_url,
477 low_speed_timeout_in_seconds,
478 },
479 AssistantProviderContent::OpenAi {
480 default_model: model_override,
481 api_url: api_url_override,
482 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
483 },
484 ) => {
485 merge(model, model_override);
486 merge(api_url, api_url_override);
487 if let Some(low_speed_timeout_in_seconds_override) =
488 low_speed_timeout_in_seconds_override
489 {
490 *low_speed_timeout_in_seconds =
491 Some(low_speed_timeout_in_seconds_override);
492 }
493 }
494 (
495 AssistantProvider::Ollama {
496 model,
497 api_url,
498 low_speed_timeout_in_seconds,
499 },
500 AssistantProviderContent::Ollama {
501 default_model: model_override,
502 api_url: api_url_override,
503 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
504 },
505 ) => {
506 merge(model, model_override);
507 merge(api_url, api_url_override);
508 if let Some(low_speed_timeout_in_seconds_override) =
509 low_speed_timeout_in_seconds_override
510 {
511 *low_speed_timeout_in_seconds =
512 Some(low_speed_timeout_in_seconds_override);
513 }
514 }
515 (
516 AssistantProvider::Anthropic {
517 model,
518 api_url,
519 low_speed_timeout_in_seconds,
520 },
521 AssistantProviderContent::Anthropic {
522 default_model: model_override,
523 api_url: api_url_override,
524 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
525 },
526 ) => {
527 merge(model, model_override);
528 merge(api_url, api_url_override);
529 if let Some(low_speed_timeout_in_seconds_override) =
530 low_speed_timeout_in_seconds_override
531 {
532 *low_speed_timeout_in_seconds =
533 Some(low_speed_timeout_in_seconds_override);
534 }
535 }
536 (provider, provider_override) => {
537 *provider = match provider_override {
538 AssistantProviderContent::ZedDotDev {
539 default_model: model,
540 } => AssistantProvider::ZedDotDev {
541 model: model.unwrap_or_default(),
542 },
543 AssistantProviderContent::OpenAi {
544 default_model: model,
545 api_url,
546 low_speed_timeout_in_seconds,
547 } => AssistantProvider::OpenAi {
548 model: model.unwrap_or_default(),
549 api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
550 low_speed_timeout_in_seconds,
551 },
552 AssistantProviderContent::Anthropic {
553 default_model: model,
554 api_url,
555 low_speed_timeout_in_seconds,
556 } => AssistantProvider::Anthropic {
557 model: model.unwrap_or_default(),
558 api_url: api_url
559 .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
560 low_speed_timeout_in_seconds,
561 },
562 AssistantProviderContent::Ollama {
563 default_model: model,
564 api_url,
565 low_speed_timeout_in_seconds,
566 } => AssistantProvider::Ollama {
567 model: model.unwrap_or_default(),
568 api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
569 low_speed_timeout_in_seconds,
570 },
571 };
572 }
573 }
574 }
575 }
576
577 Ok(settings)
578 }
579}
580
/// Overwrites `target` with the contained value when `value` is `Some`;
/// leaves `target` untouched when it is `None`.
fn merge<T>(target: &mut T, value: Option<T>) {
    match value {
        Some(replacement) => *target = replacement,
        None => {}
    }
}
586
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    /// Checks the resolved provider for the registered defaults, for the two
    /// legacy (unversioned) OpenAI settings, and for a versioned `zed.dev`
    /// provider configured with a custom model.
    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to the OpenAI provider with the `FourOmni` model.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // Ensure backward-compatibility: a legacy `openai_api_url` overrides
        // only the URL and keeps the default model.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
            }
        );
        // A legacy `default_open_ai_model` overrides only the model and keeps
        // the default URL.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );
    }
}