1use crate::streaming_diff::{Hunk, StreamingDiff};
2use ai::completion::{CompletionProvider, OpenAIRequest};
3use anyhow::Result;
4use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
5use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
6use gpui::{Entity, ModelContext, ModelHandle, Task};
7use language::{Rope, TransactionId};
8use std::{cmp, future, ops::Range, sync::Arc};
9
/// Events emitted by a `Codegen` model to its observers.
///
/// The enum carries no data, so it derives the usual value-type traits; this
/// lets subscribers log events with `{:?}` and compare them in assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Event {
    /// The generation task ran to its end, whether it succeeded or recorded
    /// an error.
    Finished,
    /// The buffer reported that the transaction holding the generated edits
    /// was undone.
    Undone,
}
14
/// Describes which part of the buffer a `Codegen` operates on.
#[derive(Clone)]
pub enum CodegenKind {
    /// Operate on an existing range of the buffer; `Codegen::range` returns
    /// this range unchanged.
    Transform { range: Range<Anchor> },
    /// Operate at a single position; `Codegen::range` turns it into the range
    /// from the left-biased form of `position` to `position` itself.
    Generate { position: Anchor },
}
20
/// Model that streams a completion from a `CompletionProvider` into a
/// `MultiBuffer` as a series of edits.
pub struct Codegen {
    /// Source of completions; `start` calls `complete` on it.
    provider: Arc<dyn CompletionProvider>,
    /// The buffer that receives the generated edits.
    buffer: ModelHandle<MultiBuffer>,
    /// Snapshot of `buffer` taken in `new`. `range` and `start` resolve
    /// anchors and offsets against it; nothing in this file refreshes it.
    snapshot: MultiBufferSnapshot,
    /// Which part of the buffer is being operated on.
    kind: CodegenKind,
    /// Ranges reported as `Hunk::Keep` by the most recently applied batch of
    /// hunks. Cleared before each batch and when generation finishes.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// The first transaction created by generation; later transactions are
    /// merged into it. Reset to `None` when that transaction is undone.
    transaction_id: Option<TransactionId>,
    /// Error produced by the last generation, if any. Cleared by `start`.
    error: Option<anyhow::Error>,
    /// Task driving the current generation. Starts out as an already-ready
    /// task and is replaced by one again when the transaction is undone.
    generation: Task<()>,
    /// `false` from the moment `start` is called until the spawned task's
    /// completion handler sets it back to `true`.
    idle: bool,
    /// Subscription to `buffer` events, stored so it stays alive as long as
    /// this model does.
    _subscription: gpui::Subscription,
}
33
/// Makes `Codegen` a gpui entity whose emitted events are of type [`Event`].
impl Entity for Codegen {
    type Event = Event;
}
37
impl Codegen {
    /// Creates an idle `Codegen` for `buffer`.
    ///
    /// A snapshot of the buffer is captured immediately and kept for the
    /// lifetime of the model, and the model subscribes to the buffer's events
    /// so that `handle_buffer_event` is invoked for them.
    pub fn new(
        buffer: ModelHandle<MultiBuffer>,
        kind: CodegenKind,
        provider: Arc<dyn CompletionProvider>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let snapshot = buffer.read(cx).snapshot(cx);
        Self {
            provider,
            buffer: buffer.clone(),
            snapshot,
            kind,
            last_equal_ranges: Default::default(),
            transaction_id: Default::default(),
            error: Default::default(),
            idle: true,
            generation: Task::ready(()),
            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
        }
    }

