shaders.metal

  1#include <metal_stdlib>
  2#include <simd/simd.h>
  3
  4using namespace metal;
  5
  // Helper prototypes so the shader entry points below can call them; the
// definitions appear later in this file. The argument types (Hsla,
// Bounds_ScaledPixels, AtlasTile, Size_DevicePixels, Corners_ScaledPixels)
// are not declared in this view — presumably a shared header. TODO confirm.
//
// Converts an HSLA color to RGBA; alpha is copied through unchanged.
  6float4 hsla_to_rgba(Hsla hsla);
// Scales `unit_vertex` into `bounds`, clamps the result to `clip_bounds`
// (pixel space), then maps pixels to clip space with y pointing up.
  7float4 to_device_position(float2 unit_vertex, Bounds_ScaledPixels bounds,
  8                          Bounds_ScaledPixels clip_bounds,
  9                          constant Size_DevicePixels *viewport_size);
// Scales `unit_vertex` into the tile's rectangle and normalizes by the atlas
// size, yielding a texture coordinate.
 10float2 to_tile_position(float2 unit_vertex, AtlasTile tile,
 11                        constant Size_DevicePixels *atlas_size);
// Signed distance from `point` to a rounded rectangle: negative inside,
// positive outside.
 12float quad_sdf(float2 point, Bounds_ScaledPixels bounds,
 13               Corners_ScaledPixels corner_radii);
 14
 // Data handed from quad_vertex to quad_fragment. Members marked [[flat]] are
// not interpolated, so every fragment of a quad sees the same value. Field
// order matters: quad_vertex builds this struct with aggregate initialization.
 15struct QuadVertexOutput {
 16  float4 position [[position]]; // vertex: clip space; fragment: .xy compared against pixel-space bounds
 17  float4 background_color [[flat]]; // hsla_to_rgba(quad.background)
 18  float4 border_color [[flat]]; // hsla_to_rgba(quad.border_color)
 19  uint quad_id [[flat]]; // instance index; the fragment stage re-reads quads[quad_id]
 20};
 21
 // Vertex stage for quads.
//
// `quad_id` (instance) selects the Quad; `unit_vertex_id` selects one of the
// unit vertices (presumably corners of a unit square — TODO confirm against
// the host-side vertex buffer). That vertex is scaled into the quad's bounds,
// clamped to its clip bounds and mapped to clip space. Both HSLA colors are
// converted to RGBA here, once per vertex, rather than per fragment.
vertex QuadVertexOutput quad_vertex(uint unit_vertex_id [[vertex_id]],
                                    uint quad_id [[instance_id]],
                                    constant float2 *unit_vertices
                                    [[buffer(QuadInputIndex_Vertices)]],
                                    constant Quad *quads
                                    [[buffer(QuadInputIndex_Quads)]],
                                    constant Size_DevicePixels *viewport_size
                                    [[buffer(QuadInputIndex_ViewportSize)]]) {
  Quad instance = quads[quad_id];

  QuadVertexOutput result;
  result.position = to_device_position(unit_vertices[unit_vertex_id],
                                       instance.bounds, instance.clip_bounds,
                                       viewport_size);
  result.background_color = hsla_to_rgba(instance.background);
  result.border_color = hsla_to_rgba(instance.border_color);
  result.quad_id = quad_id;
  return result;
}
 39
 // Fragment stage for quads: rounded corners, per-side borders and clipping.
