1use crate::{
2 prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
3 LanguageModelCompletionProvider, MessageId, MessageStatus,
4};
5use anyhow::{anyhow, Context as _, Result};
6use assistant_slash_command::{
7 SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
8};
9use client::{self, proto, telemetry::Telemetry};
10use clock::ReplicaId;
11use collections::{HashMap, HashSet};
12use fs::{Fs, RemoveOptions};
13use futures::{
14 future::{self, Shared},
15 FutureExt, StreamExt,
16};
17use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscription, Task};
18use language::{
19 AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
20};
21use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
22use open_ai::Model as OpenAiModel;
23use paths::contexts_dir;
24use project::Project;
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use std::{
28 cmp,
29 fmt::Debug,
30 iter, mem,
31 ops::Range,
32 path::{Path, PathBuf},
33 sync::Arc,
34 time::{Duration, Instant},
35};
36use telemetry_events::AssistantKind;
37use ui::SharedString;
38use util::{post_inc, ResultExt, TryFutureExt};
39use uuid::Uuid;
40
/// Identifier of an assistant context, stored as a string.
///
/// Locally created ids are random v4 UUIDs (see `ContextId::new`); ids received
/// from collaborators are wrapped verbatim.
#[derive(Clone, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ContextId(String);
43
44impl ContextId {
45 pub fn new() -> Self {
46 Self(Uuid::new_v4().to_string())
47 }
48
49 pub fn from_proto(id: String) -> Self {
50 Self(id)
51 }
52
53 pub fn to_proto(&self) -> String {
54 self.0.clone()
55 }
56}
57
/// A replicated change to a [`Context`].
///
/// Every variant except `BufferOperation` carries a `version`: the context
/// version that must already have been observed before the operation can be
/// applied (see `Context::can_apply_op`). Buffer edits are wrapped in
/// `BufferOperation` and forwarded to the underlying buffer.
#[derive(Clone, Debug)]
pub enum ContextOperation {
    /// Inserts a new message starting at `anchor.start`.
    InsertMessage {
        anchor: MessageAnchor,
        metadata: MessageMetadata,
        version: clock::Global,
    },
    /// Replaces a message's metadata; applied only when `metadata.timestamp` is
    /// newer than the currently stored metadata.
    UpdateMessage {
        message_id: MessageId,
        metadata: MessageMetadata,
        version: clock::Global,
    },
    /// Replaces the context summary when its timestamp is newer.
    UpdateSummary {
        summary: ContextSummary,
        version: clock::Global,
    },
    /// Records the output of a slash command that finished running.
    SlashCommandFinished {
        id: SlashCommandId,
        output_range: Range<language::Anchor>,
        sections: Vec<SlashCommandOutputSection<language::Anchor>>,
        version: clock::Global,
    },
    /// An operation on the context's text buffer.
    BufferOperation(language::Operation),
}
82
83impl ContextOperation {
84 pub fn from_proto(op: proto::ContextOperation) -> Result<Self> {
85 match op.variant.context("invalid variant")? {
86 proto::context_operation::Variant::InsertMessage(insert) => {
87 let message = insert.message.context("invalid message")?;
88 let id = MessageId(language::proto::deserialize_timestamp(
89 message.id.context("invalid id")?,
90 ));
91 Ok(Self::InsertMessage {
92 anchor: MessageAnchor {
93 id,
94 start: language::proto::deserialize_anchor(
95 message.start.context("invalid anchor")?,
96 )
97 .context("invalid anchor")?,
98 },
99 metadata: MessageMetadata {
100 role: Role::from_proto(message.role),
101 status: MessageStatus::from_proto(
102 message.status.context("invalid status")?,
103 ),
104 timestamp: id.0,
105 },
106 version: language::proto::deserialize_version(&insert.version),
107 })
108 }
109 proto::context_operation::Variant::UpdateMessage(update) => Ok(Self::UpdateMessage {
110 message_id: MessageId(language::proto::deserialize_timestamp(
111 update.message_id.context("invalid message id")?,
112 )),
113 metadata: MessageMetadata {
114 role: Role::from_proto(update.role),
115 status: MessageStatus::from_proto(update.status.context("invalid status")?),
116 timestamp: language::proto::deserialize_timestamp(
117 update.timestamp.context("invalid timestamp")?,
118 ),
119 },
120 version: language::proto::deserialize_version(&update.version),
121 }),
122 proto::context_operation::Variant::UpdateSummary(update) => Ok(Self::UpdateSummary {
123 summary: ContextSummary {
124 text: update.summary,
125 done: update.done,
126 timestamp: language::proto::deserialize_timestamp(
127 update.timestamp.context("invalid timestamp")?,
128 ),
129 },
130 version: language::proto::deserialize_version(&update.version),
131 }),
132 proto::context_operation::Variant::SlashCommandFinished(finished) => {
133 Ok(Self::SlashCommandFinished {
134 id: SlashCommandId(language::proto::deserialize_timestamp(
135 finished.id.context("invalid id")?,
136 )),
137 output_range: language::proto::deserialize_anchor_range(
138 finished.output_range.context("invalid range")?,
139 )?,
140 sections: finished
141 .sections
142 .into_iter()
143 .map(|section| {
144 Ok(SlashCommandOutputSection {
145 range: language::proto::deserialize_anchor_range(
146 section.range.context("invalid range")?,
147 )?,
148 icon: section.icon_name.parse()?,
149 label: section.label.into(),
150 })
151 })
152 .collect::<Result<Vec<_>>>()?,
153 version: language::proto::deserialize_version(&finished.version),
154 })
155 }
156 proto::context_operation::Variant::BufferOperation(op) => Ok(Self::BufferOperation(
157 language::proto::deserialize_operation(
158 op.operation.context("invalid buffer operation")?,
159 )?,
160 )),
161 }
162 }
163
164 pub fn to_proto(&self) -> proto::ContextOperation {
165 match self {
166 Self::InsertMessage {
167 anchor,
168 metadata,
169 version,
170 } => proto::ContextOperation {
171 variant: Some(proto::context_operation::Variant::InsertMessage(
172 proto::context_operation::InsertMessage {
173 message: Some(proto::ContextMessage {
174 id: Some(language::proto::serialize_timestamp(anchor.id.0)),
175 start: Some(language::proto::serialize_anchor(&anchor.start)),
176 role: metadata.role.to_proto() as i32,
177 status: Some(metadata.status.to_proto()),
178 }),
179 version: language::proto::serialize_version(version),
180 },
181 )),
182 },
183 Self::UpdateMessage {
184 message_id,
185 metadata,
186 version,
187 } => proto::ContextOperation {
188 variant: Some(proto::context_operation::Variant::UpdateMessage(
189 proto::context_operation::UpdateMessage {
190 message_id: Some(language::proto::serialize_timestamp(message_id.0)),
191 role: metadata.role.to_proto() as i32,
192 status: Some(metadata.status.to_proto()),
193 timestamp: Some(language::proto::serialize_timestamp(metadata.timestamp)),
194 version: language::proto::serialize_version(version),
195 },
196 )),
197 },
198 Self::UpdateSummary { summary, version } => proto::ContextOperation {
199 variant: Some(proto::context_operation::Variant::UpdateSummary(
200 proto::context_operation::UpdateSummary {
201 summary: summary.text.clone(),
202 done: summary.done,
203 timestamp: Some(language::proto::serialize_timestamp(summary.timestamp)),
204 version: language::proto::serialize_version(version),
205 },
206 )),
207 },
208 Self::SlashCommandFinished {
209 id,
210 output_range,
211 sections,
212 version,
213 } => proto::ContextOperation {
214 variant: Some(proto::context_operation::Variant::SlashCommandFinished(
215 proto::context_operation::SlashCommandFinished {
216 id: Some(language::proto::serialize_timestamp(id.0)),
217 output_range: Some(language::proto::serialize_anchor_range(
218 output_range.clone(),
219 )),
220 sections: sections
221 .iter()
222 .map(|section| {
223 let icon_name: &'static str = section.icon.into();
224 proto::SlashCommandOutputSection {
225 range: Some(language::proto::serialize_anchor_range(
226 section.range.clone(),
227 )),
228 icon_name: icon_name.to_string(),
229 label: section.label.to_string(),
230 }
231 })
232 .collect(),
233 version: language::proto::serialize_version(version),
234 },
235 )),
236 },
237 Self::BufferOperation(operation) => proto::ContextOperation {
238 variant: Some(proto::context_operation::Variant::BufferOperation(
239 proto::context_operation::BufferOperation {
240 operation: Some(language::proto::serialize_operation(operation)),
241 },
242 )),
243 },
244 }
245 }
246
247 fn timestamp(&self) -> clock::Lamport {
248 match self {
249 Self::InsertMessage { anchor, .. } => anchor.id.0,
250 Self::UpdateMessage { metadata, .. } => metadata.timestamp,
251 Self::UpdateSummary { summary, .. } => summary.timestamp,
252 Self::SlashCommandFinished { id, .. } => id.0,
253 Self::BufferOperation(_) => {
254 panic!("reading the timestamp of a buffer operation is not supported")
255 }
256 }
257 }
258
259 /// Returns the current version of the context operation.
260 pub fn version(&self) -> &clock::Global {
261 match self {
262 Self::InsertMessage { version, .. }
263 | Self::UpdateMessage { version, .. }
264 | Self::UpdateSummary { version, .. }
265 | Self::SlashCommandFinished { version, .. } => version,
266 Self::BufferOperation(_) => {
267 panic!("reading the version of a buffer operation is not supported")
268 }
269 }
270 }
271}
272
/// Events emitted by a [`Context`] model.
#[derive(Clone)]
pub enum ContextEvent {
    /// Message structure or metadata changed (emitted by `flush_ops` after applying message ops).
    MessagesEdited,
    /// The context summary was replaced by a newer one.
    SummaryChanged,
    // NOTE(review): the emitters of the next three variants are outside this view;
    // the descriptions below are inferred from their names — confirm.
    /// Presumably: the set of edit steps changed.
    EditStepsChanged,
    /// Presumably: more completion text was streamed into the buffer.
    StreamedCompletion,
    /// Presumably: slash-command parsing removed and/or updated pending commands.
    PendingSlashCommandsUpdated {
        removed: Vec<Range<language::Anchor>>,
        updated: Vec<PendingSlashCommand>,
    },
    /// A slash command's output was recorded. When applied from a replicated
    /// operation, `run_commands_in_output` is `false`.
    SlashCommandFinished {
        output_range: Range<language::Anchor>,
        sections: Vec<SlashCommandOutputSection<language::Anchor>>,
        run_commands_in_output: bool,
    },
    /// Presumably: a locally generated operation that should be sent to collaborators.
    Operation(ContextOperation),
}
290
/// A short summary of the context.
#[derive(Clone, Default, Debug)]
pub struct ContextSummary {
    /// The summary text.
    pub text: String,
    // NOTE(review): presumably true once summary generation has completed — confirm.
    done: bool,
    /// Lamport timestamp; a replicated summary replaces the current one only if newer.
    timestamp: clock::Lamport,
}
297
/// Marks where a message begins in the context buffer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MessageAnchor {
    /// Identifier of the message (a Lamport timestamp).
    pub id: MessageId,
    /// Buffer anchor at the start of the message.
    pub start: language::Anchor,
}
303
/// Replicated per-message metadata.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct MessageMetadata {
    /// Who authored the message.
    pub role: Role,
    /// Current status of the message.
    status: MessageStatus,
    /// Lamport timestamp of the last update; newer updates win.
    timestamp: clock::Lamport,
}
310
/// A resolved view of one message in the context.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Message {
    /// Byte-offset range of the message text within the buffer.
    pub offset_range: Range<usize>,
    // NOTE(review): presumably the range of positions in the message list — confirm.
    pub index_range: Range<usize>,
    pub id: MessageId,
    /// Buffer anchor at the start of the message.
    pub anchor: language::Anchor,
    pub role: Role,
    pub status: MessageStatus,
}
320
321impl Message {
322 fn to_request_message(&self, buffer: &Buffer) -> LanguageModelRequestMessage {
323 LanguageModelRequestMessage {
324 role: self.role,
325 content: buffer.text_for_range(self.offset_range.clone()).collect(),
326 }
327 }
328}
329
/// An in-flight completion request.
struct PendingCompletion {
    id: usize,
    // NOTE(review): held presumably so the task keeps running while this entry exists — confirm.
    _task: Task<()>,
}
334
/// Identifies a slash-command invocation by the Lamport timestamp of its
/// `SlashCommandFinished` operation.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct SlashCommandId(clock::Lamport);
337
/// A step in the conversation that proposes edits to the codebase.
#[derive(Debug)]
pub struct EditStep {
    /// Range of the context buffer this step was derived from.
    pub source_range: Range<language::Anchor>,
    /// The step's resolved operations; `None` when none are available.
    pub operations: Option<EditStepOperations>,
}
343
/// Suggestions in one buffer whose surrounding context ranges overlap.
#[derive(Debug)]
pub struct EditSuggestionGroup {
    /// Union of the suggestions' context ranges (each suggestion padded by up to 5 rows).
    pub context_range: Range<language::Anchor>,
    /// Suggestions in this group, in buffer order.
    pub suggestions: Vec<EditSuggestion>,
}
349
/// A single location in a buffer that should be edited.
#[derive(Debug)]
pub struct EditSuggestion {
    /// Range to transform; empty for pure insertions.
    pub range: Range<language::Anchor>,
    /// If None, assume this is a suggestion to delete the range rather than transform it.
    pub description: Option<String>,
    /// Newline padding to insert around new content, if any.
    pub initial_insertion: Option<InitialInsertion>,
}
357
358impl EditStep {
359 pub fn edit_suggestions(
360 &self,
361 project: &Model<Project>,
362 cx: &AppContext,
363 ) -> Task<HashMap<Model<Buffer>, Vec<EditSuggestionGroup>>> {
364 let Some(EditStepOperations::Ready(operations)) = &self.operations else {
365 return Task::ready(HashMap::default());
366 };
367
368 let suggestion_tasks: Vec<_> = operations
369 .iter()
370 .map(|operation| operation.edit_suggestion(project.clone(), cx))
371 .collect();
372
373 cx.spawn(|mut cx| async move {
374 let suggestions = future::join_all(suggestion_tasks)
375 .await
376 .into_iter()
377 .filter_map(|task| task.log_err())
378 .collect::<Vec<_>>();
379
380 let mut suggestions_by_buffer = HashMap::default();
381 for (buffer, suggestion) in suggestions {
382 suggestions_by_buffer
383 .entry(buffer)
384 .or_insert_with(Vec::new)
385 .push(suggestion);
386 }
387
388 let mut suggestion_groups_by_buffer = HashMap::default();
389 for (buffer, mut suggestions) in suggestions_by_buffer {
390 let mut suggestion_groups = Vec::<EditSuggestionGroup>::new();
391 buffer
392 .update(&mut cx, |buffer, _cx| {
393 // Sort suggestions by their range
394 suggestions.sort_by(|a, b| a.range.cmp(&b.range, buffer));
395
396 // Dedup overlapping suggestions
397 suggestions.dedup_by(|a, b| {
398 let a_range = a.range.to_offset(buffer);
399 let b_range = b.range.to_offset(buffer);
400 if a_range.start <= b_range.end && b_range.start <= a_range.end {
401 if b_range.start < a_range.start {
402 a.range.start = b.range.start;
403 }
404 if b_range.end > a_range.end {
405 a.range.end = b.range.end;
406 }
407
408 if let (Some(a_desc), Some(b_desc)) =
409 (a.description.as_mut(), b.description.as_mut())
410 {
411 b_desc.push('\n');
412 b_desc.push_str(a_desc);
413 } else if a.description.is_some() {
414 b.description = a.description.take();
415 }
416
417 true
418 } else {
419 false
420 }
421 });
422
423 // Create context ranges for each suggestion
424 for suggestion in suggestions {
425 let context_range = {
426 let suggestion_point_range = suggestion.range.to_point(buffer);
427 let start_row = suggestion_point_range.start.row.saturating_sub(5);
428 let end_row = cmp::min(
429 suggestion_point_range.end.row + 5,
430 buffer.max_point().row,
431 );
432 let start = buffer.anchor_before(Point::new(start_row, 0));
433 let end = buffer
434 .anchor_after(Point::new(end_row, buffer.line_len(end_row)));
435 start..end
436 };
437
438 if let Some(last_group) = suggestion_groups.last_mut() {
439 if last_group
440 .context_range
441 .end
442 .cmp(&context_range.start, buffer)
443 .is_ge()
444 {
445 // Merge with the previous group if context ranges overlap
446 last_group.context_range.end = context_range.end;
447 last_group.suggestions.push(suggestion);
448 } else {
449 // Create a new group
450 suggestion_groups.push(EditSuggestionGroup {
451 context_range,
452 suggestions: vec![suggestion],
453 });
454 }
455 } else {
456 // Create the first group
457 suggestion_groups.push(EditSuggestionGroup {
458 context_range,
459 suggestions: vec![suggestion],
460 });
461 }
462 }
463 })
464 .ok();
465 suggestion_groups_by_buffer.insert(buffer, suggestion_groups);
466 }
467
468 suggestion_groups_by_buffer
469 })
470 }
471}
472
/// The operations of an [`EditStep`], either still being produced or available.
pub enum EditStepOperations {
    // NOTE(review): presumably the task that produces the operations — confirm with the producer.
    Pending(Task<Option<()>>),
    /// The step's operations are available.
    Ready(Vec<EditOperation>),
}
477
478impl Debug for EditStepOperations {
479 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
480 match self {
481 EditStepOperations::Pending(_) => write!(f, "EditStepOperations::Pending"),
482 EditStepOperations::Ready(operations) => f
483 .debug_struct("EditStepOperations::Parsed")
484 .field("operations", operations)
485 .finish(),
486 }
487 }
488}
489
/// A description of an operation to apply to one location in the codebase.
///
/// This object represents a single edit operation that can be performed on a specific file
/// in the codebase. It encapsulates both the location (file path) and the nature of the
/// edit to be made.
///
/// # Fields
///
/// * `path`: A string representing the file path where the edit operation should be applied.
/// This path is relative to the root of the project or repository.
///
/// * `kind`: An enum representing the specific type of edit operation to be performed.
///
/// # Usage
///
/// `EditOperation` is used within a code editor to represent and apply
/// programmatic changes to source code. It provides a structured way to describe
/// edits for features like refactoring tools or AI-assisted coding suggestions.
// NOTE(review): this type derives `JsonSchema`; schemars likely emits the `///` doc
// comments (above and on fields) as schema descriptions, so rewording them probably
// changes the tool description shown to the model — confirm before editing them.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
pub struct EditOperation {
    /// The path to the file containing the relevant operation
    pub path: String,
    // Flattened: the fields of `EditOperationKind` (which is internally tagged by
    // `kind`) are read from the same object as `path`.
    #[serde(flatten)]
    pub kind: EditOperationKind,
}
515
516impl EditOperation {
517 fn edit_suggestion(
518 &self,
519 project: Model<Project>,
520 cx: &AppContext,
521 ) -> Task<Result<(Model<language::Buffer>, EditSuggestion)>> {
522 let path = self.path.clone();
523 let kind = self.kind.clone();
524 cx.spawn(move |mut cx| async move {
525 let buffer = project
526 .update(&mut cx, |project, cx| {
527 let project_path = project
528 .project_path_for_full_path(Path::new(&path), cx)
529 .with_context(|| format!("worktree not found for {:?}", path))?;
530 anyhow::Ok(project.open_buffer(project_path, cx))
531 })??
532 .await?;
533
534 let mut parse_status = buffer.read_with(&cx, |buffer, _cx| buffer.parse_status())?;
535 while *parse_status.borrow() != ParseStatus::Idle {
536 parse_status.changed().await?;
537 }
538
539 let initial_insertion = kind.initial_insertion();
540 let suggestion_range = if let Some(symbol) = kind.symbol() {
541 let outline = buffer
542 .update(&mut cx, |buffer, _| buffer.snapshot().outline(None))?
543 .context("no outline for buffer")?;
544 let candidate = outline
545 .path_candidates
546 .iter()
547 .max_by(|a, b| {
548 strsim::jaro_winkler(&a.string, symbol)
549 .total_cmp(&strsim::jaro_winkler(&b.string, symbol))
550 })
551 .with_context(|| {
552 format!(
553 "symbol {:?} not found in path {:?}.\ncandidates: {:?}.\nparse status: {:?}. text:\n{}",
554 symbol,
555 path,
556 outline
557 .path_candidates
558 .iter()
559 .map(|candidate| &candidate.string)
560 .collect::<Vec<_>>(),
561 *parse_status.borrow(),
562 buffer.read_with(&cx, |buffer, _| buffer.text()).unwrap_or_else(|_| "error".to_string())
563 )
564 })?;
565
566 buffer.update(&mut cx, |buffer, _| {
567 let outline_item = &outline.items[candidate.id];
568 let symbol_range = outline_item.range.to_point(buffer);
569 let body_range = outline_item
570 .body_range
571 .as_ref()
572 .map(|range| range.to_point(buffer))
573 .unwrap_or(symbol_range.clone());
574
575 match kind {
576 EditOperationKind::PrependChild { .. } => {
577 let position = buffer.anchor_after(body_range.start);
578 position..position
579 }
580 EditOperationKind::AppendChild { .. } => {
581 let position = buffer.anchor_before(body_range.end);
582 position..position
583 }
584 EditOperationKind::InsertSiblingBefore { .. } => {
585 let position = buffer.anchor_before(symbol_range.start);
586 position..position
587 }
588 EditOperationKind::InsertSiblingAfter { .. } => {
589 let position = buffer.anchor_after(symbol_range.end);
590 position..position
591 }
592 EditOperationKind::Update { .. } | EditOperationKind::Delete { .. } => {
593 let start = Point::new(symbol_range.start.row, 0);
594 let end = Point::new(
595 symbol_range.end.row,
596 buffer.line_len(symbol_range.end.row),
597 );
598 buffer.anchor_before(start)..buffer.anchor_after(end)
599 }
600 EditOperationKind::Create { .. } => unreachable!(),
601 }
602 })?
603 } else {
604 match kind {
605 EditOperationKind::PrependChild { .. } => {
606 language::Anchor::MIN..language::Anchor::MIN
607 }
608 EditOperationKind::AppendChild { .. } | EditOperationKind::Create { .. } => {
609 language::Anchor::MAX..language::Anchor::MAX
610 }
611 _ => unreachable!("All other operations should have a symbol"),
612 }
613 };
614
615 Ok((
616 buffer,
617 EditSuggestion {
618 range: suggestion_range,
619 description: kind.description().map(ToString::to_string),
620 initial_insertion,
621 },
622 ))
623 })
624 }
625}
626
627#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
628#[serde(tag = "kind")]
629pub enum EditOperationKind {
630 /// Rewrites the specified symbol entirely based on the given description.
631 /// This operation completely replaces the existing symbol with new content.
632 Update {
633 /// A fully-qualified reference to the symbol, e.g. `mod foo impl Bar pub fn baz` instead of just `fn baz`.
634 /// The path should uniquely identify the symbol within the containing file.
635 symbol: String,
636 /// A brief description of the transformation to apply to the symbol.
637 description: String,
638 },
639 /// Creates a new file with the given path based on the provided description.
640 /// This operation adds a new file to the codebase.
641 Create {
642 /// A brief description of the file to be created.
643 description: String,
644 },
645 /// Inserts a new symbol based on the given description before the specified symbol.
646 /// This operation adds new content immediately preceding an existing symbol.
647 InsertSiblingBefore {
648 /// A fully-qualified reference to the symbol, e.g. `mod foo impl Bar pub fn baz` instead of just `fn baz`.
649 /// The new content will be inserted immediately before this symbol.
650 symbol: String,
651 /// A brief description of the new symbol to be inserted.
652 description: String,
653 },
654 /// Inserts a new symbol based on the given description after the specified symbol.
655 /// This operation adds new content immediately following an existing symbol.
656 InsertSiblingAfter {
657 /// A fully-qualified reference to the symbol, e.g. `mod foo impl Bar pub fn baz` instead of just `fn baz`.
658 /// The new content will be inserted immediately after this symbol.
659 symbol: String,
660 /// A brief description of the new symbol to be inserted.
661 description: String,
662 },
663 /// Inserts a new symbol as a child of the specified symbol at the start.
664 /// This operation adds new content as the first child of an existing symbol (or file if no symbol is provided).
665 PrependChild {
666 /// An optional fully-qualified reference to the symbol after the code you want to insert, e.g. `mod foo impl Bar pub fn baz` instead of just `fn baz`.
667 /// If provided, the new content will be inserted as the first child of this symbol.
668 /// If not provided, the new content will be inserted at the top of the file.
669 symbol: Option<String>,
670 /// A brief description of the new symbol to be inserted.
671 description: String,
672 },
673 /// Inserts a new symbol as a child of the specified symbol at the end.
674 /// This operation adds new content as the last child of an existing symbol (or file if no symbol is provided).
675 AppendChild {
676 /// An optional fully-qualified reference to the symbol before the code you want to insert, e.g. `mod foo impl Bar pub fn baz` instead of just `fn baz`.
677 /// If provided, the new content will be inserted as the last child of this symbol.
678 /// If not provided, the new content will be applied at the bottom of the file.
679 symbol: Option<String>,
680 /// A brief description of the new symbol to be inserted.
681 description: String,
682 },
683 /// Deletes the specified symbol from the containing file.
684 Delete {
685 /// An fully-qualified reference to the symbol to be deleted, e.g. `mod foo impl Bar pub fn baz` instead of just `fn baz`.
686 symbol: String,
687 },
688}
689
690impl EditOperationKind {
691 pub fn symbol(&self) -> Option<&str> {
692 match self {
693 Self::Update { symbol, .. } => Some(symbol),
694 Self::InsertSiblingBefore { symbol, .. } => Some(symbol),
695 Self::InsertSiblingAfter { symbol, .. } => Some(symbol),
696 Self::PrependChild { symbol, .. } => symbol.as_deref(),
697 Self::AppendChild { symbol, .. } => symbol.as_deref(),
698 Self::Delete { symbol } => Some(symbol),
699 Self::Create { .. } => None,
700 }
701 }
702
703 pub fn description(&self) -> Option<&str> {
704 match self {
705 Self::Update { description, .. } => Some(description),
706 Self::Create { description } => Some(description),
707 Self::InsertSiblingBefore { description, .. } => Some(description),
708 Self::InsertSiblingAfter { description, .. } => Some(description),
709 Self::PrependChild { description, .. } => Some(description),
710 Self::AppendChild { description, .. } => Some(description),
711 Self::Delete { .. } => None,
712 }
713 }
714
715 pub fn initial_insertion(&self) -> Option<InitialInsertion> {
716 match self {
717 EditOperationKind::InsertSiblingBefore { .. } => Some(InitialInsertion::NewlineAfter),
718 EditOperationKind::InsertSiblingAfter { .. } => Some(InitialInsertion::NewlineBefore),
719 EditOperationKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
720 EditOperationKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
721 _ => None,
722 }
723 }
724}
725
/// A collaborative assistant conversation: a text buffer plus the message
/// structure, summary, slash-command output and edit steps layered on top of it.
///
/// Non-buffer changes are replicated as [`ContextOperation`]s. They are ordered by
/// a Lamport clock (`timestamp`) and tracked in a version vector (`version`).
pub struct Context {
    id: ContextId,
    /// Lamport clock of this replica; ticked by `next_timestamp` and advanced
    /// past every applied operation.
    timestamp: clock::Lamport,
    /// Context operations observed so far (buffer operations are versioned by the buffer).
    version: clock::Global,
    /// Received operations that cannot be applied yet (see `can_apply_op`).
    pending_ops: Vec<ContextOperation>,
    /// Applied context operations, re-sent by `serialize_ops` to peers that lack them.
    operations: Vec<ContextOperation>,
    /// Text of the whole conversation.
    buffer: Model<Buffer>,
    pending_slash_commands: Vec<PendingSlashCommand>,
    // NOTE(review): buffer edit subscription; presumably drained when slash
    // commands are re-parsed — confirm.
    edits_since_last_slash_command_parse: language::Subscription,
    /// Slash commands whose `SlashCommandFinished` operation has been applied.
    finished_slash_commands: HashSet<SlashCommandId>,
    /// Output sections of finished slash commands, kept sorted by range.
    slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
    /// Start anchors of the messages.
    message_anchors: Vec<MessageAnchor>,
    /// Metadata for each message, keyed by message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    summary: Option<ContextSummary>,
    pending_summary: Task<Option<()>>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    token_count: Option<usize>,
    pending_token_count: Task<Option<()>>,
    pending_save: Task<Result<()>>,
    /// File this context was loaded from or saved to, if any.
    path: Option<PathBuf>,
    _subscriptions: Vec<Subscription>,
    telemetry: Option<Arc<Telemetry>>,
    language_registry: Arc<LanguageRegistry>,
    edit_steps: Vec<EditStep>,
}

