1use anyhow::Context as _;
2use std::sync::Arc;
3use util::ResultExt;
4
/// A wgpu instance together with the adapter, logical device, and queue that were
/// selected for rendering by [`WgpuContext::new`].
pub struct WgpuContext {
    /// The instance the adapter was selected from.
    pub instance: wgpu::Instance,
    /// The adapter chosen by `WgpuContext::new`.
    pub adapter: wgpu::Adapter,
    /// Logical device requested from `adapter`, held behind an `Arc` for shared ownership.
    pub device: Arc<wgpu::Device>,
    /// Queue created alongside `device`, held behind an `Arc` for shared ownership.
    pub queue: Arc<wgpu::Queue>,
    /// Whether `wgpu::Features::DUAL_SOURCE_BLENDING` was enabled when `device` was
    /// requested; read through `supports_dual_source_blending`.
    dual_source_blending: bool,
}
12
13impl WgpuContext {
14 pub fn new(instance: wgpu::Instance, surface: &wgpu::Surface<'_>) -> anyhow::Result<Self> {
15 let device_id_filter = match std::env::var("ZED_DEVICE_ID") {
16 Ok(val) => parse_pci_id(&val)
17 .context("Failed to parse device ID from `ZED_DEVICE_ID` environment variable")
18 .log_err(),
19 Err(std::env::VarError::NotPresent) => None,
20 err => {
21 err.context("Failed to read value of `ZED_DEVICE_ID` environment variable")
22 .log_err();
23 None
24 }
25 };
26
27 let adapter = smol::block_on(Self::select_adapter(
28 &instance,
29 device_id_filter,
30 Some(surface),
31 ))?;
32
33 let caps = surface.get_capabilities(&adapter);
34 if caps.formats.is_empty() {
35 let info = adapter.get_info();
36 anyhow::bail!(
37 "No adapter compatible with the display surface could be found. \
38 Best candidate {:?} (backend={:?}, device={:#06x}) reports no \
39 supported surface formats.",
40 info.name,
41 info.backend,
42 info.device,
43 );
44 }
45
46 log::info!(
47 "Selected GPU adapter: {:?} ({:?})",
48 adapter.get_info().name,
49 adapter.get_info().backend
50 );
51
52 let (device, queue, dual_source_blending) = Self::create_device(&adapter)?;
53
54 Ok(Self {
55 instance,
56 adapter,
57 device: Arc::new(device),
58 queue: Arc::new(queue),
59 dual_source_blending,
60 })
61 }
62
63 pub fn instance() -> wgpu::Instance {
64 wgpu::Instance::new(&wgpu::InstanceDescriptor {
65 backends: wgpu::Backends::VULKAN | wgpu::Backends::GL,
66 flags: wgpu::InstanceFlags::default(),
67 backend_options: wgpu::BackendOptions::default(),
68 memory_budget_thresholds: wgpu::MemoryBudgetThresholds::default(),
69 })
70 }
71
72 pub fn check_compatible_with_surface(&self, surface: &wgpu::Surface<'_>) -> anyhow::Result<()> {
73 let caps = surface.get_capabilities(&self.adapter);
74 if caps.formats.is_empty() {
75 let info = self.adapter.get_info();
76 anyhow::bail!(
77 "Adapter {:?} (backend={:?}, device={:#06x}) is not compatible with the \
78 display surface for this window.",
79 info.name,
80 info.backend,
81 info.device,
82 );
83 }
84 Ok(())
85 }
86
87 fn create_device(adapter: &wgpu::Adapter) -> anyhow::Result<(wgpu::Device, wgpu::Queue, bool)> {
88 let dual_source_blending_available = adapter
89 .features()
90 .contains(wgpu::Features::DUAL_SOURCE_BLENDING);
91
92 let mut required_features = wgpu::Features::empty();
93 if dual_source_blending_available {
94 required_features |= wgpu::Features::DUAL_SOURCE_BLENDING;
95 } else {
96 log::warn!(
97 "Dual-source blending not available on this GPU. \
98 Subpixel text antialiasing will be disabled."
99 );
100 }
101
102 let (device, queue) = smol::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
103 label: Some("gpui_device"),
104 required_features,
105 required_limits: wgpu::Limits::default(),
106 memory_hints: wgpu::MemoryHints::MemoryUsage,
107 trace: wgpu::Trace::Off,
108 experimental_features: wgpu::ExperimentalFeatures::disabled(),
109 }))
110 .map_err(|e| anyhow::anyhow!("Failed to create wgpu device: {e}"))?;
111
112 Ok((device, queue, dual_source_blending_available))
113 }
114
115 async fn select_adapter(
116 instance: &wgpu::Instance,
117 device_id_filter: Option<u32>,
118 compatible_surface: Option<&wgpu::Surface<'_>>,
119 ) -> anyhow::Result<wgpu::Adapter> {
120 if let Some(device_id) = device_id_filter {
121 let adapters: Vec<_> = instance.enumerate_adapters(wgpu::Backends::all()).await;
122
123 if adapters.is_empty() {
124 anyhow::bail!("No GPU adapters found");
125 }
126
127 let mut non_matching_adapter_infos: Vec<wgpu::AdapterInfo> = Vec::new();
128
129 for adapter in adapters.into_iter() {
130 let info = adapter.get_info();
131 if info.device == device_id {
132 if let Some(surface) = compatible_surface {
133 let caps = surface.get_capabilities(&adapter);
134 if caps.formats.is_empty() {
135 log::warn!(
136 "GPU matching ZED_DEVICE_ID={:#06x} ({}) is not compatible \
137 with the display surface. Falling back to auto-selection.",
138 device_id,
139 info.name,
140 );
141 break;
142 }
143 }
144 log::info!(
145 "Found GPU matching ZED_DEVICE_ID={:#06x}: {}",
146 device_id,
147 info.name
148 );
149 return Ok(adapter);
150 } else {
151 non_matching_adapter_infos.push(info);
152 }
153 }
154
155 log::warn!(
156 "No compatible GPU found matching ZED_DEVICE_ID={:#06x}. Available devices:",
157 device_id
158 );
159
160 for info in &non_matching_adapter_infos {
161 log::warn!(
162 " - {} (device_id={:#06x}, backend={})",
163 info.name,
164 info.device,
165 info.backend
166 );
167 }
168 }
169
170 instance
171 .request_adapter(&wgpu::RequestAdapterOptions {
172 power_preference: wgpu::PowerPreference::None,
173 compatible_surface,
174 force_fallback_adapter: false,
175 })
176 .await
177 .map_err(|e| anyhow::anyhow!("Failed to request GPU adapter: {e}"))
178 }
179
180 pub fn supports_dual_source_blending(&self) -> bool {
181 self.dual_source_blending
182 }
183}
184
185fn parse_pci_id(id: &str) -> anyhow::Result<u32> {
186 let mut id = id.trim();
187
188 if id.starts_with("0x") || id.starts_with("0X") {
189 id = &id[2..];
190 }
191 let is_hex_string = id.chars().all(|c| c.is_ascii_hexdigit());
192 let is_4_chars = id.len() == 4;
193 anyhow::ensure!(
194 is_4_chars && is_hex_string,
195 "Expected a 4 digit PCI ID in hexadecimal format"
196 );
197
198 u32::from_str_radix(id, 16).context("parsing PCI ID as hex")
199}
200
#[cfg(test)]
mod tests {
    use super::parse_pci_id;

