@@ -28,7 +28,7 @@ use gpui::{
FontWeight, Global, HighlightStyle, Model, ModelContext, Subscription, Task, TextStyle,
UpdateGlobal, View, ViewContext, WeakView, WindowContext,
};
-use language::{Buffer, IndentKind, Point, TransactionId};
+use language::{Buffer, IndentKind, Point, Selection, TransactionId};
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
@@ -38,6 +38,7 @@ use rope::Rope;
use settings::Settings;
use smol::future::FutureExt;
use std::{
+ cmp,
future::{self, Future},
mem,
ops::{Range, RangeInclusive},
@@ -46,7 +47,6 @@ use std::{
task::{self, Poll},
time::{Duration, Instant},
};
-use text::OffsetRangeExt as _;
use theme::ThemeSettings;
use ui::{prelude::*, CheckboxWithLabel, IconButtonShape, Popover, Tooltip};
use util::{RangeExt, ResultExt};
@@ -140,81 +140,66 @@ impl InlineAssistant {
cx: &mut WindowContext,
) {
let snapshot = editor.read(cx).buffer().read(cx).snapshot(cx);
- struct CodegenRange {
- transform_range: Range<Point>,
- selection_ranges: Vec<Range<Point>>,
- focus_assist: bool,
- }
- let newest_selection_range = editor.read(cx).selections.newest::<Point>(cx).range();
- let mut codegen_ranges: Vec<CodegenRange> = Vec::new();
-
- let selection_ranges = snapshot
- .split_ranges(editor.read(cx).selections.disjoint_anchor_ranges())
- .map(|range| range.to_point(&snapshot))
- .collect::<Vec<Range<Point>>>();
-
- for selection_range in selection_ranges {
- let selection_is_newest = newest_selection_range.contains_inclusive(&selection_range);
- let mut transform_range = selection_range.start..selection_range.end;
-
- // Expand the transform range to start/end of lines.
- // If a non-empty selection ends at the start of the last line, clip at the end of the penultimate line.
- transform_range.start.column = 0;
- if transform_range.end.column == 0 && transform_range.end > transform_range.start {
- transform_range.end.row -= 1;
- }
- transform_range.end.column = snapshot.line_len(MultiBufferRow(transform_range.end.row));
- let selection_range =
- selection_range.start..selection_range.end.min(transform_range.end);
-
- // If we intersect the previous transform range,
- if let Some(CodegenRange {
- transform_range: prev_transform_range,
- selection_ranges,
- focus_assist,
- }) = codegen_ranges.last_mut()
- {
- if transform_range.start <= prev_transform_range.end {
- prev_transform_range.end = transform_range.end;
- selection_ranges.push(selection_range);
- *focus_assist |= selection_is_newest;
+ let mut selections = Vec::<Selection<Point>>::new();
+ let mut newest_selection = None;
+ for mut selection in editor.read(cx).selections.all::<Point>(cx) {
+ if selection.end > selection.start {
+ selection.start.column = 0;
+ // If the selection ends at the start of the line, we don't want to include it.
+ if selection.end.column == 0 {
+ selection.end.row -= 1;
+ }
+ selection.end.column = snapshot.line_len(MultiBufferRow(selection.end.row));
+ }
+
+ if let Some(prev_selection) = selections.last_mut() {
+ if selection.start <= prev_selection.end {
+ prev_selection.end = selection.end;
continue;
}
}
- codegen_ranges.push(CodegenRange {
- transform_range,
- selection_ranges: vec![selection_range],
- focus_assist: selection_is_newest,
- })
+ let latest_selection = newest_selection.get_or_insert_with(|| selection.clone());
+ if selection.id > latest_selection.id {
+ *latest_selection = selection.clone();
+ }
+ selections.push(selection);
+ }
+ let newest_selection = newest_selection.unwrap();
+
+ let mut codegen_ranges = Vec::new();
+ for (excerpt_id, buffer, buffer_range) in
+ snapshot.excerpts_in_ranges(selections.iter().map(|selection| {
+ snapshot.anchor_before(selection.start)..snapshot.anchor_after(selection.end)
+ }))
+ {
+ let start = Anchor {
+ buffer_id: Some(buffer.remote_id()),
+ excerpt_id,
+ text_anchor: buffer.anchor_before(buffer_range.start),
+ };
+ let end = Anchor {
+ buffer_id: Some(buffer.remote_id()),
+ excerpt_id,
+ text_anchor: buffer.anchor_after(buffer_range.end),
+ };
+ codegen_ranges.push(start..end);
}
let assist_group_id = self.next_assist_group_id.post_inc();
let prompt_buffer =
cx.new_model(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx));
let prompt_buffer = cx.new_model(|cx| MultiBuffer::singleton(prompt_buffer, cx));
+
let mut assists = Vec::new();
let mut assist_to_focus = None;
-
- for CodegenRange {
- transform_range,
- selection_ranges,
- focus_assist,
- } in codegen_ranges
- {
- let transform_range = snapshot.anchor_before(transform_range.start)
- ..snapshot.anchor_after(transform_range.end);
- let selection_ranges = selection_ranges
- .iter()
- .map(|range| snapshot.anchor_before(range.start)..snapshot.anchor_after(range.end))
- .collect::<Vec<_>>();
-
+ for range in codegen_ranges {
+ let assist_id = self.next_assist_id.post_inc();
let codegen = cx.new_model(|cx| {
Codegen::new(
editor.read(cx).buffer().clone(),
- transform_range.clone(),
- selection_ranges,
+ range.clone(),
None,
self.telemetry.clone(),
self.prompt_builder.clone(),
@@ -222,7 +207,6 @@ impl InlineAssistant {
)
});
- let assist_id = self.next_assist_id.post_inc();
let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
let prompt_editor = cx.new_view(|cx| {
PromptEditor::new(
@@ -239,16 +223,23 @@ impl InlineAssistant {
)
});
- if focus_assist {
- assist_to_focus = Some(assist_id);
+ if assist_to_focus.is_none() {
+ let focus_assist = if newest_selection.reversed {
+ range.start.to_point(&snapshot) == newest_selection.start
+ } else {
+ range.end.to_point(&snapshot) == newest_selection.end
+ };
+ if focus_assist {
+ assist_to_focus = Some(assist_id);
+ }
}
let [prompt_block_id, end_block_id] =
- self.insert_assist_blocks(editor, &transform_range, &prompt_editor, cx);
+ self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
assists.push((
assist_id,
- transform_range,
+ range,
prompt_editor,
prompt_block_id,
end_block_id,
@@ -315,7 +306,6 @@ impl InlineAssistant {
Codegen::new(
editor.read(cx).buffer().clone(),
range.clone(),
- vec![range.clone()],
initial_transaction_id,
self.telemetry.clone(),
self.prompt_builder.clone(),
@@ -925,7 +915,12 @@ impl InlineAssistant {
assist
.codegen
.update(cx, |codegen, cx| {
- codegen.start(user_prompt, assistant_panel_context, cx)
+ codegen.start(
+ assist.range.clone(),
+ user_prompt,
+ assistant_panel_context,
+ cx,
+ )
})
.log_err();
@@ -2120,9 +2115,12 @@ impl InlineAssist {
return future::ready(Err(anyhow!("no user prompt"))).boxed();
};
let assistant_panel_context = self.assistant_panel_context(cx);
- self.codegen
- .read(cx)
- .count_tokens(user_prompt, assistant_panel_context, cx)
+ self.codegen.read(cx).count_tokens(
+ self.range.clone(),
+ user_prompt,
+ assistant_panel_context,
+ cx,
+ )
}
}
@@ -2143,8 +2141,6 @@ pub struct Codegen {
buffer: Model<MultiBuffer>,
old_buffer: Model<Buffer>,
snapshot: MultiBufferSnapshot,
- transform_range: Range<Anchor>,
- selected_ranges: Vec<Range<Anchor>>,
edit_position: Option<Anchor>,
last_equal_ranges: Vec<Range<Anchor>>,
initial_transaction_id: Option<TransactionId>,
@@ -2154,7 +2150,7 @@ pub struct Codegen {
diff: Diff,
telemetry: Option<Arc<Telemetry>>,
_subscription: gpui::Subscription,
- prompt_builder: Arc<PromptBuilder>,
+ builder: Arc<PromptBuilder>,
}
enum CodegenStatus {
@@ -2181,8 +2177,7 @@ impl EventEmitter<CodegenEvent> for Codegen {}
impl Codegen {
pub fn new(
buffer: Model<MultiBuffer>,
- transform_range: Range<Anchor>,
- selected_ranges: Vec<Range<Anchor>>,
+ range: Range<Anchor>,
initial_transaction_id: Option<TransactionId>,
telemetry: Option<Arc<Telemetry>>,
builder: Arc<PromptBuilder>,
@@ -2192,7 +2187,7 @@ impl Codegen {
let (old_buffer, _, _) = buffer
.read(cx)
- .range_to_buffer_ranges(transform_range.clone(), cx)
+ .range_to_buffer_ranges(range.clone(), cx)
.pop()
.unwrap();
let old_buffer = cx.new_model(|cx| {
@@ -2223,9 +2218,7 @@ impl Codegen {
telemetry,
_subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
initial_transaction_id,
- prompt_builder: builder,
- transform_range,
- selected_ranges,
+ builder,
}
}
@@ -2250,12 +2243,14 @@ impl Codegen {
pub fn count_tokens(
&self,
+ edit_range: Range<Anchor>,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
cx: &AppContext,
) -> BoxFuture<'static, Result<TokenCounts>> {
if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
- let request = self.build_request(user_prompt, assistant_panel_context.clone(), cx);
+ let request =
+ self.build_request(user_prompt, assistant_panel_context.clone(), edit_range, cx);
match request {
Ok(request) => {
let total_count = model.count_tokens(request.clone(), cx);
@@ -2280,6 +2275,7 @@ impl Codegen {
pub fn start(
&mut self,
+ edit_range: Range<Anchor>,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
cx: &mut ModelContext<Self>,
@@ -2294,20 +2290,24 @@ impl Codegen {
});
}
- self.edit_position = Some(self.transform_range.start.bias_right(&self.snapshot));
+ self.edit_position = Some(edit_range.start.bias_right(&self.snapshot));
let telemetry_id = model.telemetry_id();
- let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> =
- if user_prompt.trim().to_lowercase() == "delete" {
- async { Ok(stream::empty().boxed()) }.boxed_local()
- } else {
- let request = self.build_request(user_prompt, assistant_panel_context, cx)?;
+ let chunks: LocalBoxFuture<Result<BoxStream<Result<String>>>> = if user_prompt
+ .trim()
+ .to_lowercase()
+ == "delete"
+ {
+ async { Ok(stream::empty().boxed()) }.boxed_local()
+ } else {
+ let request =
+ self.build_request(user_prompt, assistant_panel_context, edit_range.clone(), cx)?;
- let chunks =
- cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
- async move { Ok(chunks.await?.boxed()) }.boxed_local()
- };
- self.handle_stream(telemetry_id, self.transform_range.clone(), chunks, cx);
+ let chunks =
+ cx.spawn(|_, cx| async move { model.stream_completion(request, &cx).await });
+ async move { Ok(chunks.await?.boxed()) }.boxed_local()
+ };
+ self.handle_stream(telemetry_id, edit_range, chunks, cx);
Ok(())
}
@@ -2315,10 +2315,11 @@ impl Codegen {
&self,
user_prompt: String,
assistant_panel_context: Option<LanguageModelRequest>,
+ edit_range: Range<Anchor>,
cx: &AppContext,
) -> Result<LanguageModelRequest> {
let buffer = self.buffer.read(cx).snapshot(cx);
- let language = buffer.language_at(self.transform_range.start);
+ let language = buffer.language_at(edit_range.start);
let language_name = if let Some(language) = language.as_ref() {
if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
None
@@ -2343,9 +2344,9 @@ impl Codegen {
};
let language_name = language_name.as_deref();
- let start = buffer.point_to_buffer_offset(self.transform_range.start);
- let end = buffer.point_to_buffer_offset(self.transform_range.end);
- let (transform_buffer, transform_range) = if let Some((start, end)) = start.zip(end) {
+ let start = buffer.point_to_buffer_offset(edit_range.start);
+ let end = buffer.point_to_buffer_offset(edit_range.end);
+ let (buffer, range) = if let Some((start, end)) = start.zip(end) {
let (start_buffer, start_buffer_offset) = start;
let (end_buffer, end_buffer_offset) = end;
if start_buffer.remote_id() == end_buffer.remote_id() {
@@ -2357,39 +2358,9 @@ impl Codegen {
return Err(anyhow::anyhow!("invalid transformation range"));
};
- let mut transform_context_range = transform_range.to_point(&transform_buffer);
- transform_context_range.start.row = transform_context_range.start.row.saturating_sub(3);
- transform_context_range.start.column = 0;
- transform_context_range.end =
- (transform_context_range.end + Point::new(3, 0)).min(transform_buffer.max_point());
- transform_context_range.end.column =
- transform_buffer.line_len(transform_context_range.end.row);
- let transform_context_range = transform_context_range.to_offset(&transform_buffer);
-
- let selected_ranges = self
- .selected_ranges
- .iter()
- .filter_map(|selected_range| {
- let start = buffer
- .point_to_buffer_offset(selected_range.start)
- .map(|(_, offset)| offset)?;
- let end = buffer
- .point_to_buffer_offset(selected_range.end)
- .map(|(_, offset)| offset)?;
- Some(start..end)
- })
- .collect::<Vec<_>>();
-
let prompt = self
- .prompt_builder
- .generate_content_prompt(
- user_prompt,
- language_name,
- transform_buffer,
- transform_range,
- selected_ranges,
- transform_context_range,
- )
+ .builder
+ .generate_content_prompt(user_prompt, language_name, buffer, range)
.map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
let mut messages = Vec::new();
@@ -2462,19 +2433,84 @@ impl Codegen {
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
+ let mut new_text = String::new();
+ let mut base_indent = None;
+ let mut line_indent = None;
+ let mut first_line = true;
+
while let Some(chunk) = chunks.next().await {
if response_latency.is_none() {
response_latency = Some(request_start.elapsed());
}
let chunk = chunk?;
- let char_ops = diff.push_new(&chunk);
- line_diff.push_char_operations(&char_ops, &selected_text);
- diff_tx
- .send((char_ops, line_diff.line_operations()))
- .await?;
+
+ let mut lines = chunk.split('\n').peekable();
+ while let Some(line) = lines.next() {
+ new_text.push_str(line);
+ if line_indent.is_none() {
+ if let Some(non_whitespace_ch_ix) =
+ new_text.find(|ch: char| !ch.is_whitespace())
+ {
+ line_indent = Some(non_whitespace_ch_ix);
+ base_indent = base_indent.or(line_indent);
+
+ let line_indent = line_indent.unwrap();
+ let base_indent = base_indent.unwrap();
+ let indent_delta =
+ line_indent as i32 - base_indent as i32;
+ let mut corrected_indent_len = cmp::max(
+ 0,
+ suggested_line_indent.len as i32 + indent_delta,
+ )
+ as usize;
+ if first_line {
+ corrected_indent_len = corrected_indent_len
+ .saturating_sub(
+ selection_start.column as usize,
+ );
+ }
+
+ let indent_char = suggested_line_indent.char();
+ let mut indent_buffer = [0; 4];
+ let indent_str =
+ indent_char.encode_utf8(&mut indent_buffer);
+ new_text.replace_range(
+ ..line_indent,
+ &indent_str.repeat(corrected_indent_len),
+ );
+ }
+ }
+
+ if line_indent.is_some() {
+ let char_ops = diff.push_new(&new_text);
+ line_diff
+ .push_char_operations(&char_ops, &selected_text);
+ diff_tx
+ .send((char_ops, line_diff.line_operations()))
+ .await?;
+ new_text.clear();
+ }
+
+ if lines.peek().is_some() {
+ let char_ops = diff.push_new("\n");
+ line_diff
+ .push_char_operations(&char_ops, &selected_text);
+ diff_tx
+ .send((char_ops, line_diff.line_operations()))
+ .await?;
+ if line_indent.is_none() {
+ // Don't write out the leading indentation in empty lines on the next line
+ // This is the case where the above if statement didn't clear the buffer
+ new_text.clear();
+ }
+ line_indent = None;
+ first_line = false;
+ }
+ }
}
- let char_ops = diff.finish();
+ let mut char_ops = diff.push_new(&new_text);
+ char_ops.extend(diff.finish());
line_diff.push_char_operations(&char_ops, &selected_text);
line_diff.finish(&selected_text);
diff_tx
@@ -2938,13 +2974,311 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
mod tests {
use super::*;
use futures::stream::{self};
+ use gpui::{Context, TestAppContext};
+ use indoc::indoc;
+ use language::{
+ language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
+ Point,
+ };
+ use language_model::LanguageModelRegistry;
+ use rand::prelude::*;
use serde::Serialize;
+ use settings::SettingsStore;
+ use std::{future, sync::Arc};
#[derive(Serialize)]
pub struct DummyCompletionRequest {
pub name: String,
}
+ #[gpui::test(iterations = 10)]
+ async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_model::LanguageModelRegistry::test);
+ cx.update(language_settings::init);
+
+ let text = indoc! {"
+ fn main() {
+ let x = 0;
+ for _ in 0..10 {
+ x += 1;
+ }
+ }
+ "};
+ let buffer =
+ cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ Codegen::new(
+ buffer.clone(),
+ range.clone(),
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let (chunks_tx, chunks_rx) = mpsc::unbounded();
+ codegen.update(cx, |codegen, cx| {
+ codegen.handle_stream(
+ String::new(),
+ range,
+ future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
+ cx,
+ )
+ });
+
+ let mut new_text = concat!(
+ " let mut x = 0;\n",
+ " while x < 10 {\n",
+ " x += 1;\n",
+ " }",
+ );
+ while !new_text.is_empty() {
+ let max_len = cmp::min(new_text.len(), 10);
+ let len = rng.gen_range(1..=max_len);
+ let (chunk, suffix) = new_text.split_at(len);
+ chunks_tx.unbounded_send(chunk.to_string()).unwrap();
+ new_text = suffix;
+ cx.background_executor.run_until_parked();
+ }
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ fn main() {
+ let mut x = 0;
+ while x < 10 {
+ x += 1;
+ }
+ }
+ "}
+ );
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_autoindent_when_generating_past_indentation(
+ cx: &mut TestAppContext,
+ mut rng: StdRng,
+ ) {
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_settings::init);
+
+ let text = indoc! {"
+ fn main() {
+ le
+ }
+ "};
+ let buffer =
+ cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ Codegen::new(
+ buffer.clone(),
+ range.clone(),
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let (chunks_tx, chunks_rx) = mpsc::unbounded();
+ codegen.update(cx, |codegen, cx| {
+ codegen.handle_stream(
+ String::new(),
+ range.clone(),
+ future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
+ cx,
+ )
+ });
+
+ cx.background_executor.run_until_parked();
+
+ let mut new_text = concat!(
+ "t mut x = 0;\n",
+ "while x < 10 {\n",
+ " x += 1;\n",
+ "}", //
+ );
+ while !new_text.is_empty() {
+ let max_len = cmp::min(new_text.len(), 10);
+ let len = rng.gen_range(1..=max_len);
+ let (chunk, suffix) = new_text.split_at(len);
+ chunks_tx.unbounded_send(chunk.to_string()).unwrap();
+ new_text = suffix;
+ cx.background_executor.run_until_parked();
+ }
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ fn main() {
+ let mut x = 0;
+ while x < 10 {
+ x += 1;
+ }
+ }
+ "}
+ );
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_autoindent_when_generating_before_indentation(
+ cx: &mut TestAppContext,
+ mut rng: StdRng,
+ ) {
+ cx.update(LanguageModelRegistry::test);
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_settings::init);
+
+ let text = concat!(
+ "fn main() {\n",
+ " \n",
+ "}\n" //
+ );
+ let buffer =
+ cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ Codegen::new(
+ buffer.clone(),
+ range.clone(),
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let (chunks_tx, chunks_rx) = mpsc::unbounded();
+ codegen.update(cx, |codegen, cx| {
+ codegen.handle_stream(
+ String::new(),
+ range.clone(),
+ future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
+ cx,
+ )
+ });
+
+ cx.background_executor.run_until_parked();
+
+ let mut new_text = concat!(
+ "let mut x = 0;\n",
+ "while x < 10 {\n",
+ " x += 1;\n",
+ "}", //
+ );
+ while !new_text.is_empty() {
+ let max_len = cmp::min(new_text.len(), 10);
+ let len = rng.gen_range(1..=max_len);
+ let (chunk, suffix) = new_text.split_at(len);
+ chunks_tx.unbounded_send(chunk.to_string()).unwrap();
+ new_text = suffix;
+ cx.background_executor.run_until_parked();
+ }
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ fn main() {
+ let mut x = 0;
+ while x < 10 {
+ x += 1;
+ }
+ }
+ "}
+ );
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
+ cx.update(LanguageModelRegistry::test);
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_settings::init);
+
+ let text = indoc! {"
+ func main() {
+ \tx := 0
+ \tfor i := 0; i < 10; i++ {
+ \t\tx++
+ \t}
+ }
+ "};
+ let buffer = cx.new_model(|cx| Buffer::local(text, cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ Codegen::new(
+ buffer.clone(),
+ range.clone(),
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let (chunks_tx, chunks_rx) = mpsc::unbounded();
+ codegen.update(cx, |codegen, cx| {
+ codegen.handle_stream(
+ String::new(),
+ range.clone(),
+ future::ready(Ok(chunks_rx.map(|chunk| Ok(chunk)).boxed())),
+ cx,
+ )
+ });
+
+ let new_text = concat!(
+ "func main() {\n",
+ "\tx := 0\n",
+ "\tfor x < 10 {\n",
+ "\t\tx++\n",
+ "\t}", //
+ );
+ chunks_tx.unbounded_send(new_text.to_string()).unwrap();
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ func main() {
+ \tx := 0
+ \tfor x < 10 {
+ \t\tx++
+ \t}
+ }
+ "}
+ );
+ }
+
#[gpui::test]
async fn test_strip_invalid_spans_from_codeblock() {
assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
@@ -2984,4 +3318,27 @@ mod tests {
)
}
}
+
+ fn rust_lang() -> Language {
+ Language::new(
+ LanguageConfig {
+ name: "Rust".into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec!["rs".to_string()],
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ Some(tree_sitter_rust::language()),
+ )
+ .with_indents_query(
+ r#"
+ (call_expression) @indent
+ (field_expression) @indent
+ (_ "(" ")" @end) @indent
+ (_ "{" "}" @end) @indent
+ "#,
+ )
+ .unwrap()
+ }
}
@@ -12,15 +12,11 @@ use util::ResultExt;
pub struct ContentPromptContext {
pub content_type: String,
pub language_name: Option<String>,
+ pub is_insert: bool,
pub is_truncated: bool,
pub document_content: String,
pub user_prompt: String,
- pub rewrite_section: String,
- pub rewrite_section_prefix: String,
- pub rewrite_section_suffix: String,
- pub rewrite_section_with_edits: String,
- pub has_insertion: bool,
- pub has_replacement: bool,
+ pub rewrite_section: Option<String>,
}
#[derive(Serialize)]
@@ -46,54 +42,41 @@ pub struct PromptBuilder {
handlebars: Arc<Mutex<Handlebars<'static>>>,
}
-pub struct PromptOverrideContext<'a> {
- pub dev_mode: bool,
- pub fs: Arc<dyn Fs>,
- pub cx: &'a mut gpui::AppContext,
-}
-
impl PromptBuilder {
- pub fn new(override_cx: Option<PromptOverrideContext>) -> Result<Self, Box<TemplateError>> {
+ pub fn new(
+ fs_and_cx: Option<(Arc<dyn Fs>, &gpui::AppContext)>,
+ ) -> Result<Self, Box<TemplateError>> {
let mut handlebars = Handlebars::new();
Self::register_templates(&mut handlebars)?;
let handlebars = Arc::new(Mutex::new(handlebars));
- if let Some(override_cx) = override_cx {
- Self::watch_fs_for_template_overrides(override_cx, handlebars.clone());
+ if let Some((fs, cx)) = fs_and_cx {
+ Self::watch_fs_for_template_overrides(fs, cx, handlebars.clone());
}
Ok(Self { handlebars })
}
fn watch_fs_for_template_overrides(
- PromptOverrideContext { dev_mode, fs, cx }: PromptOverrideContext,
+ fs: Arc<dyn Fs>,
+ cx: &gpui::AppContext,
handlebars: Arc<Mutex<Handlebars<'static>>>,
) {
+ let templates_dir = paths::prompt_overrides_dir();
+
cx.background_executor()
.spawn(async move {
- let templates_dir = if dev_mode {
- std::env::current_dir()
- .ok()
- .and_then(|pwd| {
- let pwd_assets_prompts = pwd.join("assets").join("prompts");
- pwd_assets_prompts.exists().then_some(pwd_assets_prompts)
- })
- .unwrap_or_else(|| paths::prompt_overrides_dir().clone())
- } else {
- paths::prompt_overrides_dir().clone()
- };
-
// Create the prompt templates directory if it doesn't exist
- if !fs.is_dir(&templates_dir).await {
- if let Err(e) = fs.create_dir(&templates_dir).await {
+ if !fs.is_dir(templates_dir).await {
+ if let Err(e) = fs.create_dir(templates_dir).await {
log::error!("Failed to create prompt templates directory: {}", e);
return;
}
}
// Initial scan of the prompts directory
- if let Ok(mut entries) = fs.read_dir(&templates_dir).await {
+ if let Ok(mut entries) = fs.read_dir(templates_dir).await {
while let Some(Ok(file_path)) = entries.next().await {
if file_path.to_string_lossy().ends_with(".hbs") {
if let Ok(content) = fs.load(&file_path).await {
@@ -121,7 +104,7 @@ impl PromptBuilder {
}
// Watch for changes
- let (mut changes, watcher) = fs.watch(&templates_dir, Duration::from_secs(1)).await;
+ let (mut changes, watcher) = fs.watch(templates_dir, Duration::from_secs(1)).await;
while let Some(changed_paths) = changes.next().await {
for changed_path in changed_paths {
if changed_path.extension().map_or(false, |ext| ext == "hbs") {
@@ -173,9 +156,7 @@ impl PromptBuilder {
user_prompt: String,
language_name: Option<&str>,
buffer: BufferSnapshot,
- transform_range: Range<usize>,
- selected_ranges: Vec<Range<usize>>,
- transform_context_range: Range<usize>,
+ range: Range<usize>,
) -> Result<String, RenderError> {
let content_type = match language_name {
None | Some("Markdown" | "Plain Text") => "text",
@@ -183,20 +164,21 @@ impl PromptBuilder {
};
const MAX_CTX: usize = 50000;
+ let is_insert = range.is_empty();
let mut is_truncated = false;
- let before_range = 0..transform_range.start;
+ let before_range = 0..range.start;
let truncated_before = if before_range.len() > MAX_CTX {
is_truncated = true;
- transform_range.start - MAX_CTX..transform_range.start
+ range.start - MAX_CTX..range.start
} else {
before_range
};
- let after_range = transform_range.end..buffer.len();
+ let after_range = range.end..buffer.len();
let truncated_after = if after_range.len() > MAX_CTX {
is_truncated = true;
- transform_range.end..transform_range.end + MAX_CTX
+ range.end..range.end + MAX_CTX
} else {
after_range
};
@@ -205,74 +187,37 @@ impl PromptBuilder {
for chunk in buffer.text_for_range(truncated_before) {
document_content.push_str(chunk);
}
-
- document_content.push_str("<rewrite_this>\n");
- for chunk in buffer.text_for_range(transform_range.clone()) {
- document_content.push_str(chunk);
+ if is_insert {
+ document_content.push_str("<insert_here></insert_here>");
+ } else {
+ document_content.push_str("<rewrite_this>\n");
+ for chunk in buffer.text_for_range(range.clone()) {
+ document_content.push_str(chunk);
+ }
+ document_content.push_str("\n</rewrite_this>");
}
- document_content.push_str("\n</rewrite_this>");
-
for chunk in buffer.text_for_range(truncated_after) {
document_content.push_str(chunk);
}
- let mut rewrite_section = String::new();
- for chunk in buffer.text_for_range(transform_range.clone()) {
- rewrite_section.push_str(chunk);
- }
-
- let mut rewrite_section_prefix = String::new();
- for chunk in buffer.text_for_range(transform_context_range.start..transform_range.start) {
- rewrite_section_prefix.push_str(chunk);
- }
-
- let mut rewrite_section_suffix = String::new();
- for chunk in buffer.text_for_range(transform_range.end..transform_context_range.end) {
- rewrite_section_suffix.push_str(chunk);
- }
-
- let rewrite_section_with_edits = {
- let mut section_with_selections = String::new();
- let mut last_end = 0;
- for selected_range in &selected_ranges {
- if selected_range.start > last_end {
- section_with_selections.push_str(
- &rewrite_section[last_end..selected_range.start - transform_range.start],
- );
- }
- if selected_range.start == selected_range.end {
- section_with_selections.push_str("<insert_here></insert_here>");
- } else {
- section_with_selections.push_str("<edit_here>");
- section_with_selections.push_str(
- &rewrite_section[selected_range.start - transform_range.start
- ..selected_range.end - transform_range.start],
- );
- section_with_selections.push_str("</edit_here>");
- }
- last_end = selected_range.end - transform_range.start;
- }
- if last_end < rewrite_section.len() {
- section_with_selections.push_str(&rewrite_section[last_end..]);
+ let rewrite_section = if !is_insert {
+ let mut section = String::new();
+ for chunk in buffer.text_for_range(range.clone()) {
+ section.push_str(chunk);
}
- section_with_selections
+ Some(section)
+ } else {
+ None
};
- let has_insertion = selected_ranges.iter().any(|range| range.start == range.end);
- let has_replacement = selected_ranges.iter().any(|range| range.start != range.end);
-
let context = ContentPromptContext {
content_type: content_type.to_string(),
language_name: language_name.map(|s| s.to_string()),
+ is_insert,
is_truncated,
document_content,
user_prompt,
rewrite_section,
- rewrite_section_prefix,
- rewrite_section_suffix,
- rewrite_section_with_edits,
- has_insertion,
- has_replacement,
};
self.handlebars.lock().render("content_prompt", &context)