1use futures::Stream;
2use smol::lock::{Semaphore, SemaphoreGuardArc};
3use std::{
4 future::Future,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8};
9
10use crate::LanguageModelCompletionError;
11
/// Bounds how many guarded operations (see `RateLimiter::run` and `RateLimiter::stream`)
/// may hold a permit at the same time.
///
/// Despite the name, this limits *concurrency* (operations in flight at once) via a
/// semaphore; it does not enforce any requests-per-time-window rate.
///
/// Cloning is cheap: it clones the inner `Arc`, so every clone draws from the same
/// permit pool.
#[derive(Clone)]
pub struct RateLimiter {
    // Shared permit pool; one permit is held per in-flight guarded operation.
    semaphore: Arc<Semaphore>,
}
16
/// Pairs a value with a rate-limiter permit so the permit lives exactly as long as the value.
///
/// `RateLimiter::stream` produces this (behind an opaque `impl Stream`) to keep a permit
/// reserved for the whole lifetime of a response stream. When `T: Stream`, items are
/// forwarded from `inner`.
pub struct RateLimitGuard<T> {
    // The wrapped value; accessed through a pin projection in the `Stream` impl.
    inner: T,
    // Never read: held only so the permit is released when this guard is dropped.
    _guard: SemaphoreGuardArc,
}
21
22impl<T> Stream for RateLimitGuard<T>
23where
24 T: Stream,
25{
26 type Item = T::Item;
27
28 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
29 unsafe { Pin::map_unchecked_mut(self, |this| &mut this.inner).poll_next(cx) }
30 }
31}
32
33impl RateLimiter {
34 pub fn new(limit: usize) -> Self {
35 Self {
36 semaphore: Arc::new(Semaphore::new(limit)),
37 }
38 }
39
40 pub fn run<'a, Fut, T>(
41 &self,
42 future: Fut,
43 ) -> impl 'a + Future<Output = Result<T, LanguageModelCompletionError>>
44 where
45 Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
46 {
47 let guard = self.semaphore.acquire_arc();
48 async move {
49 let guard = guard.await;
50 let result = future.await?;
51 drop(guard);
52 Ok(result)
53 }
54 }
55
56 pub fn stream<'a, Fut, T>(
57 &self,
58 future: Fut,
59 ) -> impl 'a
60 + Future<
61 Output = Result<impl Stream<Item = T::Item> + use<Fut, T>, LanguageModelCompletionError>,
62 >
63 where
64 Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
65 T: Stream,
66 {
67 let guard = self.semaphore.acquire_arc();
68 async move {
69 let guard = guard.await;
70 let inner = future.await?;
71 Ok(RateLimitGuard {
72 inner,
73 _guard: guard,
74 })
75 }
76 }
77}