codegen.rs

  1use crate::streaming_diff::{Hunk, StreamingDiff};
  2use ai::completion::{CompletionProvider, CompletionRequest};
  3use anyhow::Result;
  4use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
  5use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  6use gpui::{EventEmitter, Model, ModelContext, Task};
  7use language::{Rope, TransactionId};
  8use multi_buffer;
  9use std::{cmp, future, ops::Range, sync::Arc};
 10
/// Notifications emitted by a [`Codegen`] model.
pub enum Event {
    /// The generation task ran to its end, whether it succeeded or produced an error.
    Finished,
    /// The transaction holding the generated edits was undone in the buffer.
    Undone,
}
 15
/// Selects what a [`Codegen`] operates on.
#[derive(Clone)]
pub enum CodegenKind {
    /// Operate on the text spanned by `range`.
    Transform { range: Range<Anchor> },
    /// Operate at a single `position`; the effective range runs from the left-biased
    /// `position` to `position` itself.
    Generate { position: Anchor },
}
 21
/// Streams a completion from a provider into a multi-buffer as incremental edits.
pub struct Codegen {
    /// Source of the completion stream consumed by `start`.
    provider: Arc<dyn CompletionProvider>,
    /// Buffer that receives the generated edits.
    buffer: Model<MultiBuffer>,
    /// Snapshot of `buffer` captured in `new`; anchors and offsets are resolved against it.
    snapshot: MultiBufferSnapshot,
    /// Whether this instance transforms a range or generates at a position.
    kind: CodegenKind,
    /// Ranges that the most recently applied batch of diff hunks marked as kept.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// First transaction produced by generation; later ones are merged into it.
    transaction_id: Option<TransactionId>,
    /// Error produced by the most recent generation, if any.
    error: Option<anyhow::Error>,
    /// Handle to the current generation task. Assigning a new task drops the previous
    /// one (presumably cancelling it — confirm against gpui `Task` semantics).
    generation: Task<()>,
    /// `false` from the moment `start` is called until the generation task finishes.
    idle: bool,
    /// Subscription to `buffer` events created in `new`; held here, presumably so that
    /// it stays active for as long as this model lives.
    _subscription: gpui::Subscription,
}
 34
