1#[cfg(test)]
2mod context_tests;
3
4use crate::{
5 prompts::PromptBuilder, slash_command::SlashCommandLine, workflow::WorkflowStep, MessageId,
6 MessageStatus,
7};
8use anyhow::{anyhow, Context as _, Result};
9use assistant_slash_command::{
10 SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
11};
12use client::{self, proto, telemetry::Telemetry};
13use clock::ReplicaId;
14use collections::{HashMap, HashSet};
15use fs::{Fs, RemoveOptions};
16use futures::{future::Shared, stream::FuturesUnordered, FutureExt, StreamExt};
17use gpui::{
18 AppContext, Context as _, EventEmitter, Image, Model, ModelContext, RenderImage, SharedString,
19 Subscription, Task,
20};
21
22use language::{AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, Point, ToOffset};
23use language_model::{
24 LanguageModel, LanguageModelCacheConfiguration, LanguageModelImage, LanguageModelRegistry,
25 LanguageModelRequest, LanguageModelRequestMessage, MessageContent, Role,
26};
27use open_ai::Model as OpenAiModel;
28use paths::{context_images_dir, contexts_dir};
29use project::Project;
30use serde::{Deserialize, Serialize};
31use smallvec::SmallVec;
32use std::{
33 cmp::{max, Ordering},
34 collections::hash_map,
35 fmt::Debug,
36 iter, mem,
37 ops::Range,
38 path::{Path, PathBuf},
39 sync::Arc,
40 time::{Duration, Instant},
41};
42use telemetry_events::AssistantKind;
43use text::BufferSnapshot;
44use util::{post_inc, ResultExt, TryFutureExt};
45use uuid::Uuid;
46
47#[derive(Clone, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
48pub struct ContextId(String);
49
50impl ContextId {
51 pub fn new() -> Self {
52 Self(Uuid::new_v4().to_string())
53 }
54
55 pub fn from_proto(id: String) -> Self {
56 Self(id)
57 }
58
59 pub fn to_proto(&self) -> String {
60 self.0.clone()
61 }
62}
63
/// A replicated change to a [`Context`], exchanged with collaborators via protobuf.
///
/// Every variant except `BufferOperation` carries the context `version` it
/// depends on. `Context::can_apply_op` defers an operation until that version
/// has been observed.
#[derive(Clone, Debug)]
pub enum ContextOperation {
    /// Adds a message whose text starts at `anchor.start`.
    InsertMessage {
        anchor: MessageAnchor,
        metadata: MessageMetadata,
        version: clock::Global,
    },
    /// Replaces a message's metadata; the newer `metadata.timestamp` wins.
    UpdateMessage {
        message_id: MessageId,
        metadata: MessageMetadata,
        version: clock::Global,
    },
    /// Replaces the context summary; the newer `summary.timestamp` wins.
    UpdateSummary {
        summary: ContextSummary,
        version: clock::Global,
    },
    /// Records that slash command `id` finished, producing `sections` inside
    /// `output_range`.
    SlashCommandFinished {
        id: SlashCommandId,
        output_range: Range<language::Anchor>,
        sections: Vec<SlashCommandOutputSection<language::Anchor>>,
        version: clock::Global,
    },
    /// An edit to the underlying text buffer; forwarded straight to the buffer
    /// by `Context::apply_ops`.
    BufferOperation(language::Operation),
}
88
89impl ContextOperation {
90 pub fn from_proto(op: proto::ContextOperation) -> Result<Self> {
91 match op.variant.context("invalid variant")? {
92 proto::context_operation::Variant::InsertMessage(insert) => {
93 let message = insert.message.context("invalid message")?;
94 let id = MessageId(language::proto::deserialize_timestamp(
95 message.id.context("invalid id")?,
96 ));
97 Ok(Self::InsertMessage {
98 anchor: MessageAnchor {
99 id,
100 start: language::proto::deserialize_anchor(
101 message.start.context("invalid anchor")?,
102 )
103 .context("invalid anchor")?,
104 },
105 metadata: MessageMetadata {
106 role: Role::from_proto(message.role),
107 status: MessageStatus::from_proto(
108 message.status.context("invalid status")?,
109 ),
110 timestamp: id.0,
111 cache: None,
112 },
113 version: language::proto::deserialize_version(&insert.version),
114 })
115 }
116 proto::context_operation::Variant::UpdateMessage(update) => Ok(Self::UpdateMessage {
117 message_id: MessageId(language::proto::deserialize_timestamp(
118 update.message_id.context("invalid message id")?,
119 )),
120 metadata: MessageMetadata {
121 role: Role::from_proto(update.role),
122 status: MessageStatus::from_proto(update.status.context("invalid status")?),
123 timestamp: language::proto::deserialize_timestamp(
124 update.timestamp.context("invalid timestamp")?,
125 ),
126 cache: None,
127 },
128 version: language::proto::deserialize_version(&update.version),
129 }),
130 proto::context_operation::Variant::UpdateSummary(update) => Ok(Self::UpdateSummary {
131 summary: ContextSummary {
132 text: update.summary,
133 done: update.done,
134 timestamp: language::proto::deserialize_timestamp(
135 update.timestamp.context("invalid timestamp")?,
136 ),
137 },
138 version: language::proto::deserialize_version(&update.version),
139 }),
140 proto::context_operation::Variant::SlashCommandFinished(finished) => {
141 Ok(Self::SlashCommandFinished {
142 id: SlashCommandId(language::proto::deserialize_timestamp(
143 finished.id.context("invalid id")?,
144 )),
145 output_range: language::proto::deserialize_anchor_range(
146 finished.output_range.context("invalid range")?,
147 )?,
148 sections: finished
149 .sections
150 .into_iter()
151 .map(|section| {
152 Ok(SlashCommandOutputSection {
153 range: language::proto::deserialize_anchor_range(
154 section.range.context("invalid range")?,
155 )?,
156 icon: section.icon_name.parse()?,
157 label: section.label.into(),
158 })
159 })
160 .collect::<Result<Vec<_>>>()?,
161 version: language::proto::deserialize_version(&finished.version),
162 })
163 }
164 proto::context_operation::Variant::BufferOperation(op) => Ok(Self::BufferOperation(
165 language::proto::deserialize_operation(
166 op.operation.context("invalid buffer operation")?,
167 )?,
168 )),
169 }
170 }
171
172 pub fn to_proto(&self) -> proto::ContextOperation {
173 match self {
174 Self::InsertMessage {
175 anchor,
176 metadata,
177 version,
178 } => proto::ContextOperation {
179 variant: Some(proto::context_operation::Variant::InsertMessage(
180 proto::context_operation::InsertMessage {
181 message: Some(proto::ContextMessage {
182 id: Some(language::proto::serialize_timestamp(anchor.id.0)),
183 start: Some(language::proto::serialize_anchor(&anchor.start)),
184 role: metadata.role.to_proto() as i32,
185 status: Some(metadata.status.to_proto()),
186 }),
187 version: language::proto::serialize_version(version),
188 },
189 )),
190 },
191 Self::UpdateMessage {
192 message_id,
193 metadata,
194 version,
195 } => proto::ContextOperation {
196 variant: Some(proto::context_operation::Variant::UpdateMessage(
197 proto::context_operation::UpdateMessage {
198 message_id: Some(language::proto::serialize_timestamp(message_id.0)),
199 role: metadata.role.to_proto() as i32,
200 status: Some(metadata.status.to_proto()),
201 timestamp: Some(language::proto::serialize_timestamp(metadata.timestamp)),
202 version: language::proto::serialize_version(version),
203 },
204 )),
205 },
206 Self::UpdateSummary { summary, version } => proto::ContextOperation {
207 variant: Some(proto::context_operation::Variant::UpdateSummary(
208 proto::context_operation::UpdateSummary {
209 summary: summary.text.clone(),
210 done: summary.done,
211 timestamp: Some(language::proto::serialize_timestamp(summary.timestamp)),
212 version: language::proto::serialize_version(version),
213 },
214 )),
215 },
216 Self::SlashCommandFinished {
217 id,
218 output_range,
219 sections,
220 version,
221 } => proto::ContextOperation {
222 variant: Some(proto::context_operation::Variant::SlashCommandFinished(
223 proto::context_operation::SlashCommandFinished {
224 id: Some(language::proto::serialize_timestamp(id.0)),
225 output_range: Some(language::proto::serialize_anchor_range(
226 output_range.clone(),
227 )),
228 sections: sections
229 .iter()
230 .map(|section| {
231 let icon_name: &'static str = section.icon.into();
232 proto::SlashCommandOutputSection {
233 range: Some(language::proto::serialize_anchor_range(
234 section.range.clone(),
235 )),
236 icon_name: icon_name.to_string(),
237 label: section.label.to_string(),
238 }
239 })
240 .collect(),
241 version: language::proto::serialize_version(version),
242 },
243 )),
244 },
245 Self::BufferOperation(operation) => proto::ContextOperation {
246 variant: Some(proto::context_operation::Variant::BufferOperation(
247 proto::context_operation::BufferOperation {
248 operation: Some(language::proto::serialize_operation(operation)),
249 },
250 )),
251 },
252 }
253 }
254
255 fn timestamp(&self) -> clock::Lamport {
256 match self {
257 Self::InsertMessage { anchor, .. } => anchor.id.0,
258 Self::UpdateMessage { metadata, .. } => metadata.timestamp,
259 Self::UpdateSummary { summary, .. } => summary.timestamp,
260 Self::SlashCommandFinished { id, .. } => id.0,
261 Self::BufferOperation(_) => {
262 panic!("reading the timestamp of a buffer operation is not supported")
263 }
264 }
265 }
266
267 /// Returns the current version of the context operation.
268 pub fn version(&self) -> &clock::Global {
269 match self {
270 Self::InsertMessage { version, .. }
271 | Self::UpdateMessage { version, .. }
272 | Self::UpdateSummary { version, .. }
273 | Self::SlashCommandFinished { version, .. } => version,
274 Self::BufferOperation(_) => {
275 panic!("reading the version of a buffer operation is not supported")
276 }
277 }
278 }
279}
280
/// Events emitted by a [`Context`].
///
/// NOTE(review): only `MessagesEdited`, `SummaryChanged`, `SlashCommandFinished`
/// and `Operation` are emitted in the visible part of this file; the other
/// variants are presumably emitted further down — confirm their semantics there.
#[derive(Debug, Clone)]
pub enum ContextEvent {
    ShowAssistError(SharedString),
    /// Message metadata or buffer text changed.
    MessagesEdited,
    /// The context summary was replaced.
    SummaryChanged,
    WorkflowStepsRemoved(Vec<Range<language::Anchor>>),
    WorkflowStepUpdated(Range<language::Anchor>),
    StreamedCompletion,
    PendingSlashCommandsUpdated {
        removed: Vec<Range<language::Anchor>>,
        updated: Vec<PendingSlashCommand>,
    },
    /// A slash command finished; for remotely applied completions both flags
    /// are emitted as `false`.
    SlashCommandFinished {
        output_range: Range<language::Anchor>,
        sections: Vec<SlashCommandOutputSection<language::Anchor>>,
        run_commands_in_output: bool,
        expand_result: bool,
    },
    /// An operation that should be broadcast to collaborators.
    Operation(ContextOperation),
}
301
/// Summary text of a context.
///
/// Replicated with last-writer-wins semantics on `timestamp`.
#[derive(Clone, Default, Debug)]
pub struct ContextSummary {
    pub text: String,
    // Round-tripped through protobuf as `UpdateSummary.done`.
    done: bool,
    // Lamport time of the last update; a newer timestamp replaces the summary.
    timestamp: clock::Lamport,
}
308
/// Identifies a message and the buffer position where its text starts.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MessageAnchor {
    pub id: MessageId,
    pub start: language::Anchor,
}
314
/// Prompt-cache state of a message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CacheStatus {
    /// Not yet known to be cached (new or invalidated).
    Pending,
    /// Believed to be cached by the model provider.
    Cached,
}
320
/// Prompt-cache bookkeeping attached to a message by `Context::mark_cache_anchors`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MessageCacheMetadata {
    /// The message is a cache anchor; copied into the request message's `cache` flag.
    pub is_anchor: bool,
    /// The message is the last anchor in document order.
    pub is_final_anchor: bool,
    pub status: CacheStatus,
    /// Buffer version at which this metadata was computed; edits inside the
    /// message's range after this version invalidate the cache.
    pub cached_at: clock::Global,
}
328
/// Replicated, persisted metadata for a single message.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct MessageMetadata {
    pub role: Role,
    pub status: MessageStatus,
    // Lamport time of the last update; newer updates replace older ones.
    pub(crate) timestamp: clock::Lamport,
    // Cache bookkeeping is local state: it is never serialized and is reset
    // to `None` when metadata arrives over protobuf.
    #[serde(skip)]
    pub cache: Option<MessageCacheMetadata>,
}
337
338impl From<&Message> for MessageMetadata {
339 fn from(message: &Message) -> Self {
340 Self {
341 role: message.role,
342 status: message.status.clone(),
343 timestamp: message.id.0,
344 cache: message.cache.clone(),
345 }
346 }
347}
348
349impl MessageMetadata {
350 pub fn is_cache_valid(&self, buffer: &BufferSnapshot, range: &Range<usize>) -> bool {
351 let result = match &self.cache {
352 Some(MessageCacheMetadata { cached_at, .. }) => !buffer.has_edits_since_in_range(
353 &cached_at,
354 Range {
355 start: buffer.anchor_at(range.start, Bias::Right),
356 end: buffer.anchor_at(range.end, Bias::Left),
357 },
358 ),
359 _ => false,
360 };
361 result
362 }
363}
364
/// An image embedded in a message.
#[derive(Clone, Debug)]
pub struct MessageImage {
    // Identity used for equality and persisted in `SavedMessage::image_offsets`.
    image_id: u64,
    // Shared task producing the model-ready image; `None` if it could not be produced.
    image: Shared<Task<Option<LanguageModelImage>>>,
}
370
371impl PartialEq for MessageImage {
372 fn eq(&self, other: &Self) -> bool {
373 self.image_id == other.image_id
374 }
375}
376
377impl Eq for MessageImage {}
378
/// A resolved view of one message in a context.
#[derive(Clone, Debug)]
pub struct Message {
    /// Images embedded in the message, paired with their buffer offsets. They
    /// are consumed in this order when building a request.
    pub image_offsets: SmallVec<[(usize, MessageImage); 1]>,
    /// Buffer offsets covered by the message's text.
    pub offset_range: Range<usize>,
    /// NOTE(review): not used in the visible code; presumably this is the
    /// message's index range within the context — confirm.
    pub index_range: Range<usize>,
    pub id: MessageId,
    pub anchor: language::Anchor,
    pub role: Role,
    pub status: MessageStatus,
    pub cache: Option<MessageCacheMetadata>,
}
390
391impl Message {
392 fn to_request_message(&self, buffer: &Buffer) -> Option<LanguageModelRequestMessage> {
393 let mut content = Vec::new();
394
395 let mut range_start = self.offset_range.start;
396 for (image_offset, message_image) in self.image_offsets.iter() {
397 if *image_offset != range_start {
398 if let Some(text) = Self::collect_text_content(buffer, range_start..*image_offset) {
399 content.push(text);
400 }
401 }
402
403 if let Some(image) = message_image.image.clone().now_or_never().flatten() {
404 content.push(language_model::MessageContent::Image(image));
405 }
406
407 range_start = *image_offset;
408 }
409 if range_start != self.offset_range.end {
410 if let Some(text) =
411 Self::collect_text_content(buffer, range_start..self.offset_range.end)
412 {
413 content.push(text);
414 }
415 }
416
417 if content.is_empty() {
418 return None;
419 }
420
421 Some(LanguageModelRequestMessage {
422 role: self.role,
423 content,
424 cache: self.cache.as_ref().map_or(false, |cache| cache.is_anchor),
425 })
426 }
427
428 fn collect_text_content(buffer: &Buffer, range: Range<usize>) -> Option<MessageContent> {
429 let text: String = buffer.text_for_range(range.clone()).collect();
430 if text.trim().is_empty() {
431 None
432 } else {
433 Some(MessageContent::Text(text))
434 }
435 }
436}
437
/// An image placed at a buffer position.
///
/// NOTE(review): populated outside the visible code; the field roles below are
/// inferred from their types — confirm.
#[derive(Clone, Debug)]
pub struct ImageAnchor {
    pub anchor: language::Anchor,
    pub image_id: u64,
    // Presumably the rendered form for display in the editor.
    pub render_image: Arc<RenderImage>,
    // Shared task producing the model-ready image.
    pub image: Shared<Task<Option<LanguageModelImage>>>,
}
445
/// An in-flight completion request.
///
/// NOTE(review): created outside the visible code; presumably `id` comes from
/// `Context::completion_count` — confirm.
struct PendingCompletion {
    id: usize,
    // Message that receives the streamed assistant output (inferred from name).
    assistant_message_id: MessageId,
    // Held only to keep the completion task alive.
    _task: Task<()>,
}
451
/// Identifies a slash-command invocation by its Lamport timestamp; used to
/// apply each `SlashCommandFinished` operation only once.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct SlashCommandId(clock::Lamport);
454
/// A workflow step together with the buffer range it was parsed from.
///
/// `Context::workflow_steps` is binary-searched by `range`, so entries are
/// expected to be kept sorted by range.
struct WorkflowStepEntry {
    range: Range<language::Anchor>,
    step: Model<WorkflowStep>,
}
459
/// A collaborative assistant conversation: a Markdown text buffer partitioned
/// into messages, plus replicated metadata (message metadata, summary, slash
/// command output) kept in sync via [`ContextOperation`]s.
pub struct Context {
    id: ContextId,
    /// Lamport clock for context operations; its replica id is this replica's id.
    timestamp: clock::Lamport,
    /// Timestamps of all context operations observed so far.
    version: clock::Global,
    /// Received operations whose dependencies are not yet satisfied.
    pending_ops: Vec<ContextOperation>,
    /// Context operations applied or produced locally; replayed by `serialize_ops`.
    operations: Vec<ContextOperation>,
    /// Text of all messages.
    buffer: Model<Buffer>,
    pending_slash_commands: Vec<PendingSlashCommand>,
    /// Buffer edit subscription; presumably drained when slash commands are reparsed.
    edits_since_last_slash_command_parse: language::Subscription,
    /// Ids of `SlashCommandFinished` operations already applied.
    finished_slash_commands: HashSet<SlashCommandId>,
    /// Output sections of finished slash commands, kept sorted by range.
    slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
    /// Start anchors of messages; the first one is at `Anchor::MIN`.
    message_anchors: Vec<MessageAnchor>,
    images: HashMap<u64, (Arc<RenderImage>, Shared<Task<Option<LanguageModelImage>>>)>,
    image_anchors: Vec<ImageAnchor>,
    /// Metadata for every known message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    summary: Option<ContextSummary>,
    pending_summary: Task<Option<()>>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    /// Result of the most recent `count_remaining_tokens` run.
    token_count: Option<usize>,
    pending_token_count: Task<Option<()>>,
    pending_save: Task<Result<()>>,
    /// Task spawned by `start_cache_warming`.
    pending_cache_warming_task: Task<Option<()>>,
    /// File this context was loaded from (set by `deserialize`).
    path: Option<PathBuf>,
    /// Keeps the buffer event subscription alive.
    _subscriptions: Vec<Subscription>,
    telemetry: Option<Arc<Telemetry>>,
    language_registry: Arc<LanguageRegistry>,
    /// Parsed workflow steps, searched by range.
    workflow_steps: Vec<WorkflowStepEntry>,
    /// Buffer edit subscription; presumably drained when workflow steps are pruned.
    edits_since_last_workflow_step_prune: language::Subscription,
    project: Option<Model<Project>>,
    prompt_builder: Arc<PromptBuilder>,
}

