codegen.rs

  1use crate::{
  2    stream_completion,
  3    streaming_diff::{Hunk, StreamingDiff},
  4    OpenAIRequest,
  5};
  6use anyhow::Result;
  7use editor::{multi_buffer, Anchor, MultiBuffer, ToOffset, ToPoint};
  8use futures::{
  9    channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
 10};
 11use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
 12use language::{IndentSize, Point, Rope, TransactionId};
 13use std::{cmp, future, ops::Range, sync::Arc};
 14
/// A source of streamed text completions.
pub trait CompletionProvider {
    /// Starts a completion for `prompt`.
    ///
    /// The returned future yields either an error (the request failed) or a stream of text
    /// chunks, each of which may itself be an error.
    fn complete(
        &self,
        prompt: OpenAIRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
}
 21
 22pub struct OpenAICompletionProvider {
 23    api_key: String,
 24    executor: Arc<Background>,
 25}
 26
 27impl OpenAICompletionProvider {
 28    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
 29        Self { api_key, executor }
 30    }
 31}
 32
 33impl CompletionProvider for OpenAICompletionProvider {
 34    fn complete(
 35        &self,
 36        prompt: OpenAIRequest,
 37    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
 38        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
 39        async move {
 40            let response = request.await?;
 41            let stream = response
 42                .filter_map(|response| async move {
 43                    match response {
 44                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
 45                        Err(error) => Some(Err(error)),
 46                    }
 47                })
 48                .boxed();
 49            Ok(stream)
 50        }
 51        .boxed()
 52    }
 53}
 54
/// Events emitted by [`Codegen`].
pub enum Event {
    /// The generation task finished running, whether it succeeded or failed.
    Finished,
    /// The transaction containing the generated edits was undone.
    Undone,
}

/// Streams a completion into a range of a [`MultiBuffer`], applying it as an incremental
/// diff against the text that was originally in that range.
pub struct Codegen {
    /// Source of completion text.
    provider: Arc<dyn CompletionProvider>,
    /// The buffer being edited.
    buffer: ModelHandle<MultiBuffer>,
    /// The range of `buffer` being rewritten.
    range: Range<Anchor>,
    /// Ranges of original text that the most recently applied hunks left unchanged.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// The transaction into which all generated edits are grouped, once one exists.
    transaction_id: Option<TransactionId>,
    /// The error from the most recent generation, if it failed.
    error: Option<anyhow::Error>,
    /// The in-flight generation task; assigning a new task drops the previous one.
    generation: Task<()>,
    /// Set to `false` by `start` and back to `true` when the generation task completes.
    idle: bool,
    /// Held so the buffer subscription created in `new` stays alive with this model.
    _subscription: gpui::Subscription,
}

