1use crate::streaming_diff::{Hunk, StreamingDiff};
2use ai::completion::{CompletionProvider, OpenAIRequest};
3use anyhow::Result;
4use editor::{
5 multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
6};
7use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
8use gpui::{Entity, ModelContext, ModelHandle, Task};
9use language::{Rope, TransactionId};
10use std::{cmp, future, ops::Range, sync::Arc};
11
/// Events emitted by a [`Codegen`] model.
pub enum Event {
    /// A generation run ended, whether it succeeded or failed; any failure is
    /// available through [`Codegen::error`].
    Finished,
    /// The transaction holding this session's edits was undone.
    Undone,
}
16
/// What a [`Codegen`] session does with the generated text.
#[derive(Clone)]
pub enum CodegenKind {
    /// Replace the text in `range` with the generated text (diffed against the
    /// existing text, so unchanged parts are kept in place).
    Transform { range: Range<Anchor> },
    /// Insert the generated text at `position`.
    Generate { position: Anchor },
}
22
/// Streams a completion from a [`CompletionProvider`] into a region of a
/// multi-buffer, applying it as a live diff against the text that was there
/// when the session was created.
pub struct Codegen {
    /// Source of completions.
    provider: Arc<dyn CompletionProvider>,
    /// Buffer that generated text is written into.
    buffer: ModelHandle<MultiBuffer>,
    /// Snapshot taken in `new`; the target and all edit offsets are resolved
    /// against it.
    snapshot: MultiBufferSnapshot,
    /// Target of the generation, normalized in `new`.
    kind: CodegenKind,
    /// Ranges the most recent batch of diff hunks left unchanged; cleared on
    /// every batch and when a generation finishes.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// First transaction containing generated edits; later edits are merged
    /// into it so one undo reverts the whole generation.
    transaction_id: Option<TransactionId>,
    /// Error from the most recent generation, if it failed.
    error: Option<anyhow::Error>,
    /// The in-flight generation task; assigning a new task drops the old one.
    generation: Task<()>,
    /// `false` while a generation started by `start` is running.
    idle: bool,
    /// Holds the buffer event subscription used to detect undo.
    _subscription: gpui::Subscription,
}
35
// Registers `Codegen` as a gpui model whose subscribers receive `Event`s.
impl Entity for Codegen {
    type Event = Event;
}
39
40impl Codegen {
41 pub fn new(
42 buffer: ModelHandle<MultiBuffer>,
43 mut kind: CodegenKind,
44 provider: Arc<dyn CompletionProvider>,
45 cx: &mut ModelContext<Self>,
46 ) -> Self {
47 let snapshot = buffer.read(cx).snapshot(cx);
48 match &mut kind {
49 CodegenKind::Transform { range } => {
50 let mut point_range = range.to_point(&snapshot);
51 point_range.start.column = 0;
52 if point_range.end.column > 0 || point_range.start.row == point_range.end.row {
53 point_range.end.column = snapshot.line_len(point_range.end.row);
54 }
55 range.start = snapshot.anchor_before(point_range.start);
56 range.end = snapshot.anchor_after(point_range.end);
57 }
58 CodegenKind::Generate { position } => {
59 *position = position.bias_right(&snapshot);
60 }
61 }
62
63 Self {
64 provider,
65 buffer: buffer.clone(),
66 snapshot,
67 kind,
68 last_equal_ranges: Default::default(),
69 transaction_id: Default::default(),
70 error: Default::default(),
71 idle: true,
72 generation: Task::ready(()),
73 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
74 }
75 }
76
77 fn handle_buffer_event(
78 &mut self,
79 _buffer: ModelHandle<MultiBuffer>,
80 event: &multi_buffer::Event,
81 cx: &mut ModelContext<Self>,
82 ) {
83 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
84 if self.transaction_id == Some(*transaction_id) {
85 self.transaction_id = None;
86 self.generation = Task::ready(());
87 cx.emit(Event::Undone);
88 }
89 }
90 }
91
92 pub fn range(&self) -> Range<Anchor> {
93 match &self.kind {
94 CodegenKind::Transform { range } => range.clone(),
95 CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
96 }
97 }
98
99 pub fn kind(&self) -> &CodegenKind {
100 &self.kind
101 }
102
103 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
104 &self.last_equal_ranges
105 }
106
107 pub fn idle(&self) -> bool {
108 self.idle
109 }
110
111 pub fn error(&self) -> Option<&anyhow::Error> {
112 self.error.as_ref()
113 }
114
115 pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
116 let range = self.range();
117 let snapshot = self.snapshot.clone();
118 let selected_text = snapshot
119 .text_for_range(range.start..range.end)
120 .collect::<Rope>();
121
122 let selection_start = range.start.to_point(&snapshot);
123 let suggested_line_indent = snapshot
124 .suggested_indents(selection_start.row..selection_start.row + 1, cx)
125 .into_values()
126 .next()
127 .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
128
129 let response = self.provider.complete(prompt);
130 self.generation = cx.spawn_weak(|this, mut cx| {
131 async move {
132 let generate = async {
133 let mut edit_start = range.start.to_offset(&snapshot);
134
135 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
136 let diff = cx.background().spawn(async move {
137 let chunks = strip_markdown_codeblock(response.await?);
138 futures::pin_mut!(chunks);
139 let mut diff = StreamingDiff::new(selected_text.to_string());
140
141 let mut new_text = String::new();
142 let mut base_indent = None;
143 let mut line_indent = None;
144 let mut first_line = true;
145
146 while let Some(chunk) = chunks.next().await {
147 let chunk = chunk?;
148
149 let mut lines = chunk.split('\n').peekable();
150 while let Some(line) = lines.next() {
151 new_text.push_str(line);
152 if line_indent.is_none() {
153 if let Some(non_whitespace_ch_ix) =
154 new_text.find(|ch: char| !ch.is_whitespace())
155 {
156 line_indent = Some(non_whitespace_ch_ix);
157 base_indent = base_indent.or(line_indent);
158
159 let line_indent = line_indent.unwrap();
160 let base_indent = base_indent.unwrap();
161 let indent_delta = line_indent as i32 - base_indent as i32;
162 let mut corrected_indent_len = cmp::max(
163 0,
164 suggested_line_indent.len as i32 + indent_delta,
165 )
166 as usize;
167 if first_line {
168 corrected_indent_len = corrected_indent_len
169 .saturating_sub(selection_start.column as usize);
170 }
171
172 let indent_char = suggested_line_indent.char();
173 let mut indent_buffer = [0; 4];
174 let indent_str =
175 indent_char.encode_utf8(&mut indent_buffer);
176 new_text.replace_range(
177 ..line_indent,
178 &indent_str.repeat(corrected_indent_len),
179 );
180 }
181 }
182
183 if line_indent.is_some() {
184 hunks_tx.send(diff.push_new(&new_text)).await?;
185 new_text.clear();
186 }
187
188 if lines.peek().is_some() {
189 hunks_tx.send(diff.push_new("\n")).await?;
190 line_indent = None;
191 first_line = false;
192 }
193 }
194 }
195 hunks_tx.send(diff.push_new(&new_text)).await?;
196 hunks_tx.send(diff.finish()).await?;
197
198 anyhow::Ok(())
199 });
200
201 while let Some(hunks) = hunks_rx.next().await {
202 let this = if let Some(this) = this.upgrade(&cx) {
203 this
204 } else {
205 break;
206 };
207
208 this.update(&mut cx, |this, cx| {
209 this.last_equal_ranges.clear();
210
211 let transaction = this.buffer.update(cx, |buffer, cx| {
212 // Avoid grouping assistant edits with user edits.
213 buffer.finalize_last_transaction(cx);
214
215 buffer.start_transaction(cx);
216 buffer.edit(
217 hunks.into_iter().filter_map(|hunk| match hunk {
218 Hunk::Insert { text } => {
219 let edit_start = snapshot.anchor_after(edit_start);
220 Some((edit_start..edit_start, text))
221 }
222 Hunk::Remove { len } => {
223 let edit_end = edit_start + len;
224 let edit_range = snapshot.anchor_after(edit_start)
225 ..snapshot.anchor_before(edit_end);
226 edit_start = edit_end;
227 Some((edit_range, String::new()))
228 }
229 Hunk::Keep { len } => {
230 let edit_end = edit_start + len;
231 let edit_range = snapshot.anchor_after(edit_start)
232 ..snapshot.anchor_before(edit_end);
233 edit_start = edit_end;
234 this.last_equal_ranges.push(edit_range);
235 None
236 }
237 }),
238 None,
239 cx,
240 );
241
242 buffer.end_transaction(cx)
243 });
244
245 if let Some(transaction) = transaction {
246 if let Some(first_transaction) = this.transaction_id {
247 // Group all assistant edits into the first transaction.
248 this.buffer.update(cx, |buffer, cx| {
249 buffer.merge_transactions(
250 transaction,
251 first_transaction,
252 cx,
253 )
254 });
255 } else {
256 this.transaction_id = Some(transaction);
257 this.buffer.update(cx, |buffer, cx| {
258 buffer.finalize_last_transaction(cx)
259 });
260 }
261 }
262
263 cx.notify();
264 });
265 }
266
267 diff.await?;
268 anyhow::Ok(())
269 };
270
271 let result = generate.await;
272 if let Some(this) = this.upgrade(&cx) {
273 this.update(&mut cx, |this, cx| {
274 this.last_equal_ranges.clear();
275 this.idle = true;
276 if let Err(error) = result {
277 this.error = Some(error);
278 }
279 cx.emit(Event::Finished);
280 cx.notify();
281 });
282 }
283 }
284 });
285 self.error.take();
286 self.idle = false;
287 cx.notify();
288 }
289
290 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
291 if let Some(transaction_id) = self.transaction_id {
292 self.buffer
293 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
294 }
295 }
296}
297
298fn strip_markdown_codeblock(
299 stream: impl Stream<Item = Result<String>>,
300) -> impl Stream<Item = Result<String>> {
301 let mut first_line = true;
302 let mut buffer = String::new();
303 let mut starts_with_fenced_code_block = false;
304 stream.filter_map(move |chunk| {
305 let chunk = match chunk {
306 Ok(chunk) => chunk,
307 Err(err) => return future::ready(Some(Err(err))),
308 };
309 buffer.push_str(&chunk);
310
311 if first_line {
312 if buffer == "" || buffer == "`" || buffer == "``" {
313 return future::ready(None);
314 } else if buffer.starts_with("```") {
315 starts_with_fenced_code_block = true;
316 if let Some(newline_ix) = buffer.find('\n') {
317 buffer.replace_range(..newline_ix + 1, "");
318 first_line = false;
319 } else {
320 return future::ready(None);
321 }
322 }
323 }
324
325 let text = if starts_with_fenced_code_block {
326 buffer
327 .strip_suffix("\n```\n")
328 .or_else(|| buffer.strip_suffix("\n```"))
329 .or_else(|| buffer.strip_suffix("\n``"))
330 .or_else(|| buffer.strip_suffix("\n`"))
331 .or_else(|| buffer.strip_suffix('\n'))
332 .unwrap_or(&buffer)
333 } else {
334 &buffer
335 };
336
337 if text.contains('\n') {
338 first_line = false;
339 }
340
341 let remainder = buffer.split_off(text.len());
342 let result = if buffer.is_empty() {
343 None
344 } else {
345 Some(Ok(buffer.clone()))
346 };
347 buffer = remainder;
348 future::ready(result)
349 })
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        future::BoxFuture,
        stream::{self, BoxStream},
    };
    use gpui::{executor::Deterministic, TestAppContext};
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use parking_lot::Mutex;
    use rand::prelude::*;
    use settings::SettingsStore;
    use smol::future::FutureExt;

    // Transforms the body of `main` with a completion whose base indentation
    // differs from the buffer's, and checks the result is rebased onto the
    // buffer's indentation while keeping the relative nesting.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Transform { range },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        // Stream the completion in random chunks of 1..=10 bytes so lines and
        // their leading whitespace are split at arbitrary points.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk);
            new_text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // Generates from a cursor placed after existing text ("le") on an indented
    // line; the first generated line must not receive extra indentation.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        // Stream the completion in random chunks of 1..=10 bytes.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk);
            new_text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // Generates from a cursor at column 2 of a whitespace-only line, short of
    // the indentation the expected output uses; the first generated line must
    // be padded up to that indentation.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        // Stream the completion in random chunks of 1..=10 bytes.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk);
            new_text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // Checks fence stripping on inputs delivered in 2-character chunks, so
    // fences are split across chunk boundaries.
    #[gpui::test]
    async fn test_strip_markdown_codeblock() {
        assert_eq!(
            strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        // Only the outermost fence is removed; an inner fence is kept.
        assert_eq!(
            strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "```js\nLorem ipsum dolor\n```"
        );
        // Two backticks are not a fence, so nothing is stripped.
        assert_eq!(
            strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "``\nLorem ipsum dolor\n```"
        );

        // Splits `text` into a stream of `size`-character chunks.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    // A `CompletionProvider` whose responses are fed by hand. Each `complete`
    // call opens a fresh channel; `send_completion` pushes a chunk into the
    // latest one, and `finish_completion` drops its sender to end the stream.
    struct TestCompletionProvider {
        last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
    }

    impl TestCompletionProvider {
        fn new() -> Self {
            Self {
                last_completion_tx: Mutex::new(None),
            }
        }

        // Panics if no completion is open or the channel cannot accept the chunk.
        fn send_completion(&self, completion: impl Into<String>) {
            let mut tx = self.last_completion_tx.lock();
            tx.as_mut().unwrap().try_send(completion.into()).unwrap();
        }

        // Panics if no completion is open.
        fn finish_completion(&self) {
            self.last_completion_tx.lock().take().unwrap();
        }
    }

    impl CompletionProvider for TestCompletionProvider {
        fn complete(
            &self,
            _prompt: OpenAIRequest,
        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
            let (tx, rx) = mpsc::channel(1);
            *self.last_completion_tx.lock() = Some(tx);
            async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
        }
    }

    // A minimal Rust language with an indents query, so that
    // `suggested_indents` has indentation rules to apply in the tests above.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}