1use crate::{
2 prompts::PromptBuilder, slash_command::SlashCommandLine, workflow::WorkflowStep, MessageId,
3 MessageStatus,
4};
5use anyhow::{anyhow, Context as _, Result};
6use assistant_slash_command::{
7 SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
8};
9use client::{self, proto, telemetry::Telemetry};
10use clock::ReplicaId;
11use collections::{HashMap, HashSet};
12use fs::{Fs, RemoveOptions};
13use futures::{future::Shared, stream::FuturesUnordered, FutureExt, StreamExt};
14use gpui::{
15 AppContext, Context as _, EventEmitter, Image, Model, ModelContext, RenderImage, SharedString,
16 Subscription, Task,
17};
18
19use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
20use language_model::{
21 LanguageModelImage, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
22 Role,
23};
24use open_ai::Model as OpenAiModel;
25use paths::{context_images_dir, contexts_dir};
26use project::Project;
27use serde::{Deserialize, Serialize};
28use smallvec::SmallVec;
29use std::{
30 cmp::Ordering,
31 collections::hash_map,
32 fmt::Debug,
33 iter, mem,
34 ops::Range,
35 path::{Path, PathBuf},
36 sync::Arc,
37 time::{Duration, Instant},
38};
39use telemetry_events::AssistantKind;
40use util::{post_inc, ResultExt, TryFutureExt};
41use uuid::Uuid;
42
43#[derive(Clone, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
44pub struct ContextId(String);
45
46impl ContextId {
47 pub fn new() -> Self {
48 Self(Uuid::new_v4().to_string())
49 }
50
51 pub fn from_proto(id: String) -> Self {
52 Self(id)
53 }
54
55 pub fn to_proto(&self) -> String {
56 self.0.clone()
57 }
58}
59
/// A replicated change to a [`Context`].
///
/// Every variant except `BufferOperation` carries the context `version` that
/// must have been observed locally before the operation can be applied
/// (see `Context::can_apply_op`).
#[derive(Clone, Debug)]
pub enum ContextOperation {
    /// Inserts a message that begins at `anchor.start`.
    InsertMessage {
        anchor: MessageAnchor,
        metadata: MessageMetadata,
        version: clock::Global,
    },
    /// Replaces an existing message's metadata when `metadata.timestamp` is
    /// newer than the one currently stored.
    UpdateMessage {
        message_id: MessageId,
        metadata: MessageMetadata,
        version: clock::Global,
    },
    /// Replaces the summary when `summary.timestamp` is newer than the current
    /// summary's, or when there is no summary yet.
    UpdateSummary {
        summary: ContextSummary,
        version: clock::Global,
    },
    /// Records the output range and sections of a completed slash command.
    /// Applied at most once per `id`.
    SlashCommandFinished {
        id: SlashCommandId,
        output_range: Range<language::Anchor>,
        sections: Vec<SlashCommandOutputSection<language::Anchor>>,
        version: clock::Global,
    },
    /// An operation on the context's underlying text buffer.
    BufferOperation(language::Operation),
}
84
impl ContextOperation {
    /// Decodes an operation received over the wire.
    ///
    /// # Errors
    ///
    /// Returns an error if the variant or any required field is missing, or if
    /// an anchor, anchor range, icon name, or buffer operation fails to decode.
    pub fn from_proto(op: proto::ContextOperation) -> Result<Self> {
        match op.variant.context("invalid variant")? {
            proto::context_operation::Variant::InsertMessage(insert) => {
                let message = insert.message.context("invalid message")?;
                let id = MessageId(language::proto::deserialize_timestamp(
                    message.id.context("invalid id")?,
                ));
                Ok(Self::InsertMessage {
                    anchor: MessageAnchor {
                        id,
                        start: language::proto::deserialize_anchor(
                            message.start.context("invalid anchor")?,
                        )
                        .context("invalid anchor")?,
                    },
                    metadata: MessageMetadata {
                        role: Role::from_proto(message.role),
                        status: MessageStatus::from_proto(
                            message.status.context("invalid status")?,
                        ),
                        // `to_proto` does not send a metadata timestamp for
                        // insertions, so the message id's timestamp is reused.
                        timestamp: id.0,
                    },
                    version: language::proto::deserialize_version(&insert.version),
                })
            }
            proto::context_operation::Variant::UpdateMessage(update) => Ok(Self::UpdateMessage {
                message_id: MessageId(language::proto::deserialize_timestamp(
                    update.message_id.context("invalid message id")?,
                )),
                metadata: MessageMetadata {
                    role: Role::from_proto(update.role),
                    status: MessageStatus::from_proto(update.status.context("invalid status")?),
                    timestamp: language::proto::deserialize_timestamp(
                        update.timestamp.context("invalid timestamp")?,
                    ),
                },
                version: language::proto::deserialize_version(&update.version),
            }),
            proto::context_operation::Variant::UpdateSummary(update) => Ok(Self::UpdateSummary {
                summary: ContextSummary {
                    text: update.summary,
                    done: update.done,
                    timestamp: language::proto::deserialize_timestamp(
                        update.timestamp.context("invalid timestamp")?,
                    ),
                },
                version: language::proto::deserialize_version(&update.version),
            }),
            proto::context_operation::Variant::SlashCommandFinished(finished) => {
                Ok(Self::SlashCommandFinished {
                    id: SlashCommandId(language::proto::deserialize_timestamp(
                        finished.id.context("invalid id")?,
                    )),
                    output_range: language::proto::deserialize_anchor_range(
                        finished.output_range.context("invalid range")?,
                    )?,
                    sections: finished
                        .sections
                        .into_iter()
                        .map(|section| {
                            Ok(SlashCommandOutputSection {
                                range: language::proto::deserialize_anchor_range(
                                    section.range.context("invalid range")?,
                                )?,
                                // Icons travel as their string name (see `to_proto`).
                                icon: section.icon_name.parse()?,
                                label: section.label.into(),
                            })
                        })
                        // Any malformed section fails the whole operation.
                        .collect::<Result<Vec<_>>>()?,
                    version: language::proto::deserialize_version(&finished.version),
                })
            }
            proto::context_operation::Variant::BufferOperation(op) => Ok(Self::BufferOperation(
                language::proto::deserialize_operation(
                    op.operation.context("invalid buffer operation")?,
                )?,
            )),
        }
    }

    /// Encodes this operation for the wire.
    ///
    /// `InsertMessage` omits `metadata.timestamp`; `from_proto` rebuilds it
    /// from the message id.
    pub fn to_proto(&self) -> proto::ContextOperation {
        match self {
            Self::InsertMessage {
                anchor,
                metadata,
                version,
            } => proto::ContextOperation {
                variant: Some(proto::context_operation::Variant::InsertMessage(
                    proto::context_operation::InsertMessage {
                        message: Some(proto::ContextMessage {
                            id: Some(language::proto::serialize_timestamp(anchor.id.0)),
                            start: Some(language::proto::serialize_anchor(&anchor.start)),
                            role: metadata.role.to_proto() as i32,
                            status: Some(metadata.status.to_proto()),
                        }),
                        version: language::proto::serialize_version(version),
                    },
                )),
            },
            Self::UpdateMessage {
                message_id,
                metadata,
                version,
            } => proto::ContextOperation {
                variant: Some(proto::context_operation::Variant::UpdateMessage(
                    proto::context_operation::UpdateMessage {
                        message_id: Some(language::proto::serialize_timestamp(message_id.0)),
                        role: metadata.role.to_proto() as i32,
                        status: Some(metadata.status.to_proto()),
                        timestamp: Some(language::proto::serialize_timestamp(metadata.timestamp)),
                        version: language::proto::serialize_version(version),
                    },
                )),
            },
            Self::UpdateSummary { summary, version } => proto::ContextOperation {
                variant: Some(proto::context_operation::Variant::UpdateSummary(
                    proto::context_operation::UpdateSummary {
                        summary: summary.text.clone(),
                        done: summary.done,
                        timestamp: Some(language::proto::serialize_timestamp(summary.timestamp)),
                        version: language::proto::serialize_version(version),
                    },
                )),
            },
            Self::SlashCommandFinished {
                id,
                output_range,
                sections,
                version,
            } => proto::ContextOperation {
                variant: Some(proto::context_operation::Variant::SlashCommandFinished(
                    proto::context_operation::SlashCommandFinished {
                        id: Some(language::proto::serialize_timestamp(id.0)),
                        output_range: Some(language::proto::serialize_anchor_range(
                            output_range.clone(),
                        )),
                        sections: sections
                            .iter()
                            .map(|section| {
                                // Icons are sent by name; `from_proto` parses them back.
                                let icon_name: &'static str = section.icon.into();
                                proto::SlashCommandOutputSection {
                                    range: Some(language::proto::serialize_anchor_range(
                                        section.range.clone(),
                                    )),
                                    icon_name: icon_name.to_string(),
                                    label: section.label.to_string(),
                                }
                            })
                            .collect(),
                        version: language::proto::serialize_version(version),
                    },
                )),
            },
            Self::BufferOperation(operation) => proto::ContextOperation {
                variant: Some(proto::context_operation::Variant::BufferOperation(
                    proto::context_operation::BufferOperation {
                        operation: Some(language::proto::serialize_operation(operation)),
                    },
                )),
            },
        }
    }

    /// Lamport timestamp of this operation. Used to order pending operations
    /// and to decide whether a peer has already observed an operation.
    ///
    /// # Panics
    ///
    /// Panics for `BufferOperation`, which has no context-level timestamp.
    fn timestamp(&self) -> clock::Lamport {
        match self {
            Self::InsertMessage { anchor, .. } => anchor.id.0,
            Self::UpdateMessage { metadata, .. } => metadata.timestamp,
            Self::UpdateSummary { summary, .. } => summary.timestamp,
            Self::SlashCommandFinished { id, .. } => id.0,
            Self::BufferOperation(_) => {
                panic!("reading the timestamp of a buffer operation is not supported")
            }
        }
    }

    /// Returns the context version this operation depends on.
    ///
    /// # Panics
    ///
    /// Panics for `BufferOperation`, which carries no context version.
    pub fn version(&self) -> &clock::Global {
        match self {
            Self::InsertMessage { version, .. }
            | Self::UpdateMessage { version, .. }
            | Self::UpdateSummary { version, .. }
            | Self::SlashCommandFinished { version, .. } => version,
            Self::BufferOperation(_) => {
                panic!("reading the version of a buffer operation is not supported")
            }
        }
    }
}
274
/// Events emitted by a [`Context`].
#[derive(Debug, Clone)]
pub enum ContextEvent {
    /// NOTE(review): not emitted in this section of the file; presumably an
    /// error message to surface to the user.
    ShowAssistError(SharedString),
    /// Message text or message metadata changed.
    MessagesEdited,
    /// The summary was replaced by a newer one.
    SummaryChanged,
    /// Ranges of workflow steps that were discarded because an edit touched them.
    WorkflowStepsRemoved(Vec<Range<language::Anchor>>),
    /// The workflow step with this range is being resolved.
    WorkflowStepUpdated(Range<language::Anchor>),
    /// NOTE(review): not emitted in this section of the file.
    StreamedCompletion,
    /// Slash commands were re-parsed. `removed` holds the source ranges of the
    /// commands that were dropped from re-parsed rows; `updated` holds the newly
    /// parsed commands.
    PendingSlashCommandsUpdated {
        removed: Vec<Range<language::Anchor>>,
        updated: Vec<PendingSlashCommand>,
    },
    /// A slash command's output was recorded.
    SlashCommandFinished {
        output_range: Range<language::Anchor>,
        sections: Vec<SlashCommandOutputSection<language::Anchor>>,
        run_commands_in_output: bool,
    },
    /// A context operation. Emitted by `push_op`, and also used to forward
    /// buffer operations reported by the underlying buffer.
    Operation(ContextOperation),
}
294
/// A context's summary text together with its replication metadata.
#[derive(Clone, Default, Debug)]
pub struct ContextSummary {
    pub text: String,
    /// NOTE(review): presumably whether summary generation has finished; in
    /// this section it is only copied to and from the wire.
    done: bool,
    /// Lamport timestamp; when summaries conflict, the newer one wins.
    timestamp: clock::Lamport,
}
301
/// Identifies a message and the buffer position where it begins.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MessageAnchor {
    pub id: MessageId,
    pub start: language::Anchor,
}
307
/// Replicated per-message metadata.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct MessageMetadata {
    pub role: Role,
    status: MessageStatus,
    /// Lamport timestamp of the last update. An `UpdateMessage` operation only
    /// wins if its timestamp is newer.
    timestamp: clock::Lamport,
}
314
315#[derive(Clone, Debug)]
316pub struct MessageImage {
317 image_id: u64,
318 image: Shared<Task<Option<LanguageModelImage>>>,
319}
320
321impl PartialEq for MessageImage {
322 fn eq(&self, other: &Self) -> bool {
323 self.image_id == other.image_id
324 }
325}
326
327impl Eq for MessageImage {}
328
/// A resolved view of one message: its position in the buffer plus its role
/// and status.
#[derive(Clone, Debug)]
pub struct Message {
    /// Images attached to the message, keyed by buffer offset. NOTE(review):
    /// presumably sorted ascending within `offset_range`;
    /// `to_request_message` slices text between consecutive offsets.
    pub image_offsets: SmallVec<[(usize, MessageImage); 1]>,
    /// Byte range of the message's text in the buffer.
    pub offset_range: Range<usize>,
    pub index_range: Range<usize>,
    pub id: MessageId,
    pub anchor: language::Anchor,
    pub role: Role,
    pub status: MessageStatus,
}
339
340impl Message {
341 fn to_request_message(&self, buffer: &Buffer) -> LanguageModelRequestMessage {
342 let mut content = Vec::new();
343
344 let mut range_start = self.offset_range.start;
345 for (image_offset, message_image) in self.image_offsets.iter() {
346 if *image_offset != range_start {
347 content.push(
348 buffer
349 .text_for_range(range_start..*image_offset)
350 .collect::<String>()
351 .into(),
352 )
353 }
354
355 if let Some(image) = message_image.image.clone().now_or_never().flatten() {
356 content.push(language_model::MessageContent::Image(image));
357 }
358
359 range_start = *image_offset;
360 }
361 if range_start != self.offset_range.end {
362 content.push(
363 buffer
364 .text_for_range(range_start..self.offset_range.end)
365 .collect::<String>()
366 .into(),
367 )
368 }
369
370 LanguageModelRequestMessage {
371 role: self.role,
372 content,
373 }
374 }
375}
376
/// An image pinned to a buffer position. It holds both the rendered form and
/// the task producing the language-model form.
#[derive(Clone, Debug)]
pub struct ImageAnchor {
    pub anchor: language::Anchor,
    pub image_id: u64,
    pub render_image: Arc<RenderImage>,
    pub image: Shared<Task<Option<LanguageModelImage>>>,
}
384
/// A completion request in flight.
struct PendingCompletion {
    id: usize,
    /// The assistant message that receives this completion.
    assistant_message_id: MessageId,
    /// Held only so the task stays alive for as long as this entry exists.
    _task: Task<()>,
}
390
/// Identifies a slash command invocation by its Lamport timestamp.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct SlashCommandId(clock::Lamport);
393
/// A parsed workflow step.
///
/// `range` covers the step's content, excluding the `<step>`/`</step>` tags.
/// `Context::workflow_steps` keeps entries sorted by this range.
struct WorkflowStepEntry {
    range: Range<language::Anchor>,
    step: Model<WorkflowStep>,
}
398
/// A conversation with a language model.
///
/// The conversation is stored as a single text buffer partitioned into
/// messages, and replicated through [`ContextOperation`]s.
pub struct Context {
    id: ContextId,
    /// This replica's Lamport clock. It is advanced by `next_timestamp` and by
    /// observing applied operations.
    timestamp: clock::Lamport,
    /// Timestamps of all context operations observed so far.
    version: clock::Global,
    /// Received operations whose dependencies have not yet been observed.
    pending_ops: Vec<ContextOperation>,
    /// Applied and locally produced operations, replayed by `serialize_ops`.
    operations: Vec<ContextOperation>,
    buffer: Model<Buffer>,
    /// Slash commands parsed from the buffer by `reparse_slash_commands`.
    pending_slash_commands: Vec<PendingSlashCommand>,
    /// Buffer edits not yet consumed by `reparse_slash_commands`.
    edits_since_last_slash_command_parse: language::Subscription,
    /// Ids of slash commands whose `SlashCommandFinished` has been applied.
    finished_slash_commands: HashSet<SlashCommandId>,
    /// Output sections of finished slash commands, kept sorted by range.
    slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
    message_anchors: Vec<MessageAnchor>,
    /// Images by id: the rendered image and the `LanguageModelImage` task.
    images: HashMap<u64, (Arc<RenderImage>, Shared<Task<Option<LanguageModelImage>>>)>,
    image_anchors: Vec<ImageAnchor>,
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    summary: Option<ContextSummary>,
    pending_summary: Task<Option<()>>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    /// Most recent token count of the full request, once computed.
    token_count: Option<usize>,
    /// Token-count task; replaced on each `count_remaining_tokens` call.
    pending_token_count: Task<Option<()>>,
    pending_save: Task<Result<()>>,
    /// File the context was deserialized from, if any.
    path: Option<PathBuf>,
    _subscriptions: Vec<Subscription>,
    telemetry: Option<Arc<Telemetry>>,
    language_registry: Arc<LanguageRegistry>,
    /// Parsed workflow steps, kept sorted by range (binary-searched).
    workflow_steps: Vec<WorkflowStepEntry>,
    /// Buffer edits not yet consumed by `prune_invalid_workflow_steps`.
    edits_since_last_workflow_step_prune: language::Subscription,
    project: Option<Model<Project>>,
    prompt_builder: Arc<PromptBuilder>,
}

