codegen.rs

  1use crate::streaming_diff::{Hunk, StreamingDiff};
  2use ai::completion::{CompletionProvider, CompletionRequest};
  3use anyhow::Result;
  4use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
  5use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  6use gpui::{Entity, ModelContext, ModelHandle, Task};
  7use language::{Rope, TransactionId};
  8use multi_buffer;
  9use std::{cmp, future, ops::Range};
 10
 11pub enum Event {
 12    Finished,
 13    Undone,
 14}
 15
 16#[derive(Clone)]
 17pub enum CodegenKind {
 18    Transform { range: Range<Anchor> },
 19    Generate { position: Anchor },
 20}
 21
 22pub struct Codegen {
 23    provider: Box<dyn CompletionProvider>,
 24    buffer: ModelHandle<MultiBuffer>,
 25    snapshot: MultiBufferSnapshot,
 26    kind: CodegenKind,
 27    last_equal_ranges: Vec<Range<Anchor>>,
 28    transaction_id: Option<TransactionId>,
 29    error: Option<anyhow::Error>,
 30    generation: Task<()>,
 31    idle: bool,
 32    _subscription: gpui::Subscription,
 33}
 34
 35impl Entity for Codegen {
 36    type Event = Event;
 37}
 38
 39impl Codegen {
 40    pub fn new(
 41        buffer: ModelHandle<MultiBuffer>,
 42        kind: CodegenKind,
 43        provider: Box<dyn CompletionProvider>,
 44        cx: &mut ModelContext<Self>,
 45    ) -> Self {
 46        let snapshot = buffer.read(cx).snapshot(cx);
 47        Self {
 48            provider,
 49            buffer: buffer.clone(),
 50            snapshot,
 51            kind,
 52            last_equal_ranges: Default::default(),
 53            transaction_id: Default::default(),
 54            error: Default::default(),
 55            idle: true,
 56            generation: Task::ready(()),
 57            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 58        }
 59    }
 60
 61    fn handle_buffer_event(
 62        &mut self,
 63        _buffer: ModelHandle<MultiBuffer>,
 64        event: &multi_buffer::Event,
 65        cx: &mut ModelContext<Self>,
 66    ) {
 67        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
 68            if self.transaction_id == Some(*transaction_id) {
 69                self.transaction_id = None;
 70                self.generation = Task::ready(());
 71                cx.emit(Event::Undone);
 72            }
 73        }
 74    }
 75
 76    pub fn range(&self) -> Range<Anchor> {
 77        match &self.kind {
 78            CodegenKind::Transform { range } => range.clone(),
 79            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
 80        }
 81    }
 82
 83    pub fn kind(&self) -> &CodegenKind {
 84        &self.kind
 85    }
 86
 87    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 88        &self.last_equal_ranges
 89    }
 90
 91    pub fn idle(&self) -> bool {
 92        self.idle
 93    }
 94
 95    pub fn error(&self) -> Option<&anyhow::Error> {
 96        self.error.as_ref()
 97    }
 98
 99    pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