// `Context` models emit `ContextEvent`s (see the `cx.emit` calls below).
impl EventEmitter<ContextEvent> for Context {}
494
495impl Context {
496 pub fn local(
497 language_registry: Arc<LanguageRegistry>,
498 project: Option<Model<Project>>,
499 telemetry: Option<Arc<Telemetry>>,
500 prompt_builder: Arc<PromptBuilder>,
501 cx: &mut ModelContext<Self>,
502 ) -> Self {
503 Self::new(
504 ContextId::new(),
505 ReplicaId::default(),
506 language::Capability::ReadWrite,
507 language_registry,
508 prompt_builder,
509 project,
510 telemetry,
511 cx,
512 )
513 }
514
515 #[allow(clippy::too_many_arguments)]
516 pub fn new(
517 id: ContextId,
518 replica_id: ReplicaId,
519 capability: language::Capability,
520 language_registry: Arc<LanguageRegistry>,
521 prompt_builder: Arc<PromptBuilder>,
522 project: Option<Model<Project>>,
523 telemetry: Option<Arc<Telemetry>>,
524 cx: &mut ModelContext<Self>,
525 ) -> Self {
526 let buffer = cx.new_model(|_cx| {
527 let mut buffer = Buffer::remote(
528 language::BufferId::new(1).unwrap(),
529 replica_id,
530 capability,
531 "",
532 );
533 buffer.set_language_registry(language_registry.clone());
534 buffer
535 });
536 let edits_since_last_slash_command_parse =
537 buffer.update(cx, |buffer, _| buffer.subscribe());
538 let edits_since_last_workflow_step_prune =
539 buffer.update(cx, |buffer, _| buffer.subscribe());
540 let mut this = Self {
541 id,
542 timestamp: clock::Lamport::new(replica_id),
543 version: clock::Global::new(),
544 pending_ops: Vec::new(),
545 operations: Vec::new(),
546 message_anchors: Default::default(),
547 image_anchors: Default::default(),
548 images: Default::default(),
549 messages_metadata: Default::default(),
550 pending_slash_commands: Vec::new(),
551 finished_slash_commands: HashSet::default(),
552 slash_command_output_sections: Vec::new(),
553 edits_since_last_slash_command_parse,
554 summary: None,
555 pending_summary: Task::ready(None),
556 completion_count: Default::default(),
557 pending_completions: Default::default(),
558 token_count: None,
559 pending_token_count: Task::ready(None),
560 pending_cache_warming_task: Task::ready(None),
561 _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
562 pending_save: Task::ready(Ok(())),
563 path: None,
564 buffer,
565 telemetry,
566 project,
567 language_registry,
568 workflow_steps: Vec::new(),
569 edits_since_last_workflow_step_prune,
570 prompt_builder,
571 };
572
573 let first_message_id = MessageId(clock::Lamport {
574 replica_id: 0,
575 value: 0,
576 });
577 let message = MessageAnchor {
578 id: first_message_id,
579 start: language::Anchor::MIN,
580 };
581 this.messages_metadata.insert(
582 first_message_id,
583 MessageMetadata {
584 role: Role::User,
585 status: MessageStatus::Done,
586 timestamp: first_message_id.0,
587 cache: None,
588 },
589 );
590 this.message_anchors.push(message);
591
592 this.set_language(cx);
593 this.count_remaining_tokens(cx);
594 this
595 }
596
597 pub(crate) fn serialize(&self, cx: &AppContext) -> SavedContext {
598 let buffer = self.buffer.read(cx);
599 SavedContext {
600 id: Some(self.id.clone()),
601 zed: "context".into(),
602 version: SavedContext::VERSION.into(),
603 text: buffer.text(),
604 messages: self
605 .messages(cx)
606 .map(|message| SavedMessage {
607 id: message.id,
608 start: message.offset_range.start,
609 metadata: self.messages_metadata[&message.id].clone(),
610 image_offsets: message
611 .image_offsets
612 .iter()
613 .map(|image_offset| (image_offset.0, image_offset.1.image_id))
614 .collect(),
615 })
616 .collect(),
617 summary: self
618 .summary
619 .as_ref()
620 .map(|summary| summary.text.clone())
621 .unwrap_or_default(),
622 slash_command_output_sections: self
623 .slash_command_output_sections
624 .iter()
625 .filter_map(|section| {
626 let range = section.range.to_offset(buffer);
627 if section.range.start.is_valid(buffer) && !range.is_empty() {
628 Some(assistant_slash_command::SlashCommandOutputSection {
629 range,
630 icon: section.icon,
631 label: section.label.clone(),
632 })
633 } else {
634 None
635 }
636 })
637 .collect(),
638 }
639 }
640
641 #[allow(clippy::too_many_arguments)]
642 pub fn deserialize(
643 saved_context: SavedContext,
644 path: PathBuf,
645 language_registry: Arc<LanguageRegistry>,
646 prompt_builder: Arc<PromptBuilder>,
647 project: Option<Model<Project>>,
648 telemetry: Option<Arc<Telemetry>>,
649 cx: &mut ModelContext<Self>,
650 ) -> Self {
651 let id = saved_context.id.clone().unwrap_or_else(|| ContextId::new());
652 let mut this = Self::new(
653 id,
654 ReplicaId::default(),
655 language::Capability::ReadWrite,
656 language_registry,
657 prompt_builder,
658 project,
659 telemetry,
660 cx,
661 );
662 this.path = Some(path);
663 this.buffer.update(cx, |buffer, cx| {
664 buffer.set_text(saved_context.text.as_str(), cx)
665 });
666 let operations = saved_context.into_ops(&this.buffer, cx);
667 this.apply_ops(operations, cx).unwrap();
668 this
669 }
670
671 pub fn id(&self) -> &ContextId {
672 &self.id
673 }
674
675 pub fn replica_id(&self) -> ReplicaId {
676 self.timestamp.replica_id
677 }
678
679 pub fn version(&self, cx: &AppContext) -> ContextVersion {
680 ContextVersion {
681 context: self.version.clone(),
682 buffer: self.buffer.read(cx).version(),
683 }
684 }
685
686 pub fn set_capability(
687 &mut self,
688 capability: language::Capability,
689 cx: &mut ModelContext<Self>,
690 ) {
691 self.buffer
692 .update(cx, |buffer, cx| buffer.set_capability(capability, cx));
693 }
694
695 fn next_timestamp(&mut self) -> clock::Lamport {
696 let timestamp = self.timestamp.tick();
697 self.version.observe(timestamp);
698 timestamp
699 }
700
701 pub fn serialize_ops(
702 &self,
703 since: &ContextVersion,
704 cx: &AppContext,
705 ) -> Task<Vec<proto::ContextOperation>> {
706 let buffer_ops = self
707 .buffer
708 .read(cx)
709 .serialize_ops(Some(since.buffer.clone()), cx);
710
711 let mut context_ops = self
712 .operations
713 .iter()
714 .filter(|op| !since.context.observed(op.timestamp()))
715 .cloned()
716 .collect::<Vec<_>>();
717 context_ops.extend(self.pending_ops.iter().cloned());
718
719 cx.background_executor().spawn(async move {
720 let buffer_ops = buffer_ops.await;
721 context_ops.sort_unstable_by_key(|op| op.timestamp());
722 buffer_ops
723 .into_iter()
724 .map(|op| proto::ContextOperation {
725 variant: Some(proto::context_operation::Variant::BufferOperation(
726 proto::context_operation::BufferOperation {
727 operation: Some(op),
728 },
729 )),
730 })
731 .chain(context_ops.into_iter().map(|op| op.to_proto()))
732 .collect()
733 })
734 }
735
736 pub fn apply_ops(
737 &mut self,
738 ops: impl IntoIterator<Item = ContextOperation>,
739 cx: &mut ModelContext<Self>,
740 ) -> Result<()> {
741 let mut buffer_ops = Vec::new();
742 for op in ops {
743 match op {
744 ContextOperation::BufferOperation(buffer_op) => buffer_ops.push(buffer_op),
745 op @ _ => self.pending_ops.push(op),
746 }
747 }
748 self.buffer
749 .update(cx, |buffer, cx| buffer.apply_ops(buffer_ops, cx))?;
750 self.flush_ops(cx);
751
752 Ok(())
753 }
754
    /// Applies every pending context operation whose dependencies are met.
    ///
    /// Pending operations are sorted by Lamport timestamp and visited once.
    /// Operations rejected by `can_apply_op` are put back into `pending_ops`.
    /// They are not revisited in this call, even if a later operation in the
    /// same pass satisfies them.
    ///
    /// Each accepted operation is applied as follows:
    /// - `InsertMessage` is skipped if the message id is already known.
    /// - `UpdateMessage` and `UpdateSummary` are last-writer-wins on their
    ///   timestamps.
    /// - `SlashCommandFinished` is applied once per id.
    ///
    /// Every accepted operation is then observed in the version vector and
    /// Lamport clock and appended to `operations`. NOTE(review): this also
    /// happens for duplicates that were skipped, so `operations` can contain
    /// repeats — confirm this is intended.
    ///
    /// Emits `MessagesEdited` / `SummaryChanged` (and notifies) at most once
    /// each, at the end.
    fn flush_ops(&mut self, cx: &mut ModelContext<Context>) {
        let mut messages_changed = false;
        let mut summary_changed = false;

        // `pending_ops` never holds buffer operations (`apply_ops` routes those
        // to the buffer), so `timestamp()` cannot panic here.
        self.pending_ops.sort_unstable_by_key(|op| op.timestamp());
        for op in mem::take(&mut self.pending_ops) {
            if !self.can_apply_op(&op, cx) {
                self.pending_ops.push(op);
                continue;
            }

            let timestamp = op.timestamp();
            match op.clone() {
                ContextOperation::InsertMessage {
                    anchor, metadata, ..
                } => {
                    if self.messages_metadata.contains_key(&anchor.id) {
                        // We already applied this operation.
                    } else {
                        self.insert_message(anchor, metadata, cx);
                        messages_changed = true;
                    }
                }
                ContextOperation::UpdateMessage {
                    message_id,
                    metadata: new_metadata,
                    ..
                } => {
                    // Cannot fail: `can_apply_op` requires the message to exist.
                    let metadata = self.messages_metadata.get_mut(&message_id).unwrap();
                    if new_metadata.timestamp > metadata.timestamp {
                        *metadata = new_metadata;
                        messages_changed = true;
                    }
                }
                ContextOperation::UpdateSummary {
                    summary: new_summary,
                    ..
                } => {
                    if self
                        .summary
                        .as_ref()
                        .map_or(true, |summary| new_summary.timestamp > summary.timestamp)
                    {
                        self.summary = Some(new_summary);
                        summary_changed = true;
                    }
                }
                ContextOperation::SlashCommandFinished {
                    id,
                    output_range,
                    sections,
                    ..
                } => {
                    // `insert` returns false for an id that was already applied.
                    if self.finished_slash_commands.insert(id) {
                        let buffer = self.buffer.read(cx);
                        self.slash_command_output_sections
                            .extend(sections.iter().cloned());
                        // Keep sections ordered by their position in the buffer.
                        self.slash_command_output_sections
                            .sort_by(|a, b| a.range.cmp(&b.range, buffer));
                        cx.emit(ContextEvent::SlashCommandFinished {
                            output_range,
                            sections,
                            expand_result: false,
                            run_commands_in_output: false,
                        });
                    }
                }
                ContextOperation::BufferOperation(_) => unreachable!(),
            }

            self.version.observe(timestamp);
            self.timestamp.observe(timestamp);
            self.operations.push(op);
        }

        if messages_changed {
            cx.emit(ContextEvent::MessagesEdited);
            cx.notify();
        }

        if summary_changed {
            cx.emit(ContextEvent::SummaryChanged);
            cx.notify();
        }
    }
840
841 fn can_apply_op(&self, op: &ContextOperation, cx: &AppContext) -> bool {
842 if !self.version.observed_all(op.version()) {
843 return false;
844 }
845
846 match op {
847 ContextOperation::InsertMessage { anchor, .. } => self
848 .buffer
849 .read(cx)
850 .version
851 .observed(anchor.start.timestamp),
852 ContextOperation::UpdateMessage { message_id, .. } => {
853 self.messages_metadata.contains_key(message_id)
854 }
855 ContextOperation::UpdateSummary { .. } => true,
856 ContextOperation::SlashCommandFinished {
857 output_range,
858 sections,
859 ..
860 } => {
861 let version = &self.buffer.read(cx).version;
862 sections
863 .iter()
864 .map(|section| §ion.range)
865 .chain([output_range])
866 .all(|range| {
867 let observed_start = range.start == language::Anchor::MIN
868 || range.start == language::Anchor::MAX
869 || version.observed(range.start.timestamp);
870 let observed_end = range.end == language::Anchor::MIN
871 || range.end == language::Anchor::MAX
872 || version.observed(range.end.timestamp);
873 observed_start && observed_end
874 })
875 }
876 ContextOperation::BufferOperation(_) => {
877 panic!("buffer operations should always be applied")
878 }
879 }
880 }
881
882 fn push_op(&mut self, op: ContextOperation, cx: &mut ModelContext<Self>) {
883 self.operations.push(op.clone());
884 cx.emit(ContextEvent::Operation(op));
885 }
886
887 pub fn buffer(&self) -> &Model<Buffer> {
888 &self.buffer
889 }
890
891 pub fn language_registry(&self) -> Arc<LanguageRegistry> {
892 self.language_registry.clone()
893 }
894
895 pub fn project(&self) -> Option<Model<Project>> {
896 self.project.clone()
897 }
898
899 pub fn prompt_builder(&self) -> Arc<PromptBuilder> {
900 self.prompt_builder.clone()
901 }
902
903 pub fn path(&self) -> Option<&Path> {
904 self.path.as_deref()
905 }
906
907 pub fn summary(&self) -> Option<&ContextSummary> {
908 self.summary.as_ref()
909 }
910
911 pub fn workflow_step_containing(
912 &self,
913 offset: usize,
914 cx: &AppContext,
915 ) -> Option<(Range<language::Anchor>, Model<WorkflowStep>)> {
916 let buffer = self.buffer.read(cx);
917 let index = self
918 .workflow_steps
919 .binary_search_by(|step| {
920 let step_range = step.range.to_offset(&buffer);
921 if offset < step_range.start {
922 Ordering::Greater
923 } else if offset > step_range.end {
924 Ordering::Less
925 } else {
926 Ordering::Equal
927 }
928 })
929 .ok()?;
930 let step = &self.workflow_steps[index];
931 Some((step.range.clone(), step.step.clone()))
932 }
933
934 pub fn workflow_step_for_range(
935 &self,
936 range: Range<language::Anchor>,
937 cx: &AppContext,
938 ) -> Option<Model<WorkflowStep>> {
939 let buffer = self.buffer.read(cx);
940 let index = self.workflow_step_index_for_range(&range, buffer).ok()?;
941 Some(self.workflow_steps[index].step.clone())
942 }
943
944 pub fn workflow_step_index_for_range(
945 &self,
946 tagged_range: &Range<text::Anchor>,
947 buffer: &text::BufferSnapshot,
948 ) -> Result<usize, usize> {
949 self.workflow_steps
950 .binary_search_by(|probe| probe.range.cmp(&tagged_range, buffer))
951 }
952
953 pub fn pending_slash_commands(&self) -> &[PendingSlashCommand] {
954 &self.pending_slash_commands
955 }
956
957 pub fn slash_command_output_sections(&self) -> &[SlashCommandOutputSection<language::Anchor>] {
958 &self.slash_command_output_sections
959 }
960
961 fn set_language(&mut self, cx: &mut ModelContext<Self>) {
962 let markdown = self.language_registry.language_for_name("Markdown");
963 cx.spawn(|this, mut cx| async move {
964 let markdown = markdown.await?;
965 this.update(&mut cx, |this, cx| {
966 this.buffer
967 .update(cx, |buffer, cx| buffer.set_language(Some(markdown), cx));
968 })
969 })
970 .detach_and_log_err(cx);
971 }
972
973 fn handle_buffer_event(
974 &mut self,
975 _: Model<Buffer>,
976 event: &language::Event,
977 cx: &mut ModelContext<Self>,
978 ) {
979 match event {
980 language::Event::Operation(operation) => cx.emit(ContextEvent::Operation(
981 ContextOperation::BufferOperation(operation.clone()),
982 )),
983 language::Event::Edited => {
984 self.count_remaining_tokens(cx);
985 self.reparse_slash_commands(cx);
986 // Use `inclusive = true` to invalidate a step when an edit occurs
987 // at the start/end of a parsed step.
988 self.prune_invalid_workflow_steps(true, cx);
989 cx.emit(ContextEvent::MessagesEdited);
990 }
991 _ => {}
992 }
993 }
994
995 pub(crate) fn token_count(&self) -> Option<usize> {
996 self.token_count
997 }
998
999 pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1000 let request = self.to_completion_request(cx);
1001 let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
1002 return;
1003 };
1004 self.pending_token_count = cx.spawn(|this, mut cx| {
1005 async move {
1006 cx.background_executor()
1007 .timer(Duration::from_millis(200))
1008 .await;
1009
1010 let token_count = cx.update(|cx| model.count_tokens(request, cx))?.await?;
1011 this.update(&mut cx, |this, cx| {
1012 this.token_count = Some(token_count);
1013 this.start_cache_warming(&model, cx);
1014 cx.notify()
1015 })
1016 }
1017 .log_err()
1018 });
1019 }
1020
    /// Recomputes per-message prompt-cache metadata and reports whether a cache
    /// anchor needs (re)caching.
    ///
    /// 1. Candidate anchors are user messages, longest `offset_range` first.
    ///    When `speculative` is true, the last message is dropped before
    ///    filtering, since it is likely to change.
    /// 2. The longest `max(max_cache_anchors, 1) - 1` candidates become anchors.
    ///    No anchors are chosen when the last token count (treated as 0 if
    ///    unknown) is below `min_total_token`. A `None` configuration behaves
    ///    like one allowing zero anchors.
    /// 3. A message is invalidated if it has no metadata or `is_cache_valid`
    ///    fails. Every message after the first invalidated one is invalidated
    ///    too.
    /// 4. Every message up to and including the last anchor (document order)
    ///    gets fresh `MessageCacheMetadata`. Its status is `Pending` if it was
    ///    invalidated, otherwise the previous status (default `Pending`), and
    ///    `cached_at` is the current buffer version. Messages after the last
    ///    anchor, or all messages if there are no anchors, get `cache = None`.
    ///    Updates go through `update_metadata`, which is defined elsewhere.
    ///
    /// Returns `true` if any anchor message was invalidated.
    pub fn mark_cache_anchors(
        &mut self,
        cache_configuration: &Option<LanguageModelCacheConfiguration>,
        speculative: bool,
        cx: &mut ModelContext<Self>,
    ) -> bool {
        let cache_configuration =
            cache_configuration
                .as_ref()
                .unwrap_or(&LanguageModelCacheConfiguration {
                    max_cache_anchors: 0,
                    should_speculate: false,
                    min_total_token: 0,
                });

        let messages: Vec<Message> = self.messages(cx).collect();

        let mut sorted_messages = messages.clone();
        if speculative {
            // Avoid caching the last message if this is a speculative cache fetch as
            // it's likely to change.
            sorted_messages.pop();
        }
        sorted_messages.retain(|m| m.role == Role::User);
        // Longest first, so truncation below keeps the largest user messages.
        sorted_messages.sort_by(|a, b| b.offset_range.len().cmp(&a.offset_range.len()));

        let cache_anchors = if self.token_count.unwrap_or(0) < cache_configuration.min_total_token {
            // If we haven't hit the minimum threshold to enable caching, don't cache anything.
            0
        } else {
            // Save 1 anchor for the inline assistant to use.
            max(cache_configuration.max_cache_anchors, 1) - 1
        };
        sorted_messages.truncate(cache_anchors);

        let anchors: HashSet<MessageId> = sorted_messages
            .into_iter()
            .map(|message| message.id)
            .collect();

        let buffer = self.buffer.read(cx).snapshot();
        // Once one message's cache is invalid, every later message is treated as
        // invalid too (the scan state latches to true).
        let invalidated_caches: HashSet<MessageId> = messages
            .iter()
            .scan(false, |encountered_invalid, message| {
                let message_id = message.id;
                let is_invalid = self
                    .messages_metadata
                    .get(&message_id)
                    .map_or(true, |metadata| {
                        !metadata.is_cache_valid(&buffer, &message.offset_range)
                            || *encountered_invalid
                    });
                *encountered_invalid |= is_invalid;
                Some(if is_invalid { Some(message_id) } else { None })
            })
            .flatten()
            .collect();

        // Last anchor in document order (not size order).
        let last_anchor = messages.iter().rev().find_map(|message| {
            if anchors.contains(&message.id) {
                Some(message.id)
            } else {
                None
            }
        });

        let mut new_anchor_needs_caching = false;
        let current_version = &buffer.version;
        // If we have no anchors, mark all messages as not being cached.
        let mut hit_last_anchor = last_anchor.is_none();

        for message in messages.iter() {
            if hit_last_anchor {
                self.update_metadata(message.id, cx, |metadata| metadata.cache = None);
                continue;
            }

            if let Some(last_anchor) = last_anchor {
                if message.id == last_anchor {
                    hit_last_anchor = true;
                }
            }

            new_anchor_needs_caching = new_anchor_needs_caching
                || (invalidated_caches.contains(&message.id) && anchors.contains(&message.id));

            self.update_metadata(message.id, cx, |metadata| {
                let cache_status = if invalidated_caches.contains(&message.id) {
                    CacheStatus::Pending
                } else {
                    metadata
                        .cache
                        .as_ref()
                        .map_or(CacheStatus::Pending, |cm| cm.status.clone())
                };
                metadata.cache = Some(MessageCacheMetadata {
                    is_anchor: anchors.contains(&message.id),
                    // True only for the last anchor itself, since messages after
                    // it took the early-continue branch above.
                    is_final_anchor: hit_last_anchor,
                    status: cache_status,
                    cached_at: current_version.clone(),
                });
            });
        }
        new_anchor_needs_caching
    }