// `Context` models can emit `ContextEvent`s to subscribers.
impl EventEmitter<ContextEvent> for Context {}
754
755impl Context {
756 pub fn local(
757 language_registry: Arc<LanguageRegistry>,
758 telemetry: Option<Arc<Telemetry>>,
759 cx: &mut ModelContext<Self>,
760 ) -> Self {
761 Self::new(
762 ContextId::new(),
763 ReplicaId::default(),
764 language::Capability::ReadWrite,
765 language_registry,
766 telemetry,
767 cx,
768 )
769 }
770
771 pub fn new(
772 id: ContextId,
773 replica_id: ReplicaId,
774 capability: language::Capability,
775 language_registry: Arc<LanguageRegistry>,
776 telemetry: Option<Arc<Telemetry>>,
777 cx: &mut ModelContext<Self>,
778 ) -> Self {
779 let buffer = cx.new_model(|_cx| {
780 let mut buffer = Buffer::remote(
781 language::BufferId::new(1).unwrap(),
782 replica_id,
783 capability,
784 "",
785 );
786 buffer.set_language_registry(language_registry.clone());
787 buffer
788 });
789 let edits_since_last_slash_command_parse =
790 buffer.update(cx, |buffer, _| buffer.subscribe());
791 let mut this = Self {
792 id,
793 timestamp: clock::Lamport::new(replica_id),
794 version: clock::Global::new(),
795 pending_ops: Vec::new(),
796 operations: Vec::new(),
797 message_anchors: Default::default(),
798 messages_metadata: Default::default(),
799 pending_slash_commands: Vec::new(),
800 finished_slash_commands: HashSet::default(),
801 slash_command_output_sections: Vec::new(),
802 edits_since_last_slash_command_parse,
803 summary: None,
804 pending_summary: Task::ready(None),
805 completion_count: Default::default(),
806 pending_completions: Default::default(),
807 token_count: None,
808 pending_token_count: Task::ready(None),
809 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
810 pending_save: Task::ready(Ok(())),
811 path: None,
812 buffer,
813 telemetry,
814 language_registry,
815 edit_steps: Vec::new(),
816 };
817
818 let first_message_id = MessageId(clock::Lamport {
819 replica_id: 0,
820 value: 0,
821 });
822 let message = MessageAnchor {
823 id: first_message_id,
824 start: language::Anchor::MIN,
825 };
826 this.messages_metadata.insert(
827 first_message_id,
828 MessageMetadata {
829 role: Role::User,
830 status: MessageStatus::Done,
831 timestamp: first_message_id.0,
832 },
833 );
834 this.message_anchors.push(message);
835
836 this.set_language(cx);
837 this.count_remaining_tokens(cx);
838 this
839 }
840
841 fn serialize(&self, cx: &AppContext) -> SavedContext {
842 let buffer = self.buffer.read(cx);
843 SavedContext {
844 id: Some(self.id.clone()),
845 zed: "context".into(),
846 version: SavedContext::VERSION.into(),
847 text: buffer.text(),
848 messages: self
849 .messages(cx)
850 .map(|message| SavedMessage {
851 id: message.id,
852 start: message.offset_range.start,
853 metadata: self.messages_metadata[&message.id].clone(),
854 })
855 .collect(),
856 summary: self
857 .summary
858 .as_ref()
859 .map(|summary| summary.text.clone())
860 .unwrap_or_default(),
861 slash_command_output_sections: self
862 .slash_command_output_sections
863 .iter()
864 .filter_map(|section| {
865 let range = section.range.to_offset(buffer);
866 if section.range.start.is_valid(buffer) && !range.is_empty() {
867 Some(assistant_slash_command::SlashCommandOutputSection {
868 range,
869 icon: section.icon,
870 label: section.label.clone(),
871 })
872 } else {
873 None
874 }
875 })
876 .collect(),
877 }
878 }
879
880 #[allow(clippy::too_many_arguments)]
881 pub fn deserialize(
882 saved_context: SavedContext,
883 path: PathBuf,
884 language_registry: Arc<LanguageRegistry>,
885 telemetry: Option<Arc<Telemetry>>,
886 cx: &mut ModelContext<Self>,
887 ) -> Self {
888 let id = saved_context.id.clone().unwrap_or_else(|| ContextId::new());
889 let mut this = Self::new(
890 id,
891 ReplicaId::default(),
892 language::Capability::ReadWrite,
893 language_registry,
894 telemetry,
895 cx,
896 );
897 this.path = Some(path);
898 this.buffer.update(cx, |buffer, cx| {
899 buffer.set_text(saved_context.text.as_str(), cx)
900 });
901 let operations = saved_context.into_ops(&this.buffer, cx);
902 this.apply_ops(operations, cx).unwrap();
903 this
904 }
905
906 pub fn id(&self) -> &ContextId {
907 &self.id
908 }
909
910 pub fn replica_id(&self) -> ReplicaId {
911 self.timestamp.replica_id
912 }
913
914 pub fn version(&self, cx: &AppContext) -> ContextVersion {
915 ContextVersion {
916 context: self.version.clone(),
917 buffer: self.buffer.read(cx).version(),
918 }
919 }
920
921 pub fn set_capability(
922 &mut self,
923 capability: language::Capability,
924 cx: &mut ModelContext<Self>,
925 ) {
926 self.buffer
927 .update(cx, |buffer, cx| buffer.set_capability(capability, cx));
928 }
929
930 fn next_timestamp(&mut self) -> clock::Lamport {
931 let timestamp = self.timestamp.tick();
932 self.version.observe(timestamp);
933 timestamp
934 }
935
936 pub fn serialize_ops(
937 &self,
938 since: &ContextVersion,
939 cx: &AppContext,
940 ) -> Task<Vec<proto::ContextOperation>> {
941 let buffer_ops = self
942 .buffer
943 .read(cx)
944 .serialize_ops(Some(since.buffer.clone()), cx);
945
946 let mut context_ops = self
947 .operations
948 .iter()
949 .filter(|op| !since.context.observed(op.timestamp()))
950 .cloned()
951 .collect::<Vec<_>>();
952 context_ops.extend(self.pending_ops.iter().cloned());
953
954 cx.background_executor().spawn(async move {
955 let buffer_ops = buffer_ops.await;
956 context_ops.sort_unstable_by_key(|op| op.timestamp());
957 buffer_ops
958 .into_iter()
959 .map(|op| proto::ContextOperation {
960 variant: Some(proto::context_operation::Variant::BufferOperation(
961 proto::context_operation::BufferOperation {
962 operation: Some(op),
963 },
964 )),
965 })
966 .chain(context_ops.into_iter().map(|op| op.to_proto()))
967 .collect()
968 })
969 }
970
971 pub fn apply_ops(
972 &mut self,
973 ops: impl IntoIterator<Item = ContextOperation>,
974 cx: &mut ModelContext<Self>,
975 ) -> Result<()> {
976 let mut buffer_ops = Vec::new();
977 for op in ops {
978 match op {
979 ContextOperation::BufferOperation(buffer_op) => buffer_ops.push(buffer_op),
980 op @ _ => self.pending_ops.push(op),
981 }
982 }
983 self.buffer
984 .update(cx, |buffer, cx| buffer.apply_ops(buffer_ops, cx))?;
985 self.flush_ops(cx);
986
987 Ok(())
988 }
989
    /// Applies every queued operation in `pending_ops` that `can_apply_op`
    /// accepts, in timestamp order. Operations that cannot be applied yet are
    /// put back into `pending_ops` for a later flush.
    ///
    /// Every applied operation is observed into `version` and `timestamp` and
    /// appended to `operations`. `MessagesEdited` and `SummaryChanged` are each
    /// emitted at most once, after the whole batch has been processed.
    fn flush_ops(&mut self, cx: &mut ModelContext<Context>) {
        let mut messages_changed = false;
        let mut summary_changed = false;

        self.pending_ops.sort_unstable_by_key(|op| op.timestamp());
        for op in mem::take(&mut self.pending_ops) {
            if !self.can_apply_op(&op, cx) {
                // Dependencies not observed yet; retry on a later flush.
                self.pending_ops.push(op);
                continue;
            }

            let timestamp = op.timestamp();
            match op.clone() {
                ContextOperation::InsertMessage {
                    anchor, metadata, ..
                } => {
                    if self.messages_metadata.contains_key(&anchor.id) {
                        // We already applied this operation.
                    } else {
                        self.insert_message(anchor, metadata, cx);
                        messages_changed = true;
                    }
                }
                ContextOperation::UpdateMessage {
                    message_id,
                    metadata: new_metadata,
                    ..
                } => {
                    // `can_apply_op` already checked that `message_id` is present.
                    let metadata = self.messages_metadata.get_mut(&message_id).unwrap();
                    // Only take metadata carrying a newer timestamp than ours.
                    if new_metadata.timestamp > metadata.timestamp {
                        *metadata = new_metadata;
                        messages_changed = true;
                    }
                }
                ContextOperation::UpdateSummary {
                    summary: new_summary,
                    ..
                } => {
                    // Same newest-timestamp-wins rule as message metadata.
                    if self
                        .summary
                        .as_ref()
                        .map_or(true, |summary| new_summary.timestamp > summary.timestamp)
                    {
                        self.summary = Some(new_summary);
                        summary_changed = true;
                    }
                }
                ContextOperation::SlashCommandFinished {
                    id,
                    output_range,
                    sections,
                    ..
                } => {
                    // `insert` is false for an id already recorded as finished, so
                    // a duplicate operation does not re-add its sections.
                    if self.finished_slash_commands.insert(id) {
                        let buffer = self.buffer.read(cx);
                        self.slash_command_output_sections
                            .extend(sections.iter().cloned());
                        self.slash_command_output_sections
                            .sort_by(|a, b| a.range.cmp(&b.range, buffer));
                        cx.emit(ContextEvent::SlashCommandFinished {
                            output_range,
                            sections,
                            run_commands_in_output: false,
                        });
                    }
                }
                // `apply_ops` never queues buffer operations, and `can_apply_op`
                // panics on them before this point.
                ContextOperation::BufferOperation(_) => unreachable!(),
            }

            self.version.observe(timestamp);
            self.timestamp.observe(timestamp);
            self.operations.push(op);
        }

        if messages_changed {
            cx.emit(ContextEvent::MessagesEdited);
            cx.notify();
        }

        if summary_changed {
            cx.emit(ContextEvent::SummaryChanged);
            cx.notify();
        }
    }
1074
1075 fn can_apply_op(&self, op: &ContextOperation, cx: &AppContext) -> bool {
1076 if !self.version.observed_all(op.version()) {
1077 return false;
1078 }
1079
1080 match op {
1081 ContextOperation::InsertMessage { anchor, .. } => self
1082 .buffer
1083 .read(cx)
1084 .version
1085 .observed(anchor.start.timestamp),
1086 ContextOperation::UpdateMessage { message_id, .. } => {
1087 self.messages_metadata.contains_key(message_id)
1088 }
1089 ContextOperation::UpdateSummary { .. } => true,
1090 ContextOperation::SlashCommandFinished {
1091 output_range,
1092 sections,
1093 ..
1094 } => {
1095 let version = &self.buffer.read(cx).version;
1096 sections
1097 .iter()
1098 .map(|section| §ion.range)
1099 .chain([output_range])
1100 .all(|range| {
1101 let observed_start = range.start == language::Anchor::MIN
1102 || range.start == language::Anchor::MAX
1103 || version.observed(range.start.timestamp);
1104 let observed_end = range.end == language::Anchor::MIN
1105 || range.end == language::Anchor::MAX
1106 || version.observed(range.end.timestamp);
1107 observed_start && observed_end
1108 })
1109 }
1110 ContextOperation::BufferOperation(_) => {
1111 panic!("buffer operations should always be applied")
1112 }
1113 }
1114 }
1115
1116 fn push_op(&mut self, op: ContextOperation, cx: &mut ModelContext<Self>) {
1117 self.operations.push(op.clone());
1118 cx.emit(ContextEvent::Operation(op));
1119 }
1120
1121 pub fn buffer(&self) -> &Model<Buffer> {
1122 &self.buffer
1123 }
1124
1125 pub fn path(&self) -> Option<&Path> {
1126 self.path.as_deref()
1127 }
1128
1129 pub fn summary(&self) -> Option<&ContextSummary> {
1130 self.summary.as_ref()
1131 }
1132
1133 pub fn edit_steps(&self) -> &[EditStep] {
1134 &self.edit_steps
1135 }
1136
1137 pub fn pending_slash_commands(&self) -> &[PendingSlashCommand] {
1138 &self.pending_slash_commands
1139 }
1140
1141 pub fn slash_command_output_sections(&self) -> &[SlashCommandOutputSection<language::Anchor>] {
1142 &self.slash_command_output_sections
1143 }
1144
1145 fn set_language(&mut self, cx: &mut ModelContext<Self>) {
1146 let markdown = self.language_registry.language_for_name("Markdown");
1147 cx.spawn(|this, mut cx| async move {
1148 let markdown = markdown.await?;
1149 this.update(&mut cx, |this, cx| {
1150 this.buffer
1151 .update(cx, |buffer, cx| buffer.set_language(Some(markdown), cx));
1152 })
1153 })
1154 .detach_and_log_err(cx);
1155 }
1156
1157 fn handle_buffer_event(
1158 &mut self,
1159 _: Model<Buffer>,
1160 event: &language::Event,
1161 cx: &mut ModelContext<Self>,
1162 ) {
1163 match event {
1164 language::Event::Operation(operation) => cx.emit(ContextEvent::Operation(
1165 ContextOperation::BufferOperation(operation.clone()),
1166 )),
1167 language::Event::Edited => {
1168 self.count_remaining_tokens(cx);
1169 self.reparse_slash_commands(cx);
1170 self.prune_invalid_edit_steps(cx);
1171 cx.emit(ContextEvent::MessagesEdited);
1172 }
1173 _ => {}
1174 }
1175 }
1176
1177 pub(crate) fn token_count(&self) -> Option<usize> {
1178 self.token_count
1179 }
1180
1181 pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1182 let request = self.to_completion_request(cx);
1183 self.pending_token_count = cx.spawn(|this, mut cx| {
1184 async move {
1185 cx.background_executor()
1186 .timer(Duration::from_millis(200))
1187 .await;
1188
1189 let token_count = cx
1190 .update(|cx| {
1191 LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
1192 })?
1193 .await?;
1194 this.update(&mut cx, |this, cx| {
1195 this.token_count = Some(token_count);
1196 cx.notify()
1197 })
1198 }
1199 .log_err()
1200 });
1201 }
1202
    /// Re-parses slash commands on the rows touched by edits since the last
    /// parse and replaces the overlapping entries in `pending_slash_commands`.
    ///
    /// A line becomes an `Idle` `PendingSlashCommand` only if all of these hold:
    /// - `SlashCommandLine::parse` accepts it.
    /// - Its name is registered in `SlashCommandRegistry`.
    /// - It has a non-empty argument whenever the command requires one.
    ///
    /// Emits `PendingSlashCommandsUpdated` when either list is non-empty. The
    /// event carries the source ranges of the replaced commands and the newly
    /// parsed commands.
    pub fn reparse_slash_commands(&mut self, cx: &mut ModelContext<Self>) {
        let buffer = self.buffer.read(cx);
        // Turn each edit into the half-open range of rows it touches.
        let mut row_ranges = self
            .edits_since_last_slash_command_parse
            .consume()
            .into_iter()
            .map(|edit| {
                let start_row = buffer.offset_to_point(edit.new.start).row;
                let end_row = buffer.offset_to_point(edit.new.end).row + 1;
                start_row..end_row
            })
            .peekable();

        let mut removed = Vec::new();
        let mut updated = Vec::new();
        while let Some(mut row_range) = row_ranges.next() {
            // Merge following row ranges that touch or overlap this one
            // (assumes the edits are ordered by position — TODO confirm).
            while let Some(next_row_range) = row_ranges.peek() {
                if row_range.end >= next_row_range.start {
                    row_range.end = next_row_range.end;
                    row_ranges.next();
                } else {
                    break;
                }
            }

            // Cover whole lines: start of the first row through end of the last.
            let start = buffer.anchor_before(Point::new(row_range.start, 0));
            let end = buffer.anchor_after(Point::new(
                row_range.end - 1,
                buffer.line_len(row_range.end - 1),
            ));

            let old_range = self.pending_command_indices_for_range(start..end, cx);

            let mut new_commands = Vec::new();
            let mut lines = buffer.text_for_range(start..end).lines();
            let mut offset = lines.offset();
            while let Some(line) = lines.next() {
                if let Some(command_line) = SlashCommandLine::parse(line) {
                    let name = &line[command_line.name.clone()];
                    // An empty argument range is treated as no argument.
                    let argument = command_line.argument.as_ref().and_then(|argument| {
                        (!argument.is_empty()).then_some(&line[argument.clone()])
                    });
                    if let Some(command) = SlashCommandRegistry::global(cx).command(name) {
                        if !command.requires_argument() || argument.is_some() {
                            // Start one byte before the name, presumably at the
                            // leading `/` (assumes `name` excludes it).
                            let start_ix = offset + command_line.name.start - 1;
                            let end_ix = offset
                                + command_line
                                    .argument
                                    .map_or(command_line.name.end, |argument| argument.end);
                            let source_range =
                                buffer.anchor_after(start_ix)..buffer.anchor_after(end_ix);
                            let pending_command = PendingSlashCommand {
                                name: name.to_string(),
                                argument: argument.map(ToString::to_string),
                                source_range,
                                status: PendingSlashCommandStatus::Idle,
                            };
                            updated.push(pending_command.clone());
                            new_commands.push(pending_command);
                        }
                    }
                }

                offset = lines.offset();
            }

            // Swap the commands previously found in this span for the new ones.
            let removed_commands = self.pending_slash_commands.splice(old_range, new_commands);
            removed.extend(removed_commands.map(|command| command.source_range));
        }

        if !updated.is_empty() || !removed.is_empty() {
            cx.emit(ContextEvent::PendingSlashCommandsUpdated { removed, updated });
        }
    }