    /// Handles events from the subscribed buffer.
    ///
    /// Only `TransactionUndone` for the transaction recorded in
    /// `transaction_id` is acted upon: the id is forgotten, the generation
    /// task is replaced with a ready one, and `Event::Undone` is emitted.
    ///
    /// NOTE(review): `idle` is not touched here. If replacing `generation`
    /// cancels a task that was still running (presumably it does; confirm
    /// against gpui's `Task` semantics), the handler that sets `idle = true`
    /// would never run. Confirm whether that is intended.
    fn handle_buffer_event(
        &mut self,
        _buffer: ModelHandle<MultiBuffer>,
        event: &multi_buffer::Event,
        cx: &mut ModelContext<Self>,
    ) {
        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
            if self.transaction_id == Some(*transaction_id) {
                self.transaction_id = None;
                self.generation = Task::ready(());
                cx.emit(Event::Undone);
            }
        }
    }

    /// Returns the buffer range this codegen operates on.
    ///
    /// For `Transform` this is the stored range. For `Generate` it runs from
    /// the left-biased form of the position (resolved against the stored
    /// snapshot) to the position itself.
    pub fn range(&self) -> Range<Anchor> {
        match &self.kind {
            CodegenKind::Transform { range } => range.clone(),
            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
        }
    }

    /// Returns the kind this codegen was created with.
    pub fn kind(&self) -> &CodegenKind {
        &self.kind
    }

    /// Returns the ranges that the most recently applied batch of hunks
    /// reported as unchanged (`Hunk::Keep`). Empty once generation finishes.
    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
        &self.last_equal_ranges
    }

    /// Returns `true` when no generation started by `start` is pending.
    pub fn idle(&self) -> bool {
        self.idle
    }

    /// Returns the error recorded by the last generation, if any.
    pub fn error(&self) -> Option<&anyhow::Error> {
        self.error.as_ref()
    }

    /// Starts streaming a completion for `prompt` into the buffer.
    ///
    /// The text currently in `self.range()` (according to the stored
    /// snapshot) is diffed incrementally against the incoming completion, and
    /// the resulting hunks are applied to the buffer as they arrive. All
    /// resulting transactions are merged into the first one so the whole
    /// generation can be undone at once.
    ///
    /// The spawned task is stored in `self.generation`, replacing whatever
    /// task was there before. When it ends, `last_equal_ranges` is cleared,
    /// `idle` becomes `true`, any error is stored in `self.error`, and
    /// `Event::Finished` is emitted.
    pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
        let range = self.range();
        let snapshot = self.snapshot.clone();
        // The "old" side of the streaming diff: the text being replaced.
        let selected_text = snapshot
            .text_for_range(range.start..range.end)
            .collect::<Rope>();

        // Indentation to give the first generated line: the suggested indent
        // for the row where the range starts, or that row's current indent if
        // no suggestion is returned.
        let selection_start = range.start.to_point(&snapshot);
        let suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));

        let response = self.provider.complete(prompt);
        // Only a weak handle is captured, so the task has to upgrade it each
        // time it wants to touch the model.
        self.generation = cx.spawn_weak(|this, mut cx| {
            async move {
                let generate = async {
                    // Offset into `snapshot` up to which hunks have been
                    // consumed so far.
                    let mut edit_start = range.start.to_offset(&snapshot);

                    // Bounded channel carrying batches of hunks from the
                    // background diff task to the loop below.
                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                    let diff = cx.background().spawn(async move {
                        let chunks = strip_invalid_spans_from_codeblock(response.await?);
                        futures::pin_mut!(chunks);
                        let mut diff = StreamingDiff::new(selected_text.to_string());

                        // Text of the current line not yet pushed to the diff.
                        let mut new_text = String::new();
                        // Indent (byte index of the first non-whitespace
                        // character) of the first line that had any content.
                        let mut base_indent = None;
                        // Same measurement for the current line; `None` until
                        // a non-whitespace character has been seen on it.
                        let mut line_indent = None;
                        let mut first_line = true;

                        while let Some(chunk) = chunks.next().await {
                            let chunk = chunk?;

                            // Every segment except the last was followed by a
                            // '\n' inside this chunk.
                            let mut lines = chunk.split('\n').peekable();
                            while let Some(line) = lines.next() {
                                new_text.push_str(line);
                                if line_indent.is_none() {
                                    if let Some(non_whitespace_ch_ix) =
                                        new_text.find(|ch: char| !ch.is_whitespace())
                                    {
                                        line_indent = Some(non_whitespace_ch_ix);
                                        base_indent = base_indent.or(line_indent);

                                        // Re-indent the line: keep its offset
                                        // relative to the first line, but
                                        // anchor it at the suggested indent.
                                        let line_indent = line_indent.unwrap();
                                        let base_indent = base_indent.unwrap();
                                        let indent_delta = line_indent as i32 - base_indent as i32;
                                        let mut corrected_indent_len = cmp::max(
                                            0,
                                            suggested_line_indent.len as i32 + indent_delta,
                                        )
                                            as usize;
                                        // The first line starts at the range's
                                        // start column rather than column 0,
                                        // so discount that many columns.
                                        if first_line {
                                            corrected_indent_len = corrected_indent_len
                                                .saturating_sub(selection_start.column as usize);
                                        }

                                        // Replace the original leading
                                        // whitespace with the corrected indent.
                                        let indent_char = suggested_line_indent.char();
                                        let mut indent_buffer = [0; 4];
                                        let indent_str =
                                            indent_char.encode_utf8(&mut indent_buffer);
                                        new_text.replace_range(
                                            ..line_indent,
                                            &indent_str.repeat(corrected_indent_len),
                                        );
                                    }
                                }

                                // Text is only forwarded once the line's
                                // indent is known; whitespace-only prefixes
                                // stay in `new_text` until then.
                                if line_indent.is_some() {
                                    hunks_tx.send(diff.push_new(&new_text)).await?;
                                    new_text.clear();
                                }

                                // Another segment follows, so a newline ended
                                // this one: emit it and reset per-line state.
                                if lines.peek().is_some() {
                                    hunks_tx.send(diff.push_new("\n")).await?;
                                    line_indent = None;
                                    first_line = false;
                                }
                            }
                        }
                        // Flush whatever was still held back, then let the
                        // diff emit its final hunks.
                        hunks_tx.send(diff.push_new(&new_text)).await?;
                        hunks_tx.send(diff.finish()).await?;

                        anyhow::Ok(())
                    });

                    while let Some(hunks) = hunks_rx.next().await {
                        // Stop applying edits once the model has been dropped.
                        let this = if let Some(this) = this.upgrade(&cx) {
                            this
                        } else {
                            break;
                        };

                        this.update(&mut cx, |this, cx| {
                            this.last_equal_ranges.clear();

                            let transaction = this.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                buffer.edit(
                                    // Translate hunks into buffer edits. All
                                    // anchors are created from the snapshot
                                    // captured before generation started.
                                    hunks.into_iter().filter_map(|hunk| match hunk {
                                        // Insert at the current position; the
                                        // position itself does not advance.
                                        Hunk::Insert { text } => {
                                            let edit_start = snapshot.anchor_after(edit_start);
                                            Some((edit_start..edit_start, text))
                                        }
                                        // Delete `len` of the old text and
                                        // move past it.
                                        Hunk::Remove { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            Some((edit_range, String::new()))
                                        }
                                        // Unchanged text: no edit, but the
                                        // range is recorded and skipped over.
                                        Hunk::Keep { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            this.last_equal_ranges.push(edit_range);
                                            None
                                        }
                                    }),
                                    None,
                                    cx,
                                );

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) = this.transaction_id {
                                    // Group all assistant edits into the first transaction.
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    // First transaction of this generation:
                                    // remember it and finalize it.
                                    this.transaction_id = Some(transaction);
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            cx.notify();
                        });
                    }

                    // Surface any error returned by the background diff task.
                    diff.await?;
                    anyhow::Ok(())
                };

                let result = generate.await;
                if let Some(this) = this.upgrade(&cx) {
                    this.update(&mut cx, |this, cx| {
                        this.last_equal_ranges.clear();
                        this.idle = true;
                        if let Err(error) = result {
                            this.error = Some(error);
                        }
                        cx.emit(Event::Finished);
                        cx.notify();
                    });
                }
            }
        });
        // Discard the previous run's error and mark the model busy.
        self.error.take();
        self.idle = false;
        cx.notify();
    }

    /// Asks the buffer to undo the transaction recorded in `transaction_id`,
    /// if there is one. The stored id is not cleared here; that happens in
    /// `handle_buffer_event` when the buffer reports the undo.
    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
        if let Some(transaction_id) = self.transaction_id {
            self.buffer
                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
        }
    }
}
280
281fn strip_invalid_spans_from_codeblock(
282 stream: impl Stream<Item = Result<String>>,
283) -> impl Stream<Item = Result<String>> {
284 let mut first_line = true;
285 let mut buffer = String::new();
286 let mut starts_with_markdown_codeblock = false;
287 let mut includes_start_or_end_span = false;
288 stream.filter_map(move |chunk| {
289 let chunk = match chunk {
290 Ok(chunk) => chunk,
291 Err(err) => return future::ready(Some(Err(err))),
292 };
293 buffer.push_str(&chunk);
294
295 if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
296 includes_start_or_end_span = true;
297
298 buffer = buffer
299 .strip_prefix("<|S|>")
300 .or_else(|| buffer.strip_prefix("<|S|"))
301 .unwrap_or(&buffer)
302 .to_string();
303 } else if buffer.ends_with("|E|>") {
304 includes_start_or_end_span = true;
305 } else if buffer.starts_with("<|")
306 || buffer.starts_with("<|S")
307 || buffer.starts_with("<|S|")
308 || buffer.ends_with("|")
309 || buffer.ends_with("|E")
310 || buffer.ends_with("|E|")
311 {
312 return future::ready(None);
313 }
314
315 if first_line {
316 if buffer == "" || buffer == "`" || buffer == "``" {
317 return future::ready(None);
318 } else if buffer.starts_with("```") {
319 starts_with_markdown_codeblock = true;
320 if let Some(newline_ix) = buffer.find('\n') {
321 buffer.replace_range(..newline_ix + 1, "");
322 first_line = false;
323 } else {
324 return future::ready(None);
325 }
326 }
327 }
328
329 let mut text = buffer.to_string();
330 if starts_with_markdown_codeblock {
331 text = text
332 .strip_suffix("\n```\n")
333 .or_else(|| text.strip_suffix("\n```"))
334 .or_else(|| text.strip_suffix("\n``"))
335 .or_else(|| text.strip_suffix("\n`"))
336 .or_else(|| text.strip_suffix('\n'))
337 .unwrap_or(&text)
338 .to_string();
339 }
340
341 if includes_start_or_end_span {
342 text = text
343 .strip_suffix("|E|>")
344 .or_else(|| text.strip_suffix("E|>"))
345 .or_else(|| text.strip_prefix("|>"))
346 .or_else(|| text.strip_prefix(">"))
347 .unwrap_or(&text)
348 .to_string();
349 };
350
351 if text.contains('\n') {
352 first_line = false;
353 }
354
355 let remainder = buffer.split_off(text.len());
356 let result = if buffer.is_empty() {
357 None
358 } else {
359 Some(Ok(buffer.clone()))
360 };
361
362 buffer = remainder;
363 future::ready(result)
364 })
365}
366
// Tests for `Codegen` and `strip_invalid_spans_from_codeblock`.
//
// NOTE(review): the multi-line fixtures in this module carry no nested
// indentation in this copy of the file, yet the anchors used with them (for
// example column 6 on a line that reads `le`, or column 2 on a line holding a
// single space) imply that they once did. Confirm the fixture whitespace
// against the upstream source before relying on these tests.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        future::BoxFuture,
        stream::{self, BoxStream},
    };
    use gpui::{executor::Deterministic, TestAppContext};
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use parking_lot::Mutex;
    use rand::prelude::*;
    use settings::SettingsStore;
    use smol::future::FutureExt;

    /// Transforms an existing range (rows 1 through 4) while the replacement
    /// is streamed in randomly sized pieces, then checks the resulting buffer
    /// text.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
            let x = 0;
            for _ in 0..10 {
            x += 1;
            }
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Transform { range },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        let mut new_text = concat!(
            " let mut x = 0;\n",
            " while x < 10 {\n",
            " x += 1;\n",
            " }",
        );
        // Feed the completion in pieces of 1 to 10 bytes, letting the
        // executor settle after each piece.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk);
            new_text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                let mut x = 0;
                while x < 10 {
                x += 1;
                }
                }
            "}
        );
    }

    /// Generates at a position that lies after existing text on its line
    /// (row 1, column 6), with the completion streamed in random pieces.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
            le
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            " x += 1;\n",
            "}", //
        );
        // Feed the completion in pieces of 1 to 10 bytes, letting the
        // executor settle after each piece.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk);
            new_text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                let mut x = 0;
                while x < 10 {
                x += 1;
                }
                }
            "}
        );
    }

    /// Generates at row 1, column 2 of a line that holds only whitespace,
    /// with the completion streamed in random pieces.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            " \n",
            "}\n" //
        );
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            " x += 1;\n",
            "}", //
        );
        // Feed the completion in pieces of 1 to 10 bytes, letting the
        // executor settle after each piece.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk);
            new_text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                let mut x = 0;
                while x < 10 {
                x += 1;
                }
                }
            "}
        );
    }

    /// Checks the concatenated output of `strip_invalid_spans_from_codeblock`
    /// for inputs delivered two characters at a time: plain text, fenced
    /// text, nested fences, an incomplete opening fence, and the start/end
    /// span markers with and without a surrounding fence.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        // Plain text passes through untouched.
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        // Opening fence only.
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        // Opening and closing fence.
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        // Closing fence followed by a newline.
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        // Only the outermost pair of fences is removed.
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks(
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                2
            ))
            .map(|chunk| chunk.unwrap())
            .collect::<String>()
            .await,
            "```js\nLorem ipsum dolor\n```"
        );
        // Two backticks do not open a fence, so nothing is stripped.
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "``\nLorem ipsum dolor\n```"
        );
        // Start marker without `>` plus end marker.
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        // Start marker with `>` and no end marker.
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        // Start marker inside a fenced block.
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );
        // Both markers inside a fenced block.
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );
        /// Splits `text` into successive pieces of `size` characters (the
        /// last piece may be shorter) and yields each as an `Ok` item.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Completion provider whose output is driven manually by the test.
    struct TestCompletionProvider {
        /// Sending half of the channel backing the most recently requested
        /// completion; `None` before any request and after
        /// `finish_completion`.
        last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
    }

    impl TestCompletionProvider {
        /// Creates a provider with no completion in progress.
        fn new() -> Self {
            Self {
                last_completion_tx: Mutex::new(None),
            }
        }

        /// Pushes one chunk into the current completion stream. Panics if no
        /// completion has been requested or if `try_send` fails.
        fn send_completion(&self, completion: impl Into<String>) {
            let mut tx = self.last_completion_tx.lock();
            tx.as_mut().unwrap().try_send(completion.into()).unwrap();
        }

        /// Takes the stored sender out and drops it. Panics if there is no
        /// completion in progress.
        fn finish_completion(&self) {
            self.last_completion_tx.lock().take().unwrap();
        }
    }

    impl CompletionProvider for TestCompletionProvider {
        /// Ignores the prompt, stores a fresh sender for the test to drive,
        /// and returns the receiving half as the completion stream.
        fn complete(
            &self,
            _prompt: OpenAIRequest,
        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
            let (tx, rx) = mpsc::channel(1);
            *self.last_completion_tx.lock() = Some(tx);
            async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
        }
    }

    /// Builds a Rust `Language` backed by the tree-sitter Rust grammar, with
    /// the indents query the autoindent tests rely on.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}