1use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
2use agent_settings::AgentSettings;
3use anyhow::{Context as _, Result};
4use client::telemetry::Telemetry;
5use cloud_llm_client::CompletionIntent;
6use collections::HashSet;
7use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
8use feature_flags::{FeatureFlagAppExt as _, InlineAssistantV2FeatureFlag};
9use futures::{
10 SinkExt, Stream, StreamExt, TryStreamExt as _,
11 channel::mpsc,
12 future::{LocalBoxFuture, Shared},
13 join,
14};
15use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
16use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
17use language_model::{
18 LanguageModel, LanguageModelCompletionError, LanguageModelRegistry, LanguageModelRequest,
19 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelTextStream, Role,
20 report_assistant_event,
21};
22use multi_buffer::MultiBufferRow;
23use parking_lot::Mutex;
24use prompt_store::PromptBuilder;
25use rope::Rope;
26use schemars::JsonSchema;
27use serde::{Deserialize, Serialize};
28use smol::future::FutureExt;
29use std::{
30 cmp,
31 future::Future,
32 iter,
33 ops::{Range, RangeInclusive},
34 pin::Pin,
35 sync::Arc,
36 task::{self, Poll},
37 time::Instant,
38};
39use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
40use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
41use ui::SharedString;
42
/// Use this tool to provide a message to the user when you're unable to complete a task.
// NOTE: the `///` doc comments on this type and its field are runtime behavior, not just
// documentation: the `JsonSchema` derive (schemars) uses them as `description`s in the
// input schema passed to the model for the "failure_message" tool
// (`root_schema_for::<FailureMessageInput>` in `build_request_v2`). Treat them as prompt text.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FailureMessageInput {
    /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
    ///
    /// The message may use markdown formatting if you wish.
    pub message: String,
}
51
52/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
53#[derive(Debug, Serialize, Deserialize, JsonSchema)]
54pub struct RewriteSectionInput {
55 /// A brief description of the edit you have made.
56 ///
57 /// The description may use markdown formatting if you wish.
58 /// This is optional - if the edit is simple or obvious, you should leave it empty.
59 pub description: String,
60
61 /// The text to replace the section with.
62 pub replacement_text: String,
63}
64
/// Drives inline code generation for one range of a multibuffer.
///
/// Holds one [`CodegenAlternative`] per model: the primary model plus each configured
/// inline alternative model. Only the active alternative's edits are applied to the
/// buffer; the others keep their edits recorded and re-apply them when activated.
pub struct BufferCodegen {
    /// One alternative per model; index 0 belongs to the primary model.
    alternatives: Vec<Entity<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative currently shown in the buffer.
    pub active_alternative: usize,
    /// Every index that has been passed to `activate`.
    seen_alternatives: HashSet<usize>,
    /// Observation + event-forwarding subscriptions for the active alternative only.
    subscriptions: Vec<Subscription>,
    /// The multibuffer being edited.
    buffer: Entity<MultiBuffer>,
    /// The range new alternatives are created for.
    range: Range<Anchor>,
    /// A transaction that `undo` also rolls back (once) after undoing the active alternative.
    initial_transaction_id: Option<TransactionId>,
    telemetry: Arc<Telemetry>,
    builder: Arc<PromptBuilder>,
    /// Whether `range` was empty at construction, i.e. this is an insertion rather than a rewrite.
    pub is_insertion: bool,
}
77
78impl BufferCodegen {
79 pub fn new(
80 buffer: Entity<MultiBuffer>,
81 range: Range<Anchor>,
82 initial_transaction_id: Option<TransactionId>,
83 telemetry: Arc<Telemetry>,
84 builder: Arc<PromptBuilder>,
85 cx: &mut Context<Self>,
86 ) -> Self {
87 let codegen = cx.new(|cx| {
88 CodegenAlternative::new(
89 buffer.clone(),
90 range.clone(),
91 false,
92 Some(telemetry.clone()),
93 builder.clone(),
94 cx,
95 )
96 });
97 let mut this = Self {
98 is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
99 alternatives: vec![codegen],
100 active_alternative: 0,
101 seen_alternatives: HashSet::default(),
102 subscriptions: Vec::new(),
103 buffer,
104 range,
105 initial_transaction_id,
106 telemetry,
107 builder,
108 };
109 this.activate(0, cx);
110 this
111 }
112
113 fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
114 let codegen = self.active_alternative().clone();
115 self.subscriptions.clear();
116 self.subscriptions
117 .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
118 self.subscriptions
119 .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
120 }
121
122 pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
123 &self.alternatives[self.active_alternative]
124 }
125
126 pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
127 &self.active_alternative().read(cx).status
128 }
129
130 pub fn alternative_count(&self, cx: &App) -> usize {
131 LanguageModelRegistry::read_global(cx)
132 .inline_alternative_models()
133 .len()
134 + 1
135 }
136
137 pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
138 let next_active_ix = if self.active_alternative == 0 {
139 self.alternatives.len() - 1
140 } else {
141 self.active_alternative - 1
142 };
143 self.activate(next_active_ix, cx);
144 }
145
146 pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
147 let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
148 self.activate(next_active_ix, cx);
149 }
150
151 fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
152 self.active_alternative()
153 .update(cx, |codegen, cx| codegen.set_active(false, cx));
154 self.seen_alternatives.insert(index);
155 self.active_alternative = index;
156 self.active_alternative()
157 .update(cx, |codegen, cx| codegen.set_active(true, cx));
158 self.subscribe_to_alternative(cx);
159 cx.notify();
160 }
161
162 pub fn start(
163 &mut self,
164 primary_model: Arc<dyn LanguageModel>,
165 user_prompt: String,
166 context_task: Shared<Task<Option<LoadedContext>>>,
167 cx: &mut Context<Self>,
168 ) -> Result<()> {
169 let alternative_models = LanguageModelRegistry::read_global(cx)
170 .inline_alternative_models()
171 .to_vec();
172
173 self.active_alternative()
174 .update(cx, |alternative, cx| alternative.undo(cx));
175 self.activate(0, cx);
176 self.alternatives.truncate(1);
177
178 for _ in 0..alternative_models.len() {
179 self.alternatives.push(cx.new(|cx| {
180 CodegenAlternative::new(
181 self.buffer.clone(),
182 self.range.clone(),
183 false,
184 Some(self.telemetry.clone()),
185 self.builder.clone(),
186 cx,
187 )
188 }));
189 }
190
191 for (model, alternative) in iter::once(primary_model)
192 .chain(alternative_models)
193 .zip(&self.alternatives)
194 {
195 alternative.update(cx, |alternative, cx| {
196 alternative.start(user_prompt.clone(), context_task.clone(), model.clone(), cx)
197 })?;
198 }
199
200 Ok(())
201 }
202
203 pub fn stop(&mut self, cx: &mut Context<Self>) {
204 for codegen in &self.alternatives {
205 codegen.update(cx, |codegen, cx| codegen.stop(cx));
206 }
207 }
208
209 pub fn undo(&mut self, cx: &mut Context<Self>) {
210 self.active_alternative()
211 .update(cx, |codegen, cx| codegen.undo(cx));
212
213 self.buffer.update(cx, |buffer, cx| {
214 if let Some(transaction_id) = self.initial_transaction_id.take() {
215 buffer.undo_transaction(transaction_id, cx);
216 buffer.refresh_preview(cx);
217 }
218 });
219 }
220
221 pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
222 self.active_alternative().read(cx).buffer.clone()
223 }
224
225 pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
226 self.active_alternative().read(cx).old_buffer.clone()
227 }
228
229 pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
230 self.active_alternative().read(cx).snapshot.clone()
231 }
232
233 pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
234 self.active_alternative().read(cx).edit_position
235 }
236
237 pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
238 &self.active_alternative().read(cx).diff
239 }
240
241 pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
242 self.active_alternative().read(cx).last_equal_ranges()
243 }
244}
245
// `BufferCodegen` re-emits the `CodegenEvent`s of its active alternative
// (see `subscribe_to_alternative`).
impl EventEmitter<CodegenEvent> for BufferCodegen {}
247
/// One model's attempt at rewriting a range of a multibuffer.
///
/// Its edits are applied to the buffer only while `active`. The edits are always
/// recorded, so they can be re-applied when the alternative is re-activated.
pub struct CodegenAlternative {
    /// The multibuffer being edited.
    buffer: Entity<MultiBuffer>,
    /// A local buffer holding a copy of the text, language and registry of the buffer
    /// under `range` at construction time.
    old_buffer: Entity<Buffer>,
    /// Reference snapshot that streamed diffs are expressed against. It is refreshed
    /// when a text stream starts.
    snapshot: MultiBufferSnapshot,
    /// Current write position of an in-progress generation.
    edit_position: Option<Anchor>,
    /// The range being rewritten. It is re-resolved against a fresh snapshot when a
    /// text stream starts.
    range: Range<Anchor>,
    /// Ranges left unchanged ("Keep") by the latest batch of streamed character operations.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// The buffer transaction that all of this alternative's edits are merged into.
    transformation_transaction_id: Option<TransactionId>,
    status: CodegenStatus,
    /// The in-flight generation task; it is replaced by `Task::ready(())` to drop it.
    generation: Task<()>,
    /// Inserted/deleted row ranges used to render the change.
    diff: Diff,
    telemetry: Option<Arc<Telemetry>>,
    /// Subscription to `buffer` events (see `handle_buffer_event`).
    _subscription: gpui::Subscription,
    builder: Arc<PromptBuilder>,
    /// Whether this alternative's edits are currently applied to `buffer`.
    active: bool,
    /// Every edit produced by streaming, replayed by `set_active(true)`.
    edits: Vec<(Range<Anchor>, String)>,
    /// Latest line-level diff from streaming, used to redraw the diff while pending.
    line_operations: Vec<LineOperation>,
    /// Seconds spent in the last text-stream generation.
    elapsed_time: Option<f64>,
    /// Raw completion text accumulated from the last text stream.
    completion: Option<String>,
    pub message_id: Option<String>,
    /// Description supplied by the model alongside a tool-based rewrite, if any.
    pub model_explanation: Option<SharedString>,
}
270
// Emits `CodegenEvent::Finished` when generation ends or is stopped, and
// `CodegenEvent::Undone` when its transaction is undone in the buffer.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
272
273impl CodegenAlternative {
274 pub fn new(
275 buffer: Entity<MultiBuffer>,
276 range: Range<Anchor>,
277 active: bool,
278 telemetry: Option<Arc<Telemetry>>,
279 builder: Arc<PromptBuilder>,
280 cx: &mut Context<Self>,
281 ) -> Self {
282 let snapshot = buffer.read(cx).snapshot(cx);
283
284 let (old_buffer, _, _) = snapshot
285 .range_to_buffer_ranges(range.clone())
286 .pop()
287 .unwrap();
288 let old_buffer = cx.new(|cx| {
289 let text = old_buffer.as_rope().clone();
290 let line_ending = old_buffer.line_ending();
291 let language = old_buffer.language().cloned();
292 let language_registry = buffer
293 .read(cx)
294 .buffer(old_buffer.remote_id())
295 .unwrap()
296 .read(cx)
297 .language_registry();
298
299 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
300 buffer.set_language(language, cx);
301 if let Some(language_registry) = language_registry {
302 buffer.set_language_registry(language_registry);
303 }
304 buffer
305 });
306
307 Self {
308 buffer: buffer.clone(),
309 old_buffer,
310 edit_position: None,
311 message_id: None,
312 snapshot,
313 last_equal_ranges: Default::default(),
314 transformation_transaction_id: None,
315 status: CodegenStatus::Idle,
316 generation: Task::ready(()),
317 diff: Diff::default(),
318 telemetry,
319 builder,
320 active: active,
321 edits: Vec::new(),
322 line_operations: Vec::new(),
323 range,
324 elapsed_time: None,
325 completion: None,
326 model_explanation: None,
327 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
328 }
329 }
330
331 pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
332 if active != self.active {
333 self.active = active;
334
335 if self.active {
336 let edits = self.edits.clone();
337 self.apply_edits(edits, cx);
338 if matches!(self.status, CodegenStatus::Pending) {
339 let line_operations = self.line_operations.clone();
340 self.reapply_line_based_diff(line_operations, cx);
341 } else {
342 self.reapply_batch_diff(cx).detach();
343 }
344 } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
345 self.buffer.update(cx, |buffer, cx| {
346 buffer.undo_transaction(transaction_id, cx);
347 buffer.forget_transaction(transaction_id, cx);
348 });
349 }
350 }
351 }
352
353 fn handle_buffer_event(
354 &mut self,
355 _buffer: Entity<MultiBuffer>,
356 event: &multi_buffer::Event,
357 cx: &mut Context<Self>,
358 ) {
359 if let multi_buffer::Event::TransactionUndone { transaction_id } = event
360 && self.transformation_transaction_id == Some(*transaction_id)
361 {
362 self.transformation_transaction_id = None;
363 self.generation = Task::ready(());
364 cx.emit(CodegenEvent::Undone);
365 }
366 }
367
368 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
369 &self.last_equal_ranges
370 }
371
372 pub fn start(
373 &mut self,
374 user_prompt: String,
375 context_task: Shared<Task<Option<LoadedContext>>>,
376 model: Arc<dyn LanguageModel>,
377 cx: &mut Context<Self>,
378 ) -> Result<()> {
379 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
380 self.buffer.update(cx, |buffer, cx| {
381 buffer.undo_transaction(transformation_transaction_id, cx);
382 });
383 }
384
385 self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
386
387 let api_key = model.api_key(cx);
388 let telemetry_id = model.telemetry_id();
389 let provider_id = model.provider_id();
390
391 if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
392 let request = self.build_request(&model, user_prompt, context_task, cx)?;
393 let tool_use =
394 cx.spawn(async move |_, cx| model.stream_completion_tool(request.await, cx).await);
395 self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx);
396 } else {
397 let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
398 if user_prompt.trim().to_lowercase() == "delete" {
399 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
400 } else {
401 let request = self.build_request(&model, user_prompt, context_task, cx)?;
402 cx.spawn(async move |_, cx| {
403 Ok(model.stream_completion_text(request.await, cx).await?)
404 })
405 .boxed_local()
406 };
407 self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
408 }
409
410 Ok(())
411 }
412
413 fn build_request_v2(
414 &self,
415 model: &Arc<dyn LanguageModel>,
416 user_prompt: String,
417 context_task: Shared<Task<Option<LoadedContext>>>,
418 cx: &mut App,
419 ) -> Result<Task<LanguageModelRequest>> {
420 let buffer = self.buffer.read(cx).snapshot(cx);
421 let language = buffer.language_at(self.range.start);
422 let language_name = if let Some(language) = language.as_ref() {
423 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
424 None
425 } else {
426 Some(language.name())
427 }
428 } else {
429 None
430 };
431
432 let language_name = language_name.as_ref();
433 let start = buffer.point_to_buffer_offset(self.range.start);
434 let end = buffer.point_to_buffer_offset(self.range.end);
435 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
436 let (start_buffer, start_buffer_offset) = start;
437 let (end_buffer, end_buffer_offset) = end;
438 if start_buffer.remote_id() == end_buffer.remote_id() {
439 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
440 } else {
441 anyhow::bail!("invalid transformation range");
442 }
443 } else {
444 anyhow::bail!("invalid transformation range");
445 };
446
447 let system_prompt = self
448 .builder
449 .generate_inline_transformation_prompt_v2(
450 language_name,
451 buffer,
452 range.start.0..range.end.0,
453 )
454 .context("generating content prompt")?;
455
456 let temperature = AgentSettings::temperature_for_model(model, cx);
457
458 let tool_input_format = model.tool_input_format();
459
460 Ok(cx.spawn(async move |_cx| {
461 let mut messages = vec![LanguageModelRequestMessage {
462 role: Role::System,
463 content: vec![system_prompt.into()],
464 cache: false,
465 reasoning_details: None,
466 }];
467
468 let mut user_message = LanguageModelRequestMessage {
469 role: Role::User,
470 content: Vec::new(),
471 cache: false,
472 reasoning_details: None,
473 };
474
475 if let Some(context) = context_task.await {
476 context.add_to_request_message(&mut user_message);
477 }
478
479 user_message.content.push(user_prompt.into());
480 messages.push(user_message);
481
482 let tools = vec![
483 LanguageModelRequestTool {
484 name: "rewrite_section".to_string(),
485 description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
486 input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
487 },
488 LanguageModelRequestTool {
489 name: "failure_message".to_string(),
490 description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
491 input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
492 },
493 ];
494
495 LanguageModelRequest {
496 thread_id: None,
497 prompt_id: None,
498 intent: Some(CompletionIntent::InlineAssist),
499 mode: None,
500 tools,
501 tool_choice: None,
502 stop: Vec::new(),
503 temperature,
504 messages,
505 thinking_allowed: false,
506 }
507 }))
508 }
509
510 fn build_request(
511 &self,
512 model: &Arc<dyn LanguageModel>,
513 user_prompt: String,
514 context_task: Shared<Task<Option<LoadedContext>>>,
515 cx: &mut App,
516 ) -> Result<Task<LanguageModelRequest>> {
517 if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
518 return self.build_request_v2(model, user_prompt, context_task, cx);
519 }
520
521 let buffer = self.buffer.read(cx).snapshot(cx);
522 let language = buffer.language_at(self.range.start);
523 let language_name = if let Some(language) = language.as_ref() {
524 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
525 None
526 } else {
527 Some(language.name())
528 }
529 } else {
530 None
531 };
532
533 let language_name = language_name.as_ref();
534 let start = buffer.point_to_buffer_offset(self.range.start);
535 let end = buffer.point_to_buffer_offset(self.range.end);
536 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
537 let (start_buffer, start_buffer_offset) = start;
538 let (end_buffer, end_buffer_offset) = end;
539 if start_buffer.remote_id() == end_buffer.remote_id() {
540 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
541 } else {
542 anyhow::bail!("invalid transformation range");
543 }
544 } else {
545 anyhow::bail!("invalid transformation range");
546 };
547
548 let prompt = self
549 .builder
550 .generate_inline_transformation_prompt(
551 user_prompt,
552 language_name,
553 buffer,
554 range.start.0..range.end.0,
555 )
556 .context("generating content prompt")?;
557
558 let temperature = AgentSettings::temperature_for_model(model, cx);
559
560 Ok(cx.spawn(async move |_cx| {
561 let mut request_message = LanguageModelRequestMessage {
562 role: Role::User,
563 content: Vec::new(),
564 cache: false,
565 reasoning_details: None,
566 };
567
568 if let Some(context) = context_task.await {
569 context.add_to_request_message(&mut request_message);
570 }
571
572 request_message.content.push(prompt.into());
573
574 LanguageModelRequest {
575 thread_id: None,
576 prompt_id: None,
577 intent: Some(CompletionIntent::InlineAssist),
578 mode: None,
579 tools: Vec::new(),
580 tool_choice: None,
581 stop: Vec::new(),
582 temperature,
583 messages: vec![request_message],
584 thinking_allowed: false,
585 }
586 }))
587 }
588
    /// Streams a plain-text completion into the buffer, replacing `self.range`.
    ///
    /// Before streaming, `self.snapshot` and `self.range` are re-resolved against the
    /// current buffer contents, and `diff`/`status` are reset to empty/`Pending`.
    ///
    /// On a background task, each incoming chunk is:
    /// - appended to the shared completion text;
    /// - re-indented line by line, so the generated indentation is expressed relative to
    ///   the suggested indent of the selection's first row;
    /// - diffed against the originally selected text. The resulting character- and
    ///   line-level operations are sent back over a bounded channel.
    ///
    /// The background task reports an assistant telemetry event once the diff finishes,
    /// whether or not it succeeded.
    ///
    /// On the foreground, each batch of character operations becomes anchored edits.
    /// They are applied to the buffer only while this alternative is active, but are
    /// always recorded in `self.edits` so `set_active` can replay them.
    ///
    /// When streaming ends, a batch diff replaces the streamed line diff, the status
    /// becomes `Done` or `Error`, and `CodegenEvent::Finished` is emitted.
    pub fn handle_stream(
        &mut self,
        model_telemetry_id: String,
        model_provider_id: String,
        model_api_key: Option<String>,
        stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
        cx: &mut Context<Self>,
    ) {
        let start_time = Instant::now();

        // Make a new snapshot and re-resolve anchor in case the document was modified.
        // This can happen often if the editor loses focus and is saved + reformatted,
        // as in https://github.com/zed-industries/zed/issues/39088
        self.snapshot = self.buffer.read(cx).snapshot(cx);
        self.range = self.snapshot.anchor_after(self.range.start)
            ..self.snapshot.anchor_after(self.range.end);

        // `snapshot` is the pre-edit reference: all offsets produced by the streaming diff
        // below index into it.
        let snapshot = self.snapshot.clone();
        let selected_text = snapshot
            .text_for_range(self.range.start..self.range.end)
            .collect::<Rope>();

        let selection_start = self.range.start.to_point(&snapshot);

        // Start with the indentation of the first line in the selection
        let mut suggested_line_indent = snapshot
            .suggested_indents(selection_start.row..=selection_start.row, cx)
            .into_values()
            .next()
            .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

        // If the first line in the selection does not have indentation, check the following lines
        if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
            for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
                let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
                // Prefer tabs if a line in the selection uses tabs as indentation
                if line_indent.kind == IndentKind::Tab {
                    suggested_line_indent.kind = IndentKind::Tab;
                    break;
                }
            }
        }

        let http_client = cx.http_client();
        let telemetry = self.telemetry.clone();
        // Language of the first buffer touched by the range, used only for telemetry.
        let language_name = {
            let multibuffer = self.buffer.read(cx);
            let snapshot = multibuffer.snapshot(cx);
            let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
            ranges
                .first()
                .and_then(|(buffer, _, _)| buffer.language())
                .map(|language| language.name())
        };

        self.diff = Diff::default();
        self.status = CodegenStatus::Pending;
        // Cursor into `snapshot` (old text) advanced by Keep/Delete operations.
        let mut edit_start = self.range.start.to_offset(&snapshot);
        let completion = Arc::new(Mutex::new(String::new()));
        let completion_clone = completion.clone();

        self.generation = cx.spawn(async move |codegen, cx| {
            let stream = stream.await;

            let token_usage = stream
                .as_ref()
                .ok()
                .map(|stream| stream.last_token_usage.clone());
            let message_id = stream
                .as_ref()
                .ok()
                .and_then(|stream| stream.message_id.clone());
            let generate = async {
                let model_telemetry_id = model_telemetry_id.clone();
                let model_provider_id = model_provider_id.clone();
                // Bounded channel: `send(..).await` in the background diff waits for the
                // foreground loop below to keep up.
                let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
                let executor = cx.background_executor().clone();
                let message_id = message_id.clone();
                let line_based_stream_diff: Task<anyhow::Result<()>> =
                    cx.background_spawn(async move {
                        let mut response_latency = None;
                        let request_start = Instant::now();
                        let diff = async {
                            let chunks = StripInvalidSpans::new(
                                stream?.stream.map_err(|error| error.into()),
                            );
                            futures::pin_mut!(chunks);
                            let mut diff = StreamingDiff::new(selected_text.to_string());
                            let mut line_diff = LineDiff::default();

                            // Text of the current output line not yet pushed into `diff`.
                            let mut new_text = String::new();
                            // Leading-whitespace width of the first non-blank generated line.
                            let mut base_indent = None;
                            // Leading-whitespace width of the current line, once it is known.
                            let mut line_indent = None;
                            let mut first_line = true;

                            while let Some(chunk) = chunks.next().await {
                                if response_latency.is_none() {
                                    response_latency = Some(request_start.elapsed());
                                }
                                let chunk = chunk?;
                                completion_clone.lock().push_str(&chunk);

                                let mut lines = chunk.split('\n').peekable();
                                while let Some(line) = lines.next() {
                                    new_text.push_str(line);
                                    // First non-whitespace char of this line: rewrite its leading
                                    // whitespace to the suggested indent plus the line's offset
                                    // from `base_indent`.
                                    if line_indent.is_none()
                                        && let Some(non_whitespace_ch_ix) =
                                            new_text.find(|ch: char| !ch.is_whitespace())
                                    {
                                        line_indent = Some(non_whitespace_ch_ix);
                                        base_indent = base_indent.or(line_indent);

                                        let line_indent = line_indent.unwrap();
                                        let base_indent = base_indent.unwrap();
                                        let indent_delta = line_indent as i32 - base_indent as i32;
                                        let mut corrected_indent_len = cmp::max(
                                            0,
                                            suggested_line_indent.len as i32 + indent_delta,
                                        )
                                            as usize;
                                        // The first line is written at the selection start, which
                                        // already sits at `selection_start.column`.
                                        if first_line {
                                            corrected_indent_len = corrected_indent_len
                                                .saturating_sub(selection_start.column as usize);
                                        }

                                        let indent_char = suggested_line_indent.char();
                                        let mut indent_buffer = [0; 4];
                                        let indent_str =
                                            indent_char.encode_utf8(&mut indent_buffer);
                                        new_text.replace_range(
                                            ..line_indent,
                                            &indent_str.repeat(corrected_indent_len),
                                        );
                                    }

                                    // Once the indent is settled, the rest of the line can be
                                    // streamed through the diff as it arrives.
                                    if line_indent.is_some() {
                                        let char_ops = diff.push_new(&new_text);
                                        line_diff.push_char_operations(&char_ops, &selected_text);
                                        diff_tx
                                            .send((char_ops, line_diff.line_operations()))
                                            .await?;
                                        new_text.clear();
                                    }

                                    // A following piece means this chunk contained a newline.
                                    if lines.peek().is_some() {
                                        let char_ops = diff.push_new("\n");
                                        line_diff.push_char_operations(&char_ops, &selected_text);
                                        diff_tx
                                            .send((char_ops, line_diff.line_operations()))
                                            .await?;
                                        if line_indent.is_none() {
                                            // Don't write out the leading indentation in empty lines on the next line
                                            // This is the case where the above if statement didn't clear the buffer
                                            new_text.clear();
                                        }
                                        line_indent = None;
                                        first_line = false;
                                    }
                                }
                            }

                            // Flush any buffered text and let the diff emit its trailing operations.
                            let mut char_ops = diff.push_new(&new_text);
                            char_ops.extend(diff.finish());
                            line_diff.push_char_operations(&char_ops, &selected_text);
                            line_diff.finish(&selected_text);
                            diff_tx
                                .send((char_ops, line_diff.line_operations()))
                                .await?;

                            anyhow::Ok(())
                        };

                        let result = diff.await;

                        let error_message = result.as_ref().err().map(|error| error.to_string());
                        report_assistant_event(
                            AssistantEventData {
                                conversation_id: None,
                                message_id,
                                kind: AssistantKind::Inline,
                                phase: AssistantPhase::Response,
                                model: model_telemetry_id,
                                model_provider: model_provider_id,
                                response_latency,
                                error_message,
                                language_name: language_name.map(|name| name.to_proto()),
                            },
                            telemetry,
                            http_client,
                            model_api_key,
                            &executor,
                        );

                        result?;
                        Ok(())
                    });

                while let Some((char_ops, line_ops)) = diff_rx.next().await {
                    codegen.update(cx, |codegen, cx| {
                        codegen.last_equal_ranges.clear();

                        // Translate character operations (offsets into the old `snapshot`)
                        // into anchored edits. Inserts do not advance `edit_start` because
                        // they consume no old text.
                        let edits = char_ops
                            .into_iter()
                            .filter_map(|operation| match operation {
                                CharOperation::Insert { text } => {
                                    let edit_start = snapshot.anchor_after(edit_start);
                                    Some((edit_start..edit_start, text))
                                }
                                CharOperation::Delete { bytes } => {
                                    let edit_end = edit_start + bytes;
                                    let edit_range = snapshot.anchor_after(edit_start)
                                        ..snapshot.anchor_before(edit_end);
                                    edit_start = edit_end;
                                    Some((edit_range, String::new()))
                                }
                                CharOperation::Keep { bytes } => {
                                    let edit_end = edit_start + bytes;
                                    let edit_range = snapshot.anchor_after(edit_start)
                                        ..snapshot.anchor_before(edit_end);
                                    edit_start = edit_end;
                                    codegen.last_equal_ranges.push(edit_range);
                                    None
                                }
                            })
                            .collect::<Vec<_>>();

                        // Only the active alternative touches the buffer; inactive ones just
                        // record their edits for a later `set_active(true)`.
                        if codegen.active {
                            codegen.apply_edits(edits.iter().cloned(), cx);
                            codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
                        }
                        codegen.edits.extend(edits);
                        codegen.line_operations = line_ops;
                        codegen.edit_position = Some(snapshot.anchor_after(edit_start));

                        cx.notify();
                    })?;
                }

                // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
                // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
                // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
                let batch_diff_task =
                    codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
                let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
                line_based_stream_diff?;

                anyhow::Ok(())
            };

            let result = generate.await;
            let elapsed_time = start_time.elapsed().as_secs_f64();

            // Record the outcome; ignored if the codegen entity has been dropped.
            codegen
                .update(cx, |this, cx| {
                    this.message_id = message_id;
                    this.last_equal_ranges.clear();
                    if let Err(error) = result {
                        this.status = CodegenStatus::Error(error);
                    } else {
                        this.status = CodegenStatus::Done;
                    }
                    this.elapsed_time = Some(elapsed_time);
                    this.completion = Some(completion.lock().clone());
                    if let Some(usage) = token_usage {
                        let usage = usage.lock();
                        telemetry::event!(
                            "Inline Assistant Completion",
                            model = model_telemetry_id,
                            model_provider = model_provider_id,
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                        )
                    }

                    cx.emit(CodegenEvent::Finished);
                    cx.notify();
                })
                .ok();
        });
        cx.notify();
    }