1277
1278 fn prune_invalid_edit_steps(&mut self, cx: &mut ModelContext<Self>) {
1279 let buffer = self.buffer.read(cx);
1280 let prev_len = self.edit_steps.len();
1281 self.edit_steps.retain(|step| {
1282 step.source_range.start.is_valid(buffer) && step.source_range.end.is_valid(buffer)
1283 });
1284 if self.edit_steps.len() != prev_len {
1285 cx.emit(ContextEvent::EditStepsChanged);
1286 cx.notify();
1287 }
1288 }
1289
    /// Scans the buffer text in `range` line by line for `<step>` … `</step>`
    /// spans and starts tracking every span not already in `edit_steps`.
    ///
    /// For each new step, this inserts it at its sorted position and stores a
    /// pending task from `generate_edit_step_operations` on it.
    /// `EditStepsChanged` is emitted and `cx.notify()` is called even when no
    /// new step was found.
    ///
    /// Only the first `<step>` and first `</step>` on each line are considered.
    /// A `<step>` seen while a step is already open is ignored.
    fn parse_edit_steps_in_range(&mut self, range: Range<usize>, cx: &mut ModelContext<Self>) {
        let mut new_edit_steps = Vec::new();

        self.buffer.update(cx, |buffer, _cx| {
            let mut message_lines = buffer.as_rope().chunks_in_range(range).lines();
            let mut in_step = false;
            let mut step_start = 0;
            // Assumed to be an absolute buffer offset, since it is passed to
            // `anchor_after` below — TODO confirm.
            let mut line_start_offset = message_lines.offset();

            while let Some(line) = message_lines.next() {
                if let Some(step_start_index) = line.find("<step>") {
                    if !in_step {
                        in_step = true;
                        step_start = line_start_offset + step_start_index;
                    }
                }

                if let Some(step_end_index) = line.find("</step>") {
                    if in_step {
                        let start_anchor = buffer.anchor_after(step_start);
                        let end_anchor = buffer
                            .anchor_before(line_start_offset + step_end_index + "</step>".len());
                        let source_range = start_anchor..end_anchor;

                        // Check if a step with the same range already exists
                        let existing_step_index = self.edit_steps.binary_search_by(|probe| {
                            probe.source_range.cmp(&source_range, buffer)
                        });

                        if let Err(ix) = existing_step_index {
                            // Step doesn't exist, so add it
                            new_edit_steps.push((
                                ix,
                                EditStep {
                                    source_range,
                                    operations: None,
                                },
                            ));
                        }

                        in_step = false;
                    }
                }

                line_start_offset = message_lines.offset();
            }
        });

        // Insert new steps and generate their corresponding tasks.
        // The indices were computed against the pre-existing `edit_steps`, so
        // inserting from the back keeps the earlier indices valid.
        for (index, mut step) in new_edit_steps.into_iter().rev() {
            let task = self.generate_edit_step_operations(&step, cx);
            step.operations = Some(EditStepOperations::Pending(task));
            self.edit_steps.insert(index, step);
        }

        cx.emit(ContextEvent::EditStepsChanged);
        cx.notify();
    }
