codegen.rs

  1use crate::{
  2    streaming_diff::{Hunk, StreamingDiff},
  3    CompletionProvider, LanguageModelRequest,
  4};
  5use anyhow::Result;
  6use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
  7use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  8use gpui::{EventEmitter, Model, ModelContext, Task};
  9use language::{Rope, TransactionId};
 10use multi_buffer::MultiBufferRow;
 11use std::{cmp, future, ops::Range};
 12
/// Events emitted by a [`Codegen`] through its `EventEmitter` implementation.
pub enum Event {
    /// Emitted when the task spawned by `Codegen::start` ends, whether it succeeded or stored
    /// an error.
    Finished,
    /// Emitted when the buffer reports that the transaction tracked by the `Codegen` was undone.
    Undone,
}
 17
/// Describes the part of the buffer a [`Codegen`] works on.
#[derive(Clone)]
pub enum CodegenKind {
    /// Work on an existing range of text; `Codegen::range` returns a clone of `range`.
    Transform { range: Range<Anchor> },
    /// Work at a single position; `Codegen::range` returns a range that ends at `position` and
    /// starts at its left-biased counterpart.
    Generate { position: Anchor },
}
 23
/// Streams a language-model completion into a multi-buffer as a series of edits.
pub struct Codegen {
    /// The multi-buffer that receives the generated edits.
    buffer: Model<MultiBuffer>,
    /// Snapshot of `buffer` captured in `new`; ranges, offsets and anchors used during
    /// generation are resolved against it.
    snapshot: MultiBufferSnapshot,
    /// Which part of the buffer is being transformed or generated into.
    kind: CodegenKind,
    /// Ranges kept unchanged by the most recently applied batch of diff hunks. Cleared before
    /// each batch is applied and again when generation finishes.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// The first transaction created by generation; later transactions are merged into it.
    /// Reset to `None` when that transaction is undone.
    transaction_id: Option<TransactionId>,
    /// Error produced by the most recent generation, if any. Cleared by `start`.
    error: Option<anyhow::Error>,
    /// The task driving the current generation. Replaced by an already-completed task when the
    /// tracked transaction is undone.
    generation: Task<()>,
    /// `false` from the call to `start` until the generation task finishes.
    idle: bool,
    /// Subscription to `buffer`'s events, routed to `handle_buffer_event`. Stored so it lives as
    /// long as the `Codegen` (presumably dropping it would unsubscribe — confirm).
    _subscription: gpui::Subscription,
}
 35
// Allows `Codegen` to emit `Event` values via `cx.emit`.
impl EventEmitter<Event> for Codegen {}
 37
 38impl Codegen {
 39    pub fn new(buffer: Model<MultiBuffer>, kind: CodegenKind, cx: &mut ModelContext<Self>) -> Self {
 40        let snapshot = buffer.read(cx).snapshot(cx);
 41        Self {
 42            buffer: buffer.clone(),
 43            snapshot,
 44            kind,
 45            last_equal_ranges: Default::default(),
 46            transaction_id: Default::default(),
 47            error: Default::default(),
 48            idle: true,
 49            generation: Task::ready(()),
 50            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 51        }
 52    }
 53
 54    fn handle_buffer_event(
 55        &mut self,
 56        _buffer: Model<MultiBuffer>,
 57        event: &multi_buffer::Event,
 58        cx: &mut ModelContext<Self>,
 59    ) {
 60        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
 61            if self.transaction_id == Some(*transaction_id) {
 62                self.transaction_id = None;
 63                self.generation = Task::ready(());
 64                cx.emit(Event::Undone);
 65            }
 66        }
 67    }
 68
 69    pub fn range(&self) -> Range<Anchor> {
 70        match &self.kind {
 71            CodegenKind::Transform { range } => range.clone(),
 72            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
 73        }
 74    }
 75
    /// Returns the kind this codegen was created with.
    pub fn kind(&self) -> &CodegenKind {
        &self.kind
    }
 79
    /// Returns the ranges that the most recently applied batch of diff hunks left unchanged.
    ///
    /// The list is emptied once generation has finished.
    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
        &self.last_equal_ranges
    }
 83
    /// Returns `false` from the moment `start` is called until its task finishes; `true` before
    /// the first `start` and after a generation has finished.
    pub fn idle(&self) -> bool {
        self.idle
    }
 87
    /// Returns the error that ended the most recent generation, if there was one.
    ///
    /// The stored error is cleared each time `start` is called.
    pub fn error(&self) -> Option<&anyhow::Error> {
        self.error.as_ref()
    }
 91
 92    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
 93        let range = self.range();
 94        let snapshot = self.snapshot.clone();
 95        let selected_text = snapshot
 96            .text_for_range(range.start..range.end)
 97            .collect::<Rope>();
 98
 99        let selection_start = range.start.to_point(&snapshot);
