1use futures::Stream;
 2use smol::lock::{Semaphore, SemaphoreGuardArc};
 3use std::{
 4    future::Future,
 5    pin::Pin,
 6    sync::Arc,
 7    task::{Context, Poll},
 8};
 9
10use crate::LanguageModelCompletionError;
11
12#[derive(Clone)]
13pub struct RateLimiter {
14    semaphore: Arc<Semaphore>,
15}
16
17pub struct RateLimitGuard<T> {
18    inner: T,
19    _guard: SemaphoreGuardArc,
20}
21
22impl<T> Stream for RateLimitGuard<T>
23where
24    T: Stream,
25{
26    type Item = T::Item;
27
28    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
29        unsafe { Pin::map_unchecked_mut(self, |this| &mut this.inner).poll_next(cx) }
30    }
31}
32
33impl RateLimiter {
34    pub fn new(limit: usize) -> Self {
35        Self {
36            semaphore: Arc::new(Semaphore::new(limit)),
37        }
38    }
39
40    pub fn run<'a, Fut, T>(
41        &self,
42        future: Fut,
43    ) -> impl 'a + Future<Output = Result<T, LanguageModelCompletionError>>
44    where
45        Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
46    {
47        let guard = self.semaphore.acquire_arc();
48        async move {
49            let guard = guard.await;
50            let result = future.await?;
51            drop(guard);
52            Ok(result)
53        }
54    }
55
56    pub fn stream<'a, Fut, T>(
57        &self,
58        future: Fut,
59    ) -> impl 'a
60    + Future<
61        Output = Result<impl Stream<Item = T::Item> + use<Fut, T>, LanguageModelCompletionError>,
62    >
63    where
64        Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
65        T: Stream,
66    {
67        let guard = self.semaphore.acquire_arc();
68        async move {
69            let guard = guard.await;
70            let inner = future.await?;
71            Ok(RateLimitGuard {
72                inner,
73                _guard: guard,
74            })
75        }
76    }
77}