Show remaining tokens

Antonio Scandurra created

Change summary

Cargo.lock                        |  43 +++++++++++++
crates/ai/Cargo.toml              |   1 
crates/ai/src/assistant.rs        | 108 ++++++++++++++++++++++++++++++--
crates/editor/src/editor.rs       |   2 
crates/theme/src/theme.rs         |   2 
styles/src/styleTree/assistant.ts |  14 ++++
6 files changed, 162 insertions(+), 8 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -116,6 +116,7 @@ dependencies = [
  "serde_json",
  "settings",
  "theme",
+ "tiktoken-rs",
  "util",
  "workspace",
 ]
@@ -745,6 +746,21 @@ dependencies = [
  "which",
 ]
 
+[[package]]
+name = "bit-set"
+version = "0.5.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
+dependencies = [
+ "bit-vec",
+]
+
+[[package]]
+name = "bit-vec"
+version = "0.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
+
 [[package]]
 name = "bitflags"
 version = "1.3.2"
@@ -870,6 +886,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c3d4260bcc2e8fc9df1eac4919a720effeb63a3f0952f5bf4944adfa18897f09"
 dependencies = [
  "memchr",
+ "once_cell",
+ "regex-automata",
  "serde",
 ]
 
@@ -2220,6 +2238,16 @@ version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
 
+[[package]]
+name = "fancy-regex"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2"
+dependencies = [
+ "bit-set",
+ "regex",
+]
+
 [[package]]
 name = "fastrand"
 version = "1.9.0"
@@ -6969,6 +6997,21 @@ dependencies = [
  "weezl",
 ]
 
+[[package]]
+name = "tiktoken-rs"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8ba161c549e2c0686f35f5d920e63fad5cafba2c28ad2caceaf07e5d9fa6e8c4"
+dependencies = [
+ "anyhow",
+ "base64 0.21.0",
+ "bstr",
+ "fancy-regex",
+ "lazy_static",
+ "parking_lot 0.12.1",
+ "rustc-hash",
+]
+
 [[package]]
 name = "time"
 version = "0.1.45"

crates/ai/Cargo.toml 🔗

@@ -29,6 +29,7 @@ isahc.workspace = true
 schemars.workspace = true
 serde.workspace = true
 serde_json.workspace = true
+tiktoken-rs = "0.4"
 
 [dev-dependencies]
 editor = { path = "../editor", features = ["test-support"] }

crates/ai/src/assistant.rs 🔗

