use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
use feature_flags::{FeatureFlagAppExt as _, InlineAssistantV2FeatureFlag};
use futures::{
    SinkExt, Stream, StreamExt, TryStreamExt as _,
    channel::mpsc,
    future::{LocalBoxFuture, Shared},
    join,
};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{
    LanguageModel, LanguageModelCompletionError, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelTextStream, Role,
    report_assistant_event,
};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use prompt_store::PromptBuilder;
use rope::Rope;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use smol::future::FutureExt;
use std::{
    cmp,
    future::Future,
    iter,
    ops::{Range, RangeInclusive},
    pin::Pin,
    sync::Arc,
    task::{self, Poll},
    time::Instant,
};
use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
use ui::SharedString;

/// Use this tool to provide a message to the user when you're unable to complete a task.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FailureMessageInput {
    /// A brief message to the user explaining why you're unable to fulfill the request, or asking a question about the request.
    ///
    /// The message may use markdown formatting if you wish.
    pub message: String,
}

/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RewriteSectionInput {
    /// A brief description of the edit you have made.
    ///
    /// The description may use markdown formatting if you wish.
    /// This is optional: if the edit is simple or obvious, leave it empty.
    pub description: String,

    /// The text to replace the section with.
    pub replacement_text: String,
}
64
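/// Coordinates an inline-assist generation across one or more [`CodegenAlternative`]s
/// (one for the primary model plus one per configured alternative model), tracking
/// which alternative is currently active.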
65pub struct BufferCodegen {
66 alternatives: Vec<Entity<CodegenAlternative>>,
67 pub active_alternative: usize,
68 seen_alternatives: HashSet<usize>,
69 subscriptions: Vec<Subscription>,
70 buffer: Entity<MultiBuffer>,
71 range: Range<Anchor>,
72 initial_transaction_id: Option<TransactionId>,
73 telemetry: Arc<Telemetry>,
74 builder: Arc<PromptBuilder>,
75 pub is_insertion: bool,
76}
77
78impl BufferCodegen {
79 pub fn new(
80 buffer: Entity<MultiBuffer>,
81 range: Range<Anchor>,
82 initial_transaction_id: Option<TransactionId>,
83 telemetry: Arc<Telemetry>,
84 builder: Arc<PromptBuilder>,
85 cx: &mut Context<Self>,
86 ) -> Self {
87 let codegen = cx.new(|cx| {
88 CodegenAlternative::new(
89 buffer.clone(),
90 range.clone(),
91 false,
92 Some(telemetry.clone()),
93 builder.clone(),
94 cx,
95 )
96 });
97 let mut this = Self {
98 is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
99 alternatives: vec![codegen],
100 active_alternative: 0,
101 seen_alternatives: HashSet::default(),
102 subscriptions: Vec::new(),
103 buffer,
104 range,
105 initial_transaction_id,
106 telemetry,
107 builder,
108 };
109 this.activate(0, cx);
110 this
111 }
112
113 fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
114 let codegen = self.active_alternative().clone();
115 self.subscriptions.clear();
116 self.subscriptions
117 .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
118 self.subscriptions
119 .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
120 }
121
122 pub fn active_completion(&self, cx: &App) -> Option<String> {
123 self.active_alternative().read(cx).current_completion()
124 }
125
126 pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
127 &self.alternatives[self.active_alternative]
128 }
129
130 pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
131 &self.active_alternative().read(cx).status
132 }
133
134 pub fn alternative_count(&self, cx: &App) -> usize {
135 LanguageModelRegistry::read_global(cx)
136 .inline_alternative_models()
137 .len()
138 + 1
139 }
140
141 pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
142 let next_active_ix = if self.active_alternative == 0 {
143 self.alternatives.len() - 1
144 } else {
145 self.active_alternative - 1
146 };
147 self.activate(next_active_ix, cx);
148 }
149
150 pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
151 let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
152 self.activate(next_active_ix, cx);
153 }
154
155 fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
156 self.active_alternative()
157 .update(cx, |codegen, cx| codegen.set_active(false, cx));
158 self.seen_alternatives.insert(index);
159 self.active_alternative = index;
160 self.active_alternative()
161 .update(cx, |codegen, cx| codegen.set_active(true, cx));
162 self.subscribe_to_alternative(cx);
163 cx.notify();
164 }
165
166 pub fn start(
167 &mut self,
168 primary_model: Arc<dyn LanguageModel>,
169 user_prompt: String,
170 context_task: Shared<Task<Option<LoadedContext>>>,
171 cx: &mut Context<Self>,
172 ) -> Result<()> {
173 let alternative_models = LanguageModelRegistry::read_global(cx)
174 .inline_alternative_models()
175 .to_vec();
176
177 self.active_alternative()
178 .update(cx, |alternative, cx| alternative.undo(cx));
179 self.activate(0, cx);
180 self.alternatives.truncate(1);
181
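        // Keep the first alternative for the primary model and add one more
        // alternative per configured alternative model.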
182 for _ in 0..alternative_models.len() {
183 self.alternatives.push(cx.new(|cx| {
184 CodegenAlternative::new(
185 self.buffer.clone(),
186 self.range.clone(),
187 false,
188 Some(self.telemetry.clone()),
189 self.builder.clone(),
190 cx,
191 )
192 }));
193 }
194
195 for (model, alternative) in iter::once(primary_model)
196 .chain(alternative_models)
197 .zip(&self.alternatives)
198 {
199 alternative.update(cx, |alternative, cx| {
200 alternative.start(user_prompt.clone(), context_task.clone(), model.clone(), cx)
201 })?;
202 }
203
204 Ok(())
205 }
206
207 pub fn stop(&mut self, cx: &mut Context<Self>) {
208 for codegen in &self.alternatives {
209 codegen.update(cx, |codegen, cx| codegen.stop(cx));
210 }
211 }
212
213 pub fn undo(&mut self, cx: &mut Context<Self>) {
214 self.active_alternative()
215 .update(cx, |codegen, cx| codegen.undo(cx));
216
217 self.buffer.update(cx, |buffer, cx| {
218 if let Some(transaction_id) = self.initial_transaction_id.take() {
219 buffer.undo_transaction(transaction_id, cx);
220 buffer.refresh_preview(cx);
221 }
222 });
223 }
224
225 pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
226 self.active_alternative().read(cx).buffer.clone()
227 }
228
229 pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
230 self.active_alternative().read(cx).old_buffer.clone()
231 }
232
233 pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
234 self.active_alternative().read(cx).snapshot.clone()
235 }
236
237 pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
238 self.active_alternative().read(cx).edit_position
239 }
240
241 pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
242 &self.active_alternative().read(cx).diff
243 }
244
245 pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
246 self.active_alternative().read(cx).last_equal_ranges()
247 }
248
249 pub fn selected_text<'a>(&self, cx: &'a App) -> Option<&'a str> {
250 self.active_alternative().read(cx).selected_text()
251 }
252}
253
254impl EventEmitter<CodegenEvent> for BufferCodegen {}
255
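/// A single model's attempt at the requested transformation. It streams the
/// completion into the buffer as edits, tracks the resulting diff against the
/// original text, and can undo its own transaction.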
256pub struct CodegenAlternative {
257 buffer: Entity<MultiBuffer>,
258 old_buffer: Entity<Buffer>,
259 snapshot: MultiBufferSnapshot,
260 edit_position: Option<Anchor>,
261 range: Range<Anchor>,
262 last_equal_ranges: Vec<Range<Anchor>>,
263 transformation_transaction_id: Option<TransactionId>,
264 status: CodegenStatus,
265 generation: Task<()>,
266 diff: Diff,
267 telemetry: Option<Arc<Telemetry>>,
268 _subscription: gpui::Subscription,
269 builder: Arc<PromptBuilder>,
270 active: bool,
271 edits: Vec<(Range<Anchor>, String)>,
272 line_operations: Vec<LineOperation>,
273 elapsed_time: Option<f64>,
274 completion: Option<String>,
275 selected_text: Option<String>,
276 pub message_id: Option<String>,
277 pub model_explanation: Option<SharedString>,
278}
279
280impl EventEmitter<CodegenEvent> for CodegenAlternative {}
281
282impl CodegenAlternative {
283 pub fn new(
284 buffer: Entity<MultiBuffer>,
285 range: Range<Anchor>,
286 active: bool,
287 telemetry: Option<Arc<Telemetry>>,
288 builder: Arc<PromptBuilder>,
289 cx: &mut Context<Self>,
290 ) -> Self {
291 let snapshot = buffer.read(cx).snapshot(cx);
292
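        // Capture the excerpt's buffer as a detached local copy so the
        // pre-transformation text remains available after edits are applied.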
293 let (old_buffer, _, _) = snapshot
294 .range_to_buffer_ranges(range.clone())
295 .pop()
296 .unwrap();
297 let old_buffer = cx.new(|cx| {
298 let text = old_buffer.as_rope().clone();
299 let line_ending = old_buffer.line_ending();
300 let language = old_buffer.language().cloned();
301 let language_registry = buffer
302 .read(cx)
303 .buffer(old_buffer.remote_id())
304 .unwrap()
305 .read(cx)
306 .language_registry();
307
308 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
309 buffer.set_language(language, cx);
310 if let Some(language_registry) = language_registry {
311 buffer.set_language_registry(language_registry);
312 }
313 buffer
314 });
315
316 Self {
317 buffer: buffer.clone(),
318 old_buffer,
319 edit_position: None,
320 message_id: None,
321 snapshot,
322 last_equal_ranges: Default::default(),
323 transformation_transaction_id: None,
324 status: CodegenStatus::Idle,
325 generation: Task::ready(()),
326 diff: Diff::default(),
327 telemetry,
328 builder,
            active,
330 edits: Vec::new(),
331 line_operations: Vec::new(),
332 range,
333 elapsed_time: None,
334 completion: None,
335 selected_text: None,
336 model_explanation: None,
337 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
338 }
339 }
340
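    /// Applies this alternative's accumulated edits when it becomes the active one
    /// and undoes (and forgets) its transaction when it is deactivated.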
341 pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
342 if active != self.active {
343 self.active = active;
344
345 if self.active {
346 let edits = self.edits.clone();
347 self.apply_edits(edits, cx);
348 if matches!(self.status, CodegenStatus::Pending) {
349 let line_operations = self.line_operations.clone();
350 self.reapply_line_based_diff(line_operations, cx);
351 } else {
352 self.reapply_batch_diff(cx).detach();
353 }
354 } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
355 self.buffer.update(cx, |buffer, cx| {
356 buffer.undo_transaction(transaction_id, cx);
357 buffer.forget_transaction(transaction_id, cx);
358 });
359 }
360 }
361 }
362
363 fn handle_buffer_event(
364 &mut self,
365 _buffer: Entity<MultiBuffer>,
366 event: &multi_buffer::Event,
367 cx: &mut Context<Self>,
368 ) {
369 if let multi_buffer::Event::TransactionUndone { transaction_id } = event
370 && self.transformation_transaction_id == Some(*transaction_id)
371 {
372 self.transformation_transaction_id = None;
373 self.generation = Task::ready(());
374 cx.emit(CodegenEvent::Undone);
375 }
376 }
377
378 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
379 &self.last_equal_ranges
380 }
381
382 pub fn start(
383 &mut self,
384 user_prompt: String,
385 context_task: Shared<Task<Option<LoadedContext>>>,
386 model: Arc<dyn LanguageModel>,
387 cx: &mut Context<Self>,
388 ) -> Result<()> {
389 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
390 self.buffer.update(cx, |buffer, cx| {
391 buffer.undo_transaction(transformation_transaction_id, cx);
392 });
393 }
394
395 self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
396
397 let api_key = model.api_key(cx);
398 let telemetry_id = model.telemetry_id();
399 let provider_id = model.provider_id();
400
401 if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
402 let request = self.build_request(&model, user_prompt, context_task, cx)?;
403 let tool_use =
404 cx.spawn(async move |_, cx| model.stream_completion_tool(request.await, cx).await);
405 self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx);
406 } else {
407 let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
408 if user_prompt.trim().to_lowercase() == "delete" {
409 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
410 } else {
411 let request = self.build_request(&model, user_prompt, context_task, cx)?;
412 cx.spawn(async move |_, cx| {
413 Ok(model.stream_completion_text(request.await, cx).await?)
414 })
415 .boxed_local()
416 };
417 self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
418 }
419
420 Ok(())
421 }
422
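    /// Builds a tool-based request, used when `InlineAssistantV2FeatureFlag` is
    /// enabled: the model is expected to respond by calling either the
    /// `rewrite_section` or the `failure_message` tool.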
423 fn build_request_v2(
424 &self,
425 model: &Arc<dyn LanguageModel>,
426 user_prompt: String,
427 context_task: Shared<Task<Option<LoadedContext>>>,
428 cx: &mut App,
429 ) -> Result<Task<LanguageModelRequest>> {
430 let buffer = self.buffer.read(cx).snapshot(cx);
431 let language = buffer.language_at(self.range.start);
432 let language_name = if let Some(language) = language.as_ref() {
433 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
434 None
435 } else {
436 Some(language.name())
437 }
438 } else {
439 None
440 };
441
442 let language_name = language_name.as_ref();
443 let start = buffer.point_to_buffer_offset(self.range.start);
444 let end = buffer.point_to_buffer_offset(self.range.end);
445 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
446 let (start_buffer, start_buffer_offset) = start;
447 let (end_buffer, end_buffer_offset) = end;
448 if start_buffer.remote_id() == end_buffer.remote_id() {
449 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
450 } else {
451 anyhow::bail!("invalid transformation range");
452 }
453 } else {
454 anyhow::bail!("invalid transformation range");
455 };
456
457 let system_prompt = self
458 .builder
459 .generate_inline_transformation_prompt_v2(
460 language_name,
461 buffer,
462 range.start.0..range.end.0,
463 )
464 .context("generating content prompt")?;
465
466 let temperature = AgentSettings::temperature_for_model(model, cx);
467
468 let tool_input_format = model.tool_input_format();
469
470 Ok(cx.spawn(async move |_cx| {
471 let mut messages = vec![LanguageModelRequestMessage {
472 role: Role::System,
473 content: vec![system_prompt.into()],
474 cache: false,
475 reasoning_details: None,
476 }];
477
478 let mut user_message = LanguageModelRequestMessage {
479 role: Role::User,
480 content: Vec::new(),
481 cache: false,
482 reasoning_details: None,
483 };
484
485 if let Some(context) = context_task.await {
486 context.add_to_request_message(&mut user_message);
487 }
488
489 user_message.content.push(user_prompt.into());
490 messages.push(user_message);
491
492 let tools = vec![
493 LanguageModelRequestTool {
494 name: "rewrite_section".to_string(),
495 description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
496 input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
497 },
498 LanguageModelRequestTool {
499 name: "failure_message".to_string(),
500 description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
501 input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
502 },
503 ];
504
505 LanguageModelRequest {
506 thread_id: None,
507 prompt_id: None,
508 intent: Some(CompletionIntent::InlineAssist),
509 mode: None,
510 tools,
511 tool_choice: None,
512 stop: Vec::new(),
513 temperature,
514 messages,
515 thinking_allowed: false,
516 }
517 }))
518 }
519
520 fn build_request(
521 &self,
522 model: &Arc<dyn LanguageModel>,
523 user_prompt: String,
524 context_task: Shared<Task<Option<LoadedContext>>>,
525 cx: &mut App,
526 ) -> Result<Task<LanguageModelRequest>> {
527 if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
528 return self.build_request_v2(model, user_prompt, context_task, cx);
529 }
530
531 let buffer = self.buffer.read(cx).snapshot(cx);
532 let language = buffer.language_at(self.range.start);
533 let language_name = if let Some(language) = language.as_ref() {
534 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
535 None
536 } else {
537 Some(language.name())
538 }
539 } else {
540 None
541 };
542
543 let language_name = language_name.as_ref();
544 let start = buffer.point_to_buffer_offset(self.range.start);
545 let end = buffer.point_to_buffer_offset(self.range.end);
546 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
547 let (start_buffer, start_buffer_offset) = start;
548 let (end_buffer, end_buffer_offset) = end;
549 if start_buffer.remote_id() == end_buffer.remote_id() {
550 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
551 } else {
552 anyhow::bail!("invalid transformation range");
553 }
554 } else {
555 anyhow::bail!("invalid transformation range");
556 };
557
558 let prompt = self
559 .builder
560 .generate_inline_transformation_prompt(
561 user_prompt,
562 language_name,
563 buffer,
564 range.start.0..range.end.0,
565 )
566 .context("generating content prompt")?;
567
568 let temperature = AgentSettings::temperature_for_model(model, cx);
569
570 Ok(cx.spawn(async move |_cx| {
571 let mut request_message = LanguageModelRequestMessage {
572 role: Role::User,
573 content: Vec::new(),
574 cache: false,
575 reasoning_details: None,
576 };
577
578 if let Some(context) = context_task.await {
579 context.add_to_request_message(&mut request_message);
580 }
581
582 request_message.content.push(prompt.into());
583
584 LanguageModelRequest {
585 thread_id: None,
586 prompt_id: None,
587 intent: Some(CompletionIntent::InlineAssist),
588 mode: None,
589 tools: Vec::new(),
590 tool_choice: None,
591 stop: Vec::new(),
592 temperature,
593 messages: vec![request_message],
594 thinking_allowed: false,
595 }
596 }))
597 }
598
599 pub fn handle_stream(
600 &mut self,
601 model_telemetry_id: String,
602 model_provider_id: String,
603 model_api_key: Option<String>,
604 stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
605 cx: &mut Context<Self>,
606 ) {
607 let start_time = Instant::now();
608
        // Take a fresh snapshot and re-resolve the range anchors in case the buffer
        // was modified in the meantime. This happens often when the editor loses
        // focus and the file is saved and reformatted, as in
        // https://github.com/zed-industries/zed/issues/39088.
612 self.snapshot = self.buffer.read(cx).snapshot(cx);
613 self.range = self.snapshot.anchor_after(self.range.start)
614 ..self.snapshot.anchor_after(self.range.end);
615
616 let snapshot = self.snapshot.clone();
617 let selected_text = snapshot
618 .text_for_range(self.range.start..self.range.end)
619 .collect::<Rope>();
620
621 self.selected_text = Some(selected_text.to_string());
622
623 let selection_start = self.range.start.to_point(&snapshot);
624
625 // Start with the indentation of the first line in the selection
626 let mut suggested_line_indent = snapshot
627 .suggested_indents(selection_start.row..=selection_start.row, cx)
628 .into_values()
629 .next()
630 .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
631
632 // If the first line in the selection does not have indentation, check the following lines
633 if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
634 for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
635 let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
636 // Prefer tabs if a line in the selection uses tabs as indentation
637 if line_indent.kind == IndentKind::Tab {
638 suggested_line_indent.kind = IndentKind::Tab;
639 break;
640 }
641 }
642 }
643
644 let http_client = cx.http_client();
645 let telemetry = self.telemetry.clone();
646 let language_name = {
647 let multibuffer = self.buffer.read(cx);
648 let snapshot = multibuffer.snapshot(cx);
649 let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
650 ranges
651 .first()
652 .and_then(|(buffer, _, _)| buffer.language())
653 .map(|language| language.name())
654 };
655
656 self.diff = Diff::default();
657 self.status = CodegenStatus::Pending;
658 let mut edit_start = self.range.start.to_offset(&snapshot);
659 let completion = Arc::new(Mutex::new(String::new()));
660 let completion_clone = completion.clone();
661
662 self.generation = cx.spawn(async move |codegen, cx| {
663 let stream = stream.await;
664
665 let token_usage = stream
666 .as_ref()
667 .ok()
668 .map(|stream| stream.last_token_usage.clone());
669 let message_id = stream
670 .as_ref()
671 .ok()
672 .and_then(|stream| stream.message_id.clone());
673 let generate = async {
674 let model_telemetry_id = model_telemetry_id.clone();
675 let model_provider_id = model_provider_id.clone();
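                // Compute the streaming character/line diff on a background task,
                // sending each batch of operations to the foreground through this channel.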
676 let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
677 let executor = cx.background_executor().clone();
678 let message_id = message_id.clone();
679 let line_based_stream_diff: Task<anyhow::Result<()>> =
680 cx.background_spawn(async move {
681 let mut response_latency = None;
682 let request_start = Instant::now();
683 let diff = async {
684 let chunks = StripInvalidSpans::new(
685 stream?.stream.map_err(|error| error.into()),
686 );
687 futures::pin_mut!(chunks);
688 let mut diff = StreamingDiff::new(selected_text.to_string());
689 let mut line_diff = LineDiff::default();
690
691 let mut new_text = String::new();
692 let mut base_indent = None;
693 let mut line_indent = None;
694 let mut first_line = true;
695
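                            // Re-indent the streamed text as it arrives: the first
                            // non-empty line establishes the model's base indentation,
                            // and each subsequent line is shifted by its delta from
                            // that base, relative to the selection's suggested indent.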
696 while let Some(chunk) = chunks.next().await {
697 if response_latency.is_none() {
698 response_latency = Some(request_start.elapsed());
699 }
700 let chunk = chunk?;
701 completion_clone.lock().push_str(&chunk);
702
703 let mut lines = chunk.split('\n').peekable();
704 while let Some(line) = lines.next() {
705 new_text.push_str(line);
706 if line_indent.is_none()
707 && let Some(non_whitespace_ch_ix) =
708 new_text.find(|ch: char| !ch.is_whitespace())
709 {
710 line_indent = Some(non_whitespace_ch_ix);
711 base_indent = base_indent.or(line_indent);
712
713 let line_indent = line_indent.unwrap();
714 let base_indent = base_indent.unwrap();
715 let indent_delta = line_indent as i32 - base_indent as i32;
716 let mut corrected_indent_len = cmp::max(
717 0,
718 suggested_line_indent.len as i32 + indent_delta,
719 )
720 as usize;
721 if first_line {
722 corrected_indent_len = corrected_indent_len
723 .saturating_sub(selection_start.column as usize);
724 }
725
726 let indent_char = suggested_line_indent.char();
727 let mut indent_buffer = [0; 4];
728 let indent_str =
729 indent_char.encode_utf8(&mut indent_buffer);
730 new_text.replace_range(
731 ..line_indent,
732 &indent_str.repeat(corrected_indent_len),
733 );
734 }
735
736 if line_indent.is_some() {
737 let char_ops = diff.push_new(&new_text);
738 line_diff.push_char_operations(&char_ops, &selected_text);
739 diff_tx
740 .send((char_ops, line_diff.line_operations()))
741 .await?;
742 new_text.clear();
743 }
744
745 if lines.peek().is_some() {
746 let char_ops = diff.push_new("\n");
747 line_diff.push_char_operations(&char_ops, &selected_text);
748 diff_tx
749 .send((char_ops, line_diff.line_operations()))
750 .await?;
                                        if line_indent.is_none() {
                                            // The line was empty, so the branch above never ran;
                                            // clear the buffer here so its leading whitespace
                                            // doesn't carry over into the next line.
                                            new_text.clear();
                                        }
756 line_indent = None;
757 first_line = false;
758 }
759 }
760 }
761
762 let mut char_ops = diff.push_new(&new_text);
763 char_ops.extend(diff.finish());
764 line_diff.push_char_operations(&char_ops, &selected_text);
765 line_diff.finish(&selected_text);
766 diff_tx
767 .send((char_ops, line_diff.line_operations()))
768 .await?;
769
770 anyhow::Ok(())
771 };
772
773 let result = diff.await;
774
775 let error_message = result.as_ref().err().map(|error| error.to_string());
776 report_assistant_event(
777 AssistantEventData {
778 conversation_id: None,
779 message_id,
780 kind: AssistantKind::Inline,
781 phase: AssistantPhase::Response,
782 model: model_telemetry_id,
783 model_provider: model_provider_id,
784 response_latency,
785 error_message,
786 language_name: language_name.map(|name| name.to_proto()),
787 },
788 telemetry,
789 http_client,
790 model_api_key,
791 &executor,
792 );
793
794 result?;
795 Ok(())
796 });
797
798 while let Some((char_ops, line_ops)) = diff_rx.next().await {
799 codegen.update(cx, |codegen, cx| {
800 codegen.last_equal_ranges.clear();
801
802 let edits = char_ops
803 .into_iter()
804 .filter_map(|operation| match operation {
805 CharOperation::Insert { text } => {
806 let edit_start = snapshot.anchor_after(edit_start);
807 Some((edit_start..edit_start, text))
808 }
809 CharOperation::Delete { bytes } => {
810 let edit_end = edit_start + bytes;
811 let edit_range = snapshot.anchor_after(edit_start)
812 ..snapshot.anchor_before(edit_end);
813 edit_start = edit_end;
814 Some((edit_range, String::new()))
815 }
816 CharOperation::Keep { bytes } => {
817 let edit_end = edit_start + bytes;
818 let edit_range = snapshot.anchor_after(edit_start)
819 ..snapshot.anchor_before(edit_end);
820 edit_start = edit_end;
821 codegen.last_equal_ranges.push(edit_range);
822 None
823 }
824 })
825 .collect::<Vec<_>>();
826
827 if codegen.active {
828 codegen.apply_edits(edits.iter().cloned(), cx);
829 codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
830 }
831 codegen.edits.extend(edits);
832 codegen.line_operations = line_ops;
833 codegen.edit_position = Some(snapshot.anchor_after(edit_start));
834
835 cx.notify();
836 })?;
837 }
838
                // Streaming is done: the new text is in the buffer, with a line-based
                // diff applied across the whole range. That streaming diff can differ
                // from a conventional diff and look surprising, so recompute a regular
                // batch diff. This is safe even if the line-based diffing failed, since
                // no more hunks will arrive through `diff_rx`.
842 let batch_diff_task =
843 codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
844 let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
845 line_based_stream_diff?;
846
847 anyhow::Ok(())
848 };
849
850 let result = generate.await;
851 let elapsed_time = start_time.elapsed().as_secs_f64();
852
853 codegen
854 .update(cx, |this, cx| {
855 this.message_id = message_id;
856 this.last_equal_ranges.clear();
857 if let Err(error) = result {
858 this.status = CodegenStatus::Error(error);
859 } else {
860 this.status = CodegenStatus::Done;
861 }
862 this.elapsed_time = Some(elapsed_time);
863 this.completion = Some(completion.lock().clone());
864 if let Some(usage) = token_usage {
865 let usage = usage.lock();
866 telemetry::event!(
867 "Inline Assistant Completion",
868 model = model_telemetry_id,
869 model_provider = model_provider_id,
870 input_tokens = usage.input_tokens,
871 output_tokens = usage.output_tokens,
872 )
873 }
874
875 cx.emit(CodegenEvent::Finished);
876 cx.notify();
877 })
878 .ok();
879 });
880 cx.notify();
881 }
882
883 pub fn current_completion(&self) -> Option<String> {
884 self.completion.clone()
885 }
886
887 pub fn selected_text(&self) -> Option<&str> {
888 self.selected_text.as_deref()
889 }
890
891 pub fn stop(&mut self, cx: &mut Context<Self>) {
892 self.last_equal_ranges.clear();
893 if self.diff.is_empty() {
894 self.status = CodegenStatus::Idle;
895 } else {
896 self.status = CodegenStatus::Done;
897 }
898 self.generation = Task::ready(());
899 cx.emit(CodegenEvent::Finished);
900 cx.notify();
901 }
902
903 pub fn undo(&mut self, cx: &mut Context<Self>) {
904 self.buffer.update(cx, |buffer, cx| {
905 if let Some(transaction_id) = self.transformation_transaction_id.take() {
906 buffer.undo_transaction(transaction_id, cx);
907 buffer.refresh_preview(cx);
908 }
909 });
910 }
911
912 fn apply_edits(
913 &mut self,
914 edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
915 cx: &mut Context<CodegenAlternative>,
916 ) {
917 let transaction = self.buffer.update(cx, |buffer, cx| {
918 // Avoid grouping agent edits with user edits.
919 buffer.finalize_last_transaction(cx);
920 buffer.start_transaction(cx);
921 buffer.edit(edits, None, cx);
922 buffer.end_transaction(cx)
923 });
924
925 if let Some(transaction) = transaction {
926 if let Some(first_transaction) = self.transformation_transaction_id {
927 // Group all agent edits into the first transaction.
928 self.buffer.update(cx, |buffer, cx| {
929 buffer.merge_transactions(transaction, first_transaction, cx)
930 });
931 } else {
932 self.transformation_transaction_id = Some(transaction);
933 self.buffer
934 .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
935 }
936 }
937 }
938
939 fn reapply_line_based_diff(
940 &mut self,
941 line_operations: impl IntoIterator<Item = LineOperation>,
942 cx: &mut Context<Self>,
943 ) {
944 let old_snapshot = self.snapshot.clone();
945 let old_range = self.range.to_point(&old_snapshot);
946 let new_snapshot = self.buffer.read(cx).snapshot(cx);
947 let new_range = self.range.to_point(&new_snapshot);
948
949 let mut old_row = old_range.start.row;
950 let mut new_row = new_range.start.row;
951
952 self.diff.deleted_row_ranges.clear();
953 self.diff.inserted_row_ranges.clear();
954 for operation in line_operations {
955 match operation {
956 LineOperation::Keep { lines } => {
957 old_row += lines;
958 new_row += lines;
959 }
960 LineOperation::Delete { lines } => {
961 let old_end_row = old_row + lines - 1;
962 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
963
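                    // Merge with the previous deleted range when the deleted rows
                    // are contiguous with it.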
964 if let Some((_, last_deleted_row_range)) =
965 self.diff.deleted_row_ranges.last_mut()
966 {
967 if *last_deleted_row_range.end() + 1 == old_row {
968 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
969 } else {
970 self.diff
971 .deleted_row_ranges
972 .push((new_row, old_row..=old_end_row));
973 }
974 } else {
975 self.diff
976 .deleted_row_ranges
977 .push((new_row, old_row..=old_end_row));
978 }
979
980 old_row += lines;
981 }
982 LineOperation::Insert { lines } => {
983 let new_end_row = new_row + lines - 1;
984 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
985 let end = new_snapshot.anchor_before(Point::new(
986 new_end_row,
987 new_snapshot.line_len(MultiBufferRow(new_end_row)),
988 ));
989 self.diff.inserted_row_ranges.push(start..end);
990 new_row += lines;
991 }
992 }
993
994 cx.notify();
995 }
996 }
997
998 fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
999 let old_snapshot = self.snapshot.clone();
1000 let old_range = self.range.to_point(&old_snapshot);
1001 let new_snapshot = self.buffer.read(cx).snapshot(cx);
1002 let new_range = self.range.to_point(&new_snapshot);
1003
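        // Recompute the full line diff between the original selection text and the
        // current buffer contents on a background thread, then store the resulting
        // row ranges on this entity.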
1004 cx.spawn(async move |codegen, cx| {
1005 let (deleted_row_ranges, inserted_row_ranges) = cx
1006 .background_spawn(async move {
1007 let old_text = old_snapshot
1008 .text_for_range(
1009 Point::new(old_range.start.row, 0)
1010 ..Point::new(
1011 old_range.end.row,
1012 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
1013 ),
1014 )
1015 .collect::<String>();
1016 let new_text = new_snapshot
1017 .text_for_range(
1018 Point::new(new_range.start.row, 0)
1019 ..Point::new(
1020 new_range.end.row,
1021 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1022 ),
1023 )
1024 .collect::<String>();
1025
1026 let old_start_row = old_range.start.row;
1027 let new_start_row = new_range.start.row;
1028 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1029 let mut inserted_row_ranges = Vec::new();
1030 for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
1031 let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
1032 let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
1033 if !old_rows.is_empty() {
1034 deleted_row_ranges.push((
1035 new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
1036 old_rows.start..=old_rows.end - 1,
1037 ));
1038 }
1039 if !new_rows.is_empty() {
1040 let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
1041 let new_end_row = new_rows.end - 1;
1042 let end = new_snapshot.anchor_before(Point::new(
1043 new_end_row,
1044 new_snapshot.line_len(MultiBufferRow(new_end_row)),
1045 ));
1046 inserted_row_ranges.push(start..end);
1047 }
1048 }
1049 (deleted_row_ranges, inserted_row_ranges)
1050 })
1051 .await;
1052
1053 codegen
1054 .update(cx, |codegen, cx| {
1055 codegen.diff.deleted_row_ranges = deleted_row_ranges;
1056 codegen.diff.inserted_row_ranges = inserted_row_ranges;
1057 cx.notify();
1058 })
1059 .ok();
1060 })
1061 }
1062
1063 fn handle_tool_use(
1064 &mut self,
1065 _telemetry_id: String,
1066 _provider_id: String,
1067 _api_key: Option<String>,
1068 tool_use: impl 'static
1069 + Future<
1070 Output = Result<language_model::LanguageModelToolUse, LanguageModelCompletionError>,
1071 >,
1072 cx: &mut Context<Self>,
1073 ) {
1074 self.diff = Diff::default();
1075 self.status = CodegenStatus::Pending;
1076
1077 self.generation = cx.spawn(async move |codegen, cx| {
1078 let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
1079 let _ = codegen.update(cx, |this, cx| {
1080 this.status = status;
1081 cx.emit(CodegenEvent::Finished);
1082 cx.notify();
1083 });
1084 };
1085
1086 let tool_use = tool_use.await;
1087
1088 match tool_use {
1089 Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => {
1090 // Parse the input JSON into RewriteSectionInput
1091 match serde_json::from_value::<RewriteSectionInput>(tool_use.input) {
1092 Ok(input) => {
1093 // Store the description if non-empty
1094 let description = if !input.description.trim().is_empty() {
1095 Some(input.description.clone())
1096 } else {
1097 None
1098 };
1099
1100 // Apply the replacement text to the buffer and compute diff
1101 let batch_diff_task = codegen
1102 .update(cx, |this, cx| {
1103 this.model_explanation = description.map(Into::into);
1104 let range = this.range.clone();
1105 this.apply_edits(
1106 std::iter::once((range, input.replacement_text)),
1107 cx,
1108 );
1109 this.reapply_batch_diff(cx)
1110 })
1111 .ok();
1112
1113 // Wait for the diff computation to complete
1114 if let Some(diff_task) = batch_diff_task {
1115 diff_task.await;
1116 }
1117
1118 finish_with_status(CodegenStatus::Done, cx);
1119 return;
1120 }
1121 Err(e) => {
1122 finish_with_status(CodegenStatus::Error(e.into()), cx);
1123 return;
1124 }
1125 }
1126 }
1127 Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => {
1128 // Handle failure message tool use
1129 match serde_json::from_value::<FailureMessageInput>(tool_use.input) {
1130 Ok(input) => {
1131 let _ = codegen.update(cx, |this, _cx| {
                                // Surface the failure message to the user as the model's explanation.
1133 this.model_explanation = Some(input.message.into());
1134 });
1135 finish_with_status(CodegenStatus::Done, cx);
1136 return;
1137 }
1138 Err(e) => {
1139 finish_with_status(CodegenStatus::Error(e.into()), cx);
1140 return;
1141 }
1142 }
1143 }
1144 Ok(_tool_use) => {
1145 // Unexpected tool.
1146 finish_with_status(CodegenStatus::Done, cx);
1147 return;
1148 }
1149 Err(e) => {
1150 finish_with_status(CodegenStatus::Error(e.into()), cx);
1151 return;
1152 }
1153 }
1154 });
1155 cx.notify();
1156 }
1157}
1158
1159#[derive(Copy, Clone, Debug)]
1160pub enum CodegenEvent {
1161 Finished,
1162 Undone,
1163}
1164
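/// Wraps a completion stream and strips spans the model shouldn't have produced:
/// a Markdown code fence wrapping the whole response and `<|CURSOR|>` markers.
/// Incomplete trailing lines are buffered until enough text has arrived to decide
/// whether they belong to such a span.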
1165struct StripInvalidSpans<T> {
1166 stream: T,
1167 stream_done: bool,
1168 buffer: String,
1169 first_line: bool,
1170 line_end: bool,
1171 starts_with_code_block: bool,
1172}
1173
1174impl<T> StripInvalidSpans<T>
1175where
1176 T: Stream<Item = Result<String>>,
1177{
1178 fn new(stream: T) -> Self {
1179 Self {
1180 stream,
1181 stream_done: false,
1182 buffer: String::new(),
1183 first_line: true,
1184 line_end: false,
1185 starts_with_code_block: false,
1186 }
1187 }
1188}
1189
1190impl<T> Stream for StripInvalidSpans<T>
1191where
1192 T: Stream<Item = Result<String>>,
1193{
1194 type Item = Result<String>;
1195
1196 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
1197 const CODE_BLOCK_DELIMITER: &str = "```";
1198 const CURSOR_SPAN: &str = "<|CURSOR|>";
1199
1200 let this = unsafe { self.get_unchecked_mut() };
1201 loop {
1202 if !this.stream_done {
1203 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
1204 match stream.as_mut().poll_next(cx) {
1205 Poll::Ready(Some(Ok(chunk))) => {
1206 this.buffer.push_str(&chunk);
1207 }
1208 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
1209 Poll::Ready(None) => {
1210 this.stream_done = true;
1211 }
1212 Poll::Pending => return Poll::Pending,
1213 }
1214 }
1215
1216 let mut chunk = String::new();
1217 let mut consumed = 0;
1218 if !this.buffer.is_empty() {
1219 let mut lines = this.buffer.split('\n').enumerate().peekable();
1220 while let Some((line_ix, line)) = lines.next() {
1221 if line_ix > 0 {
1222 this.first_line = false;
1223 }
1224
1225 if this.first_line {
1226 let trimmed_line = line.trim();
1227 if lines.peek().is_some() {
1228 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
1229 consumed += line.len() + 1;
1230 this.starts_with_code_block = true;
1231 continue;
1232 }
1233 } else if trimmed_line.is_empty()
1234 || prefixes(CODE_BLOCK_DELIMITER)
1235 .any(|prefix| trimmed_line.starts_with(prefix))
1236 {
1237 break;
1238 }
1239 }
1240
1241 let line_without_cursor = line.replace(CURSOR_SPAN, "");
1242 if lines.peek().is_some() {
1243 if this.line_end {
1244 chunk.push('\n');
1245 }
1246
1247 chunk.push_str(&line_without_cursor);
1248 this.line_end = true;
1249 consumed += line.len() + 1;
1250 } else if this.stream_done {
1251 if !this.starts_with_code_block
1252 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1253 {
1254 if this.line_end {
1255 chunk.push('\n');
1256 }
1257
1258 chunk.push_str(line);
1259 }
1260
1261 consumed += line.len();
1262 } else {
1263 let trimmed_line = line.trim();
1264 if trimmed_line.is_empty()
1265 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1266 || prefixes(CODE_BLOCK_DELIMITER)
1267 .any(|prefix| trimmed_line.ends_with(prefix))
1268 {
1269 break;
1270 } else {
1271 if this.line_end {
1272 chunk.push('\n');
1273 this.line_end = false;
1274 }
1275
1276 chunk.push_str(&line_without_cursor);
1277 consumed += line.len();
1278 }
1279 }
1280 }
1281 }
1282
1283 this.buffer = this.buffer.split_off(consumed);
1284 if !chunk.is_empty() {
1285 return Poll::Ready(Some(Ok(chunk)));
1286 } else if this.stream_done {
1287 return Poll::Ready(None);
1288 }
1289 }
1290 }
1291}
1292
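/// Yields every non-empty proper prefix of `text`, shortest first.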
1293fn prefixes(text: &str) -> impl Iterator<Item = &str> {
1294 (0..text.len() - 1).map(|ix| &text[..ix + 1])
1295}
1296
1297#[derive(Default)]
1298pub struct Diff {
1299 pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
1300 pub inserted_row_ranges: Vec<Range<Anchor>>,
1301}
1302
1303impl Diff {
1304 fn is_empty(&self) -> bool {
1305 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1306 }
1307}
1308
1309#[cfg(test)]
1310mod tests {
1311 use super::*;
1312 use futures::{
1313 Stream,
1314 stream::{self},
1315 };
1316 use gpui::TestAppContext;
1317 use indoc::indoc;
1318 use language::{Buffer, Point};
1319 use language_model::{LanguageModelRegistry, TokenUsage};
1320 use languages::rust_lang;
1321 use rand::prelude::*;
1322 use settings::SettingsStore;
1323 use std::{future, sync::Arc};
1324
1325 #[gpui::test(iterations = 10)]
1326 async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
1327 init_test(cx);
1328
1329 let text = indoc! {"
1330 fn main() {
1331 let x = 0;
1332 for _ in 0..10 {
1333 x += 1;
1334 }
1335 }
1336 "};
1337 let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1338 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1339 let range = buffer.read_with(cx, |buffer, cx| {
1340 let snapshot = buffer.snapshot(cx);
1341 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
1342 });
1343 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1344 let codegen = cx.new(|cx| {
1345 CodegenAlternative::new(
1346 buffer.clone(),
1347 range.clone(),
1348 true,
1349 None,
1350 prompt_builder,
1351 cx,
1352 )
1353 });
1354
1355 let chunks_tx = simulate_response_stream(&codegen, cx);
1356
1357 let mut new_text = concat!(
1358 " let mut x = 0;\n",
1359 " while x < 10 {\n",
1360 " x += 1;\n",
1361 " }",
1362 );
1363 while !new_text.is_empty() {
1364 let max_len = cmp::min(new_text.len(), 10);
1365 let len = rng.random_range(1..=max_len);
1366 let (chunk, suffix) = new_text.split_at(len);
1367 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1368 new_text = suffix;
1369 cx.background_executor.run_until_parked();
1370 }
1371 drop(chunks_tx);
1372 cx.background_executor.run_until_parked();
1373
1374 assert_eq!(
1375 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1376 indoc! {"
1377 fn main() {
1378 let mut x = 0;
1379 while x < 10 {
1380 x += 1;
1381 }
1382 }
1383 "}
1384 );
1385 }
1386
1387 #[gpui::test(iterations = 10)]
1388 async fn test_autoindent_when_generating_past_indentation(
1389 cx: &mut TestAppContext,
1390 mut rng: StdRng,
1391 ) {
1392 init_test(cx);
1393
1394 let text = indoc! {"
1395 fn main() {
1396 le
1397 }
1398 "};
1399 let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1400 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1401 let range = buffer.read_with(cx, |buffer, cx| {
1402 let snapshot = buffer.snapshot(cx);
1403 snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
1404 });
1405 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1406 let codegen = cx.new(|cx| {
1407 CodegenAlternative::new(
1408 buffer.clone(),
1409 range.clone(),
1410 true,
1411 None,
1412 prompt_builder,
1413 cx,
1414 )
1415 });
1416
1417 let chunks_tx = simulate_response_stream(&codegen, cx);
1418
1419 cx.background_executor.run_until_parked();
1420
1421 let mut new_text = concat!(
1422 "t mut x = 0;\n",
1423 "while x < 10 {\n",
1424 " x += 1;\n",
1425 "}", //
1426 );
1427 while !new_text.is_empty() {
1428 let max_len = cmp::min(new_text.len(), 10);
1429 let len = rng.random_range(1..=max_len);
1430 let (chunk, suffix) = new_text.split_at(len);
1431 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1432 new_text = suffix;
1433 cx.background_executor.run_until_parked();
1434 }
1435 drop(chunks_tx);
1436 cx.background_executor.run_until_parked();
1437
1438 assert_eq!(
1439 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1440 indoc! {"
1441 fn main() {
1442 let mut x = 0;
1443 while x < 10 {
1444 x += 1;
1445 }
1446 }
1447 "}
1448 );
1449 }
1450
1451 #[gpui::test(iterations = 10)]
1452 async fn test_autoindent_when_generating_before_indentation(
1453 cx: &mut TestAppContext,
1454 mut rng: StdRng,
1455 ) {
1456 init_test(cx);
1457
1458 let text = concat!(
1459 "fn main() {\n",
1460 " \n",
1461 "}\n" //
1462 );
1463 let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1464 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1465 let range = buffer.read_with(cx, |buffer, cx| {
1466 let snapshot = buffer.snapshot(cx);
1467 snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
1468 });
1469 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1470 let codegen = cx.new(|cx| {
1471 CodegenAlternative::new(
1472 buffer.clone(),
1473 range.clone(),
1474 true,
1475 None,
1476 prompt_builder,
1477 cx,
1478 )
1479 });
1480
1481 let chunks_tx = simulate_response_stream(&codegen, cx);
1482
1483 cx.background_executor.run_until_parked();
1484
1485 let mut new_text = concat!(
1486 "let mut x = 0;\n",
1487 "while x < 10 {\n",
1488 " x += 1;\n",
1489 "}", //
1490 );
1491 while !new_text.is_empty() {
1492 let max_len = cmp::min(new_text.len(), 10);
1493 let len = rng.random_range(1..=max_len);
1494 let (chunk, suffix) = new_text.split_at(len);
1495 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1496 new_text = suffix;
1497 cx.background_executor.run_until_parked();
1498 }
1499 drop(chunks_tx);
1500 cx.background_executor.run_until_parked();
1501
1502 assert_eq!(
1503 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1504 indoc! {"
1505 fn main() {
1506 let mut x = 0;
1507 while x < 10 {
1508 x += 1;
1509 }
1510 }
1511 "}
1512 );
1513 }
1514
1515 #[gpui::test(iterations = 10)]
1516 async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
1517 init_test(cx);
1518
1519 let text = indoc! {"
1520 func main() {
1521 \tx := 0
1522 \tfor i := 0; i < 10; i++ {
1523 \t\tx++
1524 \t}
1525 }
1526 "};
1527 let buffer = cx.new(|cx| Buffer::local(text, cx));
1528 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1529 let range = buffer.read_with(cx, |buffer, cx| {
1530 let snapshot = buffer.snapshot(cx);
1531 snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
1532 });
1533 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1534 let codegen = cx.new(|cx| {
1535 CodegenAlternative::new(
1536 buffer.clone(),
1537 range.clone(),
1538 true,
1539 None,
1540 prompt_builder,
1541 cx,
1542 )
1543 });
1544
1545 let chunks_tx = simulate_response_stream(&codegen, cx);
1546 let new_text = concat!(
1547 "func main() {\n",
1548 "\tx := 0\n",
1549 "\tfor x < 10 {\n",
1550 "\t\tx++\n",
1551 "\t}", //
1552 );
1553 chunks_tx.unbounded_send(new_text.to_string()).unwrap();
1554 drop(chunks_tx);
1555 cx.background_executor.run_until_parked();
1556
1557 assert_eq!(
1558 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1559 indoc! {"
1560 func main() {
1561 \tx := 0
1562 \tfor x < 10 {
1563 \t\tx++
1564 \t}
1565 }
1566 "}
1567 );
1568 }
1569
1570 #[gpui::test]
1571 async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
1572 init_test(cx);
1573
1574 let text = indoc! {"
1575 fn main() {
1576 let x = 0;
1577 }
1578 "};
1579 let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1580 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1581 let range = buffer.read_with(cx, |buffer, cx| {
1582 let snapshot = buffer.snapshot(cx);
1583 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
1584 });
1585 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1586 let codegen = cx.new(|cx| {
1587 CodegenAlternative::new(
1588 buffer.clone(),
1589 range.clone(),
1590 false,
1591 None,
1592 prompt_builder,
1593 cx,
1594 )
1595 });
1596
1597 let chunks_tx = simulate_response_stream(&codegen, cx);
1598 chunks_tx
1599 .unbounded_send("let mut x = 0;\nx += 1;".to_string())
1600 .unwrap();
1601 drop(chunks_tx);
1602 cx.run_until_parked();
1603
1604 // The codegen is inactive, so the buffer doesn't get modified.
1605 assert_eq!(
1606 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1607 text
1608 );
1609
1610 // Activating the codegen applies the changes.
1611 codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
1612 assert_eq!(
1613 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1614 indoc! {"
1615 fn main() {
1616 let mut x = 0;
1617 x += 1;
1618 }
1619 "}
1620 );
1621
1622 // Deactivating the codegen undoes the changes.
1623 codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
1624 cx.run_until_parked();
1625 assert_eq!(
1626 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1627 text
1628 );
1629 }
1630
1631 #[gpui::test]
1632 async fn test_strip_invalid_spans_from_codeblock() {
1633 assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
1634 assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
1635 assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
1636 assert_chunks(
1637 "```html\n```js\nLorem ipsum dolor\n```\n```",
1638 "```js\nLorem ipsum dolor\n```",
1639 )
1640 .await;
1641 assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
1642 assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
1643 assert_chunks("Lorem ipsum", "Lorem ipsum").await;
1644 assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
1645
1646 async fn assert_chunks(text: &str, expected_text: &str) {
1647 for chunk_size in 1..=text.len() {
1648 let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
1649 .map(|chunk| chunk.unwrap())
1650 .collect::<String>()
1651 .await;
1652 assert_eq!(
1653 actual_text, expected_text,
1654 "failed to strip invalid spans, chunk size: {}",
1655 chunk_size
1656 );
1657 }
1658 }
1659
1660 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
1661 stream::iter(
1662 text.chars()
1663 .collect::<Vec<_>>()
1664 .chunks(size)
1665 .map(|chunk| Ok(chunk.iter().collect::<String>()))
1666 .collect::<Vec<_>>(),
1667 )
1668 }
1669 }
1670
1671 fn init_test(cx: &mut TestAppContext) {
1672 cx.update(LanguageModelRegistry::test);
1673 cx.set_global(cx.update(SettingsStore::test));
1674 }
1675
1676 fn simulate_response_stream(
1677 codegen: &Entity<CodegenAlternative>,
1678 cx: &mut TestAppContext,
1679 ) -> mpsc::UnboundedSender<String> {
1680 let (chunks_tx, chunks_rx) = mpsc::unbounded();
1681 codegen.update(cx, |codegen, cx| {
1682 codegen.handle_stream(
1683 String::new(),
1684 String::new(),
1685 None,
1686 future::ready(Ok(LanguageModelTextStream {
1687 message_id: None,
1688 stream: chunks_rx.map(Ok).boxed(),
1689 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
1690 })),
1691 cx,
1692 );
1693 });
1694 chunks_tx
1695 }
1696}