1use std::fmt;
2
3pub use anthropic::Model as AnthropicModel;
4use gpui::Pixels;
5pub use open_ai::Model as OpenAiModel;
6use schemars::{
7 schema::{InstanceType, Metadata, Schema, SchemaObject},
8 JsonSchema,
9};
10use serde::{
11 de::{self, Visitor},
12 Deserialize, Deserializer, Serialize, Serializer,
13};
14use settings::{Settings, SettingsSources};
15use strum::{EnumIter, IntoEnumIterator};
16
17use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
18
/// A model offered through the zed.dev provider.
///
/// Known models serialize to the fixed ids returned by [`CloudModel::id`]; any id that
/// does not match a known model deserializes to [`CloudModel::Custom`].
#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
pub enum CloudModel {
    Gpt3Point5Turbo,
    Gpt4,
    Gpt4Turbo,
    #[default]
    Gpt4Omni,
    Claude3Opus,
    Claude3Sonnet,
    Claude3Haiku,
    /// A model id not listed above, kept verbatim.
    ///
    /// Note: `EnumIter` yields this variant with an empty id, which is why code that
    /// iterates the variants skips empty ids.
    Custom(String),
}
31
32impl Serialize for CloudModel {
33 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
34 where
35 S: Serializer,
36 {
37 serializer.serialize_str(self.id())
38 }
39}
40
41impl<'de> Deserialize<'de> for CloudModel {
42 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
43 where
44 D: Deserializer<'de>,
45 {
46 struct ZedDotDevModelVisitor;
47
48 impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
49 type Value = CloudModel;
50
51 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
52 formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
53 }
54
55 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
56 where
57 E: de::Error,
58 {
59 let model = CloudModel::iter()
60 .find(|model| model.id() == value)
61 .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
62 Ok(model)
63 }
64 }
65
66 deserializer.deserialize_str(ZedDotDevModelVisitor)
67 }
68}
69
70impl JsonSchema for CloudModel {
71 fn schema_name() -> String {
72 "ZedDotDevModel".to_owned()
73 }
74
75 fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
76 let variants = CloudModel::iter()
77 .filter_map(|model| {
78 let id = model.id();
79 if id.is_empty() {
80 None
81 } else {
82 Some(id.to_string())
83 }
84 })
85 .collect::<Vec<_>>();
86 Schema::Object(SchemaObject {
87 instance_type: Some(InstanceType::String.into()),
88 enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
89 metadata: Some(Box::new(Metadata {
90 title: Some("ZedDotDevModel".to_owned()),
91 default: Some(CloudModel::default().id().into()),
92 examples: variants.into_iter().map(Into::into).collect(),
93 ..Default::default()
94 })),
95 ..Default::default()
96 })
97 }
98}
99
100impl CloudModel {
101 pub fn id(&self) -> &str {
102 match self {
103 Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
104 Self::Gpt4 => "gpt-4",
105 Self::Gpt4Turbo => "gpt-4-turbo-preview",
106 Self::Gpt4Omni => "gpt-4o",
107 Self::Claude3Opus => "claude-3-opus",
108 Self::Claude3Sonnet => "claude-3-sonnet",
109 Self::Claude3Haiku => "claude-3-haiku",
110 Self::Custom(id) => id,
111 }
112 }
113
114 pub fn display_name(&self) -> &str {
115 match self {
116 Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
117 Self::Gpt4 => "GPT 4",
118 Self::Gpt4Turbo => "GPT 4 Turbo",
119 Self::Gpt4Omni => "GPT 4 Omni",
120 Self::Claude3Opus => "Claude 3 Opus",
121 Self::Claude3Sonnet => "Claude 3 Sonnet",
122 Self::Claude3Haiku => "Claude 3 Haiku",
123 Self::Custom(id) => id.as_str(),
124 }
125 }
126
127 pub fn max_token_count(&self) -> usize {
128 match self {
129 Self::Gpt3Point5Turbo => 2048,
130 Self::Gpt4 => 4096,
131 Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
132 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => 200000,
133 Self::Custom(_) => 4096, // TODO: Make this configurable
134 }
135 }
136
137 pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
138 match self {
139 Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
140 preprocess_anthropic_request(request)
141 }
142 _ => {}
143 }
144 }
145}
146
/// Where to dock the assistant panel.
// Variants are intentionally left without `///` docs: with the `JsonSchema` derive,
// variant docs would turn this string enum's schema into a `oneOf`.
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
pub enum AssistantDockPosition {
    Left,
    #[default]
    Right,
    Bottom,
}
155
/// The resolved provider configuration used by the assistant.
///
/// Produced by `AssistantSettings::load` from the `AssistantProviderContent` found in
/// each settings layer.
#[derive(Debug, PartialEq)]
pub enum AssistantProvider {
    /// The zed.dev service, using one of its [`CloudModel`]s.
    ZedDotDev {
        model: CloudModel,
    },
    /// OpenAI at `api_url`.
    OpenAi {
        model: OpenAiModel,
        api_url: String,
        /// Optional low-speed timeout, in seconds; `None` when not configured.
        low_speed_timeout_in_seconds: Option<u64>,
    },
    /// Anthropic at `api_url`.
    Anthropic {
        model: AnthropicModel,
        api_url: String,
        /// Optional low-speed timeout, in seconds; `None` when not configured.
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
172
173impl Default for AssistantProvider {
174 fn default() -> Self {
175 Self::OpenAi {
176 model: OpenAiModel::default(),
177 api_url: open_ai::OPEN_AI_API_URL.into(),
178 low_speed_timeout_in_seconds: None,
179 }
180 }
181}
182
/// Provider settings as written in a settings file.
///
/// The provider is selected by the `"name"` field: `"zed.dev"`, `"openai"`, or
/// `"anthropic"`. All other fields are optional overrides.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
#[serde(tag = "name", rename_all = "snake_case")]
pub enum AssistantProviderContent {
    #[serde(rename = "zed.dev")]
    ZedDotDev { default_model: Option<CloudModel> },
    #[serde(rename = "openai")]
    OpenAi {
        default_model: Option<OpenAiModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
    #[serde(rename = "anthropic")]
    Anthropic {
        default_model: Option<AnthropicModel>,
        api_url: Option<String>,
        low_speed_timeout_in_seconds: Option<u64>,
    },
}
201
/// The effective assistant settings, obtained by merging every settings layer.
#[derive(Debug, Default)]
pub struct AssistantSettings {
    /// Whether the assistant is enabled.
    pub enabled: bool,
    /// Whether to show the assistant panel button in the status bar.
    pub button: bool,
    /// Where to dock the assistant.
    pub dock: AssistantDockPosition,
    /// Default width when the assistant is docked to the left or right.
    pub default_width: Pixels,
    /// Default height when the assistant is docked to the bottom.
    pub default_height: Pixels,
    /// The provider, with its model and connection configuration.
    pub provider: AssistantProvider,
}
211
/// Assistant panel settings
///
/// Deserialized untagged: serde tries `Versioned` first, which requires a `"version"`
/// field. It falls back to `Legacy`, the older unversioned format.
#[derive(Clone, Serialize, Deserialize, Debug)]
#[serde(untagged)]
pub enum AssistantSettingsContent {
    Versioned(VersionedAssistantSettingsContent),
    Legacy(LegacyAssistantSettingsContent),
}
219
220impl JsonSchema for AssistantSettingsContent {
221 fn schema_name() -> String {
222 VersionedAssistantSettingsContent::schema_name()
223 }
224
225 fn json_schema(gen: &mut schemars::gen::SchemaGenerator) -> Schema {
226 VersionedAssistantSettingsContent::json_schema(gen)
227 }
228
229 fn is_referenceable() -> bool {
230 VersionedAssistantSettingsContent::is_referenceable()
231 }
232}
233
234impl Default for AssistantSettingsContent {
235 fn default() -> Self {
236 Self::Versioned(VersionedAssistantSettingsContent::default())
237 }
238}
239
240impl AssistantSettingsContent {
241 fn upgrade(&self) -> AssistantSettingsContentV1 {
242 match self {
243 AssistantSettingsContent::Versioned(settings) => match settings {
244 VersionedAssistantSettingsContent::V1(settings) => settings.clone(),
245 },
246 AssistantSettingsContent::Legacy(settings) => AssistantSettingsContentV1 {
247 enabled: None,
248 button: settings.button,
249 dock: settings.dock,
250 default_width: settings.default_width,
251 default_height: settings.default_height,
252 provider: if let Some(open_ai_api_url) = settings.openai_api_url.as_ref() {
253 Some(AssistantProviderContent::OpenAi {
254 default_model: settings.default_open_ai_model.clone(),
255 api_url: Some(open_ai_api_url.clone()),
256 low_speed_timeout_in_seconds: None,
257 })
258 } else {
259 settings.default_open_ai_model.clone().map(|open_ai_model| {
260 AssistantProviderContent::OpenAi {
261 default_model: Some(open_ai_model),
262 api_url: None,
263 low_speed_timeout_in_seconds: None,
264 }
265 })
266 },
267 },
268 }
269 }
270
271 pub fn set_dock(&mut self, dock: AssistantDockPosition) {
272 match self {
273 AssistantSettingsContent::Versioned(settings) => match settings {
274 VersionedAssistantSettingsContent::V1(settings) => {
275 settings.dock = Some(dock);
276 }
277 },
278 AssistantSettingsContent::Legacy(settings) => {
279 settings.dock = Some(dock);
280 }
281 }
282 }
283
284 pub fn set_model(&mut self, new_model: LanguageModel) {
285 match self {
286 AssistantSettingsContent::Versioned(settings) => match settings {
287 VersionedAssistantSettingsContent::V1(settings) => match &mut settings.provider {
288 Some(AssistantProviderContent::ZedDotDev {
289 default_model: model,
290 }) => {
291 if let LanguageModel::Cloud(new_model) = new_model {
292 *model = Some(new_model);
293 }
294 }
295 Some(AssistantProviderContent::OpenAi {
296 default_model: model,
297 ..
298 }) => {
299 if let LanguageModel::OpenAi(new_model) = new_model {
300 *model = Some(new_model);
301 }
302 }
303 Some(AssistantProviderContent::Anthropic {
304 default_model: model,
305 ..
306 }) => {
307 if let LanguageModel::Anthropic(new_model) = new_model {
308 *model = Some(new_model);
309 }
310 }
311 provider => match new_model {
312 LanguageModel::Cloud(model) => {
313 *provider = Some(AssistantProviderContent::ZedDotDev {
314 default_model: Some(model),
315 })
316 }
317 LanguageModel::OpenAi(model) => {
318 *provider = Some(AssistantProviderContent::OpenAi {
319 default_model: Some(model),
320 api_url: None,
321 low_speed_timeout_in_seconds: None,
322 })
323 }
324 LanguageModel::Anthropic(model) => {
325 *provider = Some(AssistantProviderContent::Anthropic {
326 default_model: Some(model),
327 api_url: None,
328 low_speed_timeout_in_seconds: None,
329 })
330 }
331 },
332 },
333 },
334 AssistantSettingsContent::Legacy(settings) => {
335 if let LanguageModel::OpenAi(model) = new_model {
336 settings.default_open_ai_model = Some(model);
337 }
338 }
339 }
340 }
341}
342
/// Assistant settings tagged with an explicit `"version"` field.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
#[serde(tag = "version")]
pub enum VersionedAssistantSettingsContent {
    /// Written as `"version": "1"`.
    #[serde(rename = "1")]
    V1(AssistantSettingsContentV1),
}
349
350impl Default for VersionedAssistantSettingsContent {
351 fn default() -> Self {
352 Self::V1(AssistantSettingsContentV1 {
353 enabled: None,
354 button: None,
355 dock: None,
356 default_width: None,
357 default_height: None,
358 provider: None,
359 })
360 }
361}
362
/// Version 1 of the assistant settings format.
///
/// Every field is optional; unset fields keep the value from earlier settings layers.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct AssistantSettingsContentV1 {
    /// Whether the Assistant is enabled.
    ///
    /// Default: true
    enabled: Option<bool>,
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    default_height: Option<f32>,
    /// The provider of the assistant service.
    ///
    /// This can be the internal `zed.dev` service, or an external `openai` or
    /// `anthropic` service, each with their respective default models and configurations.
    provider: Option<AssistantProviderContent>,
}
391
/// Assistant settings in the original, unversioned format, which only supported OpenAI.
///
/// Converted to the V1 format before being applied.
#[derive(Clone, Serialize, Deserialize, JsonSchema, Debug)]
pub struct LegacyAssistantSettingsContent {
    /// Whether to show the assistant panel button in the status bar.
    ///
    /// Default: true
    pub button: Option<bool>,
    /// Where to dock the assistant.
    ///
    /// Default: right
    pub dock: Option<AssistantDockPosition>,
    /// Default width in pixels when the assistant is docked to the left or right.
    ///
    /// Default: 640
    pub default_width: Option<f32>,
    /// Default height in pixels when the assistant is docked to the bottom.
    ///
    /// Default: 320
    pub default_height: Option<f32>,
    // NOTE(review): the "Default" below looks stale — the settings test expects the
    // default provider model to be `OpenAiModel::FourOmni`. Confirm and update.
    /// The default OpenAI model to use when creating new contexts.
    ///
    /// Default: gpt-4-1106-preview
    pub default_open_ai_model: Option<OpenAiModel>,
    /// OpenAI API base URL to use when creating new contexts.
    ///
    /// Default: https://api.openai.com/v1
    pub openai_api_url: Option<String>,
}
419
420impl Settings for AssistantSettings {
421 const KEY: Option<&'static str> = Some("assistant");
422
423 type FileContent = AssistantSettingsContent;
424
425 fn load(
426 sources: SettingsSources<Self::FileContent>,
427 _: &mut gpui::AppContext,
428 ) -> anyhow::Result<Self> {
429 let mut settings = AssistantSettings::default();
430
431 for value in sources.defaults_and_customizations() {
432 let value = value.upgrade();
433 merge(&mut settings.enabled, value.enabled);
434 merge(&mut settings.button, value.button);
435 merge(&mut settings.dock, value.dock);
436 merge(
437 &mut settings.default_width,
438 value.default_width.map(Into::into),
439 );
440 merge(
441 &mut settings.default_height,
442 value.default_height.map(Into::into),
443 );
444 if let Some(provider) = value.provider.clone() {
445 match (&mut settings.provider, provider) {
446 (
447 AssistantProvider::ZedDotDev { model },
448 AssistantProviderContent::ZedDotDev {
449 default_model: model_override,
450 },
451 ) => {
452 merge(model, model_override);
453 }
454 (
455 AssistantProvider::OpenAi {
456 model,
457 api_url,
458 low_speed_timeout_in_seconds,
459 },
460 AssistantProviderContent::OpenAi {
461 default_model: model_override,
462 api_url: api_url_override,
463 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
464 },
465 ) => {
466 merge(model, model_override);
467 merge(api_url, api_url_override);
468 if let Some(low_speed_timeout_in_seconds_override) =
469 low_speed_timeout_in_seconds_override
470 {
471 *low_speed_timeout_in_seconds =
472 Some(low_speed_timeout_in_seconds_override);
473 }
474 }
475 (
476 AssistantProvider::Anthropic {
477 model,
478 api_url,
479 low_speed_timeout_in_seconds,
480 },
481 AssistantProviderContent::Anthropic {
482 default_model: model_override,
483 api_url: api_url_override,
484 low_speed_timeout_in_seconds: low_speed_timeout_in_seconds_override,
485 },
486 ) => {
487 merge(model, model_override);
488 merge(api_url, api_url_override);
489 if let Some(low_speed_timeout_in_seconds_override) =
490 low_speed_timeout_in_seconds_override
491 {
492 *low_speed_timeout_in_seconds =
493 Some(low_speed_timeout_in_seconds_override);
494 }
495 }
496 (provider, provider_override) => {
497 *provider = match provider_override {
498 AssistantProviderContent::ZedDotDev {
499 default_model: model,
500 } => AssistantProvider::ZedDotDev {
501 model: model.unwrap_or_default(),
502 },
503 AssistantProviderContent::OpenAi {
504 default_model: model,
505 api_url,
506 low_speed_timeout_in_seconds,
507 } => AssistantProvider::OpenAi {
508 model: model.unwrap_or_default(),
509 api_url: api_url.unwrap_or_else(|| open_ai::OPEN_AI_API_URL.into()),
510 low_speed_timeout_in_seconds,
511 },
512 AssistantProviderContent::Anthropic {
513 default_model: model,
514 api_url,
515 low_speed_timeout_in_seconds,
516 } => AssistantProvider::Anthropic {
517 model: model.unwrap_or_default(),
518 api_url: api_url
519 .unwrap_or_else(|| anthropic::ANTHROPIC_API_URL.into()),
520 low_speed_timeout_in_seconds,
521 },
522 };
523 }
524 }
525 }
526 }
527
528 Ok(settings)
529 }
530}
531
/// Overwrites `target` with `value` when it is `Some`; leaves `target` untouched on `None`.
///
/// This lets each settings layer override only the fields it actually sets.
fn merge<T>(target: &mut T, value: Option<T>) {
    if let Some(replacement) = value {
        *target = replacement;
    }
}
537
#[cfg(test)]
mod tests {
    use gpui::{AppContext, UpdateGlobal};
    use settings::SettingsStore;

    use super::*;

    /// Covers three cases of provider resolution:
    /// - the default provider,
    /// - the legacy (unversioned) format,
    /// - a custom zed.dev model in the versioned format.
    #[gpui::test]
    fn test_deserialize_assistant_settings(cx: &mut AppContext) {
        let store = settings::SettingsStore::test(cx);
        cx.set_global(store);

        // Settings default to the OpenAI provider with GPT-4 Omni.
        AssistantSettings::register(cx);
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // Ensure backward-compatibility.
        // A legacy `openai_api_url` on its own only overrides the URL; the model keeps
        // its default.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
 "assistant": {
 "openai_api_url": "test-url",
 }
 }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::FourOmni,
                api_url: "test-url".into(),
                low_speed_timeout_in_seconds: None,
            }
        );
        // Replacing the user settings with only a legacy `default_open_ai_model` selects
        // that model, and the URL goes back to the default.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
 "assistant": {
 "default_open_ai_model": "gpt-4-0613"
 }
 }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::OpenAi {
                model: OpenAiModel::Four,
                api_url: open_ai::OPEN_AI_API_URL.into(),
                low_speed_timeout_in_seconds: None,
            }
        );

        // The new version supports setting a custom model when using zed.dev.
        SettingsStore::update_global(cx, |store, cx| {
            store
                .set_user_settings(
                    r#"{
 "assistant": {
 "version": "1",
 "provider": {
 "name": "zed.dev",
 "default_model": "custom"
 }
 }
 }"#,
                    cx,
                )
                .unwrap();
        });
        assert_eq!(
            AssistantSettings::get_global(cx).provider,
            AssistantProvider::ZedDotDev {
                model: CloudModel::Custom("custom".into())
            }
        );
    }
}