codegen.rs

  1use crate::{
  2    streaming_diff::{Hunk, StreamingDiff},
  3    CompletionProvider, LanguageModelRequest,
  4};
  5use anyhow::Result;
  6use client::telemetry::Telemetry;
  7use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
  8use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
  9use gpui::{EventEmitter, Model, ModelContext, Task};
 10use language::{Rope, TransactionId};
 11use multi_buffer::MultiBufferRow;
 12use std::{cmp, future, ops::Range, sync::Arc, time::Instant};
 13
 14pub enum Event {
 15    Finished,
 16    Undone,
 17}
 18
 19#[derive(Clone)]
 20pub enum CodegenKind {
 21    Transform { range: Range<Anchor> },
 22    Generate { position: Anchor },
 23}
 24
 25pub struct Codegen {
 26    buffer: Model<MultiBuffer>,
 27    snapshot: MultiBufferSnapshot,
 28    kind: CodegenKind,
 29    last_equal_ranges: Vec<Range<Anchor>>,
 30    transaction_id: Option<TransactionId>,
 31    error: Option<anyhow::Error>,
 32    generation: Task<()>,
 33    idle: bool,
 34    telemetry: Option<Arc<Telemetry>>,
 35    _subscription: gpui::Subscription,
 36}
 37
 38impl EventEmitter<Event> for Codegen {}
 39
 40impl Codegen {
 41    pub fn new(
 42        buffer: Model<MultiBuffer>,
 43        kind: CodegenKind,
 44        telemetry: Option<Arc<Telemetry>>,
 45        cx: &mut ModelContext<Self>,
 46    ) -> Self {
 47        let snapshot = buffer.read(cx).snapshot(cx);
 48        Self {
 49            buffer: buffer.clone(),
 50            snapshot,
 51            kind,
 52            last_equal_ranges: Default::default(),
 53            transaction_id: Default::default(),
 54            error: Default::default(),
 55            idle: true,
 56            generation: Task::ready(()),
 57            telemetry,
 58            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 59        }
 60    }
 61
 62    fn handle_buffer_event(
 63        &mut self,
 64        _buffer: Model<MultiBuffer>,
 65        event: &multi_buffer::Event,
 66        cx: &mut ModelContext<Self>,
 67    ) {
 68        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
 69            if self.transaction_id == Some(*transaction_id) {
 70                self.transaction_id = None;
 71                self.generation = Task::ready(());
 72                cx.emit(Event::Undone);
 73            }
 74        }
 75    }
 76
 77    pub fn range(&self) -> Range<Anchor> {
 78        match &self.kind {
 79            CodegenKind::Transform { range } => range.clone(),
 80            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
 81        }
 82    }
 83
 84    pub fn kind(&self) -> &CodegenKind {
 85        &self.kind
 86    }
 87
 88    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
 89        &self.last_equal_ranges
 90    }
 91
 92    pub fn idle(&self) -> bool {
 93        self.idle
 94    }
 95
 96    pub fn error(&self) -> Option<&anyhow::Error> {
 97        self.error.as_ref()
 98    }
 99
