Detailed changes
@@ -108,7 +108,7 @@ dependencies = [
"rusqlite",
"serde",
"serde_json",
- "tiktoken-rs 0.5.4",
+ "tiktoken-rs",
"util",
]
@@ -327,7 +327,7 @@ dependencies = [
"settings",
"smol",
"theme",
- "tiktoken-rs 0.4.5",
+ "tiktoken-rs",
"util",
"uuid 1.4.1",
"workspace",
@@ -6798,7 +6798,7 @@ dependencies = [
"smol",
"tempdir",
"theme",
- "tiktoken-rs 0.5.4",
+ "tiktoken-rs",
"tree-sitter",
"tree-sitter-cpp",
"tree-sitter-elixir",
@@ -7875,21 +7875,6 @@ dependencies = [
"weezl",
]
-[[package]]
-name = "tiktoken-rs"
-version = "0.4.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "52aacc1cff93ba9d5f198c62c49c77fa0355025c729eed3326beaf7f33bc8614"
-dependencies = [
- "anyhow",
- "base64 0.21.4",
- "bstr",
- "fancy-regex",
- "lazy_static",
- "parking_lot 0.12.1",
- "rustc-hash",
-]
-
[[package]]
name = "tiktoken-rs"
version = "0.5.4"
@@ -38,7 +38,7 @@ schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
smol.workspace = true
-tiktoken-rs = "0.4"
+tiktoken-rs = "0.5"
[dev-dependencies]
editor = { path = "../editor", features = ["test-support"] }
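
The lockfile previously pinned two incompatible copies of tiktoken-rs (0.4.5 for the assistant crate, 0.5.4 elsewhere); bumping the assistant's manifest to `0.5` lets Cargo collapse them into a single entry, which is why the version-qualified references and the duplicate `[[package]]` block disappear above. The 0.5 release also reshapes `ChatCompletionRequestMessage`: `content` becomes an `Option<String>` and a `function_call` field is added, which the later hunks in `assistant_panel.rs` and `prompts.rs` account for. A minimal sketch of the 0.5-style message, mirroring the fields this diff fills in (verify the exact layout against your installed version):

```rust
use tiktoken_rs::ChatCompletionRequestMessage;

fn main() {
    // tiktoken-rs 0.5: `content` is Option<String>, and `function_call`
    // must be supplied explicitly (None when no tool call is involved).
    let message = ChatCompletionRequestMessage {
        role: "user".to_string(),
        content: Some("Summarize this file.".to_string()),
        name: None,
        function_call: None,
    };
    println!("{:?}", message.content);
}
```
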
@@ -437,8 +437,15 @@ impl AssistantPanel {
InlineAssistantEvent::Confirmed {
prompt,
include_conversation,
+ retrieve_context,
} => {
- self.confirm_inline_assist(assist_id, prompt, *include_conversation, cx);
+ self.confirm_inline_assist(
+ assist_id,
+ prompt,
+ *include_conversation,
+ cx,
+ *retrieve_context,
+ );
}
InlineAssistantEvent::Canceled => {
self.finish_inline_assist(assist_id, true, cx);
@@ -532,6 +539,7 @@ impl AssistantPanel {
user_prompt: &str,
include_conversation: bool,
cx: &mut ViewContext<Self>,
+ retrieve_context: bool,
) {
let conversation = if include_conversation {
self.active_editor()
@@ -593,42 +601,49 @@ impl AssistantPanel {
let codegen_kind = codegen.read(cx).kind().clone();
let user_prompt = user_prompt.to_string();
- let project = if let Some(workspace) = self.workspace.upgrade(cx) {
- workspace.read(cx).project()
- } else {
- return;
- };
+ let snippets = if retrieve_context {
+ let project = if let Some(workspace) = self.workspace.upgrade(cx) {
+ workspace.read(cx).project()
+ } else {
+ return;
+ };
- let project = project.to_owned();
- let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
- let search_results = semantic_index.update(cx, |this, cx| {
- this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
- });
+ let project = project.to_owned();
+ let search_results = if let Some(semantic_index) = self.semantic_index.clone() {
+ let search_results = semantic_index.update(cx, |this, cx| {
+ this.search_project(project, user_prompt.to_string(), 10, vec![], vec![], cx)
+ });
- cx.background()
- .spawn(async move { search_results.await.unwrap_or_default() })
+ cx.background()
+ .spawn(async move { search_results.await.unwrap_or_default() })
+ } else {
+ Task::ready(Vec::new())
+ };
+
+ let snippets = cx.spawn(|_, cx| async move {
+ let mut snippets = Vec::new();
+ for result in search_results.await {
+ snippets.push(result.buffer.read_with(&cx, |buffer, _| {
+ buffer
+ .snapshot()
+ .text_for_range(result.range)
+ .collect::<String>()
+ }));
+ }
+ snippets
+ });
+ snippets
} else {
Task::ready(Vec::new())
};
- let snippets = cx.spawn(|_, cx| async move {
- let mut snippets = Vec::new();
- for result in search_results.await {
- snippets.push(result.buffer.read_with(&cx, |buffer, _| {
- buffer
- .snapshot()
- .text_for_range(result.range)
- .collect::<String>()
- }));
- }
- snippets
- });
+ let mut model = settings::get::<AssistantSettings>(cx)
+ .default_open_ai_model
+ .clone();
+ let model_name = model.full_name();
let prompt = cx.background().spawn(async move {
let snippets = snippets.await;
- for snippet in &snippets {
- println!("SNIPPET: \n{:?}", snippet);
- }
let language_name = language_name.as_deref();
generate_content_prompt(
@@ -638,13 +653,11 @@ impl AssistantPanel {
range,
codegen_kind,
snippets,
+ model_name,
)
});
let mut messages = Vec::new();
- let mut model = settings::get::<AssistantSettings>(cx)
- .default_open_ai_model
- .clone();
if let Some(conversation) = conversation {
let conversation = conversation.read(cx);
let buffer = conversation.buffer.read(cx);
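
A design note on the snippet retrieval above: both arms of the `retrieve_context` conditional resolve to the same type, a `Task<Vec<String>>`, so the prompt-building code that follows awaits snippets uniformly whether or not semantic search ran (`Task::ready(Vec::new())` serves as the no-op arm). A minimal sketch of that shape, using boxed std futures in place of gpui's `Task` and a hypothetical search body:

```rust
use std::future::Future;
use std::pin::Pin;

// Either arm yields the same boxed future, so callers just await it.
fn snippets_task(
    retrieve_context: bool,
) -> Pin<Box<dyn Future<Output = Vec<String>> + Send>> {
    if retrieve_context {
        Box::pin(async {
            // Stand-in for the semantic search + buffer reads in the diff.
            vec!["fn example() {}".to_string()]
        })
    } else {
        // Mirrors `Task::ready(Vec::new())`: resolves immediately, empty.
        Box::pin(async { Vec::new() })
    }
}
```
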
@@ -1557,12 +1570,14 @@ impl Conversation {
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
- content: self
- .buffer
- .read(cx)
- .text_for_range(message.offset_range)
- .collect(),
+ content: Some(
+ self.buffer
+ .read(cx)
+ .text_for_range(message.offset_range)
+ .collect(),
+ ),
name: None,
+ function_call: None,
})
})
.collect::<Vec<_>>();
@@ -2681,6 +2696,7 @@ enum InlineAssistantEvent {
Confirmed {
prompt: String,
include_conversation: bool,
+ retrieve_context: bool,
},
Canceled,
Dismissed,
@@ -2922,6 +2938,7 @@ impl InlineAssistant {
cx.emit(InlineAssistantEvent::Confirmed {
prompt,
include_conversation: self.include_conversation,
+ retrieve_context: self.retrieve_context,
});
self.confirmed = true;
cx.notify();
@@ -1,8 +1,10 @@
use crate::codegen::CodegenKind;
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
use std::cmp;
+use std::fmt::Write;
+use std::iter;
use std::ops::Range;
-use std::{fmt::Write, iter};
+use tiktoken_rs::ChatCompletionRequestMessage;
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
#[derive(Debug)]
@@ -122,69 +124,103 @@ pub fn generate_content_prompt(
range: Range<impl ToOffset>,
kind: CodegenKind,
search_results: Vec<String>,
+ model: &str,
) -> String {
- let mut prompt = String::new();
+ const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+
+ let mut prompts = Vec::new();
// General Preamble
if let Some(language_name) = language_name {
- writeln!(prompt, "You're an expert {language_name} engineer.\n").unwrap();
+ prompts.push(format!("You're an expert {language_name} engineer.\n"));
} else {
- writeln!(prompt, "You're an expert engineer.\n").unwrap();
+ prompts.push("You're an expert engineer.\n".to_string());
}
+ // Snippets
+ let mut snippet_position = prompts.len() - 1;
+
let outline = summarize(buffer, range);
- writeln!(
- prompt,
- "The file you are currently working on has the following outline:"
- )
- .unwrap();
+ prompts.push("The file you are currently working on has the following outline:".to_string());
if let Some(language_name) = language_name {
let language_name = language_name.to_lowercase();
- writeln!(prompt, "```{language_name}\n{outline}\n```").unwrap();
+ prompts.push(format!("```{language_name}\n{outline}\n```"));
} else {
- writeln!(prompt, "```\n{outline}\n```").unwrap();
+ prompts.push(format!("```\n{outline}\n```"));
}
match kind {
CodegenKind::Generate { position: _ } => {
- writeln!(prompt, "In particular, the user's cursor is current on the '<|START|>' span in the above outline, with no text selected.").unwrap();
- writeln!(
- prompt,
- "Assume the cursor is located where the `<|START|` marker is."
- )
- .unwrap();
- writeln!(
- prompt,
+ prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
+ prompts
+ .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
+ prompts.push(
"Text can't be replaced, so assume your answer will be inserted at the cursor."
- )
- .unwrap();
- writeln!(
- prompt,
+ .to_string(),
+ );
+ prompts.push(format!(
"Generate text based on the users prompt: {user_prompt}"
- )
- .unwrap();
+ ));
}
CodegenKind::Transform { range: _ } => {
- writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
- writeln!(
- prompt,
+ prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
+ prompts.push(format!(
"Modify the users code selected text based upon the users prompt: {user_prompt}"
- )
- .unwrap();
- writeln!(
- prompt,
- "You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file."
- )
- .unwrap();
+ ));
+ prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
}
}
if let Some(language_name) = language_name {
- writeln!(prompt, "Your answer MUST always be valid {language_name}").unwrap();
+ prompts.push(format!("Your answer MUST always be valid {language_name}"));
+ }
+ prompts.push("Always wrap your response in a Markdown codeblock".to_string());
+ prompts.push("Never make remarks about the output.".to_string());
+
+ let current_messages = [ChatCompletionRequestMessage {
+ role: "user".to_string(),
+ content: Some(prompts.join("\n")),
+ function_call: None,
+ name: None,
+ }];
+
+ let remaining_token_count = if let Ok(current_token_count) =
+ tiktoken_rs::num_tokens_from_messages(model, &current_messages)
+ {
+ let max_token_count = tiktoken_rs::model::get_context_size(model);
+ max_token_count - current_token_count
+ } else {
+ // If tiktoken fails to count tokens, assume we have no space remaining.
+ 0
+ };
+
+ // TODO:
+ // - add repository name to snippet
+ // - add file path
+ // - add language
+ if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
+ let template = "You are working inside a large repository, here are a few code snippets that may be useful";
+
+ for search_result in search_results {
+ let mut snippet_prompt = template.to_string();
+ writeln!(snippet_prompt, "```\n{search_result}\n```").unwrap();
+
+ let token_count = encoding
+ .encode_with_special_tokens(snippet_prompt.as_str())
+ .len();
+ if token_count <= remaining_token_count {
+ if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
+ prompts.insert(snippet_position, snippet_prompt);
+ snippet_position += 1;
+ }
+ } else {
+ break;
+ }
+ }
}
- writeln!(prompt, "Always wrap your response in a Markdown codeblock").unwrap();
- writeln!(prompt, "Never make remarks about the output.").unwrap();
+ let prompt = prompts.join("\n");
+ println!("PROMPT: {:?}", prompt);
prompt
}
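
The budgeting logic in this final hunk works in two passes: first it counts the tokens already committed to the base prompt with `num_tokens_from_messages` and subtracts that from the model's context window (`get_context_size`); then it walks the search results in relevance order, encoding each candidate with the model's BPE and inserting it only while it fits the remaining budget and stays under the 500-token per-snippet cap. Note that the loop checks each snippet against the initial headroom without decrementing it as snippets are accepted. A condensed sketch of the first pass, using the same tiktoken-rs calls as the diff (`saturating_sub` is an added guard where the diff subtracts directly):

```rust
use tiktoken_rs::{model::get_context_size, num_tokens_from_messages, ChatCompletionRequestMessage};

fn remaining_tokens(model: &str, messages: &[ChatCompletionRequestMessage]) -> usize {
    match num_tokens_from_messages(model, messages) {
        // Saturate so a base prompt that already overflows the context
        // window yields 0 instead of panicking on usize underflow.
        Ok(used) => get_context_size(model).saturating_sub(used),
        // Mirrors the diff: if counting fails, assume no room for snippets.
        Err(_) => 0,
    }
}
```
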