1use crate::{
2 stream_completion,
3 streaming_diff::{Hunk, StreamingDiff},
4 OpenAIRequest,
5};
6use anyhow::Result;
7use editor::{
8 multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
9};
10use futures::{
11 channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
12};
13use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
14use language::{Rope, TransactionId};
15use std::{cmp, future, ops::Range, sync::Arc};
16
17pub trait CompletionProvider {
18 fn complete(
19 &self,
20 prompt: OpenAIRequest,
21 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
22}
23
24pub struct OpenAICompletionProvider {
25 api_key: String,
26 executor: Arc<Background>,
27}
28
29impl OpenAICompletionProvider {
30 pub fn new(api_key: String, executor: Arc<Background>) -> Self {
31 Self { api_key, executor }
32 }
33}
34
35impl CompletionProvider for OpenAICompletionProvider {
36 fn complete(
37 &self,
38 prompt: OpenAIRequest,
39 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
40 let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
41 async move {
42 let response = request.await?;
43 let stream = response
44 .filter_map(|response| async move {
45 match response {
46 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
47 Err(error) => Some(Err(error)),
48 }
49 })
50 .boxed();
51 Ok(stream)
52 }
53 .boxed()
54 }
55}
56
/// Notifications a [`Codegen`] emits to its subscribers.
pub enum Event {
    /// The running generation stopped, successfully or with an error
    /// (the error, if any, is then available from `Codegen::error`).
    Finished,
    /// The buffer transaction holding the generated edits was undone.
    Undone,
}
61
62#[derive(Clone)]
63pub enum CodegenKind {
64 Transform { range: Range<Anchor> },
65 Generate { position: Anchor },
66}
67
/// Streams a completion from a [`CompletionProvider`] into a [`MultiBuffer`],
/// applying it as incremental edits computed by a streaming diff.
pub struct Codegen {
    /// Source of the completion text.
    provider: Arc<dyn CompletionProvider>,
    /// Buffer that generated edits are applied to.
    buffer: ModelHandle<MultiBuffer>,
    /// Buffer snapshot taken at construction; the target range, the original text,
    /// and the diff offsets are all resolved against it.
    snapshot: MultiBufferSnapshot,
    /// Target location, normalized in `new`.
    kind: CodegenKind,
    /// Ranges of the original text kept unchanged by the most recently applied
    /// batch of diff hunks; cleared when a generation finishes.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction that all edits of the generation are merged into, so a single
    /// undo reverts them.
    transaction_id: Option<TransactionId>,
    /// Error from the most recent generation, if it failed.
    error: Option<anyhow::Error>,
    /// Task driving the current generation (`Task::ready(())` when there is none).
    generation: Task<()>,
    /// `false` while a generation is running.
    idle: bool,
    /// Subscription to `buffer` events, routed to `handle_buffer_event`.
    _subscription: gpui::Subscription,
}
80
81impl Entity for Codegen {
82 type Event = Event;
83}
84
85impl Codegen {
86 pub fn new(
87 buffer: ModelHandle<MultiBuffer>,
88 mut kind: CodegenKind,
89 provider: Arc<dyn CompletionProvider>,
90 cx: &mut ModelContext<Self>,
91 ) -> Self {
92 let snapshot = buffer.read(cx).snapshot(cx);
93 match &mut kind {
94 CodegenKind::Transform { range } => {
95 let mut point_range = range.to_point(&snapshot);
96 point_range.start.column = 0;
97 if point_range.end.column > 0 || point_range.start.row == point_range.end.row {
98 point_range.end.column = snapshot.line_len(point_range.end.row);
99 }
100 range.start = snapshot.anchor_before(point_range.start);
101 range.end = snapshot.anchor_after(point_range.end);
102 }
103 CodegenKind::Generate { position } => {
104 *position = position.bias_right(&snapshot);
105 }
106 }
107
108 Self {
109 provider,
110 buffer: buffer.clone(),
111 snapshot,
112 kind,
113 last_equal_ranges: Default::default(),
114 transaction_id: Default::default(),
115 error: Default::default(),
116 idle: true,
117 generation: Task::ready(()),
118 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
119 }
120 }
121
122 fn handle_buffer_event(
123 &mut self,
124 _buffer: ModelHandle<MultiBuffer>,
125 event: &multi_buffer::Event,
126 cx: &mut ModelContext<Self>,
127 ) {
128 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
129 if self.transaction_id == Some(*transaction_id) {
130 self.transaction_id = None;
131 self.generation = Task::ready(());
132 cx.emit(Event::Undone);
133 }
134 }
135 }
136
137 pub fn range(&self) -> Range<Anchor> {
138 match &self.kind {
139 CodegenKind::Transform { range } => range.clone(),
140 CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
141 }
142 }
143
144 pub fn kind(&self) -> &CodegenKind {
145 &self.kind
146 }
147
148 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
149 &self.last_equal_ranges
150 }
151
152 pub fn idle(&self) -> bool {
153 self.idle
154 }
155
156 pub fn error(&self) -> Option<&anyhow::Error> {
157 self.error.as_ref()
158 }
159
160 pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
161 let range = self.range();
162 let snapshot = self.snapshot.clone();
163 let selected_text = snapshot
164 .text_for_range(range.start..range.end)
165 .collect::<Rope>();
166
167 let selection_start = range.start.to_point(&snapshot);
168 let suggested_line_indent = snapshot
169 .suggested_indents(selection_start.row..selection_start.row + 1, cx)
170 .into_values()
171 .next()
172 .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
173
174 let response = self.provider.complete(prompt);
175 self.generation = cx.spawn_weak(|this, mut cx| {
176 async move {
177 let generate = async {
178 let mut edit_start = range.start.to_offset(&snapshot);
179
180 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
181 let diff = cx.background().spawn(async move {
182 let chunks = strip_markdown_codeblock(response.await?);
183 futures::pin_mut!(chunks);
184 let mut diff = StreamingDiff::new(selected_text.to_string());
185
186 let mut new_text = String::new();
187 let mut base_indent = None;
188 let mut line_indent = None;
189 let mut first_line = true;
190
191 while let Some(chunk) = chunks.next().await {
192 let chunk = chunk?;
193
194 let mut lines = chunk.split('\n').peekable();
195 while let Some(line) = lines.next() {
196 new_text.push_str(line);
197 if line_indent.is_none() {
198 if let Some(non_whitespace_ch_ix) =
199 new_text.find(|ch: char| !ch.is_whitespace())
200 {
201 line_indent = Some(non_whitespace_ch_ix);
202 base_indent = base_indent.or(line_indent);
203
204 let line_indent = line_indent.unwrap();
205 let base_indent = base_indent.unwrap();
206 let indent_delta = line_indent as i32 - base_indent as i32;
207 let mut corrected_indent_len = cmp::max(
208 0,
209 suggested_line_indent.len as i32 + indent_delta,
210 )
211 as usize;
212 if first_line {
213 corrected_indent_len = corrected_indent_len
214 .saturating_sub(selection_start.column as usize);
215 }
216
217 let indent_char = suggested_line_indent.char();
218 let mut indent_buffer = [0; 4];
219 let indent_str =
220 indent_char.encode_utf8(&mut indent_buffer);
221 new_text.replace_range(
222 ..line_indent,
223 &indent_str.repeat(corrected_indent_len),
224 );
225 }
226 }
227
228 if lines.peek().is_some() {
229 hunks_tx.send(diff.push_new(&new_text)).await?;
230 hunks_tx.send(diff.push_new("\n")).await?;
231 new_text.clear();
232 line_indent = None;
233 first_line = false;
234 }
235 }
236 }
237 hunks_tx.send(diff.push_new(&new_text)).await?;
238 hunks_tx.send(diff.finish()).await?;
239
240 anyhow::Ok(())
241 });
242
243 while let Some(hunks) = hunks_rx.next().await {
244 let this = if let Some(this) = this.upgrade(&cx) {
245 this
246 } else {
247 break;
248 };
249
250 this.update(&mut cx, |this, cx| {
251 this.last_equal_ranges.clear();
252
253 let transaction = this.buffer.update(cx, |buffer, cx| {
254 // Avoid grouping assistant edits with user edits.
255 buffer.finalize_last_transaction(cx);
256
257 buffer.start_transaction(cx);
258 buffer.edit(
259 hunks.into_iter().filter_map(|hunk| match hunk {
260 Hunk::Insert { text } => {
261 let edit_start = snapshot.anchor_after(edit_start);
262 Some((edit_start..edit_start, text))
263 }
264 Hunk::Remove { len } => {
265 let edit_end = edit_start + len;
266 let edit_range = snapshot.anchor_after(edit_start)
267 ..snapshot.anchor_before(edit_end);
268 edit_start = edit_end;
269 Some((edit_range, String::new()))
270 }
271 Hunk::Keep { len } => {
272 let edit_end = edit_start + len;
273 let edit_range = snapshot.anchor_after(edit_start)
274 ..snapshot.anchor_before(edit_end);
275 edit_start = edit_end;
276 this.last_equal_ranges.push(edit_range);
277 None
278 }
279 }),
280 None,
281 cx,
282 );
283
284 buffer.end_transaction(cx)
285 });
286
287 if let Some(transaction) = transaction {
288 if let Some(first_transaction) = this.transaction_id {
289 // Group all assistant edits into the first transaction.
290 this.buffer.update(cx, |buffer, cx| {
291 buffer.merge_transactions(
292 transaction,
293 first_transaction,
294 cx,
295 )
296 });
297 } else {
298 this.transaction_id = Some(transaction);
299 this.buffer.update(cx, |buffer, cx| {
300 buffer.finalize_last_transaction(cx)
301 });
302 }
303 }
304
305 cx.notify();
306 });
307 }
308
309 diff.await?;
310 anyhow::Ok(())
311 };
312
313 let result = generate.await;
314 if let Some(this) = this.upgrade(&cx) {
315 this.update(&mut cx, |this, cx| {
316 this.last_equal_ranges.clear();
317 this.idle = true;
318 if let Err(error) = result {
319 this.error = Some(error);
320 }
321 cx.emit(Event::Finished);
322 cx.notify();
323 });
324 }
325 }
326 });
327 self.error.take();
328 self.idle = false;
329 cx.notify();
330 }
331
332 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
333 if let Some(transaction_id) = self.transaction_id {
334 self.buffer
335 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
336 }
337 }
338}
339
340fn strip_markdown_codeblock(
341 stream: impl Stream<Item = Result<String>>,
342) -> impl Stream<Item = Result<String>> {
343 let mut first_line = true;
344 let mut buffer = String::new();
345 let mut starts_with_fenced_code_block = false;
346 stream.filter_map(move |chunk| {
347 let chunk = match chunk {
348 Ok(chunk) => chunk,
349 Err(err) => return future::ready(Some(Err(err))),
350 };
351 buffer.push_str(&chunk);
352
353 if first_line {
354 if buffer == "" || buffer == "`" || buffer == "``" {
355 return future::ready(None);
356 } else if buffer.starts_with("```") {
357 starts_with_fenced_code_block = true;
358 if let Some(newline_ix) = buffer.find('\n') {
359 buffer.replace_range(..newline_ix + 1, "");
360 first_line = false;
361 } else {
362 return future::ready(None);
363 }
364 }
365 }
366
367 let text = if starts_with_fenced_code_block {
368 buffer
369 .strip_suffix("\n```\n")
370 .or_else(|| buffer.strip_suffix("\n```"))
371 .or_else(|| buffer.strip_suffix("\n``"))
372 .or_else(|| buffer.strip_suffix("\n`"))
373 .or_else(|| buffer.strip_suffix('\n'))
374 .unwrap_or(&buffer)
375 } else {
376 &buffer
377 };
378
379 if text.contains('\n') {
380 first_line = false;
381 }
382
383 let remainder = buffer.split_off(text.len());
384 let result = if buffer.is_empty() {
385 None
386 } else {
387 Some(Ok(buffer.clone()))
388 };
389 buffer = remainder;
390 future::ready(result)
391 })
392}
393
#[cfg(test)]
mod tests {
    // NOTE(review): in this copy of the file, runs of whitespace inside the string
    // literals below (the `concat!` pieces and the `indoc!` bodies) appear to have been
    // collapsed to single spaces. As written, the `indoc!` texts are effectively
    // flush-left while `Codegen::start` re-indents generated lines, so the autoindent
    // tests likely fail as-is. Restore the literals from upstream before relying on them.
    use super::*;
    use futures::stream;
    use gpui::{executor::Deterministic, TestAppContext};
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use parking_lot::Mutex;
    use rand::prelude::*;
    use settings::SettingsStore;

