windows: Fix the issue where `ags.dll` couldn’t be replaced during update (#35877)

张小白 and Kirill Bulatov created

Release Notes:

- N/A

---------

Co-authored-by: Kirill Bulatov <kirill@zed.dev>

Change summary

crates/gpui/src/platform/windows/directx_renderer.rs | 66 ++++++++-----
1 file changed, 39 insertions(+), 27 deletions(-)

Detailed changes

crates/gpui/src/platform/windows/directx_renderer.rs 🔗

@@ -4,15 +4,16 @@ use ::util::ResultExt;
 use anyhow::{Context, Result};
 use windows::{
     Win32::{
-        Foundation::{HMODULE, HWND},
+        Foundation::{FreeLibrary, HMODULE, HWND},
         Graphics::{
             Direct3D::*,
             Direct3D11::*,
             DirectComposition::*,
             Dxgi::{Common::*, *},
         },
+        System::LibraryLoader::LoadLibraryA,
     },
-    core::Interface,
+    core::{Interface, PCSTR},
 };
 
 use crate::{
@@ -1618,17 +1619,32 @@ pub(crate) mod shader_resources {
     }
 }
 
+fn with_dll_library<R, F>(dll_name: PCSTR, f: F) -> Result<R>
+where
+    F: FnOnce(HMODULE) -> Result<R>,
+{
+    let library = unsafe {
+        LoadLibraryA(dll_name).with_context(|| format!("Loading dll: {}", dll_name.display()))?
+    };
+    let result = f(library);
+    unsafe {
+        FreeLibrary(library)
+            .with_context(|| format!("Freeing dll: {}", dll_name.display()))
+            .log_err();
+    }
+    result
+}
+
 mod nvidia {
     use std::{
         ffi::CStr,
         os::raw::{c_char, c_int, c_uint},
     };
 
-    use anyhow::{Context, Result};
-    use windows::{
-        Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA},
-        core::s,
-    };
+    use anyhow::Result;
+    use windows::{Win32::System::LibraryLoader::GetProcAddress, core::s};
+
+    use crate::platform::windows::directx_renderer::with_dll_library;
 
     // https://github.com/NVIDIA/nvapi/blob/7cb76fce2f52de818b3da497af646af1ec16ce27/nvapi_lite_common.h#L180
     const NVAPI_SHORT_STRING_MAX: usize = 64;
@@ -1645,13 +1661,12 @@ mod nvidia {
     ) -> c_int;
 
     pub(super) fn get_driver_version() -> Result<String> {
-        unsafe {
-            // Try to load the NVIDIA driver DLL
-            #[cfg(target_pointer_width = "64")]
-            let nvidia_dll = LoadLibraryA(s!("nvapi64.dll")).context("Can't load nvapi64.dll")?;
-            #[cfg(target_pointer_width = "32")]
-            let nvidia_dll = LoadLibraryA(s!("nvapi.dll")).context("Can't load nvapi.dll")?;
+        #[cfg(target_pointer_width = "64")]
+        let nvidia_dll_name = s!("nvapi64.dll");
+        #[cfg(target_pointer_width = "32")]
+        let nvidia_dll_name = s!("nvapi.dll");
 
+        with_dll_library(nvidia_dll_name, |nvidia_dll| unsafe {
             let nvapi_query_addr = GetProcAddress(nvidia_dll, s!("nvapi_QueryInterface"))
                 .ok_or_else(|| anyhow::anyhow!("Failed to get nvapi_QueryInterface address"))?;
             let nvapi_query: extern "C" fn(u32) -> *mut () = std::mem::transmute(nvapi_query_addr);
@@ -1686,18 +1701,17 @@ mod nvidia {
                 minor,
                 branch_string.to_string_lossy()
             ))
-        }
+        })
     }
 }
 
 mod amd {
     use std::os::raw::{c_char, c_int, c_void};
 
-    use anyhow::{Context, Result};
-    use windows::{
-        Win32::System::LibraryLoader::{GetProcAddress, LoadLibraryA},
-        core::s,
-    };
+    use anyhow::Result;
+    use windows::{Win32::System::LibraryLoader::GetProcAddress, core::s};
+
+    use crate::platform::windows::directx_renderer::with_dll_library;
 
     // https://github.com/GPUOpen-LibrariesAndSDKs/AGS_SDK/blob/5d8812d703d0335741b6f7ffc37838eeb8b967f7/ags_lib/inc/amd_ags.h#L145
     const AGS_CURRENT_VERSION: i32 = (6 << 22) | (3 << 12);
@@ -1731,14 +1745,12 @@ mod amd {
     type agsDeInitialize_t = unsafe extern "C" fn(context: *mut AGSContext) -> c_int;
 
     pub(super) fn get_driver_version() -> Result<String> {
-        unsafe {
-            #[cfg(target_pointer_width = "64")]
-            let amd_dll =
-                LoadLibraryA(s!("amd_ags_x64.dll")).context("Failed to load AMD AGS library")?;
-            #[cfg(target_pointer_width = "32")]
-            let amd_dll =
-                LoadLibraryA(s!("amd_ags_x86.dll")).context("Failed to load AMD AGS library")?;
+        #[cfg(target_pointer_width = "64")]
+        let amd_dll_name = s!("amd_ags_x64.dll");
+        #[cfg(target_pointer_width = "32")]
+        let amd_dll_name = s!("amd_ags_x86.dll");
 
+        with_dll_library(amd_dll_name, |amd_dll| unsafe {
             let ags_initialize_addr = GetProcAddress(amd_dll, s!("agsInitialize"))
                 .ok_or_else(|| anyhow::anyhow!("Failed to get agsInitialize address"))?;
             let ags_deinitialize_addr = GetProcAddress(amd_dll, s!("agsDeInitialize"))
@@ -1784,7 +1796,7 @@ mod amd {
 
             ags_deinitialize(context);
             Ok(format!("{} ({})", software_version, driver_version))
-        }
+        })
     }
 }