checkpoint

Junkui Zhang created

Change summary

crates/gpui/build.rs                                 |   1 
crates/gpui/src/platform/windows/directx_renderer.rs | 177 ++++++++-----
2 files changed, 113 insertions(+), 65 deletions(-)

Detailed changes

crates/gpui/build.rs 🔗

@@ -289,6 +289,7 @@ mod windows {
             .unwrap();
     }
 
+    /// You can set the `GPUI_FXC_PATH` environment variable to specify the path to the fxc.exe compiler.
     fn compile_shaders() {
         use std::fs;
         use std::process::{self, Command};

crates/gpui/src/platform/windows/directx_renderer.rs 🔗

@@ -13,7 +13,12 @@ use windows::Win32::{
 #[cfg(not(feature = "enable-renderdoc"))]
 use windows::{Win32::Graphics::DirectComposition::*, core::Interface};
 
-use crate::*;
+use crate::{
+    platform::windows::directx_renderer::shader_resources::{
+        RawShaderBytes, ShaderModule, ShaderTarget,
+    },
+    *,
+};
 
 const RENDER_TARGET_FORMAT: DXGI_FORMAT = DXGI_FORMAT_B8G8R8A8_UNORM;
 // This configuration is used for MSAA rendering, and it's guaranteed to be supported by DirectX 11.
@@ -481,35 +486,22 @@ impl DirectXResources {
 
 impl DirectXRenderPipelines {
     pub fn new(device: &ID3D11Device) -> Result<Self> {
-        let shadow_pipeline = PipelineState::new(
-            device,
-            "shadow_pipeline",
-            "shadow_vertex",
-            "shadow_fragment",
-            4,
-        )?;
-        let quad_pipeline =
-            PipelineState::new(device, "quad_pipeline", "quad_vertex", "quad_fragment", 64)?;
+        let shadow_pipeline =
+            PipelineState::new(device, "shadow_pipeline", ShaderModule::Shadow, 4)?;
+        let quad_pipeline = PipelineState::new(device, "quad_pipeline", ShaderModule::Quad, 64)?;
         let paths_pipeline = PathsPipelineState::new(device)?;
-        let underline_pipeline = PipelineState::new(
-            device,
-            "underline_pipeline",
-            "underline_vertex",
-            "underline_fragment",
-            4,
-        )?;
+        let underline_pipeline =
+            PipelineState::new(device, "underline_pipeline", ShaderModule::Underline, 4)?;
         let mono_sprites = PipelineState::new(
             device,
             "monochrome_sprite_pipeline",
-            "monochrome_sprite_vertex",
-            "monochrome_sprite_fragment",
+            ShaderModule::MonochromeSprite,
             512,
         )?;
         let poly_sprites = PipelineState::new(
             device,
             "polychrome_sprite_pipeline",
-            "polychrome_sprite_vertex",
-            "polychrome_sprite_fragment",
+            ShaderModule::PolychromeSprite,
             16,
         )?;
 
@@ -625,29 +617,16 @@ impl<T> PipelineState<T> {
     fn new(
         device: &ID3D11Device,
         label: &'static str,
-        vertex_entry: &str,
-        fragment_entry: &str,
+        shader_module: ShaderModule,
         buffer_size: usize,
     ) -> Result<Self> {
         let vertex = {
-            let shader_blob = shader_resources::build_shader_blob(vertex_entry, "vs_5_0")?;
-            let bytes = unsafe {
-                std::slice::from_raw_parts(
-                    shader_blob.GetBufferPointer() as *mut u8,
-                    shader_blob.GetBufferSize(),
-                )
-            };
-            create_vertex_shader(device, bytes)?
+            let raw_shader = RawShaderBytes::new(shader_module, ShaderTarget::Vertex)?;
+            create_vertex_shader(device, raw_shader.as_bytes())?
         };
         let fragment = {
-            let shader_blob = shader_resources::build_shader_blob(fragment_entry, "ps_5_0")?;
-            let bytes = unsafe {
-                std::slice::from_raw_parts(
-                    shader_blob.GetBufferPointer() as *mut u8,
-                    shader_blob.GetBufferSize(),
-                )
-            };
-            create_fragment_shader(device, bytes)?
+            let raw_shader = RawShaderBytes::new(shader_module, ShaderTarget::Fragment)?;
+            create_fragment_shader(device, raw_shader.as_bytes())?
         };
         let buffer = create_buffer(device, std::mem::size_of::<T>(), buffer_size)?;
         let view = create_buffer_view(device, &buffer)?;
@@ -740,24 +719,15 @@ impl<T> PipelineState<T> {
 impl PathsPipelineState {
     fn new(device: &ID3D11Device) -> Result<Self> {
         let (vertex, vertex_shader) = {
-            let shader_blob = shader_resources::build_shader_blob("paths_vertex", "vs_5_0")?;
-            let bytes = unsafe {
-                std::slice::from_raw_parts(
-                    shader_blob.GetBufferPointer() as *mut u8,
-                    shader_blob.GetBufferSize(),
-                )
-            };
-            (create_vertex_shader(device, bytes)?, shader_blob)
+            let raw_vertex_shader = RawShaderBytes::new(ShaderModule::Paths, ShaderTarget::Vertex)?;
+            (
+                create_vertex_shader(device, raw_vertex_shader.as_bytes())?,
+                raw_vertex_shader,
+            )
         };
         let fragment = {
-            let shader_blob = shader_resources::build_shader_blob("paths_fragment", "ps_5_0")?;
-            let bytes = unsafe {
-                std::slice::from_raw_parts(
-                    shader_blob.GetBufferPointer() as *mut u8,
-                    shader_blob.GetBufferSize(),
-                )
-            };
-            create_fragment_shader(device, bytes)?
+            let raw_shader = RawShaderBytes::new(ShaderModule::Paths, ShaderTarget::Fragment)?;
+            create_fragment_shader(device, raw_shader.as_bytes())?
         };
         let buffer = create_buffer(device, std::mem::size_of::<PathSprite>(), 32)?;
         let view = create_buffer_view(device, &buffer)?;
@@ -769,10 +739,6 @@ impl PathsPipelineState {
         let indirect_draw_buffer = create_indirect_draw_buffer(device, 32)?;
         // Create input layout
         let input_layout = unsafe {
-            let shader_bytes = std::slice::from_raw_parts(
-                vertex_shader.GetBufferPointer() as *const u8,
-                vertex_shader.GetBufferSize(),
-            );
             let mut layout = None;
             device.CreateInputLayout(
                 &[
@@ -813,7 +779,7 @@ impl PathsPipelineState {
                         InstanceDataStepRate: 0,
                     },
                 ],
-                shader_bytes,
+                vertex_shader.as_bytes(),
                 Some(&mut layout),
             )?;
             layout.unwrap()
@@ -1316,12 +1282,73 @@ const BUFFER_COUNT: usize = 3;
 
 mod shader_resources {
     use anyhow::Result;
-    use windows::Win32::Graphics::Direct3D::{
-        Fxc::{D3DCOMPILE_DEBUG, D3DCOMPILE_SKIP_OPTIMIZATION, D3DCompileFromFile},
-        ID3DBlob,
-    };
-    use windows_core::{HSTRING, PCSTR};
+    // use windows::Win32::Graphics::Direct3D::{
+    //     Fxc::{D3DCOMPILE_DEBUG, D3DCOMPILE_SKIP_OPTIMIZATION, D3DCompileFromFile},
+    //     ID3DBlob,
+    // };
+    // use windows_core::{HSTRING, PCSTR};
+
+    #[derive(Copy, Clone, Debug, Eq, PartialEq)]
+    pub(super) enum ShaderModule {
+        Quad,
+        Shadow,
+        Underline,
+        Paths,
+        MonochromeSprite,
+        PolychromeSprite,
+    }
+
+    #[derive(Copy, Clone, Debug, Eq, PartialEq)]
+    pub(super) enum ShaderTarget {
+        Vertex,
+        Fragment,
+    }
 
+    pub(super) struct RawShaderBytes<'t> {
+        inner: &'t [u8],
+    }
+
+    impl<'t> RawShaderBytes<'t> {
+        pub(super) fn new(module: ShaderModule, target: ShaderTarget) -> Result<Self> {
+            Ok(Self::from_bytes(module, target))
+        }
+
+        pub(super) fn as_bytes(&self) -> &'t [u8] {
+            self.inner
+        }
+
+        fn from_bytes(module: ShaderModule, target: ShaderTarget) -> Self {
+            let bytes = match module {
+                ShaderModule::Quad => match target {
+                    ShaderTarget::Vertex => QUAD_VERTEX_BYTES,
+                    ShaderTarget::Fragment => QUAD_FRAGMENT_BYTES,
+                },
+                ShaderModule::Shadow => match target {
+                    ShaderTarget::Vertex => SHADOW_VERTEX_BYTES,
+                    ShaderTarget::Fragment => SHADOW_FRAGMENT_BYTES,
+                },
+                ShaderModule::Underline => match target {
+                    ShaderTarget::Vertex => UNDERLINE_VERTEX_BYTES,
+                    ShaderTarget::Fragment => UNDERLINE_FRAGMENT_BYTES,
+                },
+                ShaderModule::Paths => match target {
+                    ShaderTarget::Vertex => PATHS_VERTEX_BYTES,
+                    ShaderTarget::Fragment => PATHS_FRAGMENT_BYTES,
+                },
+                ShaderModule::MonochromeSprite => match target {
+                    ShaderTarget::Vertex => MONOCHROME_SPRITE_VERTEX_BYTES,
+                    ShaderTarget::Fragment => MONOCHROME_SPRITE_FRAGMENT_BYTES,
+                },
+                ShaderModule::PolychromeSprite => match target {
+                    ShaderTarget::Vertex => POLYCHROME_SPRITE_VERTEX_BYTES,
+                    ShaderTarget::Fragment => POLYCHROME_SPRITE_FRAGMENT_BYTES,
+                },
+            };
+            Self { inner: bytes }
+        }
+    }
+
+    #[cfg(not(debug_assertions))]
     pub(super) fn build_shader_blob(entry: &str, target: &str) -> Result<ID3DBlob> {
         unsafe {
             let mut entry = entry.to_owned();
@@ -1368,6 +1395,26 @@ mod shader_resources {
             Ok(compile_blob.unwrap())
         }
     }
+
+    include!(concat!(env!("OUT_DIR"), "/shaders_bytes.rs"));
+
+    impl ShaderModule {
+        pub fn as_str(&self) -> &'static str {
+            match self {
+                ShaderModule::Quad => "quad",
+                ShaderModule::Shadow => "shadow",
+                ShaderModule::Underline => "underline",
+                ShaderModule::Paths => "paths",
+                ShaderModule::MonochromeSprite => "monochrome_sprite",
+                ShaderModule::PolychromeSprite => "polychrome_sprite",
+            }
+        }
+    }
+
+    // pub fn quad_vertex_shader() -> &'static [u8] {
+    //     unsafe {
+    //         std::slice::from_raw_parts(g_quad_vertex.as_ptr() as *const u8, g_quad_vertex.len())
+    //     }
 }
 
 mod nvidia {