rate_limiter.rs

 1use anyhow::Result;
 2use futures::Stream;
 3use smol::lock::{Semaphore, SemaphoreGuardArc};
 4use std::{
 5    future::Future,
 6    pin::Pin,
 7    sync::Arc,
 8    task::{Context, Poll},
 9};
10
11#[derive(Clone)]
12pub struct RateLimiter {
13    semaphore: Arc<Semaphore>,
14}
15
16pub struct RateLimitGuard<T> {
17    inner: T,
18    _guard: SemaphoreGuardArc,
19}
20
21impl<T> Stream for RateLimitGuard<T>
22where
23    T: Stream,
24{
25    type Item = T::Item;
26
27    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
28        unsafe { Pin::map_unchecked_mut(self, |this| &mut this.inner).poll_next(cx) }
29    }
30}
31
32impl RateLimiter {
33    pub fn new(limit: usize) -> Self {
34        Self {
35            semaphore: Arc::new(Semaphore::new(limit)),
36        }
37    }
38
39    pub fn run<'a, Fut, T>(&self, future: Fut) -> impl 'a + Future<Output = Result<T>>
40    where
41        Fut: 'a + Future<Output = Result<T>>,
42    {
43        let guard = self.semaphore.acquire_arc();
44        async move {
45            let guard = guard.await;
46            let result = future.await?;
47            drop(guard);
48            Ok(result)
49        }
50    }
51
52    pub fn stream<'a, Fut, T>(
53        &self,
54        future: Fut,
55    ) -> impl 'a + Future<Output = Result<impl Stream<Item = T::Item> + use<Fut, T>>>
56    where
57        Fut: 'a + Future<Output = Result<T>>,
58        T: Stream,
59    {
60        let guard = self.semaphore.acquire_arc();
61        async move {
62            let guard = guard.await;
63            let inner = future.await?;
64            Ok(RateLimitGuard {
65                inner,
66                _guard: guard,
67            })
68        }
69    }
70}