1use crate::{
2 streaming_diff::{Hunk, StreamingDiff},
3 CompletionProvider, LanguageModelRequest,
4};
5use anyhow::Result;
6use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
7use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
8use gpui::{EventEmitter, Model, ModelContext, Task};
9use language::{Rope, TransactionId};
10use std::{cmp, future, ops::Range};
11
/// Notifications emitted by [`Codegen`] to its subscribers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Event {
    /// The generation task ran to completion, with or without an error.
    Finished,
    /// The transaction holding the generated edits was undone.
    Undone,
}
16
/// What a [`Codegen`] operates on.
#[derive(Clone)]
pub enum CodegenKind {
    /// Rewrite the text covered by `range`.
    Transform { range: Range<Anchor> },
    /// Insert generated text at `position`.
    Generate { position: Anchor },
}
22
/// Streams a model completion into a multi-buffer as a sequence of edits.
pub struct Codegen {
    /// The multi-buffer that receives the generated edits.
    buffer: Model<MultiBuffer>,
    /// Snapshot of `buffer` taken when this value was constructed; used to
    /// resolve anchors, offsets and indentation during generation.
    snapshot: MultiBufferSnapshot,
    /// Whether this codegen transforms a range or generates at a position.
    kind: CodegenKind,
    /// Ranges that the most recently applied batch of diff hunks kept
    /// unchanged; cleared before each batch and when generation finishes.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// The first transaction created by generation; later generated
    /// transactions are merged into it.
    transaction_id: Option<TransactionId>,
    /// Error produced by the most recent generation, if any.
    error: Option<anyhow::Error>,
    /// The task driving the current generation.
    generation: Task<()>,
    /// `false` from `start` until that generation's completion handler runs.
    idle: bool,
    /// Subscription to `buffer` events, routed to `handle_buffer_event`.
    _subscription: gpui::Subscription,
}
34
// Allows `Codegen` to emit `Event` values through `cx.emit`.
impl EventEmitter<Event> for Codegen {}
36
37impl Codegen {
38 pub fn new(buffer: Model<MultiBuffer>, kind: CodegenKind, cx: &mut ModelContext<Self>) -> Self {
39 let snapshot = buffer.read(cx).snapshot(cx);
40 Self {
41 buffer: buffer.clone(),
42 snapshot,
43 kind,
44 last_equal_ranges: Default::default(),
45 transaction_id: Default::default(),
46 error: Default::default(),
47 idle: true,
48 generation: Task::ready(()),
49 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
50 }
51 }
52
53 fn handle_buffer_event(
54 &mut self,
55 _buffer: Model<MultiBuffer>,
56 event: &multi_buffer::Event,
57 cx: &mut ModelContext<Self>,
58 ) {
59 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
60 if self.transaction_id == Some(*transaction_id) {
61 self.transaction_id = None;
62 self.generation = Task::ready(());
63 cx.emit(Event::Undone);
64 }
65 }
66 }
67
68 pub fn range(&self) -> Range<Anchor> {
69 match &self.kind {
70 CodegenKind::Transform { range } => range.clone(),
71 CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
72 }
73 }
74
    /// Returns what this codegen was created to do: transform a range or
    /// generate at a position.
    pub fn kind(&self) -> &CodegenKind {
        &self.kind
    }
78
79 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
80 &self.last_equal_ranges
81 }
82
    /// Returns `false` from the moment `start` is called until that
    /// generation's completion handler has run, and `true` otherwise.
    pub fn idle(&self) -> bool {
        self.idle
    }
86
    /// Returns the error that ended the most recent generation, if any.
    /// The stored error is cleared each time `start` is called.
    pub fn error(&self) -> Option<&anyhow::Error> {
        self.error.as_ref()
    }
