codegen.rs

  1use crate::streaming_diff::{Hunk, StreamingDiff};
  2use ai::completion::{CompletionProvider, CompletionRequest};
  3use anyhow::Result;
  4use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
  5use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  6use gpui::{EventEmitter, Model, ModelContext, Task};
  7use language::{Rope, TransactionId};
  8use multi_buffer;
  9use std::{cmp, future, ops::Range, sync::Arc};
 10
 11pub enum Event {
 12    Finished,
 13    Undone,
 14}
 15
/// Selects which part of the buffer a `Codegen` operates on.
#[derive(Clone)]
pub enum CodegenKind {
    /// Operate on an existing range of the buffer; `Codegen::range` returns a
    /// clone of `range`.
    Transform { range: Range<Anchor> },
    /// Operate at a single position; `Codegen::range` returns the range from
    /// `position.bias_left(..)` to `position`.
    Generate { position: Anchor },
}
 21
/// State for one assistant code-generation session against a multi-buffer.
pub struct Codegen {
    /// Completion backend; `start` calls `provider.complete(prompt)` on it.
    provider: Arc<dyn CompletionProvider>,
    /// The multi-buffer that generated edits are applied to.
    buffer: Model<MultiBuffer>,
    /// Snapshot of `buffer` captured in `new`; used to resolve ranges, offsets
    /// and indentation while generating.
    snapshot: MultiBufferSnapshot,
    /// Whether this session transforms a range or generates at a position.
    kind: CodegenKind,
    /// Ranges reported as `Hunk::Keep` by the most recently applied batch of
    /// hunks; cleared before each batch and again when generation ends.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Id of the first transaction produced by generated edits. Later edit
    /// transactions are merged into it; reset to `None` when it is undone.
    transaction_id: Option<TransactionId>,
    /// Error produced by the last generation, if any; cleared by `start`.
    error: Option<anyhow::Error>,
    /// The task driving the current generation; replaced by `start` and by the
    /// undo handler.
    generation: Task<()>,
    /// `true` until `start` is called, and again once the generation task ends.
    idle: bool,
    /// Subscription to `buffer` events, stored for the lifetime of this model
    /// (presumably dropping it would unsubscribe — confirm against gpui).
    _subscription: gpui::Subscription,
}
 34
// Declares `Event` as the event type `Codegen` emits; the methods on `Codegen`
// emit `Event::Finished` and `Event::Undone` through `cx.emit`.
impl EventEmitter<Event> for Codegen {}
 36
 37impl Codegen {
 38    pub fn new(
 39        buffer: Model<MultiBuffer>,
 40        kind: CodegenKind,
 41        provider: Arc<dyn CompletionProvider>,
 42        cx: &mut ModelContext<Self>,
 43    ) -> Self {
 44        let snapshot = buffer.read(cx).snapshot(cx);
 45        Self {
 46            provider,
 47            buffer: buffer.clone(),
 48            snapshot,
 49            kind,
 50            last_equal_ranges: Default::default(),
 51            transaction_id: Default::default(),
 52            error: Default::default(),
 53            idle: true,
 54            generation: Task::ready(()),
 55            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 56        }
 57    }
 58
 59    fn handle_buffer_event(
 60        &mut self,
 61        _buffer: Model<MultiBuffer>,
 62        event: &multi_buffer::Event,
 63        cx: &mut ModelContext<Self>,
 64    ) {
 65        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
 66            if self.transaction_id == Some(*transaction_id) {
 67                self.transaction_id = None;
 68                self.generation = Task::ready(());
 69                cx.emit(Event::Undone);
 70            }
 71        }
 72    }
 73
 74    pub fn range(&self) -> Range<Anchor> {
 75        match &self.kind {
 76            CodegenKind::Transform { range } => range.clone(),
 77            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
 78        }
 79    }
 80
 81    pub fn kind(&self) -> &CodegenKind {
 82        &self.kind
 83    }
 84
 85    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 86        &self.last_equal_ranges
 87    }
 88
 89    pub fn idle(&self) -> bool {
 90        self.idle
 91    }
 92
 93    pub fn error(&self) -> Option<&anyhow::Error> {
 94        self.error.as_ref()
 95    }
 96
 97    pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
 98        let range = self.range();
 99        let snapshot = self.snapshot.clone();
