1use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
2use agent_settings::AgentSettings;
3use anyhow::{Context as _, Result};
4
5use client::telemetry::Telemetry;
6use cloud_llm_client::CompletionIntent;
7use collections::HashSet;
8use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
9use feature_flags::{FeatureFlagAppExt as _, InlineAssistantUseToolFeatureFlag};
10use futures::{
11 SinkExt, Stream, StreamExt, TryStreamExt as _,
12 channel::mpsc,
13 future::{LocalBoxFuture, Shared},
14 join,
15 stream::BoxStream,
16};
17use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
18use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
19use language_model::{
20 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
21 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
22 LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolChoice,
23 LanguageModelToolUse, Role, TokenUsage, report_assistant_event,
24};
25use multi_buffer::MultiBufferRow;
26use parking_lot::Mutex;
27use prompt_store::PromptBuilder;
28use rope::Rope;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings as _;
32use smol::future::FutureExt;
33use std::{
34 cmp,
35 future::Future,
36 iter,
37 ops::{Range, RangeInclusive},
38 pin::Pin,
39 sync::Arc,
40 task::{self, Poll},
41 time::Instant,
42};
43use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
44use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
45use ui::SharedString;
46
47/// Use this tool to provide a message to the user when you're unable to complete a task.
48#[derive(Debug, Serialize, Deserialize, JsonSchema)]
49pub struct FailureMessageInput {
    /// A brief message to the user explaining why you're unable to fulfill the request, or asking a question about the request.
51 ///
52 /// The message may use markdown formatting if you wish.
53 #[serde(default)]
54 pub message: String,
55}
56
57/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
58#[derive(Debug, Serialize, Deserialize, JsonSchema)]
59pub struct RewriteSectionInput {
60 /// A brief description of the edit you have made.
61 ///
62 /// The description may use markdown formatting if you wish.
63 /// This is optional - if the edit is simple or obvious, you should leave it empty.
64 #[serde(default)]
65 pub description: String,
66
67 /// The text to replace the section with.
68 #[serde(default)]
69 pub replacement_text: String,
70}
71
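/// Coordinates inline-assist code generation for a buffer range, maintaining one
/// [`CodegenAlternative`] per configured model so the user can cycle between
/// alternative completions.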
72pub struct BufferCodegen {
73 alternatives: Vec<Entity<CodegenAlternative>>,
74 pub active_alternative: usize,
75 seen_alternatives: HashSet<usize>,
76 subscriptions: Vec<Subscription>,
77 buffer: Entity<MultiBuffer>,
78 range: Range<Anchor>,
79 initial_transaction_id: Option<TransactionId>,
80 telemetry: Arc<Telemetry>,
81 builder: Arc<PromptBuilder>,
82 pub is_insertion: bool,
83}
84
85impl BufferCodegen {
86 pub fn new(
87 buffer: Entity<MultiBuffer>,
88 range: Range<Anchor>,
89 initial_transaction_id: Option<TransactionId>,
90 telemetry: Arc<Telemetry>,
91 builder: Arc<PromptBuilder>,
92 cx: &mut Context<Self>,
93 ) -> Self {
94 let codegen = cx.new(|cx| {
95 CodegenAlternative::new(
96 buffer.clone(),
97 range.clone(),
98 false,
99 Some(telemetry.clone()),
100 builder.clone(),
101 cx,
102 )
103 });
104 let mut this = Self {
105 is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
106 alternatives: vec![codegen],
107 active_alternative: 0,
108 seen_alternatives: HashSet::default(),
109 subscriptions: Vec::new(),
110 buffer,
111 range,
112 initial_transaction_id,
113 telemetry,
114 builder,
115 };
116 this.activate(0, cx);
117 this
118 }
119
120 fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
121 let codegen = self.active_alternative().clone();
122 self.subscriptions.clear();
123 self.subscriptions
124 .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
125 self.subscriptions
126 .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
127 }
128
129 pub fn active_completion(&self, cx: &App) -> Option<String> {
130 self.active_alternative().read(cx).current_completion()
131 }
132
133 pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
134 &self.alternatives[self.active_alternative]
135 }
136
137 pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
138 &self.active_alternative().read(cx).status
139 }
140
141 pub fn alternative_count(&self, cx: &App) -> usize {
142 LanguageModelRegistry::read_global(cx)
143 .inline_alternative_models()
144 .len()
145 + 1
146 }
147
148 pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
149 let next_active_ix = if self.active_alternative == 0 {
150 self.alternatives.len() - 1
151 } else {
152 self.active_alternative - 1
153 };
154 self.activate(next_active_ix, cx);
155 }
156
157 pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
158 let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
159 self.activate(next_active_ix, cx);
160 }
161
162 fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
163 self.active_alternative()
164 .update(cx, |codegen, cx| codegen.set_active(false, cx));
165 self.seen_alternatives.insert(index);
166 self.active_alternative = index;
167 self.active_alternative()
168 .update(cx, |codegen, cx| codegen.set_active(true, cx));
169 self.subscribe_to_alternative(cx);
170 cx.notify();
171 }
172
173 pub fn start(
174 &mut self,
175 primary_model: Arc<dyn LanguageModel>,
176 user_prompt: String,
177 context_task: Shared<Task<Option<LoadedContext>>>,
178 cx: &mut Context<Self>,
179 ) -> Result<()> {
180 let alternative_models = LanguageModelRegistry::read_global(cx)
181 .inline_alternative_models()
182 .to_vec();
183
184 self.active_alternative()
185 .update(cx, |alternative, cx| alternative.undo(cx));
186 self.activate(0, cx);
187 self.alternatives.truncate(1);
188
189 for _ in 0..alternative_models.len() {
190 self.alternatives.push(cx.new(|cx| {
191 CodegenAlternative::new(
192 self.buffer.clone(),
193 self.range.clone(),
194 false,
195 Some(self.telemetry.clone()),
196 self.builder.clone(),
197 cx,
198 )
199 }));
200 }
201
202 for (model, alternative) in iter::once(primary_model)
203 .chain(alternative_models)
204 .zip(&self.alternatives)
205 {
206 alternative.update(cx, |alternative, cx| {
207 alternative.start(user_prompt.clone(), context_task.clone(), model.clone(), cx)
208 })?;
209 }
210
211 Ok(())
212 }
213
214 pub fn stop(&mut self, cx: &mut Context<Self>) {
215 for codegen in &self.alternatives {
216 codegen.update(cx, |codegen, cx| codegen.stop(cx));
217 }
218 }
219
220 pub fn undo(&mut self, cx: &mut Context<Self>) {
221 self.active_alternative()
222 .update(cx, |codegen, cx| codegen.undo(cx));
223
224 self.buffer.update(cx, |buffer, cx| {
225 if let Some(transaction_id) = self.initial_transaction_id.take() {
226 buffer.undo_transaction(transaction_id, cx);
227 buffer.refresh_preview(cx);
228 }
229 });
230 }
231
232 pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
233 self.active_alternative().read(cx).buffer.clone()
234 }
235
236 pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
237 self.active_alternative().read(cx).old_buffer.clone()
238 }
239
240 pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
241 self.active_alternative().read(cx).snapshot.clone()
242 }
243
244 pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
245 self.active_alternative().read(cx).edit_position
246 }
247
248 pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
249 &self.active_alternative().read(cx).diff
250 }
251
252 pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
253 self.active_alternative().read(cx).last_equal_ranges()
254 }
255
256 pub fn selected_text<'a>(&self, cx: &'a App) -> Option<&'a str> {
257 self.active_alternative().read(cx).selected_text()
258 }
259}
260
261impl EventEmitter<CodegenEvent> for BufferCodegen {}
262
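/// A single model's transformation of the selected range: it streams the completion,
/// applies edits to the buffer as they arrive (while active), and tracks the resulting
/// diff against the original text.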
263pub struct CodegenAlternative {
264 buffer: Entity<MultiBuffer>,
265 old_buffer: Entity<Buffer>,
266 snapshot: MultiBufferSnapshot,
267 edit_position: Option<Anchor>,
268 range: Range<Anchor>,
269 last_equal_ranges: Vec<Range<Anchor>>,
270 transformation_transaction_id: Option<TransactionId>,
271 status: CodegenStatus,
272 generation: Task<()>,
273 diff: Diff,
274 telemetry: Option<Arc<Telemetry>>,
275 _subscription: gpui::Subscription,
276 builder: Arc<PromptBuilder>,
277 active: bool,
278 edits: Vec<(Range<Anchor>, String)>,
279 line_operations: Vec<LineOperation>,
280 elapsed_time: Option<f64>,
281 completion: Option<String>,
282 selected_text: Option<String>,
283 pub message_id: Option<String>,
284 pub model_explanation: Option<SharedString>,
285}
286
287impl EventEmitter<CodegenEvent> for CodegenAlternative {}
288
289impl CodegenAlternative {
290 pub fn new(
291 buffer: Entity<MultiBuffer>,
292 range: Range<Anchor>,
293 active: bool,
294 telemetry: Option<Arc<Telemetry>>,
295 builder: Arc<PromptBuilder>,
296 cx: &mut Context<Self>,
297 ) -> Self {
298 let snapshot = buffer.read(cx).snapshot(cx);
299
300 let (old_buffer, _, _) = snapshot
301 .range_to_buffer_ranges(range.clone())
302 .pop()
303 .unwrap();
304 let old_buffer = cx.new(|cx| {
305 let text = old_buffer.as_rope().clone();
306 let line_ending = old_buffer.line_ending();
307 let language = old_buffer.language().cloned();
308 let language_registry = buffer
309 .read(cx)
310 .buffer(old_buffer.remote_id())
311 .unwrap()
312 .read(cx)
313 .language_registry();
314
315 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
316 buffer.set_language(language, cx);
317 if let Some(language_registry) = language_registry {
318 buffer.set_language_registry(language_registry);
319 }
320 buffer
321 });
322
323 Self {
324 buffer: buffer.clone(),
325 old_buffer,
326 edit_position: None,
327 message_id: None,
328 snapshot,
329 last_equal_ranges: Default::default(),
330 transformation_transaction_id: None,
331 status: CodegenStatus::Idle,
332 generation: Task::ready(()),
333 diff: Diff::default(),
334 telemetry,
335 builder,
            active,
337 edits: Vec::new(),
338 line_operations: Vec::new(),
339 range,
340 elapsed_time: None,
341 completion: None,
342 selected_text: None,
343 model_explanation: None,
344 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
345 }
346 }
347
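    /// Toggles whether this alternative's edits are applied to the buffer: activating
    /// replays the recorded edits and diff, while deactivating undoes and forgets the
    /// transformation transaction.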
348 pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
349 if active != self.active {
350 self.active = active;
351
352 if self.active {
353 let edits = self.edits.clone();
354 self.apply_edits(edits, cx);
355 if matches!(self.status, CodegenStatus::Pending) {
356 let line_operations = self.line_operations.clone();
357 self.reapply_line_based_diff(line_operations, cx);
358 } else {
359 self.reapply_batch_diff(cx).detach();
360 }
361 } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
362 self.buffer.update(cx, |buffer, cx| {
363 buffer.undo_transaction(transaction_id, cx);
364 buffer.forget_transaction(transaction_id, cx);
365 });
366 }
367 }
368 }
369
370 fn handle_buffer_event(
371 &mut self,
372 _buffer: Entity<MultiBuffer>,
373 event: &multi_buffer::Event,
374 cx: &mut Context<Self>,
375 ) {
376 if let multi_buffer::Event::TransactionUndone { transaction_id } = event
377 && self.transformation_transaction_id == Some(*transaction_id)
378 {
379 self.transformation_transaction_id = None;
380 self.generation = Task::ready(());
381 cx.emit(CodegenEvent::Undone);
382 }
383 }
384
385 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
386 &self.last_equal_ranges
387 }
388
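    /// Whether generation should go through streaming tool calls rather than plain
    /// text, gated on model support, the `InlineAssistantUseToolFeatureFlag`, and the
    /// corresponding agent setting.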
389 fn use_streaming_tools(model: &dyn LanguageModel, cx: &App) -> bool {
390 model.supports_streaming_tools()
391 && cx.has_flag::<InlineAssistantUseToolFeatureFlag>()
392 && AgentSettings::get_global(cx).inline_assistant_use_streaming_tools
393 }
394
395 pub fn start(
396 &mut self,
397 user_prompt: String,
398 context_task: Shared<Task<Option<LoadedContext>>>,
399 model: Arc<dyn LanguageModel>,
400 cx: &mut Context<Self>,
401 ) -> Result<()> {
402 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
403 self.buffer.update(cx, |buffer, cx| {
404 buffer.undo_transaction(transformation_transaction_id, cx);
405 });
406 }
407
408 self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
409
410 let api_key = model.api_key(cx);
411 let telemetry_id = model.telemetry_id();
412 let provider_id = model.provider_id();
413
414 if Self::use_streaming_tools(model.as_ref(), cx) {
415 let request = self.build_request(&model, user_prompt, context_task, cx)?;
416 let completion_events =
417 cx.spawn(async move |_, cx| model.stream_completion(request.await, cx).await);
418 self.generation = self.handle_completion(
419 telemetry_id,
420 provider_id.to_string(),
421 api_key,
422 completion_events,
423 cx,
424 );
425 } else {
426 let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
427 if user_prompt.trim().to_lowercase() == "delete" {
428 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
429 } else {
430 let request = self.build_request(&model, user_prompt, context_task, cx)?;
431 cx.spawn(async move |_, cx| {
432 Ok(model.stream_completion_text(request.await, cx).await?)
433 })
434 .boxed_local()
435 };
436 self.generation =
437 self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
438 }
439
440 Ok(())
441 }
442
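    /// Builds a completion request that asks the model to respond through the
    /// `rewrite_section` or `failure_message` tool rather than with raw text.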
443 fn build_request_tools(
444 &self,
445 model: &Arc<dyn LanguageModel>,
446 user_prompt: String,
447 context_task: Shared<Task<Option<LoadedContext>>>,
448 cx: &mut App,
449 ) -> Result<Task<LanguageModelRequest>> {
450 let buffer = self.buffer.read(cx).snapshot(cx);
451 let language = buffer.language_at(self.range.start);
452 let language_name = if let Some(language) = language.as_ref() {
453 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
454 None
455 } else {
456 Some(language.name())
457 }
458 } else {
459 None
460 };
461
462 let language_name = language_name.as_ref();
463 let start = buffer.point_to_buffer_offset(self.range.start);
464 let end = buffer.point_to_buffer_offset(self.range.end);
465 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
466 let (start_buffer, start_buffer_offset) = start;
467 let (end_buffer, end_buffer_offset) = end;
468 if start_buffer.remote_id() == end_buffer.remote_id() {
469 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
470 } else {
471 anyhow::bail!("invalid transformation range");
472 }
473 } else {
474 anyhow::bail!("invalid transformation range");
475 };
476
477 let system_prompt = self
478 .builder
479 .generate_inline_transformation_prompt_tools(
480 language_name,
481 buffer,
482 range.start.0..range.end.0,
483 )
484 .context("generating content prompt")?;
485
486 let temperature = AgentSettings::temperature_for_model(model, cx);
487
488 let tool_input_format = model.tool_input_format();
489 let tool_choice = model
490 .supports_tool_choice(LanguageModelToolChoice::Any)
491 .then_some(LanguageModelToolChoice::Any);
492
493 Ok(cx.spawn(async move |_cx| {
494 let mut messages = vec![LanguageModelRequestMessage {
495 role: Role::System,
496 content: vec![system_prompt.into()],
497 cache: false,
498 reasoning_details: None,
499 }];
500
501 let mut user_message = LanguageModelRequestMessage {
502 role: Role::User,
503 content: Vec::new(),
504 cache: false,
505 reasoning_details: None,
506 };
507
508 if let Some(context) = context_task.await {
509 context.add_to_request_message(&mut user_message);
510 }
511
512 user_message.content.push(user_prompt.into());
513 messages.push(user_message);
514
515 let tools = vec![
516 LanguageModelRequestTool {
517 name: "rewrite_section".to_string(),
518 description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
519 input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
520 },
521 LanguageModelRequestTool {
522 name: "failure_message".to_string(),
523 description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
524 input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
525 },
526 ];
527
528 LanguageModelRequest {
529 thread_id: None,
530 prompt_id: None,
531 intent: Some(CompletionIntent::InlineAssist),
532 mode: None,
533 tools,
534 tool_choice,
535 stop: Vec::new(),
536 temperature,
537 messages,
538 thinking_allowed: false,
539 }
540 }))
541 }
542
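    /// Builds the inline transformation request, delegating to
    /// [`Self::build_request_tools`] when streaming tool calls are in use.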
543 fn build_request(
544 &self,
545 model: &Arc<dyn LanguageModel>,
546 user_prompt: String,
547 context_task: Shared<Task<Option<LoadedContext>>>,
548 cx: &mut App,
549 ) -> Result<Task<LanguageModelRequest>> {
550 if Self::use_streaming_tools(model.as_ref(), cx) {
551 return self.build_request_tools(model, user_prompt, context_task, cx);
552 }
553
554 let buffer = self.buffer.read(cx).snapshot(cx);
555 let language = buffer.language_at(self.range.start);
556 let language_name = if let Some(language) = language.as_ref() {
557 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
558 None
559 } else {
560 Some(language.name())
561 }
562 } else {
563 None
564 };
565
566 let language_name = language_name.as_ref();
567 let start = buffer.point_to_buffer_offset(self.range.start);
568 let end = buffer.point_to_buffer_offset(self.range.end);
569 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
570 let (start_buffer, start_buffer_offset) = start;
571 let (end_buffer, end_buffer_offset) = end;
572 if start_buffer.remote_id() == end_buffer.remote_id() {
573 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
574 } else {
575 anyhow::bail!("invalid transformation range");
576 }
577 } else {
578 anyhow::bail!("invalid transformation range");
579 };
580
581 let prompt = self
582 .builder
583 .generate_inline_transformation_prompt(
584 user_prompt,
585 language_name,
586 buffer,
587 range.start.0..range.end.0,
588 )
589 .context("generating content prompt")?;
590
591 let temperature = AgentSettings::temperature_for_model(model, cx);
592
593 Ok(cx.spawn(async move |_cx| {
594 let mut request_message = LanguageModelRequestMessage {
595 role: Role::User,
596 content: Vec::new(),
597 cache: false,
598 reasoning_details: None,
599 };
600
601 if let Some(context) = context_task.await {
602 context.add_to_request_message(&mut request_message);
603 }
604
605 request_message.content.push(prompt.into());
606
607 LanguageModelRequest {
608 thread_id: None,
609 prompt_id: None,
610 intent: Some(CompletionIntent::InlineAssist),
611 mode: None,
612 tools: Vec::new(),
613 tool_choice: None,
614 stop: Vec::new(),
615 temperature,
616 messages: vec![request_message],
617 thinking_allowed: false,
618 }
619 }))
620 }
621
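    /// Consumes a plain-text completion stream, converting it into buffer edits via a
    /// character-level streaming diff while re-indenting the response to match the
    /// selection's indentation.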
622 pub fn handle_stream(
623 &mut self,
624 model_telemetry_id: String,
625 model_provider_id: String,
626 model_api_key: Option<String>,
627 stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
628 cx: &mut Context<Self>,
629 ) -> Task<()> {
630 let start_time = Instant::now();
631
        // Take a fresh snapshot and re-resolve the range's anchors in case the document was modified.
        // This often happens when the editor loses focus and the buffer is saved and reformatted,
        // as in https://github.com/zed-industries/zed/issues/39088
635 self.snapshot = self.buffer.read(cx).snapshot(cx);
636 self.range = self.snapshot.anchor_after(self.range.start)
637 ..self.snapshot.anchor_after(self.range.end);
638
639 let snapshot = self.snapshot.clone();
640 let selected_text = snapshot
641 .text_for_range(self.range.start..self.range.end)
642 .collect::<Rope>();
643
644 self.selected_text = Some(selected_text.to_string());
645
646 let selection_start = self.range.start.to_point(&snapshot);
647
648 // Start with the indentation of the first line in the selection
649 let mut suggested_line_indent = snapshot
650 .suggested_indents(selection_start.row..=selection_start.row, cx)
651 .into_values()
652 .next()
653 .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
654
655 // If the first line in the selection does not have indentation, check the following lines
656 if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
657 for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
658 let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
659 // Prefer tabs if a line in the selection uses tabs as indentation
660 if line_indent.kind == IndentKind::Tab {
661 suggested_line_indent.kind = IndentKind::Tab;
662 break;
663 }
664 }
665 }
666
667 let http_client = cx.http_client();
668 let telemetry = self.telemetry.clone();
669 let language_name = {
670 let multibuffer = self.buffer.read(cx);
671 let snapshot = multibuffer.snapshot(cx);
672 let ranges = snapshot.range_to_buffer_ranges(self.range.clone());
673 ranges
674 .first()
675 .and_then(|(buffer, _, _)| buffer.language())
676 .map(|language| language.name())
677 };
678
679 self.diff = Diff::default();
680 self.status = CodegenStatus::Pending;
681 let mut edit_start = self.range.start.to_offset(&snapshot);
682 let completion = Arc::new(Mutex::new(String::new()));
683 let completion_clone = completion.clone();
684
685 cx.notify();
686 cx.spawn(async move |codegen, cx| {
687 let stream = stream.await;
688
689 let token_usage = stream
690 .as_ref()
691 .ok()
692 .map(|stream| stream.last_token_usage.clone());
693 let message_id = stream
694 .as_ref()
695 .ok()
696 .and_then(|stream| stream.message_id.clone());
697 let generate = async {
698 let model_telemetry_id = model_telemetry_id.clone();
699 let model_provider_id = model_provider_id.clone();
700 let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
701 let executor = cx.background_executor().clone();
702 let message_id = message_id.clone();
703 let line_based_stream_diff: Task<anyhow::Result<()>> =
704 cx.background_spawn(async move {
705 let mut response_latency = None;
706 let request_start = Instant::now();
707 let diff = async {
708 let chunks = StripInvalidSpans::new(
709 stream?.stream.map_err(|error| error.into()),
710 );
711 futures::pin_mut!(chunks);
712
713 let mut diff = StreamingDiff::new(selected_text.to_string());
714 let mut line_diff = LineDiff::default();
715
716 let mut new_text = String::new();
717 let mut base_indent = None;
718 let mut line_indent = None;
719 let mut first_line = true;
720
721 while let Some(chunk) = chunks.next().await {
722 if response_latency.is_none() {
723 response_latency = Some(request_start.elapsed());
724 }
725 let chunk = chunk?;
726 completion_clone.lock().push_str(&chunk);
727
728 let mut lines = chunk.split('\n').peekable();
729 while let Some(line) = lines.next() {
730 new_text.push_str(line);
731 if line_indent.is_none()
732 && let Some(non_whitespace_ch_ix) =
733 new_text.find(|ch: char| !ch.is_whitespace())
734 {
735 line_indent = Some(non_whitespace_ch_ix);
736 base_indent = base_indent.or(line_indent);
737
738 let line_indent = line_indent.unwrap();
739 let base_indent = base_indent.unwrap();
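                                    // Re-indent relative to the first indented line of
                                    // the response: shift the selection's suggested
                                    // indent by how far this line's indentation deviates
                                    // from that base.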
740 let indent_delta = line_indent as i32 - base_indent as i32;
741 let mut corrected_indent_len = cmp::max(
742 0,
743 suggested_line_indent.len as i32 + indent_delta,
744 )
745 as usize;
746 if first_line {
747 corrected_indent_len = corrected_indent_len
748 .saturating_sub(selection_start.column as usize);
749 }
750
751 let indent_char = suggested_line_indent.char();
752 let mut indent_buffer = [0; 4];
753 let indent_str =
754 indent_char.encode_utf8(&mut indent_buffer);
755 new_text.replace_range(
756 ..line_indent,
757 &indent_str.repeat(corrected_indent_len),
758 );
759 }
760
761 if line_indent.is_some() {
762 let char_ops = diff.push_new(&new_text);
763 line_diff.push_char_operations(&char_ops, &selected_text);
764 diff_tx
765 .send((char_ops, line_diff.line_operations()))
766 .await?;
767 new_text.clear();
768 }
769
770 if lines.peek().is_some() {
771 let char_ops = diff.push_new("\n");
772 line_diff.push_char_operations(&char_ops, &selected_text);
773 diff_tx
774 .send((char_ops, line_diff.line_operations()))
775 .await?;
776 if line_indent.is_none() {
                                        // Don't carry an empty line's leading indentation over to the next line.
                                        // This handles the case where the `if` above didn't clear the buffer.
779 new_text.clear();
780 }
781 line_indent = None;
782 first_line = false;
783 }
784 }
785 }
786
787 let mut char_ops = diff.push_new(&new_text);
788 char_ops.extend(diff.finish());
789 line_diff.push_char_operations(&char_ops, &selected_text);
790 line_diff.finish(&selected_text);
791 diff_tx
792 .send((char_ops, line_diff.line_operations()))
793 .await?;
794
795 anyhow::Ok(())
796 };
797
798 let result = diff.await;
799
800 let error_message = result.as_ref().err().map(|error| error.to_string());
801 report_assistant_event(
802 AssistantEventData {
803 conversation_id: None,
804 message_id,
805 kind: AssistantKind::Inline,
806 phase: AssistantPhase::Response,
807 model: model_telemetry_id,
808 model_provider: model_provider_id,
809 response_latency,
810 error_message,
811 language_name: language_name.map(|name| name.to_proto()),
812 },
813 telemetry,
814 http_client,
815 model_api_key,
816 &executor,
817 );
818
819 result?;
820 Ok(())
821 });
822
823 while let Some((char_ops, line_ops)) = diff_rx.next().await {
824 codegen.update(cx, |codegen, cx| {
825 codegen.last_equal_ranges.clear();
826
827 let edits = char_ops
828 .into_iter()
829 .filter_map(|operation| match operation {
830 CharOperation::Insert { text } => {
831 let edit_start = snapshot.anchor_after(edit_start);
832 Some((edit_start..edit_start, text))
833 }
834 CharOperation::Delete { bytes } => {
835 let edit_end = edit_start + bytes;
836 let edit_range = snapshot.anchor_after(edit_start)
837 ..snapshot.anchor_before(edit_end);
838 edit_start = edit_end;
839 Some((edit_range, String::new()))
840 }
841 CharOperation::Keep { bytes } => {
842 let edit_end = edit_start + bytes;
843 let edit_range = snapshot.anchor_after(edit_start)
844 ..snapshot.anchor_before(edit_end);
845 edit_start = edit_end;
846 codegen.last_equal_ranges.push(edit_range);
847 None
848 }
849 })
850 .collect::<Vec<_>>();
851
852 if codegen.active {
853 codegen.apply_edits(edits.iter().cloned(), cx);
854 codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
855 }
856 codegen.edits.extend(edits);
857 codegen.line_operations = line_ops;
858 codegen.edit_position = Some(snapshot.anchor_after(edit_start));
859
860 cx.notify();
861 })?;
862 }
863
                // Streaming has stopped: the buffer now contains the new text, with a line-based diff applied across the whole new buffer.
                // That streaming diff is not a regular diff and can look unexpected, so apply a regular diff on top.
                // It's fine to apply it even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
867 let batch_diff_task =
868 codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
869 let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
870 line_based_stream_diff?;
871
872 anyhow::Ok(())
873 };
874
875 let result = generate.await;
876 let elapsed_time = start_time.elapsed().as_secs_f64();
877
878 codegen
879 .update(cx, |this, cx| {
880 this.message_id = message_id;
881 this.last_equal_ranges.clear();
882 if let Err(error) = result {
883 this.status = CodegenStatus::Error(error);
884 } else {
885 this.status = CodegenStatus::Done;
886 }
887 this.elapsed_time = Some(elapsed_time);
888 this.completion = Some(completion.lock().clone());
889 if let Some(usage) = token_usage {
890 let usage = usage.lock();
891 telemetry::event!(
892 "Inline Assistant Completion",
893 model = model_telemetry_id,
894 model_provider = model_provider_id,
895 input_tokens = usage.input_tokens,
896 output_tokens = usage.output_tokens,
897 )
898 }
899
900 cx.emit(CodegenEvent::Finished);
901 cx.notify();
902 })
903 .ok();
904 })
905 }
906
907 pub fn current_completion(&self) -> Option<String> {
908 self.completion.clone()
909 }
910
911 pub fn selected_text(&self) -> Option<&str> {
912 self.selected_text.as_deref()
913 }
914
915 pub fn stop(&mut self, cx: &mut Context<Self>) {
916 self.last_equal_ranges.clear();
917 if self.diff.is_empty() {
918 self.status = CodegenStatus::Idle;
919 } else {
920 self.status = CodegenStatus::Done;
921 }
922 self.generation = Task::ready(());
923 cx.emit(CodegenEvent::Finished);
924 cx.notify();
925 }
926
927 pub fn undo(&mut self, cx: &mut Context<Self>) {
928 self.buffer.update(cx, |buffer, cx| {
929 if let Some(transaction_id) = self.transformation_transaction_id.take() {
930 buffer.undo_transaction(transaction_id, cx);
931 buffer.refresh_preview(cx);
932 }
933 });
934 }
935
936 fn apply_edits(
937 &mut self,
938 edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
939 cx: &mut Context<CodegenAlternative>,
940 ) {
941 let transaction = self.buffer.update(cx, |buffer, cx| {
942 // Avoid grouping agent edits with user edits.
943 buffer.finalize_last_transaction(cx);
944 buffer.start_transaction(cx);
945 buffer.edit(edits, None, cx);
946 buffer.end_transaction(cx)
947 });
948
949 if let Some(transaction) = transaction {
950 if let Some(first_transaction) = self.transformation_transaction_id {
951 // Group all agent edits into the first transaction.
952 self.buffer.update(cx, |buffer, cx| {
953 buffer.merge_transactions(transaction, first_transaction, cx)
954 });
955 } else {
956 self.transformation_transaction_id = Some(transaction);
957 self.buffer
958 .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
959 }
960 }
961 }
962
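    /// Rebuilds the displayed diff from streamed line operations while generation is
    /// still in progress.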
963 fn reapply_line_based_diff(
964 &mut self,
965 line_operations: impl IntoIterator<Item = LineOperation>,
966 cx: &mut Context<Self>,
967 ) {
968 let old_snapshot = self.snapshot.clone();
969 let old_range = self.range.to_point(&old_snapshot);
970 let new_snapshot = self.buffer.read(cx).snapshot(cx);
971 let new_range = self.range.to_point(&new_snapshot);
972
973 let mut old_row = old_range.start.row;
974 let mut new_row = new_range.start.row;
975
976 self.diff.deleted_row_ranges.clear();
977 self.diff.inserted_row_ranges.clear();
978 for operation in line_operations {
979 match operation {
980 LineOperation::Keep { lines } => {
981 old_row += lines;
982 new_row += lines;
983 }
984 LineOperation::Delete { lines } => {
985 let old_end_row = old_row + lines - 1;
986 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
987
988 if let Some((_, last_deleted_row_range)) =
989 self.diff.deleted_row_ranges.last_mut()
990 {
991 if *last_deleted_row_range.end() + 1 == old_row {
992 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
993 } else {
994 self.diff
995 .deleted_row_ranges
996 .push((new_row, old_row..=old_end_row));
997 }
998 } else {
999 self.diff
1000 .deleted_row_ranges
1001 .push((new_row, old_row..=old_end_row));
1002 }
1003
1004 old_row += lines;
1005 }
1006 LineOperation::Insert { lines } => {
1007 let new_end_row = new_row + lines - 1;
1008 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
1009 let end = new_snapshot.anchor_before(Point::new(
1010 new_end_row,
1011 new_snapshot.line_len(MultiBufferRow(new_end_row)),
1012 ));
1013 self.diff.inserted_row_ranges.push(start..end);
1014 new_row += lines;
1015 }
1016 }
1017
1018 cx.notify();
1019 }
1020 }
1021
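    /// Recomputes the displayed diff by running a full line diff between the original
    /// and current text of the range on a background thread.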
1022 fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
1023 let old_snapshot = self.snapshot.clone();
1024 let old_range = self.range.to_point(&old_snapshot);
1025 let new_snapshot = self.buffer.read(cx).snapshot(cx);
1026 let new_range = self.range.to_point(&new_snapshot);
1027
1028 cx.spawn(async move |codegen, cx| {
1029 let (deleted_row_ranges, inserted_row_ranges) = cx
1030 .background_spawn(async move {
1031 let old_text = old_snapshot
1032 .text_for_range(
1033 Point::new(old_range.start.row, 0)
1034 ..Point::new(
1035 old_range.end.row,
1036 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
1037 ),
1038 )
1039 .collect::<String>();
1040 let new_text = new_snapshot
1041 .text_for_range(
1042 Point::new(new_range.start.row, 0)
1043 ..Point::new(
1044 new_range.end.row,
1045 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1046 ),
1047 )
1048 .collect::<String>();
1049
1050 let old_start_row = old_range.start.row;
1051 let new_start_row = new_range.start.row;
1052 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1053 let mut inserted_row_ranges = Vec::new();
1054 for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
1055 let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
1056 let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
1057 if !old_rows.is_empty() {
1058 deleted_row_ranges.push((
1059 new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
1060 old_rows.start..=old_rows.end - 1,
1061 ));
1062 }
1063 if !new_rows.is_empty() {
1064 let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
1065 let new_end_row = new_rows.end - 1;
1066 let end = new_snapshot.anchor_before(Point::new(
1067 new_end_row,
1068 new_snapshot.line_len(MultiBufferRow(new_end_row)),
1069 ));
1070 inserted_row_ranges.push(start..end);
1071 }
1072 }
1073 (deleted_row_ranges, inserted_row_ranges)
1074 })
1075 .await;
1076
1077 codegen
1078 .update(cx, |codegen, cx| {
1079 codegen.diff.deleted_row_ranges = deleted_row_ranges;
1080 codegen.diff.inserted_row_ranges = inserted_row_ranges;
1081 cx.notify();
1082 })
1083 .ok();
1084 })
1085 }
1086
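    /// Consumes a raw completion event stream, extracting `rewrite_section` and
    /// `failure_message` tool calls and adapting them into a text stream that is then
    /// fed through [`Self::handle_stream`].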
1087 fn handle_completion(
1088 &mut self,
1089 telemetry_id: String,
1090 provider_id: String,
1091 api_key: Option<String>,
1092 completion_stream: Task<
1093 Result<
1094 BoxStream<
1095 'static,
1096 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
1097 >,
1098 LanguageModelCompletionError,
1099 >,
1100 >,
1101 cx: &mut Context<Self>,
1102 ) -> Task<()> {
1103 self.diff = Diff::default();
1104 self.status = CodegenStatus::Pending;
1105
1106 cx.notify();
        // Keep this task in `generation` so that stop-equivalent events are respected
        // even while we're still pre-processing the completion events.
1109 cx.spawn(async move |codegen, cx| {
1110 let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
1111 let _ = codegen.update(cx, |this, cx| {
1112 this.status = status;
1113 cx.emit(CodegenEvent::Finished);
1114 cx.notify();
1115 });
1116 };
1117
1118 let mut completion_events = match completion_stream.await {
1119 Ok(events) => events,
1120 Err(err) => {
1121 finish_with_status(CodegenStatus::Error(err.into()), cx);
1122 return;
1123 }
1124 };
1125
1126 let chars_read_so_far = Arc::new(Mutex::new(0usize));
1127 let tool_to_text_and_message =
1128 move |tool_use: LanguageModelToolUse| -> (Option<String>, Option<String>) {
1129 let mut chars_read_so_far = chars_read_so_far.lock();
1130 match tool_use.name.as_ref() {
1131 "rewrite_section" => {
1132 let Ok(mut input) =
1133 serde_json::from_value::<RewriteSectionInput>(tool_use.input)
1134 else {
1135 return (None, None);
1136 };
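                            // The tool input arrives cumulatively, so forward only the
                            // portion of `replacement_text` that hasn't been emitted yet.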
1137 let value = input.replacement_text[*chars_read_so_far..].to_string();
1138 *chars_read_so_far = input.replacement_text.len();
1139 (Some(value), Some(std::mem::take(&mut input.description)))
1140 }
1141 "failure_message" => {
1142 let Ok(mut input) =
1143 serde_json::from_value::<FailureMessageInput>(tool_use.input)
1144 else {
1145 return (None, None);
1146 };
1147 (None, Some(std::mem::take(&mut input.message)))
1148 }
1149 _ => (None, None),
1150 }
1151 };
1152
1153 let mut message_id = None;
1154 let mut first_text = None;
1155 let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
1156 let total_text = Arc::new(Mutex::new(String::new()));
1157
1158 loop {
1159 if let Some(first_event) = completion_events.next().await {
1160 match first_event {
1161 Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
1162 message_id = Some(id);
1163 }
1164 Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
1165 if matches!(
1166 tool_use.name.as_ref(),
1167 "rewrite_section" | "failure_message"
1168 ) =>
1169 {
1170 let is_complete = tool_use.is_input_complete;
1171 let (text, message) = tool_to_text_and_message(tool_use);
1172 // Only update the model explanation if the tool use is complete.
1173 // Otherwise the UI element bounces around as it's updated.
1174 if is_complete {
1175 let _ = codegen.update(cx, |this, _cx| {
1176 this.model_explanation = message.map(Into::into);
1177 });
1178 }
1179 first_text = text;
1180 if first_text.is_some() {
1181 break;
1182 }
1183 }
1184 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1185 *last_token_usage.lock() = token_usage;
1186 }
1187 Ok(LanguageModelCompletionEvent::Text(text)) => {
1188 let mut lock = total_text.lock();
1189 lock.push_str(&text);
1190 }
1191 Ok(e) => {
1192 log::warn!("Unexpected event: {:?}", e);
1193 break;
1194 }
1195 Err(e) => {
1196 finish_with_status(CodegenStatus::Error(e.into()), cx);
1197 break;
1198 }
1199 }
1200 }
1201 }
1202
1203 let Some(first_text) = first_text else {
1204 finish_with_status(CodegenStatus::Done, cx);
1205 return;
1206 };
1207
1208 let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded();
1209
1210 cx.spawn({
1211 let codegen = codegen.clone();
1212 async move |cx| {
1213 while let Some(message) = message_rx.next().await {
1214 let _ = codegen.update(cx, |this, _cx| {
1215 this.model_explanation = message;
1216 });
1217 }
1218 }
1219 })
1220 .detach();
1221
1222 let move_last_token_usage = last_token_usage.clone();
1223
1224 let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
1225 completion_events.filter_map(move |e| {
1226 let tool_to_text_and_message = tool_to_text_and_message.clone();
1227 let last_token_usage = move_last_token_usage.clone();
1228 let total_text = total_text.clone();
1229 let mut message_tx = message_tx.clone();
1230 async move {
1231 match e {
1232 Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
1233 if matches!(
1234 tool_use.name.as_ref(),
1235 "rewrite_section" | "failure_message"
1236 ) =>
1237 {
1238 let is_complete = tool_use.is_input_complete;
1239 let (text, message) = tool_to_text_and_message(tool_use);
1240 if is_complete {
                                    // Again, only send the message once it's complete so the UI element doesn't bounce around.
1242 let _ = message_tx.send(message.map(Into::into)).await;
1243 }
1244 text.map(Ok)
1245 }
1246 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1247 *last_token_usage.lock() = token_usage;
1248 None
1249 }
1250 Ok(LanguageModelCompletionEvent::Text(text)) => {
1251 let mut lock = total_text.lock();
1252 lock.push_str(&text);
1253 None
1254 }
1255 Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
1256 e => {
                                log::error!("unexpected event: {:?}", e);
1258 None
1259 }
1260 }
1261 }
1262 }),
1263 ));
1264
1265 let language_model_text_stream = LanguageModelTextStream {
                message_id,
1267 stream: text_stream,
1268 last_token_usage,
1269 };
1270
1271 let Some(task) = codegen
1272 .update(cx, move |codegen, cx| {
1273 codegen.handle_stream(
1274 telemetry_id,
1275 provider_id,
1276 api_key,
1277 async { Ok(language_model_text_stream) },
1278 cx,
1279 )
1280 })
1281 .ok()
1282 else {
1283 return;
1284 };
1285
1286 task.await;
1287 })
1288 }
1289}
1290
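/// Events emitted when a generation finishes or its edits are undone.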
1291#[derive(Copy, Clone, Debug)]
1292pub enum CodegenEvent {
1293 Finished,
1294 Undone,
1295}
1296
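/// Wraps a stream of completion chunks and strips spans the model shouldn't emit: an
/// enclosing markdown code fence and any `<|CURSOR|>` markers. For example (see the
/// tests below), a response of "```\nLorem ipsum dolor\n```" is reduced to
/// "Lorem ipsum dolor".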
1297struct StripInvalidSpans<T> {
1298 stream: T,
1299 stream_done: bool,
1300 buffer: String,
1301 first_line: bool,
1302 line_end: bool,
1303 starts_with_code_block: bool,
1304}
1305
1306impl<T> StripInvalidSpans<T>
1307where
1308 T: Stream<Item = Result<String>>,
1309{
1310 fn new(stream: T) -> Self {
1311 Self {
1312 stream,
1313 stream_done: false,
1314 buffer: String::new(),
1315 first_line: true,
1316 line_end: false,
1317 starts_with_code_block: false,
1318 }
1319 }
1320}
1321
1322impl<T> Stream for StripInvalidSpans<T>
1323where
1324 T: Stream<Item = Result<String>>,
1325{
1326 type Item = Result<String>;
1327
1328 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
1329 const CODE_BLOCK_DELIMITER: &str = "```";
1330 const CURSOR_SPAN: &str = "<|CURSOR|>";
1331
1332 let this = unsafe { self.get_unchecked_mut() };
1333 loop {
1334 if !this.stream_done {
1335 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
1336 match stream.as_mut().poll_next(cx) {
1337 Poll::Ready(Some(Ok(chunk))) => {
1338 this.buffer.push_str(&chunk);
1339 }
1340 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
1341 Poll::Ready(None) => {
1342 this.stream_done = true;
1343 }
1344 Poll::Pending => return Poll::Pending,
1345 }
1346 }
1347
1348 let mut chunk = String::new();
1349 let mut consumed = 0;
1350 if !this.buffer.is_empty() {
1351 let mut lines = this.buffer.split('\n').enumerate().peekable();
1352 while let Some((line_ix, line)) = lines.next() {
1353 if line_ix > 0 {
1354 this.first_line = false;
1355 }
1356
1357 if this.first_line {
1358 let trimmed_line = line.trim();
1359 if lines.peek().is_some() {
1360 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
1361 consumed += line.len() + 1;
1362 this.starts_with_code_block = true;
1363 continue;
1364 }
1365 } else if trimmed_line.is_empty()
1366 || prefixes(CODE_BLOCK_DELIMITER)
1367 .any(|prefix| trimmed_line.starts_with(prefix))
1368 {
1369 break;
1370 }
1371 }
1372
1373 let line_without_cursor = line.replace(CURSOR_SPAN, "");
1374 if lines.peek().is_some() {
1375 if this.line_end {
1376 chunk.push('\n');
1377 }
1378
1379 chunk.push_str(&line_without_cursor);
1380 this.line_end = true;
1381 consumed += line.len() + 1;
1382 } else if this.stream_done {
1383 if !this.starts_with_code_block
1384 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1385 {
1386 if this.line_end {
1387 chunk.push('\n');
1388 }
1389
1390 chunk.push_str(line);
1391 }
1392
1393 consumed += line.len();
1394 } else {
1395 let trimmed_line = line.trim();
1396 if trimmed_line.is_empty()
1397 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1398 || prefixes(CODE_BLOCK_DELIMITER)
1399 .any(|prefix| trimmed_line.ends_with(prefix))
1400 {
1401 break;
1402 } else {
1403 if this.line_end {
1404 chunk.push('\n');
1405 this.line_end = false;
1406 }
1407
1408 chunk.push_str(&line_without_cursor);
1409 consumed += line.len();
1410 }
1411 }
1412 }
1413 }
1414
1415 this.buffer = this.buffer.split_off(consumed);
1416 if !chunk.is_empty() {
1417 return Poll::Ready(Some(Ok(chunk)));
1418 } else if this.stream_done {
1419 return Poll::Ready(None);
1420 }
1421 }
1422 }
1423}
1424
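/// Returns every proper prefix of `text`, shortest first (for "```" this yields "`"
/// and "``").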
1425fn prefixes(text: &str) -> impl Iterator<Item = &str> {
1426 (0..text.len() - 1).map(|ix| &text[..ix + 1])
1427}
1428
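/// Row ranges describing how the generated text differs from the original selection:
/// `deleted_row_ranges` pairs a position in the new buffer with the range of old rows
/// deleted there, and `inserted_row_ranges` covers the newly inserted rows.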
1429#[derive(Default)]
1430pub struct Diff {
1431 pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
1432 pub inserted_row_ranges: Vec<Range<Anchor>>,
1433}
1434
1435impl Diff {
1436 fn is_empty(&self) -> bool {
1437 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1438 }
1439}
1440
1441#[cfg(test)]
1442mod tests {
1443 use super::*;
1444 use futures::{
1445 Stream,
1446 stream::{self},
1447 };
1448 use gpui::TestAppContext;
1449 use indoc::indoc;
1450 use language::{Buffer, Point};
1451 use language_model::{LanguageModelRegistry, TokenUsage};
1452 use languages::rust_lang;
1453 use rand::prelude::*;
1454 use settings::SettingsStore;
1455 use std::{future, sync::Arc};
1456
1457 #[gpui::test(iterations = 10)]
1458 async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
1459 init_test(cx);
1460
1461 let text = indoc! {"
1462 fn main() {
1463 let x = 0;
1464 for _ in 0..10 {
1465 x += 1;
1466 }
1467 }
1468 "};
1469 let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1470 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1471 let range = buffer.read_with(cx, |buffer, cx| {
1472 let snapshot = buffer.snapshot(cx);
1473 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
1474 });
1475 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1476 let codegen = cx.new(|cx| {
1477 CodegenAlternative::new(
1478 buffer.clone(),
1479 range.clone(),
1480 true,
1481 None,
1482 prompt_builder,
1483 cx,
1484 )
1485 });
1486
1487 let chunks_tx = simulate_response_stream(&codegen, cx);
1488
1489 let mut new_text = concat!(
1490 " let mut x = 0;\n",
1491 " while x < 10 {\n",
1492 " x += 1;\n",
1493 " }",
1494 );
1495 while !new_text.is_empty() {
1496 let max_len = cmp::min(new_text.len(), 10);
1497 let len = rng.random_range(1..=max_len);
1498 let (chunk, suffix) = new_text.split_at(len);
1499 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1500 new_text = suffix;
1501 cx.background_executor.run_until_parked();
1502 }
1503 drop(chunks_tx);
1504 cx.background_executor.run_until_parked();
1505
1506 assert_eq!(
1507 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1508 indoc! {"
1509 fn main() {
1510 let mut x = 0;
1511 while x < 10 {
1512 x += 1;
1513 }
1514 }
1515 "}
1516 );
1517 }
1518
1519 #[gpui::test(iterations = 10)]
1520 async fn test_autoindent_when_generating_past_indentation(
1521 cx: &mut TestAppContext,
1522 mut rng: StdRng,
1523 ) {
1524 init_test(cx);
1525
1526 let text = indoc! {"
1527 fn main() {
1528 le
1529 }
1530 "};
1531 let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1532 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1533 let range = buffer.read_with(cx, |buffer, cx| {
1534 let snapshot = buffer.snapshot(cx);
1535 snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
1536 });
1537 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1538 let codegen = cx.new(|cx| {
1539 CodegenAlternative::new(
1540 buffer.clone(),
1541 range.clone(),
1542 true,
1543 None,
1544 prompt_builder,
1545 cx,
1546 )
1547 });
1548
1549 let chunks_tx = simulate_response_stream(&codegen, cx);
1550
1551 cx.background_executor.run_until_parked();
1552
1553 let mut new_text = concat!(
1554 "t mut x = 0;\n",
1555 "while x < 10 {\n",
1556 " x += 1;\n",
1557 "}", //
1558 );
1559 while !new_text.is_empty() {
1560 let max_len = cmp::min(new_text.len(), 10);
1561 let len = rng.random_range(1..=max_len);
1562 let (chunk, suffix) = new_text.split_at(len);
1563 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1564 new_text = suffix;
1565 cx.background_executor.run_until_parked();
1566 }
1567 drop(chunks_tx);
1568 cx.background_executor.run_until_parked();
1569
1570 assert_eq!(
1571 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1572 indoc! {"
1573 fn main() {
1574 let mut x = 0;
1575 while x < 10 {
1576 x += 1;
1577 }
1578 }
1579 "}
1580 );
1581 }
1582
1583 #[gpui::test(iterations = 10)]
1584 async fn test_autoindent_when_generating_before_indentation(
1585 cx: &mut TestAppContext,
1586 mut rng: StdRng,
1587 ) {
1588 init_test(cx);
1589
1590 let text = concat!(
1591 "fn main() {\n",
1592 " \n",
1593 "}\n" //
1594 );
1595 let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1596 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1597 let range = buffer.read_with(cx, |buffer, cx| {
1598 let snapshot = buffer.snapshot(cx);
1599 snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
1600 });
1601 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1602 let codegen = cx.new(|cx| {
1603 CodegenAlternative::new(
1604 buffer.clone(),
1605 range.clone(),
1606 true,
1607 None,
1608 prompt_builder,
1609 cx,
1610 )
1611 });
1612
1613 let chunks_tx = simulate_response_stream(&codegen, cx);
1614
1615 cx.background_executor.run_until_parked();
1616
1617 let mut new_text = concat!(
1618 "let mut x = 0;\n",
1619 "while x < 10 {\n",
1620 " x += 1;\n",
1621 "}", //
1622 );
1623 while !new_text.is_empty() {
1624 let max_len = cmp::min(new_text.len(), 10);
1625 let len = rng.random_range(1..=max_len);
1626 let (chunk, suffix) = new_text.split_at(len);
1627 chunks_tx.unbounded_send(chunk.to_string()).unwrap();
1628 new_text = suffix;
1629 cx.background_executor.run_until_parked();
1630 }
1631 drop(chunks_tx);
1632 cx.background_executor.run_until_parked();
1633
1634 assert_eq!(
1635 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1636 indoc! {"
1637 fn main() {
1638 let mut x = 0;
1639 while x < 10 {
1640 x += 1;
1641 }
1642 }
1643 "}
1644 );
1645 }
1646
1647 #[gpui::test(iterations = 10)]
1648 async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
1649 init_test(cx);
1650
1651 let text = indoc! {"
1652 func main() {
1653 \tx := 0
1654 \tfor i := 0; i < 10; i++ {
1655 \t\tx++
1656 \t}
1657 }
1658 "};
1659 let buffer = cx.new(|cx| Buffer::local(text, cx));
1660 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1661 let range = buffer.read_with(cx, |buffer, cx| {
1662 let snapshot = buffer.snapshot(cx);
1663 snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
1664 });
1665 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1666 let codegen = cx.new(|cx| {
1667 CodegenAlternative::new(
1668 buffer.clone(),
1669 range.clone(),
1670 true,
1671 None,
1672 prompt_builder,
1673 cx,
1674 )
1675 });
1676
1677 let chunks_tx = simulate_response_stream(&codegen, cx);
1678 let new_text = concat!(
1679 "func main() {\n",
1680 "\tx := 0\n",
1681 "\tfor x < 10 {\n",
1682 "\t\tx++\n",
1683 "\t}", //
1684 );
1685 chunks_tx.unbounded_send(new_text.to_string()).unwrap();
1686 drop(chunks_tx);
1687 cx.background_executor.run_until_parked();
1688
1689 assert_eq!(
1690 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1691 indoc! {"
1692 func main() {
1693 \tx := 0
1694 \tfor x < 10 {
1695 \t\tx++
1696 \t}
1697 }
1698 "}
1699 );
1700 }
1701
1702 #[gpui::test]
1703 async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
1704 init_test(cx);
1705
1706 let text = indoc! {"
1707 fn main() {
1708 let x = 0;
1709 }
1710 "};
1711 let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
1712 let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
1713 let range = buffer.read_with(cx, |buffer, cx| {
1714 let snapshot = buffer.snapshot(cx);
1715 snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
1716 });
1717 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
1718 let codegen = cx.new(|cx| {
1719 CodegenAlternative::new(
1720 buffer.clone(),
1721 range.clone(),
1722 false,
1723 None,
1724 prompt_builder,
1725 cx,
1726 )
1727 });
1728
1729 let chunks_tx = simulate_response_stream(&codegen, cx);
1730 chunks_tx
1731 .unbounded_send("let mut x = 0;\nx += 1;".to_string())
1732 .unwrap();
1733 drop(chunks_tx);
1734 cx.run_until_parked();
1735
1736 // The codegen is inactive, so the buffer doesn't get modified.
1737 assert_eq!(
1738 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1739 text
1740 );
1741
1742 // Activating the codegen applies the changes.
1743 codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
1744 assert_eq!(
1745 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1746 indoc! {"
1747 fn main() {
1748 let mut x = 0;
1749 x += 1;
1750 }
1751 "}
1752 );
1753
1754 // Deactivating the codegen undoes the changes.
1755 codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
1756 cx.run_until_parked();
1757 assert_eq!(
1758 buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
1759 text
1760 );
1761 }
1762
1763 #[gpui::test]
1764 async fn test_strip_invalid_spans_from_codeblock() {
1765 assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
1766 assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
1767 assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
1768 assert_chunks(
1769 "```html\n```js\nLorem ipsum dolor\n```\n```",
1770 "```js\nLorem ipsum dolor\n```",
1771 )
1772 .await;
1773 assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
1774 assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
1775 assert_chunks("Lorem ipsum", "Lorem ipsum").await;
1776 assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
1777
1778 async fn assert_chunks(text: &str, expected_text: &str) {
1779 for chunk_size in 1..=text.len() {
1780 let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
1781 .map(|chunk| chunk.unwrap())
1782 .collect::<String>()
1783 .await;
1784 assert_eq!(
1785 actual_text, expected_text,
1786 "failed to strip invalid spans, chunk size: {}",
1787 chunk_size
1788 );
1789 }
1790 }
1791
1792 fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
1793 stream::iter(
1794 text.chars()
1795 .collect::<Vec<_>>()
1796 .chunks(size)
1797 .map(|chunk| Ok(chunk.iter().collect::<String>()))
1798 .collect::<Vec<_>>(),
1799 )
1800 }
1801 }
1802
1803 fn init_test(cx: &mut TestAppContext) {
1804 cx.update(LanguageModelRegistry::test);
1805 cx.set_global(cx.update(SettingsStore::test));
1806 }
1807
1808 fn simulate_response_stream(
1809 codegen: &Entity<CodegenAlternative>,
1810 cx: &mut TestAppContext,
1811 ) -> mpsc::UnboundedSender<String> {
1812 let (chunks_tx, chunks_rx) = mpsc::unbounded();
1813 codegen.update(cx, |codegen, cx| {
1814 codegen.generation = codegen.handle_stream(
1815 String::new(),
1816 String::new(),
1817 None,
1818 future::ready(Ok(LanguageModelTextStream {
1819 message_id: None,
1820 stream: chunks_rx.map(Ok).boxed(),
1821 last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
1822 })),
1823 cx,
1824 );
1825 });
1826 chunks_tx
1827 }
1828}