1use anyhow::Context as _;
2use std::sync::Arc;
3use util::ResultExt;
4
/// Holds the wgpu instance, the GPU adapter selected from it, and the logical
/// device and queue opened on that adapter.
///
/// Construct with [`WgpuContext::new`].
pub struct WgpuContext {
    /// The wgpu instance the adapter was obtained from.
    pub instance: wgpu::Instance,
    /// The adapter chosen at construction time (a `ZED_DEVICE_ID` match if one
    /// was requested and found, otherwise wgpu's default choice).
    pub adapter: wgpu::Adapter,
    /// Logical device opened on `adapter`; `Arc`-wrapped so it can be shared
    /// (presumably with renderers/surfaces elsewhere — confirm with callers).
    pub device: Arc<wgpu::Device>,
    /// Submission queue paired with `device`.
    pub queue: Arc<wgpu::Queue>,
    /// Whether `wgpu::Features::DUAL_SOURCE_BLENDING` was enabled on `device`;
    /// exposed through `supports_dual_source_blending`.
    dual_source_blending: bool,
}
12
13impl WgpuContext {
14 pub fn new() -> anyhow::Result<Self> {
15 let device_id_filter = match std::env::var("ZED_DEVICE_ID") {
16 Ok(val) => parse_pci_id(&val)
17 .context("Failed to parse device ID from `ZED_DEVICE_ID` environment variable")
18 .log_err(),
19 Err(std::env::VarError::NotPresent) => None,
20 err => {
21 err.context("Failed to read value of `ZED_DEVICE_ID` environment variable")
22 .log_err();
23 None
24 }
25 };
26
27 let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
28 backends: wgpu::Backends::VULKAN | wgpu::Backends::GL,
29 flags: wgpu::InstanceFlags::default(),
30 backend_options: wgpu::BackendOptions::default(),
31 memory_budget_thresholds: wgpu::MemoryBudgetThresholds::default(),
32 });
33
34 let adapter = smol::block_on(Self::select_adapter(&instance, device_id_filter))?;
35
36 log::info!(
37 "Selected GPU adapter: {:?} ({:?})",
38 adapter.get_info().name,
39 adapter.get_info().backend
40 );
41
42 let dual_source_blending_available = adapter
43 .features()
44 .contains(wgpu::Features::DUAL_SOURCE_BLENDING);
45
46 let mut required_features = wgpu::Features::empty();
47 if dual_source_blending_available {
48 required_features |= wgpu::Features::DUAL_SOURCE_BLENDING;
49 } else {
50 log::warn!(
51 "Dual-source blending not available on this GPU. \
52 Subpixel text antialiasing will be disabled."
53 );
54 }
55
56 let (device, queue) = smol::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
57 label: Some("gpui_device"),
58 required_features,
59 required_limits: wgpu::Limits::default(),
60 memory_hints: wgpu::MemoryHints::MemoryUsage,
61 trace: wgpu::Trace::Off,
62 experimental_features: wgpu::ExperimentalFeatures::disabled(),
63 }))
64 .map_err(|e| anyhow::anyhow!("Failed to create wgpu device: {e}"))?;
65
66 Ok(Self {
67 instance,
68 adapter,
69 device: Arc::new(device),
70 queue: Arc::new(queue),
71 dual_source_blending: dual_source_blending_available,
72 })
73 }
74
75 async fn select_adapter(
76 instance: &wgpu::Instance,
77 device_id_filter: Option<u32>,
78 ) -> anyhow::Result<wgpu::Adapter> {
79 if let Some(device_id) = device_id_filter {
80 let adapters: Vec<_> = instance.enumerate_adapters(wgpu::Backends::all()).await;
81
82 if adapters.is_empty() {
83 anyhow::bail!("No GPU adapters found");
84 }
85
86 let mut non_matching_adapter_infos: Vec<wgpu::AdapterInfo> = Vec::new();
87
88 for adapter in adapters.into_iter() {
89 let info = adapter.get_info();
90 if info.device == device_id {
91 log::info!(
92 "Found GPU matching ZED_DEVICE_ID={:#06x}: {}",
93 device_id,
94 info.name
95 );
96 return Ok(adapter);
97 } else {
98 non_matching_adapter_infos.push(info);
99 }
100 }
101
102 log::warn!(
103 "No GPU found matching ZED_DEVICE_ID={:#06x}. Available devices:",
104 device_id
105 );
106
107 for info in &non_matching_adapter_infos {
108 log::warn!(
109 " - {} (device_id={:#06x}, backend={})",
110 info.name,
111 info.device,
112 info.backend
113 );
114 }
115 }
116
117 instance
118 .request_adapter(&wgpu::RequestAdapterOptions {
119 power_preference: wgpu::PowerPreference::None,
120 compatible_surface: None,
121 force_fallback_adapter: false,
122 })
123 .await
124 .map_err(|e| anyhow::anyhow!("Failed to request GPU adapter: {e}"))
125 }
126
127 pub fn supports_dual_source_blending(&self) -> bool {
128 self.dual_source_blending
129 }
130}
131
132fn parse_pci_id(id: &str) -> anyhow::Result<u32> {
133 let mut id = id.trim();
134
135 if id.starts_with("0x") || id.starts_with("0X") {
136 id = &id[2..];
137 }
138 let is_hex_string = id.chars().all(|c| c.is_ascii_hexdigit());
139 let is_4_chars = id.len() == 4;
140 anyhow::ensure!(
141 is_4_chars && is_hex_string,
142 "Expected a 4 digit PCI ID in hexadecimal format"
143 );
144
145 u32::from_str_radix(id, 16).context("parsing PCI ID as hex")
146}
147
#[cfg(test)]
mod tests {
    use super::parse_pci_id;

