1use crate::{
2 stream_completion,
3 streaming_diff::{Hunk, StreamingDiff},
4 OpenAIRequest,
5};
6use anyhow::Result;
7use editor::{
8 multi_buffer, Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
9};
10use futures::{
11 channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
12};
13use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
14use language::{Rope, TransactionId};
15use std::{cmp, future, ops::Range, sync::Arc};
16
/// A source of streamed text completions for [`Codegen`].
///
/// Implemented by [`OpenAICompletionProvider`]; the tests in this module supply a scripted
/// implementation so generation can be driven chunk by chunk.
pub trait CompletionProvider {
    /// Requests a completion for `prompt`.
    ///
    /// The returned future resolves to either an error or a stream of completion text chunks.
    /// Each item of that stream may itself be an error.
    fn complete(
        &self,
        prompt: OpenAIRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
}
23
/// [`CompletionProvider`] backed by the OpenAI streaming API (via `stream_completion`).
pub struct OpenAICompletionProvider {
    /// API key passed to every `stream_completion` request.
    api_key: String,
    /// Background executor handed to `stream_completion`.
    executor: Arc<Background>,
}
28
29impl OpenAICompletionProvider {
30 pub fn new(api_key: String, executor: Arc<Background>) -> Self {
31 Self { api_key, executor }
32 }
33}
34
35impl CompletionProvider for OpenAICompletionProvider {
36 fn complete(
37 &self,
38 prompt: OpenAIRequest,
39 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
40 let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
41 async move {
42 let response = request.await?;
43 let stream = response
44 .filter_map(|response| async move {
45 match response {
46 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
47 Err(error) => Some(Err(error)),
48 }
49 })
50 .boxed();
51 Ok(stream)
52 }
53 .boxed()
54 }
55}
56
/// Events emitted by a [`Codegen`] model.
pub enum Event {
    /// The generation task ended, either because the response stream finished or because it
    /// failed (see [`Codegen::error`]).
    Finished,
    /// The buffer transaction holding the generated edits was undone and the generation task
    /// was dropped.
    Undone,
}
61
/// What a [`Codegen`] session operates on.
#[derive(Clone)]
pub enum CodegenKind {
    /// Rewrite the text in `range`; [`Codegen::new`] widens it to whole lines.
    Transform { range: Range<Anchor> },
    /// Insert newly generated text at `position`.
    Generate { position: Anchor },
}
67
/// Streams a completion into a [`MultiBuffer`], diffing it against the original text and
/// applying the result as (merged) buffer transactions.
pub struct Codegen {
    /// Where completions come from.
    provider: Arc<dyn CompletionProvider>,
    /// The buffer being edited.
    buffer: ModelHandle<MultiBuffer>,
    /// Buffer snapshot taken in [`Codegen::new`]; original text and edit offsets are read
    /// from it.
    snapshot: MultiBufferSnapshot,
    /// The (normalized) range or position being operated on.
    kind: CodegenKind,
    /// Ranges the most recent batch of diff hunks kept unchanged; cleared on each batch and
    /// when generation finishes.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// First transaction of the current generation; later edits are merged into it.
    transaction_id: Option<TransactionId>,
    /// Error from the last generation, if it failed.
    error: Option<anyhow::Error>,
    /// Handle to the running generation task; replacing it drops the previous task.
    generation: Task<()>,
    /// `false` from [`Codegen::start`] until the generation task finishes.
    idle: bool,
    /// Keeps the buffer-event subscription alive.
    _subscription: gpui::Subscription,
}
80
// Registers `Codegen` as a gpui model whose emitted events are `Event`s.
impl Entity for Codegen {
    type Event = Event;
}
84
85impl Codegen {
86 pub fn new(
87 buffer: ModelHandle<MultiBuffer>,
88 mut kind: CodegenKind,
89 provider: Arc<dyn CompletionProvider>,
90 cx: &mut ModelContext<Self>,
91 ) -> Self {
92 let snapshot = buffer.read(cx).snapshot(cx);
93 match &mut kind {
94 CodegenKind::Transform { range } => {
95 let mut point_range = range.to_point(&snapshot);
96 point_range.start.column = 0;
97 if point_range.end.column > 0 || point_range.start.row == point_range.end.row {
98 point_range.end.column = snapshot.line_len(point_range.end.row);
99 }
100 range.start = snapshot.anchor_before(point_range.start);
101 range.end = snapshot.anchor_after(point_range.end);
102 }
103 CodegenKind::Generate { position } => {
104 *position = position.bias_right(&snapshot);
105 }
106 }
107
108 Self {
109 provider,
110 buffer: buffer.clone(),
111 snapshot,
112 kind,
113 last_equal_ranges: Default::default(),
114 transaction_id: Default::default(),
115 error: Default::default(),
116 idle: true,
117 generation: Task::ready(()),
118 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
119 }
120 }
121
122 fn handle_buffer_event(
123 &mut self,
124 _buffer: ModelHandle<MultiBuffer>,
125 event: &multi_buffer::Event,
126 cx: &mut ModelContext<Self>,
127 ) {
128 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
129 if self.transaction_id == Some(*transaction_id) {
130 self.transaction_id = None;
131 self.generation = Task::ready(());
132 cx.emit(Event::Undone);
133 }
134 }
135 }
136
137 pub fn range(&self) -> Range<Anchor> {
138 match &self.kind {
139 CodegenKind::Transform { range } => range.clone(),
140 CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
141 }
142 }
143
144 pub fn kind(&self) -> &CodegenKind {
145 &self.kind
146 }
147
148 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
149 &self.last_equal_ranges
150 }
151
152 pub fn idle(&self) -> bool {
153 self.idle
154 }
155
156 pub fn error(&self) -> Option<&anyhow::Error> {
157 self.error.as_ref()
158 }
159
160 pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
161 let range = self.range();
162 let snapshot = self.snapshot.clone();
163 let selected_text = snapshot
164 .text_for_range(range.start..range.end)
165 .collect::<Rope>();
166
167 let selection_start = range.start.to_point(&snapshot);
168 let suggested_line_indent = snapshot
169 .suggested_indents(selection_start.row..selection_start.row + 1, cx)
170 .into_values()
171 .next()
172 .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));
173
174 let response = self.provider.complete(prompt);
175 self.generation = cx.spawn_weak(|this, mut cx| {
176 async move {
177 let generate = async {
178 let mut edit_start = range.start.to_offset(&snapshot);
179
180 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
181 let diff = cx.background().spawn(async move {
182 let chunks = strip_markdown_codeblock(response.await?);
183 futures::pin_mut!(chunks);
184 let mut diff = StreamingDiff::new(selected_text.to_string());
185
186 let mut new_text = String::new();
187 let mut base_indent = None;
188 let mut line_indent = None;
189 let mut first_line = true;
190
191 while let Some(chunk) = chunks.next().await {
192 let chunk = chunk?;
193
194 let mut lines = chunk.split('\n').peekable();
195 while let Some(line) = lines.next() {
196 new_text.push_str(line);
197 if line_indent.is_none() {
198 if let Some(non_whitespace_ch_ix) =
199 new_text.find(|ch: char| !ch.is_whitespace())
200 {
201 line_indent = Some(non_whitespace_ch_ix);
202 base_indent = base_indent.or(line_indent);
203
204 let line_indent = line_indent.unwrap();
205 let base_indent = base_indent.unwrap();
206 let indent_delta = line_indent as i32 - base_indent as i32;
207 let mut corrected_indent_len = cmp::max(
208 0,
209 suggested_line_indent.len as i32 + indent_delta,
210 )
211 as usize;
212 if first_line {
213 corrected_indent_len = corrected_indent_len
214 .saturating_sub(selection_start.column as usize);
215 }
216
217 let indent_char = suggested_line_indent.char();
218 let mut indent_buffer = [0; 4];
219 let indent_str =
220 indent_char.encode_utf8(&mut indent_buffer);
221 new_text.replace_range(
222 ..line_indent,
223 &indent_str.repeat(corrected_indent_len),
224 );
225 }
226 }
227
228 if line_indent.is_some() {
229 hunks_tx.send(diff.push_new(&new_text)).await?;
230 new_text.clear();
231 }
232
233 if lines.peek().is_some() {
234 hunks_tx.send(diff.push_new("\n")).await?;
235 line_indent = None;
236 first_line = false;
237 }
238 }
239 }
240 hunks_tx.send(diff.push_new(&new_text)).await?;
241 hunks_tx.send(diff.finish()).await?;
242
243 anyhow::Ok(())
244 });
245
246 while let Some(hunks) = hunks_rx.next().await {
247 let this = if let Some(this) = this.upgrade(&cx) {
248 this
249 } else {
250 break;
251 };
252
253 this.update(&mut cx, |this, cx| {
254 this.last_equal_ranges.clear();
255
256 let transaction = this.buffer.update(cx, |buffer, cx| {
257 // Avoid grouping assistant edits with user edits.
258 buffer.finalize_last_transaction(cx);
259
260 buffer.start_transaction(cx);
261 buffer.edit(
262 hunks.into_iter().filter_map(|hunk| match hunk {
263 Hunk::Insert { text } => {
264 let edit_start = snapshot.anchor_after(edit_start);
265 Some((edit_start..edit_start, text))
266 }
267 Hunk::Remove { len } => {
268 let edit_end = edit_start + len;
269 let edit_range = snapshot.anchor_after(edit_start)
270 ..snapshot.anchor_before(edit_end);
271 edit_start = edit_end;
272 Some((edit_range, String::new()))
273 }
274 Hunk::Keep { len } => {
275 let edit_end = edit_start + len;
276 let edit_range = snapshot.anchor_after(edit_start)
277 ..snapshot.anchor_before(edit_end);
278 edit_start = edit_end;
279 this.last_equal_ranges.push(edit_range);
280 None
281 }
282 }),
283 None,
284 cx,
285 );
286
287 buffer.end_transaction(cx)
288 });
289
290 if let Some(transaction) = transaction {
291 if let Some(first_transaction) = this.transaction_id {
292 // Group all assistant edits into the first transaction.
293 this.buffer.update(cx, |buffer, cx| {
294 buffer.merge_transactions(
295 transaction,
296 first_transaction,
297 cx,
298 )
299 });
300 } else {
301 this.transaction_id = Some(transaction);
302 this.buffer.update(cx, |buffer, cx| {
303 buffer.finalize_last_transaction(cx)
304 });
305 }
306 }
307
308 cx.notify();
309 });
310 }
311
312 diff.await?;
313 anyhow::Ok(())
314 };
315
316 let result = generate.await;
317 if let Some(this) = this.upgrade(&cx) {
318 this.update(&mut cx, |this, cx| {
319 this.last_equal_ranges.clear();
320 this.idle = true;
321 if let Err(error) = result {
322 this.error = Some(error);
323 }
324 cx.emit(Event::Finished);
325 cx.notify();
326 });
327 }
328 }
329 });
330 self.error.take();
331 self.idle = false;
332 cx.notify();
333 }
334
335 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
336 if let Some(transaction_id) = self.transaction_id {
337 self.buffer
338 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
339 }
340 }
341}
342
343fn strip_markdown_codeblock(
344 stream: impl Stream<Item = Result<String>>,
345) -> impl Stream<Item = Result<String>> {
346 let mut first_line = true;
347 let mut buffer = String::new();
348 let mut starts_with_fenced_code_block = false;
349 stream.filter_map(move |chunk| {
350 let chunk = match chunk {
351 Ok(chunk) => chunk,
352 Err(err) => return future::ready(Some(Err(err))),
353 };
354 buffer.push_str(&chunk);
355
356 if first_line {
357 if buffer == "" || buffer == "`" || buffer == "``" {
358 return future::ready(None);
359 } else if buffer.starts_with("```") {
360 starts_with_fenced_code_block = true;
361 if let Some(newline_ix) = buffer.find('\n') {
362 buffer.replace_range(..newline_ix + 1, "");
363 first_line = false;
364 } else {
365 return future::ready(None);
366 }
367 }
368 }
369
370 let text = if starts_with_fenced_code_block {
371 buffer
372 .strip_suffix("\n```\n")
373 .or_else(|| buffer.strip_suffix("\n```"))
374 .or_else(|| buffer.strip_suffix("\n``"))
375 .or_else(|| buffer.strip_suffix("\n`"))
376 .or_else(|| buffer.strip_suffix('\n'))
377 .unwrap_or(&buffer)
378 } else {
379 &buffer
380 };
381
382 if text.contains('\n') {
383 first_line = false;
384 }
385
386 let remainder = buffer.split_off(text.len());
387 let result = if buffer.is_empty() {
388 None
389 } else {
390 Some(Ok(buffer.clone()))
391 };
392 buffer = remainder;
393 future::ready(result)
394 })
395}
396
#[cfg(test)]
mod tests {
    // NOTE(review): runs of whitespace inside several string fixtures below look collapsed to
    // single spaces. For example, `Point::new(1, 2)` targets the 1-column line " \n", and
    // `Point::new(1, 6)` targets the 3-column line " le". These fixtures may no longer exercise
    // re-indentation as intended; verify them against version control.
    use super::*;
    use futures::stream;
    use gpui::{executor::Deterministic, TestAppContext};
    use indoc::indoc;
    use language::{language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, Point};
    use parking_lot::Mutex;
    use rand::prelude::*;
    use settings::SettingsStore;