// `Context` emits `ContextEvent`s to its subscribers.
impl EventEmitter<ContextEvent> for Context {}
432
433impl Context {
434 pub fn local(
435 language_registry: Arc<LanguageRegistry>,
436 project: Option<Model<Project>>,
437 telemetry: Option<Arc<Telemetry>>,
438 prompt_builder: Arc<PromptBuilder>,
439 cx: &mut ModelContext<Self>,
440 ) -> Self {
441 Self::new(
442 ContextId::new(),
443 ReplicaId::default(),
444 language::Capability::ReadWrite,
445 language_registry,
446 prompt_builder,
447 project,
448 telemetry,
449 cx,
450 )
451 }
452
    /// Creates an empty context for the given replica and capability.
    ///
    /// The buffer starts empty with a single `User` message (id with replica 0,
    /// value 0) anchored at `Anchor::MIN`. Before returning, this starts loading
    /// the Markdown language and counting tokens.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        id: ContextId,
        replica_id: ReplicaId,
        capability: language::Capability,
        language_registry: Arc<LanguageRegistry>,
        prompt_builder: Arc<PromptBuilder>,
        project: Option<Model<Project>>,
        telemetry: Option<Arc<Telemetry>>,
        cx: &mut ModelContext<Self>,
    ) -> Self {
        // NOTE(review): every context buffer uses the fixed BufferId 1.
        let buffer = cx.new_model(|_cx| {
            let mut buffer = Buffer::remote(
                language::BufferId::new(1).unwrap(),
                replica_id,
                capability,
                "",
            );
            buffer.set_language_registry(language_registry.clone());
            buffer
        });
        // Two independent edit subscriptions: one for slash-command reparsing
        // and one for workflow-step pruning.
        let edits_since_last_slash_command_parse =
            buffer.update(cx, |buffer, _| buffer.subscribe());
        let edits_since_last_workflow_step_prune =
            buffer.update(cx, |buffer, _| buffer.subscribe());
        let mut this = Self {
            id,
            timestamp: clock::Lamport::new(replica_id),
            version: clock::Global::new(),
            pending_ops: Vec::new(),
            operations: Vec::new(),
            message_anchors: Default::default(),
            image_anchors: Default::default(),
            images: Default::default(),
            messages_metadata: Default::default(),
            pending_slash_commands: Vec::new(),
            finished_slash_commands: HashSet::default(),
            slash_command_output_sections: Vec::new(),
            edits_since_last_slash_command_parse,
            summary: None,
            pending_summary: Task::ready(None),
            completion_count: Default::default(),
            pending_completions: Default::default(),
            token_count: None,
            pending_token_count: Task::ready(None),
            // Struct fields evaluate in source order: this borrows `buffer`
            // before it is moved into the `buffer` field below.
            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
            pending_save: Task::ready(Ok(())),
            path: None,
            buffer,
            telemetry,
            project,
            language_registry,
            workflow_steps: Vec::new(),
            edits_since_last_workflow_step_prune,
            prompt_builder,
        };

        // Every context begins with one user message at the buffer's start.
        let first_message_id = MessageId(clock::Lamport {
            replica_id: 0,
            value: 0,
        });
        let message = MessageAnchor {
            id: first_message_id,
            start: language::Anchor::MIN,
        };
        this.messages_metadata.insert(
            first_message_id,
            MessageMetadata {
                role: Role::User,
                status: MessageStatus::Done,
                timestamp: first_message_id.0,
            },
        );
        this.message_anchors.push(message);

        this.set_language(cx);
        this.count_remaining_tokens(cx);
        this
    }
532
533 fn serialize(&self, cx: &AppContext) -> SavedContext {
534 let buffer = self.buffer.read(cx);
535 SavedContext {
536 id: Some(self.id.clone()),
537 zed: "context".into(),
538 version: SavedContext::VERSION.into(),
539 text: buffer.text(),
540 messages: self
541 .messages(cx)
542 .map(|message| SavedMessage {
543 id: message.id,
544 start: message.offset_range.start,
545 metadata: self.messages_metadata[&message.id].clone(),
546 image_offsets: message
547 .image_offsets
548 .iter()
549 .map(|image_offset| (image_offset.0, image_offset.1.image_id))
550 .collect(),
551 })
552 .collect(),
553 summary: self
554 .summary
555 .as_ref()
556 .map(|summary| summary.text.clone())
557 .unwrap_or_default(),
558 slash_command_output_sections: self
559 .slash_command_output_sections
560 .iter()
561 .filter_map(|section| {
562 let range = section.range.to_offset(buffer);
563 if section.range.start.is_valid(buffer) && !range.is_empty() {
564 Some(assistant_slash_command::SlashCommandOutputSection {
565 range,
566 icon: section.icon,
567 label: section.label.clone(),
568 })
569 } else {
570 None
571 }
572 })
573 .collect(),
574 }
575 }
576
577 #[allow(clippy::too_many_arguments)]
578 pub fn deserialize(
579 saved_context: SavedContext,
580 path: PathBuf,
581 language_registry: Arc<LanguageRegistry>,
582 prompt_builder: Arc<PromptBuilder>,
583 project: Option<Model<Project>>,
584 telemetry: Option<Arc<Telemetry>>,
585 cx: &mut ModelContext<Self>,
586 ) -> Self {
587 let id = saved_context.id.clone().unwrap_or_else(|| ContextId::new());
588 let mut this = Self::new(
589 id,
590 ReplicaId::default(),
591 language::Capability::ReadWrite,
592 language_registry,
593 prompt_builder,
594 project,
595 telemetry,
596 cx,
597 );
598 this.path = Some(path);
599 this.buffer.update(cx, |buffer, cx| {
600 buffer.set_text(saved_context.text.as_str(), cx)
601 });
602 let operations = saved_context.into_ops(&this.buffer, cx);
603 this.apply_ops(operations, cx).unwrap();
604 this
605 }
606
    /// This context's identifier.
    pub fn id(&self) -> &ContextId {
        &self.id
    }

    /// The replica id embedded in this context's Lamport clock.
    pub fn replica_id(&self) -> ReplicaId {
        self.timestamp.replica_id
    }

    /// The combined version of the context operations and the buffer.
    pub fn version(&self, cx: &AppContext) -> ContextVersion {
        ContextVersion {
            context: self.version.clone(),
            buffer: self.buffer.read(cx).version(),
        }
    }

    /// Changes the buffer's capability (e.g. `ReadWrite`).
    pub fn set_capability(
        &mut self,
        capability: language::Capability,
        cx: &mut ModelContext<Self>,
    ) {
        self.buffer
            .update(cx, |buffer, cx| buffer.set_capability(capability, cx));
    }

    /// Ticks the local Lamport clock and records the new timestamp as
    /// observed, so locally produced operations count toward `version`.
    fn next_timestamp(&mut self) -> clock::Lamport {
        let timestamp = self.timestamp.tick();
        self.version.observe(timestamp);
        timestamp
    }
636
637 pub fn serialize_ops(
638 &self,
639 since: &ContextVersion,
640 cx: &AppContext,
641 ) -> Task<Vec<proto::ContextOperation>> {
642 let buffer_ops = self
643 .buffer
644 .read(cx)
645 .serialize_ops(Some(since.buffer.clone()), cx);
646
647 let mut context_ops = self
648 .operations
649 .iter()
650 .filter(|op| !since.context.observed(op.timestamp()))
651 .cloned()
652 .collect::<Vec<_>>();
653 context_ops.extend(self.pending_ops.iter().cloned());
654
655 cx.background_executor().spawn(async move {
656 let buffer_ops = buffer_ops.await;
657 context_ops.sort_unstable_by_key(|op| op.timestamp());
658 buffer_ops
659 .into_iter()
660 .map(|op| proto::ContextOperation {
661 variant: Some(proto::context_operation::Variant::BufferOperation(
662 proto::context_operation::BufferOperation {
663 operation: Some(op),
664 },
665 )),
666 })
667 .chain(context_ops.into_iter().map(|op| op.to_proto()))
668 .collect()
669 })
670 }
671
672 pub fn apply_ops(
673 &mut self,
674 ops: impl IntoIterator<Item = ContextOperation>,
675 cx: &mut ModelContext<Self>,
676 ) -> Result<()> {
677 let mut buffer_ops = Vec::new();
678 for op in ops {
679 match op {
680 ContextOperation::BufferOperation(buffer_op) => buffer_ops.push(buffer_op),
681 op @ _ => self.pending_ops.push(op),
682 }
683 }
684 self.buffer
685 .update(cx, |buffer, cx| buffer.apply_ops(buffer_ops, cx))?;
686 self.flush_ops(cx);
687
688 Ok(())
689 }
690
    /// Applies every pending operation whose dependencies have been observed,
    /// in timestamp order, and re-queues the rest.
    ///
    /// Each applied operation advances `version` and the Lamport clock and is
    /// appended to `operations`, even if it turned out to be stale. Emits
    /// `SlashCommandFinished` once per newly finished command. `MessagesEdited`
    /// and `SummaryChanged` are emitted at most once each, at the end.
    fn flush_ops(&mut self, cx: &mut ModelContext<Context>) {
        let mut messages_changed = false;
        let mut summary_changed = false;

        self.pending_ops.sort_unstable_by_key(|op| op.timestamp());
        for op in mem::take(&mut self.pending_ops) {
            // Dependencies missing: keep it for a later flush. NOTE(review): it
            // is not retried in this pass even if a later op here satisfies it.
            if !self.can_apply_op(&op, cx) {
                self.pending_ops.push(op);
                continue;
            }

            let timestamp = op.timestamp();
            match op.clone() {
                ContextOperation::InsertMessage {
                    anchor, metadata, ..
                } => {
                    if self.messages_metadata.contains_key(&anchor.id) {
                        // We already applied this operation.
                    } else {
                        self.insert_message(anchor, metadata, cx);
                        messages_changed = true;
                    }
                }
                ContextOperation::UpdateMessage {
                    message_id,
                    metadata: new_metadata,
                    ..
                } => {
                    // `can_apply_op` guarantees the message exists.
                    let metadata = self.messages_metadata.get_mut(&message_id).unwrap();
                    // Last-writer-wins by Lamport timestamp.
                    if new_metadata.timestamp > metadata.timestamp {
                        *metadata = new_metadata;
                        messages_changed = true;
                    }
                }
                ContextOperation::UpdateSummary {
                    summary: new_summary,
                    ..
                } => {
                    // Last-writer-wins by Lamport timestamp.
                    if self
                        .summary
                        .as_ref()
                        .map_or(true, |summary| new_summary.timestamp > summary.timestamp)
                    {
                        self.summary = Some(new_summary);
                        summary_changed = true;
                    }
                }
                ContextOperation::SlashCommandFinished {
                    id,
                    output_range,
                    sections,
                    ..
                } => {
                    // `insert` returns false for a command already recorded.
                    if self.finished_slash_commands.insert(id) {
                        let buffer = self.buffer.read(cx);
                        self.slash_command_output_sections
                            .extend(sections.iter().cloned());
                        self.slash_command_output_sections
                            .sort_by(|a, b| a.range.cmp(&b.range, buffer));
                        cx.emit(ContextEvent::SlashCommandFinished {
                            output_range,
                            sections,
                            run_commands_in_output: false,
                        });
                    }
                }
                ContextOperation::BufferOperation(_) => unreachable!(),
            }

            self.version.observe(timestamp);
            self.timestamp.observe(timestamp);
            self.operations.push(op);
        }

        if messages_changed {
            cx.emit(ContextEvent::MessagesEdited);
            cx.notify();
        }

        if summary_changed {
            cx.emit(ContextEvent::SummaryChanged);
            cx.notify();
        }
    }
