1use crate::{
2 streaming_diff::{Hunk, StreamingDiff},
3 CompletionProvider, LanguageModelRequest,
4};
5use anyhow::Result;
6use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
7use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
8use gpui::{EventEmitter, Model, ModelContext, Task};
9use language::{Rope, TransactionId};
10use std::{cmp, future, ops::Range};
11
/// Notifications emitted by [`Codegen`] about the lifecycle of a generation.
///
/// The variants carry no data, so the type is cheap to copy, compare and print.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Event {
    /// The generation task ran to its end, whether it succeeded or failed.
    Finished,
    /// The transaction holding the generated edits was undone in the buffer.
    Undone,
}
16
/// Describes which part of the buffer a [`Codegen`] works on.
#[derive(Clone)]
pub enum CodegenKind {
    /// Operate on the existing text covered by `range`.
    Transform { range: Range<Anchor> },
    /// Operate at a single `position`; `Codegen::range` turns it into a range
    /// that ends at that position.
    Generate { position: Anchor },
}
22
/// Streams a language-model completion into a range of a [`MultiBuffer`].
pub struct Codegen {
    // Buffer that receives the generated edits.
    buffer: Model<MultiBuffer>,
    // Snapshot of `buffer` taken in `new`; `start` resolves offsets, points and
    // anchors against it.
    snapshot: MultiBufferSnapshot,
    // Whether an existing range is transformed or text is generated at a position.
    kind: CodegenKind,
    // Ranges of the `Keep` hunks in the most recently applied batch. Cleared
    // before each batch and once the generation ends.
    last_equal_ranges: Vec<Range<Anchor>>,
    // First transaction created by a generation; later transactions are merged
    // into it. Reset to `None` when that transaction is undone.
    transaction_id: Option<TransactionId>,
    // Error of the last generation, if it failed. Cleared by `start`.
    error: Option<anyhow::Error>,
    // Task driving the current generation; an already-completed task when none
    // has been started or after an undo.
    generation: Task<()>,
    // `false` from the call to `start` until the generation task finishes.
    idle: bool,
    // Subscription to `buffer` events, routed to `handle_buffer_event`.
    // NOTE(review): presumably the subscription lasts only while this handle is
    // held — confirm against gpui.
    _subscription: gpui::Subscription,
}
34
// Lets `Codegen` emit `Event::Finished` / `Event::Undone` through `cx.emit`.
impl EventEmitter<Event> for Codegen {}
36
37impl Codegen {
38 pub fn new(buffer: Model<MultiBuffer>, kind: CodegenKind, cx: &mut ModelContext<Self>) -> Self {
39 let snapshot = buffer.read(cx).snapshot(cx);
40 Self {
41 buffer: buffer.clone(),
42 snapshot,
43 kind,
44 last_equal_ranges: Default::default(),
45 transaction_id: Default::default(),
46 error: Default::default(),
47 idle: true,
48 generation: Task::ready(()),
49 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
50 }
51 }
52
53 fn handle_buffer_event(
54 &mut self,
55 _buffer: Model<MultiBuffer>,
56 event: &multi_buffer::Event,
57 cx: &mut ModelContext<Self>,
58 ) {
59 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
60 if self.transaction_id == Some(*transaction_id) {
61 self.transaction_id = None;
62 self.generation = Task::ready(());
63 cx.emit(Event::Undone);
64 }
65 }
66 }
67
68 pub fn range(&self) -> Range<Anchor> {
69 match &self.kind {
70 CodegenKind::Transform { range } => range.clone(),
71 CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
72 }
73 }
74
/// Returns the kind this codegen was created with.
pub fn kind(&self) -> &CodegenKind {
    &self.kind
}
78
/// Returns the ranges that the most recently applied batch of hunks kept
/// unchanged. Empty once the generation has ended.
pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
    &self.last_equal_ranges
}
82
/// Returns `false` between a call to `start` and the end of its generation
/// task, and `true` otherwise.
pub fn idle(&self) -> bool {
    self.idle
}
86
/// Returns the error of the last generation, if it failed. Cleared by the
/// next call to `start`.
pub fn error(&self) -> Option<&anyhow::Error> {
    self.error.as_ref()
}
90
/// Starts a generation for `prompt`, streaming the response into the buffer.
///
/// The text covered by `self.range()` (read from the snapshot captured in
/// `new`) is diffed against the incoming response, and the resulting hunks
/// are applied to the buffer as they arrive. Every line of the response is
/// re-indented relative to the indentation suggested for the row where the
/// range starts.
///
/// When the stream ends or fails, `idle` becomes `true` again, any error is
/// stored (see `error`), and `Event::Finished` is emitted.
pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
    let range = self.range();
    let snapshot = self.snapshot.clone();
    let selected_text = snapshot
        .text_for_range(range.start..range.end)
        .collect::<Rope>();

    // Indentation suggested for the row where the range starts, falling back
    // to that row's current indentation when there is no suggestion.
    let selection_start = range.start.to_point(&snapshot);
    let suggested_line_indent = snapshot
        .suggested_indents(selection_start.row..selection_start.row + 1, cx)
        .into_values()
        .next()
        .unwrap_or_else(|| snapshot.indent_size_for_line(selection_start.row));

    let response = CompletionProvider::global(cx).complete(prompt);
    // Storing the new task replaces whichever task `generation` held before.
    self.generation = cx.spawn(|this, mut cx| {
        async move {
            let generate = async {
                // Offset in the original snapshot up to which the old text has
                // been consumed by `Keep`/`Remove` hunks.
                let mut edit_start = range.start.to_offset(&snapshot);

                // Hunks are computed on the background executor and handed to
                // the loop below through a bounded channel.
                let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                let diff = cx.background_executor().spawn(async move {
                    let chunks = strip_invalid_spans_from_codeblock(response.await?);
                    futures::pin_mut!(chunks);
                    let mut diff = StreamingDiff::new(selected_text.to_string());

                    // Text of the current line not yet handed to the diff.
                    let mut new_text = String::new();
                    // Indentation of the first response line that contained a
                    // non-whitespace character.
                    let mut base_indent = None;
                    // Indentation of the current response line, once known.
                    let mut line_indent = None;
                    let mut first_line = true;

                    while let Some(chunk) = chunks.next().await {
                        let chunk = chunk?;

                        let mut lines = chunk.split('\n').peekable();
                        while let Some(line) = lines.next() {
                            new_text.push_str(line);
                            if line_indent.is_none() {
                                if let Some(non_whitespace_ch_ix) =
                                    new_text.find(|ch: char| !ch.is_whitespace())
                                {
                                    line_indent = Some(non_whitespace_ch_ix);
                                    base_indent = base_indent.or(line_indent);

                                    // Keep the line's indentation relative to the
                                    // first line, but rebase it onto the suggested
                                    // indentation (never below zero).
                                    let line_indent = line_indent.unwrap();
                                    let base_indent = base_indent.unwrap();
                                    let indent_delta = line_indent as i32 - base_indent as i32;
                                    let mut corrected_indent_len = cmp::max(
                                        0,
                                        suggested_line_indent.len as i32 + indent_delta,
                                    )
                                        as usize;
                                    // NOTE(review): presumably the columns before
                                    // the selection start already provide this
                                    // part of the first line's indentation.
                                    if first_line {
                                        corrected_indent_len = corrected_indent_len
                                            .saturating_sub(selection_start.column as usize);
                                    }

                                    // Replace the line's leading whitespace with
                                    // the corrected indentation.
                                    let indent_char = suggested_line_indent.char();
                                    let mut indent_buffer = [0; 4];
                                    let indent_str =
                                        indent_char.encode_utf8(&mut indent_buffer);
                                    new_text.replace_range(
                                        ..line_indent,
                                        &indent_str.repeat(corrected_indent_len),
                                    );
                                }
                            }

                            // Text is forwarded only once the line's indentation
                            // is known; until then it accumulates in `new_text`.
                            if line_indent.is_some() {
                                hunks_tx.send(diff.push_new(&new_text)).await?;
                                new_text.clear();
                            }

                            // Another piece follows, so this one ended at a '\n'.
                            if lines.peek().is_some() {
                                hunks_tx.send(diff.push_new("\n")).await?;
                                line_indent = None;
                                first_line = false;
                            }
                        }
                    }
                    // Flush whatever is still buffered, then close the diff.
                    hunks_tx.send(diff.push_new(&new_text)).await?;
                    hunks_tx.send(diff.finish()).await?;

                    anyhow::Ok(())
                });

                while let Some(hunks) = hunks_rx.next().await {
                    this.update(&mut cx, |this, cx| {
                        this.last_equal_ranges.clear();

                        let transaction = this.buffer.update(cx, |buffer, cx| {
                            // Avoid grouping assistant edits with user edits.
                            buffer.finalize_last_transaction(cx);

                            buffer.start_transaction(cx);
                            buffer.edit(
                                hunks.into_iter().filter_map(|hunk| match hunk {
                                    // Insertions consume none of the old text, so
                                    // `edit_start` does not advance.
                                    Hunk::Insert { text } => {
                                        let edit_start = snapshot.anchor_after(edit_start);
                                        Some((edit_start..edit_start, text))
                                    }
                                    Hunk::Remove { len } => {
                                        let edit_end = edit_start + len;
                                        let edit_range = snapshot.anchor_after(edit_start)
                                            ..snapshot.anchor_before(edit_end);
                                        edit_start = edit_end;
                                        Some((edit_range, String::new()))
                                    }
                                    // Kept text produces no edit; its range is
                                    // recorded in `last_equal_ranges` instead.
                                    Hunk::Keep { len } => {
                                        let edit_end = edit_start + len;
                                        let edit_range = snapshot.anchor_after(edit_start)
                                            ..snapshot.anchor_before(edit_end);
                                        edit_start = edit_end;
                                        this.last_equal_ranges.push(edit_range);
                                        None
                                    }
                                }),
                                None,
                                cx,
                            );

                            buffer.end_transaction(cx)
                        });

                        if let Some(transaction) = transaction {
                            if let Some(first_transaction) = this.transaction_id {
                                // Group all assistant edits into the first transaction.
                                this.buffer.update(cx, |buffer, cx| {
                                    buffer.merge_transactions(
                                        transaction,
                                        first_transaction,
                                        cx,
                                    )
                                });
                            } else {
                                this.transaction_id = Some(transaction);
                                this.buffer.update(cx, |buffer, cx| {
                                    buffer.finalize_last_transaction(cx)
                                });
                            }
                        }

                        cx.notify();
                    })?;
                }

                // Surface any error produced by the background task.
                diff.await?;
                anyhow::Ok(())
            };

            let result = generate.await;
            this.update(&mut cx, |this, cx| {
                this.last_equal_ranges.clear();
                this.idle = true;
                if let Err(error) = result {
                    this.error = Some(error);
                }
                cx.emit(Event::Finished);
                cx.notify();
            })
            .ok();
        }
    });
    // Clear the previous error and mark the codegen as busy.
    self.error.take();
    self.idle = false;
    cx.notify();
}
258
/// Undoes the transaction recorded for this codegen's edits, if there is one.
///
/// `transaction_id` is not cleared here; `handle_buffer_event` clears it
/// when the buffer reports the transaction as undone.
pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
    if let Some(transaction_id) = self.transaction_id {
        self.buffer
            .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
    }
}
265}
266
267fn strip_invalid_spans_from_codeblock(
268 stream: impl Stream<Item = Result<String>>,
269) -> impl Stream<Item = Result<String>> {
270 let mut first_line = true;
271 let mut buffer = String::new();
272 let mut starts_with_markdown_codeblock = false;
273 let mut includes_start_or_end_span = false;
274 stream.filter_map(move |chunk| {
275 let chunk = match chunk {
276 Ok(chunk) => chunk,
277 Err(err) => return future::ready(Some(Err(err))),
278 };
279 buffer.push_str(&chunk);
280
281 if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
282 includes_start_or_end_span = true;
283
284 buffer = buffer
285 .strip_prefix("<|S|>")
286 .or_else(|| buffer.strip_prefix("<|S|"))
287 .unwrap_or(&buffer)
288 .to_string();
289 } else if buffer.ends_with("|E|>") {
290 includes_start_or_end_span = true;
291 } else if buffer.starts_with("<|")
292 || buffer.starts_with("<|S")
293 || buffer.starts_with("<|S|")
294 || buffer.ends_with('|')
295 || buffer.ends_with("|E")
296 || buffer.ends_with("|E|")
297 {
298 return future::ready(None);
299 }
300
301 if first_line {
302 if buffer.is_empty() || buffer == "`" || buffer == "``" {
303 return future::ready(None);
304 } else if buffer.starts_with("```") {
305 starts_with_markdown_codeblock = true;
306 if let Some(newline_ix) = buffer.find('\n') {
307 buffer.replace_range(..newline_ix + 1, "");
308 first_line = false;
309 } else {
310 return future::ready(None);
311 }
312 }
313 }
314
315 let mut text = buffer.to_string();
316 if starts_with_markdown_codeblock {
317 text = text
318 .strip_suffix("\n```\n")
319 .or_else(|| text.strip_suffix("\n```"))
320 .or_else(|| text.strip_suffix("\n``"))
321 .or_else(|| text.strip_suffix("\n`"))
322 .or_else(|| text.strip_suffix('\n'))
323 .unwrap_or(&text)
324 .to_string();
325 }
326
327 if includes_start_or_end_span {
328 text = text
329 .strip_suffix("|E|>")
330 .or_else(|| text.strip_suffix("E|>"))
331 .or_else(|| text.strip_prefix("|>"))
332 .or_else(|| text.strip_prefix('>'))
333 .unwrap_or(&text)
334 .to_string();
335 };
336
337 if text.contains('\n') {
338 first_line = false;
339 }
340
341 let remainder = buffer.split_off(text.len());
342 let result = if buffer.is_empty() {
343 None
344 } else {
345 Some(Ok(buffer.clone()))
346 };
347
348 buffer = remainder;
349 future::ready(result)
350 })
351}
352
353#[cfg(test)]
354mod tests {
355 use std::sync::Arc;
356
357 use crate::FakeCompletionProvider;
358
359 use super::*;
360 use futures::stream::{self};
361 use gpui::{Context, TestAppContext};
362 use indoc::indoc;
363 use language::{
364 language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
365 Point,
366 };
367 use rand::prelude::*;
368 use serde::Serialize;
369 use settings::SettingsStore;
370
/// Minimal serializable request type.
// NOTE(review): not referenced by any test visible in this module — confirm
// whether it is still needed.
#[derive(Serialize)]
pub struct DummyCompletionRequest {
    pub name: String,
}
375
/// Transforms the body of `main` with a response whose lines carry their own
/// base indentation, delivered in randomly sized chunks, and checks that the
/// result is re-indented to fit the surrounding code.
#[gpui::test(iterations = 10)]
async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
    let provider = FakeCompletionProvider::default();
    cx.set_global(cx.update(SettingsStore::test));
    cx.set_global(CompletionProvider::Fake(provider.clone()));
    cx.update(language_settings::init);

    // Indentation inside the fixtures below is significant.
    let text = indoc! {"
        fn main() {
            let x = 0;
            for _ in 0..10 {
                x += 1;
            }
        }
    "};
    let buffer =
        cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
    let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
    // Select from the start of row 1 to the end of row 4 (the body of `main`).
    let range = buffer.read_with(cx, |buffer, cx| {
        let snapshot = buffer.snapshot(cx);
        snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
    });
    let codegen =
        cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Transform { range }, cx));

    let request = LanguageModelRequest::default();
    codegen.update(cx, |codegen, cx| codegen.start(request, cx));

    // The response uses an indentation that differs from the buffer's.
    let mut new_text = concat!(
        "       let mut x = 0;\n",
        "       while x < 10 {\n",
        "           x += 1;\n",
        "       }",
    );
    // Deliver the response in chunks of 1 to 10 bytes.
    while !new_text.is_empty() {
        let max_len = cmp::min(new_text.len(), 10);
        let len = rng.gen_range(1..=max_len);
        let (chunk, suffix) = new_text.split_at(len);
        provider.send_completion(chunk.into());
        new_text = suffix;
        cx.background_executor.run_until_parked();
    }
    provider.finish_completion();
    cx.background_executor.run_until_parked();

    assert_eq!(
        buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
        indoc! {"
            fn main() {
                let mut x = 0;
                while x < 10 {
                    x += 1;
                }
            }
        "}
    );
}
433
/// Generates text at a position that lies after the line's indentation and
/// some existing text (`le`), and checks that the first generated line
/// continues in place while the following lines are indented to match.
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_past_indentation(
    cx: &mut TestAppContext,
    mut rng: StdRng,
) {
    let provider = FakeCompletionProvider::default();
    cx.set_global(CompletionProvider::Fake(provider.clone()));
    cx.set_global(cx.update(SettingsStore::test));
    cx.update(language_settings::init);

    // Indentation inside the fixtures below is significant.
    let text = indoc! {"
        fn main() {
            le
        }
    "};
    let buffer =
        cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
    let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
    // Column 6 of row 1: right after `le`.
    let position = buffer.read_with(cx, |buffer, cx| {
        let snapshot = buffer.snapshot(cx);
        snapshot.anchor_before(Point::new(1, 6))
    });
    let codegen =
        cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));

    let request = LanguageModelRequest::default();
    codegen.update(cx, |codegen, cx| codegen.start(request, cx));

    let mut new_text = concat!(
        "t mut x = 0;\n",
        "while x < 10 {\n",
        "    x += 1;\n",
        "}", //
    );
    // Deliver the response in chunks of 1 to 10 bytes.
    while !new_text.is_empty() {
        let max_len = cmp::min(new_text.len(), 10);
        let len = rng.gen_range(1..=max_len);
        let (chunk, suffix) = new_text.split_at(len);
        provider.send_completion(chunk.into());
        new_text = suffix;
        cx.background_executor.run_until_parked();
    }
    provider.finish_completion();
    cx.background_executor.run_until_parked();

    assert_eq!(
        buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
        indoc! {"
            fn main() {
                let mut x = 0;
                while x < 10 {
                    x += 1;
                }
            }
        "}
    );
}
491
/// Generates text at a position on a whitespace-only line that is indented
/// less than the suggested indentation, and checks that every generated
/// line ends up correctly indented.
#[gpui::test(iterations = 10)]
async fn test_autoindent_when_generating_before_indentation(
    cx: &mut TestAppContext,
    mut rng: StdRng,
) {
    let provider = FakeCompletionProvider::default();
    cx.set_global(CompletionProvider::Fake(provider.clone()));
    cx.set_global(cx.update(SettingsStore::test));
    cx.update(language_settings::init);

    // Row 1 holds exactly two spaces; `concat!` keeps that whitespace explicit.
    let text = concat!(
        "fn main() {\n",
        "  \n",
        "}\n" //
    );
    let buffer =
        cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
    let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
    // Column 2 of row 1: the end of the two existing spaces.
    let position = buffer.read_with(cx, |buffer, cx| {
        let snapshot = buffer.snapshot(cx);
        snapshot.anchor_before(Point::new(1, 2))
    });
    let codegen =
        cx.new_model(|cx| Codegen::new(buffer.clone(), CodegenKind::Generate { position }, cx));

    let request = LanguageModelRequest::default();
    codegen.update(cx, |codegen, cx| codegen.start(request, cx));

    let mut new_text = concat!(
        "let mut x = 0;\n",
        "while x < 10 {\n",
        "    x += 1;\n",
        "}", //
    );
    // Deliver the response in chunks of 1 to 10 bytes.
    while !new_text.is_empty() {
        let max_len = cmp::min(new_text.len(), 10);
        let len = rng.gen_range(1..=max_len);
        let (chunk, suffix) = new_text.split_at(len);
        provider.send_completion(chunk.into());
        new_text = suffix;
        cx.background_executor.run_until_parked();
    }
    provider.finish_completion();
    cx.background_executor.run_until_parked();

    assert_eq!(
        buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
        indoc! {"
            fn main() {
                let mut x = 0;
                while x < 10 {
                    x += 1;
                }
            }
        "}
    );
}
549
550 #[gpui::test]
551 async fn test_strip_invalid_spans_from_codeblock() {
552 assert_eq!(
553 strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
554 .map(|chunk| chunk.unwrap())
555 .collect::<String>()
556 .await,
557 "Lorem ipsum dolor"
558 );
559 assert_eq!(
560 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
561 .map(|chunk| chunk.unwrap())
562 .collect::<String>()
563 .await,
564 "Lorem ipsum dolor"
565 );
566 assert_eq!(
567 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
568 .map(|chunk| chunk.unwrap())
569 .collect::<String>()
570 .await,
571 "Lorem ipsum dolor"
572 );
573 assert_eq!(
574 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
575 .map(|chunk| chunk.unwrap())
576 .collect::<String>()
577 .await,
578 "Lorem ipsum dolor"
579 );
580 assert_eq!(
581 strip_invalid_spans_from_codeblock(chunks(
582 "```html\n```js\nLorem ipsum dolor\n```\n```",
583 2
584 ))
585 .map(|chunk| chunk.unwrap())
586 .collect::<String>()
587 .await,
588 "```js\nLorem ipsum dolor\n```"
589 );
590 assert_eq!(
591 strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
592 .map(|chunk| chunk.unwrap())
593 .collect::<String>()
594 .await,
595 "``\nLorem ipsum dolor\n```"
596 );
597 assert_eq!(
598 strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
599 .map(|chunk| chunk.unwrap())
600 .collect::<String>()
601 .await,
602 "Lorem ipsum"
603 );
604
605 assert_eq!(
606 strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
607 .map(|chunk| chunk.unwrap())
608 .collect::<String>()
609 .await,
610 "Lorem ipsum"
611 );
612
613 assert_eq!(
614 strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
615 .map(|chunk| chunk.unwrap())
616 .collect::<String>()
617 .await,
618 "Lorem ipsum"
619 );
620 assert_eq!(
621 strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
622 .map(|chunk| chunk.unwrap())
623 .collect::<String>()
624 .await,
625 "Lorem ipsum"
626 );
627 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
628 stream::iter(
629 text.chars()
630 .collect::<Vec<_>>()
631 .chunks(size)
632 .map(|chunk| Ok(chunk.iter().collect::<String>()))
633 .collect::<Vec<_>>(),
634 )
635 }
636 }
637
638 fn rust_lang() -> Language {
639 Language::new(
640 LanguageConfig {
641 name: "Rust".into(),
642 matcher: LanguageMatcher {
643 path_suffixes: vec!["rs".to_string()],
644 ..Default::default()
645 },
646 ..Default::default()
647 },
648 Some(tree_sitter_rust::language()),
649 )
650 .with_indents_query(
651 r#"
652 (call_expression) @indent
653 (field_expression) @indent
654 (_ "(" ")" @end) @indent
655 (_ "{" "}" @end) @indent
656 "#,
657 )
658 .unwrap()
659 }
660}