1126
1127 fn start_cache_warming(&mut self, model: &Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
1128 let cache_configuration = model.cache_configuration();
1129
1130 if !self.mark_cache_anchors(&cache_configuration, true, cx) {
1131 return;
1132 }
1133 if !self.pending_completions.is_empty() {
1134 return;
1135 }
1136 if let Some(cache_configuration) = cache_configuration {
1137 if !cache_configuration.should_speculate {
1138 return;
1139 }
1140 }
1141
1142 let request = {
1143 let mut req = self.to_completion_request(cx);
1144 // Skip the last message because it's likely to change and
1145 // therefore would be a waste to cache.
1146 req.messages.pop();
1147 req.messages.push(LanguageModelRequestMessage {
1148 role: Role::User,
1149 content: vec!["Respond only with OK, nothing else.".into()],
1150 cache: false,
1151 });
1152 req
1153 };
1154
1155 let model = Arc::clone(model);
1156 self.pending_cache_warming_task = cx.spawn(|this, mut cx| {
1157 async move {
1158 match model.stream_completion(request, &cx).await {
1159 Ok(mut stream) => {
1160 stream.next().await;
1161 log::info!("Cache warming completed successfully");
1162 }
1163 Err(e) => {
1164 log::warn!("Cache warming failed: {}", e);
1165 }
1166 };
1167 this.update(&mut cx, |this, cx| {
1168 this.update_cache_status_for_completion(cx);
1169 })
1170 .ok();
1171 anyhow::Ok(())
1172 }
1173 .log_err()
1174 });
1175 }
1176
1177 pub fn update_cache_status_for_completion(&mut self, cx: &mut ModelContext<Self>) {
1178 let cached_message_ids: Vec<MessageId> = self
1179 .messages_metadata
1180 .iter()
1181 .filter_map(|(message_id, metadata)| {
1182 metadata.cache.as_ref().and_then(|cache| {
1183 if cache.status == CacheStatus::Pending {
1184 Some(*message_id)
1185 } else {
1186 None
1187 }
1188 })
1189 })
1190 .collect();
1191
1192 for message_id in cached_message_ids {
1193 self.update_metadata(message_id, cx, |metadata| {
1194 if let Some(cache) = &mut metadata.cache {
1195 cache.status = CacheStatus::Cached;
1196 }
1197 });
1198 }
1199 cx.notify();
1200 }
1201
    /// Re-parses slash commands on every row touched by edits made since the
    /// last parse, and replaces the pending commands on those rows.
    ///
    /// How it works:
    /// - Edited row ranges that touch or overlap are merged.
    /// - Each merged range is re-scanned line by line.
    /// - Each line that parses as a registered slash command becomes a new
    ///   `PendingSlashCommand` in the `Idle` state. If the command requires
    ///   an argument, the line must supply at least one non-empty argument.
    ///
    /// Emits `PendingSlashCommandsUpdated` when any command was removed or
    /// created. `removed` holds the source ranges of the replaced commands;
    /// `updated` holds the newly created ones.
    pub fn reparse_slash_commands(&mut self, cx: &mut ModelContext<Self>) {
        let buffer = self.buffer.read(cx);
        // Turn each edit into the half-open range of buffer rows it touches.
        let mut row_ranges = self
            .edits_since_last_slash_command_parse
            .consume()
            .into_iter()
            .map(|edit| {
                let start_row = buffer.offset_to_point(edit.new.start).row;
                let end_row = buffer.offset_to_point(edit.new.end).row + 1;
                start_row..end_row
            })
            .peekable();

        let mut removed = Vec::new();
        let mut updated = Vec::new();
        while let Some(mut row_range) = row_ranges.next() {
            // Coalesce following row ranges that touch or overlap this one.
            // NOTE(review): assumes `consume()` yields edits sorted and
            // non-overlapping, so `next_row_range.end` never shrinks the range.
            while let Some(next_row_range) = row_ranges.peek() {
                if row_range.end >= next_row_range.start {
                    row_range.end = next_row_range.end;
                    row_ranges.next();
                } else {
                    break;
                }
            }

            // Anchor whole rows: from column 0 of the first row to the end of
            // the last row.
            let start = buffer.anchor_before(Point::new(row_range.start, 0));
            let end = buffer.anchor_after(Point::new(
                row_range.end - 1,
                buffer.line_len(row_range.end - 1),
            ));

            // Indices of the existing pending commands these rows replace.
            let old_range = self.pending_command_indices_for_range(start..end, cx);

            let mut new_commands = Vec::new();
            let mut lines = buffer.text_for_range(start..end).lines();
            // Buffer offset at which the current `line` begins.
            let mut offset = lines.offset();
            while let Some(line) = lines.next() {
                if let Some(command_line) = SlashCommandLine::parse(line) {
                    let name = &line[command_line.name.clone()];
                    // Keep only non-empty argument slices.
                    let arguments = command_line
                        .arguments
                        .iter()
                        .filter_map(|argument_range| {
                            if argument_range.is_empty() {
                                None
                            } else {
                                line.get(argument_range.clone())
                            }
                        })
                        .map(ToOwned::to_owned)
                        .collect::<SmallVec<_>>();
                    if let Some(command) = SlashCommandRegistry::global(cx).command(name) {
                        if !command.requires_argument() || !arguments.is_empty() {
                            // Step back one byte so the source range also
                            // covers the character before the name
                            // (presumably the leading `/`).
                            let start_ix = offset + command_line.name.start - 1;
                            // The range ends after the last argument, or after
                            // the name when there are no arguments.
                            let end_ix = offset
                                + command_line
                                    .arguments
                                    .last()
                                    .map_or(command_line.name.end, |argument| argument.end);
                            let source_range =
                                buffer.anchor_after(start_ix)..buffer.anchor_after(end_ix);
                            let pending_command = PendingSlashCommand {
                                name: name.to_string(),
                                arguments,
                                source_range,
                                status: PendingSlashCommandStatus::Idle,
                            };
                            updated.push(pending_command.clone());
                            new_commands.push(pending_command);
                        }
                    }
                }

                offset = lines.offset();
            }

            // Swap the old commands for the freshly parsed ones, remembering
            // the ranges of the commands that went away.
            let removed_commands = self.pending_slash_commands.splice(old_range, new_commands);
            removed.extend(removed_commands.map(|command| command.source_range));
        }

        if !updated.is_empty() || !removed.is_empty() {
            cx.emit(ContextEvent::PendingSlashCommandsUpdated { removed, updated });
        }
    }
