introduce `set_pipeline_state`

Junkui Zhang created

Change summary

crates/gpui/src/platform/windows/directx_renderer.rs | 86 +++++++++----
1 file changed, 61 insertions(+), 25 deletions(-)

Detailed changes

crates/gpui/src/platform/windows/directx_renderer.rs 🔗

@@ -616,16 +616,16 @@ impl<T> PipelineState<T> {
         global_params: &[Option<ID3D11Buffer>],
         instance_count: u32,
     ) -> Result<()> {
+        set_pipeline_state(
+            device_context,
+            &self.view,
+            D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP,
+            viewport,
+            &self.vertex,
+            &self.fragment,
+            global_params,
+        );
         unsafe {
-            device_context.VSSetShaderResources(1, Some(&self.view));
-            device_context.PSSetShaderResources(1, Some(&self.view));
-            device_context.IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP);
-            device_context.RSSetViewports(Some(viewport));
-            device_context.VSSetShader(&self.vertex, None);
-            device_context.PSSetShader(&self.fragment, None);
-            device_context.VSSetConstantBuffers(0, Some(global_params));
-            device_context.PSSetConstantBuffers(0, Some(global_params));
-
             device_context.DrawInstanced(4, instance_count, 0, 0);
         }
         Ok(())
@@ -640,15 +640,16 @@ impl<T> PipelineState<T> {
         sampler: &[Option<ID3D11SamplerState>],
         instance_count: u32,
     ) -> Result<()> {
+        set_pipeline_state(
+            device_context,
+            &self.view,
+            D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP,
+            viewport,
+            &self.vertex,
+            &self.fragment,
+            global_params,
+        );
         unsafe {
-            device_context.IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLESTRIP);
-            device_context.RSSetViewports(Some(viewport));
-            device_context.VSSetShader(&self.vertex, None);
-            device_context.PSSetShader(&self.fragment, None);
-            device_context.VSSetConstantBuffers(0, Some(global_params));
-            device_context.PSSetConstantBuffers(0, Some(global_params));
-            device_context.VSSetShaderResources(1, Some(&self.view));
-            device_context.PSSetShaderResources(1, Some(&self.view));
             device_context.PSSetSamplers(0, Some(sampler));
             device_context.VSSetShaderResources(0, Some(texture));
             device_context.PSSetShaderResources(0, Some(texture));
@@ -811,15 +812,16 @@ impl PathsPipelineState {
         viewport: &[D3D11_VIEWPORT],
         global_params: &[Option<ID3D11Buffer>],
     ) -> Result<()> {
+        set_pipeline_state(
+            device_context,
+            &self.view,
+            D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST,
+            viewport,
+            &self.vertex,
+            &self.fragment,
+            global_params,
+        );
         unsafe {
-            device_context.VSSetShaderResources(1, Some(&self.view));
-            device_context.PSSetShaderResources(1, Some(&self.view));
-            device_context.IASetPrimitiveTopology(D3D_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
-            device_context.RSSetViewports(Some(viewport));
-            device_context.VSSetShader(&self.vertex, None);
-            device_context.PSSetShader(&self.fragment, None);
-            device_context.VSSetConstantBuffers(0, Some(global_params));
-            device_context.PSSetConstantBuffers(0, Some(global_params));
             const STRIDE: u32 = std::mem::size_of::<PathVertex<ScaledPixels>>() as u32;
             device_context.IASetVertexBuffers(
                 0,
@@ -849,6 +851,7 @@ struct PathSprite {
     color: Background,
 }
 
+#[inline]
 fn get_dxgi_factory() -> Result<IDXGIFactory6> {
     #[cfg(debug_assertions)]
     let factory_flag = DXGI_CREATE_FACTORY_DEBUG;
@@ -970,6 +973,7 @@ fn create_swap_chain_default(
     Ok(swap_chain)
 }
 
+#[inline]
 fn set_render_target_view(
     swap_chain: &IDXGISwapChain1,
     device: &ID3D11Device,
@@ -987,6 +991,7 @@ fn set_render_target_view(
     Ok(back_buffer)
 }
 
+#[inline]
 fn set_viewport(
     device_context: &ID3D11DeviceContext,
     width: f32,
@@ -1004,6 +1009,7 @@ fn set_viewport(
     viewport
 }
 
+#[inline]
 fn set_rasterizer_state(device: &ID3D11Device, device_context: &ID3D11DeviceContext) -> Result<()> {
     let desc = D3D11_RASTERIZER_DESC {
         FillMode: D3D11_FILL_SOLID,
@@ -1028,6 +1034,7 @@ fn set_rasterizer_state(device: &ID3D11Device, device_context: &ID3D11DeviceCont
 }
 
 // https://learn.microsoft.com/en-us/windows/win32/api/d3d11/ns-d3d11-d3d11_blend_desc
+#[inline]
 fn create_blend_state(device: &ID3D11Device) -> Result<ID3D11BlendState> {
     // If the feature level is set to greater than D3D_FEATURE_LEVEL_9_3, the display
     // device performs the blend in linear space, which is ideal.
@@ -1047,6 +1054,7 @@ fn create_blend_state(device: &ID3D11Device) -> Result<ID3D11BlendState> {
     }
 }
 
+#[inline]
 fn create_vertex_shader(device: &ID3D11Device, bytes: &[u8]) -> Result<ID3D11VertexShader> {
     unsafe {
         let mut shader = None;
@@ -1055,6 +1063,7 @@ fn create_vertex_shader(device: &ID3D11Device, bytes: &[u8]) -> Result<ID3D11Ver
     }
 }
 
+#[inline]
 fn create_fragment_shader(device: &ID3D11Device, bytes: &[u8]) -> Result<ID3D11PixelShader> {
     unsafe {
         let mut shader = None;
@@ -1063,6 +1072,7 @@ fn create_fragment_shader(device: &ID3D11Device, bytes: &[u8]) -> Result<ID3D11P
     }
 }
 
+#[inline]
 fn create_buffer(
     device: &ID3D11Device,
     element_size: usize,
@@ -1081,6 +1091,7 @@ fn create_buffer(
     Ok(buffer.unwrap())
 }
 
+#[inline]
 fn create_buffer_view(
     device: &ID3D11Device,
     buffer: &ID3D11Buffer,
@@ -1090,6 +1101,7 @@ fn create_buffer_view(
     Ok([view])
 }
 
+#[inline]
 fn create_indirect_draw_buffer(device: &ID3D11Device, buffer_size: usize) -> Result<ID3D11Buffer> {
     let desc = D3D11_BUFFER_DESC {
         ByteWidth: (std::mem::size_of::<DrawInstancedIndirectArgs>() * buffer_size) as u32,
@@ -1104,6 +1116,7 @@ fn create_indirect_draw_buffer(device: &ID3D11Device, buffer_size: usize) -> Res
     Ok(buffer.unwrap())
 }
 
+#[inline]
 fn pre_draw(
     device_context: &ID3D11DeviceContext,
     global_params_buffer: &[Option<ID3D11Buffer>; 1],
@@ -1130,6 +1143,7 @@ fn pre_draw(
     Ok(())
 }
 
+#[inline]
 fn update_buffer<T>(
     device_context: &ID3D11DeviceContext,
     buffer: &ID3D11Buffer,
@@ -1144,6 +1158,28 @@ fn update_buffer<T>(
     Ok(())
 }
 
+#[inline]
+fn set_pipeline_state(
+    device_context: &ID3D11DeviceContext,
+    buffer_view: &[Option<ID3D11ShaderResourceView>],
+    topology: D3D_PRIMITIVE_TOPOLOGY,
+    viewport: &[D3D11_VIEWPORT],
+    vertex_shader: &ID3D11VertexShader,
+    fragment_shader: &ID3D11PixelShader,
+    global_params: &[Option<ID3D11Buffer>],
+) {
+    unsafe {
+        device_context.VSSetShaderResources(1, Some(buffer_view));
+        device_context.PSSetShaderResources(1, Some(buffer_view));
+        device_context.IASetPrimitiveTopology(topology);
+        device_context.RSSetViewports(Some(viewport));
+        device_context.VSSetShader(vertex_shader, None);
+        device_context.PSSetShader(fragment_shader, None);
+        device_context.VSSetConstantBuffers(0, Some(global_params));
+        device_context.PSSetConstantBuffers(0, Some(global_params));
+    }
+}
+
 const BUFFER_COUNT: usize = 3;
 
 mod shader_resources {