codegen.rs

  1use crate::streaming_diff::{Hunk, StreamingDiff};
  2use ai::completion::{CompletionProvider, CompletionRequest};
  3use anyhow::Result;
  4use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
  5use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  6use gpui::{Entity, ModelContext, ModelHandle, Task};
  7use language::{Rope, TransactionId};
  8use multi_buffer;
  9use std::{cmp, future, ops::Range, sync::Arc};
 10
 11pub enum Event {
 12    Finished,
 13    Undone,
 14}
 15
 16#[derive(Clone)]
 17pub enum CodegenKind {
 18    Transform { range: Range<Anchor> },
 19    Generate { position: Anchor },
 20}
 21
/// Streams a model completion into a region of a `MultiBuffer`, applying the
/// output as incremental edits that are merged into a single transaction.
pub struct Codegen {
    // Completion source used by `start`.
    provider: Arc<dyn CompletionProvider>,
    // The buffer being edited.
    buffer: ModelHandle<MultiBuffer>,
    // Snapshot taken in `new`; `start` diffs generated text against the text
    // of this snapshot and anchors its edits in it.
    snapshot: MultiBufferSnapshot,
    // Whether we transform a range or generate at a position.
    kind: CodegenKind,
    // Ranges the most recently applied batch of diff hunks kept unchanged.
    last_equal_ranges: Vec<Range<Anchor>>,
    // Transaction into which all generated edits are merged; `None` until the
    // first edit lands, and reset to `None` when it is undone.
    transaction_id: Option<TransactionId>,
    // Error that ended the most recent generation, if any.
    error: Option<anyhow::Error>,
    // The running generation task; `start` and undo handling replace it.
    generation: Task<()>,
    // `false` between `start` and the end of the generation it spawned.
    idle: bool,
    // Subscription to buffer events routed to `handle_buffer_event`; held for
    // the lifetime of the model.
    _subscription: gpui::Subscription,
}
 34
 35impl Entity for Codegen {
 36    type Event = Event;
 37}
 38
 39impl Codegen {
 40    pub fn new(
 41        buffer: ModelHandle<MultiBuffer>,
 42        kind: CodegenKind,
 43        provider: Arc<dyn CompletionProvider>,
 44        cx: &mut ModelContext<Self>,
 45    ) -> Self {
 46        let snapshot = buffer.read(cx).snapshot(cx);
 47        Self {
 48            provider,
 49            buffer: buffer.clone(),
 50            snapshot,
 51            kind,
 52            last_equal_ranges: Default::default(),
 53            transaction_id: Default::default(),
 54            error: Default::default(),
 55            idle: true,
 56            generation: Task::ready(()),
 57            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 58        }
 59    }
 60
 61    fn handle_buffer_event(
 62        &mut self,
 63        _buffer: ModelHandle<MultiBuffer>,
 64        event: &multi_buffer::Event,
 65        cx: &mut ModelContext<Self>,
 66    ) {
 67        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
 68            if self.transaction_id == Some(*transaction_id) {
 69                self.transaction_id = None;
 70                self.generation = Task::ready(());
 71                cx.emit(Event::Undone);
 72            }
 73        }
 74    }
 75
 76    pub fn range(&self) -> Range<Anchor> {
 77        match &self.kind {
 78            CodegenKind::Transform { range } => range.clone(),
 79            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
 80        }
 81    }
 82
 83    pub fn kind(&self) -> &CodegenKind {
 84        &self.kind
 85    }
 86
 87    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 88        &self.last_equal_ranges
 89    }
 90
 91    pub fn idle(&self) -> bool {
 92        self.idle
 93    }
 94
 95    pub fn error(&self) -> Option<&anyhow::Error> {
 96        self.error.as_ref()
 97    }
 98
 99    pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