1348
    /// Asks the active completion provider, via the `edit` tool, to turn the
    /// text of `edit_step` into concrete `EditOperation`s.
    ///
    /// The request is the current conversation (`to_completion_request`) plus a
    /// user message. That message is the prompt store's operations prompt
    /// followed by the step's text. When the tool call returns, the step in
    /// `edit_steps` with the same source range gets
    /// `EditStepOperations::Ready` and `EditStepsChanged` is emitted.
    ///
    /// Any error is logged and the task resolves to `None`. This includes the
    /// step no longer being present.
    fn generate_edit_step_operations(
        &self,
        edit_step: &EditStep,
        cx: &mut ModelContext<Self>,
    ) -> Task<Option<()>> {
        // Note: the `///` comments on this struct's field are part of the tool's
        // JSON schema (derived by `JsonSchema`), so they are model-facing text.
        #[derive(Debug, Deserialize, JsonSchema)]
        struct EditTool {
            /// A sequence of operations to apply to the codebase.
            /// When multiple operations are required for a step, be sure to include multiple operations in this list.
            operations: Vec<EditOperation>,
        }

        impl LanguageModelTool for EditTool {
            fn name() -> String {
                "edit".into()
            }

            fn description() -> String {
                "suggest edits to one or more locations in the codebase".into()
            }
        }

        let mut request = self.to_completion_request(cx);
        let edit_step_range = edit_step.source_range.clone();
        let step_text = self
            .buffer
            .read(cx)
            .text_for_range(edit_step_range.clone())
            .collect::<String>();

        cx.spawn(|this, mut cx| {
            async move {
                let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;

                let mut prompt = prompt_store.operations_prompt();
                prompt.push_str(&step_text);

                request.messages.push(LanguageModelRequestMessage {
                    role: Role::User,
                    content: prompt,
                });

                let tool_use = cx
                    .update(|cx| {
                        LanguageModelCompletionProvider::read_global(cx)
                            .use_tool::<EditTool>(request, cx)
                    })?
                    .await?;

                // The step may have moved or been removed while the request was
                // in flight, so look it up again by its source range.
                this.update(&mut cx, |this, cx| {
                    let step_index = this
                        .edit_steps
                        .binary_search_by(|step| {
                            step.source_range
                                .cmp(&edit_step_range, this.buffer.read(cx))
                        })
                        .map_err(|_| anyhow!("edit step not found"))?;
                    if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
                        edit_step.operations = Some(EditStepOperations::Ready(tool_use.operations));
                        cx.emit(ContextEvent::EditStepsChanged);
                    }
                    anyhow::Ok(())
                })?
            }
            .log_err()
        })
    }
1416
1417 pub fn pending_command_for_position(
1418 &mut self,
1419 position: language::Anchor,
1420 cx: &mut ModelContext<Self>,
1421 ) -> Option<&mut PendingSlashCommand> {
1422 let buffer = self.buffer.read(cx);
1423 match self
1424 .pending_slash_commands
1425 .binary_search_by(|probe| probe.source_range.end.cmp(&position, buffer))
1426 {
1427 Ok(ix) => Some(&mut self.pending_slash_commands[ix]),
1428 Err(ix) => {
1429 let cmd = self.pending_slash_commands.get_mut(ix)?;
1430 if position.cmp(&cmd.source_range.start, buffer).is_ge()
1431 && position.cmp(&cmd.source_range.end, buffer).is_le()
1432 {
1433 Some(cmd)
1434 } else {
1435 None
1436 }
1437 }
1438 }
1439 }
1440
1441 pub fn pending_commands_for_range(
1442 &self,
1443 range: Range<language::Anchor>,
1444 cx: &AppContext,
1445 ) -> &[PendingSlashCommand] {
1446 let range = self.pending_command_indices_for_range(range, cx);
1447 &self.pending_slash_commands[range]
1448 }
1449
1450 fn pending_command_indices_for_range(
1451 &self,
1452 range: Range<language::Anchor>,
1453 cx: &AppContext,
1454 ) -> Range<usize> {
1455 let buffer = self.buffer.read(cx);
1456 let start_ix = match self
1457 .pending_slash_commands
1458 .binary_search_by(|probe| probe.source_range.end.cmp(&range.start, &buffer))
1459 {
1460 Ok(ix) | Err(ix) => ix,
1461 };
1462 let end_ix = match self
1463 .pending_slash_commands
1464 .binary_search_by(|probe| probe.source_range.start.cmp(&range.end, &buffer))
1465 {
1466 Ok(ix) => ix + 1,
1467 Err(ix) => ix,
1468 };
1469 start_ix..end_ix
1470 }
1471
    /// Awaits a slash command's `output` and splices it into the buffer in place
    /// of `command_range`.
    ///
    /// Before the output arrives, the pending command at `command_range.start`
    /// is marked `Running` and holds the insertion task.
    ///
    /// On success:
    /// - The output text replaces the command. A trailing newline is appended
    ///   first if `insert_trailing_newline` is set.
    /// - The output's sections are anchored and merged into
    ///   `slash_command_output_sections`.
    /// - A `SlashCommandFinished` operation is pushed and the matching event is
    ///   emitted.
    ///
    /// On failure, the pending command's status becomes `Error` with the error
    /// message.
    pub fn insert_command_output(
        &mut self,
        command_range: Range<language::Anchor>,
        output: Task<Result<SlashCommandOutput>>,
        insert_trailing_newline: bool,
        cx: &mut ModelContext<Self>,
    ) {
        self.reparse_slash_commands(cx);

        let insert_output_task = cx.spawn(|this, mut cx| {
            let command_range = command_range.clone();
            async move {
                let output = output.await;
                this.update(&mut cx, |this, cx| match output {
                    Ok(mut output) => {
                        if insert_trailing_newline {
                            output.text.push('\n');
                        }

                        // Snapshot the version before allocating the command id.
                        let version = this.version.clone();
                        let command_id = SlashCommandId(this.next_timestamp());
                        let (operation, event) = this.buffer.update(cx, |buffer, cx| {
                            let start = command_range.start.to_offset(buffer);
                            let old_end = command_range.end.to_offset(buffer);
                            let new_end = start + output.text.len();
                            buffer.edit([(start..old_end, output.text)], None, cx);

                            // Section ranges are relative to the output text;
                            // shift them by `start` and anchor them in the buffer.
                            let mut sections = output
                                .sections
                                .into_iter()
                                .map(|section| SlashCommandOutputSection {
                                    range: buffer.anchor_after(start + section.range.start)
                                        ..buffer.anchor_before(start + section.range.end),
                                    icon: section.icon,
                                    label: section.label,
                                })
                                .collect::<Vec<_>>();
                            sections.sort_by(|a, b| a.range.cmp(&b.range, buffer));

                            this.slash_command_output_sections
                                .extend(sections.iter().cloned());
                            this.slash_command_output_sections
                                .sort_by(|a, b| a.range.cmp(&b.range, buffer));

                            let output_range =
                                buffer.anchor_after(start)..buffer.anchor_before(new_end);
                            this.finished_slash_commands.insert(command_id);

                            (
                                ContextOperation::SlashCommandFinished {
                                    id: command_id,
                                    output_range: output_range.clone(),
                                    sections: sections.clone(),
                                    version,
                                },
                                ContextEvent::SlashCommandFinished {
                                    output_range,
                                    sections,
                                    run_commands_in_output: output.run_commands_in_text,
                                },
                            )
                        });

                        this.push_op(operation, cx);
                        cx.emit(event);
                    }
                    Err(error) => {
                        if let Some(pending_command) =
                            this.pending_command_for_position(command_range.start, cx)
                        {
                            pending_command.status =
                                PendingSlashCommandStatus::Error(error.to_string());
                            cx.emit(ContextEvent::PendingSlashCommandsUpdated {
                                removed: vec![pending_command.source_range.clone()],
                                updated: vec![pending_command.clone()],
                            });
                        }
                    }
                })
                .ok();
            }
        });

        // NOTE(review): if no pending command is found at `command_range.start`,
        // `insert_output_task` is dropped here without being stored or detached,
        // which presumably cancels it — confirm this is intended.
        if let Some(pending_command) = self.pending_command_for_position(command_range.start, cx) {
            pending_command.status = PendingSlashCommandStatus::Running {
                _task: insert_output_task.shared(),
            };
            cx.emit(ContextEvent::PendingSlashCommandsUpdated {
                removed: vec![pending_command.source_range.clone()],
                updated: vec![pending_command.clone()],
            });
        }
    }
