1use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
2use agent_settings::AgentSettings;
3use anyhow::{Context as _, Result};
4use client::telemetry::Telemetry;
5use cloud_llm_client::CompletionIntent;
6use collections::HashSet;
7use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
8use feature_flags::{FeatureFlagAppExt as _, InlineAssistantV2FeatureFlag};
9use futures::{
10 SinkExt, Stream, StreamExt, TryStreamExt as _,
11 channel::mpsc,
12 future::{LocalBoxFuture, Shared},
13 join,
14};
15use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
16use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
17use language_model::{
18 LanguageModel, LanguageModelCompletionError, LanguageModelRegistry, LanguageModelRequest,
19 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelTextStream, Role,
20 report_assistant_event,
21};
22use multi_buffer::MultiBufferRow;
23use parking_lot::Mutex;
24use prompt_store::PromptBuilder;
25use rope::Rope;
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use smol::future::FutureExt;
29use std::{
30 cmp,
31 future::Future,
32 iter,
33 ops::{Range, RangeInclusive},
34 pin::Pin,
35 sync::Arc,
36 task::{self, Poll},
37 time::Instant,
38};
39use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
40use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
41use ui::SharedString;
42
43/// Use this tool to provide a message to the user when you're unable to complete a task.
44#[derive(Debug, Serialize, Deserialize, JsonSchema)]
45pub struct FailureMessageInput {
    /// A brief message to the user explaining why you're unable to fulfill the request, or asking a clarifying question about it.
47 ///
48 /// The message may use markdown formatting if you wish.
49 pub message: String,
50}
51
52/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
53#[derive(Debug, Serialize, Deserialize, JsonSchema)]
54pub struct RewriteSectionInput {
55 /// A brief description of the edit you have made.
56 ///
57 /// The description may use markdown formatting if you wish.
58 /// This is optional - if the edit is simple or obvious, you should leave it empty.
59 pub description: String,
60
61 /// The text to replace the section with.
62 pub replacement_text: String,
63}
64
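/// Drives an inline-assist generation over a buffer range, running one
/// [`CodegenAlternative`] per model (the primary model plus any configured
/// inline alternatives) and letting the user cycle between them.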
65pub struct BufferCodegen {
66 alternatives: Vec<Entity<CodegenAlternative>>,
67 pub active_alternative: usize,
68 seen_alternatives: HashSet<usize>,
69 subscriptions: Vec<Subscription>,
70 buffer: Entity<MultiBuffer>,
71 range: Range<Anchor>,
72 initial_transaction_id: Option<TransactionId>,
73 telemetry: Arc<Telemetry>,
74 builder: Arc<PromptBuilder>,
75 pub is_insertion: bool,
76}
77
78impl BufferCodegen {
79 pub fn new(
80 buffer: Entity<MultiBuffer>,
81 range: Range<Anchor>,
82 initial_transaction_id: Option<TransactionId>,
83 telemetry: Arc<Telemetry>,
84 builder: Arc<PromptBuilder>,
85 cx: &mut Context<Self>,
86 ) -> Self {
87 let codegen = cx.new(|cx| {
88 CodegenAlternative::new(
89 buffer.clone(),
90 range.clone(),
91 false,
92 Some(telemetry.clone()),
93 builder.clone(),
94 cx,
95 )
96 });
97 let mut this = Self {
98 is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
99 alternatives: vec![codegen],
100 active_alternative: 0,
101 seen_alternatives: HashSet::default(),
102 subscriptions: Vec::new(),
103 buffer,
104 range,
105 initial_transaction_id,
106 telemetry,
107 builder,
108 };
109 this.activate(0, cx);
110 this
111 }
112
113 fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
114 let codegen = self.active_alternative().clone();
115 self.subscriptions.clear();
116 self.subscriptions
117 .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
118 self.subscriptions
119 .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
120 }
121
122 pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
123 &self.alternatives[self.active_alternative]
124 }
125
126 pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
127 &self.active_alternative().read(cx).status
128 }
129
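    /// One alternative is generated per configured inline alternative model,
    /// plus one for the primary model.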
130 pub fn alternative_count(&self, cx: &App) -> usize {
131 LanguageModelRegistry::read_global(cx)
132 .inline_alternative_models()
133 .len()
134 + 1
135 }
136
137 pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
138 let next_active_ix = if self.active_alternative == 0 {
139 self.alternatives.len() - 1
140 } else {
141 self.active_alternative - 1
142 };
143 self.activate(next_active_ix, cx);
144 }
145
146 pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
147 let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
148 self.activate(next_active_ix, cx);
149 }
150
151 fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
152 self.active_alternative()
153 .update(cx, |codegen, cx| codegen.set_active(false, cx));
154 self.seen_alternatives.insert(index);
155 self.active_alternative = index;
156 self.active_alternative()
157 .update(cx, |codegen, cx| codegen.set_active(true, cx));
158 self.subscribe_to_alternative(cx);
159 cx.notify();
160 }
161
162 pub fn start(
163 &mut self,
164 primary_model: Arc<dyn LanguageModel>,
165 user_prompt: String,
166 context_task: Shared<Task<Option<LoadedContext>>>,
167 cx: &mut Context<Self>,
168 ) -> Result<()> {
169 let alternative_models = LanguageModelRegistry::read_global(cx)
170 .inline_alternative_models()
171 .to_vec();
172
173 self.active_alternative()
174 .update(cx, |alternative, cx| alternative.undo(cx));
175 self.activate(0, cx);
176 self.alternatives.truncate(1);
177
178 for _ in 0..alternative_models.len() {
179 self.alternatives.push(cx.new(|cx| {
180 CodegenAlternative::new(
181 self.buffer.clone(),
182 self.range.clone(),
183 false,
184 Some(self.telemetry.clone()),
185 self.builder.clone(),
186 cx,
187 )
188 }));
189 }
190
191 for (model, alternative) in iter::once(primary_model)
192 .chain(alternative_models)
193 .zip(&self.alternatives)
194 {
195 alternative.update(cx, |alternative, cx| {
196 alternative.start(user_prompt.clone(), context_task.clone(), model.clone(), cx)
197 })?;
198 }
199
200 Ok(())
201 }
202
203 pub fn stop(&mut self, cx: &mut Context<Self>) {
204 for codegen in &self.alternatives {
205 codegen.update(cx, |codegen, cx| codegen.stop(cx));
206 }
207 }
208
209 pub fn undo(&mut self, cx: &mut Context<Self>) {
210 self.active_alternative()
211 .update(cx, |codegen, cx| codegen.undo(cx));
212
213 self.buffer.update(cx, |buffer, cx| {
214 if let Some(transaction_id) = self.initial_transaction_id.take() {
215 buffer.undo_transaction(transaction_id, cx);
216 buffer.refresh_preview(cx);
217 }
218 });
219 }
220
221 pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
222 self.active_alternative().read(cx).buffer.clone()
223 }
224
225 pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
226 self.active_alternative().read(cx).old_buffer.clone()
227 }
228
229 pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
230 self.active_alternative().read(cx).snapshot.clone()
231 }
232
233 pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
234 self.active_alternative().read(cx).edit_position
235 }
236
237 pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
238 &self.active_alternative().read(cx).diff
239 }
240
241 pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
242 self.active_alternative().read(cx).last_equal_ranges()
243 }
244}
245
246impl EventEmitter<CodegenEvent> for BufferCodegen {}
247
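/// A single generation attempt against one model: it streams the completion,
/// applies the edits to the buffer in its own transaction, and tracks the
/// resulting diff so the UI can preview and undo it.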
248pub struct CodegenAlternative {
249 buffer: Entity<MultiBuffer>,
250 old_buffer: Entity<Buffer>,
251 snapshot: MultiBufferSnapshot,
252 edit_position: Option<Anchor>,
253 range: Range<Anchor>,
254 last_equal_ranges: Vec<Range<Anchor>>,
255 transformation_transaction_id: Option<TransactionId>,
256 status: CodegenStatus,
257 generation: Task<()>,
258 diff: Diff,
259 telemetry: Option<Arc<Telemetry>>,
260 _subscription: gpui::Subscription,
261 builder: Arc<PromptBuilder>,
262 active: bool,
263 edits: Vec<(Range<Anchor>, String)>,
264 line_operations: Vec<LineOperation>,
265 elapsed_time: Option<f64>,
266 completion: Option<String>,
267 pub message_id: Option<String>,
268 pub model_explanation: Option<SharedString>,
269}
270
271impl EventEmitter<CodegenEvent> for CodegenAlternative {}
272
273impl CodegenAlternative {
274 pub fn new(
275 buffer: Entity<MultiBuffer>,
276 range: Range<Anchor>,
277 active: bool,
278 telemetry: Option<Arc<Telemetry>>,
279 builder: Arc<PromptBuilder>,
280 cx: &mut Context<Self>,
281 ) -> Self {
282 let snapshot = buffer.read(cx).snapshot(cx);
283
284 let (old_buffer, _, _) = snapshot
285 .range_to_buffer_ranges(range.clone())
286 .pop()
287 .unwrap();
288 let old_buffer = cx.new(|cx| {
289 let text = old_buffer.as_rope().clone();
290 let line_ending = old_buffer.line_ending();
291 let language = old_buffer.language().cloned();
292 let language_registry = buffer
293 .read(cx)
294 .buffer(old_buffer.remote_id())
295 .unwrap()
296 .read(cx)
297 .language_registry();
298
299 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
300 buffer.set_language(language, cx);
301 if let Some(language_registry) = language_registry {
302 buffer.set_language_registry(language_registry);
303 }
304 buffer
305 });
306
307 Self {
308 buffer: buffer.clone(),
309 old_buffer,
310 edit_position: None,
311 message_id: None,
312 snapshot,
313 last_equal_ranges: Default::default(),
314 transformation_transaction_id: None,
315 status: CodegenStatus::Idle,
316 generation: Task::ready(()),
317 diff: Diff::default(),
318 telemetry,
319 builder,
            active,
321 edits: Vec::new(),
322 line_operations: Vec::new(),
323 range,
324 elapsed_time: None,
325 completion: None,
326 model_explanation: None,
327 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
328 }
329 }
330
331 pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
332 if active != self.active {
333 self.active = active;
334
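            // Becoming active: re-apply this alternative's accumulated edits and
            // rebuild its diff. Becoming inactive: undo and forget its transaction
            // so the buffer reverts to its previous contents.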
335 if self.active {
336 let edits = self.edits.clone();
337 self.apply_edits(edits, cx);
338 if matches!(self.status, CodegenStatus::Pending) {
339 let line_operations = self.line_operations.clone();
340 self.reapply_line_based_diff(line_operations, cx);
341 } else {
342 self.reapply_batch_diff(cx).detach();
343 }
344 } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
345 self.buffer.update(cx, |buffer, cx| {
346 buffer.undo_transaction(transaction_id, cx);
347 buffer.forget_transaction(transaction_id, cx);
348 });
349 }
350 }
351 }
352
353 fn handle_buffer_event(
354 &mut self,
355 _buffer: Entity<MultiBuffer>,
356 event: &multi_buffer::Event,
357 cx: &mut Context<Self>,
358 ) {
359 if let multi_buffer::Event::TransactionUndone { transaction_id } = event
360 && self.transformation_transaction_id == Some(*transaction_id)
361 {
362 self.transformation_transaction_id = None;
363 self.generation = Task::ready(());
364 cx.emit(CodegenEvent::Undone);
365 }
366 }
367
368 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
369 &self.last_equal_ranges
370 }
371
372 pub fn start(
373 &mut self,
374 user_prompt: String,
375 context_task: Shared<Task<Option<LoadedContext>>>,
376 model: Arc<dyn LanguageModel>,
377 cx: &mut Context<Self>,
378 ) -> Result<()> {
379 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
380 self.buffer.update(cx, |buffer, cx| {
381 buffer.undo_transaction(transformation_transaction_id, cx);
382 });
383 }
384
385 self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
386
387 let api_key = model.api_key(cx);
388 let telemetry_id = model.telemetry_id();
389 let provider_id = model.provider_id();
390
391 if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
392 let request = self.build_request(&model, user_prompt, context_task, cx)?;
393 let tool_use =
394 cx.spawn(async move |_, cx| model.stream_completion_tool(request.await, cx).await);
395 self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx);
396 } else {
397 let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
398 if user_prompt.trim().to_lowercase() == "delete" {
399 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
400 } else {
401 let request = self.build_request(&model, user_prompt, context_task, cx)?;
402 cx.spawn(async move |_, cx| {
403 Ok(model.stream_completion_text(request.await, cx).await?)
404 })
405 .boxed_local()
406 };
407 self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
408 }
409
410 Ok(())
411 }
412
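    /// Builds the request used when `InlineAssistantV2FeatureFlag` is enabled: a
    /// system prompt describing the transformation, the user's prompt (plus any
    /// loaded context) as a user message, and the `rewrite_section` and
    /// `failure_message` tools.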
413 fn build_request_v2(
414 &self,
415 model: &Arc<dyn LanguageModel>,
416 user_prompt: String,
417 context_task: Shared<Task<Option<LoadedContext>>>,
418 cx: &mut App,
419 ) -> Result<Task<LanguageModelRequest>> {
420 let buffer = self.buffer.read(cx).snapshot(cx);
421 let language = buffer.language_at(self.range.start);
422 let language_name = if let Some(language) = language.as_ref() {
423 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
424 None
425 } else {
426 Some(language.name())
427 }
428 } else {
429 None
430 };
431
432 let language_name = language_name.as_ref();
433 let start = buffer.point_to_buffer_offset(self.range.start);
434 let end = buffer.point_to_buffer_offset(self.range.end);
435 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
436 let (start_buffer, start_buffer_offset) = start;
437 let (end_buffer, end_buffer_offset) = end;
438 if start_buffer.remote_id() == end_buffer.remote_id() {
439 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
440 } else {
441 anyhow::bail!("invalid transformation range");
442 }
443 } else {
444 anyhow::bail!("invalid transformation range");
445 };
446
447 let system_prompt = self
448 .builder
449 .generate_inline_transformation_prompt_v2(
450 language_name,
451 buffer,
452 range.start.0..range.end.0,
453 )
454 .context("generating content prompt")?;
455
456 let temperature = AgentSettings::temperature_for_model(model, cx);
457
458 let tool_input_format = model.tool_input_format();
459
460 Ok(cx.spawn(async move |_cx| {
461 let mut messages = vec![LanguageModelRequestMessage {
462 role: Role::System,
463 content: vec![system_prompt.into()],
464 cache: false,
465 reasoning_details: None,
466 }];
467
468 let mut user_message = LanguageModelRequestMessage {
469 role: Role::User,
470 content: Vec::new(),
471 cache: false,
472 reasoning_details: None,
473 };
474
475 if let Some(context) = context_task.await {
476 context.add_to_request_message(&mut user_message);
477 }
478
479 user_message.content.push(user_prompt.into());
480 messages.push(user_message);
481
482 let tools = vec![
483 LanguageModelRequestTool {
484 name: "rewrite_section".to_string(),
485 description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
486 input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
487 },
488 LanguageModelRequestTool {
489 name: "failure_message".to_string(),
490 description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
491 input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
492 },
493 ];
494
495 LanguageModelRequest {
496 thread_id: None,
497 prompt_id: None,
498 intent: Some(CompletionIntent::InlineAssist),
499 mode: None,
500 tools,
501 tool_choice: None,
502 stop: Vec::new(),
503 temperature,
504 messages,
505 thinking_allowed: false,
506 }
507 }))
508 }
509
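    /// Builds the plain text-completion request for the selected range, or
    /// delegates to [`Self::build_request_v2`] when the v2 feature flag is enabled.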
510 fn build_request(
511 &self,
512 model: &Arc<dyn LanguageModel>,
513 user_prompt: String,
514 context_task: Shared<Task<Option<LoadedContext>>>,
515 cx: &mut App,
516 ) -> Result<Task<LanguageModelRequest>> {
517 if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
518 return self.build_request_v2(model, user_prompt, context_task, cx);
519 }
520
521 let buffer = self.buffer.read(cx).snapshot(cx);
522 let language = buffer.language_at(self.range.start);
523 let language_name = if let Some(language) = language.as_ref() {
524 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
525 None
526 } else {
527 Some(language.name())
528 }
529 } else {
530 None
531 };
532
533 let language_name = language_name.as_ref();
534 let start = buffer.point_to_buffer_offset(self.range.start);
535 let end = buffer.point_to_buffer_offset(self.range.end);
536 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
537 let (start_buffer, start_buffer_offset) = start;
538 let (end_buffer, end_buffer_offset) = end;
539 if start_buffer.remote_id() == end_buffer.remote_id() {
540 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
541 } else {
542 anyhow::bail!("invalid transformation range");
543 }
544 } else {
545 anyhow::bail!("invalid transformation range");
546 };
547
548 let prompt = self
549 .builder
550 .generate_inline_transformation_prompt(
551 user_prompt,
552 language_name,
553 buffer,
554 range.start.0..range.end.0,
555 )
556 .context("generating content prompt")?;
557
558 let temperature = AgentSettings::temperature_for_model(model, cx);
559
560 Ok(cx.spawn(async move |_cx| {
561 let mut request_message = LanguageModelRequestMessage {
562 role: Role::User,
563 content: Vec::new(),
564 cache: false,
565 reasoning_details: None,
566 };
567
568 if let Some(context) = context_task.await {
569 context.add_to_request_message(&mut request_message);
570 }
571
572 request_message.content.push(prompt.into());
573
574 LanguageModelRequest {
575 thread_id: None,
576 prompt_id: None,
577 intent: Some(CompletionIntent::InlineAssist),
578 mode: None,
579 tools: Vec::new(),
580 tool_choice: None,
581 stop: Vec::new(),
582 temperature,
583 messages: vec![request_message],
584 thinking_allowed: false,
585 }
586 }))
587 }
588
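    /// Drives a plain text-completion response: streams chunks on a background
    /// task, re-indents them to match the selection, applies the edits through a
    /// streaming character/line diff, and finishes with a batch diff plus a
    /// telemetry report.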
589 pub fn handle_stream(
590 &mut self,
591 model_telemetry_id: String,
592 model_provider_id: String,
593 model_api_key: Option<String>,
594 stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
595 cx: &mut Context<Self>,
596 ) {
597 let start_time = Instant::now();
598
        // Take a fresh snapshot and re-resolve the anchors in case the buffer was modified.
        // This commonly happens when the editor loses focus and the file is saved and reformatted,
        // as in https://github.com/zed-industries/zed/issues/39088
602 self.snapshot = self.buffer.read(cx).snapshot(cx);
603 self.range = self.snapshot.anchor_after(self.range.start)
604 ..self.snapshot.anchor_after(self.range.end);
605
606 let snapshot = self.snapshot.clone();
607 let selected_text = snapshot
608 .text_for_range(self.range.start..self.range.end)
609 .collect::<Rope>();
610
611 let selection_start = self.range.start.to_point(&snapshot);
612
613 // Start with the indentation of the first line in the selection
614 let mut suggested_line_indent = snapshot
615 .suggested_indents(selection_start.row..=selection_start.row, cx)
616 .into_values()
617 .next()
618 .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
619
620 // If the first line in the selection does not have indentation, check the following lines
621 if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
622 for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
623 let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
624 // Prefer tabs if a line in the selection uses tabs as indentation
625 if line_indent.kind == IndentKind::Tab {
626 suggested_line_indent.kind = IndentKind::Tab;
627 break;
628 }
629 }
630 }
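        // e.g. if the selection's first line is indented with four spaces, generated
        // lines are rebased onto four spaces; if any selected line uses tabs instead,
        // tabs win.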
631
632 let http_client = cx.http_client();
633 let telemetry = self.telemetry.clone();
634 let language_name = {
635 let multibuffer = self.buffer.read(cx);
636 let snapshot = multibuffer.snapshot(cx);
637 let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
638 ranges
639 .first()
640 .and_then(|(buffer, _, _)| buffer.language())
641 .map(|language| language.name())
642 };
643
644 self.diff = Diff::default();
645 self.status = CodegenStatus::Pending;
646 let mut edit_start = self.range.start.to_offset(&snapshot);
647 let completion = Arc::new(Mutex::new(String::new()));
648 let completion_clone = completion.clone();
649
650 self.generation = cx.spawn(async move |codegen, cx| {
651 let stream = stream.await;
652
653 let token_usage = stream
654 .as_ref()
655 .ok()
656 .map(|stream| stream.last_token_usage.clone());
657 let message_id = stream
658 .as_ref()
659 .ok()
660 .and_then(|stream| stream.message_id.clone());
661 let generate = async {
662 let model_telemetry_id = model_telemetry_id.clone();
663 let model_provider_id = model_provider_id.clone();
664 let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
665 let executor = cx.background_executor().clone();
666 let message_id = message_id.clone();
667 let line_based_stream_diff: Task<anyhow::Result<()>> =
668 cx.background_spawn(async move {
669 let mut response_latency = None;
670 let request_start = Instant::now();
671 let diff = async {
672 let chunks = StripInvalidSpans::new(
673 stream?.stream.map_err(|error| error.into()),
674 );
675 futures::pin_mut!(chunks);
676 let mut diff = StreamingDiff::new(selected_text.to_string());
677 let mut line_diff = LineDiff::default();
678
679 let mut new_text = String::new();
680 let mut base_indent = None;
681 let mut line_indent = None;
682 let mut first_line = true;
683
684 while let Some(chunk) = chunks.next().await {
685 if response_latency.is_none() {
686 response_latency = Some(request_start.elapsed());
687 }
688 let chunk = chunk?;
689 completion_clone.lock().push_str(&chunk);
690
691 let mut lines = chunk.split('\n').peekable();
692 while let Some(line) = lines.next() {
693 new_text.push_str(line);
694 if line_indent.is_none()
695 && let Some(non_whitespace_ch_ix) =
696 new_text.find(|ch: char| !ch.is_whitespace())
697 {
698 line_indent = Some(non_whitespace_ch_ix);
699 base_indent = base_indent.or(line_indent);
700
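                                    // Rebase this line's indentation: keep its depth
                                    // relative to the first generated line, but measure
                                    // from the buffer's suggested indent (minus the
                                    // columns already present before the cursor on the
                                    // first line).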
701 let line_indent = line_indent.unwrap();
702 let base_indent = base_indent.unwrap();
703 let indent_delta = line_indent as i32 - base_indent as i32;
704 let mut corrected_indent_len = cmp::max(
705 0,
706 suggested_line_indent.len as i32 + indent_delta,
707 )
708 as usize;
709 if first_line {
710 corrected_indent_len = corrected_indent_len
711 .saturating_sub(selection_start.column as usize);
712 }
713
714 let indent_char = suggested_line_indent.char();
715 let mut indent_buffer = [0; 4];
716 let indent_str =
717 indent_char.encode_utf8(&mut indent_buffer);
718 new_text.replace_range(
719 ..line_indent,
720 &indent_str.repeat(corrected_indent_len),
721 );
722 }
723
724 if line_indent.is_some() {
725 let char_ops = diff.push_new(&new_text);
726 line_diff.push_char_operations(&char_ops, &selected_text);
727 diff_tx
728 .send((char_ops, line_diff.line_operations()))
729 .await?;
730 new_text.clear();
731 }
732
733 if lines.peek().is_some() {
734 let char_ops = diff.push_new("\n");
735 line_diff.push_char_operations(&char_ops, &selected_text);
736 diff_tx
737 .send((char_ops, line_diff.line_operations()))
738 .await?;
739 if line_indent.is_none() {
                                        // Don't carry the leading indentation of an empty line over to the next line.
                                        // This handles the case where the `if` above did not clear the buffer.
742 new_text.clear();
743 }
744 line_indent = None;
745 first_line = false;
746 }
747 }
748 }
749
750 let mut char_ops = diff.push_new(&new_text);
751 char_ops.extend(diff.finish());
752 line_diff.push_char_operations(&char_ops, &selected_text);
753 line_diff.finish(&selected_text);
754 diff_tx
755 .send((char_ops, line_diff.line_operations()))
756 .await?;
757
758 anyhow::Ok(())
759 };
760
761 let result = diff.await;
762
763 let error_message = result.as_ref().err().map(|error| error.to_string());
764 report_assistant_event(
765 AssistantEventData {
766 conversation_id: None,
767 message_id,
768 kind: AssistantKind::Inline,
769 phase: AssistantPhase::Response,
770 model: model_telemetry_id,
771 model_provider: model_provider_id,
772 response_latency,
773 error_message,
774 language_name: language_name.map(|name| name.to_proto()),
775 },
776 telemetry,
777 http_client,
778 model_api_key,
779 &executor,
780 );
781
782 result?;
783 Ok(())
784 });
785
786 while let Some((char_ops, line_ops)) = diff_rx.next().await {
787 codegen.update(cx, |codegen, cx| {
788 codegen.last_equal_ranges.clear();
789
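                        // Translate the diff's character operations into buffer edits:
                        // `Insert` adds text at the current offset, `Delete` removes the
                        // old bytes, and `Keep` only records the unchanged range in
                        // `last_equal_ranges`.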
790 let edits = char_ops
791 .into_iter()
792 .filter_map(|operation| match operation {
793 CharOperation::Insert { text } => {
794 let edit_start = snapshot.anchor_after(edit_start);
795 Some((edit_start..edit_start, text))
796 }
797 CharOperation::Delete { bytes } => {
798 let edit_end = edit_start + bytes;
799 let edit_range = snapshot.anchor_after(edit_start)
800 ..snapshot.anchor_before(edit_end);
801 edit_start = edit_end;
802 Some((edit_range, String::new()))
803 }
804 CharOperation::Keep { bytes } => {
805 let edit_end = edit_start + bytes;
806 let edit_range = snapshot.anchor_after(edit_start)
807 ..snapshot.anchor_before(edit_end);
808 edit_start = edit_end;
809 codegen.last_equal_ranges.push(edit_range);
810 None
811 }
812 })
813 .collect::<Vec<_>>();
814
815 if codegen.active {
816 codegen.apply_edits(edits.iter().cloned(), cx);
817 codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
818 }
819 codegen.edits.extend(edits);
820 codegen.line_operations = line_ops;
821 codegen.edit_position = Some(snapshot.anchor_after(edit_start));
822
823 cx.notify();
824 })?;
825 }
826
                // Streaming has stopped: the new text is in the buffer and a line-based diff has been applied to the whole of it.
                // That diff is not the same as a regular diff and can look unexpected, so apply a regular diff on top.
                // This is safe even if the line-based diffing above failed, since no more hunks will arrive through `diff_rx`.
830 let batch_diff_task =
831 codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
832 let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
833 line_based_stream_diff?;
834
835 anyhow::Ok(())
836 };
837
838 let result = generate.await;
839 let elapsed_time = start_time.elapsed().as_secs_f64();
840
841 codegen
842 .update(cx, |this, cx| {
843 this.message_id = message_id;
844 this.last_equal_ranges.clear();
845 if let Err(error) = result {
846 this.status = CodegenStatus::Error(error);
847 } else {
848 this.status = CodegenStatus::Done;
849 }
850 this.elapsed_time = Some(elapsed_time);
851 this.completion = Some(completion.lock().clone());
852 if let Some(usage) = token_usage {
853 let usage = usage.lock();
854 telemetry::event!(
855 "Inline Assistant Completion",
856 model = model_telemetry_id,
857 model_provider = model_provider_id,
858 input_tokens = usage.input_tokens,
859 output_tokens = usage.output_tokens,
860 )
861 }
862
863 cx.emit(CodegenEvent::Finished);
864 cx.notify();
865 })
866 .ok();
867 });
868 cx.notify();
869 }
870
871 pub fn stop(&mut self, cx: &mut Context<Self>) {
872 self.last_equal_ranges.clear();
873 if self.diff.is_empty() {
874 self.status = CodegenStatus::Idle;
875 } else {
876 self.status = CodegenStatus::Done;
877 }
878 self.generation = Task::ready(());
879 cx.emit(CodegenEvent::Finished);
880 cx.notify();
881 }
882
883 pub fn undo(&mut self, cx: &mut Context<Self>) {
884 self.buffer.update(cx, |buffer, cx| {
885 if let Some(transaction_id) = self.transformation_transaction_id.take() {
886 buffer.undo_transaction(transaction_id, cx);
887 buffer.refresh_preview(cx);
888 }
889 });
890 }
891
892 fn apply_edits(
893 &mut self,
894 edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
895 cx: &mut Context<CodegenAlternative>,
896 ) {
897 let transaction = self.buffer.update(cx, |buffer, cx| {
898 // Avoid grouping agent edits with user edits.
899 buffer.finalize_last_transaction(cx);
900 buffer.start_transaction(cx);
901 buffer.edit(edits, None, cx);
902 buffer.end_transaction(cx)
903 });
904
905 if let Some(transaction) = transaction {
906 if let Some(first_transaction) = self.transformation_transaction_id {
907 // Group all agent edits into the first transaction.
908 self.buffer.update(cx, |buffer, cx| {
909 buffer.merge_transactions(transaction, first_transaction, cx)
910 });
911 } else {
912 self.transformation_transaction_id = Some(transaction);
913 self.buffer
914 .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
915 }
916 }
917 }
918
919 fn reapply_line_based_diff(
920 &mut self,
921 line_operations: impl IntoIterator<Item = LineOperation>,
922 cx: &mut Context<Self>,
923 ) {
924 let old_snapshot = self.snapshot.clone();
925 let old_range = self.range.to_point(&old_snapshot);
926 let new_snapshot = self.buffer.read(cx).snapshot(cx);
927 let new_range = self.range.to_point(&new_snapshot);
928
929 let mut old_row = old_range.start.row;
930 let mut new_row = new_range.start.row;
931
932 self.diff.deleted_row_ranges.clear();
933 self.diff.inserted_row_ranges.clear();
934 for operation in line_operations {
935 match operation {
936 LineOperation::Keep { lines } => {
937 old_row += lines;
938 new_row += lines;
939 }
940 LineOperation::Delete { lines } => {
941 let old_end_row = old_row + lines - 1;
942 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
943
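                    // Extend the previous deleted range when it is contiguous with this
                    // one; otherwise record a new (new-buffer anchor, old row range) pair.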
944 if let Some((_, last_deleted_row_range)) =
945 self.diff.deleted_row_ranges.last_mut()
946 {
947 if *last_deleted_row_range.end() + 1 == old_row {
948 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
949 } else {
950 self.diff
951 .deleted_row_ranges
952 .push((new_row, old_row..=old_end_row));
953 }
954 } else {
955 self.diff
956 .deleted_row_ranges
957 .push((new_row, old_row..=old_end_row));
958 }
959
960 old_row += lines;
961 }
962 LineOperation::Insert { lines } => {
963 let new_end_row = new_row + lines - 1;
964 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
965 let end = new_snapshot.anchor_before(Point::new(
966 new_end_row,
967 new_snapshot.line_len(MultiBufferRow(new_end_row)),
968 ));
969 self.diff.inserted_row_ranges.push(start..end);
970 new_row += lines;
971 }
972 }
973
974 cx.notify();
975 }
976 }
977
978 fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
979 let old_snapshot = self.snapshot.clone();
980 let old_range = self.range.to_point(&old_snapshot);
981 let new_snapshot = self.buffer.read(cx).snapshot(cx);
982 let new_range = self.range.to_point(&new_snapshot);
983
984 cx.spawn(async move |codegen, cx| {
985 let (deleted_row_ranges, inserted_row_ranges) = cx
986 .background_spawn(async move {
987 let old_text = old_snapshot
988 .text_for_range(
989 Point::new(old_range.start.row, 0)
990 ..Point::new(
991 old_range.end.row,
992 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
993 ),
994 )
995 .collect::<String>();
996 let new_text = new_snapshot
997 .text_for_range(
998 Point::new(new_range.start.row, 0)
999 ..Point::new(
1000 new_range.end.row,
1001 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1002 ),
1003 )
1004 .collect::<String>();
1005
1006 let old_start_row = old_range.start.row;
1007 let new_start_row = new_range.start.row;
1008 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1009 let mut inserted_row_ranges = Vec::new();
1010 for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
1011 let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
1012 let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
1013 if !old_rows.is_empty() {
1014 deleted_row_ranges.push((
1015 new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
1016 old_rows.start..=old_rows.end - 1,
1017 ));
1018 }
1019 if !new_rows.is_empty() {
1020 let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
1021 let new_end_row = new_rows.end - 1;
1022 let end = new_snapshot.anchor_before(Point::new(
1023 new_end_row,
1024 new_snapshot.line_len(MultiBufferRow(new_end_row)),
1025 ));
1026 inserted_row_ranges.push(start..end);
1027 }
1028 }
1029 (deleted_row_ranges, inserted_row_ranges)
1030 })
1031 .await;
1032
1033 codegen
1034 .update(cx, |codegen, cx| {
1035 codegen.diff.deleted_row_ranges = deleted_row_ranges;
1036 codegen.diff.inserted_row_ranges = inserted_row_ranges;
1037 cx.notify();
1038 })
1039 .ok();
1040 })
1041 }
1042
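    /// Handles the tool-based (v2) response: `rewrite_section` replaces the selected
    /// range with the provided text and recomputes the diff, while `failure_message`
    /// only surfaces the model's explanation to the user.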
1043 fn handle_tool_use(
1044 &mut self,
1045 _telemetry_id: String,
1046 _provider_id: String,
1047 _api_key: Option<String>,
1048 tool_use: impl 'static
1049 + Future<
1050 Output = Result<language_model::LanguageModelToolUse, LanguageModelCompletionError>,
1051 >,
1052 cx: &mut Context<Self>,
1053 ) {
1054 self.diff = Diff::default();
1055 self.status = CodegenStatus::Pending;
1056
1057 self.generation = cx.spawn(async move |codegen, cx| {
1058 let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
1059 let _ = codegen.update(cx, |this, cx| {
1060 this.status = status;
1061 cx.emit(CodegenEvent::Finished);
1062 cx.notify();
1063 });
1064 };
1065
1066 let tool_use = tool_use.await;
1067
1068 match tool_use {
1069 Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => {
1070 // Parse the input JSON into RewriteSectionInput
1071 match serde_json::from_value::<RewriteSectionInput>(tool_use.input) {
1072 Ok(input) => {
1073 // Store the description if non-empty
1074 let description = if !input.description.trim().is_empty() {
1075 Some(input.description.clone())
1076 } else {
1077 None
1078 };
1079
1080 // Apply the replacement text to the buffer and compute diff
1081 let batch_diff_task = codegen
1082 .update(cx, |this, cx| {
1083 this.model_explanation = description.map(Into::into);
1084 let range = this.range.clone();
1085 this.apply_edits(
1086 std::iter::once((range, input.replacement_text)),
1087 cx,
1088 );
1089 this.reapply_batch_diff(cx)
1090 })
1091 .ok();
1092
1093 // Wait for the diff computation to complete
1094 if let Some(diff_task) = batch_diff_task {
1095 diff_task.await;
1096 }
1097
1098 finish_with_status(CodegenStatus::Done, cx);
1099 return;
1100 }
1101 Err(e) => {
1102 finish_with_status(CodegenStatus::Error(e.into()), cx);
1103 return;
1104 }
1105 }
1106 }
1107 Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => {
1108 // Handle failure message tool use
1109 match serde_json::from_value::<FailureMessageInput>(tool_use.input) {
1110 Ok(input) => {
1111 let _ = codegen.update(cx, |this, _cx| {
                                // Surface the failure message to the user as the model's explanation.
1113 this.model_explanation = Some(input.message.into());
1114 });
1115 finish_with_status(CodegenStatus::Done, cx);
1116 return;
1117 }
1118 Err(e) => {
1119 finish_with_status(CodegenStatus::Error(e.into()), cx);
1120 return;
1121 }
1122 }
1123 }
1124 Ok(_tool_use) => {
                    // The model called a tool we don't recognize; treat the generation as finished.
1126 finish_with_status(CodegenStatus::Done, cx);
1127 return;
1128 }
1129 Err(e) => {
1130 finish_with_status(CodegenStatus::Error(e.into()), cx);
1131 return;
1132 }
1133 }
1134 });
1135 cx.notify();
1136 }
1137}
1138
1139#[derive(Copy, Clone, Debug)]
1140pub enum CodegenEvent {
1141 Finished,
1142 Undone,
1143}
1144
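/// Wraps the model's text stream and removes spans that should not reach the buffer:
/// a markdown code fence surrounding the whole response and any `<|CURSOR|>` markers,
/// buffering partial lines until it can tell them apart.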
1145struct StripInvalidSpans<T> {
1146 stream: T,
1147 stream_done: bool,
1148 buffer: String,
1149 first_line: bool,
1150 line_end: bool,
1151 starts_with_code_block: bool,
1152}
1153
1154impl<T> StripInvalidSpans<T>
1155where
1156 T: Stream<Item = Result<String>>,
1157{
1158 fn new(stream: T) -> Self {
1159 Self {
1160 stream,
1161 stream_done: false,
1162 buffer: String::new(),
1163 first_line: true,
1164 line_end: false,
1165 starts_with_code_block: false,
1166 }
1167 }
1168}
1169
1170impl<T> Stream for StripInvalidSpans<T>
1171where
1172 T: Stream<Item = Result<String>>,
1173{
1174 type Item = Result<String>;
1175
1176 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
1177 const CODE_BLOCK_DELIMITER: &str = "```";
1178 const CURSOR_SPAN: &str = "<|CURSOR|>";
1179
1180 let this = unsafe { self.get_unchecked_mut() };
1181 loop {
1182 if !this.stream_done {
1183 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
1184 match stream.as_mut().poll_next(cx) {
1185 Poll::Ready(Some(Ok(chunk))) => {
1186 this.buffer.push_str(&chunk);
1187 }
1188 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
1189 Poll::Ready(None) => {
1190 this.stream_done = true;
1191 }
1192 Poll::Pending => return Poll::Pending,
1193 }
1194 }
1195
1196 let mut chunk = String::new();
1197 let mut consumed = 0;
1198 if !this.buffer.is_empty() {
1199 let mut lines = this.buffer.split('\n').enumerate().peekable();
1200 while let Some((line_ix, line)) = lines.next() {
1201 if line_ix > 0 {
1202 this.first_line = false;
1203 }
1204
1205 if this.first_line {
1206 let trimmed_line = line.trim();
1207 if lines.peek().is_some() {
1208 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
1209 consumed += line.len() + 1;
1210 this.starts_with_code_block = true;
1211 continue;
1212 }
1213 } else if trimmed_line.is_empty()
1214 || prefixes(CODE_BLOCK_DELIMITER)
1215 .any(|prefix| trimmed_line.starts_with(prefix))
1216 {
1217 break;
1218 }
1219 }
1220
1221 let line_without_cursor = line.replace(CURSOR_SPAN, "");
1222 if lines.peek().is_some() {
1223 if this.line_end {
1224 chunk.push('\n');
1225 }
1226
1227 chunk.push_str(&line_without_cursor);
1228 this.line_end = true;
1229 consumed += line.len() + 1;
1230 } else if this.stream_done {
1231 if !this.starts_with_code_block
1232 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1233 {
1234 if this.line_end {
1235 chunk.push('\n');
1236 }
1237
1238 chunk.push_str(line);
1239 }
1240
1241 consumed += line.len();
1242 } else {
1243 let trimmed_line = line.trim();
1244 if trimmed_line.is_empty()
1245 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1246 || prefixes(CODE_BLOCK_DELIMITER)
1247 .any(|prefix| trimmed_line.ends_with(prefix))
1248 {
1249 break;
1250 } else {
1251 if this.line_end {
1252 chunk.push('\n');
1253 this.line_end = false;
1254 }
1255
1256 chunk.push_str(&line_without_cursor);
1257 consumed += line.len();
1258 }
1259 }
1260 }
1261 }
1262
1263 this.buffer = this.buffer.split_off(consumed);
1264 if !chunk.is_empty() {
1265 return Poll::Ready(Some(Ok(chunk)));
1266 } else if this.stream_done {
1267 return Poll::Ready(None);
1268 }
1269 }
1270 }
1271}
1272
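/// Yields every proper, non-empty prefix of `text`, used to detect a code-block
/// delimiter or cursor marker that has only partially arrived in a streamed chunk.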
1273fn prefixes(text: &str) -> impl Iterator<Item = &str> {
1274 (0..text.len() - 1).map(|ix| &text[..ix + 1])
1275}
1276
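/// Row ranges the UI should render as deleted (each anchored to its position in the
/// new buffer, paired with the old rows it replaces) or as inserted.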
1277#[derive(Default)]
1278pub struct Diff {
1279 pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
1280 pub inserted_row_ranges: Vec<Range<Anchor>>,
1281}
1282
1283impl Diff {
1284 fn is_empty(&self) -> bool {
1285 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1286 }
1287}
1288
1289#[cfg(test)]
1290mod tests {
1291 use super::*;
1292 use futures::{
1293 Stream,
1294 stream::{self},
1295 };
1296 use gpui::TestAppContext;
1297 use indoc::indoc;
1298 use language::{Buffer, Point};
1299 use language_model::{LanguageModelRegistry, TokenUsage};
1300 use languages::rust_lang;
1301 use rand::prelude::*;
1302 use settings::SettingsStore;
1303 use std::{future, sync::Arc};
1304
1305 #[gpui::test(iterations = 10)]
1306 async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
1307 init_test(cx);
1308
1309 let text = indoc! {"
1310 fn main() {
1311 let x = 0;
1312 for _ in 0..10 {
1313 x += 1;
1314 }
1315 }
1316 "};
1317 let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1318 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1319 let range = buffer.read_with(cx, |buffer, cx| {
1320 let snapshot = buffer.snapshot(cx);
1321 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
1322 });
1323 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1324 let codegen = cx.new(|cx| {
1325 CodegenAlternative::new(
1326 buffer.clone(),
1327 range.clone(),
1328 true,
1329 None,
1330 prompt_builder,
1331 cx,
1332 )
1333 });
1334
1335 let chunks_tx = simulate_response_stream(&codegen, cx);
1336
1337 let mut new_text = concat!(
1338 " let mut x = 0;\n",
1339 " while x < 10 {\n",
1340 " x += 1;\n",
1341 " }",
1342 );
1343 while !new_text.is_empty() {
1344 let max_len = cmp::min(new_text.len(), 10);
1345 let len = rng.random_range(1..=max_len);
1346 let (chunk, suffix) = new_text.split_at(len);
1347 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1348 new_text = suffix;
1349 cx.background_executor.run_until_parked();
1350 }
1351 drop(chunks_tx);
1352 cx.background_executor.run_until_parked();
1353
1354 assert_eq!(
1355 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1356 indoc! {"
1357 fn main() {
1358 let mut x = 0;
1359 while x < 10 {
1360 x += 1;
1361 }
1362 }
1363 "}
1364 );
1365 }
1366
1367 #[gpui::test(iterations = 10)]
1368 async fn test_autoindent_when_generating_past_indentation(
1369 cx: &mut TestAppContext,
1370 mut rng: StdRng,
1371 ) {
1372 init_test(cx);
1373
1374 let text = indoc! {"
1375 fn main() {
1376 le
1377 }
1378 "};
1379 let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1380 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1381 let range = buffer.read_with(cx, |buffer, cx| {
1382 let snapshot = buffer.snapshot(cx);
1383 snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
1384 });
1385 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1386 let codegen = cx.new(|cx| {
1387 CodegenAlternative::new(
1388 buffer.clone(),
1389 range.clone(),
1390 true,
1391 None,
1392 prompt_builder,
1393 cx,
1394 )
1395 });
1396
1397 let chunks_tx = simulate_response_stream(&codegen, cx);
1398
1399 cx.background_executor.run_until_parked();
1400
1401 let mut new_text = concat!(
1402 "t mut x = 0;\n",
1403 "while x < 10 {\n",
1404 " x += 1;\n",
1405 "}", //
1406 );
1407 while !new_text.is_empty() {
1408 let max_len = cmp::min(new_text.len(), 10);
1409 let len = rng.random_range(1..=max_len);
1410 let (chunk, suffix) = new_text.split_at(len);
1411 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1412 new_text = suffix;
1413 cx.background_executor.run_until_parked();
1414 }
1415 drop(chunks_tx);
1416 cx.background_executor.run_until_parked();
1417
1418 assert_eq!(
1419 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1420 indoc! {"
1421 fn main() {
1422 let mut x = 0;
1423 while x < 10 {
1424 x += 1;
1425 }
1426 }
1427 "}
1428 );
1429 }
1430
1431 #[gpui::test(iterations = 10)]
1432 async fn test_autoindent_when_generating_before_indentation(
1433 cx: &mut TestAppContext,
1434 mut rng: StdRng,
1435 ) {
1436 init_test(cx);
1437
1438 let text = concat!(
1439 "fn main() {\n",
1440 " \n",
1441 "}\n" //
1442 );
1443 let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1444 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1445 let range = buffer.read_with(cx, |buffer, cx| {
1446 let snapshot = buffer.snapshot(cx);
1447 snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
1448 });
1449 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1450 let codegen = cx.new(|cx| {
1451 CodegenAlternative::new(
1452 buffer.clone(),
1453 range.clone(),
1454 true,
1455 None,
1456 prompt_builder,
1457 cx,
1458 )
1459 });
1460
1461 let chunks_tx = simulate_response_stream(&codegen, cx);
1462
1463 cx.background_executor.run_until_parked();
1464
1465 let mut new_text = concat!(
1466 "let mut x = 0;\n",
1467 "while x < 10 {\n",
1468 " x += 1;\n",
1469 "}", //
1470 );
1471 while !new_text.is_empty() {
1472 let max_len = cmp::min(new_text.len(), 10);
1473 let len = rng.random_range(1..=max_len);
1474 let (chunk, suffix) = new_text.split_at(len);
1475 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1476 new_text = suffix;
1477 cx.background_executor.run_until_parked();
1478 }
1479 drop(chunks_tx);
1480 cx.background_executor.run_until_parked();
1481
1482 assert_eq!(
1483 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1484 indoc! {"
1485 fn main() {
1486 let mut x = 0;
1487 while x < 10 {
1488 x += 1;
1489 }
1490 }
1491 "}
1492 );
1493 }
1494
1495 #[gpui::test(iterations = 10)]
1496 async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
1497 init_test(cx);
1498
1499 let text = indoc! {"
1500 func main() {
1501 \tx := 0
1502 \tfor i := 0; i < 10; i++ {
1503 \t\tx++
1504 \t}
1505 }
1506 "};
1507 let buffer = cx.new(|cx| Buffer::local(text, cx));
1508 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1509 let range = buffer.read_with(cx, |buffer, cx| {
1510 let snapshot = buffer.snapshot(cx);
1511 snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
1512 });
1513 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1514 let codegen = cx.new(|cx| {
1515 CodegenAlternative::new(
1516 buffer.clone(),
1517 range.clone(),
1518 true,
1519 None,
1520 prompt_builder,
1521 cx,
1522 )
1523 });
1524
1525 let chunks_tx = simulate_response_stream(&codegen, cx);
1526 let new_text = concat!(
1527 "func main() {\n",
1528 "\tx := 0\n",
1529 "\tfor x < 10 {\n",
1530 "\t\tx++\n",
1531 "\t}", //
1532 );
1533 chunks_tx.unbounded_send(new_text.to_string()).unwrap();
1534 drop(chunks_tx);
1535 cx.background_executor.run_until_parked();
1536
1537 assert_eq!(
1538 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1539 indoc! {"
1540 func main() {
1541 \tx := 0
1542 \tfor x < 10 {
1543 \t\tx++
1544 \t}
1545 }
1546 "}
1547 );
1548 }
1549
1550 #[gpui::test]
1551 async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
1552 init_test(cx);
1553
1554 let text = indoc! {"
1555 fn main() {
1556 let x = 0;
1557 }
1558 "};
1559 let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1560 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1561 let range = buffer.read_with(cx, |buffer, cx| {
1562 let snapshot = buffer.snapshot(cx);
1563 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
1564 });
1565 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1566 let codegen = cx.new(|cx| {
1567 CodegenAlternative::new(
1568 buffer.clone(),
1569 range.clone(),
1570 false,
1571 None,
1572 prompt_builder,
1573 cx,
1574 )
1575 });
1576
1577 let chunks_tx = simulate_response_stream(&codegen, cx);
1578 chunks_tx
1579 .unbounded_send("let mut x = 0;\nx += 1;".to_string())
1580 .unwrap();
1581 drop(chunks_tx);
1582 cx.run_until_parked();
1583
1584 // The codegen is inactive, so the buffer doesn't get modified.
1585 assert_eq!(
1586 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1587 text
1588 );
1589
1590 // Activating the codegen applies the changes.
1591 codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
1592 assert_eq!(
1593 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1594 indoc! {"
1595 fn main() {
1596 let mut x = 0;
1597 x += 1;
1598 }
1599 "}
1600 );
1601
1602 // Deactivating the codegen undoes the changes.
1603 codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
1604 cx.run_until_parked();
1605 assert_eq!(
1606 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1607 text
1608 );
1609 }
1610
1611 #[gpui::test]
1612 async fn test_strip_invalid_spans_from_codeblock() {
1613 assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
1614 assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
1615 assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
1616 assert_chunks(
1617 "```html\n```js\nLorem ipsum dolor\n```\n```",
1618 "```js\nLorem ipsum dolor\n```",
1619 )
1620 .await;
1621 assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
1622 assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
1623 assert_chunks("Lorem ipsum", "Lorem ipsum").await;
1624 assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
1625
1626 async fn assert_chunks(text: &str, expected_text: &str) {
1627 for chunk_size in 1..=text.len() {
1628 let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
1629 .map(|chunk| chunk.unwrap())
1630 .collect::<String>()
1631 .await;
1632 assert_eq!(
1633 actual_text, expected_text,
1634 "failed to strip invalid spans, chunk size: {}",
1635 chunk_size
1636 );
1637 }
1638 }
1639
1640 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
1641 stream::iter(
1642 text.chars()
1643 .collect::<Vec<_>>()
1644 .chunks(size)
1645 .map(|chunk| Ok(chunk.iter().collect::<String>()))
1646 .collect::<Vec<_>>(),
1647 )
1648 }
1649 }
1650
1651 fn init_test(cx: &mut TestAppContext) {
1652 cx.update(LanguageModelRegistry::test);
1653 cx.set_global(cx.update(SettingsStore::test));
1654 }
1655
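    /// Starts a generation on `codegen` that reads from an in-memory channel, so
    /// tests can feed response chunks by hand.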
1656 fn simulate_response_stream(
1657 codegen: &Entity<CodegenAlternative>,
1658 cx: &mut TestAppContext,
1659 ) -> mpsc::UnboundedSender<String> {
1660 let (chunks_tx, chunks_rx) = mpsc::unbounded();
1661 codegen.update(cx, |codegen, cx| {
1662 codegen.handle_stream(
1663 String::new(),
1664 String::new(),
1665 None,
1666 future::ready(Ok(LanguageModelTextStream {
1667 message_id: None,
1668 stream: chunks_rx.map(Ok).boxed(),
1669 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
1670 })),
1671 cx,
1672 );
1673 });
1674 chunks_tx
1675 }
1676}