Rate limiting status (#2955)

Kyle Caverly created

Add a rate limit remaining status to Project Search Semantic Search
minor text

Release Notes (Preview-Only):

- Added tracking functionality within EmbeddingProvider, to track rate
limit expiry
- Update minor text within Project Search to show countdown remaining
before rate limit expiry

Change summary

crates/search/src/project_search.rs               | 45 +++++++++
crates/semantic_index/src/embedding.rs            | 70 ++++++++++++++++
crates/semantic_index/src/semantic_index.rs       | 11 +-
crates/semantic_index/src/semantic_index_tests.rs |  6 +
4 files changed, 121 insertions(+), 11 deletions(-)

Detailed changes

crates/search/src/project_search.rs 🔗

@@ -34,6 +34,7 @@ use std::{
     ops::{Not, Range},
     path::PathBuf,
     sync::Arc,
+    time::{Duration, Instant},
 };
 use util::ResultExt as _;
 use workspace::{
@@ -130,6 +131,7 @@ pub struct ProjectSearchView {
 
 struct SemanticState {
     index_status: SemanticIndexStatus,
+    maintain_rate_limit: Option<Task<()>>,
     _subscription: Subscription,
 }
 
@@ -319,11 +321,28 @@ impl View for ProjectSearchView {
                 let status = semantic.index_status;
                 match status {
                     SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
-                    SemanticIndexStatus::Indexing { remaining_files } => {
+                    SemanticIndexStatus::Indexing {
+                        remaining_files,
+                        rate_limit_expiry,
+                    } => {
                         if remaining_files == 0 {
                             Some(format!("Indexing..."))
                         } else {
-                            Some(format!("Remaining files to index: {}", remaining_files))
+                            if let Some(rate_limit_expiry) = rate_limit_expiry {
+                                let remaining_seconds =
+                                    rate_limit_expiry.duration_since(Instant::now());
+                                if remaining_seconds > Duration::from_secs(0) {
+                                    Some(format!(
+                                        "Remaining files to index (rate limit resets in {}s): {}",
+                                        remaining_seconds.as_secs(),
+                                        remaining_files
+                                    ))
+                                } else {
+                                    Some(format!("Remaining files to index: {}", remaining_files))
+                                }
+                            } else {
+                                Some(format!("Remaining files to index: {}", remaining_files))
+                            }
                         }
                     }
                     SemanticIndexStatus::NotIndexed => None,
@@ -651,9 +670,10 @@ impl ProjectSearchView {
 
             self.semantic_state = Some(SemanticState {
                 index_status: semantic_index.read(cx).status(&project),
+                maintain_rate_limit: None,
                 _subscription: cx.observe(&semantic_index, Self::semantic_index_changed),
             });
-            cx.notify();
+            self.semantic_index_changed(semantic_index, cx);
         }
     }
 
@@ -664,8 +684,25 @@ impl ProjectSearchView {
     ) {
         let project = self.model.read(cx).project.clone();
         if let Some(semantic_state) = self.semantic_state.as_mut() {
-            semantic_state.index_status = semantic_index.read(cx).status(&project);
             cx.notify();
+            semantic_state.index_status = semantic_index.read(cx).status(&project);
+            if let SemanticIndexStatus::Indexing {
+                rate_limit_expiry: Some(_),
+                ..
+            } = &semantic_state.index_status
+            {
+                if semantic_state.maintain_rate_limit.is_none() {
+                    semantic_state.maintain_rate_limit =
+                        Some(cx.spawn(|this, mut cx| async move {
+                            loop {
+                                cx.background().timer(Duration::from_secs(1)).await;
+                                this.update(&mut cx, |_, cx| cx.notify()).log_err();
+                            }
+                        }));
+                    return;
+                }
+            }
+            semantic_state.maintain_rate_limit = None;
         }
     }
 

crates/semantic_index/src/embedding.rs 🔗

@@ -7,13 +7,16 @@ use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
 use lazy_static::lazy_static;
+use parking_lot::Mutex;
 use parse_duration::parse;
+use postage::watch;
 use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 use rusqlite::ToSql;
 use serde::{Deserialize, Serialize};
 use std::env;
+use std::ops::Add;
 use std::sync::Arc;
-use std::time::Duration;
+use std::time::{Duration, Instant};
 use tiktoken_rs::{cl100k_base, CoreBPE};
 use util::http::{HttpClient, Request};
 
@@ -82,6 +85,8 @@ impl ToSql for Embedding {
 pub struct OpenAIEmbeddings {
     pub client: Arc<dyn HttpClient>,
     pub executor: Arc<Background>,
+    rate_limit_count_rx: watch::Receiver<Option<Instant>>,
+    rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
 }
 
 #[derive(Serialize)]
@@ -114,12 +119,16 @@ pub trait EmbeddingProvider: Sync + Send {
     async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
     fn truncate(&self, span: &str) -> (String, usize);
+    fn rate_limit_expiration(&self) -> Option<Instant>;
 }
 
 pub struct DummyEmbeddings {}
 
 #[async_trait]
 impl EmbeddingProvider for DummyEmbeddings {
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        None
+    }
     async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         // 1024 is the OpenAI Embeddings size for ada models.
         // the model we will likely be starting with.
@@ -149,6 +158,50 @@ impl EmbeddingProvider for DummyEmbeddings {
 const OPENAI_INPUT_LIMIT: usize = 8190;
 
 impl OpenAIEmbeddings {
+    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
+        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
+
+        OpenAIEmbeddings {
+            client,
+            executor,
+            rate_limit_count_rx,
+            rate_limit_count_tx,
+        }
+    }
+
+    fn resolve_rate_limit(&self) {
+        let reset_time = *self.rate_limit_count_tx.lock().borrow();
+
+        if let Some(reset_time) = reset_time {
+            if Instant::now() >= reset_time {
+                *self.rate_limit_count_tx.lock().borrow_mut() = None
+            }
+        }
+
+        log::trace!(
+            "resolving reset time: {:?}",
+            *self.rate_limit_count_tx.lock().borrow()
+        );
+    }
+
+    fn update_reset_time(&self, reset_time: Instant) {
+        let original_time = *self.rate_limit_count_tx.lock().borrow();
+
+        let updated_time = if let Some(original_time) = original_time {
+            if reset_time < original_time {
+                Some(reset_time)
+            } else {
+                Some(original_time)
+            }
+        } else {
+            Some(reset_time)
+        };
+
+        log::trace!("updating rate limit time: {:?}", updated_time);
+
+        *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
+    }
     async fn send_request(
         &self,
         api_key: &str,
@@ -179,6 +232,9 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         50000
     }
 
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        *self.rate_limit_count_rx.borrow()
+    }
     fn truncate(&self, span: &str) -> (String, usize) {
         let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
         let output = if tokens.len() > OPENAI_INPUT_LIMIT {
@@ -203,6 +259,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
             .ok_or_else(|| anyhow!("no api key"))?;
 
         let mut request_number = 0;
+        let mut rate_limiting = false;
         let mut request_timeout: u64 = 15;
         let mut response: Response<AsyncBody>;
         while request_number < MAX_RETRIES {
@@ -229,6 +286,12 @@ impl EmbeddingProvider for OpenAIEmbeddings {
                         response.usage.total_tokens
                     );
 
+                    // If we complete a request successfully that was previously rate_limited
+                    // resolve the rate limit
+                    if rate_limiting {
+                        self.resolve_rate_limit()
+                    }
+
                     return Ok(response
                         .data
                         .into_iter()
@@ -236,6 +299,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
                         .collect());
                 }
                 StatusCode::TOO_MANY_REQUESTS => {
+                    rate_limiting = true;
                     let mut body = String::new();
                     response.body_mut().read_to_string(&mut body).await?;
 
@@ -254,6 +318,10 @@ impl EmbeddingProvider for OpenAIEmbeddings {
                         }
                     };
 
+                    // If we've previously rate limited, increment the duration but not the count
+                    let reset_time = Instant::now().add(delay_duration);
+                    self.update_reset_time(reset_time);
+
                     log::trace!(
                         "openai rate limiting: waiting {:?} until lifted",
                         &delay_duration

crates/semantic_index/src/semantic_index.rs 🔗

@@ -91,10 +91,7 @@ pub fn init(
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,
-            Arc::new(OpenAIEmbeddings {
-                client: http_client,
-                executor: cx.background(),
-            }),
+            Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
             language_registry,
             cx.clone(),
         )
@@ -113,7 +110,10 @@ pub fn init(
 pub enum SemanticIndexStatus {
     NotIndexed,
     Indexed,
-    Indexing { remaining_files: usize },
+    Indexing {
+        remaining_files: usize,
+        rate_limit_expiry: Option<Instant>,
+    },
 }
 
 pub struct SemanticIndex {
@@ -293,6 +293,7 @@ impl SemanticIndex {
             } else {
                 SemanticIndexStatus::Indexing {
                     remaining_files: project_state.pending_file_count_rx.borrow().clone(),
+                    rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
                 }
             }
         } else {

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -21,7 +21,7 @@ use std::{
         atomic::{self, AtomicUsize},
         Arc,
     },
-    time::SystemTime,
+    time::{Instant, SystemTime},
 };
 use unindent::Unindent;
 use util::RandomCharIter;
@@ -1275,6 +1275,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         200
     }
 
+    fn rate_limit_expiration(&self) -> Option<Instant> {
+        None
+    }
+
     async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);