codegen.rs

  1use crate::streaming_diff::{Hunk, StreamingDiff};
  2use ai::completion::{CompletionProvider, OpenAIRequest};
  3use anyhow::Result;
  4use editor::{
  5    multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
  6};
  7use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  8use gpui::{Entity, ModelContext, ModelHandle, Task};
  9use language::{Rope, TransactionId};
 10use std::{cmp, future, ops::Range, sync::Arc};
 11
 12pub enum Event {
 13    Finished,
 14    Undone,
 15}
 16
 17#[derive(Clone)]
 18pub enum CodegenKind {
 19    Transform { range: Range<Anchor> },
 20    Generate { position: Anchor },
 21}
 22
 23pub struct Codegen {
 24    provider: Arc<dyn CompletionProvider>,
 25    buffer: ModelHandle<MultiBuffer>,
 26    snapshot: MultiBufferSnapshot,
 27    kind: CodegenKind,
 28    last_equal_ranges: Vec<Range<Anchor>>,
 29    transaction_id: Option<TransactionId>,
 30    error: Option<anyhow::Error>,
 31    generation: Task<()>,
 32    idle: bool,
 33    _subscription: gpui::Subscription,
 34}
 35
 36impl Entity for Codegen {
 37    type Event = Event;
 38}
 39
 40impl Codegen {
 41    pub fn new(
 42        buffer: ModelHandle<MultiBuffer>,
 43        mut kind: CodegenKind,
 44        provider: Arc<dyn CompletionProvider>,
 45        cx: &mut ModelContext<Self>,
 46    ) -> Self {
 47        let snapshot = buffer.read(cx).snapshot(cx);
 48        match &mut kind {
 49            CodegenKind::Transform { range } => {
 50                let mut point_range = range.to_point(&snapshot);
 51                point_range.start.column = 0;
 52                if point_range.end.column > 0 || point_range.start.row == point_range.end.row {
 53                    point_range.end.column = snapshot.line_len(point_range.end.row);
 54                }
 55                range.start = snapshot.anchor_before(point_range.start);
 56                range.end = snapshot.anchor_after(point_range.end);
 57            }
 58            CodegenKind::Generate { position } => {
 59                *position = position.bias_right(&snapshot);
 60            }
 61        }
 62
 63        Self {
 64            provider,
 65            buffer: buffer.clone(),
 66            snapshot,
 67            kind,
 68            last_equal_ranges: Default::default(),
 69            transaction_id: Default::default(),
 70            error: Default::default(),
 71            idle: true,
 72            generation: Task::ready(()),
 73            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 74        }
 75    }
 76
 77    fn handle_buffer_event(
 78        &mut self,
 79        _buffer: ModelHandle<MultiBuffer>,
 80        event: &multi_buffer::Event,
 81        cx: &mut ModelContext<Self>,
 82    ) {
 83        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
 84            if self.transaction_id == Some(*transaction_id) {
 85                self.transaction_id = None;
 86                self.generation = Task::ready(());
 87                cx.emit(Event::Undone);
 88            }
 89        }
 90    }
 91
 92    pub fn range(&self) -> Range<Anchor> {
 93        match &self.kind {
 94            CodegenKind::Transform { range } => range.clone(),
 95            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
 96        }
 97    }
 98
 99    pub fn kind(&self) -> &CodegenKind {
100        &self.kind
101    }
102
103    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
104        &self.last_equal_ranges
105    }
106
107    pub fn idle(&self) -> bool {
108        self.idle
109    }
110
111    pub fn error(&self) -> Option<&anyhow::Error> {
112        self.error.as_ref()
113    }
114
115    pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
116        let range = self.range();
117        let snapshot = self.snapshot.clone();
118        let selected_text = snapshot
119            .text_for_range(range.start..range.end)
120            .collect::<Rope>();
121
122        let selection_start = range.start.to_point(&snapshot);
123        let suggested_line_indent = snapshot
124            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
125            .into_values()
126            .next()
127            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
128
129        let response = self.provider.complete(prompt);
130        self.generation = cx.spawn_weak(|this, mut cx| {
131            async move {
132                let generate = async {
133                    let mut edit_start = range.start.to_offset(&snapshot);
134
135                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
136                    let diff = cx.background().spawn(async move {
137                        let chunks = strip_markdown_codeblock(response.await?);
138                        futures::pin_mut!(chunks);
139                        let mut diff = StreamingDiff::new(selected_text.to_string());
140
141                        let mut new_text = String::new();
142                        let mut base_indent = None;
143                        let mut line_indent = None;
144                        let mut first_line = true;
145
146                        while let Some(chunk) = chunks.next().await {
147                            let chunk = chunk?;
148
149                            let mut lines = chunk.split('\n').peekable();
150                            while let Some(line) = lines.next() {
151                                new_text.push_str(line);
152                                if line_indent.is_none() {
153                                    if let Some(non_whitespace_ch_ix) =
154                                        new_text.find(|ch: char| !ch.is_whitespace())
155                                    {
156                                        line_indent = Some(non_whitespace_ch_ix);
157                                        base_indent = base_indent.or(line_indent);
158
159                                        let line_indent = line_indent.unwrap();
160                                        let base_indent = base_indent.unwrap();
161                                        let indent_delta = line_indent as i32 - base_indent as i32;
162                                        let mut corrected_indent_len = cmp::max(
163                                            0,
164                                            suggested_line_indent.len as i32 + indent_delta,
165                                        )
166                                            as usize;
167                                        if first_line {
168                                            corrected_indent_len = corrected_indent_len
169                                                .saturating_sub(selection_start.column as usize);
170                                        }
171
172                                        let indent_char = suggested_line_indent.char();
173                                        let mut indent_buffer = [0; 4];
174                                        let indent_str =
175                                            indent_char.encode_utf8(&mut indent_buffer);
176                                        new_text.replace_range(
177                                            ..line_indent,
178                                            &indent_str.repeat(corrected_indent_len),
179                                        );
180                                    }
181                                }
182
183                                if line_indent.is_some() {
184                                    hunks_tx.send(diff.push_new(&new_text)).await?;
185                                    new_text.clear();
186                                }
187
188                                if lines.peek().is_some() {
189                                    hunks_tx.send(diff.push_new("\n")).await?;
190                                    line_indent = None;
191                                    first_line = false;
192                                }
193                            }
194                        }
195                        hunks_tx.send(diff.push_new(&new_text)).await?;
196                        hunks_tx.send(diff.finish()).await?;
197
198                        anyhow::Ok(())
199                    });
200
201                    while let Some(hunks) = hunks_rx.next().await {
202                        let this = if let Some(this) = this.upgrade(&cx) {
203                            this
204                        } else {
205                            break;
206                        };
207
208                        this.update(&mut cx, |this, cx| {
209                            this.last_equal_ranges.clear();
210
211                            let transaction = this.buffer.update(cx, |buffer, cx| {
212                                // Avoid grouping assistant edits with user edits.
213                                buffer.finalize_last_transaction(cx);
214
215                                buffer.start_transaction(cx);
216                                buffer.edit(
217                                    hunks.into_iter().filter_map(|hunk| match hunk {
218                                        Hunk::Insert { text } => {
219                                            let edit_start = snapshot.anchor_after(edit_start);
220                                            Some((edit_start..edit_start, text))
221                                        }
222                                        Hunk::Remove { len } => {
223                                            let edit_end = edit_start + len;
224                                            let edit_range = snapshot.anchor_after(edit_start)
225                                                ..snapshot.anchor_before(edit_end);
226                                            edit_start = edit_end;
227                                            Some((edit_range, String::new()))
228                                        }
229                                        Hunk::Keep { len } => {
230                                            let edit_end = edit_start + len;
231                                            let edit_range = snapshot.anchor_after(edit_start)
232                                                ..snapshot.anchor_before(edit_end);
233                                            edit_start = edit_end;
234                                            this.last_equal_ranges.push(edit_range);
235                                            None
236                                        }
237                                    }),
238                                    None,
239                                    cx,
240                                );
241
242                                buffer.end_transaction(cx)
243                            });
244
245                            if let Some(transaction) = transaction {
246                                if let Some(first_transaction) = this.transaction_id {
247                                    // Group all assistant edits into the first transaction.
248                                    this.buffer.update(cx, |buffer, cx| {
249                                        buffer.merge_transactions(
250                                            transaction,
251                                            first_transaction,
252                                            cx,
253                                        )
254                                    });
255                                } else {
256                                    this.transaction_id = Some(transaction);
257                                    this.buffer.update(cx, |buffer, cx| {
258                                        buffer.finalize_last_transaction(cx)
259                                    });
260                                }
261                            }
262
263                            cx.notify();
264                        });
265                    }
266
267                    diff.await?;
268                    anyhow::Ok(())
269                };
270
271                let result = generate.await;
272                if let Some(this) = this.upgrade(&cx) {
273                    this.update(&mut cx, |this, cx| {
274                        this.last_equal_ranges.clear();
275                        this.idle = true;
276                        if let Err(error) = result {
277                            this.error = Some(error);
278                        }
279                        cx.emit(Event::Finished);
280                        cx.notify();
281                    });
282                }
283            }
284        });
285        self.error.take();
286        self.idle = false;
287        cx.notify();
288    }
289
290    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
291        if let Some(transaction_id) = self.transaction_id {
292            self.buffer
293                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
294        }
295    }
296}
297
298fn strip_markdown_codeblock(
299    stream: impl Stream<Item = Result<String>>,
300) -> impl Stream<Item = Result<String>> {
301    let mut first_line = true;
302    let mut buffer = String::new();
303    let mut starts_with_fenced_code_block = false;
304    stream.filter_map(move |chunk| {
305        let chunk = match chunk {
306            Ok(chunk) => chunk,
307            Err(err) => return future::ready(Some(Err(err))),
308        };
309        buffer.push_str(&chunk);
310
311        if first_line {
312            if buffer == "" || buffer == "`" || buffer == "``" {
313                return future::ready(None);
314            } else if buffer.starts_with("```") {
315                starts_with_fenced_code_block = true;
316                if let Some(newline_ix) = buffer.find('\n') {
317                    buffer.replace_range(..newline_ix + 1, "");
318                    first_line = false;
319                } else {
320                    return future::ready(None);
321                }
322            }
323        }
324
325        let text = if starts_with_fenced_code_block {
326            buffer
327                .strip_suffix("\n```\n")
328                .or_else(|| buffer.strip_suffix("\n```"))
329                .or_else(|| buffer.strip_suffix("\n``"))
330                .or_else(|| buffer.strip_suffix("\n`"))
331                .or_else(|| buffer.strip_suffix('\n'))
332                .unwrap_or(&buffer)
333        } else {
334            &buffer
335        };
336
337        if text.contains('\n') {
338            first_line = false;
339        }
340
341        let remainder = buffer.split_off(text.len());
342        let result = if buffer.is_empty() {
343            None
344        } else {
345            Some(Ok(buffer.clone()))
346        };
347        buffer = remainder;
348        future::ready(result)
349    })
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        future::BoxFuture,
        stream::{self, BoxStream},
    };
    use gpui::{executor::Deterministic, TestAppContext};
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use parking_lot::Mutex;
    use rand::prelude::*;
    use settings::SettingsStore;
    use smol::future::FutureExt;

    /// Installs the settings globals the codegen tests need.
    fn init_test(cx: &mut TestAppContext) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);
    }

    /// Builds a singleton multi-buffer holding `text`, using the test Rust language.
    fn rust_multi_buffer(text: &'static str, cx: &mut TestAppContext) -> ModelHandle<MultiBuffer> {
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        cx.add_model(|cx| MultiBuffer::singleton(buffer, cx))
    }

    /// Creates a `Codegen` of the given kind backed by a fresh test provider and starts
    /// it with a default prompt. The returned handle must be kept alive by the caller
    /// for the whole test, because the generation task only holds a weak reference.
    fn start_codegen(
        buffer: &ModelHandle<MultiBuffer>,
        kind: CodegenKind,
        cx: &mut TestAppContext,
    ) -> (ModelHandle<Codegen>, Arc<TestCompletionProvider>) {
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| Codegen::new(buffer.clone(), kind, provider.clone(), cx));
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
        (codegen, provider)
    }

    /// Feeds `text` to the provider in random chunks of 1 to 10 bytes, letting the
    /// executor settle after each chunk, then closes the completion stream.
    fn stream_completion(
        provider: &TestCompletionProvider,
        mut text: &str,
        rng: &mut StdRng,
        deterministic: &Arc<Deterministic>,
    ) {
        while !text.is_empty() {
            let max_len = cmp::min(text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = text.split_at(len);
            provider.send_completion(chunk);
            text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();
    }

    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = rust_multi_buffer(text, cx);
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
        });
        let (_codegen, provider) = start_codegen(&buffer, CodegenKind::Transform { range }, cx);

        // The completion arrives with an indentation that does not match the buffer.
        let completion = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        stream_completion(&provider, completion, &mut rng, &deterministic);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = rust_multi_buffer(text, cx);
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let (_codegen, provider) = start_codegen(&buffer, CodegenKind::Generate { position }, cx);

        let completion = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_completion(&provider, completion, &mut rng, &deterministic);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        init_test(cx);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = rust_multi_buffer(text, cx);
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let (_codegen, provider) = start_codegen(&buffer, CodegenKind::Generate { position }, cx);

        let completion = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        stream_completion(&provider, completion, &mut rng, &deterministic);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test]
    async fn test_strip_markdown_codeblock() {
        /// Splits `text` into a stream of `Ok` chunks of `size` characters each.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }

        // (streamed input, expected output after stripping)
        let cases = [
            ("Lorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```\n", "Lorem ipsum dolor"),
            (
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                "```js\nLorem ipsum dolor\n```",
            ),
            ("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```"),
        ];

        for (input, expected) in cases {
            let stripped = strip_markdown_codeblock(chunks(input, 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await;
            assert_eq!(stripped, expected);
        }
    }

    /// Completion provider whose response stream is driven manually by the test.
    struct TestCompletionProvider {
        /// Sender for the stream returned by the latest `complete` call.
        last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
    }

    impl TestCompletionProvider {
        fn new() -> Self {
            Self {
                last_completion_tx: Mutex::new(None),
            }
        }

        /// Pushes one chunk into the current completion stream.
        /// Panics if `complete` has not been called or the channel is full.
        fn send_completion(&self, completion: impl Into<String>) {
            let mut tx = self.last_completion_tx.lock();
            tx.as_mut().unwrap().try_send(completion.into()).unwrap();
        }

        /// Ends the current completion stream by dropping its sender.
        fn finish_completion(&self) {
            self.last_completion_tx.lock().take().unwrap();
        }
    }

    impl CompletionProvider for TestCompletionProvider {
        fn complete(
            &self,
            _prompt: OpenAIRequest,
        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
            let (tx, rx) = mpsc::channel(1);
            *self.last_completion_tx.lock() = Some(tx);
            async move { Ok(rx.map(|chunk| Ok(chunk)).boxed()) }.boxed()
        }
    }

    /// Minimal Rust language with just the indent queries the tests rely on.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}