// Allows `Codegen` to emit `Event`s through its `ModelContext` (see the `cx.emit` calls).
impl EventEmitter<Event> for Codegen {}
 36
 37impl Codegen {
 38    pub fn new(
 39        buffer: Model<MultiBuffer>,
 40        kind: CodegenKind,
 41        provider: Arc<dyn CompletionProvider>,
 42        cx: &mut ModelContext<Self>,
 43    ) -> Self {
 44        let snapshot = buffer.read(cx).snapshot(cx);
 45        Self {
 46            provider,
 47            buffer: buffer.clone(),
 48            snapshot,
 49            kind,
 50            last_equal_ranges: Default::default(),
 51            transaction_id: Default::default(),
 52            error: Default::default(),
 53            idle: true,
 54            generation: Task::ready(()),
 55            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 56        }
 57    }
 58
 59    fn handle_buffer_event(
 60        &mut self,
 61        _buffer: Model<MultiBuffer>,
 62        event: &multi_buffer::Event,
 63        cx: &mut ModelContext<Self>,
 64    ) {
 65        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
 66            if self.transaction_id == Some(*transaction_id) {
 67                self.transaction_id = None;
 68                self.generation = Task::ready(());
 69                cx.emit(Event::Undone);
 70            }
 71        }
 72    }
 73
 74    pub fn range(&self) -> Range<Anchor> {
 75        match &self.kind {
 76            CodegenKind::Transform { range } => range.clone(),
 77            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
 78        }
 79    }
 80
    /// Returns the kind this codegen was created with.
    pub fn kind(&self) -> &CodegenKind {
        &self.kind
    }
 84
    /// Returns the ranges that the most recently applied batch of hunks left unchanged.
    /// The list is emptied again once generation finishes.
    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
        &self.last_equal_ranges
    }
 88
    /// Returns `false` between a call to `start` and the end of its generation task,
    /// and `true` otherwise.
    pub fn idle(&self) -> bool {
        self.idle
    }
 92
    /// Returns the error recorded by the most recent generation, if any. The stored
    /// error is cleared whenever `start` is called.
    pub fn error(&self) -> Option<&anyhow::Error> {
        self.error.as_ref()
    }
 96
    /// Starts streaming a completion for `prompt` into the buffer.
    ///
    /// The text currently in [`Self::range`] is diffed incrementally against the streamed
    /// response and the resulting hunks are applied to the buffer as they arrive. Each
    /// line of the response is re-indented relative to the indentation suggested for the
    /// first row of the range. All edits end up merged into one transaction, whose id is
    /// kept in `transaction_id`.
    ///
    /// Calling this replaces the stored generation task, clears the stored error and
    /// marks the codegen as not idle. When the task ends, the error (if any) is recorded,
    /// the codegen becomes idle again and [`Event::Finished`] is emitted.
    pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
        let range = self.range();
        let snapshot = self.snapshot.clone();
        // The existing text of the range is the "old" side of the streaming diff.
        let selected_text = snapshot
            .text_for_range(range.start..range.end)
            .collect::<Rope>();

        let selection_start = range.start.to_point(&snapshot);
        // Indentation suggested for the first row of the range, falling back to the
        // indentation that row currently has.
        let suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));

        let response = self.provider.complete(prompt);
        self.generation = cx.spawn(|this, mut cx| {
            async move {
                let generate = async {
                    // Offset (in the captured snapshot) up to which hunks have been consumed.
                    let mut edit_start = range.start.to_offset(&snapshot);

                    // Bounded channel (buffer size 1) carrying hunk batches from the
                    // background diff task to the edit loop below.
                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                    let diff = cx.background_executor().spawn(async move {
                        let chunks = strip_invalid_spans_from_codeblock(response.await?);
                        futures::pin_mut!(chunks);
                        let mut diff = StreamingDiff::new(selected_text.to_string());

                        // Text of the current line that has not been pushed to the diff yet.
                        let mut new_text = String::new();
                        // Indent of the first response line that has non-whitespace content.
                        let mut base_indent = None;
                        // Indent of the current response line; `None` until its first
                        // non-whitespace character has been seen.
                        let mut line_indent = None;
                        let mut first_line = true;

                        while let Some(chunk) = chunks.next().await {
                            let chunk = chunk?;

                            let mut lines = chunk.split('\n').peekable();
                            while let Some(line) = lines.next() {
                                new_text.push_str(line);
                                if line_indent.is_none() {
                                    // The byte index of the first non-whitespace character
                                    // is used as the width of the line's indent.
                                    // NOTE(review): `new_text` is not cleared for
                                    // whitespace-only lines, so their whitespace is carried
                                    // into the next line's measurement — confirm intended.
                                    if let Some(non_whitespace_ch_ix) =
                                        new_text.find(|ch: char| !ch.is_whitespace())
                                    {
                                        line_indent = Some(non_whitespace_ch_ix);
                                        base_indent = base_indent.or(line_indent);

                                        let line_indent = line_indent.unwrap();
                                        let base_indent = base_indent.unwrap();
                                        // Keep the line's indent relative to the first
                                        // line, but rebase it onto the suggested indent.
                                        let indent_delta = line_indent as i32 - base_indent as i32;
                                        let mut corrected_indent_len = cmp::max(
                                            0,
                                            suggested_line_indent.len as i32 + indent_delta,
                                        )
                                            as usize;
                                        // The first line is inserted at the selection's
                                        // column, so that many columns are already present.
                                        if first_line {
                                            corrected_indent_len = corrected_indent_len
                                                .saturating_sub(selection_start.column as usize);
                                        }

                                        let indent_char = suggested_line_indent.char();
                                        let mut indent_buffer = [0; 4];
                                        let indent_str =
                                            indent_char.encode_utf8(&mut indent_buffer);
                                        // Swap the response's own indent for the corrected one.
                                        new_text.replace_range(
                                            ..line_indent,
                                            &indent_str.repeat(corrected_indent_len),
                                        );
                                    }
                                }

                                // Only push once the indent is known; whitespace-only text
                                // stays in `new_text` until then.
                                if line_indent.is_some() {
                                    hunks_tx.send(diff.push_new(&new_text)).await?;
                                    new_text.clear();
                                }

                                // Another segment follows, so a newline ended this line.
                                if lines.peek().is_some() {
                                    hunks_tx.send(diff.push_new("\n")).await?;
                                    line_indent = None;
                                    first_line = false;
                                }
                            }
                        }
                        // Push whatever text is still pending, then close the diff.
                        hunks_tx.send(diff.push_new(&new_text)).await?;
                        hunks_tx.send(diff.finish()).await?;

                        anyhow::Ok(())
                    });

                    // Apply each batch of hunks to the buffer as it arrives.
                    while let Some(hunks) = hunks_rx.next().await {
                        this.update(&mut cx, |this, cx| {
                            this.last_equal_ranges.clear();

                            let transaction = this.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                buffer.edit(
                                    hunks.into_iter().filter_map(|hunk| match hunk {
                                        // Insert at the current position without advancing.
                                        Hunk::Insert { text } => {
                                            let edit_start = snapshot.anchor_after(edit_start);
                                            Some((edit_start..edit_start, text))
                                        }
                                        // Delete `len` bytes of the old text and advance.
                                        Hunk::Remove { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            Some((edit_range, String::new()))
                                        }
                                        // Skip over `len` unchanged bytes, remembering the
                                        // range; no edit is produced.
                                        Hunk::Keep { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            this.last_equal_ranges.push(edit_range);
                                            None
                                        }
                                    }),
                                    None,
                                    cx,
                                );

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) = this.transaction_id {
                                    // Group all assistant edits into the first transaction.
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    // First batch: remember its transaction and finalize it.
                                    this.transaction_id = Some(transaction);
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            cx.notify();
                        })?;
                    }

                    // Surface any error raised inside the background diff task.
                    diff.await?;
                    anyhow::Ok(())
                };

                let result = generate.await;
                // A failed update is ignored here (presumably the model was released).
                this.update(&mut cx, |this, cx| {
                    this.last_equal_ranges.clear();
                    this.idle = true;
                    if let Err(error) = result {
                        this.error = Some(error);
                    }
                    cx.emit(Event::Finished);
                    cx.notify();
                })
                .ok();
            }
        });
        self.error.take();
        self.idle = false;
        cx.notify();
    }
