@@ -116,6 +116,7 @@ dependencies = [
"serde_json",
"settings",
"theme",
+ "tiktoken-rs",
"util",
"workspace",
]
@@ -745,6 +746,21 @@ dependencies = [
"which",
]
+[[package]]
+name = "bit-set"
+version = "0.5.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
+dependencies = [
+ "bit-vec",
+]
+
+[[package]]
+name = "bit-vec"
+version = "0.6.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
+
[[package]]
name = "bitflags"
version = "1.3.2"
@@ -870,6 +886,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3d4260bcc2e8fc9df1eac4919a720effeb63a3f0952f5bf4944adfa18897f09"
dependencies = [
"memchr",
+ "once_cell",
+ "regex-automata",
"serde",
]
@@ -2220,6 +2238,16 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
+[[package]]
+name = "fancy-regex"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2"
+dependencies = [
+ "bit-set",
+ "regex",
+]
+
[[package]]
name = "fastrand"
version = "1.9.0"
@@ -6969,6 +6997,21 @@ dependencies = [
"weezl",
]
+[[package]]
+name = "tiktoken-rs"
+version = "0.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8ba161c549e2c0686f35f5d920e63fad5cafba2c28ad2caceaf07e5d9fa6e8c4"
+dependencies = [
+ "anyhow",
+ "base64 0.21.0",
+ "bstr",
+ "fancy-regex",
+ "lazy_static",
+ "parking_lot 0.12.1",
+ "rustc-hash",
+]
+
[[package]]
name = "time"
version = "0.1.45"
@@ -16,7 +16,8 @@ use gpui::{
use isahc::{http::StatusCode, Request, RequestExt};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
use settings::SettingsStore;
-use std::{cell::Cell, io, rc::Rc, sync::Arc};
+use std::{cell::Cell, io, rc::Rc, sync::Arc, time::Duration};
+use tiktoken_rs::model::get_context_size;
use util::{post_inc, ResultExt, TryFutureExt};
use workspace::{
dock::{DockPosition, Panel},
@@ -398,7 +399,12 @@ struct Assistant {
completion_count: usize,
pending_completions: Vec<PendingCompletion>,
languages: Arc<LanguageRegistry>,
+ model: String,
+ token_count: Option<usize>,
+ max_token_count: usize,
+ pending_token_count: Task<Option<()>>,
api_key: Rc<Cell<Option<String>>>,
+ _subscriptions: Vec<Subscription>,
}
impl Entity for Assistant {
@@ -411,19 +417,78 @@ impl Assistant {
language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>,
) -> Self {
+ let model = "gpt-3.5-turbo";
+ let buffer = cx.add_model(|_| MultiBuffer::new(0));
let mut this = Self {
- buffer: cx.add_model(|_| MultiBuffer::new(0)),
messages: Default::default(),
messages_by_id: Default::default(),
completion_count: Default::default(),
pending_completions: Default::default(),
languages: language_registry,
+ token_count: None,
+ max_token_count: get_context_size(model),
+ pending_token_count: Task::ready(None),
+ model: model.into(),
+ _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
api_key,
+ buffer,
};
this.push_message(Role::User, cx);
+ this.count_remaining_tokens(cx);
this
}
+    /// Reacts to events emitted by the conversation's `MultiBuffer`.
+    ///
+    /// Excerpt additions, excerpt removals and edits trigger a fresh token
+    /// count; every other event is ignored.
+    fn handle_buffer_event(
+        &mut self,
+        _: ModelHandle<MultiBuffer>,
+        event: &editor::multi_buffer::Event,
+        cx: &mut ModelContext<Self>,
+    ) {
+        use editor::multi_buffer::Event;
+
+        let content_changed = matches!(
+            event,
+            Event::ExcerptsAdded { .. } | Event::ExcerptsRemoved { .. } | Event::Edited
+        );
+        if content_changed {
+            self.count_remaining_tokens(cx);
+        }
+    }
+
+    /// Recomputes `token_count` for the current conversation in the background.
+    ///
+    /// Snapshots the role and text of every message, then stores a new task in
+    /// `pending_token_count` that waits 200ms, counts the tokens for
+    /// `self.model` on the background executor, writes the result into
+    /// `token_count` and notifies observers. Errors go through `log_err` and
+    /// leave `token_count` untouched.
+    fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
+        // Take the snapshot up front so the spawned future owns all of its input.
+        let mut request_messages = Vec::new();
+        for message in self.messages.iter() {
+            let role = match message.role {
+                Role::User => "user",
+                Role::Assistant => "assistant",
+                Role::System => "system",
+            };
+            request_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
+                role: role.into(),
+                content: message.content.read(cx).text(),
+                name: None,
+            });
+        }
+        let model = self.model.clone();
+
+        // NOTE(review): overwriting the field drops the previous task; presumably
+        // that cancels a count still waiting on its timer, so the delay acts as a
+        // debounce. Confirm against gpui's `Task` semantics.
+        self.pending_token_count = cx.spawn(|this, mut cx| {
+            async move {
+                cx.background().timer(Duration::from_millis(200)).await;
+
+                let count_task = cx.background().spawn(async move {
+                    tiktoken_rs::num_tokens_from_messages(&model, &request_messages)
+                });
+                let token_count = count_task.await?;
+
+                this.update(&mut cx, |this, cx| {
+                    this.token_count = Some(token_count);
+                    cx.notify()
+                });
+                anyhow::Ok(())
+            }
+            .log_err()
+        });
+    }
+
+    /// Returns `max_token_count` minus the most recent `token_count`.
+    ///
+    /// Yields `None` while no token count is available. The result is signed,
+    /// so it is zero or negative once the conversation reaches or exceeds the
+    /// limit.
+    fn remaining_tokens(&self) -> Option<isize> {
+        let used = self.token_count? as isize;
+        let limit = self.max_token_count as isize;
+        Some(limit - used)
+    }
+
fn assist(&mut self, cx: &mut ModelContext<Self>) {
let messages = self
.messages
@@ -434,7 +499,7 @@ impl Assistant {
})
.collect();
let request = OpenAIRequest {
- model: "gpt-3.5-turbo".into(),
+ model: self.model.clone(),
messages,
stream: true,
};
@@ -530,6 +595,7 @@ struct PendingCompletion {
struct AssistantEditor {
assistant: ModelHandle<Assistant>,
editor: ViewHandle<Editor>,
+ _subscriptions: Vec<Subscription>, // never read (underscore-prefixed); stores the `cx.observe(&assistant, ..)` handle created at construction
}
impl AssistantEditor {
@@ -590,7 +656,11 @@ impl AssistantEditor {
);
editor
});
- Self { assistant, editor }
+ Self {
+ _subscriptions: vec![cx.observe(&assistant, |_, _, cx| cx.notify())],
+ assistant,
+ editor,
+ }
}
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
@@ -684,10 +754,34 @@ impl View for AssistantEditor {
fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
let theme = &theme::current(cx).assistant;
+ let remaining_tokens = self // optional label with the assistant's remaining token budget
+ .assistant
+ .read(cx)
+ .remaining_tokens() // `None` here means no label is built at all
+ .map(|remaining_tokens| {
+ let remaining_tokens_style = if remaining_tokens <= 0 { // zero or negative budget switches to the `no_remaining_tokens` style
+ &theme.no_remaining_tokens
+ } else {
+ &theme.remaining_tokens
+ };
+ Label::new(
+ remaining_tokens.to_string(),
+ remaining_tokens_style.text.clone(),
+ )
+ .contained()
+ .with_style(remaining_tokens_style.container)
+ .aligned() // label is aligned to the top-right
+ .top()
+ .right()
+ });
- ChildView::new(&self.editor, cx)
- .contained()
- .with_style(theme.container)
+ Stack::new() // the editor view and the optional label are children of one `Stack`
+ .with_child(
+ ChildView::new(&self.editor, cx)
+ .contained()
+ .with_style(theme.container),
+ )
+ .with_children(remaining_tokens) // an `Option`, so presumably zero or one child is added
.into_any()
}
@@ -23,6 +23,20 @@ export default function assistant(colorScheme: ColorScheme) {
margin: { top: 2, left: 8 },
...text(layer, "sans", "default", { size: "2xs" }),
},
+ remaining_tokens: {
+ padding: 4,
+ margin: { right: 16, top: 4 },
+ background: background(layer, "on"),
+ cornerRadius: 4,
+ ...text(layer, "sans", "positive", { size: "xs" }),
+ },
+ no_remaining_tokens: {
+ padding: 4,
+ margin: { right: 16, top: 4 },
+ background: background(layer, "on"),
+ cornerRadius: 4,
+ ...text(layer, "sans", "negative", { size: "xs" }),
+ },
apiKeyEditor: {
background: background(layer, "on"),
cornerRadius: 6,