@@ -27,7 +27,7 @@ pub use token::*;
pub struct LlmState {
pub config: Config,
pub executor: Executor,
- pub db: Option<Arc<LlmDatabase>>,
+ pub db: Arc<LlmDatabase>,
pub http_client: IsahcHttpClient,
active_user_count: RwLock<Option<(DateTime<Utc>, ActiveUserCount)>>,
}
@@ -36,25 +36,20 @@ const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30);
impl LlmState {
pub async fn new(config: Config, executor: Executor) -> Result<Arc<Self>> {
- // TODO: This is temporary until we have the LLM database stood up.
- let db = if config.is_development() {
- let database_url = config
- .llm_database_url
- .as_ref()
- .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
- let max_connections = config
- .llm_database_max_connections
- .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
-
- let mut db_options = db::ConnectOptions::new(database_url);
- db_options.max_connections(max_connections);
- let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
- db.initialize().await?;
-
- Some(Arc::new(db))
- } else {
- None
- };
+ let database_url = config
+ .llm_database_url
+ .as_ref()
+ .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?;
+ let max_connections = config
+ .llm_database_max_connections
+ .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?;
+
+ let mut db_options = db::ConnectOptions::new(database_url);
+ db_options.max_connections(max_connections);
+ let mut db = LlmDatabase::new(db_options, executor.clone()).await?;
+ db.initialize().await?;
+
+ let db = Arc::new(db);
let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION"));
let http_client = IsahcHttpClient::builder()
@@ -62,11 +57,8 @@ impl LlmState {
.build()
.context("failed to construct http client")?;
- let initial_active_user_count = if let Some(db) = &db {
- Some((Utc::now(), db.get_active_user_count(Utc::now()).await?))
- } else {
- None
- };
+ let initial_active_user_count =
+ Some((Utc::now(), db.get_active_user_count(Utc::now()).await?));
let this = Self {
config,
@@ -88,14 +80,10 @@ impl LlmState {
}
}
- if let Some(db) = &self.db {
- let mut cache = self.active_user_count.write().await;
- let new_count = db.get_active_user_count(now).await?;
- *cache = Some((now, new_count));
- Ok(new_count)
- } else {
- Ok(ActiveUserCount::default())
- }
+ let mut cache = self.active_user_count.write().await;
+ let new_count = self.db.get_active_user_count(now).await?;
+ *cache = Some((now, new_count));
+ Ok(new_count)
}
}
@@ -165,9 +153,7 @@ async fn perform_completion(
let user_id = claims.user_id as i32;
- if state.db.is_some() {
- check_usage_limit(&state, params.provider, &model, &claims).await?;
- }
+ check_usage_limit(&state, params.provider, &model, &claims).await?;
match params.provider {
LanguageModelProvider::Anthropic => {
@@ -199,14 +185,14 @@ async fn perform_completion(
)
.await?;
- let mut recorder = state.db.clone().map(|db| UsageRecorder {
- db,
+ let mut recorder = UsageRecorder {
+ db: state.db.clone(),
executor: state.executor.clone(),
user_id,
provider: params.provider,
model,
token_count: 0,
- });
+ };
let stream = chunks.map(move |event| {
let mut buffer = Vec::new();
@@ -216,10 +202,8 @@ async fn perform_completion(
message: anthropic::Response { usage, .. },
}
| anthropic::Event::MessageDelta { usage, .. } => {
- if let Some(recorder) = &mut recorder {
- recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
- recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
- }
+ recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
+ recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
}
_ => {}
}
@@ -349,12 +333,9 @@ async fn check_usage_limit(
model_name: &str,
claims: &LlmTokenClaims,
) -> Result<()> {
- let db = state
+ let model = state.db.model(provider, model_name)?;
+ let usage = state
.db
- .as_ref()
- .ok_or_else(|| anyhow!("LLM database not configured"))?;
- let model = db.model(provider, model_name)?;
- let usage = db
.get_usage(claims.user_id as i32, provider, model_name, Utc::now())
.await?;
@@ -248,11 +248,6 @@ async fn setup_app_database(config: &Config) -> Result<()> {
}
async fn setup_llm_database(config: &Config) -> Result<()> {
- // TODO: This is temporary until we have the LLM database stood up.
- if !config.is_development() {
- return Ok(());
- }
-
let database_url = config
.llm_database_url
.as_ref()
@@ -298,7 +293,12 @@ async fn handle_liveness_probe(
state.db.get_all_users(0, 1).await?;
}
- if let Some(_llm_state) = llm_state {}
+ if let Some(llm_state) = llm_state {
+ llm_state
+ .db
+ .get_active_user_count(chrono::Utc::now())
+ .await?;
+ }
Ok("ok".to_string())
}