Detailed changes
@@ -422,6 +422,15 @@ CREATE TABLE dev_server_projects (
paths TEXT NOT NULL
);
+CREATE TABLE IF NOT EXISTS billing_preferences (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
+ user_id INTEGER NOT NULL REFERENCES users(id),
+ max_monthly_llm_usage_spending_in_cents INTEGER NOT NULL
+);
+
+CREATE UNIQUE INDEX "uix_billing_preferences_on_user_id" ON billing_preferences (user_id);
+
CREATE TABLE IF NOT EXISTS billing_customers (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
@@ -0,0 +1,8 @@
+create table if not exists billing_preferences (
+ id serial primary key,
+ created_at timestamp without time zone not null default now(),
+ user_id integer not null references users(id) on delete cascade,
+ max_monthly_llm_usage_spending_in_cents integer not null
+);
+
+create unique index "uix_billing_preferences_on_user_id" on billing_preferences (user_id);
@@ -26,15 +26,19 @@ use crate::db::billing_subscription::{self, StripeSubscriptionStatus};
use crate::db::{
billing_customer, BillingSubscriptionId, CreateBillingCustomerParams,
CreateBillingSubscriptionParams, CreateProcessedStripeEventParams, UpdateBillingCustomerParams,
- UpdateBillingSubscriptionParams,
+ UpdateBillingPreferencesParams, UpdateBillingSubscriptionParams,
};
use crate::llm::db::LlmDatabase;
-use crate::llm::FREE_TIER_MONTHLY_SPENDING_LIMIT;
+use crate::llm::{DEFAULT_MAX_MONTHLY_SPEND, FREE_TIER_MONTHLY_SPENDING_LIMIT};
use crate::rpc::ResultExt as _;
use crate::{AppState, Error, Result};
pub fn router() -> Router {
Router::new()
+ .route(
+ "/billing/preferences",
+ get(get_billing_preferences).put(update_billing_preferences),
+ )
.route(
"/billing/subscriptions",
get(list_billing_subscriptions).post(create_billing_subscription),
@@ -45,6 +49,82 @@ pub fn router() -> Router {
)
}
+#[derive(Debug, Deserialize)]
+struct GetBillingPreferencesParams {
+ github_user_id: i32,
+}
+
+#[derive(Debug, Serialize)]
+struct BillingPreferencesResponse {
+ max_monthly_llm_usage_spending_in_cents: i32,
+}
+
+async fn get_billing_preferences(
+ Extension(app): Extension<Arc<AppState>>,
+ Query(params): Query<GetBillingPreferencesParams>,
+) -> Result<Json<BillingPreferencesResponse>> {
+ let user = app
+ .db
+ .get_user_by_github_user_id(params.github_user_id)
+ .await?
+ .ok_or_else(|| anyhow!("user not found"))?;
+
+ let preferences = app.db.get_billing_preferences(user.id).await?;
+
+ Ok(Json(BillingPreferencesResponse {
+ max_monthly_llm_usage_spending_in_cents: preferences
+ .map_or(DEFAULT_MAX_MONTHLY_SPEND.0 as i32, |preferences| {
+ preferences.max_monthly_llm_usage_spending_in_cents
+ }),
+ }))
+}
+
+#[derive(Debug, Deserialize)]
+struct UpdateBillingPreferencesBody {
+ github_user_id: i32,
+ max_monthly_llm_usage_spending_in_cents: i32,
+}
+
+async fn update_billing_preferences(
+ Extension(app): Extension<Arc<AppState>>,
+ extract::Json(body): extract::Json<UpdateBillingPreferencesBody>,
+) -> Result<Json<BillingPreferencesResponse>> {
+ let user = app
+ .db
+ .get_user_by_github_user_id(body.github_user_id)
+ .await?
+ .ok_or_else(|| anyhow!("user not found"))?;
+
+ let billing_preferences =
+ if let Some(_billing_preferences) = app.db.get_billing_preferences(user.id).await? {
+ app.db
+ .update_billing_preferences(
+ user.id,
+ &UpdateBillingPreferencesParams {
+ max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
+ body.max_monthly_llm_usage_spending_in_cents,
+ ),
+ },
+ )
+ .await?
+ } else {
+ app.db
+ .create_billing_preferences(
+ user.id,
+ &crate::db::CreateBillingPreferencesParams {
+ max_monthly_llm_usage_spending_in_cents: body
+ .max_monthly_llm_usage_spending_in_cents,
+ },
+ )
+ .await?
+ };
+
+ Ok(Json(BillingPreferencesResponse {
+ max_monthly_llm_usage_spending_in_cents: billing_preferences
+ .max_monthly_llm_usage_spending_in_cents,
+ }))
+}
+
#[derive(Debug, Deserialize)]
struct ListBillingSubscriptionsParams {
github_user_id: i32,
@@ -42,6 +42,9 @@ pub use tests::TestDb;
pub use ids::*;
pub use queries::billing_customers::{CreateBillingCustomerParams, UpdateBillingCustomerParams};
+pub use queries::billing_preferences::{
+ CreateBillingPreferencesParams, UpdateBillingPreferencesParams,
+};
pub use queries::billing_subscriptions::{
CreateBillingSubscriptionParams, UpdateBillingSubscriptionParams,
};
@@ -72,6 +72,7 @@ macro_rules! id_type {
id_type!(AccessTokenId);
id_type!(BillingCustomerId);
id_type!(BillingSubscriptionId);
+id_type!(BillingPreferencesId);
id_type!(BufferId);
id_type!(ChannelBufferCollaboratorId);
id_type!(ChannelChatParticipantId);
@@ -2,6 +2,7 @@ use super::*;
pub mod access_tokens;
pub mod billing_customers;
+pub mod billing_preferences;
pub mod billing_subscriptions;
pub mod buffers;
pub mod channels;
@@ -0,0 +1,75 @@
+use super::*;
+
+#[derive(Debug)]
+pub struct CreateBillingPreferencesParams {
+ pub max_monthly_llm_usage_spending_in_cents: i32,
+}
+
+#[derive(Debug, Default)]
+pub struct UpdateBillingPreferencesParams {
+ pub max_monthly_llm_usage_spending_in_cents: ActiveValue<i32>,
+}
+
+impl Database {
+ /// Returns the billing preferences for the given user, if they exist.
+ pub async fn get_billing_preferences(
+ &self,
+ user_id: UserId,
+ ) -> Result<Option<billing_preference::Model>> {
+ self.transaction(|tx| async move {
+ Ok(billing_preference::Entity::find()
+ .filter(billing_preference::Column::UserId.eq(user_id))
+ .one(&*tx)
+ .await?)
+ })
+ .await
+ }
+
+ /// Creates new billing preferences for the given user.
+ pub async fn create_billing_preferences(
+ &self,
+ user_id: UserId,
+ params: &CreateBillingPreferencesParams,
+ ) -> Result<billing_preference::Model> {
+ self.transaction(|tx| async move {
+ let preferences = billing_preference::Entity::insert(billing_preference::ActiveModel {
+ user_id: ActiveValue::set(user_id),
+ max_monthly_llm_usage_spending_in_cents: ActiveValue::set(
+ params.max_monthly_llm_usage_spending_in_cents,
+ ),
+ ..Default::default()
+ })
+ .exec_with_returning(&*tx)
+ .await?;
+
+ Ok(preferences)
+ })
+ .await
+ }
+
+ /// Updates the billing preferences for the given user.
+ pub async fn update_billing_preferences(
+ &self,
+ user_id: UserId,
+ params: &UpdateBillingPreferencesParams,
+ ) -> Result<billing_preference::Model> {
+ self.transaction(|tx| async move {
+ let preferences = billing_preference::Entity::update_many()
+ .set(billing_preference::ActiveModel {
+ max_monthly_llm_usage_spending_in_cents: params
+ .max_monthly_llm_usage_spending_in_cents
+ .clone(),
+ ..Default::default()
+ })
+ .filter(billing_preference::Column::UserId.eq(user_id))
+ .exec_with_returning(&*tx)
+ .await?;
+
+ Ok(preferences
+ .into_iter()
+ .next()
+ .ok_or_else(|| anyhow!("billing preferences not found"))?)
+ })
+ .await
+ }
+}
@@ -1,5 +1,6 @@
pub mod access_token;
pub mod billing_customer;
+pub mod billing_preference;
pub mod billing_subscription;
pub mod buffer;
pub mod buffer_operation;
@@ -0,0 +1,30 @@
+use crate::db::{BillingPreferencesId, UserId};
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "billing_preferences")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub id: BillingPreferencesId,
+ pub created_at: DateTime,
+ pub user_id: UserId,
+ pub max_monthly_llm_usage_spending_in_cents: i32,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+ #[sea_orm(
+ belongs_to = "super::user::Entity",
+ from = "Column::UserId",
+ to = "super::user::Column::Id"
+ )]
+ User,
+}
+
+impl Related<super::user::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::User.def()
+ }
+}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -442,6 +442,12 @@ fn normalize_model_name(known_models: Vec<String>, name: String) -> String {
/// before they have to pay.
pub const FREE_TIER_MONTHLY_SPENDING_LIMIT: Cents = Cents::from_dollars(5);
+/// The default value to use for maximum spend per month if the user did not
+/// explicitly set a maximum spend.
+///
+/// Used to prevent surprise bills.
+pub const DEFAULT_MAX_MONTHLY_SPEND: Cents = Cents::from_dollars(10);
+
/// The maximum lifetime spending an individual user can reach before being cut off.
const LIFETIME_SPENDING_LIMIT: Cents = Cents::from_dollars(1_000);