1use crate::{
2 streaming_diff::{Hunk, StreamingDiff},
3 CompletionProvider, LanguageModelRequest,
4};
5use anyhow::Result;
6use client::telemetry::Telemetry;
7use editor::{Anchor, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint};
8use futures::{channel::mpsc, SinkExt, Stream, StreamExt};
9use gpui::{EventEmitter, Model, ModelContext, Task};
10use language::{Rope, TransactionId};
11use multi_buffer::MultiBufferRow;
12use std::{cmp, future, ops::Range, sync::Arc, time::Instant};
13
/// Notifications emitted by a [`Codegen`] model.
#[derive(Debug)]
pub enum Event {
    /// Emitted by `Codegen::start`'s task once the generation has run to
    /// completion, whether it succeeded or stored an error.
    Finished,
    /// Emitted when the buffer reports that the transaction holding the
    /// generated edits was undone.
    Undone,
}
19
/// What part of the buffer a [`Codegen`] operates on.
#[derive(Clone)]
pub enum CodegenKind {
    /// Operate on the existing text covered by `range`.
    Transform { range: Range<Anchor> },
    /// Operate at a single `position`; `Codegen::range` turns this into a
    /// range that ends at `position` and starts at its left-biased counterpart.
    Generate { position: Anchor },
}
25
/// Streams a language-model completion into a multi-buffer as a series of edits.
pub struct Codegen {
    /// The multi-buffer that receives the generated edits.
    buffer: Model<MultiBuffer>,
    /// Snapshot of `buffer` taken when the `Codegen` was constructed; all
    /// offsets and anchors used while applying hunks are resolved against it.
    snapshot: MultiBufferSnapshot,
    /// Whether this is a transformation of a range or a generation at a position.
    kind: CodegenKind,
    /// Ranges reported as unchanged (`Hunk::Keep`) by the most recently applied
    /// batch of hunks; cleared before each batch and when generation ends.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// The first transaction produced by the generation; later transactions are
    /// merged into it. Reset to `None` when that transaction is undone.
    transaction_id: Option<TransactionId>,
    /// Error produced by the last generation, if any; cleared by `start`.
    error: Option<anyhow::Error>,
    /// Task driving the current generation (a ready task when none is running).
    generation: Task<()>,
    /// `true` until `start` is called and again once the generation finishes.
    idle: bool,
    /// Optional sink for assistant telemetry events.
    telemetry: Option<Arc<Telemetry>>,
    /// Keeps the subscription to `buffer` events alive for the model's lifetime.
    _subscription: gpui::Subscription,
}
38
// Allows `Codegen` to emit `Event`s through `cx.emit(..)`.
impl EventEmitter<Event> for Codegen {}
40
41impl Codegen {
42 pub fn new(
43 buffer: Model<MultiBuffer>,
44 kind: CodegenKind,
45 telemetry: Option<Arc<Telemetry>>,
46 cx: &mut ModelContext<Self>,
47 ) -> Self {
48 let snapshot = buffer.read(cx).snapshot(cx);
49 Self {
50 buffer: buffer.clone(),
51 snapshot,
52 kind,
53 last_equal_ranges: Default::default(),
54 transaction_id: Default::default(),
55 error: Default::default(),
56 idle: true,
57 generation: Task::ready(()),
58 telemetry,
59 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
60 }
61 }
62
63 fn handle_buffer_event(
64 &mut self,
65 _buffer: Model<MultiBuffer>,
66 event: &multi_buffer::Event,
67 cx: &mut ModelContext<Self>,
68 ) {
69 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
70 if self.transaction_id == Some(*transaction_id) {
71 self.transaction_id = None;
72 self.generation = Task::ready(());
73 cx.emit(Event::Undone);
74 }
75 }
76 }
77
78 pub fn range(&self) -> Range<Anchor> {
79 match &self.kind {
80 CodegenKind::Transform { range } => range.clone(),
81 CodegenKind::Generate { position } => position.bias_left(&self.snapshot)..*position,
82 }
83 }
84
    /// Returns the kind (transform or generate) this codegen was created with.
    pub fn kind(&self) -> &CodegenKind {
        &self.kind
    }
88
    /// Returns the ranges that the most recently applied batch of hunks left
    /// unchanged. Empty before the first batch and after generation finishes.
    pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
        &self.last_equal_ranges
    }
