1use crate::{
2 streaming_diff::{Hunk, StreamingDiff},
3 CompletionProvider, LanguageModelRequest,
4};
5use anyhow::Result;
6use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
7use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
8use gpui::{EventEmitter, Model, ModelContext, Task};
9use language::{Rope, TransactionId};
10use multi_buffer::MultiBufferRow;
11use std::{cmp, future, ops::Range};
12
/// Events emitted by [`Codegen`] to its subscribers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Event {
    /// The generation task ran to its end, whether it succeeded or failed.
    Finished,
    /// The transaction holding the generated edits was undone in the buffer.
    Undone,
}
17
18#[derive(Clone)]
19pub enum CodegenKind {
20 Transform { range: Range<Anchor> },
21 Generate { position: Anchor },
22}
23
24pub struct Codegen {
25 buffer: Model<MultiBuffer>,
26 snapshot: MultiBufferSnapshot,
27 kind: CodegenKind,
28 last_equal_ranges: Vec<Range<Anchor>>,
29 transaction_id: Option<TransactionId>,
30 error: Option<anyhow::Error>,
31 generation: Task<()>,
32 idle: bool,
33 _subscription: gpui::Subscription,
34}
35
36impl EventEmitter<Event> for Codegen {}
37
38impl Codegen {
39 pub fn new(buffer: Model<MultiBuffer>, kind: CodegenKind, cx: &mut ModelContext<Self>) -> Self {
40 let snapshot = buffer.read(cx).snapshot(cx);
41 Self {
42 buffer: buffer.clone(),
43 snapshot,
44 kind,
45 last_equal_ranges: Default::default(),
46 transaction_id: Default::default(),
47 error: Default::default(),
48 idle: true,
49 generation: Task::ready(()),
50 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
51 }
52 }
53
54 fn handle_buffer_event(
55 &mut self,
56 _buffer: Model<MultiBuffer>,
57 event: &multi_buffer::Event,
58 cx: &mut ModelContext<Self>,
59 ) {
60 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
61 if self.transaction_id == Some(*transaction_id) {
62 self.transaction_id = None;
63 self.generation = Task::ready(());
64 cx.emit(Event::Undone);
65 }
66 }
67 }
68
69 pub fn range(&self) -> Range<Anchor> {
70 match &self.kind {
71 CodegenKind::Transform { range } => range.clone(),
72 CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
73 }
74 }
75
76 pub fn kind(&self) -> &CodegenKind {
77 &self.kind
78 }
79
80 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
81 &self.last_equal_ranges
82 }
83
84 pub fn idle(&self) -> bool {
85 self.idle
86 }
87
88 pub fn error(&self) -> Option<&anyhow::Error> {
89 self.error.as_ref()
90 }
91
92 pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
93 let range = self.range();
94 let snapshot = self.snapshot.clone();
95 let selected_text = snapshot
96 .text_for_range(range.start..range.end)
97 .collect::<Rope>();
98
99 let selection_start = range.start.to_point(&snapshot);
100 let suggested_line_indent = snapshot
101 .suggested_indents(selection_start.row..selection_start.row + 1, cx)
102 .into_values()
103 .next()
104 .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
105
106 let response = CompletionProvider::global(cx).complete(prompt);
107 self.generation = cx.spawn(|this, mut cx| {
108 async move {
109 let generate = async {
110 let mut edit_start = range.start.to_offset(&snapshot);
111
112 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
113 let diff = cx.background_executor().spawn(async move {
114 let chunks = strip_invalid_spans_from_codeblock(response.await?);
115 futures::pin_mut!(chunks);
116 let mut diff = StreamingDiff::new(selected_text.to_string());
117
118 let mut new_text = String::new();
119 let mut base_indent = None;
120 let mut line_indent = None;
121 let mut first_line = true;
122
123 while let Some(chunk) = chunks.next().await {
124 let chunk = chunk?;
125
126 let mut lines = chunk.split('\n').peekable();
127 while let Some(line) = lines.next() {
128 new_text.push_str(line);
129 if line_indent.is_none() {
130 if let Some(non_whitespace_ch_ix) =
131 new_text.find(|ch: char| !ch.is_whitespace())
132 {
133 line_indent = Some(non_whitespace_ch_ix);
134 base_indent = base_indent.or(line_indent);
135
136 let line_indent = line_indent.unwrap();
137 let base_indent = base_indent.unwrap();
138 let indent_delta = line_indent as i32 - base_indent as i32;
139 let mut corrected_indent_len = cmp::max(
140 0,
141 suggested_line_indent.len as i32 + indent_delta,
142 )
143 as usize;
144 if first_line {
145 corrected_indent_len = corrected_indent_len
146 .saturating_sub(selection_start.column as usize);
147 }
148
149 let indent_char = suggested_line_indent.char();
150 let mut indent_buffer = [0; 4];
151 let indent_str =
152 indent_char.encode_utf8(&mut indent_buffer);
153 new_text.replace_range(
154 ..line_indent,
155 &indent_str.repeat(corrected_indent_len),
156 );
157 }
158 }
159
160 if line_indent.is_some() {
161 hunks_tx.send(diff.push_new(&new_text)).await?;
162 new_text.clear();
163 }
164
165 if lines.peek().is_some() {
166 hunks_tx.send(diff.push_new("\n")).await?;
167 line_indent = None;
168 first_line = false;
169 }
170 }
171 }
172 hunks_tx.send(diff.push_new(&new_text)).await?;
173 hunks_tx.send(diff.finish()).await?;
174
175 anyhow::Ok(())
176 });
177
178 while let Some(hunks) = hunks_rx.next().await {
179 this.update(&mut cx, |this, cx| {
180 this.last_equal_ranges.clear();
181
182 let transaction = this.buffer.update(cx, |buffer, cx| {
183 // Avoid grouping assistant edits with user edits.
184 buffer.finalize_last_transaction(cx);
185
186 buffer.start_transaction(cx);
187 buffer.edit(
188 hunks.into_iter().filter_map(|hunk| match hunk {
189 Hunk::Insert { text } => {
190 let edit_start = snapshot.anchor_after(edit_start);
191 Some((edit_start..edit_start, text))
192 }
193 Hunk::Remove { len } => {
194 let edit_end = edit_start + len;
195 let edit_range = snapshot.anchor_after(edit_start)
196 ..snapshot.anchor_before(edit_end);
197 edit_start = edit_end;
198 Some((edit_range, String::new()))
199 }
200 Hunk::Keep { len } => {
201 let edit_end = edit_start + len;
202 let edit_range = snapshot.anchor_after(edit_start)
203 ..snapshot.anchor_before(edit_end);
204 edit_start = edit_end;
205 this.last_equal_ranges.push(edit_range);
206 None
207 }
208 }),
209 None,
210 cx,
211 );
212
213 buffer.end_transaction(cx)
214 });
215
216 if let Some(transaction) = transaction {
217 if let Some(first_transaction) = this.transaction_id {
218 // Group all assistant edits into the first transaction.
219 this.buffer.update(cx, |buffer, cx| {
220 buffer.merge_transactions(
221 transaction,
222 first_transaction,
223 cx,
224 )
225 });
226 } else {
227 this.transaction_id = Some(transaction);
228 this.buffer.update(cx, |buffer, cx| {
229 buffer.finalize_last_transaction(cx)
230 });
231 }
232 }
233
234 cx.notify();
235 })?;
236 }
237
238 diff.await?;
239 anyhow::Ok(())
240 };
241
242 let result = generate.await;
243 this.update(&mut cx, |this, cx| {
244 this.last_equal_ranges.clear();
245 this.idle = true;
246 if let Err(error) = result {
247 this.error = Some(error);
248 }
249 cx.emit(Event::Finished);
250 cx.notify();
251 })
252 .ok();
253 }
254 });
255 self.error.take();
256 self.idle = false;
257 cx.notify();
258 }
259
260 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
261 if let Some(transaction_id) = self.transaction_id {
262 self.buffer
263 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
264 }
265 }
266}
267
268fn strip_invalid_spans_from_codeblock(
269 stream: impl Stream<Item = Result<String>>,
270) -> impl Stream<Item = Result<String>> {
271 let mut first_line = true;
272 let mut buffer = String::new();
273 let mut starts_with_markdown_codeblock = false;
274 let mut includes_start_or_end_span = false;
275 stream.filter_map(move |chunk| {
276 let chunk = match chunk {
277 Ok(chunk) => chunk,
278 Err(err) => return future::ready(Some(Err(err))),
279 };
280 buffer.push_str(&chunk);
281
282 if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
283 includes_start_or_end_span = true;
284
285 buffer = buffer
286 .strip_prefix("<|S|>")
287 .or_else(|| buffer.strip_prefix("<|S|"))
288 .unwrap_or(&buffer)
289 .to_string();
290 } else if buffer.ends_with("|E|>") {
291 includes_start_or_end_span = true;
292 } else if buffer.starts_with("<|")
293 || buffer.starts_with("<|S")
294 || buffer.starts_with("<|S|")
295 || buffer.ends_with('|')
296 || buffer.ends_with("|E")
297 || buffer.ends_with("|E|")
298 {
299 return future::ready(None);
300 }
301
302 if first_line {
303 if buffer.is_empty() || buffer == "`" || buffer == "``" {
304 return future::ready(None);
305 } else if buffer.starts_with("```") {
306 starts_with_markdown_codeblock = true;
307 if let Some(newline_ix) = buffer.find('\n') {
308 buffer.replace_range(..newline_ix + 1, "");
309 first_line = false;
310 } else {
311 return future::ready(None);
312 }
313 }
314 }
315
316 let mut text = buffer.to_string();
317 if starts_with_markdown_codeblock {
318 text = text
319 .strip_suffix("\n```\n")
320 .or_else(|| text.strip_suffix("\n```"))
321 .or_else(|| text.strip_suffix("\n``"))
322 .or_else(|| text.strip_suffix("\n`"))
323 .or_else(|| text.strip_suffix('\n'))
324 .unwrap_or(&text)
325 .to_string();
326 }
327
328 if includes_start_or_end_span {
329 text = text
330 .strip_suffix("|E|>")
331 .or_else(|| text.strip_suffix("E|>"))
332 .or_else(|| text.strip_prefix("|>"))
333 .or_else(|| text.strip_prefix('>'))
334 .unwrap_or(&text)
335 .to_string();
336 };
337
338 if text.contains('\n') {
339 first_line = false;
340 }
341
342 let remainder = buffer.split_off(text.len());
343 let result = if buffer.is_empty() {
344 None
345 } else {
346 Some(Ok(buffer.clone()))
347 };
348
349 buffer = remainder;
350 future::ready(result)
351 })
352}
353
354#[cfg(test)]
355mod tests {
356 use std::sync::Arc;
357
358 use crate::FakeCompletionProvider;
359
360 use super::*;
361 use futures::stream::{self};
362 use gpui::{Context, TestAppContext};
363 use indoc::indoc;
364 use language::{
365 language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
366 Point,
367 };
368 use rand::prelude::*;
369 use serde::Serialize;
370 use settings::SettingsStore;
371
372 #[derive(Serialize)]
373 pub struct DummyCompletionRequest {
374 pub name: String,
375 }
376
377 #[gpui::test(iterations = 10)]
378 async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
379 let provider = FakeCompletionProvider::default();
380 cx.set_global(cx.update(SettingsStore::test));
381 cx.set_global(CompletionProvider::Fake(provider.clone()));
382 cx.update(language_settings::init);
383
384 let text = indoc! {"
385 fn main() {
386 let x = 0;
387 for _ in 0..10 {
388 x += 1;
389 }
390 }
391 "};
392 let buffer =
393 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
394 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
395 let range = buffer.read_with(cx, |buffer, cx| {
396 let snapshot = buffer.snapshot(cx);
397 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
398 });
399 let codegen =
400 cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Transform { range }, cx));
401
402 let request = LanguageModelRequest::default();
403 codegen.update(cx, |codegen, cx| codegen.start(request, cx));
404
405 let mut new_text = concat!(
406 " let mut x = 0;\n",
407 " while x < 10 {\n",
408 " x += 1;\n",
409 " }",
410 );
411 while !new_text.is_empty() {
412 let max_len = cmp::min(new_text.len(), 10);
413 let len = rng.gen_range(1..=max_len);
414 let (chunk, suffix) = new_text.split_at(len);
415 provider.send_completion(chunk.into());
416 new_text = suffix;
417 cx.background_executor.run_until_parked();
418 }
419 provider.finish_completion();
420 cx.background_executor.run_until_parked();
421
422 assert_eq!(
423 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
424 indoc! {"
425 fn main() {
426 let mut x = 0;
427 while x < 10 {
428 x += 1;
429 }
430 }
431 "}
432 );
433 }
434
435 #[gpui::test(iterations = 10)]
436 async fn test_autoindent_when_generating_past_indentation(
437 cx: &mut TestAppContext,
438 mut rng: StdRng,
439 ) {
440 let provider = FakeCompletionProvider::default();
441 cx.set_global(CompletionProvider::Fake(provider.clone()));
442 cx.set_global(cx.update(SettingsStore::test));
443 cx.update(language_settings::init);
444
445 let text = indoc! {"
446 fn main() {
447 le
448 }
449 "};
450 let buffer =
451 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
452 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
453 let position = buffer.read_with(cx, |buffer, cx| {
454 let snapshot = buffer.snapshot(cx);
455 snapshot.anchor_before(Point::new(1, 6))
456 });
457 let codegen =
458 cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
459
460 let request = LanguageModelRequest::default();
461 codegen.update(cx, |codegen, cx| codegen.start(request, cx));
462
463 let mut new_text = concat!(
464 "t mut x = 0;\n",
465 "while x < 10 {\n",
466 " x += 1;\n",
467 "}", //
468 );
469 while !new_text.is_empty() {
470 let max_len = cmp::min(new_text.len(), 10);
471 let len = rng.gen_range(1..=max_len);
472 let (chunk, suffix) = new_text.split_at(len);
473 provider.send_completion(chunk.into());
474 new_text = suffix;
475 cx.background_executor.run_until_parked();
476 }
477 provider.finish_completion();
478 cx.background_executor.run_until_parked();
479
480 assert_eq!(
481 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
482 indoc! {"
483 fn main() {
484 let mut x = 0;
485 while x < 10 {
486 x += 1;
487 }
488 }
489 "}
490 );
491 }
492
493 #[gpui::test(iterations = 10)]
494 async fn test_autoindent_when_generating_before_indentation(
495 cx: &mut TestAppContext,
496 mut rng: StdRng,
497 ) {
498 let provider = FakeCompletionProvider::default();
499 cx.set_global(CompletionProvider::Fake(provider.clone()));
500 cx.set_global(cx.update(SettingsStore::test));
501 cx.update(language_settings::init);
502
503 let text = concat!(
504 "fn main() {\n",
505 " \n",
506 "}\n" //
507 );
508 let buffer =
509 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
510 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
511 let position = buffer.read_with(cx, |buffer, cx| {
512 let snapshot = buffer.snapshot(cx);
513 snapshot.anchor_before(Point::new(1, 2))
514 });
515 let codegen =
516 cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));
517
518 let request = LanguageModelRequest::default();
519 codegen.update(cx, |codegen, cx| codegen.start(request, cx));
520
521 let mut new_text = concat!(
522 "let mut x = 0;\n",
523 "while x < 10 {\n",
524 " x += 1;\n",
525 "}", //
526 );
527 while !new_text.is_empty() {
528 let max_len = cmp::min(new_text.len(), 10);
529 let len = rng.gen_range(1..=max_len);
530 let (chunk, suffix) = new_text.split_at(len);
531 provider.send_completion(chunk.into());
532 new_text = suffix;
533 cx.background_executor.run_until_parked();
534 }
535 provider.finish_completion();
536 cx.background_executor.run_until_parked();
537
538 assert_eq!(
539 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
540 indoc! {"
541 fn main() {
542 let mut x = 0;
543 while x < 10 {
544 x += 1;
545 }
546 }
547 "}
548 );
549 }
550
551 #[gpui::test]
552 async fn test_strip_invalid_spans_from_codeblock() {
553 assert_eq!(
554 strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
555 .map(|chunk| chunk.unwrap())
556 .collect::<String>()
557 .await,
558 "Lorem ipsum dolor"
559 );
560 assert_eq!(
561 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
562 .map(|chunk| chunk.unwrap())
563 .collect::<String>()
564 .await,
565 "Lorem ipsum dolor"
566 );
567 assert_eq!(
568 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
569 .map(|chunk| chunk.unwrap())
570 .collect::<String>()
571 .await,
572 "Lorem ipsum dolor"
573 );
574 assert_eq!(
575 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
576 .map(|chunk| chunk.unwrap())
577 .collect::<String>()
578 .await,
579 "Lorem ipsum dolor"
580 );
581 assert_eq!(
582 strip_invalid_spans_from_codeblock(chunks(
583 "```html\n```js\nLorem ipsum dolor\n```\n```",
584 2
585 ))
586 .map(|chunk| chunk.unwrap())
587 .collect::<String>()
588 .await,
589 "```js\nLorem ipsum dolor\n```"
590 );
591 assert_eq!(
592 strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
593 .map(|chunk| chunk.unwrap())
594 .collect::<String>()
595 .await,
596 "``\nLorem ipsum dolor\n```"
597 );
598 assert_eq!(
599 strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
600 .map(|chunk| chunk.unwrap())
601 .collect::<String>()
602 .await,
603 "Lorem ipsum"
604 );
605
606 assert_eq!(
607 strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
608 .map(|chunk| chunk.unwrap())
609 .collect::<String>()
610 .await,
611 "Lorem ipsum"
612 );
613
614 assert_eq!(
615 strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
616 .map(|chunk| chunk.unwrap())
617 .collect::<String>()
618 .await,
619 "Lorem ipsum"
620 );
621 assert_eq!(
622 strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
623 .map(|chunk| chunk.unwrap())
624 .collect::<String>()
625 .await,
626 "Lorem ipsum"
627 );
628 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
629 stream::iter(
630 text.chars()
631 .collect::<Vec<_>>()
632 .chunks(size)
633 .map(|chunk| Ok(chunk.iter().collect::<String>()))
634 .collect::<Vec<_>>(),
635 )
636 }
637 }
638
639 fn rust_lang() -> Language {
640 Language::new(
641 LanguageConfig {
642 name: "Rust".into(),
643 matcher: LanguageMatcher {
644 path_suffixes: vec!["rs".to_string()],
645 ..Default::default()
646 },
647 ..Default::default()
648 },
649 Some(tree_sitter_rust::language()),
650 )
651 .with_indents_query(
652 r#"
653 (call_expression) @indent
654 (field_expression) @indent
655 (_ "(" ")" @end) @indent
656 (_ "{" "}" @end) @indent
657 "#,
658 )
659 .unwrap()
660 }
661}