1use crate::context::attach_context_to_message;
2use crate::context_store::ContextStore;
3use crate::inline_prompt_editor::CodegenStatus;
4use crate::{
5 prompts::PromptBuilder,
6 streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff},
7};
8use anyhow::{Context as _, Result};
9use client::telemetry::Telemetry;
10use collections::HashSet;
11use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
12use futures::{channel::mpsc, future::LocalBoxFuture, join, SinkExt, Stream, StreamExt};
13use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscription, Task};
14use language::{Buffer, IndentKind, Point, TransactionId};
15use language_model::{
16 LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
17 LanguageModelTextStream, Role,
18};
19use language_models::report_assistant_event;
20use multi_buffer::MultiBufferRow;
21use parking_lot::Mutex;
22use rope::Rope;
23use smol::future::FutureExt;
24use std::{
25 cmp,
26 future::Future,
27 iter,
28 ops::{Range, RangeInclusive},
29 pin::Pin,
30 sync::Arc,
31 task::{self, Poll},
32 time::Instant,
33};
34use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
35
36pub struct BufferCodegen {
37 alternatives: Vec<Model<CodegenAlternative>>,
38 pub active_alternative: usize,
39 seen_alternatives: HashSet<usize>,
40 subscriptions: Vec<Subscription>,
41 buffer: Model<MultiBuffer>,
42 range: Range<Anchor>,
43 initial_transaction_id: Option<TransactionId>,
44 context_store: Model<ContextStore>,
45 telemetry: Arc<Telemetry>,
46 builder: Arc<PromptBuilder>,
47 pub is_insertion: bool,
48}
49
50impl BufferCodegen {
51 pub fn new(
52 buffer: Model<MultiBuffer>,
53 range: Range<Anchor>,
54 initial_transaction_id: Option<TransactionId>,
55 context_store: Model<ContextStore>,
56 telemetry: Arc<Telemetry>,
57 builder: Arc<PromptBuilder>,
58 cx: &mut ModelContext<Self>,
59 ) -> Self {
60 let codegen = cx.new_model(|cx| {
61 CodegenAlternative::new(
62 buffer.clone(),
63 range.clone(),
64 false,
65 Some(context_store.clone()),
66 Some(telemetry.clone()),
67 builder.clone(),
68 cx,
69 )
70 });
71 let mut this = Self {
72 is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
73 alternatives: vec![codegen],
74 active_alternative: 0,
75 seen_alternatives: HashSet::default(),
76 subscriptions: Vec::new(),
77 buffer,
78 range,
79 initial_transaction_id,
80 context_store,
81 telemetry,
82 builder,
83 };
84 this.activate(0, cx);
85 this
86 }
87
88 fn subscribe_to_alternative(&mut self, cx: &mut ModelContext<Self>) {
89 let codegen = self.active_alternative().clone();
90 self.subscriptions.clear();
91 self.subscriptions
92 .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
93 self.subscriptions
94 .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
95 }
96
97 pub fn active_alternative(&self) -> &Model<CodegenAlternative> {
98 &self.alternatives[self.active_alternative]
99 }
100
101 pub fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus {
102 &self.active_alternative().read(cx).status
103 }
104
105 pub fn alternative_count(&self, cx: &AppContext) -> usize {
106 LanguageModelRegistry::read_global(cx)
107 .inline_alternative_models()
108 .len()
109 + 1
110 }
111
112 pub fn cycle_prev(&mut self, cx: &mut ModelContext<Self>) {
113 let next_active_ix = if self.active_alternative == 0 {
114 self.alternatives.len() - 1
115 } else {
116 self.active_alternative - 1
117 };
118 self.activate(next_active_ix, cx);
119 }
120
121 pub fn cycle_next(&mut self, cx: &mut ModelContext<Self>) {
122 let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
123 self.activate(next_active_ix, cx);
124 }
125
126 fn activate(&mut self, index: usize, cx: &mut ModelContext<Self>) {
127 self.active_alternative()
128 .update(cx, |codegen, cx| codegen.set_active(false, cx));
129 self.seen_alternatives.insert(index);
130 self.active_alternative = index;
131 self.active_alternative()
132 .update(cx, |codegen, cx| codegen.set_active(true, cx));
133 self.subscribe_to_alternative(cx);
134 cx.notify();
135 }
136
137 pub fn start(&mut self, user_prompt: String, cx: &mut ModelContext<Self>) -> Result<()> {
138 let alternative_models = LanguageModelRegistry::read_global(cx)
139 .inline_alternative_models()
140 .to_vec();
141
142 self.active_alternative()
143 .update(cx, |alternative, cx| alternative.undo(cx));
144 self.activate(0, cx);
145 self.alternatives.truncate(1);
146
147 for _ in 0..alternative_models.len() {
148 self.alternatives.push(cx.new_model(|cx| {
149 CodegenAlternative::new(
150 self.buffer.clone(),
151 self.range.clone(),
152 false,
153 Some(self.context_store.clone()),
154 Some(self.telemetry.clone()),
155 self.builder.clone(),
156 cx,
157 )
158 }));
159 }
160
161 let primary_model = LanguageModelRegistry::read_global(cx)
162 .active_model()
163 .context("no active model")?;
164
165 for (model, alternative) in iter::once(primary_model)
166 .chain(alternative_models)
167 .zip(&self.alternatives)
168 {
169 alternative.update(cx, |alternative, cx| {
170 alternative.start(user_prompt.clone(), model.clone(), cx)
171 })?;
172 }
173
174 Ok(())
175 }
176
177 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
178 for codegen in &self.alternatives {
179 codegen.update(cx, |codegen, cx| codegen.stop(cx));
180 }
181 }
182
183 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
184 self.active_alternative()
185 .update(cx, |codegen, cx| codegen.undo(cx));
186
187 self.buffer.update(cx, |buffer, cx| {
188 if let Some(transaction_id) = self.initial_transaction_id.take() {
189 buffer.undo_transaction(transaction_id, cx);
190 buffer.refresh_preview(cx);
191 }
192 });
193 }
194
195 pub fn buffer(&self, cx: &AppContext) -> Model<MultiBuffer> {
196 self.active_alternative().read(cx).buffer.clone()
197 }
198
199 pub fn old_buffer(&self, cx: &AppContext) -> Model<Buffer> {
200 self.active_alternative().read(cx).old_buffer.clone()
201 }
202
203 pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot {
204 self.active_alternative().read(cx).snapshot.clone()
205 }
206
207 pub fn edit_position(&self, cx: &AppContext) -> Option<Anchor> {
208 self.active_alternative().read(cx).edit_position
209 }
210
211 pub fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff {
212 &self.active_alternative().read(cx).diff
213 }
214
215 pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range<Anchor>] {
216 self.active_alternative().read(cx).last_equal_ranges()
217 }
218}
219
// `BufferCodegen` re-emits the `CodegenEvent`s of whichever alternative is active
// (see `subscribe_to_alternative`).
impl EventEmitter<CodegenEvent> for BufferCodegen {}
221
222pub struct CodegenAlternative {
223 buffer: Model<MultiBuffer>,
224 old_buffer: Model<Buffer>,
225 snapshot: MultiBufferSnapshot,
226 edit_position: Option<Anchor>,
227 range: Range<Anchor>,
228 last_equal_ranges: Vec<Range<Anchor>>,
229 transformation_transaction_id: Option<TransactionId>,
230 status: CodegenStatus,
231 generation: Task<()>,
232 diff: Diff,
233 context_store: Option<Model<ContextStore>>,
234 telemetry: Option<Arc<Telemetry>>,
235 _subscription: gpui::Subscription,
236 builder: Arc<PromptBuilder>,
237 active: bool,
238 edits: Vec<(Range<Anchor>, String)>,
239 line_operations: Vec<LineOperation>,
240 request: Option<LanguageModelRequest>,
241 elapsed_time: Option<f64>,
242 completion: Option<String>,
243 pub message_id: Option<String>,
244}
245
// `CodegenAlternative` emits `Finished` when a generation ends or is stopped, and `Undone`
// when its transformation transaction is undone in the buffer.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
247
248impl CodegenAlternative {
249 pub fn new(
250 buffer: Model<MultiBuffer>,
251 range: Range<Anchor>,
252 active: bool,
253 context_store: Option<Model<ContextStore>>,
254 telemetry: Option<Arc<Telemetry>>,
255 builder: Arc<PromptBuilder>,
256 cx: &mut ModelContext<Self>,
257 ) -> Self {
258 let snapshot = buffer.read(cx).snapshot(cx);
259
260 let (old_excerpt, _) = snapshot
261 .range_to_buffer_ranges(range.clone())
262 .pop()
263 .unwrap();
264 let old_buffer = cx.new_model(|cx| {
265 let text = old_excerpt.buffer().as_rope().clone();
266 let line_ending = old_excerpt.buffer().line_ending();
267 let language = old_excerpt.buffer().language().cloned();
268 let language_registry = buffer
269 .read(cx)
270 .buffer(old_excerpt.buffer_id())
271 .unwrap()
272 .read(cx)
273 .language_registry();
274
275 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
276 buffer.set_language(language, cx);
277 if let Some(language_registry) = language_registry {
278 buffer.set_language_registry(language_registry)
279 }
280 buffer
281 });
282
283 Self {
284 buffer: buffer.clone(),
285 old_buffer,
286 edit_position: None,
287 message_id: None,
288 snapshot,
289 last_equal_ranges: Default::default(),
290 transformation_transaction_id: None,
291 status: CodegenStatus::Idle,
292 generation: Task::ready(()),
293 diff: Diff::default(),
294 context_store,
295 telemetry,
296 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
297 builder,
298 active,
299 edits: Vec::new(),
300 line_operations: Vec::new(),
301 range,
302 request: None,
303 elapsed_time: None,
304 completion: None,
305 }
306 }
307
308 pub fn set_active(&mut self, active: bool, cx: &mut ModelContext<Self>) {
309 if active != self.active {
310 self.active = active;
311
312 if self.active {
313 let edits = self.edits.clone();
314 self.apply_edits(edits, cx);
315 if matches!(self.status, CodegenStatus::Pending) {
316 let line_operations = self.line_operations.clone();
317 self.reapply_line_based_diff(line_operations, cx);
318 } else {
319 self.reapply_batch_diff(cx).detach();
320 }
321 } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
322 self.buffer.update(cx, |buffer, cx| {
323 buffer.undo_transaction(transaction_id, cx);
324 buffer.forget_transaction(transaction_id, cx);
325 });
326 }
327 }
328 }
329
330 fn handle_buffer_event(
331 &mut self,
332 _buffer: Model<MultiBuffer>,
333 event: &multi_buffer::Event,
334 cx: &mut ModelContext<Self>,
335 ) {
336 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
337 if self.transformation_transaction_id == Some(*transaction_id) {
338 self.transformation_transaction_id = None;
339 self.generation = Task::ready(());
340 cx.emit(CodegenEvent::Undone);
341 }
342 }
343 }
344
345 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
346 &self.last_equal_ranges
347 }
348
349 pub fn start(
350 &mut self,
351 user_prompt: String,
352 model: Arc<dyn LanguageModel>,
353 cx: &mut ModelContext<Self>,
354 ) -> Result<()> {
355 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
356 self.buffer.update(cx, |buffer, cx| {
357 buffer.undo_transaction(transformation_transaction_id, cx);
358 });
359 }
360
361 self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
362
363 let api_key = model.api_key(cx);
364 let telemetry_id = model.telemetry_id();
365 let provider_id = model.provider_id();
366 let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
367 if user_prompt.trim().to_lowercase() == "delete" {
368 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
369 } else {
370 let request = self.build_request(user_prompt, cx)?;
371 self.request = Some(request.clone());
372
373 cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await })
374 .boxed_local()
375 };
376 self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
377 Ok(())
378 }
379
380 fn build_request(
381 &self,
382 user_prompt: String,
383 cx: &mut AppContext,
384 ) -> Result<LanguageModelRequest> {
385 let buffer = self.buffer.read(cx).snapshot(cx);
386 let language = buffer.language_at(self.range.start);
387 let language_name = if let Some(language) = language.as_ref() {
388 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
389 None
390 } else {
391 Some(language.name())
392 }
393 } else {
394 None
395 };
396
397 let language_name = language_name.as_ref();
398 let start = buffer.point_to_buffer_offset(self.range.start);
399 let end = buffer.point_to_buffer_offset(self.range.end);
400 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
401 let (start_buffer, start_buffer_offset) = start;
402 let (end_buffer, end_buffer_offset) = end;
403 if start_buffer.remote_id() == end_buffer.remote_id() {
404 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
405 } else {
406 return Err(anyhow::anyhow!("invalid transformation range"));
407 }
408 } else {
409 return Err(anyhow::anyhow!("invalid transformation range"));
410 };
411
412 let prompt = self
413 .builder
414 .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
415 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
416
417 let mut request_message = LanguageModelRequestMessage {
418 role: Role::User,
419 content: Vec::new(),
420 cache: false,
421 };
422
423 if let Some(context_store) = &self.context_store {
424 attach_context_to_message(&mut request_message, context_store.read(cx).snapshot(cx));
425 }
426
427 request_message.content.push(prompt.into());
428
429 Ok(LanguageModelRequest {
430 tools: Vec::new(),
431 stop: Vec::new(),
432 temperature: None,
433 messages: vec![request_message],
434 })
435 }
436
437 pub fn handle_stream(
438 &mut self,
439 model_telemetry_id: String,
440 model_provider_id: String,
441 model_api_key: Option<String>,
442 stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
443 cx: &mut ModelContext<Self>,
444 ) {
445 let start_time = Instant::now();
446 let snapshot = self.snapshot.clone();
447 let selected_text = snapshot
448 .text_for_range(self.range.start..self.range.end)
449 .collect::<Rope>();
450
451 let selection_start = self.range.start.to_point(&snapshot);
452
453 // Start with the indentation of the first line in the selection
454 let mut suggested_line_indent = snapshot
455 .suggested_indents(selection_start.row..=selection_start.row, cx)
456 .into_values()
457 .next()
458 .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
459
460 // If the first line in the selection does not have indentation, check the following lines
461 if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
462 for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
463 let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
464 // Prefer tabs if a line in the selection uses tabs as indentation
465 if line_indent.kind == IndentKind::Tab {
466 suggested_line_indent.kind = IndentKind::Tab;
467 break;
468 }
469 }
470 }
471
472 let http_client = cx.http_client().clone();
473 let telemetry = self.telemetry.clone();
474 let language_name = {
475 let multibuffer = self.buffer.read(cx);
476 let snapshot = multibuffer.snapshot(cx);
477 let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
478 ranges
479 .first()
480 .and_then(|(excerpt, _)| excerpt.buffer().language())
481 .map(|language| language.name())
482 };
483
484 self.diff = Diff::default();
485 self.status = CodegenStatus::Pending;
486 let mut edit_start = self.range.start.to_offset(&snapshot);
487 let completion = Arc::new(Mutex::new(String::new()));
488 let completion_clone = completion.clone();
489
490 self.generation = cx.spawn(|codegen, mut cx| {
491 async move {
492 let stream = stream.await;
493 let message_id = stream
494 .as_ref()
495 .ok()
496 .and_then(|stream| stream.message_id.clone());
497 let generate = async {
498 let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
499 let executor = cx.background_executor().clone();
500 let message_id = message_id.clone();
501 let line_based_stream_diff: Task<anyhow::Result<()>> =
502 cx.background_executor().spawn(async move {
503 let mut response_latency = None;
504 let request_start = Instant::now();
505 let diff = async {
506 let chunks = StripInvalidSpans::new(stream?.stream);
507 futures::pin_mut!(chunks);
508 let mut diff = StreamingDiff::new(selected_text.to_string());
509 let mut line_diff = LineDiff::default();
510
511 let mut new_text = String::new();
512 let mut base_indent = None;
513 let mut line_indent = None;
514 let mut first_line = true;
515
516 while let Some(chunk) = chunks.next().await {
517 if response_latency.is_none() {
518 response_latency = Some(request_start.elapsed());
519 }
520 let chunk = chunk?;
521 completion_clone.lock().push_str(&chunk);
522
523 let mut lines = chunk.split('\n').peekable();
524 while let Some(line) = lines.next() {
525 new_text.push_str(line);
526 if line_indent.is_none() {
527 if let Some(non_whitespace_ch_ix) =
528 new_text.find(|ch: char| !ch.is_whitespace())
529 {
530 line_indent = Some(non_whitespace_ch_ix);
531 base_indent = base_indent.or(line_indent);
532
533 let line_indent = line_indent.unwrap();
534 let base_indent = base_indent.unwrap();
535 let indent_delta =
536 line_indent as i32 - base_indent as i32;
537 let mut corrected_indent_len = cmp::max(
538 0,
539 suggested_line_indent.len as i32 + indent_delta,
540 )
541 as usize;
542 if first_line {
543 corrected_indent_len = corrected_indent_len
544 .saturating_sub(
545 selection_start.column as usize,
546 );
547 }
548
549 let indent_char = suggested_line_indent.char();
550 let mut indent_buffer = [0; 4];
551 let indent_str =
552 indent_char.encode_utf8(&mut indent_buffer);
553 new_text.replace_range(
554 ..line_indent,
555 &indent_str.repeat(corrected_indent_len),
556 );
557 }
558 }
559
560 if line_indent.is_some() {
561 let char_ops = diff.push_new(&new_text);
562 line_diff
563 .push_char_operations(&char_ops, &selected_text);
564 diff_tx
565 .send((char_ops, line_diff.line_operations()))
566 .await?;
567 new_text.clear();
568 }
569
570 if lines.peek().is_some() {
571 let char_ops = diff.push_new("\n");
572 line_diff
573 .push_char_operations(&char_ops, &selected_text);
574 diff_tx
575 .send((char_ops, line_diff.line_operations()))
576 .await?;
577 if line_indent.is_none() {
578 // Don't write out the leading indentation in empty lines on the next line
579 // This is the case where the above if statement didn't clear the buffer
580 new_text.clear();
581 }
582 line_indent = None;
583 first_line = false;
584 }
585 }
586 }
587
588 let mut char_ops = diff.push_new(&new_text);
589 char_ops.extend(diff.finish());
590 line_diff.push_char_operations(&char_ops, &selected_text);
591 line_diff.finish(&selected_text);
592 diff_tx
593 .send((char_ops, line_diff.line_operations()))
594 .await?;
595
596 anyhow::Ok(())
597 };
598
599 let result = diff.await;
600
601 let error_message =
602 result.as_ref().err().map(|error| error.to_string());
603 report_assistant_event(
604 AssistantEvent {
605 conversation_id: None,
606 message_id,
607 kind: AssistantKind::Inline,
608 phase: AssistantPhase::Response,
609 model: model_telemetry_id,
610 model_provider: model_provider_id.to_string(),
611 response_latency,
612 error_message,
613 language_name: language_name.map(|name| name.to_proto()),
614 },
615 telemetry,
616 http_client,
617 model_api_key,
618 &executor,
619 );
620
621 result?;
622 Ok(())
623 });
624
625 while let Some((char_ops, line_ops)) = diff_rx.next().await {
626 codegen.update(&mut cx, |codegen, cx| {
627 codegen.last_equal_ranges.clear();
628
629 let edits = char_ops
630 .into_iter()
631 .filter_map(|operation| match operation {
632 CharOperation::Insert { text } => {
633 let edit_start = snapshot.anchor_after(edit_start);
634 Some((edit_start..edit_start, text))
635 }
636 CharOperation::Delete { bytes } => {
637 let edit_end = edit_start + bytes;
638 let edit_range = snapshot.anchor_after(edit_start)
639 ..snapshot.anchor_before(edit_end);
640 edit_start = edit_end;
641 Some((edit_range, String::new()))
642 }
643 CharOperation::Keep { bytes } => {
644 let edit_end = edit_start + bytes;
645 let edit_range = snapshot.anchor_after(edit_start)
646 ..snapshot.anchor_before(edit_end);
647 edit_start = edit_end;
648 codegen.last_equal_ranges.push(edit_range);
649 None
650 }
651 })
652 .collect::<Vec<_>>();
653
654 if codegen.active {
655 codegen.apply_edits(edits.iter().cloned(), cx);
656 codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
657 }
658 codegen.edits.extend(edits);
659 codegen.line_operations = line_ops;
660 codegen.edit_position = Some(snapshot.anchor_after(edit_start));
661
662 cx.notify();
663 })?;
664 }
665
666 // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
667 // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
668 // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
669 let batch_diff_task =
670 codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
671 let (line_based_stream_diff, ()) =
672 join!(line_based_stream_diff, batch_diff_task);
673 line_based_stream_diff?;
674
675 anyhow::Ok(())
676 };
677
678 let result = generate.await;
679 let elapsed_time = start_time.elapsed().as_secs_f64();
680
681 codegen
682 .update(&mut cx, |this, cx| {
683 this.message_id = message_id;
684 this.last_equal_ranges.clear();
685 if let Err(error) = result {
686 this.status = CodegenStatus::Error(error);
687 } else {
688 this.status = CodegenStatus::Done;
689 }
690 this.elapsed_time = Some(elapsed_time);
691 this.completion = Some(completion.lock().clone());
692 cx.emit(CodegenEvent::Finished);
693 cx.notify();
694 })
695 .ok();
696 }
697 });
698 cx.notify();
699 }
700
701 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
702 self.last_equal_ranges.clear();
703 if self.diff.is_empty() {
704 self.status = CodegenStatus::Idle;
705 } else {
706 self.status = CodegenStatus::Done;
707 }
708 self.generation = Task::ready(());
709 cx.emit(CodegenEvent::Finished);
710 cx.notify();
711 }
712
713 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
714 self.buffer.update(cx, |buffer, cx| {
715 if let Some(transaction_id) = self.transformation_transaction_id.take() {
716 buffer.undo_transaction(transaction_id, cx);
717 buffer.refresh_preview(cx);
718 }
719 });
720 }
721
722 fn apply_edits(
723 &mut self,
724 edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
725 cx: &mut ModelContext<CodegenAlternative>,
726 ) {
727 let transaction = self.buffer.update(cx, |buffer, cx| {
728 // Avoid grouping assistant edits with user edits.
729 buffer.finalize_last_transaction(cx);
730 buffer.start_transaction(cx);
731 buffer.edit(edits, None, cx);
732 buffer.end_transaction(cx)
733 });
734
735 if let Some(transaction) = transaction {
736 if let Some(first_transaction) = self.transformation_transaction_id {
737 // Group all assistant edits into the first transaction.
738 self.buffer.update(cx, |buffer, cx| {
739 buffer.merge_transactions(transaction, first_transaction, cx)
740 });
741 } else {
742 self.transformation_transaction_id = Some(transaction);
743 self.buffer
744 .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
745 }
746 }
747 }
748
749 fn reapply_line_based_diff(
750 &mut self,
751 line_operations: impl IntoIterator<Item = LineOperation>,
752 cx: &mut ModelContext<Self>,
753 ) {
754 let old_snapshot = self.snapshot.clone();
755 let old_range = self.range.to_point(&old_snapshot);
756 let new_snapshot = self.buffer.read(cx).snapshot(cx);
757 let new_range = self.range.to_point(&new_snapshot);
758
759 let mut old_row = old_range.start.row;
760 let mut new_row = new_range.start.row;
761
762 self.diff.deleted_row_ranges.clear();
763 self.diff.inserted_row_ranges.clear();
764 for operation in line_operations {
765 match operation {
766 LineOperation::Keep { lines } => {
767 old_row += lines;
768 new_row += lines;
769 }
770 LineOperation::Delete { lines } => {
771 let old_end_row = old_row + lines - 1;
772 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
773
774 if let Some((_, last_deleted_row_range)) =
775 self.diff.deleted_row_ranges.last_mut()
776 {
777 if *last_deleted_row_range.end() + 1 == old_row {
778 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
779 } else {
780 self.diff
781 .deleted_row_ranges
782 .push((new_row, old_row..=old_end_row));
783 }
784 } else {
785 self.diff
786 .deleted_row_ranges
787 .push((new_row, old_row..=old_end_row));
788 }
789
790 old_row += lines;
791 }
792 LineOperation::Insert { lines } => {
793 let new_end_row = new_row + lines - 1;
794 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
795 let end = new_snapshot.anchor_before(Point::new(
796 new_end_row,
797 new_snapshot.line_len(MultiBufferRow(new_end_row)),
798 ));
799 self.diff.inserted_row_ranges.push(start..end);
800 new_row += lines;
801 }
802 }
803
804 cx.notify();
805 }
806 }
807
808 fn reapply_batch_diff(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
809 let old_snapshot = self.snapshot.clone();
810 let old_range = self.range.to_point(&old_snapshot);
811 let new_snapshot = self.buffer.read(cx).snapshot(cx);
812 let new_range = self.range.to_point(&new_snapshot);
813
814 cx.spawn(|codegen, mut cx| async move {
815 let (deleted_row_ranges, inserted_row_ranges) = cx
816 .background_executor()
817 .spawn(async move {
818 let old_text = old_snapshot
819 .text_for_range(
820 Point::new(old_range.start.row, 0)
821 ..Point::new(
822 old_range.end.row,
823 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
824 ),
825 )
826 .collect::<String>();
827 let new_text = new_snapshot
828 .text_for_range(
829 Point::new(new_range.start.row, 0)
830 ..Point::new(
831 new_range.end.row,
832 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
833 ),
834 )
835 .collect::<String>();
836
837 let mut old_row = old_range.start.row;
838 let mut new_row = new_range.start.row;
839 let batch_diff =
840 similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());
841
842 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
843 let mut inserted_row_ranges = Vec::new();
844 for change in batch_diff.iter_all_changes() {
845 let line_count = change.value().lines().count() as u32;
846 match change.tag() {
847 similar::ChangeTag::Equal => {
848 old_row += line_count;
849 new_row += line_count;
850 }
851 similar::ChangeTag::Delete => {
852 let old_end_row = old_row + line_count - 1;
853 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
854
855 if let Some((_, last_deleted_row_range)) =
856 deleted_row_ranges.last_mut()
857 {
858 if *last_deleted_row_range.end() + 1 == old_row {
859 *last_deleted_row_range =
860 *last_deleted_row_range.start()..=old_end_row;
861 } else {
862 deleted_row_ranges.push((new_row, old_row..=old_end_row));
863 }
864 } else {
865 deleted_row_ranges.push((new_row, old_row..=old_end_row));
866 }
867
868 old_row += line_count;
869 }
870 similar::ChangeTag::Insert => {
871 let new_end_row = new_row + line_count - 1;
872 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
873 let end = new_snapshot.anchor_before(Point::new(
874 new_end_row,
875 new_snapshot.line_len(MultiBufferRow(new_end_row)),
876 ));
877 inserted_row_ranges.push(start..end);
878 new_row += line_count;
879 }
880 }
881 }
882
883 (deleted_row_ranges, inserted_row_ranges)
884 })
885 .await;
886
887 codegen
888 .update(&mut cx, |codegen, cx| {
889 codegen.diff.deleted_row_ranges = deleted_row_ranges;
890 codegen.diff.inserted_row_ranges = inserted_row_ranges;
891 cx.notify();
892 })
893 .ok();
894 })
895 }
896}
897
/// Notifications emitted by codegen models.
#[derive(Clone, Copy, Debug)]
pub enum CodegenEvent {
    /// A generation ended: it completed, failed, or was stopped.
    Finished,
    /// The generated transformation was undone in the buffer.
    Undone,
}
903
/// Stream adapter that cleans up model output before it is diffed into the buffer.
///
/// It drops a leading code-fence line (and, once the stream has ended, a closing fence on
/// the final line) and removes `<|CURSOR|>` markers. Trailing text that might still turn
/// into one of those spans is buffered until more input arrives.
struct StripInvalidSpans<T> {
    /// The wrapped chunk stream.
    stream: T,
    /// Set once `stream` has yielded `None`.
    stream_done: bool,
    /// Text received from `stream` but not yet emitted.
    buffer: String,
    /// Whether processing is still on the first line of the response.
    first_line: bool,
    /// Whether a newline must be emitted before the next line's text.
    line_end: bool,
    /// Whether the response opened with a code fence that was dropped.
    starts_with_code_block: bool,
}
912
913impl<T> StripInvalidSpans<T>
914where
915 T: Stream<Item = Result<String>>,
916{
917 fn new(stream: T) -> Self {
918 Self {
919 stream,
920 stream_done: false,
921 buffer: String::new(),
922 first_line: true,
923 line_end: false,
924 starts_with_code_block: false,
925 }
926 }
927}
928
929impl<T> Stream for StripInvalidSpans<T>
930where
931 T: Stream<Item = Result<String>>,
932{
933 type Item = Result<String>;
934
935 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
936 const CODE_BLOCK_DELIMITER: &str = "```";
937 const CURSOR_SPAN: &str = "<|CURSOR|>";
938
939 let this = unsafe { self.get_unchecked_mut() };
940 loop {
941 if !this.stream_done {
942 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
943 match stream.as_mut().poll_next(cx) {
944 Poll::Ready(Some(Ok(chunk))) => {
945 this.buffer.push_str(&chunk);
946 }
947 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
948 Poll::Ready(None) => {
949 this.stream_done = true;
950 }
951 Poll::Pending => return Poll::Pending,
952 }
953 }
954
955 let mut chunk = String::new();
956 let mut consumed = 0;
957 if !this.buffer.is_empty() {
958 let mut lines = this.buffer.split('\n').enumerate().peekable();
959 while let Some((line_ix, line)) = lines.next() {
960 if line_ix > 0 {
961 this.first_line = false;
962 }
963
964 if this.first_line {
965 let trimmed_line = line.trim();
966 if lines.peek().is_some() {
967 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
968 consumed += line.len() + 1;
969 this.starts_with_code_block = true;
970 continue;
971 }
972 } else if trimmed_line.is_empty()
973 || prefixes(CODE_BLOCK_DELIMITER)
974 .any(|prefix| trimmed_line.starts_with(prefix))
975 {
976 break;
977 }
978 }
979
980 let line_without_cursor = line.replace(CURSOR_SPAN, "");
981 if lines.peek().is_some() {
982 if this.line_end {
983 chunk.push('\n');
984 }
985
986 chunk.push_str(&line_without_cursor);
987 this.line_end = true;
988 consumed += line.len() + 1;
989 } else if this.stream_done {
990 if !this.starts_with_code_block
991 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
992 {
993 if this.line_end {
994 chunk.push('\n');
995 }
996
997 chunk.push_str(&line);
998 }
999
1000 consumed += line.len();
1001 } else {
1002 let trimmed_line = line.trim();
1003 if trimmed_line.is_empty()
1004 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1005 || prefixes(CODE_BLOCK_DELIMITER)
1006 .any(|prefix| trimmed_line.ends_with(prefix))
1007 {
1008 break;
1009 } else {
1010 if this.line_end {
1011 chunk.push('\n');
1012 this.line_end = false;
1013 }
1014
1015 chunk.push_str(&line_without_cursor);
1016 consumed += line.len();
1017 }
1018 }
1019 }
1020 }
1021
1022 this.buffer = this.buffer.split_off(consumed);
1023 if !chunk.is_empty() {
1024 return Poll::Ready(Some(Ok(chunk)));
1025 } else if this.stream_done {
1026 return Poll::Ready(None);
1027 }
1028 }
1029 }
1030}
1031
/// Returns every proper, non-empty prefix of `text`, shortest first.
///
/// `text` itself is excluded. Every prefix ends on a `char` boundary, so non-ASCII input is
/// handled safely. Empty or single-character input yields nothing; the previous
/// `text.len() - 1` form underflowed on empty input.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices().skip(1).map(move |(ix, _)| &text[..ix])
}
1035
1036#[derive(Default)]
1037pub struct Diff {
1038 pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
1039 pub inserted_row_ranges: Vec<Range<Anchor>>,
1040}
1041
1042impl Diff {
1043 fn is_empty(&self) -> bool {
1044 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1045 }
1046}
1047
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        stream::{self},
        Stream,
    };
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{
        language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
        Point,
    };
    use language_model::LanguageModelRegistry;
    use rand::prelude::*;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    /// Rewrites a selected `for` loop and checks that randomly chunked model
    /// output is merged into the buffer with the surrounding indentation.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let codegen = new_codegen(&buffer, range, true, cx);

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        send_in_random_chunks(
            concat!(
                "    let mut x = 0;\n",
                "    while x < 10 {\n",
                "        x += 1;\n",
                "    }",
            ),
            &chunks_tx,
            &mut rng,
            cx,
        );
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Inserts at the end of a partially typed line (`le|`) and checks that
    /// the generated continuation lines are indented to match.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let codegen = new_codegen(&buffer, range, true, cx);

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        cx.background_executor.run_until_parked();

        send_in_random_chunks(
            concat!(
                "t mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &chunks_tx,
            &mut rng,
            cx,
        );
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Inserts inside a whitespace-only line whose cursor sits before the
    /// expected indentation level and checks the result is re-indented.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let codegen = new_codegen(&buffer, range, true, cx);

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        cx.background_executor.run_until_parked();

        send_in_random_chunks(
            concat!(
                "let mut x = 0;\n",
                "while x < 10 {\n",
                "    x += 1;\n",
                "}", //
            ),
            &chunks_tx,
            &mut rng,
            cx,
        );
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Replaces a tab-indented selection in a buffer with no language attached
    /// and checks the generated tab indentation is kept as sent.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new_model(|cx| Buffer::local(text, cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let codegen = new_codegen(&buffer, range, true, cx);

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    /// An inactive alternative must leave the buffer untouched. Activating it
    /// applies the generated text, and deactivating it restores the original.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer =
            cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let codegen = new_codegen(&buffer, range, false, cx);

        let chunks_tx = simulate_response_stream(codegen.clone(), cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    /// `StripInvalidSpans` must drop code fences and cursor markers no matter
    /// how the input is split into chunks.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        /// Runs `text` through `StripInvalidSpans` at every chunk size from 1
        /// to `text.len()` and asserts the collected output equals `expected_text`.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        /// Splits `text` into successful stream items of `size` chars each.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    /// Installs the globals every codegen test relies on: the test
    /// language-model registry, a test settings store, and language settings.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
        cx.update(language_settings::init);
    }

    /// Builds a `CodegenAlternative` over `range` of `buffer`. It has no
    /// context store or telemetry and uses a default prompt builder.
    /// `active` is forwarded as the third constructor argument (`false` makes
    /// the alternative inactive, as `test_inactive_codegen_alternative` shows).
    fn new_codegen(
        buffer: &Model<MultiBuffer>,
        range: Range<Anchor>,
        active: bool,
        cx: &mut TestAppContext,
    ) -> Model<CodegenAlternative> {
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        cx.new_model(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range,
                active,
                None,
                None,
                prompt_builder,
                cx,
            )
        })
    }

    /// Sends `text` through `chunks_tx` in randomly sized pieces of 1 to 10
    /// bytes. After each piece it runs the executor until parked, so the
    /// codegen processes partial output. `text` must be ASCII, because pieces
    /// are split at byte offsets.
    fn send_in_random_chunks(
        mut text: &str,
        chunks_tx: &mpsc::UnboundedSender<String>,
        rng: &mut StdRng,
        cx: &mut TestAppContext,
    ) {
        while !text.is_empty() {
            let max_len = cmp::min(text.len(), 10);
            let len = rng.gen_range(1..=max_len);
            let (chunk, suffix) = text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            text = suffix;
            cx.background_executor.run_until_parked();
        }
    }

    /// Starts `codegen` on a fake model response. The response's chunks are
    /// whatever the test pushes into the returned sender, and dropping the
    /// sender ends the stream.
    fn simulate_response_stream(
        codegen: Model<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                String::new(),
                None,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                })),
                cx,
            );
        });
        chunks_tx
    }

    /// A minimal Rust language with just enough indent queries for the
    /// autoindent tests.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}