1use crate::{
2 stream_completion,
3 streaming_diff::{Hunk, StreamingDiff},
4 OpenAIRequest,
5};
6use anyhow::Result;
7use editor::{multi_buffer, Anchor, MultiBuffer, ToOffset, ToPoint};
8use futures::{
9 channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
10};
11use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
12use language::{IndentSize, Point, Rope, TransactionId};
13use std::{cmp, future, ops::Range, sync::Arc};
14
15pub trait CompletionProvider {
16 fn complete(
17 &self,
18 prompt: OpenAIRequest,
19 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
20}
21
22pub struct OpenAICompletionProvider {
23 api_key: String,
24 executor: Arc<Background>,
25}
26
27impl OpenAICompletionProvider {
28 pub fn new(api_key: String, executor: Arc<Background>) -> Self {
29 Self { api_key, executor }
30 }
31}
32
33impl CompletionProvider for OpenAICompletionProvider {
34 fn complete(
35 &self,
36 prompt: OpenAIRequest,
37 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
38 let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
39 async move {
40 let response = request.await?;
41 let stream = response
42 .filter_map(|response| async move {
43 match response {
44 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
45 Err(error) => Some(Err(error)),
46 }
47 })
48 .boxed();
49 Ok(stream)
50 }
51 .boxed()
52 }
53}
54
/// Notifications emitted by a [`Codegen`] model.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Event {
    /// A generation task ran to its end, whether it succeeded or failed.
    Finished,
    /// The buffer undid the transaction that holds the generated edits.
    Undone,
}
59
60pub struct Codegen {
61 provider: Arc<dyn CompletionProvider>,
62 buffer: ModelHandle<MultiBuffer>,
63 range: Range<Anchor>,
64 last_equal_ranges: Vec<Range<Anchor>>,
65 transaction_id: Option<TransactionId>,
66 error: Option<anyhow::Error>,
67 generation: Task<()>,
68 idle: bool,
69 _subscription: gpui::Subscription,
70}
71
72impl Entity for Codegen {
73 type Event = Event;
74}
75
76impl Codegen {
77 pub fn new(
78 buffer: ModelHandle<MultiBuffer>,
79 range: Range<Anchor>,
80 provider: Arc<dyn CompletionProvider>,
81 cx: &mut ModelContext<Self>,
82 ) -> Self {
83 Self {
84 provider,
85 buffer: buffer.clone(),
86 range,
87 last_equal_ranges: Default::default(),
88 transaction_id: Default::default(),
89 error: Default::default(),
90 idle: true,
91 generation: Task::ready(()),
92 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
93 }
94 }
95
96 fn handle_buffer_event(
97 &mut self,
98 _buffer: ModelHandle<MultiBuffer>,
99 event: &multi_buffer::Event,
100 cx: &mut ModelContext<Self>,
101 ) {
102 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
103 if self.transaction_id == Some(*transaction_id) {
104 self.transaction_id = None;
105 self.generation = Task::ready(());
106 cx.emit(Event::Undone);
107 }
108 }
109 }
110
111 pub fn range(&self) -> Range<Anchor> {
112 self.range.clone()
113 }
114
115 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
116 &self.last_equal_ranges
117 }
118
119 pub fn idle(&self) -> bool {
120 self.idle
121 }
122
123 pub fn error(&self) -> Option<&anyhow::Error> {
124 self.error.as_ref()
125 }
126
127 pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
128 let range = self.range.clone();
129 let snapshot = self.buffer.read(cx).snapshot(cx);
130 let selected_text = snapshot
131 .text_for_range(range.start..range.end)
132 .collect::<Rope>();
133
134 let selection_start = range.start.to_point(&snapshot);
135 let selection_end = range.end.to_point(&snapshot);
136
137 let mut base_indent: Option<IndentSize> = None;
138 let mut start_row = selection_start.row;
139 if snapshot.is_line_blank(start_row) {
140 if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) {
141 start_row = prev_non_blank_row;
142 }
143 }
144 for row in start_row..=selection_end.row {
145 if snapshot.is_line_blank(row) {
146 continue;
147 }
148
149 let line_indent = snapshot.indent_size_for_line(row);
150 if let Some(base_indent) = base_indent.as_mut() {
151 if line_indent.len < base_indent.len {
152 *base_indent = line_indent;
153 }
154 } else {
155 base_indent = Some(line_indent);
156 }
157 }
158
159 let mut normalized_selected_text = selected_text.clone();
160 if let Some(base_indent) = base_indent {
161 for row in selection_start.row..=selection_end.row {
162 let selection_row = row - selection_start.row;
163 let line_start =
164 normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
165 let indent_len = if row == selection_start.row {
166 base_indent.len.saturating_sub(selection_start.column)
167 } else {
168 let line_len = normalized_selected_text.line_len(selection_row);
169 cmp::min(line_len, base_indent.len)
170 };
171 let indent_end = cmp::min(
172 line_start + indent_len as usize,
173 normalized_selected_text.len(),
174 );
175 normalized_selected_text.replace(line_start..indent_end, "");
176 }
177 }
178
179 let response = self.provider.complete(prompt);
180 self.generation = cx.spawn_weak(|this, mut cx| {
181 async move {
182 let generate = async {
183 let mut edit_start = range.start.to_offset(&snapshot);
184
185 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
186 let diff = cx.background().spawn(async move {
187 let chunks = strip_markdown_codeblock(response.await?);
188 futures::pin_mut!(chunks);
189 let mut diff = StreamingDiff::new(selected_text.to_string());
190
191 let mut indent_len;
192 let indent_text;
193 if let Some(base_indent) = base_indent {
194 indent_len = base_indent.len;
195 indent_text = match base_indent.kind {
196 language::IndentKind::Space => " ",
197 language::IndentKind::Tab => "\t",
198 };
199 } else {
200 indent_len = 0;
201 indent_text = "";
202 };
203
204 let mut first_line_len = 0;
205 let mut first_line_non_whitespace_char_ix = None;
206 let mut first_line = true;
207 let mut new_text = String::new();
208
209 while let Some(chunk) = chunks.next().await {
210 let chunk = chunk?;
211
212 let mut lines = chunk.split('\n');
213 if let Some(mut line) = lines.next() {
214 if first_line {
215 if first_line_non_whitespace_char_ix.is_none() {
216 if let Some(mut char_ix) =
217 line.find(|ch: char| !ch.is_whitespace())
218 {
219 line = &line[char_ix..];
220 char_ix += first_line_len;
221 first_line_non_whitespace_char_ix = Some(char_ix);
222 let first_line_indent = char_ix
223 .saturating_sub(selection_start.column as usize)
224 as usize;
225 new_text
226 .push_str(&indent_text.repeat(first_line_indent));
227 indent_len = indent_len.saturating_sub(char_ix as u32);
228 }
229 }
230 first_line_len += line.len();
231 }
232
233 if first_line_non_whitespace_char_ix.is_some() {
234 new_text.push_str(line);
235 }
236 }
237
238 for line in lines {
239 first_line = false;
240 new_text.push('\n');
241 if !line.is_empty() {
242 new_text.push_str(&indent_text.repeat(indent_len as usize));
243 }
244 new_text.push_str(line);
245 }
246
247 let hunks = diff.push_new(&new_text);
248 hunks_tx.send(hunks).await?;
249 new_text.clear();
250 }
251 hunks_tx.send(diff.finish()).await?;
252
253 anyhow::Ok(())
254 });
255
256 while let Some(hunks) = hunks_rx.next().await {
257 let this = if let Some(this) = this.upgrade(&cx) {
258 this
259 } else {
260 break;
261 };
262
263 this.update(&mut cx, |this, cx| {
264 this.last_equal_ranges.clear();
265
266 let transaction = this.buffer.update(cx, |buffer, cx| {
267 // Avoid grouping assistant edits with user edits.
268 buffer.finalize_last_transaction(cx);
269
270 buffer.start_transaction(cx);
271 buffer.edit(
272 hunks.into_iter().filter_map(|hunk| match hunk {
273 Hunk::Insert { text } => {
274 let edit_start = snapshot.anchor_after(edit_start);
275 Some((edit_start..edit_start, text))
276 }
277 Hunk::Remove { len } => {
278 let edit_end = edit_start + len;
279 let edit_range = snapshot.anchor_after(edit_start)
280 ..snapshot.anchor_before(edit_end);
281 edit_start = edit_end;
282 Some((edit_range, String::new()))
283 }
284 Hunk::Keep { len } => {
285 let edit_end = edit_start + len;
286 let edit_range = snapshot.anchor_after(edit_start)
287 ..snapshot.anchor_before(edit_end);
288 edit_start += len;
289 this.last_equal_ranges.push(edit_range);
290 None
291 }
292 }),
293 None,
294 cx,
295 );
296
297 buffer.end_transaction(cx)
298 });
299
300 if let Some(transaction) = transaction {
301 if let Some(first_transaction) = this.transaction_id {
302 // Group all assistant edits into the first transaction.
303 this.buffer.update(cx, |buffer, cx| {
304 buffer.merge_transactions(
305 transaction,
306 first_transaction,
307 cx,
308 )
309 });
310 } else {
311 this.transaction_id = Some(transaction);
312 this.buffer.update(cx, |buffer, cx| {
313 buffer.finalize_last_transaction(cx)
314 });
315 }
316 }
317
318 cx.notify();
319 });
320 }
321
322 diff.await?;
323 anyhow::Ok(())
324 };
325
326 let result = generate.await;
327 if let Some(this) = this.upgrade(&cx) {
328 this.update(&mut cx, |this, cx| {
329 this.last_equal_ranges.clear();
330 this.idle = true;
331 if let Err(error) = result {
332 this.error = Some(error);
333 }
334 cx.emit(Event::Finished);
335 cx.notify();
336 });
337 }
338 }
339 });
340 self.error.take();
341 self.idle = false;
342 cx.notify();
343 }
344
345 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
346 if let Some(transaction_id) = self.transaction_id {
347 self.buffer
348 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
349 }
350 }
351}
352
353fn strip_markdown_codeblock(
354 stream: impl Stream<Item = Result<String>>,
355) -> impl Stream<Item = Result<String>> {
356 let mut first_line = true;
357 let mut buffer = String::new();
358 let mut starts_with_fenced_code_block = false;
359 stream.filter_map(move |chunk| {
360 let chunk = match chunk {
361 Ok(chunk) => chunk,
362 Err(err) => return future::ready(Some(Err(err))),
363 };
364 buffer.push_str(&chunk);
365
366 if first_line {
367 if buffer == "" || buffer == "`" || buffer == "``" {
368 return future::ready(None);
369 } else if buffer.starts_with("```") {
370 starts_with_fenced_code_block = true;
371 if let Some(newline_ix) = buffer.find('\n') {
372 buffer.replace_range(..newline_ix + 1, "");
373 first_line = false;
374 } else {
375 return future::ready(None);
376 }
377 }
378 }
379
380 let text = if starts_with_fenced_code_block {
381 buffer
382 .strip_suffix("\n```\n")
383 .or_else(|| buffer.strip_suffix("\n```"))
384 .or_else(|| buffer.strip_suffix("\n``"))
385 .or_else(|| buffer.strip_suffix("\n`"))
386 .or_else(|| buffer.strip_suffix('\n'))
387 .unwrap_or(&buffer)
388 } else {
389 &buffer
390 };
391
392 if text.contains('\n') {
393 first_line = false;
394 }
395
396 let remainder = buffer.split_off(text.len());
397 let result = if buffer.is_empty() {
398 None
399 } else {
400 Some(Ok(buffer.clone()))
401 };
402 buffer = remainder;
403 future::ready(result)
404 })
405}
406
// Tests for `Codegen` and `strip_markdown_codeblock`, together with a
// manually driven `CompletionProvider` and a Rust `Language` fixture.
407#[cfg(test)]
408mod tests {
409 use super::*;
410 use futures::stream;
411 use gpui::{executor::Deterministic, TestAppContext};
412 use indoc::indoc;
413 use language::{tree_sitter_rust, Buffer, Language, LanguageConfig};
414 use parking_lot::Mutex;
415 use rand::prelude::*;
416
 // Replaces part of a function body with a completion that is delivered in
 // randomly sized chunks, then checks the final text of the buffer.
417 #[gpui::test(iterations = 10)]
418 async fn test_autoindent(
419 cx: &mut TestAppContext,
420 mut rng: StdRng,
421 deterministic: Arc<Deterministic>,
422 ) {
423 let text = indoc! {"
424 fn main() {
425 let x = 0;
426 for _ in 0..10 {
427 x += 1;
428 }
429 }
430 "};
431 let buffer =
432 cx.add_model(|cx| Buffer::new(0, 0, text).with_language(Arc::new(rust_lang()), cx));
433 let buffer = cx.add_model(|cx| MultiBuffer::singleton(buffer, cx));
 // The range handed to `Codegen` runs from row 1, column 4 to row 4, column 4.
434 let range = buffer.read_with(cx, |buffer, cx| {
435 let snapshot = buffer.snapshot(cx);
436 snapshot.anchor_before(Point::new(1, 4))..snapshot.anchor_after(Point::new(4, 4))
437 });
438 let provider = Arc::new(TestCompletionProvider::new());
439 let codegen = cx.add_model(|cx| Codegen::new(buffer.clone(), range, provider.clone(), cx));
440 codegen.update(cx, |codegen, cx| codegen.start(Default::default(), cx));
441
442 let mut new_text = indoc! {"
443 let mut x = 0;
444 while x < 10 {
445 x += 1;
446 }
447 "};
 // Feed the completion in pieces of 1 to 10 bytes, letting the executor
 // run until parked after every piece.
448 while !new_text.is_empty() {
449 let max_len = cmp::min(new_text.len(), 10);
450 let len = rng.gen_range(1..=max_len);
451 let (chunk, suffix) = new_text.split_at(len);
452 provider.send_completion(chunk);
453 new_text = suffix;
454 deterministic.run_until_parked();
455 }
456 provider.finish_completion();
457 deterministic.run_until_parked();
458
459 assert_eq!(
460 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
461 indoc! {"
462 fn main() {
463 let mut x = 0;
464 while x < 10 {
465 x += 1;
466 }
467 }
468 "}
469 );
470 }
471
 // Runs `strip_markdown_codeblock` over inputs split into two-character
 // chunks and compares the concatenated output with the expected text.
472 #[gpui::test]
473 async fn test_strip_markdown_codeblock() {
 // No fence: the text passes through unchanged.
474 assert_eq!(
475 strip_markdown_codeblock(chunks("Lorem ipsum dolor", 2))
476 .map(|chunk| chunk.unwrap())
477 .collect::<String>()
478 .await,
479 "Lorem ipsum dolor"
480 );
 // Opening fence only.
481 assert_eq!(
482 strip_markdown_codeblock(chunks("```\nLorem ipsum dolor", 2))
483 .map(|chunk| chunk.unwrap())
484 .collect::<String>()
485 .await,
486 "Lorem ipsum dolor"
487 );
 // Opening and closing fence.
488 assert_eq!(
489 strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
490 .map(|chunk| chunk.unwrap())
491 .collect::<String>()
492 .await,
493 "Lorem ipsum dolor"
494 );
 // Closing fence followed by a newline.
495 assert_eq!(
496 strip_markdown_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
497 .map(|chunk| chunk.unwrap())
498 .collect::<String>()
499 .await,
500 "Lorem ipsum dolor"
501 );
 // Nested fences: only the outermost pair is removed.
502 assert_eq!(
503 strip_markdown_codeblock(chunks("```html\n```js\nLorem ipsum dolor\n```\n```", 2))
504 .map(|chunk| chunk.unwrap())
505 .collect::<String>()
506 .await,
507 "```js\nLorem ipsum dolor\n```"
508 );
 // Two backticks do not open a fence, so nothing is removed.
509 assert_eq!(
510 strip_markdown_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
511 .map(|chunk| chunk.unwrap())
512 .collect::<String>()
513 .await,
514 "``\nLorem ipsum dolor\n```"
515 );
516
 // Splits `text` into pieces of `size` characters and yields each piece as
 // an `Ok` item of a stream.
517 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
518 stream::iter(
519 text.chars()
520 .collect::<Vec<_>>()
521 .chunks(size)
522 .map(|chunk| Ok(chunk.iter().collect::<String>()))
523 .collect::<Vec<_>>(),
524 )
525 }
526 }
527
 // Completion provider that the test drives by hand through a channel.
528 struct TestCompletionProvider {
 // Sender of the channel created by the most recent `complete` call.
529 last_completion_tx: Mutex<Option<mpsc::Sender<String>>>,
530 }
531
532 impl TestCompletionProvider {
533 fn new() -> Self {
534 Self {
535 last_completion_tx: Mutex::new(None),
536 }
537 }
538
 // Pushes one chunk into the current completion; panics if `complete`
 // has not been called or the channel cannot accept the chunk.
539 fn send_completion(&self, completion: impl Into<String>) {
540 let mut tx = self.last_completion_tx.lock();
541 tx.as_mut().unwrap().try_send(completion.into()).unwrap();
542 }
543
 // Takes the sender out of the provider and drops it; panics if there is
 // no completion in progress.
544 fn finish_completion(&self) {
545 self.last_completion_tx.lock().take().unwrap();
546 }
547 }
548
549 impl CompletionProvider for TestCompletionProvider {
 // Ignores the prompt, installs a fresh channel and returns its receiving
 // end as the completion stream.
550 fn complete(
551 &self,
552 _prompt: OpenAIRequest,
553 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
554 let (tx, rx) = mpsc::channel(1);
555 *self.last_completion_tx.lock() = Some(tx);
556 async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed()
557 }
558 }
559
 // Builds a `Language` named "Rust" with the tree-sitter Rust grammar and an
 // indents query.
560 fn rust_lang() -> Language {
561 Language::new(
562 LanguageConfig {
563 name: "Rust".into(),
564 path_suffixes: vec!["rs".to_string()],
565 ..Default::default()
566 },
567 Some(tree_sitter_rust::language()),
568 )
569 .with_indents_query(
570 r#"
571 (call_expression) @indent
572 (field_expression) @indent
573 (_ "(" ")" @end) @indent
574 (_ "{" "}" @end) @indent
575 "#,
576 )
577 .unwrap()
578 }
579}