codegen.rs

  1use crate::streaming_diff::{Hunk, StreamingDiff};
  2use ai::completion::{CompletionProvider, CompletionRequest};
  3use anyhow::Result;
  4use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
  5use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  6use gpui::{Entity, ModelContext, ModelHandle, Task};
  7use language::{Rope, TransactionId};
  8use multi_buffer;
  9use std::{cmp, future, ops::Range, sync::Arc};
 10
 11pub enum Event {
 12    Finished,
 13    Undone,
 14}
 15
 16#[derive(Clone)]
 17pub enum CodegenKind {
 18    Transform { range: Range<Anchor> },
 19    Generate { position: Anchor },
 20}
 21
 22pub struct Codegen {
 23    provider: Arc<dyn CompletionProvider>,
 24    buffer: ModelHandle<MultiBuffer>,
 25    snapshot: MultiBufferSnapshot,
 26    kind: CodegenKind,
 27    last_equal_ranges: Vec<Range<Anchor>>,
 28    transaction_id: Option<TransactionId>,
 29    error: Option<anyhow::Error>,
 30    generation: Task<()>,
 31    idle: bool,
 32    _subscription: gpui::Subscription,
 33}
 34
 35impl Entity for Codegen {
 36    type Event = Event;
 37}
 38
 39impl Codegen {
 40    pub fn new(
 41        buffer: ModelHandle<MultiBuffer>,
 42        kind: CodegenKind,
 43        provider: Arc<dyn CompletionProvider>,
 44        cx: &mut ModelContext<Self>,
 45    ) -> Self {
 46        let snapshot = buffer.read(cx).snapshot(cx);
 47        Self {
 48            provider,
 49            buffer: buffer.clone(),
 50            snapshot,
 51            kind,
 52            last_equal_ranges: Default::default(),
 53            transaction_id: Default::default(),
 54            error: Default::default(),
 55            idle: true,
 56            generation: Task::ready(()),
 57            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 58        }
 59    }
 60
 61    fn handle_buffer_event(
 62        &mut self,
 63        _buffer: ModelHandle<MultiBuffer>,
 64        event: &multi_buffer::Event,
 65        cx: &mut ModelContext<Self>,
 66    ) {
 67        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
 68            if self.transaction_id == Some(*transaction_id) {
 69                self.transaction_id = None;
 70                self.generation = Task::ready(());
 71                cx.emit(Event::Undone);
 72            }
 73        }
 74    }
 75
 76    pub fn range(&self) -> Range<Anchor> {
 77        match &self.kind {
 78            CodegenKind::Transform { range } => range.clone(),
 79            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
 80        }
 81    }
 82
 83    pub fn kind(&self) -> &CodegenKind {
 84        &self.kind
 85    }
 86
 87    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 88        &self.last_equal_ranges
 89    }
 90
 91    pub fn idle(&self) -> bool {
 92        self.idle
 93    }
 94
 95    pub fn error(&self) -> Option<&anyhow::Error> {
 96        self.error.as_ref()
 97    }
 98
 99    pub fn start(&mut self, prompt: Box<dyn CompletionRequest>, cx: &mut ModelContext<Self>) {
100        let range = self.range();
101        let snapshot = self.snapshot.clone();
102        let selected_text = snapshot
103            .text_for_range(range.start..range.end)
104            .collect::<Rope>();
105
106        let selection_start = range.start.to_point(&snapshot);
107        let suggested_line_indent = snapshot
108            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
109            .into_values()
110            .next()
111            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
112
113        let response = self.provider.complete(prompt);
114        self.generation = cx.spawn_weak(|this, mut cx| {
115            async move {
116                let generate = async {
117                    let mut edit_start = range.start.to_offset(&snapshot);
118
119                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
120                    let diff = cx.background().spawn(async move {
121                        let chunks = strip_markdown_codeblock(response.await?);
122                        futures::pin_mut!(chunks);
123                        let mut diff = StreamingDiff::new(selected_text.to_string());
124
125                        let mut new_text = String::new();
126                        let mut base_indent = None;
127                        let mut line_indent = None;
128                        let mut first_line = true;
129
130                        while let Some(chunk) = chunks.next().await {
131                            let chunk = chunk?;
132
133                            let mut lines = chunk.split('\n').peekable();
134                            while let Some(line) = lines.next() {
135                                new_text.push_str(line);
136                                if line_indent.is_none() {
137                                    if let Some(non_whitespace_ch_ix) =
138                                        new_text.find(|ch: char| !ch.is_whitespace())
139                                    {
140                                        line_indent = Some(non_whitespace_ch_ix);
141                                        base_indent = base_indent.or(line_indent);
142
143                                        let line_indent = line_indent.unwrap();
144                                        let base_indent = base_indent.unwrap();
145                                        let indent_delta = line_indent as i32 - base_indent as i32;
146                                        let mut corrected_indent_len = cmp::max(
147                                            0,
148                                            suggested_line_indent.len as i32 + indent_delta,
149                                        )
150                                            as usize;
151                                        if first_line {
152                                            corrected_indent_len = corrected_indent_len
153                                                .saturating_sub(selection_start.column as usize);
154                                        }
155
156                                        let indent_char = suggested_line_indent.char();
157                                        let mut indent_buffer = [0; 4];
158                                        let indent_str =
159                                            indent_char.encode_utf8(&mut indent_buffer);
160                                        new_text.replace_range(
161                                            ..line_indent,
162                                            &indent_str.repeat(corrected_indent_len),
163                                        );
164                                    }
165                                }
166
167                                if line_indent.is_some() {
168                                    hunks_tx.send(diff.push_new(&new_text)).await?;
169                                    new_text.clear();
170                                }
171
172                                if lines.peek().is_some() {
173                                    hunks_tx.send(diff.push_new("\n")).await?;
174                                    line_indent = None;
175                                    first_line = false;
176                                }
177                            }
178                        }
179                        hunks_tx.send(diff.push_new(&new_text)).await?;
180                        hunks_tx.send(diff.finish()).await?;
181
182                        anyhow::Ok(())
183                    });
184
185                    while let Some(hunks) = hunks_rx.next().await {
186                        let this = if let Some(this) = this.upgrade(&cx) {
187                            this
188                        } else {
189                            break;
190                        };
191
192                        this.update(&mut cx, |this, cx| {
193                            this.last_equal_ranges.clear();
194
195                            let transaction = this.buffer.update(cx, |buffer, cx| {
196                                // Avoid grouping assistant edits with user edits.
197                                buffer.finalize_last_transaction(cx);
198
199                                buffer.start_transaction(cx);
200                                buffer.edit(
201                                    hunks.into_iter().filter_map(|hunk| match hunk {
202                                        Hunk::Insert { text } => {
203                                            let edit_start = snapshot.anchor_after(edit_start);
204                                            Some((edit_start..edit_start, text))
205                                        }
206                                        Hunk::Remove { len } => {
207                                            let edit_end = edit_start + len;
208                                            let edit_range = snapshot.anchor_after(edit_start)
209                                                ..snapshot.anchor_before(edit_end);
210                                            edit_start = edit_end;
211                                            Some((edit_range, String::new()))
212                                        }
213                                        Hunk::Keep { len } => {
214                                            let edit_end = edit_start + len;
215                                            let edit_range = snapshot.anchor_after(edit_start)
216                                                ..snapshot.anchor_before(edit_end);
217                                            edit_start = edit_end;
218                                            this.last_equal_ranges.push(edit_range);
219                                            None
220                                        }
221                                    }),
222                                    None,
223                                    cx,
224                                );
225
226                                buffer.end_transaction(cx)
227                            });
228
229                            if let Some(transaction) = transaction {
230                                if let Some(first_transaction) = this.transaction_id {
231                                    // Group all assistant edits into the first transaction.
232                                    this.buffer.update(cx, |buffer, cx| {
233                                        buffer.merge_transactions(
234                                            transaction,
235                                            first_transaction,
236                                            cx,
237                                        )
238                                    });
239                                } else {
240                                    this.transaction_id = Some(transaction);
241                                    this.buffer.update(cx, |buffer, cx| {
242                                        buffer.finalize_last_transaction(cx)
243                                    });
244                                }
245                            }
246
247                            cx.notify();
248                        });
249                    }
250
251                    diff.await?;
252                    anyhow::Ok(())
253                };
254
255                let result = generate.await;
256                if let Some(this) = this.upgrade(&cx) {
257                    this.update(&mut cx, |this, cx| {
258                        this.last_equal_ranges.clear();
259                        this.idle = true;
260                        if let Err(error) = result {
261                            this.error = Some(error);
262                        }
263                        cx.emit(Event::Finished);
264                        cx.notify();
265                    });
266                }
267            }
268        });
269        self.error.take();
270        self.idle = false;
271        cx.notify();
272    }
273
274    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
275        if let Some(transaction_id) = self.transaction_id {
276            self.buffer
277                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
278        }
279    }
280}
281
282fn strip_markdown_codeblock(
283    stream: impl Stream<Item = Result<String>>,
284) -> impl Stream<Item = Result<String>> {
285    let mut first_line = true;
286    let mut buffer = String::new();
287    let mut starts_with_fenced_code_block = false;
288    stream.filter_map(move |chunk| {
289        let chunk = match chunk {
290            Ok(chunk) => chunk,
291            Err(err) => return future::ready(Some(Err(err))),
292        };
293        buffer.push_str(&chunk);
294
295        if first_line {
296            if buffer == "" || buffer == "`" || buffer == "``" {
297                return future::ready(None);
298            } else if buffer.starts_with("```") {
299                starts_with_fenced_code_block = true;
300                if let Some(newline_ix) = buffer.find('\n') {
301                    buffer.replace_range(..newline_ix + 1, "");
302                    first_line = false;
303                } else {
304                    return future::ready(None);
305                }
306            }
307        }
308
309        let text = if starts_with_fenced_code_block {
310            buffer
311                .strip_suffix("\n```\n")
312                .or_else(|| buffer.strip_suffix("\n```"))
313                .or_else(|| buffer.strip_suffix("\n``"))
314                .or_else(|| buffer.strip_suffix("\n`"))
315                .or_else(|| buffer.strip_suffix('\n'))
316                .unwrap_or(&buffer)
317        } else {
318            &buffer
319        };
320
321        if text.contains('\n') {
322            first_line = false;
323        }
324
325        let remainder = buffer.split_off(text.len());
326        let result = if buffer.is_empty() {
327            None
328        } else {
329            Some(Ok(buffer.clone()))
330        };
331        buffer = remainder;
332        future::ready(result)
333    })
334}
335
#[cfg(test)]
mod tests {
    use super::*;
    use ai::test::FakeCompletionProvider;
    use futures::stream::{self};
    use gpui::{executor::Deterministic, TestAppContext};
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;

