1use crate::{
2 prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
3 LanguageModelCompletionProvider, MessageId, MessageStatus,
4};
5use anyhow::{anyhow, Context as _, Result};
6use assistant_slash_command::{
7 SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
8};
9use client::{self, proto, telemetry::Telemetry};
10use clock::ReplicaId;
11use collections::{HashMap, HashSet};
12use fs::{Fs, RemoveOptions};
13use futures::{
14 future::{self, Shared},
15 FutureExt, StreamExt,
16};
17use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscription, Task};
18use language::{
19 AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
20};
21use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
22use open_ai::Model as OpenAiModel;
23use paths::contexts_dir;
24use project::Project;
25use schemars::JsonSchema;
26use serde::{Deserialize, Serialize};
27use std::{
28 cmp,
29 fmt::Debug,
30 iter, mem,
31 ops::Range,
32 path::{Path, PathBuf},
33 sync::Arc,
34 time::{Duration, Instant},
35};
36use telemetry_events::AssistantKind;
37use ui::SharedString;
38use util::{post_inc, ResultExt, TryFutureExt};
39use uuid::Uuid;
40
41#[derive(Clone, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
42pub struct ContextId(String);
43
44impl ContextId {
45 pub fn new() -> Self {
46 Self(Uuid::new_v4().to_string())
47 }
48
49 pub fn from_proto(id: String) -> Self {
50 Self(id)
51 }
52
53 pub fn to_proto(&self) -> String {
54 self.0.clone()
55 }
56}
57
58#[derive(Clone, Debug)]
59pub enum ContextOperation {
60 InsertMessage {
61 anchor: MessageAnchor,
62 metadata: MessageMetadata,
63 version: clock::Global,
64 },
65 UpdateMessage {
66 message_id: MessageId,
67 metadata: MessageMetadata,
68 version: clock::Global,
69 },
70 UpdateSummary {
71 summary: ContextSummary,
72 version: clock::Global,
73 },
74 SlashCommandFinished {
75 id: SlashCommandId,
76 output_range: Range<language::Anchor>,
77 sections: Vec<SlashCommandOutputSection<language::Anchor>>,
78 version: clock::Global,
79 },
80 BufferOperation(language::Operation),
81}
82
83impl ContextOperation {
84 pub fn from_proto(op: proto::ContextOperation) -> Result<Self> {
85 match op.variant.context("invalid variant")? {
86 proto::context_operation::Variant::InsertMessage(insert) => {
87 let message = insert.message.context("invalid message")?;
88 let id = MessageId(language::proto::deserialize_timestamp(
89 message.id.context("invalid id")?,
90 ));
91 Ok(Self::InsertMessage {
92 anchor: MessageAnchor {
93 id,
94 start: language::proto::deserialize_anchor(
95 message.start.context("invalid anchor")?,
96 )
97 .context("invalid anchor")?,
98 },
99 metadata: MessageMetadata {
100 role: Role::from_proto(message.role),
101 status: MessageStatus::from_proto(
102 message.status.context("invalid status")?,
103 ),
104 timestamp: id.0,
105 },
106 version: language::proto::deserialize_version(&insert.version),
107 })
108 }
109 proto::context_operation::Variant::UpdateMessage(update) => Ok(Self::UpdateMessage {
110 message_id: MessageId(language::proto::deserialize_timestamp(
111 update.message_id.context("invalid message id")?,
112 )),
113 metadata: MessageMetadata {
114 role: Role::from_proto(update.role),
115 status: MessageStatus::from_proto(update.status.context("invalid status")?),
116 timestamp: language::proto::deserialize_timestamp(
117 update.timestamp.context("invalid timestamp")?,
118 ),
119 },
120 version: language::proto::deserialize_version(&update.version),
121 }),
122 proto::context_operation::Variant::UpdateSummary(update) => Ok(Self::UpdateSummary {
123 summary: ContextSummary {
124 text: update.summary,
125 done: update.done,
126 timestamp: language::proto::deserialize_timestamp(
127 update.timestamp.context("invalid timestamp")?,
128 ),
129 },
130 version: language::proto::deserialize_version(&update.version),
131 }),
132 proto::context_operation::Variant::SlashCommandFinished(finished) => {
133 Ok(Self::SlashCommandFinished {
134 id: SlashCommandId(language::proto::deserialize_timestamp(
135 finished.id.context("invalid id")?,
136 )),
137 output_range: language::proto::deserialize_anchor_range(
138 finished.output_range.context("invalid range")?,
139 )?,
140 sections: finished
141 .sections
142 .into_iter()
143 .map(|section| {
144 Ok(SlashCommandOutputSection {
145 range: language::proto::deserialize_anchor_range(
146 section.range.context("invalid range")?,
147 )?,
148 icon: section.icon_name.parse()?,
149 label: section.label.into(),
150 })
151 })
152 .collect::<Result<Vec<_>>>()?,
153 version: language::proto::deserialize_version(&finished.version),
154 })
155 }
156 proto::context_operation::Variant::BufferOperation(op) => Ok(Self::BufferOperation(
157 language::proto::deserialize_operation(
158 op.operation.context("invalid buffer operation")?,
159 )?,
160 )),
161 }
162 }
163
164 pub fn to_proto(&self) -> proto::ContextOperation {
165 match self {
166 Self::InsertMessage {
167 anchor,
168 metadata,
169 version,
170 } => proto::ContextOperation {
171 variant: Some(proto::context_operation::Variant::InsertMessage(
172 proto::context_operation::InsertMessage {
173 message: Some(proto::ContextMessage {
174 id: Some(language::proto::serialize_timestamp(anchor.id.0)),
175 start: Some(language::proto::serialize_anchor(&anchor.start)),
176 role: metadata.role.to_proto() as i32,
177 status: Some(metadata.status.to_proto()),
178 }),
179 version: language::proto::serialize_version(version),
180 },
181 )),
182 },
183 Self::UpdateMessage {
184 message_id,
185 metadata,
186 version,
187 } => proto::ContextOperation {
188 variant: Some(proto::context_operation::Variant::UpdateMessage(
189 proto::context_operation::UpdateMessage {
190 message_id: Some(language::proto::serialize_timestamp(message_id.0)),
191 role: metadata.role.to_proto() as i32,
192 status: Some(metadata.status.to_proto()),
193 timestamp: Some(language::proto::serialize_timestamp(metadata.timestamp)),
194 version: language::proto::serialize_version(version),
195 },
196 )),
197 },
198 Self::UpdateSummary { summary, version } => proto::ContextOperation {
199 variant: Some(proto::context_operation::Variant::UpdateSummary(
200 proto::context_operation::UpdateSummary {
201 summary: summary.text.clone(),
202 done: summary.done,
203 timestamp: Some(language::proto::serialize_timestamp(summary.timestamp)),
204 version: language::proto::serialize_version(version),
205 },
206 )),
207 },
208 Self::SlashCommandFinished {
209 id,
210 output_range,
211 sections,
212 version,
213 } => proto::ContextOperation {
214 variant: Some(proto::context_operation::Variant::SlashCommandFinished(
215 proto::context_operation::SlashCommandFinished {
216 id: Some(language::proto::serialize_timestamp(id.0)),
217 output_range: Some(language::proto::serialize_anchor_range(
218 output_range.clone(),
219 )),
220 sections: sections
221 .iter()
222 .map(|section| {
223 let icon_name: &'static str = section.icon.into();
224 proto::SlashCommandOutputSection {
225 range: Some(language::proto::serialize_anchor_range(
226 section.range.clone(),
227 )),
228 icon_name: icon_name.to_string(),
229 label: section.label.to_string(),
230 }
231 })
232 .collect(),
233 version: language::proto::serialize_version(version),
234 },
235 )),
236 },
237 Self::BufferOperation(operation) => proto::ContextOperation {
238 variant: Some(proto::context_operation::Variant::BufferOperation(
239 proto::context_operation::BufferOperation {
240 operation: Some(language::proto::serialize_operation(operation)),
241 },
242 )),
243 },
244 }
245 }
246
247 fn timestamp(&self) -> clock::Lamport {
248 match self {
249 Self::InsertMessage { anchor, .. } => anchor.id.0,
250 Self::UpdateMessage { metadata, .. } => metadata.timestamp,
251 Self::UpdateSummary { summary, .. } => summary.timestamp,
252 Self::SlashCommandFinished { id, .. } => id.0,
253 Self::BufferOperation(_) => {
254 panic!("reading the timestamp of a buffer operation is not supported")
255 }
256 }
257 }
258
259 /// Returns the current version of the context operation.
260 pub fn version(&self) -> &clock::Global {
261 match self {
262 Self::InsertMessage { version, .. }
263 | Self::UpdateMessage { version, .. }
264 | Self::UpdateSummary { version, .. }
265 | Self::SlashCommandFinished { version, .. } => version,
266 Self::BufferOperation(_) => {
267 panic!("reading the version of a buffer operation is not supported")
268 }
269 }
270 }
271}
272
273#[derive(Clone)]
274pub enum ContextEvent {
275 MessagesEdited,
276 SummaryChanged,
277 EditStepsChanged,
278 StreamedCompletion,
279 PendingSlashCommandsUpdated {
280 removed: Vec<Range<language::Anchor>>,
281 updated: Vec<PendingSlashCommand>,
282 },
283 SlashCommandFinished {
284 output_range: Range<language::Anchor>,
285 sections: Vec<SlashCommandOutputSection<language::Anchor>>,
286 run_commands_in_output: bool,
287 },
288 Operation(ContextOperation),
289}
290
291#[derive(Clone, Default, Debug)]
292pub struct ContextSummary {
293 pub text: String,
294 done: bool,
295 timestamp: clock::Lamport,
296}
297
298#[derive(Clone, Debug, Eq, PartialEq)]
299pub struct MessageAnchor {
300 pub id: MessageId,
301 pub start: language::Anchor,
302}
303
304#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
305pub struct MessageMetadata {
306 pub role: Role,
307 status: MessageStatus,
308 timestamp: clock::Lamport,
309}
310
311#[derive(Clone, Debug, PartialEq, Eq)]
312pub struct Message {
313 pub offset_range: Range<usize>,
314 pub index_range: Range<usize>,
315 pub id: MessageId,
316 pub anchor: language::Anchor,
317 pub role: Role,
318 pub status: MessageStatus,
319}
320
321impl Message {
322 fn to_request_message(&self, buffer: &Buffer) -> LanguageModelRequestMessage {
323 LanguageModelRequestMessage {
324 role: self.role,
325 content: buffer.text_for_range(self.offset_range.clone()).collect(),
326 }
327 }
328}
329
330struct PendingCompletion {
331 id: usize,
332 _task: Task<()>,
333}
334
335#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
336pub struct SlashCommandId(clock::Lamport);
337
338#[derive(Debug)]
339pub struct EditStep {
340 pub source_range: Range<language::Anchor>,
341 pub operations: Option<EditStepOperations>,
342}
343
344#[derive(Debug)]
345pub struct EditSuggestionGroup {
346 pub context_range: Range<language::Anchor>,
347 pub suggestions: Vec<EditSuggestion>,
348}
349
350#[derive(Debug)]
351pub struct EditSuggestion {
352 pub range: Range<language::Anchor>,
353 /// If None, assume this is a suggestion to delete the range rather than transform it.
354 pub description: Option<String>,
355 pub initial_insertion: Option<InitialInsertion>,
356}
357
358impl EditStep {
359 pub fn edit_suggestions(
360 &self,
361 project: &Model<Project>,
362 cx: &AppContext,
363 ) -> Task<HashMap<Model<Buffer>, Vec<EditSuggestionGroup>>> {
364 let Some(EditStepOperations::Ready(operations)) = &self.operations else {
365 return Task::ready(HashMap::default());
366 };
367
368 let suggestion_tasks: Vec<_> = operations
369 .iter()
370 .map(|operation| operation.edit_suggestion(project.clone(), cx))
371 .collect();
372
373 cx.spawn(|mut cx| async move {
374 let suggestions = future::join_all(suggestion_tasks)
375 .await
376 .into_iter()
377 .filter_map(|task| task.log_err())
378 .collect::<Vec<_>>();
379
380 let mut suggestions_by_buffer = HashMap::default();
381 for (buffer, suggestion) in suggestions {
382 suggestions_by_buffer
383 .entry(buffer)
384 .or_insert_with(Vec::new)
385 .push(suggestion);
386 }
387
388 let mut suggestion_groups_by_buffer = HashMap::default();
389 for (buffer, mut suggestions) in suggestions_by_buffer {
390 let mut suggestion_groups = Vec::<EditSuggestionGroup>::new();
391 buffer
392 .update(&mut cx, |buffer, _cx| {
393 // Sort suggestions by their range
394 suggestions.sort_by(|a, b| a.range.cmp(&b.range, buffer));
395
396 // Dedup overlapping suggestions
397 suggestions.dedup_by(|a, b| {
398 let a_range = a.range.to_offset(buffer);
399 let b_range = b.range.to_offset(buffer);
400 if a_range.start <= b_range.end && b_range.start <= a_range.end {
401 if b_range.start < a_range.start {
402 a.range.start = b.range.start;
403 }
404 if b_range.end > a_range.end {
405 a.range.end = b.range.end;
406 }
407
408 if let (Some(a_desc), Some(b_desc)) =
409 (a.description.as_mut(), b.description.as_mut())
410 {
411 b_desc.push('\n');
412 b_desc.push_str(a_desc);
413 } else if a.description.is_some() {
414 b.description = a.description.take();
415 }
416
417 true
418 } else {
419 false
420 }
421 });
422
423 // Create context ranges for each suggestion
424 for suggestion in suggestions {
425 let context_range = {
426 let suggestion_point_range = suggestion.range.to_point(buffer);
427 let start_row = suggestion_point_range.start.row.saturating_sub(5);
428 let end_row = cmp::min(
429 suggestion_point_range.end.row + 5,
430 buffer.max_point().row,
431 );
432 let start = buffer.anchor_before(Point::new(start_row, 0));
433 let end = buffer
434 .anchor_after(Point::new(end_row, buffer.line_len(end_row)));
435 start..end
436 };
437
438 if let Some(last_group) = suggestion_groups.last_mut() {
439 if last_group
440 .context_range
441 .end
442 .cmp(&context_range.start, buffer)
443 .is_ge()
444 {
445 // Merge with the previous group if context ranges overlap
446 last_group.context_range.end = context_range.end;
447 last_group.suggestions.push(suggestion);
448 } else {
449 // Create a new group
450 suggestion_groups.push(EditSuggestionGroup {
451 context_range,
452 suggestions: vec![suggestion],
453 });
454 }
455 } else {
456 // Create the first group
457 suggestion_groups.push(EditSuggestionGroup {
458 context_range,
459 suggestions: vec![suggestion],
460 });
461 }
462 }
463 })
464 .ok();
465 suggestion_groups_by_buffer.insert(buffer, suggestion_groups);
466 }
467
468 suggestion_groups_by_buffer
469 })
470 }
471}
472
473pub enum EditStepOperations {
474 Pending(Task<Option<()>>),
475 Ready(Vec<EditOperation>),
476}
477
478impl Debug for EditStepOperations {
479 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
480 match self {
481 EditStepOperations::Pending(_) => write!(f, "EditStepOperations::Pending"),
482 EditStepOperations::Ready(operations) => f
483 .debug_struct("EditStepOperations::Parsed")
484 .field("operations", operations)
485 .finish(),
486 }
487 }
488}
489
490/// A description of an operation to apply to one location in the codebase.
491#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
492pub struct EditOperation {
493 /// The path to the file containing the relevant operation
494 pub path: String,
495 #[serde(flatten)]
496 pub kind: EditOperationKind,
497}
498
499impl EditOperation {
500 fn edit_suggestion(
501 &self,
502 project: Model<Project>,
503 cx: &AppContext,
504 ) -> Task<Result<(Model<language::Buffer>, EditSuggestion)>> {
505 let path = self.path.clone();
506 let kind = self.kind.clone();
507 cx.spawn(move |mut cx| async move {
508 let buffer = project
509 .update(&mut cx, |project, cx| {
510 let project_path = project
511 .project_path_for_full_path(Path::new(&path), cx)
512 .with_context(|| format!("worktree not found for {:?}", path))?;
513 anyhow::Ok(project.open_buffer(project_path, cx))
514 })??
515 .await?;
516
517 let mut parse_status = buffer.read_with(&cx, |buffer, _cx| buffer.parse_status())?;
518 while *parse_status.borrow() != ParseStatus::Idle {
519 parse_status.changed().await?;
520 }
521
522 let initial_insertion = kind.initial_insertion();
523 let suggestion_range = if let Some(symbol) = kind.symbol() {
524 let outline = buffer
525 .update(&mut cx, |buffer, _| buffer.snapshot().outline(None))?
526 .context("no outline for buffer")?;
527 let candidate = outline
528 .path_candidates
529 .iter()
530 .find(|item| item.string == symbol)
531 .with_context(|| {
532 format!(
533 "symbol {:?} not found in path {:?}.\ncandidates: {:?}.\nparse status: {:?}. text:\n{}",
534 symbol,
535 path,
536 outline
537 .path_candidates
538 .iter()
539 .map(|candidate| &candidate.string)
540 .collect::<Vec<_>>(),
541 *parse_status.borrow(),
542 buffer.read_with(&cx, |buffer, _| buffer.text()).unwrap_or_else(|_| "error".to_string())
543 )
544 })?;
545
546 buffer.update(&mut cx, |buffer, _| {
547 let outline_item = &outline.items[candidate.id];
548 let symbol_range = outline_item.range.to_point(buffer);
549 let body_range = outline_item
550 .body_range
551 .as_ref()
552 .map(|range| range.to_point(buffer))
553 .unwrap_or(symbol_range.clone());
554
555 match kind {
556 EditOperationKind::PrependChild { .. } => {
557 let position = buffer.anchor_after(body_range.start);
558 position..position
559 }
560 EditOperationKind::AppendChild { .. } => {
561 let position = buffer.anchor_before(body_range.end);
562 position..position
563 }
564 EditOperationKind::InsertSiblingBefore { .. } => {
565 let position = buffer.anchor_before(symbol_range.start);
566 position..position
567 }
568 EditOperationKind::InsertSiblingAfter { .. } => {
569 let position = buffer.anchor_after(symbol_range.end);
570 position..position
571 }
572 EditOperationKind::Update { .. } | EditOperationKind::Delete { .. } => {
573 let start = Point::new(symbol_range.start.row, 0);
574 let end = Point::new(
575 symbol_range.end.row,
576 buffer.line_len(symbol_range.end.row),
577 );
578 buffer.anchor_before(start)..buffer.anchor_after(end)
579 }
580 EditOperationKind::Create { .. } => unreachable!(),
581 }
582 })?
583 } else {
584 match kind {
585 EditOperationKind::PrependChild { .. } => {
586 language::Anchor::MIN..language::Anchor::MIN
587 }
588 EditOperationKind::AppendChild { .. } | EditOperationKind::Create { .. } => {
589 language::Anchor::MAX..language::Anchor::MAX
590 }
591 _ => unreachable!("All other operations should have a symbol"),
592 }
593 };
594
595 Ok((
596 buffer,
597 EditSuggestion {
598 range: suggestion_range,
599 description: kind.description().map(ToString::to_string),
600 initial_insertion,
601 },
602 ))
603 })
604 }
605}
606
607#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
608#[serde(tag = "kind")]
609pub enum EditOperationKind {
610 /// Rewrite the specified symbol in its entirely based on the given description.
611 Update {
612 /// A full path to the symbol to be rewritten from the provided list.
613 symbol: String,
614 /// A brief one-line description of the change that should be applied.
615 description: String,
616 },
617 /// Create a new file with the given path based on the given description.
618 Create {
619 /// A brief one-line description of the change that should be applied.
620 description: String,
621 },
622 /// Insert a new symbol based on the given description before the specified symbol.
623 InsertSiblingBefore {
624 /// A full path to the symbol to be rewritten from the provided list.
625 symbol: String,
626 /// A brief one-line description of the change that should be applied.
627 description: String,
628 },
629 /// Insert a new symbol based on the given description after the specified symbol.
630 InsertSiblingAfter {
631 /// A full path to the symbol to be rewritten from the provided list.
632 symbol: String,
633 /// A brief one-line description of the change that should be applied.
634 description: String,
635 },
636 /// Insert a new symbol as a child of the specified symbol at the start.
637 PrependChild {
638 /// An optional full path to the symbol to be rewritten from the provided list.
639 /// If not provided, the edit should be applied at the top of the file.
640 symbol: Option<String>,
641 /// A brief one-line description of the change that should be applied.
642 description: String,
643 },
644 /// Insert a new symbol as a child of the specified symbol at the end.
645 AppendChild {
646 /// An optional full path to the symbol to be rewritten from the provided list.
647 /// If not provided, the edit should be applied at the top of the file.
648 symbol: Option<String>,
649 /// A brief one-line description of the change that should be applied.
650 description: String,
651 },
652 /// Delete the specified symbol.
653 Delete {
654 /// A full path to the symbol to be rewritten from the provided list.
655 symbol: String,
656 },
657}
658
659impl EditOperationKind {
660 pub fn symbol(&self) -> Option<&str> {
661 match self {
662 Self::Update { symbol, .. } => Some(symbol),
663 Self::InsertSiblingBefore { symbol, .. } => Some(symbol),
664 Self::InsertSiblingAfter { symbol, .. } => Some(symbol),
665 Self::PrependChild { symbol, .. } => symbol.as_deref(),
666 Self::AppendChild { symbol, .. } => symbol.as_deref(),
667 Self::Delete { symbol } => Some(symbol),
668 Self::Create { .. } => None,
669 }
670 }
671
672 pub fn description(&self) -> Option<&str> {
673 match self {
674 Self::Update { description, .. } => Some(description),
675 Self::Create { description } => Some(description),
676 Self::InsertSiblingBefore { description, .. } => Some(description),
677 Self::InsertSiblingAfter { description, .. } => Some(description),
678 Self::PrependChild { description, .. } => Some(description),
679 Self::AppendChild { description, .. } => Some(description),
680 Self::Delete { .. } => None,
681 }
682 }
683
684 pub fn initial_insertion(&self) -> Option<InitialInsertion> {
685 match self {
686 EditOperationKind::InsertSiblingBefore { .. } => Some(InitialInsertion::NewlineAfter),
687 EditOperationKind::InsertSiblingAfter { .. } => Some(InitialInsertion::NewlineBefore),
688 EditOperationKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
689 EditOperationKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
690 _ => None,
691 }
692 }
693}
694
695pub struct Context {
696 id: ContextId,
697 timestamp: clock::Lamport,
698 version: clock::Global,
699 pending_ops: Vec<ContextOperation>,
700 operations: Vec<ContextOperation>,
701 buffer: Model<Buffer>,
702 pending_slash_commands: Vec<PendingSlashCommand>,
703 edits_since_last_slash_command_parse: language::Subscription,
704 finished_slash_commands: HashSet<SlashCommandId>,
705 slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
706 message_anchors: Vec<MessageAnchor>,
707 messages_metadata: HashMap<MessageId, MessageMetadata>,
708 summary: Option<ContextSummary>,
709 pending_summary: Task<Option<()>>,
710 completion_count: usize,
711 pending_completions: Vec<PendingCompletion>,
712 token_count: Option<usize>,
713 pending_token_count: Task<Option<()>>,
714 pending_save: Task<Result<()>>,
715 path: Option<PathBuf>,
716 _subscriptions: Vec<Subscription>,
717 telemetry: Option<Arc<Telemetry>>,
718 language_registry: Arc<LanguageRegistry>,
719 edit_steps: Vec<EditStep>,
720}
721
722impl EventEmitter<ContextEvent> for Context {}
723
724impl Context {
725 pub fn local(
726 language_registry: Arc<LanguageRegistry>,
727 telemetry: Option<Arc<Telemetry>>,
728 cx: &mut ModelContext<Self>,
729 ) -> Self {
730 Self::new(
731 ContextId::new(),
732 ReplicaId::default(),
733 language::Capability::ReadWrite,
734 language_registry,
735 telemetry,
736 cx,
737 )
738 }
739
740 pub fn new(
741 id: ContextId,
742 replica_id: ReplicaId,
743 capability: language::Capability,
744 language_registry: Arc<LanguageRegistry>,
745 telemetry: Option<Arc<Telemetry>>,
746 cx: &mut ModelContext<Self>,
747 ) -> Self {
748 let buffer = cx.new_model(|_cx| {
749 let mut buffer = Buffer::remote(
750 language::BufferId::new(1).unwrap(),
751 replica_id,
752 capability,
753 "",
754 );
755 buffer.set_language_registry(language_registry.clone());
756 buffer
757 });
758 let edits_since_last_slash_command_parse =
759 buffer.update(cx, |buffer, _| buffer.subscribe());
760 let mut this = Self {
761 id,
762 timestamp: clock::Lamport::new(replica_id),
763 version: clock::Global::new(),
764 pending_ops: Vec::new(),
765 operations: Vec::new(),
766 message_anchors: Default::default(),
767 messages_metadata: Default::default(),
768 pending_slash_commands: Vec::new(),
769 finished_slash_commands: HashSet::default(),
770 slash_command_output_sections: Vec::new(),
771 edits_since_last_slash_command_parse,
772 summary: None,
773 pending_summary: Task::ready(None),
774 completion_count: Default::default(),
775 pending_completions: Default::default(),
776 token_count: None,
777 pending_token_count: Task::ready(None),
778 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
779 pending_save: Task::ready(Ok(())),
780 path: None,
781 buffer,
782 telemetry,
783 language_registry,
784 edit_steps: Vec::new(),
785 };
786
787 let first_message_id = MessageId(clock::Lamport {
788 replica_id: 0,
789 value: 0,
790 });
791 let message = MessageAnchor {
792 id: first_message_id,
793 start: language::Anchor::MIN,
794 };
795 this.messages_metadata.insert(
796 first_message_id,
797 MessageMetadata {
798 role: Role::User,
799 status: MessageStatus::Done,
800 timestamp: first_message_id.0,
801 },
802 );
803 this.message_anchors.push(message);
804
805 this.set_language(cx);
806 this.count_remaining_tokens(cx);
807 this
808 }
809
810 fn serialize(&self, cx: &AppContext) -> SavedContext {
811 let buffer = self.buffer.read(cx);
812 SavedContext {
813 id: Some(self.id.clone()),
814 zed: "context".into(),
815 version: SavedContext::VERSION.into(),
816 text: buffer.text(),
817 messages: self
818 .messages(cx)
819 .map(|message| SavedMessage {
820 id: message.id,
821 start: message.offset_range.start,
822 metadata: self.messages_metadata[&message.id].clone(),
823 })
824 .collect(),
825 summary: self
826 .summary
827 .as_ref()
828 .map(|summary| summary.text.clone())
829 .unwrap_or_default(),
830 slash_command_output_sections: self
831 .slash_command_output_sections
832 .iter()
833 .filter_map(|section| {
834 let range = section.range.to_offset(buffer);
835 if section.range.start.is_valid(buffer) && !range.is_empty() {
836 Some(assistant_slash_command::SlashCommandOutputSection {
837 range,
838 icon: section.icon,
839 label: section.label.clone(),
840 })
841 } else {
842 None
843 }
844 })
845 .collect(),
846 }
847 }
848
849 #[allow(clippy::too_many_arguments)]
850 pub fn deserialize(
851 saved_context: SavedContext,
852 path: PathBuf,
853 language_registry: Arc<LanguageRegistry>,
854 telemetry: Option<Arc<Telemetry>>,
855 cx: &mut ModelContext<Self>,
856 ) -> Self {
857 let id = saved_context.id.clone().unwrap_or_else(|| ContextId::new());
858 let mut this = Self::new(
859 id,
860 ReplicaId::default(),
861 language::Capability::ReadWrite,
862 language_registry,
863 telemetry,
864 cx,
865 );
866 this.path = Some(path);
867 this.buffer.update(cx, |buffer, cx| {
868 buffer.set_text(saved_context.text.as_str(), cx)
869 });
870 let operations = saved_context.into_ops(&this.buffer, cx);
871 this.apply_ops(operations, cx).unwrap();
872 this
873 }
874
875 pub fn id(&self) -> &ContextId {
876 &self.id
877 }
878
879 pub fn replica_id(&self) -> ReplicaId {
880 self.timestamp.replica_id
881 }
882
883 pub fn version(&self, cx: &AppContext) -> ContextVersion {
884 ContextVersion {
885 context: self.version.clone(),
886 buffer: self.buffer.read(cx).version(),
887 }
888 }
889
890 pub fn set_capability(
891 &mut self,
892 capability: language::Capability,
893 cx: &mut ModelContext<Self>,
894 ) {
895 self.buffer
896 .update(cx, |buffer, cx| buffer.set_capability(capability, cx));
897 }
898
899 fn next_timestamp(&mut self) -> clock::Lamport {
900 let timestamp = self.timestamp.tick();
901 self.version.observe(timestamp);
902 timestamp
903 }
904
905 pub fn serialize_ops(
906 &self,
907 since: &ContextVersion,
908 cx: &AppContext,
909 ) -> Task<Vec<proto::ContextOperation>> {
910 let buffer_ops = self
911 .buffer
912 .read(cx)
913 .serialize_ops(Some(since.buffer.clone()), cx);
914
915 let mut context_ops = self
916 .operations
917 .iter()
918 .filter(|op| !since.context.observed(op.timestamp()))
919 .cloned()
920 .collect::<Vec<_>>();
921 context_ops.extend(self.pending_ops.iter().cloned());
922
923 cx.background_executor().spawn(async move {
924 let buffer_ops = buffer_ops.await;
925 context_ops.sort_unstable_by_key(|op| op.timestamp());
926 buffer_ops
927 .into_iter()
928 .map(|op| proto::ContextOperation {
929 variant: Some(proto::context_operation::Variant::BufferOperation(
930 proto::context_operation::BufferOperation {
931 operation: Some(op),
932 },
933 )),
934 })
935 .chain(context_ops.into_iter().map(|op| op.to_proto()))
936 .collect()
937 })
938 }
939
940 pub fn apply_ops(
941 &mut self,
942 ops: impl IntoIterator<Item = ContextOperation>,
943 cx: &mut ModelContext<Self>,
944 ) -> Result<()> {
945 let mut buffer_ops = Vec::new();
946 for op in ops {
947 match op {
948 ContextOperation::BufferOperation(buffer_op) => buffer_ops.push(buffer_op),
949 op @ _ => self.pending_ops.push(op),
950 }
951 }
952 self.buffer
953 .update(cx, |buffer, cx| buffer.apply_ops(buffer_ops, cx))?;
954 self.flush_ops(cx);
955
956 Ok(())
957 }
958
959 fn flush_ops(&mut self, cx: &mut ModelContext<Context>) {
960 let mut messages_changed = false;
961 let mut summary_changed = false;
962
963 self.pending_ops.sort_unstable_by_key(|op| op.timestamp());
964 for op in mem::take(&mut self.pending_ops) {
965 if !self.can_apply_op(&op, cx) {
966 self.pending_ops.push(op);
967 continue;
968 }
969
970 let timestamp = op.timestamp();
971 match op.clone() {
972 ContextOperation::InsertMessage {
973 anchor, metadata, ..
974 } => {
975 if self.messages_metadata.contains_key(&anchor.id) {
976 // We already applied this operation.
977 } else {
978 self.insert_message(anchor, metadata, cx);
979 messages_changed = true;
980 }
981 }
982 ContextOperation::UpdateMessage {
983 message_id,
984 metadata: new_metadata,
985 ..
986 } => {
987 let metadata = self.messages_metadata.get_mut(&message_id).unwrap();
988 if new_metadata.timestamp > metadata.timestamp {
989 *metadata = new_metadata;
990 messages_changed = true;
991 }
992 }
993 ContextOperation::UpdateSummary {
994 summary: new_summary,
995 ..
996 } => {
997 if self
998 .summary
999 .as_ref()
1000 .map_or(true, |summary| new_summary.timestamp > summary.timestamp)
1001 {
1002 self.summary = Some(new_summary);
1003 summary_changed = true;
1004 }
1005 }
1006 ContextOperation::SlashCommandFinished {
1007 id,
1008 output_range,
1009 sections,
1010 ..
1011 } => {
1012 if self.finished_slash_commands.insert(id) {
1013 let buffer = self.buffer.read(cx);
1014 self.slash_command_output_sections
1015 .extend(sections.iter().cloned());
1016 self.slash_command_output_sections
1017 .sort_by(|a, b| a.range.cmp(&b.range, buffer));
1018 cx.emit(ContextEvent::SlashCommandFinished {
1019 output_range,
1020 sections,
1021 run_commands_in_output: false,
1022 });
1023 }
1024 }
1025 ContextOperation::BufferOperation(_) => unreachable!(),
1026 }
1027
1028 self.version.observe(timestamp);
1029 self.timestamp.observe(timestamp);
1030 self.operations.push(op);
1031 }
1032
1033 if messages_changed {
1034 cx.emit(ContextEvent::MessagesEdited);
1035 cx.notify();
1036 }
1037
1038 if summary_changed {
1039 cx.emit(ContextEvent::SummaryChanged);
1040 cx.notify();
1041 }
1042 }
1043
1044 fn can_apply_op(&self, op: &ContextOperation, cx: &AppContext) -> bool {
1045 if !self.version.observed_all(op.version()) {
1046 return false;
1047 }
1048
1049 match op {
1050 ContextOperation::InsertMessage { anchor, .. } => self
1051 .buffer
1052 .read(cx)
1053 .version
1054 .observed(anchor.start.timestamp),
1055 ContextOperation::UpdateMessage { message_id, .. } => {
1056 self.messages_metadata.contains_key(message_id)
1057 }
1058 ContextOperation::UpdateSummary { .. } => true,
1059 ContextOperation::SlashCommandFinished {
1060 output_range,
1061 sections,
1062 ..
1063 } => {
1064 let version = &self.buffer.read(cx).version;
1065 sections
1066 .iter()
1067 .map(|section| §ion.range)
1068 .chain([output_range])
1069 .all(|range| {
1070 let observed_start = range.start == language::Anchor::MIN
1071 || range.start == language::Anchor::MAX
1072 || version.observed(range.start.timestamp);
1073 let observed_end = range.end == language::Anchor::MIN
1074 || range.end == language::Anchor::MAX
1075 || version.observed(range.end.timestamp);
1076 observed_start && observed_end
1077 })
1078 }
1079 ContextOperation::BufferOperation(_) => {
1080 panic!("buffer operations should always be applied")
1081 }
1082 }
1083 }
1084
1085 fn push_op(&mut self, op: ContextOperation, cx: &mut ModelContext<Self>) {
1086 self.operations.push(op.clone());
1087 cx.emit(ContextEvent::Operation(op));
1088 }
1089
    /// Returns the handle to the buffer that holds this context's text.
    pub fn buffer(&self) -> &Model<Buffer> {
        &self.buffer
    }
1093
1094 pub fn path(&self) -> Option<&Path> {
1095 self.path.as_deref()
1096 }
1097
    /// Returns the current summary (title) of this context, if one has been produced or set.
    pub fn summary(&self) -> Option<&ContextSummary> {
        self.summary.as_ref()
    }
1101
1102 pub fn edit_steps(&self) -> &[EditStep] {
1103 &self.edit_steps
1104 }
1105
1106 pub fn pending_slash_commands(&self) -> &[PendingSlashCommand] {
1107 &self.pending_slash_commands
1108 }
1109
1110 pub fn slash_command_output_sections(&self) -> &[SlashCommandOutputSection<language::Anchor>] {
1111 &self.slash_command_output_sections
1112 }
1113
1114 fn set_language(&mut self, cx: &mut ModelContext<Self>) {
1115 let markdown = self.language_registry.language_for_name("Markdown");
1116 cx.spawn(|this, mut cx| async move {
1117 let markdown = markdown.await?;
1118 this.update(&mut cx, |this, cx| {
1119 this.buffer
1120 .update(cx, |buffer, cx| buffer.set_language(Some(markdown), cx));
1121 })
1122 })
1123 .detach_and_log_err(cx);
1124 }
1125
1126 fn handle_buffer_event(
1127 &mut self,
1128 _: Model<Buffer>,
1129 event: &language::Event,
1130 cx: &mut ModelContext<Self>,
1131 ) {
1132 match event {
1133 language::Event::Operation(operation) => cx.emit(ContextEvent::Operation(
1134 ContextOperation::BufferOperation(operation.clone()),
1135 )),
1136 language::Event::Edited => {
1137 self.count_remaining_tokens(cx);
1138 self.reparse_slash_commands(cx);
1139 self.prune_invalid_edit_steps(cx);
1140 cx.emit(ContextEvent::MessagesEdited);
1141 }
1142 _ => {}
1143 }
1144 }
1145
    /// Returns the most recently computed token count for the completion request, if any.
    /// Updated asynchronously by `count_remaining_tokens`.
    pub(crate) fn token_count(&self) -> Option<usize> {
        self.token_count
    }
1149
1150 pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1151 let request = self.to_completion_request(cx);
1152 self.pending_token_count = cx.spawn(|this, mut cx| {
1153 async move {
1154 cx.background_executor()
1155 .timer(Duration::from_millis(200))
1156 .await;
1157
1158 let token_count = cx
1159 .update(|cx| {
1160 LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
1161 })?
1162 .await?;
1163 this.update(&mut cx, |this, cx| {
1164 this.token_count = Some(token_count);
1165 cx.notify()
1166 })
1167 }
1168 .log_err()
1169 });
1170 }
1171
    /// Re-parses slash commands on the lines touched since the last parse.
    ///
    /// Edits accumulated in `edits_since_last_slash_command_parse` are converted to row ranges,
    /// and touching or overlapping ranges are merged. For each merged range, the pending slash
    /// commands overlapping it are replaced (via `splice`) with the commands found by parsing
    /// those lines again. If any commands were added or replaced, emits
    /// `PendingSlashCommandsUpdated` with the replaced commands' source ranges and the newly
    /// parsed commands.
    pub fn reparse_slash_commands(&mut self, cx: &mut ModelContext<Self>) {
        let buffer = self.buffer.read(cx);
        // Map each edit to the half-open range of rows it covers in the new text.
        let mut row_ranges = self
            .edits_since_last_slash_command_parse
            .consume()
            .into_iter()
            .map(|edit| {
                let start_row = buffer.offset_to_point(edit.new.start).row;
                let end_row = buffer.offset_to_point(edit.new.end).row + 1;
                start_row..end_row
            })
            .peekable();

        let mut removed = Vec::new();
        let mut updated = Vec::new();
        while let Some(mut row_range) = row_ranges.next() {
            // Coalesce touching/overlapping row ranges so each region is parsed only once.
            while let Some(next_row_range) = row_ranges.peek() {
                if row_range.end >= next_row_range.start {
                    row_range.end = next_row_range.end;
                    row_ranges.next();
                } else {
                    break;
                }
            }

            // Expand to whole lines: start of the first row through the end of the last row.
            let start = buffer.anchor_before(Point::new(row_range.start, 0));
            let end = buffer.anchor_after(Point::new(
                row_range.end - 1,
                buffer.line_len(row_range.end - 1),
            ));

            // Indices of the existing pending commands that these lines replace.
            let old_range = self.pending_command_indices_for_range(start..end, cx);

            let mut new_commands = Vec::new();
            let mut lines = buffer.text_for_range(start..end).lines();
            let mut offset = lines.offset();
            while let Some(line) = lines.next() {
                if let Some(command_line) = SlashCommandLine::parse(line) {
                    let name = &line[command_line.name.clone()];
                    // An empty argument is treated the same as no argument.
                    let argument = command_line.argument.as_ref().and_then(|argument| {
                        (!argument.is_empty()).then_some(&line[argument.clone()])
                    });
                    // Only names registered in the slash command registry are tracked, and
                    // commands that require an argument are skipped until one is present.
                    if let Some(command) = SlashCommandRegistry::global(cx).command(name) {
                        if !command.requires_argument() || argument.is_some() {
                            // NOTE(review): the `- 1` presumably steps back over the leading `/`
                            // preceding the name — confirm against `SlashCommandLine::parse`.
                            let start_ix = offset + command_line.name.start - 1;
                            let end_ix = offset
                                + command_line
                                    .argument
                                    .map_or(command_line.name.end, |argument| argument.end);
                            let source_range =
                                buffer.anchor_after(start_ix)..buffer.anchor_after(end_ix);
                            let pending_command = PendingSlashCommand {
                                name: name.to_string(),
                                argument: argument.map(ToString::to_string),
                                source_range,
                                status: PendingSlashCommandStatus::Idle,
                            };
                            updated.push(pending_command.clone());
                            new_commands.push(pending_command);
                        }
                    }
                }

                offset = lines.offset();
            }

            let removed_commands = self.pending_slash_commands.splice(old_range, new_commands);
            removed.extend(removed_commands.map(|command| command.source_range));
        }

        if !updated.is_empty() || !removed.is_empty() {
            cx.emit(ContextEvent::PendingSlashCommandsUpdated { removed, updated });
        }
    }
