1use anyhow::Result;
2use futures::Stream;
3use smol::lock::{Semaphore, SemaphoreGuardArc};
4use std::{
5 future::Future,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9};
10
11#[derive(Clone)]
12pub struct RateLimiter {
13 semaphore: Arc<Semaphore>,
14}
15
16pub struct RateLimitGuard<T> {
17 inner: T,
18 _guard: SemaphoreGuardArc,
19}
20
21impl<T> Stream for RateLimitGuard<T>
22where
23 T: Stream,
24{
25 type Item = T::Item;
26
27 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
28 unsafe { Pin::map_unchecked_mut(self, |this| &mut this.inner).poll_next(cx) }
29 }
30}
31
32impl RateLimiter {
33 pub fn new(limit: usize) -> Self {
34 Self {
35 semaphore: Arc::new(Semaphore::new(limit)),
36 }
37 }
38
39 pub fn run<'a, Fut, T>(&self, future: Fut) -> impl 'a + Future<Output = Result<T>>
40 where
41 Fut: 'a + Future<Output = Result<T>>,
42 {
43 let guard = self.semaphore.acquire_arc();
44 async move {
45 let guard = guard.await;
46 let result = future.await?;
47 drop(guard);
48 Ok(result)
49 }
50 }
51
52 pub fn stream<'a, Fut, T>(
53 &self,
54 future: Fut,
55 ) -> impl 'a + Future<Output = Result<impl Stream<Item = T::Item> + use<Fut, T>>>
56 where
57 Fut: 'a + Future<Output = Result<T>>,
58 T: Stream,
59 {
60 let guard = self.semaphore.acquire_arc();
61 async move {
62 let guard = guard.await;
63 let inner = future.await?;
64 Ok(RateLimitGuard {
65 inner,
66 _guard: guard,
67 })
68 }
69 }
70}