@@ -11,8 +11,7 @@ pub struct BladeContext {
impl BladeContext {
pub fn new() -> anyhow::Result<Self> {
let device_id_forced = match std::env::var("ZED_DEVICE_ID") {
- Ok(val) => val
- .parse()
+ Ok(val) => parse_pci_id(&val)
.context("Failed to parse device ID from `ZED_DEVICE_ID` environment variable")
.log_err(),
Err(std::env::VarError::NotPresent) => None,
@@ -36,3 +35,47 @@ impl BladeContext {
Ok(Self { gpu })
}
}
+
+fn parse_pci_id(id: &str) -> anyhow::Result<u32> {
+ let mut id = id.trim();
+
+ if id.starts_with("0x") || id.starts_with("0X") {
+ id = &id[2..];
+ }
+ let is_hex_string = id.chars().all(|c| c.is_ascii_hexdigit());
+ let is_4_chars = id.len() == 4;
+ anyhow::ensure!(
+ is_4_chars && is_hex_string,
+ "Expected a 4 digit PCI ID in hexadecimal format"
+ );
+
+ return u32::from_str_radix(id, 16)
+ .map_err(|_| anyhow::anyhow!("Failed to parse PCI ID as hex"));
+}
+
+#[cfg(test)]
+mod tests {
+ use super::parse_pci_id;
+
+ #[test]
+ fn test_parse_device_id() {
+ assert!(parse_pci_id("0xABCD").is_ok());
+ assert!(parse_pci_id("ABCD").is_ok());
+ assert!(parse_pci_id("abcd").is_ok());
+ assert!(parse_pci_id("1234").is_ok());
+ assert!(parse_pci_id("123").is_err());
+ assert_eq!(
+ parse_pci_id(&format!("{:x}", 0x1234)).unwrap(),
+ parse_pci_id(&format!("{:X}", 0x1234)).unwrap(),
+ );
+
+ assert_eq!(
+ parse_pci_id(&format!("{:#x}", 0x1234)).unwrap(),
+ parse_pci_id(&format!("{:#X}", 0x1234)).unwrap(),
+ );
+ assert_eq!(
+ parse_pci_id(&format!("{:#x}", 0x1234)).unwrap(),
+ parse_pci_id(&format!("{:#X}", 0x1234)).unwrap(),
+ );
+ }
+}