    /// Transforms a multi-line selection while the completion arrives in randomly sized
    /// chunks (1..=10 bytes), then checks the final buffer text.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
 fn main() {
 let x = 0;
 for _ in 0..10 {
 x += 1;
 }
 }
 "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Transform { range },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        let mut new_text = concat!(
            " let mut x = 0;\n",
            " while x < 10 {\n",
            " x += 1;\n",
            " }",
        );
        // Feed the completion in random chunks, letting the model process each one.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk);
            new_text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
 fn main() {
 let mut x = 0;
 while x < 10 {
 x += 1;
 }
 }
 "}
        );
    }

    /// Generates at a position partway through an existing line ("le" + "t mut ..."),
    /// streaming in random chunks, then checks the final buffer text.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
 fn main() {
 le
 }
 "};
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            " x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk);
            new_text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
 fn main() {
 let mut x = 0;
 while x < 10 {
 x += 1;
 }
 }
 "}
        );
    }

    /// Generates on a whitespace-only line at a column inside its leading whitespace,
    /// streaming in random chunks, then checks the final buffer text.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
        deterministic: Arc<Deterministic>,
    ) {
        cx.set_global(cx.read(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            " \n",
            "}\n" //
        );
        let buffer =
            cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let provider = Arc::new(TestCompletionProvider::new());
        let codegen = cx.add_model(|cx| {
            Codegen::new(
                buffer.clone(),
                CodegenKind::Generate { position },
                provider.clone(),
                cx,
            )
        });
        codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            " x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk);
            new_text = suffix;
            deterministic.run_until_parked();
        }
        provider.finish_completion();
        deterministic.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
 fn main() {
 let mut x = 0;
 while x < 10 {
 x += 1;
 }
 }
 "}
        );
    }

    /// Covers the following fence-stripping cases, each streamed in 2-character chunks:
    /// - no fence
    /// - an opening fence only
    /// - opening and closing fences, with and without a trailing newline
    /// - a language-tagged fence whose content contains another fence
    /// - a two-backtick line, which is not a fence
    #[gpui::test]
    async fn test_strip_markdown_codeblock() {
        assert_eq!(
            strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "Lorem ipsum dolor"
        );
        assert_eq!(
            strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "```js\nLorem ipsum dolor\n```"
        );
        assert_eq!(
            strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await,
            "``\nLorem ipsum dolor\n```"
        );

        /// Splits `text` into successful chunks of `size` characters (the last may be shorter).
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Scripted [`CompletionProvider`]: each `complete` call returns a stream fed by
    /// `send_completion` and ended by `finish_completion`.
    struct TestCompletionProvider {
        /// Sender for the stream returned by the most recent `complete` call.
        last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
    }

    impl TestCompletionProvider {
        fn new() -> Self {
            Self {
                last_completion_tx: Mutex::new(None),
            }
        }

        /// Pushes one chunk into the current completion stream.
        ///
        /// Panics if no completion was started or if `try_send` fails. The channel has
        /// capacity 1, so tests let the model drain it (`run_until_parked`) between sends.
        fn send_completion(&self, completion: impl Into<String>) {
            let mut tx = self.last_completion_tx.lock();
            tx.as_mut().unwrap().try_send(completion.into()).unwrap();
        }

        /// Drops the sender, ending the current completion stream. Panics if none is active.
        fn finish_completion(&self) {
            self.last_completion_tx.lock().take().unwrap();
        }
    }

    impl CompletionProvider for TestCompletionProvider {
        fn complete(
            &self,
            _prompt: OpenAIRequest,
        ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
            let (tx, rx) = mpsc::channel(1);
            // Replaces any previous sender; the prompt is ignored.
            *self.last_completion_tx.lock() = Some(tx);
            async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
        }
    }

    /// A Rust language with a minimal indents query (calls, field accesses, and
    /// paren/brace-delimited nodes), used for suggested indentation.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            Some(tree_sitter_rust::language()),
        )
        .with_indents_query(
            r#"
 (call_expression) @indent
 (field_expression) @indent
 (_ "(" ")" @end) @indent
 (_ "{" "}" @end) @indent
 "#,
        )
        .unwrap()
    }
}