1use std::fmt;
2
3use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
4pub use anthropic::Model as AnthropicModel;
5use gpui::Pixels;
6pub use ollama::Model as OllamaModel;
7pub use open_ai::Model as OpenAiModel;
8use schemars::{
9 schema::{InstanceType, Metadata, Schema, SchemaObject},
10 JsonSchema,
11};
12use serde::{
13 de::{self, Visitor},
14 Deserialize, Deserializer, Serialize, Serializer,
15};
16use settings::{Settings, SettingsSources};
17use strum::{EnumIter, IntoEnumIterator};
18
/// Models available through the hosted zed.dev ("cloud") provider.
///
/// Serialized as the model's stable string id (see the manual
/// `Serialize`/`Deserialize` impls below); unknown ids deserialize into
/// `Custom`.
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
    Gpt3Point5Turbo,
    Gpt4,
    Gpt4Turbo,
    #[default]
    Gpt4Omni,
    Claude3_5Sonnet,
    Claude3Opus,
    Claude3Sonnet,
    Claude3Haiku,
    /// Any model id not in the fixed list above.
    Custom(String),
}
32
impl Serialize for CloudModel {
    // Serialize as the stable string id (e.g. "gpt-4o", "claude-3-opus")
    // rather than the Rust variant name, matching what `deserialize` expects.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(self.id())
    }
}
41
42impl<'de> Deserialize<'de> for CloudModel {
43 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
44 where
45 D: Deserializer<'de>,
46 {
47 struct ZedDotDevModelVisitor;
48
49 impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
50 type Value = CloudModel;
51
52 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
53 formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
54 }
55
56 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
57 where
58 E: de::Error,
59 {
60 let model = CloudModel::iter()
61 .find(|model| model.id() == value)
62 .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
63 Ok(model)
64 }
65 }
66
67 deserializer.deserialize_str(ZedDotDevModelVisitor)
68 }
69}
70
71impl JsonSchema for CloudModel {
72 fn schema_name() -> String {
73 "ZedDotDevModel".to_owned()
74 }
75
76 fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
77 let variants = CloudModel::iter()
78 .filter_map(|model| {
79 let id = model.id();
80 if id.is_empty() {
81 None
82 } else {
83 Some(id.to_string())
84 }
85 })
86 .collect::<Vec<_>>();
87 Schema::Object(SchemaObject {
88 instance_type: Some(InstanceType::String.into()),
89 enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
90 metadata: Some(Box::new(Metadata {
91 title: Some("ZedDotDevModel".to_owned()),
92 default: Some(CloudModel::default().id().into()),
93 examples: variants.into_iter().map(Into::into).collect(),
94 ..Default::default()
95 })),
96 ..Default::default()
97 })
98 }
99}
100
101impl CloudModel {
102 pub fn id(&self) -> &str {
103 match self {
104 Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
105 Self::Gpt4 => "gpt-4",
106 Self::Gpt4Turbo => "gpt-4-turbo-preview",
107 Self::Gpt4Omni => "gpt-4o",
108 Self::Claude3_5Sonnet => "claude-3-5-sonnet",
109 Self::Claude3Opus => "claude-3-opus",
110 Self::Claude3Sonnet => "claude-3-sonnet",
111 Self::Claude3Haiku => "claude-3-haiku",
112 Self::Custom(id) => id,
113 }
114 }
115
116 pub fn display_name(&self) -> &str {
117 match self {
118 Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
119 Self::Gpt4 => "GPT 4",
120 Self::Gpt4Turbo => "GPT 4 Turbo",
121 Self::Gpt4Omni => "GPT 4 Omni",
122 Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
123 Self::Claude3Opus => "Claude 3 Opus",
124 Self::Claude3Sonnet => "Claude 3 Sonnet",
125 Self::Claude3Haiku => "Claude 3 Haiku",
126 Self::Custom(id) => id.as_str(),
127 }
128 }
129
130 pub fn max_token_count(&self) -> usize {
131 match self {
132 Self::Gpt3Point5Turbo => 2048,
133 Self::Gpt4 => 4096,
134 Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
135 Self::Claude3_5Sonnet
136 | Self::Claude3Opus
137 | Self::Claude3Sonnet
138 | Self::Claude3Haiku => 200000,
139 Self::Custom(_) => 4096, // TODO: Make this configurable
140 }
141 }
142
143 pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
144 match self {
145 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
146 preprocess_anthropic_request(request)
147 }
148 _ => {}
149 }
150 }
151}
152
/// Which edge of the workspace the assistant panel is docked to.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}
161
/// A fully-resolved provider configuration with all defaults applied.
///
/// Runtime counterpart of [`AssistantProviderContent`], the partial,
/// all-optional form read from the settings file.
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    /// Zed's hosted zed.dev service.
    ZedDotDev {
        model: CloudModel,
    },
    /// The OpenAI API (or an OpenAI-compatible endpoint via `api_url`).
    OpenAi {
        model: OpenAiModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
        available_models: Vec<OpenAiModel>,
    },
    /// The Anthropic API.
    Anthropic {
        model: AnthropicModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// An Ollama server (typically running locally).
    Ollama {
        model: OllamaModel,
        api_url: String,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
184
impl Default for AssistantProvider {
    /// Defaults to OpenAI at the official API URL, with the default model
    /// and no extra configuration.
    fn default() -> Self {
        Self::OpenAi {
            model: OpenAiModel::default(),
            api_url: open_ai::OPEN_AI_API_URL.into(),
            low_speed_timeout_in_seconds: None,
            available_models: Default::default(),
        }
    }
}
195
/// Provider configuration as written in the settings file.
///
/// Internally tagged by the `"name"` field (e.g. `"name": "openai"`); every
/// other field is optional and falls back to a default when resolved into
/// [`AssistantProvider`].
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    #[serde(rename = "zed.dev")]
    ZedDotDev { default_model: Option<CloudModel> },
    #[serde(rename = "openai")]
    OpenAi {
        default_model: Option<OpenAiModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
        available_models: Option<Vec<OpenAiModel>>,
    },
    #[serde(rename = "anthropic")]
    Anthropic {
        default_model: Option<AnthropicModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "ollama")]
    Ollama {
        default_model: Option<OllamaModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
221
/// Resolved assistant settings used at runtime; built by `Settings::load`
/// from the on-disk [`AssistantSettingsContent`].
#[derive(Debug, Default)]
pub struct AssistantSettings {
    // Whether the assistant is enabled.
    pub enabled: bool,
    // Whether to show the assistant panel button in the status bar.
    pub button: bool,
    // Where the assistant panel is docked.
    pub dock: AssistantDockPosition,
    // Panel width when docked to the left or right.
    pub default_width: Pixels,
    // Panel height when docked to the bottom.
    pub default_height: Pixels,
    // The provider (and model) serving assistant requests.
    pub provider: AssistantProvider,
}
231
/// Assistant panel settings as they appear in the settings file.
///
/// `#[serde(untagged)]` tries the variants in order: the explicitly
/// versioned form first, then the legacy (pre-versioning) shape for
/// backward compatibility.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}
239
240impl JsonSchema for AssistantSettingsContent {
241 fn schema_name() -> String {
242 VersionedAssistantSettingsContent::schema_name()
243 }
244
245 fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
246 VersionedAssistantSettingsContent::json_schema(gen)
247 }
248
249 fn is_referenceable() -> bool {
250 VersionedAssistantSettingsContent::is_referenceable()
251 }
252}
253
impl Default for AssistantSettingsContent {
    /// Defaults to the latest versioned representation with no fields set.
    fn default() -> Self {
        Self::Versioned(VersionedAssistantSettingsContent::default())
    }
}
259
impl AssistantSettingsContent {
    /// Normalize any on-disk representation (versioned or legacy) into the
    /// latest versioned form, `AssistantSettingsContentV1`.
    fn upgrade(&self) -> AssistantSettingsContentV1 {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
            },
            AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
                enabled: None,
                button: settings.button,
                dock: settings.dock,
                default_width: settings.default_width,
                default_height: settings.default_height,
                // Legacy settings could only configure OpenAI: synthesize an
                // OpenAI provider entry when either the API URL or the
                // default model was set; otherwise leave the provider unset.
                provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
                    Some(AssistantProviderContent::OpenAi {
                        default_model: settings.default_open_ai_model.clone(),
                        api_url: Some(open_ai_api_url.clone()),
                        low_speed_timeout_in_seconds: None,
                        available_models: Some(Default::default()),
                    })
                } else {
                    settings.default_open_ai_model.clone().map(|open_ai_model| {
                        AssistantProviderContent::OpenAi {
                            default_model: Some(open_ai_model),
                            api_url: None,
                            low_speed_timeout_in_seconds: None,
                            available_models: Some(Default::default()),
                        }
                    })
                },
            },
        }
    }

    /// Set the dock position in whichever representation is currently held.
    pub fn set_dock(&mut self, dock: AssistantDockPosition) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => {
                    settings.dock = Some(dock);
                }
            },
            AssistantSettingsContent::Legacy(settings) => {
                settings.dock = Some(dock);
            }
        }
    }

    /// Set the default model.
    ///
    /// If the configured provider already matches the model's provider kind,
    /// only its `default_model` is updated. Otherwise the provider entry is
    /// replaced wholesale with a fresh entry for the model's provider.
    /// Legacy settings only support OpenAI models; other models are ignored
    /// there.
    pub fn set_model(&mut self, new_model: LanguageModel) {
        match self {
            AssistantSettingsContent::Versioned(settings) => match settings {
                VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
                    Some(AssistantProviderContent::ZedDotDev {
                        default_model: model,
                    }) => {
                        if let LanguageModel::Cloud(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::OpenAi {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::OpenAi(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::Anthropic {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::Anthropic(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    Some(AssistantProviderContent::Ollama {
                        default_model: model,
                        ..
                    }) => {
                        if let LanguageModel::Ollama(new_model) = new_model {
                            *model = Some(new_model);
                        }
                    }
                    // Provider unset or of a different kind: install a new
                    // provider entry matching the model's provider.
                    provider => match new_model {
                        LanguageModel::Cloud(model) => {
                            *provider = Some(AssistantProviderContent::ZedDotDev {
                                default_model: Some(model),
                            })
                        }
                        LanguageModel::OpenAi(model) => {
                            *provider = Some(AssistantProviderContent::OpenAi {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                                available_models: Some(Default::default()),
                            })
                        }
                        LanguageModel::Anthropic(model) => {
                            *provider = Some(AssistantProviderContent::Anthropic {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                        LanguageModel::Ollama(model) => {
                            *provider = Some(AssistantProviderContent::Ollama {
                                default_model: Some(model),
                                api_url: None,
                                low_speed_timeout_in_seconds: None,
                            })
                        }
                    },
                },
            },
            AssistantSettingsContent::Legacy(settings) => {
                if let LanguageModel::OpenAi(model) = new_model {
                    settings.default_open_ai_model = Some(model);
                }
            }
        }
    }
}
380
/// Settings content carrying an explicit `"version"` tag so the on-disk
/// format can evolve; `"1"` is the only version so far.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
387
impl Default for VersionedAssistantSettingsContent {
    /// An empty V1 payload: every field unset, so resolved values fall back
    /// to `AssistantSettings::default()` during load.
    fn default() -> Self {
        Self::V1(AssistantSettingsContentV1 {
            enabled: None,
            button: None,
            dock: None,
            default_width: None,
            default_height: None,
            provider: None,
        })
    }
}
400
/// Version 1 of the assistant settings file format. All fields are optional;
/// unset fields keep their defaults when resolved.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can either be the internal `zed.dev` service or an external `openai` service,
    /// each with their respective default models and configurations.
    provider: Option<AssistantProviderContent>,
}
429
/// The legacy (pre-versioning) settings shape, accepted for backward
/// compatibility and upgraded to V1 by `AssistantSettingsContent::upgrade`.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
457
impl Settings for AssistantSettings {
    const KEY: Option<&'static str> = Some("assistant");

    type FileContent = AssistantSettingsContent;

    /// Resolve final settings by folding defaults and user customizations in
    /// order; later sources override earlier ones field-by-field.
    fn load(
        sources: SettingsSources<Self::FileContent>,
        _: &mut gpui::AppContext,
    ) -> anyhow::Result<Self> {
        let mut settings = AssistantSettings::default();

        for value in sources.defaults_and_customizations() {
            // Normalize legacy content to the latest version first.
            let value = value.upgrade();
            merge(&mut settings.enabled, value.enabled);
            merge(&mut settings.button, value.button);
            merge(&mut settings.dock, value.dock);
            merge(
                &mut settings.default_width,
                value.default_width.map(Into::into),
            );
            merge(
                &mut settings.default_height,
                value.default_height.map(Into::into),
            );
            if let Some(provider) = value.provider.clone() {
                match (&mut settings.provider, provider) {
                    // Same provider kind as currently resolved: merge only
                    // the fields this source actually sets.
                    (
                        AssistantProvider::ZedDotDev { model },
                        AssistantProviderContent::ZedDotDev {
                            default_model: model_override,
                        },
                    ) => {
                        merge(model, model_override);
                    }
                    (
                        AssistantProvider::OpenAi {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                            available_models,
                        },
                        AssistantProviderContent::OpenAi {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                            available_models: available_models_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        merge(available_models, available_models_override);
                        // `merge` can't be used here: the target is an
                        // Option, and an unset override must not clear it.
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (
                        AssistantProvider::Ollama {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::Ollama {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    (
                        AssistantProvider::Anthropic {
                            model,
                            api_url,
                            low_speed_timeout_in_seconds,
                        },
                        AssistantProviderContent::Anthropic {
                            default_model: model_override,
                            api_url: api_url_override,
                            low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
                        },
                    ) => {
                        merge(model, model_override);
                        merge(api_url, api_url_override);
                        if let Some(low_speed_timeout_in_seconds_override) =
                            low_speed_timeout_in_seconds_override
                        {
                            *low_speed_timeout_in_seconds =
                                Some(low_speed_timeout_in_seconds_override);
                        }
                    }
                    // Provider kind changed: replace the resolved provider
                    // wholesale, filling unset fields with that provider's
                    // defaults.
                    (provider, provider_override) => {
                        *provider = match provider_override {
                            AssistantProviderContent::ZedDotDev {
                                default_model: model,
                            } => AssistantProvider::ZedDotDev {
                                model: model.unwrap_or_default(),
                            },
                            AssistantProviderContent::OpenAi {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                                available_models,
                            } => AssistantProvider::OpenAi {
                                model: model.unwrap_or_default(),
                                api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
                                low_speed_timeout_in_seconds,
                                available_models: available_models.unwrap_or_default(),
                            },
                            AssistantProviderContent::Anthropic {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::Anthropic {
                                model: model.unwrap_or_default(),
                                api_url: api_url
                                    .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                            AssistantProviderContent::Ollama {
                                default_model: model,
                                api_url,
                                low_speed_timeout_in_seconds,
                            } => AssistantProvider::Ollama {
                                model: model.unwrap_or_default(),
                                api_url: api_url.unwrap_or_else(|| ollama::OLLAMA_API_URL.into()),
                                low_speed_timeout_in_seconds,
                            },
                        };
                    }
                }
            }
        }

        Ok(settings)
    }
}
604
/// Overwrite `target` with `value`'s contents when a value is present;
/// leave `target` untouched otherwise.
fn merge<T>(target: &mut T, value: Option<T>) {
    match value {
        Some(value) => *target = value,
        None => {}
    }
}
610
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // With no user settings, the provider defaults to OpenAI with the
        // default model (gpt-4o / `OpenAiModel::FourOmni`).
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // Ensure backward-compatibility: a legacy `openai_api_url` key is
        // upgraded into an OpenAI provider with that URL.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "openai_api_url": "test-url",
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );
        // A legacy `default_open_ai_model` key still selects the model.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "default_open_ai_model": "gpt-4-0613"
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
                available_models: Default::default(),
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
                        "assistant": {
                            "version": "1",
                            "provider": {
                                "name": "zed.dev",
                                "default_model": "custom"
                            }
                        }
                    }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );
    }
}