codegen.rs

  1use crate::{
  2    streaming_diff::{Hunk, StreamingDiff},
  3    CompletionProvider, LanguageModelRequest,
  4};
  5use anyhow::Result;
  6use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
  7use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  8use gpui::{EventEmitter, Model, ModelContext, Task};
  9use language::{Rope, TransactionId};
 10use std::{cmp, future, ops::Range};
 11
/// Notifications emitted by [`Codegen`].
pub enum Event {
    /// A generation started with `Codegen::start` ran to completion, either
    /// successfully or with an error (reported by `Codegen::error`).
    Finished,
    /// The transaction holding the assistant's edits was undone.
    Undone,
}
 16
/// What a [`Codegen`] does with the model's output.
#[derive(Clone)]
pub enum CodegenKind {
    /// Rewrite the text in `range`; the model output is diffed against that text.
    Transform { range: Range<Anchor> },
    /// Insert generated text at `position`; the output is diffed against an
    /// empty range there.
    Generate { position: Anchor },
}
 22
/// Streams a language-model completion into a buffer as a live diff.
///
/// The model output is diffed against the text of `Codegen::range` and applied
/// batch by batch; all batches are merged into the first assistant transaction
/// so they undo together.
pub struct Codegen {
    /// The buffer that receives the generated edits.
    buffer: Model<MultiBuffer>,
    /// Snapshot of `buffer` captured in `Codegen::new`; the diff's offsets and
    /// anchors are computed against it.
    snapshot: MultiBufferSnapshot,
    /// Whether this codegen rewrites a range or inserts at a position.
    kind: CodegenKind,
    /// Ranges from the most recent batch of hunks that the model left unchanged
    /// (`Hunk::Keep`); cleared when the next batch arrives and when generation ends.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// The first transaction created by the assistant; later batches are merged
    /// into it. Reset to `None` when that transaction is undone.
    transaction_id: Option<TransactionId>,
    /// The error that ended the most recent generation, if any.
    error: Option<anyhow::Error>,
    /// The most recent generation task; assigning a new task drops the old one.
    generation: Task<()>,
    /// `false` while a generation is running.
    idle: bool,
    /// Keeps the subscription to `buffer` events alive.
    _subscription: gpui::Subscription,
}
 34
/// Lets subscribers observe the [`Event`]s that [`Codegen`] emits via `cx.emit`.
impl EventEmitter<Event> for Codegen {}
 36
 37impl Codegen {
 38    pub fn new(buffer: Model<MultiBuffer>, kind: CodegenKind, cx: &mut ModelContext<Self>) -> Self {
 39        let snapshot = buffer.read(cx).snapshot(cx);
 40        Self {
 41            buffer: buffer.clone(),
 42            snapshot,
 43            kind,
 44            last_equal_ranges: Default::default(),
 45            transaction_id: Default::default(),
 46            error: Default::default(),
 47            idle: true,
 48            generation: Task::ready(()),
 49            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 50        }
 51    }
 52
 53    fn handle_buffer_event(
 54        &mut self,
 55        _buffer: Model<MultiBuffer>,
 56        event: &multi_buffer::Event,
 57        cx: &mut ModelContext<Self>,
 58    ) {
 59        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
 60            if self.transaction_id == Some(*transaction_id) {
 61                self.transaction_id = None;
 62                self.generation = Task::ready(());
 63                cx.emit(Event::Undone);
 64            }
 65        }
 66    }
 67
 68    pub fn range(&self) -> Range<Anchor> {
 69        match &self.kind {
 70            CodegenKind::Transform { range } => range.clone(),
 71            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
 72        }
 73    }
 74
 75    pub fn kind(&self) -> &CodegenKind {
 76        &self.kind
 77    }
 78
 79    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 80        &self.last_equal_ranges
 81    }
 82
 83    pub fn idle(&self) -> bool {
 84        self.idle
 85    }
 86
 87    pub fn error(&self) -> Option<&anyhow::Error> {
 88        self.error.as_ref()
 89    }
 90
 91    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
 92        let range = self.range();
 93        let snapshot = self.snapshot.clone();
 94        let selected_text = snapshot
 95            .text_for_range(range.start..range.end)
 96            .collect::<Rope>();
 97
 98        let selection_start = range.start.to_point(&snapshot);
 99        let suggested_line_indent = snapshot