870
871 pub fn stop(&mut self, cx: &mut Context<Self>) {
872 self.last_equal_ranges.clear();
873 if self.diff.is_empty() {
874 self.status = CodegenStatus::Idle;
875 } else {
876 self.status = CodegenStatus::Done;
877 }
878 self.generation = Task::ready(());
879 cx.emit(CodegenEvent::Finished);
880 cx.notify();
881 }
882
883 pub fn undo(&mut self, cx: &mut Context<Self>) {
884 self.buffer.update(cx, |buffer, cx| {
885 if let Some(transaction_id) = self.transformation_transaction_id.take() {
886 buffer.undo_transaction(transaction_id, cx);
887 buffer.refresh_preview(cx);
888 }
889 });
890 }
891
892 fn apply_edits(
893 &mut self,
894 edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
895 cx: &mut Context<CodegenAlternative>,
896 ) {
897 let transaction = self.buffer.update(cx, |buffer, cx| {
898 // Avoid grouping agent edits with user edits.
899 buffer.finalize_last_transaction(cx);
900 buffer.start_transaction(cx);
901 buffer.edit(edits, None, cx);
902 buffer.end_transaction(cx)
903 });
904
905 if let Some(transaction) = transaction {
906 if let Some(first_transaction) = self.transformation_transaction_id {
907 // Group all agent edits into the first transaction.
908 self.buffer.update(cx, |buffer, cx| {
909 buffer.merge_transactions(transaction, first_transaction, cx)
910 });
911 } else {
912 self.transformation_transaction_id = Some(transaction);
913 self.buffer
914 .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
915 }
916 }
917 }
918
919 fn reapply_line_based_diff(
920 &mut self,
921 line_operations: impl IntoIterator<Item = LineOperation>,
922 cx: &mut Context<Self>,
923 ) {
924 let old_snapshot = self.snapshot.clone();
925 let old_range = self.range.to_point(&old_snapshot);
926 let new_snapshot = self.buffer.read(cx).snapshot(cx);
927 let new_range = self.range.to_point(&new_snapshot);
928
929 let mut old_row = old_range.start.row;
930 let mut new_row = new_range.start.row;
931
932 self.diff.deleted_row_ranges.clear();
933 self.diff.inserted_row_ranges.clear();
934 for operation in line_operations {
935 match operation {
936 LineOperation::Keep { lines } => {
937 old_row += lines;
938 new_row += lines;
939 }
940 LineOperation::Delete { lines } => {
941 let old_end_row = old_row + lines - 1;
942 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
943
944 if let Some((_, last_deleted_row_range)) =
945 self.diff.deleted_row_ranges.last_mut()
946 {
947 if *last_deleted_row_range.end() + 1 == old_row {
948 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
949 } else {
950 self.diff
951 .deleted_row_ranges
952 .push((new_row, old_row..=old_end_row));
953 }
954 } else {
955 self.diff
956 .deleted_row_ranges
957 .push((new_row, old_row..=old_end_row));
958 }
959
960 old_row += lines;
961 }
962 LineOperation::Insert { lines } => {
963 let new_end_row = new_row + lines - 1;
964 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
965 let end = new_snapshot.anchor_before(Point::new(
966 new_end_row,
967 new_snapshot.line_len(MultiBufferRow(new_end_row)),
968 ));
969 self.diff.inserted_row_ranges.push(start..end);
970 new_row += lines;
971 }
972 }
973
974 cx.notify();
975 }
976 }
977
978 fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
979 let old_snapshot = self.snapshot.clone();
980 let old_range = self.range.to_point(&old_snapshot);
981 let new_snapshot = self.buffer.read(cx).snapshot(cx);
982 let new_range = self.range.to_point(&new_snapshot);
983
984 cx.spawn(async move |codegen, cx| {
985 let (deleted_row_ranges, inserted_row_ranges) = cx
986 .background_spawn(async move {
987 let old_text = old_snapshot
988 .text_for_range(
989 Point::new(old_range.start.row, 0)
990 ..Point::new(
991 old_range.end.row,
992 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
993 ),
994 )
995 .collect::<String>();
996 let new_text = new_snapshot
997 .text_for_range(
998 Point::new(new_range.start.row, 0)
999 ..Point::new(
1000 new_range.end.row,
1001 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1002 ),
1003 )
1004 .collect::<String>();
1005
1006 let old_start_row = old_range.start.row;
1007 let new_start_row = new_range.start.row;
1008 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1009 let mut inserted_row_ranges = Vec::new();
1010 for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
1011 let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
1012 let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
1013 if !old_rows.is_empty() {
1014 deleted_row_ranges.push((
1015 new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
1016 old_rows.start..=old_rows.end - 1,
1017 ));
1018 }
1019 if !new_rows.is_empty() {
1020 let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
1021 let new_end_row = new_rows.end - 1;
1022 let end = new_snapshot.anchor_before(Point::new(
1023 new_end_row,
1024 new_snapshot.line_len(MultiBufferRow(new_end_row)),
1025 ));
1026 inserted_row_ranges.push(start..end);
1027 }
1028 }
1029 (deleted_row_ranges, inserted_row_ranges)
1030 })
1031 .await;
1032
1033 codegen
1034 .update(cx, |codegen, cx| {
1035 codegen.diff.deleted_row_ranges = deleted_row_ranges;
1036 codegen.diff.inserted_row_ranges = inserted_row_ranges;
1037 cx.notify();
1038 })
1039 .ok();
1040 })
1041 }
1042
1043 fn handle_tool_use(
1044 &mut self,
1045 _telemetry_id: String,
1046 _provider_id: String,
1047 _api_key: Option<String>,
1048 tool_use: impl 'static
1049 + Future<
1050 Output = Result<language_model::LanguageModelToolUse, LanguageModelCompletionError>,
1051 >,
1052 cx: &mut Context<Self>,
1053 ) {
1054 self.diff = Diff::default();
1055 self.status = CodegenStatus::Pending;
1056
1057 self.generation = cx.spawn(async move |codegen, cx| {
1058 let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
1059 let _ = codegen.update(cx, |this, cx| {
1060 this.status = status;
1061 cx.emit(CodegenEvent::Finished);
1062 cx.notify();
1063 });
1064 };
1065
1066 let tool_use = tool_use.await;
1067
1068 match tool_use {
1069 Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => {
1070 // Parse the input JSON into RewriteSectionInput
1071 match serde_json::from_value::<RewriteSectionInput>(tool_use.input) {
1072 Ok(input) => {
1073 // Store the description if non-empty
1074 let description = if !input.description.trim().is_empty() {
1075 Some(input.description.clone())
1076 } else {
1077 None
1078 };
1079
1080 // Apply the replacement text to the buffer and compute diff
1081 let batch_diff_task = codegen
1082 .update(cx, |this, cx| {
1083 this.model_explanation = description.map(Into::into);
1084 let range = this.range.clone();
1085 this.apply_edits(
1086 std::iter::once((range, input.replacement_text)),
1087 cx,
1088 );
1089 this.reapply_batch_diff(cx)
1090 })
1091 .ok();
1092
1093 // Wait for the diff computation to complete
1094 if let Some(diff_task) = batch_diff_task {
1095 diff_task.await;
1096 }
1097
1098 finish_with_status(CodegenStatus::Done, cx);
1099 return;
1100 }
1101 Err(e) => {
1102 finish_with_status(CodegenStatus::Error(e.into()), cx);
1103 return;
1104 }
1105 }
1106 }
1107 Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => {
1108 // Handle failure message tool use
1109 match serde_json::from_value::<FailureMessageInput>(tool_use.input) {
1110 Ok(input) => {
1111 let _ = codegen.update(cx, |this, _cx| {
1112 // Store the failure message as the tool description
1113 this.model_explanation = Some(input.message.into());
1114 });
1115 finish_with_status(CodegenStatus::Done, cx);
1116 return;
1117 }
1118 Err(e) => {
1119 finish_with_status(CodegenStatus::Error(e.into()), cx);
1120 return;
1121 }
1122 }
1123 }
1124 Ok(_tool_use) => {
1125 // Unexpected tool.
1126 finish_with_status(CodegenStatus::Done, cx);
1127 return;
1128 }
1129 Err(e) => {
1130 finish_with_status(CodegenStatus::Error(e.into()), cx);
1131 return;
1132 }
1133 }
1134 });
1135 cx.notify();
1136 }
1137}
1138
/// Events a codegen entity emits to its observers.
#[derive(Clone, Copy, Debug)]
pub enum CodegenEvent {
    /// A generation reached its final status (success or error); emitted right
    /// after that status is stored.
    Finished,
    /// Emitted outside this section; presumably signals that generated edits
    /// were undone — confirm at the emit site.
    Undone,
}
1144
/// Stream adapter that cleans raw model output before it is inserted.
///
/// Wraps a stream of text chunks, drops an opening Markdown code fence on the
/// first line (and the matching closing fence on the last line), and removes
/// `<|CURSOR|>` markers.
struct StripInvalidSpans<T> {
    /// The wrapped stream of raw text chunks.
    stream: T,
    /// Text received from `stream` that has not yet been emitted or discarded.
    buffer: String,
    /// Set once `stream` has returned `Poll::Ready(None)`.
    stream_done: bool,
    /// True while the first line of the response is still being processed.
    first_line: bool,
    /// True when a complete line was emitted and its newline has not been written yet.
    line_end: bool,
    /// True when the first line was an opening code fence that was dropped.
    starts_with_code_block: bool,
}
1153
1154impl<T> StripInvalidSpans<T>
1155where
1156 T: Stream<Item = Result<String>>,
1157{
1158 fn new(stream: T) -> Self {
1159 Self {
1160 stream,
1161 stream_done: false,
1162 buffer: String::new(),
1163 first_line: true,
1164 line_end: false,
1165 starts_with_code_block: false,
1166 }
1167 }
1168}
1169
1170impl<T> Stream for StripInvalidSpans<T>
1171where
1172 T: Stream<Item = Result<String>>,
1173{
1174 type Item = Result<String>;
1175
1176 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
1177 const CODE_BLOCK_DELIMITER: &str = "```";
1178 const CURSOR_SPAN: &str = "<|CURSOR|>";
1179
1180 let this = unsafe { self.get_unchecked_mut() };
1181 loop {
1182 if !this.stream_done {
1183 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
1184 match stream.as_mut().poll_next(cx) {
1185 Poll::Ready(Some(Ok(chunk))) => {
1186 this.buffer.push_str(&chunk);
1187 }
1188 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
1189 Poll::Ready(None) => {
1190 this.stream_done = true;
1191 }
1192 Poll::Pending => return Poll::Pending,
1193 }
1194 }
1195
1196 let mut chunk = String::new();
1197 let mut consumed = 0;
1198 if !this.buffer.is_empty() {
1199 let mut lines = this.buffer.split('\n').enumerate().peekable();
1200 while let Some((line_ix, line)) = lines.next() {
1201 if line_ix > 0 {
1202 this.first_line = false;
1203 }
1204
1205 if this.first_line {
1206 let trimmed_line = line.trim();
1207 if lines.peek().is_some() {
1208 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
1209 consumed += line.len() + 1;
1210 this.starts_with_code_block = true;
1211 continue;
1212 }
1213 } else if trimmed_line.is_empty()
1214 || prefixes(CODE_BLOCK_DELIMITER)
1215 .any(|prefix| trimmed_line.starts_with(prefix))
1216 {
1217 break;
1218 }
1219 }
1220
1221 let line_without_cursor = line.replace(CURSOR_SPAN, "");
1222 if lines.peek().is_some() {
1223 if this.line_end {
1224 chunk.push('\n');
1225 }
1226
1227 chunk.push_str(&line_without_cursor);
1228 this.line_end = true;
1229 consumed += line.len() + 1;
1230 } else if this.stream_done {
1231 if !this.starts_with_code_block
1232 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1233 {
1234 if this.line_end {
1235 chunk.push('\n');
1236 }
1237
1238 chunk.push_str(line);
1239 }
1240
1241 consumed += line.len();
1242 } else {
1243 let trimmed_line = line.trim();
1244 if trimmed_line.is_empty()
1245 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1246 || prefixes(CODE_BLOCK_DELIMITER)
1247 .any(|prefix| trimmed_line.ends_with(prefix))
1248 {
1249 break;
1250 } else {
1251 if this.line_end {
1252 chunk.push('\n');
1253 this.line_end = false;
1254 }
1255
1256 chunk.push_str(&line_without_cursor);
1257 consumed += line.len();
1258 }
1259 }
1260 }
1261 }
1262
1263 this.buffer = this.buffer.split_off(consumed);
1264 if !chunk.is_empty() {
1265 return Poll::Ready(Some(Ok(chunk)));
1266 } else if this.stream_done {
1267 return Poll::Ready(None);
1268 }
1269 }
1270 }
1271}
1272
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// For example, `prefixes("<|C")` yields `"<"` and `"<|"`; the full string is
/// never included. Prefixes always end on a `char` boundary, so non-ASCII text
/// is handled safely, and an empty or single-character `text` yields nothing.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices()
        .skip(1)
        .map(move |(end, _)| &text[..end])
}
1276
1277#[derive(Default)]
1278pub struct Diff {
1279 pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
1280 pub inserted_row_ranges: Vec<Range<Anchor>>,
1281}
1282
1283impl Diff {
1284 fn is_empty(&self) -> bool {
1285 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1286 }
1287}
1288
#[cfg(test)]
mod tests {
    // Tests for `CodegenAlternative` streaming (autoindentation of generated text,
    // inactive alternatives) and for the `StripInvalidSpans` adapter.
    use super::*;
    use futures::{
        Stream,
        stream::{self},
    };
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, Point, tree_sitter_rust};
    use language_model::{LanguageModelRegistry, TokenUsage};
    use rand::prelude::*;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    /// Streams a rewrite of rows 1-4 in random-sized chunks and checks that the
    /// final buffer text is indented to fit inside `fn main`.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        // Send the response in random chunks of 1..=10 bytes.
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        // Dropping the sender ends the simulated response stream.
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generation starts mid-word after `le` on an indented line; checks that
    /// the continuation and the following generated lines are indented correctly.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// Generation starts at column 2 of a whitespace-only line; checks that the
    /// generated lines are indented to fit inside `fn main`.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    /// In a tab-indented buffer with no language set, checks that a
    /// tab-indented response is inserted with its tabs intact.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    /// An alternative created inactive buffers its output: activating it applies
    /// the generated edits, and deactivating it reverts them.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                None,
                prompt_builder,
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    /// Checks `StripInvalidSpans` output against expectations for every chunk size.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        // Feeds `text` through `StripInvalidSpans` with every chunk size from 1 to
        // `text.len()` and asserts each run produces `expected_text`.
        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        // Splits `text` into a stream of `Ok` chunks of `size` chars each (the last may be shorter).
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    // Sets up the test `LanguageModelRegistry` and a test `SettingsStore` global.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
    }

    // Starts `codegen` on a fake text response and returns the sender used to push
    // response chunks; the tests drop the sender to end the response.
    fn simulate_response_stream(
        codegen: &Entity<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        codegen.update(cx, |codegen, cx| {
            codegen.handle_stream(
                String::new(),
                String::new(),
                None,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
                })),
                cx,
            );
        });
        chunks_tx
    }

    // Minimal Rust language with an indents query, used by the autoindent tests.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_indents_query(
            r#"
            (call_expression) @indent
            (field_expression) @indent
            (_ "(" ")" @end) @indent
            (_ "{" "}" @end) @indent
            "#,
        )
        .unwrap()
    }
}