1565
1566 pub fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
1567 self.count_remaining_tokens(cx);
1568 }
1569
1570 pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
1571 let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1572 message
1573 .start
1574 .is_valid(self.buffer.read(cx))
1575 .then_some(message.id)
1576 })?;
1577
1578 if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
1579 log::info!("completion provider has no credentials");
1580 return None;
1581 }
1582
1583 let request = self.to_completion_request(cx);
1584 let stream =
1585 LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
1586 let assistant_message = self
1587 .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
1588 .unwrap();
1589
1590 // Queue up the user's next reply.
1591 let user_message = self
1592 .insert_message_after(assistant_message.id, Role::User, MessageStatus::Done, cx)
1593 .unwrap();
1594
1595 let task = cx.spawn({
1596 |this, mut cx| async move {
1597 let assistant_message_id = assistant_message.id;
1598 let mut response_latency = None;
1599 let stream_completion = async {
1600 let request_start = Instant::now();
1601 let mut chunks = stream.await?;
1602
1603 while let Some(chunk) = chunks.next().await {
1604 if response_latency.is_none() {
1605 response_latency = Some(request_start.elapsed());
1606 }
1607 let chunk = chunk?;
1608
1609 this.update(&mut cx, |this, cx| {
1610 let message_ix = this
1611 .message_anchors
1612 .iter()
1613 .position(|message| message.id == assistant_message_id)?;
1614 let message_range = this.buffer.update(cx, |buffer, cx| {
1615 let message_start_offset =
1616 this.message_anchors[message_ix].start.to_offset(buffer);
1617 let message_old_end_offset = this.message_anchors[message_ix + 1..]
1618 .iter()
1619 .find(|message| message.start.is_valid(buffer))
1620 .map_or(buffer.len(), |message| {
1621 message.start.to_offset(buffer).saturating_sub(1)
1622 });
1623 let message_new_end_offset = message_old_end_offset + chunk.len();
1624 buffer.edit(
1625 [(message_old_end_offset..message_old_end_offset, chunk)],
1626 None,
1627 cx,
1628 );
1629 message_start_offset..message_new_end_offset
1630 });
1631 this.parse_edit_steps_in_range(message_range, cx);
1632 cx.emit(ContextEvent::StreamedCompletion);
1633
1634 Some(())
1635 })?;
1636 smol::future::yield_now().await;
1637 }
1638
1639 this.update(&mut cx, |this, cx| {
1640 this.pending_completions
1641 .retain(|completion| completion.id != this.completion_count);
1642 this.summarize(false, cx);
1643 })?;
1644
1645 anyhow::Ok(())
1646 };
1647
1648 let result = stream_completion.await;
1649
1650 this.update(&mut cx, |this, cx| {
1651 let error_message = result
1652 .err()
1653 .map(|error| error.to_string().trim().to_string());
1654
1655 this.update_metadata(assistant_message_id, cx, |metadata| {
1656 if let Some(error_message) = error_message.as_ref() {
1657 metadata.status =
1658 MessageStatus::Error(SharedString::from(error_message.clone()));
1659 } else {
1660 metadata.status = MessageStatus::Done;
1661 }
1662 });
1663
1664 if let Some(telemetry) = this.telemetry.as_ref() {
1665 let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
1666 .active_model()
1667 .map(|m| m.telemetry_id())
1668 .unwrap_or_default();
1669 telemetry.report_assistant_event(
1670 Some(this.id.0.clone()),
1671 AssistantKind::Panel,
1672 model_telemetry_id,
1673 response_latency,
1674 error_message,
1675 );
1676 }
1677 })
1678 .ok();
1679 }
1680 });
1681
1682 self.pending_completions.push(PendingCompletion {
1683 id: post_inc(&mut self.completion_count),
1684 _task: task,
1685 });
1686
1687 Some(user_message)
1688 }
1689
1690 pub fn to_completion_request(&self, cx: &AppContext) -> LanguageModelRequest {
1691 let messages = self
1692 .messages(cx)
1693 .filter(|message| matches!(message.status, MessageStatus::Done))
1694 .map(|message| message.to_request_message(self.buffer.read(cx)));
1695
1696 LanguageModelRequest {
1697 messages: messages.collect(),
1698 stop: vec![],
1699 temperature: 1.0,
1700 }
1701 }
1702
1703 pub fn cancel_last_assist(&mut self) -> bool {
1704 self.pending_completions.pop().is_some()
1705 }
1706
1707 pub fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1708 for id in ids {
1709 if let Some(metadata) = self.messages_metadata.get(&id) {
1710 let role = metadata.role.cycle();
1711 self.update_metadata(id, cx, |metadata| metadata.role = role);
1712 }
1713 }
1714 }
1715
1716 pub fn update_metadata(
1717 &mut self,
1718 id: MessageId,
1719 cx: &mut ModelContext<Self>,
1720 f: impl FnOnce(&mut MessageMetadata),
1721 ) {
1722 let version = self.version.clone();
1723 let timestamp = self.next_timestamp();
1724 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1725 f(metadata);
1726 metadata.timestamp = timestamp;
1727 let operation = ContextOperation::UpdateMessage {
1728 message_id: id,
1729 metadata: metadata.clone(),
1730 version,
1731 };
1732 self.push_op(operation, cx);
1733 cx.emit(ContextEvent::MessagesEdited);
1734 cx.notify();
1735 }
1736 }
1737
1738 fn insert_message_after(
1739 &mut self,
1740 message_id: MessageId,
1741 role: Role,
1742 status: MessageStatus,
1743 cx: &mut ModelContext<Self>,
1744 ) -> Option<MessageAnchor> {
1745 if let Some(prev_message_ix) = self
1746 .message_anchors
1747 .iter()
1748 .position(|message| message.id == message_id)
1749 {
1750 // Find the next valid message after the one we were given.
1751 let mut next_message_ix = prev_message_ix + 1;
1752 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1753 if next_message.start.is_valid(self.buffer.read(cx)) {
1754 break;
1755 }
1756 next_message_ix += 1;
1757 }
1758
1759 let start = self.buffer.update(cx, |buffer, cx| {
1760 let offset = self
1761 .message_anchors
1762 .get(next_message_ix)
1763 .map_or(buffer.len(), |message| {
1764 buffer.clip_offset(message.start.to_offset(buffer) - 1, Bias::Left)
1765 });
1766 buffer.edit([(offset..offset, "\n")], None, cx);
1767 buffer.anchor_before(offset + 1)
1768 });
1769
1770 let version = self.version.clone();
1771 let anchor = MessageAnchor {
1772 id: MessageId(self.next_timestamp()),
1773 start,
1774 };
1775 let metadata = MessageMetadata {
1776 role,
1777 status,
1778 timestamp: anchor.id.0,
1779 };
1780 self.insert_message(anchor.clone(), metadata.clone(), cx);
1781 self.push_op(
1782 ContextOperation::InsertMessage {
1783 anchor: anchor.clone(),
1784 metadata,
1785 version,
1786 },
1787 cx,
1788 );
1789 Some(anchor)
1790 } else {
1791 None
1792 }
1793 }
1794
1795 pub fn split_message(
1796 &mut self,
1797 range: Range<usize>,
1798 cx: &mut ModelContext<Self>,
1799 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1800 let start_message = self.message_for_offset(range.start, cx);
1801 let end_message = self.message_for_offset(range.end, cx);
1802 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1803 // Prevent splitting when range spans multiple messages.
1804 if start_message.id != end_message.id {
1805 return (None, None);
1806 }
1807
1808 let message = start_message;
1809 let role = message.role;
1810 let mut edited_buffer = false;
1811
1812 let mut suffix_start = None;
1813 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1814 {
1815 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1816 suffix_start = Some(range.end + 1);
1817 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1818 suffix_start = Some(range.end);
1819 }
1820 }
1821
1822 let version = self.version.clone();
1823 let suffix = if let Some(suffix_start) = suffix_start {
1824 MessageAnchor {
1825 id: MessageId(self.next_timestamp()),
1826 start: self.buffer.read(cx).anchor_before(suffix_start),
1827 }
1828 } else {
1829 self.buffer.update(cx, |buffer, cx| {
1830 buffer.edit([(range.end..range.end, "\n")], None, cx);
1831 });
1832 edited_buffer = true;
1833 MessageAnchor {
1834 id: MessageId(self.next_timestamp()),
1835 start: self.buffer.read(cx).anchor_before(range.end + 1),
1836 }
1837 };
1838
1839 let suffix_metadata = MessageMetadata {
1840 role,
1841 status: MessageStatus::Done,
1842 timestamp: suffix.id.0,
1843 };
1844 self.insert_message(suffix.clone(), suffix_metadata.clone(), cx);
1845 self.push_op(
1846 ContextOperation::InsertMessage {
1847 anchor: suffix.clone(),
1848 metadata: suffix_metadata,
1849 version,
1850 },
1851 cx,
1852 );
1853
1854 let new_messages =
1855 if range.start == range.end || range.start == message.offset_range.start {
1856 (None, Some(suffix))
1857 } else {
1858 let mut prefix_end = None;
1859 if range.start > message.offset_range.start
1860 && range.end < message.offset_range.end - 1
1861 {
1862 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1863 prefix_end = Some(range.start + 1);
1864 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1865 == Some('\n')
1866 {
1867 prefix_end = Some(range.start);
1868 }
1869 }
1870
1871 let version = self.version.clone();
1872 let selection = if let Some(prefix_end) = prefix_end {
1873 MessageAnchor {
1874 id: MessageId(self.next_timestamp()),
1875 start: self.buffer.read(cx).anchor_before(prefix_end),
1876 }
1877 } else {
1878 self.buffer.update(cx, |buffer, cx| {
1879 buffer.edit([(range.start..range.start, "\n")], None, cx)
1880 });
1881 edited_buffer = true;
1882 MessageAnchor {
1883 id: MessageId(self.next_timestamp()),
1884 start: self.buffer.read(cx).anchor_before(range.end + 1),
1885 }
1886 };
1887
1888 let selection_metadata = MessageMetadata {
1889 role,
1890 status: MessageStatus::Done,
1891 timestamp: selection.id.0,
1892 };
1893 self.insert_message(selection.clone(), selection_metadata.clone(), cx);
1894 self.push_op(
1895 ContextOperation::InsertMessage {
1896 anchor: selection.clone(),
1897 metadata: selection_metadata,
1898 version,
1899 },
1900 cx,
1901 );
1902
1903 (Some(selection), Some(suffix))
1904 };
1905
1906 if !edited_buffer {
1907 cx.emit(ContextEvent::MessagesEdited);
1908 }
1909 new_messages
1910 } else {
1911 (None, None)
1912 }
1913 }
1914
1915 fn insert_message(
1916 &mut self,
1917 new_anchor: MessageAnchor,
1918 new_metadata: MessageMetadata,
1919 cx: &mut ModelContext<Self>,
1920 ) {
1921 cx.emit(ContextEvent::MessagesEdited);
1922
1923 self.messages_metadata.insert(new_anchor.id, new_metadata);
1924
1925 let buffer = self.buffer.read(cx);
1926 let insertion_ix = self
1927 .message_anchors
1928 .iter()
1929 .position(|anchor| {
1930 let comparison = new_anchor.start.cmp(&anchor.start, buffer);
1931 comparison.is_lt() || (comparison.is_eq() && new_anchor.id > anchor.id)
1932 })
1933 .unwrap_or(self.message_anchors.len());
1934 self.message_anchors.insert(insertion_ix, new_anchor);
1935 }
1936
    /// Streams a short title for the context from the completion provider into
    /// `summary`.
    ///
    /// Runs when `replace_old` is true, or when there are at least two message
    /// anchors and no summary yet. Does nothing if the provider is not
    /// authenticated.
    ///
    /// With `replace_old`, the existing text is cleared on the first streamed
    /// chunk. Each chunk contributes its first line, and streaming stops once a
    /// chunk yields more than one line. Every update is pushed as an
    /// `UpdateSummary` operation. A final operation marks the summary `done`.
    ///
    /// NOTE(review): a chunk that ends exactly in "\n" yields one line, so
    /// streaming continues and the next line's text is appended to the title.
    pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
        if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
            if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
                return;
            }

            let messages = self
                .messages(cx)
                .map(|message| message.to_request_message(self.buffer.read(cx)))
                .chain(Some(LanguageModelRequestMessage {
                    role: Role::User,
                    content: "Summarize the context into a short title without punctuation.".into(),
                }));
            let request = LanguageModelRequest {
                messages: messages.collect(),
                stop: vec![],
                temperature: 1.0,
            };

            let stream =
                LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
            self.pending_summary = cx.spawn(|this, mut cx| {
                async move {
                    let mut messages = stream.await?;

                    // When not replacing, treat the old text as already handled
                    // so new chunks are appended to it.
                    let mut replaced = !replace_old;
                    while let Some(message) = messages.next().await {
                        let text = message?;
                        let mut lines = text.lines();
                        this.update(&mut cx, |this, cx| {
                            let version = this.version.clone();
                            let timestamp = this.next_timestamp();
                            let summary = this.summary.get_or_insert(ContextSummary::default());
                            if !replaced && replace_old {
                                summary.text.clear();
                                replaced = true;
                            }
                            summary.text.extend(lines.next());
                            summary.timestamp = timestamp;
                            let operation = ContextOperation::UpdateSummary {
                                summary: summary.clone(),
                                version,
                            };
                            this.push_op(operation, cx);
                            cx.emit(ContextEvent::SummaryChanged);
                        })?;

                        // Stop if the LLM generated multiple lines.
                        if lines.next().is_some() {
                            break;
                        }
                    }

                    this.update(&mut cx, |this, cx| {
                        let version = this.version.clone();
                        let timestamp = this.next_timestamp();
                        if let Some(summary) = this.summary.as_mut() {
                            summary.done = true;
                            summary.timestamp = timestamp;
                            let operation = ContextOperation::UpdateSummary {
                                summary: summary.clone(),
                                version,
                            };
                            this.push_op(operation, cx);
                            cx.emit(ContextEvent::SummaryChanged);
                        }
                    })?;

                    anyhow::Ok(())
                }
                .log_err()
            });
        }
    }
2011
2012 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
2013 self.messages_for_offsets([offset], cx).pop()
2014 }
2015
    /// Maps `offsets` to the messages containing them.
    ///
    /// The scan only moves forward through `messages`, so this assumes
    /// `offsets` are sorted ascending (TODO confirm callers guarantee this).
    /// Runs of consecutive offsets in the same message yield one entry. An
    /// offset not contained in any earlier message resolves to the last
    /// message, and once the last message is reached all remaining offsets
    /// collapse into that single entry.
    pub fn messages_for_offsets(
        &self,
        offsets: impl IntoIterator<Item = usize>,
        cx: &AppContext,
    ) -> Vec<Message> {
        let mut result = Vec::new();

        let mut messages = self.messages(cx).peekable();
        let mut offsets = offsets.into_iter().peekable();
        let mut current_message = messages.next();
        while let Some(offset) = offsets.next() {
            // Locate the message that contains the offset.
            // Never advance past the last message.
            while current_message.as_ref().map_or(false, |message| {
                !message.offset_range.contains(&offset) && messages.peek().is_some()
            }) {
                current_message = messages.next();
            }
            let Some(message) = current_message.as_ref() else {
                break;
            };

            // Skip offsets that are in the same message.
            while offsets.peek().map_or(false, |offset| {
                message.offset_range.contains(offset) || messages.peek().is_none()
            }) {
                offsets.next();
            }

            result.push(message.clone());
        }
        result
    }
2048
    /// Iterates over the context's messages in `message_anchors` order.
    ///
    /// Each message spans from its anchor to the start of the next message
    /// whose anchor is still valid, or to the end of the buffer. Anchors that
    /// are no longer valid are skipped and their text is folded into the
    /// preceding message. The iterator ends (yields `None`) at an anchor with no
    /// entry in `messages_metadata`.
    ///
    /// NOTE(review): `index_range` is `start_ix..end_ix`, where `end_ix`
    /// advances once per skipped invalid anchor. With nothing skipped it is the
    /// empty range `ix..ix`. Confirm this is the intended meaning.
    pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
        let buffer = self.buffer.read(cx);
        let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
        iter::from_fn(move || {
            if let Some((start_ix, message_anchor)) = message_anchors.next() {
                let metadata = self.messages_metadata.get(&message_anchor.id)?;
                let message_start = message_anchor.start.to_offset(buffer);
                let mut message_end = None;
                let mut end_ix = start_ix;
                // Find the next valid anchor, consuming invalid ones on the way.
                while let Some((_, next_message)) = message_anchors.peek() {
                    if next_message.start.is_valid(buffer) {
                        message_end = Some(next_message.start);
                        break;
                    } else {
                        end_ix += 1;
                        message_anchors.next();
                    }
                }
                let message_end = message_end
                    .unwrap_or(language::Anchor::MAX)
                    .to_offset(buffer);

                return Some(Message {
                    index_range: start_ix..end_ix,
                    offset_range: message_start..message_end,
                    id: message_anchor.id,
                    anchor: message_anchor.start,
                    role: metadata.role,
                    status: metadata.status.clone(),
                });
            }
            None
        })
    }