90
    /// Starts a generation for `prompt` and streams the result into the buffer.
    ///
    /// The text currently covered by [`Self::range`] is diffed against the
    /// streamed completion, and the resulting hunks are applied to the buffer
    /// as edits. Every transaction produced this way is merged into the first
    /// one, whose id is kept in `transaction_id`.
    ///
    /// Calling this replaces the previous generation task, clears any stored
    /// error and marks the codegen as not idle. When the generation ends, the
    /// error (if any) is stored, `idle` becomes `true` again and
    /// [`Event::Finished`] is emitted.
    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
        let range = self.range();
        let snapshot = self.snapshot.clone();
        // The text being replaced; the streaming diff compares new text against it.
        let selected_text = snapshot
            .text_for_range(range.start..range.end)
            .collect::<Rope>();

        let selection_start = range.start.to_point(&snapshot);
        // Indentation to use for the first generated line: the suggested indent
        // for the row where the selection starts, falling back to that row's
        // current indent when no suggestion is returned.
        let suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));

        let response = CompletionProvider::global(cx).complete(prompt);
        self.generation = cx.spawn(|this, mut cx| {
            async move {
                let generate = async {
                    // Offset (in `snapshot`) up to which the old text has been consumed.
                    let mut edit_start = range.start.to_offset(&snapshot);

                    // Carries batches of hunks from the background diff task to
                    // the buffer-editing loop below.
                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                    let diff = cx.background_executor().spawn(async move {
                        let chunks = strip_invalid_spans_from_codeblock(response.await?);
                        futures::pin_mut!(chunks);
                        let mut diff = StreamingDiff::new(selected_text.to_string());

                        // Text of the current line that has not been pushed to the diff yet.
                        let mut new_text = String::new();
                        // Indent width of the first generated line that had visible content.
                        let mut base_indent = None;
                        // Indent width of the current line, once its first
                        // non-whitespace character has arrived.
                        let mut line_indent = None;
                        let mut first_line = true;

                        while let Some(chunk) = chunks.next().await {
                            let chunk = chunk?;

                            let mut lines = chunk.split('\n').peekable();
                            while let Some(line) = lines.next() {
                                new_text.push_str(line);
                                if line_indent.is_none() {
                                    if let Some(non_whitespace_ch_ix) =
                                        new_text.find(|ch: char| !ch.is_whitespace())
                                    {
                                        line_indent = Some(non_whitespace_ch_ix);
                                        base_indent = base_indent.or(line_indent);

                                        // Re-indent the line: keep its offset relative
                                        // to the base indent, but anchor it at the
                                        // suggested indent instead. Never go negative.
                                        let line_indent = line_indent.unwrap();
                                        let base_indent = base_indent.unwrap();
                                        let indent_delta = line_indent as i32 - base_indent as i32;
                                        let mut corrected_indent_len = cmp::max(
                                            0,
                                            suggested_line_indent.len as i32 + indent_delta,
                                        )
                                            as usize;
                                        // The first line is inserted at the selection's
                                        // column, so that much of the indent is
                                        // subtracted from what gets inserted.
                                        if first_line {
                                            corrected_indent_len = corrected_indent_len
                                                .saturating_sub(selection_start.column as usize);
                                        }

                                        let indent_char = suggested_line_indent.char();
                                        let mut indent_buffer = [0; 4];
                                        let indent_str =
                                            indent_char.encode_utf8(&mut indent_buffer);
                                        new_text.replace_range(
                                            ..line_indent,
                                            &indent_str.repeat(corrected_indent_len),
                                        );
                                    }
                                }

                                // Text is only forwarded once the line's indent is
                                // known; until then it stays in `new_text`.
                                if line_indent.is_some() {
                                    hunks_tx.send(diff.push_new(&new_text)).await?;
                                    new_text.clear();
                                }

                                // Another piece follows, so this one ended with a newline.
                                if lines.peek().is_some() {
                                    hunks_tx.send(diff.push_new("\n")).await?;
                                    line_indent = None;
                                    first_line = false;
                                }
                            }
                        }
                        // Flush whatever is still pending, then close out the diff.
                        hunks_tx.send(diff.push_new(&new_text)).await?;
                        hunks_tx.send(diff.finish()).await?;

                        anyhow::Ok(())
                    });

                    // Apply each batch of hunks to the buffer as one transaction.
                    while let Some(hunks) = hunks_rx.next().await {
                        this.update(&mut cx, |this, cx| {
                            this.last_equal_ranges.clear();

                            let transaction = this.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                buffer.edit(
                                    hunks.into_iter().filter_map(|hunk| match hunk {
                                        // Insert at the current position; old text is not consumed.
                                        Hunk::Insert { text } => {
                                            let edit_start = snapshot.anchor_after(edit_start);
                                            Some((edit_start..edit_start, text))
                                        }
                                        // Delete `len` of old text and advance past it.
                                        Hunk::Remove { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            Some((edit_range, String::new()))
                                        }
                                        // Leave `len` of old text in place, record the
                                        // range as unchanged and advance past it.
                                        Hunk::Keep { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            this.last_equal_ranges.push(edit_range);
                                            None
                                        }
                                    }),
                                    None,
                                    cx,
                                );

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) = this.transaction_id {
                                    // Group all assistant edits into the first transaction.
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    this.transaction_id = Some(transaction);
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            cx.notify();
                        })?;
                    }

                    // Surface any error produced by the background diff task.
                    diff.await?;
                    anyhow::Ok(())
                };

                let result = generate.await;
                // Completion handler: runs whether generation succeeded or failed.
                this.update(&mut cx, |this, cx| {
                    this.last_equal_ranges.clear();
                    this.idle = true;
                    if let Err(error) = result {
                        this.error = Some(error);
                    }
                    cx.emit(Event::Finished);
                    cx.notify();
                })
                .ok();
            }
        });
        self.error.take();
        self.idle = false;
        cx.notify();
    }