100            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
101            .into_values()
102            .next()
103            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
104
105        let response = CompletionProvider::global(cx).complete(prompt);
106        self.generation = cx.spawn(|this, mut cx| {
107            async move {
108                let generate = async {
109                    let mut edit_start = range.start.to_offset(&snapshot);
110
111                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
112                    let diff = cx.background_executor().spawn(async move {
113                        let chunks = strip_invalid_spans_from_codeblock(response.await?);
114                        futures::pin_mut!(chunks);
115                        let mut diff = StreamingDiff::new(selected_text.to_string());
116
117                        let mut new_text = String::new();
118                        let mut base_indent = None;
119                        let mut line_indent = None;
120                        let mut first_line = true;
121
122                        while let Some(chunk) = chunks.next().await {
123                            let chunk = chunk?;
124
125                            let mut lines = chunk.split('\n').peekable();
126                            while let Some(line) = lines.next() {
127                                new_text.push_str(line);
128                                if line_indent.is_none() {
129                                    if let Some(non_whitespace_ch_ix) =
130                                        new_text.find(|ch: char| !ch.is_whitespace())
131                                    {
132                                        line_indent = Some(non_whitespace_ch_ix);
133                                        base_indent = base_indent.or(line_indent);
134
135                                        let line_indent = line_indent.unwrap();
136                                        let base_indent = base_indent.unwrap();
137                                        let indent_delta = line_indent as i32 - base_indent as i32;
138                                        let mut corrected_indent_len = cmp::max(
139                                            0,
140                                            suggested_line_indent.len as i32 + indent_delta,
141                                        )
142                                            as usize;
143                                        if first_line {
144                                            corrected_indent_len = corrected_indent_len
145                                                .saturating_sub(selection_start.column as usize);
146                                        }
147
148                                        let indent_char = suggested_line_indent.char();
149                                        let mut indent_buffer = [0; 4];
150                                        let indent_str =
151                                            indent_char.encode_utf8(&mut indent_buffer);
152                                        new_text.replace_range(
153                                            ..line_indent,
154                                            &indent_str.repeat(corrected_indent_len),
155                                        );
156                                    }
157                                }
158
159                                if line_indent.is_some() {
160                                    hunks_tx.send(diff.push_new(&new_text)).await?;
161                                    new_text.clear();
162                                }
163
164                                if lines.peek().is_some() {
165                                    hunks_tx.send(diff.push_new("\n")).await?;
166                                    line_indent = None;
167                                    first_line = false;
168                                }
169                            }
170                        }
171                        hunks_tx.send(diff.push_new(&new_text)).await?;
172                        hunks_tx.send(diff.finish()).await?;
173
174                        anyhow::Ok(())
175                    });
176
177                    while let Some(hunks) = hunks_rx.next().await {
178                        this.update(&mut cx, |this, cx| {
179                            this.last_equal_ranges.clear();
180
181                            let transaction = this.buffer.update(cx, |buffer, cx| {
182                                // Avoid grouping assistant edits with user edits.
183                                buffer.finalize_last_transaction(cx);
184
185                                buffer.start_transaction(cx);
186                                buffer.edit(
187                                    hunks.into_iter().filter_map(|hunk| match hunk {
188                                        Hunk::Insert { text } => {
189                                            let edit_start = snapshot.anchor_after(edit_start);
190                                            Some((edit_start..edit_start, text))
191                                        }
192                                        Hunk::Remove { len } => {
193                                            let edit_end = edit_start + len;
194                                            let edit_range = snapshot.anchor_after(edit_start)
195                                                ..snapshot.anchor_before(edit_end);
196                                            edit_start = edit_end;
197                                            Some((edit_range, String::new()))
198                                        }
199                                        Hunk::Keep { len } => {
200                                            let edit_end = edit_start + len;
201                                            let edit_range = snapshot.anchor_after(edit_start)
202                                                ..snapshot.anchor_before(edit_end);
203                                            edit_start = edit_end;
204                                            this.last_equal_ranges.push(edit_range);
205                                            None
206                                        }
207                                    }),
208                                    None,
209                                    cx,
210                                );
211
212                                buffer.end_transaction(cx)
213                            });
214
215                            if let Some(transaction) = transaction {
216                                if let Some(first_transaction) = this.transaction_id {
217                                    // Group all assistant edits into the first transaction.
218                                    this.buffer.update(cx, |buffer, cx| {
219                                        buffer.merge_transactions(
220                                            transaction,
221                                            first_transaction,
222                                            cx,
223                                        )
224                                    });
225                                } else {
226                                    this.transaction_id = Some(transaction);
227                                    this.buffer.update(cx, |buffer, cx| {
228                                        buffer.finalize_last_transaction(cx)
229                                    });
230                                }
231                            }
232
233                            cx.notify();
234                        })?;
235                    }
236
237                    diff.await?;
238                    anyhow::Ok(())
239                };
240
241                let result = generate.await;
242                this.update(&mut cx, |this, cx| {
243                    this.last_equal_ranges.clear();
244                    this.idle = true;
245                    if let Err(error) = result {
246                        this.error = Some(error);
247                    }
248                    cx.emit(Event::Finished);
249                    cx.notify();
250                })
251                .ok();
252            }
253        });
254        self.error.take();
255        self.idle = false;
256        cx.notify();
257    }
258
259    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
260        if let Some(transaction_id) = self.transaction_id {
261            self.buffer
262                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
263        }
264    }
265}
266
267fn strip_invalid_spans_from_codeblock(
268    stream: impl Stream<Item = Result<String>>,
269) -> impl Stream<Item = Result<String>> {
270    let mut first_line = true;
271    let mut buffer = String::new();
272    let mut starts_with_markdown_codeblock = false;
273    let mut includes_start_or_end_span = false;
274    stream.filter_map(move |chunk| {
275        let chunk = match chunk {
276            Ok(chunk) => chunk,
277            Err(err) => return future::ready(Some(Err(err))),
278        };
279        buffer.push_str(&chunk);
280
281        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
282            includes_start_or_end_span = true;
283
284            buffer = buffer
285                .strip_prefix("<|S|>")
286                .or_else(|| buffer.strip_prefix("<|S|"))
287                .unwrap_or(&buffer)
288                .to_string();
289        } else if buffer.ends_with("|E|>") {
290            includes_start_or_end_span = true;
291        } else if buffer.starts_with("<|")
292            || buffer.starts_with("<|S")
293            || buffer.starts_with("<|S|")
294            || buffer.ends_with('|')
295            || buffer.ends_with("|E")
296            || buffer.ends_with("|E|")
297        {
298            return future::ready(None);
299        }
300
301        if first_line {
302            if buffer.is_empty() || buffer == "`" || buffer == "``" {
303                return future::ready(None);
304            } else if buffer.starts_with("```") {
305                starts_with_markdown_codeblock = true;
306                if let Some(newline_ix) = buffer.find('\n') {
307                    buffer.replace_range(..newline_ix + 1, "");
308                    first_line = false;
309                } else {
310                    return future::ready(None);
311                }
312            }
313        }
314
315        let mut text = buffer.to_string();
316        if starts_with_markdown_codeblock {
317            text = text
318                .strip_suffix("\n```\n")
319                .or_else(|| text.strip_suffix("\n```"))
320                .or_else(|| text.strip_suffix("\n``"))
321                .or_else(|| text.strip_suffix("\n`"))
322                .or_else(|| text.strip_suffix('\n'))
323                .unwrap_or(&text)
324                .to_string();
325        }
326
327        if includes_start_or_end_span {
328            text = text
329                .strip_suffix("|E|>")
330                .or_else(|| text.strip_suffix("E|>"))
331                .or_else(|| text.strip_prefix("|>"))
332                .or_else(|| text.strip_prefix('>'))
333                .unwrap_or(&text)
334                .to_string();
335        };
336
337        if text.contains('\n') {
338            first_line = false;
339        }
340
341        let remainder = buffer.split_off(text.len());
342        let result = if buffer.is_empty() {
343            None
344        } else {
345            Some(Ok(buffer.clone()))
346        };
347
348        buffer = remainder;
349        future::ready(result)
350    })
351}
352
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::FakeCompletionProvider;

    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, BufferId, Language, LanguageConfig,
        LanguageMatcher, Point,
    };
    use rand::prelude::*;
    use settings::SettingsStore;

    /// Transforming an indented range re-indents the model's output relative
    /// to the indentation suggested for the range's first line.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(cx.update(SettingsStore::test));
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.new_model(|cx| {
            Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
        });
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Transform { range }, cx));

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        simulate_response(
            &provider,
            concat!(
                "       let mut x = 0;\n",
                "       while x < 10 {\n",
                "           x += 1;\n",
                "       }",
            ),
            &mut rng,
            cx,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating from a cursor placed after the line's indentation must not
    /// add indentation to the (already indented) first line.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = cx.new_model(|cx| {
            Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
        });
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        simulate_response(
            &provider,
            concat!(
                "t mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &mut rng,
            cx,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generating from a cursor inside under-indented whitespace tops the first
    /// line up to the suggested indentation.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = cx.new_model(|cx| {
            Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
        });
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        simulate_response(
            &provider,
            concat!(
                "let mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &mut rng,
            cx,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_eq!(
            collect_stripped("Lorem ipsum dolor", 2).await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            collect_stripped("```\nLorem ipsum dolor", 2).await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            collect_stripped("```\nLorem ipsum dolor\n```", 2).await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            collect_stripped("```\nLorem ipsum dolor\n```\n", 2).await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            collect_stripped("```html\n```js\nLorem ipsum dolor\n```\n```", 2).await,
            "```js\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            collect_stripped("``\nLorem ipsum dolor\n```", 2).await,
            "``\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            collect_stripped("<|S|Lorem ipsum|E|>", 2).await,
            "Lorem ipsum"
        );
        assert_eq!(collect_stripped("<|S|>Lorem ipsum", 2).await, "Lorem ipsum");
        assert_eq!(
            collect_stripped("```\n<|S|>Lorem ipsum\n```", 2).await,
            "Lorem ipsum"
        );
        assert_eq!(
            collect_stripped("```\n<|S|Lorem ipsum|E|>\n```", 2).await,
            "Lorem ipsum"
        );
    }

    /// Feeds `text` to `provider` in randomly sized chunks of 1..=10 bytes,
    /// letting the executor settle after each one, then finishes the completion.
    fn simulate_response(
        provider: &FakeCompletionProvider,
        text: &str,
        rng: &mut StdRng,
        cx: &mut TestAppContext,
    ) {
        let mut remaining = text;
        while !remaining.is_empty() {
            let max_len = cmp::min(remaining.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = remaining.split_at(len);
            provider.send_completion(chunk.into());
            remaining = suffix;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();
    }

    /// Runs `text`, split into `size`-character chunks, through
    /// `strip_invalid_spans_from_codeblock` and concatenates the output.
    async fn collect_stripped(text: &str, size: usize) -> String {
        strip_invalid_spans_from_codeblock(chunks(text, size))
            .map(|chunk| chunk.unwrap())
            .collect::<String>()
            .await
    }

    /// Splits `text` into a stream of `size`-character chunks.
    fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
        stream::iter(
            text.chars()
                .collect::<Vec<_>>()
                .chunks(size)
                .map(|chunk| Ok(chunk.iter().collect::<String>()))
                .collect::<Vec<_>>(),
        )
    }

    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}