2083
2084 pub fn save(
2085 &mut self,
2086 debounce: Option<Duration>,
2087 fs: Arc<dyn Fs>,
2088 cx: &mut ModelContext<Context>,
2089 ) {
2090 if self.replica_id() != ReplicaId::default() {
2091 // Prevent saving a remote context for now.
2092 return;
2093 }
2094
2095 self.pending_save = cx.spawn(|this, mut cx| async move {
2096 if let Some(debounce) = debounce {
2097 cx.background_executor().timer(debounce).await;
2098 }
2099
2100 let (old_path, summary) = this.read_with(&cx, |this, _| {
2101 let path = this.path.clone();
2102 let summary = if let Some(summary) = this.summary.as_ref() {
2103 if summary.done {
2104 Some(summary.text.clone())
2105 } else {
2106 None
2107 }
2108 } else {
2109 None
2110 };
2111 (path, summary)
2112 })?;
2113
2114 if let Some(summary) = summary {
2115 let context = this.read_with(&cx, |this, cx| this.serialize(cx))?;
2116 let mut discriminant = 1;
2117 let mut new_path;
2118 loop {
2119 new_path = contexts_dir().join(&format!(
2120 "{} - {}.zed.json",
2121 summary.trim(),
2122 discriminant
2123 ));
2124 if fs.is_file(&new_path).await {
2125 discriminant += 1;
2126 } else {
2127 break;
2128 }
2129 }
2130
2131 fs.create_dir(contexts_dir().as_ref()).await?;
2132 fs.atomic_write(new_path.clone(), serde_json::to_string(&context).unwrap())
2133 .await?;
2134 if let Some(old_path) = old_path {
2135 if new_path != old_path {
2136 fs.remove_file(
2137 &old_path,
2138 RemoveOptions {
2139 recursive: false,
2140 ignore_if_not_exists: true,
2141 },
2142 )
2143 .await?;
2144 }
2145 }
2146
2147 this.update(&mut cx, |this, _| this.path = Some(new_path))?;
2148 }
2149
2150 Ok(())
2151 });
2152 }
2153
2154 pub(crate) fn custom_summary(&mut self, custom_summary: String, cx: &mut ModelContext<Self>) {
2155 let timestamp = self.next_timestamp();
2156 let summary = self.summary.get_or_insert(ContextSummary::default());
2157 summary.timestamp = timestamp;
2158 summary.done = true;
2159 summary.text = custom_summary;
2160 cx.emit(ContextEvent::SummaryChanged);
2161 }
2162}
2163
/// A pair of version vectors describing how far a context has progressed.
///
/// `context` is exchanged over the wire as `context_version` and `buffer` as
/// `buffer_version` (see the `from_proto`/`to_proto` conversions).
#[derive(Debug, Default)]
pub struct ContextVersion {
    /// Version vector for the context itself (serialized as `context_version`).
    context: clock::Global,
    /// Version vector for the context's text buffer (serialized as `buffer_version`).
    buffer: clock::Global,
}
2169
2170impl ContextVersion {
2171 pub fn from_proto(proto: &proto::ContextVersion) -> Self {
2172 Self {
2173 context: language::proto::deserialize_version(&proto.context_version),
2174 buffer: language::proto::deserialize_version(&proto.buffer_version),
2175 }
2176 }
2177
2178 pub fn to_proto(&self, context_id: ContextId) -> proto::ContextVersion {
2179 proto::ContextVersion {
2180 context_id: context_id.to_proto(),
2181 context_version: language::proto::serialize_version(&self.context),
2182 buffer_version: language::proto::serialize_version(&self.buffer),
2183 }
2184 }
2185}
2186
/// A slash command written in the context's buffer, together with its execution status.
#[derive(Clone)]
pub struct PendingSlashCommand {
    /// Name of the command.
    pub name: String,
    /// Argument text following the command name, if any.
    pub argument: Option<String>,
    /// Whether the command is idle, running, or has failed.
    pub status: PendingSlashCommandStatus,
    /// Anchor range covering the command's text in the buffer.
    pub source_range: Range<language::Anchor>,
}
2194
/// Execution state of a [`PendingSlashCommand`].
#[derive(Clone)]
pub enum PendingSlashCommandStatus {
    /// The command is not currently running.
    Idle,
    /// The command is running. The shared task is only held here, never read (presumably to
    /// keep it from being dropped — TODO confirm).
    Running { _task: Shared<Task<()>> },
    /// The command failed with the given message.
    Error(String),
}
2201
/// A message as stored in the current saved-context format.
#[derive(Serialize, Deserialize)]
pub struct SavedMessage {
    /// The message's id.
    pub id: MessageId,
    /// Offset into the saved context's `text` where the message starts.
    pub start: usize,
    /// Role, status, and timestamp of the message.
    pub metadata: MessageMetadata,
}
2208
/// A context as persisted to disk in the current format (`SavedContext::VERSION`).
#[derive(Serialize, Deserialize)]
pub struct SavedContext {
    /// The context's id; optional in the on-disk schema.
    pub id: Option<ContextId>,
    // NOTE(review): `zed` is carried through upgrades but not interpreted here; meaning unconfirmed.
    pub zed: String,
    /// Format version string, inspected by `SavedContext::from_json` to choose a schema.
    pub version: String,
    /// Text of the context (presumably the full buffer contents — TODO confirm).
    pub text: String,
    /// Messages, each starting at an offset into `text`.
    pub messages: Vec<SavedMessage>,
    /// The context's summary text.
    pub summary: String,
    /// Slash-command output sections, with ranges given as offsets into `text`.
    pub slash_command_output_sections:
        Vec<assistant_slash_command::SlashCommandOutputSection<usize>>,
}
2220
impl SavedContext {
    /// The format version written by the current serializer. Files with this version are
    /// deserialized directly, without an upgrade step.
    pub const VERSION: &'static str = "0.4.0";

    /// Parses a saved context from JSON, upgrading legacy formats to the current one.
    ///
    /// The top-level `"version"` string selects the schema:
    /// - [`SavedContext::VERSION`] is deserialized directly.
    /// - Versions 0.3.0, 0.2.0 and 0.1.0 are deserialized into their legacy structs and then
    ///   converted with their `upgrade` methods.
    ///
    /// # Errors
    ///
    /// Fails if:
    /// - `json` is not valid JSON;
    /// - `"version"` is missing, is not a string, or is unrecognized;
    /// - the document does not match the schema for its version.
    pub fn from_json(json: &str) -> Result<Self> {
        // Parse into an untyped value first so the version can be inspected before choosing
        // which struct to deserialize into.
        let saved_context_json = serde_json::from_str::<serde_json::Value>(json)?;
        match saved_context_json
            .get("version")
            .ok_or_else(|| anyhow!("version not found"))?
        {
            serde_json::Value::String(version) => match version.as_str() {
                SavedContext::VERSION => {
                    Ok(serde_json::from_value::<SavedContext>(saved_context_json)?)
                }
                SavedContextV0_3_0::VERSION => {
                    let saved_context =
                        serde_json::from_value::<SavedContextV0_3_0>(saved_context_json)?;
                    Ok(saved_context.upgrade())
                }
                SavedContextV0_2_0::VERSION => {
                    let saved_context =
                        serde_json::from_value::<SavedContextV0_2_0>(saved_context_json)?;
                    Ok(saved_context.upgrade())
                }
                SavedContextV0_1_0::VERSION => {
                    let saved_context =
                        serde_json::from_value::<SavedContextV0_1_0>(saved_context_json)?;
                    Ok(saved_context.upgrade())
                }
                _ => Err(anyhow!("unrecognized saved context version: {}", version)),
            },
            // NOTE(review): this arm is reached when "version" exists but is not a string, so
            // the "not found" wording is somewhat misleading.
            _ => Err(anyhow!("version not found on saved context")),
        }
    }