1286
1287 fn prune_invalid_workflow_steps(&mut self, inclusive: bool, cx: &mut ModelContext<Self>) {
1288 let mut removed = Vec::new();
1289
1290 for edit_range in self.edits_since_last_workflow_step_prune.consume() {
1291 let intersecting_range = self.find_intersecting_steps(edit_range.new, inclusive, cx);
1292 removed.extend(
1293 self.workflow_steps
1294 .drain(intersecting_range)
1295 .map(|step| step.range),
1296 );
1297 }
1298
1299 if !removed.is_empty() {
1300 cx.emit(ContextEvent::WorkflowStepsRemoved(removed));
1301 cx.notify();
1302 }
1303 }
1304
1305 fn find_intersecting_steps(
1306 &self,
1307 range: Range<usize>,
1308 inclusive: bool,
1309 cx: &AppContext,
1310 ) -> Range<usize> {
1311 let buffer = self.buffer.read(cx);
1312 let start_ix = match self.workflow_steps.binary_search_by(|probe| {
1313 probe
1314 .range
1315 .end
1316 .to_offset(buffer)
1317 .cmp(&range.start)
1318 .then(if inclusive {
1319 Ordering::Greater
1320 } else {
1321 Ordering::Less
1322 })
1323 }) {
1324 Ok(ix) | Err(ix) => ix,
1325 };
1326 let end_ix = match self.workflow_steps.binary_search_by(|probe| {
1327 probe
1328 .range
1329 .start
1330 .to_offset(buffer)
1331 .cmp(&range.end)
1332 .then(if inclusive {
1333 Ordering::Less
1334 } else {
1335 Ordering::Greater
1336 })
1337 }) {
1338 Ok(ix) | Err(ix) => ix,
1339 };
1340 start_ix..end_ix
1341 }
1342
    /// Scans the buffer offsets in `range` for `<step>` … `</step>` pairs.
    ///
    /// Steps:
    /// - Registers a `WorkflowStep` for every tagged range not already
    ///   tracked, and asks each new step to resolve.
    /// - Deletes the tags themselves from the buffer.
    ///
    /// Tags are matched line by line. `<step>` opens a step only if none is
    /// open, and `</step>` closes the currently open step.
    fn parse_workflow_steps_in_range(&mut self, range: Range<usize>, cx: &mut ModelContext<Self>) {
        let weak_self = cx.weak_model();
        let mut new_edit_steps = Vec::new();
        // Tag deletions to apply once parsing is done.
        let mut edits = Vec::new();

        let buffer = self.buffer.read(cx).snapshot();
        let mut message_lines = buffer.as_rope().chunks_in_range(range).lines();
        let mut in_step = false;
        let mut step_open_tag_start_ix = 0;
        // Buffer offset at which the current line begins.
        let mut line_start_offset = message_lines.offset();

        while let Some(line) = message_lines.next() {
            if let Some(step_start_index) = line.find("<step>") {
                if !in_step {
                    in_step = true;
                    step_open_tag_start_ix = line_start_offset + step_start_index;
                }
            }

            if let Some(step_end_index) = line.find("</step>") {
                if in_step {
                    // Also delete the newline right after `<step>`, if any.
                    let mut step_open_tag_end_ix = step_open_tag_start_ix + "<step>".len();
                    if buffer.chars_at(step_open_tag_end_ix).next() == Some('\n') {
                        step_open_tag_end_ix += 1;
                    }
                    // Also delete the newline right before `</step>`, if any.
                    let mut step_end_tag_start_ix = line_start_offset + step_end_index;
                    let step_end_tag_end_ix = step_end_tag_start_ix + "</step>".len();
                    if buffer.reversed_chars_at(step_end_tag_start_ix).next() == Some('\n') {
                        step_end_tag_start_ix -= 1;
                    }
                    edits.push((step_open_tag_start_ix..step_open_tag_end_ix, ""));
                    edits.push((step_end_tag_start_ix..step_end_tag_end_ix, ""));
                    // The step covers only the text between the (deleted) tags.
                    let tagged_range = buffer.anchor_after(step_open_tag_end_ix)
                        ..buffer.anchor_before(step_end_tag_start_ix);

                    // Check if a step with the same range already exists
                    let existing_step_index =
                        self.workflow_step_index_for_range(&tagged_range, &buffer);

                    // `Err(ix)` is where the new step belongs in `workflow_steps`.
                    if let Err(ix) = existing_step_index {
                        new_edit_steps.push((
                            ix,
                            WorkflowStepEntry {
                                step: cx.new_model(|_| {
                                    WorkflowStep::new(tagged_range.clone(), weak_self.clone())
                                }),
                                range: tagged_range,
                            },
                        ));
                    }

                    in_step = false;
                }
            }

            line_start_offset = message_lines.offset();
        }

        // Insert back-to-front. The indices were computed against the
        // unmodified `workflow_steps`, so earlier insertions must not shift
        // the indices of later ones.
        let mut updated = Vec::new();
        for (index, step) in new_edit_steps.into_iter().rev() {
            let step_range = step.range.clone();
            updated.push(step_range.clone());
            self.workflow_steps.insert(index, step);
            self.resolve_workflow_step(step_range, cx);
        }

        // Delete <step> tags, making sure we don't accidentally invalidate
        // the step we just parsed.
        self.buffer
            .update(cx, |buffer, cx| buffer.edit(edits, None, cx));
        // Discard the tag-deletion edits so the next prune doesn't drop the
        // steps we just registered.
        self.edits_since_last_workflow_step_prune.consume();
    }