775
776 fn can_apply_op(&self, op: &ContextOperation, cx: &AppContext) -> bool {
777 if !self.version.observed_all(op.version()) {
778 return false;
779 }
780
781 match op {
782 ContextOperation::InsertMessage { anchor, .. } => self
783 .buffer
784 .read(cx)
785 .version
786 .observed(anchor.start.timestamp),
787 ContextOperation::UpdateMessage { message_id, .. } => {
788 self.messages_metadata.contains_key(message_id)
789 }
790 ContextOperation::UpdateSummary { .. } => true,
791 ContextOperation::SlashCommandFinished {
792 output_range,
793 sections,
794 ..
795 } => {
796 let version = &self.buffer.read(cx).version;
797 sections
798 .iter()
799 .map(|section| §ion.range)
800 .chain([output_range])
801 .all(|range| {
802 let observed_start = range.start == language::Anchor::MIN
803 || range.start == language::Anchor::MAX
804 || version.observed(range.start.timestamp);
805 let observed_end = range.end == language::Anchor::MIN
806 || range.end == language::Anchor::MAX
807 || version.observed(range.end.timestamp);
808 observed_start && observed_end
809 })
810 }
811 ContextOperation::BufferOperation(_) => {
812 panic!("buffer operations should always be applied")
813 }
814 }
815 }
816
    /// Records a locally produced operation and emits it as
    /// `ContextEvent::Operation`.
    fn push_op(&mut self, op: ContextOperation, cx: &mut ModelContext<Self>) {
        self.operations.push(op.clone());
        cx.emit(ContextEvent::Operation(op));
    }
821
    /// The text buffer that holds all message content.
    pub fn buffer(&self) -> &Model<Buffer> {
        &self.buffer
    }

    /// A shared handle to the language registry.
    pub fn language_registry(&self) -> Arc<LanguageRegistry> {
        self.language_registry.clone()
    }

    /// The project this context is associated with, if any.
    pub fn project(&self) -> Option<Model<Project>> {
        self.project.clone()
    }

    /// A shared handle to the prompt builder.
    pub fn prompt_builder(&self) -> Arc<PromptBuilder> {
        self.prompt_builder.clone()
    }

    /// The file this context was loaded from, if any.
    pub fn path(&self) -> Option<&Path> {
        self.path.as_deref()
    }

    /// The current summary, if one has been set.
    pub fn summary(&self) -> Option<&ContextSummary> {
        self.summary.as_ref()
    }
845
846 pub fn workflow_step_containing(
847 &self,
848 offset: usize,
849 cx: &AppContext,
850 ) -> Option<(Range<language::Anchor>, Model<WorkflowStep>)> {
851 let buffer = self.buffer.read(cx);
852 let index = self
853 .workflow_steps
854 .binary_search_by(|step| {
855 let step_range = step.range.to_offset(&buffer);
856 if offset < step_range.start {
857 Ordering::Greater
858 } else if offset > step_range.end {
859 Ordering::Less
860 } else {
861 Ordering::Equal
862 }
863 })
864 .ok()?;
865 let step = &self.workflow_steps[index];
866 Some((step.range.clone(), step.step.clone()))
867 }
868
869 pub fn workflow_step_for_range(
870 &self,
871 range: Range<language::Anchor>,
872 ) -> Option<Model<WorkflowStep>> {
873 Some(
874 self.workflow_steps
875 .iter()
876 .find(|step| step.range == range)?
877 .step
878 .clone(),
879 )
880 }
881
    /// Slash commands currently parsed from the buffer.
    pub fn pending_slash_commands(&self) -> &[PendingSlashCommand] {
        &self.pending_slash_commands
    }

    /// Output sections of finished slash commands, sorted by range.
    pub fn slash_command_output_sections(&self) -> &[SlashCommandOutputSection<language::Anchor>] {
        &self.slash_command_output_sections
    }
889
    /// Asynchronously loads the "Markdown" language and assigns it to the
    /// buffer. Failures are logged rather than propagated.
    fn set_language(&mut self, cx: &mut ModelContext<Self>) {
        let markdown = self.language_registry.language_for_name("Markdown");
        cx.spawn(|this, mut cx| async move {
            let markdown = markdown.await?;
            this.update(&mut cx, |this, cx| {
                this.buffer
                    .update(cx, |buffer, cx| buffer.set_language(Some(markdown), cx));
            })
        })
        .detach_and_log_err(cx);
    }
901
902 fn handle_buffer_event(
903 &mut self,
904 _: Model<Buffer>,
905 event: &language::Event,
906 cx: &mut ModelContext<Self>,
907 ) {
908 match event {
909 language::Event::Operation(operation) => cx.emit(ContextEvent::Operation(
910 ContextOperation::BufferOperation(operation.clone()),
911 )),
912 language::Event::Edited => {
913 self.count_remaining_tokens(cx);
914 self.reparse_slash_commands(cx);
915 // Use `inclusive = true` to invalidate a step when an edit occurs
916 // at the start/end of a parsed step.
917 self.prune_invalid_workflow_steps(true, cx);
918 cx.emit(ContextEvent::MessagesEdited);
919 }
920 _ => {}
921 }
922 }
923
    /// The most recently computed token count of the full request, if any.
    pub(crate) fn token_count(&self) -> Option<usize> {
        self.token_count
    }
927
    /// Schedules a token count of the current completion request using the
    /// active language model. Does nothing if no model is active.
    ///
    /// The count runs after a 200ms delay. Each call replaces
    /// `pending_token_count`, dropping the previous task. NOTE(review): in gpui,
    /// dropping a task presumably cancels it, which makes the delay act as a
    /// debounce. Errors are logged.
    pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
        let request = self.to_completion_request(cx);
        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
            return;
        };
        self.pending_token_count = cx.spawn(|this, mut cx| {
            async move {
                cx.background_executor()
                    .timer(Duration::from_millis(200))
                    .await;

                let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
                this.update(&mut cx, |this, cx| {
                    this.token_count = Some(token_count);
                    cx.notify()
                })
            }
            .log_err()
        });
    }
948
    /// Re-parses slash commands on the rows touched since the last parse.
    ///
    /// Edited row ranges that touch or overlap are merged. Within each merged
    /// range, the existing pending commands are replaced by freshly parsed ones.
    /// A line becomes a pending command only if its name is registered in the
    /// global `SlashCommandRegistry` and, when the command requires an argument,
    /// at least one non-empty argument is present. Emits
    /// `PendingSlashCommandsUpdated` when anything was removed or added.
    pub fn reparse_slash_commands(&mut self, cx: &mut ModelContext<Self>) {
        let buffer = self.buffer.read(cx);
        // Map each edit to the half-open range of rows it now covers.
        let mut row_ranges = self
            .edits_since_last_slash_command_parse
            .consume()
            .into_iter()
            .map(|edit| {
                let start_row = buffer.offset_to_point(edit.new.start).row;
                let end_row = buffer.offset_to_point(edit.new.end).row + 1;
                start_row..end_row
            })
            .peekable();

        let mut removed = Vec::new();
        let mut updated = Vec::new();
        while let Some(mut row_range) = row_ranges.next() {
            // Absorb following row ranges that touch or overlap this one.
            // NOTE(review): assumes the edits arrive in ascending order.
            while let Some(next_row_range) = row_ranges.peek() {
                if row_range.end >= next_row_range.start {
                    row_range.end = next_row_range.end;
                    row_ranges.next();
                } else {
                    break;
                }
            }

            // Anchor the full lines spanned by the merged rows.
            let start = buffer.anchor_before(Point::new(row_range.start, 0));
            let end = buffer.anchor_after(Point::new(
                row_range.end - 1,
                buffer.line_len(row_range.end - 1),
            ));

            let old_range = self.pending_command_indices_for_range(start..end, cx);

            let mut new_commands = Vec::new();
            let mut lines = buffer.text_for_range(start..end).lines();
            let mut offset = lines.offset();
            while let Some(line) = lines.next() {
                if let Some(command_line) = SlashCommandLine::parse(line) {
                    let name = &line[command_line.name.clone()];
                    // Drop empty arguments.
                    let arguments = command_line
                        .arguments
                        .iter()
                        .filter_map(|argument_range| {
                            if argument_range.is_empty() {
                                None
                            } else {
                                line.get(argument_range.clone())
                            }
                        })
                        .map(ToOwned::to_owned)
                        .collect::<SmallVec<_>>();
                    if let Some(command) = SlashCommandRegistry::global(cx).command(name) {
                        if !command.requires_argument() || !arguments.is_empty() {
                            // NOTE(review): `- 1` presumably includes the `/`
                            // preceding the command name.
                            let start_ix = offset + command_line.name.start - 1;
                            // The source range runs through the last argument
                            // (or the name when there are no arguments).
                            let end_ix = offset
                                + command_line
                                    .arguments
                                    .last()
                                    .map_or(command_line.name.end, |argument| argument.end);
                            let source_range =
                                buffer.anchor_after(start_ix)..buffer.anchor_after(end_ix);
                            let pending_command = PendingSlashCommand {
                                name: name.to_string(),
                                arguments,
                                source_range,
                                status: PendingSlashCommandStatus::Idle,
                            };
                            updated.push(pending_command.clone());
                            new_commands.push(pending_command);
                        }
                    }
                }

                offset = lines.offset();
            }

            // Replace the old commands in this region with the re-parsed ones.
            let removed_commands = self.pending_slash_commands.splice(old_range, new_commands);
            removed.extend(removed_commands.map(|command| command.source_range));
        }

        if !updated.is_empty() || !removed.is_empty() {
            cx.emit(ContextEvent::PendingSlashCommandsUpdated { removed, updated });
        }
    }
1033
1034 fn prune_invalid_workflow_steps(&mut self, inclusive: bool, cx: &mut ModelContext<Self>) {
1035 let mut removed = Vec::new();
1036
1037 for edit_range in self.edits_since_last_workflow_step_prune.consume() {
1038 let intersecting_range = self.find_intersecting_steps(edit_range.new, inclusive, cx);
1039 removed.extend(
1040 self.workflow_steps
1041 .drain(intersecting_range)
1042 .map(|step| step.range),
1043 );
1044 }
1045
1046 if !removed.is_empty() {
1047 cx.emit(ContextEvent::WorkflowStepsRemoved(removed));
1048 cx.notify();
1049 }
1050 }
1051
    /// Returns the index range of `workflow_steps` that intersect the offset
    /// `range`.
    ///
    /// With `inclusive`, a step whose end equals `range.start`, or whose start
    /// equals `range.end`, counts as intersecting. Without it, such boundary
    /// contact is excluded. Assumes the steps are sorted and non-overlapping,
    /// so both binary searches see monotonic keys.
    fn find_intersecting_steps(
        &self,
        range: Range<usize>,
        inclusive: bool,
        cx: &AppContext,
    ) -> Range<usize> {
        let buffer = self.buffer.read(cx);
        // First step whose end is past (or, if inclusive, at) `range.start`.
        // The `then` tie-breaker never yields `Equal`, so the search returns
        // the partition point.
        let start_ix = match self.workflow_steps.binary_search_by(|probe| {
            probe
                .range
                .end
                .to_offset(buffer)
                .cmp(&range.start)
                .then(if inclusive {
                    Ordering::Greater
                } else {
                    Ordering::Less
                })
        }) {
            Ok(ix) | Err(ix) => ix,
        };
        // One past the last step whose start is before (or, if inclusive, at)
        // `range.end`.
        let end_ix = match self.workflow_steps.binary_search_by(|probe| {
            probe
                .range
                .start
                .to_offset(buffer)
                .cmp(&range.end)
                .then(if inclusive {
                    Ordering::Less
                } else {
                    Ordering::Greater
                })
        }) {
            Ok(ix) | Err(ix) => ix,
        };
        start_ix..end_ix
    }
