codegen.rs

  1use crate::streaming_diff::{Hunk, StreamingDiff};
  2use ai::completion::{CompletionProvider, OpenAIRequest};
  3use anyhow::Result;
  4use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
  5use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  6use gpui::{Entity, ModelContext, ModelHandle, Task};
  7use language::{Rope, TransactionId};
  8use multi_buffer;
  9use std::{cmp, future, ops::Range, sync::Arc};
 10
/// Events emitted by a [`Codegen`] model.
pub enum Event {
    /// Emitted by the task spawned in `Codegen::start` once the generation
    /// has ended, whether it succeeded or produced an error.
    Finished,
    /// Emitted when the buffer reports that the transaction holding the
    /// generated edits was undone.
    Undone,
}
 15
/// The part of the buffer a [`Codegen`] works on.
#[derive(Clone)]
pub enum CodegenKind {
    /// Work on an existing range of text; `Codegen::range` returns this range
    /// unchanged.
    Transform { range: Range<Anchor> },
    /// Generate text at a single position; `Codegen::range` returns the range
    /// from the left-biased position to the position itself.
    Generate { position: Anchor },
}
 21
/// Model that streams a completion from a [`CompletionProvider`] into a
/// [`MultiBuffer`] and keeps track of the resulting edits.
pub struct Codegen {
    /// Source of completions; queried once per call to `start`.
    provider: Arc<dyn CompletionProvider>,
    /// The buffer that receives the generated edits.
    buffer: ModelHandle<MultiBuffer>,
    /// Snapshot of `buffer` taken in `new`; anchors and offsets used during
    /// generation are resolved against it.
    snapshot: MultiBufferSnapshot,
    /// Whether an existing range is transformed or new text is generated.
    kind: CodegenKind,
    /// Ranges reported as `Hunk::Keep` by the most recent batch of diff hunks;
    /// cleared before every batch and when the generation finishes.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// First transaction created by the generation; later transactions are
    /// merged into it. Reset to `None` when that transaction is undone.
    transaction_id: Option<TransactionId>,
    /// Error produced by the last generation, if any; cleared by `start`.
    error: Option<anyhow::Error>,
    /// Task driving the current generation. It is replaced with an
    /// already-completed task when the generated transaction is undone
    /// (presumably dropping the old task cancels it — confirm against gpui).
    generation: Task<()>,
    /// `false` between a call to `start` and the end of that generation.
    idle: bool,
    /// Keeps the subscription to `buffer`'s events alive; events are routed
    /// to `handle_buffer_event`.
    _subscription: gpui::Subscription,
}
 34
// Registers `Codegen` as a gpui entity whose emitted events are of type `Event`.
impl Entity for Codegen {
    type Event = Event;
}
 38
 39impl Codegen {
 40    pub fn new(
 41        buffer: ModelHandle<MultiBuffer>,
 42        kind: CodegenKind,
 43        provider: Arc<dyn CompletionProvider>,
 44        cx: &mut ModelContext<Self>,
 45    ) -> Self {
 46        let snapshot = buffer.read(cx).snapshot(cx);
 47        Self {
 48            provider,
 49            buffer: buffer.clone(),
 50            snapshot,
 51            kind,
 52            last_equal_ranges: Default::default(),
 53            transaction_id: Default::default(),
 54            error: Default::default(),
 55            idle: true,
 56            generation: Task::ready(()),
 57            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 58        }
 59    }
 60
 61    fn handle_buffer_event(
 62        &mut self,
 63        _buffer: ModelHandle<MultiBuffer>,
 64        event: &multi_buffer::Event,
 65        cx: &mut ModelContext<Self>,
 66    ) {
 67        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
 68            if self.transaction_id == Some(*transaction_id) {
 69                self.transaction_id = None;
 70                self.generation = Task::ready(());
 71                cx.emit(Event::Undone);
 72            }
 73        }
 74    }
 75
 76    pub fn range(&self) -> Range<Anchor> {
 77        match &self.kind {
 78            CodegenKind::Transform { range } => range.clone(),
 79            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
 80        }
 81    }
 82
    /// Returns the kind of operation (transform or generate) this codegen was
    /// created with.
    pub fn kind(&self) -> &CodegenKind {
        &self.kind
    }
 86
 87    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 88        &self.last_equal_ranges
 89    }
 90
    /// Returns `true` when no generation is in progress, i.e. before the first
    /// call to `start` or after the spawned generation has finished.
    pub fn idle(&self) -> bool {
        self.idle
    }
 94
    /// Returns the error produced by the last generation, if it failed.
    ///
    /// The stored error is cleared at the beginning of every `start` call.
    pub fn error(&self) -> Option<&anyhow::Error> {
        self.error.as_ref()
    }
 98
 99    pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
