1use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
2use agent_settings::AgentSettings;
3use anyhow::{Context as _, Result};
4use uuid::Uuid;
5
6use cloud_llm_client::CompletionIntent;
7use collections::HashSet;
8use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
9use futures::{
10 SinkExt, Stream, StreamExt, TryStreamExt as _,
11 channel::mpsc,
12 future::{LocalBoxFuture, Shared},
13 join,
14 stream::BoxStream,
15};
16use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
17use language::{Buffer, IndentKind, LanguageName, Point, TransactionId, line_diff};
18use language_model::{
19 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
20 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
21 LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolChoice,
22 LanguageModelToolUse, Role, TokenUsage,
23};
24use multi_buffer::MultiBufferRow;
25use parking_lot::Mutex;
26use prompt_store::PromptBuilder;
27use rope::Rope;
28use schemars::JsonSchema;
29use serde::{Deserialize, Serialize};
30use settings::Settings as _;
31use smol::future::FutureExt;
32use std::{
33 cmp,
34 future::Future,
35 iter,
36 ops::{Range, RangeInclusive},
37 pin::Pin,
38 sync::Arc,
39 task::{self, Poll},
40 time::Instant,
41};
42use streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff};
43
// NOTE: schemars derives the schema `description`s from the `///` doc comments
// on this type and its fields. This type's schema is offered to the model as the
// input schema of the `FAILURE_MESSAGE_TOOL_NAME` tool (see `build_request_tools`),
// so the `///` text below is model-facing prompt text, not just developer docs.
/// Use this tool when you cannot or should not make a rewrite. This includes:
/// - The user's request is unclear, ambiguous, or nonsensical
/// - The requested change cannot be made by only editing the <rewrite_this> section
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct FailureMessageInput {
    /// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
    // `#[serde(default)]`: a missing `message` deserializes as an empty string.
    #[serde(default)]
    pub message: String,
}
53
// NOTE: schemars derives the schema `description`s from the `///` doc comments
// on this type and its fields. This type's schema is offered to the model as the
// input schema of the `REWRITE_SECTION_TOOL_NAME` tool (see `build_request_tools`),
// so the `///` text below is model-facing prompt text, not just developer docs.
/// Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.
/// Only use this tool when you are confident you understand the user's request and can fulfill it
/// by editing the marked section.
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
pub struct RewriteSectionInput {
    /// The text to replace the section with.
    // `#[serde(default)]`: a missing `replacement_text` deserializes as an empty string.
    #[serde(default)]
    pub replacement_text: String,
}
63
/// Drives an inline-assist generation over `range` in `buffer`.
///
/// Holds one `CodegenAlternative` per model and forwards queries and events
/// to whichever alternative is currently active.
pub struct BufferCodegen {
    /// Candidate generations. Index 0 is the primary model's; the rest belong to
    /// the registry's inline alternative models (see `start`).
    alternatives: Vec<Entity<CodegenAlternative>>,
    /// Index into `alternatives` of the alternative whose edits are applied.
    pub active_alternative: usize,
    /// Every index that has been passed to `activate`.
    seen_alternatives: HashSet<usize>,
    /// Observation and event-forwarding subscriptions to the active alternative.
    subscriptions: Vec<Subscription>,
    /// The multibuffer being edited.
    buffer: Entity<MultiBuffer>,
    /// The range being transformed.
    range: Range<Anchor>,
    /// A transaction supplied by the caller at construction; `undo` undoes it once.
    initial_transaction_id: Option<TransactionId>,
    builder: Arc<PromptBuilder>,
    /// Whether `range` was empty (resolved to an empty offset range) at construction.
    pub is_insertion: bool,
    session_id: Uuid,
}
76
/// Tool name under which `RewriteSectionInput` is offered to the model.
pub const REWRITE_SECTION_TOOL_NAME: &str = "rewrite_section";
/// Tool name under which `FailureMessageInput` is offered to the model.
pub const FAILURE_MESSAGE_TOOL_NAME: &str = "failure_message";
79
80impl BufferCodegen {
81 pub fn new(
82 buffer: Entity<MultiBuffer>,
83 range: Range<Anchor>,
84 initial_transaction_id: Option<TransactionId>,
85 session_id: Uuid,
86 builder: Arc<PromptBuilder>,
87 cx: &mut Context<Self>,
88 ) -> Self {
89 let codegen = cx.new(|cx| {
90 CodegenAlternative::new(
91 buffer.clone(),
92 range.clone(),
93 false,
94 builder.clone(),
95 session_id,
96 cx,
97 )
98 });
99 let mut this = Self {
100 is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
101 alternatives: vec![codegen],
102 active_alternative: 0,
103 seen_alternatives: HashSet::default(),
104 subscriptions: Vec::new(),
105 buffer,
106 range,
107 initial_transaction_id,
108 builder,
109 session_id,
110 };
111 this.activate(0, cx);
112 this
113 }
114
115 fn subscribe_to_alternative(&mut self, cx: &mut Context<Self>) {
116 let codegen = self.active_alternative().clone();
117 self.subscriptions.clear();
118 self.subscriptions
119 .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
120 self.subscriptions
121 .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
122 }
123
124 pub fn active_completion(&self, cx: &App) -> Option<String> {
125 self.active_alternative().read(cx).current_completion()
126 }
127
128 pub fn active_alternative(&self) -> &Entity<CodegenAlternative> {
129 &self.alternatives[self.active_alternative]
130 }
131
132 pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
133 self.active_alternative().read(cx).language_name(cx)
134 }
135
136 pub fn status<'a>(&self, cx: &'a App) -> &'a CodegenStatus {
137 &self.active_alternative().read(cx).status
138 }
139
140 pub fn alternative_count(&self, cx: &App) -> usize {
141 LanguageModelRegistry::read_global(cx)
142 .inline_alternative_models()
143 .len()
144 + 1
145 }
146
147 pub fn cycle_prev(&mut self, cx: &mut Context<Self>) {
148 let next_active_ix = if self.active_alternative == 0 {
149 self.alternatives.len() - 1
150 } else {
151 self.active_alternative - 1
152 };
153 self.activate(next_active_ix, cx);
154 }
155
156 pub fn cycle_next(&mut self, cx: &mut Context<Self>) {
157 let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
158 self.activate(next_active_ix, cx);
159 }
160
161 fn activate(&mut self, index: usize, cx: &mut Context<Self>) {
162 self.active_alternative()
163 .update(cx, |codegen, cx| codegen.set_active(false, cx));
164 self.seen_alternatives.insert(index);
165 self.active_alternative = index;
166 self.active_alternative()
167 .update(cx, |codegen, cx| codegen.set_active(true, cx));
168 self.subscribe_to_alternative(cx);
169 cx.notify();
170 }
171
172 pub fn start(
173 &mut self,
174 primary_model: Arc<dyn LanguageModel>,
175 user_prompt: String,
176 context_task: Shared<Task<Option<LoadedContext>>>,
177 cx: &mut Context<Self>,
178 ) -> Result<()> {
179 let alternative_models = LanguageModelRegistry::read_global(cx)
180 .inline_alternative_models()
181 .to_vec();
182
183 self.active_alternative()
184 .update(cx, |alternative, cx| alternative.undo(cx));
185 self.activate(0, cx);
186 self.alternatives.truncate(1);
187
188 for _ in 0..alternative_models.len() {
189 self.alternatives.push(cx.new(|cx| {
190 CodegenAlternative::new(
191 self.buffer.clone(),
192 self.range.clone(),
193 false,
194 self.builder.clone(),
195 self.session_id,
196 cx,
197 )
198 }));
199 }
200
201 for (model, alternative) in iter::once(primary_model)
202 .chain(alternative_models)
203 .zip(&self.alternatives)
204 {
205 alternative.update(cx, |alternative, cx| {
206 alternative.start(user_prompt.clone(), context_task.clone(), model.clone(), cx)
207 })?;
208 }
209
210 Ok(())
211 }
212
213 pub fn stop(&mut self, cx: &mut Context<Self>) {
214 for codegen in &self.alternatives {
215 codegen.update(cx, |codegen, cx| codegen.stop(cx));
216 }
217 }
218
219 pub fn undo(&mut self, cx: &mut Context<Self>) {
220 self.active_alternative()
221 .update(cx, |codegen, cx| codegen.undo(cx));
222
223 self.buffer.update(cx, |buffer, cx| {
224 if let Some(transaction_id) = self.initial_transaction_id.take() {
225 buffer.undo_transaction(transaction_id, cx);
226 buffer.refresh_preview(cx);
227 }
228 });
229 }
230
231 pub fn buffer(&self, cx: &App) -> Entity<MultiBuffer> {
232 self.active_alternative().read(cx).buffer.clone()
233 }
234
235 pub fn old_buffer(&self, cx: &App) -> Entity<Buffer> {
236 self.active_alternative().read(cx).old_buffer.clone()
237 }
238
239 pub fn snapshot(&self, cx: &App) -> MultiBufferSnapshot {
240 self.active_alternative().read(cx).snapshot.clone()
241 }
242
243 pub fn edit_position(&self, cx: &App) -> Option<Anchor> {
244 self.active_alternative().read(cx).edit_position
245 }
246
247 pub fn diff<'a>(&self, cx: &'a App) -> &'a Diff {
248 &self.active_alternative().read(cx).diff
249 }
250
251 pub fn last_equal_ranges<'a>(&self, cx: &'a App) -> &'a [Range<Anchor>] {
252 self.active_alternative().read(cx).last_equal_ranges()
253 }
254
255 pub fn selected_text<'a>(&self, cx: &'a App) -> Option<&'a str> {
256 self.active_alternative().read(cx).selected_text()
257 }
258
259 pub fn session_id(&self) -> Uuid {
260 self.session_id
261 }
262}
263
// `BufferCodegen` re-emits the `CodegenEvent`s of its active alternative
// (see `subscribe_to_alternative`).
impl EventEmitter<CodegenEvent> for BufferCodegen {}
265
/// One candidate transformation of the target range, produced by a single model.
///
/// Only the active alternative keeps its edits applied to the shared buffer.
/// An inactive one retains `edits` so they can be re-applied on activation
/// (see `set_active`).
pub struct CodegenAlternative {
    /// The multibuffer being edited.
    buffer: Entity<MultiBuffer>,
    /// A standalone local buffer copied from the underlying buffer's text, line
    /// ending and language at construction.
    old_buffer: Entity<Buffer>,
    /// Snapshot of `buffer` taken at construction and refreshed when a streamed
    /// generation starts; anchors and diffs are resolved against it.
    snapshot: MultiBufferSnapshot,
    /// Where generated text is currently being written, once a generation starts.
    edit_position: Option<Anchor>,
    /// The range being transformed.
    range: Range<Anchor>,
    /// Ranges the latest streamed diff chunk kept unchanged; cleared when a
    /// generation finishes or is stopped.
    last_equal_ranges: Vec<Range<Anchor>>,
    /// The transaction into which all of this alternative's buffer edits are merged.
    transformation_transaction_id: Option<TransactionId>,
    /// Generation status (`Idle`, `Pending`, `Done` or `Error`).
    status: CodegenStatus,
    /// The in-flight generation task; replaced with `Task::ready(())` by `stop`
    /// and when the transformation transaction is undone.
    generation: Task<()>,
    /// Deleted/inserted row ranges describing the current diff.
    diff: Diff,
    /// Subscription to `buffer` events (see `handle_buffer_event`).
    _subscription: gpui::Subscription,
    builder: Arc<PromptBuilder>,
    /// Whether this alternative's edits are currently applied to the buffer.
    active: bool,
    /// Edits produced by streamed generations, recorded even while inactive so
    /// `set_active(true)` can re-apply them.
    edits: Vec<(Range<Anchor>, String)>,
    /// The latest line-level diff operations of the streamed generation.
    line_operations: Vec<LineOperation>,
    /// Duration of the last finished generation, in seconds.
    elapsed_time: Option<f64>,
    /// Completion text accumulated by the last finished generation.
    completion: Option<String>,
    /// Text of `range` when the last streamed generation started.
    selected_text: Option<String>,
    pub message_id: Option<String>,
    session_id: Uuid,
    /// Explanation text from the model; cleared when a new generation starts.
    pub description: Option<String>,
    // NOTE(review): set outside the visible code; presumably the model's
    // failure message (see `FailureMessageInput`) — confirm.
    pub failure: Option<String>,
}
290
// Emits, among others, `CodegenEvent::Undone` when the transformation
// transaction is undone and `CodegenEvent::Finished` when a generation ends or
// is stopped.
impl EventEmitter<CodegenEvent> for CodegenAlternative {}
292
293impl CodegenAlternative {
294 pub fn new(
295 buffer: Entity<MultiBuffer>,
296 range: Range<Anchor>,
297 active: bool,
298 builder: Arc<PromptBuilder>,
299 session_id: Uuid,
300 cx: &mut Context<Self>,
301 ) -> Self {
302 let snapshot = buffer.read(cx).snapshot(cx);
303
304 let (old_buffer, _, _) = snapshot
305 .range_to_buffer_ranges(range.start..=range.end)
306 .pop()
307 .unwrap();
308 let old_buffer = cx.new(|cx| {
309 let text = old_buffer.as_rope().clone();
310 let line_ending = old_buffer.line_ending();
311 let language = old_buffer.language().cloned();
312 let language_registry = buffer
313 .read(cx)
314 .buffer(old_buffer.remote_id())
315 .unwrap()
316 .read(cx)
317 .language_registry();
318
319 let mut buffer = Buffer::local_normalized(text, line_ending, cx);
320 buffer.set_language(language, cx);
321 if let Some(language_registry) = language_registry {
322 buffer.set_language_registry(language_registry);
323 }
324 buffer
325 });
326
327 Self {
328 buffer: buffer.clone(),
329 old_buffer,
330 edit_position: None,
331 message_id: None,
332 snapshot,
333 last_equal_ranges: Default::default(),
334 transformation_transaction_id: None,
335 status: CodegenStatus::Idle,
336 generation: Task::ready(()),
337 diff: Diff::default(),
338 builder,
339 active: active,
340 edits: Vec::new(),
341 line_operations: Vec::new(),
342 range,
343 elapsed_time: None,
344 completion: None,
345 selected_text: None,
346 session_id,
347 description: None,
348 failure: None,
349 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
350 }
351 }
352
353 pub fn language_name(&self, cx: &App) -> Option<LanguageName> {
354 self.old_buffer
355 .read(cx)
356 .language()
357 .map(|language| language.name())
358 }
359
360 pub fn set_active(&mut self, active: bool, cx: &mut Context<Self>) {
361 if active != self.active {
362 self.active = active;
363
364 if self.active {
365 let edits = self.edits.clone();
366 self.apply_edits(edits, cx);
367 if matches!(self.status, CodegenStatus::Pending) {
368 let line_operations = self.line_operations.clone();
369 self.reapply_line_based_diff(line_operations, cx);
370 } else {
371 self.reapply_batch_diff(cx).detach();
372 }
373 } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
374 self.buffer.update(cx, |buffer, cx| {
375 buffer.undo_transaction(transaction_id, cx);
376 buffer.forget_transaction(transaction_id, cx);
377 });
378 }
379 }
380 }
381
382 fn handle_buffer_event(
383 &mut self,
384 _buffer: Entity<MultiBuffer>,
385 event: &multi_buffer::Event,
386 cx: &mut Context<Self>,
387 ) {
388 if let multi_buffer::Event::TransactionUndone { transaction_id } = event
389 && self.transformation_transaction_id == Some(*transaction_id)
390 {
391 self.transformation_transaction_id = None;
392 self.generation = Task::ready(());
393 cx.emit(CodegenEvent::Undone);
394 }
395 }
396
397 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
398 &self.last_equal_ranges
399 }
400
401 pub fn use_streaming_tools(model: &dyn LanguageModel, cx: &App) -> bool {
402 model.supports_streaming_tools()
403 && AgentSettings::get_global(cx).inline_assistant_use_streaming_tools
404 }
405
406 pub fn start(
407 &mut self,
408 user_prompt: String,
409 context_task: Shared<Task<Option<LoadedContext>>>,
410 model: Arc<dyn LanguageModel>,
411 cx: &mut Context<Self>,
412 ) -> Result<()> {
413 // Clear the model explanation since the user has started a new generation.
414 self.description = None;
415
416 if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
417 self.buffer.update(cx, |buffer, cx| {
418 buffer.undo_transaction(transformation_transaction_id, cx);
419 });
420 }
421
422 self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
423
424 if Self::use_streaming_tools(model.as_ref(), cx) {
425 let request = self.build_request(&model, user_prompt, context_task, cx)?;
426 let completion_events = cx.spawn({
427 let model = model.clone();
428 async move |_, cx| model.stream_completion(request.await, cx).await
429 });
430 self.generation = self.handle_completion(model, completion_events, cx);
431 } else {
432 let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
433 if user_prompt.trim().to_lowercase() == "delete" {
434 async { Ok(LanguageModelTextStream::default()) }.boxed_local()
435 } else {
436 let request = self.build_request(&model, user_prompt, context_task, cx)?;
437 cx.spawn({
438 let model = model.clone();
439 async move |_, cx| {
440 Ok(model.stream_completion_text(request.await, cx).await?)
441 }
442 })
443 .boxed_local()
444 };
445 self.generation =
446 self.handle_stream(model, /* strip_invalid_spans: */ true, stream, cx);
447 }
448
449 Ok(())
450 }
451
452 fn build_request_tools(
453 &self,
454 model: &Arc<dyn LanguageModel>,
455 user_prompt: String,
456 context_task: Shared<Task<Option<LoadedContext>>>,
457 cx: &mut App,
458 ) -> Result<Task<LanguageModelRequest>> {
459 let buffer = self.buffer.read(cx).snapshot(cx);
460 let language = buffer.language_at(self.range.start);
461 let language_name = if let Some(language) = language.as_ref() {
462 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
463 None
464 } else {
465 Some(language.name())
466 }
467 } else {
468 None
469 };
470
471 let language_name = language_name.as_ref();
472 let start = buffer.point_to_buffer_offset(self.range.start);
473 let end = buffer.point_to_buffer_offset(self.range.end);
474 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
475 let (start_buffer, start_buffer_offset) = start;
476 let (end_buffer, end_buffer_offset) = end;
477 if start_buffer.remote_id() == end_buffer.remote_id() {
478 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
479 } else {
480 anyhow::bail!("invalid transformation range");
481 }
482 } else {
483 anyhow::bail!("invalid transformation range");
484 };
485
486 let system_prompt = self
487 .builder
488 .generate_inline_transformation_prompt_tools(
489 language_name,
490 buffer,
491 range.start.0..range.end.0,
492 )
493 .context("generating content prompt")?;
494
495 let temperature = AgentSettings::temperature_for_model(model, cx);
496
497 let tool_input_format = model.tool_input_format();
498 let tool_choice = model
499 .supports_tool_choice(LanguageModelToolChoice::Any)
500 .then_some(LanguageModelToolChoice::Any);
501
502 Ok(cx.spawn(async move |_cx| {
503 let mut messages = vec![LanguageModelRequestMessage {
504 role: Role::System,
505 content: vec![system_prompt.into()],
506 cache: false,
507 reasoning_details: None,
508 }];
509
510 let mut user_message = LanguageModelRequestMessage {
511 role: Role::User,
512 content: Vec::new(),
513 cache: false,
514 reasoning_details: None,
515 };
516
517 if let Some(context) = context_task.await {
518 context.add_to_request_message(&mut user_message);
519 }
520
521 user_message.content.push(user_prompt.into());
522 messages.push(user_message);
523
524 let tools = vec![
525 LanguageModelRequestTool {
526 name: REWRITE_SECTION_TOOL_NAME.to_string(),
527 description: "Replaces text in <rewrite_this></rewrite_this> tags with your replacement_text.".to_string(),
528 input_schema: language_model::tool_schema::root_schema_for::<RewriteSectionInput>(tool_input_format).to_value(),
529 use_input_streaming: false,
530 },
531 LanguageModelRequestTool {
532 name: FAILURE_MESSAGE_TOOL_NAME.to_string(),
533 description: "Use this tool to provide a message to the user when you're unable to complete a task.".to_string(),
534 input_schema: language_model::tool_schema::root_schema_for::<FailureMessageInput>(tool_input_format).to_value(),
535 use_input_streaming: false,
536 },
537 ];
538
539 LanguageModelRequest {
540 thread_id: None,
541 prompt_id: None,
542 intent: Some(CompletionIntent::InlineAssist),
543 tools,
544 tool_choice,
545 stop: Vec::new(),
546 temperature,
547 messages,
548 thinking_allowed: false,
549 thinking_effort: None,
550 speed: None,
551 }
552 }))
553 }
554
555 fn build_request(
556 &self,
557 model: &Arc<dyn LanguageModel>,
558 user_prompt: String,
559 context_task: Shared<Task<Option<LoadedContext>>>,
560 cx: &mut App,
561 ) -> Result<Task<LanguageModelRequest>> {
562 if Self::use_streaming_tools(model.as_ref(), cx) {
563 return self.build_request_tools(model, user_prompt, context_task, cx);
564 }
565
566 let buffer = self.buffer.read(cx).snapshot(cx);
567 let language = buffer.language_at(self.range.start);
568 let language_name = if let Some(language) = language.as_ref() {
569 if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
570 None
571 } else {
572 Some(language.name())
573 }
574 } else {
575 None
576 };
577
578 let language_name = language_name.as_ref();
579 let start = buffer.point_to_buffer_offset(self.range.start);
580 let end = buffer.point_to_buffer_offset(self.range.end);
581 let (buffer, range) = if let Some((start, end)) = start.zip(end) {
582 let (start_buffer, start_buffer_offset) = start;
583 let (end_buffer, end_buffer_offset) = end;
584 if start_buffer.remote_id() == end_buffer.remote_id() {
585 (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
586 } else {
587 anyhow::bail!("invalid transformation range");
588 }
589 } else {
590 anyhow::bail!("invalid transformation range");
591 };
592
593 let prompt = self
594 .builder
595 .generate_inline_transformation_prompt(
596 user_prompt,
597 language_name,
598 buffer,
599 range.start.0..range.end.0,
600 )
601 .context("generating content prompt")?;
602
603 let temperature = AgentSettings::temperature_for_model(model, cx);
604
605 Ok(cx.spawn(async move |_cx| {
606 let mut request_message = LanguageModelRequestMessage {
607 role: Role::User,
608 content: Vec::new(),
609 cache: false,
610 reasoning_details: None,
611 };
612
613 if let Some(context) = context_task.await {
614 context.add_to_request_message(&mut request_message);
615 }
616
617 request_message.content.push(prompt.into());
618
619 LanguageModelRequest {
620 thread_id: None,
621 prompt_id: None,
622 intent: Some(CompletionIntent::InlineAssist),
623 tools: Vec::new(),
624 tool_choice: None,
625 stop: Vec::new(),
626 temperature,
627 messages: vec![request_message],
628 thinking_allowed: false,
629 thinking_effort: None,
630 speed: None,
631 }
632 }))
633 }
634
/// Streams a model's text completion into the buffer as a live diff against the selected range.
///
/// Synchronously, this re-snapshots the buffer and re-resolves `self.range`, records the
/// selected text, computes the indentation to use for generated lines, resets `self.diff`
/// and sets the status to `Pending`. It then spawns a task that:
/// - on a background task, feeds each completion chunk (optionally passed through
///   `StripInvalidSpans` when `strip_invalid_spans` is true) into a char-level
///   `StreamingDiff` and a line-level `LineDiff`, re-indenting each generated line, and sends
///   the resulting operations over a channel; it also reports response telemetry;
/// - on the foreground, converts each batch of char operations into anchored buffer edits,
///   applying them (and refreshing the line-based diff) only while this alternative is
///   active, and always recording them in `self.edits`;
/// - after the stream ends, replaces the line-based diff with a batch diff;
/// - finally stores the message id, elapsed time, accumulated completion and final status
///   (`Done` or `Error`), reports token usage, and emits `CodegenEvent::Finished`.
///
/// Returns the spawned task. Callers store it in `self.generation`; `stop` replaces it with
/// `Task::ready(())` (presumably to cancel it — confirm against gpui `Task` drop semantics).
pub fn handle_stream(
    &mut self,
    model: Arc<dyn LanguageModel>,
    strip_invalid_spans: bool,
    stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
    cx: &mut Context<Self>,
) -> Task<()> {
    let anthropic_reporter = language_model::AnthropicEventReporter::new(&model, cx);
    let session_id = self.session_id;
    let model_telemetry_id = model.telemetry_id();
    let model_provider_id = model.provider_id().to_string();
    let start_time = Instant::now();

    // Make a new snapshot and re-resolve anchor in case the document was modified.
    // This can happen often if the editor loses focus and is saved + reformatted,
    // as in https://github.com/zed-industries/zed/issues/39088
    self.snapshot = self.buffer.read(cx).snapshot(cx);
    self.range = self.snapshot.anchor_after(self.range.start)
        ..self.snapshot.anchor_after(self.range.end);

    // The original text of the range: the streaming diff compares the model's output to it.
    let snapshot = self.snapshot.clone();
    let selected_text = snapshot
        .text_for_range(self.range.start..self.range.end)
        .collect::<Rope>();

    self.selected_text = Some(selected_text.to_string());

    let selection_start = self.range.start.to_point(&snapshot);

    // Start with the indentation of the first line in the selection
    let mut suggested_line_indent = snapshot
        .suggested_indents(selection_start.row..=selection_start.row, cx)
        .into_values()
        .next()
        .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));

    // If the first line in the selection does not have indentation, check the following lines
    if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
        for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
            let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
            // Prefer tabs if a line in the selection uses tabs as indentation
            if line_indent.kind == IndentKind::Tab {
                suggested_line_indent.kind = IndentKind::Tab;
                break;
            }
        }
    }

    // Language of the first buffer touched by the range; only used for telemetry below.
    let language_name = {
        let multibuffer = self.buffer.read(cx);
        let snapshot = multibuffer.snapshot(cx);
        let ranges = snapshot.range_to_buffer_ranges(self.range.start..=self.range.end);
        ranges
            .first()
            .and_then(|(buffer, _, _)| buffer.language())
            .map(|language| language.name())
    };

    // Reset the diff highlights and mark the generation as in progress.
    self.diff = Diff::default();
    self.status = CodegenStatus::Pending;
    // Offset in `snapshot` where the next char operation applies; advanced by
    // `Keep`/`Delete` operations on the foreground.
    let mut edit_start = self.range.start.to_offset(&snapshot);
    // Accumulates every received chunk; written by the background task.
    let completion = Arc::new(Mutex::new(String::new()));
    let completion_clone = completion.clone();

    cx.notify();
    cx.spawn(async move |codegen, cx| {
        let stream = stream.await;

        // Pull metadata off the stream before the stream is moved into the background task.
        let token_usage = stream
            .as_ref()
            .ok()
            .map(|stream| stream.last_token_usage.clone());
        let message_id = stream
            .as_ref()
            .ok()
            .and_then(|stream| stream.message_id.clone());
        let generate = async {
            let model_telemetry_id = model_telemetry_id.clone();
            let model_provider_id = model_provider_id.clone();
            // Small bounded channel: `send(..).await` waits while it is full, which presumably
            // keeps the background diffing from running far ahead of the foreground edits.
            let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
            let message_id = message_id.clone();
            // Background task: turn completion chunks into char- and line-level diff operations.
            let line_based_stream_diff: Task<anyhow::Result<()>> = cx.background_spawn({
                let anthropic_reporter = anthropic_reporter.clone();
                let language_name = language_name.clone();
                async move {
                    let mut response_latency = None;
                    let request_start = Instant::now();
                    let diff = async {
                        let raw_stream = stream?.stream.map_err(|error| error.into());

                        // Declared outside the `if` and initialized only on the branch that needs it.
                        let stripped;
                        let mut chunks: Pin<Box<dyn Stream<Item = Result<String>> + Send>> =
                            if strip_invalid_spans {
                                stripped = StripInvalidSpans::new(raw_stream);
                                Box::pin(stripped)
                            } else {
                                Box::pin(raw_stream)
                            };

                        let mut diff = StreamingDiff::new(selected_text.to_string());
                        let mut line_diff = LineDiff::default();

                        // Text of the current line not yet pushed into the diff.
                        let mut new_text = String::new();
                        // Leading-whitespace width of the first generated line that had content.
                        let mut base_indent = None;
                        // Leading-whitespace width of the current line, once known.
                        let mut line_indent = None;
                        let mut first_line = true;

                        while let Some(chunk) = chunks.next().await {
                            if response_latency.is_none() {
                                response_latency = Some(request_start.elapsed());
                            }
                            let chunk = chunk?;
                            completion_clone.lock().push_str(&chunk);

                            let mut lines = chunk.split('\n').peekable();
                            while let Some(line) = lines.next() {
                                new_text.push_str(line);
                                // First non-whitespace character of the current line: replace the
                                // model's leading whitespace with indentation that keeps the line's
                                // offset relative to `base_indent`, rooted at `suggested_line_indent`.
                                // The first line already starts at the selection's column, so that
                                // many columns are subtracted from it.
                                if line_indent.is_none()
                                    && let Some(non_whitespace_ch_ix) =
                                        new_text.find(|ch: char| !ch.is_whitespace())
                                {
                                    line_indent = Some(non_whitespace_ch_ix);
                                    base_indent = base_indent.or(line_indent);

                                    let line_indent = line_indent.unwrap();
                                    let base_indent = base_indent.unwrap();
                                    let indent_delta = line_indent as i32 - base_indent as i32;
                                    let mut corrected_indent_len = cmp::max(
                                        0,
                                        suggested_line_indent.len as i32 + indent_delta,
                                    )
                                        as usize;
                                    if first_line {
                                        corrected_indent_len = corrected_indent_len
                                            .saturating_sub(selection_start.column as usize);
                                    }

                                    let indent_char = suggested_line_indent.char();
                                    let mut indent_buffer = [0; 4];
                                    let indent_str =
                                        indent_char.encode_utf8(&mut indent_buffer);
                                    new_text.replace_range(
                                        ..line_indent,
                                        &indent_str.repeat(corrected_indent_len),
                                    );
                                }

                                // Once the line's indentation is settled, stream its text right
                                // away; until then leading whitespace stays buffered in `new_text`.
                                if line_indent.is_some() {
                                    let char_ops = diff.push_new(&new_text);
                                    line_diff.push_char_operations(&char_ops, &selected_text);
                                    diff_tx
                                        .send((char_ops, line_diff.line_operations()))
                                        .await?;
                                    new_text.clear();
                                }

                                // A '\n' followed this segment: emit it and reset per-line state.
                                if lines.peek().is_some() {
                                    let char_ops = diff.push_new("\n");
                                    line_diff.push_char_operations(&char_ops, &selected_text);
                                    diff_tx
                                        .send((char_ops, line_diff.line_operations()))
                                        .await?;
                                    if line_indent.is_none() {
                                        // Don't write out the leading indentation in empty lines on the next line
                                        // This is the case where the above if statement didn't clear the buffer
                                        new_text.clear();
                                    }
                                    line_indent = None;
                                    first_line = false;
                                }
                            }
                        }

                        // Flush any buffered tail and let both diffs emit their final operations.
                        let mut char_ops = diff.push_new(&new_text);
                        char_ops.extend(diff.finish());
                        line_diff.push_char_operations(&char_ops, &selected_text);
                        line_diff.finish(&selected_text);
                        diff_tx
                            .send((char_ops, line_diff.line_operations()))
                            .await?;

                        anyhow::Ok(())
                    };

                    let result = diff.await;

                    // Report the response (including any error) before propagating `result`.
                    let error_message = result.as_ref().err().map(|error| error.to_string());
                    telemetry::event!(
                        "Assistant Responded",
                        kind = "inline",
                        phase = "response",
                        session_id = session_id.to_string(),
                        model = model_telemetry_id,
                        model_provider = model_provider_id,
                        language_name = language_name.as_ref().map(|n| n.to_string()),
                        message_id = message_id.as_deref(),
                        response_latency = response_latency,
                        error_message = error_message.as_deref(),
                    );

                    anthropic_reporter.report(language_model::AnthropicEventData {
                        completion_type: language_model::AnthropicCompletionType::Editor,
                        event: language_model::AnthropicEventType::Response,
                        language_name: language_name.map(|n| n.to_string()),
                        message_id,
                    });

                    result?;
                    Ok(())
                }
            });

            // Foreground: translate char operations into buffer edits anchored in `snapshot`.
            while let Some((char_ops, line_ops)) = diff_rx.next().await {
                codegen.update(cx, |codegen, cx| {
                    codegen.last_equal_ranges.clear();

                    let edits = char_ops
                        .into_iter()
                        .filter_map(|operation| match operation {
                            // Inserts consume no original text, so `edit_start` stays put.
                            CharOperation::Insert { text } => {
                                let edit_start = snapshot.anchor_after(edit_start);
                                Some((edit_start..edit_start, text))
                            }
                            CharOperation::Delete { bytes } => {
                                let edit_end = edit_start + bytes;
                                let edit_range = snapshot.anchor_after(edit_start)
                                    ..snapshot.anchor_before(edit_end);
                                edit_start = edit_end;
                                Some((edit_range, String::new()))
                            }
                            // Kept text produces no edit; its range is exposed via `last_equal_ranges`.
                            CharOperation::Keep { bytes } => {
                                let edit_end = edit_start + bytes;
                                let edit_range = snapshot.anchor_after(edit_start)
                                    ..snapshot.anchor_before(edit_end);
                                edit_start = edit_end;
                                codegen.last_equal_ranges.push(edit_range);
                                None
                            }
                        })
                        .collect::<Vec<_>>();

                    // Only the active alternative touches the buffer; edits are recorded
                    // either way so `set_active(true)` can re-apply them.
                    if codegen.active {
                        codegen.apply_edits(edits.iter().cloned(), cx);
                        codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
                    }
                    codegen.edits.extend(edits);
                    codegen.line_operations = line_ops;
                    codegen.edit_position = Some(snapshot.anchor_after(edit_start));

                    cx.notify();
                })?;
            }

            // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
            // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
            // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
            let batch_diff_task =
                codegen.update(cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
            let (line_based_stream_diff, ()) = join!(line_based_stream_diff, batch_diff_task);
            line_based_stream_diff?;

            anyhow::Ok(())
        };

        let result = generate.await;
        let elapsed_time = start_time.elapsed().as_secs_f64();

        // Record the outcome on the entity whether or not generation succeeded.
        codegen
            .update(cx, |this, cx| {
                this.message_id = message_id;
                this.last_equal_ranges.clear();
                if let Err(error) = result {
                    this.status = CodegenStatus::Error(error);
                } else {
                    this.status = CodegenStatus::Done;
                }
                this.elapsed_time = Some(elapsed_time);
                this.completion = Some(completion.lock().clone());
                if let Some(usage) = token_usage {
                    let usage = usage.lock();
                    telemetry::event!(
                        "Inline Assistant Completion",
                        model = model_telemetry_id,
                        model_provider = model_provider_id,
                        input_tokens = usage.input_tokens,
                        output_tokens = usage.output_tokens,
                    )
                }

                cx.emit(CodegenEvent::Finished);
                cx.notify();
            })
            .ok();
    })
}
930
931 pub fn current_completion(&self) -> Option<String> {
932 self.completion.clone()
933 }
934
/// Test helper: a copy of the model-provided description, if any.
#[cfg(any(test, feature = "test-support"))]
pub fn current_description(&self) -> Option<String> {
    self.description.as_deref().map(str::to_owned)
}
939
/// Test helper: a copy of the recorded failure message, if any.
#[cfg(any(test, feature = "test-support"))]
pub fn current_failure(&self) -> Option<String> {
    self.failure.as_deref().map(str::to_owned)
}
944
945 pub fn selected_text(&self) -> Option<&str> {
946 self.selected_text.as_deref()
947 }
948
949 pub fn stop(&mut self, cx: &mut Context<Self>) {
950 self.last_equal_ranges.clear();
951 if self.diff.is_empty() {
952 self.status = CodegenStatus::Idle;
953 } else {
954 self.status = CodegenStatus::Done;
955 }
956 self.generation = Task::ready(());
957 cx.emit(CodegenEvent::Finished);
958 cx.notify();
959 }
960
961 pub fn undo(&mut self, cx: &mut Context<Self>) {
962 self.buffer.update(cx, |buffer, cx| {
963 if let Some(transaction_id) = self.transformation_transaction_id.take() {
964 buffer.undo_transaction(transaction_id, cx);
965 buffer.refresh_preview(cx);
966 }
967 });
968 }
969
970 fn apply_edits(
971 &mut self,
972 edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
973 cx: &mut Context<CodegenAlternative>,
974 ) {
975 let transaction = self.buffer.update(cx, |buffer, cx| {
976 // Avoid grouping agent edits with user edits.
977 buffer.finalize_last_transaction(cx);
978 buffer.start_transaction(cx);
979 buffer.edit(edits, None, cx);
980 buffer.end_transaction(cx)
981 });
982
983 if let Some(transaction) = transaction {
984 if let Some(first_transaction) = self.transformation_transaction_id {
985 // Group all agent edits into the first transaction.
986 self.buffer.update(cx, |buffer, cx| {
987 buffer.merge_transactions(transaction, first_transaction, cx)
988 });
989 } else {
990 self.transformation_transaction_id = Some(transaction);
991 self.buffer
992 .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
993 }
994 }
995 }
996
997 fn reapply_line_based_diff(
998 &mut self,
999 line_operations: impl IntoIterator<Item = LineOperation>,
1000 cx: &mut Context<Self>,
1001 ) {
1002 let old_snapshot = self.snapshot.clone();
1003 let old_range = self.range.to_point(&old_snapshot);
1004 let new_snapshot = self.buffer.read(cx).snapshot(cx);
1005 let new_range = self.range.to_point(&new_snapshot);
1006
1007 let mut old_row = old_range.start.row;
1008 let mut new_row = new_range.start.row;
1009
1010 self.diff.deleted_row_ranges.clear();
1011 self.diff.inserted_row_ranges.clear();
1012 for operation in line_operations {
1013 match operation {
1014 LineOperation::Keep { lines } => {
1015 old_row += lines;
1016 new_row += lines;
1017 }
1018 LineOperation::Delete { lines } => {
1019 let old_end_row = old_row + lines - 1;
1020 let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
1021
1022 if let Some((_, last_deleted_row_range)) =
1023 self.diff.deleted_row_ranges.last_mut()
1024 {
1025 if *last_deleted_row_range.end() + 1 == old_row {
1026 *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
1027 } else {
1028 self.diff
1029 .deleted_row_ranges
1030 .push((new_row, old_row..=old_end_row));
1031 }
1032 } else {
1033 self.diff
1034 .deleted_row_ranges
1035 .push((new_row, old_row..=old_end_row));
1036 }
1037
1038 old_row += lines;
1039 }
1040 LineOperation::Insert { lines } => {
1041 let new_end_row = new_row + lines - 1;
1042 let start = new_snapshot.anchor_before(Point::new(new_row, 0));
1043 let end = new_snapshot.anchor_before(Point::new(
1044 new_end_row,
1045 new_snapshot.line_len(MultiBufferRow(new_end_row)),
1046 ));
1047 self.diff.inserted_row_ranges.push(start..end);
1048 new_row += lines;
1049 }
1050 }
1051
1052 cx.notify();
1053 }
1054 }
1055
1056 fn reapply_batch_diff(&mut self, cx: &mut Context<Self>) -> Task<()> {
1057 let old_snapshot = self.snapshot.clone();
1058 let old_range = self.range.to_point(&old_snapshot);
1059 let new_snapshot = self.buffer.read(cx).snapshot(cx);
1060 let new_range = self.range.to_point(&new_snapshot);
1061
1062 cx.spawn(async move |codegen, cx| {
1063 let (deleted_row_ranges, inserted_row_ranges) = cx
1064 .background_spawn(async move {
1065 let old_text = old_snapshot
1066 .text_for_range(
1067 Point::new(old_range.start.row, 0)
1068 ..Point::new(
1069 old_range.end.row,
1070 old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
1071 ),
1072 )
1073 .collect::<String>();
1074 let new_text = new_snapshot
1075 .text_for_range(
1076 Point::new(new_range.start.row, 0)
1077 ..Point::new(
1078 new_range.end.row,
1079 new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
1080 ),
1081 )
1082 .collect::<String>();
1083
1084 let old_start_row = old_range.start.row;
1085 let new_start_row = new_range.start.row;
1086 let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
1087 let mut inserted_row_ranges = Vec::new();
1088 for (old_rows, new_rows) in line_diff(&old_text, &new_text) {
1089 let old_rows = old_start_row + old_rows.start..old_start_row + old_rows.end;
1090 let new_rows = new_start_row + new_rows.start..new_start_row + new_rows.end;
1091 if !old_rows.is_empty() {
1092 deleted_row_ranges.push((
1093 new_snapshot.anchor_before(Point::new(new_rows.start, 0)),
1094 old_rows.start..=old_rows.end - 1,
1095 ));
1096 }
1097 if !new_rows.is_empty() {
1098 let start = new_snapshot.anchor_before(Point::new(new_rows.start, 0));
1099 let new_end_row = new_rows.end - 1;
1100 let end = new_snapshot.anchor_before(Point::new(
1101 new_end_row,
1102 new_snapshot.line_len(MultiBufferRow(new_end_row)),
1103 ));
1104 inserted_row_ranges.push(start..end);
1105 }
1106 }
1107 (deleted_row_ranges, inserted_row_ranges)
1108 })
1109 .await;
1110
1111 codegen
1112 .update(cx, |codegen, cx| {
1113 codegen.diff.deleted_row_ranges = deleted_row_ranges;
1114 codegen.diff.inserted_row_ranges = inserted_row_ranges;
1115 cx.notify();
1116 })
1117 .ok();
1118 })
1119 }
1120
1121 fn handle_completion(
1122 &mut self,
1123 model: Arc<dyn LanguageModel>,
1124 completion_stream: Task<
1125 Result<
1126 BoxStream<
1127 'static,
1128 Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
1129 >,
1130 LanguageModelCompletionError,
1131 >,
1132 >,
1133 cx: &mut Context<Self>,
1134 ) -> Task<()> {
1135 self.diff = Diff::default();
1136 self.status = CodegenStatus::Pending;
1137
1138 cx.notify();
1139 // Leaving this in generation so that STOP equivalent events are respected even
1140 // while we're still pre-processing the completion event
1141 cx.spawn(async move |codegen, cx| {
1142 let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
1143 let _ = codegen.update(cx, |this, cx| {
1144 this.status = status;
1145 cx.emit(CodegenEvent::Finished);
1146 cx.notify();
1147 });
1148 };
1149
1150 let mut completion_events = match completion_stream.await {
1151 Ok(events) => events,
1152 Err(err) => {
1153 finish_with_status(CodegenStatus::Error(err.into()), cx);
1154 return;
1155 }
1156 };
1157
1158 enum ToolUseOutput {
1159 Rewrite {
1160 text: String,
1161 description: Option<String>,
1162 },
1163 Failure(String),
1164 }
1165
1166 enum ModelUpdate {
1167 Description(String),
1168 Failure(String),
1169 }
1170
1171 let chars_read_so_far = Arc::new(Mutex::new(0usize));
1172 let process_tool_use = move |tool_use: LanguageModelToolUse| -> Option<ToolUseOutput> {
1173 let mut chars_read_so_far = chars_read_so_far.lock();
1174 match tool_use.name.as_ref() {
1175 REWRITE_SECTION_TOOL_NAME => {
1176 let Ok(input) =
1177 serde_json::from_value::<RewriteSectionInput>(tool_use.input)
1178 else {
1179 return None;
1180 };
1181 let text = input.replacement_text[*chars_read_so_far..].to_string();
1182 *chars_read_so_far = input.replacement_text.len();
1183 Some(ToolUseOutput::Rewrite {
1184 text,
1185 description: None,
1186 })
1187 }
1188 FAILURE_MESSAGE_TOOL_NAME => {
1189 let Ok(mut input) =
1190 serde_json::from_value::<FailureMessageInput>(tool_use.input)
1191 else {
1192 return None;
1193 };
1194 Some(ToolUseOutput::Failure(std::mem::take(&mut input.message)))
1195 }
1196 _ => None,
1197 }
1198 };
1199
1200 let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded::<ModelUpdate>();
1201
1202 cx.spawn({
1203 let codegen = codegen.clone();
1204 async move |cx| {
1205 while let Some(update) = message_rx.next().await {
1206 let _ = codegen.update(cx, |this, _cx| match update {
1207 ModelUpdate::Description(d) => this.description = Some(d),
1208 ModelUpdate::Failure(f) => this.failure = Some(f),
1209 });
1210 }
1211 }
1212 })
1213 .detach();
1214
1215 let mut message_id = None;
1216 let mut first_text = None;
1217 let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
1218 let total_text = Arc::new(Mutex::new(String::new()));
1219
1220 loop {
1221 if let Some(first_event) = completion_events.next().await {
1222 match first_event {
1223 Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
1224 message_id = Some(id);
1225 }
1226 Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1227 if let Some(output) = process_tool_use(tool_use) {
1228 let (text, update) = match output {
1229 ToolUseOutput::Rewrite { text, description } => {
1230 (Some(text), description.map(ModelUpdate::Description))
1231 }
1232 ToolUseOutput::Failure(message) => {
1233 (None, Some(ModelUpdate::Failure(message)))
1234 }
1235 };
1236 if let Some(update) = update {
1237 let _ = message_tx.unbounded_send(update);
1238 }
1239 first_text = text;
1240 if first_text.is_some() {
1241 break;
1242 }
1243 }
1244 }
1245 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1246 *last_token_usage.lock() = token_usage;
1247 }
1248 Ok(LanguageModelCompletionEvent::Text(text)) => {
1249 let mut lock = total_text.lock();
1250 lock.push_str(&text);
1251 }
1252 Ok(e) => {
1253 log::warn!("Unexpected event: {:?}", e);
1254 break;
1255 }
1256 Err(e) => {
1257 finish_with_status(CodegenStatus::Error(e.into()), cx);
1258 break;
1259 }
1260 }
1261 }
1262 }
1263
1264 let Some(first_text) = first_text else {
1265 finish_with_status(CodegenStatus::Done, cx);
1266 return;
1267 };
1268
1269 let move_last_token_usage = last_token_usage.clone();
1270
1271 let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
1272 completion_events.filter_map(move |e| {
1273 let process_tool_use = process_tool_use.clone();
1274 let last_token_usage = move_last_token_usage.clone();
1275 let total_text = total_text.clone();
1276 let mut message_tx = message_tx.clone();
1277 async move {
1278 match e {
1279 Ok(LanguageModelCompletionEvent::ToolUse(tool_use)) => {
1280 let Some(output) = process_tool_use(tool_use) else {
1281 return None;
1282 };
1283 let (text, update) = match output {
1284 ToolUseOutput::Rewrite { text, description } => {
1285 (Some(text), description.map(ModelUpdate::Description))
1286 }
1287 ToolUseOutput::Failure(message) => {
1288 (None, Some(ModelUpdate::Failure(message)))
1289 }
1290 };
1291 if let Some(update) = update {
1292 let _ = message_tx.send(update).await;
1293 }
1294 text.map(Ok)
1295 }
1296 Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
1297 *last_token_usage.lock() = token_usage;
1298 None
1299 }
1300 Ok(LanguageModelCompletionEvent::Text(text)) => {
1301 let mut lock = total_text.lock();
1302 lock.push_str(&text);
1303 None
1304 }
1305 Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
1306 e => {
1307 log::error!("UNEXPECTED EVENT {:?}", e);
1308 None
1309 }
1310 }
1311 }
1312 }),
1313 ));
1314
1315 let language_model_text_stream = LanguageModelTextStream {
1316 message_id: message_id,
1317 stream: text_stream,
1318 last_token_usage,
1319 };
1320
1321 let Some(task) = codegen
1322 .update(cx, move |codegen, cx| {
1323 codegen.handle_stream(
1324 model,
1325 /* strip_invalid_spans: */ false,
1326 async { Ok(language_model_text_stream) },
1327 cx,
1328 )
1329 })
1330 .ok()
1331 else {
1332 return;
1333 };
1334
1335 task.await;
1336 })
1337 }
1338}
1339
/// Events emitted by the codegen entity.
#[derive(Copy, Clone, Debug)]
pub enum CodegenEvent {
    /// Generation has ended; `status` is set to the final state before this is emitted
    /// (see `finish_with_status` in `handle_completion`).
    Finished,
    /// NOTE(review): presumably emitted when the generated edits are undone; the emitter
    /// is not visible in this section.
    Undone,
}
1345
1346struct StripInvalidSpans<T> {
1347 stream: T,
1348 stream_done: bool,
1349 buffer: String,
1350 first_line: bool,
1351 line_end: bool,
1352 starts_with_code_block: bool,
1353}
1354
1355impl<T> StripInvalidSpans<T>
1356where
1357 T: Stream<Item = Result<String>>,
1358{
1359 fn new(stream: T) -> Self {
1360 Self {
1361 stream,
1362 stream_done: false,
1363 buffer: String::new(),
1364 first_line: true,
1365 line_end: false,
1366 starts_with_code_block: false,
1367 }
1368 }
1369}
1370
1371impl<T> Stream for StripInvalidSpans<T>
1372where
1373 T: Stream<Item = Result<String>>,
1374{
1375 type Item = Result<String>;
1376
1377 fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
1378 const CODE_BLOCK_DELIMITER: &str = "```";
1379 const CURSOR_SPAN: &str = "<|CURSOR|>";
1380
1381 let this = unsafe { self.get_unchecked_mut() };
1382 loop {
1383 if !this.stream_done {
1384 let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
1385 match stream.as_mut().poll_next(cx) {
1386 Poll::Ready(Some(Ok(chunk))) => {
1387 this.buffer.push_str(&chunk);
1388 }
1389 Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
1390 Poll::Ready(None) => {
1391 this.stream_done = true;
1392 }
1393 Poll::Pending => return Poll::Pending,
1394 }
1395 }
1396
1397 let mut chunk = String::new();
1398 let mut consumed = 0;
1399 if !this.buffer.is_empty() {
1400 let mut lines = this.buffer.split('\n').enumerate().peekable();
1401 while let Some((line_ix, line)) = lines.next() {
1402 if line_ix > 0 {
1403 this.first_line = false;
1404 }
1405
1406 if this.first_line {
1407 let trimmed_line = line.trim();
1408 if lines.peek().is_some() {
1409 if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
1410 consumed += line.len() + 1;
1411 this.starts_with_code_block = true;
1412 continue;
1413 }
1414 } else if trimmed_line.is_empty()
1415 || prefixes(CODE_BLOCK_DELIMITER)
1416 .any(|prefix| trimmed_line.starts_with(prefix))
1417 {
1418 break;
1419 }
1420 }
1421
1422 let line_without_cursor = line.replace(CURSOR_SPAN, "");
1423 if lines.peek().is_some() {
1424 if this.line_end {
1425 chunk.push('\n');
1426 }
1427
1428 chunk.push_str(&line_without_cursor);
1429 this.line_end = true;
1430 consumed += line.len() + 1;
1431 } else if this.stream_done {
1432 if !this.starts_with_code_block
1433 || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
1434 {
1435 if this.line_end {
1436 chunk.push('\n');
1437 }
1438
1439 chunk.push_str(line);
1440 }
1441
1442 consumed += line.len();
1443 } else {
1444 let trimmed_line = line.trim();
1445 if trimmed_line.is_empty()
1446 || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
1447 || prefixes(CODE_BLOCK_DELIMITER)
1448 .any(|prefix| trimmed_line.ends_with(prefix))
1449 {
1450 break;
1451 } else {
1452 if this.line_end {
1453 chunk.push('\n');
1454 this.line_end = false;
1455 }
1456
1457 chunk.push_str(&line_without_cursor);
1458 consumed += line.len();
1459 }
1460 }
1461 }
1462 }
1463
1464 this.buffer = this.buffer.split_off(consumed);
1465 if !chunk.is_empty() {
1466 return Poll::Ready(Some(Ok(chunk)));
1467 } else if this.stream_done {
1468 return Poll::Ready(None);
1469 }
1470 }
1471 }
1472}
1473
/// Returns every non-empty proper prefix of `text`, shortest first.
///
/// For example, `"abc"` yields `"a"` then `"ab"`. Prefixes end on `char` boundaries, so
/// non-ASCII input does not panic. An empty `text` yields nothing instead of underflowing
/// `text.len() - 1`.
fn prefixes(text: &str) -> impl Iterator<Item = &str> {
    text.char_indices()
        .skip(1)
        .map(move |(end, _)| &text[..end])
}
1477
1478#[derive(Default)]
1479pub struct Diff {
1480 pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
1481 pub inserted_row_ranges: Vec<Range<Anchor>>,
1482}
1483
1484impl Diff {
1485 fn is_empty(&self) -> bool {
1486 self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
1487 }
1488}
1489
#[cfg(test)]
mod tests {
    // Tests for streaming model output into a buffer via `CodegenAlternative`
    // (plain-text and tool-based paths) and for the `StripInvalidSpans` adapter.
    use super::*;
    use futures::{
        Stream,
        stream::{self},
    };
    use gpui::TestAppContext;
    use indoc::indoc;
    use language::{Buffer, Point};
    use language_model::fake_provider::FakeLanguageModel;
    use language_model::{
        LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelRegistry,
        LanguageModelToolUse, StopReason, TokenUsage,
    };
    use languages::rust_lang;
    use rand::prelude::*;
    use settings::SettingsStore;
    use std::{future, sync::Arc};

