codegen.rs

  1use crate::streaming_diff::{Hunk, StreamingDiff};
  2use ai::completion::{CompletionProvider, CompletionRequest};
  3use anyhow::Result;
  4use editor::{multi_buffer, Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
  5use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  6use gpui::{Entity, ModelContext, ModelHandle, Task};
  7use language::{Rope, TransactionId};
  8use std::{cmp, future, ops::Range, sync::Arc};
  9
 10pub enum Event {
 11    Finished,
 12    Undone,
 13}
 14
 15#[derive(Clone)]
 16pub enum CodegenKind {
 17    Transform { range: Range<Anchor> },
 18    Generate { position: Anchor },
 19}
 20
 21pub struct Codegen {
 22    provider: Arc<dyn CompletionProvider>,
 23    buffer: ModelHandle<MultiBuffer>,
 24    snapshot: MultiBufferSnapshot,
 25    kind: CodegenKind,
 26    last_equal_ranges: Vec<Range<Anchor>>,
 27    transaction_id: Option<TransactionId>,
 28    error: Option<anyhow::Error>,
 29    generation: Task<()>,
 30    idle: bool,
 31    _subscription: gpui::Subscription,
 32}
 33
 34impl Entity for Codegen {
 35    type Event = Event;
 36}
 37
 38impl Codegen {
 39    pub fn new(
 40        buffer: ModelHandle<MultiBuffer>,
 41        kind: CodegenKind,
 42        provider: Arc<dyn CompletionProvider>,
 43        cx: &mut ModelContext<Self>,
 44    ) -> Self {
 45        let snapshot = buffer.read(cx).snapshot(cx);
 46        Self {
 47            provider,
 48            buffer: buffer.clone(),
 49            snapshot,
 50            kind,
 51            last_equal_ranges: Default::default(),
 52            transaction_id: Default::default(),
 53            error: Default::default(),
 54            idle: true,
 55            generation: Task::ready(()),
 56            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 57        }
 58    }
 59
 60    fn handle_buffer_event(
 61        &mut self,
 62        _buffer: ModelHandle<MultiBuffer>,
 63        event: &multi_buffer::Event,
 64        cx: &mut ModelContext<Self>,
 65    ) {
 66        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
 67            if self.transaction_id == Some(*transaction_id) {
 68                self.transaction_id = None;
 69                self.generation = Task::ready(());
 70                cx.emit(Event::Undone);
 71            }
 72        }
 73    }
 74
 75    pub fn range(&self) -> Range<Anchor> {
 76        match &self.kind {
 77            CodegenKind::Transform { range } => range.clone(),
 78            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
 79        }
 80    }
 81
 82    pub fn kind(&self) -> &CodegenKind {
 83        &self.kind
 84    }
 85
 86    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 87        &self.last_equal_ranges
 88    }
 89
 90    pub fn idle(&self) -> bool {
 91        self.idle
 92    }
 93
 94    pub fn error(&self) -> Option<&anyhow::Error> {
 95        self.error.as_ref()
 96    }
 97
 98    pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
 99        let range = self.range();
