codegen.rs

  1use crate::{
  2    streaming_diff::{Hunk, StreamingDiff},
  3    CompletionProvider, LanguageModelRequest,
  4};
  5use anyhow::Result;
  6use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
  7use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  8use gpui::{EventEmitter, Model, ModelContext, Task};
  9use language::{Rope, TransactionId};
 10use std::{cmp, future, ops::Range};
 11
 12pub enum Event {
 13    Finished,
 14    Undone,
 15}
 16
 17#[derive(Clone)]
 18pub enum CodegenKind {
 19    Transform { range: Range<Anchor> },
 20    Generate { position: Anchor },
 21}
 22
 23pub struct Codegen {
 24    buffer: Model<MultiBuffer>,
 25    snapshot: MultiBufferSnapshot,
 26    kind: CodegenKind,
 27    last_equal_ranges: Vec<Range<Anchor>>,
 28    transaction_id: Option<TransactionId>,
 29    error: Option<anyhow::Error>,
 30    generation: Task<()>,
 31    idle: bool,
 32    _subscription: gpui::Subscription,
 33}
 34
 35impl EventEmitter<Event> for Codegen {}
 36
 37impl Codegen {
 38    pub fn new(buffer: Model<MultiBuffer>, kind: CodegenKind, cx: &mut ModelContext<Self>) -> Self {
 39        let snapshot = buffer.read(cx).snapshot(cx);
 40        Self {
 41            buffer: buffer.clone(),
 42            snapshot,
 43            kind,
 44            last_equal_ranges: Default::default(),
 45            transaction_id: Default::default(),
 46            error: Default::default(),
 47            idle: true,
 48            generation: Task::ready(()),
 49            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 50        }
 51    }
 52
 53    fn handle_buffer_event(
 54        &mut self,
 55        _buffer: Model<MultiBuffer>,
 56        event: &multi_buffer::Event,
 57        cx: &mut ModelContext<Self>,
 58    ) {
 59        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
 60            if self.transaction_id == Some(*transaction_id) {
 61                self.transaction_id = None;
 62                self.generation = Task::ready(());
 63                cx.emit(Event::Undone);
 64            }
 65        }
 66    }
 67
 68    pub fn range(&self) -> Range<Anchor> {
 69        match &self.kind {
 70            CodegenKind::Transform { range } => range.clone(),
 71            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
 72        }
 73    }
 74
 75    pub fn kind(&self) -> &CodegenKind {
 76        &self.kind
 77    }
 78
 79    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 80        &self.last_equal_ranges
 81    }
 82
 83    pub fn idle(&self) -> bool {
 84        self.idle
 85    }
 86
 87    pub fn error(&self) -> Option<&anyhow::Error> {
 88        self.error.as_ref()
 89    }
 90
 91    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
 92        let range = self.range();
 93        let snapshot = self.snapshot.clone();
 94        let selected_text = snapshot
 95            .text_for_range(range.start..range.end)
 96            .collect::<Rope>();
 97
 98        let selection_start = range.start.to_point(&snapshot);
 99        let suggested_line_indent = snapshot
