Detailed changes
@@ -34,6 +34,7 @@ use std::{
ops::{Not, Range},
path::PathBuf,
sync::Arc,
+ time::{Duration, Instant},
};
use util::ResultExt as _;
use workspace::{
@@ -130,6 +131,7 @@ pub struct ProjectSearchView {
struct SemanticState {
index_status: SemanticIndexStatus,
+ maintain_rate_limit: Option<Task<()>>,
_subscription: Subscription,
}
@@ -319,11 +321,28 @@ impl View for ProjectSearchView {
let status = semantic.index_status;
match status {
SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
- SemanticIndexStatus::Indexing { remaining_files } => {
+ SemanticIndexStatus::Indexing {
+ remaining_files,
+ rate_limit_expiry,
+ } => {
if remaining_files == 0 {
Some(format!("Indexing..."))
} else {
- Some(format!("Remaining files to index: {}", remaining_files))
+ if let Some(rate_limit_expiry) = rate_limit_expiry {
+ let remaining_seconds =
+ rate_limit_expiry.duration_since(Instant::now());
+ if remaining_seconds > Duration::from_secs(0) {
+ Some(format!(
+ "Remaining files to index (rate limit resets in {}s): {}",
+ remaining_seconds.as_secs(),
+ remaining_files
+ ))
+ } else {
+ Some(format!("Remaining files to index: {}", remaining_files))
+ }
+ } else {
+ Some(format!("Remaining files to index: {}", remaining_files))
+ }
}
}
SemanticIndexStatus::NotIndexed => None,
@@ -651,9 +670,10 @@ impl ProjectSearchView {
self.semantic_state = Some(SemanticState {
index_status: semantic_index.read(cx).status(&project),
+ maintain_rate_limit: None,
_subscription: cx.observe(&semantic_index, Self::semantic_index_changed),
});
- cx.notify();
+ self.semantic_index_changed(semantic_index, cx);
}
}
@@ -664,8 +684,25 @@ impl ProjectSearchView {
) {
let project = self.model.read(cx).project.clone();
if let Some(semantic_state) = self.semantic_state.as_mut() {
- semantic_state.index_status = semantic_index.read(cx).status(&project);
cx.notify();
+ semantic_state.index_status = semantic_index.read(cx).status(&project);
+ if let SemanticIndexStatus::Indexing {
+ rate_limit_expiry: Some(_),
+ ..
+ } = &semantic_state.index_status
+ {
+ if semantic_state.maintain_rate_limit.is_none() {
+ semantic_state.maintain_rate_limit =
+ Some(cx.spawn(|this, mut cx| async move {
+ loop {
+ cx.background().timer(Duration::from_secs(1)).await;
+ this.update(&mut cx, |_, cx| cx.notify()).log_err();
+ }
+ }));
+ return;
+ }
+ }
+ semantic_state.maintain_rate_limit = None;
}
}
@@ -7,13 +7,16 @@ use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
use lazy_static::lazy_static;
+use parking_lot::Mutex;
use parse_duration::parse;
+use postage::watch;
use rusqlite::types::{FromSql, FromSqlResult, ToSqlOutput, ValueRef};
use rusqlite::ToSql;
use serde::{Deserialize, Serialize};
use std::env;
+use std::ops::Add;
use std::sync::Arc;
-use std::time::Duration;
+use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
@@ -82,6 +85,8 @@ impl ToSql for Embedding {
pub struct OpenAIEmbeddings {
pub client: Arc<dyn HttpClient>,
pub executor: Arc<Background>,
+ rate_limit_count_rx: watch::Receiver<Option<Instant>>,
+ rate_limit_count_tx: Arc<Mutex<watch::Sender<Option<Instant>>>>,
}
#[derive(Serialize)]
@@ -114,12 +119,16 @@ pub trait EmbeddingProvider: Sync + Send {
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>>;
fn max_tokens_per_batch(&self) -> usize;
fn truncate(&self, span: &str) -> (String, usize);
+ fn rate_limit_expiration(&self) -> Option<Instant>;
}
pub struct DummyEmbeddings {}
#[async_trait]
impl EmbeddingProvider for DummyEmbeddings {
+ fn rate_limit_expiration(&self) -> Option<Instant> {
+ None
+ }
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
// 1024 is the OpenAI Embeddings size for ada models.
// the model we will likely be starting with.
@@ -149,6 +158,50 @@ impl EmbeddingProvider for DummyEmbeddings {
const OPENAI_INPUT_LIMIT: usize = 8190;
impl OpenAIEmbeddings {
+ pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+ let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
+ let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
+
+ OpenAIEmbeddings {
+ client,
+ executor,
+ rate_limit_count_rx,
+ rate_limit_count_tx,
+ }
+ }
+
+ fn resolve_rate_limit(&self) {
+ let reset_time = *self.rate_limit_count_tx.lock().borrow();
+
+ if let Some(reset_time) = reset_time {
+ if Instant::now() >= reset_time {
+ *self.rate_limit_count_tx.lock().borrow_mut() = None
+ }
+ }
+
+ log::trace!(
+ "resolving reset time: {:?}",
+ *self.rate_limit_count_tx.lock().borrow()
+ );
+ }
+
+ fn update_reset_time(&self, reset_time: Instant) {
+ let original_time = *self.rate_limit_count_tx.lock().borrow();
+
+ let updated_time = if let Some(original_time) = original_time {
+ if reset_time < original_time {
+ Some(reset_time)
+ } else {
+ Some(original_time)
+ }
+ } else {
+ Some(reset_time)
+ };
+
+ log::trace!("updating rate limit time: {:?}", updated_time);
+
+ *self.rate_limit_count_tx.lock().borrow_mut() = updated_time;
+ }
async fn send_request(
&self,
api_key: &str,
@@ -179,6 +232,9 @@ impl EmbeddingProvider for OpenAIEmbeddings {
50000
}
+ fn rate_limit_expiration(&self) -> Option<Instant> {
+ *self.rate_limit_count_rx.borrow()
+ }
fn truncate(&self, span: &str) -> (String, usize) {
let mut tokens = OPENAI_BPE_TOKENIZER.encode_with_special_tokens(span);
let output = if tokens.len() > OPENAI_INPUT_LIMIT {
@@ -203,6 +259,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
.ok_or_else(|| anyhow!("no api key"))?;
let mut request_number = 0;
+ let mut rate_limiting = false;
let mut request_timeout: u64 = 15;
let mut response: Response<AsyncBody>;
while request_number < MAX_RETRIES {
@@ -229,6 +286,12 @@ impl EmbeddingProvider for OpenAIEmbeddings {
response.usage.total_tokens
);
+ // If a request succeeds after we were previously rate-limited,
+ // clear the stored rate-limit expiry.
+ if rate_limiting {
+ self.resolve_rate_limit()
+ }
+
return Ok(response
.data
.into_iter()
@@ -236,6 +299,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
.collect());
}
StatusCode::TOO_MANY_REQUESTS => {
+ rate_limiting = true;
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
@@ -254,6 +318,10 @@ impl EmbeddingProvider for OpenAIEmbeddings {
}
};
+ // Record this rate limit's reset time; update_reset_time keeps the earliest known expiry
+ let reset_time = Instant::now().add(delay_duration);
+ self.update_reset_time(reset_time);
+
log::trace!(
"openai rate limiting: waiting {:?} until lifted",
&delay_duration
@@ -91,10 +91,7 @@ pub fn init(
let semantic_index = SemanticIndex::new(
fs,
db_file_path,
- Arc::new(OpenAIEmbeddings {
- client: http_client,
- executor: cx.background(),
- }),
+ Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
language_registry,
cx.clone(),
)
@@ -113,7 +110,10 @@ pub fn init(
pub enum SemanticIndexStatus {
NotIndexed,
Indexed,
- Indexing { remaining_files: usize },
+ Indexing {
+ remaining_files: usize,
+ rate_limit_expiry: Option<Instant>,
+ },
}
pub struct SemanticIndex {
@@ -293,6 +293,7 @@ impl SemanticIndex {
} else {
SemanticIndexStatus::Indexing {
remaining_files: project_state.pending_file_count_rx.borrow().clone(),
+ rate_limit_expiry: self.embedding_provider.rate_limit_expiration(),
}
}
} else {
@@ -21,7 +21,7 @@ use std::{
atomic::{self, AtomicUsize},
Arc,
},
- time::SystemTime,
+ time::{Instant, SystemTime},
};
use unindent::Unindent;
use util::RandomCharIter;
@@ -1275,6 +1275,10 @@ impl EmbeddingProvider for FakeEmbeddingProvider {
200
}
+ fn rate_limit_expiration(&self) -> Option<Instant> {
+ None
+ }
+
async fn embed_batch(&self, spans: Vec<String>) -> Result<Vec<Embedding>> {
self.embedding_count
.fetch_add(spans.len(), atomic::Ordering::SeqCst);