impl Entity for Codegen {
    type Event = Event;
}
 75
 76impl Codegen {
 77    pub fn new(
 78        buffer: ModelHandle<MultiBuffer>,
 79        range: Range<Anchor>,
 80        provider: Arc<dyn CompletionProvider>,
 81        cx: &mut ModelContext<Self>,
 82    ) -> Self {
 83        Self {
 84            provider,
 85            buffer: buffer.clone(),
 86            range,
 87            last_equal_ranges: Default::default(),
 88            transaction_id: Default::default(),
 89            error: Default::default(),
 90            idle: true,
 91            generation: Task::ready(()),
 92            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 93        }
 94    }
 95
 96    fn handle_buffer_event(
 97        &mut self,
 98        _buffer: ModelHandle<MultiBuffer>,
 99        event: &multi_buffer::Event,
100        cx: &mut ModelContext<Self>,
101    ) {
102        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
103            if self.transaction_id == Some(*transaction_id) {
104                self.transaction_id = None;
105                self.generation = Task::ready(());
106                cx.emit(Event::Undone);
107            }
108        }
109    }
110
111    pub fn range(&self) -> Range<Anchor> {
112        self.range.clone()
113    }
114
115    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
116        &self.last_equal_ranges
117    }
118
119    pub fn idle(&self) -> bool {
120        self.idle
121    }
122
123    pub fn error(&self) -> Option<&anyhow::Error> {
124        self.error.as_ref()
125    }
126
127    pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
128        let range = self.range.clone();
129        let snapshot = self.buffer.read(cx).snapshot(cx);
130        let selected_text = snapshot
131            .text_for_range(range.start..range.end)
132            .collect::<Rope>();
133
134        let selection_start = range.start.to_point(&snapshot);
135        let selection_end = range.end.to_point(&snapshot);
136
137        let mut base_indent: Option<IndentSize> = None;
138        let mut start_row = selection_start.row;
139        if snapshot.is_line_blank(start_row) {
140            if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) {
141                start_row = prev_non_blank_row;
142            }
143        }
144        for row in start_row..=selection_end.row {
145            if snapshot.is_line_blank(row) {
146                continue;
147            }
148
149            let line_indent = snapshot.indent_size_for_line(row);
150            if let Some(base_indent) = base_indent.as_mut() {
151                if line_indent.len < base_indent.len {
152                    *base_indent = line_indent;
153                }
154            } else {
155                base_indent = Some(line_indent);
156            }
157        }
158
159        let mut normalized_selected_text = selected_text.clone();
160        if let Some(base_indent) = base_indent {
161            for row in selection_start.row..=selection_end.row {
162                let selection_row = row - selection_start.row;
163                let line_start =
164                    normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
165                let indent_len = if row == selection_start.row {
166                    base_indent.len.saturating_sub(selection_start.column)
167                } else {
168                    let line_len = normalized_selected_text.line_len(selection_row);
169                    cmp::min(line_len, base_indent.len)
170                };
171                let indent_end = cmp::min(
172                    line_start + indent_len as usize,
173                    normalized_selected_text.len(),
174                );
175                normalized_selected_text.replace(line_start..indent_end, "");
176            }
177        }
178
179        let response = self.provider.complete(prompt);
180        self.generation = cx.spawn_weak(|this, mut cx| {
181            async move {
182                let generate = async {
183                    let mut edit_start = range.start.to_offset(&snapshot);
184
185                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
186                    let diff = cx.background().spawn(async move {
187                        let chunks = strip_markdown_codeblock(response.await?);
188                        futures::pin_mut!(chunks);
189                        let mut diff = StreamingDiff::new(selected_text.to_string());
190
191                        let mut indent_len;
192                        let indent_text;
193                        if let Some(base_indent) = base_indent {
194                            indent_len = base_indent.len;
195                            indent_text = match base_indent.kind {
196                                language::IndentKind::Space => " ",
197                                language::IndentKind::Tab => "\t",
198                            };
199                        } else {
200                            indent_len = 0;
201                            indent_text = "";
202                        };
203
204                        let mut first_line_len = 0;
205                        let mut first_line_non_whitespace_char_ix = None;
206                        let mut first_line = true;
207                        let mut new_text = String::new();
208
209                        while let Some(chunk) = chunks.next().await {
210                            let chunk = chunk?;
211
212                            let mut lines = chunk.split('\n');
213                            if let Some(mut line) = lines.next() {
214                                if first_line {
215                                    if first_line_non_whitespace_char_ix.is_none() {
216                                        if let Some(mut char_ix) =
217                                            line.find(|ch: char| !ch.is_whitespace())
218                                        {
219                                            line = &line[char_ix..];
220                                            char_ix += first_line_len;
221                                            first_line_non_whitespace_char_ix = Some(char_ix);
222                                            let first_line_indent = char_ix
223                                                .saturating_sub(selection_start.column as usize)
224                                                as usize;
225                                            new_text
226                                                .push_str(&indent_text.repeat(first_line_indent));
227                                            indent_len = indent_len.saturating_sub(char_ix as u32);
228                                        }
229                                    }
230                                    first_line_len += line.len();
231                                }
232
233                                if first_line_non_whitespace_char_ix.is_some() {
234                                    new_text.push_str(line);
235                                }
236                            }
237
238                            for line in lines {
239                                first_line = false;
240                                new_text.push('\n');
241                                if !line.is_empty() {
242                                    new_text.push_str(&indent_text.repeat(indent_len as usize));
243                                }
244                                new_text.push_str(line);
245                            }
246
247                            let hunks = diff.push_new(&new_text);
248                            hunks_tx.send(hunks).await?;
249                            new_text.clear();
250                        }
251                        hunks_tx.send(diff.finish()).await?;
252
253                        anyhow::Ok(())
254                    });
255
256                    while let Some(hunks) = hunks_rx.next().await {
257                        let this = if let Some(this) = this.upgrade(&cx) {
258                            this
259                        } else {
260                            break;
261                        };
262
263                        this.update(&mut cx, |this, cx| {
264                            this.last_equal_ranges.clear();
265
266                            let transaction = this.buffer.update(cx, |buffer, cx| {
267                                // Avoid grouping assistant edits with user edits.
268                                buffer.finalize_last_transaction(cx);
269
270                                buffer.start_transaction(cx);
271                                buffer.edit(
272                                    hunks.into_iter().filter_map(|hunk| match hunk {
273                                        Hunk::Insert { text } => {
274                                            let edit_start = snapshot.anchor_after(edit_start);
275                                            Some((edit_start..edit_start, text))
276                                        }
277                                        Hunk::Remove { len } => {
278                                            let edit_end = edit_start + len;
279                                            let edit_range = snapshot.anchor_after(edit_start)
280                                                ..snapshot.anchor_before(edit_end);
281                                            edit_start = edit_end;
282                                            Some((edit_range, String::new()))
283                                        }
284                                        Hunk::Keep { len } => {
285                                            let edit_end = edit_start + len;
286                                            let edit_range = snapshot.anchor_after(edit_start)
287                                                ..snapshot.anchor_before(edit_end);
288                                            edit_start += len;
289                                            this.last_equal_ranges.push(edit_range);
290                                            None
291                                        }
292                                    }),
293                                    None,
294                                    cx,
295                                );
296
297                                buffer.end_transaction(cx)
298                            });
299
300                            if let Some(transaction) = transaction {
301                                if let Some(first_transaction) = this.transaction_id {
302                                    // Group all assistant edits into the first transaction.
303                                    this.buffer.update(cx, |buffer, cx| {
304                                        buffer.merge_transactions(
305                                            transaction,
306                                            first_transaction,
307                                            cx,
308                                        )
309                                    });
310                                } else {
311                                    this.transaction_id = Some(transaction);
312                                    this.buffer.update(cx, |buffer, cx| {
313                                        buffer.finalize_last_transaction(cx)
314                                    });
315                                }
316                            }
317
318                            cx.notify();
319                        });
320                    }
321
322                    diff.await?;
323                    anyhow::Ok(())
324                };
325
326                let result = generate.await;
327                if let Some(this) = this.upgrade(&cx) {
328                    this.update(&mut cx, |this, cx| {
329                        this.last_equal_ranges.clear();
330                        this.idle = true;
331                        if let Err(error) = result {
332                            this.error = Some(error);
333                        }
334                        cx.emit(Event::Finished);
335                        cx.notify();
336                    });
337                }
338            }
339        });
340        self.error.take();
341        self.idle = false;
342        cx.notify();
343    }
344
345    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
346        if let Some(transaction_id) = self.transaction_id {
347            self.buffer
348                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
349        }
350    }
351}
352
353fn strip_markdown_codeblock(
354    stream: impl Stream<Item = Result<String>>,
355) -> impl Stream<Item = Result<String>> {
356    let mut first_line = true;
357    let mut buffer = String::new();
358    let mut starts_with_fenced_code_block = false;
359    stream.filter_map(move |chunk| {
360        let chunk = match chunk {
361            Ok(chunk) => chunk,
362            Err(err) => return future::ready(Some(Err(err))),
363        };
364        buffer.push_str(&chunk);
365
366        if first_line {
367            if buffer == "" || buffer == "`" || buffer == "``" {
368                return future::ready(None);
369            } else if buffer.starts_with("```") {
370                starts_with_fenced_code_block = true;
371                if let Some(newline_ix) = buffer.find('\n') {
372                    buffer.replace_range(..newline_ix + 1, "");
373                    first_line = false;
374                } else {
375                    return future::ready(None);
376                }
377            }
378        }
379
380        let text = if starts_with_fenced_code_block {
381            buffer
382                .strip_suffix("\n```\n")
383                .or_else(|| buffer.strip_suffix("\n```"))
384                .or_else(|| buffer.strip_suffix("\n``"))
385                .or_else(|| buffer.strip_suffix("\n`"))
386                .or_else(|| buffer.strip_suffix('\n'))
387                .unwrap_or(&buffer)
388        } else {
389            &buffer
390        };
391
392        if text.contains('\n') {
393            first_line = false;
394        }
395
396        let remainder = buffer.split_off(text.len());
397        let result = if buffer.is_empty() {
398            None
399        } else {
400            Some(Ok(buffer.clone()))
401        };
402        buffer = remainder;
403        future::ready(result)
404    })
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;
    use gpui::{executor::Deterministic, TestAppContext};
    use indoc::indoc;
    use language::{tree_sitter_rust, Buffer, Language, LanguageConfig};
    use parking_lot::Mutex;
    use rand::prelude::*;

    /// Streams a mis-indented completion into part of a Rust function in randomly sized
    /// chunks, then checks the final buffer text.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        let original_text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.add_model(|cx| {
            Buffer::new(0, 0, original_text).with_language(Arc::new(rust_lang()), cx)
        });
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Select from column 4 of row 1 through column 4 of row 4.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            let start = snapshot.anchor_before(Point::new(1, 4));
            let end = snapshot.anchor_after(Point::new(4, 4));
            start..end
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| Codegen::new(buffer.clone(), range, provider.clone(), cx));
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        let mut remaining = indoc! {"
                   let mut x = 0;
            while x < 10 {
                           x += 1;
               }
        "};
        while !remaining.is_empty() {
            let chunk_len = rng.gen_range(1..=remaining.len().min(10));
            let (chunk, rest) = remaining.split_at(chunk_len);
            provider.send_completion(chunk);
            remaining = rest;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();

        let final_text = buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text());
        assert_eq!(
            final_text,
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Checks fence stripping on inputs delivered two characters at a time.
    #[gpui::test]
    async fn test_strip_markdown_codeblock() {
        // Runs `text`, split into `size`-character chunks, through `strip_markdown_codeblock`
        // and concatenates everything it emits.
        async fn strip(text: &str, size: usize) -> String {
            strip_markdown_codeblock(chunks(text, size))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await
        }

        // Splits `text` into a stream of chunks of at most `size` characters each.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            let chars = text.chars().collect::<Vec<_>>();
            let pieces = chars
                .chunks(size)
                .map(|piece| anyhow::Ok(piece.iter().collect::<String>()))
                .collect::<Vec<_>>();
            stream::iter(pieces)
        }

        assert_eq!(strip("Lorem ipsum dolor", 2).await, "Lorem ipsum dolor");
        assert_eq!(strip("```\nLorem ipsum dolor", 2).await, "Lorem ipsum dolor");
        assert_eq!(
            strip("```\nLorem ipsum dolor\n```", 2).await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip("```\nLorem ipsum dolor\n```\n", 2).await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip("```html\n```js\nLorem ipsum dolor\n```\n```", 2).await,
            "```js\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            strip("``\nLorem ipsum dolor\n```", 2).await,
            "``\nLorem ipsum dolor\n```"
        );
    }

    /// A [`CompletionProvider`] whose completion text is pushed by the test itself.
    struct TestCompletionProvider {
        /// Sender for the most recently requested completion; `None` once finished.
        last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
    }

    impl TestCompletionProvider {
        fn new() -> Self {
            Self {
                last_completion_tx: Mutex::new(None),
            }
        }

        /// Pushes one chunk into the in-flight completion.
        ///
        /// Panics if no completion is in flight or its channel is full.
        fn send_completion(&self, completion: impl Into<String>) {
            self.last_completion_tx
                .lock()
                .as_mut()
                .unwrap()
                .try_send(completion.into())
                .unwrap();
        }

        /// Ends the in-flight completion. Panics if there is none.
        fn finish_completion(&self) {
            let sender = self.last_completion_tx.lock().take();
            // Dropping the only sender closes the channel, ending the completion stream.
            drop(sender.unwrap());
        }
    }

    impl CompletionProvider for TestCompletionProvider {
        fn complete(
            &self,
            _prompt: OpenAIRequest,
        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
            let (tx, rx) = mpsc::channel(1);
            *self.last_completion_tx.lock() = Some(tx);
            let completion: BoxStream<'static, Result<String>> = rx.map(Ok).boxed();
            async move { Ok(completion) }.boxed()
        }
    }

    /// A Rust language definition with a minimal indentation query.
    fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            path_suffixes: vec!["rs".to_string()],
            ..Default::default()
        };
        Language::new(config, Some(tree_sitter_rust::language()))
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}