1246
1247 fn prune_invalid_edit_steps(&mut self, cx: &mut ModelContext<Self>) {
1248 let buffer = self.buffer.read(cx);
1249 let prev_len = self.edit_steps.len();
1250 self.edit_steps.retain(|step| {
1251 step.source_range.start.is_valid(buffer) && step.source_range.end.is_valid(buffer)
1252 });
1253 if self.edit_steps.len() != prev_len {
1254 cx.emit(ContextEvent::EditStepsChanged);
1255 cx.notify();
1256 }
1257 }
1258
    /// Scans the buffer text in `range` for `<step>`…`</step>` regions and registers new edit
    /// steps.
    ///
    /// A step starts at the first `<step>` seen while not already inside a step and ends at the
    /// next `</step>`; both tags are matched per line. Steps whose source range already exists
    /// in `edit_steps` are skipped. Each new step is inserted at its sorted position, and
    /// `generate_edit_step_operations` is started to fill in its operations. Always emits
    /// `EditStepsChanged` and notifies observers, even when no step was added.
    fn parse_edit_steps_in_range(&mut self, range: Range<usize>, cx: &mut ModelContext<Self>) {
        let mut new_edit_steps = Vec::new();

        // `update` is used only to reach the buffer; nothing is edited here.
        self.buffer.update(cx, |buffer, _cx| {
            let mut message_lines = buffer.as_rope().chunks_in_range(range).lines();
            let mut in_step = false;
            let mut step_start = 0;
            let mut line_start_offset = message_lines.offset();

            while let Some(line) = message_lines.next() {
                if let Some(step_start_index) = line.find("<step>") {
                    if !in_step {
                        in_step = true;
                        step_start = line_start_offset + step_start_index;
                    }
                }

                if let Some(step_end_index) = line.find("</step>") {
                    if in_step {
                        // The range spans from the opening `<step>` through the closing tag.
                        let start_anchor = buffer.anchor_after(step_start);
                        let end_anchor = buffer
                            .anchor_before(line_start_offset + step_end_index + "</step>".len());
                        let source_range = start_anchor..end_anchor;

                        // Check if a step with the same range already exists
                        let existing_step_index = self.edit_steps.binary_search_by(|probe| {
                            probe.source_range.cmp(&source_range, buffer)
                        });

                        if let Err(ix) = existing_step_index {
                            // Step doesn't exist, so add it
                            new_edit_steps.push((
                                ix,
                                EditStep {
                                    source_range,
                                    operations: None,
                                },
                            ));
                        }

                        in_step = false;
                    }
                }

                line_start_offset = message_lines.offset();
            }
        });

        // Insert new steps and generate their corresponding tasks
        // The indices were computed against the unmodified `edit_steps`; inserting from the last
        // one backwards keeps the earlier indices valid.
        for (index, mut step) in new_edit_steps.into_iter().rev() {
            let task = self.generate_edit_step_operations(&step, cx);
            step.operations = Some(EditStepOperations::Pending(task));
            self.edit_steps.insert(index, step);
        }

        cx.emit(ContextEvent::EditStepsChanged);
        cx.notify();
    }