    // Streams a rewrite of rows 1-4 in random-sized chunks (at most 10 bytes each). The
    // streamed lines carry different leading whitespace than the buffer; the final buffer
    // text must match the expected indentation.
    #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
                for _ in 0..10 {
                    x += 1;
                }
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        let mut new_text = concat!(
            "       let mut x = 0;\n",
            "       while x < 10 {\n",
            "           x += 1;\n",
            "       }",
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // Generation starts mid-line, after the existing indentation and a partial token ("le").
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_past_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                le
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "t mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // Generation starts at column 2 of a whitespace-only line, i.e. before the row's
    // expected indentation.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_when_generating_before_indentation(
        cx: &mut TestAppContext,
        mut rng: StdRng,
    ) {
        init_test(cx);

        let text = concat!(
            "fn main() {\n",
            "  \n",
            "}\n" //
        );
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);

        cx.background_executor.run_until_parked();

        let mut new_text = concat!(
            "let mut x = 0;\n",
            "while x < 10 {\n",
            "    x += 1;\n",
            "}", //
        );
        while !new_text.is_empty() {
            let max_len = cmp::min(new_text.len(), 10);
            let len = rng.random_range(1..=max_len);
            let (chunk, suffix) = new_text.split_at(len);
            chunks_tx.unbounded_send(chunk.to_string()).unwrap();
            new_text = suffix;
            cx.background_executor.run_until_parked();
        }
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    while x < 10 {
                        x += 1;
                    }
                }
            "}
        );
    }

    // A buffer without a language, indented with tabs: the tab-indented rewrite is sent
    // in a single chunk and must keep its tabs.
    #[gpui::test(iterations = 10)]
    async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            func main() {
            \tx := 0
            \tfor i := 0; i < 10; i++ {
            \t\tx++
            \t}
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);
        let new_text = concat!(
            "func main() {\n",
            "\tx := 0\n",
            "\tfor x < 10 {\n",
            "\t\tx++\n",
            "\t}", //
        );
        chunks_tx.unbounded_send(new_text.to_string()).unwrap();
        drop(chunks_tx);
        cx.background_executor.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                func main() {
                \tx := 0
                \tfor x < 10 {
                \t\tx++
                \t}
                }
            "}
        );
    }

    // An alternative created inactive (third `new` argument `false`) leaves the buffer
    // untouched until `set_active(true)`; `set_active(false)` restores the original text.
    #[gpui::test]
    async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
        init_test(cx);

        let text = indoc! {"
            fn main() {
                let x = 0;
            }
        "};
        let buffer = cx.new(|cx| Buffer::local(text, cx).with_language(rust_lang(), cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                false,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        let chunks_tx = simulate_response_stream(&codegen, cx);
        chunks_tx
            .unbounded_send("let mut x = 0;\nx += 1;".to_string())
            .unwrap();
        drop(chunks_tx);
        cx.run_until_parked();

        // The codegen is inactive, so the buffer doesn't get modified.
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );

        // Activating the codegen applies the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            indoc! {"
                fn main() {
                    let mut x = 0;
                    x += 1;
                }
            "}
        );

        // Deactivating the codegen undoes the changes.
        codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
        cx.run_until_parked();
        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    // When not streaming tool calls, we strip backticks as part of parsing the model's
    // plain text response. This is a regression test for a bug where we stripped
    // backticks incorrectly.
    #[gpui::test]
    async fn test_allows_model_to_output_backticks(cx: &mut TestAppContext) {
        init_test(cx);
        let text = "- Improved; `cmd+click` behavior. Now requires `cmd` to be pressed before the click starts or it doesn't run. ([#44579](https://github.com/zed-industries/zed/pull/44579); thanks [Zachiah](https://github.com/Zachiah))";
        let buffer = cx.new(|cx| Buffer::local("", cx));
        let buffer = cx.new(|cx| MultiBuffer::singleton(buffer, cx));
        let range = buffer.read_with(cx, |buffer, cx| {
            let snapshot = buffer.snapshot(cx);
            snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(0, 0))
        });
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let codegen = cx.new(|cx| {
            CodegenAlternative::new(
                buffer.clone(),
                range.clone(),
                true,
                prompt_builder,
                Uuid::new_v4(),
                cx,
            )
        });

        // Send the text up to the first backtick, then the full text: the second event
        // carries the cumulative input, as a streamed tool call would.
        let events_tx = simulate_tool_based_completion(&codegen, cx);
        let chunk_len = text.find('`').unwrap();
        events_tx
            .unbounded_send(rewrite_tool_use("tool_1", &text[..chunk_len], false))
            .unwrap();
        events_tx
            .unbounded_send(rewrite_tool_use("tool_2", &text, true))
            .unwrap();
        events_tx
            .unbounded_send(LanguageModelCompletionEvent::Stop(StopReason::EndTurn))
            .unwrap();
        drop(events_tx);
        cx.run_until_parked();

        assert_eq!(
            buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
            text
        );
    }

    // Each input is fed to `StripInvalidSpans` at every chunk size from 1 to its full
    // length; the concatenated output must always equal the expected text.
    #[gpui::test]
    async fn test_strip_invalid_spans_from_codeblock() {
        assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
        assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
        assert_chunks(
            "```html\n```js\nLorem ipsum dolor\n```\n```",
            "```js\nLorem ipsum dolor\n```",
        )
        .await;
        assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
        assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
        assert_chunks("Lorem ipsum", "Lorem ipsum").await;
        assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;

        async fn assert_chunks(text: &str, expected_text: &str) {
            for chunk_size in 1..=text.len() {
                let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
                    .map(|chunk| chunk.unwrap())
                    .collect::<String>()
                    .await;
                assert_eq!(
                    actual_text, expected_text,
                    "failed to strip invalid spans, chunk size: {}",
                    chunk_size
                );
            }
        }

        // Splits `text` into chunks of `size` chars (not bytes), each wrapped in `Ok`.
        fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
            stream::iter(
                text.chars()
                    .collect::<Vec<_>>()
                    .chunks(size)
                    .map(|chunk| Ok(chunk.iter().collect::<String>()))
                    .collect::<Vec<_>>(),
            )
        }
    }

    // Installs the test language-model registry and a test settings store.
    fn init_test(cx: &mut TestAppContext) {
        cx.update(LanguageModelRegistry::test);
        cx.set_global(cx.update(SettingsStore::test));
    }

    // Starts plain-text generation on `codegen` and returns a sender for text chunks.
    // Dropping the sender ends the stream.
    fn simulate_response_stream(
        codegen: &Entity<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<String> {
        let (chunks_tx, chunks_rx) = mpsc::unbounded();
        let model = Arc::new(FakeLanguageModel::default());
        codegen.update(cx, |codegen, cx| {
            codegen.generation = codegen.handle_stream(
                model,
                /* strip_invalid_spans: */ false,
                future::ready(Ok(LanguageModelTextStream {
                    message_id: None,
                    stream: chunks_rx.map(Ok).boxed(),
                    last_token_usage: Arc::new(Mutex::new(TokenUsage::default())),
                })),
                cx,
            );
        });
        chunks_tx
    }

    // Starts tool-based generation (`handle_completion`) on `codegen` and returns a sender
    // for raw completion events. Dropping the sender ends the stream.
    fn simulate_tool_based_completion(
        codegen: &Entity<CodegenAlternative>,
        cx: &mut TestAppContext,
    ) -> mpsc::UnboundedSender<LanguageModelCompletionEvent> {
        let (events_tx, events_rx) = mpsc::unbounded();
        let model = Arc::new(FakeLanguageModel::default());
        codegen.update(cx, |codegen, cx| {
            let completion_stream = Task::ready(Ok(events_rx.map(Ok).boxed()
                as BoxStream<
                    'static,
                    Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
                >));
            codegen.generation = codegen.handle_completion(model, completion_stream, cx);
        });
        events_tx
    }

    // Builds a `ToolUse` event for the rewrite tool carrying `replacement_text` as its input.
    fn rewrite_tool_use(
        id: &str,
        replacement_text: &str,
        is_complete: bool,
    ) -> LanguageModelCompletionEvent {
        let input = RewriteSectionInput {
            replacement_text: replacement_text.into(),
        };
        LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
            id: id.into(),
            name: REWRITE_SECTION_TOOL_NAME.into(),
            raw_input: serde_json::to_string(&input).unwrap(),
            input: serde_json::to_value(&input).unwrap(),
            is_input_complete: is_complete,
            thought_signature: None,
        })
    }
}