1089
    /// Scans `range` for `<step>...</step>` blocks and registers each new step.
    ///
    /// For every closed block, the tags are deleted from the buffer, along with
    /// a newline directly after `<step>` or directly before `</step>`. The
    /// content in between becomes a `WorkflowStepEntry`, unless a step with
    /// exactly that range already exists. Each new step is inserted in sorted
    /// position and then resolved.
    fn parse_workflow_steps_in_range(&mut self, range: Range<usize>, cx: &mut ModelContext<Self>) {
        let weak_self = cx.weak_model();
        let mut new_edit_steps = Vec::new();
        let mut edits = Vec::new();

        let buffer = self.buffer.read(cx).snapshot();
        let mut message_lines = buffer.as_rope().chunks_in_range(range).lines();
        let mut in_step = false;
        let mut step_open_tag_start_ix = 0;
        let mut line_start_offset = message_lines.offset();

        while let Some(line) = message_lines.next() {
            // Remember where the first unmatched `<step>` begins.
            if let Some(step_start_index) = line.find("<step>") {
                if !in_step {
                    in_step = true;
                    step_open_tag_start_ix = line_start_offset + step_start_index;
                }
            }

            if let Some(step_end_index) = line.find("</step>") {
                if in_step {
                    // Extend the opening-tag deletion over a trailing newline.
                    let mut step_open_tag_end_ix = step_open_tag_start_ix + "<step>".len();
                    if buffer.chars_at(step_open_tag_end_ix).next() == Some('\n') {
                        step_open_tag_end_ix += 1;
                    }
                    // Extend the closing-tag deletion back over a preceding newline.
                    let mut step_end_tag_start_ix = line_start_offset + step_end_index;
                    let step_end_tag_end_ix = step_end_tag_start_ix + "</step>".len();
                    if buffer.reversed_chars_at(step_end_tag_start_ix).next() == Some('\n') {
                        step_end_tag_start_ix -= 1;
                    }
                    edits.push((step_open_tag_start_ix..step_open_tag_end_ix, ""));
                    edits.push((step_end_tag_start_ix..step_end_tag_end_ix, ""));
                    // Content between the (extended) tags.
                    let tagged_range = buffer.anchor_after(step_open_tag_end_ix)
                        ..buffer.anchor_before(step_end_tag_start_ix);

                    // Check if a step with the same range already exists
                    let existing_step_index = self
                        .workflow_steps
                        .binary_search_by(|probe| probe.range.cmp(&tagged_range, &buffer));

                    if let Err(ix) = existing_step_index {
                        new_edit_steps.push((
                            ix,
                            WorkflowStepEntry {
                                step: cx.new_model(|_| {
                                    WorkflowStep::new(tagged_range.clone(), weak_self.clone())
                                }),
                                range: tagged_range,
                            },
                        ));
                    }

                    in_step = false;
                }
            }

            line_start_offset = message_lines.offset();
        }

        // The indices were computed against the original list. Inserting in
        // reverse keeps earlier indices valid and preserves the order of new
        // steps that share an index.
        let mut updated = Vec::new();
        for (index, step) in new_edit_steps.into_iter().rev() {
            let step_range = step.range.clone();
            updated.push(step_range.clone());
            self.workflow_steps.insert(index, step);
            self.resolve_workflow_step(step_range, cx);
        }

        // Delete <step> tags, making sure we don't accidentally invalidate
        // the step we just parsed.
        self.buffer
            .update(cx, |buffer, cx| buffer.edit(edits, None, cx));
        // Discard the tag-deletion edits so the next prune ignores them.
        self.edits_since_last_workflow_step_prune.consume();
    }
