1use crate::{
2 streaming_diff::{Hunk, StreamingDiff},
3 CompletionProvider, LanguageModelRequest,
4};
5use anyhow::Result;
6use client::telemetry::Telemetry;
7use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
8use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
9use gpui::{EventEmitter, Model, ModelContext, Task};
10use language::{Rope, TransactionId};
11use multi_buffer::MultiBufferRow;
12use std::{cmp, future, ops::Range, sync::Arc, time::Instant};
13
/// Events emitted by a [`Codegen`] model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Event {
    /// Emitted once the generation started by `Codegen::start` has run to
    /// completion, whether it succeeded or stored an error.
    Finished,
    /// Emitted when the buffer reports that the transaction recorded by the
    /// codegen was undone.
    Undone,
}
18
/// Selects which part of the buffer a [`Codegen`] operates on.
#[derive(Clone)]
pub enum CodegenKind {
    /// Operate on an existing `range`; `Codegen::range` returns it unchanged and
    /// `Codegen::start` diffs the model output against the text it contains.
    Transform { range: Range<Anchor> },
    /// Operate at a single `position`; `Codegen::range` turns it into
    /// `position.bias_left(..)..position`.
    Generate { position: Anchor },
}
24
/// Streams a language-model completion into a [`MultiBuffer`] as incremental edits.
pub struct Codegen {
    /// Buffer that receives the generated edits.
    buffer: Model<MultiBuffer>,
    /// Snapshot of `buffer` taken when the codegen was constructed; offsets and
    /// anchors used while applying hunks are resolved against it.
    snapshot: MultiBufferSnapshot,
    /// Whether an existing range is transformed or text is generated at a position.
    kind: CodegenKind,
    /// Ranges covered by `Hunk::Keep` in the most recently applied batch of hunks.
    /// Cleared before every batch and again when generation ends.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// First transaction created by this codegen; later transactions are merged
    /// into it. Reset to `None` when that transaction is undone.
    transaction_id: Option<TransactionId>,
    /// Error produced by the most recent generation, cleared when a new one starts.
    error: Option<anyhow::Error>,
    /// Task driving the current generation; an already-completed task when none
    /// is running.
    generation: Task<()>,
    /// `false` from the moment `start` is called until the generation finishes.
    idle: bool,
    /// Optional telemetry sink used to report an assistant event per generation.
    telemetry: Option<Arc<Telemetry>>,
    /// Keeps the subscription to `buffer` events alive for as long as the
    /// codegen exists.
    _subscription: gpui::Subscription,
}
37
// Allows `Codegen` to emit `Event` values through its model context
// (`Event::Finished` and `Event::Undone` are emitted via `cx.emit`).
impl EventEmitter<Event> for Codegen {}
39
impl Codegen {
    /// Creates an idle codegen for `buffer`.
    ///
    /// A snapshot of the buffer is captured immediately, and the codegen
    /// subscribes to the buffer's events so it can react to its transaction
    /// being undone. `telemetry`, when present, receives one assistant event per
    /// generation.
    pub fn new(
        buffer: Model<MultiBuffer>,
        kind: CodegenKind,
        telemetry: Option<Arc<Telemetry>>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        let snapshot = buffer.read(cx).snapshot(cx);
        Self {
            buffer: buffer.clone(),
            snapshot,
            kind,
            last_equal_ranges: Default::default(),
            transaction_id: Default::default(),
            error: Default::default(),
            idle: true,
            generation: Task::ready(()),
            telemetry,
            _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
        }
    }

