1use crate::context::ContextLoadResult;
2use crate::inline_prompt_editor::CodegenStatus;
3use crate::{context::load_context, context_store::ContextStore};
4use anyhow::Result;
5use client::telemetry::Telemetry;
6use collections::HashSet;
7use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
8use futures::{
9 SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
10};
11use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
12use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
13use language_model::{
14 LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
15 LanguageModelTextStream, Role, report_assistant_event,
16};
17use multi_buffer::MultiBufferRow;
18use parking_lot::Mutex;
19use project::Project;
20use prompt_store::PromptBuilder;
21use prompt_store::PromptStore;
22use rope::Rope;
23use smol::future::FutureExt;
24use std::{
25 cmp,
26 future::Future,
27 iter,
28 ops::{Range, RangeInclusive},
29 pin::Pin,
30 sync::Arc,
31 task::{self, Poll},
32 time::Instant,
33};
34use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
35use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
36
/// Coordinates one or more [`CodegenAlternative`]s that all target the same
/// range of a multi-buffer, and tracks which of them is currently active.
pub struct BufferCodegen {
    /// Every alternative created so far; `new` seeds this with a single entry.
    alternatives: Vec<Entity<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative currently active.
    pub active_alternative: usize,
    /// Indices that have been passed to `activate` at least once.
    seen_alternatives: HashSet<usize>,
    /// Observation and event subscriptions on the active alternative.
    subscriptions: Vec<Subscription>,
    /// The multi-buffer handed to every alternative.
    buffer: Entity<MultiBuffer>,
    /// The anchor range handed to every alternative.
    range: Range<Anchor>,
    /// Transaction that `undo` reverts (at most once) after undoing the active alternative.
    initial_transaction_id: Option<TransactionId>,
    context_store: Entity<ContextStore>,
    project: WeakEntity<Project>,
    prompt_store: Option<Entity<PromptStore>>,
    telemetry: Arc<Telemetry>,
    builder: Arc<PromptBuilder>,
    /// Whether `range` was empty (by offset) when this value was created.
    pub is_insertion: bool,
}
52
53impl BufferCodegen {
54 pub fn new(
55 buffer: Entity<MultiBuffer>,
56 range: Range<Anchor>,
57 initial_transaction_id: Option<TransactionId>,
58 context_store: Entity<ContextStore>,
59 project: WeakEntity<Project>,
60 prompt_store: Option<Entity<PromptStore>>,
61 telemetry: Arc<Telemetry>,
62 builder: Arc<PromptBuilder>,
63 cx: &mut Context<Self>,
64 ) -> Self {
65 let codegen = cx.new(|cx| {
66 CodegenAlternative::new(
67 buffer.clone(),
68 range.clone(),
69 false,
70 Some(context_store.clone()),
71 project.clone(),
72 prompt_store.clone(),
73 Some(telemetry.clone()),
74 builder.clone(),
75 cx,
76 )
77 });
78 let mut this = Self {
79 is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
80 alternatives: vec![codegen],
81 active_alternative: 0,
82 seen_alternatives: HashSet::default(),
83 subscriptions: Vec::new(),
84 buffer,
85 range,
86 initial_transaction_id,
87 context_store,
88 project,
89 prompt_store,
90 telemetry,
91 builder,
92 };
93 this.activate(0, cx);
94 this
95 }
96
97 fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
98 let codegen = self.active_alternative().clone();
99 self.subscriptions.clear();
100 self.subscriptions
101 .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
102 self.subscriptions
103 .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
104 }
105
106 pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
107 &self.alternatives[self.active_alternative]
108 }
109
110 pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
111 &self.active_alternative().read(cx).status
112 }
113
114 pub fn alternative_count(&self, cx: &App) -> usize {
115 LanguageModelRegistry::read_global(cx)
116 .inline_alternative_models()
117 .len()
118 + 1
119 }
120
121 pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
122 let next_active_ix = if self.active_alternative == 0 {
123 self.alternatives.len() - 1
124 } else {
125 self.active_alternative - 1
126 };
127 self.activate(next_active_ix, cx);
128 }
129
130 pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
131 let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
132 self.activate(next_active_ix, cx);
133 }
134
135 fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
136 self.active_alternative()
137 .update(cx, |codegen, cx| codegen.set_active(false, cx));
138 self.seen_alternatives.insert(index);
139 self.active_alternative = index;
140 self.active_alternative()
141 .update(cx, |codegen, cx| codegen.set_active(true, cx));
142 self.subscribe_to_alternative(cx);
143 cx.notify();
144 }
145
146 pub fn start(
147 &mut self,
148 primary_model: Arc<dyn LanguageModel>,
149 user_prompt: String,
150 cx: &mut Context<Self>,
151 ) -> Result<()> {
152 let alternative_models = LanguageModelRegistry::read_global(cx)
153 .inline_alternative_models()
154 .to_vec();
155
156 self.active_alternative()
157 .update(cx, |alternative, cx| alternative.undo(cx));
158 self.activate(0, cx);
159 self.alternatives.truncate(1);
160
161 for _ in 0..alternative_models.len() {
162 self.alternatives.push(cx.new(|cx| {
163 CodegenAlternative::new(
164 self.buffer.clone(),
165 self.range.clone(),
166 false,
167 Some(self.context_store.clone()),
168 self.project.clone(),
169 self.prompt_store.clone(),
170 Some(self.telemetry.clone()),
171 self.builder.clone(),
172 cx,
173 )
174 }));
175 }
176
177 for (model, alternative) in iter::once(primary_model)
178 .chain(alternative_models)
179 .zip(&self.alternatives)
180 {
181 alternative.update(cx, |alternative, cx| {
182 alternative.start(user_prompt.clone(), model.clone(), cx)
183 })?;
184 }
185
186 Ok(())
187 }
188
189 pub fn stop(&mut self, cx: &mut Context<Self>) {
190 for codegen in &self.alternatives {
191 codegen.update(cx, |codegen, cx| codegen.stop(cx));
192 }
193 }
194
195 pub fn undo(&mut self, cx: &mut Context<Self>) {
196 self.active_alternative()
197 .update(cx, |codegen, cx| codegen.undo(cx));
198
199 self.buffer.update(cx, |buffer, cx| {
200 if let Some(transaction_id) = self.initial_transaction_id.take() {
201 buffer.undo_transaction(transaction_id, cx);
202 buffer.refresh_preview(cx);
203 }
204 });
205 }
206
207 pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
208 self.active_alternative().read(cx).buffer.clone()
209 }
210
211 pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
212 self.active_alternative().read(cx).old_buffer.clone()
213 }
214
215 pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
216 self.active_alternative().read(cx).snapshot.clone()
217 }
218
219 pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
220 self.active_alternative().read(cx).edit_position
221 }
222
223 pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
224 &self.active_alternative().read(cx).diff
225 }
226
227 pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
228 self.active_alternative().read(cx).last_equal_ranges()
229 }
230}
231
// `BufferCodegen` re-emits the `CodegenEvent`s of its active alternative
// (see `subscribe_to_alternative`).
impl EventEmitter<CodegenEvent> for BufferCodegen {}
233
/// A single generation attempt over `range` of a multi-buffer: it consumes a
/// model's text stream, turns it into buffer edits, and maintains a [`Diff`]
/// describing the result.
pub struct CodegenAlternative {
    /// The multi-buffer that edits are applied to.
    buffer: Entity<MultiBuffer>,
    /// A local buffer created in `new` from the text, line ending, and language
    /// of the buffer underlying `range`.
    old_buffer: Entity<Buffer>,
    /// Snapshot of `buffer` taken in `new`.
    snapshot: MultiBufferSnapshot,
    /// Set in `start` and advanced in `handle_stream` as streamed operations are processed.
    edit_position: Option<Anchor>,
    /// The range being transformed.
    range: Range<Anchor>,
    /// Ranges kept unchanged by the most recently processed batch of character
    /// operations; cleared when generation finishes or is stopped.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// The transaction into which this alternative's edits are merged.
    transformation_transaction_id: Option<TransactionId>,
    status: CodegenStatus,
    /// The task spawned by `handle_stream`; overwritten with a ready task by
    /// `stop` and when the transformation transaction is undone.
    generation: Task<()>,
    diff: Diff,
    context_store: Option<Entity<ContextStore>>,
    project: WeakEntity<Project>,
    prompt_store: Option<Entity<PromptStore>>,
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events, handled by `handle_buffer_event`.
    _subscription: gpui::Subscription,
    builder: Arc<PromptBuilder>,
    /// Whether streamed edits are applied to `buffer` as they arrive.
    active: bool,
    /// Every edit produced so far; re-applied by `set_active(true, ..)`.
    edits: Vec<(Range<Anchor>, String)>,
    /// The line operations from the most recently processed batch.
    line_operations: Vec<LineOperation>,
    /// Seconds elapsed between the start of `handle_stream` and its completion.
    elapsed_time: Option<f64>,
    /// The text accumulated from the stream, stored once generation completes.
    completion: Option<String>,
    /// The message id reported by the stream, stored once generation completes.
    pub message_id: Option<String>,
}
258
// `CodegenAlternative` emits `CodegenEvent::Finished` when generation ends or is
// stopped, and `CodegenEvent::Undone` when its transaction is undone.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
260
261impl CodegenAlternative {
262 pub fn new(
263 buffer: Entity<MultiBuffer>,
264 range: Range<Anchor>,
265 active: bool,
266 context_store: Option<Entity<ContextStore>>,
267 project: WeakEntity<Project>,
268 prompt_store: Option<Entity<PromptStore>>,
269 telemetry: Option<Arc<Telemetry>>,
270 builder: Arc<PromptBuilder>,
271 cx: &mut Context<Self>,
272 ) -> Self {
273 let snapshot = buffer.read(cx).snapshot(cx);
274
275 let (old_buffer, _, _) = snapshot
276 .range_to_buffer_ranges(range.clone())
277 .pop()
278 .unwrap();
279 let old_buffer = cx.new(|cx| {
280 let text = old_buffer.as_rope().clone();
281 let line_ending = old_buffer.line_ending();
282 let language = old_buffer.language().cloned();
283 let language_registry = buffer
284 .read(cx)
285 .buffer(old_buffer.remote_id())
286 .unwrap()
287 .read(cx)
288 .language_registry();
289
290 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
291 buffer.set_language(language, cx);
292 if let Some(language_registry) = language_registry {
293 buffer.set_language_registry(language_registry)
294 }
295 buffer
296 });
297
298 Self {
299 buffer: buffer.clone(),
300 old_buffer,
301 edit_position: None,
302 message_id: None,
303 snapshot,
304 last_equal_ranges: Default::default(),
305 transformation_transaction_id: None,
306 status: CodegenStatus::Idle,
307 generation: Task::ready(()),
308 diff: Diff::default(),
309 context_store,
310 project,
311 prompt_store,
312 telemetry,
313 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
314 builder,
315 active,
316 edits: Vec::new(),
317 line_operations: Vec::new(),
318 range,
319 elapsed_time: None,
320 completion: None,
321 }
322 }
323
324 pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
325 if active != self.active {
326 self.active = active;
327
328 if self.active {
329 let edits = self.edits.clone();
330 self.apply_edits(edits, cx);
331 if matches!(self.status, CodegenStatus::Pending) {
332 let line_operations = self.line_operations.clone();
333 self.reapply_line_based_diff(line_operations, cx);
334 } else {
335 self.reapply_batch_diff(cx).detach();
336 }
337 } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
338 self.buffer.update(cx, |buffer, cx| {
339 buffer.undo_transaction(transaction_id, cx);
340 buffer.forget_transaction(transaction_id, cx);
341 });
342 }
343 }
344 }
345
346 fn handle_buffer_event(
347 &mut self,
348 _buffer: Entity<MultiBuffer>,
349 event: &multi_buffer::Event,
350 cx: &mut Context<Self>,
351 ) {
352 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
353 if self.transformation_transaction_id == Some(*transaction_id) {
354 self.transformation_transaction_id = None;
355 self.generation = Task::ready(());
356 cx.emit(CodegenEvent::Undone);
357 }
358 }
359 }
360
361 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
362 &self.last_equal_ranges
363 }
364
365 pub fn start(
366 &mut self,
367 user_prompt: String,
368 model: Arc<dyn LanguageModel>,
369 cx: &mut Context<Self>,
370 ) -> Result<()> {
371 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
372 self.buffer.update(cx, |buffer, cx| {
373 buffer.undo_transaction(transformation_transaction_id, cx);
374 });
375 }
376
377 self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
378
379 let api_key = model.api_key(cx);
380 let telemetry_id = model.telemetry_id();
381 let provider_id = model.provider_id();
382 let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
383 if user_prompt.trim().to_lowercase() == "delete" {
384 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
385 } else {
386 let request = self.build_request(user_prompt, cx)?;
387 cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
388 .boxed_local()
389 };
390 self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
391 Ok(())
392 }
393
394 fn build_request(
395 &self,
396 user_prompt: String,
397 cx: &mut App,
398 ) -> Result<Task<LanguageModelRequest>> {
399 let buffer = self.buffer.read(cx).snapshot(cx);
400 let language = buffer.language_at(self.range.start);
401 let language_name = if let Some(language) = language.as_ref() {
402 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
403 None
404 } else {
405 Some(language.name())
406 }
407 } else {
408 None
409 };
410
411 let language_name = language_name.as_ref();
412 let start = buffer.point_to_buffer_offset(self.range.start);
413 let end = buffer.point_to_buffer_offset(self.range.end);
414 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
415 let (start_buffer, start_buffer_offset) = start;
416 let (end_buffer, end_buffer_offset) = end;
417 if start_buffer.remote_id() == end_buffer.remote_id() {
418 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
419 } else {
420 return Err(anyhow::anyhow!("invalid transformation range"));
421 }
422 } else {
423 return Err(anyhow::anyhow!("invalid transformation range"));
424 };
425
426 let prompt = self
427 .builder
428 .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
429 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
430
431 let context_task = self.context_store.as_ref().map(|context_store| {
432 if let Some(project) = self.project.upgrade() {
433 let context = context_store
434 .read(cx)
435 .context()
436 .cloned()
437 .collect::<Vec<_>>();
438 load_context(context, &project, &self.prompt_store, cx)
439 } else {
440 Task::ready(ContextLoadResult::default())
441 }
442 });
443
444 Ok(cx.spawn(async move |_cx| {
445 let mut request_message = LanguageModelRequestMessage {
446 role: Role::User,
447 content: Vec::new(),
448 cache: false,
449 };
450
451 if let Some(context_task) = context_task {
452 context_task
453 .await
454 .loaded_context
455 .add_to_request_message(&mut request_message);
456 }
457
458 request_message.content.push(prompt.into());
459
460 LanguageModelRequest {
461 thread_id: None,
462 prompt_id: None,
463 mode: None,
464 tools: Vec::new(),
465 stop: Vec::new(),
466 temperature: None,
467 messages: vec![request_message],
468 }
469 }))
470 }
471
    /// Consumes `stream` and turns the generated text into edits over `self.range`.
    ///
    /// The diff is reset and the status set to `Pending` immediately. A
    /// background task then reads the stream, re-indents each line relative to
    /// the indentation suggested for the selection, and diffs the text against
    /// the originally selected text. The resulting character and line
    /// operations are sent over a channel to the foreground, where they are
    /// recorded and, if this alternative is active, applied to the buffer.
    ///
    /// When the stream ends, a batch diff is applied, the status becomes `Done`
    /// or `Error`, and [`CodegenEvent::Finished`] is emitted.
    ///
    /// `model_telemetry_id`, `model_provider_id`, and `model_api_key` are only
    /// used for event reporting.
    pub fn handle_stream(
        &mut self,
        model_telemetry_id: String,
        model_provider_id: String,
        model_api_key: Option<String>,
        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
        cx: &mut Context<Self>,
    ) {
        let start_time = Instant::now();
        let snapshot = self.snapshot.clone();
        let selected_text = snapshot
            .text_for_range(self.range.start..self.range.end)
            .collect::<Rope>();

        let selection_start = self.range.start.to_point(&snapshot);

        // Start with the indentation of the first line in the selection
        let mut suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..=selection_start.row, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        // If the first line in the selection does not have indentation, check the following lines
        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                // Prefer tabs if a line in the selection uses tabs as indentation
                if line_indent.kind == IndentKind::Tab {
                    suggested_line_indent.kind = IndentKind::Tab;
                    break;
                }
            }
        }

        let http_client = cx.http_client();
        let telemetry = self.telemetry.clone();
        // Language of the first buffer range covered by the selection, if any;
        // only used for event reporting below.
        let language_name = {
            let multibuffer = self.buffer.read(cx);
            let snapshot = multibuffer.snapshot(cx);
            let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
            ranges
                .first()
                .and_then(|(buffer, _, _)| buffer.language())
                .map(|language| language.name())
        };

        self.diff = Diff::default();
        self.status = CodegenStatus::Pending;
        // Offset (in `snapshot`) up to which streamed operations have been processed.
        let mut edit_start = self.range.start.to_offset(&snapshot);
        // Accumulates the text received from the stream; shared with the background task.
        let completion = Arc::new(Mutex::new(String::new()));
        let completion_clone = completion.clone();

        self.generation = cx.spawn(async move |codegen, cx| {
            let stream = stream.await;
            let token_usage = stream
                .as_ref()
                .ok()
                .map(|stream| stream.last_token_usage.clone());
            let message_id = stream
                .as_ref()
                .ok()
                .and_then(|stream| stream.message_id.clone());
            let generate = async {
                let model_telemetry_id = model_telemetry_id.clone();
                let model_provider_id = model_provider_id.clone();
                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
                let executor = cx.background_executor().clone();
                let message_id = message_id.clone();
                // Background half: read the stream, fix up indentation, and diff.
                let line_based_stream_diff: Task<anyhow::Result<()>> =
                    cx.background_spawn(async move {
                        let mut response_latency = None;
                        let request_start = Instant::now();
                        let diff = async {
                            let chunks = StripInvalidSpans::new(
                                stream?.stream.map_err(|error| error.into()),
                            );
                            futures::pin_mut!(chunks);
                            let mut diff = StreamingDiff::new(selected_text.to_string());
                            let mut line_diff = LineDiff::default();

                            // Text of the current line that has not been pushed into the diff yet.
                            let mut new_text = String::new();
                            // Indentation of the first non-blank generated line.
                            let mut base_indent = None;
                            // Indentation of the current line, once a non-whitespace character is seen.
                            let mut line_indent = None;
                            let mut first_line = true;

                            while let Some(chunk) = chunks.next().await {
                                // Latency is measured up to the first item the stream yields.
                                if response_latency.is_none() {
                                    response_latency = Some(request_start.elapsed());
                                }
                                let chunk = chunk?;
                                completion_clone.lock().push_str(&chunk);

                                let mut lines = chunk.split('\n').peekable();
                                while let Some(line) = lines.next() {
                                    new_text.push_str(line);
                                    if line_indent.is_none() {
                                        if let Some(non_whitespace_ch_ix) =
                                            new_text.find(|ch: char| !ch.is_whitespace())
                                        {
                                            line_indent = Some(non_whitespace_ch_ix);
                                            base_indent = base_indent.or(line_indent);

                                            // Re-base this line's indentation onto the
                                            // indentation suggested for the selection.
                                            let line_indent = line_indent.unwrap();
                                            let base_indent = base_indent.unwrap();
                                            let indent_delta =
                                                line_indent as i32 - base_indent as i32;
                                            let mut corrected_indent_len = cmp::max(
                                                0,
                                                suggested_line_indent.len as i32 + indent_delta,
                                            )
                                                as usize;
                                            // The first line starts at the selection's column,
                                            // so that much indentation is subtracted.
                                            if first_line {
                                                corrected_indent_len = corrected_indent_len
                                                    .saturating_sub(
                                                        selection_start.column as usize,
                                                    );
                                            }

                                            let indent_char = suggested_line_indent.char();
                                            let mut indent_buffer = [0; 4];
                                            let indent_str =
                                                indent_char.encode_utf8(&mut indent_buffer);
                                            new_text.replace_range(
                                                ..line_indent,
                                                &indent_str.repeat(corrected_indent_len),
                                            );
                                        }
                                    }

                                    // Text is only pushed once the line's indentation is known.
                                    if line_indent.is_some() {
                                        let char_ops = diff.push_new(&new_text);
                                        line_diff.push_char_operations(&char_ops, &selected_text);
                                        diff_tx
                                            .send((char_ops, line_diff.line_operations()))
                                            .await?;
                                        new_text.clear();
                                    }

                                    // Another piece follows, so this line ended with a newline.
                                    if lines.peek().is_some() {
                                        let char_ops = diff.push_new("\n");
                                        line_diff.push_char_operations(&char_ops, &selected_text);
                                        diff_tx
                                            .send((char_ops, line_diff.line_operations()))
                                            .await?;
                                        if line_indent.is_none() {
                                            // Don't write out the leading indentation in empty lines on the next line
                                            // This is the case where the above if statement didn't clear the buffer
                                            new_text.clear();
                                        }
                                        line_indent = None;
                                        first_line = false;
                                    }
                                }
                            }

                            // Flush whatever is left and finalize both diffs.
                            let mut char_ops = diff.push_new(&new_text);
                            char_ops.extend(diff.finish());
                            line_diff.push_char_operations(&char_ops, &selected_text);
                            line_diff.finish(&selected_text);
                            diff_tx
                                .send((char_ops, line_diff.line_operations()))
                                .await?;

                            anyhow::Ok(())
                        };

                        let result = diff.await;

                        // The response is reported whether or not diffing succeeded.
                        let error_message = result.as_ref().err().map(|error| error.to_string());
                        report_assistant_event(
                            AssistantEventData {
                                conversation_id: None,
                                message_id,
                                kind: AssistantKind::Inline,
                                phase: AssistantPhase::Response,
                                model: model_telemetry_id,
                                model_provider: model_provider_id,
                                response_latency,
                                error_message,
                                language_name: language_name.map(|name| name.to_proto()),
                            },
                            telemetry,
                            http_client,
                            model_api_key,
                            &executor,
                        );

                        result?;
                        Ok(())
                    });

                // Foreground half: turn each batch of operations into anchored edits.
                while let Some((char_ops, line_ops)) = diff_rx.next().await {
                    codegen.update(cx, |codegen, cx| {
                        codegen.last_equal_ranges.clear();

                        let edits = char_ops
                            .into_iter()
                            .filter_map(|operation| match operation {
                                // Inserts do not advance `edit_start`; deletes and keeps do.
                                CharOperation::Insert { text } => {
                                    let edit_start = snapshot.anchor_after(edit_start);
                                    Some((edit_start..edit_start, text))
                                }
                                CharOperation::Delete { bytes } => {
                                    let edit_end = edit_start + bytes;
                                    let edit_range = snapshot.anchor_after(edit_start)
                                        ..snapshot.anchor_before(edit_end);
                                    edit_start = edit_end;
                                    Some((edit_range, String::new()))
                                }
                                CharOperation::Keep { bytes } => {
                                    let edit_end = edit_start + bytes;
                                    let edit_range = snapshot.anchor_after(edit_start)
                                        ..snapshot.anchor_before(edit_end);
                                    edit_start = edit_end;
                                    codegen.last_equal_ranges.push(edit_range);
                                    None
                                }
                            })
                            .collect::<Vec<_>>();

                        // Inactive alternatives only record the edits; `set_active`
                        // replays them later.
                        if codegen.active {
                            codegen.apply_edits(edits.iter().cloned(), cx);
                            codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
                        }
                        codegen.edits.extend(edits);
                        codegen.line_operations = line_ops;
                        codegen.edit_position = Some(snapshot.anchor_after(edit_start));

                        cx.notify();
                    })?;
                }

                // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
                // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
                // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
                let batch_diff_task =
                    codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
                let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
                line_based_stream_diff?;

                anyhow::Ok(())
            };

            let result = generate.await;
            let elapsed_time = start_time.elapsed().as_secs_f64();

            // Record the outcome; `.ok()` ignores the case where the entity is gone.
            codegen
                .update(cx, |this, cx| {
                    this.message_id = message_id;
                    this.last_equal_ranges.clear();
                    if let Err(error) = result {
                        this.status = CodegenStatus::Error(error);
                    } else {
                        this.status = CodegenStatus::Done;
                    }
                    this.elapsed_time = Some(elapsed_time);
                    this.completion = Some(completion.lock().clone());
                    if let Some(usage) = token_usage {
                        let usage = usage.lock();
                        telemetry::event!(
                            "Inline Assistant Completion",
                            model = model_telemetry_id,
                            model_provider = model_provider_id,
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                        )
                    }
                    cx.emit(CodegenEvent::Finished);
                    cx.notify();
                })
                .ok();
        });
        cx.notify();
    }
