codegen.rs

  1use crate::{
  2    stream_completion,
  3    streaming_diff::{Hunk, StreamingDiff},
  4    OpenAIRequest,
  5};
  6use anyhow::Result;
  7use editor::{multi_buffer, Anchor, MultiBuffer, ToOffset, ToPoint};
  8use futures::{
  9    channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
 10};
 11use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
 12use language::{IndentSize, Point, Rope, TransactionId};
 13use std::{cmp, future, ops::Range, sync::Arc};
 14
 15pub trait CompletionProvider {
 16    fn complete(
 17        &self,
 18        prompt: OpenAIRequest,
 19    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
 20}
 21
 22pub struct OpenAICompletionProvider {
 23    api_key: String,
 24    executor: Arc<Background>,
 25}
 26
 27impl OpenAICompletionProvider {
 28    pub fn new(api_key: String, executor: Arc<Background>) -> Self {
 29        Self { api_key, executor }
 30    }
 31}
 32
 33impl CompletionProvider for OpenAICompletionProvider {
 34    fn complete(
 35        &self,
 36        prompt: OpenAIRequest,
 37    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
 38        let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
 39        async move {
 40            let response = request.await?;
 41            let stream = response
 42                .filter_map(|response| async move {
 43                    match response {
 44                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
 45                        Err(error) => Some(Err(error)),
 46                    }
 47                })
 48                .boxed();
 49            Ok(stream)
 50        }
 51        .boxed()
 52    }
 53}
 54
 55pub enum Event {
 56    Finished,
 57    Undone,
 58}
 59
 60pub struct Codegen {
 61    provider: Arc<dyn CompletionProvider>,
 62    buffer: ModelHandle<MultiBuffer>,
 63    range: Range<Anchor>,
 64    last_equal_ranges: Vec<Range<Anchor>>,
 65    transaction_id: Option<TransactionId>,
 66    error: Option<anyhow::Error>,
 67    generation: Task<()>,
 68    idle: bool,
 69    _subscription: gpui::Subscription,
 70}
 71
 72impl Entity for Codegen {
 73    type Event = Event;
 74}
 75
 76impl Codegen {
 77    pub fn new(
 78        buffer: ModelHandle<MultiBuffer>,
 79        range: Range<Anchor>,
 80        provider: Arc<dyn CompletionProvider>,
 81        cx: &mut ModelContext<Self>,
 82    ) -> Self {
 83        Self {
 84            provider,
 85            buffer: buffer.clone(),
 86            range,
 87            last_equal_ranges: Default::default(),
 88            transaction_id: Default::default(),
 89            error: Default::default(),
 90            idle: true,
 91            generation: Task::ready(()),
 92            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
 93        }
 94    }
 95
 96    fn handle_buffer_event(
 97        &mut self,
 98        _buffer: ModelHandle<MultiBuffer>,
 99        event: &multi_buffer::Event,
100        cx: &mut ModelContext<Self>,
101    ) {
102        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
103            if self.transaction_id == Some(*transaction_id) {
104                self.transaction_id = None;
105                self.generation = Task::ready(());
106                cx.emit(Event::Undone);
107            }
108        }
109    }
110
111    pub fn range(&self) -> Range<Anchor> {
112        self.range.clone()
113    }
114
115    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
116        &self.last_equal_ranges
117    }
118
119    pub fn idle(&self) -> bool {
120        self.idle
121    }
122
123    pub fn error(&self) -> Option<&anyhow::Error> {
124        self.error.as_ref()
125    }
126
127    pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
128        let range = self.range.clone();
129        let snapshot = self.buffer.read(cx).snapshot(cx);
130        let selected_text = snapshot
131            .text_for_range(range.start..range.end)
132            .collect::<Rope>();
133
134        let selection_start = range.start.to_point(&snapshot);
135        let selection_end = range.end.to_point(&snapshot);
136
137        let mut base_indent: Option<IndentSize> = None;
138        let mut start_row = selection_start.row;
139        if snapshot.is_line_blank(start_row) {
140            if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) {
141                start_row = prev_non_blank_row;
142            }
143        }
144        for row in start_row..=selection_end.row {
145            if snapshot.is_line_blank(row) {
146                continue;
147            }
148
149            let line_indent = snapshot.indent_size_for_line(row);
150            if let Some(base_indent) = base_indent.as_mut() {
151                if line_indent.len < base_indent.len {
152                    *base_indent = line_indent;
153                }
154            } else {
155                base_indent = Some(line_indent);
156            }
157        }
158
159        let mut normalized_selected_text = selected_text.clone();
160        if let Some(base_indent) = base_indent {
161            for row in selection_start.row..=selection_end.row {
162                let selection_row = row - selection_start.row;
163                let line_start =
164                    normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
165                let indent_len = if row == selection_start.row {
166                    base_indent.len.saturating_sub(selection_start.column)
167                } else {
168                    let line_len = normalized_selected_text.line_len(selection_row);
169                    cmp::min(line_len, base_indent.len)
170                };
171                let indent_end = cmp::min(
172                    line_start + indent_len as usize,
173                    normalized_selected_text.len(),
174                );
175                normalized_selected_text.replace(line_start..indent_end, "");
176            }
177        }
178
179        let response = self.provider.complete(prompt);
180        self.generation = cx.spawn_weak(|this, mut cx| {
181            async move {
182                let generate = async {
183                    let mut edit_start = range.start.to_offset(&snapshot);
184
185                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
186                    let diff = cx.background().spawn(async move {
187                        let chunks = strip_markdown_codeblock(response.await?);
188                        futures::pin_mut!(chunks);
189                        let mut diff = StreamingDiff::new(selected_text.to_string());
190
191                        let mut indent_len;
192                        let indent_text;
193                        if let Some(base_indent) = base_indent {
194                            indent_len = base_indent.len;
195                            indent_text = match base_indent.kind {
196                                language::IndentKind::Space => " ",
197                                language::IndentKind::Tab => "\t",
198                            };
199                        } else {
200                            indent_len = 0;
201                            indent_text = "";
202                        };
203
204                        let mut first_line_len = 0;
205                        let mut first_line_non_whitespace_char_ix = None;
206                        let mut first_line = true;
207                        let mut new_text = String::new();
208
209                        while let Some(chunk) = chunks.next().await {
210                            let chunk = chunk?;
211
212                            let mut lines = chunk.split('\n');
213                            if let Some(mut line) = lines.next() {
214                                if first_line {
215                                    if first_line_non_whitespace_char_ix.is_none() {
216                                        if let Some(mut char_ix) =
217                                            line.find(|ch: char| !ch.is_whitespace())
218                                        {
219                                            line = &line[char_ix..];
220                                            char_ix += first_line_len;
221                                            first_line_non_whitespace_char_ix = Some(char_ix);
222                                            let first_line_indent = char_ix
223                                                .saturating_sub(selection_start.column as usize)
224                                                as usize;
225                                            new_text
226                                                .push_str(&indent_text.repeat(first_line_indent));
227                                            indent_len = indent_len.saturating_sub(char_ix as u32);
228                                        }
229                                    }
230                                    first_line_len += line.len();
231                                }
232
233                                if first_line_non_whitespace_char_ix.is_some() {
234                                    new_text.push_str(line);
235                                }
236                            }
237
238                            for line in lines {
239                                first_line = false;
240                                new_text.push('\n');
241                                if !line.is_empty() {
242                                    new_text.push_str(&indent_text.repeat(indent_len as usize));
243                                }
244                                new_text.push_str(line);
245                            }
246
247                            let hunks = diff.push_new(&new_text);
248                            hunks_tx.send(hunks).await?;
249                            new_text.clear();
250                        }
251                        hunks_tx.send(diff.finish()).await?;
252
253                        anyhow::Ok(())
254                    });
255
256                    while let Some(hunks) = hunks_rx.next().await {
257                        let this = if let Some(this) = this.upgrade(&cx) {
258                            this
259                        } else {
260                            break;
261                        };
262
263                        this.update(&mut cx, |this, cx| {
264                            this.last_equal_ranges.clear();
265
266                            let transaction = this.buffer.update(cx, |buffer, cx| {
267                                // Avoid grouping assistant edits with user edits.
268                                buffer.finalize_last_transaction(cx);
269
270                                buffer.start_transaction(cx);
271                                buffer.edit(
272                                    hunks.into_iter().filter_map(|hunk| match hunk {
273                                        Hunk::Insert { text } => {
274                                            let edit_start = snapshot.anchor_after(edit_start);
275                                            Some((edit_start..edit_start, text))
276                                        }
277                                        Hunk::Remove { len } => {
278                                            let edit_end = edit_start + len;
279                                            let edit_range = snapshot.anchor_after(edit_start)
280                                                ..snapshot.anchor_before(edit_end);
281                                            edit_start = edit_end;
282                                            Some((edit_range, String::new()))
283                                        }
284                                        Hunk::Keep { len } => {
285                                            let edit_end = edit_start + len;
286                                            let edit_range = snapshot.anchor_after(edit_start)
287                                                ..snapshot.anchor_before(edit_end);
288                                            edit_start += len;
289                                            this.last_equal_ranges.push(edit_range);
290                                            None
291                                        }
292                                    }),
293                                    None,
294                                    cx,
295                                );
296
297                                buffer.end_transaction(cx)
298                            });
299
300                            if let Some(transaction) = transaction {
301                                if let Some(first_transaction) = this.transaction_id {
302                                    // Group all assistant edits into the first transaction.
303                                    this.buffer.update(cx, |buffer, cx| {
304                                        buffer.merge_transactions(
305                                            transaction,
306                                            first_transaction,
307                                            cx,
308                                        )
309                                    });
310                                } else {
311                                    this.transaction_id = Some(transaction);
312                                    this.buffer.update(cx, |buffer, cx| {
313                                        buffer.finalize_last_transaction(cx)
314                                    });
315                                }
316                            }
317
318                            cx.notify();
319                        });
320                    }
321
322                    diff.await?;
323                    anyhow::Ok(())
324                };
325
326                let result = generate.await;
327                if let Some(this) = this.upgrade(&cx) {
328                    this.update(&mut cx, |this, cx| {
329                        this.last_equal_ranges.clear();
330                        this.idle = true;
331                        if let Err(error) = result {
332                            this.error = Some(error);
333                        }
334                        cx.emit(Event::Finished);
335                        cx.notify();
336                    });
337                }
338            }
339        });
340        self.error.take();
341        self.idle = false;
342        cx.notify();
343    }
344
345    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
346        if let Some(transaction_id) = self.transaction_id {
347            self.buffer
348                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
349        }
350    }
351}
352
353fn strip_markdown_codeblock(
354    stream: impl Stream<Item = Result<String>>,
355) -> impl Stream<Item = Result<String>> {
356    let mut first_line = true;
357    let mut buffer = String::new();
358    let mut starts_with_fenced_code_block = false;
359    stream.filter_map(move |chunk| {
360        let chunk = match chunk {
361            Ok(chunk) => chunk,
362            Err(err) => return future::ready(Some(Err(err))),
363        };
364        buffer.push_str(&chunk);
365
366        if first_line {
367            if buffer == "" || buffer == "`" || buffer == "``" {
368                return future::ready(None);
369            } else if buffer.starts_with("```") {
370                starts_with_fenced_code_block = true;
371                if let Some(newline_ix) = buffer.find('\n') {
372                    buffer.replace_range(..newline_ix + 1, "");
373                    first_line = false;
374                } else {
375                    return future::ready(None);
376                }
377            }
378        }
379
380        let text = if starts_with_fenced_code_block {
381            buffer
382                .strip_suffix("\n```\n")
383                .or_else(|| buffer.strip_suffix("\n```"))
384                .or_else(|| buffer.strip_suffix("\n``"))
385                .or_else(|| buffer.strip_suffix("\n`"))
386                .or_else(|| buffer.strip_suffix('\n'))
387                .unwrap_or(&buffer)
388        } else {
389            &buffer
390        };
391
392        if text.contains('\n') {
393            first_line = false;
394        }
395
396        let remainder = buffer.split_off(text.len());
397        let result = if buffer.is_empty() {
398            None
399        } else {
400            Some(Ok(buffer.clone()))
401        };
402        buffer = remainder;
403        future::ready(result)
404    })
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    #[gpui::test]
    async fn test_strip_markdown_codeblock() {
        // (input, expected output) pairs, each fed through in 2-character chunks.
        let cases = [
            ("Lorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```\n", "Lorem ipsum dolor"),
            (
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                "```js\nLorem ipsum dolor\n```",
            ),
            ("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(strip_in_chunks(input, 2).await, expected, "input: {:?}", input);
        }
    }

    /// Runs `text` through `strip_markdown_codeblock` in `size`-character chunks and
    /// concatenates the output.
    async fn strip_in_chunks(text: &str, size: usize) -> String {
        strip_markdown_codeblock(chunks(text, size))
            .map(|chunk| chunk.unwrap())
            .collect::<String>()
            .await
    }

    /// Splits `text` into a stream of successful `size`-character chunks.
    fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
        let chars = text.chars().collect::<Vec<_>>();
        let pieces = chars
            .chunks(size)
            .map(|piece| Ok(piece.iter().collect::<String>()))
            .collect::<Vec<_>>();
        stream::iter(pieces)
    }
}