100        let range = self.range();
101        let snapshot = self.snapshot.clone();
102        let selected_text = snapshot
103            .text_for_range(range.start..range.end)
104            .collect::<Rope>();
105
106        let selection_start = range.start.to_point(&snapshot);
107        let suggested_line_indent = snapshot
108            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
109            .into_values()
110            .next()
111            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
112
113        let response = self.provider.complete(prompt);
114        self.generation = cx.spawn_weak(|this, mut cx| {
115            async move {
116                let generate = async {
117                    let mut edit_start = range.start.to_offset(&snapshot);
118
119                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
120                    let diff = cx.background().spawn(async move {
121                        let chunks = strip_invalid_spans_from_codeblock(response.await?);
122                        futures::pin_mut!(chunks);
123                        let mut diff = StreamingDiff::new(selected_text.to_string());
124
125                        let mut new_text = String::new();
126                        let mut base_indent = None;
127                        let mut line_indent = None;
128                        let mut first_line = true;
129
130                        while let Some(chunk) = chunks.next().await {
131                            let chunk = chunk?;
132
133                            let mut lines = chunk.split('\n').peekable();
134                            while let Some(line) = lines.next() {
135                                new_text.push_str(line);
136                                if line_indent.is_none() {
137                                    if let Some(non_whitespace_ch_ix) =
138                                        new_text.find(|ch: char| !ch.is_whitespace())
139                                    {
140                                        line_indent = Some(non_whitespace_ch_ix);
141                                        base_indent = base_indent.or(line_indent);
142
143                                        let line_indent = line_indent.unwrap();
144                                        let base_indent = base_indent.unwrap();
145                                        let indent_delta = line_indent as i32 - base_indent as i32;
146                                        let mut corrected_indent_len = cmp::max(
147                                            0,
148                                            suggested_line_indent.len as i32 + indent_delta,
149                                        )
150                                            as usize;
151                                        if first_line {
152                                            corrected_indent_len = corrected_indent_len
153                                                .saturating_sub(selection_start.column as usize);
154                                        }
155
156                                        let indent_char = suggested_line_indent.char();
157                                        let mut indent_buffer = [0; 4];
158                                        let indent_str =
159                                            indent_char.encode_utf8(&mut indent_buffer);
160                                        new_text.replace_range(
161                                            ..line_indent,
162                                            &indent_str.repeat(corrected_indent_len),
163                                        );
164                                    }
165                                }
166
167                                if line_indent.is_some() {
168                                    hunks_tx.send(diff.push_new(&new_text)).await?;
169                                    new_text.clear();
170                                }
171
172                                if lines.peek().is_some() {
173                                    hunks_tx.send(diff.push_new("\n")).await?;
174                                    line_indent = None;
175                                    first_line = false;
176                                }
177                            }
178                        }
179                        hunks_tx.send(diff.push_new(&new_text)).await?;
180                        hunks_tx.send(diff.finish()).await?;
181
182                        anyhow::Ok(())
183                    });
184
185                    while let Some(hunks) = hunks_rx.next().await {
186                        let this = if let Some(this) = this.upgrade(&cx) {
187                            this
188                        } else {
189                            break;
190                        };
191
192                        this.update(&mut cx, |this, cx| {
193                            this.last_equal_ranges.clear();
194
195                            let transaction = this.buffer.update(cx, |buffer, cx| {
196                                // Avoid grouping assistant edits with user edits.
197                                buffer.finalize_last_transaction(cx);
198
199                                buffer.start_transaction(cx);
200                                buffer.edit(
201                                    hunks.into_iter().filter_map(|hunk| match hunk {
202                                        Hunk::Insert { text } => {
203                                            let edit_start = snapshot.anchor_after(edit_start);
204                                            Some((edit_start..edit_start, text))
205                                        }
206                                        Hunk::Remove { len } => {
207                                            let edit_end = edit_start + len;
208                                            let edit_range = snapshot.anchor_after(edit_start)
209                                                ..snapshot.anchor_before(edit_end);
210                                            edit_start = edit_end;
211                                            Some((edit_range, String::new()))
212                                        }
213                                        Hunk::Keep { len } => {
214                                            let edit_end = edit_start + len;
215                                            let edit_range = snapshot.anchor_after(edit_start)
216                                                ..snapshot.anchor_before(edit_end);
217                                            edit_start = edit_end;
218                                            this.last_equal_ranges.push(edit_range);
219                                            None
220                                        }
221                                    }),
222                                    None,
223                                    cx,
224                                );
225
226                                buffer.end_transaction(cx)
227                            });
228
229                            if let Some(transaction) = transaction {
230                                if let Some(first_transaction) = this.transaction_id {
231                                    // Group all assistant edits into the first transaction.
232                                    this.buffer.update(cx, |buffer, cx| {
233                                        buffer.merge_transactions(
234                                            transaction,
235                                            first_transaction,
236                                            cx,
237                                        )
238                                    });
239                                } else {
240                                    this.transaction_id = Some(transaction);
241                                    this.buffer.update(cx, |buffer, cx| {
242                                        buffer.finalize_last_transaction(cx)
243                                    });
244                                }
245                            }
246
247                            cx.notify();
248                        });
249                    }
250
251                    diff.await?;
252                    anyhow::Ok(())
253                };
254
255                let result = generate.await;
256                if let Some(this) = this.upgrade(&cx) {
257                    this.update(&mut cx, |this, cx| {
258                        this.last_equal_ranges.clear();
259                        this.idle = true;
260                        if let Err(error) = result {
261                            this.error = Some(error);
262                        }
263                        cx.emit(Event::Finished);
264                        cx.notify();
265                    });
266                }
267            }
268        });
269        self.error.take();
270        self.idle = false;
271        cx.notify();
272    }
273
274    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
275        if let Some(transaction_id) = self.transaction_id {
276            self.buffer
277                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
278        }
279    }
280}
281
282fn strip_invalid_spans_from_codeblock(
283    stream: impl Stream<Item = Result<String>>,
284) -> impl Stream<Item = Result<String>> {
285    let mut first_line = true;
286    let mut buffer = String::new();
287    let mut starts_with_markdown_codeblock = false;
288    let mut includes_start_or_end_span = false;
289    stream.filter_map(move |chunk| {
290        let chunk = match chunk {
291            Ok(chunk) => chunk,
292            Err(err) => return future::ready(Some(Err(err))),
293        };
294        buffer.push_str(&chunk);
295
296        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
297            includes_start_or_end_span = true;
298
299            buffer = buffer
300                .strip_prefix("<|S|>")
301                .or_else(|| buffer.strip_prefix("<|S|"))
302                .unwrap_or(&buffer)
303                .to_string();
304        } else if buffer.ends_with("|E|>") {
305            includes_start_or_end_span = true;
306        } else if buffer.starts_with("<|")
307            || buffer.starts_with("<|S")
308            || buffer.starts_with("<|S|")
309            || buffer.ends_with("|")
310            || buffer.ends_with("|E")
311            || buffer.ends_with("|E|")
312        {
313            return future::ready(None);
314        }
315
316        if first_line {
317            if buffer == "" || buffer == "`" || buffer == "``" {
318                return future::ready(None);
319            } else if buffer.starts_with("```") {
320                starts_with_markdown_codeblock = true;
321                if let Some(newline_ix) = buffer.find('\n') {
322                    buffer.replace_range(..newline_ix + 1, "");
323                    first_line = false;
324                } else {
325                    return future::ready(None);
326                }
327            }
328        }
329
330        let mut text = buffer.to_string();
331        if starts_with_markdown_codeblock {
332            text = text
333                .strip_suffix("\n```\n")
334                .or_else(|| text.strip_suffix("\n```"))
335                .or_else(|| text.strip_suffix("\n``"))
336                .or_else(|| text.strip_suffix("\n`"))
337                .or_else(|| text.strip_suffix('\n'))
338                .unwrap_or(&text)
339                .to_string();
340        }
341
342        if includes_start_or_end_span {
343            text = text
344                .strip_suffix("|E|>")
345                .or_else(|| text.strip_suffix("E|>"))
346                .or_else(|| text.strip_prefix("|>"))
347                .or_else(|| text.strip_prefix(">"))
348                .unwrap_or(&text)
349                .to_string();
350        };
351
352        if text.contains('\n') {
353            first_line = false;
354        }
355
356        let remainder = buffer.split_off(text.len());
357        let result = if buffer.is_empty() {
358            None
359        } else {
360            Some(Ok(buffer.clone()))
361        };
362
363        buffer = remainder;
364        future::ready(result)
365    })
366}
367
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use ai::test::FakeCompletionProvider;
    use futures::stream::{self};
    use gpui::{executor::Deterministic, TestAppContext};
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;

    /// Minimal `CompletionRequest` whose payload is its own JSON serialization.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    impl CompletionRequest for DummyCompletionRequest {
        fn data(&self) -> serde_json::Result<String> {
            serde_json::to_string(self)
        }
    }

    /// Feeds `text` to `provider` in randomly sized chunks of 1..=10 bytes,
    /// running the executor until parked after each chunk, then finishes the
    /// completion. `text` must be ASCII so every split lands on a char boundary.
    fn stream_completion_in_random_chunks(
        provider: &FakeCompletionProvider,
        mut text: &str,
        rng: &mut StdRng,
        deterministic: &Deterministic,
    ) {
        while !text.is_empty() {
            let max_len = cmp::min(text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = text.split_at(len);
            provider.send_completion(chunk);
            text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();
    }

    /// Transforming a range re-indents the output to the range's suggested
    /// indentation, whatever indentation the model used.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Transform { range },
                provider.clone(),
                cx,
            )
        });

        let request = Box::new(DummyCompletionRequest {
            name: "test".to_string(),
        });
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        stream_completion_in_random_chunks(&provider, new_text, &mut rng, &deterministic);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating mid-line (after existing text on the row) does not add
    /// indentation to the first generated line.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });

        let request = Box::new(DummyCompletionRequest {
            name: "test".to_string(),
        });
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_completion_in_random_chunks(&provider, new_text, &mut rng, &deterministic);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating inside too-shallow indentation pads the first line up to the
    /// suggested indent.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });

        let request = Box::new(DummyCompletionRequest {
            name: "test".to_string(),
        });
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_completion_in_random_chunks(&provider, new_text, &mut rng, &deterministic);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Markdown fences and start/end span markers are stripped regardless of
    /// how the input is chunked.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks(
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                2
            ))
            .map(|chunk| chunk.unwrap())
            .collect::<String>()
            .await,
            "```js\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "``\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );
        assert_eq!(
            strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum"
        );

        /// Splits `text` into `size`-character chunks, each wrapped in `Ok`.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// A Rust language with a minimal indents query, sufficient for the
    /// autoindent tests above.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}