    /// Converts this saved context into operations that rebuild it on top of `buffer`.
    ///
    /// Saved offsets (message starts, slash-command section ranges) are resolved to anchors in
    /// `buffer`. `buffer` is therefore assumed to already hold `self.text` (TODO: confirm
    /// against the caller). Operations are produced in this order:
    ///
    /// 1. An `InsertMessage` for every message except the one whose id is
    ///    `MessageId(clock::Lamport::default())`.
    /// 2. An `UpdateMessage` with that default-id message's role and status, if it was saved.
    /// 3. One `SlashCommandFinished` spanning `Anchor::MIN..Anchor::MAX` with all saved
    ///    output sections.
    /// 4. An `UpdateSummary` with the saved summary text, marked as done.
    ///
    /// Each operation carries the version vector as it stood before the operation. The
    /// operation's timestamp is then observed into that version.
    fn into_ops(
        self,
        buffer: &Model<Buffer>,
        cx: &mut ModelContext<Context>,
    ) -> Vec<ContextOperation> {
        let mut operations = Vec::new();
        let mut version = clock::Global::new();
        let mut next_timestamp = clock::Lamport::new(ReplicaId::default());

        // The default-id message is set aside and emitted as an update further down.
        // NOTE(review): presumably a freshly created context already contains a message with
        // this id, so it is updated rather than inserted — verify against `Context::new`.
        let mut first_message_metadata = None;
        for message in self.messages {
            if message.id == MessageId(clock::Lamport::default()) {
                first_message_metadata = Some(message.metadata);
            } else {
                operations.push(ContextOperation::InsertMessage {
                    anchor: MessageAnchor {
                        id: message.id,
                        start: buffer.read(cx).anchor_before(message.start),
                    },
                    metadata: MessageMetadata {
                        role: message.metadata.role,
                        status: message.metadata.status,
                        timestamp: message.metadata.timestamp,
                    },
                    version: version.clone(),
                });
                version.observe(message.id.0);
                // Feed the saved id into the local Lamport clock before any new timestamps are
                // ticked below (exact effect depends on `clock::Lamport::observe`).
                next_timestamp.observe(message.id.0);
            }
        }

        if let Some(metadata) = first_message_metadata {
            // The saved timestamp of this message is not reused; a fresh one is ticked.
            let timestamp = next_timestamp.tick();
            operations.push(ContextOperation::UpdateMessage {
                message_id: MessageId(clock::Lamport::default()),
                metadata: MessageMetadata {
                    role: metadata.role,
                    status: metadata.status,
                    timestamp,
                },
                version: version.clone(),
            });
            version.observe(timestamp);
        }

        // All saved output sections are restored through a single synthetic
        // "slash command finished" operation covering the whole buffer.
        let timestamp = next_timestamp.tick();
        operations.push(ContextOperation::SlashCommandFinished {
            id: SlashCommandId(timestamp),
            output_range: language::Anchor::MIN..language::Anchor::MAX,
            sections: self
                .slash_command_output_sections
                .into_iter()
                .map(|section| {
                    let buffer = buffer.read(cx);
                    // Start anchored after / end anchored before, presumably so that text
                    // inserted exactly at either boundary stays outside the section — TODO
                    // confirm bias semantics.
                    SlashCommandOutputSection {
                        range: buffer.anchor_after(section.range.start)
                            ..buffer.anchor_before(section.range.end),
                        icon: section.icon,
                        label: section.label,
                    }
                })
                .collect(),
            version: version.clone(),
        });
        version.observe(timestamp);

        let timestamp = next_timestamp.tick();
        operations.push(ContextOperation::UpdateSummary {
            summary: ContextSummary {
                text: self.summary,
                done: true,
                timestamp,
            },
            version: version.clone(),
        });
        version.observe(timestamp);

        operations
    }
}
2335
/// Message id used by saved-context formats before 0.4.0: a plain integer that is converted
/// into a Lamport timestamp value when upgrading.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
struct SavedMessageIdPreV0_4_0(usize);
2338
/// A message in pre-0.4.0 saved contexts. Its metadata is stored separately, keyed by `id`.
#[derive(Serialize, Deserialize)]
struct SavedMessagePreV0_4_0 {
    /// Key into the context's `message_metadata` map.
    id: SavedMessageIdPreV0_4_0,
    /// Offset into the saved text where the message starts.
    start: usize,
}
2344
/// Message metadata in pre-0.4.0 saved contexts. Unlike the current format, it has no
/// timestamp.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
struct SavedMessageMetadataPreV0_4_0 {
    role: Role,
    status: MessageStatus,
}
2350
/// On-disk layout of a saved context at format version 0.3.0.
///
/// Messages and their metadata are stored separately and joined by id when upgrading.
#[derive(Serialize, Deserialize)]
struct SavedContextV0_3_0 {
    id: Option<ContextId>,
    zed: String,
    version: String,
    text: String,
    messages: Vec<SavedMessagePreV0_4_0>,
    /// Metadata for each entry in `messages`, keyed by message id.
    message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
    summary: String,
    /// Slash-command output sections, with ranges given as offsets into `text`.
    slash_command_output_sections: Vec<assistant_slash_command::SlashCommandOutputSection<usize>>,
}
2362
2363impl SavedContextV0_3_0 {
2364 const VERSION: &'static str = "0.3.0";
2365
2366 fn upgrade(self) -> SavedContext {
2367 SavedContext {
2368 id: self.id,
2369 zed: self.zed,
2370 version: SavedContext::VERSION.into(),
2371 text: self.text,
2372 messages: self
2373 .messages
2374 .into_iter()
2375 .filter_map(|message| {
2376 let metadata = self.message_metadata.get(&message.id)?;
2377 let timestamp = clock::Lamport {
2378 replica_id: ReplicaId::default(),
2379 value: message.id.0 as u32,
2380 };
2381 Some(SavedMessage {
2382 id: MessageId(timestamp),
2383 start: message.start,
2384 metadata: MessageMetadata {
2385 role: metadata.role,
2386 status: metadata.status.clone(),
2387 timestamp,
2388 },
2389 })
2390 })
2391 .collect(),
2392 summary: self.summary,
2393 slash_command_output_sections: self.slash_command_output_sections,
2394 }
2395 }
2396}
2397
/// On-disk layout of a saved context at format version 0.2.0.
///
/// Identical to v0.3.0 except that it has no slash-command output sections.
#[derive(Serialize, Deserialize)]
struct SavedContextV0_2_0 {
    id: Option<ContextId>,
    zed: String,
    version: String,
    text: String,
    messages: Vec<SavedMessagePreV0_4_0>,
    /// Metadata for each entry in `messages`, keyed by message id.
    message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
    summary: String,
}
2408
2409impl SavedContextV0_2_0 {
2410 const VERSION: &'static str = "0.2.0";
2411
2412 fn upgrade(self) -> SavedContext {
2413 SavedContextV0_3_0 {
2414 id: self.id,
2415 zed: self.zed,
2416 version: SavedContextV0_3_0::VERSION.to_string(),
2417 text: self.text,
2418 messages: self.messages,
2419 message_metadata: self.message_metadata,
2420 summary: self.summary,
2421 slash_command_output_sections: Vec::new(),
2422 }
2423 .upgrade()
2424 }
2425}
2426
/// On-disk layout of a saved context at format version 0.1.0.
///
/// In addition to the v0.2.0 fields it records an API URL and an OpenAI model. `upgrade`
/// does not carry either of these forward.
#[derive(Serialize, Deserialize)]
struct SavedContextV0_1_0 {
    id: Option<ContextId>,
    zed: String,
    version: String,
    text: String,
    messages: Vec<SavedMessagePreV0_4_0>,
    /// Metadata for each entry in `messages`, keyed by message id.
    message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
    summary: String,
    api_url: Option<String>,
    model: OpenAiModel,
}
2439
2440impl SavedContextV0_1_0 {
2441 const VERSION: &'static str = "0.1.0";
2442
2443 fn upgrade(self) -> SavedContext {
2444 SavedContextV0_2_0 {
2445 id: self.id,
2446 zed: self.zed,
2447 version: SavedContextV0_2_0::VERSION.to_string(),
2448 text: self.text,
2449 messages: self.messages,
2450 message_metadata: self.message_metadata,
2451 summary: self.summary,
2452 }
2453 .upgrade()
2454 }
2455}
2456
/// Summary information about a saved context file.
///
/// NOTE(review): presumably used to list saved contexts without loading them — verify with
/// callers.
#[derive(Clone)]
pub struct SavedContextMetadata {
    /// Display title of the saved context.
    pub title: String,
    /// Path of the saved context file.
    pub path: PathBuf,
    /// Modification time of the file, in local time.
    pub mtime: chrono::DateTime<chrono::Local>,
}
2463
2464#[cfg(test)]
2465mod tests {
2466 use super::*;
2467 use crate::{
2468 assistant_panel, prompt_library,
2469 slash_command::{active_command, file_command},
2470 MessageId,
2471 };
2472 use assistant_slash_command::{ArgumentCompletion, SlashCommand};
2473 use fs::FakeFs;
2474 use gpui::{AppContext, TestAppContext, WeakView};
2475 use indoc::indoc;
2476 use language::LspAdapterDelegate;
2477 use parking_lot::Mutex;
2478 use project::Project;
2479 use rand::prelude::*;
2480 use serde_json::json;
2481 use settings::SettingsStore;
2482 use std::{cell::RefCell, env, rc::Rc, sync::atomic::AtomicBool};
2483 use text::{network::Network, ToPoint};
2484 use ui::WindowContext;
2485 use unindent::Unindent;
2486 use util::{test::marked_text_ranges, RandomCharIter};
2487 use workspace::Workspace;
2488
2489 #[gpui::test]
2490 fn test_inserting_and_removing_messages(cx: &mut AppContext) {
2491 let settings_store = SettingsStore::test(cx);
2492 language_model::LanguageModelRegistry::test(cx);
2493 completion::LanguageModelCompletionProvider::test(cx);
2494 cx.set_global(settings_store);
2495 assistant_panel::init(cx);
2496 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
2497
2498 let context = cx.new_model(|cx| Context::local(registry, None, cx));
2499 let buffer = context.read(cx).buffer.clone();
2500
2501 let message_1 = context.read(cx).message_anchors[0].clone();
2502 assert_eq!(
2503 messages(&context, cx),
2504 vec![(message_1.id, Role::User, 0..0)]
2505 );
2506
2507 let message_2 = context.update(cx, |context, cx| {
2508 context
2509 .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
2510 .unwrap()
2511 });
2512 assert_eq!(
2513 messages(&context, cx),
2514 vec![
2515 (message_1.id, Role::User, 0..1),
2516 (message_2.id, Role::Assistant, 1..1)
2517 ]
2518 );
2519
2520 buffer.update(cx, |buffer, cx| {
2521 buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
2522 });
2523 assert_eq!(
2524 messages(&context, cx),
2525 vec![
2526 (message_1.id, Role::User, 0..2),
2527 (message_2.id, Role::Assistant, 2..3)
2528 ]
2529 );
2530
2531 let message_3 = context.update(cx, |context, cx| {
2532 context
2533 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
2534 .unwrap()
2535 });
2536 assert_eq!(
2537 messages(&context, cx),
2538 vec![
2539 (message_1.id, Role::User, 0..2),
2540 (message_2.id, Role::Assistant, 2..4),
2541 (message_3.id, Role::User, 4..4)
2542 ]
2543 );
2544
2545 let message_4 = context.update(cx, |context, cx| {
2546 context
2547 .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
2548 .unwrap()
2549 });
2550 assert_eq!(
2551 messages(&context, cx),
2552 vec![
2553 (message_1.id, Role::User, 0..2),
2554 (message_2.id, Role::Assistant, 2..4),
2555 (message_4.id, Role::User, 4..5),
2556 (message_3.id, Role::User, 5..5),
2557 ]
2558 );
2559
2560 buffer.update(cx, |buffer, cx| {
2561 buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
2562 });
2563 assert_eq!(
2564 messages(&context, cx),
2565 vec![
2566 (message_1.id, Role::User, 0..2),
2567 (message_2.id, Role::Assistant, 2..4),
2568 (message_4.id, Role::User, 4..6),
2569 (message_3.id, Role::User, 6..7),
2570 ]
2571 );
2572
2573 // Deleting across message boundaries merges the messages.
2574 buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
2575 assert_eq!(
2576 messages(&context, cx),
2577 vec![
2578 (message_1.id, Role::User, 0..3),
2579 (message_3.id, Role::User, 3..4),
2580 ]
2581 );
2582
2583 // Undoing the deletion should also undo the merge.
2584 buffer.update(cx, |buffer, cx| buffer.undo(cx));
2585 assert_eq!(
2586 messages(&context, cx),
2587 vec![
2588 (message_1.id, Role::User, 0..2),
2589 (message_2.id, Role::Assistant, 2..4),
2590 (message_4.id, Role::User, 4..6),
2591 (message_3.id, Role::User, 6..7),
2592 ]
2593 );
2594
2595 // Redoing the deletion should also redo the merge.
2596 buffer.update(cx, |buffer, cx| buffer.redo(cx));
2597 assert_eq!(
2598 messages(&context, cx),
2599 vec![
2600 (message_1.id, Role::User, 0..3),
2601 (message_3.id, Role::User, 3..4),
2602 ]
2603 );
2604
2605 // Ensure we can still insert after a merged message.
2606 let message_5 = context.update(cx, |context, cx| {
2607 context
2608 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
2609 .unwrap()
2610 });
2611 assert_eq!(
2612 messages(&context, cx),
2613 vec![
2614 (message_1.id, Role::User, 0..3),
2615 (message_5.id, Role::System, 3..4),
2616 (message_3.id, Role::User, 4..5)
2617 ]
2618 );
2619 }
2620
2621 #[gpui::test]
2622 fn test_message_splitting(cx: &mut AppContext) {
2623 let settings_store = SettingsStore::test(cx);
2624 cx.set_global(settings_store);
2625 language_model::LanguageModelRegistry::test(cx);
2626 completion::LanguageModelCompletionProvider::test(cx);
2627 assistant_panel::init(cx);
2628 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
2629
2630 let context = cx.new_model(|cx| Context::local(registry, None, cx));
2631 let buffer = context.read(cx).buffer.clone();
2632
2633 let message_1 = context.read(cx).message_anchors[0].clone();
2634 assert_eq!(
2635 messages(&context, cx),
2636 vec![(message_1.id, Role::User, 0..0)]
2637 );
2638
2639 buffer.update(cx, |buffer, cx| {
2640 buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
2641 });
2642
2643 let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
2644 let message_2 = message_2.unwrap();
2645
2646 // We recycle newlines in the middle of a split message
2647 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
2648 assert_eq!(
2649 messages(&context, cx),
2650 vec![
2651 (message_1.id, Role::User, 0..4),
2652 (message_2.id, Role::User, 4..16),
2653 ]
2654 );
2655
2656 let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
2657 let message_3 = message_3.unwrap();
2658
2659 // We don't recycle newlines at the end of a split message
2660 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
2661 assert_eq!(
2662 messages(&context, cx),
2663 vec![
2664 (message_1.id, Role::User, 0..4),
2665 (message_3.id, Role::User, 4..5),
2666 (message_2.id, Role::User, 5..17),
2667 ]
2668 );
2669
2670 let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
2671 let message_4 = message_4.unwrap();
2672 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
2673 assert_eq!(
2674 messages(&context, cx),
2675 vec![
2676 (message_1.id, Role::User, 0..4),
2677 (message_3.id, Role::User, 4..5),
2678 (message_2.id, Role::User, 5..9),
2679 (message_4.id, Role::User, 9..17),
2680 ]
2681 );
2682
2683 let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
2684 let message_5 = message_5.unwrap();
2685 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
2686 assert_eq!(
2687 messages(&context, cx),
2688 vec![
2689 (message_1.id, Role::User, 0..4),
2690 (message_3.id, Role::User, 4..5),
2691 (message_2.id, Role::User, 5..9),
2692 (message_4.id, Role::User, 9..10),
2693 (message_5.id, Role::User, 10..18),
2694 ]
2695 );
2696
2697 let (message_6, message_7) =
2698 context.update(cx, |context, cx| context.split_message(14..16, cx));
2699 let message_6 = message_6.unwrap();
2700 let message_7 = message_7.unwrap();
2701 assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
2702 assert_eq!(
2703 messages(&context, cx),
2704 vec![
2705 (message_1.id, Role::User, 0..4),
2706 (message_3.id, Role::User, 4..5),
2707 (message_2.id, Role::User, 5..9),
2708 (message_4.id, Role::User, 9..10),
2709 (message_5.id, Role::User, 10..14),
2710 (message_6.id, Role::User, 14..17),
2711 (message_7.id, Role::User, 17..19),
2712 ]
2713 );
2714 }
2715
2716 #[gpui::test]
2717 fn test_messages_for_offsets(cx: &mut AppContext) {
2718 let settings_store = SettingsStore::test(cx);
2719 language_model::LanguageModelRegistry::test(cx);
2720 completion::LanguageModelCompletionProvider::test(cx);
2721 cx.set_global(settings_store);
2722 assistant_panel::init(cx);
2723 let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
2724 let context = cx.new_model(|cx| Context::local(registry, None, cx));
2725 let buffer = context.read(cx).buffer.clone();
2726
2727 let message_1 = context.read(cx).message_anchors[0].clone();
2728 assert_eq!(
2729 messages(&context, cx),
2730 vec![(message_1.id, Role::User, 0..0)]
2731 );
2732
2733 buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
2734 let message_2 = context
2735 .update(cx, |context, cx| {
2736 context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
2737 })
2738 .unwrap();
2739 buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));
2740
2741 let message_3 = context
2742 .update(cx, |context, cx| {
2743 context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
2744 })
2745 .unwrap();
2746 buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));
2747
2748 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
2749 assert_eq!(
2750 messages(&context, cx),
2751 vec![
2752 (message_1.id, Role::User, 0..4),
2753 (message_2.id, Role::User, 4..8),
2754 (message_3.id, Role::User, 8..11)
2755 ]
2756 );
2757
2758 assert_eq!(
2759 message_ids_for_offsets(&context, &[0, 4, 9], cx),
2760 [message_1.id, message_2.id, message_3.id]
2761 );
2762 assert_eq!(
2763 message_ids_for_offsets(&context, &[0, 1, 11], cx),
2764 [message_1.id, message_3.id]
2765 );
2766
2767 let message_4 = context
2768 .update(cx, |context, cx| {
2769 context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
2770 })
2771 .unwrap();
2772 assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
2773 assert_eq!(
2774 messages(&context, cx),
2775 vec![
2776 (message_1.id, Role::User, 0..4),
2777 (message_2.id, Role::User, 4..8),
2778 (message_3.id, Role::User, 8..12),
2779 (message_4.id, Role::User, 12..12)
2780 ]
2781 );
2782 assert_eq!(
2783 message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
2784 [message_1.id, message_2.id, message_3.id, message_4.id]
2785 );
2786
2787 fn message_ids_for_offsets(
2788 context: &Model<Context>,
2789 offsets: &[usize],
2790 cx: &AppContext,
2791 ) -> Vec<MessageId> {
2792 context
2793 .read(cx)
2794 .messages_for_offsets(offsets.iter().copied(), cx)
2795 .into_iter()
2796 .map(|message| message.id)
2797 .collect()
2798 }
2799 }
2800
2801 #[gpui::test]
2802 async fn test_slash_commands(cx: &mut TestAppContext) {
2803 let settings_store = cx.update(SettingsStore::test);
2804 cx.set_global(settings_store);
2805 cx.update(language_model::LanguageModelRegistry::test);
2806 cx.update(completion::LanguageModelCompletionProvider::test);
2807 cx.update(Project::init_settings);
2808 cx.update(assistant_panel::init);
2809 let fs = FakeFs::new(cx.background_executor.clone());
2810
2811 fs.insert_tree(
2812 "/test",
2813 json!({
2814 "src": {
2815 "lib.rs": "fn one() -> usize { 1 }",
2816 "main.rs": "
2817 use crate::one;
2818 fn main() { one(); }
2819 ".unindent(),
2820 }
2821 }),
2822 )
2823 .await;
2824
2825 let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
2826 slash_command_registry.register_command(file_command::FileSlashCommand, false);
2827 slash_command_registry.register_command(active_command::ActiveSlashCommand, false);
2828
2829 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
2830 let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
2831
2832 let output_ranges = Rc::new(RefCell::new(HashSet::default()));
2833 context.update(cx, |_, cx| {
2834 cx.subscribe(&context, {
2835 let ranges = output_ranges.clone();
2836 move |_, _, event, _| match event {
2837 ContextEvent::PendingSlashCommandsUpdated { removed, updated } => {
2838 for range in removed {
2839 ranges.borrow_mut().remove(range);
2840 }
2841 for command in updated {
2842 ranges.borrow_mut().insert(command.source_range.clone());
2843 }
2844 }
2845 _ => {}
2846 }
2847 })
2848 .detach();
2849 });
2850
2851 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
2852
2853 // Insert a slash command
2854 buffer.update(cx, |buffer, cx| {
2855 buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
2856 });
2857 assert_text_and_output_ranges(
2858 &buffer,
2859 &output_ranges.borrow(),
2860 "
2861 «/file src/lib.rs»
2862 "
2863 .unindent()
2864 .trim_end(),
2865 cx,
2866 );
2867
2868 // Edit the argument of the slash command.
2869 buffer.update(cx, |buffer, cx| {
2870 let edit_offset = buffer.text().find("lib.rs").unwrap();
2871 buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
2872 });
2873 assert_text_and_output_ranges(
2874 &buffer,
2875 &output_ranges.borrow(),
2876 "
2877 «/file src/main.rs»
2878 "
2879 .unindent()
2880 .trim_end(),
2881 cx,
2882 );
2883
2884 // Edit the name of the slash command, using one that doesn't exist.
2885 buffer.update(cx, |buffer, cx| {
2886 let edit_offset = buffer.text().find("/file").unwrap();
2887 buffer.edit(
2888 [(edit_offset..edit_offset + "/file".len(), "/unknown")],
2889 None,
2890 cx,
2891 );
2892 });
2893 assert_text_and_output_ranges(
2894 &buffer,
2895 &output_ranges.borrow(),
2896 "
2897 /unknown src/main.rs
2898 "
2899 .unindent()
2900 .trim_end(),
2901 cx,
2902 );
2903
2904 #[track_caller]
2905 fn assert_text_and_output_ranges(
2906 buffer: &Model<Buffer>,
2907 ranges: &HashSet<Range<language::Anchor>>,
2908 expected_marked_text: &str,
2909 cx: &mut TestAppContext,
2910 ) {
2911 let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false);
2912 let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| {
2913 let mut ranges = ranges
2914 .iter()
2915 .map(|range| range.to_offset(buffer))
2916 .collect::<Vec<_>>();
2917 ranges.sort_by_key(|a| a.start);
2918 (buffer.text(), ranges)
2919 });
2920
2921 assert_eq!(actual_text, expected_text);
2922 assert_eq!(actual_ranges, expected_ranges);
2923 }
2924 }
2925
2926 #[gpui::test]
2927 async fn test_edit_step_parsing(cx: &mut TestAppContext) {
2928 cx.update(prompt_library::init);
2929 let settings_store = cx.update(SettingsStore::test);
2930 cx.set_global(settings_store);
2931
2932 let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
2933 cx.update(completion::LanguageModelCompletionProvider::test);
2934
2935 let fake_model = fake_provider.test_model();
2936 cx.update(assistant_panel::init);
2937 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
2938
2939 // Create a new context
2940 let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
2941 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
2942
2943 // Simulate user input
2944 let user_message = indoc! {r#"
2945 Please refactor this code:
2946
2947 fn main() {
2948 println!("Hello, World!");
2949 }
2950 "#};
2951 buffer.update(cx, |buffer, cx| {
2952 buffer.edit([(0..0, user_message)], None, cx);
2953 });
2954
2955 // Simulate LLM response with edit steps
2956 let llm_response = indoc! {r#"
2957 Sure, I can help you refactor that code. Here's a step-by-step process:
2958
2959 <step>
2960 First, let's extract the greeting into a separate function:
2961
2962 ```rust
2963 fn greet() {
2964 println!("Hello, World!");
2965 }
2966
2967 fn main() {
2968 greet();
2969 }
2970 ```
2971 </step>
2972
2973 <step>
2974 Now, let's make the greeting customizable:
2975
2976 ```rust
2977 fn greet(name: &str) {
2978 println!("Hello, {}!", name);
2979 }
2980
2981 fn main() {
2982 greet("World");
2983 }
2984 ```
2985 </step>
2986
2987 These changes make the code more modular and flexible.
2988 "#};
2989
2990 // Simulate the assist method to trigger the LLM response
2991 context.update(cx, |context, cx| context.assist(cx));
2992 cx.run_until_parked();
2993
2994 // Retrieve the assistant response message's start from the context
2995 let response_start_row = context.read_with(cx, |context, cx| {
2996 let buffer = context.buffer.read(cx);
2997 context.message_anchors[1].start.to_point(buffer).row
2998 });
2999
3000 // Simulate the LLM completion
3001 fake_model.send_last_completion_chunk(llm_response.to_string());
3002 fake_model.finish_last_completion();
3003
3004 // Wait for the completion to be processed
3005 cx.run_until_parked();
3006
3007 // Verify that the edit steps were parsed correctly
3008 context.read_with(cx, |context, cx| {
3009 assert_eq!(
3010 edit_steps(context, cx),
3011 vec![
3012 Point::new(response_start_row + 2, 0)..Point::new(response_start_row + 14, 7),
3013 Point::new(response_start_row + 16, 0)..Point::new(response_start_row + 28, 7),
3014 ]
3015 );
3016 });
3017
3018 fn edit_steps(context: &Context, cx: &AppContext) -> Vec<Range<Point>> {
3019 context
3020 .edit_steps
3021 .iter()
3022 .map(|step| {
3023 let buffer = context.buffer.read(cx);
3024 step.source_range.to_point(buffer)
3025 })
3026 .collect()
3027 }
3028 }
3029
3030 #[gpui::test]
3031 async fn test_serialization(cx: &mut TestAppContext) {
3032 let settings_store = cx.update(SettingsStore::test);
3033 cx.set_global(settings_store);
3034 cx.update(language_model::LanguageModelRegistry::test);
3035 cx.update(completion::LanguageModelCompletionProvider::test);
3036 cx.update(assistant_panel::init);
3037 let registry = Arc::new(LanguageRegistry::test(cx.executor()));
3038 let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
3039 let buffer = context.read_with(cx, |context, _| context.buffer.clone());
3040 let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
3041 let message_1 = context.update(cx, |context, cx| {
3042 context
3043 .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
3044 .unwrap()
3045 });
3046 let message_2 = context.update(cx, |context, cx| {
3047 context
3048 .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
3049 .unwrap()
3050 });
3051 buffer.update(cx, |buffer, cx| {
3052 buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
3053 buffer.finalize_last_transaction();
3054 });
3055 let _message_3 = context.update(cx, |context, cx| {
3056 context
3057 .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
3058 .unwrap()
3059 });
3060 buffer.update(cx, |buffer, cx| buffer.undo(cx));
3061 assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
3062 assert_eq!(
3063 cx.read(|cx| messages(&context, cx)),
3064 [
3065 (message_0, Role::User, 0..2),
3066 (message_1.id, Role::Assistant, 2..6),
3067 (message_2.id, Role::System, 6..6),
3068 ]
3069 );
3070
3071 let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
3072 let deserialized_context = cx.new_model(|cx| {
3073 Context::deserialize(
3074 serialized_context,
3075 Default::default(),
3076 registry.clone(),
3077 None,
3078 cx,
3079 )
3080 });
3081 let deserialized_buffer =
3082 deserialized_context.read_with(cx, |context, _| context.buffer.clone());
3083 assert_eq!(
3084 deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
3085 "a\nb\nc\n"
3086 );
3087 assert_eq!(
3088 cx.read(|cx| messages(&deserialized_context, cx)),
3089 [
3090 (message_0, Role::User, 0..2),
3091 (message_1.id, Role::Assistant, 2..6),
3092 (message_2.id, Role::System, 6..6),
3093 ]
3094 );
3095 }
3096
3097 #[gpui::test(iterations = 100)]
3098 async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
3099 let min_peers = env::var("MIN_PEERS")
3100 .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
3101 .unwrap_or(2);
3102 let max_peers = env::var("MAX_PEERS")
3103 .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
3104 .unwrap_or(5);
3105 let operations = env::var("OPERATIONS")
3106 .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
3107 .unwrap_or(50);
3108
3109 let settings_store = cx.update(SettingsStore::test);
3110 cx.set_global(settings_store);
3111 cx.update(language_model::LanguageModelRegistry::test);
3112 cx.update(completion::LanguageModelCompletionProvider::test);
3113
3114 cx.update(assistant_panel::init);
3115 let slash_commands = cx.update(SlashCommandRegistry::default_global);
3116 slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
3117 slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
3118 slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
3119
3120 let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
3121 let network = Arc::new(Mutex::new(Network::new(rng.clone())));
3122 let mut contexts = Vec::new();
3123
3124 let num_peers = rng.gen_range(min_peers..=max_peers);
3125 let context_id = ContextId::new();
3126 for i in 0..num_peers {
3127 let context = cx.new_model(|cx| {
3128 Context::new(
3129 context_id.clone(),
3130 i as ReplicaId,
3131 language::Capability::ReadWrite,
3132 registry.clone(),
3133 None,
3134 cx,
3135 )
3136 });
3137
3138 cx.update(|cx| {
3139 cx.subscribe(&context, {
3140 let network = network.clone();
3141 move |_, event, _| {
3142 if let ContextEvent::Operation(op) = event {
3143 network
3144 .lock()
3145 .broadcast(i as ReplicaId, vec![op.to_proto()]);
3146 }
3147 }
3148 })
3149 .detach();
3150 });
3151
3152 contexts.push(context);
3153 network.lock().add_peer(i as ReplicaId);
3154 }
3155
3156 let mut mutation_count = operations;
3157
3158 while mutation_count > 0
3159 || !network.lock().is_idle()
3160 || network.lock().contains_disconnected_peers()
3161 {
3162 let context_index = rng.gen_range(0..contexts.len());
3163 let context = &contexts[context_index];
3164
3165 match rng.gen_range(0..100) {
3166 0..=29 if mutation_count > 0 => {
3167 log::info!("Context {}: edit buffer", context_index);
3168 context.update(cx, |context, cx| {
3169 context
3170 .buffer
3171 .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
3172 });
3173 mutation_count -= 1;
3174 }
3175 30..=44 if mutation_count > 0 => {
3176 context.update(cx, |context, cx| {
3177 let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
3178 log::info!("Context {}: split message at {:?}", context_index, range);
3179 context.split_message(range, cx);
3180 });
3181 mutation_count -= 1;
3182 }
3183 45..=59 if mutation_count > 0 => {
3184 context.update(cx, |context, cx| {
3185 if let Some(message) = context.messages(cx).choose(&mut rng) {
3186 let role = *[Role::User, Role::Assistant, Role::System]
3187 .choose(&mut rng)
3188 .unwrap();
3189 log::info!(
3190 "Context {}: insert message after {:?} with {:?}",
3191 context_index,
3192 message.id,
3193 role
3194 );
3195 context.insert_message_after(message.id, role, MessageStatus::Done, cx);
3196 }
3197 });
3198 mutation_count -= 1;
3199 }
3200 60..=74 if mutation_count > 0 => {
3201 context.update(cx, |context, cx| {
3202 let command_text = "/".to_string()
3203 + slash_commands
3204 .command_names()
3205 .choose(&mut rng)
3206 .unwrap()
3207 .clone()
3208 .as_ref();
3209
3210 let command_range = context.buffer.update(cx, |buffer, cx| {
3211 let offset = buffer.random_byte_range(0, &mut rng).start;
3212 buffer.edit(
3213 [(offset..offset, format!("\n{}\n", command_text))],
3214 None,
3215 cx,
3216 );
3217 offset + 1..offset + 1 + command_text.len()
3218 });
3219
3220 let output_len = rng.gen_range(1..=10);
3221 let output_text = RandomCharIter::new(&mut rng)
3222 .filter(|c| *c != '\r')
3223 .take(output_len)
3224 .collect::<String>();
3225
3226 let num_sections = rng.gen_range(0..=3);
3227 let mut sections = Vec::with_capacity(num_sections);
3228 for _ in 0..num_sections {
3229 let section_start = rng.gen_range(0..output_len);
3230 let section_end = rng.gen_range(section_start..=output_len);
3231 sections.push(SlashCommandOutputSection {
3232 range: section_start..section_end,
3233 icon: ui::IconName::Ai,
3234 label: "section".into(),
3235 });
3236 }
3237
3238 log::info!(
3239 "Context {}: insert slash command output at {:?} with {:?}",
3240 context_index,
3241 command_range,
3242 sections
3243 );
3244
3245 let command_range =
3246 context.buffer.read(cx).anchor_after(command_range.start)
3247 ..context.buffer.read(cx).anchor_after(command_range.end);
3248 context.insert_command_output(
3249 command_range,
3250 Task::ready(Ok(SlashCommandOutput {
3251 text: output_text,
3252 sections,
3253 run_commands_in_text: false,
3254 })),
3255 true,
3256 cx,
3257 );
3258 });
3259 cx.run_until_parked();
3260 mutation_count -= 1;
3261 }
3262 75..=84 if mutation_count > 0 => {
3263 context.update(cx, |context, cx| {
3264 if let Some(message) = context.messages(cx).choose(&mut rng) {
3265 let new_status = match rng.gen_range(0..3) {
3266 0 => MessageStatus::Done,
3267 1 => MessageStatus::Pending,
3268 _ => MessageStatus::Error(SharedString::from("Random error")),
3269 };
3270 log::info!(
3271 "Context {}: update message {:?} status to {:?}",
3272 context_index,
3273 message.id,
3274 new_status
3275 );
3276 context.update_metadata(message.id, cx, |metadata| {
3277 metadata.status = new_status;
3278 });
3279 }
3280 });
3281 mutation_count -= 1;
3282 }
3283 _ => {
3284 let replica_id = context_index as ReplicaId;
3285 if network.lock().is_disconnected(replica_id) {
3286 network.lock().reconnect_peer(replica_id, 0);
3287
3288 let (ops_to_send, ops_to_receive) = cx.read(|cx| {
3289 let host_context = &contexts[0].read(cx);
3290 let guest_context = context.read(cx);
3291 (
3292 guest_context.serialize_ops(&host_context.version(cx), cx),
3293 host_context.serialize_ops(&guest_context.version(cx), cx),
3294 )
3295 });
3296 let ops_to_send = ops_to_send.await;
3297 let ops_to_receive = ops_to_receive
3298 .await
3299 .into_iter()
3300 .map(ContextOperation::from_proto)
3301 .collect::<Result<Vec<_>>>()
3302 .unwrap();
3303 log::info!(
3304 "Context {}: reconnecting. Sent {} operations, received {} operations",
3305 context_index,
3306 ops_to_send.len(),
3307 ops_to_receive.len()
3308 );
3309
3310 network.lock().broadcast(replica_id, ops_to_send);
3311 context
3312 .update(cx, |context, cx| context.apply_ops(ops_to_receive, cx))
3313 .unwrap();
3314 } else if rng.gen_bool(0.1) && replica_id != 0 {
3315 log::info!("Context {}: disconnecting", context_index);
3316 network.lock().disconnect_peer(replica_id);
3317 } else if network.lock().has_unreceived(replica_id) {
3318 log::info!("Context {}: applying operations", context_index);
3319 let ops = network.lock().receive(replica_id);
3320 let ops = ops
3321 .into_iter()
3322 .map(ContextOperation::from_proto)
3323 .collect::<Result<Vec<_>>>()
3324 .unwrap();
3325 context
3326 .update(cx, |context, cx| context.apply_ops(ops, cx))
3327 .unwrap();
3328 }
3329 }
3330 }
3331 }
3332
3333 cx.read(|cx| {
3334 let first_context = contexts[0].read(cx);
3335 for context in &contexts[1..] {
3336 let context = context.read(cx);
3337 assert!(context.pending_ops.is_empty());
3338 assert_eq!(
3339 context.buffer.read(cx).text(),
3340 first_context.buffer.read(cx).text(),
3341 "Context {} text != Context 0 text",
3342 context.buffer.read(cx).replica_id()
3343 );
3344 assert_eq!(
3345 context.message_anchors,
3346 first_context.message_anchors,
3347 "Context {} messages != Context 0 messages",
3348 context.buffer.read(cx).replica_id()
3349 );
3350 assert_eq!(
3351 context.messages_metadata,
3352 first_context.messages_metadata,
3353 "Context {} message metadata != Context 0 message metadata",
3354 context.buffer.read(cx).replica_id()
3355 );
3356 assert_eq!(
3357 context.slash_command_output_sections,
3358 first_context.slash_command_output_sections,
3359 "Context {} slash command output sections != Context 0 slash command output sections",
3360 context.buffer.read(cx).replica_id()
3361 );
3362 }
3363 });
3364 }
3365
3366 fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
3367 context
3368 .read(cx)
3369 .messages(cx)
3370 .map(|message| (message.id, message.role, message.offset_range))
3371 .collect()
3372 }
3373
3374 #[derive(Clone)]
3375 struct FakeSlashCommand(String);
3376
3377 impl SlashCommand for FakeSlashCommand {
3378 fn name(&self) -> String {
3379 self.0.clone()
3380 }
3381
3382 fn description(&self) -> String {
3383 format!("Fake slash command: {}", self.0)
3384 }
3385
3386 fn menu_text(&self) -> String {
3387 format!("Run fake command: {}", self.0)
3388 }
3389
3390 fn complete_argument(
3391 self: Arc<Self>,
3392 _query: String,
3393 _cancel: Arc<AtomicBool>,
3394 _workspace: Option<WeakView<Workspace>>,
3395 _cx: &mut AppContext,
3396 ) -> Task<Result<Vec<ArgumentCompletion>>> {
3397 Task::ready(Ok(vec![]))
3398 }
3399
3400 fn requires_argument(&self) -> bool {
3401 false
3402 }
3403
3404 fn run(
3405 self: Arc<Self>,
3406 _argument: Option<&str>,
3407 _workspace: WeakView<Workspace>,
3408 _delegate: Arc<dyn LspAdapterDelegate>,
3409 _cx: &mut WindowContext,
3410 ) -> Task<Result<SlashCommandOutput>> {
3411 Task::ready(Ok(SlashCommandOutput {
3412 text: format!("Executed fake command: {}", self.0),
3413 sections: vec![],
3414 run_commands_in_text: false,
3415 }))
3416 }
3417 }
3418}