1317
    /// Asks the language model to turn an edit step's text into concrete edit operations.
    ///
    /// Builds a request from the current conversation, appends a user message made of the
    /// prompt store's operations prompt followed by the step's text, and invokes the `edit`
    /// tool. When the response arrives, the matching step (located by its source range) gets
    /// `EditStepOperations::Ready` and `EditStepsChanged` is emitted. Errors, including the step
    /// having been removed in the meantime ("edit step not found"), are logged and the task
    /// resolves to `None`.
    fn generate_edit_step_operations(
        &self,
        edit_step: &EditStep,
        cx: &mut ModelContext<Self>,
    ) -> Task<Option<()>> {
        // NOTE: the `///` comments on `EditTool` below are read by `JsonSchema` and presumably
        // end up in the tool schema sent to the model — edit them as prompt text, not as docs.
        #[derive(Debug, Deserialize, JsonSchema)]
        struct EditTool {
            /// A sequence of operations to apply to the codebase.
            /// When multiple operations are required for a step, be sure to include multiple operations in this list.
            operations: Vec<EditOperation>,
        }

        impl LanguageModelTool for EditTool {
            fn name() -> String {
                "edit".into()
            }

            fn description() -> String {
                "suggest edits to one or more locations in the codebase".into()
            }
        }

        let mut request = self.to_completion_request(cx);
        let edit_step_range = edit_step.source_range.clone();
        let step_text = self
            .buffer
            .read(cx)
            .text_for_range(edit_step_range.clone())
            .collect::<String>();

        cx.spawn(|this, mut cx| {
            async move {
                let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;

                let mut prompt = prompt_store.operations_prompt();
                prompt.push_str(&step_text);

                request.messages.push(LanguageModelRequestMessage {
                    role: Role::User,
                    content: prompt,
                });

                let tool_use = cx
                    .update(|cx| {
                        LanguageModelCompletionProvider::read_global(cx)
                            .use_tool::<EditTool>(request, cx)
                    })?
                    .await?;

                this.update(&mut cx, |this, cx| {
                    // Steps may have been inserted or pruned while waiting, so look the step up
                    // again by its source range rather than reusing an index.
                    let step_index = this
                        .edit_steps
                        .binary_search_by(|step| {
                            step.source_range
                                .cmp(&edit_step_range, this.buffer.read(cx))
                        })
                        .map_err(|_| anyhow!("edit step not found"))?;
                    if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
                        edit_step.operations = Some(EditStepOperations::Ready(tool_use.operations));
                        cx.emit(ContextEvent::EditStepsChanged);
                    }
                    anyhow::Ok(())
                })?
            }
            .log_err()
        })
    }
