@@ -15,10 +15,14 @@ use axum::{
};
use chrono::{DateTime, Duration, Utc};
use db::{ActiveUserCount, LlmDatabase};
-use futures::StreamExt as _;
+use futures::{Stream, StreamExt as _};
use http_client::IsahcHttpClient;
use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
-use std::sync::Arc;
+use std::{
+ pin::Pin,
+ sync::Arc,
+ task::{Context, Poll},
+};
use tokio::sync::RwLock;
use util::ResultExt;
@@ -155,7 +159,7 @@ async fn perform_completion(
check_usage_limit(&state, params.provider, &model, &claims).await?;
- match params.provider {
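+ // Each arm boxes its stream so the match produces one concrete type: a
+ // stream of (serialized chunk, input tokens, output tokens) tuples.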
+ let stream = match params.provider {
LanguageModelProvider::Anthropic => {
let api_key = state
.config
@@ -185,37 +189,27 @@ async fn perform_completion(
)
.await?;
- let mut recorder = UsageRecorder {
- db: state.db.clone(),
- executor: state.executor.clone(),
- user_id,
- provider: params.provider,
- model,
- token_count: 0,
- };
-
- let stream = chunks.map(move |event| {
- let mut buffer = Vec::new();
- event.map(|chunk| {
- match &chunk {
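+ // Usage arrives only on MessageStart and MessageDelta events; read both
+ // token counts from each, defaulting missing fields to zero.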
+ chunks
+ .map(move |event| {
+ let chunk = event?;
+ let (input_tokens, output_tokens) = match &chunk {
anthropic::Event::MessageStart {
message: anthropic::Response { usage, .. },
}
- | anthropic::Event::MessageDelta { usage, .. } => {
- recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
- recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
- }
- _ => {}
- }
-
- buffer.clear();
- serde_json::to_writer(&mut buffer, &chunk).unwrap();
- buffer.push(b'\n');
- buffer
+ | anthropic::Event::MessageDelta { usage, .. } => (
+ usage.input_tokens.unwrap_or(0) as usize,
+ usage.output_tokens.unwrap_or(0) as usize,
+ ),
+ _ => (0, 0),
+ };
+
+ anyhow::Ok((
+ serde_json::to_vec(&chunk)?,
+ input_tokens,
+ output_tokens,
+ ))
})
- });
-
- Ok(Response::new(Body::wrap_stream(stream)))
+ .boxed()
}
LanguageModelProvider::OpenAi => {
let api_key = state
@@ -232,17 +226,21 @@ async fn perform_completion(
)
.await?;
- let stream = chunks.map(|event| {
- let mut buffer = Vec::new();
- event.map(|chunk| {
- buffer.clear();
- serde_json::to_writer(&mut buffer, &chunk).unwrap();
- buffer.push(b'\n');
- buffer
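+ // A chunk's usage field, when present, reports the prompt and completion
+ // token counts; chunks without usage count as zero.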
+ chunks
+ .map(|event| {
+ event.map(|chunk| {
+ let input_tokens =
+ chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
+ let output_tokens =
+ chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
+ (
+ serde_json::to_vec(&chunk).unwrap(),
+ input_tokens,
+ output_tokens,
+ )
+ })
})
- });
-
- Ok(Response::new(Body::wrap_stream(stream)))
+ .boxed()
}
LanguageModelProvider::Google => {
let api_key = state
@@ -258,17 +256,20 @@ async fn perform_completion(
)
.await?;
- let stream = chunks.map(|event| {
- let mut buffer = Vec::new();
- event.map(|chunk| {
- buffer.clear();
- serde_json::to_writer(&mut buffer, &chunk).unwrap();
- buffer.push(b'\n');
- buffer
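+ // Emit zero counts for now so the tuple shape matches the other arms.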
+ chunks
+ .map(|event| {
+ event.map(|chunk| {
+ // TODO: Implement token counting for Google AI.
+ let input_tokens = 0;
+ let output_tokens = 0;
+ (
+ serde_json::to_vec(&chunk).unwrap(),
+ input_tokens,
+ output_tokens,
+ )
+ })
})
- });
-
- Ok(Response::new(Body::wrap_stream(stream)))
+ .boxed()
}
LanguageModelProvider::Zed => {
let api_key = state
@@ -290,19 +291,34 @@ async fn perform_completion(
)
.await?;
- let stream = chunks.map(|event| {
- let mut buffer = Vec::new();
- event.map(|chunk| {
- buffer.clear();
- serde_json::to_writer(&mut buffer, &chunk).unwrap();
- buffer.push(b'\n');
- buffer
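+ // Zed-hosted chunks carry the same usage shape as the OpenAI arm, so
+ // token extraction mirrors that branch.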
+ chunks
+ .map(|event| {
+ event.map(|chunk| {
+ let input_tokens =
+ chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
+ let output_tokens =
+ chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
+ (
+ serde_json::to_vec(&chunk).unwrap(),
+ input_tokens,
+ output_tokens,
+ )
+ })
})
- });
-
- Ok(Response::new(Body::wrap_stream(stream)))
+ .boxed()
}
- }
+ };
+
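+ // Wrap the unified stream so token counts accumulate as chunks are
+ // forwarded and get recorded when the stream is dropped.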
+ Ok(Response::new(Body::wrap_stream(TokenCountingStream {
+ db: state.db.clone(),
+ executor: state.executor.clone(),
+ user_id,
+ provider: params.provider,
+ model,
+ input_tokens: 0,
+ output_tokens: 0,
+ inner_stream: stream,
+ })))
}
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
@@ -377,22 +393,46 @@ async fn check_usage_limit(
Ok(())
}
-struct UsageRecorder {
+
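+/// Forwards serialized chunks from the underlying provider stream as
+/// newline-delimited JSON while tallying the token counts each chunk reports.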
+struct TokenCountingStream<S> {
db: Arc<LlmDatabase>,
executor: Executor,
user_id: i32,
provider: LanguageModelProvider,
model: String,
- token_count: usize,
+ input_tokens: usize,
+ output_tokens: usize,
+ inner_stream: S,
+}
+
+impl<S> Stream for TokenCountingStream<S>
+where
+ S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
+{
+ type Item = Result<Vec<u8>, anyhow::Error>;
+
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ match Pin::new(&mut self.inner_stream).poll_next(cx) {
+ Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
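+ // Append the newline that frames each chunk as a line of JSON, and
+ // fold this chunk's token counts into the running totals.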
+ bytes.push(b'\n');
+ self.input_tokens += input_tokens;
+ self.output_tokens += output_tokens;
+ Poll::Ready(Some(Ok(bytes)))
+ }
+ Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
+ Poll::Ready(None) => Poll::Ready(None),
+ Poll::Pending => Poll::Pending,
+ }
+ }
}
-impl Drop for UsageRecorder {
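+// Usage is recorded in Drop so the totals are persisted even if the client
+// disconnects before the stream is exhausted.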
+impl<S> Drop for TokenCountingStream<S> {
fn drop(&mut self) {
let db = self.db.clone();
let user_id = self.user_id;
let provider = self.provider;
let model = std::mem::take(&mut self.model);
- let token_count = self.token_count;
+ let token_count = self.input_tokens + self.output_tokens;
self.executor.spawn_detached(async move {
db.record_usage(user_id, provider, &model, token_count, Utc::now())
.await