747
748 pub fn stop(&mut self, cx: &mut Context<Self>) {
749 self.last_equal_ranges.clear();
750 if self.diff.is_empty() {
751 self.status = CodegenStatus::Idle;
752 } else {
753 self.status = CodegenStatus::Done;
754 }
755 self.generation = Task::ready(());
756 cx.emit(CodegenEvent::Finished);
757 cx.notify();
758 }
759
760 pub fn undo(&mut self, cx: &mut Context<Self>) {
761 self.buffer.update(cx, |buffer, cx| {
762 if let Some(transaction_id) = self.transformation_transaction_id.take() {
763 buffer.undo_transaction(transaction_id, cx);
764 buffer.refresh_preview(cx);
765 }
766 });
767 }
768
769 fn apply_edits(
770 &mut self,
771 edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
772 cx: &mut Context<CodegenAlternative>,
773 ) {
774 let transaction = self.buffer.update(cx, |buffer, cx| {
775 // Avoid grouping assistant edits with user edits.
776 buffer.finalize_last_transaction(cx);
777 buffer.start_transaction(cx);
778 buffer.edit(edits, None, cx);
779 buffer.end_transaction(cx)
780 });
781
782 if let Some(transaction) = transaction {
783 if let Some(first_transaction) = self.transformation_transaction_id {
784 // Group all assistant edits into the first transaction.
785 self.buffer.update(cx, |buffer, cx| {
786 buffer.merge_transactions(transaction, first_transaction, cx)
787 });
788 } else {
789 self.transformation_transaction_id = Some(transaction);
790 self.buffer
791 .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
792 }
793 }
794 }
795
796 fn reapply_line_based_diff(
797 &mut self,
798 line_operations: impl IntoIterator<Item = LineOperation>,
799 cx: &mut Context<Self>,
800 ) {
801 let old_snapshot = self.snapshot.clone();
802 let old_range = self.range.to_point(&old_snapshot);
803 let new_snapshot = self.buffer.read(cx).snapshot(cx);
804 let new_range = self.range.to_point(&new_snapshot);
805
806 let mut old_row = old_range.start.row;
807 let mut new_row = new_range.start.row;
808
809 self.diff.deleted_row_ranges.clear();
810 self.diff.inserted_row_ranges.clear();
811 for operation in line_operations {
812 match operation {
813 LineOperation::Keep { lines } => {
814 old_row += lines;
815 new_row += lines;
816 }
817 LineOperation::Delete { lines } => {
818 let old_end_row = old_row + lines - 1;
819 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
820
821 if let Some((_, last_deleted_row_range)) =
822 self.diff.deleted_row_ranges.last_mut()
823 {
824 if *last_deleted_row_range.end() + 1 == old_row {
825 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
826 } else {
827 self.diff
828 .deleted_row_ranges
829 .push((new_row, old_row..=old_end_row));
830 }
831 } else {
832 self.diff
833 .deleted_row_ranges
834 .push((new_row, old_row..=old_end_row));
835 }
836
837 old_row += lines;
838 }
839 LineOperation::Insert { lines } => {
840 let new_end_row = new_row + lines - 1;
841 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
842 let end = new_snapshot.anchor_before(Point::new(
843 new_end_row,
844 new_snapshot.line_len(MultiBufferRow(new_end_row)),
845 ));
846 self.diff.inserted_row_ranges.push(start..end);
847 new_row += lines;
848 }
849 }
850
851 cx.notify();
852 }
853 }
854
855 fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
856 let old_snapshot = self.snapshot.clone();
857 let old_range = self.range.to_point(&old_snapshot);
858 let new_snapshot = self.buffer.read(cx).snapshot(cx);
859 let new_range = self.range.to_point(&new_snapshot);
860
861 cx.spawn(async move |codegen, cx| {
862 let (deleted_row_ranges, inserted_row_ranges) = cx
863 .background_spawn(async move {
864 let old_text = old_snapshot
865 .text_for_range(
866 Point::new(old_range.start.row, 0)
867 ..Point::new(
868 old_range.end.row,
869 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
870 ),
871 )
872 .collect::<String>();
873 let new_text = new_snapshot
874 .text_for_range(
875 Point::new(new_range.start.row, 0)
876 ..Point::new(
877 new_range.end.row,
878 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
879 ),
880 )
881 .collect::<String>();
882
883 let old_start_row = old_range.start.row;
884 let new_start_row = new_range.start.row;
885 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
886 let mut inserted_row_ranges = Vec::new();
887 for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
888 let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
889 let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
890 if !old_rows.is_empty() {
891 deleted_row_ranges.push((
892 new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
893 old_rows.start..=old_rows.end - 1,
894 ));
895 }
896 if !new_rows.is_empty() {
897 let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
898 let new_end_row = new_rows.end - 1;
899 let end = new_snapshot.anchor_before(Point::new(
900 new_end_row,
901 new_snapshot.line_len(MultiBufferRow(new_end_row)),
902 ));
903 inserted_row_ranges.push(start..end);
904 }
905 }
906 (deleted_row_ranges, inserted_row_ranges)
907 })
908 .await;
909
910 codegen
911 .update(cx, |codegen, cx| {
912 codegen.diff.deleted_row_ranges = deleted_row_ranges;
913 codegen.diff.inserted_row_ranges = inserted_row_ranges;
914 cx.notify();
915 })
916 .ok();
917 })
918 }
919}
920
/// Events emitted by [`CodegenAlternative`] and re-emitted by [`BufferCodegen`].
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// Generation completed (successfully or with an error) or was stopped.
    Finished,
    /// The transformation transaction was undone in the buffer.
    Undone,
}
926
/// Stream adapter that removes a leading code-fence line, a matching trailing
/// code fence, and cursor markers from the text chunks of the wrapped stream.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Set once the wrapped stream has yielded `None`.
    stream_done: bool,
    /// Text received from `stream` that has not been emitted or discarded yet.
    buffer: String,
    /// True until the first newline of the input has been seen.
    first_line: bool,
    /// True when a newline is owed before the next emitted text.
    line_end: bool,
    /// Whether the first line of the input was a code-fence line.
    starts_with_code_block: bool,
}
935
936impl<T> StripInvalidSpans<T>
937where
938 T: Stream<Item = Result<String>>,
939{
940 fn new(stream: T) -> Self {
941 Self {
942 stream,
943 stream_done: false,
944 buffer: String::new(),
945 first_line: true,
946 line_end: false,
947 starts_with_code_block: false,
948 }
949 }
950}
951
952impl<T> Stream for StripInvalidSpans<T>
953where
954 T: Stream<Item = Result<String>>,
955{
956 type Item = Result<String>;
957
958 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
959 const CODE_BLOCK_DELIMITER: &str = "```";
960 const CURSOR_SPAN: &str = "<|CURSOR|>";
961
962 let this = unsafe { self.get_unchecked_mut() };
963 loop {
964 if !this.stream_done {
965 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
966 match stream.as_mut().poll_next(cx) {
967 Poll::Ready(Some(Ok(chunk))) => {
968 this.buffer.push_str(&chunk);
969 }
970 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
971 Poll::Ready(None) => {
972 this.stream_done = true;
973 }
974 Poll::Pending => return Poll::Pending,
975 }
976 }
977
978 let mut chunk = String::new();
979 let mut consumed = 0;
980 if !this.buffer.is_empty() {
981 let mut lines = this.buffer.split('\n').enumerate().peekable();
982 while let Some((line_ix, line)) = lines.next() {
983 if line_ix > 0 {
984 this.first_line = false;
985 }
986
987 if this.first_line {
988 let trimmed_line = line.trim();
989 if lines.peek().is_some() {
990 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
991 consumed += line.len() + 1;
992 this.starts_with_code_block = true;
993 continue;
994 }
995 } else if trimmed_line.is_empty()
996 || prefixes(CODE_BLOCK_DELIMITER)
997 .any(|prefix| trimmed_line.starts_with(prefix))
998 {
999 break;
1000 }
1001 }
1002
1003 let line_without_cursor = line.replace(CURSOR_SPAN, "");
1004 if lines.peek().is_some() {
1005 if this.line_end {
1006 chunk.push('\n');
1007 }
1008
1009 chunk.push_str(&line_without_cursor);
1010 this.line_end = true;
1011 consumed += line.len() + 1;
1012 } else if this.stream_done {
1013 if !this.starts_with_code_block
1014 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1015 {
1016 if this.line_end {
1017 chunk.push('\n');
1018 }
1019
1020 chunk.push_str(&line);
1021 }
1022
1023 consumed += line.len();
1024 } else {
1025 let trimmed_line = line.trim();
1026 if trimmed_line.is_empty()
1027 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1028 || prefixes(CODE_BLOCK_DELIMITER)
1029 .any(|prefix| trimmed_line.ends_with(prefix))
1030 {
1031 break;
1032 } else {
1033 if this.line_end {
1034 chunk.push('\n');
1035 this.line_end = false;
1036 }
1037
1038 chunk.push_str(&line_without_cursor);
1039 consumed += line.len();
1040 }
1041 }
1042 }
1043 }
1044
1045 this.buffer = this.buffer.split_off(consumed);
1046 if !chunk.is_empty() {
1047 return Poll::Ready(Some(Ok(chunk)));
1048 } else if this.stream_done {
1049 return Poll::Ready(None);
1050 }
1051 }
1052 }
1053}
1054
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// The full `text` itself is never yielded, so an empty or single-character
/// input yields nothing. Prefixes always end on a character boundary, so
/// non-ASCII input is handled without panicking.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Skipping the first character's index (0) leaves exactly the boundaries
    // that end a non-empty prefix shorter than `text`.
    text.char_indices()
        .skip(1)
        .map(move |(boundary, _)| &text[..boundary])
}
1058
/// Row-level description of how generated text differs from the original.
#[derive(Default)]
pub struct Diff {
    /// Deleted rows: an anchor in the new buffer paired with the inclusive
    /// range of rows, in the original snapshot, that were removed there.
    pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Inserted rows, each spanning from the start of its first row to the end
    /// of its last row in the new buffer.
    pub inserted_row_ranges: Vec<Range<Anchor>>,
}
1064
1065impl Diff {
1066 fn is_empty(&self) -> bool {
1067 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1068 }
1069}
1070
1071#[cfg(test)]
1072mod tests {
1073 use super::*;
1074 use fs::FakeFs;
1075 use futures::{
1076 Stream,
1077 stream::{self},
1078 };
1079 use gpui::TestAppContext;
1080 use indoc::indoc;
1081 use language::{
1082 Buffer, Language, LanguageConfig, LanguageMatcher, Point, language_settings,
1083 tree_sitter_rust,
1084 };
1085 use language_model::{LanguageModelRegistry, TokenUsage};
1086 use rand::prelude::*;
1087 use serde::Serialize;
1088 use settings::SettingsStore;
1089 use std::{future, sync::Arc};
1090
    /// Serializable request payload carrying a single `name` field.
    /// NOTE(review): no test in this module references this type — confirm
    /// whether it is still needed.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        /// Name carried by the request.
        pub name: String,
    }