100        let snapshot = self.snapshot.clone();
101        let selected_text = snapshot
102            .text_for_range(range.start..range.end)
103            .collect::<Rope>();
104
105        let selection_start = range.start.to_point(&snapshot);
106        let suggested_line_indent = snapshot
107            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
108            .into_values()
109            .next()
110            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
111
112        let response = self.provider.complete(prompt);
113        self.generation = cx.spawn_weak(|this, mut cx| {
114            async move {
115                let generate = async {
116                    let mut edit_start = range.start.to_offset(&snapshot);
117
118                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
119                    let diff = cx.background().spawn(async move {
120                        let chunks = strip_markdown_codeblock(response.await?);
121                        futures::pin_mut!(chunks);
122                        let mut diff = StreamingDiff::new(selected_text.to_string());
123
124                        let mut new_text = String::new();
125                        let mut base_indent = None;
126                        let mut line_indent = None;
127                        let mut first_line = true;
128
129                        while let Some(chunk) = chunks.next().await {
130                            let chunk = chunk?;
131
132                            let mut lines = chunk.split('\n').peekable();
133                            while let Some(line) = lines.next() {
134                                new_text.push_str(line);
135                                if line_indent.is_none() {
136                                    if let Some(non_whitespace_ch_ix) =
137                                        new_text.find(|ch: char| !ch.is_whitespace())
138                                    {
139                                        line_indent = Some(non_whitespace_ch_ix);
140                                        base_indent = base_indent.or(line_indent);
141
142                                        let line_indent = line_indent.unwrap();
143                                        let base_indent = base_indent.unwrap();
144                                        let indent_delta = line_indent as i32 - base_indent as i32;
145                                        let mut corrected_indent_len = cmp::max(
146                                            0,
147                                            suggested_line_indent.len as i32 + indent_delta,
148                                        )
149                                            as usize;
150                                        if first_line {
151                                            corrected_indent_len = corrected_indent_len
152                                                .saturating_sub(selection_start.column as usize);
153                                        }
154
155                                        let indent_char = suggested_line_indent.char();
156                                        let mut indent_buffer = [0; 4];
157                                        let indent_str =
158                                            indent_char.encode_utf8(&mut indent_buffer);
159                                        new_text.replace_range(
160                                            ..line_indent,
161                                            &indent_str.repeat(corrected_indent_len),
162                                        );
163                                    }
164                                }
165
166                                if line_indent.is_some() {
167                                    hunks_tx.send(diff.push_new(&new_text)).await?;
168                                    new_text.clear();
169                                }
170
171                                if lines.peek().is_some() {
172                                    hunks_tx.send(diff.push_new("\n")).await?;
173                                    line_indent = None;
174                                    first_line = false;
175                                }
176                            }
177                        }
178                        hunks_tx.send(diff.push_new(&new_text)).await?;
179                        hunks_tx.send(diff.finish()).await?;
180
181                        anyhow::Ok(())
182                    });
183
184                    while let Some(hunks) = hunks_rx.next().await {
185                        let this = if let Some(this) = this.upgrade(&cx) {
186                            this
187                        } else {
188                            break;
189                        };
190
191                        this.update(&mut cx, |this, cx| {
192                            this.last_equal_ranges.clear();
193
194                            let transaction = this.buffer.update(cx, |buffer, cx| {
195                                // Avoid grouping assistant edits with user edits.
196                                buffer.finalize_last_transaction(cx);
197
198                                buffer.start_transaction(cx);
199                                buffer.edit(
200                                    hunks.into_iter().filter_map(|hunk| match hunk {
201                                        Hunk::Insert { text } => {
202                                            let edit_start = snapshot.anchor_after(edit_start);
203                                            Some((edit_start..edit_start, text))
204                                        }
205                                        Hunk::Remove { len } => {
206                                            let edit_end = edit_start + len;
207                                            let edit_range = snapshot.anchor_after(edit_start)
208                                                ..snapshot.anchor_before(edit_end);
209                                            edit_start = edit_end;
210                                            Some((edit_range, String::new()))
211                                        }
212                                        Hunk::Keep { len } => {
213                                            let edit_end = edit_start + len;
214                                            let edit_range = snapshot.anchor_after(edit_start)
215                                                ..snapshot.anchor_before(edit_end);
216                                            edit_start = edit_end;
217                                            this.last_equal_ranges.push(edit_range);
218                                            None
219                                        }
220                                    }),
221                                    None,
222                                    cx,
223                                );
224
225                                buffer.end_transaction(cx)
226                            });
227
228                            if let Some(transaction) = transaction {
229                                if let Some(first_transaction) = this.transaction_id {
230                                    // Group all assistant edits into the first transaction.
231                                    this.buffer.update(cx, |buffer, cx| {
232                                        buffer.merge_transactions(
233                                            transaction,
234                                            first_transaction,
235                                            cx,
236                                        )
237                                    });
238                                } else {
239                                    this.transaction_id = Some(transaction);
240                                    this.buffer.update(cx, |buffer, cx| {
241                                        buffer.finalize_last_transaction(cx)
242                                    });
243                                }
244                            }
245
246                            cx.notify();
247                        });
248                    }
249
250                    diff.await?;
251                    anyhow::Ok(())
252                };
253
254                let result = generate.await;
255                if let Some(this) = this.upgrade(&cx) {
256                    this.update(&mut cx, |this, cx| {
257                        this.last_equal_ranges.clear();
258                        this.idle = true;
259                        if let Err(error) = result {
260                            this.error = Some(error);
261                        }
262                        cx.emit(Event::Finished);
263                        cx.notify();
264                    });
265                }
266            }
267        });
268        self.error.take();
269        self.idle = false;
270        cx.notify();
271    }
272
273    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
274        if let Some(transaction_id) = self.transaction_id {
275            self.buffer
276                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
277        }
278    }
279}
280
281fn strip_markdown_codeblock(
282    stream: impl Stream<Item = Result<String>>,
283) -> impl Stream<Item = Result<String>> {
284    let mut first_line = true;
285    let mut buffer = String::new();
286    let mut starts_with_fenced_code_block = false;
287    stream.filter_map(move |chunk| {
288        let chunk = match chunk {
289            Ok(chunk) => chunk,
290            Err(err) => return future::ready(Some(Err(err))),
291        };
292        buffer.push_str(&chunk);
293
294        if first_line {
295            if buffer == "" || buffer == "`" || buffer == "``" {
296                return future::ready(None);
297            } else if buffer.starts_with("```") {
298                starts_with_fenced_code_block = true;
299                if let Some(newline_ix) = buffer.find('\n') {
300                    buffer.replace_range(..newline_ix + 1, "");
301                    first_line = false;
302                } else {
303                    return future::ready(None);
304                }
305            }
306        }
307
308        let text = if starts_with_fenced_code_block {
309            buffer
310                .strip_suffix("\n```\n")
311                .or_else(|| buffer.strip_suffix("\n```"))
312                .or_else(|| buffer.strip_suffix("\n``"))
313                .or_else(|| buffer.strip_suffix("\n`"))
314                .or_else(|| buffer.strip_suffix('\n'))
315                .unwrap_or(&buffer)
316        } else {
317            &buffer
318        };
319
320        if text.contains('\n') {
321            first_line = false;
322        }
323
324        let remainder = buffer.split_off(text.len());
325        let result = if buffer.is_empty() {
326            None
327        } else {
328            Some(Ok(buffer.clone()))
329        };
330        buffer = remainder;
331        future::ready(result)
332    })
333}
334
335#[cfg(test)]
336mod tests {
337    use super::*;
338    use ai::{models::LanguageModel, test::FakeLanguageModel};
339    use futures::{
340        future::BoxFuture,
341        stream::{self, BoxStream},
342    };
343    use gpui::{executor::Deterministic, TestAppContext};
344    use indoc::indoc;
345    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
346    use parking_lot::Mutex;
347    use rand::prelude::*;
348    use serde::Serialize;
349    use settings::SettingsStore;
350    use smol::future::FutureExt;
351
352    #[derive(Serialize)]
353    pub struct DummyCompletionRequest {
354        pub name: String,
355    }
356
357    impl CompletionRequest for DummyCompletionRequest {
358        fn data(&self) -> serde_json::Result<String> {
359            serde_json::to_string(self)
360        }
361    }
362
363    #[gpui::test(iterations = 10)]
364    async fn test_transform_autoindent(
365        cx: &mut TestAppContext,
366        mut rng: StdRng,
367        deterministic: Arc<Deterministic>,
368    ) {
369        cx.set_global(cx.read(SettingsStore::test));
370        cx.update(language_settings::init);
371
372        let text = indoc! {"
373            fn main() {
374                let x = 0;
375                for _ in 0..10 {
376                    x += 1;
377                }
378            }
379        "};
380        let buffer =
381            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
382        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
383        let range = buffer.read_with(cx, |buffer, cx| {
384            let snapshot = buffer.snapshot(cx);
385            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
386        });
387        let provider = Arc::new(TestCompletionProvider::new());
388        let codegen = cx.add_model(|cx| {
389            Codegen::new(
390                buffer.clone(),
391                CodegenKind::Transform { range },
392                provider.clone(),
393                cx,
394            )
395        });
396
397        let request = Box::new(DummyCompletionRequest {
398            name: "test".to_string(),
399        });
400        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
401
402        let mut new_text = concat!(
403            "       let mut x = 0;\n",
404            "       while x < 10 {\n",
405            "           x += 1;\n",
406            "       }",
407        );
408        while !new_text.is_empty() {
409            let max_len = cmp::min(new_text.len(), 10);
410            let len = rng.gen_range(1..=max_len);
411            let (chunk, suffix) = new_text.split_at(len);
412            provider.send_completion(chunk);
413            new_text = suffix;
414            deterministic.run_until_parked();
415        }
416        provider.finish_completion();
417        deterministic.run_until_parked();
418
419        assert_eq!(
420            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
421            indoc! {"
422                fn main() {
423                    let mut x = 0;
424                    while x < 10 {
425                        x += 1;
426                    }
427                }
428            "}
429        );
430    }
431
432    #[gpui::test(iterations = 10)]
433    async fn test_autoindent_when_generating_past_indentation(
434        cx: &mut TestAppContext,
435        mut rng: StdRng,
436        deterministic: Arc<Deterministic>,
437    ) {
438        cx.set_global(cx.read(SettingsStore::test));
439        cx.update(language_settings::init);
440
441        let text = indoc! {"
442            fn main() {
443                le
444            }
445        "};
446        let buffer =
447            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
448        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
449        let position = buffer.read_with(cx, |buffer, cx| {
450            let snapshot = buffer.snapshot(cx);
451            snapshot.anchor_before(Point::new(1, 6))
452        });
453        let provider = Arc::new(TestCompletionProvider::new());
454        let codegen = cx.add_model(|cx| {
455            Codegen::new(
456                buffer.clone(),
457                CodegenKind::Generate { position },
458                provider.clone(),
459                cx,
460            )
461        });
462
463        let request = Box::new(DummyCompletionRequest {
464            name: "test".to_string(),
465        });
466        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
467
468        let mut new_text = concat!(
469            "t mut x = 0;\n",
470            "while x < 10 {\n",
471            "    x += 1;\n",
472            "}", //
473        );
474        while !new_text.is_empty() {
475            let max_len = cmp::min(new_text.len(), 10);
476            let len = rng.gen_range(1..=max_len);
477            let (chunk, suffix) = new_text.split_at(len);
478            provider.send_completion(chunk);
479            new_text = suffix;
480            deterministic.run_until_parked();
481        }
482        provider.finish_completion();
483        deterministic.run_until_parked();
484
485        assert_eq!(
486            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
487            indoc! {"
488                fn main() {
489                    let mut x = 0;
490                    while x < 10 {
491                        x += 1;
492                    }
493                }
494            "}
495        );
496    }
497
498    #[gpui::test(iterations = 10)]
499    async fn test_autoindent_when_generating_before_indentation(
500        cx: &mut TestAppContext,
501        mut rng: StdRng,
502        deterministic: Arc<Deterministic>,
503    ) {
504        cx.set_global(cx.read(SettingsStore::test));
505        cx.update(language_settings::init);
506
507        let text = concat!(
508            "fn main() {\n",
509            "  \n",
510            "}\n" //
511        );
512        let buffer =
513            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
514        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
515        let position = buffer.read_with(cx, |buffer, cx| {
516            let snapshot = buffer.snapshot(cx);
517            snapshot.anchor_before(Point::new(1, 2))
518        });
519        let provider = Arc::new(TestCompletionProvider::new());
520        let codegen = cx.add_model(|cx| {
521            Codegen::new(
522                buffer.clone(),
523                CodegenKind::Generate { position },
524                provider.clone(),
525                cx,
526            )
527        });
528
529        let request = Box::new(DummyCompletionRequest {
530            name: "test".to_string(),
531        });
532        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
533
534        let mut new_text = concat!(
535            "let mut x = 0;\n",
536            "while x < 10 {\n",
537            "    x += 1;\n",
538            "}", //
539        );
540        while !new_text.is_empty() {
541            let max_len = cmp::min(new_text.len(), 10);
542            let len = rng.gen_range(1..=max_len);
543            let (chunk, suffix) = new_text.split_at(len);
544            provider.send_completion(chunk);
545            new_text = suffix;
546            deterministic.run_until_parked();
547        }
548        provider.finish_completion();
549        deterministic.run_until_parked();
550
551        assert_eq!(
552            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
553            indoc! {"
554                fn main() {
555                    let mut x = 0;
556                    while x < 10 {
557                        x += 1;
558                    }
559                }
560            "}
561        );
562    }
563
564    #[gpui::test]
565    async fn test_strip_markdown_codeblock() {
566        assert_eq!(
567            strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
568                .map(|chunk| chunk.unwrap())
569                .collect::<String>()
570                .await,
571            "Lorem ipsum dolor"
572        );
573        assert_eq!(
574            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
575                .map(|chunk| chunk.unwrap())
576                .collect::<String>()
577                .await,
578            "Lorem ipsum dolor"
579        );
580        assert_eq!(
581            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
582                .map(|chunk| chunk.unwrap())
583                .collect::<String>()
584                .await,
585            "Lorem ipsum dolor"
586        );
587        assert_eq!(
588            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
589                .map(|chunk| chunk.unwrap())
590                .collect::<String>()
591                .await,
592            "Lorem ipsum dolor"
593        );
594        assert_eq!(
595            strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
596                .map(|chunk| chunk.unwrap())
597                .collect::<String>()
598                .await,
599            "```js\nLorem ipsum dolor\n```"
600        );
601        assert_eq!(
602            strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
603                .map(|chunk| chunk.unwrap())
604                .collect::<String>()
605                .await,
606            "``\nLorem ipsum dolor\n```"
607        );
608
609        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
610            stream::iter(
611                text.chars()
612                    .collect::<Vec<_>>()
613                    .chunks(size)
614                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
615                    .collect::<Vec<_>>(),
616            )
617        }
618    }
619
620    struct TestCompletionProvider {
621        last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
622    }
623
624    impl TestCompletionProvider {
625        fn new() -> Self {
626            Self {
627                last_completion_tx: Mutex::new(None),
628            }
629        }
630
631        fn send_completion(&self, completion: impl Into<String>) {
632            let mut tx = self.last_completion_tx.lock();
633            tx.as_mut().unwrap().try_send(completion.into()).unwrap();
634        }
635
636        fn finish_completion(&self) {
637            self.last_completion_tx.lock().take().unwrap();
638        }
639    }
640
641    impl CompletionProvider for TestCompletionProvider {
642        fn base_model(&self) -> Box<dyn LanguageModel> {
643            let model: Box<dyn LanguageModel> = Box::new(FakeLanguageModel { capacity: 8190 });
644            model
645        }
646        fn complete(
647            &self,
648            _prompt: Box<dyn CompletionRequest>,
649        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
650            let (tx, rx) = mpsc::channel(1);
651            *self.last_completion_tx.lock() = Some(tx);
652            async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
653        }
654    }
655
656    fn rust_lang() -> Language {
657        Language::new(
658            LanguageConfig {
659                name: "Rust".into(),
660                path_suffixes: vec!["rs".to_string()],
661                ..Default::default()
662            },
663            Some(tree_sitter_rust::language()),
664        )
665        .with_indents_query(
666            r#"
667            (call_expression) @indent
668            (field_expression) @indent
669            (_ "(" ")" @end) @indent
670            (_ "{" "}" @end) @indent
671            "#,
672        )
673        .unwrap()
674    }
675}