100        let selected_text = snapshot
101            .text_for_range(range.start..range.end)
102            .collect::<Rope>();
103
104        let selection_start = range.start.to_point(&snapshot);
105        let suggested_line_indent = snapshot
106            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
107            .into_values()
108            .next()
109            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
110
111        let response = self.provider.complete(prompt);
112        self.generation = cx.spawn_weak(|this, mut cx| {
113            async move {
114                let generate = async {
115                    let mut edit_start = range.start.to_offset(&snapshot);
116
117                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
118                    let diff = cx.background().spawn(async move {
119                        let chunks = strip_invalid_spans_from_codeblock(response.await?);
120                        futures::pin_mut!(chunks);
121                        let mut diff = StreamingDiff::new(selected_text.to_string());
122
123                        let mut new_text = String::new();
124                        let mut base_indent = None;
125                        let mut line_indent = None;
126                        let mut first_line = true;
127
128                        while let Some(chunk) = chunks.next().await {
129                            let chunk = chunk?;
130
131                            let mut lines = chunk.split('\n').peekable();
132                            while let Some(line) = lines.next() {
133                                new_text.push_str(line);
134                                if line_indent.is_none() {
135                                    if let Some(non_whitespace_ch_ix) =
136                                        new_text.find(|ch: char| !ch.is_whitespace())
137                                    {
138                                        line_indent = Some(non_whitespace_ch_ix);
139                                        base_indent = base_indent.or(line_indent);
140
141                                        let line_indent = line_indent.unwrap();
142                                        let base_indent = base_indent.unwrap();
143                                        let indent_delta = line_indent as i32 - base_indent as i32;
144                                        let mut corrected_indent_len = cmp::max(
145                                            0,
146                                            suggested_line_indent.len as i32 + indent_delta,
147                                        )
148                                            as usize;
149                                        if first_line {
150                                            corrected_indent_len = corrected_indent_len
151                                                .saturating_sub(selection_start.column as usize);
152                                        }
153
154                                        let indent_char = suggested_line_indent.char();
155                                        let mut indent_buffer = [0; 4];
156                                        let indent_str =
157                                            indent_char.encode_utf8(&mut indent_buffer);
158                                        new_text.replace_range(
159                                            ..line_indent,
160                                            &indent_str.repeat(corrected_indent_len),
161                                        );
162                                    }
163                                }
164
165                                if line_indent.is_some() {
166                                    hunks_tx.send(diff.push_new(&new_text)).await?;
167                                    new_text.clear();
168                                }
169
170                                if lines.peek().is_some() {
171                                    hunks_tx.send(diff.push_new("\n")).await?;
172                                    line_indent = None;
173                                    first_line = false;
174                                }
175                            }
176                        }
177                        hunks_tx.send(diff.push_new(&new_text)).await?;
178                        hunks_tx.send(diff.finish()).await?;
179
180                        anyhow::Ok(())
181                    });
182
183                    while let Some(hunks) = hunks_rx.next().await {
184                        let this = if let Some(this) = this.upgrade(&cx) {
185                            this
186                        } else {
187                            break;
188                        };
189
190                        this.update(&mut cx, |this, cx| {
191                            this.last_equal_ranges.clear();
192
193                            let transaction = this.buffer.update(cx, |buffer, cx| {
194                                // Avoid grouping assistant edits with user edits.
195                                buffer.finalize_last_transaction(cx);
196
197                                buffer.start_transaction(cx);
198                                buffer.edit(
199                                    hunks.into_iter().filter_map(|hunk| match hunk {
200                                        Hunk::Insert { text } => {
201                                            let edit_start = snapshot.anchor_after(edit_start);
202                                            Some((edit_start..edit_start, text))
203                                        }
204                                        Hunk::Remove { len } => {
205                                            let edit_end = edit_start + len;
206                                            let edit_range = snapshot.anchor_after(edit_start)
207                                                ..snapshot.anchor_before(edit_end);
208                                            edit_start = edit_end;
209                                            Some((edit_range, String::new()))
210                                        }
211                                        Hunk::Keep { len } => {
212                                            let edit_end = edit_start + len;
213                                            let edit_range = snapshot.anchor_after(edit_start)
214                                                ..snapshot.anchor_before(edit_end);
215                                            edit_start = edit_end;
216                                            this.last_equal_ranges.push(edit_range);
217                                            None
218                                        }
219                                    }),
220                                    None,
221                                    cx,
222                                );
223
224                                buffer.end_transaction(cx)
225                            });
226
227                            if let Some(transaction) = transaction {
228                                if let Some(first_transaction) = this.transaction_id {
229                                    // Group all assistant edits into the first transaction.
230                                    this.buffer.update(cx, |buffer, cx| {
231                                        buffer.merge_transactions(
232                                            transaction,
233                                            first_transaction,
234                                            cx,
235                                        )
236                                    });
237                                } else {
238                                    this.transaction_id = Some(transaction);
239                                    this.buffer.update(cx, |buffer, cx| {
240                                        buffer.finalize_last_transaction(cx)
241                                    });
242                                }
243                            }
244
245                            cx.notify();
246                        });
247                    }
248
249                    diff.await?;
250                    anyhow::Ok(())
251                };
252
253                let result = generate.await;
254                if let Some(this) = this.upgrade(&cx) {
255                    this.update(&mut cx, |this, cx| {
256                        this.last_equal_ranges.clear();
257                        this.idle = true;
258                        if let Err(error) = result {
259                            this.error = Some(error);
260                        }
261                        cx.emit(Event::Finished);
262                        cx.notify();
263                    });
264                }
265            }
266        });
267        self.error.take();
268        self.idle = false;
269        cx.notify();
270    }
271
272    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
273        if let Some(transaction_id) = self.transaction_id {
274            self.buffer
275                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
276        }
277    }
278}
279
280fn strip_invalid_spans_from_codeblock(
281    stream: impl Stream<Item = Result<String>>,
282) -> impl Stream<Item = Result<String>> {
283    let mut first_line = true;
284    let mut buffer = String::new();
285    let mut starts_with_markdown_codeblock = false;
286    let mut includes_start_or_end_span = false;
287    stream.filter_map(move |chunk| {
288        let chunk = match chunk {
289            Ok(chunk) => chunk,
290            Err(err) => return future::ready(Some(Err(err))),
291        };
292        buffer.push_str(&chunk);
293
294        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
295            includes_start_or_end_span = true;
296
297            buffer = buffer
298                .strip_prefix("<|S|>")
299                .or_else(|| buffer.strip_prefix("<|S|"))
300                .unwrap_or(&buffer)
301                .to_string();
302        } else if buffer.ends_with("|E|>") {
303            includes_start_or_end_span = true;
304        } else if buffer.starts_with("<|")
305            || buffer.starts_with("<|S")
306            || buffer.starts_with("<|S|")
307            || buffer.ends_with("|")
308            || buffer.ends_with("|E")
309            || buffer.ends_with("|E|")
310        {
311            return future::ready(None);
312        }
313
314        if first_line {
315            if buffer == "" || buffer == "`" || buffer == "``" {
316                return future::ready(None);
317            } else if buffer.starts_with("```") {
318                starts_with_markdown_codeblock = true;
319                if let Some(newline_ix) = buffer.find('\n') {
320                    buffer.replace_range(..newline_ix + 1, "");
321                    first_line = false;
322                } else {
323                    return future::ready(None);
324                }
325            }
326        }
327
328        let mut text = buffer.to_string();
329        if starts_with_markdown_codeblock {
330            text = text
331                .strip_suffix("\n```\n")
332                .or_else(|| text.strip_suffix("\n```"))
333                .or_else(|| text.strip_suffix("\n``"))
334                .or_else(|| text.strip_suffix("\n`"))
335                .or_else(|| text.strip_suffix('\n'))
336                .unwrap_or(&text)
337                .to_string();
338        }
339
340        if includes_start_or_end_span {
341            text = text
342                .strip_suffix("|E|>")
343                .or_else(|| text.strip_suffix("E|>"))
344                .or_else(|| text.strip_prefix("|>"))
345                .or_else(|| text.strip_prefix(">"))
346                .unwrap_or(&text)
347                .to_string();
348        };
349
350        if text.contains('\n') {
351            first_line = false;
352        }
353
354        let remainder = buffer.split_off(text.len());
355        let result = if buffer.is_empty() {
356            None
357        } else {
358            Some(Ok(buffer.clone()))
359        };
360
361        buffer = remainder;
362        future::ready(result)
363    })
364}
365
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use ai::test::FakeCompletionProvider;
    use futures::stream::{self};
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;

    /// Minimal completion request used to drive `Codegen::start` in tests.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    impl CompletionRequest for DummyCompletionRequest {
        fn data(&self) -> serde_json::Result<String> {
            serde_json::to_string(self)
        }
    }

    /// Sends `text` to `provider` in randomly sized pieces of 1 to 10 bytes,
    /// letting the executor run until parked after each piece, then finishes
    /// the completion and lets the executor settle once more.
    fn stream_completion(
        provider: &FakeCompletionProvider,
        mut text: &str,
        rng: &mut StdRng,
        cx: &mut TestAppContext,
    ) {
        while !text.is_empty() {
            let max_len = cmp::min(text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = text.split_at(len);
            provider.send_completion(chunk);
            text = suffix;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();
    }

    /// Splits `text` into a stream of `Ok` chunks of at most `size` characters.
    fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
        stream::iter(
            text.chars()
                .collect::<Vec<_>>()
                .chunks(size)
                .map(|chunk| Ok(chunk.iter().collect::<String>()))
                .collect::<Vec<_>>(),
        )
    }

    /// Transforming a selected range re-indents the generated replacement to
    /// match the surrounding code.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Transform { range },
                provider.clone(),
                cx,
            )
        });

        let request = Box::new(DummyCompletionRequest {
            name: "test".to_string(),
        });
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        // The response is deliberately indented differently from the buffer.
        let new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        stream_completion(&provider, new_text, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating at a position that lies after the line's indentation (and
    /// after existing text) indents the following lines correctly.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });

        let request = Box::new(DummyCompletionRequest {
            name: "test".to_string(),
        });
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_completion(&provider, new_text, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating on a whitespace-only line that is indented less than the
    /// suggested indentation still produces correctly indented code.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });

        let request = Box::new(DummyCompletionRequest {
            name: "test".to_string(),
        });
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_completion(&provider, new_text, &mut rng, cx);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Markdown fences and `<|S|` / `|E|>` markers are stripped from a response
    /// that arrives two characters at a time.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        // (input, expected output)
        let cases = [
            ("Lorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```\n", "Lorem ipsum dolor"),
            (
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                "```js\nLorem ipsum dolor\n```",
            ),
            ("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```"),
            ("<|S|Lorem ipsum|E|>", "Lorem ipsum"),
            ("<|S|>Lorem ipsum", "Lorem ipsum"),
            ("```\n<|S|>Lorem ipsum\n```", "Lorem ipsum"),
            ("```\n<|S|Lorem ipsum|E|>\n```", "Lorem ipsum"),
        ];

        for (input, expected) in cases.iter() {
            let output = strip_invalid_spans_from_codeblock(chunks(input, 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await;
            assert_eq!(output, *expected, "input: {:?}", input);
        }
    }

    /// Builds a Rust language definition with the indent queries the
    /// auto-indent tests rely on.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}