1385
1386 pub fn pending_command_for_position(
1387 &mut self,
1388 position: language::Anchor,
1389 cx: &mut ModelContext<Self>,
1390 ) -> Option<&mut PendingSlashCommand> {
1391 let buffer = self.buffer.read(cx);
1392 match self
1393 .pending_slash_commands
1394 .binary_search_by(|probe| probe.source_range.end.cmp(&position, buffer))
1395 {
1396 Ok(ix) => Some(&mut self.pending_slash_commands[ix]),
1397 Err(ix) => {
1398 let cmd = self.pending_slash_commands.get_mut(ix)?;
1399 if position.cmp(&cmd.source_range.start, buffer).is_ge()
1400 && position.cmp(&cmd.source_range.end, buffer).is_le()
1401 {
1402 Some(cmd)
1403 } else {
1404 None
1405 }
1406 }
1407 }
1408 }
1409
1410 pub fn pending_commands_for_range(
1411 &self,
1412 range: Range<language::Anchor>,
1413 cx: &AppContext,
1414 ) -> &[PendingSlashCommand] {
1415 let range = self.pending_command_indices_for_range(range, cx);
1416 &self.pending_slash_commands[range]
1417 }
1418
1419 fn pending_command_indices_for_range(
1420 &self,
1421 range: Range<language::Anchor>,
1422 cx: &AppContext,
1423 ) -> Range<usize> {
1424 let buffer = self.buffer.read(cx);
1425 let start_ix = match self
1426 .pending_slash_commands
1427 .binary_search_by(|probe| probe.source_range.end.cmp(&range.start, &buffer))
1428 {
1429 Ok(ix) | Err(ix) => ix,
1430 };
1431 let end_ix = match self
1432 .pending_slash_commands
1433 .binary_search_by(|probe| probe.source_range.start.cmp(&range.end, &buffer))
1434 {
1435 Ok(ix) => ix + 1,
1436 Err(ix) => ix,
1437 };
1438 start_ix..end_ix
1439 }
1440
    /// Runs a slash command's `output` task and writes its result into the buffer.
    ///
    /// On success the text in `command_range` is replaced by the output text (plus a trailing
    /// newline if `insert_trailing_newline`). The output's sections are anchored into the buffer
    /// and merged into `slash_command_output_sections`. A `SlashCommandFinished` operation is
    /// pushed, and a matching `ContextEvent::SlashCommandFinished` is emitted.
    ///
    /// On failure the pending command at `command_range.start` (if still present) is marked
    /// `Error` and `PendingSlashCommandsUpdated` is emitted. While the task runs, that pending
    /// command is marked `Running` and holds the shared task.
    pub fn insert_command_output(
        &mut self,
        command_range: Range<language::Anchor>,
        output: Task<Result<SlashCommandOutput>>,
        insert_trailing_newline: bool,
        cx: &mut ModelContext<Self>,
    ) {
        // Bring `pending_slash_commands` up to date so the lookup at the bottom sees the command.
        self.reparse_slash_commands(cx);

        let insert_output_task = cx.spawn(|this, mut cx| {
            let command_range = command_range.clone();
            async move {
                let output = output.await;
                this.update(&mut cx, |this, cx| match output {
                    Ok(mut output) => {
                        if insert_trailing_newline {
                            output.text.push('\n');
                        }

                        // Captured before the edit so the operation records the version it
                        // was produced against.
                        let version = this.version.clone();
                        let command_id = SlashCommandId(this.next_timestamp());
                        let (operation, event) = this.buffer.update(cx, |buffer, cx| {
                            let start = command_range.start.to_offset(buffer);
                            let old_end = command_range.end.to_offset(buffer);
                            let new_end = start + output.text.len();
                            buffer.edit([(start..old_end, output.text)], None, cx);

                            // Section ranges are relative to the output text; shift them by
                            // `start` and anchor them into the edited buffer.
                            let mut sections = output
                                .sections
                                .into_iter()
                                .map(|section| SlashCommandOutputSection {
                                    range: buffer.anchor_after(start + section.range.start)
                                        ..buffer.anchor_before(start + section.range.end),
                                    icon: section.icon,
                                    label: section.label,
                                })
                                .collect::<Vec<_>>();
                            sections.sort_by(|a, b| a.range.cmp(&b.range, buffer));

                            this.slash_command_output_sections
                                .extend(sections.iter().cloned());
                            this.slash_command_output_sections
                                .sort_by(|a, b| a.range.cmp(&b.range, buffer));

                            let output_range =
                                buffer.anchor_after(start)..buffer.anchor_before(new_end);
                            // Recorded locally so a replayed copy of this operation is ignored.
                            this.finished_slash_commands.insert(command_id);

                            (
                                ContextOperation::SlashCommandFinished {
                                    id: command_id,
                                    output_range: output_range.clone(),
                                    sections: sections.clone(),
                                    version,
                                },
                                ContextEvent::SlashCommandFinished {
                                    output_range,
                                    sections,
                                    run_commands_in_output: output.run_commands_in_text,
                                },
                            )
                        });

                        this.push_op(operation, cx);
                        cx.emit(event);
                    }
                    Err(error) => {
                        if let Some(pending_command) =
                            this.pending_command_for_position(command_range.start, cx)
                        {
                            pending_command.status =
                                PendingSlashCommandStatus::Error(error.to_string());
                            cx.emit(ContextEvent::PendingSlashCommandsUpdated {
                                removed: vec![pending_command.source_range.clone()],
                                updated: vec![pending_command.clone()],
                            });
                        }
                    }
                })
                .ok();
            }
        });

        // Store the task on the pending command (as `Running`) and report the status change.
        if let Some(pending_command) = self.pending_command_for_position(command_range.start, cx) {
            pending_command.status = PendingSlashCommandStatus::Running {
                _task: insert_output_task.shared(),
            };
            cx.emit(ContextEvent::PendingSlashCommandsUpdated {
                removed: vec![pending_command.source_range.clone()],
                updated: vec![pending_command.clone()],
            });
        }
    }
1534
    /// Called when the active completion provider changes; recounts tokens for the new provider.
    pub fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
        self.count_remaining_tokens(cx);
    }