@@ -16,7 +16,8 @@ use gpui::{
 use isahc::{http::StatusCode, Request, RequestExt};
 use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
 use settings::SettingsStore;
-use std::{cell::Cell, io, rc::Rc, sync::Arc};
+use std::{cell::Cell, io, rc::Rc, sync::Arc, time::Duration};
+use tiktoken_rs::model::get_context_size;
 use util::{post_inc, ResultExt, TryFutureExt};
 use workspace::{
     dock::{DockPosition, Panel},
@@ -398,7 +399,12 @@ struct Assistant {
     completion_count: usize,
     pending_completions: Vec<PendingCompletion>,
     languages: Arc<LanguageRegistry>,
+    model: String,
+    token_count: Option<usize>,
+    max_token_count: usize,
+    pending_token_count: Task<Option<()>>,
     api_key: Rc<Cell<Option<String>>>,
+    _subscriptions: Vec<Subscription>,
 }
 
 impl Entity for Assistant {
@@ -411,19 +417,78 @@ impl Assistant {
         language_registry: Arc<LanguageRegistry>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
+        let model = "gpt-3.5-turbo";
+        let buffer = cx.add_model(|_| MultiBuffer::new(0));
         let mut this = Self {
-            buffer: cx.add_model(|_| MultiBuffer::new(0)),
             messages: Default::default(),
             messages_by_id: Default::default(),
             completion_count: Default::default(),
             pending_completions: Default::default(),
             languages: language_registry,
+            token_count: None,
+            max_token_count: get_context_size(model),
+            pending_token_count: Task::ready(None),
+            model: model.into(),
+            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
             api_key,
+            buffer,
         };
         this.push_message(Role::User, cx);
+        this.count_remaining_tokens(cx);
         this
     }
 
+    fn handle_buffer_event(
+        &mut self,
+        _: ModelHandle<MultiBuffer>,
+        event: &editor::multi_buffer::Event,
+        cx: &mut ModelContext<Self>,
+    ) {
+        match event {
+            editor::multi_buffer::Event::ExcerptsAdded { .. }
+            | editor::multi_buffer::Event::ExcerptsRemoved { .. }
+            | editor::multi_buffer::Event::Edited => self.count_remaining_tokens(cx),
+            _ => {}
+        }
+    }
+
+    fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
+        let messages = self
+            .messages
+            .iter()
+            .map(|message| tiktoken_rs::ChatCompletionRequestMessage {
+                role: match message.role {
+                    Role::User => "user".into(),
+                    Role::Assistant => "assistant".into(),
+                    Role::System => "system".into(),
+                },
+                content: message.content.read(cx).text(),
+                name: None,
+            })
+            .collect::<Vec<_>>();
+        let model = self.model.clone();
+        self.pending_token_count = cx.spawn(|this, mut cx| {
+            async move {
+                cx.background().timer(Duration::from_millis(200)).await;
+                let token_count = cx
+                    .background()
+                    .spawn(async move { tiktoken_rs::num_tokens_from_messages(&model, &messages) })
+                    .await?;
+
+                this.update(&mut cx, |this, cx| {
+                    this.token_count = Some(token_count);
+                    cx.notify()
+                });
+                anyhow::Ok(())
+            }
+            .log_err()
+        });
+    }
+
+    fn remaining_tokens(&self) -> Option<isize> {
+        Some(self.max_token_count as isize - self.token_count? as isize)
+    }
+
     fn assist(&mut self, cx: &mut ModelContext<Self>) {
         let messages = self
             .messages
@@ -434,7 +499,7 @@ impl Assistant {
             })
             .collect();
         let request = OpenAIRequest {
-            model: "gpt-3.5-turbo".into(),
+            model: self.model.clone(),
             messages,
             stream: true,
         };
@@ -530,6 +595,7 @@ struct PendingCompletion {
 struct AssistantEditor {
     assistant: ModelHandle<Assistant>,
     editor: ViewHandle<Editor>,
+    _subscriptions: Vec<Subscription>,
 }
 
 impl AssistantEditor {
@@ -590,7 +656,11 @@ impl AssistantEditor {
             );
             editor
         });
-        Self { assistant, editor }
+        Self {
+            _subscriptions: vec![cx.observe(&assistant, |_, _, cx| cx.notify())],
+            assistant,
+            editor,
+        }
     }
 
     fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
@@ -684,10 +754,34 @@ impl View for AssistantEditor {
 
     fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
         let theme = &theme::current(cx).assistant;
+        let remaining_tokens = self
+            .assistant
+            .read(cx)
+            .remaining_tokens()
+            .map(|remaining_tokens| {
+                let remaining_tokens_style = if remaining_tokens <= 0 {
+                    &theme.no_remaining_tokens
+                } else {
+                    &theme.remaining_tokens
+                };
+                Label::new(
+                    remaining_tokens.to_string(),
+                    remaining_tokens_style.text.clone(),
+                )
+                .contained()
+                .with_style(remaining_tokens_style.container)
+                .aligned()
+                .top()
+                .right()
+            });
 
-        ChildView::new(&self.editor, cx)
-            .contained()
-            .with_style(theme.container)
+        Stack::new()
+            .with_child(
+                ChildView::new(&self.editor, cx)
+                    .contained()
+                    .with_style(theme.container),
+            )
+            .with_children(remaining_tokens)
             .into_any()
     }
 

crates/editor/src/editor.rs 🔗

@@ -10,7 +10,7 @@ pub mod items;
 mod link_go_to_definition;
 mod mouse_context_menu;
 pub mod movement;
-mod multi_buffer;
+pub mod multi_buffer;
 mod persistence;
 pub mod scroll;
 pub mod selections_collection;

crates/theme/src/theme.rs 🔗

@@ -976,6 +976,8 @@ pub struct AssistantStyle {
     pub sent_at: ContainedText,
     pub user_sender: ContainedText,
     pub assistant_sender: ContainedText,
+    pub remaining_tokens: ContainedText,
+    pub no_remaining_tokens: ContainedText,
     pub api_key_editor: FieldEditor,
     pub api_key_prompt: ContainedText,
 }

styles/src/styleTree/assistant.ts 🔗

@@ -23,6 +23,20 @@ export default function assistant(colorScheme: ColorScheme) {
         margin: { top: 2, left: 8 },
         ...text(layer, "sans", "default", { size: "2xs" }),
       },
+      remaining_tokens: {
+        padding: 4,
+        margin: { right: 16, top: 4 },
+        background: background(layer, "on"),
+        cornerRadius: 4,
+        ...text(layer, "sans", "positive", { size: "xs" }),
+      },
+      no_remaining_tokens: {
+        padding: 4,
+        margin: { right: 16, top: 4 },
+        background: background(layer, "on"),
+        cornerRadius: 4,
+        ...text(layer, "sans", "negative", { size: "xs" }),
+      },
       apiKeyEditor: {
           background: background(layer, "on"),
           cornerRadius: 6,