1use crate::context::attach_context_to_message;
2use crate::context_store::ContextStore;
3use crate::inline_prompt_editor::CodegenStatus;
4use crate::{
5 prompts::PromptBuilder,
6 streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff},
7};
8use anyhow::{Context as _, Result};
9use client::telemetry::Telemetry;
10use collections::HashSet;
11use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
12use futures::{channel::mpsc, future::LocalBoxFuture, join, SinkExt, Stream, StreamExt};
13use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscription, Task};
14use language::{Buffer, IndentKind, Point, TransactionId};
15use language_model::{
16 LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
17 LanguageModelTextStream, Role,
18};
19use language_models::report_assistant_event;
20use multi_buffer::MultiBufferRow;
21use parking_lot::Mutex;
22use rope::Rope;
23use smol::future::FutureExt;
24use std::{
25 cmp,
26 future::Future,
27 iter,
28 ops::{Range, RangeInclusive},
29 pin::Pin,
30 sync::Arc,
31 task::{self, Poll},
32 time::Instant,
33};
34use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
35
/// Coordinates one or more [`CodegenAlternative`]s that all target the same range of a
/// [`MultiBuffer`], and tracks which of them is currently active.
pub struct BufferCodegen {
    /// Every alternative created so far; `active_alternative` indexes into this list.
    alternatives: Vec<Model<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative whose state is exposed by the accessors.
    pub active_alternative: usize,
    /// Indices that have been passed to `activate` at least once.
    seen_alternatives: HashSet<usize>,
    /// Subscriptions to the active alternative; replaced on every activation.
    subscriptions: Vec<Subscription>,
    /// Multi-buffer handed to every alternative this codegen creates.
    buffer: Model<MultiBuffer>,
    /// Range handed to every alternative this codegen creates.
    range: Range<Anchor>,
    /// Transaction that `undo` reverts (at most once) after undoing the active alternative.
    initial_transaction_id: Option<TransactionId>,
    /// Context store handed to every alternative this codegen creates.
    context_store: Model<ContextStore>,
    /// Telemetry handle handed to every alternative this codegen creates.
    telemetry: Arc<Telemetry>,
    /// Prompt builder handed to every alternative this codegen creates.
    builder: Arc<PromptBuilder>,
    /// `true` when `range` was empty in the buffer snapshot taken at construction time.
    pub is_insertion: bool,
}
49
50impl BufferCodegen {
51 pub fn new(
52 buffer: Model<MultiBuffer>,
53 range: Range<Anchor>,
54 initial_transaction_id: Option<TransactionId>,
55 context_store: Model<ContextStore>,
56 telemetry: Arc<Telemetry>,
57 builder: Arc<PromptBuilder>,
58 cx: &mut ModelContext<Self>,
59 ) -> Self {
60 let codegen = cx.new_model(|cx| {
61 CodegenAlternative::new(
62 buffer.clone(),
63 range.clone(),
64 false,
65 Some(context_store.clone()),
66 Some(telemetry.clone()),
67 builder.clone(),
68 cx,
69 )
70 });
71 let mut this = Self {
72 is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
73 alternatives: vec![codegen],
74 active_alternative: 0,
75 seen_alternatives: HashSet::default(),
76 subscriptions: Vec::new(),
77 buffer,
78 range,
79 initial_transaction_id,
80 context_store,
81 telemetry,
82 builder,
83 };
84 this.activate(0, cx);
85 this
86 }
87
88 fn subscribe_to_alternative(&mut self, cx: &mut ModelContext<Self>) {
89 let codegen = self.active_alternative().clone();
90 self.subscriptions.clear();
91 self.subscriptions
92 .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
93 self.subscriptions
94 .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
95 }
96
97 pub fn active_alternative(&self) -> &Model<CodegenAlternative> {
98 &self.alternatives[self.active_alternative]
99 }
100
101 pub fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus {
102 &self.active_alternative().read(cx).status
103 }
104
105 pub fn alternative_count(&self, cx: &AppContext) -> usize {
106 LanguageModelRegistry::read_global(cx)
107 .inline_alternative_models()
108 .len()
109 + 1
110 }
111
112 pub fn cycle_prev(&mut self, cx: &mut ModelContext<Self>) {
113 let next_active_ix = if self.active_alternative == 0 {
114 self.alternatives.len() - 1
115 } else {
116 self.active_alternative - 1
117 };
118 self.activate(next_active_ix, cx);
119 }
120
121 pub fn cycle_next(&mut self, cx: &mut ModelContext<Self>) {
122 let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
123 self.activate(next_active_ix, cx);
124 }
125
126 fn activate(&mut self, index: usize, cx: &mut ModelContext<Self>) {
127 self.active_alternative()
128 .update(cx, |codegen, cx| codegen.set_active(false, cx));
129 self.seen_alternatives.insert(index);
130 self.active_alternative = index;
131 self.active_alternative()
132 .update(cx, |codegen, cx| codegen.set_active(true, cx));
133 self.subscribe_to_alternative(cx);
134 cx.notify();
135 }
136
137 pub fn start(&mut self, user_prompt: String, cx: &mut ModelContext<Self>) -> Result<()> {
138 let alternative_models = LanguageModelRegistry::read_global(cx)
139 .inline_alternative_models()
140 .to_vec();
141
142 self.active_alternative()
143 .update(cx, |alternative, cx| alternative.undo(cx));
144 self.activate(0, cx);
145 self.alternatives.truncate(1);
146
147 for _ in 0..alternative_models.len() {
148 self.alternatives.push(cx.new_model(|cx| {
149 CodegenAlternative::new(
150 self.buffer.clone(),
151 self.range.clone(),
152 false,
153 Some(self.context_store.clone()),
154 Some(self.telemetry.clone()),
155 self.builder.clone(),
156 cx,
157 )
158 }));
159 }
160
161 let primary_model = LanguageModelRegistry::read_global(cx)
162 .active_model()
163 .context("no active model")?;
164
165 for (model, alternative) in iter::once(primary_model)
166 .chain(alternative_models)
167 .zip(&self.alternatives)
168 {
169 alternative.update(cx, |alternative, cx| {
170 alternative.start(user_prompt.clone(), model.clone(), cx)
171 })?;
172 }
173
174 Ok(())
175 }
176
177 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
178 for codegen in &self.alternatives {
179 codegen.update(cx, |codegen, cx| codegen.stop(cx));
180 }
181 }
182
183 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
184 self.active_alternative()
185 .update(cx, |codegen, cx| codegen.undo(cx));
186
187 self.buffer.update(cx, |buffer, cx| {
188 if let Some(transaction_id) = self.initial_transaction_id.take() {
189 buffer.undo_transaction(transaction_id, cx);
190 buffer.refresh_preview(cx);
191 }
192 });
193 }
194
195 pub fn buffer(&self, cx: &AppContext) -> Model<MultiBuffer> {
196 self.active_alternative().read(cx).buffer.clone()
197 }
198
199 pub fn old_buffer(&self, cx: &AppContext) -> Model<Buffer> {
200 self.active_alternative().read(cx).old_buffer.clone()
201 }
202
203 pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot {
204 self.active_alternative().read(cx).snapshot.clone()
205 }
206
207 pub fn edit_position(&self, cx: &AppContext) -> Option<Anchor> {
208 self.active_alternative().read(cx).edit_position
209 }
210
211 pub fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff {
212 &self.active_alternative().read(cx).diff
213 }
214
215 pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range<Anchor>] {
216 self.active_alternative().read(cx).last_equal_ranges()
217 }
218}
219
// `BufferCodegen` emits `CodegenEvent`s; it re-emits those of its active alternative.
impl EventEmitter<CodegenEvent> for BufferCodegen {}
221
/// A single generation attempt against a range of a [`MultiBuffer`].
///
/// The alternative records every edit produced while streaming so the edits can be
/// removed from and re-applied to the buffer when it is deactivated or activated.
pub struct CodegenAlternative {
    /// Multi-buffer that generated edits are applied to.
    buffer: Model<MultiBuffer>,
    /// Standalone copy (text, line ending, language) of a buffer underlying `range`,
    /// created at construction time.
    old_buffer: Model<Buffer>,
    /// Snapshot of `buffer` taken at construction; edit anchors are created against it.
    snapshot: MultiBufferSnapshot,
    /// Set to the start of `range` when a generation starts and moved after each batch
    /// of streamed edits.
    edit_position: Option<Anchor>,
    /// Range of `buffer` being transformed.
    range: Range<Anchor>,
    /// Ranges reported as unchanged (`CharOperation::Keep`) by the most recent batch.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// Transaction into which all edits applied by this alternative are merged, if any
    /// are currently applied.
    transformation_transaction_id: Option<TransactionId>,
    /// Current state of the generation.
    status: CodegenStatus,
    /// Task driving the current generation; replaced by a ready task when stopped.
    generation: Task<()>,
    /// Row-level diff between the original and the generated text.
    diff: Diff,
    /// Optional context attached to the request message.
    context_store: Option<Model<ContextStore>>,
    /// Optional telemetry handle passed to `report_assistant_event`.
    telemetry: Option<Arc<Telemetry>>,
    /// Keeps the subscription to `buffer` events alive.
    _subscription: gpui::Subscription,
    /// Builds the inline transformation prompt.
    builder: Arc<PromptBuilder>,
    /// Whether streamed edits are applied to `buffer` as they arrive.
    active: bool,
    /// Edits recorded from the stream; re-applied when the alternative is activated.
    edits: Vec<(Range<Anchor>, String)>,
    /// Line operations from the most recent batch of the stream.
    line_operations: Vec<LineOperation>,
    /// The request built by the most recent `start` that issued one.
    request: Option<LanguageModelRequest>,
    /// Seconds the last generation took, set when it finishes.
    elapsed_time: Option<f64>,
    /// Concatenation of the chunks received during the last generation, set when it finishes.
    completion: Option<String>,
    /// Message id reported by the text stream, if any.
    pub message_id: Option<String>,
}
245
// `CodegenAlternative` emits `CodegenEvent::Finished` and `CodegenEvent::Undone`.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
247
248impl CodegenAlternative {
249 pub fn new(
250 buffer: Model<MultiBuffer>,
251 range: Range<Anchor>,
252 active: bool,
253 context_store: Option<Model<ContextStore>>,
254 telemetry: Option<Arc<Telemetry>>,
255 builder: Arc<PromptBuilder>,
256 cx: &mut ModelContext<Self>,
257 ) -> Self {
258 let snapshot = buffer.read(cx).snapshot(cx);
259
260 let (old_buffer, _, _) = buffer
261 .read(cx)
262 .range_to_buffer_ranges(range.clone(), cx)
263 .pop()
264 .unwrap();
265 let old_buffer = cx.new_model(|cx| {
266 let old_buffer = old_buffer.read(cx);
267 let text = old_buffer.as_rope().clone();
268 let line_ending = old_buffer.line_ending();
269 let language = old_buffer.language().cloned();
270 let language_registry = old_buffer.language_registry();
271
272 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
273 buffer.set_language(language, cx);
274 if let Some(language_registry) = language_registry {
275 buffer.set_language_registry(language_registry)
276 }
277 buffer
278 });
279
280 Self {
281 buffer: buffer.clone(),
282 old_buffer,
283 edit_position: None,
284 message_id: None,
285 snapshot,
286 last_equal_ranges: Default::default(),
287 transformation_transaction_id: None,
288 status: CodegenStatus::Idle,
289 generation: Task::ready(()),
290 diff: Diff::default(),
291 context_store,
292 telemetry,
293 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
294 builder,
295 active,
296 edits: Vec::new(),
297 line_operations: Vec::new(),
298 range,
299 request: None,
300 elapsed_time: None,
301 completion: None,
302 }
303 }
304
305 pub fn set_active(&mut self, active: bool, cx: &mut ModelContext<Self>) {
306 if active != self.active {
307 self.active = active;
308
309 if self.active {
310 let edits = self.edits.clone();
311 self.apply_edits(edits, cx);
312 if matches!(self.status, CodegenStatus::Pending) {
313 let line_operations = self.line_operations.clone();
314 self.reapply_line_based_diff(line_operations, cx);
315 } else {
316 self.reapply_batch_diff(cx).detach();
317 }
318 } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
319 self.buffer.update(cx, |buffer, cx| {
320 buffer.undo_transaction(transaction_id, cx);
321 buffer.forget_transaction(transaction_id, cx);
322 });
323 }
324 }
325 }
326
327 fn handle_buffer_event(
328 &mut self,
329 _buffer: Model<MultiBuffer>,
330 event: &multi_buffer::Event,
331 cx: &mut ModelContext<Self>,
332 ) {
333 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
334 if self.transformation_transaction_id == Some(*transaction_id) {
335 self.transformation_transaction_id = None;
336 self.generation = Task::ready(());
337 cx.emit(CodegenEvent::Undone);
338 }
339 }
340 }
341
342 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
343 &self.last_equal_ranges
344 }
345
346 pub fn start(
347 &mut self,
348 user_prompt: String,
349 model: Arc<dyn LanguageModel>,
350 cx: &mut ModelContext<Self>,
351 ) -> Result<()> {
352 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
353 self.buffer.update(cx, |buffer, cx| {
354 buffer.undo_transaction(transformation_transaction_id, cx);
355 });
356 }
357
358 self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
359
360 let api_key = model.api_key(cx);
361 let telemetry_id = model.telemetry_id();
362 let provider_id = model.provider_id();
363 let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
364 if user_prompt.trim().to_lowercase() == "delete" {
365 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
366 } else {
367 let request = self.build_request(user_prompt, cx)?;
368 self.request = Some(request.clone());
369
370 cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await })
371 .boxed_local()
372 };
373 self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
374 Ok(())
375 }
376
377 fn build_request(
378 &self,
379 user_prompt: String,
380 cx: &mut AppContext,
381 ) -> Result<LanguageModelRequest> {
382 let buffer = self.buffer.read(cx).snapshot(cx);
383 let language = buffer.language_at(self.range.start);
384 let language_name = if let Some(language) = language.as_ref() {
385 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
386 None
387 } else {
388 Some(language.name())
389 }
390 } else {
391 None
392 };
393
394 let language_name = language_name.as_ref();
395 let start = buffer.point_to_buffer_offset(self.range.start);
396 let end = buffer.point_to_buffer_offset(self.range.end);
397 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
398 let (start_buffer, start_buffer_offset) = start;
399 let (end_buffer, end_buffer_offset) = end;
400 if start_buffer.remote_id() == end_buffer.remote_id() {
401 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
402 } else {
403 return Err(anyhow::anyhow!("invalid transformation range"));
404 }
405 } else {
406 return Err(anyhow::anyhow!("invalid transformation range"));
407 };
408
409 let prompt = self
410 .builder
411 .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
412 .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
413
414 let mut request_message = LanguageModelRequestMessage {
415 role: Role::User,
416 content: Vec::new(),
417 cache: false,
418 };
419
420 if let Some(context_store) = &self.context_store {
421 let context = context_store.update(cx, |this, _cx| this.context().clone());
422 attach_context_to_message(&mut request_message, context);
423 }
424
425 request_message.content.push(prompt.into());
426
427 Ok(LanguageModelRequest {
428 tools: Vec::new(),
429 stop: Vec::new(),
430 temperature: None,
431 messages: vec![request_message],
432 })
433 }
434
    /// Consumes `stream` and turns the streamed text into edits against the buffer.
    ///
    /// Resets `diff`, sets the status to `Pending` and replaces `generation` with a task that:
    ///
    /// 1. awaits `stream` and, in a task spawned on the background executor, strips invalid
    ///    spans, re-indents each line and feeds the text through a `StreamingDiff` and a
    ///    `LineDiff` computed against the originally selected text;
    /// 2. converts every batch of char operations received from that task into anchor
    ///    edits, records them and, when this alternative is active, applies them;
    /// 3. once the stream ends, recomputes the diff with `reapply_batch_diff`, stores the
    ///    final status, elapsed time and completion text, and emits
    ///    `CodegenEvent::Finished`.
    ///
    /// The `model_*` arguments are only forwarded to `report_assistant_event`.
    pub fn handle_stream(
        &mut self,
        model_telemetry_id: String,
        model_provider_id: String,
        model_api_key: Option<String>,
        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
        cx: &mut ModelContext<Self>,
    ) {
        let start_time = Instant::now();
        let snapshot = self.snapshot.clone();
        let selected_text = snapshot
            .text_for_range(self.range.start..self.range.end)
            .collect::<Rope>();

        let selection_start = self.range.start.to_point(&snapshot);

        // Start with the indentation of the first line in the selection
        let mut suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..=selection_start.row, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        // If the first line in the selection does not have indentation, check the following lines
        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                // Prefer tabs if a line in the selection uses tabs as indentation
                if line_indent.kind == IndentKind::Tab {
                    suggested_line_indent.kind = IndentKind::Tab;
                    break;
                }
            }
        }

        let http_client = cx.http_client().clone();
        let telemetry = self.telemetry.clone();
        // Language name of the first buffer range covered by the selection (telemetry only).
        let language_name = {
            let multibuffer = self.buffer.read(cx);
            let ranges = multibuffer.range_to_buffer_ranges(self.range.clone(), cx);
            ranges
                .first()
                .and_then(|(buffer, _, _)| buffer.read(cx).language())
                .map(|language| language.name())
        };

        self.diff = Diff::default();
        self.status = CodegenStatus::Pending;
        // Offset (in `snapshot`) up to which the original text has been consumed by the diff.
        let mut edit_start = self.range.start.to_offset(&snapshot);
        // Shared between the background diffing task (writer) and the completion handler (reader).
        let completion = Arc::new(Mutex::new(String::new()));
        let completion_clone = completion.clone();

        self.generation = cx.spawn(|codegen, mut cx| {
            async move {
                let stream = stream.await;
                let message_id = stream
                    .as_ref()
                    .ok()
                    .and_then(|stream| stream.message_id.clone());
                let generate = async {
                    // Capacity 1: the diffing task waits until the previous batch was taken.
                    let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
                    let executor = cx.background_executor().clone();
                    let message_id = message_id.clone();
                    let line_based_stream_diff: Task<anyhow::Result<()>> =
                        cx.background_executor().spawn(async move {
                            let mut response_latency = None;
                            let request_start = Instant::now();
                            let diff = async {
                                let chunks = StripInvalidSpans::new(stream?.stream);
                                futures::pin_mut!(chunks);
                                let mut diff = StreamingDiff::new(selected_text.to_string());
                                let mut line_diff = LineDiff::default();

                                // Text of the line currently being assembled.
                                let mut new_text = String::new();
                                // Indentation (byte index of the first non-whitespace char)
                                // of the first generated line that had any.
                                let mut base_indent = None;
                                // Indentation of the current line, once it is known.
                                let mut line_indent = None;
                                let mut first_line = true;

                                while let Some(chunk) = chunks.next().await {
                                    // Latency is measured up to the first item of the stream.
                                    if response_latency.is_none() {
                                        response_latency = Some(request_start.elapsed());
                                    }
                                    let chunk = chunk?;
                                    completion_clone.lock().push_str(&chunk);

                                    let mut lines = chunk.split('\n').peekable();
                                    while let Some(line) = lines.next() {
                                        new_text.push_str(line);
                                        if line_indent.is_none() {
                                            if let Some(non_whitespace_ch_ix) =
                                                new_text.find(|ch: char| !ch.is_whitespace())
                                            {
                                                line_indent = Some(non_whitespace_ch_ix);
                                                base_indent = base_indent.or(line_indent);

                                                // Re-base the generated indentation onto the
                                                // indentation suggested for the selection.
                                                let line_indent = line_indent.unwrap();
                                                let base_indent = base_indent.unwrap();
                                                let indent_delta =
                                                    line_indent as i32 - base_indent as i32;
                                                let mut corrected_indent_len = cmp::max(
                                                    0,
                                                    suggested_line_indent.len as i32 + indent_delta,
                                                )
                                                    as usize;
                                                // On the first line, the column of the selection
                                                // start is subtracted from the indentation.
                                                if first_line {
                                                    corrected_indent_len = corrected_indent_len
                                                        .saturating_sub(
                                                            selection_start.column as usize,
                                                        );
                                                }

                                                let indent_char = suggested_line_indent.char();
                                                let mut indent_buffer = [0; 4];
                                                let indent_str =
                                                    indent_char.encode_utf8(&mut indent_buffer);
                                                new_text.replace_range(
                                                    ..line_indent,
                                                    &indent_str.repeat(corrected_indent_len),
                                                );
                                            }
                                        }

                                        // Only flush text once the line's indentation is known.
                                        if line_indent.is_some() {
                                            let char_ops = diff.push_new(&new_text);
                                            line_diff
                                                .push_char_operations(&char_ops, &selected_text);
                                            diff_tx
                                                .send((char_ops, line_diff.line_operations()))
                                                .await?;
                                            new_text.clear();
                                        }

                                        // Another piece follows, so this piece ended with '\n'.
                                        if lines.peek().is_some() {
                                            let char_ops = diff.push_new("\n");
                                            line_diff
                                                .push_char_operations(&char_ops, &selected_text);
                                            diff_tx
                                                .send((char_ops, line_diff.line_operations()))
                                                .await?;
                                            if line_indent.is_none() {
                                                // Don't write out the leading indentation in empty lines on the next line
                                                // This is the case where the above if statement didn't clear the buffer
                                                new_text.clear();
                                            }
                                            line_indent = None;
                                            first_line = false;
                                        }
                                    }
                                }

                                // Flush whatever is left and close both diffs.
                                let mut char_ops = diff.push_new(&new_text);
                                char_ops.extend(diff.finish());
                                line_diff.push_char_operations(&char_ops, &selected_text);
                                line_diff.finish(&selected_text);
                                diff_tx
                                    .send((char_ops, line_diff.line_operations()))
                                    .await?;

                                anyhow::Ok(())
                            };

                            let result = diff.await;

                            // Report the response whether or not diffing succeeded.
                            let error_message =
                                result.as_ref().err().map(|error| error.to_string());
                            report_assistant_event(
                                AssistantEvent {
                                    conversation_id: None,
                                    message_id,
                                    kind: AssistantKind::Inline,
                                    phase: AssistantPhase::Response,
                                    model: model_telemetry_id,
                                    model_provider: model_provider_id.to_string(),
                                    response_latency,
                                    error_message,
                                    language_name: language_name.map(|name| name.to_proto()),
                                },
                                telemetry,
                                http_client,
                                model_api_key,
                                &executor,
                            );

                            result?;
                            Ok(())
                        });

                    while let Some((char_ops, line_ops)) = diff_rx.next().await {
                        codegen.update(&mut cx, |codegen, cx| {
                            codegen.last_equal_ranges.clear();

                            // Translate char operations into edits anchored in the original
                            // snapshot; `edit_start` advances over deleted and kept text.
                            let edits = char_ops
                                .into_iter()
                                .filter_map(|operation| match operation {
                                    CharOperation::Insert { text } => {
                                        let edit_start = snapshot.anchor_after(edit_start);
                                        Some((edit_start..edit_start, text))
                                    }
                                    CharOperation::Delete { bytes } => {
                                        let edit_end = edit_start + bytes;
                                        let edit_range = snapshot.anchor_after(edit_start)
                                            ..snapshot.anchor_before(edit_end);
                                        edit_start = edit_end;
                                        Some((edit_range, String::new()))
                                    }
                                    CharOperation::Keep { bytes } => {
                                        let edit_end = edit_start + bytes;
                                        let edit_range = snapshot.anchor_after(edit_start)
                                            ..snapshot.anchor_before(edit_end);
                                        edit_start = edit_end;
                                        codegen.last_equal_ranges.push(edit_range);
                                        None
                                    }
                                })
                                .collect::<Vec<_>>();

                            // Inactive alternatives only record the edits; they are applied
                            // when the alternative is activated.
                            if codegen.active {
                                codegen.apply_edits(edits.iter().cloned(), cx);
                                codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
                            }
                            codegen.edits.extend(edits);
                            codegen.line_operations = line_ops;
                            codegen.edit_position = Some(snapshot.anchor_after(edit_start));

                            cx.notify();
                        })?;
                    }

                    // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
                    // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
                    // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
                    let batch_diff_task =
                        codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
                    let (line_based_stream_diff, ()) =
                        join!(line_based_stream_diff, batch_diff_task);
                    line_based_stream_diff?;

                    anyhow::Ok(())
                };

                let result = generate.await;
                let elapsed_time = start_time.elapsed().as_secs_f64();

                // Record the outcome; `.ok()` ignores a failed update.
                codegen
                    .update(&mut cx, |this, cx| {
                        this.message_id = message_id;
                        this.last_equal_ranges.clear();
                        if let Err(error) = result {
                            this.status = CodegenStatus::Error(error);
                        } else {
                            this.status = CodegenStatus::Done;
                        }
                        this.elapsed_time = Some(elapsed_time);
                        this.completion = Some(completion.lock().clone());
                        cx.emit(CodegenEvent::Finished);
                        cx.notify();
                    })
                    .ok();
            }
        });
        cx.notify();
    }
