1use anyhow::Context as _;
2use blade_graphics as gpu;
3use std::sync::Arc;
4use util::ResultExt;
5
/// Holds the Blade GPU context for this platform backend.
///
/// The context lives behind an [`Arc`], so the handle itself is cheap to share.
/// `Clone` is only derived when targeting macOS.
// NOTE(review): the reason `Clone` is macOS-only is not visible here — confirm with the
// platform code that consumes this type.
#[cfg_attr(target_os = "macos", derive(Clone))]
pub struct BladeContext {
    /// The underlying Blade context, shared by reference count with the parent module.
    pub(super) gpu: Arc<gpu::Context>,
}
10
11impl BladeContext {
12 pub fn new() -> anyhow::Result<Self> {
13 let device_id_forced = match std::env::var("ZED_DEVICE_ID") {
14 Ok(val) => parse_pci_id(&val)
15 .context("Failed to parse device ID from `ZED_DEVICE_ID` environment variable")
16 .log_err(),
17 Err(std::env::VarError::NotPresent) => None,
18 err => {
19 err.context("Failed to read value of `ZED_DEVICE_ID` environment variable")
20 .log_err();
21 None
22 }
23 };
24 let gpu = Arc::new(
25 unsafe {
26 gpu::Context::init(gpu::ContextDesc {
27 presentation: true,
28 validation: false,
29 device_id: device_id_forced.unwrap_or(0),
30 ..Default::default()
31 })
32 }
33 .map_err(|e| anyhow::anyhow!("{:?}", e))?,
34 );
35 Ok(Self { gpu })
36 }
37}
38
39fn parse_pci_id(id: &str) -> anyhow::Result<u32> {
40 let mut id = id.trim();
41
42 if id.starts_with("0x") || id.starts_with("0X") {
43 id = &id[2..];
44 }
45 let is_hex_string = id.chars().all(|c| c.is_ascii_hexdigit());
46 let is_4_chars = id.len() == 4;
47 anyhow::ensure!(
48 is_4_chars && is_hex_string,
49 "Expected a 4 digit PCI ID in hexadecimal format"
50 );
51
52 return u32::from_str_radix(id, 16)
53 .map_err(|_| anyhow::anyhow!("Failed to parse PCI ID as hex"));
54}
55
#[cfg(test)]
mod tests {
    use super::parse_pci_id;

    #[test]
    fn test_parse_device_id() {
        // Accepted spellings, pinned to the parsed value rather than just `is_ok()`.
        assert_eq!(parse_pci_id("0xABCD").unwrap(), 0xABCD);
        assert_eq!(parse_pci_id("0XABCD").unwrap(), 0xABCD);
        assert_eq!(parse_pci_id("ABCD").unwrap(), 0xABCD);
        assert_eq!(parse_pci_id("abcd").unwrap(), 0xABCD);
        assert_eq!(parse_pci_id("1234").unwrap(), 0x1234);
        assert_eq!(parse_pci_id("  0x1234\n").unwrap(), 0x1234);

        // Lower- and upper-case renderings of a value containing letter digits must
        // agree, both without and with the `0x` prefix.
        assert_eq!(
            parse_pci_id(&format!("{:x}", 0xABCD)).unwrap(),
            parse_pci_id(&format!("{:X}", 0xABCD)).unwrap(),
        );
        assert_eq!(
            parse_pci_id(&format!("{:#x}", 0xABCD)).unwrap(),
            parse_pci_id(&format!("{:#X}", 0xABCD)).unwrap(),
        );
    }

    #[test]
    fn test_parse_device_id_rejects_malformed() {
        assert!(parse_pci_id("123").is_err()); // too short
        assert!(parse_pci_id("12345").is_err()); // too long
        assert!(parse_pci_id("").is_err()); // empty
        assert!(parse_pci_id("0x").is_err()); // prefix only
        assert!(parse_pci_id("0x0x12").is_err()); // only one prefix is stripped
        assert!(parse_pci_id("wxyz").is_err()); // not hexadecimal
        assert!(parse_pci_id("+123").is_err()); // sign that `from_str_radix` would accept
    }
}