258
259 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
260 if let Some(transaction_id) = self.transaction_id {
261 self.buffer
262 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
263 }
264 }
265}
266
267fn strip_invalid_spans_from_codeblock(
268 stream: impl Stream<Item = Result<String>>,
269) -> impl Stream<Item = Result<String>> {
270 let mut first_line = true;
271 let mut buffer = String::new();
272 let mut starts_with_markdown_codeblock = false;
273 let mut includes_start_or_end_span = false;
274 stream.filter_map(move |chunk| {
275 let chunk = match chunk {
276 Ok(chunk) => chunk,
277 Err(err) => return future::ready(Some(Err(err))),
278 };
279 buffer.push_str(&chunk);
280
281 if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
282 includes_start_or_end_span = true;
283
284 buffer = buffer
285 .strip_prefix("<|S|>")
286 .or_else(|| buffer.strip_prefix("<|S|"))
287 .unwrap_or(&buffer)
288 .to_string();
289 } else if buffer.ends_with("|E|>") {
290 includes_start_or_end_span = true;
291 } else if buffer.starts_with("<|")
292 || buffer.starts_with("<|S")
293 || buffer.starts_with("<|S|")
294 || buffer.ends_with('|')
295 || buffer.ends_with("|E")
296 || buffer.ends_with("|E|")
297 {
298 return future::ready(None);
299 }
300
301 if first_line {
302 if buffer.is_empty() || buffer == "`" || buffer == "``" {
303 return future::ready(None);
304 } else if buffer.starts_with("```") {
305 starts_with_markdown_codeblock = true;
306 if let Some(newline_ix) = buffer.find('\n') {
307 buffer.replace_range(..newline_ix + 1, "");
308 first_line = false;
309 } else {
310 return future::ready(None);
311 }
312 }
313 }
314
315 let mut text = buffer.to_string();
316 if starts_with_markdown_codeblock {
317 text = text
318 .strip_suffix("\n```\n")
319 .or_else(|| text.strip_suffix("\n```"))
320 .or_else(|| text.strip_suffix("\n``"))
321 .or_else(|| text.strip_suffix("\n`"))
322 .or_else(|| text.strip_suffix('\n'))
323 .unwrap_or(&text)
324 .to_string();
325 }
326
327 if includes_start_or_end_span {
328 text = text
329 .strip_suffix("|E|>")
330 .or_else(|| text.strip_suffix("E|>"))
331 .or_else(|| text.strip_prefix("|>"))
332 .or_else(|| text.strip_prefix('>'))
333 .unwrap_or(&text)
334 .to_string();
335 };
336
337 if text.contains('\n') {
338 first_line = false;
339 }
340
341 let remainder = buffer.split_off(text.len());
342 let result = if buffer.is_empty() {
343 None
344 } else {
345 Some(Ok(buffer.clone()))
346 };
347
348 buffer = remainder;
349 future::ready(result)
350 })
351}
352
353#[cfg(test)]
354mod tests {
355 use std::sync::Arc;
356
357 use crate::FakeCompletionProvider;
358
359 use super::*;
360 use futures::stream::{self};
361 use gpui::{Context, TestAppContext};
362 use indoc::indoc;
363 use language::{
364 language_settings, tree_sitter_rust, Buffer, BufferId, Language, LanguageConfig,
365 LanguageMatcher, Point,
366 };
367 use rand::prelude::*;
368 use serde::Serialize;
369 use settings::SettingsStore;
370
    /// Serializable placeholder request type. None of the tests visible in
    /// this module reference it.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }
375
    /// Transforms a range of a small Rust function by streaming replacement
    /// text through the fake completion provider in random chunks of 1 to 10
    /// bytes, then checks the resulting buffer text.
    //
    // NOTE(review): the leading whitespace inside the string fixtures below
    // looks collapsed in this copy of the file — `Point::new(4, 5)` lies past
    // the end of line 4 as written. Confirm the literals against the upstream
    // source before relying on this test.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(cx.update(SettingsStore::test));
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.update(language_settings::init);

        let text = indoc! {"
 fn main() {
 let x = 0;
 for _ in 0..10 {
 x += 1;
 }
 }
 "};
        let buffer = cx.new_model(|cx| {
            Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
        });
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        // Select from the start of line 1 to column 5 of line 4.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Transform { range }, cx));

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let mut new_text = concat!(
            " let mut x = 0;\n",
            " while x < 10 {\n",
            " x += 1;\n",
            " }",
        );
        // Feed the completion in random-sized pieces, letting the codegen
        // process each piece before sending the next.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk.into());
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
 fn main() {
 let mut x = 0;
 while x < 10 {
 x += 1;
 }
 }
 "}
        );
    }
