codegen.rs

  1use crate::{
  2    stream_completion,
  3    streaming_diff::{Hunk, StreamingDiff},
  4    OpenAIRequest,
  5};
  6use anyhow::Result;
  7use editor::{
  8    multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
  9};
 10use futures::{
 11    channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
 12};
 13use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
 14use language::{Rope, TransactionId};
 15use std::{cmp, future, ops::Range, sync::Arc};
 16
 17pub trait CompletionProvider {
 18    fn complete(
 19        &self,
 20        prompt: OpenAIRequest,
 21    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 22}
 23
 24pub struct OpenAICompletionProvider {
 25    api_key: String,
 26    executor: Arc<Background>,
 27}
 28
 29impl OpenAICompletionProvider {
 30    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
 31        Self { api_key, executor }
 32    }
 33}
 34
 35impl CompletionProvider for OpenAICompletionProvider {
 36    fn complete(
 37        &self,
 38        prompt: OpenAIRequest,
 39    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
 40        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
 41        async move {
 42            let response = request.await?;
 43            let stream = response
 44                .filter_map(|response| async move {
 45                    match response {
 46                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
 47                        Err(error) => Some(Err(error)),
 48                    }
 49                })
 50                .boxed();
 51            Ok(stream)
 52        }
 53        .boxed()
 54    }
 55}
 56
 57pub enum Event {
 58    Finished,
 59    Undone,
 60}
 61
 62#[derive(Clone)]
 63pub enum CodegenKind {
 64    Transform { range: Range<Anchor> },
 65    Generate { position: Anchor },
 66}
 67
 68pub struct Codegen {
 69    provider: Arc<dyn CompletionProvider>,
 70    buffer: ModelHandle<MultiBuffer>,
 71    snapshot: MultiBufferSnapshot,
 72    kind: CodegenKind,
 73    last_equal_ranges: Vec<Range<Anchor>>,
 74    transaction_id: Option<TransactionId>,
 75    error: Option<anyhow::Error>,
 76    generation: Task<()>,
 77    idle: bool,
 78    _subscription: gpui::Subscription,
 79}
 80
 81impl Entity for Codegen {
 82    type Event = Event;
 83}
 84
 85impl Codegen {
 86    pub fn new(
 87        buffer: ModelHandle<MultiBuffer>,
 88        mut kind: CodegenKind,
 89        provider: Arc<dyn CompletionProvider>,
 90        cx: &mut ModelContext<Self>,
 91    ) -> Self {
 92        let snapshot = buffer.read(cx).snapshot(cx);
 93        match &mut kind {
 94            CodegenKind::Transform { range } => {
 95                let mut point_range = range.to_point(&snapshot);
 96                point_range.start.column = 0;
 97                if point_range.end.column > 0 || point_range.start.row == point_range.end.row {
 98                    point_range.end.column = snapshot.line_len(point_range.end.row);
 99                }
100                range.start = snapshot.anchor_before(point_range.start);
101                range.end = snapshot.anchor_after(point_range.end);
102            }
103            CodegenKind::Generate { position } => {
104                *position = position.bias_right(&snapshot);
105            }
106        }
107
108        Self {
109            provider,
110            buffer: buffer.clone(),
111            snapshot,
112            kind,
113            last_equal_ranges: Default::default(),
114            transaction_id: Default::default(),
115            error: Default::default(),
116            idle: true,
117            generation: Task::ready(()),
118            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
119        }
120    }
121
122    fn handle_buffer_event(
123        &mut self,
124        _buffer: ModelHandle<MultiBuffer>,
125        event: &multi_buffer::Event,
126        cx: &mut ModelContext<Self>,
127    ) {
128        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
129            if self.transaction_id == Some(*transaction_id) {
130                self.transaction_id = None;
131                self.generation = Task::ready(());
132                cx.emit(Event::Undone);
133            }
134        }
135    }
136
137    pub fn range(&self) -> Range<Anchor> {
138        match &self.kind {
139            CodegenKind::Transform { range } => range.clone(),
140            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
141        }
142    }
143
144    pub fn kind(&self) -> &CodegenKind {
145        &self.kind
146    }
147
148    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
149        &self.last_equal_ranges
150    }
151
152    pub fn idle(&self) -> bool {
153        self.idle
154    }
155
156    pub fn error(&self) -> Option<&anyhow::Error> {
157        self.error.as_ref()
158    }
159
160    pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
161        let range = self.range();
162        let snapshot = self.snapshot.clone();
163        let selected_text = snapshot
164            .text_for_range(range.start..range.end)
165            .collect::<Rope>();
166
167        let selection_start = range.start.to_point(&snapshot);
168        let suggested_line_indent = snapshot
169            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
170            .into_values()
171            .next()
172            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
173
174        let response = self.provider.complete(prompt);
175        self.generation = cx.spawn_weak(|this, mut cx| {
176            async move {
177                let generate = async {
178                    let mut edit_start = range.start.to_offset(&snapshot);
179
180                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
181                    let diff = cx.background().spawn(async move {
182                        let chunks = strip_markdown_codeblock(response.await?);
183                        futures::pin_mut!(chunks);
184                        let mut diff = StreamingDiff::new(selected_text.to_string());
185
186                        let mut new_text = String::new();
187                        let mut base_indent = None;
188                        let mut line_indent = None;
189                        let mut first_line = true;
190
191                        while let Some(chunk) = chunks.next().await {
192                            let chunk = chunk?;
193
194                            let mut lines = chunk.split('\n').peekable();
195                            while let Some(line) = lines.next() {
196                                new_text.push_str(line);
197                                if line_indent.is_none() {
198                                    if let Some(non_whitespace_ch_ix) =
199                                        new_text.find(|ch: char| !ch.is_whitespace())
200                                    {
201                                        line_indent = Some(non_whitespace_ch_ix);
202                                        base_indent = base_indent.or(line_indent);
203
204                                        let line_indent = line_indent.unwrap();
205                                        let base_indent = base_indent.unwrap();
206                                        let indent_delta = line_indent as i32 - base_indent as i32;
207                                        let mut corrected_indent_len = cmp::max(
208                                            0,
209                                            suggested_line_indent.len as i32 + indent_delta,
210                                        )
211                                            as usize;
212                                        if first_line {
213                                            corrected_indent_len = corrected_indent_len
214                                                .saturating_sub(selection_start.column as usize);
215                                        }
216
217                                        let indent_char = suggested_line_indent.char();
218                                        let mut indent_buffer = [0; 4];
219                                        let indent_str =
220                                            indent_char.encode_utf8(&mut indent_buffer);
221                                        new_text.replace_range(
222                                            ..line_indent,
223                                            &indent_str.repeat(corrected_indent_len),
224                                        );
225                                    }
226                                }
227
228                                if line_indent.is_some() {
229                                    hunks_tx.send(diff.push_new(&new_text)).await?;
230                                    new_text.clear();
231                                }
232
233                                if lines.peek().is_some() {
234                                    hunks_tx.send(diff.push_new("\n")).await?;
235                                    line_indent = None;
236                                    first_line = false;
237                                }
238                            }
239                        }
240                        hunks_tx.send(diff.push_new(&new_text)).await?;
241                        hunks_tx.send(diff.finish()).await?;
242
243                        anyhow::Ok(())
244                    });
245
246                    while let Some(hunks) = hunks_rx.next().await {
247                        let this = if let Some(this) = this.upgrade(&cx) {
248                            this
249                        } else {
250                            break;
251                        };
252
253                        this.update(&mut cx, |this, cx| {
254                            this.last_equal_ranges.clear();
255
256                            let transaction = this.buffer.update(cx, |buffer, cx| {
257                                // Avoid grouping assistant edits with user edits.
258                                buffer.finalize_last_transaction(cx);
259
260                                buffer.start_transaction(cx);
261                                buffer.edit(
262                                    hunks.into_iter().filter_map(|hunk| match hunk {
263                                        Hunk::Insert { text } => {
264                                            let edit_start = snapshot.anchor_after(edit_start);
265                                            Some((edit_start..edit_start, text))
266                                        }
267                                        Hunk::Remove { len } => {
268                                            let edit_end = edit_start + len;
269                                            let edit_range = snapshot.anchor_after(edit_start)
270                                                ..snapshot.anchor_before(edit_end);
271                                            edit_start = edit_end;
272                                            Some((edit_range, String::new()))
273                                        }
274                                        Hunk::Keep { len } => {
275                                            let edit_end = edit_start + len;
276                                            let edit_range = snapshot.anchor_after(edit_start)
277                                                ..snapshot.anchor_before(edit_end);
278                                            edit_start = edit_end;
279                                            this.last_equal_ranges.push(edit_range);
280                                            None
281                                        }
282                                    }),
283                                    None,
284                                    cx,
285                                );
286
287                                buffer.end_transaction(cx)
288                            });
289
290                            if let Some(transaction) = transaction {
291                                if let Some(first_transaction) = this.transaction_id {
292                                    // Group all assistant edits into the first transaction.
293                                    this.buffer.update(cx, |buffer, cx| {
294                                        buffer.merge_transactions(
295                                            transaction,
296                                            first_transaction,
297                                            cx,
298                                        )
299                                    });
300                                } else {
301                                    this.transaction_id = Some(transaction);
302                                    this.buffer.update(cx, |buffer, cx| {
303                                        buffer.finalize_last_transaction(cx)
304                                    });
305                                }
306                            }
307
308                            cx.notify();
309                        });
310                    }
311
312                    diff.await?;
313                    anyhow::Ok(())
314                };
315
316                let result = generate.await;
317                if let Some(this) = this.upgrade(&cx) {
318                    this.update(&mut cx, |this, cx| {
319                        this.last_equal_ranges.clear();
320                        this.idle = true;
321                        if let Err(error) = result {
322                            this.error = Some(error);
323                        }
324                        cx.emit(Event::Finished);
325                        cx.notify();
326                    });
327                }
328            }
329        });
330        self.error.take();
331        self.idle = false;
332        cx.notify();
333    }
334
335    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
336        if let Some(transaction_id) = self.transaction_id {
337            self.buffer
338                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
339        }
340    }
341}
342
343fn strip_markdown_codeblock(
344    stream: impl Stream<Item = Result<String>>,
345) -> impl Stream<Item = Result<String>> {
346    let mut first_line = true;
347    let mut buffer = String::new();
348    let mut starts_with_fenced_code_block = false;
349    stream.filter_map(move |chunk| {
350        let chunk = match chunk {
351            Ok(chunk) => chunk,
352            Err(err) => return future::ready(Some(Err(err))),
353        };
354        buffer.push_str(&chunk);
355
356        if first_line {
357            if buffer == "" || buffer == "`" || buffer == "``" {
358                return future::ready(None);
359            } else if buffer.starts_with("```") {
360                starts_with_fenced_code_block = true;
361                if let Some(newline_ix) = buffer.find('\n') {
362                    buffer.replace_range(..newline_ix + 1, "");
363                    first_line = false;
364                } else {
365                    return future::ready(None);
366                }
367            }
368        }
369
370        let text = if starts_with_fenced_code_block {
371            buffer
372                .strip_suffix("\n```\n")
373                .or_else(|| buffer.strip_suffix("\n```"))
374                .or_else(|| buffer.strip_suffix("\n``"))
375                .or_else(|| buffer.strip_suffix("\n`"))
376                .or_else(|| buffer.strip_suffix('\n'))
377                .unwrap_or(&buffer)
378        } else {
379            &buffer
380        };
381
382        if text.contains('\n') {
383            first_line = false;
384        }
385
386        let remainder = buffer.split_off(text.len());
387        let result = if buffer.is_empty() {
388            None
389        } else {
390            Some(Ok(buffer.clone()))
391        };
392        buffer = remainder;
393        future::ready(result)
394    })
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;
    use gpui::{executor::Deterministic, TestAppContext};
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use parking_lot::Mutex;
    use rand::prelude::*;
    use settings::SettingsStore;

    /// Transforming a multi-line selection re-indents the streamed replacement.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        init_test(cx);

        let buffer = rust_multi_buffer(
            indoc! {"
                fn main() {
                    let x = 0;
                    for _ in 0..10 {
                        x += 1;
                    }
                }
            "},
            cx,
        );
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Transform { range },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        stream_in_random_chunks(
            &provider,
            concat!(
                "       let mut x = 0;\n",
                "       while x < 10 {\n",
                "           x += 1;\n",
                "       }",
            ),
            &mut rng,
            &deterministic,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating from a cursor that sits after existing text on an indented line.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        init_test(cx);

        let buffer = rust_multi_buffer(
            indoc! {"
                fn main() {
                    le
                }
            "},
            cx,
        );
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        stream_in_random_chunks(
            &provider,
            concat!(
                "t mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &mut rng,
            &deterministic,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating from a cursor on a whitespace-only, under-indented line.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        init_test(cx);

        let buffer = rust_multi_buffer(
            concat!(
                "fn main() {\n",
                "  \n",
                "}\n" //
            ),
            cx,
        );
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        stream_in_random_chunks(
            &provider,
            concat!(
                "let mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &mut rng,
            &deterministic,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Table-driven check of fence stripping, feeding every input two characters at a time.
    #[gpui::test]
    async fn test_strip_markdown_codeblock() {
        let cases = [
            ("Lorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```\n", "Lorem ipsum dolor"),
            (
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                "```js\nLorem ipsum dolor\n```",
            ),
            ("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```"),
        ];

        for &(input, expected) in &cases {
            let stripped = strip_markdown_codeblock(chunks(input, 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await;
            assert_eq!(stripped, expected);
        }

        /// Splits `text` into successful chunks of at most `size` characters.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            let characters = text.chars().collect::<Vec<_>>();
            let pieces = characters
                .chunks(size)
                .map(|piece| Ok(piece.iter().collect::<String>()))
                .collect::<Vec<_>>();
            stream::iter(pieces)
        }
    }

    /// Installs the test settings store and language settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);
    }

    /// Wraps `text` in a Rust-language buffer inside a singleton multi-buffer.
    fn rust_multi_buffer(text: &str, cx: &mut TestAppContext) -> ModelHandle<MultiBuffer> {
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        cx.add_model(|cx| MultiBuffer::singleton(buffer, cx))
    }

    /// Sends `completion` through `provider` in random pieces of 1 to 10 bytes,
    /// running the executor until parked after each piece, then ends the stream.
    fn stream_in_random_chunks(
        provider: &TestCompletionProvider,
        completion: &str,
        rng: &mut StdRng,
        deterministic: &Arc<Deterministic>,
    ) {
        let mut remaining = completion;
        while !remaining.is_empty() {
            let max_len = cmp::min(remaining.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, rest) = remaining.split_at(len);
            provider.send_completion(chunk);
            remaining = rest;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();
    }

    /// Completion provider whose output is driven manually by the test.
    struct TestCompletionProvider {
        /// Sender for the stream returned by the latest `complete` call.
        last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
    }

    impl TestCompletionProvider {
        fn new() -> Self {
            let last_completion_tx = Mutex::new(None);
            Self { last_completion_tx }
        }

        /// Pushes one fragment into the active completion stream.
        fn send_completion(&self, completion: impl Into<String>) {
            let mut sender = self.last_completion_tx.lock();
            sender
                .as_mut()
                .unwrap()
                .try_send(completion.into())
                .unwrap();
        }

        /// Ends the active completion stream by dropping its sender.
        fn finish_completion(&self) {
            self.last_completion_tx.lock().take().unwrap();
        }
    }

    impl CompletionProvider for TestCompletionProvider {
        fn complete(
            &self,
            _prompt: OpenAIRequest,
        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
            let (sender, receiver) = mpsc::channel(1);
            *self.last_completion_tx.lock() = Some(sender);
            async move { Ok(receiver.map(Ok).boxed()) }.boxed()
        }
    }

    /// Minimal Rust language definition with just enough indent rules for these tests.
    fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            path_suffixes: vec!["rs".to_string()],
            ..Default::default()
        };
        Language::new(config, Some(tree_sitter_rust::language()))
            .with_indents_query(
                r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
            )
            .unwrap()
    }
}