1415
1416 pub fn resolve_workflow_step(
1417 &mut self,
1418 tagged_range: Range<language::Anchor>,
1419 cx: &mut ModelContext<Self>,
1420 ) {
1421 let Ok(step_index) = self
1422 .workflow_steps
1423 .binary_search_by(|step| step.range.cmp(&tagged_range, self.buffer.read(cx)))
1424 else {
1425 return;
1426 };
1427
1428 cx.emit(ContextEvent::WorkflowStepUpdated(tagged_range.clone()));
1429 cx.notify();
1430
1431 let resolution = self.workflow_steps[step_index].step.clone();
1432 cx.defer(move |cx| {
1433 resolution.update(cx, |resolution, cx| resolution.resolve(cx));
1434 });
1435 }
1436
1437 pub fn workflow_step_updated(
1438 &mut self,
1439 range: Range<language::Anchor>,
1440 cx: &mut ModelContext<Self>,
1441 ) {
1442 cx.emit(ContextEvent::WorkflowStepUpdated(range));
1443 cx.notify();
1444 }
1445
1446 pub fn pending_command_for_position(
1447 &mut self,
1448 position: language::Anchor,
1449 cx: &mut ModelContext<Self>,
1450 ) -> Option<&mut PendingSlashCommand> {
1451 let buffer = self.buffer.read(cx);
1452 match self
1453 .pending_slash_commands
1454 .binary_search_by(|probe| probe.source_range.end.cmp(&position, buffer))
1455 {
1456 Ok(ix) => Some(&mut self.pending_slash_commands[ix]),
1457 Err(ix) => {
1458 let cmd = self.pending_slash_commands.get_mut(ix)?;
1459 if position.cmp(&cmd.source_range.start, buffer).is_ge()
1460 && position.cmp(&cmd.source_range.end, buffer).is_le()
1461 {
1462 Some(cmd)
1463 } else {
1464 None
1465 }
1466 }
1467 }
1468 }
1469
1470 pub fn pending_commands_for_range(
1471 &self,
1472 range: Range<language::Anchor>,
1473 cx: &AppContext,
1474 ) -> &[PendingSlashCommand] {
1475 let range = self.pending_command_indices_for_range(range, cx);
1476 &self.pending_slash_commands[range]
1477 }
1478
1479 fn pending_command_indices_for_range(
1480 &self,
1481 range: Range<language::Anchor>,
1482 cx: &AppContext,
1483 ) -> Range<usize> {
1484 let buffer = self.buffer.read(cx);
1485 let start_ix = match self
1486 .pending_slash_commands
1487 .binary_search_by(|probe| probe.source_range.end.cmp(&range.start, &buffer))
1488 {
1489 Ok(ix) | Err(ix) => ix,
1490 };
1491 let end_ix = match self
1492 .pending_slash_commands
1493 .binary_search_by(|probe| probe.source_range.start.cmp(&range.end, &buffer))
1494 {
1495 Ok(ix) => ix + 1,
1496 Err(ix) => ix,
1497 };
1498 start_ix..end_ix
1499 }
1500
    /// Replaces the slash command at `command_range` with its output once
    /// `output` resolves.
    ///
    /// Pending slash commands are re-parsed first. The pending command at
    /// `command_range.start` is then marked `Running` and holds a shared
    /// handle to the insertion task.
    ///
    /// On success:
    /// - output section ranges are clamped to the text and widened to char
    ///   boundaries;
    /// - a trailing newline is optionally ensured after the last section;
    /// - the command text is replaced by the output, and the sections are
    ///   recorded;
    /// - the command is marked finished, a `SlashCommandFinished` operation
    ///   is pushed and a `SlashCommandFinished` event is emitted.
    ///
    /// On failure, the pending command (if still found) is marked `Error`.
    pub fn insert_command_output(
        &mut self,
        command_range: Range<language::Anchor>,
        output: Task<Result<SlashCommandOutput>>,
        ensure_trailing_newline: bool,
        expand_result: bool,
        cx: &mut ModelContext<Self>,
    ) {
        self.reparse_slash_commands(cx);

        let insert_output_task = cx.spawn(|this, mut cx| {
            let command_range = command_range.clone();
            async move {
                let output = output.await;
                this.update(&mut cx, |this, cx| match output {
                    Ok(mut output) => {
                        // Ensure section ranges are valid.
                        // Clamp to the text length, then move the start left and
                        // the end right until both sit on char boundaries.
                        for section in &mut output.sections {
                            section.range.start = section.range.start.min(output.text.len());
                            section.range.end = section.range.end.min(output.text.len());
                            while !output.text.is_char_boundary(section.range.start) {
                                section.range.start -= 1;
                            }
                            while !output.text.is_char_boundary(section.range.end) {
                                section.range.end += 1;
                            }
                        }

                        // Ensure there is a newline after the last section.
                        // With no sections at all, a newline is always appended.
                        if ensure_trailing_newline {
                            let has_newline_after_last_section =
                                output.sections.last().map_or(false, |last_section| {
                                    output.text[last_section.range.end..].ends_with('\n')
                                });
                            if !has_newline_after_last_section {
                                output.text.push('\n');
                            }
                        }

                        let version = this.version.clone();
                        let command_id = SlashCommandId(this.next_timestamp());
                        let (operation, event) = this.buffer.update(cx, |buffer, cx| {
                            // Replace the command text with the output.
                            let start = command_range.start.to_offset(buffer);
                            let old_end = command_range.end.to_offset(buffer);
                            let new_end = start + output.text.len();
                            buffer.edit([(start..old_end, output.text)], None, cx);

                            // Re-express section ranges (relative to the output
                            // text) as buffer anchors.
                            let mut sections = output
                                .sections
                                .into_iter()
                                .map(|section| SlashCommandOutputSection {
                                    range: buffer.anchor_after(start + section.range.start)
                                        ..buffer.anchor_before(start + section.range.end),
                                    icon: section.icon,
                                    label: section.label,
                                })
                                .collect::<Vec<_>>();
                            sections.sort_by(|a, b| a.range.cmp(&b.range, buffer));

                            // Keep the context-wide section list sorted as well.
                            this.slash_command_output_sections
                                .extend(sections.iter().cloned());
                            this.slash_command_output_sections
                                .sort_by(|a, b| a.range.cmp(&b.range, buffer));

                            let output_range =
                                buffer.anchor_after(start)..buffer.anchor_before(new_end);
                            this.finished_slash_commands.insert(command_id);

                            (
                                ContextOperation::SlashCommandFinished {
                                    id: command_id,
                                    output_range: output_range.clone(),
                                    sections: sections.clone(),
                                    version,
                                },
                                ContextEvent::SlashCommandFinished {
                                    output_range,
                                    sections,
                                    run_commands_in_output: output.run_commands_in_text,
                                    expand_result,
                                },
                            )
                        });

                        this.push_op(operation, cx);
                        cx.emit(event);
                    }
                    Err(error) => {
                        // Surface the failure on the pending command, if it
                        // still exists.
                        if let Some(pending_command) =
                            this.pending_command_for_position(command_range.start, cx)
                        {
                            pending_command.status =
                                PendingSlashCommandStatus::Error(error.to_string());
                            cx.emit(ContextEvent::PendingSlashCommandsUpdated {
                                removed: vec![pending_command.source_range.clone()],
                                updated: vec![pending_command.clone()],
                            });
                        }
                    }
                })
                .ok();
            }
        });

        // The pending command owns the (shared) task while it runs.
        if let Some(pending_command) = self.pending_command_for_position(command_range.start, cx) {
            pending_command.status = PendingSlashCommandStatus::Running {
                _task: insert_output_task.shared(),
            };
            cx.emit(ContextEvent::PendingSlashCommandsUpdated {
                removed: vec![pending_command.source_range.clone()],
                updated: vec![pending_command.clone()],
            });
        }
    }