1538
1539 pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
1540 let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1541 message
1542 .start
1543 .is_valid(self.buffer.read(cx))
1544 .then_some(message.id)
1545 })?;
1546
1547 if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
1548 log::info!("completion provider has no credentials");
1549 return None;
1550 }
1551
1552 let request = self.to_completion_request(cx);
1553 let stream =
1554 LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
1555 let assistant_message = self
1556 .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
1557 .unwrap();
1558
1559 // Queue up the user's next reply.
1560 let user_message = self
1561 .insert_message_after(assistant_message.id, Role::User, MessageStatus::Done, cx)
1562 .unwrap();
1563
1564 let task = cx.spawn({
1565 |this, mut cx| async move {
1566 let assistant_message_id = assistant_message.id;
1567 let mut response_latency = None;
1568 let stream_completion = async {
1569 let request_start = Instant::now();
1570 let mut chunks = stream.await?;
1571
1572 while let Some(chunk) = chunks.next().await {
1573 if response_latency.is_none() {
1574 response_latency = Some(request_start.elapsed());
1575 }
1576 let chunk = chunk?;
1577
1578 this.update(&mut cx, |this, cx| {
1579 let message_ix = this
1580 .message_anchors
1581 .iter()
1582 .position(|message| message.id == assistant_message_id)?;
1583 let message_range = this.buffer.update(cx, |buffer, cx| {
1584 let message_start_offset =
1585 this.message_anchors[message_ix].start.to_offset(buffer);
1586 let message_old_end_offset = this.message_anchors[message_ix + 1..]
1587 .iter()
1588 .find(|message| message.start.is_valid(buffer))
1589 .map_or(buffer.len(), |message| {
1590 message.start.to_offset(buffer).saturating_sub(1)
1591 });
1592 let message_new_end_offset = message_old_end_offset + chunk.len();
1593 buffer.edit(
1594 [(message_old_end_offset..message_old_end_offset, chunk)],
1595 None,
1596 cx,
1597 );
1598 message_start_offset..message_new_end_offset
1599 });
1600 this.parse_edit_steps_in_range(message_range, cx);
1601 cx.emit(ContextEvent::StreamedCompletion);
1602
1603 Some(())
1604 })?;
1605 smol::future::yield_now().await;
1606 }
1607
1608 this.update(&mut cx, |this, cx| {
1609 this.pending_completions
1610 .retain(|completion| completion.id != this.completion_count);
1611 this.summarize(false, cx);
1612 })?;
1613
1614 anyhow::Ok(())
1615 };
1616
1617 let result = stream_completion.await;
1618
1619 this.update(&mut cx, |this, cx| {
1620 let error_message = result
1621 .err()
1622 .map(|error| error.to_string().trim().to_string());
1623
1624 this.update_metadata(assistant_message_id, cx, |metadata| {
1625 if let Some(error_message) = error_message.as_ref() {
1626 metadata.status =
1627 MessageStatus::Error(SharedString::from(error_message.clone()));
1628 } else {
1629 metadata.status = MessageStatus::Done;
1630 }
1631 });
1632
1633 if let Some(telemetry) = this.telemetry.as_ref() {
1634 let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
1635 .active_model()
1636 .map(|m| m.telemetry_id())
1637 .unwrap_or_default();
1638 telemetry.report_assistant_event(
1639 Some(this.id.0.clone()),
1640 AssistantKind::Panel,
1641 model_telemetry_id,
1642 response_latency,
1643 error_message,
1644 );
1645 }
1646 })
1647 .ok();
1648 }
1649 });
1650
1651 self.pending_completions.push(PendingCompletion {
1652 id: post_inc(&mut self.completion_count),
1653 _task: task,
1654 });
1655
1656 Some(user_message)
1657 }
1658
1659 pub fn to_completion_request(&self, cx: &AppContext) -> LanguageModelRequest {
1660 let messages = self
1661 .messages(cx)
1662 .filter(|message| matches!(message.status, MessageStatus::Done))
1663 .map(|message| message.to_request_message(self.buffer.read(cx)));
1664
1665 LanguageModelRequest {
1666 messages: messages.collect(),
1667 stop: vec![],
1668 temperature: 1.0,
1669 }
1670 }
1671
    /// Drops the most recently pushed pending completion, if any, and returns whether one existed.
    /// NOTE(review): dropping the stored task is presumably what cancels the stream — confirm
    /// against gpui `Task` drop semantics.
    pub fn cancel_last_assist(&mut self) -> bool {
        self.pending_completions.pop().is_some()
    }
1675
1676 pub fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1677 for id in ids {
1678 if let Some(metadata) = self.messages_metadata.get(&id) {
1679 let role = metadata.role.cycle();
1680 self.update_metadata(id, cx, |metadata| metadata.role = role);
1681 }
1682 }
1683 }
1684
1685 pub fn update_metadata(
1686 &mut self,
1687 id: MessageId,
1688 cx: &mut ModelContext<Self>,
1689 f: impl FnOnce(&mut MessageMetadata),
1690 ) {
1691 let version = self.version.clone();
1692 let timestamp = self.next_timestamp();
1693 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1694 f(metadata);
1695 metadata.timestamp = timestamp;
1696 let operation = ContextOperation::UpdateMessage {
1697 message_id: id,
1698 metadata: metadata.clone(),
1699 version,
1700 };
1701 self.push_op(operation, cx);
1702 cx.emit(ContextEvent::MessagesEdited);
1703 cx.notify();
1704 }
1705 }
1706
1707 fn insert_message_after(
1708 &mut self,
1709 message_id: MessageId,
1710 role: Role,
1711 status: MessageStatus,
1712 cx: &mut ModelContext<Self>,
1713 ) -> Option<MessageAnchor> {
1714 if let Some(prev_message_ix) = self
1715 .message_anchors
1716 .iter()
1717 .position(|message| message.id == message_id)
1718 {
1719 // Find the next valid message after the one we were given.
1720 let mut next_message_ix = prev_message_ix + 1;
1721 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1722 if next_message.start.is_valid(self.buffer.read(cx)) {
1723 break;
1724 }
1725 next_message_ix += 1;
1726 }
1727
1728 let start = self.buffer.update(cx, |buffer, cx| {
1729 let offset = self
1730 .message_anchors
1731 .get(next_message_ix)
1732 .map_or(buffer.len(), |message| {
1733 buffer.clip_offset(message.start.to_offset(buffer) - 1, Bias::Left)
1734 });
1735 buffer.edit([(offset..offset, "\n")], None, cx);
1736 buffer.anchor_before(offset + 1)
1737 });
1738
1739 let version = self.version.clone();
1740 let anchor = MessageAnchor {
1741 id: MessageId(self.next_timestamp()),
1742 start,
1743 };
1744 let metadata = MessageMetadata {
1745 role,
1746 status,
1747 timestamp: anchor.id.0,
1748 };
1749 self.insert_message(anchor.clone(), metadata.clone(), cx);
1750 self.push_op(
1751 ContextOperation::InsertMessage {
1752 anchor: anchor.clone(),
1753 metadata,
1754 version,
1755 },
1756 cx,
1757 );
1758 Some(anchor)
1759 } else {
1760 None
1761 }
1762 }
1763
1764 pub fn split_message(
1765 &mut self,
1766 range: Range<usize>,
1767 cx: &mut ModelContext<Self>,
1768 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1769 let start_message = self.message_for_offset(range.start, cx);
1770 let end_message = self.message_for_offset(range.end, cx);
1771 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1772 // Prevent splitting when range spans multiple messages.
1773 if start_message.id != end_message.id {
1774 return (None, None);
1775 }
1776
1777 let message = start_message;
1778 let role = message.role;
1779 let mut edited_buffer = false;
1780
1781 let mut suffix_start = None;
1782 if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1783 {
1784 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1785 suffix_start = Some(range.end + 1);
1786 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1787 suffix_start = Some(range.end);
1788 }
1789 }
1790
1791 let version = self.version.clone();
1792 let suffix = if let Some(suffix_start) = suffix_start {
1793 MessageAnchor {
1794 id: MessageId(self.next_timestamp()),
1795 start: self.buffer.read(cx).anchor_before(suffix_start),
1796 }
1797 } else {
1798 self.buffer.update(cx, |buffer, cx| {
1799 buffer.edit([(range.end..range.end, "\n")], None, cx);
1800 });
1801 edited_buffer = true;
1802 MessageAnchor {
1803 id: MessageId(self.next_timestamp()),
1804 start: self.buffer.read(cx).anchor_before(range.end + 1),
1805 }
1806 };
1807
1808 let suffix_metadata = MessageMetadata {
1809 role,
1810 status: MessageStatus::Done,
1811 timestamp: suffix.id.0,
1812 };
1813 self.insert_message(suffix.clone(), suffix_metadata.clone(), cx);
1814 self.push_op(
1815 ContextOperation::InsertMessage {
1816 anchor: suffix.clone(),
1817 metadata: suffix_metadata,
1818 version,
1819 },
1820 cx,
1821 );
1822
1823 let new_messages =
1824 if range.start == range.end || range.start == message.offset_range.start {
1825 (None, Some(suffix))
1826 } else {
1827 let mut prefix_end = None;
1828 if range.start > message.offset_range.start
1829 && range.end < message.offset_range.end - 1
1830 {
1831 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1832 prefix_end = Some(range.start + 1);
1833 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1834 == Some('\n')
1835 {
1836 prefix_end = Some(range.start);
1837 }
1838 }
1839
1840 let version = self.version.clone();
1841 let selection = if let Some(prefix_end) = prefix_end {
1842 MessageAnchor {
1843 id: MessageId(self.next_timestamp()),
1844 start: self.buffer.read(cx).anchor_before(prefix_end),
1845 }
1846 } else {
1847 self.buffer.update(cx, |buffer, cx| {
1848 buffer.edit([(range.start..range.start, "\n")], None, cx)
1849 });
1850 edited_buffer = true;
1851 MessageAnchor {
1852 id: MessageId(self.next_timestamp()),
1853 start: self.buffer.read(cx).anchor_before(range.end + 1),
1854 }
1855 };
1856
1857 let selection_metadata = MessageMetadata {
1858 role,
1859 status: MessageStatus::Done,
1860 timestamp: selection.id.0,
1861 };
1862 self.insert_message(selection.clone(), selection_metadata.clone(), cx);
1863 self.push_op(
1864 ContextOperation::InsertMessage {
1865 anchor: selection.clone(),
1866 metadata: selection_metadata,
1867 version,
1868 },
1869 cx,
1870 );
1871
1872 (Some(selection), Some(suffix))
1873 };
1874
1875 if !edited_buffer {
1876 cx.emit(ContextEvent::MessagesEdited);
1877 }
1878 new_messages
1879 } else {
1880 (None, None)
1881 }
1882 }
1883
1884 fn insert_message(
1885 &mut self,
1886 new_anchor: MessageAnchor,
1887 new_metadata: MessageMetadata,
1888 cx: &mut ModelContext<Self>,
1889 ) {
1890 cx.emit(ContextEvent::MessagesEdited);
1891
1892 self.messages_metadata.insert(new_anchor.id, new_metadata);
1893
1894 let buffer = self.buffer.read(cx);
1895 let insertion_ix = self
1896 .message_anchors
1897 .iter()
1898 .position(|anchor| {
1899 let comparison = new_anchor.start.cmp(&anchor.start, buffer);
1900 comparison.is_lt() || (comparison.is_eq() && new_anchor.id > anchor.id)
1901 })
1902 .unwrap_or(self.message_anchors.len());
1903 self.message_anchors.insert(insertion_ix, new_anchor);
1904 }
1905
    /// Asks the language model for a short title for this context and streams it into `summary`.
    ///
    /// Runs when `replace_old` is set, or when there are at least two messages and no summary
    /// yet, and only if the completion provider is authenticated. With `replace_old`, the
    /// existing text is cleared when the first chunk arrives.
    ///
    /// For each streamed chunk, only its first line is appended. Streaming stops once a chunk
    /// contains more than one line. Every change bumps the summary's timestamp, pushes an
    /// `UpdateSummary` operation, and emits `SummaryChanged`. At the end the summary is marked
    /// `done`. Errors are logged.
    pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
        if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
            if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
                return;
            }

            // Every message (regardless of status) plus the summarization instruction.
            let messages = self
                .messages(cx)
                .map(|message| message.to_request_message(self.buffer.read(cx)))
                .chain(Some(LanguageModelRequestMessage {
                    role: Role::User,
                    content: "Summarize the context into a short title without punctuation.".into(),
                }));
            let request = LanguageModelRequest {
                messages: messages.collect(),
                stop: vec![],
                temperature: 1.0,
            };

            let stream =
                LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
            self.pending_summary = cx.spawn(|this, mut cx| {
                async move {
                    let mut messages = stream.await?;

                    // Tracks whether the old text has been cleared (only relevant with `replace_old`).
                    let mut replaced = !replace_old;
                    while let Some(message) = messages.next().await {
                        let text = message?;
                        let mut lines = text.lines();
                        this.update(&mut cx, |this, cx| {
                            let version = this.version.clone();
                            let timestamp = this.next_timestamp();
                            let summary = this.summary.get_or_insert(ContextSummary::default());
                            if !replaced && replace_old {
                                summary.text.clear();
                                replaced = true;
                            }
                            summary.text.extend(lines.next());
                            summary.timestamp = timestamp;
                            let operation = ContextOperation::UpdateSummary {
                                summary: summary.clone(),
                                version,
                            };
                            this.push_op(operation, cx);
                            cx.emit(ContextEvent::SummaryChanged);
                        })?;

                        // Stop if the LLM generated multiple lines.
                        if lines.next().is_some() {
                            break;
                        }
                    }

                    this.update(&mut cx, |this, cx| {
                        let version = this.version.clone();
                        let timestamp = this.next_timestamp();
                        if let Some(summary) = this.summary.as_mut() {
                            summary.done = true;
                            summary.timestamp = timestamp;
                            let operation = ContextOperation::UpdateSummary {
                                summary: summary.clone(),
                                version,
                            };
                            this.push_op(operation, cx);
                            cx.emit(ContextEvent::SummaryChanged);
                        }
                    })?;

                    anyhow::Ok(())
                }
                .log_err()
            });
        }
    }