434
    /// Generates text at a position that follows existing content on its line,
    /// streaming the completion in random chunks of 1 to 10 bytes, then checks
    /// the resulting buffer text.
    //
    // NOTE(review): the leading whitespace inside the string fixtures below
    // looks collapsed in this copy of the file — `Point::new(1, 6)` lies past
    // the end of line 1 as written. Confirm the literals against the upstream
    // source before relying on this test.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
 fn main() {
 le
 }
 "};
        let buffer = cx.new_model(|cx| {
            Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
        });
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            " x += 1;\n",
            "}", //
        );
        // Feed the completion in random-sized pieces, letting the codegen
        // process each piece before sending the next.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk.into());
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
 fn main() {
 let mut x = 0;
 while x < 10 {
 x += 1;
 }
 }
 "}
        );
    }
493
    /// Generates text at a position on a whitespace-only line, streaming the
    /// completion in random chunks of 1 to 10 bytes, then checks the resulting
    /// buffer text.
    //
    // NOTE(review): the leading whitespace inside the string fixtures below
    // looks collapsed in this copy of the file — `Point::new(1, 2)` lies past
    // the end of line 1 as written. Confirm the literals against the upstream
    // source before relying on this test.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = concat!(
            "fn main() {\n",
            " \n",
            "}\n" //
        );
        let buffer = cx.new_model(|cx| {
            Buffer::new(0, BufferId::new(1).unwrap(), text).with_language(Arc::new(rust_lang()), cx)
        });
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let codegen =
            cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            " x += 1;\n",
            "}", //
        );
        // Feed the completion in random-sized pieces, letting the codegen
        // process each piece before sending the next.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            provider.send_completion(chunk.into());
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
 fn main() {
 let mut x = 0;
 while x < 10 {
 x += 1;
 }
 }
 "}
        );
    }
552
553 #[gpui::test]
554 async fn test_strip_invalid_spans_from_codeblock() {
555 assert_eq!(
556 strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
557 .map(|chunk| chunk.unwrap())
558 .collect::<String>()
559 .await,
560 "Lorem ipsum dolor"
561 );
562 assert_eq!(
563 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
564 .map(|chunk| chunk.unwrap())
565 .collect::<String>()
566 .await,
567 "Lorem ipsum dolor"
568 );
569 assert_eq!(
570 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
571 .map(|chunk| chunk.unwrap())
572 .collect::<String>()
573 .await,
574 "Lorem ipsum dolor"
575 );
576 assert_eq!(
577 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
578 .map(|chunk| chunk.unwrap())
579 .collect::<String>()
580 .await,
581 "Lorem ipsum dolor"
582 );
583 assert_eq!(
584 strip_invalid_spans_from_codeblock(chunks(
585 "```html\n```js\nLorem ipsum dolor\n```\n```",
586 2
587 ))
588 .map(|chunk| chunk.unwrap())
589 .collect::<String>()
590 .await,
591 "```js\nLorem ipsum dolor\n```"
592 );
593 assert_eq!(
594 strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
595 .map(|chunk| chunk.unwrap())
596 .collect::<String>()
597 .await,
598 "``\nLorem ipsum dolor\n```"
599 );
600 assert_eq!(
601 strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
602 .map(|chunk| chunk.unwrap())
603 .collect::<String>()
604 .await,
605 "Lorem ipsum"
606 );
607
608 assert_eq!(
609 strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
610 .map(|chunk| chunk.unwrap())
611 .collect::<String>()
612 .await,
613 "Lorem ipsum"
614 );
615
616 assert_eq!(
617 strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
618 .map(|chunk| chunk.unwrap())
619 .collect::<String>()
620 .await,
621 "Lorem ipsum"
622 );
623 assert_eq!(
624 strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
625 .map(|chunk| chunk.unwrap())
626 .collect::<String>()
627 .await,
628 "Lorem ipsum"
629 );
630 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
631 stream::iter(
632 text.chars()
633 .collect::<Vec<_>>()
634 .chunks(size)
635 .map(|chunk| Ok(chunk.iter().collect::<String>()))
636 .collect::<Vec<_>>(),
637 )
638 }
639 }
640
641 fn rust_lang() -> Language {
642 Language::new(
643 LanguageConfig {
644 name: "Rust".into(),
645 matcher: LanguageMatcher {
646 path_suffixes: vec!["rs".to_string()],
647 ..Default::default()
648 },
649 ..Default::default()
650 },
651 Some(tree_sitter_rust::language()),
652 )
653 .with_indents_query(
654 r#"
655 (call_expression) @indent
656 (field_expression) @indent
657 (_ "(" ")" @end) @indent
658 (_ "{" "}" @end) @indent
659 "#,
660 )
661 .unwrap()
662 }
663}