264
265    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
266        if let Some(transaction_id) = self.transaction_id {
267            self.buffer
268                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
269        }
270    }
271}
272
273fn strip_invalid_spans_from_codeblock(
274    stream: impl Stream<Item = Result<String>>,
275) -> impl Stream<Item = Result<String>> {
276    let mut first_line = true;
277    let mut buffer = String::new();
278    let mut starts_with_markdown_codeblock = false;
279    let mut includes_start_or_end_span = false;
280    stream.filter_map(move |chunk| {
281        let chunk = match chunk {
282            Ok(chunk) => chunk,
283            Err(err) => return future::ready(Some(Err(err))),
284        };
285        buffer.push_str(&chunk);
286
287        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
288            includes_start_or_end_span = true;
289
290            buffer = buffer
291                .strip_prefix("<|S|>")
292                .or_else(|| buffer.strip_prefix("<|S|"))
293                .unwrap_or(&buffer)
294                .to_string();
295        } else if buffer.ends_with("|E|>") {
296            includes_start_or_end_span = true;
297        } else if buffer.starts_with("<|")
298            || buffer.starts_with("<|S")
299            || buffer.starts_with("<|S|")
300            || buffer.ends_with("|")
301            || buffer.ends_with("|E")
302            || buffer.ends_with("|E|")
303        {
304            return future::ready(None);
305        }
306
307        if first_line {
308            if buffer == "" || buffer == "`" || buffer == "``" {
309                return future::ready(None);
310            } else if buffer.starts_with("```") {
311                starts_with_markdown_codeblock = true;
312                if let Some(newline_ix) = buffer.find('\n') {
313                    buffer.replace_range(..newline_ix + 1, "");
314                    first_line = false;
315                } else {
316                    return future::ready(None);
317                }
318            }
319        }
320
321        let mut text = buffer.to_string();
322        if starts_with_markdown_codeblock {
323            text = text
324                .strip_suffix("\n```\n")
325                .or_else(|| text.strip_suffix("\n```"))
326                .or_else(|| text.strip_suffix("\n``"))
327                .or_else(|| text.strip_suffix("\n`"))
328                .or_else(|| text.strip_suffix('\n'))
329                .unwrap_or(&text)
330                .to_string();
331        }
332
333        if includes_start_or_end_span {
334            text = text
335                .strip_suffix("|E|>")
336                .or_else(|| text.strip_suffix("E|>"))
337                .or_else(|| text.strip_prefix("|>"))
338                .or_else(|| text.strip_prefix(">"))
339                .unwrap_or(&text)
340                .to_string();
341        };
342
343        if text.contains('\n') {
344            first_line = false;
345        }
346
347        let remainder = buffer.split_off(text.len());
348        let result = if buffer.is_empty() {
349            None
350        } else {
351            Some(Ok(buffer.clone()))
352        };
353
354        buffer = remainder;
355        future::ready(result)
356    })
357}
358
359#[cfg(test)]
360mod tests {
361    use std::sync::Arc;
362
363    use super::*;
364    use ai::test::FakeCompletionProvider;
365    use futures::stream::{self};
366    use gpui::{Context, TestAppContext};
367    use indoc::indoc;
368    use language::{
369        language_settings, tree_sitter_rust, Buffer, BufferId, Language, LanguageConfig,
370        LanguageMatcher, Point,
371    };
372    use rand::prelude::*;
373    use serde::Serialize;
374    use settings::SettingsStore;
375
    /// Minimal serializable completion request used by the tests in this module.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        /// Arbitrary label; it only ends up in the serialized request data.
        pub name: String,
    }
