codegen.rs

  1use crate::{
  2    stream_completion,
  3    streaming_diff::{Hunk, StreamingDiff},
  4    OpenAIRequest,
  5};
  6use anyhow::Result;
  7use editor::{
  8    multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
  9};
 10use futures::{
 11    channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
 12};
 13use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
 14use language::{Rope, TransactionId};
 15use std::{cmp, future, ops::Range, sync::Arc};
 16
 17pub trait CompletionProvider {
 18    fn complete(
 19        &self,
 20        prompt: OpenAIRequest,
 21    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 22}
 23
 24pub struct OpenAICompletionProvider {
 25    api_key: String,
 26    executor: Arc<Background>,
 27}
 28
 29impl OpenAICompletionProvider {
 30    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
 31        Self { api_key, executor }
 32    }
 33}
 34
 35impl CompletionProvider for OpenAICompletionProvider {
 36    fn complete(
 37        &self,
 38        prompt: OpenAIRequest,
 39    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
 40        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
 41        async move {
 42            let response = request.await?;
 43            let stream = response
 44                .filter_map(|response| async move {
 45                    match response {
 46                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
 47                        Err(error) => Some(Err(error)),
 48                    }
 49                })
 50                .boxed();
 51            Ok(stream)
 52        }
 53        .boxed()
 54    }
 55}
 56
 57pub enum Event {
 58    Finished,
 59    Undone,
 60}
 61
 62#[derive(Clone)]
 63pub enum CodegenKind {
 64    Transform { range: Range<Anchor> },
 65    Generate { position: Anchor },
 66}
 67
 68pub struct Codegen {
 69    provider: Arc<dyn CompletionProvider>,
 70    buffer: ModelHandle<MultiBuffer>,
 71    snapshot: MultiBufferSnapshot,
 72    kind: CodegenKind,
 73    last_equal_ranges: Vec<Range<Anchor>>,
 74    transaction_id: Option<TransactionId>,
 75    error: Option<anyhow::Error>,
 76    generation: Task<()>,
 77    idle: bool,
 78    _subscription: gpui::Subscription,
 79}
 80
 81impl Entity for Codegen {
 82    type Event = Event;
 83}
 84
 85impl Codegen {
 86    pub fn new(
 87        buffer: ModelHandle<MultiBuffer>,
 88        mut kind: CodegenKind,
 89        provider: Arc<dyn CompletionProvider>,
 90        cx: &mut ModelContext<Self>,
 91    ) -> Self {
 92        let snapshot = buffer.read(cx).snapshot(cx);
 93        match &mut kind {
 94            CodegenKind::Transform { range } => {
 95                let mut point_range = range.to_point(&snapshot);
 96                point_range.start.column = 0;
 97                if point_range.end.column > 0 || point_range.start.row == point_range.end.row {
 98                    point_range.end.column = snapshot.line_len(point_range.end.row);
 99                }
100                range.start = snapshot.anchor_before(point_range.start);
101                range.end = snapshot.anchor_after(point_range.end);
102            }
103            CodegenKind::Generate { position } => {
104                *position = position.bias_right(&snapshot);
105            }
106        }
107
108        Self {
109            provider,
110            buffer: buffer.clone(),
111            snapshot,
112            kind,
113            last_equal_ranges: Default::default(),
114            transaction_id: Default::default(),
115            error: Default::default(),
116            idle: true,
117            generation: Task::ready(()),
118            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
119        }
120    }
121
122    fn handle_buffer_event(
123        &mut self,
124        _buffer: ModelHandle<MultiBuffer>,
125        event: &multi_buffer::Event,
126        cx: &mut ModelContext<Self>,
127    ) {
128        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
129            if self.transaction_id == Some(*transaction_id) {
130                self.transaction_id = None;
131                self.generation = Task::ready(());
132                cx.emit(Event::Undone);
133            }
134        }
135    }
136
137    pub fn range(&self) -> Range<Anchor> {
138        match &self.kind {
139            CodegenKind::Transform { range } => range.clone(),
140            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
141        }
142    }
143
144    pub fn kind(&self) -> &CodegenKind {
145        &self.kind
146    }
147
148    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
149        &self.last_equal_ranges
150    }
151
152    pub fn idle(&self) -> bool {
153        self.idle
154    }
155
156    pub fn error(&self) -> Option<&anyhow::Error> {
157        self.error.as_ref()
158    }
159
160    pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
161        let range = self.range();
162        let snapshot = self.snapshot.clone();
163        let selected_text = snapshot
164            .text_for_range(range.start..range.end)
165            .collect::<Rope>();
166
167        let selection_start = range.start.to_point(&snapshot);
168        let suggested_line_indent = snapshot
169            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
170            .into_values()
171            .next()
172            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
173
174        let response = self.provider.complete(prompt);
175        self.generation = cx.spawn_weak(|this, mut cx| {
176            async move {
177                let generate = async {
178                    let mut edit_start = range.start.to_offset(&snapshot);
179
180                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
181                    let diff = cx.background().spawn(async move {
182                        let chunks = strip_markdown_codeblock(response.await?);
183                        futures::pin_mut!(chunks);
184                        let mut diff = StreamingDiff::new(selected_text.to_string());
185
186                        let mut new_text = String::new();
187                        let mut base_indent = None;
188                        let mut line_indent = None;
189                        let mut first_line = true;
190
191                        while let Some(chunk) = chunks.next().await {
192                            let chunk = chunk?;
193
194                            let mut lines = chunk.split('\n').peekable();
195                            while let Some(line) = lines.next() {
196                                new_text.push_str(line);
197                                if line_indent.is_none() {
198                                    if let Some(non_whitespace_ch_ix) =
199                                        new_text.find(|ch: char| !ch.is_whitespace())
200                                    {
201                                        line_indent = Some(non_whitespace_ch_ix);
202                                        base_indent = base_indent.or(line_indent);
203
204                                        let line_indent = line_indent.unwrap();
205                                        let base_indent = base_indent.unwrap();
206                                        let indent_delta = line_indent as i32 - base_indent as i32;
207                                        let mut corrected_indent_len = cmp::max(
208                                            0,
209                                            suggested_line_indent.len as i32 + indent_delta,
210                                        )
211                                            as usize;
212                                        if first_line {
213                                            corrected_indent_len = corrected_indent_len
214                                                .saturating_sub(selection_start.column as usize);
215                                        }
216
217                                        let indent_char = suggested_line_indent.char();
218                                        let mut indent_buffer = [0; 4];
219                                        let indent_str =
220                                            indent_char.encode_utf8(&mut indent_buffer);
221                                        new_text.replace_range(
222                                            ..line_indent,
223                                            &indent_str.repeat(corrected_indent_len),
224                                        );
225                                    }
226                                }
227
228                                if lines.peek().is_some() {
229                                    hunks_tx.send(diff.push_new(&new_text)).await?;
230                                    hunks_tx.send(diff.push_new("\n")).await?;
231                                    new_text.clear();
232                                    line_indent = None;
233                                    first_line = false;
234                                }
235                            }
236                        }
237                        hunks_tx.send(diff.push_new(&new_text)).await?;
238                        hunks_tx.send(diff.finish()).await?;
239
240                        anyhow::Ok(())
241                    });
242
243                    while let Some(hunks) = hunks_rx.next().await {
244                        let this = if let Some(this) = this.upgrade(&cx) {
245                            this
246                        } else {
247                            break;
248                        };
249
250                        this.update(&mut cx, |this, cx| {
251                            this.last_equal_ranges.clear();
252
253                            let transaction = this.buffer.update(cx, |buffer, cx| {
254                                // Avoid grouping assistant edits with user edits.
255                                buffer.finalize_last_transaction(cx);
256
257                                buffer.start_transaction(cx);
258                                buffer.edit(
259                                    hunks.into_iter().filter_map(|hunk| match hunk {
260                                        Hunk::Insert { text } => {
261                                            let edit_start = snapshot.anchor_after(edit_start);
262                                            Some((edit_start..edit_start, text))
263                                        }
264                                        Hunk::Remove { len } => {
265                                            let edit_end = edit_start + len;
266                                            let edit_range = snapshot.anchor_after(edit_start)
267                                                ..snapshot.anchor_before(edit_end);
268                                            edit_start = edit_end;
269                                            Some((edit_range, String::new()))
270                                        }
271                                        Hunk::Keep { len } => {
272                                            let edit_end = edit_start + len;
273                                            let edit_range = snapshot.anchor_after(edit_start)
274                                                ..snapshot.anchor_before(edit_end);
275                                            edit_start = edit_end;
276                                            this.last_equal_ranges.push(edit_range);
277                                            None
278                                        }
279                                    }),
280                                    None,
281                                    cx,
282                                );
283
284                                buffer.end_transaction(cx)
285                            });
286
287                            if let Some(transaction) = transaction {
288                                if let Some(first_transaction) = this.transaction_id {
289                                    // Group all assistant edits into the first transaction.
290                                    this.buffer.update(cx, |buffer, cx| {
291                                        buffer.merge_transactions(
292                                            transaction,
293                                            first_transaction,
294                                            cx,
295                                        )
296                                    });
297                                } else {
298                                    this.transaction_id = Some(transaction);
299                                    this.buffer.update(cx, |buffer, cx| {
300                                        buffer.finalize_last_transaction(cx)
301                                    });
302                                }
303                            }
304
305                            cx.notify();
306                        });
307                    }
308
309                    diff.await?;
310                    anyhow::Ok(())
311                };
312
313                let result = generate.await;
314                if let Some(this) = this.upgrade(&cx) {
315                    this.update(&mut cx, |this, cx| {
316                        this.last_equal_ranges.clear();
317                        this.idle = true;
318                        if let Err(error) = result {
319                            this.error = Some(error);
320                        }
321                        cx.emit(Event::Finished);
322                        cx.notify();
323                    });
324                }
325            }
326        });
327        self.error.take();
328        self.idle = false;
329        cx.notify();
330    }
331
332    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
333        if let Some(transaction_id) = self.transaction_id {
334            self.buffer
335                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
336        }
337    }
338}
339
340fn strip_markdown_codeblock(
341    stream: impl Stream<Item = Result<String>>,
342) -> impl Stream<Item = Result<String>> {
343    let mut first_line = true;
344    let mut buffer = String::new();
345    let mut starts_with_fenced_code_block = false;
346    stream.filter_map(move |chunk| {
347        let chunk = match chunk {
348            Ok(chunk) => chunk,
349            Err(err) => return future::ready(Some(Err(err))),
350        };
351        buffer.push_str(&chunk);
352
353        if first_line {
354            if buffer == "" || buffer == "`" || buffer == "``" {
355                return future::ready(None);
356            } else if buffer.starts_with("```") {
357                starts_with_fenced_code_block = true;
358                if let Some(newline_ix) = buffer.find('\n') {
359                    buffer.replace_range(..newline_ix + 1, "");
360                    first_line = false;
361                } else {
362                    return future::ready(None);
363                }
364            }
365        }
366
367        let text = if starts_with_fenced_code_block {
368            buffer
369                .strip_suffix("\n```\n")
370                .or_else(|| buffer.strip_suffix("\n```"))
371                .or_else(|| buffer.strip_suffix("\n``"))
372                .or_else(|| buffer.strip_suffix("\n`"))
373                .or_else(|| buffer.strip_suffix('\n'))
374                .unwrap_or(&buffer)
375        } else {
376            &buffer
377        };
378
379        if text.contains('\n') {
380            first_line = false;
381        }
382
383        let remainder = buffer.split_off(text.len());
384        let result = if buffer.is_empty() {
385            None
386        } else {
387            Some(Ok(buffer.clone()))
388        };
389        buffer = remainder;
390        future::ready(result)
391    })
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;
    use gpui::{executor::Deterministic, TestAppContext};
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use parking_lot::Mutex;
    use rand::prelude::*;
    use settings::SettingsStore;

    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Transform { range },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        // The model indents by 7 columns; the output must be re-based onto the
        // buffer's own indentation.
        let response = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        simulate_response_stream(&provider, response, &mut rng, &deterministic);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        let response = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        simulate_response_stream(&provider, response, &mut rng, &deterministic);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        let response = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        simulate_response_stream(&provider, response, &mut rng, &deterministic);

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test]
    async fn test_strip_markdown_codeblock() {
        assert_eq!(
            strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "```js\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "``\nLorem ipsum dolor\n```"
        );

        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Feeds `text` to `provider` in random chunks of 1..=10 bytes.
    ///
    /// After each chunk the executor runs until parked. The completion is then
    /// ended and the executor drained once more.
    fn simulate_response_stream(
        provider: &TestCompletionProvider,
        mut text: &str,
        rng: &mut StdRng,
        deterministic: &Deterministic,
    ) {
        while !text.is_empty() {
            let max_len = cmp::min(text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = text.split_at(len);
            provider.send_completion(chunk);
            text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();
    }

    /// A [`CompletionProvider`] whose stream is fed manually by the test.
    struct TestCompletionProvider {
        last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
    }

    impl TestCompletionProvider {
        fn new() -> Self {
            Self {
                last_completion_tx: Mutex::new(None),
            }
        }

        /// Pushes `completion` into the most recently started completion stream.
        fn send_completion(&self, completion: impl Into<String>) {
            let mut tx = self.last_completion_tx.lock();
            tx.as_mut().unwrap().try_send(completion.into()).unwrap();
        }

        /// Drops the sender, ending the most recently started completion stream.
        fn finish_completion(&self) {
            self.last_completion_tx.lock().take().unwrap();
        }
    }

    impl CompletionProvider for TestCompletionProvider {
        fn complete(
            &self,
            _prompt: OpenAIRequest,
        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
            let (tx, rx) = mpsc::channel(1);
            *self.last_completion_tx.lock() = Some(tx);
            async move { Ok(rx.map(Ok).boxed()) }.boxed()
        }
    }

    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}