Fix usage recording in llm service (#16044)

Created by Max Brunsfeld and Marshall
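
As the diff below shows, usage recording previously relied on a UsageRecorder that was constructed only inside the Anthropic branch of perform_completion, so token counts from the other providers were never accumulated. With this change, every provider branch maps its upstream events into (serialized chunk, input tokens, output tokens) triples, and the resulting boxed stream is wrapped in a single TokenCountingStream. The wrapper appends the trailing newline to each chunk, tallies the counts as the body streams to the client, and records the total to the database when it is dropped.

For orientation, a condensed sketch of the wrapper (names and fields abbreviated; the real type in the diff also carries the database handle, executor, user id, provider, and model so usage can be persisted on drop):

use std::{
    pin::Pin,
    task::{Context, Poll},
};

use futures::Stream;

// Sketch: forward serialized chunks to the caller while tallying the token
// counts reported alongside them.
struct CountingStream<S> {
    input_tokens: usize,
    output_tokens: usize,
    inner: S,
}

impl<S> Stream for CountingStream<S>
where
    S: Stream<Item = anyhow::Result<(Vec<u8>, usize, usize)>> + Unpin,
{
    type Item = anyhow::Result<Vec<u8>>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        match Pin::new(&mut self.inner).poll_next(cx) {
            Poll::Ready(Some(Ok((bytes, input, output)))) => {
                // Accumulate usage as each chunk passes through to the client.
                self.input_tokens += input;
                self.output_tokens += output;
                Poll::Ready(Some(Ok(bytes)))
            }
            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}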

Release Notes:

- N/A

Co-authored-by: Marshall <marshall@zed.dev>

Change summary

crates/collab/src/llm.rs | 170 +++++++++++++++++++++++++----------------
1 file changed, 105 insertions(+), 65 deletions(-)
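
One apparent benefit of recording in the wrapper's Drop impl, rather than only after the stream completes, is that whatever counts were accumulated get persisted even if the client disconnects mid-stream; in the diff, the Drop impl spawns the db.record_usage call on the background executor. A minimal, self-contained sketch of that flush-on-drop pattern, using an in-memory counter as a stand-in for the database write:

use std::sync::{
    atomic::{AtomicUsize, Ordering},
    Arc,
};

// Sketch: a guard that flushes an accumulated count when dropped, whether
// the surrounding stream finished normally or was torn down early.
struct FlushOnDrop {
    tokens: usize,
    sink: Arc<AtomicUsize>, // stand-in for the database write
}

impl Drop for FlushOnDrop {
    fn drop(&mut self) {
        self.sink.fetch_add(self.tokens, Ordering::SeqCst);
    }
}

fn main() {
    let recorded = Arc::new(AtomicUsize::new(0));
    {
        let mut guard = FlushOnDrop { tokens: 0, sink: recorded.clone() };
        guard.tokens += 42; // counts observed while streaming
    } // guard dropped here, even if the response had been cancelled
    assert_eq!(recorded.load(Ordering::SeqCst), 42);
}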

Detailed changes

crates/collab/src/llm.rs

@@ -15,10 +15,14 @@ use axum::{
 };
 use chrono::{DateTime, Duration, Utc};
 use db::{ActiveUserCount, LlmDatabase};
-use futures::StreamExt as _;
+use futures::{Stream, StreamExt as _};
 use http_client::IsahcHttpClient;
 use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
-use std::sync::Arc;
+use std::{
+    pin::Pin,
+    sync::Arc,
+    task::{Context, Poll},
+};
 use tokio::sync::RwLock;
 use util::ResultExt;
 
@@ -155,7 +159,7 @@ async fn perform_completion(
 
     check_usage_limit(&state, params.provider, &model, &claims).await?;
 
-    match params.provider {
+    let stream = match params.provider {
         LanguageModelProvider::Anthropic => {
             let api_key = state
                 .config
@@ -185,37 +189,27 @@ async fn perform_completion(
             )
             .await?;
 
-            let mut recorder = UsageRecorder {
-                db: state.db.clone(),
-                executor: state.executor.clone(),
-                user_id,
-                provider: params.provider,
-                model,
-                token_count: 0,
-            };
-
-            let stream = chunks.map(move |event| {
-                let mut buffer = Vec::new();
-                event.map(|chunk| {
-                    match &chunk {
+            chunks
+                .map(move |event| {
+                    let chunk = event?;
+                    let (input_tokens, output_tokens) = match &chunk {
                         anthropic::Event::MessageStart {
                             message: anthropic::Response { usage, .. },
                         }
-                        | anthropic::Event::MessageDelta { usage, .. } => {
-                            recorder.token_count += usage.input_tokens.unwrap_or(0) as usize;
-                            recorder.token_count += usage.output_tokens.unwrap_or(0) as usize;
-                        }
-                        _ => {}
-                    }
-
-                    buffer.clear();
-                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
-                    buffer.push(b'\n');
-                    buffer
+                        | anthropic::Event::MessageDelta { usage, .. } => (
+                            usage.input_tokens.unwrap_or(0) as usize,
+                            usage.output_tokens.unwrap_or(0) as usize,
+                        ),
+                        _ => (0, 0),
+                    };
+
+                    anyhow::Ok((
+                        serde_json::to_vec(&chunk).unwrap(),
+                        input_tokens,
+                        output_tokens,
+                    ))
                 })
-            });
-
-            Ok(Response::new(Body::wrap_stream(stream)))
+                .boxed()
         }
         LanguageModelProvider::OpenAi => {
             let api_key = state
@@ -232,17 +226,21 @@ async fn perform_completion(
             )
             .await?;
 
-            let stream = chunks.map(|event| {
-                let mut buffer = Vec::new();
-                event.map(|chunk| {
-                    buffer.clear();
-                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
-                    buffer.push(b'\n');
-                    buffer
+            chunks
+                .map(|event| {
+                    event.map(|chunk| {
+                        let input_tokens =
+                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
+                        let output_tokens =
+                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
+                        (
+                            serde_json::to_vec(&chunk).unwrap(),
+                            input_tokens,
+                            output_tokens,
+                        )
+                    })
                 })
-            });
-
-            Ok(Response::new(Body::wrap_stream(stream)))
+                .boxed()
         }
         LanguageModelProvider::Google => {
             let api_key = state
@@ -258,17 +256,20 @@ async fn perform_completion(
             )
             .await?;
 
-            let stream = chunks.map(|event| {
-                let mut buffer = Vec::new();
-                event.map(|chunk| {
-                    buffer.clear();
-                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
-                    buffer.push(b'\n');
-                    buffer
+            chunks
+                .map(|event| {
+                    event.map(|chunk| {
+                        // TODO - implement token counting for Google AI
+                        let input_tokens = 0;
+                        let output_tokens = 0;
+                        (
+                            serde_json::to_vec(&chunk).unwrap(),
+                            input_tokens,
+                            output_tokens,
+                        )
+                    })
                 })
-            });
-
-            Ok(Response::new(Body::wrap_stream(stream)))
+                .boxed()
         }
         LanguageModelProvider::Zed => {
             let api_key = state
@@ -290,19 +291,34 @@ async fn perform_completion(
             )
             .await?;
 
-            let stream = chunks.map(|event| {
-                let mut buffer = Vec::new();
-                event.map(|chunk| {
-                    buffer.clear();
-                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
-                    buffer.push(b'\n');
-                    buffer
+            chunks
+                .map(|event| {
+                    event.map(|chunk| {
+                        let input_tokens =
+                            chunk.usage.as_ref().map_or(0, |u| u.prompt_tokens) as usize;
+                        let output_tokens =
+                            chunk.usage.as_ref().map_or(0, |u| u.completion_tokens) as usize;
+                        (
+                            serde_json::to_vec(&chunk).unwrap(),
+                            input_tokens,
+                            output_tokens,
+                        )
+                    })
                 })
-            });
-
-            Ok(Response::new(Body::wrap_stream(stream)))
+                .boxed()
         }
-    }
+    };
+
+    Ok(Response::new(Body::wrap_stream(TokenCountingStream {
+        db: state.db.clone(),
+        executor: state.executor.clone(),
+        user_id,
+        provider: params.provider,
+        model,
+        input_tokens: 0,
+        output_tokens: 0,
+        inner_stream: stream,
+    })))
 }
 
 fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
@@ -377,22 +393,46 @@ async fn check_usage_limit(
 
     Ok(())
 }
-struct UsageRecorder {
+
+struct TokenCountingStream<S> {
     db: Arc<LlmDatabase>,
     executor: Executor,
     user_id: i32,
     provider: LanguageModelProvider,
     model: String,
-    token_count: usize,
+    input_tokens: usize,
+    output_tokens: usize,
+    inner_stream: S,
+}
+
+impl<S> Stream for TokenCountingStream<S>
+where
+    S: Stream<Item = Result<(Vec<u8>, usize, usize), anyhow::Error>> + Unpin,
+{
+    type Item = Result<Vec<u8>, anyhow::Error>;
+
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        match Pin::new(&mut self.inner_stream).poll_next(cx) {
+            Poll::Ready(Some(Ok((mut bytes, input_tokens, output_tokens)))) => {
+                bytes.push(b'\n');
+                self.input_tokens += input_tokens;
+                self.output_tokens += output_tokens;
+                Poll::Ready(Some(Ok(bytes)))
+            }
+            Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
+            Poll::Ready(None) => Poll::Ready(None),
+            Poll::Pending => Poll::Pending,
+        }
+    }
 }
 
-impl Drop for UsageRecorder {
+impl<S> Drop for TokenCountingStream<S> {
     fn drop(&mut self) {
         let db = self.db.clone();
         let user_id = self.user_id;
         let provider = self.provider;
         let model = std::mem::take(&mut self.model);
-        let token_count = self.token_count;
+        let token_count = self.input_tokens + self.output_tokens;
         self.executor.spawn_detached(async move {
             db.record_usage(user_id, provider, &model, token_count, Utc::now())
                 .await