    /// Handles events from the subscribed buffer.
    ///
    /// Only `TransactionUndone` for the transaction recorded in
    /// `self.transaction_id` is acted upon: the id is forgotten, the generation
    /// task is replaced and `Event::Undone` is emitted.
    fn handle_buffer_event(
        &mut self,
        _buffer: Model<MultiBuffer>,
        event: &multi_buffer::Event,
        cx: &mut ModelContext<Self>,
    ) {
        if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
            if self.transaction_id == Some(*transaction_id) {
                self.transaction_id = None;
                // Replacing the task drops the in-flight one (presumably
                // cancelling it — confirm gpui `Task` drop semantics).
                self.generation = Task::ready(());
                cx.emit(Event::Undone);
            }
        }
    }

    /// Returns the anchor range this codegen operates on.
    ///
    /// For `Transform` this is the stored range; for `Generate` it is
    /// `position.bias_left(snapshot)..position`.
    pub fn range(&self) -> Range<Anchor> {
        match &self.kind {
            CodegenKind::Transform { range } => range.clone(),
            CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
        }
    }

    /// Returns the kind this codegen was created with.
    pub fn kind(&self) -> &CodegenKind {
        &self.kind
    }

    /// Returns the ranges left unchanged (`Hunk::Keep`) by the most recently
    /// applied batch of hunks. Empty once generation has finished.
    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
        &self.last_equal_ranges
    }

    /// Returns `true` when no generation is in progress.
    pub fn idle(&self) -> bool {
        self.idle
    }

    /// Returns the error stored by the most recent generation, if any.
    pub fn error(&self) -> Option<&anyhow::Error> {
        self.error.as_ref()
    }

    /// Starts a generation for `prompt` and streams the result into the buffer.
    ///
    /// The text currently inside `self.range()` is the base of a streaming diff.
    /// A background task reads completion chunks, rewrites each line's leading
    /// indentation, and sends the resulting hunks over a channel. The spawned
    /// foreground task applies each batch of hunks to the buffer, grouping all
    /// resulting transactions into the first one. When generation ends,
    /// `Event::Finished` is emitted and any error is stored in `self.error`.
    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
        let range = self.range();
        let snapshot = self.snapshot.clone();
        // Base text for the diff: whatever the range contained in the snapshot.
        let selected_text = snapshot
            .text_for_range(range.start..range.end)
            .collect::<Rope>();

        let selection_start = range.start.to_point(&snapshot);
        // Indentation to apply to generated lines: the suggested indent for the
        // row where the range starts, falling back to that row's current indent.
        let suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        let model_telemetry_id = prompt.model.telemetry_id();
        let response = CompletionProvider::global(cx).complete(prompt);
        let telemetry = self.telemetry.clone();
        self.generation = cx.spawn(|this, mut cx| {
            async move {
                let generate = async {
                    // Offset (in the original snapshot) up to which hunks have
                    // been consumed so far.
                    let mut edit_start = range.start.to_offset(&snapshot);

                    // Channel created with a buffer size of 1 (presumably so the
                    // diffing task cannot run far ahead of the applied edits).
                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                    // Diffing runs on the background executor; only hunks cross
                    // over to this task.
                    let diff = cx.background_executor().spawn(async move {
                        let mut response_latency = None;
                        let request_start = Instant::now();
                        let diff = async {
                            let chunks = strip_invalid_spans_from_codeblock(response.await?);
                            futures::pin_mut!(chunks);
                            let mut diff = StreamingDiff::new(selected_text.to_string());

                            // Text of the current line not yet pushed into the diff.
                            let mut new_text = String::new();
                            // Indent (index of first non-whitespace char) of the
                            // first generated line that had content.
                            let mut base_indent = None;
                            // Indent of the current line; `None` until a
                            // non-whitespace character has been seen on it.
                            let mut line_indent = None;
                            let mut first_line = true;

                            while let Some(chunk) = chunks.next().await {
                                // Latency is measured up to the first stream
                                // item, even if that item is an error.
                                if response_latency.is_none() {
                                    response_latency = Some(request_start.elapsed());
                                }
                                let chunk = chunk?;

                                let mut lines = chunk.split('\n').peekable();
                                while let Some(line) = lines.next() {
                                    new_text.push_str(line);
                                    if line_indent.is_none() {
                                        if let Some(non_whitespace_ch_ix) =
                                            new_text.find(|ch: char| !ch.is_whitespace())
                                        {
                                            line_indent = Some(non_whitespace_ch_ix);
                                            base_indent = base_indent.or(line_indent);

                                            // Keep this line's indentation relative
                                            // to the first line, but rebase it onto
                                            // the suggested indent (never below 0).
                                            let line_indent = line_indent.unwrap();
                                            let base_indent = base_indent.unwrap();
                                            let indent_delta =
                                                line_indent as i32 - base_indent as i32;
                                            let mut corrected_indent_len = cmp::max(
                                                0,
                                                suggested_line_indent.len as i32 + indent_delta,
                                            )
                                                as usize;
                                            // The first line starts at the selection's
                                            // column, so that much indentation is
                                            // already present in the buffer.
                                            if first_line {
                                                corrected_indent_len = corrected_indent_len
                                                    .saturating_sub(
                                                        selection_start.column as usize,
                                                    );
                                            }

                                            // Replace the model's leading whitespace
                                            // with the corrected indentation.
                                            let indent_char = suggested_line_indent.char();
                                            let mut indent_buffer = [0; 4];
                                            let indent_str =
                                                indent_char.encode_utf8(&mut indent_buffer);
                                            new_text.replace_range(
                                                ..line_indent,
                                                &indent_str.repeat(corrected_indent_len),
                                            );
                                        }
                                    }

                                    // Text is only pushed once the line's indent is
                                    // known; whitespace-only text keeps accumulating.
                                    if line_indent.is_some() {
                                        hunks_tx.send(diff.push_new(&new_text)).await?;
                                        new_text.clear();
                                    }

                                    // Another piece follows, so `line` was terminated
                                    // by a newline within this chunk.
                                    if lines.peek().is_some() {
                                        hunks_tx.send(diff.push_new("\n")).await?;
                                        line_indent = None;
                                        first_line = false;
                                    }
                                }
                            }
                            // Flush whatever is left and let the diff emit its
                            // final hunks.
                            hunks_tx.send(diff.push_new(&new_text)).await?;
                            hunks_tx.send(diff.finish()).await?;

                            anyhow::Ok(())
                        };

                        // Report the outcome (latency and error message, if any)
                        // when telemetry is configured.
                        let error_message = diff.await.err().map(|error| error.to_string());
                        if let Some(telemetry) = telemetry {
                            telemetry.report_assistant_event(
                                None,
                                telemetry_events::AssistantKind::Inline,
                                model_telemetry_id,
                                response_latency,
                                error_message,
                            );
                        }
                    });

                    // Apply every batch of hunks to the buffer as it arrives.
                    while let Some(hunks) = hunks_rx.next().await {
                        this.update(&mut cx, |this, cx| {
                            this.last_equal_ranges.clear();

                            let transaction = this.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                buffer.edit(
                                    hunks.into_iter().filter_map(|hunk| match hunk {
                                        // Insertions happen at the current offset
                                        // and do not advance it.
                                        Hunk::Insert { text } => {
                                            let edit_start = snapshot.anchor_after(edit_start);
                                            Some((edit_start..edit_start, text))
                                        }
                                        // Removals delete `len` bytes of the
                                        // original text and advance past them.
                                        Hunk::Remove { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            Some((edit_range, String::new()))
                                        }
                                        // Kept text produces no edit; its range is
                                        // recorded and the offset advances past it.
                                        Hunk::Keep { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            this.last_equal_ranges.push(edit_range);
                                            None
                                        }
                                    }),
                                    None,
                                    cx,
                                );

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) = this.transaction_id {
                                    // Group all assistant edits into the first transaction.
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    // First transaction of this generation:
                                    // remember it and finalize it.
                                    this.transaction_id = Some(transaction);
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            cx.notify();
                        })?;
                    }

                    // The channel is closed; wait for the background task to
                    // finish as well.
                    diff.await;

                    anyhow::Ok(())
                };

                let result = generate.await;
                // Whatever the outcome, return to the idle state and tell
                // observers that generation is over.
                this.update(&mut cx, |this, cx| {
                    this.last_equal_ranges.clear();
                    this.idle = true;
                    if let Err(error) = result {
                        this.error = Some(error);
                    }
                    cx.emit(Event::Finished);
                    cx.notify();
                })
                .ok();
            }
        });
        // Discard the previous generation's error and mark the codegen busy.
        self.error.take();
        self.idle = false;
        cx.notify();
    }

    /// Asks the buffer to undo the transaction recorded by this codegen, if any.
    ///
    /// `self.transaction_id` is not cleared here; that happens in
    /// `handle_buffer_event` when the matching `TransactionUndone` event arrives.
    pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
        if let Some(transaction_id) = self.transaction_id {
            self.buffer
                .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
        }
    }
}
299
300fn strip_invalid_spans_from_codeblock(
301 stream: impl Stream<Item = Result<String>>,
302) -> impl Stream<Item = Result<String>> {
303 let mut first_line = true;
304 let mut buffer = String::new();
305 let mut starts_with_markdown_codeblock = false;
306 let mut includes_start_or_end_span = false;
307 stream.filter_map(move |chunk| {
308 let chunk = match chunk {
309 Ok(chunk) => chunk,
310 Err(err) => return future::ready(Some(Err(err))),
311 };
312 buffer.push_str(&chunk);
313
314 if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
315 includes_start_or_end_span = true;
316
317 buffer = buffer
318 .strip_prefix("<|S|>")
319 .or_else(|| buffer.strip_prefix("<|S|"))
320 .unwrap_or(&buffer)
321 .to_string();
322 } else if buffer.ends_with("|E|>") {
323 includes_start_or_end_span = true;
324 } else if buffer.starts_with("<|")
325 || buffer.starts_with("<|S")
326 || buffer.starts_with("<|S|")
327 || buffer.ends_with('|')
328 || buffer.ends_with("|E")
329 || buffer.ends_with("|E|")
330 {
331 return future::ready(None);
332 }
333
334 if first_line {
335 if buffer.is_empty() || buffer == "`" || buffer == "``" {
336 return future::ready(None);
337 } else if buffer.starts_with("```") {
338 starts_with_markdown_codeblock = true;
339 if let Some(newline_ix) = buffer.find('\n') {
340 buffer.replace_range(..newline_ix + 1, "");
341 first_line = false;
342 } else {
343 return future::ready(None);
344 }
345 }
346 }
347
348 let mut text = buffer.to_string();
349 if starts_with_markdown_codeblock {
350 text = text
351 .strip_suffix("\n```\n")
352 .or_else(|| text.strip_suffix("\n```"))
353 .or_else(|| text.strip_suffix("\n``"))
354 .or_else(|| text.strip_suffix("\n`"))
355 .or_else(|| text.strip_suffix('\n'))
356 .unwrap_or(&text)
357 .to_string();
358 }
359
360 if includes_start_or_end_span {
361 text = text
362 .strip_suffix("|E|>")
363 .or_else(|| text.strip_suffix("E|>"))
364 .or_else(|| text.strip_prefix("|>"))
365 .or_else(|| text.strip_prefix('>'))
366 .unwrap_or(&text)
367 .to_string();
368 };
369
370 if text.contains('\n') {
371 first_line = false;
372 }
373
374 let remainder = buffer.split_off(text.len());
375 let result = if buffer.is_empty() {
376 None
377 } else {
378 Some(Ok(buffer.clone()))
379 };
380
381 buffer = remainder;
382 future::ready(result)
383 })
384}
385
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::FakeCompletionProvider;

    use super::*;
    use futures::stream::{self};
    use gpui::{Context, TestAppContext};
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;

    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Buffer contents every autoindent test expects once generation is done.
    const REWRITTEN_MAIN: &str = indoc! {"
        fn main() {
            let mut x = 0;
            while x < 10 {
                x += 1;
            }
        }
    "};

    /// Builds a singleton multi-buffer over a Rust-language buffer holding `text`.
    fn rust_multibuffer(text: &'static str, cx: &mut TestAppContext) -> Model<MultiBuffer> {
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        cx.new_model(|cx| MultiBuffer::singleton(buffer, cx))
    }

    /// Sends `remaining` through the fake provider in random pieces of 1 to 10
    /// bytes, letting the executor settle after each piece, then finishes the
    /// completion.
    fn stream_completion(
        provider: FakeCompletionProvider,
        mut remaining: &str,
        rng: &mut StdRng,
        cx: &mut TestAppContext,
    ) {
        while !remaining.is_empty() {
            let max_len = cmp::min(remaining.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, rest) = remaining.split_at(len);
            provider.send_completion(chunk.into());
            remaining = rest;
            cx.background_executor.run_until_parked();
        }
        provider.finish_completion();
        cx.background_executor.run_until_parked();
    }

    /// Returns the current text of `buffer`.
    fn buffer_text(buffer: &Model<MultiBuffer>, cx: &mut TestAppContext) -> String {
        buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text())
    }

    /// Splits `text` into pieces of `size` characters and yields them as a stream.
    fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
        let chars: Vec<char> = text.chars().collect();
        let pieces: Vec<Result<String>> = chars
            .chunks(size)
            .map(|piece| Ok(piece.iter().collect::<String>()))
            .collect();
        stream::iter(pieces)
    }

    /// Runs `text` through the stripper in two-character chunks and joins the output.
    async fn stripped(text: &str) -> String {
        strip_invalid_spans_from_codeblock(chunks(text, 2))
            .map(|chunk| chunk.unwrap())
            .collect::<String>()
            .await
    }

    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(cx.update(SettingsStore::test));
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.update(language_settings::init);

        let buffer = rust_multibuffer(
            indoc! {"
                fn main() {
                    let x = 0;
                    for _ in 0..10 {
                        x += 1;
                    }
                }
            "},
            cx,
        );
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Transform { range }, None, cx)
        });

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        // The model's output uses an arbitrary base indentation that must be
        // rebased onto the buffer's.
        stream_completion(
            provider,
            concat!(
                "       let mut x = 0;\n",
                "       while x < 10 {\n",
                "           x += 1;\n",
                "       }",
            ),
            &mut rng,
            cx,
        );

        assert_eq!(buffer_text(&buffer, cx), REWRITTEN_MAIN);
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let buffer = rust_multibuffer(
            indoc! {"
                fn main() {
                    le
                }
            "},
            cx,
        );
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        stream_completion(
            provider,
            concat!(
                "t mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &mut rng,
            cx,
        );

        assert_eq!(buffer_text(&buffer, cx), REWRITTEN_MAIN);
    }

    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        let provider = FakeCompletionProvider::default();
        cx.set_global(CompletionProvider::Fake(provider.clone()));
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let buffer = rust_multibuffer(
            concat!(
                "fn main() {\n",
                "  \n",
                "}\n" //
            ),
            cx,
        );
        let position = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))
        });
        let codegen = cx.new_model(|cx| {
            Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
        });

        let request = LanguageModelRequest::default();
        codegen.update(cx, |codegen, cx| codegen.start(request, cx));

        stream_completion(
            provider,
            concat!(
                "let mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &mut rng,
            cx,
        );

        assert_eq!(buffer_text(&buffer, cx), REWRITTEN_MAIN);
    }

    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        // (input, expected output), checked in order.
        let cases = [
            ("Lorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```\n", "Lorem ipsum dolor"),
            (
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                "```js\nLorem ipsum dolor\n```",
            ),
            ("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```"),
            ("<|S|Lorem ipsum|E|>", "Lorem ipsum"),
            ("<|S|>Lorem ipsum", "Lorem ipsum"),
            ("```\n<|S|>Lorem ipsum\n```", "Lorem ipsum"),
            ("```\n<|S|Lorem ipsum|E|>\n```", "Lorem ipsum"),
        ];

        for (input, expected) in cases {
            assert_eq!(stripped(input).await, expected);
        }
    }

    /// Minimal Rust language definition with an indents query, enough for the
    /// autoindent tests above.
    fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher: LanguageMatcher {
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            ..Default::default()
        };
        Language::new(config, Some(tree_sitter_rust::language()))
            .with_indents_query(
                r#"
                (call_expression) @indent
                (field_expression) @indent
                (_ "(" ")" @end) @indent
                (_ "{" "}" @end) @indent
                "#,
            )
            .unwrap()
    }
}