rate_limiter.rs

  1use anyhow::Result;
  2use futures::Stream;
  3use smol::lock::{Semaphore, SemaphoreGuardArc};
  4use std::{
  5    future::Future,
  6    pin::Pin,
  7    sync::Arc,
  8    task::{Context, Poll},
  9};
 10
 11use crate::RequestUsage;
 12
/// Bounds how many guarded operations may be in flight at the same time.
///
/// Despite the name, nothing here tracks time: this limits concurrency, not
/// requests per unit of time. `#[derive(Clone)]` clones the `Arc`, so every
/// clone of a `RateLimiter` draws from the same pool of permits.
#[derive(Clone)]
pub struct RateLimiter {
    /// Shared permit pool; each guarded operation holds one permit while active.
    semaphore: Arc<Semaphore>,
}
 17
 18pub struct RateLimitGuard<T> {
 19    inner: T,
 20    _guard: SemaphoreGuardArc,
 21}
 22
 23impl<T> Stream for RateLimitGuard<T>
 24where
 25    T: Stream,
 26{
 27    type Item = T::Item;
 28
 29    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
 30        unsafe { Pin::map_unchecked_mut(self, |this| &mut this.inner).poll_next(cx) }
 31    }
 32}
 33
 34impl RateLimiter {
 35    pub fn new(limit: usize) -> Self {
 36        Self {
 37            semaphore: Arc::new(Semaphore::new(limit)),
 38        }
 39    }
 40
 41    pub fn run<'a, Fut, T>(&self, future: Fut) -> impl 'a + Future<Output = Result<T>>
 42    where
 43        Fut: 'a + Future<Output = Result<T>>,
 44    {
 45        let guard = self.semaphore.acquire_arc();
 46        async move {
 47            let guard = guard.await;
 48            let result = future.await?;
 49            drop(guard);
 50            Ok(result)
 51        }
 52    }
 53
 54    pub fn stream<'a, Fut, T>(
 55        &self,
 56        future: Fut,
 57    ) -> impl 'a + Future<Output = Result<impl Stream<Item = T::Item> + use<Fut, T>>>
 58    where
 59        Fut: 'a + Future<Output = Result<T>>,
 60        T: Stream,
 61    {
 62        let guard = self.semaphore.acquire_arc();
 63        async move {
 64            let guard = guard.await;
 65            let inner = future.await?;
 66            Ok(RateLimitGuard {
 67                inner,
 68                _guard: guard,
 69            })
 70        }
 71    }
 72
 73    pub fn stream_with_usage<'a, Fut, T>(
 74        &self,
 75        future: Fut,
 76    ) -> impl 'a
 77    + Future<
 78        Output = Result<(
 79            impl Stream<Item = T::Item> + use<Fut, T>,
 80            Option<RequestUsage>,
 81        )>,
 82    >
 83    where
 84        Fut: 'a + Future<Output = Result<(T, Option<RequestUsage>)>>,
 85        T: Stream,
 86    {
 87        let guard = self.semaphore.acquire_arc();
 88        async move {
 89            let guard = guard.await;
 90            let (inner, usage) = future.await?;
 91            Ok((
 92                RateLimitGuard {
 93                    inner,
 94                    _guard: guard,
 95                },
 96                usage,
 97            ))
 98        }
 99    }
100}