1980
1981 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1982 self.messages_for_offsets([offset], cx).pop()
1983 }
1984
    /// Returns the messages containing the given buffer offsets, one entry per distinct message.
    ///
    /// Messages are scanned forward only, so `offsets` are expected in ascending order
    /// (NOTE(review): not checked here). Consecutive offsets in the same message yield that
    /// message once. An offset beyond every message resolves to the last message. Once the last
    /// message is reached, all remaining offsets are folded into it.
    pub fn messages_for_offsets(
        &self,
        offsets: impl IntoIterator<Item = usize>,
        cx: &AppContext,
    ) -> Vec<Message> {
        let mut result = Vec::new();

        let mut messages = self.messages(cx).peekable();
        let mut offsets = offsets.into_iter().peekable();
        let mut current_message = messages.next();
        while let Some(offset) = offsets.next() {
            // Locate the message that contains the offset.
            // (Never advances past the last message, which acts as a catch-all.)
            while current_message.as_ref().map_or(false, |message| {
                !message.offset_range.contains(&offset) && messages.peek().is_some()
            }) {
                current_message = messages.next();
            }
            let Some(message) = current_message.as_ref() else {
                break;
            };

            // Skip offsets that are in the same message.
            // On the last message every remaining offset is skipped, so it is reported only once.
            while offsets.peek().map_or(false, |offset| {
                message.offset_range.contains(offset) || messages.peek().is_none()
            }) {
                offsets.next();
            }

            result.push(message.clone());
        }
        result
    }
2017
    /// Iterates over the context's messages in anchor order, resolving each to buffer offsets.
    ///
    /// A message spans from its own start to the start of the next anchor that is still valid
    /// (or to the end of the buffer). Following anchors whose start is no longer valid are
    /// absorbed into the preceding message and counted in its `index_range`.
    ///
    /// Iteration ends early if an anchor has no entry in `messages_metadata`.
    pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
        let buffer = self.buffer.read(cx);
        let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
        iter::from_fn(move || {
            if let Some((start_ix, message_anchor)) = message_anchors.next() {
                // `?` here ends the whole iterator (from_fn stops at the first `None`).
                let metadata = self.messages_metadata.get(&message_anchor.id)?;
                let message_start = message_anchor.start.to_offset(buffer);
                let mut message_end = None;
                let mut end_ix = start_ix;
                // Consume anchors whose start was deleted until the next valid one.
                while let Some((_, next_message)) = message_anchors.peek() {
                    if next_message.start.is_valid(buffer) {
                        message_end = Some(next_message.start);
                        break;
                    } else {
                        end_ix += 1;
                        message_anchors.next();
                    }
                }
                let message_end = message_end
                    .unwrap_or(language::Anchor::MAX)
                    .to_offset(buffer);

                return Some(Message {
                    index_range: start_ix..end_ix,
                    offset_range: message_start..message_end,
                    id: message_anchor.id,
                    anchor: message_anchor.start,
                    role: metadata.role,
                    status: metadata.status.clone(),
                });
            }
            None
        })
    }
2052
2053 pub fn save(
2054 &mut self,
2055 debounce: Option<Duration>,
2056 fs: Arc<dyn Fs>,
2057 cx: &mut ModelContext<Context>,
2058 ) {
2059 if self.replica_id() != ReplicaId::default() {
2060 // Prevent saving a remote context for now.
2061 return;
2062 }
2063
2064 self.pending_save = cx.spawn(|this, mut cx| async move {
2065 if let Some(debounce) = debounce {
2066 cx.background_executor().timer(debounce).await;
2067 }
2068
2069 let (old_path, summary) = this.read_with(&cx, |this, _| {
2070 let path = this.path.clone();
2071 let summary = if let Some(summary) = this.summary.as_ref() {
2072 if summary.done {
2073 Some(summary.text.clone())
2074 } else {
2075 None
2076 }
2077 } else {
2078 None
2079 };
2080 (path, summary)
2081 })?;
2082
2083 if let Some(summary) = summary {
2084 let context = this.read_with(&cx, |this, cx| this.serialize(cx))?;
2085 let mut discriminant = 1;
2086 let mut new_path;
2087 loop {
2088 new_path = contexts_dir().join(&format!(
2089 "{} - {}.zed.json",
2090 summary.trim(),
2091 discriminant
2092 ));
2093 if fs.is_file(&new_path).await {
2094 discriminant += 1;
2095 } else {
2096 break;
2097 }
2098 }
2099
2100 fs.create_dir(contexts_dir().as_ref()).await?;
2101 fs.atomic_write(new_path.clone(), serde_json::to_string(&context).unwrap())
2102 .await?;
2103 if let Some(old_path) = old_path {
2104 if new_path != old_path {
2105 fs.remove_file(
2106 &old_path,
2107 RemoveOptions {
2108 recursive: false,
2109 ignore_if_not_exists: true,
2110 },
2111 )
2112 .await?;
2113 }
2114 }
2115
2116 this.update(&mut cx, |this, _| this.path = Some(new_path))?;
2117 }
2118
2119 Ok(())
2120 });
2121 }
2122
2123 pub(crate) fn custom_summary(&mut self, custom_summary: String, cx: &mut ModelContext<Self>) {
2124 let timestamp = self.next_timestamp();
2125 let summary = self.summary.get_or_insert(ContextSummary::default());
2126 summary.timestamp = timestamp;
2127 summary.done = true;
2128 summary.text = custom_summary;
2129 cx.emit(ContextEvent::SummaryChanged);
2130 }
2131}
2132
/// A pair of version vectors describing how much of a context a replica has seen.
#[derive(Debug, Default)]
pub struct ContextVersion {
    /// Version of the context-level operations (messages, summary, slash command results).
    context: clock::Global,
    /// Version of the underlying text buffer.
    buffer: clock::Global,
}
2138
2139impl ContextVersion {
2140 pub fn from_proto(proto: &proto::ContextVersion) -> Self {
2141 Self {
2142 context: language::proto::deserialize_version(&proto.context_version),
2143 buffer: language::proto::deserialize_version(&proto.buffer_version),
2144 }
2145 }
2146
2147 pub fn to_proto(&self, context_id: ContextId) -> proto::ContextVersion {
2148 proto::ContextVersion {
2149 context_id: context_id.to_proto(),
2150 context_version: language::proto::serialize_version(&self.context),
2151 buffer_version: language::proto::serialize_version(&self.buffer),
2152 }
2153 }
2154}
2155
/// A slash command recognized in the context buffer, together with its
/// execution status.
#[derive(Clone)]
pub struct PendingSlashCommand {
    /// The command's name.
    pub name: String,
    /// Argument text following the command name, if any.
    pub argument: Option<String>,
    /// Whether the command is idle, running, or failed.
    pub status: PendingSlashCommandStatus,
    /// Buffer range covering the command invocation text (e.g. `/file src/lib.rs`).
    pub source_range: Range<language::Anchor>,
}
2163
/// Execution state of a [`PendingSlashCommand`].
#[derive(Clone)]
pub enum PendingSlashCommandStatus {
    /// The command is not currently running.
    Idle,
    /// The command is running. The shared task handle is held here so the work
    /// stays owned by the pending command (presumably dropping it cancels the
    /// task, per gpui `Task` semantics — confirm).
    Running { _task: Shared<Task<()>> },
    /// The command failed with the given message.
    Error(String),
}
2170
/// A single message as persisted in the current saved-context format
/// (see [`SavedContext::VERSION`]).
#[derive(Serialize, Deserialize)]
pub struct SavedMessage {
    /// Lamport-timestamp-based message id.
    pub id: MessageId,
    /// Offset into the saved `text` at which this message begins.
    pub start: usize,
    /// Role, status, and timestamp of the message.
    pub metadata: MessageMetadata,
}
2177
/// On-disk JSON representation of a context in the current format.
///
/// Older formats are parsed and upgraded into this shape by
/// [`SavedContext::from_json`].
#[derive(Serialize, Deserialize)]
pub struct SavedContext {
    /// Context id; optional because it may be absent in older saves.
    pub id: Option<ContextId>,
    /// Value of the `zed` JSON field.
    pub zed: String,
    /// Format version string; the current format uses [`SavedContext::VERSION`].
    pub version: String,
    /// Full buffer text of the context.
    pub text: String,
    /// Messages, each anchored by an offset into `text`.
    pub messages: Vec<SavedMessage>,
    /// Summary text of the context.
    pub summary: String,
    /// Slash-command output sections, with ranges as offsets into `text`.
    pub slash_command_output_sections:
        Vec<assistant_slash_command::SlashCommandOutputSection<usize>>,
}
2189
2190impl SavedContext {
2191 pub const VERSION: &'static str = "0.4.0";
2192
2193 pub fn from_json(json: &str) -> Result<Self> {
2194 let saved_context_json = serde_json::from_str::<serde_json::Value>(json)?;
2195 match saved_context_json
2196 .get("version")
2197 .ok_or_else(|| anyhow!("version not found"))?
2198 {
2199 serde_json::Value::String(version) => match version.as_str() {
2200 SavedContext::VERSION => {
2201 Ok(serde_json::from_value::<SavedContext>(saved_context_json)?)
2202 }
2203 SavedContextV0_3_0::VERSION => {
2204 let saved_context =
2205 serde_json::from_value::<SavedContextV0_3_0>(saved_context_json)?;
2206 Ok(saved_context.upgrade())
2207 }
2208 SavedContextV0_2_0::VERSION => {
2209 let saved_context =
2210 serde_json::from_value::<SavedContextV0_2_0>(saved_context_json)?;
2211 Ok(saved_context.upgrade())
2212 }
2213 SavedContextV0_1_0::VERSION => {
2214 let saved_context =
2215 serde_json::from_value::<SavedContextV0_1_0>(saved_context_json)?;
2216 Ok(saved_context.upgrade())
2217 }
2218 _ => Err(anyhow!("unrecognized saved context version: {}", version)),
2219 },
2220 _ => Err(anyhow!("version not found on saved context")),
2221 }
2222 }
2223
2224 fn into_ops(
2225 self,
2226 buffer: &Model<Buffer>,
2227 cx: &mut ModelContext<Context>,
2228 ) -> Vec<ContextOperation> {
2229 let mut operations = Vec::new();
2230 let mut version = clock::Global::new();
2231 let mut next_timestamp = clock::Lamport::new(ReplicaId::default());
2232
2233 let mut first_message_metadata = None;
2234 for message in self.messages {
2235 if message.id == MessageId(clock::Lamport::default()) {
2236 first_message_metadata = Some(message.metadata);
2237 } else {
2238 operations.push(ContextOperation::InsertMessage {
2239 anchor: MessageAnchor {
2240 id: message.id,
2241 start: buffer.read(cx).anchor_before(message.start),
2242 },
2243 metadata: MessageMetadata {
2244 role: message.metadata.role,
2245 status: message.metadata.status,
2246 timestamp: message.metadata.timestamp,
2247 },
2248 version: version.clone(),
2249 });
2250 version.observe(message.id.0);
2251 next_timestamp.observe(message.id.0);
2252 }
2253 }
2254
2255 if let Some(metadata) = first_message_metadata {
2256 let timestamp = next_timestamp.tick();
2257 operations.push(ContextOperation::UpdateMessage {
2258 message_id: MessageId(clock::Lamport::default()),
2259 metadata: MessageMetadata {
2260 role: metadata.role,
2261 status: metadata.status,
2262 timestamp,
2263 },
2264 version: version.clone(),
2265 });
2266 version.observe(timestamp);
2267 }
2268
2269 let timestamp = next_timestamp.tick();
2270 operations.push(ContextOperation::SlashCommandFinished {
2271 id: SlashCommandId(timestamp),
2272 output_range: language::Anchor::MIN..language::Anchor::MAX,
2273 sections: self
2274 .slash_command_output_sections
2275 .into_iter()
2276 .map(|section| {
2277 let buffer = buffer.read(cx);
2278 SlashCommandOutputSection {
2279 range: buffer.anchor_after(section.range.start)
2280 ..buffer.anchor_before(section.range.end),
2281 icon: section.icon,
2282 label: section.label,
2283 }
2284 })
2285 .collect(),
2286 version: version.clone(),
2287 });
2288 version.observe(timestamp);
2289
2290 let timestamp = next_timestamp.tick();
2291 operations.push(ContextOperation::UpdateSummary {
2292 summary: ContextSummary {
2293 text: self.summary,
2294 done: true,
2295 timestamp,
2296 },
2297 version: version.clone(),
2298 });
2299 version.observe(timestamp);
2300
2301 operations
2302 }
2303}
2304
/// Message id used by saved-context formats before 0.4.0: a plain integer.
///
/// `SavedContextV0_3_0::upgrade` turns it into a Lamport timestamp with the
/// default replica id.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
struct SavedMessageIdPreV0_4_0(usize);
2307
/// Message record in pre-0.4.0 formats. Metadata was stored separately, in a
/// `message_metadata` map keyed by this record's `id`.
#[derive(Serialize, Deserialize)]
struct SavedMessagePreV0_4_0 {
    id: SavedMessageIdPreV0_4_0,
    /// Offset into the saved text where the message begins.
    start: usize,
}
2313
/// Per-message metadata in pre-0.4.0 formats. Unlike the current
/// `MessageMetadata`, it carries no timestamp.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
struct SavedMessageMetadataPreV0_4_0 {
    role: Role,
    status: MessageStatus,
}
2319
/// Saved-context schema for format version 0.3.0. It has integer message ids,
/// a separate metadata map, and slash-command output sections.
#[derive(Serialize, Deserialize)]
struct SavedContextV0_3_0 {
    id: Option<ContextId>,
    zed: String,
    version: String,
    text: String,
    messages: Vec<SavedMessagePreV0_4_0>,
    /// Metadata for each entry of `messages`, keyed by message id.
    message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
    summary: String,
    slash_command_output_sections: Vec<assistant_slash_command::SlashCommandOutputSection<usize>>,
}
2331
2332impl SavedContextV0_3_0 {
2333 const VERSION: &'static str = "0.3.0";
2334
2335 fn upgrade(self) -> SavedContext {
2336 SavedContext {
2337 id: self.id,
2338 zed: self.zed,
2339 version: SavedContext::VERSION.into(),
2340 text: self.text,
2341 messages: self
2342 .messages
2343 .into_iter()
2344 .filter_map(|message| {
2345 let metadata = self.message_metadata.get(&message.id)?;
2346 let timestamp = clock::Lamport {
2347 replica_id: ReplicaId::default(),
2348 value: message.id.0 as u32,
2349 };
2350 Some(SavedMessage {
2351 id: MessageId(timestamp),
2352 start: message.start,
2353 metadata: MessageMetadata {
2354 role: metadata.role,
2355 status: metadata.status.clone(),
2356 timestamp,
2357 },
2358 })
2359 })
2360 .collect(),
2361 summary: self.summary,
2362 slash_command_output_sections: self.slash_command_output_sections,
2363 }
2364 }
2365}
2366
/// Saved-context schema for format version 0.2.0. It is the same as 0.3.0
/// without slash-command output sections.
#[derive(Serialize, Deserialize)]
struct SavedContextV0_2_0 {
    id: Option<ContextId>,
    zed: String,
    version: String,
    text: String,
    messages: Vec<SavedMessagePreV0_4_0>,
    /// Metadata for each entry of `messages`, keyed by message id.
    message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
    summary: String,
}
2377
2378impl SavedContextV0_2_0 {
2379 const VERSION: &'static str = "0.2.0";
2380
2381 fn upgrade(self) -> SavedContext {
2382 SavedContextV0_3_0 {
2383 id: self.id,
2384 zed: self.zed,
2385 version: SavedContextV0_3_0::VERSION.to_string(),
2386 text: self.text,
2387 messages: self.messages,
2388 message_metadata: self.message_metadata,
2389 summary: self.summary,
2390 slash_command_output_sections: Vec::new(),
2391 }
2392 .upgrade()
2393 }
2394}
2395
/// Saved-context schema for format version 0.1.0. It is the same as 0.2.0
/// plus OpenAI-specific `api_url` and `model` fields, which later formats dropped.
#[derive(Serialize, Deserialize)]
struct SavedContextV0_1_0 {
    id: Option<ContextId>,
    zed: String,
    version: String,
    text: String,
    messages: Vec<SavedMessagePreV0_4_0>,
    /// Metadata for each entry of `messages`, keyed by message id.
    message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
    summary: String,
    api_url: Option<String>,
    model: OpenAiModel,
}
2408
2409impl SavedContextV0_1_0 {
2410 const VERSION: &'static str = "0.1.0";
2411
2412 fn upgrade(self) -> SavedContext {
2413 SavedContextV0_2_0 {
2414 id: self.id,
2415 zed: self.zed,
2416 version: SavedContextV0_2_0::VERSION.to_string(),
2417 text: self.text,
2418 messages: self.messages,
2419 message_metadata: self.message_metadata,
2420 summary: self.summary,
2421 }
2422 .upgrade()
2423 }
2424}
2425
/// Descriptive information about a saved context file (presumably used to list
/// saved contexts — confirm against callers).
#[derive(Clone)]
pub struct SavedContextMetadata {
    /// Display title of the saved context.
    pub title: String,
    /// Location of the saved context file.
    pub path: PathBuf,
    /// Modification time of the file, in local time.
    pub mtime: chrono::DateTime<chrono::Local>,
}
2432
2433#[cfg(test)]
2434mod tests {
2435 use super::*;
2436 use crate::{
2437 assistant_panel, prompt_library,
2438 slash_command::{active_command, file_command},
2439 MessageId,
2440 };
2441 use assistant_slash_command::{ArgumentCompletion, SlashCommand};
2442 use fs::FakeFs;
2443 use gpui::{AppContext, TestAppContext, WeakView};
2444 use indoc::indoc;
2445 use language::LspAdapterDelegate;
2446 use parking_lot::Mutex;
2447 use project::Project;
2448 use rand::prelude::*;
2449 use serde_json::json;
2450 use settings::SettingsStore;
2451 use std::{cell::RefCell, env, rc::Rc, sync::atomic::AtomicBool};
2452 use text::{network::Network, ToPoint};
2453 use ui::WindowContext;
2454 use unindent::Unindent;
2455 use util::{test::marked_text_ranges, RandomCharIter};
2456 use workspace::Workspace;
2457
    // Exercises inserting messages, buffer edits that grow message ranges, and
    // merging messages when a deletion crosses a message boundary (including
    // undo/redo of that deletion).
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let settings_store = SettingsStore::test(cx);
        language_model::LanguageModelRegistry::test(cx);
        completion::LanguageModelCompletionProvider::test(cx);
        cx.set_global(settings_store);
        assistant_panel::init(cx);
        let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));

        let context = cx.new_model(|cx| Context::local(registry, None, cx));
        let buffer = context.read(cx).buffer.clone();

        // A new context starts with a single empty user message.
        let message_1 = context.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&context, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting a message grows the buffer by one character, which now
        // terminates message 1.
        let message_2 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        // Typing inside each message extends that message's range.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        let message_3 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message 2 again places the new message directly after
        // message 2, ahead of message 3.
        let message_4 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }
2589
    // Exercises `split_message` with empty ranges (a cursor position) and a
    // non-empty range, checking both the resulting buffer text and the message
    // ranges. `split_message` returns a pair of optional new messages; for empty
    // ranges only the second one is used here.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        language_model::LanguageModelRegistry::test(cx);
        completion::LanguageModelCompletionProvider::test(cx);
        assistant_panel::init(cx);
        let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));

        let context = cx.new_model(|cx| Context::local(registry, None, cx));
        let buffer = context.read(cx).buffer.clone();

        let message_1 = context.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&context, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message: the existing
        // newline at offset 3 becomes the boundary, so the text is unchanged.
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message: a new newline
        // is inserted, yielding a one-character message.
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting on a non-empty range yields two new messages: one starting at
        // the range start and one following it.
        let (message_6, message_7) =
            context.update(cx, |context, cx| context.split_message(14..16, cx));
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }
2684
    // Checks that `messages_for_offsets` maps buffer offsets to the messages
    // containing them, yielding each message once even when several offsets fall
    // inside it, and that an empty trailing message is found at the buffer end.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let settings_store = SettingsStore::test(cx);
        language_model::LanguageModelRegistry::test(cx);
        completion::LanguageModelCompletionProvider::test(cx);
        cx.set_global(settings_store);
        assistant_panel::init(cx);
        let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
        let context = cx.new_model(|cx| Context::local(registry, None, cx));
        let buffer = context.read(cx).buffer.clone();

        let message_1 = context.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&context, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = context
            .update(cx, |context, cx| {
                context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = context
            .update(cx, |context, cx| {
                context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&context, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Offsets 0 and 1 both fall in message 1, which is reported only once.
        assert_eq!(
            message_ids_for_offsets(&context, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        let message_4 = context
            .update(cx, |context, cx| {
                context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..12),
                (message_4.id, Role::User, 12..12)
            ]
        );
        assert_eq!(
            message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
            [message_1.id, message_2.id, message_3.id, message_4.id]
        );

        // Returns the ids of the messages `messages_for_offsets` yields for `offsets`.
        fn message_ids_for_offsets(
            context: &Model<Context>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            context
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }
2769
    // Typing a registered slash command into the buffer surfaces it as a pending
    // command whose source range tracks edits to its argument. Renaming it to an
    // unregistered command removes it from the pending set.
    #[gpui::test]
    async fn test_slash_commands(cx: &mut TestAppContext) {
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(completion::LanguageModelCompletionProvider::test);
        cx.update(Project::init_settings);
        cx.update(assistant_panel::init);
        let fs = FakeFs::new(cx.background_executor.clone());

        fs.insert_tree(
            "/test",
            json!({
                "src": {
                    "lib.rs": "fn one() -> usize { 1 }",
                    "main.rs": "
                        use crate::one;
                        fn main() { one(); }
                    ".unindent(),
                }
            }),
        )
        .await;

        let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
        slash_command_registry.register_command(file_command::FileSlashCommand, false);
        slash_command_registry.register_command(active_command::ActiveSlashCommand, false);

        let registry = Arc::new(LanguageRegistry::test(cx.executor()));
        let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));

        // Mirror the context's pending slash commands into a set of source ranges,
        // driven by `PendingSlashCommandsUpdated` events.
        let output_ranges = Rc::new(RefCell::new(HashSet::default()));
        context.update(cx, |_, cx| {
            cx.subscribe(&context, {
                let ranges = output_ranges.clone();
                move |_, _, event, _| match event {
                    ContextEvent::PendingSlashCommandsUpdated { removed, updated } => {
                        for range in removed {
                            ranges.borrow_mut().remove(range);
                        }
                        for command in updated {
                            ranges.borrow_mut().insert(command.source_range.clone());
                        }
                    }
                    _ => {}
                }
            })
            .detach();
        });

        let buffer = context.read_with(cx, |context, _| context.buffer.clone());

        // Insert a slash command
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
        });
        assert_text_and_output_ranges(
            &buffer,
            &output_ranges.borrow(),
            "
            «/file src/lib.rs»
            "
            .unindent()
            .trim_end(),
            cx,
        );

        // Edit the argument of the slash command.
        buffer.update(cx, |buffer, cx| {
            let edit_offset = buffer.text().find("lib.rs").unwrap();
            buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
        });
        assert_text_and_output_ranges(
            &buffer,
            &output_ranges.borrow(),
            "
            «/file src/main.rs»
            "
            .unindent()
            .trim_end(),
            cx,
        );

        // Edit the name of the slash command, using one that doesn't exist.
        buffer.update(cx, |buffer, cx| {
            let edit_offset = buffer.text().find("/file").unwrap();
            buffer.edit(
                [(edit_offset..edit_offset + "/file".len(), "/unknown")],
                None,
                cx,
            );
        });
        assert_text_and_output_ranges(
            &buffer,
            &output_ranges.borrow(),
            "
            /unknown src/main.rs
            "
            .unindent()
            .trim_end(),
            cx,
        );

        // Asserts that the buffer text and the tracked ranges (as sorted offsets)
        // match `expected_marked_text`, where `«…»` marks each expected range.
        #[track_caller]
        fn assert_text_and_output_ranges(
            buffer: &Model<Buffer>,
            ranges: &HashSet<Range<language::Anchor>>,
            expected_marked_text: &str,
            cx: &mut TestAppContext,
        ) {
            let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false);
            let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| {
                let mut ranges = ranges
                    .iter()
                    .map(|range| range.to_offset(buffer))
                    .collect::<Vec<_>>();
                ranges.sort_by_key(|a| a.start);
                (buffer.text(), ranges)
            });

            assert_eq!(actual_text, expected_text);
            assert_eq!(actual_ranges, expected_ranges);
        }
    }