1095
    /// Replaces the statements inside `main` with a response streamed in
    /// randomly sized chunks and checks the resulting buffer text.
    ///
    /// `iterations = 10` reruns the test; presumably each run receives a
    /// differently seeded `rng`, so the response is split at different points.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        // Install the test globals: settings store, language model registry
        // and language settings.
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        // Select from the start of row 1 to column 5 of row 4, i.e. the
        // statements between the braces of `main`.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, vec![], cx).await;
        // NOTE(review): the `true` is presumably the "active" flag — confirm
        // against `CodegenAlternative::new`.
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                project.downgrade(),
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        let mut new_text = concat!(
            " let mut x = 0;\n",
            " while x < 10 {\n",
            " x += 1;\n",
            " }",
        );
        // Feed the response in chunks of 1 to 10 bytes (capped by what is
        // left), letting the executor settle after every chunk.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the simulated response stream.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }
1164
    /// Streams a response that continues the partially typed `le` on row 1
    /// (starting with `t mut x = 0;`) and checks the resulting buffer text.
    ///
    /// The response is inserted at an empty range placed at column 6 of row 1
    /// and arrives in randomly sized chunks.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        // NOTE(review): unlike the sibling tests, this one does not call
        // `LanguageModelRegistry::test` — confirm that this is intentional.
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        // Empty range (start == end) at row 1, column 6.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, vec![], cx).await;
        // NOTE(review): the `true` is presumably the "active" flag — confirm
        // against `CodegenAlternative::new`.
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                project.downgrade(),
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        // Let the codegen settle before any response text arrives.
        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            " x += 1;\n",
            "}", //
        );
        // Feed the response in chunks of 1 to 10 bytes (capped by what is
        // left), letting the executor settle after every chunk.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the simulated response stream.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }
1234
    /// Streams a response into an empty range on the whitespace-only line
    /// inside `main` and checks the resulting buffer text.
    ///
    /// The response arrives in randomly sized chunks.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        // Install the test globals: language model registry, settings store
        // and language settings.
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        // Row 1 contains only whitespace.
        let text = concat!(
            "fn main() {\n",
            " \n",
            "}\n" //
        );
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        // Empty range (start == end) at row 1, column 2.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, vec![], cx).await;
        // NOTE(review): the `true` is presumably the "active" flag — confirm
        // against `CodegenAlternative::new`.
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                project.downgrade(),
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        // Let the codegen settle before any response text arrives.
        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            " x += 1;\n",
            "}", //
        );
        // Feed the response in chunks of 1 to 10 bytes (capped by what is
        // left), letting the executor settle after every chunk.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the simulated response stream.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }
1305
1306 #[gpui::test(iterations = 10)]
1307 async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
1308 cx.update(LanguageModelRegistry::test);
1309 cx.set_global(cx.update(SettingsStore::test));
1310 cx.update(language_settings::init);
1311
1312 let text = indoc! {"
1313 func main() {
1314 \tx := 0
1315 \tfor i := 0; i < 10; i++ {
1316 \t\tx++
1317 \t}
1318 }
1319 "};
1320 let buffer = cx.new(|cx| Buffer::local(text, cx));
1321 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1322 let range = buffer.read_with(cx, |buffer, cx| {
1323 let snapshot = buffer.snapshot(cx);
1324 snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
1325 });
1326 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1327 let fs = FakeFs::new(cx.executor());
1328 let project = Project::test(fs, vec![], cx).await;
1329 let codegen = cx.new(|cx| {
1330 CodegenAlternative::new(
1331 buffer.clone(),
1332 range.clone(),
1333 true,
1334 None,
1335 project.downgrade(),
1336 None,
1337 None,
1338 prompt_builder,
1339 cx,
1340 )
1341 });
1342
1343 let chunks_tx = simulate_response_stream(codegen.clone(), cx);
1344 let new_text = concat!(
1345 "func main() {\n",
1346 "\tx := 0\n",
1347 "\tfor x < 10 {\n",
1348 "\t\tx++\n",
1349 "\t}", //
1350 );
1351 chunks_tx.unbounded_send(new_text.to_string()).unwrap();
1352 drop(chunks_tx);
1353 cx.background_executor.run_until_parked();
1354
1355 assert_eq!(
1356 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1357 indoc! {"
1358 func main() {
1359 \tx := 0
1360 \tfor x < 10 {
1361 \t\tx++
1362 \t}
1363 }
1364 "}
1365 );
1366 }
1367
    /// Checks that an alternative created inactive leaves the buffer untouched
    /// while its response streams in, applies the generated text once
    /// activated, and restores the original text when deactivated again.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        // Install the test globals: language model registry, settings store
        // and language settings.
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        // Select row 1 from column 0 to column 14.
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, vec![], cx).await;
        // The third argument is `false` here: the alternative starts inactive.
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                None,
                project.downgrade(),
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        // Send the whole response as one chunk, then end the stream.
        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }
1435
1436 #[gpui::test]
1437 async fn test_strip_invalid_spans_from_codeblock() {
1438 assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
1439 assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
1440 assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
1441 assert_chunks(
1442 "```html\n```js\nLorem ipsum dolor\n```\n```",
1443 "```js\nLorem ipsum dolor\n```",
1444 )
1445 .await;
1446 assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
1447 assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
1448 assert_chunks("Lorem ipsum", "Lorem ipsum").await;
1449 assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
1450
1451 async fn assert_chunks(text: &str, expected_text: &str) {
1452 for chunk_size in 1..=text.len() {
1453 let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
1454 .map(|chunk| chunk.unwrap())
1455 .collect::<String>()
1456 .await;
1457 assert_eq!(
1458 actual_text, expected_text,
1459 "failed to strip invalid spans, chunk size: {}",
1460 chunk_size
1461 );
1462 }
1463 }
1464
1465 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
1466 stream::iter(
1467 text.chars()
1468 .collect::<Vec<_>>()
1469 .chunks(size)
1470 .map(|chunk| Ok(chunk.iter().collect::<String>()))
1471 .collect::<Vec<_>>(),
1472 )
1473 }
1474 }
1475
1476 fn simulate_response_stream(
1477 codegen: Entity<CodegenAlternative>,
1478 cx: &mut TestAppContext,
1479 ) -> mpsc::UnboundedSender<String> {
1480 let (chunks_tx, chunks_rx) = mpsc::unbounded();
1481 codegen.update(cx, |codegen, cx| {
1482 codegen.handle_stream(
1483 String::new(),
1484 String::new(),
1485 None,
1486 future::ready(Ok(LanguageModelTextStream {
1487 message_id: None,
1488 stream: chunks_rx.map(Ok).boxed(),
1489 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
1490 })),
1491 cx,
1492 );
1493 });
1494 chunks_tx
1495 }
1496
1497 fn rust_lang() -> Language {
1498 Language::new(
1499 LanguageConfig {
1500 name: "Rust".into(),
1501 matcher: LanguageMatcher {
1502 path_suffixes: vec!["rs".to_string()],
1503 ..Default::default()
1504 },
1505 ..Default::default()
1506 },
1507 Some(tree_sitter_rust::LANGUAGE.into()),
1508 )
1509 .with_indents_query(
1510 r#"
1511 (call_expression) @indent
1512 (field_expression) @indent
1513 (_ "(" ")" @end) @indent
1514 (_ "{" "}" @end) @indent
1515 "#,
1516 )
1517 .unwrap()
1518 }
1519}