100    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
101        let range = self.range();
102        let snapshot = self.snapshot.clone();
103        let selected_text = snapshot
104            .text_for_range(range.start..range.end)
105            .collect::<Rope>();
106
107        let selection_start = range.start.to_point(&snapshot);
108        let suggested_line_indent = snapshot
109            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
110            .into_values()
111            .next()
112            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
113
114        let model_telemetry_id = prompt.model.telemetry_id();
115        let response = CompletionProvider::global(cx).complete(prompt);
116        let telemetry = self.telemetry.clone();
117        self.generation = cx.spawn(|this, mut cx| {
118            async move {
119                let generate = async {
120                    let mut edit_start = range.start.to_offset(&snapshot);
121
122                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
123                    let diff = cx.background_executor().spawn(async move {
124                        let mut response_latency = None;
125                        let request_start = Instant::now();
126                        let diff = async {
127                            let chunks = strip_invalid_spans_from_codeblock(response.await?);
128                            futures::pin_mut!(chunks);
129                            let mut diff = StreamingDiff::new(selected_text.to_string());
130
131                            let mut new_text = String::new();
132                            let mut base_indent = None;
133                            let mut line_indent = None;
134                            let mut first_line = true;
135
136                            while let Some(chunk) = chunks.next().await {
137                                if response_latency.is_none() {
138                                    response_latency = Some(request_start.elapsed());
139                                }
140                                let chunk = chunk?;
141
142                                let mut lines = chunk.split('\n').peekable();
143                                while let Some(line) = lines.next() {
144                                    new_text.push_str(line);
145                                    if line_indent.is_none() {
146                                        if let Some(non_whitespace_ch_ix) =
147                                            new_text.find(|ch: char| !ch.is_whitespace())
148                                        {
149                                            line_indent = Some(non_whitespace_ch_ix);
150                                            base_indent = base_indent.or(line_indent);
151
152                                            let line_indent = line_indent.unwrap();
153                                            let base_indent = base_indent.unwrap();
154                                            let indent_delta =
155                                                line_indent as i32 - base_indent as i32;
156                                            let mut corrected_indent_len = cmp::max(
157                                                0,
158                                                suggested_line_indent.len as i32 + indent_delta,
159                                            )
160                                                as usize;
161                                            if first_line {
162                                                corrected_indent_len = corrected_indent_len
163                                                    .saturating_sub(
164                                                        selection_start.column as usize,
165                                                    );
166                                            }
167
168                                            let indent_char = suggested_line_indent.char();
169                                            let mut indent_buffer = [0; 4];
170                                            let indent_str =
171                                                indent_char.encode_utf8(&mut indent_buffer);
172                                            new_text.replace_range(
173                                                ..line_indent,
174                                                &indent_str.repeat(corrected_indent_len),
175                                            );
176                                        }
177                                    }
178
179                                    if line_indent.is_some() {
180                                        hunks_tx.send(diff.push_new(&new_text)).await?;
181                                        new_text.clear();
182                                    }
183
184                                    if lines.peek().is_some() {
185                                        hunks_tx.send(diff.push_new("\n")).await?;
186                                        line_indent = None;
187                                        first_line = false;
188                                    }
189                                }
190                            }
191                            hunks_tx.send(diff.push_new(&new_text)).await?;
192                            hunks_tx.send(diff.finish()).await?;
193
194                            anyhow::Ok(())
195                        };
196
197                        let error_message = diff.await.err().map(|error| error.to_string());
198                        if let Some(telemetry) = telemetry {
199                            telemetry.report_assistant_event(
200                                None,
201                                telemetry_events::AssistantKind::Inline,
202                                model_telemetry_id,
203                                response_latency,
204                                error_message,
205                            );
206                        }
207                    });
208
209                    while let Some(hunks) = hunks_rx.next().await {
210                        this.update(&mut cx, |this, cx| {
211                            this.last_equal_ranges.clear();
212
213                            let transaction = this.buffer.update(cx, |buffer, cx| {
214                                // Avoid grouping assistant edits with user edits.
215                                buffer.finalize_last_transaction(cx);
216
217                                buffer.start_transaction(cx);
218                                buffer.edit(
219                                    hunks.into_iter().filter_map(|hunk| match hunk {
220                                        Hunk::Insert { text } => {
221                                            let edit_start = snapshot.anchor_after(edit_start);
222                                            Some((edit_start..edit_start, text))
223                                        }
224                                        Hunk::Remove { len } => {
225                                            let edit_end = edit_start + len;
226                                            let edit_range = snapshot.anchor_after(edit_start)
227                                                ..snapshot.anchor_before(edit_end);
228                                            edit_start = edit_end;
229                                            Some((edit_range, String::new()))
230                                        }
231                                        Hunk::Keep { len } => {
232                                            let edit_end = edit_start + len;
233                                            let edit_range = snapshot.anchor_after(edit_start)
234                                                ..snapshot.anchor_before(edit_end);
235                                            edit_start = edit_end;
236                                            this.last_equal_ranges.push(edit_range);
237                                            None
238                                        }
239                                    }),
240                                    None,
241                                    cx,
242                                );
243
244                                buffer.end_transaction(cx)
245                            });
246
247                            if let Some(transaction) = transaction {
248                                if let Some(first_transaction) = this.transaction_id {
249                                    // Group all assistant edits into the first transaction.
250                                    this.buffer.update(cx, |buffer, cx| {
251                                        buffer.merge_transactions(
252                                            transaction,
253                                            first_transaction,
254                                            cx,
255                                        )
256                                    });
257                                } else {
258                                    this.transaction_id = Some(transaction);
259                                    this.buffer.update(cx, |buffer, cx| {
260                                        buffer.finalize_last_transaction(cx)
261                                    });
262                                }
263                            }
264
265                            cx.notify();
266                        })?;
267                    }
268
269                    diff.await;
270
271                    anyhow::Ok(())
272                };
273
274                let result = generate.await;
275                this.update(&mut cx, |this, cx| {
276                    this.last_equal_ranges.clear();
277                    this.idle = true;
278                    if let Err(error) = result {
279                        this.error = Some(error);
280                    }
281                    cx.emit(Event::Finished);
282                    cx.notify();
283                })
284                .ok();
285            }
286        });
287        self.error.take();
288        self.idle = false;
289        cx.notify();
290    }
291
292    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
293        if let Some(transaction_id) = self.transaction_id {
294            self.buffer
295                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
296        }
297    }
298}
299
300fn strip_invalid_spans_from_codeblock(
301    stream: impl Stream<Item = Result<String>>,
302) -> impl Stream<Item = Result<String>> {
303    let mut first_line = true;
304    let mut buffer = String::new();
305    let mut starts_with_markdown_codeblock = false;
306    let mut includes_start_or_end_span = false;
307    stream.filter_map(move |chunk| {
308        let chunk = match chunk {
309            Ok(chunk) => chunk,
310            Err(err) => return future::ready(Some(Err(err))),
311        };
312        buffer.push_str(&chunk);
313
314        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
315            includes_start_or_end_span = true;
316
317            buffer = buffer
318                .strip_prefix("<|S|>")
319                .or_else(|| buffer.strip_prefix("<|S|"))
320                .unwrap_or(&buffer)
321                .to_string();
322        } else if buffer.ends_with("|E|>") {
323            includes_start_or_end_span = true;
324        } else if buffer.starts_with("<|")
325            || buffer.starts_with("<|S")
326            || buffer.starts_with("<|S|")
327            || buffer.ends_with('|')
328            || buffer.ends_with("|E")
329            || buffer.ends_with("|E|")
330        {
331            return future::ready(None);
332        }
333
334        if first_line {
335            if buffer.is_empty() || buffer == "`" || buffer == "``" {
336                return future::ready(None);
337            } else if buffer.starts_with("```") {
338                starts_with_markdown_codeblock = true;
339                if let Some(newline_ix) = buffer.find('\n') {
340                    buffer.replace_range(..newline_ix + 1, "");
341                    first_line = false;
342                } else {
343                    return future::ready(None);
344                }
345            }
346        }
347
348        let mut text = buffer.to_string();
349        if starts_with_markdown_codeblock {
350            text = text
351                .strip_suffix("\n```\n")
352                .or_else(|| text.strip_suffix("\n```"))
353                .or_else(|| text.strip_suffix("\n``"))
354                .or_else(|| text.strip_suffix("\n`"))
355                .or_else(|| text.strip_suffix('\n'))
356                .unwrap_or(&text)
357                .to_string();
358        }
359
360        if includes_start_or_end_span {
361            text = text
362                .strip_suffix("|E|>")
363                .or_else(|| text.strip_suffix("E|>"))
364                .or_else(|| text.strip_prefix("|>"))
365                .or_else(|| text.strip_prefix('>'))
366                .unwrap_or(&text)
367                .to_string();
368        };
369
370        if text.contains('\n') {
371            first_line = false;
372        }
373
374        let remainder = buffer.split_off(text.len());
375        let result = if buffer.is_empty() {
376            None
377        } else {
378            Some(Ok(buffer.clone()))
379        };
380
381        buffer = remainder;
382        future::ready(result)
383    })
384}
385
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::FakeCompletionProvider;

    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;

    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Buffer contents every autoindent test expects once generation has finished.
    const EXPECTED_LOOP: &str = indoc! {"
        fn main() {
            let mut x = 0;
            while x < 10 {
                x += 1;
            }
        }
    "};

    /// Builds a Rust buffer containing `text`, wrapped in a singleton multibuffer.
    fn rust_multibuffer(text: &'static str, cx: &mut TestAppContext) -> Model<MultiBuffer> {
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        cx.new_model(|cx| MultiBuffer::singleton(buffer, cx))
    }

    /// Feeds `text` to `provider` in random chunks of 1..=10 bytes. The executor
    /// is allowed to settle after each chunk, and finally the completion is
    /// finished.
    fn stream_completion(
        provider: &FakeCompletionProvider,
        text: &str,
        rng: &mut StdRng,
        cx: &mut TestAppContext,
    ) {
        let mut remaining = text;
        while !remaining.is_empty() {
            let len = rng.gen_range(1..=remaining.len().min(10));
            let (chunk, rest) = remaining.split_at(len);
            provider.send_completion(chunk.into());
            remaining = rest;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();
    }

    /// Turns `text` into a stream of successful chunks of at most `size` chars.
    fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
        let pieces: Vec<Result<String>> = text
            .chars()
            .collect::<Vec<_>>()
            .chunks(size)
            .map(|piece| Ok(piece.iter().collect()))
            .collect();
        stream::iter(pieces)
    }

    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(cx.update(SettingsStore::test));
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.update(language_settings::init);

        let buffer = rust_multibuffer(
            indoc! {"
                fn main() {
                    let x = 0;
                    for _ in 0..10 {
                        x += 1;
                    }
                }
            "},
            cx,
        );
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Transform { range }, None, cx)
        });
        codegen.update(cx, |codegen, cx| codegen.start(LanguageModelRequest::default(), cx));

        stream_completion(
            &provider,
            concat!(
                "       let mut x = 0;\n",
                "       while x < 10 {\n",
                "           x += 1;\n",
                "       }",
            ),
            &mut rng,
            cx,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            EXPECTED_LOOP
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let buffer = rust_multibuffer(
            indoc! {"
                fn main() {
                    le
                }
            "},
            cx,
        );
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });
        codegen.update(cx, |codegen, cx| codegen.start(LanguageModelRequest::default(), cx));

        stream_completion(
            &provider,
            concat!(
                "t mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &mut rng,
            cx,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            EXPECTED_LOOP
        );
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let buffer = rust_multibuffer(
            concat!(
                "fn main() {\n",
                "  \n",
                "}\n" //
            ),
            cx,
        );
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });
        codegen.update(cx, |codegen, cx| codegen.start(LanguageModelRequest::default(), cx));

        stream_completion(
            &provider,
            concat!(
                "let mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &mut rng,
            cx,
        );

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            EXPECTED_LOOP
        );
    }

    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        // (model output, expected text after stripping), each fed in 2-char chunks.
        let cases = [
            ("Lorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```\n", "Lorem ipsum dolor"),
            (
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                "```js\nLorem ipsum dolor\n```",
            ),
            ("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```"),
            ("<|S|Lorem ipsum|E|>", "Lorem ipsum"),
            ("<|S|>Lorem ipsum", "Lorem ipsum"),
            ("```\n<|S|>Lorem ipsum\n```", "Lorem ipsum"),
            ("```\n<|S|Lorem ipsum|E|>\n```", "Lorem ipsum"),
        ];

        for (input, expected) in cases {
            let stripped = strip_invalid_spans_from_codeblock(chunks(input, 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await;
            assert_eq!(stripped, expected, "input: {input:?}");
        }
    }

    /// A minimal Rust language with just enough indent rules for the tests.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}