codegen.rs

  1use crate::streaming_diff::{Hunk, StreamingDiff};
  2use ai::completion::{CompletionProvider, CompletionRequest};
  3use anyhow::Result;
  4use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
  5use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  6use gpui::{Entity, ModelContext, ModelHandle, Task};
  7use language::{Rope, TransactionId};
  8use multi_buffer;
  9use std::{cmp, future, ops::Range, sync::Arc};
 10
 11pub enum Event {
 12    Finished,
 13    Undone,
 14}
 15
 16#[derive(Clone)]
 17pub enum CodegenKind {
 18    Transform { range: Range<Anchor> },
 19    Generate { position: Anchor },
 20}
 21
 22pub struct Codegen {
 23    provider: Arc<dyn CompletionProvider>,
 24    buffer: ModelHandle<MultiBuffer>,
 25    snapshot: MultiBufferSnapshot,
 26    kind: CodegenKind,
 27    last_equal_ranges: Vec<Range<Anchor>>,
 28    transaction_id: Option<TransactionId>,
 29    error: Option<anyhow::Error>,
 30    generation: Task<()>,
 31    idle: bool,
 32    _subscription: gpui::Subscription,
 33}
 34
 35impl Entity for Codegen {
 36    type Event = Event;
 37}
 38
 39impl Codegen {
 40    pub fn new(
 41        buffer: ModelHandle<MultiBuffer>,
 42        kind: CodegenKind,
 43        provider: Arc<dyn CompletionProvider>,
 44        cx: &mut ModelContext<Self>,
 45    ) -> Self {
 46        let snapshot = buffer.read(cx).snapshot(cx);
 47        Self {
 48            provider,
 49            buffer: buffer.clone(),
 50            snapshot,
 51            kind,
 52            last_equal_ranges: Default::default(),
 53            transaction_id: Default::default(),
 54            error: Default::default(),
 55            idle: true,
 56            generation: Task::ready(()),
 57            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 58        }
 59    }
 60
 61    fn handle_buffer_event(
 62        &mut self,
 63        _buffer: ModelHandle<MultiBuffer>,
 64        event: &multi_buffer::Event,
 65        cx: &mut ModelContext<Self>,
 66    ) {
 67        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
 68            if self.transaction_id == Some(*transaction_id) {
 69                self.transaction_id = None;
 70                self.generation = Task::ready(());
 71                cx.emit(Event::Undone);
 72            }
 73        }
 74    }
 75
 76    pub fn range(&self) -> Range<Anchor> {
 77        match &self.kind {
 78            CodegenKind::Transform { range } => range.clone(),
 79            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
 80        }
 81    }
 82
 83    pub fn kind(&self) -> &CodegenKind {
 84        &self.kind
 85    }
 86
 87    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 88        &self.last_equal_ranges
 89    }
 90
 91    pub fn idle(&self) -> bool {
 92        self.idle
 93    }
 94
 95    pub fn error(&self) -> Option<&anyhow::Error> {
 96        self.error.as_ref()
 97    }
 98
    /// Starts generating: requests a completion for `prompt` and streams the
    /// result into the buffer.
    ///
    /// The work is split in two:
    /// * a background task reads the completion chunks, fixes up indentation
    ///   line by line, and diffs the result against the currently selected
    ///   text, sending each batch of hunks over a channel;
    /// * the spawned foreground task receives those batches and applies them
    ///   to the buffer, merging every resulting transaction into the first
    ///   one so they can be addressed through a single transaction id.
    ///
    /// When the stream ends (successfully or not) the model becomes idle
    /// again, stores the error if there was one, and emits `Event::Finished`.
    ///
    /// * `prompt` - the request forwarded to the completion provider.
    /// * `cx` - the model context used to spawn the task and notify observers.
    pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
        let range = self.range();
        let snapshot = self.snapshot.clone();
        // The text the streamed completion will be diffed against.
        let selected_text = snapshot
            .text_for_range(range.start..range.end)
            .collect::<Rope>();

        // Indentation for the row where the selection starts: the suggested
        // indent if there is one, otherwise the row's current indent.
        let selection_start = range.start.to_point(&snapshot);
        let suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));

        let response = self.provider.complete(prompt);
        // The task only holds a weak handle to the model (`this`), which is
        // upgraded each time the model needs to be touched.
        self.generation = cx.spawn_weak(|this, mut cx| {
            async move {
                let generate = async {
                    // Offset into `snapshot` up to which hunks have been consumed.
                    let mut edit_start = range.start.to_offset(&snapshot);

                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                    // Background half: stream -> re-indent -> diff -> hunks.
                    let diff = cx.background().spawn(async move {
                        let chunks = strip_invalid_spans_from_codeblock(response.await?);
                        futures::pin_mut!(chunks);
                        let mut diff = StreamingDiff::new(selected_text.to_string());

                        // Text of the current line that has not been pushed
                        // into the diff yet.
                        let mut new_text = String::new();
                        // Indent of the first line that contained
                        // non-whitespace text; other lines are measured
                        // relative to it.
                        let mut base_indent = None;
                        // Indent of the current line, once its first
                        // non-whitespace character has been seen.
                        let mut line_indent = None;
                        let mut first_line = true;

                        while let Some(chunk) = chunks.next().await {
                            let chunk = chunk?;

                            let mut lines = chunk.split('\n').peekable();
                            while let Some(line) = lines.next() {
                                new_text.push_str(line);
                                if line_indent.is_none() {
                                    if let Some(non_whitespace_ch_ix) =
                                        new_text.find(|ch: char| !ch.is_whitespace())
                                    {
                                        line_indent = Some(non_whitespace_ch_ix);
                                        base_indent = base_indent.or(line_indent);

                                        let line_indent = line_indent.unwrap();
                                        let base_indent = base_indent.unwrap();
                                        // Keep the line's indentation relative
                                        // to the first line, but rebase it on
                                        // the indent computed for the buffer.
                                        let indent_delta = line_indent as i32 - base_indent as i32;
                                        let mut corrected_indent_len = cmp::max(
                                            0,
                                            suggested_line_indent.len as i32 + indent_delta,
                                        )
                                            as usize;
                                        // The first line is inserted at the
                                        // selection start, so the columns
                                        // already in front of it are deducted
                                        // from its indent.
                                        if first_line {
                                            corrected_indent_len = corrected_indent_len
                                                .saturating_sub(selection_start.column as usize);
                                        }

                                        // Replace the leading whitespace with
                                        // the corrected indent.
                                        let indent_char = suggested_line_indent.char();
                                        let mut indent_buffer = [0; 4];
                                        let indent_str =
                                            indent_char.encode_utf8(&mut indent_buffer);
                                        new_text.replace_range(
                                            ..line_indent,
                                            &indent_str.repeat(corrected_indent_len),
                                        );
                                    }
                                }

                                // Text is only pushed once the line's indent
                                // is known; a line that is all whitespace so
                                // far stays in `new_text`.
                                if line_indent.is_some() {
                                    hunks_tx.send(diff.push_new(&new_text)).await?;
                                    new_text.clear();
                                }

                                // Another piece follows, so this piece ended
                                // with a newline: push it and start a new line.
                                if lines.peek().is_some() {
                                    hunks_tx.send(diff.push_new("\n")).await?;
                                    line_indent = None;
                                    first_line = false;
                                }
                            }
                        }
                        // Flush whatever is still pending, then let the diff
                        // emit its final hunks.
                        hunks_tx.send(diff.push_new(&new_text)).await?;
                        hunks_tx.send(diff.finish()).await?;

                        anyhow::Ok(())
                    });

                    // Foreground half: apply each batch of hunks to the buffer.
                    while let Some(hunks) = hunks_rx.next().await {
                        // Stop if the model has been dropped in the meantime.
                        let this = if let Some(this) = this.upgrade(&cx) {
                            this
                        } else {
                            break;
                        };

                        this.update(&mut cx, |this, cx| {
                            this.last_equal_ranges.clear();

                            let transaction = this.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                buffer.edit(
                                    hunks.into_iter().filter_map(|hunk| match hunk {
                                        // Insertions happen at the current
                                        // offset and do not advance it.
                                        Hunk::Insert { text } => {
                                            let edit_start = snapshot.anchor_after(edit_start);
                                            Some((edit_start..edit_start, text))
                                        }
                                        // Removals delete `len` bytes of the
                                        // original text and advance past them.
                                        Hunk::Remove { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            Some((edit_range, String::new()))
                                        }
                                        // Kept text produces no edit; its range
                                        // is recorded and the offset advances.
                                        Hunk::Keep { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            this.last_equal_ranges.push(edit_range);
                                            None
                                        }
                                    }),
                                    None,
                                    cx,
                                );

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) = this.transaction_id {
                                    // Group all assistant edits into the first transaction.
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    // First batch: remember its transaction
                                    // as the one later batches merge into.
                                    this.transaction_id = Some(transaction);
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            cx.notify();
                        });
                    }

                    // Surface any error produced by the background task.
                    diff.await?;
                    anyhow::Ok(())
                };

                let result = generate.await;
                // Report completion, unless the model is already gone.
                if let Some(this) = this.upgrade(&cx) {
                    this.update(&mut cx, |this, cx| {
                        this.last_equal_ranges.clear();
                        this.idle = true;
                        if let Err(error) = result {
                            this.error = Some(error);
                        }
                        cx.emit(Event::Finished);
                        cx.notify();
                    });
                }
            }
        });
        // Discard the error of any previous run and mark the model as busy.
        self.error.take();
        self.idle = false;
        cx.notify();
    }