100        let range = self.range();
101        let snapshot = self.snapshot.clone();
102        let selected_text = snapshot
103            .text_for_range(range.start..range.end)
104            .collect::<Rope>();
105
106        let selection_start = range.start.to_point(&snapshot);
107        let suggested_line_indent = snapshot
108            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
109            .into_values()
110            .next()
111            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
112
113        let response = self.provider.complete(prompt);
114        self.generation = cx.spawn_weak(|this, mut cx| {
115            async move {
116                let generate = async {
117                    let mut edit_start = range.start.to_offset(&snapshot);
118
119                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
120                    let diff = cx.background().spawn(async move {
121                        let chunks = strip_markdown_codeblock(response.await?);
122                        futures::pin_mut!(chunks);
123                        let mut diff = StreamingDiff::new(selected_text.to_string());
124
125                        let mut new_text = String::new();
126                        let mut base_indent = None;
127                        let mut line_indent = None;
128                        let mut first_line = true;
129
130                        while let Some(chunk) = chunks.next().await {
131                            let chunk = chunk?;
132
133                            let mut lines = chunk.split('\n').peekable();
134                            while let Some(line) = lines.next() {
135                                new_text.push_str(line);
136                                if line_indent.is_none() {
137                                    if let Some(non_whitespace_ch_ix) =
138                                        new_text.find(|ch: char| !ch.is_whitespace())
139                                    {
140                                        line_indent = Some(non_whitespace_ch_ix);
141                                        base_indent = base_indent.or(line_indent);
142
143                                        let line_indent = line_indent.unwrap();
144                                        let base_indent = base_indent.unwrap();
145                                        let indent_delta = line_indent as i32 - base_indent as i32;
146                                        let mut corrected_indent_len = cmp::max(
147                                            0,
148                                            suggested_line_indent.len as i32 + indent_delta,
149                                        )
150                                            as usize;
151                                        if first_line {
152                                            corrected_indent_len = corrected_indent_len
153                                                .saturating_sub(selection_start.column as usize);
154                                        }
155
156                                        let indent_char = suggested_line_indent.char();
157                                        let mut indent_buffer = [0; 4];
158                                        let indent_str =
159                                            indent_char.encode_utf8(&mut indent_buffer);
160                                        new_text.replace_range(
161                                            ..line_indent,
162                                            &indent_str.repeat(corrected_indent_len),
163                                        );
164                                    }
165                                }
166
167                                if line_indent.is_some() {
168                                    hunks_tx.send(diff.push_new(&new_text)).await?;
169                                    new_text.clear();
170                                }
171
172                                if lines.peek().is_some() {
173                                    hunks_tx.send(diff.push_new("\n")).await?;
174                                    line_indent = None;
175                                    first_line = false;
176                                }
177                            }
178                        }
179                        hunks_tx.send(diff.push_new(&new_text)).await?;
180                        hunks_tx.send(diff.finish()).await?;
181
182                        anyhow::Ok(())
183                    });
184
185                    while let Some(hunks) = hunks_rx.next().await {
186                        let this = if let Some(this) = this.upgrade(&cx) {
187                            this
188                        } else {
189                            break;
190                        };
191
192                        this.update(&mut cx, |this, cx| {
193                            this.last_equal_ranges.clear();
194
195                            let transaction = this.buffer.update(cx, |buffer, cx| {
196                                // Avoid grouping assistant edits with user edits.
197                                buffer.finalize_last_transaction(cx);
198
199                                buffer.start_transaction(cx);
200                                buffer.edit(
201                                    hunks.into_iter().filter_map(|hunk| match hunk {
202                                        Hunk::Insert { text } => {
203                                            let edit_start = snapshot.anchor_after(edit_start);
204                                            Some((edit_start..edit_start, text))
205                                        }
206                                        Hunk::Remove { len } => {
207                                            let edit_end = edit_start + len;
208                                            let edit_range = snapshot.anchor_after(edit_start)
209                                                ..snapshot.anchor_before(edit_end);
210                                            edit_start = edit_end;
211                                            Some((edit_range, String::new()))
212                                        }
213                                        Hunk::Keep { len } => {
214                                            let edit_end = edit_start + len;
215                                            let edit_range = snapshot.anchor_after(edit_start)
216                                                ..snapshot.anchor_before(edit_end);
217                                            edit_start = edit_end;
218                                            this.last_equal_ranges.push(edit_range);
219                                            None
220                                        }
221                                    }),
222                                    None,
223                                    cx,
224                                );
225
226                                buffer.end_transaction(cx)
227                            });
228
229                            if let Some(transaction) = transaction {
230                                if let Some(first_transaction) = this.transaction_id {
231                                    // Group all assistant edits into the first transaction.
232                                    this.buffer.update(cx, |buffer, cx| {
233                                        buffer.merge_transactions(
234                                            transaction,
235                                            first_transaction,
236                                            cx,
237                                        )
238                                    });
239                                } else {
240                                    this.transaction_id = Some(transaction);
241                                    this.buffer.update(cx, |buffer, cx| {
242                                        buffer.finalize_last_transaction(cx)
243                                    });
244                                }
245                            }
246
247                            cx.notify();
248                        });
249                    }
250
251                    diff.await?;
252                    anyhow::Ok(())
253                };
254
255                let result = generate.await;
256                if let Some(this) = this.upgrade(&cx) {
257                    this.update(&mut cx, |this, cx| {
258                        this.last_equal_ranges.clear();
259                        this.idle = true;
260                        if let Err(error) = result {
261                            this.error = Some(error);
262                        }
263                        cx.emit(Event::Finished);
264                        cx.notify();
265                    });
266                }
267            }
268        });
269        self.error.take();
270        self.idle = false;
271        cx.notify();
272    }
273
274    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
275        if let Some(transaction_id) = self.transaction_id {
276            self.buffer
277                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
278        }
279    }
280}
281
282fn strip_markdown_codeblock(
283    stream: impl Stream<Item = Result<String>>,
284) -> impl Stream<Item = Result<String>> {
285    let mut first_line = true;
286    let mut buffer = String::new();
287    let mut starts_with_fenced_code_block = false;
288    stream.filter_map(move |chunk| {
289        let chunk = match chunk {
290            Ok(chunk) => chunk,
291            Err(err) => return future::ready(Some(Err(err))),
292        };
293        buffer.push_str(&chunk);
294
295        if first_line {
296            if buffer == "" || buffer == "`" || buffer == "``" {
297                return future::ready(None);
298            } else if buffer.starts_with("```") {
299                starts_with_fenced_code_block = true;
300                if let Some(newline_ix) = buffer.find('\n') {
301                    buffer.replace_range(..newline_ix + 1, "");
302                    first_line = false;
303                } else {
304                    return future::ready(None);
305                }
306            }
307        }
308
309        let text = if starts_with_fenced_code_block {
310            buffer
311                .strip_suffix("\n```\n")
312                .or_else(|| buffer.strip_suffix("\n```"))
313                .or_else(|| buffer.strip_suffix("\n``"))
314                .or_else(|| buffer.strip_suffix("\n`"))
315                .or_else(|| buffer.strip_suffix('\n'))
316                .unwrap_or(&buffer)
317        } else {
318            &buffer
319        };
320
321        if text.contains('\n') {
322            first_line = false;
323        }
324
325        let remainder = buffer.split_off(text.len());
326        let result = if buffer.is_empty() {
327            None
328        } else {
329            Some(Ok(buffer.clone()))
330        };
331        buffer = remainder;
332        future::ready(result)
333    })
334}
335
336#[cfg(test)]
337mod tests {
338    use super::*;
339    use futures::{
340        future::BoxFuture,
341        stream::{self, BoxStream},
342    };
343    use gpui::{executor::Deterministic, TestAppContext};
344    use indoc::indoc;
345    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
346    use parking_lot::Mutex;
347    use rand::prelude::*;
348    use settings::SettingsStore;
349    use smol::future::FutureExt;
350
351    #[gpui::test(iterations = 10)]
352    async fn test_transform_autoindent(
353        cx: &mut TestAppContext,
354        mut rng: StdRng,
355        deterministic: Arc<Deterministic>,
356    ) {
357        cx.set_global(cx.read(SettingsStore::test));
358        cx.update(language_settings::init);
359
360        let text = indoc! {"
361            fn main() {
362                let x = 0;
363                for _ in 0..10 {
364                    x += 1;
365                }
366            }
367        "};
368        let buffer =
369            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
370        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
371        let range = buffer.read_with(cx, |buffer, cx| {
372            let snapshot = buffer.snapshot(cx);
373            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
374        });
375        let provider = Arc::new(TestCompletionProvider::new());
376        let codegen = cx.add_model(|cx| {
377            Codegen::new(
378                buffer.clone(),
379                CodegenKind::Transform { range },
380                provider.clone(),
381                cx,
382            )
383        });
384        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
385
386        let mut new_text = concat!(
387            "       let mut x = 0;\n",
388            "       while x < 10 {\n",
389            "           x += 1;\n",
390            "       }",
391        );
392        while !new_text.is_empty() {
393            let max_len = cmp::min(new_text.len(), 10);
394            let len = rng.gen_range(1..=max_len);
395            let (chunk, suffix) = new_text.split_at(len);
396            provider.send_completion(chunk);
397            new_text = suffix;
398            deterministic.run_until_parked();
399        }
400        provider.finish_completion();
401        deterministic.run_until_parked();
402
403        assert_eq!(
404            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
405            indoc! {"
406                fn main() {
407                    let mut x = 0;
408                    while x < 10 {
409                        x += 1;
410                    }
411                }
412            "}
413        );
414    }
415
416    #[gpui::test(iterations = 10)]
417    async fn test_autoindent_when_generating_past_indentation(
418        cx: &mut TestAppContext,
419        mut rng: StdRng,
420        deterministic: Arc<Deterministic>,
421    ) {
422        cx.set_global(cx.read(SettingsStore::test));
423        cx.update(language_settings::init);
424
425        let text = indoc! {"
426            fn main() {
427                le
428            }
429        "};
430        let buffer =
431            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
432        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
433        let position = buffer.read_with(cx, |buffer, cx| {
434            let snapshot = buffer.snapshot(cx);
435            snapshot.anchor_before(Point::new(1, 6))
436        });
437        let provider = Arc::new(TestCompletionProvider::new());
438        let codegen = cx.add_model(|cx| {
439            Codegen::new(
440                buffer.clone(),
441                CodegenKind::Generate { position },
442                provider.clone(),
443                cx,
444            )
445        });
446        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
447
448        let mut new_text = concat!(
449            "t mut x = 0;\n",
450            "while x < 10 {\n",
451            "    x += 1;\n",
452            "}", //
453        );
454        while !new_text.is_empty() {
455            let max_len = cmp::min(new_text.len(), 10);
456            let len = rng.gen_range(1..=max_len);
457            let (chunk, suffix) = new_text.split_at(len);
458            provider.send_completion(chunk);
459            new_text = suffix;
460            deterministic.run_until_parked();
461        }
462        provider.finish_completion();
463        deterministic.run_until_parked();
464
465        assert_eq!(
466            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
467            indoc! {"
468                fn main() {
469                    let mut x = 0;
470                    while x < 10 {
471                        x += 1;
472                    }
473                }
474            "}
475        );
476    }
477
478    #[gpui::test(iterations = 10)]
479    async fn test_autoindent_when_generating_before_indentation(
480        cx: &mut TestAppContext,
481        mut rng: StdRng,
482        deterministic: Arc<Deterministic>,
483    ) {
484        cx.set_global(cx.read(SettingsStore::test));
485        cx.update(language_settings::init);
486
487        let text = concat!(
488            "fn main() {\n",
489            "  \n",
490            "}\n" //
491        );
492        let buffer =
493            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
494        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
495        let position = buffer.read_with(cx, |buffer, cx| {
496            let snapshot = buffer.snapshot(cx);
497            snapshot.anchor_before(Point::new(1, 2))
498        });
499        let provider = Arc::new(TestCompletionProvider::new());
500        let codegen = cx.add_model(|cx| {
501            Codegen::new(
502                buffer.clone(),
503                CodegenKind::Generate { position },
504                provider.clone(),
505                cx,
506            )
507        });
508        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
509
510        let mut new_text = concat!(
511            "let mut x = 0;\n",
512            "while x < 10 {\n",
513            "    x += 1;\n",
514            "}", //
515        );
516        while !new_text.is_empty() {
517            let max_len = cmp::min(new_text.len(), 10);
518            let len = rng.gen_range(1..=max_len);
519            let (chunk, suffix) = new_text.split_at(len);
520            provider.send_completion(chunk);
521            new_text = suffix;
522            deterministic.run_until_parked();
523        }
524        provider.finish_completion();
525        deterministic.run_until_parked();
526
527        assert_eq!(
528            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
529            indoc! {"
530                fn main() {
531                    let mut x = 0;
532                    while x < 10 {
533                        x += 1;
534                    }
535                }
536            "}
537        );
538    }
539
540    #[gpui::test]
541    async fn test_strip_markdown_codeblock() {
542        assert_eq!(
543            strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
544                .map(|chunk| chunk.unwrap())
545                .collect::<String>()
546                .await,
547            "Lorem ipsum dolor"
548        );
549        assert_eq!(
550            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
551                .map(|chunk| chunk.unwrap())
552                .collect::<String>()
553                .await,
554            "Lorem ipsum dolor"
555        );
556        assert_eq!(
557            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
558                .map(|chunk| chunk.unwrap())
559                .collect::<String>()
560                .await,
561            "Lorem ipsum dolor"
562        );
563        assert_eq!(
564            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
565                .map(|chunk| chunk.unwrap())
566                .collect::<String>()
567                .await,
568            "Lorem ipsum dolor"
569        );
570        assert_eq!(
571            strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
572                .map(|chunk| chunk.unwrap())
573                .collect::<String>()
574                .await,
575            "```js\nLorem ipsum dolor\n```"
576        );
577        assert_eq!(
578            strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
579                .map(|chunk| chunk.unwrap())
580                .collect::<String>()
581                .await,
582            "``\nLorem ipsum dolor\n```"
583        );
584
585        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
586            stream::iter(
587                text.chars()
588                    .collect::<Vec<_>>()
589                    .chunks(size)
590                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
591                    .collect::<Vec<_>>(),
592            )
593        }
594    }
595
    /// Completion provider for tests whose response is driven manually through
    /// `send_completion` and `finish_completion`.
    struct TestCompletionProvider {
        /// Sending half of the channel created by the most recent `complete`
        /// call; `None` before the first call and after `finish_completion`.
        last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
    }
599
600    impl TestCompletionProvider {
601        fn new() -> Self {
602            Self {
603                last_completion_tx: Mutex::new(None),
604            }
605        }
606
607        fn send_completion(&self, completion: impl Into<String>) {
608            let mut tx = self.last_completion_tx.lock();
609            tx.as_mut().unwrap().try_send(completion.into()).unwrap();
610        }
611
612        fn finish_completion(&self) {
613            self.last_completion_tx.lock().take().unwrap();
614        }
615    }
616
617    impl CompletionProvider for TestCompletionProvider {
618        fn complete(
619            &self,
620            _prompt: OpenAIRequest,
621        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
622            let (tx, rx) = mpsc::channel(1);
623            *self.last_completion_tx.lock() = Some(tx);
624            async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
625        }
626    }
627
628    fn rust_lang() -> Language {
629        Language::new(
630            LanguageConfig {
631                name: "Rust".into(),
632                path_suffixes: vec!["rs".to_string()],
633                ..Default::default()
634            },
635            Some(tree_sitter_rust::language()),
636        )
637        .with_indents_query(
638            r#"
639            (call_expression) @indent
640            (field_expression) @indent
641            (_ "(" ")" @end) @indent
642            (_ "{" "}" @end) @indent
643            "#,
644        )
645        .unwrap()
646    }
647}