1163
1164 pub fn resolve_workflow_step(
1165 &mut self,
1166 tagged_range: Range<language::Anchor>,
1167 cx: &mut ModelContext<Self>,
1168 ) {
1169 let Ok(step_index) = self
1170 .workflow_steps
1171 .binary_search_by(|step| step.range.cmp(&tagged_range, self.buffer.read(cx)))
1172 else {
1173 return;
1174 };
1175
1176 cx.emit(ContextEvent::WorkflowStepUpdated(tagged_range.clone()));
1177 cx.notify();
1178
1179 let resolution = self.workflow_steps[step_index].step.clone();
1180 cx.defer(move |cx| {
1181 resolution.update(cx, |resolution, cx| resolution.resolve(cx));
1182 });
1183 }
1184
1185 pub fn workflow_step_updated(
1186 &mut self,
1187 range: Range<language::Anchor>,
1188 cx: &mut ModelContext<Self>,
1189 ) {
1190 cx.emit(ContextEvent::WorkflowStepUpdated(range));
1191 cx.notify();
1192 }
1193
1194 pub fn pending_command_for_position(
1195 &mut self,
1196 position: language::Anchor,
1197 cx: &mut ModelContext<Self>,
1198 ) -> Option<&mut PendingSlashCommand> {
1199 let buffer = self.buffer.read(cx);
1200 match self
1201 .pending_slash_commands
1202 .binary_search_by(|probe| probe.source_range.end.cmp(&position, buffer))
1203 {
1204 Ok(ix) => Some(&mut self.pending_slash_commands[ix]),
1205 Err(ix) => {
1206 let cmd = self.pending_slash_commands.get_mut(ix)?;
1207 if position.cmp(&cmd.source_range.start, buffer).is_ge()
1208 && position.cmp(&cmd.source_range.end, buffer).is_le()
1209 {
1210 Some(cmd)
1211 } else {
1212 None
1213 }
1214 }
1215 }
1216 }
1217
1218 pub fn pending_commands_for_range(
1219 &self,
1220 range: Range<language::Anchor>,
1221 cx: &AppContext,
1222 ) -> &[PendingSlashCommand] {
1223 let range = self.pending_command_indices_for_range(range, cx);
1224 &self.pending_slash_commands[range]
1225 }
1226
1227 fn pending_command_indices_for_range(
1228 &self,
1229 range: Range<language::Anchor>,
1230 cx: &AppContext,
1231 ) -> Range<usize> {
1232 let buffer = self.buffer.read(cx);
1233 let start_ix = match self
1234 .pending_slash_commands
1235 .binary_search_by(|probe| probe.source_range.end.cmp(&range.start, &buffer))
1236 {
1237 Ok(ix) | Err(ix) => ix,
1238 };
1239 let end_ix = match self
1240 .pending_slash_commands
1241 .binary_search_by(|probe| probe.source_range.start.cmp(&range.end, &buffer))
1242 {
1243 Ok(ix) => ix + 1,
1244 Err(ix) => ix,
1245 };
1246 start_ix..end_ix
1247 }
1248
/// Waits for a slash command's `output` task and replaces `command_range`
/// in the buffer with the produced text.
///
/// On success, the output sections are anchored into the buffer and merged
/// into `slash_command_output_sections`. A `SlashCommandFinished` operation
/// is pushed and a matching event emitted. On failure, the pending command at
/// `command_range.start` is marked `Error`. While the task runs, that pending
/// command is marked `Running` and holds the (shared) task.
pub fn insert_command_output(
    &mut self,
    command_range: Range<language::Anchor>,
    output: Task<Result<SlashCommandOutput>>,
    insert_trailing_newline: bool,
    cx: &mut ModelContext<Self>,
) {
    // Re-parse first; presumably so the pending-command lookups below see the
    // current buffer contents.
    self.reparse_slash_commands(cx);

    let insert_output_task = cx.spawn(|this, mut cx| {
        let command_range = command_range.clone();
        async move {
            let output = output.await;
            this.update(&mut cx, |this, cx| match output {
                Ok(mut output) => {
                    if insert_trailing_newline {
                        output.text.push('\n');
                    }

                    let version = this.version.clone();
                    let command_id = SlashCommandId(this.next_timestamp());
                    let (operation, event) = this.buffer.update(cx, |buffer, cx| {
                        let start = command_range.start.to_offset(buffer);
                        let old_end = command_range.end.to_offset(buffer);
                        let new_end = start + output.text.len();
                        buffer.edit([(start..old_end, output.text)], None, cx);

                        // Section ranges are relative to the output text; shift
                        // them by `start` and anchor them in the edited buffer.
                        let mut sections = output
                            .sections
                            .into_iter()
                            .map(|section| SlashCommandOutputSection {
                                range: buffer.anchor_after(start + section.range.start)
                                    ..buffer.anchor_before(start + section.range.end),
                                icon: section.icon,
                                label: section.label,
                            })
                            .collect::<Vec<_>>();
                        sections.sort_by(|a, b| a.range.cmp(&b.range, buffer));

                        // Keep the context-wide section list sorted by range.
                        this.slash_command_output_sections
                            .extend(sections.iter().cloned());
                        this.slash_command_output_sections
                            .sort_by(|a, b| a.range.cmp(&b.range, buffer));

                        let output_range =
                            buffer.anchor_after(start)..buffer.anchor_before(new_end);
                        this.finished_slash_commands.insert(command_id);

                        (
                            ContextOperation::SlashCommandFinished {
                                id: command_id,
                                output_range: output_range.clone(),
                                sections: sections.clone(),
                                version,
                            },
                            ContextEvent::SlashCommandFinished {
                                output_range,
                                sections,
                                run_commands_in_output: output.run_commands_in_text,
                            },
                        )
                    });

                    this.push_op(operation, cx);
                    cx.emit(event);
                }
                Err(error) => {
                    if let Some(pending_command) =
                        this.pending_command_for_position(command_range.start, cx)
                    {
                        pending_command.status =
                            PendingSlashCommandStatus::Error(error.to_string());
                        cx.emit(ContextEvent::PendingSlashCommandsUpdated {
                            removed: vec![pending_command.source_range.clone()],
                            updated: vec![pending_command.clone()],
                        });
                    }
                }
            })
            .ok();
        }
    });

    // Store the task on the pending command so it stays alive and the
    // command is reported as running.
    if let Some(pending_command) = self.pending_command_for_position(command_range.start, cx) {
        pending_command.status = PendingSlashCommandStatus::Running {
            _task: insert_output_task.shared(),
        };
        cx.emit(ContextEvent::PendingSlashCommandsUpdated {
            removed: vec![pending_command.source_range.clone()],
            updated: vec![pending_command.clone()],
        });
    }
}
1342
1343 pub fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
1344 self.count_remaining_tokens(cx);
1345 }
1346
1347 pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
1348 let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
1349 let model = LanguageModelRegistry::read_global(cx).active_model()?;
1350 let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1351 message
1352 .start
1353 .is_valid(self.buffer.read(cx))
1354 .then_some(message.id)
1355 })?;
1356
1357 if !provider.is_authenticated(cx) {
1358 log::info!("completion provider has no credentials");
1359 return None;
1360 }
1361
1362 let request = self.to_completion_request(cx);
1363 let assistant_message = self
1364 .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
1365 .unwrap();
1366
1367 // Queue up the user's next reply.
1368 let user_message = self
1369 .insert_message_after(assistant_message.id, Role::User, MessageStatus::Done, cx)
1370 .unwrap();
1371
1372 let pending_completion_id = post_inc(&mut self.completion_count);
1373
1374 let task = cx.spawn({
1375 |this, mut cx| async move {
1376 let stream = model.stream_completion(request, &cx);
1377 let assistant_message_id = assistant_message.id;
1378 let mut response_latency = None;
1379 let stream_completion = async {
1380 let request_start = Instant::now();
1381 let mut chunks = stream.await?;
1382
1383 while let Some(chunk) = chunks.next().await {
1384 if response_latency.is_none() {
1385 response_latency = Some(request_start.elapsed());
1386 }
1387 let chunk = chunk?;
1388
1389 this.update(&mut cx, |this, cx| {
1390 let message_ix = this
1391 .message_anchors
1392 .iter()
1393 .position(|message| message.id == assistant_message_id)?;
1394 let message_range = this.buffer.update(cx, |buffer, cx| {
1395 let message_start_offset =
1396 this.message_anchors[message_ix].start.to_offset(buffer);
1397 let message_old_end_offset = this.message_anchors[message_ix + 1..]
1398 .iter()
1399 .find(|message| message.start.is_valid(buffer))
1400 .map_or(buffer.len(), |message| {
1401 message.start.to_offset(buffer).saturating_sub(1)
1402 });
1403 let message_new_end_offset = message_old_end_offset + chunk.len();
1404 buffer.edit(
1405 [(message_old_end_offset..message_old_end_offset, chunk)],
1406 None,
1407 cx,
1408 );
1409 message_start_offset..message_new_end_offset
1410 });
1411
1412 // Use `inclusive = false` as edits might occur at the end of a parsed step.
1413 this.prune_invalid_workflow_steps(false, cx);
1414 this.parse_workflow_steps_in_range(message_range, cx);
1415 cx.emit(ContextEvent::StreamedCompletion);
1416
1417 Some(())
1418 })?;
1419 smol::future::yield_now().await;
1420 }
1421 this.update(&mut cx, |this, cx| {
1422 this.pending_completions
1423 .retain(|completion| completion.id != pending_completion_id);
1424 this.summarize(false, cx);
1425 })?;
1426
1427 anyhow::Ok(())
1428 };
1429
1430 let result = stream_completion.await;
1431
1432 this.update(&mut cx, |this, cx| {
1433 let error_message = result
1434 .err()
1435 .map(|error| error.to_string().trim().to_string());
1436
1437 if let Some(error_message) = error_message.as_ref() {
1438 cx.emit(ContextEvent::ShowAssistError(SharedString::from(
1439 error_message.clone(),
1440 )));
1441 }
1442
1443 this.update_metadata(assistant_message_id, cx, |metadata| {
1444 if let Some(error_message) = error_message.as_ref() {
1445 metadata.status =
1446 MessageStatus::Error(SharedString::from(error_message.clone()));
1447 } else {
1448 metadata.status = MessageStatus::Done;
1449 }
1450 });
1451
1452 if let Some(telemetry) = this.telemetry.as_ref() {
1453 telemetry.report_assistant_event(
1454 Some(this.id.0.clone()),
1455 AssistantKind::Panel,
1456 model.telemetry_id(),
1457 response_latency,
1458 error_message,
1459 );
1460 }
1461 })
1462 .ok();
1463 }
1464 });
1465
1466 self.pending_completions.push(PendingCompletion {
1467 id: pending_completion_id,
1468 assistant_message_id: assistant_message.id,
1469 _task: task,
1470 });
1471
1472 Some(user_message)
1473 }
1474
1475 pub fn to_completion_request(&self, cx: &AppContext) -> LanguageModelRequest {
1476 let buffer = self.buffer.read(cx);
1477 let request_messages = self
1478 .messages(cx)
1479 .filter(|message| message.status == MessageStatus::Done)
1480 .map(|message| message.to_request_message(&buffer))
1481 .collect();
1482
1483 LanguageModelRequest {
1484 messages: request_messages,
1485 stop: vec![],
1486 temperature: 1.0,
1487 }
1488 }
1489
1490 pub fn cancel_last_assist(&mut self, cx: &mut ModelContext<Self>) -> bool {
1491 if let Some(pending_completion) = self.pending_completions.pop() {
1492 self.update_metadata(pending_completion.assistant_message_id, cx, |metadata| {
1493 if metadata.status == MessageStatus::Pending {
1494 metadata.status = MessageStatus::Canceled;
1495 }
1496 });
1497 true
1498 } else {
1499 false
1500 }
1501 }
1502
1503 pub fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1504 for id in ids {
1505 if let Some(metadata) = self.messages_metadata.get(&id) {
1506 let role = metadata.role.cycle();
1507 self.update_metadata(id, cx, |metadata| metadata.role = role);
1508 }
1509 }
1510 }
1511
1512 pub fn update_metadata(
1513 &mut self,
1514 id: MessageId,
1515 cx: &mut ModelContext<Self>,
1516 f: impl FnOnce(&mut MessageMetadata),
1517 ) {
1518 let version = self.version.clone();
1519 let timestamp = self.next_timestamp();
1520 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1521 f(metadata);
1522 metadata.timestamp = timestamp;
1523 let operation = ContextOperation::UpdateMessage {
1524 message_id: id,
1525 metadata: metadata.clone(),
1526 version,
1527 };
1528 self.push_op(operation, cx);
1529 cx.emit(ContextEvent::MessagesEdited);
1530 cx.notify();
1531 }
1532 }
1533
1534 fn insert_message_after(
1535 &mut self,
1536 message_id: MessageId,
1537 role: Role,
1538 status: MessageStatus,
1539 cx: &mut ModelContext<Self>,
1540 ) -> Option<MessageAnchor> {
1541 if let Some(prev_message_ix) = self
1542 .message_anchors
1543 .iter()
1544 .position(|message| message.id == message_id)
1545 {
1546 // Find the next valid message after the one we were given.
1547 let mut next_message_ix = prev_message_ix + 1;
1548 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1549 if next_message.start.is_valid(self.buffer.read(cx)) {
1550 break;
1551 }
1552 next_message_ix += 1;
1553 }
1554
1555 let start = self.buffer.update(cx, |buffer, cx| {
1556 let offset = self
1557 .message_anchors
1558 .get(next_message_ix)
1559 .map_or(buffer.len(), |message| {
1560 buffer.clip_offset(message.start.to_offset(buffer) - 1, Bias::Left)
1561 });
1562 buffer.edit([(offset..offset, "\n")], None, cx);
1563 buffer.anchor_before(offset + 1)
1564 });
1565
1566 let version = self.version.clone();
1567 let anchor = MessageAnchor {
1568 id: MessageId(self.next_timestamp()),
1569 start,
1570 };
1571 let metadata = MessageMetadata {
1572 role,
1573 status,
1574 timestamp: anchor.id.0,
1575 };
1576 self.insert_message(anchor.clone(), metadata.clone(), cx);
1577 self.push_op(
1578 ContextOperation::InsertMessage {
1579 anchor: anchor.clone(),
1580 metadata,
1581 version,
1582 },
1583 cx,
1584 );
1585 Some(anchor)
1586 } else {
1587 None
1588 }
1589 }
1590
1591 pub fn insert_image(&mut self, image: Image, cx: &mut ModelContext<Self>) -> Option<()> {
1592 if let hash_map::Entry::Vacant(entry) = self.images.entry(image.id()) {
1593 entry.insert((
1594 image.to_image_data(cx).log_err()?,
1595 LanguageModelImage::from_image(image, cx).shared(),
1596 ));
1597 }
1598
1599 Some(())
1600 }
1601
1602 pub fn insert_image_anchor(
1603 &mut self,
1604 image_id: u64,
1605 anchor: language::Anchor,
1606 cx: &mut ModelContext<Self>,
1607 ) -> bool {
1608 cx.emit(ContextEvent::MessagesEdited);
1609
1610 let buffer = self.buffer.read(cx);
1611 let insertion_ix = match self
1612 .image_anchors
1613 .binary_search_by(|existing_anchor| anchor.cmp(&existing_anchor.anchor, buffer))
1614 {
1615 Ok(ix) => ix,
1616 Err(ix) => ix,
1617 };
1618
1619 if let Some((render_image, image)) = self.images.get(&image_id) {
1620 self.image_anchors.insert(
1621 insertion_ix,
1622 ImageAnchor {
1623 anchor,
1624 image_id,
1625 image: image.clone(),
1626 render_image: render_image.clone(),
1627 },
1628 );
1629
1630 true
1631 } else {
1632 false
1633 }
1634 }
1635
1636 pub fn images<'a>(&'a self, _cx: &'a AppContext) -> impl 'a + Iterator<Item = ImageAnchor> {
1637 self.image_anchors.iter().cloned()
1638 }
1639
1640 pub fn split_message(
1641 &mut self,
1642 range: Range<usize>,
1643 cx: &mut ModelContext<Self>,
1644 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1645 let start_message = self.message_for_offset(range.start, cx);
1646 let end_message = self.message_for_offset(range.end, cx);
1647 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1648 // Prevent splitting when range spans multiple messages.
1649 if start_message.id != end_message.id {
1650 return (None, None);
1651 }
1652
1653 let message = start_message;
1654 let role = message.role;
1655 let mut edited_buffer = false;
1656
1657 let mut suffix_start = None;
1658
1659 // TODO: why did this start panicking?
1660 if range.start > message.offset_range.start
1661 && range.end < message.offset_range.end.saturating_sub(1)
1662 {
1663 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1664 suffix_start = Some(range.end + 1);
1665 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1666 suffix_start = Some(range.end);
1667 }
1668 }
1669
1670 let version = self.version.clone();
1671 let suffix = if let Some(suffix_start) = suffix_start {
1672 MessageAnchor {
1673 id: MessageId(self.next_timestamp()),
1674 start: self.buffer.read(cx).anchor_before(suffix_start),
1675 }
1676 } else {
1677 self.buffer.update(cx, |buffer, cx| {
1678 buffer.edit([(range.end..range.end, "\n")], None, cx);
1679 });
1680 edited_buffer = true;
1681 MessageAnchor {
1682 id: MessageId(self.next_timestamp()),
1683 start: self.buffer.read(cx).anchor_before(range.end + 1),
1684 }
1685 };
1686
1687 let suffix_metadata = MessageMetadata {
1688 role,
1689 status: MessageStatus::Done,
1690 timestamp: suffix.id.0,
1691 };
1692 self.insert_message(suffix.clone(), suffix_metadata.clone(), cx);
1693 self.push_op(
1694 ContextOperation::InsertMessage {
1695 anchor: suffix.clone(),
1696 metadata: suffix_metadata,
1697 version,
1698 },
1699 cx,
1700 );
1701
1702 let new_messages =
1703 if range.start == range.end || range.start == message.offset_range.start {
1704 (None, Some(suffix))
1705 } else {
1706 let mut prefix_end = None;
1707 if range.start > message.offset_range.start
1708 && range.end < message.offset_range.end - 1
1709 {
1710 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1711 prefix_end = Some(range.start + 1);
1712 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1713 == Some('\n')
1714 {
1715 prefix_end = Some(range.start);
1716 }
1717 }
1718
1719 let version = self.version.clone();
1720 let selection = if let Some(prefix_end) = prefix_end {
1721 MessageAnchor {
1722 id: MessageId(self.next_timestamp()),
1723 start: self.buffer.read(cx).anchor_before(prefix_end),
1724 }
1725 } else {
1726 self.buffer.update(cx, |buffer, cx| {
1727 buffer.edit([(range.start..range.start, "\n")], None, cx)
1728 });
1729 edited_buffer = true;
1730 MessageAnchor {
1731 id: MessageId(self.next_timestamp()),
1732 start: self.buffer.read(cx).anchor_before(range.end + 1),
1733 }
1734 };
1735
1736 let selection_metadata = MessageMetadata {
1737 role,
1738 status: MessageStatus::Done,
1739 timestamp: selection.id.0,
1740 };
1741 self.insert_message(selection.clone(), selection_metadata.clone(), cx);
1742 self.push_op(
1743 ContextOperation::InsertMessage {
1744 anchor: selection.clone(),
1745 metadata: selection_metadata,
1746 version,
1747 },
1748 cx,
1749 );
1750
1751 (Some(selection), Some(suffix))
1752 };
1753
1754 if !edited_buffer {
1755 cx.emit(ContextEvent::MessagesEdited);
1756 }
1757 new_messages
1758 } else {
1759 (None, None)
1760 }
1761 }
1762
1763 fn insert_message(
1764 &mut self,
1765 new_anchor: MessageAnchor,
1766 new_metadata: MessageMetadata,
1767 cx: &mut ModelContext<Self>,
1768 ) {
1769 cx.emit(ContextEvent::MessagesEdited);
1770
1771 self.messages_metadata.insert(new_anchor.id, new_metadata);
1772
1773 let buffer = self.buffer.read(cx);
1774 let insertion_ix = self
1775 .message_anchors
1776 .iter()
1777 .position(|anchor| {
1778 let comparison = new_anchor.start.cmp(&anchor.start, buffer);
1779 comparison.is_lt() || (comparison.is_eq() && new_anchor.id > anchor.id)
1780 })
1781 .unwrap_or(self.message_anchors.len());
1782 self.message_anchors.insert(insertion_ix, new_anchor);
1783 }
1784
/// Asks the active language model for a short title for this context and
/// streams it into `self.summary`.
///
/// Returns early unless there is an active provider and model. A request is
/// made when `replace_old` is true, or when no summary exists yet and there
/// are at least two message anchors; the provider must also be authenticated.
/// Only the first line of the model output is kept. Each streamed chunk
/// pushes an `UpdateSummary` operation and emits `SummaryChanged`. When the
/// stream ends, or is cut off after the first line, the summary is marked
/// `done` and one more operation is pushed.
pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
    let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
        return;
    };
    let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
        return;
    };

    if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
        if !provider.is_authenticated(cx) {
            return;
        }

        // The whole conversation followed by a final instruction to summarize it.
        let messages = self
            .messages(cx)
            .map(|message| message.to_request_message(self.buffer.read(cx)))
            .chain(Some(LanguageModelRequestMessage {
                role: Role::User,
                content: vec![
                    "Summarize the context into a short title without punctuation.".into(),
                ],
            }));
        let request = LanguageModelRequest {
            messages: messages.collect(),
            stop: vec![],
            temperature: 1.0,
        };

        self.pending_summary = cx.spawn(|this, mut cx| {
            async move {
                let stream = model.stream_completion(request, &cx);
                let mut messages = stream.await?;

                // When replacing, the old text is cleared only once the first
                // chunk arrives, so it is kept until then.
                let mut replaced = !replace_old;
                while let Some(message) = messages.next().await {
                    let text = message?;
                    let mut lines = text.lines();
                    this.update(&mut cx, |this, cx| {
                        let version = this.version.clone();
                        let timestamp = this.next_timestamp();
                        let summary = this.summary.get_or_insert(ContextSummary::default());
                        if !replaced && replace_old {
                            summary.text.clear();
                            replaced = true;
                        }
                        summary.text.extend(lines.next());
                        summary.timestamp = timestamp;
                        let operation = ContextOperation::UpdateSummary {
                            summary: summary.clone(),
                            version,
                        };
                        this.push_op(operation, cx);
                        cx.emit(ContextEvent::SummaryChanged);
                    })?;

                    // Stop if the LLM generated multiple lines.
                    if lines.next().is_some() {
                        break;
                    }
                }

                this.update(&mut cx, |this, cx| {
                    let version = this.version.clone();
                    let timestamp = this.next_timestamp();
                    if let Some(summary) = this.summary.as_mut() {
                        summary.done = true;
                        summary.timestamp = timestamp;
                        let operation = ContextOperation::UpdateSummary {
                            summary: summary.clone(),
                            version,
                        };
                        this.push_op(operation, cx);
                        cx.emit(ContextEvent::SummaryChanged);
                    }
                })?;

                anyhow::Ok(())
            }
            .log_err()
        });
    }
}
1867
1868 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1869 self.messages_for_offsets([offset], cx).pop()
1870 }
1871
1872 pub fn messages_for_offsets(
1873 &self,
1874 offsets: impl IntoIterator<Item = usize>,
1875 cx: &AppContext,
1876 ) -> Vec<Message> {
1877 let mut result = Vec::new();
1878
1879 let mut messages = self.messages(cx).peekable();
1880 let mut offsets = offsets.into_iter().peekable();
1881 let mut current_message = messages.next();
1882 while let Some(offset) = offsets.next() {
1883 // Locate the message that contains the offset.
1884 while current_message.as_ref().map_or(false, |message| {
1885 !message.offset_range.contains(&offset) && messages.peek().is_some()
1886 }) {
1887 current_message = messages.next();
1888 }
1889 let Some(message) = current_message.as_ref() else {
1890 break;
1891 };
1892
1893 // Skip offsets that are in the same message.
1894 while offsets.peek().map_or(false, |offset| {
1895 message.offset_range.contains(offset) || messages.peek().is_none()
1896 }) {
1897 offsets.next();
1898 }
1899
1900 result.push(message.clone());
1901 }
1902 result
1903 }
1904
1905 pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
1906 let buffer = self.buffer.read(cx);
1907 let messages = self.message_anchors.iter().enumerate();
1908 let images = self.image_anchors.iter();
1909
1910 Self::messages_from_iters(buffer, &self.messages_metadata, messages, images)
1911 }
1912
1913 pub fn messages_from_iters<'a>(
1914 buffer: &'a Buffer,
1915 metadata: &'a HashMap<MessageId, MessageMetadata>,
1916 messages: impl Iterator<Item = (usize, &'a MessageAnchor)> + 'a,
1917 images: impl Iterator<Item = &'a ImageAnchor> + 'a,
1918 ) -> impl 'a + Iterator<Item = Message> {
1919 let mut messages = messages.peekable();
1920 let mut images = images.peekable();
1921
1922 iter::from_fn(move || {
1923 if let Some((start_ix, message_anchor)) = messages.next() {
1924 let metadata = metadata.get(&message_anchor.id)?;
1925
1926 let message_start = message_anchor.start.to_offset(buffer);
1927 let mut message_end = None;
1928 let mut end_ix = start_ix;
1929 while let Some((_, next_message)) = messages.peek() {
1930 if next_message.start.is_valid(buffer) {
1931 message_end = Some(next_message.start);
1932 break;
1933 } else {
1934 end_ix += 1;
1935 messages.next();
1936 }
1937 }
1938 let message_end_anchor = message_end.unwrap_or(language::Anchor::MAX);
1939 let message_end = message_end_anchor.to_offset(buffer);
1940
1941 let mut image_offsets = SmallVec::new();
1942 while let Some(image_anchor) = images.peek() {
1943 if image_anchor.anchor.cmp(&message_end_anchor, buffer).is_lt() {
1944 image_offsets.push((
1945 image_anchor.anchor.to_offset(buffer),
1946 MessageImage {
1947 image_id: image_anchor.image_id,
1948 image: image_anchor.image.clone(),
1949 },
1950 ));
1951 images.next();
1952 } else {
1953 break;
1954 }
1955 }
1956
1957 return Some(Message {
1958 index_range: start_ix..end_ix,
1959 offset_range: message_start..message_end,
1960 id: message_anchor.id,
1961 anchor: message_anchor.start,
1962 role: metadata.role,
1963 status: metadata.status.clone(),
1964 image_offsets,
1965 });
1966 }
1967 None
1968 })
1969 }
1970
1971 pub fn save(
1972 &mut self,
1973 debounce: Option<Duration>,
1974 fs: Arc<dyn Fs>,
1975 cx: &mut ModelContext<Context>,
1976 ) {
1977 if self.replica_id() != ReplicaId::default() {
1978 // Prevent saving a remote context for now.
1979 return;
1980 }
1981
1982 self.pending_save = cx.spawn(|this, mut cx| async move {
1983 if let Some(debounce) = debounce {
1984 cx.background_executor().timer(debounce).await;
1985 }
1986
1987 let (old_path, summary) = this.read_with(&cx, |this, _| {
1988 let path = this.path.clone();
1989 let summary = if let Some(summary) = this.summary.as_ref() {
1990 if summary.done {
1991 Some(summary.text.clone())
1992 } else {
1993 None
1994 }
1995 } else {
1996 None
1997 };
1998 (path, summary)
1999 })?;
2000
2001 if let Some(summary) = summary {
2002 this.read_with(&cx, |this, cx| this.serialize_images(fs.clone(), cx))?
2003 .await;
2004
2005 let context = this.read_with(&cx, |this, cx| this.serialize(cx))?;
2006 let mut discriminant = 1;
2007 let mut new_path;
2008 loop {
2009 new_path = contexts_dir().join(&format!(
2010 "{} - {}.zed.json",
2011 summary.trim(),
2012 discriminant
2013 ));
2014 if fs.is_file(&new_path).await {
2015 discriminant += 1;
2016 } else {
2017 break;
2018 }
2019 }
2020
2021 fs.create_dir(contexts_dir().as_ref()).await?;
2022 fs.atomic_write(new_path.clone(), serde_json::to_string(&context).unwrap())
2023 .await?;
2024 if let Some(old_path) = old_path {
2025 if new_path != old_path {
2026 fs.remove_file(
2027 &old_path,
2028 RemoveOptions {
2029 recursive: false,
2030 ignore_if_not_exists: true,
2031 },
2032 )
2033 .await?;
2034 }
2035 }
2036
2037 this.update(&mut cx, |this, _| this.path = Some(new_path))?;
2038 }
2039
2040 Ok(())
2041 });
2042 }
2043
2044 pub fn serialize_images(&self, fs: Arc<dyn Fs>, cx: &AppContext) -> Task<()> {
2045 let mut images_to_save = self
2046 .images
2047 .iter()
2048 .map(|(id, (_, llm_image))| {
2049 let fs = fs.clone();
2050 let llm_image = llm_image.clone();
2051 let id = *id;
2052 async move {
2053 if let Some(llm_image) = llm_image.await {
2054 let path: PathBuf =
2055 context_images_dir().join(&format!("{}.png.base64", id));
2056 if fs
2057 .metadata(path.as_path())
2058 .await
2059 .log_err()
2060 .flatten()
2061 .is_none()
2062 {
2063 fs.atomic_write(path, llm_image.source.to_string())
2064 .await
2065 .log_err();
2066 }
2067 }
2068 }
2069 })
2070 .collect::<FuturesUnordered<_>>();
2071 cx.background_executor().spawn(async move {
2072 if fs
2073 .create_dir(context_images_dir().as_ref())
2074 .await
2075 .log_err()
2076 .is_some()
2077 {
2078 while let Some(_) = images_to_save.next().await {}
2079 }
2080 })
2081 }
2082
2083 pub(crate) fn custom_summary(&mut self, custom_summary: String, cx: &mut ModelContext<Self>) {
2084 let timestamp = self.next_timestamp();
2085 let summary = self.summary.get_or_insert(ContextSummary::default());
2086 summary.timestamp = timestamp;
2087 summary.done = true;
2088 summary.text = custom_summary;
2089 cx.emit(ContextEvent::SummaryChanged);
2090 }
2091}
2092
/// A pair of version vectors for a context: one for context-level state and
/// one for its text buffer (`context_version` / `buffer_version` on the wire).
#[derive(Debug, Default)]
pub struct ContextVersion {
    /// Version vector of the context-level state.
    context: clock::Global,
    /// Version vector of the context's buffer.
    buffer: clock::Global,
}
2098
2099impl ContextVersion {
2100 pub fn from_proto(proto: &proto::ContextVersion) -> Self {
2101 Self {
2102 context: language::proto::deserialize_version(&proto.context_version),
2103 buffer: language::proto::deserialize_version(&proto.buffer_version),
2104 }
2105 }
2106
2107 pub fn to_proto(&self, context_id: ContextId) -> proto::ContextVersion {
2108 proto::ContextVersion {
2109 context_id: context_id.to_proto(),
2110 context_version: language::proto::serialize_version(&self.context),
2111 buffer_version: language::proto::serialize_version(&self.buffer),
2112 }
2113 }
2114}
2115
/// A slash command invocation found in the context buffer, together with its
/// execution status.
#[derive(Debug, Clone)]
pub struct PendingSlashCommand {
    /// Command name.
    pub name: String,
    /// Arguments passed to the command.
    pub arguments: SmallVec<[String; 3]>,
    /// Whether the command is idle, running, or failed.
    pub status: PendingSlashCommandStatus,
    /// Buffer range of the command text.
    pub source_range: Range<language::Anchor>,
}
2123
/// Execution state of a [`PendingSlashCommand`].
#[derive(Debug, Clone)]
pub enum PendingSlashCommandStatus {
    /// Not started.
    Idle,
    /// Output is being produced; the shared task is held here to keep it alive.
    Running { _task: Shared<Task<()>> },
    /// The command failed with this error message.
    Error(String),
}
2130
/// A message as stored in a serialized [`SavedContext`].
#[derive(Serialize, Deserialize)]
pub struct SavedMessage {
    pub id: MessageId,
    /// Buffer offset at which the message starts.
    pub start: usize,
    pub metadata: MessageMetadata,
    /// Images in the message; presumably `(buffer offset, image id)` pairs —
    /// TODO confirm against the serializer.
    #[serde(default)]
    // Defaulted for backwards compatibility with JSON files created before
    // August 2024, which do not always contain this field.
    pub image_offsets: Vec<(usize, u64)>,
}
2140
/// The on-disk JSON form of a context (format [`SavedContext::VERSION`]).
#[derive(Serialize, Deserialize)]
pub struct SavedContext {
    /// Context id; optional, presumably because older files may lack it —
    /// TODO confirm.
    pub id: Option<ContextId>,
    /// NOTE(review): meaning not visible here; presumably the Zed version
    /// that wrote the file.
    pub zed: String,
    /// Format version string; `from_json` dispatches on it.
    pub version: String,
    /// Full buffer text.
    pub text: String,
    /// Messages, with offsets into `text`.
    pub messages: Vec<SavedMessage>,
    /// Summary (title) text.
    pub summary: String,
    /// Slash command output sections, with offset ranges into `text`.
    pub slash_command_output_sections:
        Vec<assistant_slash_command::SlashCommandOutputSection<usize>>,
}
2152
2153impl SavedContext {
2154 pub const VERSION: &'static str = "0.4.0";
2155
2156 pub fn from_json(json: &str) -> Result<Self> {
2157 let saved_context_json = serde_json::from_str::<serde_json::Value>(json)?;
2158 match saved_context_json
2159 .get("version")
2160 .ok_or_else(|| anyhow!("version not found"))?
2161 {
2162 serde_json::Value::String(version) => match version.as_str() {
2163 SavedContext::VERSION => {
2164 Ok(serde_json::from_value::<SavedContext>(saved_context_json)?)
2165 }
2166 SavedContextV0_3_0::VERSION => {
2167 let saved_context =
2168 serde_json::from_value::<SavedContextV0_3_0>(saved_context_json)?;
2169 Ok(saved_context.upgrade())
2170 }
2171 SavedContextV0_2_0::VERSION => {
2172 let saved_context =
2173 serde_json::from_value::<SavedContextV0_2_0>(saved_context_json)?;
2174 Ok(saved_context.upgrade())
2175 }
2176 SavedContextV0_1_0::VERSION => {
2177 let saved_context =
2178 serde_json::from_value::<SavedContextV0_1_0>(saved_context_json)?;
2179 Ok(saved_context.upgrade())
2180 }
2181 _ => Err(anyhow!("unrecognized saved context version: {}", version)),
2182 },
2183 _ => Err(anyhow!("version not found on saved context")),
2184 }
2185 }
2186
2187 fn into_ops(
2188 self,
2189 buffer: &Model<Buffer>,
2190 cx: &mut ModelContext<Context>,
2191 ) -> Vec<ContextOperation> {
2192 let mut operations = Vec::new();
2193 let mut version = clock::Global::new();
2194 let mut next_timestamp = clock::Lamport::new(ReplicaId::default());
2195
2196 let mut first_message_metadata = None;
2197 for message in self.messages {
2198 if message.id == MessageId(clock::Lamport::default()) {
2199 first_message_metadata = Some(message.metadata);
2200 } else {
2201 operations.push(ContextOperation::InsertMessage {
2202 anchor: MessageAnchor {
2203 id: message.id,
2204 start: buffer.read(cx).anchor_before(message.start),
2205 },
2206 metadata: MessageMetadata {
2207 role: message.metadata.role,
2208 status: message.metadata.status,
2209 timestamp: message.metadata.timestamp,
2210 },
2211 version: version.clone(),
2212 });
2213 version.observe(message.id.0);
2214 next_timestamp.observe(message.id.0);
2215 }
2216 }
2217
2218 if let Some(metadata) = first_message_metadata {
2219 let timestamp = next_timestamp.tick();
2220 operations.push(ContextOperation::UpdateMessage {
2221 message_id: MessageId(clock::Lamport::default()),
2222 metadata: MessageMetadata {
2223 role: metadata.role,
2224 status: metadata.status,
2225 timestamp,
2226 },
2227 version: version.clone(),
2228 });
2229 version.observe(timestamp);
2230 }
2231
2232 let timestamp = next_timestamp.tick();
2233 operations.push(ContextOperation::SlashCommandFinished {
2234 id: SlashCommandId(timestamp),
2235 output_range: language::Anchor::MIN..language::Anchor::MAX,
2236 sections: self
2237 .slash_command_output_sections
2238 .into_iter()
2239 .map(|section| {
2240 let buffer = buffer.read(cx);
2241 SlashCommandOutputSection {
2242 range: buffer.anchor_after(section.range.start)
2243 ..buffer.anchor_before(section.range.end),
2244 icon: section.icon,
2245 label: section.label,
2246 }
2247 })
2248 .collect(),
2249 version: version.clone(),
2250 });
2251 version.observe(timestamp);
2252
2253 let timestamp = next_timestamp.tick();
2254 operations.push(ContextOperation::UpdateSummary {
2255 summary: ContextSummary {
2256 text: self.summary,
2257 done: true,
2258 timestamp,
2259 },
2260 version: version.clone(),
2261 });
2262 version.observe(timestamp);
2263
2264 operations
2265 }
2266}
2267
/// Message identifier used by saved contexts older than format 0.4.0.
///
/// It is a plain integer. During upgrade it becomes the value of a Lamport
/// timestamp on the default replica.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
struct SavedMessageIdPreV0_4_0(usize);
2270
/// A message entry in saved contexts older than format 0.4.0.
///
/// Only the id and start offset live here. Role and status are stored
/// separately in the context's `message_metadata` map, keyed by this id.
#[derive(Serialize, Deserialize)]
struct SavedMessagePreV0_4_0 {
    id: SavedMessageIdPreV0_4_0,
    /// Offset at which the message begins; after upgrade it is used as a
    /// buffer offset when anchoring the message.
    start: usize,
}
2276
/// Per-message metadata in saved contexts older than format 0.4.0.
///
/// Unlike `MessageMetadata`, it has no timestamp. During upgrade the timestamp
/// is derived from the message id.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
struct SavedMessageMetadataPreV0_4_0 {
    role: Role,
    status: MessageStatus,
}
2282
/// On-disk schema of a saved context with format version `"0.3.0"`.
#[derive(Serialize, Deserialize)]
struct SavedContextV0_3_0 {
    id: Option<ContextId>,
    zed: String,
    version: String,
    text: String,
    messages: Vec<SavedMessagePreV0_4_0>,
    /// Role and status of each entry in `messages`, keyed by message id.
    /// Messages without an entry are dropped on upgrade.
    message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
    summary: String,
    /// Slash command output sections with `usize` ranges. They are carried
    /// over unchanged on upgrade.
    slash_command_output_sections: Vec<assistant_slash_command::SlashCommandOutputSection<usize>>,
}
2294
2295impl SavedContextV0_3_0 {
2296 const VERSION: &'static str = "0.3.0";
2297
2298 fn upgrade(self) -> SavedContext {
2299 SavedContext {
2300 id: self.id,
2301 zed: self.zed,
2302 version: SavedContext::VERSION.into(),
2303 text: self.text,
2304 messages: self
2305 .messages
2306 .into_iter()
2307 .filter_map(|message| {
2308 let metadata = self.message_metadata.get(&message.id)?;
2309 let timestamp = clock::Lamport {
2310 replica_id: ReplicaId::default(),
2311 value: message.id.0 as u32,
2312 };
2313 Some(SavedMessage {
2314 id: MessageId(timestamp),
2315 start: message.start,
2316 metadata: MessageMetadata {
2317 role: metadata.role,
2318 status: metadata.status.clone(),
2319 timestamp,
2320 },
2321 image_offsets: Vec::new(),
2322 })
2323 })
2324 .collect(),
2325 summary: self.summary,
2326 slash_command_output_sections: self.slash_command_output_sections,
2327 }
2328 }
2329}
2330
/// On-disk schema of a saved context with format version `"0.2.0"`.
///
/// It has the same fields as [`SavedContextV0_3_0`] except
/// `slash_command_output_sections`, which did not exist yet.
#[derive(Serialize, Deserialize)]
struct SavedContextV0_2_0 {
    id: Option<ContextId>,
    zed: String,
    version: String,
    text: String,
    messages: Vec<SavedMessagePreV0_4_0>,
    /// Role and status of each entry in `messages`, keyed by message id.
    message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
    summary: String,
}
2341
2342impl SavedContextV0_2_0 {
2343 const VERSION: &'static str = "0.2.0";
2344
2345 fn upgrade(self) -> SavedContext {
2346 SavedContextV0_3_0 {
2347 id: self.id,
2348 zed: self.zed,
2349 version: SavedContextV0_3_0::VERSION.to_string(),
2350 text: self.text,
2351 messages: self.messages,
2352 message_metadata: self.message_metadata,
2353 summary: self.summary,
2354 slash_command_output_sections: Vec::new(),
2355 }
2356 .upgrade()
2357 }
2358}
2359
/// On-disk schema of a saved context with format version `"0.1.0"`.
///
/// It is [`SavedContextV0_2_0`] plus two fields that only this version stored.
#[derive(Serialize, Deserialize)]
struct SavedContextV0_1_0 {
    id: Option<ContextId>,
    zed: String,
    version: String,
    text: String,
    messages: Vec<SavedMessagePreV0_4_0>,
    /// Role and status of each entry in `messages`, keyed by message id.
    message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
    summary: String,
    /// Only present in 0.1.0; discarded by `upgrade`.
    api_url: Option<String>,
    /// OpenAI model recorded by 0.1.0; discarded by `upgrade`.
    model: OpenAiModel,
}
2372
2373impl SavedContextV0_1_0 {
2374 const VERSION: &'static str = "0.1.0";
2375
2376 fn upgrade(self) -> SavedContext {
2377 SavedContextV0_2_0 {
2378 id: self.id,
2379 zed: self.zed,
2380 version: SavedContextV0_2_0::VERSION.to_string(),
2381 text: self.text,
2382 messages: self.messages,
2383 message_metadata: self.message_metadata,
2384 summary: self.summary,
2385 }
2386 .upgrade()
2387 }
2388}
2389
/// Summary information about a saved context.
#[derive(Clone)]
pub struct SavedContextMetadata {
    /// Human-readable title of the context. Presumably its summary; TODO
    /// confirm where this is populated.
    pub title: String,
    /// Presumably the location of the saved context file on disk.
    pub path: PathBuf,
    /// Modification time in local time. Presumably that of the file at `path`.
    pub mtime: chrono::DateTime<chrono::Local>,
}
2396
2397#[cfg(test)]
2398mod tests {
2399 use super::*;
2400 use crate::{
2401 assistant_panel, prompt_library, slash_command::file_command, workflow::tool, MessageId,
2402 };
2403 use assistant_slash_command::{ArgumentCompletion, SlashCommand};
2404 use fs::FakeFs;
2405 use gpui::{AppContext, TestAppContext, WeakView};
2406 use indoc::indoc;
2407 use language::LspAdapterDelegate;
2408 use parking_lot::Mutex;
2409 use project::Project;
2410 use rand::prelude::*;
2411 use serde_json::json;
2412 use settings::SettingsStore;
2413 use std::{cell::RefCell, env, rc::Rc, sync::atomic::AtomicBool};
2414 use text::{network::Network, ToPoint};
2415 use ui::WindowContext;
2416 use unindent::Unindent;
2417 use util::{test::marked_text_ranges, RandomCharIter};
2418 use workspace::Workspace;
2419
    // Exercises `insert_message_after`. It also checks how message ranges
    // follow buffer edits, including a deletion that spans message boundaries
    // and the undo/redo of that deletion.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let settings_store = SettingsStore::test(cx);
        LanguageModelRegistry::test(cx);
        cx.set_global(settings_store);
        assistant_panel::init(cx);
        let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let context =
            cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
        let buffer = context.read(cx).buffer.clone();

        // A freshly created context has a single, empty user message.
        let message_1 = context.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&context, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting a message grows the preceding message by one character.
        let message_2 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        let message_3 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message 2 a second time places the new message
        // between message 2 and message 3.
        let message_4 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }
2551
    // Exercises `split_message`. It covers when existing newlines are reused
    // as message boundaries versus when new ones are inserted, and splitting
    // a non-empty range.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        LanguageModelRegistry::test(cx);
        assistant_panel::init(cx);
        let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));

        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let context =
            cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
        let buffer = context.read(cx).buffer.clone();

        let message_1 = context.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&context, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        // `split_message` returns a pair of optional new messages. For an empty
        // range only the second one is used here.
        let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting a non-empty range creates two messages: one starting at
        // the range and one starting after it.
        let (message_6, message_7) =
            context.update(cx, |context, cx| context.split_message(14..16, cx));
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }
2647
    // Checks that `messages_for_offsets` maps buffer offsets to the messages
    // that contain them, reporting each message at most once.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let settings_store = SettingsStore::test(cx);
        LanguageModelRegistry::test(cx);
        cx.set_global(settings_store);
        assistant_panel::init(cx);
        let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let context =
            cx.new_model(|cx| Context::local(registry, None, None, prompt_builder.clone(), cx));
        let buffer = context.read(cx).buffer.clone();

        let message_1 = context.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&context, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = context
            .update(cx, |context, cx| {
                context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = context
            .update(cx, |context, cx| {
                context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&context, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Offsets 0 and 1 fall in the same message, which is reported once.
        assert_eq!(
            message_ids_for_offsets(&context, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        // An empty trailing message is still found at its start offset.
        let message_4 = context
            .update(cx, |context, cx| {
                context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..12),
                (message_4.id, Role::User, 12..12)
            ]
        );
        assert_eq!(
            message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
            [message_1.id, message_2.id, message_3.id, message_4.id]
        );

        // Returns the ids of the messages reported by `messages_for_offsets`.
        fn message_ids_for_offsets(
            context: &Model<Context>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            context
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }
2733
    // Tracks pending slash command ranges through `PendingSlashCommandsUpdated`
    // events while a command's argument and name are edited.
    #[gpui::test]
    async fn test_slash_commands(cx: &mut TestAppContext) {
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);
        cx.update(LanguageModelRegistry::test);
        cx.update(Project::init_settings);
        cx.update(assistant_panel::init);
        let fs = FakeFs::new(cx.background_executor.clone());

        fs.insert_tree(
            "/test",
            json!({
                "src": {
                    "lib.rs": "fn one() -> usize { 1 }",
                    "main.rs": "
                        use crate::one;
                        fn main() { one(); }
                    ".unindent(),
                }
            }),
        )
        .await;

        let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
        slash_command_registry.register_command(file_command::FileSlashCommand, false);

        let registry = Arc::new(LanguageRegistry::test(cx.executor()));
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let context = cx.new_model(|cx| {
            Context::local(registry.clone(), None, None, prompt_builder.clone(), cx)
        });

        // Mirror the context's pending slash command source ranges into a
        // shared set, driven by `PendingSlashCommandsUpdated` events.
        let output_ranges = Rc::new(RefCell::new(HashSet::default()));
        context.update(cx, |_, cx| {
            cx.subscribe(&context, {
                let ranges = output_ranges.clone();
                move |_, _, event, _| match event {
                    ContextEvent::PendingSlashCommandsUpdated { removed, updated } => {
                        for range in removed {
                            ranges.borrow_mut().remove(range);
                        }
                        for command in updated {
                            ranges.borrow_mut().insert(command.source_range.clone());
                        }
                    }
                    _ => {}
                }
            })
            .detach();
        });

        let buffer = context.read_with(cx, |context, _| context.buffer.clone());

        // Insert a slash command
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
        });
        assert_text_and_output_ranges(
            &buffer,
            &output_ranges.borrow(),
            "
            «/file src/lib.rs»
            "
            .unindent()
            .trim_end(),
            cx,
        );

        // Edit the argument of the slash command.
        buffer.update(cx, |buffer, cx| {
            let edit_offset = buffer.text().find("lib.rs").unwrap();
            buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
        });
        assert_text_and_output_ranges(
            &buffer,
            &output_ranges.borrow(),
            "
            «/file src/main.rs»
            "
            .unindent()
            .trim_end(),
            cx,
        );

        // Edit the name of the slash command, using one that doesn't exist.
        // The command is then no longer tracked, so no range is expected.
        buffer.update(cx, |buffer, cx| {
            let edit_offset = buffer.text().find("/file").unwrap();
            buffer.edit(
                [(edit_offset..edit_offset + "/file".len(), "/unknown")],
                None,
                cx,
            );
        });
        assert_text_and_output_ranges(
            &buffer,
            &output_ranges.borrow(),
            "
            /unknown src/main.rs
            "
            .unindent()
            .trim_end(),
            cx,
        );

        // Asserts that the buffer text and the tracked ranges (as offsets,
        // sorted by start) match `expected_marked_text`. The «…» markers
        // presumably delimit the expected ranges, as parsed by
        // `marked_text_ranges`.
        #[track_caller]
        fn assert_text_and_output_ranges(
            buffer: &Model<Buffer>,
            ranges: &HashSet<Range<language::Anchor>>,
            expected_marked_text: &str,
            cx: &mut TestAppContext,
        ) {
            let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false);
            let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| {
                let mut ranges = ranges
                    .iter()
                    .map(|range| range.to_offset(buffer))
                    .collect::<Vec<_>>();
                ranges.sort_by_key(|a| a.start);
                (buffer.text(), ranges)
            });

            assert_eq!(actual_text, expected_text);
            assert_eq!(actual_ranges, expected_ranges);
        }
    }