    /// Well-formed IDs in every accepted spelling are accepted.
    #[test]
    fn test_parse_device_id() {
        assert!(parse_pci_id("0xABCD").is_ok());
        assert!(parse_pci_id("ABCD").is_ok());
        assert!(parse_pci_id("abcd").is_ok());
        assert!(parse_pci_id("1234").is_ok());
        assert!(parse_pci_id("123").is_err());
        assert_eq!(
            parse_pci_id(&format!("{:x}", 0x1234)).unwrap(),
            parse_pci_id(&format!("{:X}", 0x1234)).unwrap(),
        );

        assert_eq!(
            parse_pci_id(&format!("{:#x}", 0x1234)).unwrap(),
            parse_pci_id(&format!("{:#X}", 0x1234)).unwrap(),
        );
    }

    /// The parsed number matches the hexadecimal text, regardless of case,
    /// prefix, or surrounding whitespace.
    #[test]
    fn test_parse_device_id_values() {
        assert_eq!(parse_pci_id("0xABCD").unwrap(), 0xABCD);
        assert_eq!(parse_pci_id("0XABCD").unwrap(), 0xABCD);
        assert_eq!(parse_pci_id("ABCD").unwrap(), 0xABCD);
        assert_eq!(parse_pci_id("abcd").unwrap(), 0xABCD);
        assert_eq!(parse_pci_id("1234").unwrap(), 0x1234);
        assert_eq!(parse_pci_id("0000").unwrap(), 0);
        assert_eq!(parse_pci_id("  0x10de \n").unwrap(), 0x10DE);
    }

    /// Anything other than exactly four hex digits (after trimming and an
    /// optional single `0x` prefix) is rejected.
    #[test]
    fn test_parse_device_id_rejects_malformed() {
        assert!(parse_pci_id("").is_err());
        assert!(parse_pci_id("0x").is_err());
        assert!(parse_pci_id("123").is_err());
        assert!(parse_pci_id("12345").is_err());
        assert!(parse_pci_id("wxyz").is_err());
        assert!(parse_pci_id("+123").is_err());
        assert!(parse_pci_id("12 4").is_err());
        assert!(parse_pci_id("0x0x12").is_err());
    }
}