    /// `Transform` over a multi-line range. The replacement is streamed in randomly
    /// sized chunks (1..=10 bytes) and the final buffer text is checked.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
            let x = 0;
            for _ in 0..10 {
            x += 1;
            }
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Transform { range },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        let mut new_text = concat!(
            " let mut x = 0;\n",
            " while x < 10 {\n",
            " x += 1;\n",
            " }",
        );
        // Feed the completion in random chunk sizes, letting the codegen process each.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk);
            new_text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                let mut x = 0;
                while x < 10 {
                x += 1;
                }
                }
            "}
        );
    }

    /// `Generate` at a cursor position past the start of its line, after existing
    /// text; the streamed completion continues that line.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
            le
            }
        "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            " x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk);
            new_text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                let mut x = 0;
                while x < 10 {
                x += 1;
                }
                }
            "}
        );
    }

    /// `Generate` on a whitespace-only line; generated lines should be re-indented
    /// relative to the suggested indentation.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            " \n",
            "}\n" //
        );
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            " x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk);
            new_text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                let mut x = 0;
                while x < 10 {
                x += 1;
                }
                }
            "}
        );
    }

    /// Fence stripping with input split into 2-character chunks, so fences straddle
    /// chunk boundaries.
    #[gpui::test]
    async fn test_strip_markdown_codeblock() {
        // No fence: passed through unchanged.
        assert_eq!(
            strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        // Opening fence only.
        assert_eq!(
            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        // Opening and closing fence.
        assert_eq!(
            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        // Closing fence followed by a newline.
        assert_eq!(
            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        // Only the outermost fence (with info string) is stripped.
        assert_eq!(
            strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "```js\nLorem ipsum dolor\n```"
        );
        // Two backticks are not a fence, so nothing is stripped.
        assert_eq!(
            strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "``\nLorem ipsum dolor\n```"
        );

        /// Splits `text` into successive `size`-character chunks, as a stream of `Ok` items.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// A `CompletionProvider` whose completion stream is fed by the test itself.
    struct TestCompletionProvider {
        // Sender for the most recently started completion, if still open.
        last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
    }

    impl TestCompletionProvider {
        fn new() -> Self {
            Self {
                last_completion_tx: Mutex::new(None),
            }
        }

        /// Pushes `completion` into the active completion stream.
        /// Panics if no completion is active or the send fails.
        fn send_completion(&self, completion: impl Into<String>) {
            let mut tx = self.last_completion_tx.lock();
            tx.as_mut().unwrap().try_send(completion.into()).unwrap();
        }

        /// Drops the active sender, ending the completion stream.
        /// Panics if no completion is active.
        fn finish_completion(&self) {
            self.last_completion_tx.lock().take().unwrap();
        }
    }

    impl CompletionProvider for TestCompletionProvider {
        /// Opens a fresh channel, keeps its sender for the test to drive, and returns
        /// the receiving end as the completion stream.
        fn complete(
            &self,
            _prompt: OpenAIRequest,
        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
            let (tx, rx) = mpsc::channel(1);
            *self.last_completion_tx.lock() = Some(tx);
            async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
        }
    }

    /// A minimal Rust language with an indents query, so the buffer can produce
    /// suggested indentation for the autoindent tests.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}