100            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
101            .into_values()
102            .next()
103            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
104
105        let response = CompletionProvider::global(cx).complete(prompt);
106        self.generation = cx.spawn(|this, mut cx| {
107            async move {
108                let generate = async {
109                    let mut edit_start = range.start.to_offset(&snapshot);
110
111                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
112                    let diff = cx.background_executor().spawn(async move {
113                        let chunks = strip_invalid_spans_from_codeblock(response.await?);
114                        futures::pin_mut!(chunks);
115                        let mut diff = StreamingDiff::new(selected_text.to_string());
116
117                        let mut new_text = String::new();
118                        let mut base_indent = None;
119                        let mut line_indent = None;
120                        let mut first_line = true;
121
122                        while let Some(chunk) = chunks.next().await {
123                            let chunk = chunk?;
124
125                            let mut lines = chunk.split('\n').peekable();
126                            while let Some(line) = lines.next() {
127                                new_text.push_str(line);
128                                if line_indent.is_none() {
129                                    if let Some(non_whitespace_ch_ix) =
130                                        new_text.find(|ch: char| !ch.is_whitespace())
131                                    {
132                                        line_indent = Some(non_whitespace_ch_ix);
133                                        base_indent = base_indent.or(line_indent);
134
135                                        let line_indent = line_indent.unwrap();
136                                        let base_indent = base_indent.unwrap();
137                                        let indent_delta = line_indent as i32 - base_indent as i32;
138                                        let mut corrected_indent_len = cmp::max(
139                                            0,
140                                            suggested_line_indent.len as i32 + indent_delta,
141                                        )
142                                            as usize;
143                                        if first_line {
144                                            corrected_indent_len = corrected_indent_len
145                                                .saturating_sub(selection_start.column as usize);
146                                        }
147
148                                        let indent_char = suggested_line_indent.char();
149                                        let mut indent_buffer = [0; 4];
150                                        let indent_str =
151                                            indent_char.encode_utf8(&mut indent_buffer);
152                                        new_text.replace_range(
153                                            ..line_indent,
154                                            &indent_str.repeat(corrected_indent_len),
155                                        );
156                                    }
157                                }
158
159                                if line_indent.is_some() {
160                                    hunks_tx.send(diff.push_new(&new_text)).await?;
161                                    new_text.clear();
162                                }
163
164                                if lines.peek().is_some() {
165                                    hunks_tx.send(diff.push_new("\n")).await?;
166                                    line_indent = None;
167                                    first_line = false;
168                                }
169                            }
170                        }
171                        hunks_tx.send(diff.push_new(&new_text)).await?;
172                        hunks_tx.send(diff.finish()).await?;
173
174                        anyhow::Ok(())
175                    });
176
177                    while let Some(hunks) = hunks_rx.next().await {
178                        this.update(&mut cx, |this, cx| {
179                            this.last_equal_ranges.clear();
180
181                            let transaction = this.buffer.update(cx, |buffer, cx| {
182                                // Avoid grouping assistant edits with user edits.
183                                buffer.finalize_last_transaction(cx);
184
185                                buffer.start_transaction(cx);
186                                buffer.edit(
187                                    hunks.into_iter().filter_map(|hunk| match hunk {
188                                        Hunk::Insert { text } => {
189                                            let edit_start = snapshot.anchor_after(edit_start);
190                                            Some((edit_start..edit_start, text))
191                                        }
192                                        Hunk::Remove { len } => {
193                                            let edit_end = edit_start + len;
194                                            let edit_range = snapshot.anchor_after(edit_start)
195                                                ..snapshot.anchor_before(edit_end);
196                                            edit_start = edit_end;
197                                            Some((edit_range, String::new()))
198                                        }
199                                        Hunk::Keep { len } => {
200                                            let edit_end = edit_start + len;
201                                            let edit_range = snapshot.anchor_after(edit_start)
202                                                ..snapshot.anchor_before(edit_end);
203                                            edit_start = edit_end;
204                                            this.last_equal_ranges.push(edit_range);
205                                            None
206                                        }
207                                    }),
208                                    None,
209                                    cx,
210                                );
211
212                                buffer.end_transaction(cx)
213                            });
214
215                            if let Some(transaction) = transaction {
216                                if let Some(first_transaction) = this.transaction_id {
217                                    // Group all assistant edits into the first transaction.
218                                    this.buffer.update(cx, |buffer, cx| {
219                                        buffer.merge_transactions(
220                                            transaction,
221                                            first_transaction,
222                                            cx,
223                                        )
224                                    });
225                                } else {
226                                    this.transaction_id = Some(transaction);
227                                    this.buffer.update(cx, |buffer, cx| {
228                                        buffer.finalize_last_transaction(cx)
229                                    });
230                                }
231                            }
232
233                            cx.notify();
234                        })?;
235                    }
236
237                    diff.await?;
238                    anyhow::Ok(())
239                };
240
241                let result = generate.await;
242                this.update(&mut cx, |this, cx| {
243                    this.last_equal_ranges.clear();
244                    this.idle = true;
245                    if let Err(error) = result {
246                        this.error = Some(error);
247                    }
248                    cx.emit(Event::Finished);
249                    cx.notify();
250                })
251                .ok();
252            }
253        });
254        self.error.take();
255        self.idle = false;
256        cx.notify();
257    }
258
259    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
260        if let Some(transaction_id) = self.transaction_id {
261            self.buffer
262                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
263        }
264    }
265}
266
267fn strip_invalid_spans_from_codeblock(
268    stream: impl Stream<Item = Result<String>>,
269) -> impl Stream<Item = Result<String>> {
270    let mut first_line = true;
271    let mut buffer = String::new();
272    let mut starts_with_markdown_codeblock = false;
273    let mut includes_start_or_end_span = false;
274    stream.filter_map(move |chunk| {
275        let chunk = match chunk {
276            Ok(chunk) => chunk,
277            Err(err) => return future::ready(Some(Err(err))),
278        };
279        buffer.push_str(&chunk);
280
281        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
282            includes_start_or_end_span = true;
283
284            buffer = buffer
285                .strip_prefix("<|S|>")
286                .or_else(|| buffer.strip_prefix("<|S|"))
287                .unwrap_or(&buffer)
288                .to_string();
289        } else if buffer.ends_with("|E|>") {
290            includes_start_or_end_span = true;
291        } else if buffer.starts_with("<|")
292            || buffer.starts_with("<|S")
293            || buffer.starts_with("<|S|")
294            || buffer.ends_with('|')
295            || buffer.ends_with("|E")
296            || buffer.ends_with("|E|")
297        {
298            return future::ready(None);
299        }
300
301        if first_line {
302            if buffer.is_empty() || buffer == "`" || buffer == "``" {
303                return future::ready(None);
304            } else if buffer.starts_with("```") {
305                starts_with_markdown_codeblock = true;
306                if let Some(newline_ix) = buffer.find('\n') {
307                    buffer.replace_range(..newline_ix + 1, "");
308                    first_line = false;
309                } else {
310                    return future::ready(None);
311                }
312            }
313        }
314
315        let mut text = buffer.to_string();
316        if starts_with_markdown_codeblock {
317            text = text
318                .strip_suffix("\n```\n")
319                .or_else(|| text.strip_suffix("\n```"))
320                .or_else(|| text.strip_suffix("\n``"))
321                .or_else(|| text.strip_suffix("\n`"))
322                .or_else(|| text.strip_suffix('\n'))
323                .unwrap_or(&text)
324                .to_string();
325        }
326
327        if includes_start_or_end_span {
328            text = text
329                .strip_suffix("|E|>")
330                .or_else(|| text.strip_suffix("E|>"))
331                .or_else(|| text.strip_prefix("|>"))
332                .or_else(|| text.strip_prefix('>'))
333                .unwrap_or(&text)
334                .to_string();
335        };
336
337        if text.contains('\n') {
338            first_line = false;
339        }
340
341        let remainder = buffer.split_off(text.len());
342        let result = if buffer.is_empty() {
343            None
344        } else {
345            Some(Ok(buffer.clone()))
346        };
347
348        buffer = remainder;
349        future::ready(result)
350    })
351}
352
353#[cfg(test)]
354mod tests {
355    use std::sync::Arc;
356
357    use crate::FakeCompletionProvider;
358
359    use super::*;
360    use futures::stream::{self};
361    use gpui::{Context, TestAppContext};
362    use indoc::indoc;
363    use language::{
364        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
365        Point,
366    };
367    use rand::prelude::*;
368    use serde::Serialize;
369    use settings::SettingsStore;
370
371    #[derive(Serialize)]
372    pub struct DummyCompletionRequest {
373        pub name: String,
374    }
375
376    #[gpui::test(iterations = 10)]
377    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
378        let provider = FakeCompletionProvider::default();
379        cx.set_global(cx.update(SettingsStore::test));
380        cx.set_global(CompletionProvider::Fake(provider.clone()));
381        cx.update(language_settings::init);
382
383        let text = indoc! {"
384            fn main() {
385                let x = 0;
386                for _ in 0..10 {
387                    x += 1;
388                }
389            }
390        "};
391        let buffer =
392            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
393        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
394        let range = buffer.read_with(cx, |buffer, cx| {
395            let snapshot = buffer.snapshot(cx);
396            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
397        });
398        let codegen =
399            cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Transform { range }, cx));
400
401        let request = LanguageModelRequest::default();
402        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
403
404        let mut new_text = concat!(
405            "       let mut x = 0;\n",
406            "       while x < 10 {\n",
407            "           x += 1;\n",
408            "       }",
409        );
410        while !new_text.is_empty() {
411            let max_len = cmp::min(new_text.len(), 10);
412            let len = rng.gen_range(1..=max_len);
413            let (chunk, suffix) = new_text.split_at(len);
414            provider.send_completion(chunk.into());
415            new_text = suffix;
416            cx.background_executor.run_until_parked();
417        }
418        provider.finish_completion();
419        cx.background_executor.run_until_parked();
420
421        assert_eq!(
422            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
423            indoc! {"
424                fn main() {
425                    let mut x = 0;
426                    while x < 10 {
427                        x += 1;
428                    }
429                }
430            "}
431        );
432    }
433
434    #[gpui::test(iterations = 10)]
435    async fn test_autoindent_when_generating_past_indentation(
436        cx: &mut TestAppContext,
437        mut rng: StdRng,
438    ) {
439        let provider = FakeCompletionProvider::default();
440        cx.set_global(CompletionProvider::Fake(provider.clone()));
441        cx.set_global(cx.update(SettingsStore::test));
442        cx.update(language_settings::init);
443
444        let text = indoc! {"
445            fn main() {
446                le
447            }
448        "};
449        let buffer =
450            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
451        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
452        let position = buffer.read_with(cx, |buffer, cx| {
453            let snapshot = buffer.snapshot(cx);
454            snapshot.anchor_before(Point::new(1, 6))
455        });
456        let codegen =
457            cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
458
459        let request = LanguageModelRequest::default();
460        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
461
462        let mut new_text = concat!(
463            "t mut x = 0;\n",
464            "while x < 10 {\n",
465            "    x += 1;\n",
466            "}", //
467        );
468        while !new_text.is_empty() {
469            let max_len = cmp::min(new_text.len(), 10);
470            let len = rng.gen_range(1..=max_len);
471            let (chunk, suffix) = new_text.split_at(len);
472            provider.send_completion(chunk.into());
473            new_text = suffix;
474            cx.background_executor.run_until_parked();
475        }
476        provider.finish_completion();
477        cx.background_executor.run_until_parked();
478
479        assert_eq!(
480            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
481            indoc! {"
482                fn main() {
483                    let mut x = 0;
484                    while x < 10 {
485                        x += 1;
486                    }
487                }
488            "}
489        );
490    }
491
492    #[gpui::test(iterations = 10)]
493    async fn test_autoindent_when_generating_before_indentation(
494        cx: &mut TestAppContext,
495        mut rng: StdRng,
496    ) {
497        let provider = FakeCompletionProvider::default();
498        cx.set_global(CompletionProvider::Fake(provider.clone()));
499        cx.set_global(cx.update(SettingsStore::test));
500        cx.update(language_settings::init);
501
502        let text = concat!(
503            "fn main() {\n",
504            "  \n",
505            "}\n" //
506        );
507        let buffer =
508            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
509        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
510        let position = buffer.read_with(cx, |buffer, cx| {
511            let snapshot = buffer.snapshot(cx);
512            snapshot.anchor_before(Point::new(1, 2))
513        });
514        let codegen =
515            cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
516
517        let request = LanguageModelRequest::default();
518        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
519
520        let mut new_text = concat!(
521            "let mut x = 0;\n",
522            "while x < 10 {\n",
523            "    x += 1;\n",
524            "}", //
525        );
526        while !new_text.is_empty() {
527            let max_len = cmp::min(new_text.len(), 10);
528            let len = rng.gen_range(1..=max_len);
529            let (chunk, suffix) = new_text.split_at(len);
530            provider.send_completion(chunk.into());
531            new_text = suffix;
532            cx.background_executor.run_until_parked();
533        }
534        provider.finish_completion();
535        cx.background_executor.run_until_parked();
536
537        assert_eq!(
538            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
539            indoc! {"
540                fn main() {
541                    let mut x = 0;
542                    while x < 10 {
543                        x += 1;
544                    }
545                }
546            "}
547        );
548    }
549
550    #[gpui::test]
551    async fn test_strip_invalid_spans_from_codeblock() {
552        assert_eq!(
553            strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
554                .map(|chunk| chunk.unwrap())
555                .collect::<String>()
556                .await,
557            "Lorem ipsum dolor"
558        );
559        assert_eq!(
560            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
561                .map(|chunk| chunk.unwrap())
562                .collect::<String>()
563                .await,
564            "Lorem ipsum dolor"
565        );
566        assert_eq!(
567            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
568                .map(|chunk| chunk.unwrap())
569                .collect::<String>()
570                .await,
571            "Lorem ipsum dolor"
572        );
573        assert_eq!(
574            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
575                .map(|chunk| chunk.unwrap())
576                .collect::<String>()
577                .await,
578            "Lorem ipsum dolor"
579        );
580        assert_eq!(
581            strip_invalid_spans_from_codeblock(chunks(
582                "```html\n```js\nLorem ipsum dolor\n```\n```",
583                2
584            ))
585            .map(|chunk| chunk.unwrap())
586            .collect::<String>()
587            .await,
588            "```js\nLorem ipsum dolor\n```"
589        );
590        assert_eq!(
591            strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
592                .map(|chunk| chunk.unwrap())
593                .collect::<String>()
594                .await,
595            "``\nLorem ipsum dolor\n```"
596        );
597        assert_eq!(
598            strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
599                .map(|chunk| chunk.unwrap())
600                .collect::<String>()
601                .await,
602            "Lorem ipsum"
603        );
604
605        assert_eq!(
606            strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
607                .map(|chunk| chunk.unwrap())
608                .collect::<String>()
609                .await,
610            "Lorem ipsum"
611        );
612
613        assert_eq!(
614            strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
615                .map(|chunk| chunk.unwrap())
616                .collect::<String>()
617                .await,
618            "Lorem ipsum"
619        );
620        assert_eq!(
621            strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
622                .map(|chunk| chunk.unwrap())
623                .collect::<String>()
624                .await,
625            "Lorem ipsum"
626        );
627        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
628            stream::iter(
629                text.chars()
630                    .collect::<Vec<_>>()
631                    .chunks(size)
632                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
633                    .collect::<Vec<_>>(),
634            )
635        }
636    }
637
638    fn rust_lang() -> Language {
639        Language::new(
640            LanguageConfig {
641                name: "Rust".into(),
642                matcher: LanguageMatcher {
643                    path_suffixes: vec!["rs".to_string()],
644                    ..Default::default()
645                },
646                ..Default::default()
647            },
648            Some(tree_sitter_rust::language()),
649        )
650        .with_indents_query(
651            r#"
652            (call_expression) @indent
653            (field_expression) @indent
654            (_ "(" ")" @end) @indent
655            (_ "{" "}" @end) @indent
656            "#,
657        )
658        .unwrap()
659    }
660}