2894
    // Streams a fake assistant response containing two `<step>…</step>` blocks and
    // checks that each is recorded as an edit step whose source range spans from
    // its opening `<step>` line to the end of its `</step>` line.
    #[gpui::test]
    async fn test_edit_step_parsing(cx: &mut TestAppContext) {
        cx.update(prompt_library::init);
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);

        let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
        cx.update(completion::LanguageModelCompletionProvider::test);

        let fake_model = fake_provider.test_model();
        cx.update(assistant_panel::init);
        let registry = Arc::new(LanguageRegistry::test(cx.executor()));

        // Create a new context
        let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
        let buffer = context.read_with(cx, |context, _| context.buffer.clone());

        // Simulate user input
        let user_message = indoc! {r#"
            Please refactor this code:

            fn main() {
                println!("Hello, World!");
            }
        "#};
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, user_message)], None, cx);
        });

        // Simulate LLM response with edit steps
        let llm_response = indoc! {r#"
            Sure, I can help you refactor that code. Here's a step-by-step process:

            <step>
            First, let's extract the greeting into a separate function:

            ```rust
            fn greet() {
                println!("Hello, World!");
            }

            fn main() {
                greet();
            }
            ```
            </step>

            <step>
            Now, let's make the greeting customizable:

            ```rust
            fn greet(name: &str) {
                println!("Hello, {}!", name);
            }

            fn main() {
                greet("World");
            }
            ```
            </step>

            These changes make the code more modular and flexible.
        "#};

        // Simulate the assist method to trigger the LLM response
        context.update(cx, |context, cx| context.assist(cx));
        cx.run_until_parked();

        // Retrieve the assistant response message's start from the context
        let response_start_row = context.read_with(cx, |context, cx| {
            let buffer = context.buffer.read(cx);
            context.message_anchors[1].start.to_point(buffer).row
        });

        // Simulate the LLM completion
        fake_model.send_last_completion_chunk(llm_response.to_string());
        fake_model.finish_last_completion();

        // Wait for the completion to be processed
        cx.run_until_parked();

        // Verify that the edit steps were parsed correctly. Rows are relative to the
        // response start; column 7 is the end of a `</step>` line.
        context.read_with(cx, |context, cx| {
            assert_eq!(
                edit_steps(context, cx),
                vec![
                    Point::new(response_start_row + 2, 0)..Point::new(response_start_row + 14, 7),
                    Point::new(response_start_row + 16, 0)..Point::new(response_start_row + 28, 7),
                ]
            );
        });

        // Returns each edit step's source range as buffer points.
        fn edit_steps(context: &Context, cx: &AppContext) -> Vec<Range<Point>> {
            context
                .edit_steps
                .iter()
                .map(|step| {
                    let buffer = context.buffer.read(cx);
                    step.source_range.to_point(buffer)
                })
                .collect()
        }
    }
2998
    // Round-trips a context through `serialize`/`deserialize` and checks that the
    // buffer text and message list (ids, roles, ranges) are preserved.
    #[gpui::test]
    async fn test_serialization(cx: &mut TestAppContext) {
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(completion::LanguageModelCompletionProvider::test);
        cx.update(assistant_panel::init);
        let registry = Arc::new(LanguageRegistry::test(cx.executor()));
        let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
        let buffer = context.read_with(cx, |context, _| context.buffer.clone());
        let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
        let message_1 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        let message_2 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // Finalizing the transaction keeps this typing separate from the next edit,
        // so the undo below reverts only the later insertion.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            buffer.finalize_last_transaction();
        });
        let _message_3 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // Per the assertions below, undo removes the fourth message, leaving three.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
        assert_eq!(
            cx.read(|cx| messages(&context, cx)),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );

        let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
        let deserialized_context = cx.new_model(|cx| {
            Context::deserialize(
                serialized_context,
                Default::default(),
                registry.clone(),
                None,
                cx,
            )
        });
        let deserialized_buffer =
            deserialized_context.read_with(cx, |context, _| context.buffer.clone());
        assert_eq!(
            deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
            "a\nb\nc\n"
        );
        assert_eq!(
            cx.read(|cx| messages(&deserialized_context, cx)),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );
    }
3065
3066 #[gpui::test(iterations = 100)]
3067 async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
3068 let min_peers = env::var("MIN_PEERS")
3069 .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
3070 .unwrap_or(2);
3071 let max_peers = env::var("MAX_PEERS")
3072 .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
3073 .unwrap_or(5);
3074 let operations = env::var("OPERATIONS")
3075 .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
3076 .unwrap_or(50);
3077
3078 let settings_store = cx.update(SettingsStore::test);
3079 cx.set_global(settings_store);
3080 cx.update(language_model::LanguageModelRegistry::test);
3081 cx.update(completion::LanguageModelCompletionProvider::test);
3082
3083 cx.update(assistant_panel::init);
3084 let slash_commands = cx.update(SlashCommandRegistry::default_global);
3085 slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
3086 slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
3087 slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
3088
3089 let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
3090 let network = Arc::new(Mutex::new(Network::new(rng.clone())));
3091 let mut contexts = Vec::new();
3092
3093 let num_peers = rng.gen_range(min_peers..=max_peers);
3094 let context_id = ContextId::new();
3095 for i in 0..num_peers {
3096 let context = cx.new_model(|cx| {
3097 Context::new(
3098 context_id.clone(),
3099 i as ReplicaId,
3100 language::Capability::ReadWrite,
3101 registry.clone(),
3102 None,
3103 cx,
3104 )
3105 });
3106
3107 cx.update(|cx| {
3108 cx.subscribe(&context, {
3109 let network = network.clone();
3110 move |_, event, _| {
3111 if let ContextEvent::Operation(op) = event {
3112 network
3113 .lock()
3114 .broadcast(i as ReplicaId, vec![op.to_proto()]);
3115 }
3116 }
3117 })
3118 .detach();
3119 });
3120
3121 contexts.push(context);
3122 network.lock().add_peer(i as ReplicaId);
3123 }
3124
3125 let mut mutation_count = operations;
3126
3127 while mutation_count > 0
3128 || !network.lock().is_idle()
3129 || network.lock().contains_disconnected_peers()
3130 {
3131 let context_index = rng.gen_range(0..contexts.len());
3132 let context = &contexts[context_index];
3133
3134 match rng.gen_range(0..100) {
3135 0..=29 if mutation_count > 0 => {
3136 log::info!("Context {}: edit buffer", context_index);
3137 context.update(cx, |context, cx| {
3138 context
3139 .buffer
3140 .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
3141 });
3142 mutation_count -= 1;
3143 }
3144 30..=44 if mutation_count > 0 => {
3145 context.update(cx, |context, cx| {
3146 let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
3147 log::info!("Context {}: split message at {:?}", context_index, range);
3148 context.split_message(range, cx);
3149 });
3150 mutation_count -= 1;
3151 }
3152 45..=59 if mutation_count > 0 => {
3153 context.update(cx, |context, cx| {
3154 if let Some(message) = context.messages(cx).choose(&mut rng) {
3155 let role = *[Role::User, Role::Assistant, Role::System]
3156 .choose(&mut rng)
3157 .unwrap();
3158 log::info!(
3159 "Context {}: insert message after {:?} with {:?}",
3160 context_index,
3161 message.id,
3162 role
3163 );
3164 context.insert_message_after(message.id, role, MessageStatus::Done, cx);
3165 }
3166 });
3167 mutation_count -= 1;
3168 }
3169 60..=74 if mutation_count > 0 => {
3170 context.update(cx, |context, cx| {
3171 let command_text = "/".to_string()
3172 + slash_commands
3173 .command_names()
3174 .choose(&mut rng)
3175 .unwrap()
3176 .clone()
3177 .as_ref();
3178
3179 let command_range = context.buffer.update(cx, |buffer, cx| {
3180 let offset = buffer.random_byte_range(0, &mut rng).start;
3181 buffer.edit(
3182 [(offset..offset, format!("\n{}\n", command_text))],
3183 None,
3184 cx,
3185 );
3186 offset + 1..offset + 1 + command_text.len()
3187 });
3188
3189 let output_len = rng.gen_range(1..=10);
3190 let output_text = RandomCharIter::new(&mut rng)
3191 .filter(|c| *c != '\r')
3192 .take(output_len)
3193 .collect::<String>();
3194
3195 let num_sections = rng.gen_range(0..=3);
3196 let mut sections = Vec::with_capacity(num_sections);
3197 for _ in 0..num_sections {
3198 let section_start = rng.gen_range(0..output_len);
3199 let section_end = rng.gen_range(section_start..=output_len);
3200 sections.push(SlashCommandOutputSection {
3201 range: section_start..section_end,
3202 icon: ui::IconName::Ai,
3203 label: "section".into(),
3204 });
3205 }
3206
3207 log::info!(
3208 "Context {}: insert slash command output at {:?} with {:?}",
3209 context_index,
3210 command_range,
3211 sections
3212 );
3213
3214 let command_range =
3215 context.buffer.read(cx).anchor_after(command_range.start)
3216 ..context.buffer.read(cx).anchor_after(command_range.end);
3217 context.insert_command_output(
3218 command_range,
3219 Task::ready(Ok(SlashCommandOutput {
3220 text: output_text,
3221 sections,
3222 run_commands_in_text: false,
3223 })),
3224 true,
3225 cx,
3226 );
3227 });
3228 cx.run_until_parked();
3229 mutation_count -= 1;
3230 }
3231 75..=84 if mutation_count > 0 => {
3232 context.update(cx, |context, cx| {
3233 if let Some(message) = context.messages(cx).choose(&mut rng) {
3234 let new_status = match rng.gen_range(0..3) {
3235 0 => MessageStatus::Done,
3236 1 => MessageStatus::Pending,
3237 _ => MessageStatus::Error(SharedString::from("Random error")),
3238 };
3239 log::info!(
3240 "Context {}: update message {:?} status to {:?}",
3241 context_index,
3242 message.id,
3243 new_status
3244 );
3245 context.update_metadata(message.id, cx, |metadata| {
3246 metadata.status = new_status;
3247 });
3248 }
3249 });
3250 mutation_count -= 1;
3251 }
3252 _ => {
3253 let replica_id = context_index as ReplicaId;
3254 if network.lock().is_disconnected(replica_id) {
3255 network.lock().reconnect_peer(replica_id, 0);
3256
3257 let (ops_to_send, ops_to_receive) = cx.read(|cx| {
3258 let host_context = &contexts[0].read(cx);
3259 let guest_context = context.read(cx);
3260 (
3261 guest_context.serialize_ops(&host_context.version(cx), cx),
3262 host_context.serialize_ops(&guest_context.version(cx), cx),
3263 )
3264 });
3265 let ops_to_send = ops_to_send.await;
3266 let ops_to_receive = ops_to_receive
3267 .await
3268 .into_iter()
3269 .map(ContextOperation::from_proto)
3270 .collect::<Result<Vec<_>>>()
3271 .unwrap();
3272 log::info!(
3273 "Context {}: reconnecting. Sent {} operations, received {} operations",
3274 context_index,
3275 ops_to_send.len(),
3276 ops_to_receive.len()
3277 );
3278
3279 network.lock().broadcast(replica_id, ops_to_send);
3280 context
3281 .update(cx, |context, cx| context.apply_ops(ops_to_receive, cx))
3282 .unwrap();
3283 } else if rng.gen_bool(0.1) && replica_id != 0 {
3284 log::info!("Context {}: disconnecting", context_index);
3285 network.lock().disconnect_peer(replica_id);
3286 } else if network.lock().has_unreceived(replica_id) {
3287 log::info!("Context {}: applying operations", context_index);
3288 let ops = network.lock().receive(replica_id);
3289 let ops = ops
3290 .into_iter()
3291 .map(ContextOperation::from_proto)
3292 .collect::<Result<Vec<_>>>()
3293 .unwrap();
3294 context
3295 .update(cx, |context, cx| context.apply_ops(ops, cx))
3296 .unwrap();
3297 }
3298 }
3299 }
3300 }
3301
3302 cx.read(|cx| {
3303 let first_context = contexts[0].read(cx);
3304 for context in &contexts[1..] {
3305 let context = context.read(cx);
3306 assert!(context.pending_ops.is_empty());
3307 assert_eq!(
3308 context.buffer.read(cx).text(),
3309 first_context.buffer.read(cx).text(),
3310 "Context {} text != Context 0 text",
3311 context.buffer.read(cx).replica_id()
3312 );
3313 assert_eq!(
3314 context.message_anchors,
3315 first_context.message_anchors,
3316 "Context {} messages != Context 0 messages",
3317 context.buffer.read(cx).replica_id()
3318 );
3319 assert_eq!(
3320 context.messages_metadata,
3321 first_context.messages_metadata,
3322 "Context {} message metadata != Context 0 message metadata",
3323 context.buffer.read(cx).replica_id()
3324 );
3325 assert_eq!(
3326 context.slash_command_output_sections,
3327 first_context.slash_command_output_sections,
3328 "Context {} slash command output sections != Context 0 slash command output sections",
3329 context.buffer.read(cx).replica_id()
3330 );
3331 }
3332 });
3333 }
3334
3335 fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
3336 context
3337 .read(cx)
3338 .messages(cx)
3339 .map(|message| (message.id, message.role, message.offset_range))
3340 .collect()
3341 }
3342
3343 #[derive(Clone)]
3344 struct FakeSlashCommand(String);
3345
3346 impl SlashCommand for FakeSlashCommand {
3347 fn name(&self) -> String {
3348 self.0.clone()
3349 }
3350
3351 fn description(&self) -> String {
3352 format!("Fake slash command: {}", self.0)
3353 }
3354
3355 fn menu_text(&self) -> String {
3356 format!("Run fake command: {}", self.0)
3357 }
3358
3359 fn complete_argument(
3360 self: Arc<Self>,
3361 _query: String,
3362 _cancel: Arc<AtomicBool>,
3363 _workspace: Option<WeakView<Workspace>>,
3364 _cx: &mut AppContext,
3365 ) -> Task<Result<Vec<ArgumentCompletion>>> {
3366 Task::ready(Ok(vec![]))
3367 }
3368
3369 fn requires_argument(&self) -> bool {
3370 false
3371 }
3372
3373 fn run(
3374 self: Arc<Self>,
3375 _argument: Option<&str>,
3376 _workspace: WeakView<Workspace>,
3377 _delegate: Arc<dyn LspAdapterDelegate>,
3378 _cx: &mut WindowContext,
3379 ) -> Task<Result<SlashCommandOutput>> {
3380 Task::ready(Ok(SlashCommandOutput {
3381 text: format!("Executed fake command: {}", self.0),
3382 sections: vec![],
3383 run_commands_in_text: false,
3384 }))
3385 }
3386 }
3387}