rate_limiter.rs

  1use futures::Stream;
  2use smol::lock::{Semaphore, SemaphoreGuardArc};
  3use std::{
  4    future::Future,
  5    pin::Pin,
  6    sync::Arc,
  7    task::{Context, Poll},
  8};
  9
 10use crate::LanguageModelCompletionError;
 11
 12#[derive(Clone)]
 13pub struct RateLimiter {
 14    semaphore: Arc<Semaphore>,
 15}
 16
 17pub struct RateLimitGuard<T> {
 18    inner: T,
 19    _guard: Option<SemaphoreGuardArc>,
 20}
 21
 22impl<T> Stream for RateLimitGuard<T>
 23where
 24    T: Stream,
 25{
 26    type Item = T::Item;
 27
 28    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
 29        unsafe { Pin::map_unchecked_mut(self, |this| &mut this.inner).poll_next(cx) }
 30    }
 31}
 32
 33impl RateLimiter {
 34    pub fn new(limit: usize) -> Self {
 35        Self {
 36            semaphore: Arc::new(Semaphore::new(limit)),
 37        }
 38    }
 39
 40    pub fn run<'a, Fut, T>(
 41        &self,
 42        future: Fut,
 43    ) -> impl 'a + Future<Output = Result<T, LanguageModelCompletionError>>
 44    where
 45        Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
 46    {
 47        let guard = self.semaphore.acquire_arc();
 48        async move {
 49            let guard = guard.await;
 50            let result = future.await?;
 51            drop(guard);
 52            Ok(result)
 53        }
 54    }
 55
 56    pub fn stream<'a, Fut, T>(
 57        &self,
 58        future: Fut,
 59    ) -> impl 'a
 60    + Future<
 61        Output = Result<impl Stream<Item = T::Item> + use<Fut, T>, LanguageModelCompletionError>,
 62    >
 63    where
 64        Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
 65        T: Stream,
 66    {
 67        let guard = self.semaphore.acquire_arc();
 68        async move {
 69            let guard = guard.await;
 70            let inner = future.await?;
 71            Ok(RateLimitGuard {
 72                inner,
 73                _guard: Some(guard),
 74            })
 75        }
 76    }
 77
 78    /// Like `stream`, but conditionally bypasses the rate limiter based on the flag.
 79    /// Used for nested requests (like edit agent requests) that are already "part of"
 80    /// a rate-limited request to avoid deadlocks.
 81    pub fn stream_with_bypass<'a, Fut, T>(
 82        &self,
 83        future: Fut,
 84        bypass: bool,
 85    ) -> impl 'a
 86    + Future<
 87        Output = Result<impl Stream<Item = T::Item> + use<Fut, T>, LanguageModelCompletionError>,
 88    >
 89    where
 90        Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
 91        T: Stream,
 92    {
 93        let semaphore = self.semaphore.clone();
 94        async move {
 95            let guard = if bypass {
 96                None
 97            } else {
 98                Some(semaphore.acquire_arc().await)
 99            };
100            let inner = future.await?;
101            Ok(RateLimitGuard {
102                inner,
103                _guard: guard,
104            })
105        }
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{StreamExt, stream};
    use smol::lock::Barrier;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Number of parent requests, and also the number of permits in the
    /// limiter, so that the parents alone exhaust every permit.
    const PARENT_COUNT: usize = 2;

    /// Spawns `PARENT_COUNT` "parent" requests on a limiter with exactly
    /// `PARENT_COUNT` permits. Each parent acquires a permit (bypass = false),
    /// waits on a barrier until every parent holds one, and then issues a
    /// "nested" request using `nested_bypass`.
    ///
    /// Returns the counter of parents that ran to completion together with
    /// the task handles. Dropping a handle cancels its task, so callers must
    /// keep the handles alive for as long as the tasks should run.
    fn spawn_parents_with_nested_requests(
        nested_bypass: bool,
    ) -> (Arc<AtomicUsize>, Vec<smol::Task<()>>) {
        let rate_limiter = RateLimiter::new(PARENT_COUNT);
        let completed = Arc::new(AtomicUsize::new(0));
        // Barrier ensures all parents acquire permits before any tries a nested request.
        let barrier = Arc::new(Barrier::new(PARENT_COUNT));

        let mut handles = Vec::with_capacity(PARENT_COUNT);
        for _ in 0..PARENT_COUNT {
            let limiter = rate_limiter.clone();
            let completed = completed.clone();
            let barrier = barrier.clone();

            handles.push(smol::spawn(async move {
                // Parent request always goes through the limiter.
                let parent_stream = limiter
                    .stream_with_bypass(
                        async {
                            // Wait until every parent holds a permit, so no
                            // permits are left when the nested request starts.
                            barrier.wait().await;

                            // While holding the parent permit, make the nested request.
                            let nested_stream = limiter
                                .stream_with_bypass(
                                    async { Ok(stream::iter(vec![1, 2, 3])) },
                                    nested_bypass,
                                )
                                .await?;

                            // Consume the nested stream.
                            let _: Vec<_> = nested_stream.collect().await;

                            Ok(stream::iter(vec!["done"]))
                        },
                        false, // parent does NOT bypass
                    )
                    .await
                    .unwrap();

                // Consume the parent stream.
                let _: Vec<_> = parent_stream.collect().await;

                completed.fetch_add(1, Ordering::SeqCst);
            }));
        }

        (completed, handles)
    }

    /// Tests that nested requests issued with `bypass = true` complete even
    /// when every permit is already held by their parents.
    ///
    /// This simulates multiple "parent" requests that each spawn a "nested"
    /// request (like the edit_file tool spawning an edit agent). With a rate
    /// limit of 2 and 2 parent requests, a nested request that waited for a
    /// permit would block forever on permits the parents hold.
    #[test]
    fn test_nested_requests_bypass_prevents_deadlock() {
        smol::block_on(async {
            let (completed, handles) = spawn_parents_with_nested_requests(true);

            let all_parents_finished = async move {
                for handle in handles {
                    handle.await;
                }
                true
            };
            let deadline = async {
                smol::Timer::after(Duration::from_secs(2)).await;
                false
            };

            // Race the parents against the deadline so that a deadlock fails
            // the test instead of hanging it. If the deadline wins, the
            // pending handles are dropped, which cancels the stuck tasks.
            let finished_in_time = smol::future::or(all_parents_finished, deadline).await;
            assert!(
                finished_in_time,
                "Test timed out - deadlock detected! This means bypass is not working."
            );
            assert_eq!(completed.load(Ordering::SeqCst), PARENT_COUNT);
        });
    }

    /// Tests that without bypass, nested requests DO cause deadlock.
    /// This test verifies the problem exists when bypass is not used.
    #[test]
    fn test_nested_requests_without_bypass_deadlocks() {
        smol::block_on(async {
            // Keep the handles bound: dropping them would cancel the tasks
            // before they have had a chance to (fail to) make progress.
            let (completed, _handles) = spawn_parents_with_nested_requests(false);

            // Give the tasks time to run. Both parents hold permits and both
            // nested requests wait for permits, so none should finish.
            smol::Timer::after(Duration::from_millis(100)).await;

            // Expected - deadlock occurred, which proves the bypass is necessary.
            let count = completed.load(Ordering::SeqCst);
            assert_eq!(
                count, 0,
                "Expected complete deadlock (0 completed) but {} requests completed",
                count
            );
        });
    }
}