1use futures::Stream;
2use smol::lock::{Semaphore, SemaphoreGuardArc};
3use std::{
4 future::Future,
5 pin::Pin,
6 sync::Arc,
7 task::{Context, Poll},
8};
9
10use crate::LanguageModelCompletionError;
11
/// Limits how many requests may be in flight at once.
///
/// The limiter wraps a reference-counted [`Semaphore`], so cloning a
/// `RateLimiter` yields a handle to the same semaphore and therefore the same
/// pool of permits.
#[derive(Clone)]
pub struct RateLimiter {
    /// Shared semaphore; each rate-limited request acquires a guard from it.
    semaphore: Arc<Semaphore>,
}
16
/// Pairs a value (in practice a stream) with the semaphore guard that was
/// acquired for it.
///
/// The guard is stored only to keep it alive: it is never read, and it is
/// dropped together with this wrapper.
pub struct RateLimitGuard<T> {
    /// The wrapped value that callers actually interact with.
    inner: T,
    /// Guard tied to the lifetime of `inner`; `None` when the rate limiter
    /// was bypassed for this request.
    _guard: Option<SemaphoreGuardArc>,
}
21
22impl<T> Stream for RateLimitGuard<T>
23where
24 T: Stream,
25{
26 type Item = T::Item;
27
28 fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
29 unsafe { Pin::map_unchecked_mut(self, |this| &mut this.inner).poll_next(cx) }
30 }
31}
32
33impl RateLimiter {
34 pub fn new(limit: usize) -> Self {
35 Self {
36 semaphore: Arc::new(Semaphore::new(limit)),
37 }
38 }
39
40 pub fn run<'a, Fut, T>(
41 &self,
42 future: Fut,
43 ) -> impl 'a + Future<Output = Result<T, LanguageModelCompletionError>>
44 where
45 Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
46 {
47 let guard = self.semaphore.acquire_arc();
48 async move {
49 let guard = guard.await;
50 let result = future.await?;
51 drop(guard);
52 Ok(result)
53 }
54 }
55
56 pub fn stream<'a, Fut, T>(
57 &self,
58 future: Fut,
59 ) -> impl 'a
60 + Future<
61 Output = Result<impl Stream<Item = T::Item> + use<Fut, T>, LanguageModelCompletionError>,
62 >
63 where
64 Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
65 T: Stream,
66 {
67 let guard = self.semaphore.acquire_arc();
68 async move {
69 let guard = guard.await;
70 let inner = future.await?;
71 Ok(RateLimitGuard {
72 inner,
73 _guard: Some(guard),
74 })
75 }
76 }
77
78 /// Like `stream`, but conditionally bypasses the rate limiter based on the flag.
79 /// Used for nested requests (like edit agent requests) that are already "part of"
80 /// a rate-limited request to avoid deadlocks.
81 pub fn stream_with_bypass<'a, Fut, T>(
82 &self,
83 future: Fut,
84 bypass: bool,
85 ) -> impl 'a
86 + Future<
87 Output = Result<impl Stream<Item = T::Item> + use<Fut, T>, LanguageModelCompletionError>,
88 >
89 where
90 Fut: 'a + Future<Output = Result<T, LanguageModelCompletionError>>,
91 T: Stream,
92 {
93 let semaphore = self.semaphore.clone();
94 async move {
95 let guard = if bypass {
96 None
97 } else {
98 Some(semaphore.acquire_arc().await)
99 };
100 let inner = future.await?;
101 Ok(RateLimitGuard {
102 inner,
103 _guard: guard,
104 })
105 }
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{StreamExt, stream};
    use smol::lock::Barrier;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Tests that nested requests which bypass the rate limiter complete
    /// successfully even when every permit is held by a parent request.
    ///
    /// This test simulates the scenario where multiple "parent" requests each
    /// try to spawn a "nested" request (like edit_file tool spawning an edit agent).
    /// With a rate limit of 2 and 2 parent requests, without bypass the nested
    /// requests would block forever waiting for permits that the parents hold.
    #[test]
    fn test_nested_requests_bypass_prevents_deadlock() {
        smol::block_on(async {
            // Use only 2 permits so we can guarantee deadlock conditions
            let rate_limiter = RateLimiter::new(2);
            let completed = Arc::new(AtomicUsize::new(0));
            // Barrier ensures all parents acquire permits before any tries nested request
            let barrier = Arc::new(Barrier::new(2));

            // Spawn 2 "parent" requests that each try to make a "nested" request
            let mut handles = Vec::new();
            for _ in 0..2 {
                let limiter = rate_limiter.clone();
                let completed = completed.clone();
                let barrier = barrier.clone();

                let handle = smol::spawn(async move {
                    // Parent request acquires a permit via stream_with_bypass (bypass=false)
                    let parent_stream = limiter
                        .stream_with_bypass(
                            async {
                                // Wait for all parents to acquire permits
                                barrier.wait().await;

                                // While holding the parent permit, make a nested request
                                // WITH bypass=true (simulating EditAgent behavior)
                                let nested_stream = limiter
                                    .stream_with_bypass(
                                        async { Ok(stream::iter(vec![1, 2, 3])) },
                                        true, // bypass - this is the key!
                                    )
                                    .await?;

                                // Consume the nested stream
                                let _: Vec<_> = nested_stream.collect().await;

                                Ok(stream::iter(vec!["done"]))
                            },
                            false, // parent does NOT bypass
                        )
                        .await
                        .unwrap();

                    // Consume parent stream
                    let _: Vec<_> = parent_stream.collect().await;

                    completed.fetch_add(1, Ordering::SeqCst);
                });
                handles.push(handle);
            }

            // Race the parents against a timeout. Awaiting the handles
            // unconditionally would hang forever on a real deadlock, so the
            // failure would never be reported; racing turns it into a test
            // failure instead. Losing the race drops (and thereby cancels) the
            // spawned tasks.
            let all_finished = smol::future::or(
                async {
                    for handle in handles {
                        handle.await;
                    }
                    true
                },
                async {
                    smol::Timer::after(Duration::from_secs(2)).await;
                    false
                },
            )
            .await;

            assert!(
                all_finished,
                "Test timed out - deadlock detected! This means bypass_rate_limit is not working."
            );
            assert_eq!(completed.load(Ordering::SeqCst), 2);
        });
    }

    /// Tests that without bypass, nested requests DO cause deadlock.
    /// This test verifies the problem exists when bypass is not used.
    #[test]
    fn test_nested_requests_without_bypass_deadlocks() {
        smol::block_on(async {
            // Use only 2 permits so we can guarantee deadlock conditions
            let rate_limiter = RateLimiter::new(2);
            let completed = Arc::new(AtomicUsize::new(0));
            // Barrier ensures all parents acquire permits before any tries nested request
            let barrier = Arc::new(Barrier::new(2));

            // Spawn 2 "parent" requests that each try to make a "nested" request
            let mut handles = Vec::new();
            for _ in 0..2 {
                let limiter = rate_limiter.clone();
                let completed = completed.clone();
                let barrier = barrier.clone();

                let handle = smol::spawn(async move {
                    // Parent request acquires a permit
                    let parent_stream = limiter
                        .stream_with_bypass(
                            async {
                                // Wait for all parents to acquire permits - this guarantees
                                // that all 2 permits are held before any nested request starts
                                barrier.wait().await;

                                // Nested request WITHOUT bypass - this will deadlock!
                                // Both parents hold permits, so no permits available
                                let nested_stream = limiter
                                    .stream_with_bypass(
                                        async { Ok(stream::iter(vec![1, 2, 3])) },
                                        false, // NO bypass - will try to acquire permit
                                    )
                                    .await?;

                                let _: Vec<_> = nested_stream.collect().await;

                                Ok(stream::iter(vec!["done"]))
                            },
                            false,
                        )
                        .await
                        .unwrap();

                    let _: Vec<_> = parent_stream.collect().await;

                    completed.fetch_add(1, Ordering::SeqCst);
                });
                handles.push(handle);
            }

            // Give the spawned tasks time to run. They are expected to get
            // stuck: both parents hold permits while both nested requests wait
            // for one. A timer is used rather than spinning on `yield_now` so
            // the wait does not burn CPU.
            smol::Timer::after(Duration::from_millis(100)).await;

            // Expected - deadlock occurred, which proves the bypass is necessary
            let count = completed.load(Ordering::SeqCst);
            assert_eq!(
                count, 0,
                "Expected complete deadlock (0 completed) but {} requests completed",
                count
            );

            // Dropping the handles cancels the deadlocked tasks.
            drop(handles);
        });
    }
}