    /// Accepted spellings parse, and case/prefix variants agree with each other.
    #[test]
    fn test_parse_device_id() {
        assert!(parse_pci_id("0xABCD").is_ok());
        assert!(parse_pci_id("ABCD").is_ok());
        assert!(parse_pci_id("abcd").is_ok());
        assert!(parse_pci_id("1234").is_ok());
        assert!(parse_pci_id("123").is_err());
        assert_eq!(
            parse_pci_id(&format!("{:x}", 0x1234)).unwrap(),
            parse_pci_id(&format!("{:X}", 0x1234)).unwrap(),
        );

        assert_eq!(
            parse_pci_id(&format!("{:#x}", 0x1234)).unwrap(),
            parse_pci_id(&format!("{:#X}", 0x1234)).unwrap(),
        );
    }

    /// The parsed value is the hexadecimal interpretation of the four digits.
    #[test]
    fn test_parse_device_id_values() {
        assert_eq!(parse_pci_id("0xABCD").unwrap(), 0xABCD);
        assert_eq!(parse_pci_id("0XabCd").unwrap(), 0xABCD);
        assert_eq!(parse_pci_id("abcd").unwrap(), 0xABCD);
        assert_eq!(parse_pci_id("1234").unwrap(), 0x1234);
        // Surrounding whitespace (e.g. from a shell export) is tolerated.
        assert_eq!(parse_pci_id("  10de\n").unwrap(), 0x10DE);
    }

    /// Anything other than exactly four ASCII hex digits is rejected.
    #[test]
    fn test_parse_device_id_rejects_malformed() {
        let malformed = [
            "",         // empty
            "0x",       // prefix only
            "123",      // too short
            "12345",    // too long
            "0x123",    // too short after prefix
            "0x12345",  // too long after prefix
            "GHIJ",     // not hex
            "12 4",     // inner whitespace
            "0x0x12",   // prefix is stripped only once
            "+123",     // sign accepted by from_str_radix, but not a PCI ID
            "ＡＢＣＤ", // full-width (non-ASCII) letters
        ];
        for input in malformed {
            assert!(
                parse_pci_id(input).is_err(),
                "expected {input:?} to be rejected"
            );
        }
    }
}