wgpu_context.rs

  1use anyhow::Context as _;
  2use std::sync::Arc;
  3use util::ResultExt;
  4
/// GPU state shared by the renderer: the wgpu instance, the adapter that was
/// selected for it, and the device/queue pair created from that adapter.
///
/// NOTE(review): Rust drops struct fields in declaration order, so the order
/// below decides the order in which these GPU handles are released; confirm
/// before reordering.
pub struct WgpuContext {
    /// Instance the adapter was obtained from.
    pub instance: wgpu::Instance,
    /// Adapter chosen by `WgpuContext::new`.
    pub adapter: wgpu::Adapter,
    /// Logical device requested from `adapter`, shared via `Arc`.
    pub device: Arc<wgpu::Device>,
    /// Queue created alongside `device`, shared via `Arc`.
    pub queue: Arc<wgpu::Queue>,
    /// Whether `device` was created with `wgpu::Features::DUAL_SOURCE_BLENDING`;
    /// read through `supports_dual_source_blending`.
    dual_source_blending: bool,
}
 12
 13impl WgpuContext {
 14    pub fn new(instance: wgpu::Instance, surface: &wgpu::Surface<'_>) -> anyhow::Result<Self> {
 15        let device_id_filter = match std::env::var("ZED_DEVICE_ID") {
 16            Ok(val) => parse_pci_id(&val)
 17                .context("Failed to parse device ID from `ZED_DEVICE_ID` environment variable")
 18                .log_err(),
 19            Err(std::env::VarError::NotPresent) => None,
 20            err => {
 21                err.context("Failed to read value of `ZED_DEVICE_ID` environment variable")
 22                    .log_err();
 23                None
 24            }
 25        };
 26
 27        let adapter = smol::block_on(Self::select_adapter(
 28            &instance,
 29            device_id_filter,
 30            Some(surface),
 31        ))?;
 32
 33        let caps = surface.get_capabilities(&adapter);
 34        if caps.formats.is_empty() {
 35            let info = adapter.get_info();
 36            anyhow::bail!(
 37                "No adapter compatible with the display surface could be found. \
 38                 Best candidate {:?} (backend={:?}, device={:#06x}) reports no \
 39                 supported surface formats.",
 40                info.name,
 41                info.backend,
 42                info.device,
 43            );
 44        }
 45
 46        log::info!(
 47            "Selected GPU adapter: {:?} ({:?})",
 48            adapter.get_info().name,
 49            adapter.get_info().backend
 50        );
 51
 52        let (device, queue, dual_source_blending) = Self::create_device(&adapter)?;
 53
 54        Ok(Self {
 55            instance,
 56            adapter,
 57            device: Arc::new(device),
 58            queue: Arc::new(queue),
 59            dual_source_blending,
 60        })
 61    }
 62
 63    pub fn instance() -> wgpu::Instance {
 64        wgpu::Instance::new(&wgpu::InstanceDescriptor {
 65            backends: wgpu::Backends::VULKAN | wgpu::Backends::GL,
 66            flags: wgpu::InstanceFlags::default(),
 67            backend_options: wgpu::BackendOptions::default(),
 68            memory_budget_thresholds: wgpu::MemoryBudgetThresholds::default(),
 69        })
 70    }
 71
 72    pub fn check_compatible_with_surface(&self, surface: &wgpu::Surface<'_>) -> anyhow::Result<()> {
 73        let caps = surface.get_capabilities(&self.adapter);
 74        if caps.formats.is_empty() {
 75            let info = self.adapter.get_info();
 76            anyhow::bail!(
 77                "Adapter {:?} (backend={:?}, device={:#06x}) is not compatible with the \
 78                 display surface for this window.",
 79                info.name,
 80                info.backend,
 81                info.device,
 82            );
 83        }
 84        Ok(())
 85    }
 86
 87    fn create_device(adapter: &wgpu::Adapter) -> anyhow::Result<(wgpu::Device, wgpu::Queue, bool)> {
 88        let dual_source_blending_available = adapter
 89            .features()
 90            .contains(wgpu::Features::DUAL_SOURCE_BLENDING);
 91
 92        let mut required_features = wgpu::Features::empty();
 93        if dual_source_blending_available {
 94            required_features |= wgpu::Features::DUAL_SOURCE_BLENDING;
 95        } else {
 96            log::warn!(
 97                "Dual-source blending not available on this GPU. \
 98                Subpixel text antialiasing will be disabled."
 99            );
100        }
101
102        let (device, queue) = smol::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
103            label: Some("gpui_device"),
104            required_features,
105            required_limits: wgpu::Limits::default(),
106            memory_hints: wgpu::MemoryHints::MemoryUsage,
107            trace: wgpu::Trace::Off,
108            experimental_features: wgpu::ExperimentalFeatures::disabled(),
109        }))
110        .map_err(|e| anyhow::anyhow!("Failed to create wgpu device: {e}"))?;
111
112        Ok((device, queue, dual_source_blending_available))
113    }
114
115    async fn select_adapter(
116        instance: &wgpu::Instance,
117        device_id_filter: Option<u32>,
118        compatible_surface: Option<&wgpu::Surface<'_>>,
119    ) -> anyhow::Result<wgpu::Adapter> {
120        if let Some(device_id) = device_id_filter {
121            let adapters: Vec<_> = instance.enumerate_adapters(wgpu::Backends::all()).await;
122
123            if adapters.is_empty() {
124                anyhow::bail!("No GPU adapters found");
125            }
126
127            let mut non_matching_adapter_infos: Vec<wgpu::AdapterInfo> = Vec::new();
128
129            for adapter in adapters.into_iter() {
130                let info = adapter.get_info();
131                if info.device == device_id {
132                    if let Some(surface) = compatible_surface {
133                        let caps = surface.get_capabilities(&adapter);
134                        if caps.formats.is_empty() {
135                            log::warn!(
136                                "GPU matching ZED_DEVICE_ID={:#06x} ({}) is not compatible \
137                                 with the display surface. Falling back to auto-selection.",
138                                device_id,
139                                info.name,
140                            );
141                            break;
142                        }
143                    }
144                    log::info!(
145                        "Found GPU matching ZED_DEVICE_ID={:#06x}: {}",
146                        device_id,
147                        info.name
148                    );
149                    return Ok(adapter);
150                } else {
151                    non_matching_adapter_infos.push(info);
152                }
153            }
154
155            log::warn!(
156                "No compatible GPU found matching ZED_DEVICE_ID={:#06x}. Available devices:",
157                device_id
158            );
159
160            for info in &non_matching_adapter_infos {
161                log::warn!(
162                    "  - {} (device_id={:#06x}, backend={})",
163                    info.name,
164                    info.device,
165                    info.backend
166                );
167            }
168        }
169
170        instance
171            .request_adapter(&wgpu::RequestAdapterOptions {
172                power_preference: wgpu::PowerPreference::None,
173                compatible_surface,
174                force_fallback_adapter: false,
175            })
176            .await
177            .map_err(|e| anyhow::anyhow!("Failed to request GPU adapter: {e}"))
178    }
179
180    pub fn supports_dual_source_blending(&self) -> bool {
181        self.dual_source_blending
182    }
183}
184
185fn parse_pci_id(id: &str) -> anyhow::Result<u32> {
186    let mut id = id.trim();
187
188    if id.starts_with("0x") || id.starts_with("0X") {
189        id = &id[2..];
190    }
191    let is_hex_string = id.chars().all(|c| c.is_ascii_hexdigit());
192    let is_4_chars = id.len() == 4;
193    anyhow::ensure!(
194        is_4_chars && is_hex_string,
195        "Expected a 4 digit PCI ID in hexadecimal format"
196    );
197
198    u32::from_str_radix(id, 16).context("parsing PCI ID as hex")
199}
200
#[cfg(test)]
mod tests {
    use super::parse_pci_id;