1615
1616 pub fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
1617 self.count_remaining_tokens(cx);
1618 }
1619
1620 fn get_last_valid_message_id(&self, cx: &ModelContext<Self>) -> Option<MessageId> {
1621 self.message_anchors.iter().rev().find_map(|message| {
1622 message
1623 .start
1624 .is_valid(self.buffer.read(cx))
1625 .then_some(message.id)
1626 })
1627 }
1628
1629 pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
1630 let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
1631 let model = LanguageModelRegistry::read_global(cx).active_model()?;
1632 let last_message_id = self.get_last_valid_message_id(cx)?;
1633
1634 if !provider.is_authenticated(cx) {
1635 log::info!("completion provider has no credentials");
1636 return None;
1637 }
1638 // Compute which messages to cache, including the last one.
1639 self.mark_cache_anchors(&model.cache_configuration(), false, cx);
1640
1641 let request = self.to_completion_request(cx);
1642 let assistant_message = self
1643 .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
1644 .unwrap();
1645
1646 // Queue up the user's next reply.
1647 let user_message = self
1648 .insert_message_after(assistant_message.id, Role::User, MessageStatus::Done, cx)
1649 .unwrap();
1650
1651 let pending_completion_id = post_inc(&mut self.completion_count);
1652
1653 let task = cx.spawn({
1654 |this, mut cx| async move {
1655 let stream = model.stream_completion(request, &cx);
1656 let assistant_message_id = assistant_message.id;
1657 let mut response_latency = None;
1658 let stream_completion = async {
1659 let request_start = Instant::now();
1660 let mut chunks = stream.await?;
1661
1662 while let Some(chunk) = chunks.next().await {
1663 if response_latency.is_none() {
1664 response_latency = Some(request_start.elapsed());
1665 }
1666 let chunk = chunk?;
1667
1668 this.update(&mut cx, |this, cx| {
1669 let message_ix = this
1670 .message_anchors
1671 .iter()
1672 .position(|message| message.id == assistant_message_id)?;
1673 let message_range = this.buffer.update(cx, |buffer, cx| {
1674 let message_start_offset =
1675 this.message_anchors[message_ix].start.to_offset(buffer);
1676 let message_old_end_offset = this.message_anchors[message_ix + 1..]
1677 .iter()
1678 .find(|message| message.start.is_valid(buffer))
1679 .map_or(buffer.len(), |message| {
1680 message.start.to_offset(buffer).saturating_sub(1)
1681 });
1682 let message_new_end_offset = message_old_end_offset + chunk.len();
1683 buffer.edit(
1684 [(message_old_end_offset..message_old_end_offset, chunk)],
1685 None,
1686 cx,
1687 );
1688 message_start_offset..message_new_end_offset
1689 });
1690
1691 // Use `inclusive = false` as edits might occur at the end of a parsed step.
1692 this.prune_invalid_workflow_steps(false, cx);
1693 this.parse_workflow_steps_in_range(message_range, cx);
1694 cx.emit(ContextEvent::StreamedCompletion);
1695
1696 Some(())
1697 })?;
1698 smol::future::yield_now().await;
1699 }
1700 this.update(&mut cx, |this, cx| {
1701 this.pending_completions
1702 .retain(|completion| completion.id != pending_completion_id);
1703 this.summarize(false, cx);
1704 this.update_cache_status_for_completion(cx);
1705 })?;
1706
1707 anyhow::Ok(())
1708 };
1709
1710 let result = stream_completion.await;
1711
1712 this.update(&mut cx, |this, cx| {
1713 let error_message = result
1714 .err()
1715 .map(|error| error.to_string().trim().to_string());
1716
1717 if let Some(error_message) = error_message.as_ref() {
1718 cx.emit(ContextEvent::ShowAssistError(SharedString::from(
1719 error_message.clone(),
1720 )));
1721 }
1722
1723 this.update_metadata(assistant_message_id, cx, |metadata| {
1724 if let Some(error_message) = error_message.as_ref() {
1725 metadata.status =
1726 MessageStatus::Error(SharedString::from(error_message.clone()));
1727 } else {
1728 metadata.status = MessageStatus::Done;
1729 }
1730 });
1731
1732 if let Some(telemetry) = this.telemetry.as_ref() {
1733 telemetry.report_assistant_event(
1734 Some(this.id.0.clone()),
1735 AssistantKind::Panel,
1736 model.telemetry_id(),
1737 response_latency,
1738 error_message,
1739 );
1740 }
1741 })
1742 .ok();
1743 }
1744 });
1745
1746 self.pending_completions.push(PendingCompletion {
1747 id: pending_completion_id,
1748 assistant_message_id: assistant_message.id,
1749 _task: task,
1750 });
1751
1752 Some(user_message)
1753 }
1754
1755 pub fn to_completion_request(&self, cx: &AppContext) -> LanguageModelRequest {
1756 let buffer = self.buffer.read(cx);
1757 let request_messages = self
1758 .messages(cx)
1759 .filter(|message| message.status == MessageStatus::Done)
1760 .filter_map(|message| message.to_request_message(&buffer))
1761 .collect();
1762
1763 LanguageModelRequest {
1764 messages: request_messages,
1765 stop: vec![],
1766 temperature: 1.0,
1767 }
1768 }
1769
1770 pub fn cancel_last_assist(&mut self, cx: &mut ModelContext<Self>) -> bool {
1771 if let Some(pending_completion) = self.pending_completions.pop() {
1772 self.update_metadata(pending_completion.assistant_message_id, cx, |metadata| {
1773 if metadata.status == MessageStatus::Pending {
1774 metadata.status = MessageStatus::Canceled;
1775 }
1776 });
1777 true
1778 } else {
1779 false
1780 }
1781 }
1782
1783 pub fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1784 for id in ids {
1785 if let Some(metadata) = self.messages_metadata.get(&id) {
1786 let role = metadata.role.cycle();
1787 self.update_metadata(id, cx, |metadata| metadata.role = role);
1788 }
1789 }
1790 }
1791
1792 pub fn update_metadata(
1793 &mut self,
1794 id: MessageId,
1795 cx: &mut ModelContext<Self>,
1796 f: impl FnOnce(&mut MessageMetadata),
1797 ) {
1798 let version = self.version.clone();
1799 let timestamp = self.next_timestamp();
1800 if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1801 f(metadata);
1802 metadata.timestamp = timestamp;
1803 let operation = ContextOperation::UpdateMessage {
1804 message_id: id,
1805 metadata: metadata.clone(),
1806 version,
1807 };
1808 self.push_op(operation, cx);
1809 cx.emit(ContextEvent::MessagesEdited);
1810 cx.notify();
1811 }
1812 }
1813
1814 pub fn insert_message_after(
1815 &mut self,
1816 message_id: MessageId,
1817 role: Role,
1818 status: MessageStatus,
1819 cx: &mut ModelContext<Self>,
1820 ) -> Option<MessageAnchor> {
1821 if let Some(prev_message_ix) = self
1822 .message_anchors
1823 .iter()
1824 .position(|message| message.id == message_id)
1825 {
1826 // Find the next valid message after the one we were given.
1827 let mut next_message_ix = prev_message_ix + 1;
1828 while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1829 if next_message.start.is_valid(self.buffer.read(cx)) {
1830 break;
1831 }
1832 next_message_ix += 1;
1833 }
1834
1835 let start = self.buffer.update(cx, |buffer, cx| {
1836 let offset = self
1837 .message_anchors
1838 .get(next_message_ix)
1839 .map_or(buffer.len(), |message| {
1840 buffer.clip_offset(message.start.to_offset(buffer) - 1, Bias::Left)
1841 });
1842 buffer.edit([(offset..offset, "\n")], None, cx);
1843 buffer.anchor_before(offset + 1)
1844 });
1845
1846 let version = self.version.clone();
1847 let anchor = MessageAnchor {
1848 id: MessageId(self.next_timestamp()),
1849 start,
1850 };
1851 let metadata = MessageMetadata {
1852 role,
1853 status,
1854 timestamp: anchor.id.0,
1855 cache: None,
1856 };
1857 self.insert_message(anchor.clone(), metadata.clone(), cx);
1858 self.push_op(
1859 ContextOperation::InsertMessage {
1860 anchor: anchor.clone(),
1861 metadata,
1862 version,
1863 },
1864 cx,
1865 );
1866 Some(anchor)
1867 } else {
1868 None
1869 }
1870 }
1871
1872 pub fn insert_image(&mut self, image: Image, cx: &mut ModelContext<Self>) -> Option<()> {
1873 if let hash_map::Entry::Vacant(entry) = self.images.entry(image.id()) {
1874 entry.insert((
1875 image.to_image_data(cx).log_err()?,
1876 LanguageModelImage::from_image(image, cx).shared(),
1877 ));
1878 }
1879
1880 Some(())
1881 }
1882
1883 pub fn insert_image_anchor(
1884 &mut self,
1885 image_id: u64,
1886 anchor: language::Anchor,
1887 cx: &mut ModelContext<Self>,
1888 ) -> bool {
1889 cx.emit(ContextEvent::MessagesEdited);
1890
1891 let buffer = self.buffer.read(cx);
1892 let insertion_ix = match self
1893 .image_anchors
1894 .binary_search_by(|existing_anchor| anchor.cmp(&existing_anchor.anchor, buffer))
1895 {
1896 Ok(ix) => ix,
1897 Err(ix) => ix,
1898 };
1899
1900 if let Some((render_image, image)) = self.images.get(&image_id) {
1901 self.image_anchors.insert(
1902 insertion_ix,
1903 ImageAnchor {
1904 anchor,
1905 image_id,
1906 image: image.clone(),
1907 render_image: render_image.clone(),
1908 },
1909 );
1910
1911 true
1912 } else {
1913 false
1914 }
1915 }
1916
1917 pub fn images<'a>(&'a self, _cx: &'a AppContext) -> impl 'a + Iterator<Item = ImageAnchor> {
1918 self.image_anchors.iter().cloned()
1919 }
1920
1921 pub fn split_message(
1922 &mut self,
1923 range: Range<usize>,
1924 cx: &mut ModelContext<Self>,
1925 ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1926 let start_message = self.message_for_offset(range.start, cx);
1927 let end_message = self.message_for_offset(range.end, cx);
1928 if let Some((start_message, end_message)) = start_message.zip(end_message) {
1929 // Prevent splitting when range spans multiple messages.
1930 if start_message.id != end_message.id {
1931 return (None, None);
1932 }
1933
1934 let message = start_message;
1935 let role = message.role;
1936 let mut edited_buffer = false;
1937
1938 let mut suffix_start = None;
1939
1940 // TODO: why did this start panicking?
1941 if range.start > message.offset_range.start
1942 && range.end < message.offset_range.end.saturating_sub(1)
1943 {
1944 if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1945 suffix_start = Some(range.end + 1);
1946 } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1947 suffix_start = Some(range.end);
1948 }
1949 }
1950
1951 let version = self.version.clone();
1952 let suffix = if let Some(suffix_start) = suffix_start {
1953 MessageAnchor {
1954 id: MessageId(self.next_timestamp()),
1955 start: self.buffer.read(cx).anchor_before(suffix_start),
1956 }
1957 } else {
1958 self.buffer.update(cx, |buffer, cx| {
1959 buffer.edit([(range.end..range.end, "\n")], None, cx);
1960 });
1961 edited_buffer = true;
1962 MessageAnchor {
1963 id: MessageId(self.next_timestamp()),
1964 start: self.buffer.read(cx).anchor_before(range.end + 1),
1965 }
1966 };
1967
1968 let suffix_metadata = MessageMetadata {
1969 role,
1970 status: MessageStatus::Done,
1971 timestamp: suffix.id.0,
1972 cache: None,
1973 };
1974 self.insert_message(suffix.clone(), suffix_metadata.clone(), cx);
1975 self.push_op(
1976 ContextOperation::InsertMessage {
1977 anchor: suffix.clone(),
1978 metadata: suffix_metadata,
1979 version,
1980 },
1981 cx,
1982 );
1983
1984 let new_messages =
1985 if range.start == range.end || range.start == message.offset_range.start {
1986 (None, Some(suffix))
1987 } else {
1988 let mut prefix_end = None;
1989 if range.start > message.offset_range.start
1990 && range.end < message.offset_range.end - 1
1991 {
1992 if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1993 prefix_end = Some(range.start + 1);
1994 } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1995 == Some('\n')
1996 {
1997 prefix_end = Some(range.start);
1998 }
1999 }
2000
2001 let version = self.version.clone();
2002 let selection = if let Some(prefix_end) = prefix_end {
2003 MessageAnchor {
2004 id: MessageId(self.next_timestamp()),
2005 start: self.buffer.read(cx).anchor_before(prefix_end),
2006 }
2007 } else {
2008 self.buffer.update(cx, |buffer, cx| {
2009 buffer.edit([(range.start..range.start, "\n")], None, cx)
2010 });
2011 edited_buffer = true;
2012 MessageAnchor {
2013 id: MessageId(self.next_timestamp()),
2014 start: self.buffer.read(cx).anchor_before(range.end + 1),
2015 }
2016 };
2017
2018 let selection_metadata = MessageMetadata {
2019 role,
2020 status: MessageStatus::Done,
2021 timestamp: selection.id.0,
2022 cache: None,
2023 };
2024 self.insert_message(selection.clone(), selection_metadata.clone(), cx);
2025 self.push_op(
2026 ContextOperation::InsertMessage {
2027 anchor: selection.clone(),
2028 metadata: selection_metadata,
2029 version,
2030 },
2031 cx,
2032 );
2033
2034 (Some(selection), Some(suffix))
2035 };
2036
2037 if !edited_buffer {
2038 cx.emit(ContextEvent::MessagesEdited);
2039 }
2040 new_messages
2041 } else {
2042 (None, None)
2043 }
2044 }
2045
2046 fn insert_message(
2047 &mut self,
2048 new_anchor: MessageAnchor,
2049 new_metadata: MessageMetadata,
2050 cx: &mut ModelContext<Self>,
2051 ) {
2052 cx.emit(ContextEvent::MessagesEdited);
2053
2054 self.messages_metadata.insert(new_anchor.id, new_metadata);
2055
2056 let buffer = self.buffer.read(cx);
2057 let insertion_ix = self
2058 .message_anchors
2059 .iter()
2060 .position(|anchor| {
2061 let comparison = new_anchor.start.cmp(&anchor.start, buffer);
2062 comparison.is_lt() || (comparison.is_eq() && new_anchor.id > anchor.id)
2063 })
2064 .unwrap_or(self.message_anchors.len());
2065 self.message_anchors.insert(insertion_ix, new_anchor);
2066 }
2067
2068 pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
2069 let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
2070 return;
2071 };
2072 let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
2073 return;
2074 };
2075
2076 if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
2077 if !provider.is_authenticated(cx) {
2078 return;
2079 }
2080
2081 let messages = self
2082 .messages(cx)
2083 .filter_map(|message| message.to_request_message(self.buffer.read(cx)))
2084 .chain(Some(LanguageModelRequestMessage {
2085 role: Role::User,
2086 content: vec![
2087 "Summarize the context into a short title without punctuation.".into(),
2088 ],
2089 cache: false,
2090 }));
2091 let request = LanguageModelRequest {
2092 messages: messages.collect(),
2093 stop: vec![],
2094 temperature: 1.0,
2095 };
2096
2097 self.pending_summary = cx.spawn(|this, mut cx| {
2098 async move {
2099 let stream = model.stream_completion(request, &cx);
2100 let mut messages = stream.await?;
2101
2102 let mut replaced = !replace_old;
2103 while let Some(message) = messages.next().await {
2104 let text = message?;
2105 let mut lines = text.lines();
2106 this.update(&mut cx, |this, cx| {
2107 let version = this.version.clone();
2108 let timestamp = this.next_timestamp();
2109 let summary = this.summary.get_or_insert(ContextSummary::default());
2110 if !replaced && replace_old {
2111 summary.text.clear();
2112 replaced = true;
2113 }
2114 summary.text.extend(lines.next());
2115 summary.timestamp = timestamp;
2116 let operation = ContextOperation::UpdateSummary {
2117 summary: summary.clone(),
2118 version,
2119 };
2120 this.push_op(operation, cx);
2121 cx.emit(ContextEvent::SummaryChanged);
2122 })?;
2123
2124 // Stop if the LLM generated multiple lines.
2125 if lines.next().is_some() {
2126 break;
2127 }
2128 }
2129
2130 this.update(&mut cx, |this, cx| {
2131 let version = this.version.clone();
2132 let timestamp = this.next_timestamp();
2133 if let Some(summary) = this.summary.as_mut() {
2134 summary.done = true;
2135 summary.timestamp = timestamp;
2136 let operation = ContextOperation::UpdateSummary {
2137 summary: summary.clone(),
2138 version,
2139 };
2140 this.push_op(operation, cx);
2141 cx.emit(ContextEvent::SummaryChanged);
2142 }
2143 })?;
2144
2145 anyhow::Ok(())
2146 }
2147 .log_err()
2148 });
2149 }
2150 }
2151
2152 fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
2153 self.messages_for_offsets([offset], cx).pop()
2154 }
2155
    /// Returns the messages containing the given buffer offsets.
    ///
    /// Consecutive offsets that fall in the same message yield that message
    /// only once. An offset past the last message resolves to the last
    /// message, and once the last message is reached all remaining offsets
    /// are attributed to it.
    ///
    /// NOTE(review): the message cursor only moves forward, so this appears to
    /// assume `offsets` is sorted ascending. Confirm with callers.
    pub fn messages_for_offsets(
        &self,
        offsets: impl IntoIterator<Item = usize>,
        cx: &AppContext,
    ) -> Vec<Message> {
        let mut result = Vec::new();

        let mut messages = self.messages(cx).peekable();
        let mut offsets = offsets.into_iter().peekable();
        let mut current_message = messages.next();
        while let Some(offset) = offsets.next() {
            // Locate the message that contains the offset.
            // Never advances past the last message.
            while current_message.as_ref().map_or(false, |message| {
                !message.offset_range.contains(&offset) && messages.peek().is_some()
            }) {
                current_message = messages.next();
            }
            let Some(message) = current_message.as_ref() else {
                break;
            };

            // Skip offsets that are in the same message.
            // On the last message, every remaining offset is skipped too.
            while offsets.peek().map_or(false, |offset| {
                message.offset_range.contains(offset) || messages.peek().is_none()
            }) {
                offsets.next();
            }

            result.push(message.clone());
        }
        result
    }
