wgpu_context.rs

  1use anyhow::Context as _;
  2use std::sync::Arc;
  3use util::ResultExt;
  4
  5pub struct WgpuContext {
  6    pub instance: wgpu::Instance,
  7    pub adapter: wgpu::Adapter,
  8    pub device: Arc<wgpu::Device>,
  9    pub queue: Arc<wgpu::Queue>,
 10    dual_source_blending: bool,
 11}
 12
 13impl WgpuContext {
 14    pub fn new() -> anyhow::Result<Self> {
 15        let device_id_filter = match std::env::var("ZED_DEVICE_ID") {
 16            Ok(val) => parse_pci_id(&val)
 17                .context("Failed to parse device ID from `ZED_DEVICE_ID` environment variable")
 18                .log_err(),
 19            Err(std::env::VarError::NotPresent) => None,
 20            err => {
 21                err.context("Failed to read value of `ZED_DEVICE_ID` environment variable")
 22                    .log_err();
 23                None
 24            }
 25        };
 26
 27        let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
 28            backends: wgpu::Backends::VULKAN | wgpu::Backends::GL,
 29            flags: wgpu::InstanceFlags::default(),
 30            backend_options: wgpu::BackendOptions::default(),
 31            memory_budget_thresholds: wgpu::MemoryBudgetThresholds::default(),
 32        });
 33
 34        let adapter = smol::block_on(Self::select_adapter(&instance, device_id_filter))?;
 35
 36        log::info!(
 37            "Selected GPU adapter: {:?} ({:?})",
 38            adapter.get_info().name,
 39            adapter.get_info().backend
 40        );
 41
 42        let dual_source_blending_available = adapter
 43            .features()
 44            .contains(wgpu::Features::DUAL_SOURCE_BLENDING);
 45
 46        let mut required_features = wgpu::Features::empty();
 47        if dual_source_blending_available {
 48            required_features |= wgpu::Features::DUAL_SOURCE_BLENDING;
 49        } else {
 50            log::warn!(
 51                "Dual-source blending not available on this GPU. \
 52                Subpixel text antialiasing will be disabled."
 53            );
 54        }
 55
 56        let (device, queue) = smol::block_on(adapter.request_device(&wgpu::DeviceDescriptor {
 57            label: Some("gpui_device"),
 58            required_features,
 59            required_limits: wgpu::Limits::default(),
 60            memory_hints: wgpu::MemoryHints::MemoryUsage,
 61            trace: wgpu::Trace::Off,
 62            experimental_features: wgpu::ExperimentalFeatures::disabled(),
 63        }))
 64        .map_err(|e| anyhow::anyhow!("Failed to create wgpu device: {e}"))?;
 65
 66        Ok(Self {
 67            instance,
 68            adapter,
 69            device: Arc::new(device),
 70            queue: Arc::new(queue),
 71            dual_source_blending: dual_source_blending_available,
 72        })
 73    }
 74
 75    async fn select_adapter(
 76        instance: &wgpu::Instance,
 77        device_id_filter: Option<u32>,
 78    ) -> anyhow::Result<wgpu::Adapter> {
 79        if let Some(device_id) = device_id_filter {
 80            let adapters: Vec<_> = instance.enumerate_adapters(wgpu::Backends::all()).await;
 81
 82            if adapters.is_empty() {
 83                anyhow::bail!("No GPU adapters found");
 84            }
 85
 86            let mut non_matching_adapter_infos: Vec<wgpu::AdapterInfo> = Vec::new();
 87
 88            for adapter in adapters.into_iter() {
 89                let info = adapter.get_info();
 90                if info.device == device_id {
 91                    log::info!(
 92                        "Found GPU matching ZED_DEVICE_ID={:#06x}: {}",
 93                        device_id,
 94                        info.name
 95                    );
 96                    return Ok(adapter);
 97                } else {
 98                    non_matching_adapter_infos.push(info);
 99                }
100            }
101
102            log::warn!(
103                "No GPU found matching ZED_DEVICE_ID={:#06x}. Available devices:",
104                device_id
105            );
106
107            for info in &non_matching_adapter_infos {
108                log::warn!(
109                    "  - {} (device_id={:#06x}, backend={})",
110                    info.name,
111                    info.device,
112                    info.backend
113                );
114            }
115        }
116
117        instance
118            .request_adapter(&wgpu::RequestAdapterOptions {
119                power_preference: wgpu::PowerPreference::None,
120                compatible_surface: None,
121                force_fallback_adapter: false,
122            })
123            .await
124            .map_err(|e| anyhow::anyhow!("Failed to request GPU adapter: {e}"))
125    }
126
127    pub fn supports_dual_source_blending(&self) -> bool {
128        self.dual_source_blending
129    }
130}
131
132fn parse_pci_id(id: &str) -> anyhow::Result<u32> {
133    let mut id = id.trim();
134
135    if id.starts_with("0x") || id.starts_with("0X") {
136        id = &id[2..];
137    }
138    let is_hex_string = id.chars().all(|c| c.is_ascii_hexdigit());
139    let is_4_chars = id.len() == 4;
140    anyhow::ensure!(
141        is_4_chars && is_hex_string,
142        "Expected a 4 digit PCI ID in hexadecimal format"
143    );
144
145    u32::from_str_radix(id, 16).context("parsing PCI ID as hex")
146}
147
#[cfg(test)]
mod tests {
    use super::parse_pci_id;

    /// Well-formed IDs are accepted in every supported spelling and yield the expected value.
    #[test]
    fn test_parse_device_id() {
        assert!(parse_pci_id("0xABCD").is_ok());
        assert!(parse_pci_id("ABCD").is_ok());
        assert!(parse_pci_id("abcd").is_ok());
        assert!(parse_pci_id("1234").is_ok());
        assert!(parse_pci_id("123").is_err());
        assert_eq!(
            parse_pci_id(&format!("{:x}", 0x1234)).unwrap(),
            parse_pci_id(&format!("{:X}", 0x1234)).unwrap(),
        );

        assert_eq!(
            parse_pci_id(&format!("{:#x}", 0x1234)).unwrap(),
            parse_pci_id(&format!("{:#X}", 0x1234)).unwrap(),
        );

        // Check the parsed value itself, not just that parsing succeeded.
        assert_eq!(parse_pci_id("0xABCD").unwrap(), 0xABCD);
        assert_eq!(parse_pci_id("abcd").unwrap(), 0xABCD);
        assert_eq!(parse_pci_id("1234").unwrap(), 0x1234);
        assert_eq!(parse_pci_id("0X1a2b").unwrap(), 0x1A2B);
        assert_eq!(parse_pci_id("0000").unwrap(), 0);

        // Surrounding whitespace is trimmed before validation.
        assert_eq!(parse_pci_id("  0x1234\n").unwrap(), 0x1234);
    }

    /// Anything other than exactly four hex digits (after the optional prefix) is rejected.
    #[test]
    fn test_parse_device_id_rejects_malformed_input() {
        assert!(parse_pci_id("").is_err());
        assert!(parse_pci_id("0x").is_err());
        assert!(parse_pci_id("0x123").is_err());
        assert!(parse_pci_id("12345").is_err());
        assert!(parse_pci_id("0x12345").is_err());
        assert!(parse_pci_id("wxyz").is_err());
        assert!(parse_pci_id("-123").is_err());
        assert!(parse_pci_id("12 4").is_err());
    }
}