@@ -6,10 +6,10 @@ use sea_orm::prelude::DateTimeUtc;
use std::sync::Arc;
use util::ResultExt;
-pub trait RateLimit: 'static {
- fn capacity() -> usize;
- fn refill_duration() -> Duration;
- fn db_name() -> &'static str;
+pub trait RateLimit: Send + Sync { // methods take `&self` so a limit can be passed around as `&dyn RateLimit`
+ fn capacity(&self) -> usize; // token capacity of a bucket; a newly created bucket starts with this many tokens
+ fn refill_duration(&self) -> Duration; // divided by `capacity()` to obtain the refill time per token
+ fn db_name(&self) -> &'static str; // name used in the in-memory bucket key and for the database lookup
}
/// Used to enforce per-user rate limits
@@ -42,18 +42,23 @@ impl RateLimiter {
/// Returns an error if the user has exceeded the specified `RateLimit`.
- /// Attempts to read the from the database if no cached RateBucket currently exists.
+ /// Attempts to read the bucket from the database if no cached RateBucket currently exists.
- pub async fn check<T: RateLimit>(&self, user_id: UserId) -> Result<()> {
- self.check_internal::<T>(user_id, Utc::now()).await
+ pub async fn check(&self, limit: &dyn RateLimit, user_id: UserId) -> Result<()> { // `limit` supplies the bucket parameters and the bucket/DB name for `user_id`
+ self.check_internal(limit, user_id, Utc::now()).await // evaluated against the current UTC time
}
- async fn check_internal<T: RateLimit>(&self, user_id: UserId, now: DateTimeUtc) -> Result<()> {
- let bucket_key = (user_id, T::db_name().to_string());
+ async fn check_internal(
+ &self,
+ limit: &dyn RateLimit,
+ user_id: UserId,
+ now: DateTimeUtc,
+ ) -> Result<()> {
+ let bucket_key = (user_id, limit.db_name().to_string());
// Attempt to fetch the bucket from the database if it hasn't been cached.
// For now, we keep buckets in memory for the lifetime of the process rather than expiring them,
// but this enforces limits across restarts so long as the database is reachable.
if !self.buckets.contains_key(&bucket_key) {
- if let Some(bucket) = self.load_bucket::<T>(user_id).await.log_err().flatten() {
+ if let Some(bucket) = self.load_bucket(limit, user_id).await.log_err().flatten() {
self.buckets.insert(bucket_key.clone(), bucket);
self.dirty_buckets.insert(bucket_key.clone());
}
@@ -62,7 +67,7 @@ impl RateLimiter {
let mut bucket = self
.buckets
.entry(bucket_key.clone())
- .or_insert_with(|| RateBucket::new::<T>(now));
+ .or_insert_with(|| RateBucket::new(limit, now));
if bucket.value_mut().allow(now) {
self.dirty_buckets.insert(bucket_key);
@@ -72,16 +77,18 @@ impl RateLimiter {
}
}
- async fn load_bucket<T: RateLimit>(
+ async fn load_bucket(
&self,
+ limit: &dyn RateLimit,
user_id: UserId,
) -> Result<Option<RateBucket>, Error> {
Ok(self
.db
- .get_rate_bucket(user_id, T::db_name())
+ .get_rate_bucket(user_id, limit.db_name())
.await?
.map(|saved_bucket| {
- RateBucket::from_db::<T>(
+ RateBucket::from_db(
+ limit,
saved_bucket.token_count as usize,
DateTime::from_naive_utc_and_offset(saved_bucket.last_refill, Utc),
)
@@ -124,20 +131,20 @@ struct RateBucket {
}
impl RateBucket {
- fn new<T: RateLimit>(now: DateTimeUtc) -> Self {
+ fn new(limit: &dyn RateLimit, now: DateTimeUtc) -> Self { // fresh bucket for `limit`: starts full, last refilled at `now`
Self {
- capacity: T::capacity(),
- token_count: T::capacity(),
- refill_time_per_token: T::refill_duration() / T::capacity() as i32,
+ capacity: limit.capacity(),
+ token_count: limit.capacity(), // a new bucket holds `capacity` tokens
+ refill_time_per_token: limit.refill_duration() / limit.capacity().clamp(1, i32::MAX as usize) as i32, // clamped divisor: a capacity of 0 would panic the division, one above i32::MAX would wrap in the cast
last_refill: now,
}
}
- fn from_db<T: RateLimit>(token_count: usize, last_refill: DateTimeUtc) -> Self {
+ fn from_db(limit: &dyn RateLimit, token_count: usize, last_refill: DateTimeUtc) -> Self { // rebuilds a bucket from persisted state; capacity and refill rate come from `limit`
Self {
- capacity: T::capacity(),
+ capacity: limit.capacity(),
token_count,
- refill_time_per_token: T::refill_duration() / T::capacity() as i32,
+ refill_time_per_token: limit.refill_duration() / limit.capacity().clamp(1, i32::MAX as usize) as i32, // clamped divisor: a capacity of 0 would panic the division, one above i32::MAX would wrap in the cast
last_refill,
}
}
@@ -205,50 +212,52 @@ mod tests {
let mut now = Utc::now();
let rate_limiter = RateLimiter::new(db.clone());
+ let rate_limit_a = Box::new(RateLimitA);
+ let rate_limit_b = Box::new(RateLimitB);
// User 1 can access resource A two times before being rate-limited.
rate_limiter
- .check_internal::<RateLimitA>(user_1, now)
+ .check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap();
rate_limiter
- .check_internal::<RateLimitA>(user_1, now)
+ .check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap();
rate_limiter
- .check_internal::<RateLimitA>(user_1, now)
+ .check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap_err();
// User 2 can access resource A and user 1 can access resource B.
rate_limiter
- .check_internal::<RateLimitB>(user_2, now)
+ .check_internal(&*rate_limit_b, user_2, now)
.await
.unwrap();
rate_limiter
- .check_internal::<RateLimitB>(user_1, now)
+ .check_internal(&*rate_limit_b, user_1, now)
.await
.unwrap();
// After 1.5s, user 1 can make another request before being rate-limited again.
now += Duration::milliseconds(1500);
rate_limiter
- .check_internal::<RateLimitA>(user_1, now)
+ .check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap();
rate_limiter
- .check_internal::<RateLimitA>(user_1, now)
+ .check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap_err();
// After 500ms, user 1 can make another request before being rate-limited again.
now += Duration::milliseconds(500);
rate_limiter
- .check_internal::<RateLimitA>(user_1, now)
+ .check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap();
rate_limiter
- .check_internal::<RateLimitA>(user_1, now)
+ .check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap_err();
@@ -258,18 +267,18 @@ mod tests {
// for resource A.
let rate_limiter = RateLimiter::new(db.clone());
rate_limiter
- .check_internal::<RateLimitA>(user_1, now)
+ .check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap_err();
// After 1s, user 1 can make another request before being rate-limited again.
now += Duration::seconds(1);
rate_limiter
- .check_internal::<RateLimitA>(user_1, now)
+ .check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap();
rate_limiter
- .check_internal::<RateLimitA>(user_1, now)
+ .check_internal(&*rate_limit_a, user_1, now)
.await
.unwrap_err();
}
@@ -277,15 +286,15 @@ mod tests {
struct RateLimitA;
impl RateLimit for RateLimitA {
- fn capacity() -> usize {
+ fn capacity(&self) -> usize { // test limit A: bucket of 2 tokens
2
}
- fn refill_duration() -> Duration {
+ fn refill_duration(&self) -> Duration { // 2s for a full refill, i.e. 1s per token
Duration::seconds(2)
}
- fn db_name() -> &'static str {
+ fn db_name(&self) -> &'static str { // distinct name, so A and B are tracked in separate buckets
"rate-limit-a"
}
}
@@ -293,15 +302,15 @@ mod tests {
struct RateLimitB;
impl RateLimit for RateLimitB {
- fn capacity() -> usize {
+ fn capacity(&self) -> usize { // test limit B: bucket of 10 tokens
10
}
- fn refill_duration() -> Duration {
+ fn refill_duration(&self) -> Duration { // 3s for a full refill, i.e. 300ms per token
Duration::seconds(3)
}
- fn db_name() -> &'static str {
+ fn db_name(&self) -> &'static str { // distinct name, so A and B are tracked in separate buckets
"rate-limit-b"
}
}
@@ -199,6 +199,23 @@ impl Session {
}
}
+ pub async fn current_plan(&self) -> anyhow::Result<proto::Plan> { // resolves the plan for this session: ZedPro or Free
+ if self.is_staff() { // staff are always ZedPro; no database lookup is made
+ return Ok(proto::Plan::ZedPro);
+ }
+
+ let Some(user_id) = self.user_id() else { // a session without a user id gets the Free plan
+ return Ok(proto::Plan::Free);
+ };
+
+ let db = self.db().await;
+ if db.has_active_billing_subscription(user_id).await? { // a failed lookup is returned to the caller as an error
+ Ok(proto::Plan::ZedPro)
+ } else {
+ Ok(proto::Plan::Free)
+ }
+ }
+
fn dev_server_id(&self) -> Option<DevServerId> {
match &self.principal {
Principal::User(_) | Principal::Impersonated { .. } => None,
@@ -3537,15 +3554,8 @@ fn should_auto_subscribe_to_channels(version: ZedVersion) -> bool {
version.0.minor() < 139
}
-async fn update_user_plan(user_id: UserId, session: &Session) -> Result<()> {
- let db = session.db().await;
- let active_subscriptions = db.get_active_billing_subscriptions(user_id).await?;
-
- let plan = if session.is_staff() || !active_subscriptions.is_empty() {
- proto::Plan::ZedPro
- } else {
- proto::Plan::Free
- };
+async fn update_user_plan(_user_id: UserId, session: &Session) -> Result<()> {
+ let plan = session.current_plan().await?;
session
.peer
@@ -4532,22 +4542,41 @@ async fn acknowledge_buffer_version(
Ok(())
}
-struct CompleteWithLanguageModelRateLimit;
+struct ZedProCompleteWithLanguageModelRateLimit; // completion limit selected when the session's plan is `proto::Plan::ZedPro`
-impl RateLimit for CompleteWithLanguageModelRateLimit {
- fn capacity() -> usize {
+impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
+ fn capacity(&self) -> usize { // env override, read on every call; unset or unparsable values fall back to 120
std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(120) // Picked arbitrarily
}
- fn refill_duration() -> chrono::Duration {
+ fn refill_duration(&self) -> chrono::Duration { // one hour, matching the `_PER_HOUR` env var name
chrono::Duration::hours(1)
}
- fn db_name() -> &'static str {
- "complete-with-language-model"
+ fn db_name(&self) -> &'static str { // NOTE(review): name changed; buckets saved under the old name are presumably no longer loaded — confirm this reset is intended
+ "zed-pro:complete-with-language-model"
+ }
+}
+
+struct FreeCompleteWithLanguageModelRateLimit; // completion limit selected when the session's plan is `proto::Plan::Free`
+
+impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
+ fn capacity(&self) -> usize { // env override, read on every call; unset or unparsable values fall back to 120 / 10 = 12
+ std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
+ .ok()
+ .and_then(|v| v.parse().ok())
+ .unwrap_or(120 / 10) // Picked arbitrarily
+ }
+
+ fn refill_duration(&self) -> chrono::Duration { // one hour, matching the `_PER_HOUR` env var name
+ chrono::Duration::hours(1)
+ }
+
+ fn db_name(&self) -> &'static str { // `free:` prefix keeps this bucket separate from the `zed-pro:` one
+ "free:complete-with-language-model"
}
}
@@ -4562,9 +4591,14 @@ async fn complete_with_language_model(
};
authorize_access_to_language_models(&session).await?;
+ let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
+ proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
+ proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
+ };
+
session
.rate_limiter
- .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+ .check(&*rate_limit, session.user_id())
.await?;
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
@@ -4602,9 +4636,14 @@ async fn stream_complete_with_language_model(
};
authorize_access_to_language_models(&session).await?;
+ let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
+ proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
+ proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
+ };
+
session
.rate_limiter
- .check::<CompleteWithLanguageModelRateLimit>(session.user_id())
+ .check(&*rate_limit, session.user_id())
.await?;
match proto::LanguageModelProvider::from_i32(request.provider) {
@@ -4684,9 +4723,14 @@ async fn count_language_model_tokens(
};
authorize_access_to_language_models(&session).await?;
+ let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
+ proto::Plan::ZedPro => Box::new(ZedProCountLanguageModelTokensRateLimit),
+ proto::Plan::Free => Box::new(FreeCountLanguageModelTokensRateLimit),
+ };
+
session
.rate_limiter
- .check::<CountLanguageModelTokensRateLimit>(session.user_id())
+ .check(&*rate_limit, session.user_id())
.await?;
let result = match proto::LanguageModelProvider::from_i32(request.provider) {
@@ -4713,41 +4757,79 @@ async fn count_language_model_tokens(
Ok(())
}
-struct CountLanguageModelTokensRateLimit;
+struct ZedProCountLanguageModelTokensRateLimit; // token-counting limit selected when the session's plan is `proto::Plan::ZedPro`
-impl RateLimit for CountLanguageModelTokensRateLimit {
- fn capacity() -> usize {
+impl RateLimit for ZedProCountLanguageModelTokensRateLimit {
+ fn capacity(&self) -> usize { // env override, read on every call; unset or unparsable values fall back to 600
std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(600) // Picked arbitrarily
}
- fn refill_duration() -> chrono::Duration {
+ fn refill_duration(&self) -> chrono::Duration { // one hour, matching the `_PER_HOUR` env var name
chrono::Duration::hours(1)
}
- fn db_name() -> &'static str {
- "count-language-model-tokens"
+ fn db_name(&self) -> &'static str { // NOTE(review): name changed; buckets saved under the old name are presumably no longer loaded — confirm this reset is intended
+ "zed-pro:count-language-model-tokens"
}
}
-struct ComputeEmbeddingsRateLimit;
+struct FreeCountLanguageModelTokensRateLimit; // token-counting limit selected when the session's plan is `proto::Plan::Free`
-impl RateLimit for ComputeEmbeddingsRateLimit {
- fn capacity() -> usize {
+impl RateLimit for FreeCountLanguageModelTokensRateLimit {
+ fn capacity(&self) -> usize { // env override, read on every call; unset or unparsable values fall back to 600 / 10 = 60
+ std::env::var("COUNT_LANGUAGE_MODEL_TOKENS_RATE_LIMIT_PER_HOUR_FREE")
+ .ok()
+ .and_then(|v| v.parse().ok())
+ .unwrap_or(600 / 10) // Picked arbitrarily
+ }
+
+ fn refill_duration(&self) -> chrono::Duration { // one hour, matching the `_PER_HOUR` env var name
+ chrono::Duration::hours(1)
+ }
+
+ fn db_name(&self) -> &'static str { // `free:` prefix keeps this bucket separate from the `zed-pro:` one
+ "free:count-language-model-tokens"
+ }
+}
+
+struct ZedProComputeEmbeddingsRateLimit; // embeddings limit selected when the session's plan is `proto::Plan::ZedPro`
+
+impl RateLimit for ZedProComputeEmbeddingsRateLimit {
+ fn capacity(&self) -> usize { // env override, read on every call; unset or unparsable values fall back to 5000
std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(5000) // Picked arbitrarily
}
- fn refill_duration() -> chrono::Duration {
+ fn refill_duration(&self) -> chrono::Duration { // one hour, matching the `_PER_HOUR` env var name
+ chrono::Duration::hours(1)
+ }
+
+ fn db_name(&self) -> &'static str { // NOTE(review): the pre-change name was "compute-embeddings"; buckets saved under it are presumably no longer loaded — confirm
+ "zed-pro:compute-embeddings"
+ }
+}
+
+struct FreeComputeEmbeddingsRateLimit; // embeddings limit selected when the session's plan is `proto::Plan::Free`
+
+impl RateLimit for FreeComputeEmbeddingsRateLimit {
+ fn capacity(&self) -> usize { // env override, read on every call; unset or unparsable values fall back to 5000 / 10 = 500
+ std::env::var("EMBED_TEXTS_RATE_LIMIT_PER_HOUR_FREE")
+ .ok()
+ .and_then(|v| v.parse().ok())
+ .unwrap_or(5000 / 10) // Picked arbitrarily
+ }
+
+ fn refill_duration(&self) -> chrono::Duration { // one hour, matching the `_PER_HOUR` env var name
chrono::Duration::hours(1)
}
- fn db_name() -> &'static str {
- "compute-embeddings"
+ fn db_name(&self) -> &'static str { // `free:` prefix keeps this bucket separate from the `zed-pro:` one
+ "free:compute-embeddings"
}
}
@@ -4760,9 +4842,14 @@ async fn compute_embeddings(
let api_key = api_key.context("no OpenAI API key configured on the server")?;
authorize_access_to_language_models(&session).await?;
+ let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
+ proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
+ proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
+ };
+
session
.rate_limiter
- .check::<ComputeEmbeddingsRateLimit>(session.user_id())
+ .check(&*rate_limit, session.user_id())
.await?;
let embeddings = match request.model.as_str() {
@@ -4834,10 +4921,10 @@ async fn authorize_access_to_language_models(session: &UserSession) -> Result<()
let db = session.db().await;
let flags = db.get_user_flags(session.user_id()).await?;
if flags.iter().any(|flag| flag == "language-models") {
- Ok(())
- } else {
- Err(anyhow!("permission denied"))?
+ return Ok(());
}
+
+ Err(anyhow!("permission denied"))?
}
/// Get a Supermaven API key for the user