2188
2189 fn messages_from_anchors<'a>(
2190 &'a self,
2191 message_anchors: impl Iterator<Item = &'a MessageAnchor> + 'a,
2192 cx: &'a AppContext,
2193 ) -> impl 'a + Iterator<Item = Message> {
2194 let buffer = self.buffer.read(cx);
2195 let messages = message_anchors.enumerate();
2196 let images = self.image_anchors.iter();
2197
2198 Self::messages_from_iters(buffer, &self.messages_metadata, messages, images)
2199 }
2200
2201 pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
2202 self.messages_from_anchors(self.message_anchors.iter(), cx)
2203 }
2204
2205 pub fn messages_from_iters<'a>(
2206 buffer: &'a Buffer,
2207 metadata: &'a HashMap<MessageId, MessageMetadata>,
2208 messages: impl Iterator<Item = (usize, &'a MessageAnchor)> + 'a,
2209 images: impl Iterator<Item = &'a ImageAnchor> + 'a,
2210 ) -> impl 'a + Iterator<Item = Message> {
2211 let mut messages = messages.peekable();
2212 let mut images = images.peekable();
2213
2214 iter::from_fn(move || {
2215 if let Some((start_ix, message_anchor)) = messages.next() {
2216 let metadata = metadata.get(&message_anchor.id)?;
2217
2218 let message_start = message_anchor.start.to_offset(buffer);
2219 let mut message_end = None;
2220 let mut end_ix = start_ix;
2221 while let Some((_, next_message)) = messages.peek() {
2222 if next_message.start.is_valid(buffer) {
2223 message_end = Some(next_message.start);
2224 break;
2225 } else {
2226 end_ix += 1;
2227 messages.next();
2228 }
2229 }
2230 let message_end_anchor = message_end.unwrap_or(language::Anchor::MAX);
2231 let message_end = message_end_anchor.to_offset(buffer);
2232
2233 let mut image_offsets = SmallVec::new();
2234 while let Some(image_anchor) = images.peek() {
2235 if image_anchor.anchor.cmp(&message_end_anchor, buffer).is_lt() {
2236 image_offsets.push((
2237 image_anchor.anchor.to_offset(buffer),
2238 MessageImage {
2239 image_id: image_anchor.image_id,
2240 image: image_anchor.image.clone(),
2241 },
2242 ));
2243 images.next();
2244 } else {
2245 break;
2246 }
2247 }
2248
2249 return Some(Message {
2250 index_range: start_ix..end_ix,
2251 offset_range: message_start..message_end,
2252 id: message_anchor.id,
2253 anchor: message_anchor.start,
2254 role: metadata.role,
2255 status: metadata.status.clone(),
2256 cache: metadata.cache.clone(),
2257 image_offsets,
2258 });
2259 }
2260 None
2261 })
2262 }
2263
2264 pub fn save(
2265 &mut self,
2266 debounce: Option<Duration>,
2267 fs: Arc<dyn Fs>,
2268 cx: &mut ModelContext<Context>,
2269 ) {
2270 if self.replica_id() != ReplicaId::default() {
2271 // Prevent saving a remote context for now.
2272 return;
2273 }
2274
2275 self.pending_save = cx.spawn(|this, mut cx| async move {
2276 if let Some(debounce) = debounce {
2277 cx.background_executor().timer(debounce).await;
2278 }
2279
2280 let (old_path, summary) = this.read_with(&cx, |this, _| {
2281 let path = this.path.clone();
2282 let summary = if let Some(summary) = this.summary.as_ref() {
2283 if summary.done {
2284 Some(summary.text.clone())
2285 } else {
2286 None
2287 }
2288 } else {
2289 None
2290 };
2291 (path, summary)
2292 })?;
2293
2294 if let Some(summary) = summary {
2295 this.read_with(&cx, |this, cx| this.serialize_images(fs.clone(), cx))?
2296 .await;
2297
2298 let context = this.read_with(&cx, |this, cx| this.serialize(cx))?;
2299 let mut discriminant = 1;
2300 let mut new_path;
2301 loop {
2302 new_path = contexts_dir().join(&format!(
2303 "{} - {}.zed.json",
2304 summary.trim(),
2305 discriminant
2306 ));
2307 if fs.is_file(&new_path).await {
2308 discriminant += 1;
2309 } else {
2310 break;
2311 }
2312 }
2313
2314 fs.create_dir(contexts_dir().as_ref()).await?;
2315 fs.atomic_write(new_path.clone(), serde_json::to_string(&context).unwrap())
2316 .await?;
2317 if let Some(old_path) = old_path {
2318 if new_path != old_path {
2319 fs.remove_file(
2320 &old_path,
2321 RemoveOptions {
2322 recursive: false,
2323 ignore_if_not_exists: true,
2324 },
2325 )
2326 .await?;
2327 }
2328 }
2329
2330 this.update(&mut cx, |this, _| this.path = Some(new_path))?;
2331 }
2332
2333 Ok(())
2334 });
2335 }
2336
2337 pub fn serialize_images(&self, fs: Arc<dyn Fs>, cx: &AppContext) -> Task<()> {
2338 let mut images_to_save = self
2339 .images
2340 .iter()
2341 .map(|(id, (_, llm_image))| {
2342 let fs = fs.clone();
2343 let llm_image = llm_image.clone();
2344 let id = *id;
2345 async move {
2346 if let Some(llm_image) = llm_image.await {
2347 let path: PathBuf =
2348 context_images_dir().join(&format!("{}.png.base64", id));
2349 if fs
2350 .metadata(path.as_path())
2351 .await
2352 .log_err()
2353 .flatten()
2354 .is_none()
2355 {
2356 fs.atomic_write(path, llm_image.source.to_string())
2357 .await
2358 .log_err();
2359 }
2360 }
2361 }
2362 })
2363 .collect::<FuturesUnordered<_>>();
2364 cx.background_executor().spawn(async move {
2365 if fs
2366 .create_dir(context_images_dir().as_ref())
2367 .await
2368 .log_err()
2369 .is_some()
2370 {
2371 while let Some(_) = images_to_save.next().await {}
2372 }
2373 })
2374 }
2375
2376 pub(crate) fn custom_summary(&mut self, custom_summary: String, cx: &mut ModelContext<Self>) {
2377 let timestamp = self.next_timestamp();
2378 let summary = self.summary.get_or_insert(ContextSummary::default());
2379 summary.timestamp = timestamp;
2380 summary.done = true;
2381 summary.text = custom_summary;
2382 cx.emit(ContextEvent::SummaryChanged);
2383 }
2384}
2385
2386#[derive(Debug, Default)]
2387pub struct ContextVersion {
2388 context: clock::Global,
2389 buffer: clock::Global,
2390}
2391
2392impl ContextVersion {
2393 pub fn from_proto(proto: &proto::ContextVersion) -> Self {
2394 Self {
2395 context: language::proto::deserialize_version(&proto.context_version),
2396 buffer: language::proto::deserialize_version(&proto.buffer_version),
2397 }
2398 }
2399
2400 pub fn to_proto(&self, context_id: ContextId) -> proto::ContextVersion {
2401 proto::ContextVersion {
2402 context_id: context_id.to_proto(),
2403 context_version: language::proto::serialize_version(&self.context),
2404 buffer_version: language::proto::serialize_version(&self.buffer),
2405 }
2406 }
2407}
2408
/// A slash command found in the context, together with its execution status.
#[derive(Debug, Clone)]
pub struct PendingSlashCommand {
    /// Command name.
    pub name: String,
    /// Arguments passed to the command.
    pub arguments: SmallVec<[String; 3]>,
    /// Current execution state of the command.
    pub status: PendingSlashCommandStatus,
    /// Buffer range where the command appears — presumably the full command line; confirm.
    pub source_range: Range<language::Anchor>,
}
2416
/// Execution state of a [`PendingSlashCommand`].
#[derive(Debug, Clone)]
pub enum PendingSlashCommandStatus {
    /// Not currently running.
    Idle,
    /// Running. The underscore-prefixed field is never read here; it appears to exist only
    /// to hold the command's task — NOTE(review): confirm it is kept to keep the task alive.
    Running { _task: Shared<Task<()>> },
    /// The command failed with the given message.
    Error(String),
}
2423
/// On-disk representation of one message in a [`SavedContext`].
#[derive(Serialize, Deserialize)]
pub struct SavedMessage {
    pub id: MessageId,
    /// Buffer offset where the message starts (converted to an anchor when loading).
    pub start: usize,
    pub metadata: MessageMetadata,
    /// Images embedded in the message — presumably `(offset, image id)` pairs; confirm
    /// against `Context::serialize`.
    #[serde(default)]
    // Defaulted for backwards compatibility: JSON files created before August 2024 don't
    // have this field.
    pub image_offsets: Vec<(usize, u64)>,
}
2433
/// The current (`SavedContext::VERSION`) on-disk JSON format of a context.
/// Field names are the serialized keys, so they must not be renamed.
#[derive(Serialize, Deserialize)]
pub struct SavedContext {
    /// Context id; optional so files written without one still deserialize.
    pub id: Option<ContextId>,
    pub zed: String,
    /// Format version string, checked by [`SavedContext::from_json`].
    pub version: String,
    /// Full buffer text; message and section offsets index into it.
    pub text: String,
    pub messages: Vec<SavedMessage>,
    pub summary: String,
    /// Slash-command output sections, as offset ranges into `text`.
    pub slash_command_output_sections:
        Vec<assistant_slash_command::SlashCommandOutputSection<usize>>,
}
2445
2446impl SavedContext {
2447 pub const VERSION: &'static str = "0.4.0";
2448
2449 pub fn from_json(json: &str) -> Result<Self> {
2450 let saved_context_json = serde_json::from_str::<serde_json::Value>(json)?;
2451 match saved_context_json
2452 .get("version")
2453 .ok_or_else(|| anyhow!("version not found"))?
2454 {
2455 serde_json::Value::String(version) => match version.as_str() {
2456 SavedContext::VERSION => {
2457 Ok(serde_json::from_value::<SavedContext>(saved_context_json)?)
2458 }
2459 SavedContextV0_3_0::VERSION => {
2460 let saved_context =
2461 serde_json::from_value::<SavedContextV0_3_0>(saved_context_json)?;
2462 Ok(saved_context.upgrade())
2463 }
2464 SavedContextV0_2_0::VERSION => {
2465 let saved_context =
2466 serde_json::from_value::<SavedContextV0_2_0>(saved_context_json)?;
2467 Ok(saved_context.upgrade())
2468 }
2469 SavedContextV0_1_0::VERSION => {
2470 let saved_context =
2471 serde_json::from_value::<SavedContextV0_1_0>(saved_context_json)?;
2472 Ok(saved_context.upgrade())
2473 }
2474 _ => Err(anyhow!("unrecognized saved context version: {}", version)),
2475 },
2476 _ => Err(anyhow!("version not found on saved context")),
2477 }
2478 }
2479
2480 fn into_ops(
2481 self,
2482 buffer: &Model<Buffer>,
2483 cx: &mut ModelContext<Context>,
2484 ) -> Vec<ContextOperation> {
2485 let mut operations = Vec::new();
2486 let mut version = clock::Global::new();
2487 let mut next_timestamp = clock::Lamport::new(ReplicaId::default());
2488
2489 let mut first_message_metadata = None;
2490 for message in self.messages {
2491 if message.id == MessageId(clock::Lamport::default()) {
2492 first_message_metadata = Some(message.metadata);
2493 } else {
2494 operations.push(ContextOperation::InsertMessage {
2495 anchor: MessageAnchor {
2496 id: message.id,
2497 start: buffer.read(cx).anchor_before(message.start),
2498 },
2499 metadata: MessageMetadata {
2500 role: message.metadata.role,
2501 status: message.metadata.status,
2502 timestamp: message.metadata.timestamp,
2503 cache: None,
2504 },
2505 version: version.clone(),
2506 });
2507 version.observe(message.id.0);
2508 next_timestamp.observe(message.id.0);
2509 }
2510 }
2511
2512 if let Some(metadata) = first_message_metadata {
2513 let timestamp = next_timestamp.tick();
2514 operations.push(ContextOperation::UpdateMessage {
2515 message_id: MessageId(clock::Lamport::default()),
2516 metadata: MessageMetadata {
2517 role: metadata.role,
2518 status: metadata.status,
2519 timestamp,
2520 cache: None,
2521 },
2522 version: version.clone(),
2523 });
2524 version.observe(timestamp);
2525 }
2526
2527 let timestamp = next_timestamp.tick();
2528 operations.push(ContextOperation::SlashCommandFinished {
2529 id: SlashCommandId(timestamp),
2530 output_range: language::Anchor::MIN..language::Anchor::MAX,
2531 sections: self
2532 .slash_command_output_sections
2533 .into_iter()
2534 .map(|section| {
2535 let buffer = buffer.read(cx);
2536 SlashCommandOutputSection {
2537 range: buffer.anchor_after(section.range.start)
2538 ..buffer.anchor_before(section.range.end),
2539 icon: section.icon,
2540 label: section.label,
2541 }
2542 })
2543 .collect(),
2544 version: version.clone(),
2545 });
2546 version.observe(timestamp);
2547
2548 let timestamp = next_timestamp.tick();
2549 operations.push(ContextOperation::UpdateSummary {
2550 summary: ContextSummary {
2551 text: self.summary,
2552 done: true,
2553 timestamp,
2554 },
2555 version: version.clone(),
2556 });
2557 version.observe(timestamp);
2558
2559 operations
2560 }
2561}
2562
/// Message id used by saved-context formats before 0.4.0: a plain integer, mapped to a
/// Lamport timestamp on the default replica during upgrade.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
struct SavedMessageIdPreV0_4_0(usize);
2565
/// A message in pre-0.4.0 saved contexts. Its metadata lives separately, keyed by `id`.
#[derive(Serialize, Deserialize)]
struct SavedMessagePreV0_4_0 {
    id: SavedMessageIdPreV0_4_0,
    /// Buffer offset where the message starts.
    start: usize,
}
2571
/// Per-message metadata in pre-0.4.0 saved contexts (no timestamp or cache info).
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
struct SavedMessageMetadataPreV0_4_0 {
    role: Role,
    status: MessageStatus,
}
2577
/// Saved-context format 0.3.0. It stores message metadata in a separate map keyed by
/// legacy message id.
#[derive(Serialize, Deserialize)]
struct SavedContextV0_3_0 {
    id: Option<ContextId>,
    zed: String,
    version: String,
    text: String,
    messages: Vec<SavedMessagePreV0_4_0>,
    message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
    summary: String,
    slash_command_output_sections: Vec<assistant_slash_command::SlashCommandOutputSection<usize>>,
}
2589
2590impl SavedContextV0_3_0 {
2591 const VERSION: &'static str = "0.3.0";
2592
2593 fn upgrade(self) -> SavedContext {
2594 SavedContext {
2595 id: self.id,
2596 zed: self.zed,
2597 version: SavedContext::VERSION.into(),
2598 text: self.text,
2599 messages: self
2600 .messages
2601 .into_iter()
2602 .filter_map(|message| {
2603 let metadata = self.message_metadata.get(&message.id)?;
2604 let timestamp = clock::Lamport {
2605 replica_id: ReplicaId::default(),
2606 value: message.id.0 as u32,
2607 };
2608 Some(SavedMessage {
2609 id: MessageId(timestamp),
2610 start: message.start,
2611 metadata: MessageMetadata {
2612 role: metadata.role,
2613 status: metadata.status.clone(),
2614 timestamp,
2615 cache: None,
2616 },
2617 image_offsets: Vec::new(),
2618 })
2619 })
2620 .collect(),
2621 summary: self.summary,
2622 slash_command_output_sections: self.slash_command_output_sections,
2623 }
2624 }
2625}
2626
/// Saved-context format 0.2.0: like 0.3.0 but without slash-command output sections.
#[derive(Serialize, Deserialize)]
struct SavedContextV0_2_0 {
    id: Option<ContextId>,
    zed: String,
    version: String,
    text: String,
    messages: Vec<SavedMessagePreV0_4_0>,
    message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
    summary: String,
}
2637
2638impl SavedContextV0_2_0 {
2639 const VERSION: &'static str = "0.2.0";
2640
2641 fn upgrade(self) -> SavedContext {
2642 SavedContextV0_3_0 {
2643 id: self.id,
2644 zed: self.zed,
2645 version: SavedContextV0_3_0::VERSION.to_string(),
2646 text: self.text,
2647 messages: self.messages,
2648 message_metadata: self.message_metadata,
2649 summary: self.summary,
2650 slash_command_output_sections: Vec::new(),
2651 }
2652 .upgrade()
2653 }
2654}
2655
/// Saved-context format 0.1.0: like 0.2.0 plus the OpenAI endpoint and model the context
/// used. Both of those are discarded on upgrade.
#[derive(Serialize, Deserialize)]
struct SavedContextV0_1_0 {
    id: Option<ContextId>,
    zed: String,
    version: String,
    text: String,
    messages: Vec<SavedMessagePreV0_4_0>,
    message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
    summary: String,
    api_url: Option<String>,
    model: OpenAiModel,
}
2668
2669impl SavedContextV0_1_0 {
2670 const VERSION: &'static str = "0.1.0";
2671
2672 fn upgrade(self) -> SavedContext {
2673 SavedContextV0_2_0 {
2674 id: self.id,
2675 zed: self.zed,
2676 version: SavedContextV0_2_0::VERSION.to_string(),
2677 text: self.text,
2678 messages: self.messages,
2679 message_metadata: self.message_metadata,
2680 summary: self.summary,
2681 }
2682 .upgrade()
2683 }
2684}
2685
2686#[derive(Clone)]
2687pub struct SavedContextMetadata {
2688 pub title: String,
2689 pub path: PathBuf,
2690 pub mtime: chrono::DateTime<chrono::Local>,
2691}