1use crate::streaming_diff::{Hunk, StreamingDiff};
2use ai::completion::{CompletionProvider, CompletionRequest};
3use anyhow::Result;
4use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
5use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
6use gpui::{EventEmitter, Model, ModelContext, Task};
7use language::{Rope, TransactionId};
8use multi_buffer;
9use std::{cmp, future, ops::Range, sync::Arc};
10
/// Notifications a [`Codegen`] emits to its subscribers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Event {
    /// The generation task has run to completion, whether it succeeded or recorded an error.
    Finished,
    /// The buffer reported that the transaction holding the generated edits was undone.
    Undone,
}
15
16#[derive(Clone)]
17pub enum CodegenKind {
18 Transform { range: Range<Anchor> },
19 Generate { position: Anchor },
20}
21
22pub struct Codegen {
23 provider: Arc<dyn CompletionProvider>,
24 buffer: Model<MultiBuffer>,
25 snapshot: MultiBufferSnapshot,
26 kind: CodegenKind,
27 last_equal_ranges: Vec<Range<Anchor>>,
28 transaction_id: Option<TransactionId>,
29 error: Option<anyhow::Error>,
30 generation: Task<()>,
31 idle: bool,
32 _subscription: gpui::Subscription,
33}
34
35impl EventEmitter<Event> for Codegen {}
36
37impl Codegen {
38 pub fn new(
39 buffer: Model<MultiBuffer>,
40 kind: CodegenKind,
41 provider: Arc<dyn CompletionProvider>,
42 cx: &mut ModelContext<Self>,
43 ) -> Self {
44 let snapshot = buffer.read(cx).snapshot(cx);
45 Self {
46 provider,
47 buffer: buffer.clone(),
48 snapshot,
49 kind,
50 last_equal_ranges: Default::default(),
51 transaction_id: Default::default(),
52 error: Default::default(),
53 idle: true,
54 generation: Task::ready(()),
55 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
56 }
57 }
58
59 fn handle_buffer_event(
60 &mut self,
61 _buffer: Model<MultiBuffer>,
62 event: &multi_buffer::Event,
63 cx: &mut ModelContext<Self>,
64 ) {
65 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
66 if self.transaction_id == Some(*transaction_id) {
67 self.transaction_id = None;
68 self.generation = Task::ready(());
69 cx.emit(Event::Undone);
70 }
71 }
72 }
73
74 pub fn range(&self) -> Range<Anchor> {
75 match &self.kind {
76 CodegenKind::Transform { range } => range.clone(),
77 CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
78 }
79 }
80
81 pub fn kind(&self) -> &CodegenKind {
82 &self.kind
83 }
84
85 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
86 &self.last_equal_ranges
87 }
88
89 pub fn idle(&self) -> bool {
90 self.idle
91 }
92
93 pub fn error(&self) -> Option<&anyhow::Error> {
94 self.error.as_ref()
95 }
96
97 pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
98 let range = self.range();
99 let snapshot = self.snapshot.clone();
100 let selected_text = snapshot
101 .text_for_range(range.start..range.end)
102 .collect::<Rope>();
103
104 let selection_start = range.start.to_point(&snapshot);
105 let suggested_line_indent = snapshot
106 .suggested_indents(selection_start.row..selection_start.row + 1, cx)
107 .into_values()
108 .next()
109 .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
110
111 let response = self.provider.complete(prompt);
112 self.generation = cx.spawn_weak(|this, mut cx| {
113 async move {
114 let generate = async {
115 let mut edit_start = range.start.to_offset(&snapshot);
116
117 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
118 let diff = cx.background().spawn(async move {
119 let chunks = strip_invalid_spans_from_codeblock(response.await?);
120 futures::pin_mut!(chunks);
121 let mut diff = StreamingDiff::new(selected_text.to_string());
122
123 let mut new_text = String::new();
124 let mut base_indent = None;
125 let mut line_indent = None;
126 let mut first_line = true;
127
128 while let Some(chunk) = chunks.next().await {
129 let chunk = chunk?;
130
131 let mut lines = chunk.split('\n').peekable();
132 while let Some(line) = lines.next() {
133 new_text.push_str(line);
134 if line_indent.is_none() {
135 if let Some(non_whitespace_ch_ix) =
136 new_text.find(|ch: char| !ch.is_whitespace())
137 {
138 line_indent = Some(non_whitespace_ch_ix);
139 base_indent = base_indent.or(line_indent);
140
141 let line_indent = line_indent.unwrap();
142 let base_indent = base_indent.unwrap();
143 let indent_delta = line_indent as i32 - base_indent as i32;
144 let mut corrected_indent_len = cmp::max(
145 0,
146 suggested_line_indent.len as i32 + indent_delta,
147 )
148 as usize;
149 if first_line {
150 corrected_indent_len = corrected_indent_len
151 .saturating_sub(selection_start.column as usize);
152 }
153
154 let indent_char = suggested_line_indent.char();
155 let mut indent_buffer = [0; 4];
156 let indent_str =
157 indent_char.encode_utf8(&mut indent_buffer);
158 new_text.replace_range(
159 ..line_indent,
160 &indent_str.repeat(corrected_indent_len),
161 );
162 }
163 }
164
165 if line_indent.is_some() {
166 hunks_tx.send(diff.push_new(&new_text)).await?;
167 new_text.clear();
168 }
169
170 if lines.peek().is_some() {
171 hunks_tx.send(diff.push_new("\n")).await?;
172 line_indent = None;
173 first_line = false;
174 }
175 }
176 }
177 hunks_tx.send(diff.push_new(&new_text)).await?;
178 hunks_tx.send(diff.finish()).await?;
179
180 anyhow::Ok(())
181 });
182
183 while let Some(hunks) = hunks_rx.next().await {
184 let this = if let Some(this) = this.upgrade(&cx) {
185 this
186 } else {
187 break;
188 };
189
190 this.update(&mut cx, |this, cx| {
191 this.last_equal_ranges.clear();
192
193 let transaction = this.buffer.update(cx, |buffer, cx| {
194 // Avoid grouping assistant edits with user edits.
195 buffer.finalize_last_transaction(cx);
196
197 buffer.start_transaction(cx);
198 buffer.edit(
199 hunks.into_iter().filter_map(|hunk| match hunk {
200 Hunk::Insert { text } => {
201 let edit_start = snapshot.anchor_after(edit_start);
202 Some((edit_start..edit_start, text))
203 }
204 Hunk::Remove { len } => {
205 let edit_end = edit_start + len;
206 let edit_range = snapshot.anchor_after(edit_start)
207 ..snapshot.anchor_before(edit_end);
208 edit_start = edit_end;
209 Some((edit_range, String::new()))
210 }
211 Hunk::Keep { len } => {
212 let edit_end = edit_start + len;
213 let edit_range = snapshot.anchor_after(edit_start)
214 ..snapshot.anchor_before(edit_end);
215 edit_start = edit_end;
216 this.last_equal_ranges.push(edit_range);
217 None
218 }
219 }),
220 None,
221 cx,
222 );
223
224 buffer.end_transaction(cx)
225 });
226
227 if let Some(transaction) = transaction {
228 if let Some(first_transaction) = this.transaction_id {
229 // Group all assistant edits into the first transaction.
230 this.buffer.update(cx, |buffer, cx| {
231 buffer.merge_transactions(
232 transaction,
233 first_transaction,
234 cx,
235 )
236 });
237 } else {
238 this.transaction_id = Some(transaction);
239 this.buffer.update(cx, |buffer, cx| {
240 buffer.finalize_last_transaction(cx)
241 });
242 }
243 }
244
245 cx.notify();
246 });
247 }
248
249 diff.await?;
250 anyhow::Ok(())
251 };
252
253 let result = generate.await;
254 if let Some(this) = this.upgrade(&cx) {
255 this.update(&mut cx, |this, cx| {
256 this.last_equal_ranges.clear();
257 this.idle = true;
258 if let Err(error) = result {
259 this.error = Some(error);
260 }
261 cx.emit(Event::Finished);
262 cx.notify();
263 });
264 }
265 }
266 });
267 self.error.take();
268 self.idle = false;
269 cx.notify();
270 }
271
272 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
273 if let Some(transaction_id) = self.transaction_id {
274 self.buffer
275 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
276 }
277 }
278}
279
280fn strip_invalid_spans_from_codeblock(
281 stream: impl Stream<Item = Result<String>>,
282) -> impl Stream<Item = Result<String>> {
283 let mut first_line = true;
284 let mut buffer = String::new();
285 let mut starts_with_markdown_codeblock = false;
286 let mut includes_start_or_end_span = false;
287 stream.filter_map(move |chunk| {
288 let chunk = match chunk {
289 Ok(chunk) => chunk,
290 Err(err) => return future::ready(Some(Err(err))),
291 };
292 buffer.push_str(&chunk);
293
294 if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
295 includes_start_or_end_span = true;
296
297 buffer = buffer
298 .strip_prefix("<|S|>")
299 .or_else(|| buffer.strip_prefix("<|S|"))
300 .unwrap_or(&buffer)
301 .to_string();
302 } else if buffer.ends_with("|E|>") {
303 includes_start_or_end_span = true;
304 } else if buffer.starts_with("<|")
305 || buffer.starts_with("<|S")
306 || buffer.starts_with("<|S|")
307 || buffer.ends_with("|")
308 || buffer.ends_with("|E")
309 || buffer.ends_with("|E|")
310 {
311 return future::ready(None);
312 }
313
314 if first_line {
315 if buffer == "" || buffer == "`" || buffer == "``" {
316 return future::ready(None);
317 } else if buffer.starts_with("```") {
318 starts_with_markdown_codeblock = true;
319 if let Some(newline_ix) = buffer.find('\n') {
320 buffer.replace_range(..newline_ix + 1, "");
321 first_line = false;
322 } else {
323 return future::ready(None);
324 }
325 }
326 }
327
328 let mut text = buffer.to_string();
329 if starts_with_markdown_codeblock {
330 text = text
331 .strip_suffix("\n```\n")
332 .or_else(|| text.strip_suffix("\n```"))
333 .or_else(|| text.strip_suffix("\n``"))
334 .or_else(|| text.strip_suffix("\n`"))
335 .or_else(|| text.strip_suffix('\n'))
336 .unwrap_or(&text)
337 .to_string();
338 }
339
340 if includes_start_or_end_span {
341 text = text
342 .strip_suffix("|E|>")
343 .or_else(|| text.strip_suffix("E|>"))
344 .or_else(|| text.strip_prefix("|>"))
345 .or_else(|| text.strip_prefix(">"))
346 .unwrap_or(&text)
347 .to_string();
348 };
349
350 if text.contains('\n') {
351 first_line = false;
352 }
353
354 let remainder = buffer.split_off(text.len());
355 let result = if buffer.is_empty() {
356 None
357 } else {
358 Some(Ok(buffer.clone()))
359 };
360
361 buffer = remainder;
362 future::ready(result)
363 })
364}
365
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use ai::test::FakeCompletionProvider;
    use futures::stream::{self};
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;

    /// Minimal completion request used to drive `Codegen::start` in tests.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    impl CompletionRequest for DummyCompletionRequest {
        fn data(&self) -> serde_json::Result<String> {
            serde_json::to_string(self)
        }
    }

    /// Sends `text` through `provider` in randomly sized chunks of 1 to 10 bytes, letting the
    /// executor settle after each chunk, then finishes the completion.
    fn stream_completion(
        provider: &FakeCompletionProvider,
        mut text: &str,
        rng: &mut StdRng,
        cx: &mut TestAppContext,
    ) {
        while !text.is_empty() {
            let max_len = cmp::min(text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = text.split_at(len);
            provider.send_completion(chunk);
            text = suffix;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();
    }

    /// Transforming a multi-line selection re-indents the completion to match the buffer.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Transform { range },
                provider.clone(),
                cx,
            )
        });

        let request = Box::new(DummyCompletionRequest {
            name: "test".to_string(),
        });
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        // The completion uses an arbitrary base indentation; only the relative nesting must
        // survive the re-indentation.
        let new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        stream_completion(&provider, new_text, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating after existing text on an indented line does not indent the first line again.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });

        let request = Box::new(DummyCompletionRequest {
            name: "test".to_string(),
        });
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_completion(&provider, new_text, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating inside a partially indented blank line tops the indentation up to the
    /// suggested amount.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });

        let request = Box::new(DummyCompletionRequest {
            name: "test".to_string(),
        });
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_completion(&provider, new_text, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Markers and markdown fences are stripped even when they are split across chunks.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        /// Splits `text` into chunks of `size` characters and wraps them in a stream.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }

        /// Runs `text` through the filter in two-character chunks and joins the output.
        async fn stripped(text: &str) -> String {
            strip_invalid_spans_from_codeblock(chunks(text, 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await
        }

        let cases = [
            ("Lorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```\n", "Lorem ipsum dolor"),
            (
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                "```js\nLorem ipsum dolor\n```",
            ),
            ("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```"),
            ("<|S|Lorem ipsum|E|>", "Lorem ipsum"),
            ("<|S|>Lorem ipsum", "Lorem ipsum"),
            ("```\n<|S|>Lorem ipsum\n```", "Lorem ipsum"),
            ("```\n<|S|Lorem ipsum|E|>\n```", "Lorem ipsum"),
        ];
        for (input, expected) in cases {
            assert_eq!(stripped(input).await, expected, "input: {:?}", input);
        }
    }

    /// Builds a Rust language with the indentation query the autoindent tests rely on.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}