92
    /// Returns `false` between a call to `start` and the completion of the
    /// generation it spawned, and `true` otherwise.
    pub fn idle(&self) -> bool {
        self.idle
    }
96
    /// Returns the error stored by the last generation, if it failed.
    /// `start` clears any previously stored error.
    pub fn error(&self) -> Option<&anyhow::Error> {
        self.error.as_ref()
    }
100
    /// Starts a generation for `prompt` and streams the result into the buffer.
    ///
    /// The text currently covered by `self.range()` is diffed against the
    /// streamed completion on a background task; the resulting hunks are sent
    /// over a channel and applied to the buffer as edits, all grouped into a
    /// single transaction. Indentation of the incoming text is rewritten
    /// relative to the indent suggested for the first row of the range.
    ///
    /// The spawned task is stored in `self.generation`. While it runs `idle()`
    /// is `false`; when it ends, any error is stored (see `error()`),
    /// `Event::Finished` is emitted and `idle()` becomes `true` again.
    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
        let range = self.range();
        let snapshot = self.snapshot.clone();
        let selected_text = snapshot
            .text_for_range(range.start..range.end)
            .collect::<Rope>();

        let selection_start = range.start.to_point(&snapshot);
        // Use the suggested indent for the first row of the range, falling back
        // to that row's current indent when no suggestion is returned.
        let suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..selection_start.row + 1, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        let model_telemetry_id = prompt.model.telemetry_id();
        let response = CompletionProvider::global(cx).complete(prompt);
        let telemetry = self.telemetry.clone();
        self.generation = cx.spawn(|this, mut cx| {
            async move {
                let generate = async {
                    // Running offset into `snapshot` marking where the next hunk applies.
                    let mut edit_start = range.start.to_offset(&snapshot);

                    // Hunks flow from the background diff task to the loop below
                    // that applies them to the buffer.
                    let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
                    let diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            // Time from issuing the request to receiving the first chunk.
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = strip_invalid_spans_from_codeblock(response.await?);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());

                                // Text of the current line that has not been pushed to the diff yet.
                                let mut new_text = String::new();
                                // Indent (byte offset of the first non-whitespace char) of the
                                // first non-blank line of the response.
                                let mut base_indent = None;
                                // Indent of the current line, once its first non-whitespace
                                // character has been seen.
                                let mut line_indent = None;
                                let mut first_line = true;

                                while let Some(chunk) = chunks.next().await {
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;

                                    let mut lines = chunk.split('\n').peekable();
                                    while let Some(line) = lines.next() {
                                        new_text.push_str(line);
                                        if line_indent.is_none() {
                                            if let Some(non_whitespace_ch_ix) =
                                                new_text.find(|ch: char| !ch.is_whitespace())
                                            {
                                                line_indent = Some(non_whitespace_ch_ix);
                                                base_indent = base_indent.or(line_indent);

                                                let line_indent = line_indent.unwrap();
                                                let base_indent = base_indent.unwrap();
                                                // Keep this line's indent relative to the first
                                                // line, re-based onto the suggested indent and
                                                // clamped at zero.
                                                let indent_delta =
                                                    line_indent as i32 - base_indent as i32;
                                                let mut corrected_indent_len = cmp::max(
                                                    0,
                                                    suggested_line_indent.len as i32 + indent_delta,
                                                )
                                                    as usize;
                                                // The first line starts at the selection's column,
                                                // so that many columns are subtracted from its indent.
                                                if first_line {
                                                    corrected_indent_len = corrected_indent_len
                                                        .saturating_sub(
                                                            selection_start.column as usize,
                                                        );
                                                }

                                                // Replace the line's leading whitespace with the
                                                // corrected indent.
                                                let indent_char = suggested_line_indent.char();
                                                let mut indent_buffer = [0; 4];
                                                let indent_str =
                                                    indent_char.encode_utf8(&mut indent_buffer);
                                                new_text.replace_range(
                                                    ..line_indent,
                                                    &indent_str.repeat(corrected_indent_len),
                                                );
                                            }
                                        }

                                        // Whitespace-only text is held back until the line's
                                        // indent is known.
                                        if line_indent.is_some() {
                                            hunks_tx.send(diff.push_new(&new_text)).await?;
                                            new_text.clear();
                                        }

                                        // Another piece follows, so a '\n' separated them in
                                        // the chunk: emit it and start a new line.
                                        if lines.peek().is_some() {
                                            hunks_tx.send(diff.push_new("\n")).await?;
                                            line_indent = None;
                                            first_line = false;
                                        }
                                    }
                                }
                                // Flush whatever is still pending, then close the diff.
                                hunks_tx.send(diff.push_new(&new_text)).await?;
                                hunks_tx.send(diff.finish()).await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            // Telemetry is reported for both successful and failed runs.
                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            if let Some(telemetry) = telemetry {
                                telemetry.report_assistant_event(
                                    None,
                                    telemetry_events::AssistantKind::Inline,
                                    model_telemetry_id,
                                    response_latency,
                                    error_message,
                                );
                            }

                            result?;
                            Ok(())
                        });

                    while let Some(hunks) = hunks_rx.next().await {
                        this.update(&mut cx, |this, cx| {
                            this.last_equal_ranges.clear();

                            let transaction = this.buffer.update(cx, |buffer, cx| {
                                // Avoid grouping assistant edits with user edits.
                                buffer.finalize_last_transaction(cx);

                                buffer.start_transaction(cx);
                                buffer.edit(
                                    hunks.into_iter().filter_map(|hunk| match hunk {
                                        // Insertions happen at the current offset and do not
                                        // advance it.
                                        Hunk::Insert { text } => {
                                            let edit_start = snapshot.anchor_after(edit_start);
                                            Some((edit_start..edit_start, text))
                                        }
                                        // Removals delete `len` bytes of the original text and
                                        // advance past them.
                                        Hunk::Remove { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            Some((edit_range, String::new()))
                                        }
                                        // Kept text produces no edit; its range is recorded
                                        // in `last_equal_ranges`.
                                        Hunk::Keep { len } => {
                                            let edit_end = edit_start + len;
                                            let edit_range = snapshot.anchor_after(edit_start)
                                                ..snapshot.anchor_before(edit_end);
                                            edit_start = edit_end;
                                            this.last_equal_ranges.push(edit_range);
                                            None
                                        }
                                    }),
                                    None,
                                    cx,
                                );

                                buffer.end_transaction(cx)
                            });

                            if let Some(transaction) = transaction {
                                if let Some(first_transaction) = this.transaction_id {
                                    // Group all assistant edits into the first transaction.
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.merge_transactions(
                                            transaction,
                                            first_transaction,
                                            cx,
                                        )
                                    });
                                } else {
                                    this.transaction_id = Some(transaction);
                                    this.buffer.update(cx, |buffer, cx| {
                                        buffer.finalize_last_transaction(cx)
                                    });
                                }
                            }

                            cx.notify();
                        })?;
                    }

                    // Surface any error produced by the background diff task.
                    diff.await?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                this.update(&mut cx, |this, cx| {
                    this.last_equal_ranges.clear();
                    this.idle = true;
                    if let Err(error) = result {
                        this.error = Some(error);
                    }
                    cx.emit(Event::Finished);
                    cx.notify();
                })
                .ok();
            }
        });
        // Reset state left over from a previous run and mark the codegen as busy.
        self.error.take();
        self.idle = false;
        cx.notify();
    }