2859
    // Streams a fake LLM response containing two `<step>` blocks and checks
    // that both are parsed as workflow steps. It then checks that answering
    // the resolution tool call marks only the first step as resolved.
    #[gpui::test]
    async fn test_edit_step_parsing(cx: &mut TestAppContext) {
        cx.update(prompt_library::init);
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);
        cx.update(Project::init_settings);
        let fs = FakeFs::new(cx.executor());
        fs.as_fake()
            .insert_tree(
                "/root",
                json!({
                    "hello.rs": r#"
                    fn hello() {
                        println!("Hello, World!");
                    }
                "#.unindent()
                }),
            )
            .await;
        let project = Project::test(fs, [Path::new("/root")], cx).await;
        cx.update(LanguageModelRegistry::test);

        let model = cx.read(|cx| {
            LanguageModelRegistry::read_global(cx)
                .active_model()
                .unwrap()
        });
        cx.update(assistant_panel::init);
        let registry = Arc::new(LanguageRegistry::test(cx.executor()));

        // Create a new context
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let context = cx.new_model(|cx| {
            Context::local(
                registry.clone(),
                Some(project),
                None,
                prompt_builder.clone(),
                cx,
            )
        });
        let buffer = context.read_with(cx, |context, _| context.buffer.clone());

        // Simulate user input
        let user_message = indoc! {r#"
            Please add unnecessary complexity to this code:

            ```hello.rs
            fn main() {
                println!("Hello, World!");
            }
            ```
        "#};
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, user_message)], None, cx);
        });

        // Simulate LLM response with edit steps
        let llm_response = indoc! {r#"
            Sure, I can help you with that. Here's a step-by-step process:

            <step>
            First, let's extract the greeting into a separate function:

            ```rust
            fn greet() {
                println!("Hello, World!");
            }

            fn main() {
                greet();
            }
            ```
            </step>

            <step>
            Now, let's make the greeting customizable:

            ```rust
            fn greet(name: &str) {
                println!("Hello, {}!", name);
            }

            fn main() {
                greet("World");
            }
            ```
            </step>

            These changes make the code more modular and flexible.
        "#};

        // Simulate the assist method to trigger the LLM response
        context.update(cx, |context, cx| context.assist(cx));
        cx.run_until_parked();

        // Retrieve the assistant response message's start from the context
        let response_start_row = context.read_with(cx, |context, cx| {
            let buffer = context.buffer.read(cx);
            context.message_anchors[1].start.to_point(buffer).row
        });

        // Simulate the LLM completion
        model
            .as_fake()
            .stream_last_completion_response(llm_response.to_string());
        model.as_fake().end_last_completion_stream();

        // Wait for the completion to be processed
        cx.run_until_parked();

        // Verify that the edit steps were parsed correctly
        context.read_with(cx, |context, cx| {
            assert_eq!(
                workflow_steps(context, cx),
                vec![
                    (
                        Point::new(response_start_row + 2, 0)
                            ..Point::new(response_start_row + 12, 3),
                        WorkflowStepTestStatus::Pending
                    ),
                    (
                        Point::new(response_start_row + 14, 0)
                            ..Point::new(response_start_row + 24, 3),
                        WorkflowStepTestStatus::Pending
                    ),
                ]
            );
        });

        // Answer the pending resolution tool call for the first step.
        model
            .as_fake()
            .respond_to_last_tool_use(tool::WorkflowStepResolutionTool {
                step_title: "Title".into(),
                suggestions: vec![tool::WorkflowSuggestionTool {
                    path: "/root/hello.rs".into(),
                    // Simulate a symbol name that's slightly different than our outline query
                    kind: tool::WorkflowSuggestionToolKind::Update {
                        symbol: "fn main()".into(),
                        description: "Extract a greeting function".into(),
                    },
                }],
            });

        // Wait for tool use to be processed.
        cx.run_until_parked();

        // Verify that the first edit step is not pending anymore.
        context.read_with(cx, |context, cx| {
            assert_eq!(
                workflow_steps(context, cx),
                vec![
                    (
                        Point::new(response_start_row + 2, 0)
                            ..Point::new(response_start_row + 12, 3),
                        WorkflowStepTestStatus::Resolved
                    ),
                    (
                        Point::new(response_start_row + 14, 0)
                            ..Point::new(response_start_row + 24, 3),
                        WorkflowStepTestStatus::Pending
                    ),
                ]
            );
        });

        // Simplified view of a step's `resolution` for assertions:
        // `None` => Pending, `Some(Ok(_))` => Resolved, `Some(Err(_))` => Error.
        #[derive(Copy, Clone, Debug, Eq, PartialEq)]
        enum WorkflowStepTestStatus {
            Pending,
            Resolved,
            Error,
        }

        // Returns each workflow step's range (as points) and its test status.
        fn workflow_steps(
            context: &Context,
            cx: &AppContext,
        ) -> Vec<(Range<Point>, WorkflowStepTestStatus)> {
            context
                .workflow_steps
                .iter()
                .map(|step| {
                    let buffer = context.buffer.read(cx);
                    let status = match &step.step.read(cx).resolution {
                        None => WorkflowStepTestStatus::Pending,
                        Some(Ok(_)) => WorkflowStepTestStatus::Resolved,
                        Some(Err(_)) => WorkflowStepTestStatus::Error,
                    };
                    (step.range.to_point(buffer), status)
                })
                .collect()
        }
    }
