1use crate::context::ContextLoadResult;
2use crate::inline_prompt_editor::CodegenStatus;
3use crate::{context::load_context, context_store::ContextStore};
4use anyhow::Result;
5use client::telemetry::Telemetry;
6use collections::HashSet;
7use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
8use futures::{
9 SinkExt, Stream, StreamExt, TryStreamExt as _, channel::mpsc, future::LocalBoxFuture, join,
10};
11use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Subscription, Task, WeakEntity};
12use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
13use language_model::{
14 LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
15 LanguageModelTextStream, Role, report_assistant_event,
16};
17use multi_buffer::MultiBufferRow;
18use parking_lot::Mutex;
19use project::Project;
20use prompt_store::PromptBuilder;
21use prompt_store::PromptStore;
22use rope::Rope;
23use smol::future::FutureExt;
24use std::{
25 cmp,
26 future::Future,
27 iter,
28 ops::{Range, RangeInclusive},
29 pin::Pin,
30 sync::Arc,
31 task::{self, Poll},
32 time::Instant,
33};
34use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
35use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
36
/// Drives inline code generation for one range of a multibuffer.
///
/// Holds one [`CodegenAlternative`] per model used for the request. Only the
/// active alternative has its edits applied to the buffer at any time.
pub struct BufferCodegen {
    /// One alternative per model; `start` drives index 0 with the primary model
    /// and the remaining entries with the registry's inline alternative models.
    alternatives: Vec<Entity<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative currently applied/shown.
    pub active_alternative: usize,
    /// Indices of every alternative that has been activated.
    /// NOTE(review): only ever inserted into in the code visible here; confirm
    /// whether anything reads it.
    seen_alternatives: HashSet<usize>,
    /// Subscriptions to the active alternative; replaced on each activation.
    subscriptions: Vec<Subscription>,
    /// The multibuffer being transformed.
    buffer: Entity<MultiBuffer>,
    /// The range of `buffer` being transformed.
    range: Range<Anchor>,
    /// A transaction supplied by the caller that `undo` also rolls back.
    initial_transaction_id: Option<TransactionId>,
    // The following are handed to every `CodegenAlternative` this creates.
    context_store: Entity<ContextStore>,
    project: WeakEntity<Project>,
    prompt_store: Option<Entity<PromptStore>>,
    telemetry: Arc<Telemetry>,
    builder: Arc<PromptBuilder>,
    /// True when `range` was empty at construction (text is inserted rather
    /// than an existing selection being rewritten).
    pub is_insertion: bool,
}
52
53impl BufferCodegen {
54 pub fn new(
55 buffer: Entity<MultiBuffer>,
56 range: Range<Anchor>,
57 initial_transaction_id: Option<TransactionId>,
58 context_store: Entity<ContextStore>,
59 project: WeakEntity<Project>,
60 prompt_store: Option<Entity<PromptStore>>,
61 telemetry: Arc<Telemetry>,
62 builder: Arc<PromptBuilder>,
63 cx: &mut Context<Self>,
64 ) -> Self {
65 let codegen = cx.new(|cx| {
66 CodegenAlternative::new(
67 buffer.clone(),
68 range.clone(),
69 false,
70 Some(context_store.clone()),
71 project.clone(),
72 prompt_store.clone(),
73 Some(telemetry.clone()),
74 builder.clone(),
75 cx,
76 )
77 });
78 let mut this = Self {
79 is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
80 alternatives: vec![codegen],
81 active_alternative: 0,
82 seen_alternatives: HashSet::default(),
83 subscriptions: Vec::new(),
84 buffer,
85 range,
86 initial_transaction_id,
87 context_store,
88 project,
89 prompt_store,
90 telemetry,
91 builder,
92 };
93 this.activate(0, cx);
94 this
95 }
96
97 fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
98 let codegen = self.active_alternative().clone();
99 self.subscriptions.clear();
100 self.subscriptions
101 .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
102 self.subscriptions
103 .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
104 }
105
106 pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
107 &self.alternatives[self.active_alternative]
108 }
109
110 pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
111 &self.active_alternative().read(cx).status
112 }
113
114 pub fn alternative_count(&self, cx: &App) -> usize {
115 LanguageModelRegistry::read_global(cx)
116 .inline_alternative_models()
117 .len()
118 + 1
119 }
120
121 pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
122 let next_active_ix = if self.active_alternative == 0 {
123 self.alternatives.len() - 1
124 } else {
125 self.active_alternative - 1
126 };
127 self.activate(next_active_ix, cx);
128 }
129
130 pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
131 let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
132 self.activate(next_active_ix, cx);
133 }
134
135 fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
136 self.active_alternative()
137 .update(cx, |codegen, cx| codegen.set_active(false, cx));
138 self.seen_alternatives.insert(index);
139 self.active_alternative = index;
140 self.active_alternative()
141 .update(cx, |codegen, cx| codegen.set_active(true, cx));
142 self.subscribe_to_alternative(cx);
143 cx.notify();
144 }
145
146 pub fn start(
147 &mut self,
148 primary_model: Arc<dyn LanguageModel>,
149 user_prompt: String,
150 cx: &mut Context<Self>,
151 ) -> Result<()> {
152 let alternative_models = LanguageModelRegistry::read_global(cx)
153 .inline_alternative_models()
154 .to_vec();
155
156 self.active_alternative()
157 .update(cx, |alternative, cx| alternative.undo(cx));
158 self.activate(0, cx);
159 self.alternatives.truncate(1);
160
161 for _ in 0..alternative_models.len() {
162 self.alternatives.push(cx.new(|cx| {
163 CodegenAlternative::new(
164 self.buffer.clone(),
165 self.range.clone(),
166 false,
167 Some(self.context_store.clone()),
168 self.project.clone(),
169 self.prompt_store.clone(),
170 Some(self.telemetry.clone()),
171 self.builder.clone(),
172 cx,
173 )
174 }));
175 }
176
177 for (model, alternative) in iter::once(primary_model)
178 .chain(alternative_models)
179 .zip(&self.alternatives)
180 {
181 alternative.update(cx, |alternative, cx| {
182 alternative.start(user_prompt.clone(), model.clone(), cx)
183 })?;
184 }
185
186 Ok(())
187 }
188
189 pub fn stop(&mut self, cx: &mut Context<Self>) {
190 for codegen in &self.alternatives {
191 codegen.update(cx, |codegen, cx| codegen.stop(cx));
192 }
193 }
194
195 pub fn undo(&mut self, cx: &mut Context<Self>) {
196 self.active_alternative()
197 .update(cx, |codegen, cx| codegen.undo(cx));
198
199 self.buffer.update(cx, |buffer, cx| {
200 if let Some(transaction_id) = self.initial_transaction_id.take() {
201 buffer.undo_transaction(transaction_id, cx);
202 buffer.refresh_preview(cx);
203 }
204 });
205 }
206
207 pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
208 self.active_alternative().read(cx).buffer.clone()
209 }
210
211 pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
212 self.active_alternative().read(cx).old_buffer.clone()
213 }
214
215 pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
216 self.active_alternative().read(cx).snapshot.clone()
217 }
218
219 pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
220 self.active_alternative().read(cx).edit_position
221 }
222
223 pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
224 &self.active_alternative().read(cx).diff
225 }
226
227 pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
228 self.active_alternative().read(cx).last_equal_ranges()
229 }
230}
231
// `BufferCodegen` re-emits the `CodegenEvent`s of whichever alternative is
// active (see `subscribe_to_alternative`).
impl EventEmitter<CodegenEvent> for BufferCodegen {}
233
/// One model's attempt at transforming a range of a multibuffer.
///
/// Its output is diffed against the original text and applied as buffer edits
/// while the alternative is active. The edits are also recorded, so they can be
/// withdrawn and replayed when switching between alternatives.
pub struct CodegenAlternative {
    /// The multibuffer being edited.
    buffer: Entity<MultiBuffer>,
    /// Standalone local copy of the original buffer's text, created in `new`.
    old_buffer: Entity<Buffer>,
    /// Multibuffer snapshot taken at construction; streamed diff offsets are
    /// resolved against it.
    snapshot: MultiBufferSnapshot,
    /// Anchor just after the most recently applied streamed output.
    edit_position: Option<Anchor>,
    /// The range being transformed.
    range: Range<Anchor>,
    /// Ranges the latest streamed chunk kept unchanged; cleared on every
    /// streamed update and when generation finishes or stops.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// The transaction all of this alternative's edits are merged into.
    transformation_transaction_id: Option<TransactionId>,
    status: CodegenStatus,
    /// The in-flight generation task (a ready task when idle).
    generation: Task<()>,
    /// Row ranges used to highlight deleted and inserted lines.
    diff: Diff,
    // Used by `build_request` to load context into the request.
    context_store: Option<Entity<ContextStore>>,
    project: WeakEntity<Project>,
    prompt_store: Option<Entity<PromptStore>>,
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events, used to detect undos of our transaction.
    _subscription: gpui::Subscription,
    builder: Arc<PromptBuilder>,
    /// Whether this alternative's edits are currently applied to `buffer`.
    active: bool,
    /// Every edit produced so far, replayed when the alternative is reactivated.
    edits: Vec<(Range<Anchor>, String)>,
    /// Latest streaming line diff, replayed on reactivation while pending.
    line_operations: Vec<LineOperation>,
    /// Seconds from the start of `handle_stream` until generation finished.
    elapsed_time: Option<f64>,
    /// The full completion text received from the model.
    completion: Option<String>,
    /// Message id reported by the model's stream, if any.
    pub message_id: Option<String>,
}
258
// Emits `Finished` when generation ends and `Undone` when its transformation
// transaction is undone in the buffer.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
260
261impl CodegenAlternative {
262 pub fn new(
263 buffer: Entity<MultiBuffer>,
264 range: Range<Anchor>,
265 active: bool,
266 context_store: Option<Entity<ContextStore>>,
267 project: WeakEntity<Project>,
268 prompt_store: Option<Entity<PromptStore>>,
269 telemetry: Option<Arc<Telemetry>>,
270 builder: Arc<PromptBuilder>,
271 cx: &mut Context<Self>,
272 ) -> Self {
273 let snapshot = buffer.read(cx).snapshot(cx);
274
275 let (old_buffer, _, _) = snapshot
276 .range_to_buffer_ranges(range.clone())
277 .pop()
278 .unwrap();
279 let old_buffer = cx.new(|cx| {
280 let text = old_buffer.as_rope().clone();
281 let line_ending = old_buffer.line_ending();
282 let language = old_buffer.language().cloned();
283 let language_registry = buffer
284 .read(cx)
285 .buffer(old_buffer.remote_id())
286 .unwrap()
287 .read(cx)
288 .language_registry();
289
290 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
291 buffer.set_language(language, cx);
292 if let Some(language_registry) = language_registry {
293 buffer.set_language_registry(language_registry)
294 }
295 buffer
296 });
297
298 Self {
299 buffer: buffer.clone(),
300 old_buffer,
301 edit_position: None,
302 message_id: None,
303 snapshot,
304 last_equal_ranges: Default::default(),
305 transformation_transaction_id: None,
306 status: CodegenStatus::Idle,
307 generation: Task::ready(()),
308 diff: Diff::default(),
309 context_store,
310 project,
311 prompt_store,
312 telemetry,
313 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
314 builder,
315 active,
316 edits: Vec::new(),
317 line_operations: Vec::new(),
318 range,
319 elapsed_time: None,
320 completion: None,
321 }
322 }
323
324 pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
325 if active != self.active {
326 self.active = active;
327
328 if self.active {
329 let edits = self.edits.clone();
330 self.apply_edits(edits, cx);
331 if matches!(self.status, CodegenStatus::Pending) {
332 let line_operations = self.line_operations.clone();
333 self.reapply_line_based_diff(line_operations, cx);
334 } else {
335 self.reapply_batch_diff(cx).detach();
336 }
337 } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
338 self.buffer.update(cx, |buffer, cx| {
339 buffer.undo_transaction(transaction_id, cx);
340 buffer.forget_transaction(transaction_id, cx);
341 });
342 }
343 }
344 }
345
346 fn handle_buffer_event(
347 &mut self,
348 _buffer: Entity<MultiBuffer>,
349 event: &multi_buffer::Event,
350 cx: &mut Context<Self>,
351 ) {
352 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
353 if self.transformation_transaction_id == Some(*transaction_id) {
354 self.transformation_transaction_id = None;
355 self.generation = Task::ready(());
356 cx.emit(CodegenEvent::Undone);
357 }
358 }
359 }
360
361 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
362 &self.last_equal_ranges
363 }
364
365 pub fn start(
366 &mut self,
367 user_prompt: String,
368 model: Arc<dyn LanguageModel>,
369 cx: &mut Context<Self>,
370 ) -> Result<()> {
371 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
372 self.buffer.update(cx, |buffer, cx| {
373 buffer.undo_transaction(transformation_transaction_id, cx);
374 });
375 }
376
377 self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
378
379 let api_key = model.api_key(cx);
380 let telemetry_id = model.telemetry_id();
381 let provider_id = model.provider_id();
382 let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
383 if user_prompt.trim().to_lowercase() == "delete" {
384 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
385 } else {
386 let request = self.build_request(user_prompt, cx)?;
387 cx.spawn(async move |_, cx| model.stream_completion_text(request.await, &cx).await)
388 .boxed_local()
389 };
390 self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
391 Ok(())
392 }
393
394 fn build_request(
395 &self,
396 user_prompt: String,
397 cx: &mut App,
398 ) -> Result<Task<LanguageModelRequest>> {
399 let buffer = self.buffer.read(cx).snapshot(cx);
400 let language = buffer.language_at(self.range.start);
401 let language_name = if let Some(language) = language.as_ref() {
402 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
403 None
404 } else {
405 Some(language.name())
406 }
407 } else {
408 None
409 };
410
411 let language_name = language_name.as_ref();
412 let start = buffer.point_to_buffer_offset(self.range.start);
413 let end = buffer.point_to_buffer_offset(self.range.end);
414 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
415 let (start_buffer, start_buffer_offset) = start;
416 let (end_buffer, end_buffer_offset) = end;
417 if start_buffer.remote_id() == end_buffer.remote_id() {
418 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
419 } else {
420 return Err(anyhow::anyhow!("invalid transformation range"));
421 }
422 } else {
423 return Err(anyhow::anyhow!("invalid transformation range"));
424 };
425
426 let prompt = self
427 .builder
428 .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
429 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
430
431 let context_task = self.context_store.as_ref().map(|context_store| {
432 if let Some(project) = self.project.upgrade() {
433 let context = context_store
434 .read(cx)
435 .context()
436 .cloned()
437 .collect::<Vec<_>>();
438 load_context(context, &project, &self.prompt_store, cx)
439 } else {
440 Task::ready(ContextLoadResult::default())
441 }
442 });
443
444 Ok(cx.spawn(async move |_cx| {
445 let mut request_message = LanguageModelRequestMessage {
446 role: Role::User,
447 content: Vec::new(),
448 cache: false,
449 };
450
451 if let Some(context_task) = context_task {
452 context_task
453 .await
454 .loaded_context
455 .add_to_request_message(&mut request_message);
456 }
457
458 request_message.content.push(prompt.into());
459
460 LanguageModelRequest {
461 thread_id: None,
462 prompt_id: None,
463 tools: Vec::new(),
464 stop: Vec::new(),
465 temperature: None,
466 messages: vec![request_message],
467 }
468 }))
469 }
470
471 pub fn handle_stream(
472 &mut self,
473 model_telemetry_id: String,
474 model_provider_id: String,
475 model_api_key: Option<String>,
476 stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
477 cx: &mut Context<Self>,
478 ) {
479 let start_time = Instant::now();
480 let snapshot = self.snapshot.clone();
481 let selected_text = snapshot
482 .text_for_range(self.range.start..self.range.end)
483 .collect::<Rope>();
484
485 let selection_start = self.range.start.to_point(&snapshot);
486
487 // Start with the indentation of the first line in the selection
488 let mut suggested_line_indent = snapshot
489 .suggested_indents(selection_start.row..=selection_start.row, cx)
490 .into_values()
491 .next()
492 .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
493
494 // If the first line in the selection does not have indentation, check the following lines
495 if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
496 for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
497 let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
498 // Prefer tabs if a line in the selection uses tabs as indentation
499 if line_indent.kind == IndentKind::Tab {
500 suggested_line_indent.kind = IndentKind::Tab;
501 break;
502 }
503 }
504 }
505
506 let http_client = cx.http_client().clone();
507 let telemetry = self.telemetry.clone();
508 let language_name = {
509 let multibuffer = self.buffer.read(cx);
510 let snapshot = multibuffer.snapshot(cx);
511 let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
512 ranges
513 .first()
514 .and_then(|(buffer, _, _)| buffer.language())
515 .map(|language| language.name())
516 };
517
518 self.diff = Diff::default();
519 self.status = CodegenStatus::Pending;
520 let mut edit_start = self.range.start.to_offset(&snapshot);
521 let completion = Arc::new(Mutex::new(String::new()));
522 let completion_clone = completion.clone();
523
524 self.generation = cx.spawn(async move |codegen, cx| {
525 let stream = stream.await;
526 let token_usage = stream
527 .as_ref()
528 .ok()
529 .map(|stream| stream.last_token_usage.clone());
530 let message_id = stream
531 .as_ref()
532 .ok()
533 .and_then(|stream| stream.message_id.clone());
534 let generate = async {
535 let model_telemetry_id = model_telemetry_id.clone();
536 let model_provider_id = model_provider_id.clone();
537 let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
538 let executor = cx.background_executor().clone();
539 let message_id = message_id.clone();
540 let line_based_stream_diff: Task<anyhow::Result<()>> =
541 cx.background_spawn(async move {
542 let mut response_latency = None;
543 let request_start = Instant::now();
544 let diff = async {
545 let chunks = StripInvalidSpans::new(
546 stream?.stream.map_err(|error| error.into()),
547 );
548 futures::pin_mut!(chunks);
549 let mut diff = StreamingDiff::new(selected_text.to_string());
550 let mut line_diff = LineDiff::default();
551
552 let mut new_text = String::new();
553 let mut base_indent = None;
554 let mut line_indent = None;
555 let mut first_line = true;
556
557 while let Some(chunk) = chunks.next().await {
558 if response_latency.is_none() {
559 response_latency = Some(request_start.elapsed());
560 }
561 let chunk = chunk?;
562 completion_clone.lock().push_str(&chunk);
563
564 let mut lines = chunk.split('\n').peekable();
565 while let Some(line) = lines.next() {
566 new_text.push_str(line);
567 if line_indent.is_none() {
568 if let Some(non_whitespace_ch_ix) =
569 new_text.find(|ch: char| !ch.is_whitespace())
570 {
571 line_indent = Some(non_whitespace_ch_ix);
572 base_indent = base_indent.or(line_indent);
573
574 let line_indent = line_indent.unwrap();
575 let base_indent = base_indent.unwrap();
576 let indent_delta =
577 line_indent as i32 - base_indent as i32;
578 let mut corrected_indent_len = cmp::max(
579 0,
580 suggested_line_indent.len as i32 + indent_delta,
581 )
582 as usize;
583 if first_line {
584 corrected_indent_len = corrected_indent_len
585 .saturating_sub(
586 selection_start.column as usize,
587 );
588 }
589
590 let indent_char = suggested_line_indent.char();
591 let mut indent_buffer = [0; 4];
592 let indent_str =
593 indent_char.encode_utf8(&mut indent_buffer);
594 new_text.replace_range(
595 ..line_indent,
596 &indent_str.repeat(corrected_indent_len),
597 );
598 }
599 }
600
601 if line_indent.is_some() {
602 let char_ops = diff.push_new(&new_text);
603 line_diff.push_char_operations(&char_ops, &selected_text);
604 diff_tx
605 .send((char_ops, line_diff.line_operations()))
606 .await?;
607 new_text.clear();
608 }
609
610 if lines.peek().is_some() {
611 let char_ops = diff.push_new("\n");
612 line_diff.push_char_operations(&char_ops, &selected_text);
613 diff_tx
614 .send((char_ops, line_diff.line_operations()))
615 .await?;
616 if line_indent.is_none() {
617 // Don't write out the leading indentation in empty lines on the next line
618 // This is the case where the above if statement didn't clear the buffer
619 new_text.clear();
620 }
621 line_indent = None;
622 first_line = false;
623 }
624 }
625 }
626
627 let mut char_ops = diff.push_new(&new_text);
628 char_ops.extend(diff.finish());
629 line_diff.push_char_operations(&char_ops, &selected_text);
630 line_diff.finish(&selected_text);
631 diff_tx
632 .send((char_ops, line_diff.line_operations()))
633 .await?;
634
635 anyhow::Ok(())
636 };
637
638 let result = diff.await;
639
640 let error_message = result.as_ref().err().map(|error| error.to_string());
641 report_assistant_event(
642 AssistantEventData {
643 conversation_id: None,
644 message_id,
645 kind: AssistantKind::Inline,
646 phase: AssistantPhase::Response,
647 model: model_telemetry_id,
648 model_provider: model_provider_id,
649 response_latency,
650 error_message,
651 language_name: language_name.map(|name| name.to_proto()),
652 },
653 telemetry,
654 http_client,
655 model_api_key,
656 &executor,
657 );
658
659 result?;
660 Ok(())
661 });
662
663 while let Some((char_ops, line_ops)) = diff_rx.next().await {
664 codegen.update(cx, |codegen, cx| {
665 codegen.last_equal_ranges.clear();
666
667 let edits = char_ops
668 .into_iter()
669 .filter_map(|operation| match operation {
670 CharOperation::Insert { text } => {
671 let edit_start = snapshot.anchor_after(edit_start);
672 Some((edit_start..edit_start, text))
673 }
674 CharOperation::Delete { bytes } => {
675 let edit_end = edit_start + bytes;
676 let edit_range = snapshot.anchor_after(edit_start)
677 ..snapshot.anchor_before(edit_end);
678 edit_start = edit_end;
679 Some((edit_range, String::new()))
680 }
681 CharOperation::Keep { bytes } => {
682 let edit_end = edit_start + bytes;
683 let edit_range = snapshot.anchor_after(edit_start)
684 ..snapshot.anchor_before(edit_end);
685 edit_start = edit_end;
686 codegen.last_equal_ranges.push(edit_range);
687 None
688 }
689 })
690 .collect::<Vec<_>>();
691
692 if codegen.active {
693 codegen.apply_edits(edits.iter().cloned(), cx);
694 codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
695 }
696 codegen.edits.extend(edits);
697 codegen.line_operations = line_ops;
698 codegen.edit_position = Some(snapshot.anchor_after(edit_start));
699
700 cx.notify();
701 })?;
702 }
703
704 // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
705 // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
706 // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
707 let batch_diff_task =
708 codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
709 let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
710 line_based_stream_diff?;
711
712 anyhow::Ok(())
713 };
714
715 let result = generate.await;
716 let elapsed_time = start_time.elapsed().as_secs_f64();
717
718 codegen
719 .update(cx, |this, cx| {
720 this.message_id = message_id;
721 this.last_equal_ranges.clear();
722 if let Err(error) = result {
723 this.status = CodegenStatus::Error(error);
724 } else {
725 this.status = CodegenStatus::Done;
726 }
727 this.elapsed_time = Some(elapsed_time);
728 this.completion = Some(completion.lock().clone());
729 if let Some(usage) = token_usage {
730 let usage = usage.lock();
731 telemetry::event!(
732 "Inline Assistant Completion",
733 model = model_telemetry_id,
734 model_provider = model_provider_id,
735 input_tokens = usage.input_tokens,
736 output_tokens = usage.output_tokens,
737 )
738 }
739 cx.emit(CodegenEvent::Finished);
740 cx.notify();
741 })
742 .ok();
743 });
744 cx.notify();
745 }
746
747 pub fn stop(&mut self, cx: &mut Context<Self>) {
748 self.last_equal_ranges.clear();
749 if self.diff.is_empty() {
750 self.status = CodegenStatus::Idle;
751 } else {
752 self.status = CodegenStatus::Done;
753 }
754 self.generation = Task::ready(());
755 cx.emit(CodegenEvent::Finished);
756 cx.notify();
757 }
758
759 pub fn undo(&mut self, cx: &mut Context<Self>) {
760 self.buffer.update(cx, |buffer, cx| {
761 if let Some(transaction_id) = self.transformation_transaction_id.take() {
762 buffer.undo_transaction(transaction_id, cx);
763 buffer.refresh_preview(cx);
764 }
765 });
766 }
767
768 fn apply_edits(
769 &mut self,
770 edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
771 cx: &mut Context<CodegenAlternative>,
772 ) {
773 let transaction = self.buffer.update(cx, |buffer, cx| {
774 // Avoid grouping assistant edits with user edits.
775 buffer.finalize_last_transaction(cx);
776 buffer.start_transaction(cx);
777 buffer.edit(edits, None, cx);
778 buffer.end_transaction(cx)
779 });
780
781 if let Some(transaction) = transaction {
782 if let Some(first_transaction) = self.transformation_transaction_id {
783 // Group all assistant edits into the first transaction.
784 self.buffer.update(cx, |buffer, cx| {
785 buffer.merge_transactions(transaction, first_transaction, cx)
786 });
787 } else {
788 self.transformation_transaction_id = Some(transaction);
789 self.buffer
790 .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
791 }
792 }
793 }
794
795 fn reapply_line_based_diff(
796 &mut self,
797 line_operations: impl IntoIterator<Item = LineOperation>,
798 cx: &mut Context<Self>,
799 ) {
800 let old_snapshot = self.snapshot.clone();
801 let old_range = self.range.to_point(&old_snapshot);
802 let new_snapshot = self.buffer.read(cx).snapshot(cx);
803 let new_range = self.range.to_point(&new_snapshot);
804
805 let mut old_row = old_range.start.row;
806 let mut new_row = new_range.start.row;
807
808 self.diff.deleted_row_ranges.clear();
809 self.diff.inserted_row_ranges.clear();
810 for operation in line_operations {
811 match operation {
812 LineOperation::Keep { lines } => {
813 old_row += lines;
814 new_row += lines;
815 }
816 LineOperation::Delete { lines } => {
817 let old_end_row = old_row + lines - 1;
818 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
819
820 if let Some((_, last_deleted_row_range)) =
821 self.diff.deleted_row_ranges.last_mut()
822 {
823 if *last_deleted_row_range.end() + 1 == old_row {
824 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
825 } else {
826 self.diff
827 .deleted_row_ranges
828 .push((new_row, old_row..=old_end_row));
829 }
830 } else {
831 self.diff
832 .deleted_row_ranges
833 .push((new_row, old_row..=old_end_row));
834 }
835
836 old_row += lines;
837 }
838 LineOperation::Insert { lines } => {
839 let new_end_row = new_row + lines - 1;
840 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
841 let end = new_snapshot.anchor_before(Point::new(
842 new_end_row,
843 new_snapshot.line_len(MultiBufferRow(new_end_row)),
844 ));
845 self.diff.inserted_row_ranges.push(start..end);
846 new_row += lines;
847 }
848 }
849
850 cx.notify();
851 }
852 }
853
854 fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
855 let old_snapshot = self.snapshot.clone();
856 let old_range = self.range.to_point(&old_snapshot);
857 let new_snapshot = self.buffer.read(cx).snapshot(cx);
858 let new_range = self.range.to_point(&new_snapshot);
859
860 cx.spawn(async move |codegen, cx| {
861 let (deleted_row_ranges, inserted_row_ranges) = cx
862 .background_spawn(async move {
863 let old_text = old_snapshot
864 .text_for_range(
865 Point::new(old_range.start.row, 0)
866 ..Point::new(
867 old_range.end.row,
868 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
869 ),
870 )
871 .collect::<String>();
872 let new_text = new_snapshot
873 .text_for_range(
874 Point::new(new_range.start.row, 0)
875 ..Point::new(
876 new_range.end.row,
877 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
878 ),
879 )
880 .collect::<String>();
881
882 let old_start_row = old_range.start.row;
883 let new_start_row = new_range.start.row;
884 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
885 let mut inserted_row_ranges = Vec::new();
886 for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
887 let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
888 let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
889 if !old_rows.is_empty() {
890 deleted_row_ranges.push((
891 new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
892 old_rows.start..=old_rows.end - 1,
893 ));
894 }
895 if !new_rows.is_empty() {
896 let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
897 let new_end_row = new_rows.end - 1;
898 let end = new_snapshot.anchor_before(Point::new(
899 new_end_row,
900 new_snapshot.line_len(MultiBufferRow(new_end_row)),
901 ));
902 inserted_row_ranges.push(start..end);
903 }
904 }
905 (deleted_row_ranges, inserted_row_ranges)
906 })
907 .await;
908
909 codegen
910 .update(cx, |codegen, cx| {
911 codegen.diff.deleted_row_ranges = deleted_row_ranges;
912 codegen.diff.inserted_row_ranges = inserted_row_ranges;
913 cx.notify();
914 })
915 .ok();
916 })
917 }
918}
919
/// Events emitted by a `CodegenAlternative` and re-emitted by `BufferCodegen`
/// for its active alternative.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// Generation ended: the stream completed (successfully or with an
    /// error) or `stop` was called.
    Finished,
    /// The buffer reported that the alternative's tracked transformation
    /// transaction was undone.
    Undone,
}
925
/// Stream adapter that cleans up raw model output before it is diffed.
///
/// It drops a leading code-fence line and a closing fence on the final line
/// when the response opened with one. It also removes `<|CURSOR|>` markers
/// from emitted lines.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Set once the wrapped stream has returned `None`.
    stream_done: bool,
    /// Received text not yet emitted, e.g. an incomplete line that may still
    /// turn out to be a fence or cursor marker.
    buffer: String,
    /// True until a newline has been seen in the response.
    first_line: bool,
    /// True when a line break is owed before the next emitted text.
    line_end: bool,
    /// True when the response's first line was a code fence that was dropped.
    starts_with_code_block: bool,
}
934
935impl<T> StripInvalidSpans<T>
936where
937 T: Stream<Item = Result<String>>,
938{
939 fn new(stream: T) -> Self {
940 Self {
941 stream,
942 stream_done: false,
943 buffer: String::new(),
944 first_line: true,
945 line_end: false,
946 starts_with_code_block: false,
947 }
948 }
949}
950
951impl<T> Stream for StripInvalidSpans<T>
952where
953 T: Stream<Item = Result<String>>,
954{
955 type Item = Result<String>;
956
957 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
958 const CODE_BLOCK_DELIMITER: &str = "```";
959 const CURSOR_SPAN: &str = "<|CURSOR|>";
960
961 let this = unsafe { self.get_unchecked_mut() };
962 loop {
963 if !this.stream_done {
964 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
965 match stream.as_mut().poll_next(cx) {
966 Poll::Ready(Some(Ok(chunk))) => {
967 this.buffer.push_str(&chunk);
968 }
969 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
970 Poll::Ready(None) => {
971 this.stream_done = true;
972 }
973 Poll::Pending => return Poll::Pending,
974 }
975 }
976
977 let mut chunk = String::new();
978 let mut consumed = 0;
979 if !this.buffer.is_empty() {
980 let mut lines = this.buffer.split('\n').enumerate().peekable();
981 while let Some((line_ix, line)) = lines.next() {
982 if line_ix > 0 {
983 this.first_line = false;
984 }
985
986 if this.first_line {
987 let trimmed_line = line.trim();
988 if lines.peek().is_some() {
989 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
990 consumed += line.len() + 1;
991 this.starts_with_code_block = true;
992 continue;
993 }
994 } else if trimmed_line.is_empty()
995 || prefixes(CODE_BLOCK_DELIMITER)
996 .any(|prefix| trimmed_line.starts_with(prefix))
997 {
998 break;
999 }
1000 }
1001
1002 let line_without_cursor = line.replace(CURSOR_SPAN, "");
1003 if lines.peek().is_some() {
1004 if this.line_end {
1005 chunk.push('\n');
1006 }
1007
1008 chunk.push_str(&line_without_cursor);
1009 this.line_end = true;
1010 consumed += line.len() + 1;
1011 } else if this.stream_done {
1012 if !this.starts_with_code_block
1013 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1014 {
1015 if this.line_end {
1016 chunk.push('\n');
1017 }
1018
1019 chunk.push_str(&line);
1020 }
1021
1022 consumed += line.len();
1023 } else {
1024 let trimmed_line = line.trim();
1025 if trimmed_line.is_empty()
1026 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1027 || prefixes(CODE_BLOCK_DELIMITER)
1028 .any(|prefix| trimmed_line.ends_with(prefix))
1029 {
1030 break;
1031 } else {
1032 if this.line_end {
1033 chunk.push('\n');
1034 this.line_end = false;
1035 }
1036
1037 chunk.push_str(&line_without_cursor);
1038 consumed += line.len();
1039 }
1040 }
1041 }
1042 }
1043
1044 this.buffer = this.buffer.split_off(consumed);
1045 if !chunk.is_empty() {
1046 return Poll::Ready(Some(Ok(chunk)));
1047 } else if this.stream_done {
1048 return Poll::Ready(None);
1049 }
1050 }
1051 }
1052}
1053
/// Returns every non-empty proper prefix of `text`, shortest first; `text`
/// itself is excluded. For a three-backtick fence this yields one and two
/// backticks.
///
/// Prefixes always end on a `char` boundary, so non-ASCII input cannot panic.
/// An empty or single-character `text` yields nothing.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
1057
/// Line-level highlighting data for an alternative's changes.
#[derive(Default)]
pub struct Diff {
    /// Deleted rows as inclusive row ranges of the original snapshot, each
    /// paired with an anchor in the current buffer where they were removed.
    pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Ranges of inserted lines in the current buffer.
    pub inserted_row_ranges: Vec<Range<Anchor>>,
}
1063
1064impl Diff {
1065 fn is_empty(&self) -> bool {
1066 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1067 }
1068}
1069
#[cfg(test)]
mod tests {
    use super::*;
    use fs::FakeFs;
    use futures::{
        Stream,
        stream::{self},
    };
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{
        Buffer, Language, LanguageConfig, LanguageMatcher, Point, language_settings,
        tree_sitter_rust,
    };
    use language_model::{LanguageModelRegistry, TokenUsage};
    use rand::prelude::*;
    use serde::Serialize;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    // NOTE(review): not referenced by any test in this module; possibly a
    // leftover — consider removing together with the `Serialize` import.
    #[derive(Serialize)]
    pub struct DummyCompletionRequest {
        pub name: String,
    }