380
    impl CompletionRequest for DummyCompletionRequest {
        /// Serializes the request to a JSON string.
        fn data(&self) -> serde_json::Result<String> {
            serde_json::to_string(self)
        }
    }
386
387    #[gpui::test(iterations = 10)]
388    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
389        cx.set_global(cx.update(SettingsStore::test));
390        cx.update(language_settings::init);
391
392        let text = indoc! {"
393            fn main() {
394                let x = 0;
395                for _ in 0..10 {
396                    x += 1;
397                }
398            }
399        "};
400        let buffer = cx.new_model(|cx| {
401            Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
402        });
403        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
404        let range = buffer.read_with(cx, |buffer, cx| {
405            let snapshot = buffer.snapshot(cx);
406            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
407        });
408        let provider = Arc::new(FakeCompletionProvider::new());
409        let codegen = cx.new_model(|cx| {
410            Codegen::new(
411                buffer.clone(),
412                CodegenKind::Transform { range },
413                provider.clone(),
414                cx,
415            )
416        });
417
418        let request = Box::new(DummyCompletionRequest {
419            name: "test".to_string(),
420        });
421        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
422
423        let mut new_text = concat!(
424            "       let mut x = 0;\n",
425            "       while x < 10 {\n",
426            "           x += 1;\n",
427            "       }",
428        );
429        while !new_text.is_empty() {
430            let max_len = cmp::min(new_text.len(), 10);
431            let len = rng.gen_range(1..=max_len);
432            let (chunk, suffix) = new_text.split_at(len);
433            println!("CHUNK: {:?}", &chunk);
434            provider.send_completion(chunk);
435            new_text = suffix;
436            cx.background_executor.run_until_parked();
437        }
438        provider.finish_completion();
439        cx.background_executor.run_until_parked();
440
441        assert_eq!(
442            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
443            indoc! {"
444                fn main() {
445                    let mut x = 0;
446                    while x < 10 {
447                        x += 1;
448                    }
449                }
450            "}
451        );
452    }
453
454    #[gpui::test(iterations = 10)]
455    async fn test_autoindent_when_generating_past_indentation(
456        cx: &mut TestAppContext,
457        mut rng: StdRng,
458    ) {
459        cx.set_global(cx.update(SettingsStore::test));
460        cx.update(language_settings::init);
461
462        let text = indoc! {"
463            fn main() {
464                le
465            }
466        "};
467        let buffer = cx.new_model(|cx| {
468            Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
469        });
470        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
471        let position = buffer.read_with(cx, |buffer, cx| {
472            let snapshot = buffer.snapshot(cx);
473            snapshot.anchor_before(Point::new(1, 6))
474        });
475        let provider = Arc::new(FakeCompletionProvider::new());
476        let codegen = cx.new_model(|cx| {
477            Codegen::new(
478                buffer.clone(),
479                CodegenKind::Generate { position },
480                provider.clone(),
481                cx,
482            )
483        });
484
485        let request = Box::new(DummyCompletionRequest {
486            name: "test".to_string(),
487        });
488        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
489
490        let mut new_text = concat!(
491            "t mut x = 0;\n",
492            "while x < 10 {\n",
493            "    x += 1;\n",
494            "}", //
495        );
496        while !new_text.is_empty() {
497            let max_len = cmp::min(new_text.len(), 10);
498            let len = rng.gen_range(1..=max_len);
499            let (chunk, suffix) = new_text.split_at(len);
500            provider.send_completion(chunk);
501            new_text = suffix;
502            cx.background_executor.run_until_parked();
503        }
504        provider.finish_completion();
505        cx.background_executor.run_until_parked();
506
507        assert_eq!(
508            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
509            indoc! {"
510                fn main() {
511                    let mut x = 0;
512                    while x < 10 {
513                        x += 1;
514                    }
515                }
516            "}
517        );
518    }
519
520    #[gpui::test(iterations = 10)]
521    async fn test_autoindent_when_generating_before_indentation(
522        cx: &mut TestAppContext,
523        mut rng: StdRng,
524    ) {
525        cx.set_global(cx.update(SettingsStore::test));
526        cx.update(language_settings::init);
527
528        let text = concat!(
529            "fn main() {\n",
530            "  \n",
531            "}\n" //
532        );
533        let buffer = cx.new_model(|cx| {
534            Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
535        });
536        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
537        let position = buffer.read_with(cx, |buffer, cx| {
538            let snapshot = buffer.snapshot(cx);
539            snapshot.anchor_before(Point::new(1, 2))
540        });
541        let provider = Arc::new(FakeCompletionProvider::new());
542        let codegen = cx.new_model(|cx| {
543            Codegen::new(
544                buffer.clone(),
545                CodegenKind::Generate { position },
546                provider.clone(),
547                cx,
548            )
549        });
550
551        let request = Box::new(DummyCompletionRequest {
552            name: "test".to_string(),
553        });
554        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
555
556        let mut new_text = concat!(
557            "let mut x = 0;\n",
558            "while x < 10 {\n",
559            "    x += 1;\n",
560            "}", //
561        );
562        while !new_text.is_empty() {
563            let max_len = cmp::min(new_text.len(), 10);
564            let len = rng.gen_range(1..=max_len);
565            let (chunk, suffix) = new_text.split_at(len);
566            println!("{:?}", &chunk);
567            provider.send_completion(chunk);
568            new_text = suffix;
569            cx.background_executor.run_until_parked();
570        }
571        provider.finish_completion();
572        cx.background_executor.run_until_parked();
573
574        assert_eq!(
575            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
576            indoc! {"
577                fn main() {
578                    let mut x = 0;
579                    while x < 10 {
580                        x += 1;
581                    }
582                }
583            "}
584        );
585    }
586
587    #[gpui::test]
588    async fn test_strip_invalid_spans_from_codeblock() {
589        assert_eq!(
590            strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
591                .map(|chunk| chunk.unwrap())
592                .collect::<String>()
593                .await,
594            "Lorem ipsum dolor"
595        );
596        assert_eq!(
597            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
598                .map(|chunk| chunk.unwrap())
599                .collect::<String>()
600                .await,
601            "Lorem ipsum dolor"
602        );
603        assert_eq!(
604            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
605                .map(|chunk| chunk.unwrap())
606                .collect::<String>()
607                .await,
608            "Lorem ipsum dolor"
609        );
610        assert_eq!(
611            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
612                .map(|chunk| chunk.unwrap())
613                .collect::<String>()
614                .await,
615            "Lorem ipsum dolor"
616        );
617        assert_eq!(
618            strip_invalid_spans_from_codeblock(chunks(
619                "```html\n```js\nLorem ipsum dolor\n```\n```",
620                2
621            ))
622            .map(|chunk| chunk.unwrap())
623            .collect::<String>()
624            .await,
625            "```js\nLorem ipsum dolor\n```"
626        );
627        assert_eq!(
628            strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
629                .map(|chunk| chunk.unwrap())
630                .collect::<String>()
631                .await,
632            "``\nLorem ipsum dolor\n```"
633        );
634        assert_eq!(
635            strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
636                .map(|chunk| chunk.unwrap())
637                .collect::<String>()
638                .await,
639            "Lorem ipsum"
640        );
641
642        assert_eq!(
643            strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
644                .map(|chunk| chunk.unwrap())
645                .collect::<String>()
646                .await,
647            "Lorem ipsum"
648        );
649
650        assert_eq!(
651            strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
652                .map(|chunk| chunk.unwrap())
653                .collect::<String>()
654                .await,
655            "Lorem ipsum"
656        );
657        assert_eq!(
658            strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
659                .map(|chunk| chunk.unwrap())
660                .collect::<String>()
661                .await,
662            "Lorem ipsum"
663        );
664        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
665            stream::iter(
666                text.chars()
667                    .collect::<Vec<_>>()
668                    .chunks(size)
669                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
670                    .collect::<Vec<_>>(),
671            )
672        }
673    }
674
675    fn rust_lang() -> Language {
676        Language::new(
677            LanguageConfig {
678                name: "Rust".into(),
679                matcher: LanguageMatcher {
680                    path_suffixes: vec!["rs".to_string()],
681                    ..Default::default()
682                },
683                ..Default::default()
684            },
685            Some(tree_sitter_rust::language()),
686        )
687        .with_indents_query(
688            r#"
689            (call_expression) @indent
690            (field_expression) @indent
691            (_ "(" ")" @end) @indent
692            (_ "{" "}" @end) @indent
693            "#,
694        )
695        .unwrap()
696    }
697}