    /// Placeholder prompt handed to `Codegen::start` in these tests.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    impl CompletionRequest for DummyCompletionRequest {
        fn data(&self) -> serde_json::Result<String> {
            serde_json::to_string(self)
        }
    }

    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        init_test(cx);
        let buffer = rust_buffer(
            indoc! {"
                fn main() {
                    let x = 0;
                    for _ in 0..10 {
                        x += 1;
                    }
                }
            "},
            cx,
        );
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        // Bound for the whole test: the generation task is owned by the model.
        let _codegen = start_codegen(&buffer, CodegenKind::Transform { range }, &provider, cx);

        // The response is indented by 7 columns and must be re-indented to 4.
        stream_completion(
            concat!(
                "       let mut x = 0;\n",
                "       while x < 10 {\n",
                "           x += 1;\n",
                "       }",
            ),
            &provider,
            &mut rng,
            &deterministic,
        );

        assert_eq!(
            buffer_text(&buffer, cx),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        init_test(cx);
        let buffer = rust_buffer(
            indoc! {"
                fn main() {
                    le
                }
            "},
            cx,
        );
        // Generate right after the partially typed `le`.
        let position = buffer.read_with(cx, |buffer, cx| {
            buffer.snapshot(cx).anchor_before(Point::new(1, 6))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let _codegen = start_codegen(&buffer, CodegenKind::Generate { position }, &provider, cx);

        stream_completion(
            concat!(
                "t mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &provider,
            &mut rng,
            &deterministic,
        );

        assert_eq!(
            buffer_text(&buffer, cx),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        init_test(cx);
        let buffer = rust_buffer(
            concat!(
                "fn main() {\n",
                "  \n",
                "}\n" //
            ),
            cx,
        );
        // Generate at the end of a line holding only two spaces (less than the suggested 4).
        let position = buffer.read_with(cx, |buffer, cx| {
            buffer.snapshot(cx).anchor_before(Point::new(1, 2))
        });
        let provider = Arc::new(FakeCompletionProvider::new());
        let _codegen = start_codegen(&buffer, CodegenKind::Generate { position }, &provider, cx);

        stream_completion(
            concat!(
                "let mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &provider,
            &mut rng,
            &deterministic,
        );

        assert_eq!(
            buffer_text(&buffer, cx),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test]
    async fn test_strip_markdown_codeblock() {
        /// Runs `text`, split into `size`-char chunks, through the stripper and joins the output.
        async fn strip(text: &str, size: usize) -> String {
            strip_markdown_codeblock(chunks(text, size))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await
        }

        /// Splits `text` into a stream of successful chunks of at most `size` chars.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            let pieces: Vec<Result<String>> = text
                .chars()
                .collect::<Vec<_>>()
                .chunks(size)
                .map(|piece| Ok(piece.iter().collect::<String>()))
                .collect();
            stream::iter(pieces)
        }

        assert_eq!(strip("Lorem ipsum dolor", 2).await, "Lorem ipsum dolor");
        assert_eq!(strip("```\nLorem ipsum dolor", 2).await, "Lorem ipsum dolor");
        assert_eq!(
            strip("```\nLorem ipsum dolor\n```", 2).await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip("```\nLorem ipsum dolor\n```\n", 2).await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip("```html\n```js\nLorem ipsum dolor\n```\n```", 2).await,
            "```js\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            strip("``\nLorem ipsum dolor\n```", 2).await,
            "``\nLorem ipsum dolor\n```"
        );
    }

    /// Installs the test settings store and language settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);
    }

    /// Creates a singleton multibuffer over a Rust buffer containing `text`.
    fn rust_buffer(text: &'static str, cx: &mut TestAppContext) -> ModelHandle<MultiBuffer> {
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        cx.add_model(|cx| MultiBuffer::singleton(buffer, cx))
    }

    /// Creates a `Codegen` for `buffer` and immediately starts a generation.
    fn start_codegen(
        buffer: &ModelHandle<MultiBuffer>,
        kind: CodegenKind,
        provider: &Arc<FakeCompletionProvider>,
        cx: &mut TestAppContext,
    ) -> ModelHandle<Codegen> {
        let codegen =
            cx.add_model(|cx| Codegen::new(buffer.clone(), kind, provider.clone(), cx));
        let request = Box::new(DummyCompletionRequest {
            name: "test".to_string(),
        });
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
        codegen
    }

    /// Sends `text` to `provider` in random chunks of 1..=10 bytes, letting the executor
    /// settle after each one, then finishes the completion.
    fn stream_completion(
        mut text: &'static str,
        provider: &FakeCompletionProvider,
        rng: &mut StdRng,
        deterministic: &Arc<Deterministic>,
    ) {
        while !text.is_empty() {
            let len = rng.gen_range(1..=cmp::min(text.len(), 10));
            let (chunk, rest) = text.split_at(len);
            provider.send_completion(chunk);
            text = rest;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();
    }

    /// Full text of `buffer`.
    fn buffer_text(buffer: &ModelHandle<MultiBuffer>, cx: &mut TestAppContext) -> String {
        buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text())
    }

    /// A Rust language with a minimal indents query, enough for auto-indent suggestions.
    fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            path_suffixes: vec!["rs".to_string()],
            ..Default::default()
        };
        Language::new(config, Some(tree_sitter_rust::language()))
            .with_indents_query(
                r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
            )
            .unwrap()
    }
}