    /// Accept/reject behavior and case-insensitivity of the parser.
    #[test]
    fn test_parse_device_id() {
        assert!(parse_pci_id("0xABCD").is_ok());
        assert!(parse_pci_id("ABCD").is_ok());
        assert!(parse_pci_id("abcd").is_ok());
        assert!(parse_pci_id("1234").is_ok());
        assert!(parse_pci_id("123").is_err());
        assert_eq!(
            parse_pci_id(&format!("{:x}", 0x1234)).unwrap(),
            parse_pci_id(&format!("{:X}", 0x1234)).unwrap(),
        );

        assert_eq!(
            parse_pci_id(&format!("{:#x}", 0x1234)).unwrap(),
            parse_pci_id(&format!("{:#X}", 0x1234)).unwrap(),
        );
    }

    /// Pins the parsed numeric values and the edge cases the validation must reject.
    #[test]
    fn test_parse_device_id_values() {
        assert_eq!(parse_pci_id("0xABCD").unwrap(), 0xABCD);
        assert_eq!(parse_pci_id("0X1234").unwrap(), 0x1234);
        assert_eq!(parse_pci_id("abcd").unwrap(), 0xABCD);
        // Surrounding whitespace is trimmed before validation.
        assert_eq!(parse_pci_id("  10de \n").unwrap(), 0x10DE);

        assert!(parse_pci_id("").is_err());
        assert!(parse_pci_id("0x").is_err());
        assert!(parse_pci_id("12345").is_err());
        assert!(parse_pci_id("0x12345").is_err());
        assert!(parse_pci_id("ghij").is_err());
        // `from_str_radix` alone would accept a sign; the hex check must not.
        assert!(parse_pci_id("+123").is_err());
    }
}