3052
    // Round-trips a context through `serialize` / `Context::deserialize`. It
    // checks that the buffer text and the message ids, roles and ranges
    // survive unchanged.
    #[gpui::test]
    async fn test_serialization(cx: &mut TestAppContext) {
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);
        cx.update(LanguageModelRegistry::test);
        cx.update(assistant_panel::init);
        let registry = Arc::new(LanguageRegistry::test(cx.executor()));
        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
        let context = cx.new_model(|cx| {
            Context::local(registry.clone(), None, None, prompt_builder.clone(), cx)
        });
        let buffer = context.read_with(cx, |context, _| context.buffer.clone());
        let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
        let message_1 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        let message_2 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // Close the current transaction, presumably so the following undo only
        // reverts the message-3 insertion and not this edit.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            buffer.finalize_last_transaction();
        });
        let _message_3 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        // Per the assertions below, undoing drops message 3 from the context.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
        assert_eq!(
            cx.read(|cx| messages(&context, cx)),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );

        let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
        let deserialized_context = cx.new_model(|cx| {
            Context::deserialize(
                serialized_context,
                Default::default(),
                registry.clone(),
                prompt_builder.clone(),
                None,
                None,
                cx,
            )
        });
        let deserialized_buffer =
            deserialized_context.read_with(cx, |context, _| context.buffer.clone());
        assert_eq!(
            deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
            "a\nb\nc\n"
        );
        assert_eq!(
            cx.read(|cx| messages(&deserialized_context, cx)),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );
    }
3123
3124 #[gpui::test(iterations = 100)]
3125 async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
3126 let min_peers = env::var("MIN_PEERS")
3127 .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
3128 .unwrap_or(2);
3129 let max_peers = env::var("MAX_PEERS")
3130 .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
3131 .unwrap_or(5);
3132 let operations = env::var("OPERATIONS")
3133 .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
3134 .unwrap_or(50);
3135
3136 let settings_store = cx.update(SettingsStore::test);
3137 cx.set_global(settings_store);
3138 cx.update(LanguageModelRegistry::test);
3139
3140 cx.update(assistant_panel::init);
3141 let slash_commands = cx.update(SlashCommandRegistry::default_global);
3142 slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
3143 slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
3144 slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);
3145
3146 let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
3147 let network = Arc::new(Mutex::new(Network::new(rng.clone())));
3148 let mut contexts = Vec::new();
3149
3150 let num_peers = rng.gen_range(min_peers..=max_peers);
3151 let context_id = ContextId::new();
3152 let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
3153 for i in 0..num_peers {
3154 let context = cx.new_model(|cx| {
3155 Context::new(
3156 context_id.clone(),
3157 i as ReplicaId,
3158 language::Capability::ReadWrite,
3159 registry.clone(),
3160 prompt_builder.clone(),
3161 None,
3162 None,
3163 cx,
3164 )
3165 });
3166
3167 cx.update(|cx| {
3168 cx.subscribe(&context, {
3169 let network = network.clone();
3170 move |_, event, _| {
3171 if let ContextEvent::Operation(op) = event {
3172 network
3173 .lock()
3174 .broadcast(i as ReplicaId, vec![op.to_proto()]);
3175 }
3176 }
3177 })
3178 .detach();
3179 });
3180
3181 contexts.push(context);
3182 network.lock().add_peer(i as ReplicaId);
3183 }
3184
3185 let mut mutation_count = operations;
3186
3187 while mutation_count > 0
3188 || !network.lock().is_idle()
3189 || network.lock().contains_disconnected_peers()
3190 {
3191 let context_index = rng.gen_range(0..contexts.len());
3192 let context = &contexts[context_index];
3193
3194 match rng.gen_range(0..100) {
3195 0..=29 if mutation_count > 0 => {
3196 log::info!("Context {}: edit buffer", context_index);
3197 context.update(cx, |context, cx| {
3198 context
3199 .buffer
3200 .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
3201 });
3202 mutation_count -= 1;
3203 }
3204 30..=44 if mutation_count > 0 => {
3205 context.update(cx, |context, cx| {
3206 let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
3207 log::info!("Context {}: split message at {:?}", context_index, range);
3208 context.split_message(range, cx);
3209 });
3210 mutation_count -= 1;
3211 }
3212 45..=59 if mutation_count > 0 => {
3213 context.update(cx, |context, cx| {
3214 if let Some(message) = context.messages(cx).choose(&mut rng) {
3215 let role = *[Role::User, Role::Assistant, Role::System]
3216 .choose(&mut rng)
3217 .unwrap();
3218 log::info!(
3219 "Context {}: insert message after {:?} with {:?}",
3220 context_index,
3221 message.id,
3222 role
3223 );
3224 context.insert_message_after(message.id, role, MessageStatus::Done, cx);
3225 }
3226 });
3227 mutation_count -= 1;
3228 }
3229 60..=74 if mutation_count > 0 => {
3230 context.update(cx, |context, cx| {
3231 let command_text = "/".to_string()
3232 + slash_commands
3233 .command_names()
3234 .choose(&mut rng)
3235 .unwrap()
3236 .clone()
3237 .as_ref();
3238
3239 let command_range = context.buffer.update(cx, |buffer, cx| {
3240 let offset = buffer.random_byte_range(0, &mut rng).start;
3241 buffer.edit(
3242 [(offset..offset, format!("\n{}\n", command_text))],
3243 None,
3244 cx,
3245 );
3246 offset + 1..offset + 1 + command_text.len()
3247 });
3248
3249 let output_len = rng.gen_range(1..=10);
3250 let output_text = RandomCharIter::new(&mut rng)
3251 .filter(|c| *c != '\r')
3252 .take(output_len)
3253 .collect::<String>();
3254
3255 let num_sections = rng.gen_range(0..=3);
3256 let mut sections = Vec::with_capacity(num_sections);
3257 for _ in 0..num_sections {
3258 let section_start = rng.gen_range(0..output_len);
3259 let section_end = rng.gen_range(section_start..=output_len);
3260 sections.push(SlashCommandOutputSection {
3261 range: section_start..section_end,
3262 icon: ui::IconName::Ai,
3263 label: "section".into(),
3264 });
3265 }
3266
3267 log::info!(
3268 "Context {}: insert slash command output at {:?} with {:?}",
3269 context_index,
3270 command_range,
3271 sections
3272 );
3273
3274 let command_range =
3275 context.buffer.read(cx).anchor_after(command_range.start)
3276 ..context.buffer.read(cx).anchor_after(command_range.end);
3277 context.insert_command_output(
3278 command_range,
3279 Task::ready(Ok(SlashCommandOutput {
3280 text: output_text,
3281 sections,
3282 run_commands_in_text: false,
3283 })),
3284 true,
3285 cx,
3286 );
3287 });
3288 cx.run_until_parked();
3289 mutation_count -= 1;
3290 }
3291 75..=84 if mutation_count > 0 => {
3292 context.update(cx, |context, cx| {
3293 if let Some(message) = context.messages(cx).choose(&mut rng) {
3294 let new_status = match rng.gen_range(0..3) {
3295 0 => MessageStatus::Done,
3296 1 => MessageStatus::Pending,
3297 _ => MessageStatus::Error(SharedString::from("Random error")),
3298 };
3299 log::info!(
3300 "Context {}: update message {:?} status to {:?}",
3301 context_index,
3302 message.id,
3303 new_status
3304 );
3305 context.update_metadata(message.id, cx, |metadata| {
3306 metadata.status = new_status;
3307 });
3308 }
3309 });
3310 mutation_count -= 1;
3311 }
3312 _ => {
3313 let replica_id = context_index as ReplicaId;
3314 if network.lock().is_disconnected(replica_id) {
3315 network.lock().reconnect_peer(replica_id, 0);
3316
3317 let (ops_to_send, ops_to_receive) = cx.read(|cx| {
3318 let host_context = &contexts[0].read(cx);
3319 let guest_context = context.read(cx);
3320 (
3321 guest_context.serialize_ops(&host_context.version(cx), cx),
3322 host_context.serialize_ops(&guest_context.version(cx), cx),
3323 )
3324 });
3325 let ops_to_send = ops_to_send.await;
3326 let ops_to_receive = ops_to_receive
3327 .await
3328 .into_iter()
3329 .map(ContextOperation::from_proto)
3330 .collect::<Result<Vec<_>>>()
3331 .unwrap();
3332 log::info!(
3333 "Context {}: reconnecting. Sent {} operations, received {} operations",
3334 context_index,
3335 ops_to_send.len(),
3336 ops_to_receive.len()
3337 );
3338
3339 network.lock().broadcast(replica_id, ops_to_send);
3340 context
3341 .update(cx, |context, cx| context.apply_ops(ops_to_receive, cx))
3342 .unwrap();
3343 } else if rng.gen_bool(0.1) && replica_id != 0 {
3344 log::info!("Context {}: disconnecting", context_index);
3345 network.lock().disconnect_peer(replica_id);
3346 } else if network.lock().has_unreceived(replica_id) {
3347 log::info!("Context {}: applying operations", context_index);
3348 let ops = network.lock().receive(replica_id);
3349 let ops = ops
3350 .into_iter()
3351 .map(ContextOperation::from_proto)
3352 .collect::<Result<Vec<_>>>()
3353 .unwrap();
3354 context
3355 .update(cx, |context, cx| context.apply_ops(ops, cx))
3356 .unwrap();
3357 }
3358 }
3359 }
3360 }
3361
3362 cx.read(|cx| {
3363 let first_context = contexts[0].read(cx);
3364 for context in &contexts[1..] {
3365 let context = context.read(cx);
3366 assert!(context.pending_ops.is_empty());
3367 assert_eq!(
3368 context.buffer.read(cx).text(),
3369 first_context.buffer.read(cx).text(),
3370 "Context {} text != Context 0 text",
3371 context.buffer.read(cx).replica_id()
3372 );
3373 assert_eq!(
3374 context.message_anchors,
3375 first_context.message_anchors,
3376 "Context {} messages != Context 0 messages",
3377 context.buffer.read(cx).replica_id()
3378 );
3379 assert_eq!(
3380 context.messages_metadata,
3381 first_context.messages_metadata,
3382 "Context {} message metadata != Context 0 message metadata",
3383 context.buffer.read(cx).replica_id()
3384 );
3385 assert_eq!(
3386 context.slash_command_output_sections,
3387 first_context.slash_command_output_sections,
3388 "Context {} slash command output sections != Context 0 slash command output sections",
3389 context.buffer.read(cx).replica_id()
3390 );
3391 }
3392 });
3393 }
3394
3395 fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
3396 context
3397 .read(cx)
3398 .messages(cx)
3399 .map(|message| (message.id, message.role, message.offset_range))
3400 .collect()
3401 }
3402
3403 #[derive(Clone)]
3404 struct FakeSlashCommand(String);
3405
3406 impl SlashCommand for FakeSlashCommand {
3407 fn name(&self) -> String {
3408 self.0.clone()
3409 }
3410
3411 fn description(&self) -> String {
3412 format!("Fake slash command: {}", self.0)
3413 }
3414
3415 fn menu_text(&self) -> String {
3416 format!("Run fake command: {}", self.0)
3417 }
3418
3419 fn complete_argument(
3420 self: Arc<Self>,
3421 _arguments: &[String],
3422 _cancel: Arc<AtomicBool>,
3423 _workspace: Option<WeakView<Workspace>>,
3424 _cx: &mut WindowContext,
3425 ) -> Task<Result<Vec<ArgumentCompletion>>> {
3426 Task::ready(Ok(vec![]))
3427 }
3428
3429 fn requires_argument(&self) -> bool {
3430 false
3431 }
3432
3433 fn run(
3434 self: Arc<Self>,
3435 _arguments: &[String],
3436 _workspace: WeakView<Workspace>,
3437 _delegate: Option<Arc<dyn LspAdapterDelegate>>,
3438 _cx: &mut WindowContext,
3439 ) -> Task<Result<SlashCommandOutput>> {
3440 Task::ready(Ok(SlashCommandOutput {
3441 text: format!("Executed fake command: {}", self.0),
3442 sections: vec![],
3443 run_commands_in_text: false,
3444 }))
3445 }
3446 }
3447}