//
// Steps:
//   1. Signed distance from the fragment to the quad's rounded outline, using
//      the radius of the corner whose quadrant contains the fragment.
//   2. Choose the border width of the nearest side. Fragments inside the inner
//      (inset) rectangle get border_width == 0 and show only the background.
//   3. In the border band, fade the border out as it crosses into the
//      background, then composite it over the background (Porter-Duff "over").
//   4. Scale alpha by the anti-aliased coverage of the outline and of the clip
//      shape (quad.clip_bounds / quad.clip_corner_radii).
//
// The result is straight (non-premultiplied) RGBA. The background-only path
// returns input.background_color unchanged, and step 4 scales only alpha, so
// the border path un-premultiplies its result to match.
fragment float4 quad_fragment(QuadVertexOutput input [[stage_in]],
                              constant Quad *quads
                              [[buffer(QuadInputIndex_Quads)]]) {
  Quad quad = quads[input.quad_id];
  float2 half_size =
      float2(quad.bounds.size.width, quad.bounds.size.height) / 2.;
  float2 center =
      float2(quad.bounds.origin.x, quad.bounds.origin.y) + half_size;
  float2 center_to_point = input.position.xy - center;

  // Radius of the corner in whose quadrant this fragment lies.
  float corner_radius;
  if (center_to_point.x < 0.) {
    if (center_to_point.y < 0.) {
      corner_radius = quad.corner_radii.top_left;
    } else {
      corner_radius = quad.corner_radii.bottom_left;
    }
  } else {
    if (center_to_point.y < 0.) {
      corner_radius = quad.corner_radii.top_right;
    } else {
      corner_radius = quad.corner_radii.bottom_right;
    }
  }

  // Rounded-rectangle signed distance (same formula as quad_sdf): negative
  // inside the quad, positive outside.
  float2 rounded_edge_to_point =
      fabs(center_to_point) - half_size + corner_radius;
  float distance =
      length(max(0., rounded_edge_to_point)) +
      min(0., max(rounded_edge_to_point.x, rounded_edge_to_point.y)) -
      corner_radius;

  // Border widths of the left/right and top/bottom sides nearest the fragment.
  float vertical_border = center_to_point.x <= 0. ? quad.border_widths.left
                                                  : quad.border_widths.right;
  float horizontal_border = center_to_point.y <= 0. ? quad.border_widths.top
                                                    : quad.border_widths.bottom;
  // Half-extent of the inner rectangle left after removing the corner radius
  // and the border. Fragments past it on an axis lie in that axis' border.
  float2 inset_size =
      half_size - corner_radius - float2(vertical_border, horizontal_border);
  float2 point_to_inset_corner = fabs(center_to_point) - inset_size;
  float border_width;
  if (point_to_inset_corner.x < 0. && point_to_inset_corner.y < 0.) {
    border_width = 0.;
  } else if (point_to_inset_corner.y > point_to_inset_corner.x) {
    border_width = horizontal_border;
  } else {
    border_width = vertical_border;
  }

  float4 color;
  if (border_width == 0.) {
    color = input.background_color;
  } else {
    float inset_distance = distance + border_width;

    // Decrease the border's opacity as we move inside the background.
    float4 border_color = input.border_color;
    border_color.a *= 1. - saturate(0.5 - inset_distance);
    float4 background_color = input.background_color;

    // Composite the border over the background. Every term uses the *faded*
    // border alpha so the fade above affects color and alpha consistently.
    float output_alpha =
        border_color.a + background_color.a * (1. - border_color.a);
    float3 premultiplied_border_rgb = border_color.rgb * border_color.a;
    float3 premultiplied_background_rgb =
        background_color.rgb * background_color.a;
    float3 premultiplied_output_rgb =
        premultiplied_border_rgb +
        premultiplied_background_rgb * (1. - border_color.a);

    // Convert back to straight alpha to match the background-only path; guard
    // the fully transparent case to avoid dividing by zero.
    float3 output_rgb = output_alpha > 0.
                            ? premultiplied_output_rgb / output_alpha
                            : float3(0.);
    color = float4(output_rgb, output_alpha);
  }

  float clip_distance =
      quad_sdf(input.position.xy, quad.clip_bounds, quad.clip_corner_radii);
  return color *
         float4(1., 1., 1.,
                saturate(0.5 - distance) * saturate(0.5 - clip_distance));
}
115
// Data handed from monochrome_sprite_vertex to monochrome_sprite_fragment.
// Field order matters: the vertex function uses aggregate initialization.
116struct MonochromeSpriteVertexOutput {
117  float4 position [[position]]; // vertex: clip space; fragment: .xy used as pixel position for clipping
118  float2 tile_position; // atlas texture coordinate; interpolated (no [[flat]])
119  float4 color [[flat]]; // hsla_to_rgba(sprite.color), constant across the sprite
120  uint sprite_id [[flat]]; // instance index; the fragment stage re-reads sprites[sprite_id]
121};
122
// Vertex stage for single-color (atlas-alpha) sprites.
//
// The same unit vertex drives two mappings: into the sprite's on-screen bounds
// (clamped to its content mask, then converted to clip space), and into its
// atlas tile (normalized texture coordinate). The tint is converted from HSLA
// to RGBA here, once per vertex.
vertex MonochromeSpriteVertexOutput monochrome_sprite_vertex(
    uint unit_vertex_id [[vertex_id]], uint sprite_id [[instance_id]],
    constant float2 *unit_vertices [[buffer(SpriteInputIndex_Vertices)]],
    constant MonochromeSprite *sprites [[buffer(SpriteInputIndex_Sprites)]],
    constant Size_DevicePixels *viewport_size
    [[buffer(SpriteInputIndex_ViewportSize)]],
    constant Size_DevicePixels *atlas_size
    [[buffer(SpriteInputIndex_AtlasTextureSize)]]) {
  MonochromeSprite sprite = sprites[sprite_id];
  float2 corner = unit_vertices[unit_vertex_id];

  MonochromeSpriteVertexOutput result;
  result.position = to_device_position(corner, sprite.bounds,
                                       sprite.content_mask.bounds,
                                       viewport_size);
  result.tile_position = to_tile_position(corner, sprite.tile, atlas_size);
  result.color = hsla_to_rgba(sprite.color);
  result.sprite_id = sprite_id;
  return result;
}
141
// Fragment stage for single-color sprites.
//
// The atlas is sampled with linear filtering and only its alpha channel is
// used, as coverage. That coverage, times the anti-aliased coverage of the
// sprite's rounded content mask, scales the alpha of the flat tint. The tint's
// rgb is returned unchanged.
fragment float4 monochrome_sprite_fragment(
    MonochromeSpriteVertexOutput input [[stage_in]],
    constant MonochromeSprite *sprites [[buffer(SpriteInputIndex_Sprites)]],
    texture2d<float> atlas_texture [[texture(SpriteInputIndex_AtlasTexture)]]) {
  constexpr sampler linear_sampler(mag_filter::linear, min_filter::linear);
  MonochromeSprite sprite = sprites[input.sprite_id];

  float glyph_coverage =
      atlas_texture.sample(linear_sampler, input.tile_position).a;
  float mask_coverage =
      saturate(0.5 - quad_sdf(input.position.xy, sprite.content_mask.bounds,
                              sprite.content_mask.corner_radii));

  return float4(input.color.rgb,
                input.color.a * (glyph_coverage * mask_coverage));
}
157
// Data handed from polychrome_sprite_vertex to polychrome_sprite_fragment.
// Field order matters: the vertex function uses aggregate initialization.
158struct PolychromeSpriteVertexOutput {
159  float4 position [[position]]; // vertex: clip space; fragment: .xy used as pixel position for clipping
160  float2 tile_position; // atlas texture coordinate; interpolated (no [[flat]])
161  uint sprite_id [[flat]]; // instance index; the fragment stage re-reads sprites[sprite_id]
162};
163
// Vertex stage for full-color sprites.
//
// The unit vertex is mapped both into the sprite's on-screen bounds (clamped
// to its content mask, then converted to clip space) and into its atlas tile
// (normalized texture coordinate). No color conversion is needed; the color
// comes from the atlas texture in the fragment stage.
vertex PolychromeSpriteVertexOutput polychrome_sprite_vertex(
    uint unit_vertex_id [[vertex_id]], uint sprite_id [[instance_id]],
    constant float2 *unit_vertices [[buffer(SpriteInputIndex_Vertices)]],
    constant PolychromeSprite *sprites [[buffer(SpriteInputIndex_Sprites)]],
    constant Size_DevicePixels *viewport_size
    [[buffer(SpriteInputIndex_ViewportSize)]],
    constant Size_DevicePixels *atlas_size
    [[buffer(SpriteInputIndex_AtlasTextureSize)]]) {
  PolychromeSprite sprite = sprites[sprite_id];
  float2 corner = unit_vertices[unit_vertex_id];

  PolychromeSpriteVertexOutput result;
  result.position = to_device_position(corner, sprite.bounds,
                                       sprite.content_mask.bounds,
                                       viewport_size);
  result.tile_position = to_tile_position(corner, sprite.tile, atlas_size);
  result.sprite_id = sprite_id;
  return result;
}
181
// Fragment stage for full-color sprites.
//
// The atlas is sampled with linear filtering. If the sprite is flagged
// grayscale, rgb is replaced by its luminance, using Rec. 709 weights. Alpha
// is then scaled by the anti-aliased coverage of the sprite's rounded content
// mask.
fragment float4 polychrome_sprite_fragment(
    PolychromeSpriteVertexOutput input [[stage_in]],
    constant PolychromeSprite *sprites [[buffer(SpriteInputIndex_Sprites)]],
    texture2d<float> atlas_texture [[texture(SpriteInputIndex_AtlasTexture)]]) {
  constexpr sampler linear_sampler(mag_filter::linear, min_filter::linear);
  PolychromeSprite sprite = sprites[input.sprite_id];

  float4 texel = atlas_texture.sample(linear_sampler, input.tile_position);
  if (sprite.grayscale) {
    float luma = 0.2126 * texel.r + 0.7152 * texel.g + 0.0722 * texel.b;
    texel.rgb = float3(luma);
  }

  float mask_distance = quad_sdf(input.position.xy, sprite.content_mask.bounds,
                                 sprite.content_mask.corner_radii);
  texel.a *= saturate(0.5 - mask_distance);
  return texel;
}
203
// Converts an HSLA color to RGBA using the chroma / secondary-component
// formulation. Hue is expected in [0, 1] (scaled to six 60-degree sectors
// [0, 6)). Values outside that range, or NaN, fall through to the last sector,
// exactly as the sector tests below are written. Alpha is passed through
// unchanged.
float4 hsla_to_rgba(Hsla hsla) {
  float sector = hsla.h * 6.0;
  float chroma = (1.0 - fabs(2.0 * hsla.l - 1.0)) * hsla.s;
  float secondary = chroma * (1.0 - fabs(fmod(sector, 2.0) - 1.0));
  float lightness_offset = hsla.l - chroma / 2.0;

  float3 rgb;
  if (sector >= 0.0 && sector < 1.0) {
    rgb = float3(chroma, secondary, 0.0); // red -> yellow
  } else if (sector >= 1.0 && sector < 2.0) {
    rgb = float3(secondary, chroma, 0.0); // yellow -> green
  } else if (sector >= 2.0 && sector < 3.0) {
    rgb = float3(0.0, chroma, secondary); // green -> cyan
  } else if (sector >= 3.0 && sector < 4.0) {
    rgb = float3(0.0, secondary, chroma); // cyan -> blue
  } else if (sector >= 4.0 && sector < 5.0) {
    rgb = float3(secondary, 0.0, chroma); // blue -> magenta
  } else {
    rgb = float3(chroma, 0.0, secondary); // magenta -> red (and out-of-range)
  }

  return float4(rgb + lightness_offset, hsla.a);
}
251
// Maps a unit vertex into `bounds` (pixel space) and clamps it into
// `clip_bounds`, so geometry never extends past the clip rectangle. The
// pixel position is then converted to clip space: x goes from [0, width] to
// [-1, 1] and y from [0, height] to [1, -1], so pixel y grows downward.
// z = 0, w = 1.
float4 to_device_position(float2 unit_vertex, Bounds_ScaledPixels bounds,
                          Bounds_ScaledPixels clip_bounds,
                          constant Size_DevicePixels *input_viewport_size) {
  float2 origin = float2(bounds.origin.x, bounds.origin.y);
  float2 extent = float2(bounds.size.width, bounds.size.height);
  float2 clip_min = float2(clip_bounds.origin.x, clip_bounds.origin.y);
  float2 clip_max =
      clip_min + float2(clip_bounds.size.width, clip_bounds.size.height);

  // Component-wise max-then-min (not clamp(), which is undefined when the
  // lower bound exceeds the upper one).
  float2 pixel = min(clip_max, max(clip_min, unit_vertex * extent + origin));

  float2 viewport = float2((float)input_viewport_size->width,
                           (float)input_viewport_size->height);
  float2 ndc = pixel / viewport * float2(2., -2.) + float2(-1., 1.);
  return float4(ndc, 0., 1.);
}
269
// Maps a unit vertex into the tile's rectangle inside the atlas and divides
// by the atlas dimensions, producing a normalized texture coordinate.
float2 to_tile_position(float2 unit_vertex, AtlasTile tile,
                        constant Size_DevicePixels *atlas_size) {
  float2 atlas_extent =
      float2((float)atlas_size->width, (float)atlas_size->height);
  float2 texel_position =
      float2(tile.bounds.origin.x, tile.bounds.origin.y) +
      unit_vertex * float2(tile.bounds.size.width, tile.bounds.size.height);
  return texel_position / atlas_extent;
}
277
// Signed distance from `point` to a rounded rectangle: negative inside,
// zero on the outline, positive outside. Each quadrant uses its own corner
// radius, chosen by which side of the center `point` lies on.
float quad_sdf(float2 point, Bounds_ScaledPixels bounds,
               Corners_ScaledPixels corner_radii) {
  float2 half_extent = float2(bounds.size.width, bounds.size.height) / 2.;
  float2 offset =
      point - (float2(bounds.origin.x, bounds.origin.y) + half_extent);

  bool is_left = offset.x < 0.;
  bool is_top = offset.y < 0.;
  float radius =
      is_left ? (is_top ? corner_radii.top_left : corner_radii.bottom_left)
              : (is_top ? corner_radii.top_right : corner_radii.bottom_right);

  // Distance to the rectangle shrunk by `radius`, then pushed back out by it.
  float2 q = abs(offset) - half_extent + radius;
  float outside = length(max(0., q));
  float inside = min(0., max(q.x, q.y));
  return outside + inside - radius;
}