299
300 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
301 if let Some(transaction_id) = self.transaction_id {
302 self.buffer
303 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
304 }
305 }
306}
307
/// Filters a stream of completion chunks, removing wrapper text around the
/// actual content:
///
/// * a leading markdown fence line (```` ``` ```` up to and including its newline),
/// * a trailing markdown fence (only when the response started with one),
/// * a leading `<|S|` / `<|S|>` marker and a trailing `|E|>` marker.
///
/// Because these markers may be split across chunks, text that could still turn
/// out to be part of one is held back in an internal buffer and only emitted
/// once later chunks disambiguate it. Errors from the input stream are passed
/// through unchanged.
fn strip_invalid_spans_from_codeblock(
    stream: impl Stream<Item = Result<String>>,
) -> impl Stream<Item = Result<String>> {
    // True until the first newline has been seen or the opening fence line removed.
    let mut first_line = true;
    // Text received so far that has not been emitted yet.
    let mut buffer = String::new();
    let mut starts_with_markdown_codeblock = false;
    let mut includes_start_or_end_span = false;
    stream.filter_map(move |chunk| {
        let chunk = match chunk {
            Ok(chunk) => chunk,
            Err(err) => return future::ready(Some(Err(err))),
        };
        buffer.push_str(&chunk);

        if buffer.len() > "<|S|".len() && buffer.starts_with("<|S|") {
            // A start marker plus at least one more byte is buffered, so it is
            // now known whether the marker is `<|S|>` or `<|S|`: drop it.
            includes_start_or_end_span = true;

            buffer = buffer
                .strip_prefix("<|S|>")
                .or_else(|| buffer.strip_prefix("<|S|"))
                .unwrap_or(&buffer)
                .to_string();
        } else if buffer.ends_with("|E|>") {
            includes_start_or_end_span = true;
        } else if buffer.starts_with("<|")
            || buffer.starts_with("<|S")
            || buffer.starts_with("<|S|")
            || buffer.ends_with('|')
            || buffer.ends_with("|E")
            || buffer.ends_with("|E|")
        {
            // The buffer starts or ends with what may be a partial marker:
            // wait for more input before emitting anything.
            return future::ready(None);
        }

        if first_line {
            if buffer.is_empty() || buffer == "`" || buffer == "``" {
                // Possibly the beginning of an opening fence; wait for more input.
                return future::ready(None);
            } else if buffer.starts_with("```") {
                starts_with_markdown_codeblock = true;
                if let Some(newline_ix) = buffer.find('\n') {
                    // Drop the whole fence line, including any language tag.
                    buffer.replace_range(..newline_ix + 1, "");
                    first_line = false;
                } else {
                    // The fence line is not complete yet.
                    return future::ready(None);
                }
            }
        }

        // `text` is the part of `buffer` considered safe to emit now.
        let mut text = buffer.to_string();
        if starts_with_markdown_codeblock {
            // Hold back a trailing newline and any (partial) closing fence after it.
            text = text
                .strip_suffix("\n```\n")
                .or_else(|| text.strip_suffix("\n```"))
                .or_else(|| text.strip_suffix("\n``"))
                .or_else(|| text.strip_suffix("\n`"))
                .or_else(|| text.strip_suffix('\n'))
                .unwrap_or(&text)
                .to_string();
        }

        if includes_start_or_end_span {
            text = text
                .strip_suffix("|E|>")
                .or_else(|| text.strip_suffix("E|>"))
                .or_else(|| text.strip_prefix("|>"))
                .or_else(|| text.strip_prefix('>'))
                .unwrap_or(&text)
                .to_string();
        };

        if text.contains('\n') {
            first_line = false;
        }

        // Emit the first `text.len()` bytes of the buffer and keep the rest for
        // the next chunk.
        // NOTE(review): when one of the `strip_prefix` branches above matched,
        // `text` is shorter than `buffer` at the front rather than the back, yet
        // the emitted bytes are still taken from the front of `buffer` — confirm
        // this is intended.
        let remainder = buffer.split_off(text.len());
        let result = if buffer.is_empty() {
            None
        } else {
            Some(Ok(buffer.clone()))
        };

        buffer = remainder;
        future::ready(result)
    })
}
393
394#[cfg(test)]
395mod tests {
396 use std::sync::Arc;
397
398 use crate::FakeCompletionProvider;
399
400 use super::*;
401 use futures::stream::{self};
402 use gpui::{Context, TestAppContext};
403 use indoc::indoc;
404 use language::{
405 language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
406 Point,
407 };
408 use rand::prelude::*;
409 use serde::Serialize;
410 use settings::SettingsStore;
411
    /// Minimal serializable request type carrying only a name.
    /// It is not referenced by any test visible in this module.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }
416
417 #[gpui::test(iterations = 10)]
418 async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
419 let provider = FakeCompletionProvider::default();
420 cx.set_global(cx.update(SettingsStore::test));
421 cx.set_global(CompletionProvider::Fake(provider.clone()));
422 cx.update(language_settings::init);
423
424 let text = indoc! {"
425 fn main() {
426 let x = 0;
427 for _ in 0..10 {
428 x += 1;
429 }
430 }
431 "};
432 let buffer =
433 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
434 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
435 let range = buffer.read_with(cx, |buffer, cx| {
436 let snapshot = buffer.snapshot(cx);
437 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
438 });
439 let codegen = cx.new_model(|cx| {
440 Codegen::new(buffer.clone(), CodegenKind::Transform { range }, None, cx)
441 });
442
443 let request = LanguageModelRequest::default();
444 codegen.update(cx, |codegen, cx| codegen.start(request, cx));
445
446 let mut new_text = concat!(
447 " let mut x = 0;\n",
448 " while x < 10 {\n",
449 " x += 1;\n",
450 " }",
451 );
452 while !new_text.is_empty() {
453 let max_len = cmp::min(new_text.len(), 10);
454 let len = rng.gen_range(1..=max_len);
455 let (chunk, suffix) = new_text.split_at(len);
456 provider.send_completion(chunk.into());
457 new_text = suffix;
458 cx.background_executor.run_until_parked();
459 }
460 provider.finish_completion();
461 cx.background_executor.run_until_parked();
462
463 assert_eq!(
464 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
465 indoc! {"
466 fn main() {
467 let mut x = 0;
468 while x < 10 {
469 x += 1;
470 }
471 }
472 "}
473 );
474 }
475
476 #[gpui::test(iterations = 10)]
477 async fn test_autoindent_when_generating_past_indentation(
478 cx: &mut TestAppContext,
479 mut rng: StdRng,
480 ) {
481 let provider = FakeCompletionProvider::default();
482 cx.set_global(CompletionProvider::Fake(provider.clone()));
483 cx.set_global(cx.update(SettingsStore::test));
484 cx.update(language_settings::init);
485
486 let text = indoc! {"
487 fn main() {
488 le
489 }
490 "};
491 let buffer =
492 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
493 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
494 let position = buffer.read_with(cx, |buffer, cx| {
495 let snapshot = buffer.snapshot(cx);
496 snapshot.anchor_before(Point::new(1, 6))
497 });
498 let codegen = cx.new_model(|cx| {
499 Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
500 });
501
502 let request = LanguageModelRequest::default();
503 codegen.update(cx, |codegen, cx| codegen.start(request, cx));
504
505 let mut new_text = concat!(
506 "t mut x = 0;\n",
507 "while x < 10 {\n",
508 " x += 1;\n",
509 "}", //
510 );
511 while !new_text.is_empty() {
512 let max_len = cmp::min(new_text.len(), 10);
513 let len = rng.gen_range(1..=max_len);
514 let (chunk, suffix) = new_text.split_at(len);
515 provider.send_completion(chunk.into());
516 new_text = suffix;
517 cx.background_executor.run_until_parked();
518 }
519 provider.finish_completion();
520 cx.background_executor.run_until_parked();
521
522 assert_eq!(
523 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
524 indoc! {"
525 fn main() {
526 let mut x = 0;
527 while x < 10 {
528 x += 1;
529 }
530 }
531 "}
532 );
533 }
534
535 #[gpui::test(iterations = 10)]
536 async fn test_autoindent_when_generating_before_indentation(
537 cx: &mut TestAppContext,
538 mut rng: StdRng,
539 ) {
540 let provider = FakeCompletionProvider::default();
541 cx.set_global(CompletionProvider::Fake(provider.clone()));
542 cx.set_global(cx.update(SettingsStore::test));
543 cx.update(language_settings::init);
544
545 let text = concat!(
546 "fn main() {\n",
547 " \n",
548 "}\n" //
549 );
550 let buffer =
551 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
552 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
553 let position = buffer.read_with(cx, |buffer, cx| {
554 let snapshot = buffer.snapshot(cx);
555 snapshot.anchor_before(Point::new(1, 2))
556 });
557 let codegen = cx.new_model(|cx| {
558 Codegen::new(buffer.clone(), CodegenKind::Generate { position }, None, cx)
559 });
560
561 let request = LanguageModelRequest::default();
562 codegen.update(cx, |codegen, cx| codegen.start(request, cx));
563
564 let mut new_text = concat!(
565 "let mut x = 0;\n",
566 "while x < 10 {\n",
567 " x += 1;\n",
568 "}", //
569 );
570 while !new_text.is_empty() {
571 let max_len = cmp::min(new_text.len(), 10);
572 let len = rng.gen_range(1..=max_len);
573 let (chunk, suffix) = new_text.split_at(len);
574 provider.send_completion(chunk.into());
575 new_text = suffix;
576 cx.background_executor.run_until_parked();
577 }
578 provider.finish_completion();
579 cx.background_executor.run_until_parked();
580
581 assert_eq!(
582 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
583 indoc! {"
584 fn main() {
585 let mut x = 0;
586 while x < 10 {
587 x += 1;
588 }
589 }
590 "}
591 );
592 }
593
594 #[gpui::test]
595 async fn test_strip_invalid_spans_from_codeblock() {
596 assert_eq!(
597 strip_invalid_spans_from_codeblock(chunks("Lorem ipsum dolor", 2))
598 .map(|chunk| chunk.unwrap())
599 .collect::<String>()
600 .await,
601 "Lorem ipsum dolor"
602 );
603 assert_eq!(
604 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor", 2))
605 .map(|chunk| chunk.unwrap())
606 .collect::<String>()
607 .await,
608 "Lorem ipsum dolor"
609 );
610 assert_eq!(
611 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```", 2))
612 .map(|chunk| chunk.unwrap())
613 .collect::<String>()
614 .await,
615 "Lorem ipsum dolor"
616 );
617 assert_eq!(
618 strip_invalid_spans_from_codeblock(chunks("```\nLorem ipsum dolor\n```\n", 2))
619 .map(|chunk| chunk.unwrap())
620 .collect::<String>()
621 .await,
622 "Lorem ipsum dolor"
623 );
624 assert_eq!(
625 strip_invalid_spans_from_codeblock(chunks(
626 "```html\n```js\nLorem ipsum dolor\n```\n```",
627 2
628 ))
629 .map(|chunk| chunk.unwrap())
630 .collect::<String>()
631 .await,
632 "```js\nLorem ipsum dolor\n```"
633 );
634 assert_eq!(
635 strip_invalid_spans_from_codeblock(chunks("``\nLorem ipsum dolor\n```", 2))
636 .map(|chunk| chunk.unwrap())
637 .collect::<String>()
638 .await,
639 "``\nLorem ipsum dolor\n```"
640 );
641 assert_eq!(
642 strip_invalid_spans_from_codeblock(chunks("<|S|Lorem ipsum|E|>", 2))
643 .map(|chunk| chunk.unwrap())
644 .collect::<String>()
645 .await,
646 "Lorem ipsum"
647 );
648
649 assert_eq!(
650 strip_invalid_spans_from_codeblock(chunks("<|S|>Lorem ipsum", 2))
651 .map(|chunk| chunk.unwrap())
652 .collect::<String>()
653 .await,
654 "Lorem ipsum"
655 );
656
657 assert_eq!(
658 strip_invalid_spans_from_codeblock(chunks("```\n<|S|>Lorem ipsum\n```", 2))
659 .map(|chunk| chunk.unwrap())
660 .collect::<String>()
661 .await,
662 "Lorem ipsum"
663 );
664 assert_eq!(
665 strip_invalid_spans_from_codeblock(chunks("```\n<|S|Lorem ipsum|E|>\n```", 2))
666 .map(|chunk| chunk.unwrap())
667 .collect::<String>()
668 .await,
669 "Lorem ipsum"
670 );
671 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
672 stream::iter(
673 text.chars()
674 .collect::<Vec<_>>()
675 .chunks(size)
676 .map(|chunk| Ok(chunk.iter().collect::<String>()))
677 .collect::<Vec<_>>(),
678 )
679 }
680 }
681
682 fn rust_lang() -> Language {
683 Language::new(
684 LanguageConfig {
685 name: "Rust".into(),
686 matcher: LanguageMatcher {
687 path_suffixes: vec!["rs".to_string()],
688 ..Default::default()
689 },
690 ..Default::default()
691 },
692 Some(tree_sitter_rust::language()),
693 )
694 .with_indents_query(
695 r#"
696 (call_expression) @indent
697 (field_expression) @indent
698 (_ "(" ")" @end) @indent
699 (_ "{" "}" @end) @indent
700 "#,
701 )
702 .unwrap()
703 }
704}