wgpu_atlas.rs

  1use anyhow::{Context as _, Result};
  2use collections::FxHashMap;
  3use etagere::{BucketedAtlasAllocator, size2};
  4use gpui::{
  5    AtlasKey, AtlasTextureId, AtlasTextureKind, AtlasTextureList, AtlasTile, Bounds, DevicePixels,
  6    PlatformAtlas, Point, Size,
  7};
  8use parking_lot::Mutex;
  9use std::{borrow::Cow, ops, sync::Arc};
 10
 11use crate::WgpuContext;
 12
 13fn device_size_to_etagere(size: Size<DevicePixels>) -> etagere::Size {
 14    size2(size.width.0, size.height.0)
 15}
 16
 17fn etagere_point_to_device(point: etagere::Point) -> Point<DevicePixels> {
 18    Point {
 19        x: DevicePixels(point.x),
 20        y: DevicePixels(point.y),
 21    }
 22}
 23
 24pub struct WgpuAtlas(Mutex<WgpuAtlasState>);
 25
 26struct PendingUpload {
 27    id: AtlasTextureId,
 28    bounds: Bounds<DevicePixels>,
 29    data: Vec<u8>,
 30}
 31
 32struct WgpuAtlasState {
 33    device: Arc<wgpu::Device>,
 34    queue: Arc<wgpu::Queue>,
 35    max_texture_size: u32,
 36    color_texture_format: wgpu::TextureFormat,
 37    storage: WgpuAtlasStorage,
 38    tiles_by_key: FxHashMap<AtlasKey, AtlasTile>,
 39    pending_uploads: Vec<PendingUpload>,
 40}
 41
 42pub struct WgpuTextureInfo {
 43    pub view: wgpu::TextureView,
 44}
 45
 46impl WgpuAtlas {
 47    pub fn new(
 48        device: Arc<wgpu::Device>,
 49        queue: Arc<wgpu::Queue>,
 50        color_texture_format: wgpu::TextureFormat,
 51    ) -> Self {
 52        let max_texture_size = device.limits().max_texture_dimension_2d;
 53        WgpuAtlas(Mutex::new(WgpuAtlasState {
 54            device,
 55            queue,
 56            max_texture_size,
 57            color_texture_format,
 58            storage: WgpuAtlasStorage::default(),
 59            tiles_by_key: Default::default(),
 60            pending_uploads: Vec::new(),
 61        }))
 62    }
 63
 64    pub fn from_context(context: &WgpuContext) -> Self {
 65        Self::new(
 66            context.device.clone(),
 67            context.queue.clone(),
 68            context.color_texture_format(),
 69        )
 70    }
 71
 72    pub fn before_frame(&self) {
 73        let mut lock = self.0.lock();
 74        lock.flush_uploads();
 75    }
 76
 77    pub fn get_texture_info(&self, id: AtlasTextureId) -> WgpuTextureInfo {
 78        let lock = self.0.lock();
 79        let texture = &lock.storage[id];
 80        WgpuTextureInfo {
 81            view: texture.view.clone(),
 82        }
 83    }
 84
 85    /// Clears all cached textures and tiles, forcing them to be recreated.
 86    /// Use this for incremental recovery when the device is still valid.
 87    pub fn clear(&self) {
 88        let mut lock = self.0.lock();
 89        lock.storage = WgpuAtlasStorage::default();
 90        lock.tiles_by_key.clear();
 91        lock.pending_uploads.clear();
 92    }
 93
 94    /// Handles device lost by clearing all textures and cached tiles.
 95    /// The atlas will lazily recreate textures as needed on subsequent frames.
 96    pub fn handle_device_lost(&self, context: &WgpuContext) {
 97        let mut lock = self.0.lock();
 98        lock.device = context.device.clone();
 99        lock.queue = context.queue.clone();
100        lock.color_texture_format = context.color_texture_format();
101        lock.storage = WgpuAtlasStorage::default();
102        lock.tiles_by_key.clear();
103        lock.pending_uploads.clear();
104    }
105}
106
107impl PlatformAtlas for WgpuAtlas {
108    fn get_or_insert_with<'a>(
109        &self,
110        key: &AtlasKey,
111        build: &mut dyn FnMut() -> Result<Option<(Size<DevicePixels>, Cow<'a, [u8]>)>>,
112    ) -> Result<Option<AtlasTile>> {
113        let mut lock = self.0.lock();
114        if let Some(tile) = lock.tiles_by_key.get(key) {
115            Ok(Some(*tile))
116        } else {
117            profiling::scope!("new tile");
118            let Some((size, bytes)) = build()? else {
119                return Ok(None);
120            };
121            let tile = lock
122                .allocate(size, key.texture_kind())
123                .context("failed to allocate")?;
124            lock.upload_texture(tile.texture_id, tile.bounds, &bytes);
125            lock.tiles_by_key.insert(key.clone(), tile);
126            Ok(Some(tile))
127        }
128    }
129
130    fn remove(&self, key: &AtlasKey) {
131        let mut lock = self.0.lock();
132
133        let Some(id) = lock.tiles_by_key.remove(key).map(|tile| tile.texture_id) else {
134            return;
135        };
136
137        let Some(texture_slot) = lock.storage[id.kind].textures.get_mut(id.index as usize) else {
138            return;
139        };
140
141        if let Some(mut texture) = texture_slot.take() {
142            texture.decrement_ref_count();
143            if texture.is_unreferenced() {
144                lock.pending_uploads
145                    .retain(|upload| upload.id != texture.id);
146                lock.storage[id.kind]
147                    .free_list
148                    .push(texture.id.index as usize);
149            } else {
150                *texture_slot = Some(texture);
151            }
152        }
153    }
154}
155
156impl WgpuAtlasState {
157    fn allocate(
158        &mut self,
159        size: Size<DevicePixels>,
160        texture_kind: AtlasTextureKind,
161    ) -> Option<AtlasTile> {
162        {
163            let textures = &mut self.storage[texture_kind];
164
165            if let Some(tile) = textures
166                .iter_mut()
167                .rev()
168                .find_map(|texture| texture.allocate(size))
169            {
170                return Some(tile);
171            }
172        }
173
174        let texture = self.push_texture(size, texture_kind);
175        texture.allocate(size)
176    }
177
178    fn push_texture(
179        &mut self,
180        min_size: Size<DevicePixels>,
181        kind: AtlasTextureKind,
182    ) -> &mut WgpuAtlasTexture {
183        const DEFAULT_ATLAS_SIZE: Size<DevicePixels> = Size {
184            width: DevicePixels(1024),
185            height: DevicePixels(1024),
186        };
187        let max_texture_size = self.max_texture_size as i32;
188        let max_atlas_size = Size {
189            width: DevicePixels(max_texture_size),
190            height: DevicePixels(max_texture_size),
191        };
192
193        let size = min_size.min(&max_atlas_size).max(&DEFAULT_ATLAS_SIZE);
194        let format = match kind {
195            AtlasTextureKind::Monochrome => wgpu::TextureFormat::R8Unorm,
196            AtlasTextureKind::Subpixel | AtlasTextureKind::Polychrome => self.color_texture_format,
197        };
198
199        let texture = self.device.create_texture(&wgpu::TextureDescriptor {
200            label: Some("atlas"),
201            size: wgpu::Extent3d {
202                width: size.width.0 as u32,
203                height: size.height.0 as u32,
204                depth_or_array_layers: 1,
205            },
206            mip_level_count: 1,
207            sample_count: 1,
208            dimension: wgpu::TextureDimension::D2,
209            format,
210            usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST,
211            view_formats: &[],
212        });
213
214        let view = texture.create_view(&wgpu::TextureViewDescriptor::default());
215
216        let texture_list = &mut self.storage[kind];
217        let index = texture_list.free_list.pop();
218
219        let atlas_texture = WgpuAtlasTexture {
220            id: AtlasTextureId {
221                index: index.unwrap_or(texture_list.textures.len()) as u32,
222                kind,
223            },
224            allocator: BucketedAtlasAllocator::new(device_size_to_etagere(size)),
225            format,
226            texture,
227            view,
228            live_atlas_keys: 0,
229        };
230
231        if let Some(ix) = index {
232            texture_list.textures[ix] = Some(atlas_texture);
233            texture_list
234                .textures
235                .get_mut(ix)
236                .and_then(|t| t.as_mut())
237                .expect("texture must exist")
238        } else {
239            texture_list.textures.push(Some(atlas_texture));
240            texture_list
241                .textures
242                .last_mut()
243                .and_then(|t| t.as_mut())
244                .expect("texture must exist")
245        }
246    }
247
248    fn upload_texture(&mut self, id: AtlasTextureId, bounds: Bounds<DevicePixels>, bytes: &[u8]) {
249        let data = self
250            .storage
251            .get(id)
252            .map(|texture| swizzle_upload_data(bytes, texture.format))
253            .unwrap_or_else(|| bytes.to_vec());
254
255        self.pending_uploads
256            .push(PendingUpload { id, bounds, data });
257    }
258
259    fn flush_uploads(&mut self) {
260        for upload in self.pending_uploads.drain(..) {
261            let Some(texture) = self.storage.get(upload.id) else {
262                continue;
263            };
264            let bytes_per_pixel = texture.bytes_per_pixel();
265
266            self.queue.write_texture(
267                wgpu::TexelCopyTextureInfo {
268                    texture: &texture.texture,
269                    mip_level: 0,
270                    origin: wgpu::Origin3d {
271                        x: upload.bounds.origin.x.0 as u32,
272                        y: upload.bounds.origin.y.0 as u32,
273                        z: 0,
274                    },
275                    aspect: wgpu::TextureAspect::All,
276                },
277                &upload.data,
278                wgpu::TexelCopyBufferLayout {
279                    offset: 0,
280                    bytes_per_row: Some(upload.bounds.size.width.0 as u32 * bytes_per_pixel as u32),
281                    rows_per_image: None,
282                },
283                wgpu::Extent3d {
284                    width: upload.bounds.size.width.0 as u32,
285                    height: upload.bounds.size.height.0 as u32,
286                    depth_or_array_layers: 1,
287                },
288            );
289        }
290    }
291}
292
293#[derive(Default)]
294struct WgpuAtlasStorage {
295    monochrome_textures: AtlasTextureList<WgpuAtlasTexture>,
296    subpixel_textures: AtlasTextureList<WgpuAtlasTexture>,
297    polychrome_textures: AtlasTextureList<WgpuAtlasTexture>,
298}
299
300impl ops::Index<AtlasTextureKind> for WgpuAtlasStorage {
301    type Output = AtlasTextureList<WgpuAtlasTexture>;
302    fn index(&self, kind: AtlasTextureKind) -> &Self::Output {
303        match kind {
304            AtlasTextureKind::Monochrome => &self.monochrome_textures,
305            AtlasTextureKind::Subpixel => &self.subpixel_textures,
306            AtlasTextureKind::Polychrome => &self.polychrome_textures,
307        }
308    }
309}
310
311impl ops::IndexMut<AtlasTextureKind> for WgpuAtlasStorage {
312    fn index_mut(&mut self, kind: AtlasTextureKind) -> &mut Self::Output {
313        match kind {
314            AtlasTextureKind::Monochrome => &mut self.monochrome_textures,
315            AtlasTextureKind::Subpixel => &mut self.subpixel_textures,
316            AtlasTextureKind::Polychrome => &mut self.polychrome_textures,
317        }
318    }
319}
320
321impl WgpuAtlasStorage {
322    fn get(&self, id: AtlasTextureId) -> Option<&WgpuAtlasTexture> {
323        self[id.kind]
324            .textures
325            .get(id.index as usize)
326            .and_then(|t| t.as_ref())
327    }
328}
329
330impl ops::Index<AtlasTextureId> for WgpuAtlasStorage {
331    type Output = WgpuAtlasTexture;
332    fn index(&self, id: AtlasTextureId) -> &Self::Output {
333        let textures = match id.kind {
334            AtlasTextureKind::Monochrome => &self.monochrome_textures,
335            AtlasTextureKind::Subpixel => &self.subpixel_textures,
336            AtlasTextureKind::Polychrome => &self.polychrome_textures,
337        };
338        textures[id.index as usize]
339            .as_ref()
340            .expect("texture must exist")
341    }
342}
343
344struct WgpuAtlasTexture {
345    id: AtlasTextureId,
346    allocator: BucketedAtlasAllocator,
347    texture: wgpu::Texture,
348    view: wgpu::TextureView,
349    format: wgpu::TextureFormat,
350    live_atlas_keys: u32,
351}
352
353impl WgpuAtlasTexture {
354    fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> {
355        let allocation = self.allocator.allocate(device_size_to_etagere(size))?;
356        let tile = AtlasTile {
357            texture_id: self.id,
358            tile_id: allocation.id.into(),
359            padding: 0,
360            bounds: Bounds {
361                origin: etagere_point_to_device(allocation.rectangle.min),
362                size,
363            },
364        };
365        self.live_atlas_keys += 1;
366        Some(tile)
367    }
368
369    fn bytes_per_pixel(&self) -> u8 {
370        match self.format {
371            wgpu::TextureFormat::R8Unorm => 1,
372            wgpu::TextureFormat::Bgra8Unorm | wgpu::TextureFormat::Rgba8Unorm => 4,
373            _ => 4,
374        }
375    }
376
377    fn decrement_ref_count(&mut self) {
378        self.live_atlas_keys -= 1;
379    }
380
381    fn is_unreferenced(&self) -> bool {
382        self.live_atlas_keys == 0
383    }
384}
385
386fn swizzle_upload_data(bytes: &[u8], format: wgpu::TextureFormat) -> Vec<u8> {
387    match format {
388        wgpu::TextureFormat::Rgba8Unorm => {
389            let mut data = bytes.to_vec();
390            for pixel in data.chunks_exact_mut(4) {
391                pixel.swap(0, 2);
392            }
393            data
394        }
395        _ => bytes.to_vec(),
396    }
397}
398
#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
    use super::*;
    use gpui::{ImageId, RenderImageParams};
    use pollster::block_on;
    use std::sync::Arc;

    /// Requests a headless, low-power device and queue to exercise the atlas.
    ///
    /// Returns an error rather than panicking when no adapter or device can
    /// be obtained.
    fn test_device_and_queue() -> anyhow::Result<(Arc<wgpu::Device>, Arc<wgpu::Queue>)> {
        let instance_descriptor = wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            flags: wgpu::InstanceFlags::default(),
            backend_options: wgpu::BackendOptions::default(),
            memory_budget_thresholds: wgpu::MemoryBudgetThresholds::default(),
            display: None,
        };
        let adapter_options = wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::LowPower,
            compatible_surface: None,
            force_fallback_adapter: false,
        };

        block_on(async move {
            let instance = wgpu::Instance::new(instance_descriptor);
            let adapter = instance
                .request_adapter(&adapter_options)
                .await
                .map_err(|error| anyhow::anyhow!("failed to request adapter: {error}"))?;

            let limits = wgpu::Limits::downlevel_defaults()
                .using_resolution(adapter.limits())
                .using_alignment(adapter.limits());
            let device_descriptor = wgpu::DeviceDescriptor {
                label: Some("wgpu_atlas_test_device"),
                required_features: wgpu::Features::empty(),
                required_limits: limits,
                memory_hints: wgpu::MemoryHints::MemoryUsage,
                trace: wgpu::Trace::Off,
                experimental_features: wgpu::ExperimentalFeatures::disabled(),
            };
            let (device, queue) = adapter
                .request_device(&device_descriptor)
                .await
                .map_err(|error| anyhow::anyhow!("failed to request device: {error}"))?;

            Ok((Arc::new(device), Arc::new(queue)))
        })
    }

    #[test]
    fn before_frame_skips_uploads_for_removed_texture() -> anyhow::Result<()> {
        let (device, queue) = test_device_and_queue()?;
        let atlas = WgpuAtlas::new(device, queue, wgpu::TextureFormat::Bgra8Unorm);

        let key = AtlasKey::Image(RenderImageParams {
            image_id: ImageId(1),
            frame_index: 0,
        });
        let one_pixel = Size {
            width: DevicePixels(1),
            height: DevicePixels(1),
        };
        let mut build = || Ok(Some((one_pixel, Cow::Owned(vec![0, 0, 0, 255]))));

        // Regression test: before the fix, this panicked in flush_uploads
        let inserted = atlas.get_or_insert_with(&key, &mut build)?;
        inserted.expect("tile should be created");

        // Removing the only tile frees its texture while its upload is still
        // queued; flushing must then skip that upload instead of panicking.
        atlas.remove(&key);
        atlas.before_frame();
        Ok(())
    }

    #[test]
    fn swizzle_upload_data_preserves_bgra_uploads() {
        let bgra = vec![0x10, 0x20, 0x30, 0x40];
        let uploaded = swizzle_upload_data(&bgra, wgpu::TextureFormat::Bgra8Unorm);
        assert_eq!(uploaded, bgra);
    }

    #[test]
    fn swizzle_upload_data_converts_bgra_to_rgba() {
        let bgra = vec![0x10, 0x20, 0x30, 0x40, 0xAA, 0xBB, 0xCC, 0xDD];
        let expected_rgba = vec![0x30, 0x20, 0x10, 0x40, 0xCC, 0xBB, 0xAA, 0xDD];
        let uploaded = swizzle_upload_data(&bgra, wgpu::TextureFormat::Rgba8Unorm);
        assert_eq!(uploaded, expected_rgba);
    }
}