100        let range = self.range();
101        let snapshot = self.snapshot.clone();
102        let selected_text = snapshot
103            .text_for_range(range.start..range.end)
104            .collect::<Rope>();
105
106        let selection_start = range.start.to_point(&snapshot);
107        let suggested_line_indent = snapshot
108            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
109            .into_values()
110            .next()
111            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
112
113        let response = self.provider.complete(prompt);
114        self.generation = cx.spawn_weak(|this, mut cx| {
115            async move {
116                let generate = async {
117                    let mut edit_start = range.start.to_offset(&snapshot);
118
119                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
120                    let diff = cx.background().spawn(async move {
121                        let chunks = strip_invalid_spans_from_codeblock(response.await?);
122                        futures::pin_mut!(chunks);
123                        let mut diff = StreamingDiff::new(selected_text.to_string());
124
125                        let mut new_text = String::new();
126                        let mut base_indent = None;
127                        let mut line_indent = None;
128                        let mut first_line = true;
129
130                        while let Some(chunk) = chunks.next().await {
131                            let chunk = chunk?;
132
133                            let mut lines = chunk.split('\n').peekable();
134                            while let Some(line) = lines.next() {
135                                new_text.push_str(line);
136                                if line_indent.is_none() {
137                                    if let Some(non_whitespace_ch_ix) =
138                                        new_text.find(|ch: char| !ch.is_whitespace())
139                                    {
140                                        line_indent = Some(non_whitespace_ch_ix);
141                                        base_indent = base_indent.or(line_indent);
142
143                                        let line_indent = line_indent.unwrap();
144                                        let base_indent = base_indent.unwrap();
145                                        let indent_delta = line_indent as i32 - base_indent as i32;
146                                        let mut corrected_indent_len = cmp::max(
147                                            0,
148                                            suggested_line_indent.len as i32 + indent_delta,
149                                        )
150                                            as usize;
151                                        if first_line {
152                                            corrected_indent_len = corrected_indent_len
153                                                .saturating_sub(selection_start.column as usize);
154                                        }
155
156                                        let indent_char = suggested_line_indent.char();
157                                        let mut indent_buffer = [0; 4];
158                                        let indent_str =
159                                            indent_char.encode_utf8(&mut indent_buffer);
160                                        new_text.replace_range(
161                                            ..line_indent,
162                                            &indent_str.repeat(corrected_indent_len),
163                                        );
164                                    }
165                                }
166
167                                if line_indent.is_some() {
168                                    hunks_tx.send(diff.push_new(&new_text)).await?;
169                                    new_text.clear();
170                                }
171
172                                if lines.peek().is_some() {
173                                    hunks_tx.send(diff.push_new("\n")).await?;
174                                    line_indent = None;
175                                    first_line = false;
176                                }
177                            }
178                        }
179                        hunks_tx.send(diff.push_new(&new_text)).await?;
180                        hunks_tx.send(diff.finish()).await?;
181
182                        anyhow::Ok(())
183                    });
184
185                    while let Some(hunks) = hunks_rx.next().await {
186                        let this = if let Some(this) = this.upgrade(&cx) {
187                            this
188                        } else {
189                            break;
190                        };
191
192                        this.update(&mut cx, |this, cx| {
193                            this.last_equal_ranges.clear();
194
195                            let transaction = this.buffer.update(cx, |buffer, cx| {
196                                // Avoid grouping assistant edits with user edits.
197                                buffer.finalize_last_transaction(cx);
198
199                                buffer.start_transaction(cx);
200                                buffer.edit(
201                                    hunks.into_iter().filter_map(|hunk| match hunk {
202                                        Hunk::Insert { text } => {
203                                            let edit_start = snapshot.anchor_after(edit_start);
204                                            Some((edit_start..edit_start, text))
205                                        }
206                                        Hunk::Remove { len } => {
207                                            let edit_end = edit_start + len;
208                                            let edit_range = snapshot.anchor_after(edit_start)
209                                                ..snapshot.anchor_before(edit_end);
210                                            edit_start = edit_end;
211                                            Some((edit_range, String::new()))
212                                        }
213                                        Hunk::Keep { len } => {
214                                            let edit_end = edit_start + len;
215                                            let edit_range = snapshot.anchor_after(edit_start)
216                                                ..snapshot.anchor_before(edit_end);
217                                            edit_start = edit_end;
218                                            this.last_equal_ranges.push(edit_range);
219                                            None
220                                        }
221                                    }),
222                                    None,
223                                    cx,
224                                );
225
226                                buffer.end_transaction(cx)
227                            });
228
229                            if let Some(transaction) = transaction {
230                                if let Some(first_transaction) = this.transaction_id {
231                                    // Group all assistant edits into the first transaction.
232                                    this.buffer.update(cx, |buffer, cx| {
233                                        buffer.merge_transactions(
234                                            transaction,
235                                            first_transaction,
236                                            cx,
237                                        )
238                                    });
239                                } else {
240                                    this.transaction_id = Some(transaction);
241                                    this.buffer.update(cx, |buffer, cx| {
242                                        buffer.finalize_last_transaction(cx)
243                                    });
244                                }
245                            }
246
247                            cx.notify();
248                        });
249                    }
250
251                    diff.await?;
252                    anyhow::Ok(())
253                };
254
255                let result = generate.await;
256                if let Some(this) = this.upgrade(&cx) {
257                    this.update(&mut cx, |this, cx| {
258                        this.last_equal_ranges.clear();
259                        this.idle = true;
260                        if let Err(error) = result {
261                            this.error = Some(error);
262                        }
263                        cx.emit(Event::Finished);
264                        cx.notify();
265                    });
266                }
267            }
268        });
269        self.error.take();
270        self.idle = false;
271        cx.notify();
272    }
273
274    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
275        if let Some(transaction_id) = self.transaction_id {
276            self.buffer
277                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
278        }
279    }
280}
281
282fn strip_invalid_spans_from_codeblock(
283    stream: impl Stream<Item = Result<String>>,
284) -> impl Stream<Item = Result<String>> {
285    let mut first_line = true;
286    let mut buffer = String::new();
287    let mut starts_with_markdown_codeblock = false;
288    let mut includes_start_or_end_span = false;
289    stream.filter_map(move |chunk| {
290        let chunk = match chunk {
291            Ok(chunk) => chunk,
292            Err(err) => return future::ready(Some(Err(err))),
293        };
294        buffer.push_str(&chunk);
295
296        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
297            includes_start_or_end_span = true;
298
299            buffer = buffer
300                .strip_prefix("<|S|>")
301                .or_else(|| buffer.strip_prefix("<|S|"))
302                .unwrap_or(&buffer)
303                .to_string();
304        } else if buffer.ends_with("|E|>") {
305            includes_start_or_end_span = true;
306        } else if buffer.starts_with("<|")
307            || buffer.starts_with("<|S")
308            || buffer.starts_with("<|S|")
309            || buffer.ends_with("|")
310            || buffer.ends_with("|E")
311            || buffer.ends_with("|E|")
312        {
313            return future::ready(None);
314        }
315
316        if first_line {
317            if buffer == "" || buffer == "`" || buffer == "``" {
318                return future::ready(None);
319            } else if buffer.starts_with("```") {
320                starts_with_markdown_codeblock = true;
321                if let Some(newline_ix) = buffer.find('\n') {
322                    buffer.replace_range(..newline_ix + 1, "");
323                    first_line = false;
324                } else {
325                    return future::ready(None);
326                }
327            }
328        }
329
330        let mut text = buffer.to_string();
331        if starts_with_markdown_codeblock {
332            text = text
333                .strip_suffix("\n```\n")
334                .or_else(|| text.strip_suffix("\n```"))
335                .or_else(|| text.strip_suffix("\n``"))
336                .or_else(|| text.strip_suffix("\n`"))
337                .or_else(|| text.strip_suffix('\n'))
338                .unwrap_or(&text)
339                .to_string();
340        }
341
342        if includes_start_or_end_span {
343            text = text
344                .strip_suffix("|E|>")
345                .or_else(|| text.strip_suffix("E|>"))
346                .or_else(|| text.strip_prefix("|>"))
347                .or_else(|| text.strip_prefix(">"))
348                .unwrap_or(&text)
349                .to_string();
350        };
351
352        if text.contains('\n') {
353            first_line = false;
354        }
355
356        let remainder = buffer.split_off(text.len());
357        let result = if buffer.is_empty() {
358            None
359        } else {
360            Some(Ok(buffer.clone()))
361        };
362
363        buffer = remainder;
364        future::ready(result)
365    })
366}
367
368#[cfg(test)]
369mod tests {
370    use std::sync::Arc;
371
372    use super::*;
373    use ai::test::FakeCompletionProvider;
374    use futures::stream::{self};
375    use gpui::{executor::Deterministic, TestAppContext};
376    use indoc::indoc;
377    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
378    use rand::prelude::*;
379    use serde::Serialize;
380    use settings::SettingsStore;
381
382    #[derive(Serialize)]
383    pub struct DummyCompletionRequest {
384        pub name: String,
385    }
386
387    impl CompletionRequest for DummyCompletionRequest {
388        fn data(&self) -> serde_json::Result<String> {
389            serde_json::to_string(self)
390        }
391    }
392
393    #[gpui::test(iterations = 10)]
394    async fn test_transform_autoindent(
395        cx: &mut TestAppContext,
396        mut rng: StdRng,
397        deterministic: Arc<Deterministic>,
398    ) {
399        cx.set_global(cx.read(SettingsStore::test));
400        cx.update(language_settings::init);
401
402        let text = indoc! {"
403            fn main() {
404                let x = 0;
405                for _ in 0..10 {
406                    x += 1;
407                }
408            }
409        "};
410        let buffer =
411            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
412        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
413        let range = buffer.read_with(cx, |buffer, cx| {
414            let snapshot = buffer.snapshot(cx);
415            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
416        });
417        let provider = Box::new(FakeCompletionProvider::new());
418        let codegen = cx.add_model(|cx| {
419            Codegen::new(
420                buffer.clone(),
421                CodegenKind::Transform { range },
422                provider.clone(),
423                cx,
424            )
425        });
426
427        let request = Box::new(DummyCompletionRequest {
428            name: "test".to_string(),
429        });
430        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
431
432        let mut new_text = concat!(
433            "       let mut x = 0;\n",
434            "       while x < 10 {\n",
435            "           x += 1;\n",
436            "       }",
437        );
438        while !new_text.is_empty() {
439            let max_len = cmp::min(new_text.len(), 10);
440            let len = rng.gen_range(1..=max_len);
441            let (chunk, suffix) = new_text.split_at(len);
442            provider.send_completion(chunk);
443            new_text = suffix;
444            deterministic.run_until_parked();
445        }
446        provider.finish_completion();
447        deterministic.run_until_parked();
448
449        assert_eq!(
450            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
451            indoc! {"
452                fn main() {
453                    let mut x = 0;
454                    while x < 10 {
455                        x += 1;
456                    }
457                }
458            "}
459        );
460    }
461
462    #[gpui::test(iterations = 10)]
463    async fn test_autoindent_when_generating_past_indentation(
464        cx: &mut TestAppContext,
465        mut rng: StdRng,
466        deterministic: Arc<Deterministic>,
467    ) {
468        cx.set_global(cx.read(SettingsStore::test));
469        cx.update(language_settings::init);
470
471        let text = indoc! {"
472            fn main() {
473                le
474            }
475        "};
476        let buffer =
477            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
478        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
479        let position = buffer.read_with(cx, |buffer, cx| {
480            let snapshot = buffer.snapshot(cx);
481            snapshot.anchor_before(Point::new(1, 6))
482        });
483        let provider = Box::new(FakeCompletionProvider::new());
484        let codegen = cx.add_model(|cx| {
485            Codegen::new(
486                buffer.clone(),
487                CodegenKind::Generate { position },
488                provider.clone(),
489                cx,
490            )
491        });
492
493        let request = Box::new(DummyCompletionRequest {
494            name: "test".to_string(),
495        });
496        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
497
498        let mut new_text = concat!(
499            "t mut x = 0;\n",
500            "while x < 10 {\n",
501            "    x += 1;\n",
502            "}", //
503        );
504        while !new_text.is_empty() {
505            let max_len = cmp::min(new_text.len(), 10);
506            let len = rng.gen_range(1..=max_len);
507            let (chunk, suffix) = new_text.split_at(len);
508            provider.send_completion(chunk);
509            new_text = suffix;
510            deterministic.run_until_parked();
511        }
512        provider.finish_completion();
513        deterministic.run_until_parked();
514
515        assert_eq!(
516            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
517            indoc! {"
518                fn main() {
519                    let mut x = 0;
520                    while x < 10 {
521                        x += 1;
522                    }
523                }
524            "}
525        );
526    }
527
528    #[gpui::test(iterations = 10)]
529    async fn test_autoindent_when_generating_before_indentation(
530        cx: &mut TestAppContext,
531        mut rng: StdRng,
532        deterministic: Arc<Deterministic>,
533    ) {
534        cx.set_global(cx.read(SettingsStore::test));
535        cx.update(language_settings::init);
536
537        let text = concat!(
538            "fn main() {\n",
539            "  \n",
540            "}\n" //
541        );
542        let buffer =
543            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
544        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
545        let position = buffer.read_with(cx, |buffer, cx| {
546            let snapshot = buffer.snapshot(cx);
547            snapshot.anchor_before(Point::new(1, 2))
548        });
549        let provider = Box::new(FakeCompletionProvider::new());
550        let codegen = cx.add_model(|cx| {
551            Codegen::new(
552                buffer.clone(),
553                CodegenKind::Generate { position },
554                provider.clone(),
555                cx,
556            )
557        });
558
559        let request = Box::new(DummyCompletionRequest {
560            name: "test".to_string(),
561        });
562        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
563
564        let mut new_text = concat!(
565            "let mut x = 0;\n",
566            "while x < 10 {\n",
567            "    x += 1;\n",
568            "}", //
569        );
570        while !new_text.is_empty() {
571            let max_len = cmp::min(new_text.len(), 10);
572            let len = rng.gen_range(1..=max_len);
573            let (chunk, suffix) = new_text.split_at(len);
574            provider.send_completion(chunk);
575            new_text = suffix;
576            deterministic.run_until_parked();
577        }
578        provider.finish_completion();
579        deterministic.run_until_parked();
580
581        assert_eq!(
582            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
583            indoc! {"
584                fn main() {
585                    let mut x = 0;
586                    while x < 10 {
587                        x += 1;
588                    }
589                }
590            "}
591        );
592    }
593
594    #[gpui::test]
595    async fn test_strip_invalid_spans_from_codeblock() {
596        assert_eq!(
597            strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
598                .map(|chunk| chunk.unwrap())
599                .collect::<String>()
600                .await,
601            "Lorem ipsum dolor"
602        );
603        assert_eq!(
604            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
605                .map(|chunk| chunk.unwrap())
606                .collect::<String>()
607                .await,
608            "Lorem ipsum dolor"
609        );
610        assert_eq!(
611            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
612                .map(|chunk| chunk.unwrap())
613                .collect::<String>()
614                .await,
615            "Lorem ipsum dolor"
616        );
617        assert_eq!(
618            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
619                .map(|chunk| chunk.unwrap())
620                .collect::<String>()
621                .await,
622            "Lorem ipsum dolor"
623        );
624        assert_eq!(
625            strip_invalid_spans_from_codeblock(chunks(
626                "```html\n```js\nLorem ipsum dolor\n```\n```",
627                2
628            ))
629            .map(|chunk| chunk.unwrap())
630            .collect::<String>()
631            .await,
632            "```js\nLorem ipsum dolor\n```"
633        );
634        assert_eq!(
635            strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
636                .map(|chunk| chunk.unwrap())
637                .collect::<String>()
638                .await,
639            "``\nLorem ipsum dolor\n```"
640        );
641        assert_eq!(
642            strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
643                .map(|chunk| chunk.unwrap())
644                .collect::<String>()
645                .await,
646            "Lorem ipsum"
647        );
648
649        assert_eq!(
650            strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
651                .map(|chunk| chunk.unwrap())
652                .collect::<String>()
653                .await,
654            "Lorem ipsum"
655        );
656
657        assert_eq!(
658            strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
659                .map(|chunk| chunk.unwrap())
660                .collect::<String>()
661                .await,
662            "Lorem ipsum"
663        );
664        assert_eq!(
665            strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
666                .map(|chunk| chunk.unwrap())
667                .collect::<String>()
668                .await,
669            "Lorem ipsum"
670        );
671        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
672            stream::iter(
673                text.chars()
674                    .collect::<Vec<_>>()
675                    .chunks(size)
676                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
677                    .collect::<Vec<_>>(),
678            )
679        }
680    }
681
682    fn rust_lang() -> Language {
683        Language::new(
684            LanguageConfig {
685                name: "Rust".into(),
686                path_suffixes: vec!["rs".to_string()],
687                ..Default::default()
688            },
689            Some(tree_sitter_rust::language()),
690        )
691        .with_indents_query(
692            r#"
693            (call_expression) @indent
694            (field_expression) @indent
695            (_ "(" ")" @end) @indent
696            (_ "{" "}" @end) @indent
697            "#,
698        )
699        .unwrap()
700    }
701}