273
274    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
275        if let Some(transaction_id) = self.transaction_id {
276            self.buffer
277                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
278        }
279    }
280}
281
282fn strip_invalid_spans_from_codeblock(
283    stream: impl Stream<Item = Result<String>>,
284) -> impl Stream<Item = Result<String>> {
285    let mut first_line = true;
286    let mut buffer = String::new();
287    let mut starts_with_markdown_codeblock = false;
288    let mut includes_start_or_end_span = false;
289    stream.filter_map(move |chunk| {
290        let chunk = match chunk {
291            Ok(chunk) => chunk,
292            Err(err) => return future::ready(Some(Err(err))),
293        };
294        buffer.push_str(&chunk);
295
296        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
297            includes_start_or_end_span = true;
298
299            buffer = buffer
300                .strip_prefix("<|S|>")
301                .or_else(|| buffer.strip_prefix("<|S|"))
302                .unwrap_or(&buffer)
303                .to_string();
304        } else if buffer.ends_with("|E|>") {
305            includes_start_or_end_span = true;
306        } else if buffer.starts_with("<|")
307            || buffer.starts_with("<|S")
308            || buffer.starts_with("<|S|")
309            || buffer.ends_with("|")
310            || buffer.ends_with("|E")
311            || buffer.ends_with("|E|")
312        {
313            return future::ready(None);
314        }
315
316        if first_line {
317            if buffer == "" || buffer == "`" || buffer == "``" {
318                return future::ready(None);
319            } else if buffer.starts_with("```") {
320                starts_with_markdown_codeblock = true;
321                if let Some(newline_ix) = buffer.find('\n') {
322                    buffer.replace_range(..newline_ix + 1, "");
323                    first_line = false;
324                } else {
325                    return future::ready(None);
326                }
327            }
328        }
329
330        let mut text = buffer.to_string();
331        if starts_with_markdown_codeblock {
332            text = text
333                .strip_suffix("\n```\n")
334                .or_else(|| text.strip_suffix("\n```"))
335                .or_else(|| text.strip_suffix("\n``"))
336                .or_else(|| text.strip_suffix("\n`"))
337                .or_else(|| text.strip_suffix('\n'))
338                .unwrap_or(&text)
339                .to_string();
340        }
341
342        if includes_start_or_end_span {
343            text = text
344                .strip_suffix("|E|>")
345                .or_else(|| text.strip_suffix("E|>"))
346                .or_else(|| text.strip_prefix("|>"))
347                .or_else(|| text.strip_prefix(">"))
348                .unwrap_or(&text)
349                .to_string();
350        };
351
352        if text.contains('\n') {
353            first_line = false;
354        }
355
356        let remainder = buffer.split_off(text.len());
357        let result = if buffer.is_empty() {
358            None
359        } else {
360            Some(Ok(buffer.clone()))
361        };
362
363        buffer = remainder;
364        future::ready(result)
365    })
366}
367
368#[cfg(test)]
369mod tests {
370    use super::*;
371    use ai::test::FakeCompletionProvider;
372    use futures::stream::{self};
373    use gpui::{executor::Deterministic, TestAppContext};
374    use indoc::indoc;
375    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
376    use rand::prelude::*;
377    use serde::Serialize;
378    use settings::SettingsStore;
379
380    #[derive(Serialize)]
381    pub struct DummyCompletionRequest {
382        pub name: String,
383    }
384
385    impl CompletionRequest for DummyCompletionRequest {
386        fn data(&self) -> serde_json::Result<String> {
387            serde_json::to_string(self)
388        }
389    }
390
391    #[gpui::test(iterations = 10)]
392    async fn test_transform_autoindent(
393        cx: &mut TestAppContext,
394        mut rng: StdRng,
395        deterministic: Arc<Deterministic>,
396    ) {
397        cx.set_global(cx.read(SettingsStore::test));
398        cx.update(language_settings::init);
399
400        let text = indoc! {"
401            fn main() {
402                let x = 0;
403                for _ in 0..10 {
404                    x += 1;
405                }
406            }
407        "};
408        let buffer =
409            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
410        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
411        let range = buffer.read_with(cx, |buffer, cx| {
412            let snapshot = buffer.snapshot(cx);
413            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
414        });
415        let provider = Arc::new(FakeCompletionProvider::new());
416        let codegen = cx.add_model(|cx| {
417            Codegen::new(
418                buffer.clone(),
419                CodegenKind::Transform { range },
420                provider.clone(),
421                cx,
422            )
423        });
424
425        let request = Box::new(DummyCompletionRequest {
426            name: "test".to_string(),
427        });
428        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
429
430        let mut new_text = concat!(
431            "       let mut x = 0;\n",
432            "       while x < 10 {\n",
433            "           x += 1;\n",
434            "       }",
435        );
436        while !new_text.is_empty() {
437            let max_len = cmp::min(new_text.len(), 10);
438            let len = rng.gen_range(1..=max_len);
439            let (chunk, suffix) = new_text.split_at(len);
440            provider.send_completion(chunk);
441            new_text = suffix;
442            deterministic.run_until_parked();
443        }
444        provider.finish_completion();
445        deterministic.run_until_parked();
446
447        assert_eq!(
448            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
449            indoc! {"
450                fn main() {
451                    let mut x = 0;
452                    while x < 10 {
453                        x += 1;
454                    }
455                }
456            "}
457        );
458    }
459
460    #[gpui::test(iterations = 10)]
461    async fn test_autoindent_when_generating_past_indentation(
462        cx: &mut TestAppContext,
463        mut rng: StdRng,
464        deterministic: Arc<Deterministic>,
465    ) {
466        cx.set_global(cx.read(SettingsStore::test));
467        cx.update(language_settings::init);
468
469        let text = indoc! {"
470            fn main() {
471                le
472            }
473        "};
474        let buffer =
475            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
476        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
477        let position = buffer.read_with(cx, |buffer, cx| {
478            let snapshot = buffer.snapshot(cx);
479            snapshot.anchor_before(Point::new(1, 6))
480        });
481        let provider = Arc::new(FakeCompletionProvider::new());
482        let codegen = cx.add_model(|cx| {
483            Codegen::new(
484                buffer.clone(),
485                CodegenKind::Generate { position },
486                provider.clone(),
487                cx,
488            )
489        });
490
491        let request = Box::new(DummyCompletionRequest {
492            name: "test".to_string(),
493        });
494        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
495
496        let mut new_text = concat!(
497            "t mut x = 0;\n",
498            "while x < 10 {\n",
499            "    x += 1;\n",
500            "}", //
501        );
502        while !new_text.is_empty() {
503            let max_len = cmp::min(new_text.len(), 10);
504            let len = rng.gen_range(1..=max_len);
505            let (chunk, suffix) = new_text.split_at(len);
506            provider.send_completion(chunk);
507            new_text = suffix;
508            deterministic.run_until_parked();
509        }
510        provider.finish_completion();
511        deterministic.run_until_parked();
512
513        assert_eq!(
514            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
515            indoc! {"
516                fn main() {
517                    let mut x = 0;
518                    while x < 10 {
519                        x += 1;
520                    }
521                }
522            "}
523        );
524    }
525
526    #[gpui::test(iterations = 10)]
527    async fn test_autoindent_when_generating_before_indentation(
528        cx: &mut TestAppContext,
529        mut rng: StdRng,
530        deterministic: Arc<Deterministic>,
531    ) {
532        cx.set_global(cx.read(SettingsStore::test));
533        cx.update(language_settings::init);
534
535        let text = concat!(
536            "fn main() {\n",
537            "  \n",
538            "}\n" //
539        );
540        let buffer =
541            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
542        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
543        let position = buffer.read_with(cx, |buffer, cx| {
544            let snapshot = buffer.snapshot(cx);
545            snapshot.anchor_before(Point::new(1, 2))
546        });
547        let provider = Arc::new(FakeCompletionProvider::new());
548        let codegen = cx.add_model(|cx| {
549            Codegen::new(
550                buffer.clone(),
551                CodegenKind::Generate { position },
552                provider.clone(),
553                cx,
554            )
555        });
556
557        let request = Box::new(DummyCompletionRequest {
558            name: "test".to_string(),
559        });
560        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
561
562        let mut new_text = concat!(
563            "let mut x = 0;\n",
564            "while x < 10 {\n",
565            "    x += 1;\n",
566            "}", //
567        );
568        while !new_text.is_empty() {
569            let max_len = cmp::min(new_text.len(), 10);
570            let len = rng.gen_range(1..=max_len);
571            let (chunk, suffix) = new_text.split_at(len);
572            provider.send_completion(chunk);
573            new_text = suffix;
574            deterministic.run_until_parked();
575        }
576        provider.finish_completion();
577        deterministic.run_until_parked();
578
579        assert_eq!(
580            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
581            indoc! {"
582                fn main() {
583                    let mut x = 0;
584                    while x < 10 {
585                        x += 1;
586                    }
587                }
588            "}
589        );
590    }
591
592    #[gpui::test]
593    async fn test_strip_invalid_spans_from_codeblock() {
594        assert_eq!(
595            strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
596                .map(|chunk| chunk.unwrap())
597                .collect::<String>()
598                .await,
599            "Lorem ipsum dolor"
600        );
601        assert_eq!(
602            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
603                .map(|chunk| chunk.unwrap())
604                .collect::<String>()
605                .await,
606            "Lorem ipsum dolor"
607        );
608        assert_eq!(
609            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
610                .map(|chunk| chunk.unwrap())
611                .collect::<String>()
612                .await,
613            "Lorem ipsum dolor"
614        );
615        assert_eq!(
616            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
617                .map(|chunk| chunk.unwrap())
618                .collect::<String>()
619                .await,
620            "Lorem ipsum dolor"
621        );
622        assert_eq!(
623            strip_invalid_spans_from_codeblock(chunks(
624                "```html\n```js\nLorem ipsum dolor\n```\n```",
625                2
626            ))
627            .map(|chunk| chunk.unwrap())
628            .collect::<String>()
629            .await,
630            "```js\nLorem ipsum dolor\n```"
631        );
632        assert_eq!(
633            strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
634                .map(|chunk| chunk.unwrap())
635                .collect::<String>()
636                .await,
637            "``\nLorem ipsum dolor\n```"
638        );
639        assert_eq!(
640            strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
641                .map(|chunk| chunk.unwrap())
642                .collect::<String>()
643                .await,
644            "Lorem ipsum"
645        );
646
647        assert_eq!(
648            strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
649                .map(|chunk| chunk.unwrap())
650                .collect::<String>()
651                .await,
652            "Lorem ipsum"
653        );
654
655        assert_eq!(
656            strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
657                .map(|chunk| chunk.unwrap())
658                .collect::<String>()
659                .await,
660            "Lorem ipsum"
661        );
662        assert_eq!(
663            strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
664                .map(|chunk| chunk.unwrap())
665                .collect::<String>()
666                .await,
667            "Lorem ipsum"
668        );
669        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
670            stream::iter(
671                text.chars()
672                    .collect::<Vec<_>>()
673                    .chunks(size)
674                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
675                    .collect::<Vec<_>>(),
676            )
677        }
678    }
679
680    fn rust_lang() -> Language {
681        Language::new(
682            LanguageConfig {
683                name: "Rust".into(),
684                path_suffixes: vec!["rs".to_string()],
685                ..Default::default()
686            },
687            Some(tree_sitter_rust::language()),
688        )
689        .with_indents_query(
690            r#"
691            (call_expression) @indent
692            (field_expression) @indent
693            (_ "(" ")" @end) @indent
694            (_ "{" "}" @end) @indent
695            "#,
696        )
697        .unwrap()
698    }
699}