Add token usage to `LanguageModelTextStream` (#27490)

Created by Thomas Mickley-Doyle and Michael Sloan

Release Notes:

- N/A

---------

Co-authored-by: Michael Sloan <michael@zed.dev>

Change summary

crates/assistant/src/inline_assistant.rs    |  3 +
crates/assistant2/src/buffer_codegen.rs     | 21 +++++++++++-
crates/language_model/src/language_model.rs | 38 ++++++++++++++++------
3 files changed, 49 insertions(+), 13 deletions(-)

Detailed changes

crates/assistant/src/inline_assistant.rs 🔗

@@ -3712,7 +3712,7 @@ mod tests {
         language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
         Point,
     };
-    use language_model::LanguageModelRegistry;
+    use language_model::{LanguageModelRegistry, TokenUsage};
     use rand::prelude::*;
     use serde::Serialize;
     use settings::SettingsStore;
@@ -4091,6 +4091,7 @@ mod tests {
                 future::ready(Ok(LanguageModelTextStream {
                     message_id: None,
                     stream: chunks_rx.map(Ok).boxed(),
+                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
                 })),
                 cx,
             );

crates/assistant2/src/buffer_codegen.rs 🔗

@@ -482,11 +482,17 @@ impl CodegenAlternative {
 
         self.generation = cx.spawn(async move |codegen, cx| {
             let stream = stream.await;
+            let token_usage = stream
+                .as_ref()
+                .ok()
+                .map(|stream| stream.last_token_usage.clone());
             let message_id = stream
                 .as_ref()
                 .ok()
                 .and_then(|stream| stream.message_id.clone());
             let generate = async {
+                let model_telemetry_id = model_telemetry_id.clone();
+                let model_provider_id = model_provider_id.clone();
                 let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
                 let executor = cx.background_executor().clone();
                 let message_id = message_id.clone();
@@ -596,7 +602,7 @@ impl CodegenAlternative {
                                 kind: AssistantKind::Inline,
                                 phase: AssistantPhase::Response,
                                 model: model_telemetry_id,
-                                model_provider: model_provider_id.to_string(),
+                                model_provider: model_provider_id,
                                 response_latency,
                                 error_message,
                                 language_name: language_name.map(|name| name.to_proto()),
@@ -677,6 +683,16 @@ impl CodegenAlternative {
                     }
                     this.elapsed_time = Some(elapsed_time);
                     this.completion = Some(completion.lock().clone());
+                    if let Some(usage) = token_usage {
+                        let usage = usage.lock();
+                        telemetry::event!(
+                            "Inline Assistant Completion",
+                            model = model_telemetry_id,
+                            model_provider = model_provider_id,
+                            input_tokens = usage.input_tokens,
+                            output_tokens = usage.output_tokens,
+                        )
+                    }
                     cx.emit(CodegenEvent::Finished);
                     cx.notify();
                 })
@@ -1021,7 +1037,7 @@ mod tests {
         language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
         Point,
     };
-    use language_model::LanguageModelRegistry;
+    use language_model::{LanguageModelRegistry, TokenUsage};
     use rand::prelude::*;
     use serde::Serialize;
     use settings::SettingsStore;
@@ -1405,6 +1421,7 @@ mod tests {
                 future::ready(Ok(LanguageModelTextStream {
                     message_id: None,
                     stream: chunks_rx.map(Ok).boxed(),
+                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
                 })),
                 cx,
             );

crates/language_model/src/language_model.rs 🔗

@@ -14,6 +14,7 @@ use futures::FutureExt;
 use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt as _};
 use gpui::{AnyElement, AnyView, App, AsyncApp, SharedString, Task, Window};
 use icons::IconName;
+use parking_lot::Mutex;
 use proto::Plan;
 use schemars::JsonSchema;
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
@@ -141,6 +142,8 @@ pub struct LanguageModelToolUse {
 pub struct LanguageModelTextStream {
     pub message_id: Option<String>,
     pub stream: BoxStream<'static, Result<String>>,
+    // Has complete token usage after the stream has finished
+    pub last_token_usage: Arc<Mutex<TokenUsage>>,
 }
 
 impl Default for LanguageModelTextStream {
@@ -148,6 +151,7 @@ impl Default for LanguageModelTextStream {
         Self {
             message_id: None,
             stream: Box::pin(futures::stream::empty()),
+            last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
         }
     }
 }
@@ -200,6 +204,7 @@ pub trait LanguageModel: Send + Sync {
             let mut events = events.await?.fuse();
             let mut message_id = None;
             let mut first_item_text = None;
+            let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
 
             if let Some(first_event) = events.next().await {
                 match first_event {
@@ -214,20 +219,33 @@ pub trait LanguageModel: Send + Sync {
             }
 
             let stream = futures::stream::iter(first_item_text.map(Ok))
-                .chain(events.filter_map(|result| async move {
-                    match result {
-                        Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
-                        Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
-                        Ok(LanguageModelCompletionEvent::Thinking(_)) => None,
-                        Ok(LanguageModelCompletionEvent::Stop(_)) => None,
-                        Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
-                        Ok(LanguageModelCompletionEvent::UsageUpdate(_)) => None,
-                        Err(err) => Some(Err(err)),
+                .chain(events.filter_map({
+                    let last_token_usage = last_token_usage.clone();
+                    move |result| {
+                        let last_token_usage = last_token_usage.clone();
+                        async move {
+                            match result {
+                                Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
+                                Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
+                                Ok(LanguageModelCompletionEvent::Thinking(_)) => None,
+                                Ok(LanguageModelCompletionEvent::Stop(_)) => None,
+                                Ok(LanguageModelCompletionEvent::ToolUse(_)) => None,
+                                Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
+                                    *last_token_usage.lock() = token_usage;
+                                    None
+                                }
+                                Err(err) => Some(Err(err)),
+                            }
+                        }
                     }
                 }))
                 .boxed();
 
-            Ok(LanguageModelTextStream { message_id, stream })
+            Ok(LanguageModelTextStream {
+                message_id,
+                stream,
+                last_token_usage,
+            })
         }
         .boxed()
     }