Initial outline for rate-limiting status updates

Created by KCaverly

Change summary

crates/search/src/project_search.rs               | 16 +++
crates/semantic_index/src/embedding.rs            | 75 +++++++++++++++++
crates/semantic_index/src/semantic_index.rs       | 17 ++-
crates/semantic_index/src/semantic_index_tests.rs |  6 +
4 files changed, 106 insertions(+), 8 deletions(-)

Detailed changes

crates/search/src/project_search.rs 🔗

@@ -34,6 +34,7 @@ use std::{
     ops::{Not, Range},
     path::PathBuf,
     sync::Arc,
+    time::Duration,
 };
 use util::ResultExt as _;
 use workspace::{
@@ -319,11 +320,22 @@ impl View for ProjectSearchView {
                 let status = semantic.index_status;
                 match status {
                     SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
-                    SemanticIndexStatus::Indexing { remaining_files } => {
+                    SemanticIndexStatus::Indexing {
+                        remaining_files,
+                        rate_limiting,
+                    } => {
                         if remaining_files == 0 {
                             Some(format!("Indexing..."))
                         } else {
-                            Some(format!("Remaining files to index: {}", remaining_files))
+                            if rate_limiting > Duration::ZERO {
+                                Some(format!(
+                                    "Remaining files to index (rate limit resets in {}s): {}",
+                                    rate_limiting.as_secs(),
+                                    remaining_files
+                                ))
+                            } else {
+                                Some(format!("Remaining files to index: {}", remaining_files))
+                            }
                         }
                     }
                     SemanticIndexStatus::NotIndexed => None,

crates/semantic_index/src/embedding.rs 🔗

@@ -7,7 +7,9 @@ use isahc::http::StatusCode;
 use isahc::prelude::Configurable;
 use isahc::{AsyncBody, Response};
 use lazy_static::lazy_static;
+use parking_lot::Mutex;
 use parse_duration::parse;
+use postage::watch;
 use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
 use rusqlite::ToSql;
 use serde::{Deserialize, Serialize};
@@ -82,6 +84,8 @@ impl ToSql for Embedding {
 pub struct OpenAIEmbeddings {
     pub client: Arc<dyn HttpClient>,
     pub executor: Arc<Background>,
+    rate_limit_count_rx: watch::Receiver<(Duration, usize)>,
+    rate_limit_count_tx: Arc<Mutex<watch::Sender<(Duration, usize)>>>,
 }
 
 #[derive(Serialize)]
@@ -114,12 +118,16 @@ pub trait EmbeddingProvider: Sync + Send {
     async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
     fn max_tokens_per_batch(&self) -> usize;
     fn truncate(&self, span: &str) -> (String, usize);
+    fn rate_limit_expiration(&self) -> Duration;
 }
 
 pub struct DummyEmbeddings {}
 
 #[async_trait]
 impl EmbeddingProvider for DummyEmbeddings {
+    fn rate_limit_expiration(&self) -> Duration {
+        Duration::ZERO
+    }
     async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         // 1024 is the OpenAI Embeddings size for ada models.
         // the model we will likely be starting with.
@@ -149,6 +157,53 @@ impl EmbeddingProvider for DummyEmbeddings {
 const OPENAI_INPUT_LIMIT: usize = 8190;
 
 impl OpenAIEmbeddings {
+    pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+        let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with((Duration::ZERO, 0));
+        let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
+
+        OpenAIEmbeddings {
+            client,
+            executor,
+            rate_limit_count_rx,
+            rate_limit_count_tx,
+        }
+    }
+
+    fn resolve_rate_limit(&self) {
+        let (current_delay, delay_count) = *self.rate_limit_count_tx.lock().borrow();
+        let updated_count = delay_count - 1;
+        let updated_duration = if updated_count == 0 {
+            Duration::ZERO
+        } else {
+            current_delay
+        };
+
+        log::trace!(
+            "resolving rate limit: Count: {:?} Duration: {:?}",
+            updated_count,
+            updated_duration
+        );
+
+        *self.rate_limit_count_tx.lock().borrow_mut() = (updated_duration, updated_count);
+    }
+
+    fn update_rate_limit(&self, delay_duration: Duration, count_increase: usize) {
+        let (current_delay, delay_count) = *self.rate_limit_count_tx.lock().borrow();
+        let updated_count = delay_count + count_increase;
+        let updated_duration = if current_delay < delay_duration {
+            delay_duration
+        } else {
+            current_delay
+        };
+
+        log::trace!(
+            "updating rate limit: Count: {:?} Duration: {:?}",
+            updated_count,
+            updated_duration
+        );
+
+        *self.rate_limit_count_tx.lock().borrow_mut() = (updated_duration, updated_count);
+    }
     async fn send_request(
         &self,
         api_key: &str,
@@ -179,6 +234,10 @@ impl EmbeddingProvider for OpenAIEmbeddings {
         50000
     }
 
+    fn rate_limit_expiration(&self) -> Duration {
+        let (duration, _) = *self.rate_limit_count_rx.borrow();
+        duration
+    }
     fn truncate(&self, span: &str) -> (String, usize) {
         let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
         let output = if tokens.len() > OPENAI_INPUT_LIMIT {
@@ -203,6 +262,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
             .ok_or_else(|| anyhow!("no api key"))?;
 
         let mut request_number = 0;
+        let mut rate_limiting = false;
         let mut request_timeout: u64 = 15;
         let mut response: Response<AsyncBody>;
         while request_number < MAX_RETRIES {
@@ -229,6 +289,12 @@ impl EmbeddingProvider for OpenAIEmbeddings {
                         response.usage.total_tokens
                     );
 
+                    // If we complete a request successfully that was previously rate_limited
+                    // resolve the rate limit
+                    if rate_limiting {
+                        self.resolve_rate_limit()
+                    }
+
                     return Ok(response
                         .data
                         .into_iter()
@@ -254,6 +320,15 @@ impl EmbeddingProvider for OpenAIEmbeddings {
                         }
                     };
 
+                    // If we've previously rate limited, increment the duration but not the count
+                    if rate_limiting {
+                        self.update_rate_limit(delay_duration, 0);
+                    } else {
+                        self.update_rate_limit(delay_duration, 1);
+                    }
+
+                    rate_limiting = true;
+
                     log::trace!(
                         "openai rate limiting: waiting {:?} until lifted",
                         &delay_duration

crates/semantic_index/src/semantic_index.rs 🔗

@@ -91,10 +91,7 @@ pub fn init(
         let semantic_index = SemanticIndex::new(
             fs,
             db_file_path,
-            Arc::new(OpenAIEmbeddings {
-                client: http_client,
-                executor: cx.background(),
-            }),
+            Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
             language_registry,
             cx.clone(),
         )
@@ -113,7 +110,10 @@ pub fn init(
 pub enum SemanticIndexStatus {
     NotIndexed,
     Indexed,
-    Indexing { remaining_files: usize },
+    Indexing {
+        remaining_files: usize,
+        rate_limiting: Duration,
+    },
 }
 
 pub struct SemanticIndex {
@@ -132,6 +132,8 @@ struct ProjectState {
     pending_file_count_rx: watch::Receiver<usize>,
     pending_file_count_tx: Arc<Mutex<watch::Sender<usize>>>,
     pending_index: usize,
+    rate_limiting_count_rx: watch::Receiver<usize>,
+    rate_limiting_count_tx: Arc<Mutex<watch::Sender<usize>>>,
     _subscription: gpui::Subscription,
     _observe_pending_file_count: Task<()>,
 }
@@ -223,11 +225,15 @@ impl ProjectState {
     fn new(subscription: gpui::Subscription, cx: &mut ModelContext<SemanticIndex>) -> Self {
         let (pending_file_count_tx, pending_file_count_rx) = watch::channel_with(0);
         let pending_file_count_tx = Arc::new(Mutex::new(pending_file_count_tx));
+        let (rate_limiting_count_tx, rate_limiting_count_rx) = watch::channel_with(0);
+        let rate_limiting_count_tx = Arc::new(Mutex::new(rate_limiting_count_tx));
         Self {
             worktrees: Default::default(),
             pending_file_count_rx: pending_file_count_rx.clone(),
             pending_file_count_tx,
             pending_index: 0,
+            rate_limiting_count_rx: rate_limiting_count_rx.clone(),
+            rate_limiting_count_tx,
             _subscription: subscription,
             _observe_pending_file_count: cx.spawn_weak({
                 let mut pending_file_count_rx = pending_file_count_rx.clone();
@@ -293,6 +299,7 @@ impl SemanticIndex {
             } else {
                 SemanticIndexStatus::Indexing {
                     remaining_files: project_state.pending_file_count_rx.borrow().clone(),
+                    rate_limiting: self.embedding_provider.rate_limit_expiration(),
                 }
             }
         } else {

crates/semantic_index/src/semantic_index_tests.rs 🔗

@@ -21,7 +21,7 @@ use std::{
         atomic::{self, AtomicUsize},
         Arc,
     },
-    time::SystemTime,
+    time::{Duration, SystemTime},
 };
 use unindent::Unindent;
 use util::RandomCharIter;
@@ -1275,6 +1275,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
         200
     }
 
+    fn rate_limit_expiration(&self) -> Duration {
+        Duration::ZERO
+    }
+
     async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
         self.embedding_count
             .fetch_add(spans.len(), atomic::Ordering::SeqCst);