    /// Installs the globals every test that builds a `CodegenAlternative`
    /// needs: a test `SettingsStore`, the test `LanguageModelRegistry`, and
    /// language settings.
    ///
    /// Centralising this keeps tests from silently skipping part of the setup.
    fn init_test(cx: &mut TestAppContext) {
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(LanguageModelRegistry::test);
        cx.update(language_settings::init);
    }

    /// Rewrites a multi-line selection (rows 1..=4) with streamed output that
    /// arrives in random-sized chunks, then checks the final buffer text,
    /// including its indentation.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, vec![], cx).await;
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                project.downgrade(),
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Inserts at a cursor placed after existing indentation and a partial
    /// token (`le|`). The streamed continuation has no leading indentation,
    /// and the result must still be indented correctly.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        // Previously this test skipped the language model registry setup that
        // every sibling test performs; `init_test` installs it uniformly.
        init_test(cx);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, vec![], cx).await;
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                project.downgrade(),
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Inserts at a cursor on a whitespace-only line that has less indentation
    /// (two spaces) than the block requires. The result must use the proper
    /// indentation.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, vec![], cx).await;
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                project.downgrade(),
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// In a buffer with no language (tab-indented Go source), tab indentation
    /// in the replacement is kept as-is.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, vec![], cx).await;
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                project.downgrade(),
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    /// An inactive alternative leaves the buffer untouched while streaming.
    /// Activating it applies the generated edit; deactivating it again undoes
    /// the edit.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let fs = FakeFs::new(cx.executor());
        let project = Project::test(fs, vec![], cx).await;
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                None,
                project.downgrade(),
                None,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    /// `StripInvalidSpans` removes a surrounding code fence and `<|CURSOR|>`
    /// markers, for every possible chunking of the input stream.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        /// Feeds `text` through `StripInvalidSpans` once for every chunk size
        /// from one `char` up to the whole string, and asserts each run
        /// produces `expected_text`.
        async fn assert_chunks(text: &str, expected_text: &str) {
            // `chunks` splits by `char`, so bound the sizes by the char count
            // rather than the byte length.
            for chunk_size in 1..=text.chars().count() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        /// Splits `text` into a stream of `Ok` chunks of `size` chars each
        /// (the last chunk may be shorter).
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Starts `codegen` consuming an in-memory response stream and returns the
    /// sending half. Each `String` sent becomes one chunk of the model
    /// response. Dropping the sender ends the stream.
    fn simulate_response_stream(
        codegen: Entity<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                String::new(),
                None,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
                })),
                cx,
            );
        });
        chunks_tx
    }

    /// A minimal Rust language with the tree-sitter grammar and an indents
    /// query (calls, field accesses, parenthesised and braced nodes). This is
    /// enough for the autoindent tests above.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}