1use anyhow::Result;
2use futures::Stream;
3use smol::lock::{Semaphore, SemaphoreGuardArc};
4use std::{
5 future::Future,
6 pin::Pin,
7 sync::Arc,
8 task::{Context, Poll},
9};
10
11use crate::RequestUsage;
12
13#[derive(Clone)]
14pub struct RateLimiter {
15 semaphore: Arc<Semaphore>,
16}
17
18pub struct RateLimitGuard<T> {
19 inner: T,
20 _guard: SemaphoreGuardArc,
21}
22
23impl<T> Stream for RateLimitGuard<T>
24where
25 T: Stream,
26{
27 type Item = T::Item;
28
29 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
30 unsafe { Pin::map_unchecked_mut(self, |this| &mut this.inner).poll_next(cx) }
31 }
32}
33
34impl RateLimiter {
35 pub fn new(limit: usize) -> Self {
36 Self {
37 semaphore: Arc::new(Semaphore::new(limit)),
38 }
39 }
40
41 pub fn run<'a, Fut, T>(&self, future: Fut) -> impl 'a + Future<Output = Result<T>>
42 where
43 Fut: 'a + Future<Output = Result<T>>,
44 {
45 let guard = self.semaphore.acquire_arc();
46 async move {
47 let guard = guard.await;
48 let result = future.await?;
49 drop(guard);
50 Ok(result)
51 }
52 }
53
54 pub fn stream<'a, Fut, T>(
55 &self,
56 future: Fut,
57 ) -> impl 'a + Future<Output = Result<impl Stream<Item = T::Item> + use<Fut, T>>>
58 where
59 Fut: 'a + Future<Output = Result<T>>,
60 T: Stream,
61 {
62 let guard = self.semaphore.acquire_arc();
63 async move {
64 let guard = guard.await;
65 let inner = future.await?;
66 Ok(RateLimitGuard {
67 inner,
68 _guard: guard,
69 })
70 }
71 }
72
73 pub fn stream_with_usage<'a, Fut, T>(
74 &self,
75 future: Fut,
76 ) -> impl 'a
77 + Future<
78 Output = Result<(
79 impl Stream<Item = T::Item> + use<Fut, T>,
80 Option<RequestUsage>,
81 )>,
82 >
83 where
84 Fut: 'a + Future<Output = Result<(T, Option<RequestUsage>)>>,
85 T: Stream,
86 {
87 let guard = self.semaphore.acquire_arc();
88 async move {
89 let guard = guard.await;
90 let (inner, usage) = future.await?;
91 Ok((
92 RateLimitGuard {
93 inner,
94 _guard: guard,
95 },
96 usage,
97 ))
98 }
99 }
100}