wgpu_atlas.rs

  1use anyhow::{Context as _, Result};
  2use collections::FxHashMap;
  3use etagere::{BucketedAtlasAllocator, size2};
  4use gpui::{
  5    AtlasKey, AtlasTextureId, AtlasTextureKind, AtlasTextureList, AtlasTile, Bounds, DevicePixels,
  6    PlatformAtlas, Point, Size,
  7};
  8use parking_lot::Mutex;
  9use std::{borrow::Cow, ops, sync::Arc};
 10
 11fn device_size_to_etagere(size: Size<DevicePixels>) -> etagere::Size {
 12    size2(size.width.0, size.height.0)
 13}
 14
 15fn etagere_point_to_device(point: etagere::Point) -> Point<DevicePixels> {
 16    Point {
 17        x: DevicePixels(point.x),
 18        y: DevicePixels(point.y),
 19    }
 20}
 21
 22pub struct WgpuAtlas(Mutex<WgpuAtlasState>);
 23
 24struct PendingUpload {
 25    id: AtlasTextureId,
 26    bounds: Bounds<DevicePixels>,
 27    data: Vec<u8>,
 28}
 29
 30struct WgpuAtlasState {
 31    device: Arc<wgpu::Device>,
 32    queue: Arc<wgpu::Queue>,
 33    max_texture_size: u32,
 34    storage: WgpuAtlasStorage,
 35    tiles_by_key: FxHashMap<AtlasKey, AtlasTile>,
 36    pending_uploads: Vec<PendingUpload>,
 37}
 38
 39pub struct WgpuTextureInfo {
 40    pub view: wgpu::TextureView,
 41}
 42
 43impl WgpuAtlas {
 44    pub fn new(device: Arc<wgpu::Device>, queue: Arc<wgpu::Queue>) -> Self {
 45        let max_texture_size = device.limits().max_texture_dimension_2d;
 46        WgpuAtlas(Mutex::new(WgpuAtlasState {
 47            device,
 48            queue,
 49            max_texture_size,
 50            storage: WgpuAtlasStorage::default(),
 51            tiles_by_key: Default::default(),
 52            pending_uploads: Vec::new(),
 53        }))
 54    }
 55
 56    pub fn before_frame(&self) {
 57        let mut lock = self.0.lock();
 58        lock.flush_uploads();
 59    }
 60
 61    pub fn get_texture_info(&self, id: AtlasTextureId) -> WgpuTextureInfo {
 62        let lock = self.0.lock();
 63        let texture = &lock.storage[id];
 64        WgpuTextureInfo {
 65            view: texture.view.clone(),
 66        }
 67    }
 68
 69    /// Handles device lost by clearing all textures and cached tiles.
 70    /// The atlas will lazily recreate textures as needed on subsequent frames.
 71    pub fn handle_device_lost(&self, device: Arc<wgpu::Device>, queue: Arc<wgpu::Queue>) {
 72        let mut lock = self.0.lock();
 73        lock.device = device;
 74        lock.queue = queue;
 75        lock.storage = WgpuAtlasStorage::default();
 76        lock.tiles_by_key.clear();
 77        lock.pending_uploads.clear();
 78    }
 79}
 80
 81impl PlatformAtlas for WgpuAtlas {
 82    fn get_or_insert_with<'a>(
 83        &self,
 84        key: &AtlasKey,
 85        build: &mut dyn FnMut() -> Result<Option<(Size<DevicePixels>, Cow<'a, [u8]>)>>,
 86    ) -> Result<Option<AtlasTile>> {
 87        let mut lock = self.0.lock();
 88        if let Some(tile) = lock.tiles_by_key.get(key) {
 89            Ok(Some(tile.clone()))
 90        } else {
 91            profiling::scope!("new tile");
 92            let Some((size, bytes)) = build()? else {
 93                return Ok(None);
 94            };
 95            let tile = lock
 96                .allocate(size, key.texture_kind())
 97                .context("failed to allocate")?;
 98            lock.upload_texture(tile.texture_id, tile.bounds, &bytes);
 99            lock.tiles_by_key.insert(key.clone(), tile.clone());
100            Ok(Some(tile))
101        }
102    }
103
104    fn remove(&self, key: &AtlasKey) {
105        let mut lock = self.0.lock();
106
107        let Some(id) = lock.tiles_by_key.remove(key).map(|tile| tile.texture_id) else {
108            return;
109        };
110
111        let Some(texture_slot) = lock.storage[id.kind].textures.get_mut(id.index as usize) else {
112            return;
113        };
114
115        if let Some(mut texture) = texture_slot.take() {
116            texture.decrement_ref_count();
117            if texture.is_unreferenced() {
118                lock.pending_uploads
119                    .retain(|upload| upload.id != texture.id);
120                lock.storage[id.kind]
121                    .free_list
122                    .push(texture.id.index as usize);
123            } else {
124                *texture_slot = Some(texture);
125            }
126        }
127    }
128}
129
130impl WgpuAtlasState {
131    fn allocate(
132        &mut self,
133        size: Size<DevicePixels>,
134        texture_kind: AtlasTextureKind,
135    ) -> Option<AtlasTile> {
136        {
137            let textures = &mut self.storage[texture_kind];
138
139            if let Some(tile) = textures
140                .iter_mut()
141                .rev()
142                .find_map(|texture| texture.allocate(size))
143            {
144                return Some(tile);
145            }
146        }
147
148        let texture = self.push_texture(size, texture_kind);
149        texture.allocate(size)
150    }
151
152    fn push_texture(
153        &mut self,
154        min_size: Size<DevicePixels>,
155        kind: AtlasTextureKind,
156    ) -> &mut WgpuAtlasTexture {
157        const DEFAULT_ATLAS_SIZE: Size<DevicePixels> = Size {
158            width: DevicePixels(1024),
159            height: DevicePixels(1024),
160        };
161        let max_texture_size = self.max_texture_size as i32;
162        let max_atlas_size = Size {
163            width: DevicePixels(max_texture_size),
164            height: DevicePixels(max_texture_size),
165        };
166
167        let size = min_size.min(&max_atlas_size).max(&DEFAULT_ATLAS_SIZE);
168        let format = match kind {
169            AtlasTextureKind::Monochrome => wgpu::TextureFormat::R8Unorm,
170            AtlasTextureKind::Subpixel => wgpu::TextureFormat::Bgra8Unorm,
171            AtlasTextureKind::Polychrome => wgpu::TextureFormat::Bgra8Unorm,
172        };
173
174        let texture = self.device.create_texture(&wgpu::TextureDescriptor {
175            label: Some("atlas"),
176            size: wgpu::Extent3d {
177                width: size.width.0 as u32,
178                height: size.height.0 as u32,
179                depth_or_array_layers: 1,
180            },
181            mip_level_count: 1,
182            sample_count: 1,
183            dimension: wgpu::TextureDimension::D2,
184            format,
185            usage: wgpu::TextureUsages::TEXTURE_BINDING | wgpu::TextureUsages::COPY_DST,
186            view_formats: &[],
187        });
188
189        let view = texture.create_view(&wgpu::TextureViewDescriptor::default());
190
191        let texture_list = &mut self.storage[kind];
192        let index = texture_list.free_list.pop();
193
194        let atlas_texture = WgpuAtlasTexture {
195            id: AtlasTextureId {
196                index: index.unwrap_or(texture_list.textures.len()) as u32,
197                kind,
198            },
199            allocator: BucketedAtlasAllocator::new(device_size_to_etagere(size)),
200            format,
201            texture,
202            view,
203            live_atlas_keys: 0,
204        };
205
206        if let Some(ix) = index {
207            texture_list.textures[ix] = Some(atlas_texture);
208            texture_list
209                .textures
210                .get_mut(ix)
211                .and_then(|t| t.as_mut())
212                .expect("texture must exist")
213        } else {
214            texture_list.textures.push(Some(atlas_texture));
215            texture_list
216                .textures
217                .last_mut()
218                .and_then(|t| t.as_mut())
219                .expect("texture must exist")
220        }
221    }
222
223    fn upload_texture(&mut self, id: AtlasTextureId, bounds: Bounds<DevicePixels>, bytes: &[u8]) {
224        self.pending_uploads.push(PendingUpload {
225            id,
226            bounds,
227            data: bytes.to_vec(),
228        });
229    }
230
231    fn flush_uploads(&mut self) {
232        for upload in self.pending_uploads.drain(..) {
233            let Some(texture) = self.storage.get(upload.id) else {
234                continue;
235            };
236            let bytes_per_pixel = texture.bytes_per_pixel();
237
238            self.queue.write_texture(
239                wgpu::TexelCopyTextureInfo {
240                    texture: &texture.texture,
241                    mip_level: 0,
242                    origin: wgpu::Origin3d {
243                        x: upload.bounds.origin.x.0 as u32,
244                        y: upload.bounds.origin.y.0 as u32,
245                        z: 0,
246                    },
247                    aspect: wgpu::TextureAspect::All,
248                },
249                &upload.data,
250                wgpu::TexelCopyBufferLayout {
251                    offset: 0,
252                    bytes_per_row: Some(upload.bounds.size.width.0 as u32 * bytes_per_pixel as u32),
253                    rows_per_image: None,
254                },
255                wgpu::Extent3d {
256                    width: upload.bounds.size.width.0 as u32,
257                    height: upload.bounds.size.height.0 as u32,
258                    depth_or_array_layers: 1,
259                },
260            );
261        }
262    }
263}
264
265#[derive(Default)]
266struct WgpuAtlasStorage {
267    monochrome_textures: AtlasTextureList<WgpuAtlasTexture>,
268    subpixel_textures: AtlasTextureList<WgpuAtlasTexture>,
269    polychrome_textures: AtlasTextureList<WgpuAtlasTexture>,
270}
271
272impl ops::Index<AtlasTextureKind> for WgpuAtlasStorage {
273    type Output = AtlasTextureList<WgpuAtlasTexture>;
274    fn index(&self, kind: AtlasTextureKind) -> &Self::Output {
275        match kind {
276            AtlasTextureKind::Monochrome => &self.monochrome_textures,
277            AtlasTextureKind::Subpixel => &self.subpixel_textures,
278            AtlasTextureKind::Polychrome => &self.polychrome_textures,
279        }
280    }
281}
282
283impl ops::IndexMut<AtlasTextureKind> for WgpuAtlasStorage {
284    fn index_mut(&mut self, kind: AtlasTextureKind) -> &mut Self::Output {
285        match kind {
286            AtlasTextureKind::Monochrome => &mut self.monochrome_textures,
287            AtlasTextureKind::Subpixel => &mut self.subpixel_textures,
288            AtlasTextureKind::Polychrome => &mut self.polychrome_textures,
289        }
290    }
291}
292
293impl WgpuAtlasStorage {
294    fn get(&self, id: AtlasTextureId) -> Option<&WgpuAtlasTexture> {
295        self[id.kind]
296            .textures
297            .get(id.index as usize)
298            .and_then(|t| t.as_ref())
299    }
300}
301
302impl ops::Index<AtlasTextureId> for WgpuAtlasStorage {
303    type Output = WgpuAtlasTexture;
304    fn index(&self, id: AtlasTextureId) -> &Self::Output {
305        let textures = match id.kind {
306            AtlasTextureKind::Monochrome => &self.monochrome_textures,
307            AtlasTextureKind::Subpixel => &self.subpixel_textures,
308            AtlasTextureKind::Polychrome => &self.polychrome_textures,
309        };
310        textures[id.index as usize]
311            .as_ref()
312            .expect("texture must exist")
313    }
314}
315
316struct WgpuAtlasTexture {
317    id: AtlasTextureId,
318    allocator: BucketedAtlasAllocator,
319    texture: wgpu::Texture,
320    view: wgpu::TextureView,
321    format: wgpu::TextureFormat,
322    live_atlas_keys: u32,
323}
324
325impl WgpuAtlasTexture {
326    fn allocate(&mut self, size: Size<DevicePixels>) -> Option<AtlasTile> {
327        let allocation = self.allocator.allocate(device_size_to_etagere(size))?;
328        let tile = AtlasTile {
329            texture_id: self.id,
330            tile_id: allocation.id.into(),
331            padding: 0,
332            bounds: Bounds {
333                origin: etagere_point_to_device(allocation.rectangle.min),
334                size,
335            },
336        };
337        self.live_atlas_keys += 1;
338        Some(tile)
339    }
340
341    fn bytes_per_pixel(&self) -> u8 {
342        match self.format {
343            wgpu::TextureFormat::R8Unorm => 1,
344            wgpu::TextureFormat::Bgra8Unorm => 4,
345            _ => 4,
346        }
347    }
348
349    fn decrement_ref_count(&mut self) {
350        self.live_atlas_keys -= 1;
351    }
352
353    fn is_unreferenced(&self) -> bool {
354        self.live_atlas_keys == 0
355    }
356}
357
#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
    use super::*;
    use gpui::{ImageId, RenderImageParams};
    use pollster::block_on;
    use std::sync::Arc;

    /// Requests a low-power adapter and a device with downlevel default limits,
    /// failing with a descriptive error when either request is refused.
    fn test_device_and_queue() -> anyhow::Result<(Arc<wgpu::Device>, Arc<wgpu::Queue>)> {
        block_on(async {
            let instance_descriptor = wgpu::InstanceDescriptor {
                backends: wgpu::Backends::all(),
                flags: wgpu::InstanceFlags::default(),
                backend_options: wgpu::BackendOptions::default(),
                memory_budget_thresholds: wgpu::MemoryBudgetThresholds::default(),
                display: None,
            };
            let instance = wgpu::Instance::new(instance_descriptor);

            let adapter_options = wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::LowPower,
                compatible_surface: None,
                force_fallback_adapter: false,
            };
            let adapter = instance
                .request_adapter(&adapter_options)
                .await
                .map_err(|error| anyhow::anyhow!("failed to request adapter: {error}"))?;

            let device_descriptor = wgpu::DeviceDescriptor {
                label: Some("wgpu_atlas_test_device"),
                required_features: wgpu::Features::empty(),
                required_limits: wgpu::Limits::downlevel_defaults()
                    .using_resolution(adapter.limits())
                    .using_alignment(adapter.limits()),
                memory_hints: wgpu::MemoryHints::MemoryUsage,
                trace: wgpu::Trace::Off,
                experimental_features: wgpu::ExperimentalFeatures::disabled(),
            };
            let (device, queue) = adapter
                .request_device(&device_descriptor)
                .await
                .map_err(|error| anyhow::anyhow!("failed to request device: {error}"))?;

            Ok((Arc::new(device), Arc::new(queue)))
        })
    }

    #[test]
    fn before_frame_skips_uploads_for_removed_texture() -> anyhow::Result<()> {
        let (device, queue) = test_device_and_queue()?;
        let atlas = WgpuAtlas::new(device, queue);

        let image_key = AtlasKey::Image(RenderImageParams {
            image_id: ImageId(1),
            frame_index: 0,
        });
        let one_pixel = Size {
            width: DevicePixels(1),
            height: DevicePixels(1),
        };
        let mut build_tile = || Ok(Some((one_pixel, Cow::Owned(vec![0, 0, 0, 255]))));

        // Regression test: before the fix, this panicked in flush_uploads.
        // Inserting queues an upload; removing the only tile drops the texture,
        // so the flush must not try to write into it.
        atlas
            .get_or_insert_with(&image_key, &mut build_tile)?
            .expect("tile should be created");
        atlas.remove(&image_key);
        atlas.before_frame();

        Ok(())
    }
}