697
698 pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
699 self.last_equal_ranges.clear();
700 if self.diff.is_empty() {
701 self.status = CodegenStatus::Idle;
702 } else {
703 self.status = CodegenStatus::Done;
704 }
705 self.generation = Task::ready(());
706 cx.emit(CodegenEvent::Finished);
707 cx.notify();
708 }
709
710 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
711 self.buffer.update(cx, |buffer, cx| {
712 if let Some(transaction_id) = self.transformation_transaction_id.take() {
713 buffer.undo_transaction(transaction_id, cx);
714 buffer.refresh_preview(cx);
715 }
716 });
717 }
718
719 fn apply_edits(
720 &mut self,
721 edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
722 cx: &mut ModelContext<CodegenAlternative>,
723 ) {
724 let transaction = self.buffer.update(cx, |buffer, cx| {
725 // Avoid grouping assistant edits with user edits.
726 buffer.finalize_last_transaction(cx);
727 buffer.start_transaction(cx);
728 buffer.edit(edits, None, cx);
729 buffer.end_transaction(cx)
730 });
731
732 if let Some(transaction) = transaction {
733 if let Some(first_transaction) = self.transformation_transaction_id {
734 // Group all assistant edits into the first transaction.
735 self.buffer.update(cx, |buffer, cx| {
736 buffer.merge_transactions(transaction, first_transaction, cx)
737 });
738 } else {
739 self.transformation_transaction_id = Some(transaction);
740 self.buffer
741 .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
742 }
743 }
744 }
745
746 fn reapply_line_based_diff(
747 &mut self,
748 line_operations: impl IntoIterator<Item = LineOperation>,
749 cx: &mut ModelContext<Self>,
750 ) {
751 let old_snapshot = self.snapshot.clone();
752 let old_range = self.range.to_point(&old_snapshot);
753 let new_snapshot = self.buffer.read(cx).snapshot(cx);
754 let new_range = self.range.to_point(&new_snapshot);
755
756 let mut old_row = old_range.start.row;
757 let mut new_row = new_range.start.row;
758
759 self.diff.deleted_row_ranges.clear();
760 self.diff.inserted_row_ranges.clear();
761 for operation in line_operations {
762 match operation {
763 LineOperation::Keep { lines } => {
764 old_row += lines;
765 new_row += lines;
766 }
767 LineOperation::Delete { lines } => {
768 let old_end_row = old_row + lines - 1;
769 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
770
771 if let Some((_, last_deleted_row_range)) =
772 self.diff.deleted_row_ranges.last_mut()
773 {
774 if *last_deleted_row_range.end() + 1 == old_row {
775 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
776 } else {
777 self.diff
778 .deleted_row_ranges
779 .push((new_row, old_row..=old_end_row));
780 }
781 } else {
782 self.diff
783 .deleted_row_ranges
784 .push((new_row, old_row..=old_end_row));
785 }
786
787 old_row += lines;
788 }
789 LineOperation::Insert { lines } => {
790 let new_end_row = new_row + lines - 1;
791 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
792 let end = new_snapshot.anchor_before(Point::new(
793 new_end_row,
794 new_snapshot.line_len(MultiBufferRow(new_end_row)),
795 ));
796 self.diff.inserted_row_ranges.push(start..end);
797 new_row += lines;
798 }
799 }
800
801 cx.notify();
802 }
803 }
804
805 fn reapply_batch_diff(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
806 let old_snapshot = self.snapshot.clone();
807 let old_range = self.range.to_point(&old_snapshot);
808 let new_snapshot = self.buffer.read(cx).snapshot(cx);
809 let new_range = self.range.to_point(&new_snapshot);
810
811 cx.spawn(|codegen, mut cx| async move {
812 let (deleted_row_ranges, inserted_row_ranges) = cx
813 .background_executor()
814 .spawn(async move {
815 let old_text = old_snapshot
816 .text_for_range(
817 Point::new(old_range.start.row, 0)
818 ..Point::new(
819 old_range.end.row,
820 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
821 ),
822 )
823 .collect::<String>();
824 let new_text = new_snapshot
825 .text_for_range(
826 Point::new(new_range.start.row, 0)
827 ..Point::new(
828 new_range.end.row,
829 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
830 ),
831 )
832 .collect::<String>();
833
834 let mut old_row = old_range.start.row;
835 let mut new_row = new_range.start.row;
836 let batch_diff =
837 similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());
838
839 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
840 let mut inserted_row_ranges = Vec::new();
841 for change in batch_diff.iter_all_changes() {
842 let line_count = change.value().lines().count() as u32;
843 match change.tag() {
844 similar::ChangeTag::Equal => {
845 old_row += line_count;
846 new_row += line_count;
847 }
848 similar::ChangeTag::Delete => {
849 let old_end_row = old_row + line_count - 1;
850 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
851
852 if let Some((_, last_deleted_row_range)) =
853 deleted_row_ranges.last_mut()
854 {
855 if *last_deleted_row_range.end() + 1 == old_row {
856 *last_deleted_row_range =
857 *last_deleted_row_range.start()..=old_end_row;
858 } else {
859 deleted_row_ranges.push((new_row, old_row..=old_end_row));
860 }
861 } else {
862 deleted_row_ranges.push((new_row, old_row..=old_end_row));
863 }
864
865 old_row += line_count;
866 }
867 similar::ChangeTag::Insert => {
868 let new_end_row = new_row + line_count - 1;
869 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
870 let end = new_snapshot.anchor_before(Point::new(
871 new_end_row,
872 new_snapshot.line_len(MultiBufferRow(new_end_row)),
873 ));
874 inserted_row_ranges.push(start..end);
875 new_row += line_count;
876 }
877 }
878 }
879
880 (deleted_row_ranges, inserted_row_ranges)
881 })
882 .await;
883
884 codegen
885 .update(&mut cx, |codegen, cx| {
886 codegen.diff.deleted_row_ranges = deleted_row_ranges;
887 codegen.diff.inserted_row_ranges = inserted_row_ranges;
888 cx.notify();
889 })
890 .ok();
891 })
892 }
893}
894
/// Events emitted by `CodegenAlternative` and re-emitted by `BufferCodegen`.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// A generation task ran to completion (successfully or not), or `stop` was called.
    Finished,
    /// The buffer reported that the alternative's transformation transaction was undone.
    Undone,
}
900
/// Stream adapter that removes `<|CURSOR|>` markers and a wrapping code fence from the
/// text chunks produced by the inner stream.
struct StripInvalidSpans<T> {
    /// The wrapped stream of text chunks.
    stream: T,
    /// Set once the wrapped stream has returned `None`.
    stream_done: bool,
    /// Text received from the wrapped stream that has not been emitted yet.
    buffer: String,
    /// `true` until a second line has been seen in `buffer`.
    first_line: bool,
    /// `true` when a newline is owed before the next emitted text.
    line_end: bool,
    /// `true` when the first line was dropped because it started with a code fence.
    starts_with_code_block: bool,
}
909
910impl<T> StripInvalidSpans<T>
911where
912 T: Stream<Item = Result<String>>,
913{
914 fn new(stream: T) -> Self {
915 Self {
916 stream,
917 stream_done: false,
918 buffer: String::new(),
919 first_line: true,
920 line_end: false,
921 starts_with_code_block: false,
922 }
923 }
924}
925
926impl<T> Stream for StripInvalidSpans<T>
927where
928 T: Stream<Item = Result<String>>,
929{
930 type Item = Result<String>;
931
932 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
933 const CODE_BLOCK_DELIMITER: &str = "```";
934 const CURSOR_SPAN: &str = "<|CURSOR|>";
935
936 let this = unsafe { self.get_unchecked_mut() };
937 loop {
938 if !this.stream_done {
939 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
940 match stream.as_mut().poll_next(cx) {
941 Poll::Ready(Some(Ok(chunk))) => {
942 this.buffer.push_str(&chunk);
943 }
944 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
945 Poll::Ready(None) => {
946 this.stream_done = true;
947 }
948 Poll::Pending => return Poll::Pending,
949 }
950 }
951
952 let mut chunk = String::new();
953 let mut consumed = 0;
954 if !this.buffer.is_empty() {
955 let mut lines = this.buffer.split('\n').enumerate().peekable();
956 while let Some((line_ix, line)) = lines.next() {
957 if line_ix > 0 {
958 this.first_line = false;
959 }
960
961 if this.first_line {
962 let trimmed_line = line.trim();
963 if lines.peek().is_some() {
964 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
965 consumed += line.len() + 1;
966 this.starts_with_code_block = true;
967 continue;
968 }
969 } else if trimmed_line.is_empty()
970 || prefixes(CODE_BLOCK_DELIMITER)
971 .any(|prefix| trimmed_line.starts_with(prefix))
972 {
973 break;
974 }
975 }
976
977 let line_without_cursor = line.replace(CURSOR_SPAN, "");
978 if lines.peek().is_some() {
979 if this.line_end {
980 chunk.push('\n');
981 }
982
983 chunk.push_str(&line_without_cursor);
984 this.line_end = true;
985 consumed += line.len() + 1;
986 } else if this.stream_done {
987 if !this.starts_with_code_block
988 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
989 {
990 if this.line_end {
991 chunk.push('\n');
992 }
993
994 chunk.push_str(&line);
995 }
996
997 consumed += line.len();
998 } else {
999 let trimmed_line = line.trim();
1000 if trimmed_line.is_empty()
1001 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1002 || prefixes(CODE_BLOCK_DELIMITER)
1003 .any(|prefix| trimmed_line.ends_with(prefix))
1004 {
1005 break;
1006 } else {
1007 if this.line_end {
1008 chunk.push('\n');
1009 this.line_end = false;
1010 }
1011
1012 chunk.push_str(&line_without_cursor);
1013 consumed += line.len();
1014 }
1015 }
1016 }
1017 }
1018
1019 this.buffer = this.buffer.split_off(consumed);
1020 if !chunk.is_empty() {
1021 return Poll::Ready(Some(Ok(chunk)));
1022 } else if this.stream_done {
1023 return Poll::Ready(None);
1024 }
1025 }
1026 }
1027}
1028
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// The full `text` itself is never yielded. Empty and single-character inputs yield
/// nothing, and prefixes always end on a character boundary.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    // Each character after the first starts at the end of one proper prefix.
    text.char_indices()
        .skip(1)
        .map(move |(prefix_end, _)| &text[..prefix_end])
}
1032
/// Row-level description of how the generated text differs from the original text.
#[derive(Default)]
pub struct Diff {
    /// Deleted rows: an anchor in the current buffer where the deletion happened, paired
    /// with the inclusive range of rows that were deleted from the original text.
    pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
    /// Anchor ranges in the current buffer covering rows that were inserted.
    pub inserted_row_ranges: Vec<Range<Anchor>>,
}
1038
1039impl Diff {
1040 fn is_empty(&self) -> bool {
1041 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1042 }
1043}
1044
1045#[cfg(test)]
1046mod tests {
1047 use super::*;
1048 use futures::{
1049 stream::{self},
1050 Stream,
1051 };
1052 use gpui::{Context, TestAppContext};
1053 use indoc::indoc;
1054 use language::{
1055 language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
1056 Point,
1057 };
1058 use language_model::LanguageModelRegistry;
1059 use rand::prelude::*;
1060 use serde::Serialize;
1061 use settings::SettingsStore;
1062 use std::{future, sync::Arc};
1063
1064 #[derive(Serialize)]
1065 pub struct DummyCompletionRequest {
1066 pub name: String,
1067 }
1068
1069 #[gpui::test(iterations = 10)]
1070 async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
1071 cx.set_global(cx.update(SettingsStore::test));
1072 cx.update(language_model::LanguageModelRegistry::test);
1073 cx.update(language_settings::init);
1074
1075 let text = indoc! {"
1076 fn main() {
1077 let x = 0;
1078 for _ in 0..10 {
1079 x += 1;
1080 }
1081 }
1082 "};
1083 let buffer =
1084 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
1085 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
1086 let range = buffer.read_with(cx, |buffer, cx| {
1087 let snapshot = buffer.snapshot(cx);
1088 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
1089 });
1090 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1091 let codegen = cx.new_model(|cx| {
1092 CodegenAlternative::new(
1093 buffer.clone(),
1094 range.clone(),
1095 true,
1096 None,
1097 None,
1098 prompt_builder,
1099 cx,
1100 )
1101 });
1102
1103 let chunks_tx = simulate_response_stream(codegen.clone(), cx);
1104
1105 let mut new_text = concat!(
1106 " let mut x = 0;\n",
1107 " while x < 10 {\n",
1108 " x += 1;\n",
1109 " }",
1110 );
1111 while !new_text.is_empty() {
1112 let max_len = cmp::min(new_text.len(), 10);
1113 let len = rng.gen_range(1..=max_len);
1114 let (chunk, suffix) = new_text.split_at(len);
1115 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1116 new_text = suffix;
1117 cx.background_executor.run_until_parked();
1118 }
1119 drop(chunks_tx);
1120 cx.background_executor.run_until_parked();
1121
1122 assert_eq!(
1123 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1124 indoc! {"
1125 fn main() {
1126 let mut x = 0;
1127 while x < 10 {
1128 x += 1;
1129 }
1130 }
1131 "}
1132 );
1133 }
1134
1135 #[gpui::test(iterations = 10)]
1136 async fn test_autoindent_when_generating_past_indentation(
1137 cx: &mut TestAppContext,
1138 mut rng: StdRng,
1139 ) {
1140 cx.set_global(cx.update(SettingsStore::test));
1141 cx.update(language_settings::init);
1142
1143 let text = indoc! {"
1144 fn main() {
1145 le
1146 }
1147 "};
1148 let buffer =
1149 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
1150 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
1151 let range = buffer.read_with(cx, |buffer, cx| {
1152 let snapshot = buffer.snapshot(cx);
1153 snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
1154 });
1155 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1156 let codegen = cx.new_model(|cx| {
1157 CodegenAlternative::new(
1158 buffer.clone(),
1159 range.clone(),
1160 true,
1161 None,
1162 None,
1163 prompt_builder,
1164 cx,
1165 )
1166 });
1167
1168 let chunks_tx = simulate_response_stream(codegen.clone(), cx);
1169
1170 cx.background_executor.run_until_parked();
1171
1172 let mut new_text = concat!(
1173 "t mut x = 0;\n",
1174 "while x < 10 {\n",
1175 " x += 1;\n",
1176 "}", //
1177 );
1178 while !new_text.is_empty() {
1179 let max_len = cmp::min(new_text.len(), 10);
1180 let len = rng.gen_range(1..=max_len);
1181 let (chunk, suffix) = new_text.split_at(len);
1182 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1183 new_text = suffix;
1184 cx.background_executor.run_until_parked();
1185 }
1186 drop(chunks_tx);
1187 cx.background_executor.run_until_parked();
1188
1189 assert_eq!(
1190 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1191 indoc! {"
1192 fn main() {
1193 let mut x = 0;
1194 while x < 10 {
1195 x += 1;
1196 }
1197 }
1198 "}
1199 );
1200 }
1201
1202 #[gpui::test(iterations = 10)]
1203 async fn test_autoindent_when_generating_before_indentation(
1204 cx: &mut TestAppContext,
1205 mut rng: StdRng,
1206 ) {
1207 cx.update(LanguageModelRegistry::test);
1208 cx.set_global(cx.update(SettingsStore::test));
1209 cx.update(language_settings::init);
1210
1211 let text = concat!(
1212 "fn main() {\n",
1213 " \n",
1214 "}\n" //
1215 );
1216 let buffer =
1217 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
1218 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
1219 let range = buffer.read_with(cx, |buffer, cx| {
1220 let snapshot = buffer.snapshot(cx);
1221 snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
1222 });
1223 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1224 let codegen = cx.new_model(|cx| {
1225 CodegenAlternative::new(
1226 buffer.clone(),
1227 range.clone(),
1228 true,
1229 None,
1230 None,
1231 prompt_builder,
1232 cx,
1233 )
1234 });
1235
1236 let chunks_tx = simulate_response_stream(codegen.clone(), cx);
1237
1238 cx.background_executor.run_until_parked();
1239
1240 let mut new_text = concat!(
1241 "let mut x = 0;\n",
1242 "while x < 10 {\n",
1243 " x += 1;\n",
1244 "}", //
1245 );
1246 while !new_text.is_empty() {
1247 let max_len = cmp::min(new_text.len(), 10);
1248 let len = rng.gen_range(1..=max_len);
1249 let (chunk, suffix) = new_text.split_at(len);
1250 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1251 new_text = suffix;
1252 cx.background_executor.run_until_parked();
1253 }
1254 drop(chunks_tx);
1255 cx.background_executor.run_until_parked();
1256
1257 assert_eq!(
1258 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1259 indoc! {"
1260 fn main() {
1261 let mut x = 0;
1262 while x < 10 {
1263 x += 1;
1264 }
1265 }
1266 "}
1267 );
1268 }
1269
1270 #[gpui::test(iterations = 10)]
1271 async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
1272 cx.update(LanguageModelRegistry::test);
1273 cx.set_global(cx.update(SettingsStore::test));
1274 cx.update(language_settings::init);
1275
1276 let text = indoc! {"
1277 func main() {
1278 \tx := 0
1279 \tfor i := 0; i < 10; i++ {
1280 \t\tx++
1281 \t}
1282 }
1283 "};
1284 let buffer = cx.new_model(|cx| Buffer::local(text, cx));
1285 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
1286 let range = buffer.read_with(cx, |buffer, cx| {
1287 let snapshot = buffer.snapshot(cx);
1288 snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
1289 });
1290 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1291 let codegen = cx.new_model(|cx| {
1292 CodegenAlternative::new(
1293 buffer.clone(),
1294 range.clone(),
1295 true,
1296 None,
1297 None,
1298 prompt_builder,
1299 cx,
1300 )
1301 });
1302
1303 let chunks_tx = simulate_response_stream(codegen.clone(), cx);
1304 let new_text = concat!(
1305 "func main() {\n",
1306 "\tx := 0\n",
1307 "\tfor x < 10 {\n",
1308 "\t\tx++\n",
1309 "\t}", //
1310 );
1311 chunks_tx.unbounded_send(new_text.to_string()).unwrap();
1312 drop(chunks_tx);
1313 cx.background_executor.run_until_parked();
1314
1315 assert_eq!(
1316 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1317 indoc! {"
1318 func main() {
1319 \tx := 0
1320 \tfor x < 10 {
1321 \t\tx++
1322 \t}
1323 }
1324 "}
1325 );
1326 }
1327
1328 #[gpui::test]
1329 async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
1330 cx.update(LanguageModelRegistry::test);
1331 cx.set_global(cx.update(SettingsStore::test));
1332 cx.update(language_settings::init);
1333
1334 let text = indoc! {"
1335 fn main() {
1336 let x = 0;
1337 }
1338 "};
1339 let buffer =
1340 cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
1341 let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
1342 let range = buffer.read_with(cx, |buffer, cx| {
1343 let snapshot = buffer.snapshot(cx);
1344 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
1345 });
1346 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1347 let codegen = cx.new_model(|cx| {
1348 CodegenAlternative::new(
1349 buffer.clone(),
1350 range.clone(),
1351 false,
1352 None,
1353 None,
1354 prompt_builder,
1355 cx,
1356 )
1357 });
1358
1359 let chunks_tx = simulate_response_stream(codegen.clone(), cx);
1360 chunks_tx
1361 .unbounded_send("let mut x = 0;\nx += 1;".to_string())
1362 .unwrap();
1363 drop(chunks_tx);
1364 cx.run_until_parked();
1365
1366 // The codegen is inactive, so the buffer doesn't get modified.
1367 assert_eq!(
1368 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1369 text
1370 );
1371
1372 // Activating the codegen applies the changes.
1373 codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
1374 assert_eq!(
1375 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1376 indoc! {"
1377 fn main() {
1378 let mut x = 0;
1379 x += 1;
1380 }
1381 "}
1382 );
1383
1384 // Deactivating the codegen undoes the changes.
1385 codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
1386 cx.run_until_parked();
1387 assert_eq!(
1388 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1389 text
1390 );
1391 }
1392
1393 #[gpui::test]
1394 async fn test_strip_invalid_spans_from_codeblock() {
1395 assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
1396 assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
1397 assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
1398 assert_chunks(
1399 "```html\n```js\nLorem ipsum dolor\n```\n```",
1400 "```js\nLorem ipsum dolor\n```",
1401 )
1402 .await;
1403 assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
1404 assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
1405 assert_chunks("Lorem ipsum", "Lorem ipsum").await;
1406 assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
1407
1408 async fn assert_chunks(text: &str, expected_text: &str) {
1409 for chunk_size in 1..=text.len() {
1410 let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
1411 .map(|chunk| chunk.unwrap())
1412 .collect::<String>()
1413 .await;
1414 assert_eq!(
1415 actual_text, expected_text,
1416 "failed to strip invalid spans, chunk size: {}",
1417 chunk_size
1418 );
1419 }
1420 }
1421
1422 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
1423 stream::iter(
1424 text.chars()
1425 .collect::<Vec<_>>()
1426 .chunks(size)
1427 .map(|chunk| Ok(chunk.iter().collect::<String>()))
1428 .collect::<Vec<_>>(),
1429 )
1430 }
1431 }
1432
1433 fn simulate_response_stream(
1434 codegen: Model<CodegenAlternative>,
1435 cx: &mut TestAppContext,
1436 ) -> mpsc::UnboundedSender<String> {
1437 let (chunks_tx, chunks_rx) = mpsc::unbounded();
1438 codegen.update(cx, |codegen, cx| {
1439 codegen.handle_stream(
1440 String::new(),
1441 String::new(),
1442 None,
1443 future::ready(Ok(LanguageModelTextStream {
1444 message_id: None,
1445 stream: chunks_rx.map(Ok).boxed(),
1446 })),
1447 cx,
1448 );
1449 });
1450 chunks_tx
1451 }
1452
1453 fn rust_lang() -> Language {
1454 Language::new(
1455 LanguageConfig {
1456 name: "Rust".into(),
1457 matcher: LanguageMatcher {
1458 path_suffixes: vec!["rs".to_string()],
1459 ..Default::default()
1460 },
1461 ..Default::default()
1462 },
1463 Some(tree_sitter_rust::LANGUAGE.into()),
1464 )
1465 .with_indents_query(
1466 r#"
1467 (call_expression) @indent
1468 (field_expression) @indent
1469 (_ "(" ")" @end) @indent
1470 (_ "{" "}" @end) @indent
1471 "#,
1472 )
1473 .unwrap()
1474 }
1475}