1use crate::streaming_diff::{Hunk, StreamingDiff};
2use ai::completion::{CompletionProvider, CompletionRequest};
3use anyhow::Result;
4use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
5use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
6use gpui::{Entity, ModelContext, ModelHandle, Task};
7use language::{Rope, TransactionId};
8use multi_buffer;
9use std::{cmp, future, ops::Range, sync::Arc};
10
/// Notifications emitted by a [`Codegen`] model to its observers.
pub enum Event {
    /// A generation run has ended, either successfully or with an error
    /// (see [`Codegen::error`]).
    Finished,
    /// The buffer transaction holding this session's edits was undone.
    Undone,
}
15
16#[derive(Clone)]
17pub enum CodegenKind {
18 Transform { range: Range<Anchor> },
19 Generate { position: Anchor },
20}
21
22pub struct Codegen {
23 provider: Arc<dyn CompletionProvider>,
24 buffer: ModelHandle<MultiBuffer>,
25 snapshot: MultiBufferSnapshot,
26 kind: CodegenKind,
27 last_equal_ranges: Vec<Range<Anchor>>,
28 transaction_id: Option<TransactionId>,
29 error: Option<anyhow::Error>,
30 generation: Task<()>,
31 idle: bool,
32 _subscription: gpui::Subscription,
33}
34
35impl Entity for Codegen {
36 type Event = Event;
37}
38
39impl Codegen {
40 pub fn new(
41 buffer: ModelHandle<MultiBuffer>,
42 kind: CodegenKind,
43 provider: Arc<dyn CompletionProvider>,
44 cx: &mut ModelContext<Self>,
45 ) -> Self {
46 let snapshot = buffer.read(cx).snapshot(cx);
47 Self {
48 provider,
49 buffer: buffer.clone(),
50 snapshot,
51 kind,
52 last_equal_ranges: Default::default(),
53 transaction_id: Default::default(),
54 error: Default::default(),
55 idle: true,
56 generation: Task::ready(()),
57 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
58 }
59 }
60
61 fn handle_buffer_event(
62 &mut self,
63 _buffer: ModelHandle<MultiBuffer>,
64 event: &multi_buffer::Event,
65 cx: &mut ModelContext<Self>,
66 ) {
67 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
68 if self.transaction_id == Some(*transaction_id) {
69 self.transaction_id = None;
70 self.generation = Task::ready(());
71 cx.emit(Event::Undone);
72 }
73 }
74 }
75
76 pub fn range(&self) -> Range<Anchor> {
77 match &self.kind {
78 CodegenKind::Transform { range } => range.clone(),
79 CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
80 }
81 }
82
83 pub fn kind(&self) -> &CodegenKind {
84 &self.kind
85 }
86
87 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
88 &self.last_equal_ranges
89 }
90
91 pub fn idle(&self) -> bool {
92 self.idle
93 }
94
95 pub fn error(&self) -> Option<&anyhow::Error> {
96 self.error.as_ref()
97 }
98
99 pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
100 let range = self.range();
101 let snapshot = self.snapshot.clone();
102 let selected_text = snapshot
103 .text_for_range(range.start..range.end)
104 .collect::<Rope>();
105
106 let selection_start = range.start.to_point(&snapshot);
107 let suggested_line_indent = snapshot
108 .suggested_indents(selection_start.row..selection_start.row + 1, cx)
109 .into_values()
110 .next()
111 .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
112
113 let response = self.provider.complete(prompt);
114 self.generation = cx.spawn_weak(|this, mut cx| {
115 async move {
116 let generate = async {
117 let mut edit_start = range.start.to_offset(&snapshot);
118
119 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
120 let diff = cx.background().spawn(async move {
121 let chunks = strip_invalid_spans_from_codeblock(response.await?);
122 futures::pin_mut!(chunks);
123 let mut diff = StreamingDiff::new(selected_text.to_string());
124
125 let mut new_text = String::new();
126 let mut base_indent = None;
127 let mut line_indent = None;
128 let mut first_line = true;
129
130 while let Some(chunk) = chunks.next().await {
131 let chunk = chunk?;
132
133 let mut lines = chunk.split('\n').peekable();
134 while let Some(line) = lines.next() {
135 new_text.push_str(line);
136 if line_indent.is_none() {
137 if let Some(non_whitespace_ch_ix) =
138 new_text.find(|ch: char| !ch.is_whitespace())
139 {
140 line_indent = Some(non_whitespace_ch_ix);
141 base_indent = base_indent.or(line_indent);
142
143 let line_indent = line_indent.unwrap();
144 let base_indent = base_indent.unwrap();
145 let indent_delta = line_indent as i32 - base_indent as i32;
146 let mut corrected_indent_len = cmp::max(
147 0,
148 suggested_line_indent.len as i32 + indent_delta,
149 )
150 as usize;
151 if first_line {
152 corrected_indent_len = corrected_indent_len
153 .saturating_sub(selection_start.column as usize);
154 }
155
156 let indent_char = suggested_line_indent.char();
157 let mut indent_buffer = [0; 4];
158 let indent_str =
159 indent_char.encode_utf8(&mut indent_buffer);
160 new_text.replace_range(
161 ..line_indent,
162 &indent_str.repeat(corrected_indent_len),
163 );
164 }
165 }
166
167 if line_indent.is_some() {
168 hunks_tx.send(diff.push_new(&new_text)).await?;
169 new_text.clear();
170 }
171
172 if lines.peek().is_some() {
173 hunks_tx.send(diff.push_new("\n")).await?;
174 line_indent = None;
175 first_line = false;
176 }
177 }
178 }
179 hunks_tx.send(diff.push_new(&new_text)).await?;
180 hunks_tx.send(diff.finish()).await?;
181
182 anyhow::Ok(())
183 });
184
185 while let Some(hunks) = hunks_rx.next().await {
186 let this = if let Some(this) = this.upgrade(&cx) {
187 this
188 } else {
189 break;
190 };
191
192 this.update(&mut cx, |this, cx| {
193 this.last_equal_ranges.clear();
194
195 let transaction = this.buffer.update(cx, |buffer, cx| {
196 // Avoid grouping assistant edits with user edits.
197 buffer.finalize_last_transaction(cx);
198
199 buffer.start_transaction(cx);
200 buffer.edit(
201 hunks.into_iter().filter_map(|hunk| match hunk {
202 Hunk::Insert { text } => {
203 let edit_start = snapshot.anchor_after(edit_start);
204 Some((edit_start..edit_start, text))
205 }
206 Hunk::Remove { len } => {
207 let edit_end = edit_start + len;
208 let edit_range = snapshot.anchor_after(edit_start)
209 ..snapshot.anchor_before(edit_end);
210 edit_start = edit_end;
211 Some((edit_range, String::new()))
212 }
213 Hunk::Keep { len } => {
214 let edit_end = edit_start + len;
215 let edit_range = snapshot.anchor_after(edit_start)
216 ..snapshot.anchor_before(edit_end);
217 edit_start = edit_end;
218 this.last_equal_ranges.push(edit_range);
219 None
220 }
221 }),
222 None,
223 cx,
224 );
225
226 buffer.end_transaction(cx)
227 });
228
229 if let Some(transaction) = transaction {
230 if let Some(first_transaction) = this.transaction_id {
231 // Group all assistant edits into the first transaction.
232 this.buffer.update(cx, |buffer, cx| {
233 buffer.merge_transactions(
234 transaction,
235 first_transaction,
236 cx,
237 )
238 });
239 } else {
240 this.transaction_id = Some(transaction);
241 this.buffer.update(cx, |buffer, cx| {
242 buffer.finalize_last_transaction(cx)
243 });
244 }
245 }
246
247 cx.notify();
248 });
249 }
250
251 diff.await?;
252 anyhow::Ok(())
253 };
254
255 let result = generate.await;
256 if let Some(this) = this.upgrade(&cx) {
257 this.update(&mut cx, |this, cx| {
258 this.last_equal_ranges.clear();
259 this.idle = true;
260 if let Err(error) = result {
261 this.error = Some(error);
262 }
263 cx.emit(Event::Finished);
264 cx.notify();
265 });
266 }
267 }
268 });
269 self.error.take();
270 self.idle = false;
271 cx.notify();
272 }
273
274 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
275 if let Some(transaction_id) = self.transaction_id {
276 self.buffer
277 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
278 }
279 }
280}
281
282fn strip_invalid_spans_from_codeblock(
283 stream: impl Stream<Item = Result<String>>,
284) -> impl Stream<Item = Result<String>> {
285 let mut first_line = true;
286 let mut buffer = String::new();
287 let mut starts_with_markdown_codeblock = false;
288 let mut includes_start_or_end_span = false;
289 stream.filter_map(move |chunk| {
290 let chunk = match chunk {
291 Ok(chunk) => chunk,
292 Err(err) => return future::ready(Some(Err(err))),
293 };
294 buffer.push_str(&chunk);
295
296 if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
297 includes_start_or_end_span = true;
298
299 buffer = buffer
300 .strip_prefix("<|S|>")
301 .or_else(|| buffer.strip_prefix("<|S|"))
302 .unwrap_or(&buffer)
303 .to_string();
304 } else if buffer.ends_with("|E|>") {
305 includes_start_or_end_span = true;
306 } else if buffer.starts_with("<|")
307 || buffer.starts_with("<|S")
308 || buffer.starts_with("<|S|")
309 || buffer.ends_with("|")
310 || buffer.ends_with("|E")
311 || buffer.ends_with("|E|")
312 {
313 return future::ready(None);
314 }
315
316 if first_line {
317 if buffer == "" || buffer == "`" || buffer == "``" {
318 return future::ready(None);
319 } else if buffer.starts_with("```") {
320 starts_with_markdown_codeblock = true;
321 if let Some(newline_ix) = buffer.find('\n') {
322 buffer.replace_range(..newline_ix + 1, "");
323 first_line = false;
324 } else {
325 return future::ready(None);
326 }
327 }
328 }
329
330 let mut text = buffer.to_string();
331 if starts_with_markdown_codeblock {
332 text = text
333 .strip_suffix("\n```\n")
334 .or_else(|| text.strip_suffix("\n```"))
335 .or_else(|| text.strip_suffix("\n``"))
336 .or_else(|| text.strip_suffix("\n`"))
337 .or_else(|| text.strip_suffix('\n'))
338 .unwrap_or(&text)
339 .to_string();
340 }
341
342 if includes_start_or_end_span {
343 text = text
344 .strip_suffix("|E|>")
345 .or_else(|| text.strip_suffix("E|>"))
346 .or_else(|| text.strip_prefix("|>"))
347 .or_else(|| text.strip_prefix(">"))
348 .unwrap_or(&text)
349 .to_string();
350 };
351
352 if text.contains('\n') {
353 first_line = false;
354 }
355
356 let remainder = buffer.split_off(text.len());
357 let result = if buffer.is_empty() {
358 None
359 } else {
360 Some(Ok(buffer.clone()))
361 };
362
363 buffer = remainder;
364 future::ready(result)
365 })
366}
367
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use ai::test::FakeCompletionProvider;
    use futures::stream::{self};
    use gpui::{executor::Deterministic, TestAppContext};
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;

    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    impl CompletionRequest for DummyCompletionRequest {
        fn data(&self) -> serde_json::Result<String> {
            serde_json::to_string(self)
        }
    }

    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = rust_multi_buffer(cx, text);
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        // Keep the handle alive: the generation task only holds a weak reference.
        let _codegen = start_codegen(cx, &buffer, CodegenKind::Transform { range }, &provider);

        let new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        stream_completion(&provider, new_text, &mut rng, &deterministic);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = rust_multi_buffer(cx, text);
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let _codegen = start_codegen(cx, &buffer, CodegenKind::Generate { position }, &provider);

        let new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_completion(&provider, new_text, &mut rng, &deterministic);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        init_test(cx);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = rust_multi_buffer(cx, text);
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let _codegen = start_codegen(cx, &buffer, CodegenKind::Generate { position }, &provider);

        let new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_completion(&provider, new_text, &mut rng, &deterministic);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_eq!(strip("Lorem ipsum dolor").await, "Lorem ipsum dolor");
        assert_eq!(strip("```\nLorem ipsum dolor").await, "Lorem ipsum dolor");
        assert_eq!(strip("```\nLorem ipsum dolor\n```").await, "Lorem ipsum dolor");
        assert_eq!(strip("```\nLorem ipsum dolor\n```\n").await, "Lorem ipsum dolor");
        assert_eq!(
            strip("```html\n```js\nLorem ipsum dolor\n```\n```").await,
            "```js\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            strip("``\nLorem ipsum dolor\n```").await,
            "``\nLorem ipsum dolor\n```"
        );
        assert_eq!(strip("<|S|Lorem ipsum|E|>").await, "Lorem ipsum");
        assert_eq!(strip("<|S|>Lorem ipsum").await, "Lorem ipsum");
        assert_eq!(strip("```\n<|S|>Lorem ipsum\n```").await, "Lorem ipsum");
        assert_eq!(strip("```\n<|S|Lorem ipsum|E|>\n```").await, "Lorem ipsum");
    }

    /// Installs the global settings every codegen test needs.
    fn init_test(cx: &mut TestAppContext) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);
    }

    /// Wraps `text` in a Rust-language buffer inside a singleton multi-buffer.
    fn rust_multi_buffer(cx: &mut TestAppContext, text: &str) -> ModelHandle<MultiBuffer> {
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        cx.add_model(|cx| MultiBuffer::singleton(buffer, cx))
    }

    /// Creates a codegen session over `buffer` and starts it with a dummy request.
    ///
    /// The returned handle must be kept alive for the duration of the test: the
    /// generation task only holds a weak reference and stops if it is dropped.
    fn start_codegen(
        cx: &mut TestAppContext,
        buffer: &ModelHandle<MultiBuffer>,
        kind: CodegenKind,
        provider: &Arc<FakeCompletionProvider>,
    ) -> ModelHandle<Codegen> {
        let codegen =
            cx.add_model(|cx| Codegen::new(buffer.clone(), kind, provider.clone(), cx));
        let request = Box::new(DummyCompletionRequest {
            name: "test".to_string(),
        });
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
        codegen
    }

    /// Feeds `text` to `provider` in randomly sized chunks of at most 10 bytes,
    /// letting the executor settle after each chunk, then finishes the completion.
    fn stream_completion(
        provider: &FakeCompletionProvider,
        mut text: &str,
        rng: &mut StdRng,
        deterministic: &Deterministic,
    ) {
        while !text.is_empty() {
            let max_len = cmp::min(text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = text.split_at(len);
            provider.send_completion(chunk);
            text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();
    }

    /// Runs `text`, split into two-character chunks, through
    /// `strip_invalid_spans_from_codeblock` and concatenates the output.
    async fn strip(text: &str) -> String {
        strip_invalid_spans_from_codeblock(chunks(text, 2))
            .map(|chunk| chunk.unwrap())
            .collect::<String>()
            .await
    }

    /// Splits `text` into a stream of `size`-character chunks.
    fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
        stream::iter(
            text.chars()
                .collect::<Vec<_>>()
                .chunks(size)
                .map(|chunk| Ok(chunk.iter().collect::<String>()))
                .collect::<Vec<_>>(),
        )
    }

    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}