100        let suggested_line_indent = snapshot
101            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
102            .into_values()
103            .next()
104            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
105
106        let response = CompletionProvider::global(cx).complete(prompt);
107        self.generation = cx.spawn(|this, mut cx| {
108            async move {
109                let generate = async {
110                    let mut edit_start = range.start.to_offset(&snapshot);
111
112                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
113                    let diff = cx.background_executor().spawn(async move {
114                        let chunks = strip_invalid_spans_from_codeblock(response.await?);
115                        futures::pin_mut!(chunks);
116                        let mut diff = StreamingDiff::new(selected_text.to_string());
117
118                        let mut new_text = String::new();
119                        let mut base_indent = None;
120                        let mut line_indent = None;
121                        let mut first_line = true;
122
123                        while let Some(chunk) = chunks.next().await {
124                            let chunk = chunk?;
125
126                            let mut lines = chunk.split('\n').peekable();
127                            while let Some(line) = lines.next() {
128                                new_text.push_str(line);
129                                if line_indent.is_none() {
130                                    if let Some(non_whitespace_ch_ix) =
131                                        new_text.find(|ch: char| !ch.is_whitespace())
132                                    {
133                                        line_indent = Some(non_whitespace_ch_ix);
134                                        base_indent = base_indent.or(line_indent);
135
136                                        let line_indent = line_indent.unwrap();
137                                        let base_indent = base_indent.unwrap();
138                                        let indent_delta = line_indent as i32 - base_indent as i32;
139                                        let mut corrected_indent_len = cmp::max(
140                                            0,
141                                            suggested_line_indent.len as i32 + indent_delta,
142                                        )
143                                            as usize;
144                                        if first_line {
145                                            corrected_indent_len = corrected_indent_len
146                                                .saturating_sub(selection_start.column as usize);
147                                        }
148
149                                        let indent_char = suggested_line_indent.char();
150                                        let mut indent_buffer = [0; 4];
151                                        let indent_str =
152                                            indent_char.encode_utf8(&mut indent_buffer);
153                                        new_text.replace_range(
154                                            ..line_indent,
155                                            &indent_str.repeat(corrected_indent_len),
156                                        );
157                                    }
158                                }
159
160                                if line_indent.is_some() {
161                                    hunks_tx.send(diff.push_new(&new_text)).await?;
162                                    new_text.clear();
163                                }
164
165                                if lines.peek().is_some() {
166                                    hunks_tx.send(diff.push_new("\n")).await?;
167                                    line_indent = None;
168                                    first_line = false;
169                                }
170                            }
171                        }
172                        hunks_tx.send(diff.push_new(&new_text)).await?;
173                        hunks_tx.send(diff.finish()).await?;
174
175                        anyhow::Ok(())
176                    });
177
178                    while let Some(hunks) = hunks_rx.next().await {
179                        this.update(&mut cx, |this, cx| {
180                            this.last_equal_ranges.clear();
181
182                            let transaction = this.buffer.update(cx, |buffer, cx| {
183                                // Avoid grouping assistant edits with user edits.
184                                buffer.finalize_last_transaction(cx);
185
186                                buffer.start_transaction(cx);
187                                buffer.edit(
188                                    hunks.into_iter().filter_map(|hunk| match hunk {
189                                        Hunk::Insert { text } => {
190                                            let edit_start = snapshot.anchor_after(edit_start);
191                                            Some((edit_start..edit_start, text))
192                                        }
193                                        Hunk::Remove { len } => {
194                                            let edit_end = edit_start + len;
195                                            let edit_range = snapshot.anchor_after(edit_start)
196                                                ..snapshot.anchor_before(edit_end);
197                                            edit_start = edit_end;
198                                            Some((edit_range, String::new()))
199                                        }
200                                        Hunk::Keep { len } => {
201                                            let edit_end = edit_start + len;
202                                            let edit_range = snapshot.anchor_after(edit_start)
203                                                ..snapshot.anchor_before(edit_end);
204                                            edit_start = edit_end;
205                                            this.last_equal_ranges.push(edit_range);
206                                            None
207                                        }
208                                    }),
209                                    None,
210                                    cx,
211                                );
212
213                                buffer.end_transaction(cx)
214                            });
215
216                            if let Some(transaction) = transaction {
217                                if let Some(first_transaction) = this.transaction_id {
218                                    // Group all assistant edits into the first transaction.
219                                    this.buffer.update(cx, |buffer, cx| {
220                                        buffer.merge_transactions(
221                                            transaction,
222                                            first_transaction,
223                                            cx,
224                                        )
225                                    });
226                                } else {
227                                    this.transaction_id = Some(transaction);
228                                    this.buffer.update(cx, |buffer, cx| {
229                                        buffer.finalize_last_transaction(cx)
230                                    });
231                                }
232                            }
233
234                            cx.notify();
235                        })?;
236                    }
237
238                    diff.await?;
239                    anyhow::Ok(())
240                };
241
242                let result = generate.await;
243                this.update(&mut cx, |this, cx| {
244                    this.last_equal_ranges.clear();
245                    this.idle = true;
246                    if let Err(error) = result {
247                        this.error = Some(error);
248                    }
249                    cx.emit(Event::Finished);
250                    cx.notify();
251                })
252                .ok();
253            }
254        });
255        self.error.take();
256        self.idle = false;
257        cx.notify();
258    }
259
260    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
261        if let Some(transaction_id) = self.transaction_id {
262            self.buffer
263                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
264        }
265    }
266}
267
268fn strip_invalid_spans_from_codeblock(
269    stream: impl Stream<Item = Result<String>>,
270) -> impl Stream<Item = Result<String>> {
271    let mut first_line = true;
272    let mut buffer = String::new();
273    let mut starts_with_markdown_codeblock = false;
274    let mut includes_start_or_end_span = false;
275    stream.filter_map(move |chunk| {
276        let chunk = match chunk {
277            Ok(chunk) => chunk,
278            Err(err) => return future::ready(Some(Err(err))),
279        };
280        buffer.push_str(&chunk);
281
282        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
283            includes_start_or_end_span = true;
284
285            buffer = buffer
286                .strip_prefix("<|S|>")
287                .or_else(|| buffer.strip_prefix("<|S|"))
288                .unwrap_or(&buffer)
289                .to_string();
290        } else if buffer.ends_with("|E|>") {
291            includes_start_or_end_span = true;
292        } else if buffer.starts_with("<|")
293            || buffer.starts_with("<|S")
294            || buffer.starts_with("<|S|")
295            || buffer.ends_with('|')
296            || buffer.ends_with("|E")
297            || buffer.ends_with("|E|")
298        {
299            return future::ready(None);
300        }
301
302        if first_line {
303            if buffer.is_empty() || buffer == "`" || buffer == "``" {
304                return future::ready(None);
305            } else if buffer.starts_with("```") {
306                starts_with_markdown_codeblock = true;
307                if let Some(newline_ix) = buffer.find('\n') {
308                    buffer.replace_range(..newline_ix + 1, "");
309                    first_line = false;
310                } else {
311                    return future::ready(None);
312                }
313            }
314        }
315
316        let mut text = buffer.to_string();
317        if starts_with_markdown_codeblock {
318            text = text
319                .strip_suffix("\n```\n")
320                .or_else(|| text.strip_suffix("\n```"))
321                .or_else(|| text.strip_suffix("\n``"))
322                .or_else(|| text.strip_suffix("\n`"))
323                .or_else(|| text.strip_suffix('\n'))
324                .unwrap_or(&text)
325                .to_string();
326        }
327
328        if includes_start_or_end_span {
329            text = text
330                .strip_suffix("|E|>")
331                .or_else(|| text.strip_suffix("E|>"))
332                .or_else(|| text.strip_prefix("|>"))
333                .or_else(|| text.strip_prefix('>'))
334                .unwrap_or(&text)
335                .to_string();
336        };
337
338        if text.contains('\n') {
339            first_line = false;
340        }
341
342        let remainder = buffer.split_off(text.len());
343        let result = if buffer.is_empty() {
344            None
345        } else {
346            Some(Ok(buffer.clone()))
347        };
348
349        buffer = remainder;
350        future::ready(result)
351    })
352}
353
354#[cfg(test)]
355mod tests {
356    use std::sync::Arc;
357
358    use crate::FakeCompletionProvider;
359
360    use super::*;
361    use futures::stream::{self};
362    use gpui::{Context, TestAppContext};
363    use indoc::indoc;
364    use language::{
365        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
366        Point,
367    };
368    use rand::prelude::*;
369    use serde::Serialize;
370    use settings::SettingsStore;
371
    /// Serializable request payload that carries only a name.
    ///
    /// NOTE(review): nothing in the visible test module references this type — confirm whether
    /// it is still needed.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }
376
377    #[gpui::test(iterations = 10)]
378    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
379        let provider = FakeCompletionProvider::default();
380        cx.set_global(cx.update(SettingsStore::test));
381        cx.set_global(CompletionProvider::Fake(provider.clone()));
382        cx.update(language_settings::init);
383
384        let text = indoc! {"
385            fn main() {
386                let x = 0;
387                for _ in 0..10 {
388                    x += 1;
389                }
390            }
391        "};
392        let buffer =
393            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
394        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
395        let range = buffer.read_with(cx, |buffer, cx| {
396            let snapshot = buffer.snapshot(cx);
397            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
398        });
399        let codegen =
400            cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Transform { range }, cx));
401
402        let request = LanguageModelRequest::default();
403        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
404
405        let mut new_text = concat!(
406            "       let mut x = 0;\n",
407            "       while x < 10 {\n",
408            "           x += 1;\n",
409            "       }",
410        );
411        while !new_text.is_empty() {
412            let max_len = cmp::min(new_text.len(), 10);
413            let len = rng.gen_range(1..=max_len);
414            let (chunk, suffix) = new_text.split_at(len);
415            provider.send_completion(chunk.into());
416            new_text = suffix;
417            cx.background_executor.run_until_parked();
418        }
419        provider.finish_completion();
420        cx.background_executor.run_until_parked();
421
422        assert_eq!(
423            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
424            indoc! {"
425                fn main() {
426                    let mut x = 0;
427                    while x < 10 {
428                        x += 1;
429                    }
430                }
431            "}
432        );
433    }
434
435    #[gpui::test(iterations = 10)]
436    async fn test_autoindent_when_generating_past_indentation(
437        cx: &mut TestAppContext,
438        mut rng: StdRng,
439    ) {
440        let provider = FakeCompletionProvider::default();
441        cx.set_global(CompletionProvider::Fake(provider.clone()));
442        cx.set_global(cx.update(SettingsStore::test));
443        cx.update(language_settings::init);
444
445        let text = indoc! {"
446            fn main() {
447                le
448            }
449        "};
450        let buffer =
451            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
452        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
453        let position = buffer.read_with(cx, |buffer, cx| {
454            let snapshot = buffer.snapshot(cx);
455            snapshot.anchor_before(Point::new(1, 6))
456        });
457        let codegen =
458            cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
459
460        let request = LanguageModelRequest::default();
461        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
462
463        let mut new_text = concat!(
464            "t mut x = 0;\n",
465            "while x < 10 {\n",
466            "    x += 1;\n",
467            "}", //
468        );
469        while !new_text.is_empty() {
470            let max_len = cmp::min(new_text.len(), 10);
471            let len = rng.gen_range(1..=max_len);
472            let (chunk, suffix) = new_text.split_at(len);
473            provider.send_completion(chunk.into());
474            new_text = suffix;
475            cx.background_executor.run_until_parked();
476        }
477        provider.finish_completion();
478        cx.background_executor.run_until_parked();
479
480        assert_eq!(
481            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
482            indoc! {"
483                fn main() {
484                    let mut x = 0;
485                    while x < 10 {
486                        x += 1;
487                    }
488                }
489            "}
490        );
491    }
492
493    #[gpui::test(iterations = 10)]
494    async fn test_autoindent_when_generating_before_indentation(
495        cx: &mut TestAppContext,
496        mut rng: StdRng,
497    ) {
498        let provider = FakeCompletionProvider::default();
499        cx.set_global(CompletionProvider::Fake(provider.clone()));
500        cx.set_global(cx.update(SettingsStore::test));
501        cx.update(language_settings::init);
502
503        let text = concat!(
504            "fn main() {\n",
505            "  \n",
506            "}\n" //
507        );
508        let buffer =
509            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
510        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
511        let position = buffer.read_with(cx, |buffer, cx| {
512            let snapshot = buffer.snapshot(cx);
513            snapshot.anchor_before(Point::new(1, 2))
514        });
515        let codegen =
516            cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
517
518        let request = LanguageModelRequest::default();
519        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
520
521        let mut new_text = concat!(
522            "let mut x = 0;\n",
523            "while x < 10 {\n",
524            "    x += 1;\n",
525            "}", //
526        );
527        while !new_text.is_empty() {
528            let max_len = cmp::min(new_text.len(), 10);
529            let len = rng.gen_range(1..=max_len);
530            let (chunk, suffix) = new_text.split_at(len);
531            provider.send_completion(chunk.into());
532            new_text = suffix;
533            cx.background_executor.run_until_parked();
534        }
535        provider.finish_completion();
536        cx.background_executor.run_until_parked();
537
538        assert_eq!(
539            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
540            indoc! {"
541                fn main() {
542                    let mut x = 0;
543                    while x < 10 {
544                        x += 1;
545                    }
546                }
547            "}
548        );
549    }
550
551    #[gpui::test]
552    async fn test_strip_invalid_spans_from_codeblock() {
553        assert_eq!(
554            strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
555                .map(|chunk| chunk.unwrap())
556                .collect::<String>()
557                .await,
558            "Lorem ipsum dolor"
559        );
560        assert_eq!(
561            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
562                .map(|chunk| chunk.unwrap())
563                .collect::<String>()
564                .await,
565            "Lorem ipsum dolor"
566        );
567        assert_eq!(
568            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
569                .map(|chunk| chunk.unwrap())
570                .collect::<String>()
571                .await,
572            "Lorem ipsum dolor"
573        );
574        assert_eq!(
575            strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
576                .map(|chunk| chunk.unwrap())
577                .collect::<String>()
578                .await,
579            "Lorem ipsum dolor"
580        );
581        assert_eq!(
582            strip_invalid_spans_from_codeblock(chunks(
583                "```html\n```js\nLorem ipsum dolor\n```\n```",
584                2
585            ))
586            .map(|chunk| chunk.unwrap())
587            .collect::<String>()
588            .await,
589            "```js\nLorem ipsum dolor\n```"
590        );
591        assert_eq!(
592            strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
593                .map(|chunk| chunk.unwrap())
594                .collect::<String>()
595                .await,
596            "``\nLorem ipsum dolor\n```"
597        );
598        assert_eq!(
599            strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
600                .map(|chunk| chunk.unwrap())
601                .collect::<String>()
602                .await,
603            "Lorem ipsum"
604        );
605
606        assert_eq!(
607            strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
608                .map(|chunk| chunk.unwrap())
609                .collect::<String>()
610                .await,
611            "Lorem ipsum"
612        );
613
614        assert_eq!(
615            strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
616                .map(|chunk| chunk.unwrap())
617                .collect::<String>()
618                .await,
619            "Lorem ipsum"
620        );
621        assert_eq!(
622            strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
623                .map(|chunk| chunk.unwrap())
624                .collect::<String>()
625                .await,
626            "Lorem ipsum"
627        );
628        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
629            stream::iter(
630                text.chars()
631                    .collect::<Vec<_>>()
632                    .chunks(size)
633                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
634                    .collect::<Vec<_>>(),
635            )
636        }
637    }
638
639    fn rust_lang() -> Language {
640        Language::new(
641            LanguageConfig {
642                name: "Rust".into(),
643                matcher: LanguageMatcher {
644                    path_suffixes: vec!["rs".to_string()],
645                    ..Default::default()
646                },
647                ..Default::default()
648            },
649            Some(tree_sitter_rust::language()),
650        )
651        .with_indents_query(
652            r#"
653            (call_expression) @indent
654            (field_expression) @indent
655            (_ "(" ")" @end) @indent
656            (_ "{" "}" @end) @indent
657            "#,
658        )
659        .unwrap()
660    }
661}