1use crate::streaming_diff::{Hunk, StreamingDiff};
2use ai::completion::{CompletionProvider, CompletionRequest};
3use anyhow::Result;
4use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
5use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
6use gpui::{EventEmitter, Model, ModelContext, Task};
7use language::{Rope, TransactionId};
8use multi_buffer;
9use std::{cmp, future, ops::Range, sync::Arc};
10
/// Notifications emitted by a [`Codegen`] model.
pub enum Event {
    /// The generation started by [`Codegen::start`] has ended; if it failed, the error is
    /// available through [`Codegen::error`].
    Finished,
    /// The buffer transaction holding this session's edits was undone.
    Undone,
}
15
/// What a [`Codegen`] session is asked to produce.
#[derive(Clone)]
pub enum CodegenKind {
    /// Rewrite the existing text in `range`; the model output is diffed against that text.
    Transform { range: Range<Anchor> },
    /// Insert new text at `position`; the diff base is the empty range at that position.
    Generate { position: Anchor },
}
21
/// Streams a model completion into a buffer as a sequence of diff hunks, merging every edit it
/// makes into a single buffer transaction.
pub struct Codegen {
    /// Source of the completion stream.
    provider: Arc<dyn CompletionProvider>,
    /// Buffer being edited.
    buffer: Model<MultiBuffer>,
    /// Snapshot taken in `Codegen::new`; `start` resolves offsets in it and diffs against its
    /// text.
    snapshot: MultiBufferSnapshot,
    kind: CodegenKind,
    /// Ranges left unchanged (`Hunk::Keep`) by the most recently applied batch of hunks; cleared
    /// when the generation ends.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// The first transaction created by this session; later edits are merged into it.
    transaction_id: Option<TransactionId>,
    /// Error from the most recent generation, if it failed.
    error: Option<anyhow::Error>,
    /// The in-flight generation task. Replacing it drops the previous task (presumably
    /// cancelling it, as gpui tasks do on drop).
    generation: Task<()>,
    /// `false` from `start` until the generation task finishes.
    idle: bool,
    /// Held so the buffer-event subscription stays registered while this model is alive.
    _subscription: gpui::Subscription,
}
34
// Lets observers subscribe to `Event::Finished` / `Event::Undone` from a `Codegen` model.
impl EventEmitter<Event> for Codegen {}
36
37impl Codegen {
38 pub fn new(
39 buffer: Model<MultiBuffer>,
40 kind: CodegenKind,
41 provider: Arc<dyn CompletionProvider>,
42 cx: &mut ModelContext<Self>,
43 ) -> Self {
44 let snapshot = buffer.read(cx).snapshot(cx);
45 Self {
46 provider,
47 buffer: buffer.clone(),
48 snapshot,
49 kind,
50 last_equal_ranges: Default::default(),
51 transaction_id: Default::default(),
52 error: Default::default(),
53 idle: true,
54 generation: Task::ready(()),
55 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
56 }
57 }
58
59 fn handle_buffer_event(
60 &mut self,
61 _buffer: Model<MultiBuffer>,
62 event: &multi_buffer::Event,
63 cx: &mut ModelContext<Self>,
64 ) {
65 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
66 if self.transaction_id == Some(*transaction_id) {
67 self.transaction_id = None;
68 self.generation = Task::ready(());
69 cx.emit(Event::Undone);
70 }
71 }
72 }
73
74 pub fn range(&self) -> Range<Anchor> {
75 match &self.kind {
76 CodegenKind::Transform { range } => range.clone(),
77 CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
78 }
79 }
80
81 pub fn kind(&self) -> &CodegenKind {
82 &self.kind
83 }
84
85 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
86 &self.last_equal_ranges
87 }
88
89 pub fn idle(&self) -> bool {
90 self.idle
91 }
92
93 pub fn error(&self) -> Option<&anyhow::Error> {
94 self.error.as_ref()
95 }
96
97 pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
98 let range = self.range();
99 let snapshot = self.snapshot.clone();
100 let selected_text = snapshot
101 .text_for_range(range.start..range.end)
102 .collect::<Rope>();
103
104 let selection_start = range.start.to_point(&snapshot);
105 let suggested_line_indent = snapshot
106 .suggested_indents(selection_start.row..selection_start.row + 1, cx)
107 .into_values()
108 .next()
109 .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
110
111 let response = self.provider.complete(prompt);
112 self.generation = cx.spawn(|this, mut cx| {
113 async move {
114 let generate = async {
115 let mut edit_start = range.start.to_offset(&snapshot);
116
117 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
118 let diff = cx.background_executor().spawn(async move {
119 let chunks = strip_invalid_spans_from_codeblock(response.await?);
120 futures::pin_mut!(chunks);
121 let mut diff = StreamingDiff::new(selected_text.to_string());
122
123 let mut new_text = String::new();
124 let mut base_indent = None;
125 let mut line_indent = None;
126 let mut first_line = true;
127
128 while let Some(chunk) = chunks.next().await {
129 let chunk = chunk?;
130
131 let mut lines = chunk.split('\n').peekable();
132 while let Some(line) = lines.next() {
133 new_text.push_str(line);
134 if line_indent.is_none() {
135 if let Some(non_whitespace_ch_ix) =
136 new_text.find(|ch: char| !ch.is_whitespace())
137 {
138 line_indent = Some(non_whitespace_ch_ix);
139 base_indent = base_indent.or(line_indent);
140
141 let line_indent = line_indent.unwrap();
142 let base_indent = base_indent.unwrap();
143 let indent_delta = line_indent as i32 - base_indent as i32;
144 let mut corrected_indent_len = cmp::max(
145 0,
146 suggested_line_indent.len as i32 + indent_delta,
147 )
148 as usize;
149 if first_line {
150 corrected_indent_len = corrected_indent_len
151 .saturating_sub(selection_start.column as usize);
152 }
153
154 let indent_char = suggested_line_indent.char();
155 let mut indent_buffer = [0; 4];
156 let indent_str =
157 indent_char.encode_utf8(&mut indent_buffer);
158 new_text.replace_range(
159 ..line_indent,
160 &indent_str.repeat(corrected_indent_len),
161 );
162 }
163 }
164
165 if line_indent.is_some() {
166 hunks_tx.send(diff.push_new(&new_text)).await?;
167 new_text.clear();
168 }
169
170 if lines.peek().is_some() {
171 hunks_tx.send(diff.push_new("\n")).await?;
172 line_indent = None;
173 first_line = false;
174 }
175 }
176 }
177 hunks_tx.send(diff.push_new(&new_text)).await?;
178 hunks_tx.send(diff.finish()).await?;
179
180 anyhow::Ok(())
181 });
182
183 while let Some(hunks) = hunks_rx.next().await {
184 this.update(&mut cx, |this, cx| {
185 this.last_equal_ranges.clear();
186
187 let transaction = this.buffer.update(cx, |buffer, cx| {
188 // Avoid grouping assistant edits with user edits.
189 buffer.finalize_last_transaction(cx);
190
191 buffer.start_transaction(cx);
192 buffer.edit(
193 hunks.into_iter().filter_map(|hunk| match hunk {
194 Hunk::Insert { text } => {
195 let edit_start = snapshot.anchor_after(edit_start);
196 Some((edit_start..edit_start, text))
197 }
198 Hunk::Remove { len } => {
199 let edit_end = edit_start + len;
200 let edit_range = snapshot.anchor_after(edit_start)
201 ..snapshot.anchor_before(edit_end);
202 edit_start = edit_end;
203 Some((edit_range, String::new()))
204 }
205 Hunk::Keep { len } => {
206 let edit_end = edit_start + len;
207 let edit_range = snapshot.anchor_after(edit_start)
208 ..snapshot.anchor_before(edit_end);
209 edit_start = edit_end;
210 this.last_equal_ranges.push(edit_range);
211 None
212 }
213 }),
214 None,
215 cx,
216 );
217
218 buffer.end_transaction(cx)
219 });
220
221 if let Some(transaction) = transaction {
222 if let Some(first_transaction) = this.transaction_id {
223 // Group all assistant edits into the first transaction.
224 this.buffer.update(cx, |buffer, cx| {
225 buffer.merge_transactions(
226 transaction,
227 first_transaction,
228 cx,
229 )
230 });
231 } else {
232 this.transaction_id = Some(transaction);
233 this.buffer.update(cx, |buffer, cx| {
234 buffer.finalize_last_transaction(cx)
235 });
236 }
237 }
238
239 cx.notify();
240 })?;
241 }
242
243 diff.await?;
244 anyhow::Ok(())
245 };
246
247 let result = generate.await;
248 this.update(&mut cx, |this, cx| {
249 this.last_equal_ranges.clear();
250 this.idle = true;
251 if let Err(error) = result {
252 this.error = Some(error);
253 }
254 cx.emit(Event::Finished);
255 cx.notify();
256 })
257 .ok();
258 }
259 });
260 self.error.take();
261 self.idle = false;
262 cx.notify();
263 }
264
265 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
266 if let Some(transaction_id) = self.transaction_id {
267 self.buffer
268 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
269 }
270 }
271}
272
273fn strip_invalid_spans_from_codeblock(
274 stream: impl Stream<Item = Result<String>>,
275) -> impl Stream<Item = Result<String>> {
276 let mut first_line = true;
277 let mut buffer = String::new();
278 let mut starts_with_markdown_codeblock = false;
279 let mut includes_start_or_end_span = false;
280 stream.filter_map(move |chunk| {
281 let chunk = match chunk {
282 Ok(chunk) => chunk,
283 Err(err) => return future::ready(Some(Err(err))),
284 };
285 buffer.push_str(&chunk);
286
287 if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
288 includes_start_or_end_span = true;
289
290 buffer = buffer
291 .strip_prefix("<|S|>")
292 .or_else(|| buffer.strip_prefix("<|S|"))
293 .unwrap_or(&buffer)
294 .to_string();
295 } else if buffer.ends_with("|E|>") {
296 includes_start_or_end_span = true;
297 } else if buffer.starts_with("<|")
298 || buffer.starts_with("<|S")
299 || buffer.starts_with("<|S|")
300 || buffer.ends_with("|")
301 || buffer.ends_with("|E")
302 || buffer.ends_with("|E|")
303 {
304 return future::ready(None);
305 }
306
307 if first_line {
308 if buffer == "" || buffer == "`" || buffer == "``" {
309 return future::ready(None);
310 } else if buffer.starts_with("```") {
311 starts_with_markdown_codeblock = true;
312 if let Some(newline_ix) = buffer.find('\n') {
313 buffer.replace_range(..newline_ix + 1, "");
314 first_line = false;
315 } else {
316 return future::ready(None);
317 }
318 }
319 }
320
321 let mut text = buffer.to_string();
322 if starts_with_markdown_codeblock {
323 text = text
324 .strip_suffix("\n```\n")
325 .or_else(|| text.strip_suffix("\n```"))
326 .or_else(|| text.strip_suffix("\n``"))
327 .or_else(|| text.strip_suffix("\n`"))
328 .or_else(|| text.strip_suffix('\n'))
329 .unwrap_or(&text)
330 .to_string();
331 }
332
333 if includes_start_or_end_span {
334 text = text
335 .strip_suffix("|E|>")
336 .or_else(|| text.strip_suffix("E|>"))
337 .or_else(|| text.strip_prefix("|>"))
338 .or_else(|| text.strip_prefix(">"))
339 .unwrap_or(&text)
340 .to_string();
341 };
342
343 if text.contains('\n') {
344 first_line = false;
345 }
346
347 let remainder = buffer.split_off(text.len());
348 let result = if buffer.is_empty() {
349 None
350 } else {
351 Some(Ok(buffer.clone()))
352 };
353
354 buffer = remainder;
355 future::ready(result)
356 })
357}
358
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use ai::test::FakeCompletionProvider;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, BufferId, Language, LanguageConfig, Point,
    };
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;

    /// Minimal [`CompletionRequest`] used to start a codegen session in tests.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    impl CompletionRequest for DummyCompletionRequest {
        fn data(&self) -> serde_json::Result<String> {
            serde_json::to_string(self)
        }
    }

    /// Sends `text` through `provider` in randomly sized chunks of 1 to 10 bytes. The executor
    /// is run until parked after every chunk; then the completion is finished and the executor
    /// is run until parked once more.
    ///
    /// Chunks are split at byte offsets, so `text` must be ASCII.
    fn stream_completion_in_random_chunks(
        provider: &FakeCompletionProvider,
        mut text: &str,
        rng: &mut StdRng,
        cx: &mut TestAppContext,
    ) {
        while !text.is_empty() {
            let max_len = cmp::min(text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = text.split_at(len);
            provider.send_completion(chunk);
            text = suffix;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();
    }

    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.new_model(|cx| {
            Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
        });
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let codegen = cx.new_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Transform { range },
                provider.clone(),
                cx,
            )
        });

        let request = Box::new(DummyCompletionRequest {
            name: "test".to_string(),
        });
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        // The completion uses its own base indentation; it must be re-indented to fit the body.
        let new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        stream_completion_in_random_chunks(&provider, new_text, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = cx.new_model(|cx| {
            Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
        });
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let codegen = cx.new_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });

        let request = Box::new(DummyCompletionRequest {
            name: "test".to_string(),
        });
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_completion_in_random_chunks(&provider, new_text, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = cx.new_model(|cx| {
            Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
        });
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let codegen = cx.new_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });

        let request = Box::new(DummyCompletionRequest {
            name: "test".to_string(),
        });
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_completion_in_random_chunks(&provider, new_text, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks(
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                2
            ))
            .map(|chunk| chunk.unwrap())
            .collect::<String>()
            .await,
            "```js\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "``\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        /// Splits `text` into a stream of `Ok` chunks of `size` characters each.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// A Rust language with a small indents query, enough for indent suggestions in these tests.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}