Detailed changes
@@ -29,6 +29,7 @@ use crate::db::billing_subscription::{
use crate::llm::db::subscription_usage_meter::CompletionMode;
use crate::llm::{AGENT_EXTENDED_TRIAL_FEATURE_FLAG, DEFAULT_MAX_MONTHLY_SPEND};
use crate::rpc::{ResultExt as _, Server};
+use crate::stripe_client::{StripeCustomerId, StripeSubscriptionId};
use crate::{AppState, Error, Result};
use crate::{db::UserId, llm::db::LlmDatabase};
use crate::{
@@ -1545,14 +1546,10 @@ async fn sync_model_request_usage_with_stripe(
);
};
- let stripe_customer_id = billing_customer
- .stripe_customer_id
- .parse::<stripe::CustomerId>()
- .context("failed to parse Stripe customer ID from database")?;
- let stripe_subscription_id = billing_subscription
- .stripe_subscription_id
- .parse::<stripe::SubscriptionId>()
- .context("failed to parse Stripe subscription ID from database")?;
+ let stripe_customer_id =
+ StripeCustomerId(billing_customer.stripe_customer_id.clone().into());
+ let stripe_subscription_id =
+ StripeSubscriptionId(billing_subscription.stripe_subscription_id.clone().into());
let model = llm_db.model_by_id(usage_meter.model_id)?;
@@ -1578,7 +1575,7 @@ async fn sync_model_request_usage_with_stripe(
};
stripe_billing
- .subscribe_to_price(&stripe_subscription_id.into(), price)
+ .subscribe_to_price(&stripe_subscription_id, price)
.await?;
stripe_billing
.bill_model_request_usage(
@@ -3,7 +3,6 @@ use std::sync::Arc;
use anyhow::{Context as _, anyhow};
use chrono::Utc;
use collections::HashMap;
-use serde::{Deserialize, Serialize};
use stripe::SubscriptionStatus;
use tokio::sync::RwLock;
use uuid::Uuid;
@@ -12,8 +11,9 @@ use crate::Result;
use crate::db::billing_subscription::SubscriptionKind;
use crate::llm::AGENT_EXTENDED_TRIAL_FEATURE_FLAG;
use crate::stripe_client::{
- RealStripeClient, StripeClient, StripeCustomerId, StripeMeter, StripePrice, StripePriceId,
- StripeSubscription, StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams,
+ RealStripeClient, StripeClient, StripeCreateMeterEventParams, StripeCreateMeterEventPayload,
+ StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripeSubscription,
+ StripeSubscriptionId, UpdateSubscriptionItems, UpdateSubscriptionParams,
UpdateSubscriptionTrialSettings, UpdateSubscriptionTrialSettingsEndBehavior,
UpdateSubscriptionTrialSettingsEndBehaviorMissingPaymentMethod,
};
@@ -204,16 +204,15 @@ impl StripeBilling {
pub async fn bill_model_request_usage(
&self,
- customer_id: &stripe::CustomerId,
+ customer_id: &StripeCustomerId,
event_name: &str,
requests: i32,
) -> Result<()> {
let timestamp = Utc::now().timestamp();
let idempotency_key = Uuid::new_v4();
- StripeMeterEvent::create(
- &self.real_client,
- StripeCreateMeterEventParams {
+ self.client
+ .create_meter_event(StripeCreateMeterEventParams {
identifier: &format!("model_requests/{}", idempotency_key),
event_name,
payload: StripeCreateMeterEventPayload {
@@ -221,9 +220,8 @@ impl StripeBilling {
stripe_customer_id: customer_id,
},
timestamp: Some(timestamp),
- },
- )
- .await?;
+ })
+ .await?;
Ok(())
}
@@ -371,52 +369,6 @@ impl StripeBilling {
}
}
-#[derive(Deserialize)]
-struct StripeMeterEvent {
- identifier: String,
-}
-
-impl StripeMeterEvent {
- pub async fn create(
- client: &stripe::Client,
- params: StripeCreateMeterEventParams<'_>,
- ) -> Result<Self, stripe::StripeError> {
- let identifier = params.identifier;
- match client.post_form("/billing/meter_events", params).await {
- Ok(event) => Ok(event),
- Err(stripe::StripeError::Stripe(error)) => {
- if error.http_status == 400
- && error
- .message
- .as_ref()
- .map_or(false, |message| message.contains(identifier))
- {
- Ok(Self {
- identifier: identifier.to_string(),
- })
- } else {
- Err(stripe::StripeError::Stripe(error))
- }
- }
- Err(error) => Err(error),
- }
- }
-}
-
-#[derive(Serialize)]
-struct StripeCreateMeterEventParams<'a> {
- identifier: &'a str,
- event_name: &'a str,
- payload: StripeCreateMeterEventPayload<'a>,
- timestamp: Option<i64>,
-}
-
-#[derive(Serialize)]
-struct StripeCreateMeterEventPayload<'a> {
- value: u64,
- stripe_customer_id: &'a stripe::CustomerId,
-}
-
fn subscription_contains_price(
subscription: &StripeSubscription,
price_id: &StripePriceId,
@@ -10,9 +10,9 @@ use async_trait::async_trait;
#[cfg(test)]
pub use fake_stripe_client::*;
pub use real_stripe_client::*;
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
-#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display)]
+#[derive(Debug, PartialEq, Eq, Hash, Clone, derive_more::Display, Serialize)]
pub struct StripeCustomerId(pub Arc<str>);
#[derive(Debug, Clone)]
@@ -97,6 +97,20 @@ pub struct StripeMeter {
pub event_name: String,
}
+#[derive(Debug, Serialize)]
+pub struct StripeCreateMeterEventParams<'a> {
+ pub identifier: &'a str,
+ pub event_name: &'a str,
+ pub payload: StripeCreateMeterEventPayload<'a>,
+ pub timestamp: Option<i64>,
+}
+
+#[derive(Debug, Serialize)]
+pub struct StripeCreateMeterEventPayload<'a> {
+ pub value: u64,
+ pub stripe_customer_id: &'a StripeCustomerId,
+}
+
#[async_trait]
pub trait StripeClient: Send + Sync {
async fn list_customers_by_email(&self, email: &str) -> Result<Vec<StripeCustomer>>;
@@ -117,4 +131,6 @@ pub trait StripeClient: Send + Sync {
async fn list_prices(&self) -> Result<Vec<StripePrice>>;
async fn list_meters(&self) -> Result<Vec<StripeMeter>>;
+
+ async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()>;
}
@@ -7,11 +7,20 @@ use parking_lot::Mutex;
use uuid::Uuid;
use crate::stripe_client::{
- CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter,
- StripeMeterId, StripePrice, StripePriceId, StripeSubscription, StripeSubscriptionId,
- UpdateSubscriptionParams,
+ CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer,
+ StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripeSubscription,
+ StripeSubscriptionId, UpdateSubscriptionParams,
};
+#[derive(Debug, Clone)]
+pub struct StripeCreateMeterEventCall {
+ pub identifier: Arc<str>,
+ pub event_name: Arc<str>,
+ pub value: u64,
+ pub stripe_customer_id: StripeCustomerId,
+ pub timestamp: Option<i64>,
+}
+
pub struct FakeStripeClient {
pub customers: Arc<Mutex<HashMap<StripeCustomerId, StripeCustomer>>>,
pub subscriptions: Arc<Mutex<HashMap<StripeSubscriptionId, StripeSubscription>>>,
@@ -19,6 +28,7 @@ pub struct FakeStripeClient {
Arc<Mutex<Vec<(StripeSubscriptionId, UpdateSubscriptionParams)>>>,
pub prices: Arc<Mutex<HashMap<StripePriceId, StripePrice>>>,
pub meters: Arc<Mutex<HashMap<StripeMeterId, StripeMeter>>>,
+ pub create_meter_event_calls: Arc<Mutex<Vec<StripeCreateMeterEventCall>>>,
}
impl FakeStripeClient {
@@ -29,6 +39,7 @@ impl FakeStripeClient {
update_subscription_calls: Arc::new(Mutex::new(Vec::new())),
prices: Arc::new(Mutex::new(HashMap::default())),
meters: Arc::new(Mutex::new(HashMap::default())),
+ create_meter_event_calls: Arc::new(Mutex::new(Vec::new())),
}
}
}
@@ -94,4 +105,18 @@ impl StripeClient for FakeStripeClient {
Ok(meters)
}
+
+ async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
+ self.create_meter_event_calls
+ .lock()
+ .push(StripeCreateMeterEventCall {
+ identifier: params.identifier.into(),
+ event_name: params.event_name.into(),
+ value: params.payload.value,
+ stripe_customer_id: params.payload.stripe_customer_id.clone(),
+ timestamp: params.timestamp,
+ });
+
+ Ok(())
+ }
}
@@ -1,7 +1,7 @@
use std::str::FromStr as _;
use std::sync::Arc;
-use anyhow::{Context as _, Result};
+use anyhow::{Context as _, Result, anyhow};
use async_trait::async_trait;
use serde::Serialize;
use stripe::{
@@ -12,9 +12,10 @@ use stripe::{
};
use crate::stripe_client::{
- CreateCustomerParams, StripeClient, StripeCustomer, StripeCustomerId, StripeMeter, StripePrice,
- StripePriceId, StripePriceRecurring, StripeSubscription, StripeSubscriptionId,
- StripeSubscriptionItem, StripeSubscriptionItemId, UpdateSubscriptionParams,
+ CreateCustomerParams, StripeClient, StripeCreateMeterEventParams, StripeCustomer,
+ StripeCustomerId, StripeMeter, StripePrice, StripePriceId, StripePriceRecurring,
+ StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
+ UpdateSubscriptionParams,
};
pub struct RealStripeClient {
@@ -129,6 +130,26 @@ impl StripeClient for RealStripeClient {
Ok(response.data)
}
+
+ async fn create_meter_event(&self, params: StripeCreateMeterEventParams<'_>) -> Result<()> {
+ let identifier = params.identifier;
+ match self.client.post_form("/billing/meter_events", params).await {
+ Ok(event) => Ok(event),
+ Err(stripe::StripeError::Stripe(error)) => {
+ if error.http_status == 400
+ && error
+ .message
+ .as_ref()
+ .map_or(false, |message| message.contains(identifier))
+ {
+ Ok(())
+ } else {
+ Err(anyhow!(stripe::StripeError::Stripe(error)))
+ }
+ }
+ Err(error) => Err(anyhow!(error)),
+ }
+ }
}
impl From<CustomerId> for StripeCustomerId {
@@ -4,9 +4,9 @@ use pretty_assertions::assert_eq;
use crate::stripe_billing::StripeBilling;
use crate::stripe_client::{
- FakeStripeClient, StripeMeter, StripeMeterId, StripePrice, StripePriceId, StripePriceRecurring,
- StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem, StripeSubscriptionItemId,
- UpdateSubscriptionItems,
+ FakeStripeClient, StripeCustomerId, StripeMeter, StripeMeterId, StripePrice, StripePriceId,
+ StripePriceRecurring, StripeSubscription, StripeSubscriptionId, StripeSubscriptionItem,
+ StripeSubscriptionItemId, UpdateSubscriptionItems,
};
fn make_stripe_billing() -> (StripeBilling, Arc<FakeStripeClient>) {
@@ -210,3 +210,34 @@ async fn test_subscribe_to_price() {
assert_eq!(stripe_client.update_subscription_calls.lock().len(), 1);
}
}
+
+#[gpui::test]
+async fn test_bill_model_request_usage() {
+ let (stripe_billing, stripe_client) = make_stripe_billing();
+
+ let customer_id = StripeCustomerId("cus_test".into());
+
+ stripe_billing
+ .bill_model_request_usage(&customer_id, "some_model/requests", 73)
+ .await
+ .unwrap();
+
+ let create_meter_event_calls = stripe_client
+ .create_meter_event_calls
+ .lock()
+ .iter()
+ .cloned()
+ .collect::<Vec<_>>();
+ assert_eq!(create_meter_event_calls.len(), 1);
+ assert!(
+ create_meter_event_calls[0]
+ .identifier
+ .starts_with("model_requests/")
+ );
+ assert_eq!(create_meter_event_calls[0].stripe_customer_id, customer_id);
+ assert_eq!(
+ create_meter_event_calls[0].event_name.as_ref(),
+ "some_model/requests"
+ );
+ assert_eq!(create_meter_event_calls[0].value, 73);
+}