context.rs

   1use crate::{
   2    prompt_library::PromptStore, slash_command::SlashCommandLine, InitialInsertion,
   3    LanguageModelCompletionProvider, MessageId, MessageStatus,
   4};
   5use anyhow::{anyhow, Context as _, Result};
   6use assistant_slash_command::{
   7    SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
   8};
   9use client::{self, proto, telemetry::Telemetry};
  10use clock::ReplicaId;
  11use collections::{HashMap, HashSet};
  12use fs::{Fs, RemoveOptions};
  13use futures::{
  14    future::{self, Shared},
  15    FutureExt, StreamExt,
  16};
  17use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscription, Task};
  18use language::{
  19    AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
  20};
  21use language_model::{LanguageModelRequest, LanguageModelRequestMessage, LanguageModelTool, Role};
  22use open_ai::Model as OpenAiModel;
  23use paths::contexts_dir;
  24use project::Project;
  25use schemars::JsonSchema;
  26use serde::{Deserialize, Serialize};
  27use std::{
  28    cmp,
  29    fmt::Debug,
  30    iter, mem,
  31    ops::Range,
  32    path::{Path, PathBuf},
  33    sync::Arc,
  34    time::{Duration, Instant},
  35};
  36use telemetry_events::AssistantKind;
  37use ui::SharedString;
  38use util::{post_inc, ResultExt, TryFutureExt};
  39use uuid::Uuid;
  40
  41#[derive(Clone, Eq, PartialEq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
  42pub struct ContextId(String);
  43
  44impl ContextId {
  45    pub fn new() -> Self {
  46        Self(Uuid::new_v4().to_string())
  47    }
  48
  49    pub fn from_proto(id: String) -> Self {
  50        Self(id)
  51    }
  52
  53    pub fn to_proto(&self) -> String {
  54        self.0.clone()
  55    }
  56}
  57
/// A change to a [`Context`] that can be encoded to and decoded from
/// [`proto::ContextOperation`] via `ContextOperation::to_proto` / `from_proto`.
///
/// Every variant except `BufferOperation` carries a `clock::Global` version and
/// exposes a Lamport timestamp through `ContextOperation::timestamp`.
#[derive(Clone, Debug)]
pub enum ContextOperation {
    /// A message whose start is anchored at `anchor.start` in the context buffer.
    /// When decoded from proto, `metadata.timestamp` equals the timestamp in `anchor.id`.
    InsertMessage {
        anchor: MessageAnchor,
        metadata: MessageMetadata,
        version: clock::Global,
    },
    /// New metadata (role, status, timestamp) for the message identified by `message_id`.
    UpdateMessage {
        message_id: MessageId,
        metadata: MessageMetadata,
        version: clock::Global,
    },
    /// A new value for the context's summary; its Lamport timestamp is `summary.timestamp`.
    UpdateSummary {
        summary: ContextSummary,
        version: clock::Global,
    },
    /// Records that slash command `id` finished, with its output spanning
    /// `output_range` and divided into `sections`.
    SlashCommandFinished {
        id: SlashCommandId,
        output_range: Range<language::Anchor>,
        sections: Vec<SlashCommandOutputSection<language::Anchor>>,
        version: clock::Global,
    },
    /// An operation on the underlying text buffer. It has no context-level
    /// timestamp or version: `timestamp()` and `version()` panic on this variant.
    BufferOperation(language::Operation),
}
  82
  83impl ContextOperation {
  84    pub fn from_proto(op: proto::ContextOperation) -> Result<Self> {
  85        match op.variant.context("invalid variant")? {
  86            proto::context_operation::Variant::InsertMessage(insert) => {
  87                let message = insert.message.context("invalid message")?;
  88                let id = MessageId(language::proto::deserialize_timestamp(
  89                    message.id.context("invalid id")?,
  90                ));
  91                Ok(Self::InsertMessage {
  92                    anchor: MessageAnchor {
  93                        id,
  94                        start: language::proto::deserialize_anchor(
  95                            message.start.context("invalid anchor")?,
  96                        )
  97                        .context("invalid anchor")?,
  98                    },
  99                    metadata: MessageMetadata {
 100                        role: Role::from_proto(message.role),
 101                        status: MessageStatus::from_proto(
 102                            message.status.context("invalid status")?,
 103                        ),
 104                        timestamp: id.0,
 105                    },
 106                    version: language::proto::deserialize_version(&insert.version),
 107                })
 108            }
 109            proto::context_operation::Variant::UpdateMessage(update) => Ok(Self::UpdateMessage {
 110                message_id: MessageId(language::proto::deserialize_timestamp(
 111                    update.message_id.context("invalid message id")?,
 112                )),
 113                metadata: MessageMetadata {
 114                    role: Role::from_proto(update.role),
 115                    status: MessageStatus::from_proto(update.status.context("invalid status")?),
 116                    timestamp: language::proto::deserialize_timestamp(
 117                        update.timestamp.context("invalid timestamp")?,
 118                    ),
 119                },
 120                version: language::proto::deserialize_version(&update.version),
 121            }),
 122            proto::context_operation::Variant::UpdateSummary(update) => Ok(Self::UpdateSummary {
 123                summary: ContextSummary {
 124                    text: update.summary,
 125                    done: update.done,
 126                    timestamp: language::proto::deserialize_timestamp(
 127                        update.timestamp.context("invalid timestamp")?,
 128                    ),
 129                },
 130                version: language::proto::deserialize_version(&update.version),
 131            }),
 132            proto::context_operation::Variant::SlashCommandFinished(finished) => {
 133                Ok(Self::SlashCommandFinished {
 134                    id: SlashCommandId(language::proto::deserialize_timestamp(
 135                        finished.id.context("invalid id")?,
 136                    )),
 137                    output_range: language::proto::deserialize_anchor_range(
 138                        finished.output_range.context("invalid range")?,
 139                    )?,
 140                    sections: finished
 141                        .sections
 142                        .into_iter()
 143                        .map(|section| {
 144                            Ok(SlashCommandOutputSection {
 145                                range: language::proto::deserialize_anchor_range(
 146                                    section.range.context("invalid range")?,
 147                                )?,
 148                                icon: section.icon_name.parse()?,
 149                                label: section.label.into(),
 150                            })
 151                        })
 152                        .collect::<Result<Vec<_>>>()?,
 153                    version: language::proto::deserialize_version(&finished.version),
 154                })
 155            }
 156            proto::context_operation::Variant::BufferOperation(op) => Ok(Self::BufferOperation(
 157                language::proto::deserialize_operation(
 158                    op.operation.context("invalid buffer operation")?,
 159                )?,
 160            )),
 161        }
 162    }
 163
 164    pub fn to_proto(&self) -> proto::ContextOperation {
 165        match self {
 166            Self::InsertMessage {
 167                anchor,
 168                metadata,
 169                version,
 170            } => proto::ContextOperation {
 171                variant: Some(proto::context_operation::Variant::InsertMessage(
 172                    proto::context_operation::InsertMessage {
 173                        message: Some(proto::ContextMessage {
 174                            id: Some(language::proto::serialize_timestamp(anchor.id.0)),
 175                            start: Some(language::proto::serialize_anchor(&anchor.start)),
 176                            role: metadata.role.to_proto() as i32,
 177                            status: Some(metadata.status.to_proto()),
 178                        }),
 179                        version: language::proto::serialize_version(version),
 180                    },
 181                )),
 182            },
 183            Self::UpdateMessage {
 184                message_id,
 185                metadata,
 186                version,
 187            } => proto::ContextOperation {
 188                variant: Some(proto::context_operation::Variant::UpdateMessage(
 189                    proto::context_operation::UpdateMessage {
 190                        message_id: Some(language::proto::serialize_timestamp(message_id.0)),
 191                        role: metadata.role.to_proto() as i32,
 192                        status: Some(metadata.status.to_proto()),
 193                        timestamp: Some(language::proto::serialize_timestamp(metadata.timestamp)),
 194                        version: language::proto::serialize_version(version),
 195                    },
 196                )),
 197            },
 198            Self::UpdateSummary { summary, version } => proto::ContextOperation {
 199                variant: Some(proto::context_operation::Variant::UpdateSummary(
 200                    proto::context_operation::UpdateSummary {
 201                        summary: summary.text.clone(),
 202                        done: summary.done,
 203                        timestamp: Some(language::proto::serialize_timestamp(summary.timestamp)),
 204                        version: language::proto::serialize_version(version),
 205                    },
 206                )),
 207            },
 208            Self::SlashCommandFinished {
 209                id,
 210                output_range,
 211                sections,
 212                version,
 213            } => proto::ContextOperation {
 214                variant: Some(proto::context_operation::Variant::SlashCommandFinished(
 215                    proto::context_operation::SlashCommandFinished {
 216                        id: Some(language::proto::serialize_timestamp(id.0)),
 217                        output_range: Some(language::proto::serialize_anchor_range(
 218                            output_range.clone(),
 219                        )),
 220                        sections: sections
 221                            .iter()
 222                            .map(|section| {
 223                                let icon_name: &'static str = section.icon.into();
 224                                proto::SlashCommandOutputSection {
 225                                    range: Some(language::proto::serialize_anchor_range(
 226                                        section.range.clone(),
 227                                    )),
 228                                    icon_name: icon_name.to_string(),
 229                                    label: section.label.to_string(),
 230                                }
 231                            })
 232                            .collect(),
 233                        version: language::proto::serialize_version(version),
 234                    },
 235                )),
 236            },
 237            Self::BufferOperation(operation) => proto::ContextOperation {
 238                variant: Some(proto::context_operation::Variant::BufferOperation(
 239                    proto::context_operation::BufferOperation {
 240                        operation: Some(language::proto::serialize_operation(operation)),
 241                    },
 242                )),
 243            },
 244        }
 245    }
 246
 247    fn timestamp(&self) -> clock::Lamport {
 248        match self {
 249            Self::InsertMessage { anchor, .. } => anchor.id.0,
 250            Self::UpdateMessage { metadata, .. } => metadata.timestamp,
 251            Self::UpdateSummary { summary, .. } => summary.timestamp,
 252            Self::SlashCommandFinished { id, .. } => id.0,
 253            Self::BufferOperation(_) => {
 254                panic!("reading the timestamp of a buffer operation is not supported")
 255            }
 256        }
 257    }
 258
 259    /// Returns the current version of the context operation.
 260    pub fn version(&self) -> &clock::Global {
 261        match self {
 262            Self::InsertMessage { version, .. }
 263            | Self::UpdateMessage { version, .. }
 264            | Self::UpdateSummary { version, .. }
 265            | Self::SlashCommandFinished { version, .. } => version,
 266            Self::BufferOperation(_) => {
 267                panic!("reading the version of a buffer operation is not supported")
 268            }
 269        }
 270    }
 271}
 272
/// Events emitted by [`Context`] through its `EventEmitter<ContextEvent>` impl.
#[derive(Clone)]
pub enum ContextEvent {
    MessagesEdited,
    SummaryChanged,
    EditStepsChanged,
    StreamedCompletion,
    /// Changes to the set of pending slash commands.
    PendingSlashCommandsUpdated {
        /// NOTE(review): presumably ranges of commands that are no longer pending — confirm at the emit site.
        removed: Vec<Range<language::Anchor>>,
        updated: Vec<PendingSlashCommand>,
    },
    /// A slash command finished. `output_range` and `sections` have the same
    /// types as the fields of `ContextOperation::SlashCommandFinished`.
    SlashCommandFinished {
        output_range: Range<language::Anchor>,
        sections: Vec<SlashCommandOutputSection<language::Anchor>>,
        /// NOTE(review): presumably whether slash commands inside the output should also run — confirm.
        run_commands_in_output: bool,
    },
    /// Carries a [`ContextOperation`]. NOTE(review): presumably emitted so listeners
    /// can forward it to other replicas — confirm.
    Operation(ContextOperation),
}
 290
/// A context's summary, as carried by `ContextOperation::UpdateSummary`.
#[derive(Clone, Default, Debug)]
pub struct ContextSummary {
    /// The summary text.
    pub text: String,
    /// Encoded as the proto `done` flag. NOTE(review): presumably marks the summary as final — confirm.
    done: bool,
    /// Lamport timestamp reported by `ContextOperation::timestamp` for summary updates.
    timestamp: clock::Lamport,
}
 297
/// Identifies a message and the position in the context buffer where it starts.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MessageAnchor {
    /// The message id; its inner Lamport timestamp is what gets serialized.
    pub id: MessageId,
    /// Buffer anchor at which the message begins.
    pub start: language::Anchor,
}
 303
/// Per-message attributes carried by `InsertMessage` and `UpdateMessage` operations.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct MessageMetadata {
    /// Who authored the message.
    pub role: Role,
    /// Current status of the message.
    status: MessageStatus,
    /// Lamport timestamp of this metadata. For decoded `InsertMessage` operations
    /// it equals the message id's timestamp.
    timestamp: clock::Lamport,
}
 310
/// A message expressed as a concrete range of buffer text plus its id, role and status.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Message {
    /// Offsets of the message's text in the buffer; `to_request_message` sends this range.
    pub offset_range: Range<usize>,
    /// NOTE(review): presumably this message's position range in the message list — confirm.
    pub index_range: Range<usize>,
    pub id: MessageId,
    /// Buffer anchor where the message starts.
    pub anchor: language::Anchor,
    pub role: Role,
    pub status: MessageStatus,
}
 320
 321impl Message {
 322    fn to_request_message(&self, buffer: &Buffer) -> LanguageModelRequestMessage {
 323        LanguageModelRequestMessage {
 324            role: self.role,
 325            content: buffer.text_for_range(self.offset_range.clone()).collect(),
 326        }
 327    }
 328}
 329
/// Bookkeeping for a completion request. NOTE(review): presumably one that is still in flight — confirm.
struct PendingCompletion {
    /// Identifier of this completion. NOTE(review): presumably allocated from `Context::completion_count` — confirm.
    id: usize,
    /// Held but never read; presumably kept so the task is not dropped.
    _task: Task<()>,
}
 334
/// Identifies a slash command invocation by a Lamport timestamp, which is the
/// value serialized as the `id` of `ContextOperation::SlashCommandFinished`.
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub struct SlashCommandId(clock::Lamport);
 337
/// A step of suggested code edits, tied to a range of the context's text.
#[derive(Debug)]
pub struct EditStep {
    /// NOTE(review): presumably the range in the context buffer this step was parsed from — confirm.
    pub source_range: Range<language::Anchor>,
    /// The step's operations. Only `Some(EditStepOperations::Ready(..))` produces
    /// suggestions in `EditStep::edit_suggestions`; `None` and `Pending` yield none.
    pub operations: Option<EditStepOperations>,
}
 343
/// Edit suggestions in one buffer whose surrounding context windows overlap.
#[derive(Debug)]
pub struct EditSuggestionGroup {
    /// Span covering the grouped suggestions plus surrounding rows of context.
    pub context_range: Range<language::Anchor>,
    /// The grouped suggestions, in buffer order.
    pub suggestions: Vec<EditSuggestion>,
}
 349
/// A single location in a buffer where an edit is suggested.
#[derive(Debug)]
pub struct EditSuggestion {
    /// Range the edit applies to; an empty range for pure insertions.
    pub range: Range<language::Anchor>,
    /// If None, assume this is a suggestion to delete the range rather than transform it.
    pub description: Option<String>,
    /// Newline placement for inserted content, taken from `EditOperationKind::initial_insertion`.
    pub initial_insertion: Option<InitialInsertion>,
}
 357
 358impl EditStep {
 359    pub fn edit_suggestions(
 360        &self,
 361        project: &Model<Project>,
 362        cx: &AppContext,
 363    ) -> Task<HashMap<Model<Buffer>, Vec<EditSuggestionGroup>>> {
 364        let Some(EditStepOperations::Ready(operations)) = &self.operations else {
 365            return Task::ready(HashMap::default());
 366        };
 367
 368        let suggestion_tasks: Vec<_> = operations
 369            .iter()
 370            .map(|operation| operation.edit_suggestion(project.clone(), cx))
 371            .collect();
 372
 373        cx.spawn(|mut cx| async move {
 374            let suggestions = future::join_all(suggestion_tasks)
 375                .await
 376                .into_iter()
 377                .filter_map(|task| task.log_err())
 378                .collect::<Vec<_>>();
 379
 380            let mut suggestions_by_buffer = HashMap::default();
 381            for (buffer, suggestion) in suggestions {
 382                suggestions_by_buffer
 383                    .entry(buffer)
 384                    .or_insert_with(Vec::new)
 385                    .push(suggestion);
 386            }
 387
 388            let mut suggestion_groups_by_buffer = HashMap::default();
 389            for (buffer, mut suggestions) in suggestions_by_buffer {
 390                let mut suggestion_groups = Vec::<EditSuggestionGroup>::new();
 391                buffer
 392                    .update(&mut cx, |buffer, _cx| {
 393                        // Sort suggestions by their range
 394                        suggestions.sort_by(|a, b| a.range.cmp(&b.range, buffer));
 395
 396                        // Dedup overlapping suggestions
 397                        suggestions.dedup_by(|a, b| {
 398                            let a_range = a.range.to_offset(buffer);
 399                            let b_range = b.range.to_offset(buffer);
 400                            if a_range.start <= b_range.end && b_range.start <= a_range.end {
 401                                if b_range.start < a_range.start {
 402                                    a.range.start = b.range.start;
 403                                }
 404                                if b_range.end > a_range.end {
 405                                    a.range.end = b.range.end;
 406                                }
 407
 408                                if let (Some(a_desc), Some(b_desc)) =
 409                                    (a.description.as_mut(), b.description.as_mut())
 410                                {
 411                                    b_desc.push('\n');
 412                                    b_desc.push_str(a_desc);
 413                                } else if a.description.is_some() {
 414                                    b.description = a.description.take();
 415                                }
 416
 417                                true
 418                            } else {
 419                                false
 420                            }
 421                        });
 422
 423                        // Create context ranges for each suggestion
 424                        for suggestion in suggestions {
 425                            let context_range = {
 426                                let suggestion_point_range = suggestion.range.to_point(buffer);
 427                                let start_row = suggestion_point_range.start.row.saturating_sub(5);
 428                                let end_row = cmp::min(
 429                                    suggestion_point_range.end.row + 5,
 430                                    buffer.max_point().row,
 431                                );
 432                                let start = buffer.anchor_before(Point::new(start_row, 0));
 433                                let end = buffer
 434                                    .anchor_after(Point::new(end_row, buffer.line_len(end_row)));
 435                                start..end
 436                            };
 437
 438                            if let Some(last_group) = suggestion_groups.last_mut() {
 439                                if last_group
 440                                    .context_range
 441                                    .end
 442                                    .cmp(&context_range.start, buffer)
 443                                    .is_ge()
 444                                {
 445                                    // Merge with the previous group if context ranges overlap
 446                                    last_group.context_range.end = context_range.end;
 447                                    last_group.suggestions.push(suggestion);
 448                                } else {
 449                                    // Create a new group
 450                                    suggestion_groups.push(EditSuggestionGroup {
 451                                        context_range,
 452                                        suggestions: vec![suggestion],
 453                                    });
 454                                }
 455                            } else {
 456                                // Create the first group
 457                                suggestion_groups.push(EditSuggestionGroup {
 458                                    context_range,
 459                                    suggestions: vec![suggestion],
 460                                });
 461                            }
 462                        }
 463                    })
 464                    .ok();
 465                suggestion_groups_by_buffer.insert(buffer, suggestion_groups);
 466            }
 467
 468            suggestion_groups_by_buffer
 469        })
 470    }
 471}
 472
/// The operations of an [`EditStep`]: either still being produced or ready to use.
pub enum EditStepOperations {
    /// NOTE(review): presumably the task that is still producing the operations — confirm.
    Pending(Task<Option<()>>),
    /// The parsed operations, consumed by `EditStep::edit_suggestions`.
    Ready(Vec<EditOperation>),
}
 477
 478impl Debug for EditStepOperations {
 479    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 480        match self {
 481            EditStepOperations::Pending(_) => write!(f, "EditStepOperations::Pending"),
 482            EditStepOperations::Ready(operations) => f
 483                .debug_struct("EditStepOperations::Parsed")
 484                .field("operations", operations)
 485                .finish(),
 486        }
 487    }
 488}
 489
/// A description of an operation to apply to one location in the codebase.
// NOTE(review): this type derives `JsonSchema`; schemars uses `///` doc comments as
// schema descriptions, so the `///` text here is presumably shown to the model.
// Keep maintainer-only notes in `//` comments like this one.
#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
pub struct EditOperation {
    /// The path to the file containing the relevant operation
    // Resolved via `Project::project_path_for_full_path` in `edit_suggestion`.
    pub path: String,
    // Flattened: the `kind` tag and the variant's fields appear alongside `path`.
    #[serde(flatten)]
    pub kind: EditOperationKind,
}
 498
impl EditOperation {
    /// Resolves this operation to the buffer it targets and an [`EditSuggestion`]
    /// locating the edit within that buffer.
    ///
    /// The spawned task opens the buffer for `self.path` through `project` and
    /// waits until its parse status is `ParseStatus::Idle`. It then computes the
    /// suggestion range:
    ///
    /// - With a symbol: the symbol is matched exactly against the buffer outline's
    ///   path candidates. The range is derived from the matched item's range (and
    ///   its body range, when present) according to `self.kind`.
    /// - Without a symbol: `PrependChild` targets `Anchor::MIN`, while
    ///   `AppendChild` and `Create` target `Anchor::MAX`.
    ///
    /// # Errors
    ///
    /// The task fails when:
    /// - no worktree matches `self.path`, or the buffer cannot be opened;
    /// - waiting on the parse status fails;
    /// - the buffer has no outline, or the symbol is not among its path candidates;
    /// - an entity read/update fails.
    ///
    /// # Panics
    ///
    /// Reaches `unreachable!` if a `Create` operation carries a symbol, or if any
    /// other symbol-requiring kind arrives without one. `EditOperationKind::symbol`
    /// as written rules out both.
    fn edit_suggestion(
        &self,
        project: Model<Project>,
        cx: &AppContext,
    ) -> Task<Result<(Model<language::Buffer>, EditSuggestion)>> {
        let path = self.path.clone();
        let kind = self.kind.clone();
        cx.spawn(move |mut cx| async move {
            let buffer = project
                .update(&mut cx, |project, cx| {
                    let project_path = project
                        .project_path_for_full_path(Path::new(&path), cx)
                        .with_context(|| format!("worktree not found for {:?}", path))?;
                    anyhow::Ok(project.open_buffer(project_path, cx))
                })??
                .await?;

            // Wait until parsing is idle before reading the outline below.
            let mut parse_status = buffer.read_with(&cx, |buffer, _cx| buffer.parse_status())?;
            while *parse_status.borrow() != ParseStatus::Idle {
                parse_status.changed().await?;
            }

            let initial_insertion = kind.initial_insertion();
            let suggestion_range = if let Some(symbol) = kind.symbol() {
                let outline = buffer
                    .update(&mut cx, |buffer, _| buffer.snapshot().outline(None))?
                    .context("no outline for buffer")?;
                let candidate = outline
                    .path_candidates
                    .iter()
                    .find(|item| item.string == symbol)
                    .with_context(|| {
                        // NOTE(review): this error message embeds the whole buffer text,
                        // which can be very large.
                        format!(
                            "symbol {:?} not found in path {:?}.\ncandidates: {:?}.\nparse status: {:?}. text:\n{}",
                            symbol,
                            path,
                            outline
                                .path_candidates
                                .iter()
                                .map(|candidate| &candidate.string)
                                .collect::<Vec<_>>(),
                            *parse_status.borrow(),
                            buffer.read_with(&cx, |buffer, _| buffer.text()).unwrap_or_else(|_| "error".to_string())
                        )
                    })?;

                buffer.update(&mut cx, |buffer, _| {
                    let outline_item = &outline.items[candidate.id];
                    let symbol_range = outline_item.range.to_point(buffer);
                    // Fall back to the whole symbol when the item has no separate body range.
                    let body_range = outline_item
                        .body_range
                        .as_ref()
                        .map(|range| range.to_point(buffer))
                        .unwrap_or(symbol_range.clone());

                    match kind {
                        // Empty range at the start of the body.
                        EditOperationKind::PrependChild { .. } => {
                            let position = buffer.anchor_after(body_range.start);
                            position..position
                        }
                        // Empty range at the end of the body.
                        EditOperationKind::AppendChild { .. } => {
                            let position = buffer.anchor_before(body_range.end);
                            position..position
                        }
                        // Empty range just before the symbol.
                        EditOperationKind::InsertSiblingBefore { .. } => {
                            let position = buffer.anchor_before(symbol_range.start);
                            position..position
                        }
                        // Empty range just after the symbol.
                        EditOperationKind::InsertSiblingAfter { .. } => {
                            let position = buffer.anchor_after(symbol_range.end);
                            position..position
                        }
                        // Whole rows spanned by the symbol: column 0 of its first row
                        // through the end of its last row.
                        EditOperationKind::Update { .. } | EditOperationKind::Delete { .. } => {
                            let start = Point::new(symbol_range.start.row, 0);
                            let end = Point::new(
                                symbol_range.end.row,
                                buffer.line_len(symbol_range.end.row),
                            );
                            buffer.anchor_before(start)..buffer.anchor_after(end)
                        }
                        // `Create` never has a symbol, so it cannot reach this branch.
                        EditOperationKind::Create { .. } => unreachable!(),
                    }
                })?
            } else {
                match kind {
                    EditOperationKind::PrependChild { .. } => {
                        language::Anchor::MIN..language::Anchor::MIN
                    }
                    EditOperationKind::AppendChild { .. } | EditOperationKind::Create { .. } => {
                        language::Anchor::MAX..language::Anchor::MAX
                    }
                    _ => unreachable!("All other operations should have a symbol"),
                }
            };

            Ok((
                buffer,
                EditSuggestion {
                    range: suggestion_range,
                    description: kind.description().map(ToString::to_string),
                    initial_insertion,
                },
            ))
        })
    }
}
 606
 607#[derive(Clone, Debug, PartialEq, Eq, Deserialize, JsonSchema)]
 608#[serde(tag = "kind")]
 609pub enum EditOperationKind {
 610    /// Rewrite the specified symbol in its entirely based on the given description.
 611    Update {
 612        /// A full path to the symbol to be rewritten from the provided list.
 613        symbol: String,
 614        /// A brief one-line description of the change that should be applied.
 615        description: String,
 616    },
 617    /// Create a new file with the given path based on the given description.
 618    Create {
 619        /// A brief one-line description of the change that should be applied.
 620        description: String,
 621    },
 622    /// Insert a new symbol based on the given description before the specified symbol.
 623    InsertSiblingBefore {
 624        /// A full path to the symbol to be rewritten from the provided list.
 625        symbol: String,
 626        /// A brief one-line description of the change that should be applied.
 627        description: String,
 628    },
 629    /// Insert a new symbol based on the given description after the specified symbol.
 630    InsertSiblingAfter {
 631        /// A full path to the symbol to be rewritten from the provided list.
 632        symbol: String,
 633        /// A brief one-line description of the change that should be applied.
 634        description: String,
 635    },
 636    /// Insert a new symbol as a child of the specified symbol at the start.
 637    PrependChild {
 638        /// An optional full path to the symbol to be rewritten from the provided list.
 639        /// If not provided, the edit should be applied at the top of the file.
 640        symbol: Option<String>,
 641        /// A brief one-line description of the change that should be applied.
 642        description: String,
 643    },
 644    /// Insert a new symbol as a child of the specified symbol at the end.
 645    AppendChild {
 646        /// An optional full path to the symbol to be rewritten from the provided list.
 647        /// If not provided, the edit should be applied at the top of the file.
 648        symbol: Option<String>,
 649        /// A brief one-line description of the change that should be applied.
 650        description: String,
 651    },
 652    /// Delete the specified symbol.
 653    Delete {
 654        /// A full path to the symbol to be rewritten from the provided list.
 655        symbol: String,
 656    },
 657}
 658
 659impl EditOperationKind {
 660    pub fn symbol(&self) -> Option<&str> {
 661        match self {
 662            Self::Update { symbol, .. } => Some(symbol),
 663            Self::InsertSiblingBefore { symbol, .. } => Some(symbol),
 664            Self::InsertSiblingAfter { symbol, .. } => Some(symbol),
 665            Self::PrependChild { symbol, .. } => symbol.as_deref(),
 666            Self::AppendChild { symbol, .. } => symbol.as_deref(),
 667            Self::Delete { symbol } => Some(symbol),
 668            Self::Create { .. } => None,
 669        }
 670    }
 671
 672    pub fn description(&self) -> Option<&str> {
 673        match self {
 674            Self::Update { description, .. } => Some(description),
 675            Self::Create { description } => Some(description),
 676            Self::InsertSiblingBefore { description, .. } => Some(description),
 677            Self::InsertSiblingAfter { description, .. } => Some(description),
 678            Self::PrependChild { description, .. } => Some(description),
 679            Self::AppendChild { description, .. } => Some(description),
 680            Self::Delete { .. } => None,
 681        }
 682    }
 683
 684    pub fn initial_insertion(&self) -> Option<InitialInsertion> {
 685        match self {
 686            EditOperationKind::InsertSiblingBefore { .. } => Some(InitialInsertion::NewlineAfter),
 687            EditOperationKind::InsertSiblingAfter { .. } => Some(InitialInsertion::NewlineBefore),
 688            EditOperationKind::PrependChild { .. } => Some(InitialInsertion::NewlineAfter),
 689            EditOperationKind::AppendChild { .. } => Some(InitialInsertion::NewlineBefore),
 690            _ => None,
 691        }
 692    }
 693}
 694
/// An assistant conversation stored in a text buffer that is divided into messages,
/// together with the operation log and clocks used to exchange it with other
/// replicas (`serialize_ops` / `apply_ops`).
pub struct Context {
    /// Identifier of this context.
    id: ContextId,
    /// Lamport clock for context operations; ticked by `next_timestamp` and advanced
    /// when operations are applied in `flush_ops`.
    timestamp: clock::Lamport,
    /// Context operations observed so far (buffer operations are tracked by the buffer's
    /// own version).
    version: clock::Global,
    /// Received context operations that cannot be applied yet; retried by `flush_ops`.
    pending_ops: Vec<ContextOperation>,
    /// Context operations applied from remote or produced locally (`push_op`); read by
    /// `serialize_ops`.
    operations: Vec<ContextOperation>,
    /// The conversation text.
    buffer: Model<Buffer>,
    /// Slash commands currently parsed from the buffer, kept sorted by source range
    /// (binary-searched in `pending_command_indices_for_range`).
    pending_slash_commands: Vec<PendingSlashCommand>,
    /// Buffer edits not yet consumed by `reparse_slash_commands`.
    edits_since_last_slash_command_parse: language::Subscription,
    /// Ids of slash commands whose output has already been recorded.
    finished_slash_commands: HashSet<SlashCommandId>,
    /// Output sections of finished slash commands, kept sorted by range.
    slash_command_output_sections: Vec<SlashCommandOutputSection<language::Anchor>>,
    /// Start anchor of each message.
    message_anchors: Vec<MessageAnchor>,
    /// Role, status and timestamp of each message, keyed by message id.
    messages_metadata: HashMap<MessageId, MessageMetadata>,
    /// Current summary, set locally or via `UpdateSummary` operations.
    summary: Option<ContextSummary>,
    /// Presumably the in-flight summarization task; starts as a ready no-op.
    pending_summary: Task<Option<()>>,
    /// Counter that starts at 0; its use is not visible in this part of the file.
    completion_count: usize,
    /// Presumably the in-flight assistant completions.
    pending_completions: Vec<PendingCompletion>,
    /// Last token count produced by `count_remaining_tokens`, if any.
    token_count: Option<usize>,
    /// Task running the latest token count.
    pending_token_count: Task<Option<()>>,
    /// Presumably the in-flight save; starts as `Task::ready(Ok(()))`.
    pending_save: Task<Result<()>>,
    /// Path of the saved context on disk, if known (set by `deserialize`).
    path: Option<PathBuf>,
    /// Subscription routing buffer events to `handle_buffer_event`.
    _subscriptions: Vec<Subscription>,
    /// Optional telemetry client.
    telemetry: Option<Arc<Telemetry>>,
    /// Registry given to the buffer and used to look up the Markdown language.
    language_registry: Arc<LanguageRegistry>,
    /// `<step>` blocks found in the buffer, kept sorted by source range.
    edit_steps: Vec<EditStep>,
}
 721
// Allows `Context` to `cx.emit` `ContextEvent`s (operations, message/summary changes,
// slash-command and edit-step updates) to its subscribers.
impl EventEmitter<ContextEvent> for Context {}
 723
 724impl Context {
 725    pub fn local(
 726        language_registry: Arc<LanguageRegistry>,
 727        telemetry: Option<Arc<Telemetry>>,
 728        cx: &mut ModelContext<Self>,
 729    ) -> Self {
 730        Self::new(
 731            ContextId::new(),
 732            ReplicaId::default(),
 733            language::Capability::ReadWrite,
 734            language_registry,
 735            telemetry,
 736            cx,
 737        )
 738    }
 739
 740    pub fn new(
 741        id: ContextId,
 742        replica_id: ReplicaId,
 743        capability: language::Capability,
 744        language_registry: Arc<LanguageRegistry>,
 745        telemetry: Option<Arc<Telemetry>>,
 746        cx: &mut ModelContext<Self>,
 747    ) -> Self {
 748        let buffer = cx.new_model(|_cx| {
 749            let mut buffer = Buffer::remote(
 750                language::BufferId::new(1).unwrap(),
 751                replica_id,
 752                capability,
 753                "",
 754            );
 755            buffer.set_language_registry(language_registry.clone());
 756            buffer
 757        });
 758        let edits_since_last_slash_command_parse =
 759            buffer.update(cx, |buffer, _| buffer.subscribe());
 760        let mut this = Self {
 761            id,
 762            timestamp: clock::Lamport::new(replica_id),
 763            version: clock::Global::new(),
 764            pending_ops: Vec::new(),
 765            operations: Vec::new(),
 766            message_anchors: Default::default(),
 767            messages_metadata: Default::default(),
 768            pending_slash_commands: Vec::new(),
 769            finished_slash_commands: HashSet::default(),
 770            slash_command_output_sections: Vec::new(),
 771            edits_since_last_slash_command_parse,
 772            summary: None,
 773            pending_summary: Task::ready(None),
 774            completion_count: Default::default(),
 775            pending_completions: Default::default(),
 776            token_count: None,
 777            pending_token_count: Task::ready(None),
 778            _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)],
 779            pending_save: Task::ready(Ok(())),
 780            path: None,
 781            buffer,
 782            telemetry,
 783            language_registry,
 784            edit_steps: Vec::new(),
 785        };
 786
 787        let first_message_id = MessageId(clock::Lamport {
 788            replica_id: 0,
 789            value: 0,
 790        });
 791        let message = MessageAnchor {
 792            id: first_message_id,
 793            start: language::Anchor::MIN,
 794        };
 795        this.messages_metadata.insert(
 796            first_message_id,
 797            MessageMetadata {
 798                role: Role::User,
 799                status: MessageStatus::Done,
 800                timestamp: first_message_id.0,
 801            },
 802        );
 803        this.message_anchors.push(message);
 804
 805        this.set_language(cx);
 806        this.count_remaining_tokens(cx);
 807        this
 808    }
 809
 810    fn serialize(&self, cx: &AppContext) -> SavedContext {
 811        let buffer = self.buffer.read(cx);
 812        SavedContext {
 813            id: Some(self.id.clone()),
 814            zed: "context".into(),
 815            version: SavedContext::VERSION.into(),
 816            text: buffer.text(),
 817            messages: self
 818                .messages(cx)
 819                .map(|message| SavedMessage {
 820                    id: message.id,
 821                    start: message.offset_range.start,
 822                    metadata: self.messages_metadata[&message.id].clone(),
 823                })
 824                .collect(),
 825            summary: self
 826                .summary
 827                .as_ref()
 828                .map(|summary| summary.text.clone())
 829                .unwrap_or_default(),
 830            slash_command_output_sections: self
 831                .slash_command_output_sections
 832                .iter()
 833                .filter_map(|section| {
 834                    let range = section.range.to_offset(buffer);
 835                    if section.range.start.is_valid(buffer) && !range.is_empty() {
 836                        Some(assistant_slash_command::SlashCommandOutputSection {
 837                            range,
 838                            icon: section.icon,
 839                            label: section.label.clone(),
 840                        })
 841                    } else {
 842                        None
 843                    }
 844                })
 845                .collect(),
 846        }
 847    }
 848
 849    #[allow(clippy::too_many_arguments)]
 850    pub fn deserialize(
 851        saved_context: SavedContext,
 852        path: PathBuf,
 853        language_registry: Arc<LanguageRegistry>,
 854        telemetry: Option<Arc<Telemetry>>,
 855        cx: &mut ModelContext<Self>,
 856    ) -> Self {
 857        let id = saved_context.id.clone().unwrap_or_else(|| ContextId::new());
 858        let mut this = Self::new(
 859            id,
 860            ReplicaId::default(),
 861            language::Capability::ReadWrite,
 862            language_registry,
 863            telemetry,
 864            cx,
 865        );
 866        this.path = Some(path);
 867        this.buffer.update(cx, |buffer, cx| {
 868            buffer.set_text(saved_context.text.as_str(), cx)
 869        });
 870        let operations = saved_context.into_ops(&this.buffer, cx);
 871        this.apply_ops(operations, cx).unwrap();
 872        this
 873    }
 874
 875    pub fn id(&self) -> &ContextId {
 876        &self.id
 877    }
 878
 879    pub fn replica_id(&self) -> ReplicaId {
 880        self.timestamp.replica_id
 881    }
 882
 883    pub fn version(&self, cx: &AppContext) -> ContextVersion {
 884        ContextVersion {
 885            context: self.version.clone(),
 886            buffer: self.buffer.read(cx).version(),
 887        }
 888    }
 889
 890    pub fn set_capability(
 891        &mut self,
 892        capability: language::Capability,
 893        cx: &mut ModelContext<Self>,
 894    ) {
 895        self.buffer
 896            .update(cx, |buffer, cx| buffer.set_capability(capability, cx));
 897    }
 898
 899    fn next_timestamp(&mut self) -> clock::Lamport {
 900        let timestamp = self.timestamp.tick();
 901        self.version.observe(timestamp);
 902        timestamp
 903    }
 904
 905    pub fn serialize_ops(
 906        &self,
 907        since: &ContextVersion,
 908        cx: &AppContext,
 909    ) -> Task<Vec<proto::ContextOperation>> {
 910        let buffer_ops = self
 911            .buffer
 912            .read(cx)
 913            .serialize_ops(Some(since.buffer.clone()), cx);
 914
 915        let mut context_ops = self
 916            .operations
 917            .iter()
 918            .filter(|op| !since.context.observed(op.timestamp()))
 919            .cloned()
 920            .collect::<Vec<_>>();
 921        context_ops.extend(self.pending_ops.iter().cloned());
 922
 923        cx.background_executor().spawn(async move {
 924            let buffer_ops = buffer_ops.await;
 925            context_ops.sort_unstable_by_key(|op| op.timestamp());
 926            buffer_ops
 927                .into_iter()
 928                .map(|op| proto::ContextOperation {
 929                    variant: Some(proto::context_operation::Variant::BufferOperation(
 930                        proto::context_operation::BufferOperation {
 931                            operation: Some(op),
 932                        },
 933                    )),
 934                })
 935                .chain(context_ops.into_iter().map(|op| op.to_proto()))
 936                .collect()
 937        })
 938    }
 939
 940    pub fn apply_ops(
 941        &mut self,
 942        ops: impl IntoIterator<Item = ContextOperation>,
 943        cx: &mut ModelContext<Self>,
 944    ) -> Result<()> {
 945        let mut buffer_ops = Vec::new();
 946        for op in ops {
 947            match op {
 948                ContextOperation::BufferOperation(buffer_op) => buffer_ops.push(buffer_op),
 949                op @ _ => self.pending_ops.push(op),
 950            }
 951        }
 952        self.buffer
 953            .update(cx, |buffer, cx| buffer.apply_ops(buffer_ops, cx))?;
 954        self.flush_ops(cx);
 955
 956        Ok(())
 957    }
 958
    /// Applies every queued operation in `pending_ops` that `can_apply_op` accepts,
    /// in timestamp order. Operations that are not ready yet go back into
    /// `pending_ops` for a later flush.
    ///
    /// Each applied operation is observed by `version` and `timestamp` and recorded in
    /// `operations`. At the end, `MessagesEdited` / `SummaryChanged` are emitted (with a
    /// notify) if any message metadata or the summary changed. `SlashCommandFinished`
    /// is emitted once per newly finished command.
    fn flush_ops(&mut self, cx: &mut ModelContext<Context>) {
        let mut messages_changed = false;
        let mut summary_changed = false;

        // Process in timestamp order.
        self.pending_ops.sort_unstable_by_key(|op| op.timestamp());
        for op in mem::take(&mut self.pending_ops) {
            if !self.can_apply_op(&op, cx) {
                // Dependencies (context version or buffer anchors) not satisfied yet.
                self.pending_ops.push(op);
                continue;
            }

            let timestamp = op.timestamp();
            match op.clone() {
                ContextOperation::InsertMessage {
                    anchor, metadata, ..
                } => {
                    if self.messages_metadata.contains_key(&anchor.id) {
                        // We already applied this operation.
                    } else {
                        self.insert_message(anchor, metadata, cx);
                        messages_changed = true;
                    }
                }
                ContextOperation::UpdateMessage {
                    message_id,
                    metadata: new_metadata,
                    ..
                } => {
                    // `can_apply_op` only accepts this op when the message exists, so the
                    // unwrap cannot fail.
                    let metadata = self.messages_metadata.get_mut(&message_id).unwrap();
                    // Keep the metadata with the newest timestamp.
                    if new_metadata.timestamp > metadata.timestamp {
                        *metadata = new_metadata;
                        messages_changed = true;
                    }
                }
                ContextOperation::UpdateSummary {
                    summary: new_summary,
                    ..
                } => {
                    // Keep the summary with the newest timestamp.
                    if self
                        .summary
                        .as_ref()
                        .map_or(true, |summary| new_summary.timestamp > summary.timestamp)
                    {
                        self.summary = Some(new_summary);
                        summary_changed = true;
                    }
                }
                ContextOperation::SlashCommandFinished {
                    id,
                    output_range,
                    sections,
                    ..
                } => {
                    // `insert` is false for ids already recorded (e.g. by the local path in
                    // `insert_command_output`), so sections are added only once.
                    if self.finished_slash_commands.insert(id) {
                        let buffer = self.buffer.read(cx);
                        self.slash_command_output_sections
                            .extend(sections.iter().cloned());
                        // Keep all sections sorted by buffer position.
                        self.slash_command_output_sections
                            .sort_by(|a, b| a.range.cmp(&b.range, buffer));
                        // Unlike the local path, which forwards `run_commands_in_text`,
                        // operations applied here never request running commands.
                        cx.emit(ContextEvent::SlashCommandFinished {
                            output_range,
                            sections,
                            run_commands_in_output: false,
                        });
                    }
                }
                // `apply_ops` routes buffer operations to the buffer; they are never queued.
                ContextOperation::BufferOperation(_) => unreachable!(),
            }

            // NOTE(review): the op is recorded even when the match above treated it as a
            // no-op (already-inserted message, stale update, already-finished command);
            // confirm duplicates cannot reach here or that re-recording them is harmless.
            self.version.observe(timestamp);
            self.timestamp.observe(timestamp);
            self.operations.push(op);
        }

        if messages_changed {
            cx.emit(ContextEvent::MessagesEdited);
            cx.notify();
        }

        if summary_changed {
            cx.emit(ContextEvent::SummaryChanged);
            cx.notify();
        }
    }
1043
1044    fn can_apply_op(&self, op: &ContextOperation, cx: &AppContext) -> bool {
1045        if !self.version.observed_all(op.version()) {
1046            return false;
1047        }
1048
1049        match op {
1050            ContextOperation::InsertMessage { anchor, .. } => self
1051                .buffer
1052                .read(cx)
1053                .version
1054                .observed(anchor.start.timestamp),
1055            ContextOperation::UpdateMessage { message_id, .. } => {
1056                self.messages_metadata.contains_key(message_id)
1057            }
1058            ContextOperation::UpdateSummary { .. } => true,
1059            ContextOperation::SlashCommandFinished {
1060                output_range,
1061                sections,
1062                ..
1063            } => {
1064                let version = &self.buffer.read(cx).version;
1065                sections
1066                    .iter()
1067                    .map(|section| &section.range)
1068                    .chain([output_range])
1069                    .all(|range| {
1070                        let observed_start = range.start == language::Anchor::MIN
1071                            || range.start == language::Anchor::MAX
1072                            || version.observed(range.start.timestamp);
1073                        let observed_end = range.end == language::Anchor::MIN
1074                            || range.end == language::Anchor::MAX
1075                            || version.observed(range.end.timestamp);
1076                        observed_start && observed_end
1077                    })
1078            }
1079            ContextOperation::BufferOperation(_) => {
1080                panic!("buffer operations should always be applied")
1081            }
1082        }
1083    }
1084
1085    fn push_op(&mut self, op: ContextOperation, cx: &mut ModelContext<Self>) {
1086        self.operations.push(op.clone());
1087        cx.emit(ContextEvent::Operation(op));
1088    }
1089
1090    pub fn buffer(&self) -> &Model<Buffer> {
1091        &self.buffer
1092    }
1093
1094    pub fn path(&self) -> Option<&Path> {
1095        self.path.as_deref()
1096    }
1097
1098    pub fn summary(&self) -> Option<&ContextSummary> {
1099        self.summary.as_ref()
1100    }
1101
1102    pub fn edit_steps(&self) -> &[EditStep] {
1103        &self.edit_steps
1104    }
1105
1106    pub fn pending_slash_commands(&self) -> &[PendingSlashCommand] {
1107        &self.pending_slash_commands
1108    }
1109
1110    pub fn slash_command_output_sections(&self) -> &[SlashCommandOutputSection<language::Anchor>] {
1111        &self.slash_command_output_sections
1112    }
1113
1114    fn set_language(&mut self, cx: &mut ModelContext<Self>) {
1115        let markdown = self.language_registry.language_for_name("Markdown");
1116        cx.spawn(|this, mut cx| async move {
1117            let markdown = markdown.await?;
1118            this.update(&mut cx, |this, cx| {
1119                this.buffer
1120                    .update(cx, |buffer, cx| buffer.set_language(Some(markdown), cx));
1121            })
1122        })
1123        .detach_and_log_err(cx);
1124    }
1125
1126    fn handle_buffer_event(
1127        &mut self,
1128        _: Model<Buffer>,
1129        event: &language::Event,
1130        cx: &mut ModelContext<Self>,
1131    ) {
1132        match event {
1133            language::Event::Operation(operation) => cx.emit(ContextEvent::Operation(
1134                ContextOperation::BufferOperation(operation.clone()),
1135            )),
1136            language::Event::Edited => {
1137                self.count_remaining_tokens(cx);
1138                self.reparse_slash_commands(cx);
1139                self.prune_invalid_edit_steps(cx);
1140                cx.emit(ContextEvent::MessagesEdited);
1141            }
1142            _ => {}
1143        }
1144    }
1145
1146    pub(crate) fn token_count(&self) -> Option<usize> {
1147        self.token_count
1148    }
1149
1150    pub(crate) fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
1151        let request = self.to_completion_request(cx);
1152        self.pending_token_count = cx.spawn(|this, mut cx| {
1153            async move {
1154                cx.background_executor()
1155                    .timer(Duration::from_millis(200))
1156                    .await;
1157
1158                let token_count = cx
1159                    .update(|cx| {
1160                        LanguageModelCompletionProvider::read_global(cx).count_tokens(request, cx)
1161                    })?
1162                    .await?;
1163                this.update(&mut cx, |this, cx| {
1164                    this.token_count = Some(token_count);
1165                    cx.notify()
1166                })
1167            }
1168            .log_err()
1169        });
1170    }
1171
    /// Re-parses slash commands on every row touched by buffer edits since the last
    /// parse, and splices the results into `pending_slash_commands`.
    ///
    /// For each group of touched rows, the existing pending commands overlapping those
    /// rows are replaced by freshly parsed ones. If anything was removed or added, one
    /// `PendingSlashCommandsUpdated` event is emitted. A command that reparses
    /// identically is reported both in `removed` and in `updated`.
    pub fn reparse_slash_commands(&mut self, cx: &mut ModelContext<Self>) {
        let buffer = self.buffer.read(cx);
        // Map each pending edit to the half-open range of rows it now covers.
        let mut row_ranges = self
            .edits_since_last_slash_command_parse
            .consume()
            .into_iter()
            .map(|edit| {
                let start_row = buffer.offset_to_point(edit.new.start).row;
                let end_row = buffer.offset_to_point(edit.new.end).row + 1;
                start_row..end_row
            })
            .peekable();

        let mut removed = Vec::new();
        let mut updated = Vec::new();
        while let Some(mut row_range) = row_ranges.next() {
            // Coalesce touching/overlapping row ranges (assumes edits arrive sorted by
            // position).
            while let Some(next_row_range) = row_ranges.peek() {
                if row_range.end >= next_row_range.start {
                    row_range.end = next_row_range.end;
                    row_ranges.next();
                } else {
                    break;
                }
            }

            // Expand to whole lines so commands are always parsed in full.
            let start = buffer.anchor_before(Point::new(row_range.start, 0));
            let end = buffer.anchor_after(Point::new(
                row_range.end - 1,
                buffer.line_len(row_range.end - 1),
            ));

            // Indices of the existing pending commands that this parse replaces.
            let old_range = self.pending_command_indices_for_range(start..end, cx);

            let mut new_commands = Vec::new();
            let mut lines = buffer.text_for_range(start..end).lines();
            let mut offset = lines.offset();
            while let Some(line) = lines.next() {
                if let Some(command_line) = SlashCommandLine::parse(line) {
                    let name = &line[command_line.name.clone()];
                    // An empty argument is treated as no argument.
                    let argument = command_line.argument.as_ref().and_then(|argument| {
                        (!argument.is_empty()).then_some(&line[argument.clone()])
                    });
                    // Only registered commands become pending, and only when their
                    // argument requirement is met.
                    if let Some(command) = SlashCommandRegistry::global(cx).command(name) {
                        if !command.requires_argument() || argument.is_some() {
                            // `- 1` presumably covers the leading `/`, which `name` excludes.
                            let start_ix = offset + command_line.name.start - 1;
                            let end_ix = offset
                                + command_line
                                    .argument
                                    .map_or(command_line.name.end, |argument| argument.end);
                            let source_range =
                                buffer.anchor_after(start_ix)..buffer.anchor_after(end_ix);
                            let pending_command = PendingSlashCommand {
                                name: name.to_string(),
                                argument: argument.map(ToString::to_string),
                                source_range,
                                status: PendingSlashCommandStatus::Idle,
                            };
                            updated.push(pending_command.clone());
                            new_commands.push(pending_command);
                        }
                    }
                }

                // Offset of the next line, used above as an absolute buffer offset.
                offset = lines.offset();
            }

            let removed_commands = self.pending_slash_commands.splice(old_range, new_commands);
            removed.extend(removed_commands.map(|command| command.source_range));
        }

        if !updated.is_empty() || !removed.is_empty() {
            cx.emit(ContextEvent::PendingSlashCommandsUpdated { removed, updated });
        }
    }
1246
1247    fn prune_invalid_edit_steps(&mut self, cx: &mut ModelContext<Self>) {
1248        let buffer = self.buffer.read(cx);
1249        let prev_len = self.edit_steps.len();
1250        self.edit_steps.retain(|step| {
1251            step.source_range.start.is_valid(buffer) && step.source_range.end.is_valid(buffer)
1252        });
1253        if self.edit_steps.len() != prev_len {
1254            cx.emit(ContextEvent::EditStepsChanged);
1255            cx.notify();
1256        }
1257    }
1258
    /// Scans the buffer text in `range` for `<step>` … `</step>` blocks. Each block not
    /// already tracked in `edit_steps` is registered, and operation generation is
    /// started for it.
    ///
    /// Only the first `<step>` on a line is considered, and a `<step>` seen while already
    /// inside a step is ignored (steps do not nest). `EditStepsChanged` is emitted and
    /// observers are notified even when no new step was found.
    fn parse_edit_steps_in_range(&mut self, range: Range<usize>, cx: &mut ModelContext<Self>) {
        // (insertion index, step) pairs, with indices computed against `self.edit_steps`
        // as it is before any insertion.
        let mut new_edit_steps = Vec::new();

        self.buffer.update(cx, |buffer, _cx| {
            let mut message_lines = buffer.as_rope().chunks_in_range(range).lines();
            let mut in_step = false;
            let mut step_start = 0;
            let mut line_start_offset = message_lines.offset();

            while let Some(line) = message_lines.next() {
                if let Some(step_start_index) = line.find("<step>") {
                    if !in_step {
                        in_step = true;
                        step_start = line_start_offset + step_start_index;
                    }
                }

                if let Some(step_end_index) = line.find("</step>") {
                    if in_step {
                        // The step's range spans from `<step>` through the end of `</step>`.
                        let start_anchor = buffer.anchor_after(step_start);
                        let end_anchor = buffer
                            .anchor_before(line_start_offset + step_end_index + "</step>".len());
                        let source_range = start_anchor..end_anchor;

                        // Check if a step with the same range already exists
                        // (`edit_steps` is kept sorted by source range).
                        let existing_step_index = self.edit_steps.binary_search_by(|probe| {
                            probe.source_range.cmp(&source_range, buffer)
                        });

                        if let Err(ix) = existing_step_index {
                            // Step doesn't exist, so add it
                            new_edit_steps.push((
                                ix,
                                EditStep {
                                    source_range,
                                    operations: None,
                                },
                            ));
                        }

                        in_step = false;
                    }
                }

                line_start_offset = message_lines.offset();
            }
        });

        // Insert new steps and generate their corresponding tasks.
        // Inserting from the back keeps the earlier, smaller indices valid.
        for (index, mut step) in new_edit_steps.into_iter().rev() {
            let task = self.generate_edit_step_operations(&step, cx);
            step.operations = Some(EditStepOperations::Pending(task));
            self.edit_steps.insert(index, step);
        }

        cx.emit(ContextEvent::EditStepsChanged);
        cx.notify();
    }
1317
    /// Asks the language model, through the `edit` tool, to turn the text of
    /// `edit_step` into concrete edit operations.
    ///
    /// The request is the current conversation (`to_completion_request`) plus one extra
    /// user message: the operations prompt from the `PromptStore` followed by the step's
    /// source text. When the tool call returns, the step with the same source range is
    /// looked up again and its operations become `EditStepOperations::Ready`. Errors are
    /// logged and resolve the task to `None`. This includes the case where the step was
    /// removed in the meantime ("edit step not found").
    fn generate_edit_step_operations(
        &self,
        edit_step: &EditStep,
        cx: &mut ModelContext<Self>,
    ) -> Task<Option<()>> {
        // NOTE: the `///` comments on `EditTool` feed the JSON schema sent to the model
        // (via `JsonSchema`), so they are runtime-visible text.
        #[derive(Debug, Deserialize, JsonSchema)]
        struct EditTool {
            /// A sequence of operations to apply to the codebase.
            /// When multiple operations are required for a step, be sure to include multiple operations in this list.
            operations: Vec<EditOperation>,
        }

        impl LanguageModelTool for EditTool {
            fn name() -> String {
                "edit".into()
            }

            fn description() -> String {
                "suggest edits to one or more locations in the codebase".into()
            }
        }

        let mut request = self.to_completion_request(cx);
        let edit_step_range = edit_step.source_range.clone();
        let step_text = self
            .buffer
            .read(cx)
            .text_for_range(edit_step_range.clone())
            .collect::<String>();

        cx.spawn(|this, mut cx| {
            async move {
                let prompt_store = cx.update(|cx| PromptStore::global(cx))?.await?;

                let mut prompt = prompt_store.operations_prompt();
                prompt.push_str(&step_text);

                request.messages.push(LanguageModelRequestMessage {
                    role: Role::User,
                    content: prompt,
                });

                let tool_use = cx
                    .update(|cx| {
                        LanguageModelCompletionProvider::read_global(cx)
                            .use_tool::<EditTool>(request, cx)
                    })?
                    .await?;

                this.update(&mut cx, |this, cx| {
                    // The step may have moved, so find it again by its source range.
                    let step_index = this
                        .edit_steps
                        .binary_search_by(|step| {
                            step.source_range
                                .cmp(&edit_step_range, this.buffer.read(cx))
                        })
                        .map_err(|_| anyhow!("edit step not found"))?;
                    if let Some(edit_step) = this.edit_steps.get_mut(step_index) {
                        edit_step.operations = Some(EditStepOperations::Ready(tool_use.operations));
                        // NOTE(review): unlike `prune_invalid_edit_steps` and
                        // `parse_edit_steps_in_range`, no `cx.notify()` follows this emit;
                        // confirm observers don't depend on a notify here.
                        cx.emit(ContextEvent::EditStepsChanged);
                    }
                    anyhow::Ok(())
                })?
            }
            .log_err()
        })
    }
1385
1386    pub fn pending_command_for_position(
1387        &mut self,
1388        position: language::Anchor,
1389        cx: &mut ModelContext<Self>,
1390    ) -> Option<&mut PendingSlashCommand> {
1391        let buffer = self.buffer.read(cx);
1392        match self
1393            .pending_slash_commands
1394            .binary_search_by(|probe| probe.source_range.end.cmp(&position, buffer))
1395        {
1396            Ok(ix) => Some(&mut self.pending_slash_commands[ix]),
1397            Err(ix) => {
1398                let cmd = self.pending_slash_commands.get_mut(ix)?;
1399                if position.cmp(&cmd.source_range.start, buffer).is_ge()
1400                    && position.cmp(&cmd.source_range.end, buffer).is_le()
1401                {
1402                    Some(cmd)
1403                } else {
1404                    None
1405                }
1406            }
1407        }
1408    }
1409
1410    pub fn pending_commands_for_range(
1411        &self,
1412        range: Range<language::Anchor>,
1413        cx: &AppContext,
1414    ) -> &[PendingSlashCommand] {
1415        let range = self.pending_command_indices_for_range(range, cx);
1416        &self.pending_slash_commands[range]
1417    }
1418
1419    fn pending_command_indices_for_range(
1420        &self,
1421        range: Range<language::Anchor>,
1422        cx: &AppContext,
1423    ) -> Range<usize> {
1424        let buffer = self.buffer.read(cx);
1425        let start_ix = match self
1426            .pending_slash_commands
1427            .binary_search_by(|probe| probe.source_range.end.cmp(&range.start, &buffer))
1428        {
1429            Ok(ix) | Err(ix) => ix,
1430        };
1431        let end_ix = match self
1432            .pending_slash_commands
1433            .binary_search_by(|probe| probe.source_range.start.cmp(&range.end, &buffer))
1434        {
1435            Ok(ix) => ix + 1,
1436            Err(ix) => ix,
1437        };
1438        start_ix..end_ix
1439    }
1440
    /// Waits for a slash command's `output` and, on success, replaces `command_range`
    /// in the buffer with the output text.
    ///
    /// On success:
    /// - a trailing newline is appended if `insert_trailing_newline` is set;
    /// - a new `SlashCommandId` is allocated;
    /// - the output's sections (whose offsets are relative to the start of the output)
    ///   are anchored and merged into `slash_command_output_sections`;
    /// - a `SlashCommandFinished` operation is recorded and broadcast, and the matching
    ///   event is emitted.
    ///
    /// On failure, the pending command at `command_range.start` is marked as `Error`.
    /// While the command runs, the pending command is marked as `Running` and holds the
    /// shared task.
    pub fn insert_command_output(
        &mut self,
        command_range: Range<language::Anchor>,
        output: Task<Result<SlashCommandOutput>>,
        insert_trailing_newline: bool,
        cx: &mut ModelContext<Self>,
    ) {
        // Make sure `pending_slash_commands` reflects the current buffer before looking
        // up the command at `command_range.start` below.
        self.reparse_slash_commands(cx);

        let insert_output_task = cx.spawn(|this, mut cx| {
            let command_range = command_range.clone();
            async move {
                let output = output.await;
                this.update(&mut cx, |this, cx| match output {
                    Ok(mut output) => {
                        if insert_trailing_newline {
                            output.text.push('\n');
                        }

                        // Capture the version before `next_timestamp` advances it, so the
                        // operation carries the version it was generated against.
                        let version = this.version.clone();
                        let command_id = SlashCommandId(this.next_timestamp());
                        let (operation, event) = this.buffer.update(cx, |buffer, cx| {
                            let start = command_range.start.to_offset(buffer);
                            let old_end = command_range.end.to_offset(buffer);
                            let new_end = start + output.text.len();
                            buffer.edit([(start..old_end, output.text)], None, cx);

                            // Section offsets are relative to the start of the output text.
                            let mut sections = output
                                .sections
                                .into_iter()
                                .map(|section| SlashCommandOutputSection {
                                    range: buffer.anchor_after(start + section.range.start)
                                        ..buffer.anchor_before(start + section.range.end),
                                    icon: section.icon,
                                    label: section.label,
                                })
                                .collect::<Vec<_>>();
                            sections.sort_by(|a, b| a.range.cmp(&b.range, buffer));

                            this.slash_command_output_sections
                                .extend(sections.iter().cloned());
                            this.slash_command_output_sections
                                .sort_by(|a, b| a.range.cmp(&b.range, buffer));

                            let output_range =
                                buffer.anchor_after(start)..buffer.anchor_before(new_end);
                            // Recorded locally so `flush_ops` skips this id if it comes back.
                            this.finished_slash_commands.insert(command_id);

                            (
                                ContextOperation::SlashCommandFinished {
                                    id: command_id,
                                    output_range: output_range.clone(),
                                    sections: sections.clone(),
                                    version,
                                },
                                ContextEvent::SlashCommandFinished {
                                    output_range,
                                    sections,
                                    run_commands_in_output: output.run_commands_in_text,
                                },
                            )
                        });

                        this.push_op(operation, cx);
                        cx.emit(event);
                    }
                    Err(error) => {
                        if let Some(pending_command) =
                            this.pending_command_for_position(command_range.start, cx)
                        {
                            pending_command.status =
                                PendingSlashCommandStatus::Error(error.to_string());
                            cx.emit(ContextEvent::PendingSlashCommandsUpdated {
                                removed: vec![pending_command.source_range.clone()],
                                updated: vec![pending_command.clone()],
                            });
                        }
                    }
                })
                // The context may have been dropped while the command ran.
                .ok();
            }
        });

        // NOTE(review): if no pending command is found here, `insert_output_task` is
        // dropped when this function returns; if gpui tasks cancel on drop, the output
        // is then never inserted — confirm that is intended.
        if let Some(pending_command) = self.pending_command_for_position(command_range.start, cx) {
            pending_command.status = PendingSlashCommandStatus::Running {
                _task: insert_output_task.shared(),
            };
            cx.emit(ContextEvent::PendingSlashCommandsUpdated {
                removed: vec![pending_command.source_range.clone()],
                updated: vec![pending_command.clone()],
            });
        }
    }
1534
1535    pub fn completion_provider_changed(&mut self, cx: &mut ModelContext<Self>) {
1536        self.count_remaining_tokens(cx);
1537    }
1538
1539    pub fn assist(&mut self, cx: &mut ModelContext<Self>) -> Option<MessageAnchor> {
1540        let last_message_id = self.message_anchors.iter().rev().find_map(|message| {
1541            message
1542                .start
1543                .is_valid(self.buffer.read(cx))
1544                .then_some(message.id)
1545        })?;
1546
1547        if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
1548            log::info!("completion provider has no credentials");
1549            return None;
1550        }
1551
1552        let request = self.to_completion_request(cx);
1553        let stream =
1554            LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
1555        let assistant_message = self
1556            .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
1557            .unwrap();
1558
1559        // Queue up the user's next reply.
1560        let user_message = self
1561            .insert_message_after(assistant_message.id, Role::User, MessageStatus::Done, cx)
1562            .unwrap();
1563
1564        let task = cx.spawn({
1565            |this, mut cx| async move {
1566                let assistant_message_id = assistant_message.id;
1567                let mut response_latency = None;
1568                let stream_completion = async {
1569                    let request_start = Instant::now();
1570                    let mut chunks = stream.await?;
1571
1572                    while let Some(chunk) = chunks.next().await {
1573                        if response_latency.is_none() {
1574                            response_latency = Some(request_start.elapsed());
1575                        }
1576                        let chunk = chunk?;
1577
1578                        this.update(&mut cx, |this, cx| {
1579                            let message_ix = this
1580                                .message_anchors
1581                                .iter()
1582                                .position(|message| message.id == assistant_message_id)?;
1583                            let message_range = this.buffer.update(cx, |buffer, cx| {
1584                                let message_start_offset =
1585                                    this.message_anchors[message_ix].start.to_offset(buffer);
1586                                let message_old_end_offset = this.message_anchors[message_ix + 1..]
1587                                    .iter()
1588                                    .find(|message| message.start.is_valid(buffer))
1589                                    .map_or(buffer.len(), |message| {
1590                                        message.start.to_offset(buffer).saturating_sub(1)
1591                                    });
1592                                let message_new_end_offset = message_old_end_offset + chunk.len();
1593                                buffer.edit(
1594                                    [(message_old_end_offset..message_old_end_offset, chunk)],
1595                                    None,
1596                                    cx,
1597                                );
1598                                message_start_offset..message_new_end_offset
1599                            });
1600                            this.parse_edit_steps_in_range(message_range, cx);
1601                            cx.emit(ContextEvent::StreamedCompletion);
1602
1603                            Some(())
1604                        })?;
1605                        smol::future::yield_now().await;
1606                    }
1607
1608                    this.update(&mut cx, |this, cx| {
1609                        this.pending_completions
1610                            .retain(|completion| completion.id != this.completion_count);
1611                        this.summarize(false, cx);
1612                    })?;
1613
1614                    anyhow::Ok(())
1615                };
1616
1617                let result = stream_completion.await;
1618
1619                this.update(&mut cx, |this, cx| {
1620                    let error_message = result
1621                        .err()
1622                        .map(|error| error.to_string().trim().to_string());
1623
1624                    this.update_metadata(assistant_message_id, cx, |metadata| {
1625                        if let Some(error_message) = error_message.as_ref() {
1626                            metadata.status =
1627                                MessageStatus::Error(SharedString::from(error_message.clone()));
1628                        } else {
1629                            metadata.status = MessageStatus::Done;
1630                        }
1631                    });
1632
1633                    if let Some(telemetry) = this.telemetry.as_ref() {
1634                        let model_telemetry_id = LanguageModelCompletionProvider::read_global(cx)
1635                            .active_model()
1636                            .map(|m| m.telemetry_id())
1637                            .unwrap_or_default();
1638                        telemetry.report_assistant_event(
1639                            Some(this.id.0.clone()),
1640                            AssistantKind::Panel,
1641                            model_telemetry_id,
1642                            response_latency,
1643                            error_message,
1644                        );
1645                    }
1646                })
1647                .ok();
1648            }
1649        });
1650
1651        self.pending_completions.push(PendingCompletion {
1652            id: post_inc(&mut self.completion_count),
1653            _task: task,
1654        });
1655
1656        Some(user_message)
1657    }
1658
1659    pub fn to_completion_request(&self, cx: &AppContext) -> LanguageModelRequest {
1660        let messages = self
1661            .messages(cx)
1662            .filter(|message| matches!(message.status, MessageStatus::Done))
1663            .map(|message| message.to_request_message(self.buffer.read(cx)));
1664
1665        LanguageModelRequest {
1666            messages: messages.collect(),
1667            stop: vec![],
1668            temperature: 1.0,
1669        }
1670    }
1671
1672    pub fn cancel_last_assist(&mut self) -> bool {
1673        self.pending_completions.pop().is_some()
1674    }
1675
1676    pub fn cycle_message_roles(&mut self, ids: HashSet<MessageId>, cx: &mut ModelContext<Self>) {
1677        for id in ids {
1678            if let Some(metadata) = self.messages_metadata.get(&id) {
1679                let role = metadata.role.cycle();
1680                self.update_metadata(id, cx, |metadata| metadata.role = role);
1681            }
1682        }
1683    }
1684
1685    pub fn update_metadata(
1686        &mut self,
1687        id: MessageId,
1688        cx: &mut ModelContext<Self>,
1689        f: impl FnOnce(&mut MessageMetadata),
1690    ) {
1691        let version = self.version.clone();
1692        let timestamp = self.next_timestamp();
1693        if let Some(metadata) = self.messages_metadata.get_mut(&id) {
1694            f(metadata);
1695            metadata.timestamp = timestamp;
1696            let operation = ContextOperation::UpdateMessage {
1697                message_id: id,
1698                metadata: metadata.clone(),
1699                version,
1700            };
1701            self.push_op(operation, cx);
1702            cx.emit(ContextEvent::MessagesEdited);
1703            cx.notify();
1704        }
1705    }
1706
1707    fn insert_message_after(
1708        &mut self,
1709        message_id: MessageId,
1710        role: Role,
1711        status: MessageStatus,
1712        cx: &mut ModelContext<Self>,
1713    ) -> Option<MessageAnchor> {
1714        if let Some(prev_message_ix) = self
1715            .message_anchors
1716            .iter()
1717            .position(|message| message.id == message_id)
1718        {
1719            // Find the next valid message after the one we were given.
1720            let mut next_message_ix = prev_message_ix + 1;
1721            while let Some(next_message) = self.message_anchors.get(next_message_ix) {
1722                if next_message.start.is_valid(self.buffer.read(cx)) {
1723                    break;
1724                }
1725                next_message_ix += 1;
1726            }
1727
1728            let start = self.buffer.update(cx, |buffer, cx| {
1729                let offset = self
1730                    .message_anchors
1731                    .get(next_message_ix)
1732                    .map_or(buffer.len(), |message| {
1733                        buffer.clip_offset(message.start.to_offset(buffer) - 1, Bias::Left)
1734                    });
1735                buffer.edit([(offset..offset, "\n")], None, cx);
1736                buffer.anchor_before(offset + 1)
1737            });
1738
1739            let version = self.version.clone();
1740            let anchor = MessageAnchor {
1741                id: MessageId(self.next_timestamp()),
1742                start,
1743            };
1744            let metadata = MessageMetadata {
1745                role,
1746                status,
1747                timestamp: anchor.id.0,
1748            };
1749            self.insert_message(anchor.clone(), metadata.clone(), cx);
1750            self.push_op(
1751                ContextOperation::InsertMessage {
1752                    anchor: anchor.clone(),
1753                    metadata,
1754                    version,
1755                },
1756                cx,
1757            );
1758            Some(anchor)
1759        } else {
1760            None
1761        }
1762    }
1763
1764    pub fn split_message(
1765        &mut self,
1766        range: Range<usize>,
1767        cx: &mut ModelContext<Self>,
1768    ) -> (Option<MessageAnchor>, Option<MessageAnchor>) {
1769        let start_message = self.message_for_offset(range.start, cx);
1770        let end_message = self.message_for_offset(range.end, cx);
1771        if let Some((start_message, end_message)) = start_message.zip(end_message) {
1772            // Prevent splitting when range spans multiple messages.
1773            if start_message.id != end_message.id {
1774                return (None, None);
1775            }
1776
1777            let message = start_message;
1778            let role = message.role;
1779            let mut edited_buffer = false;
1780
1781            let mut suffix_start = None;
1782            if range.start > message.offset_range.start && range.end < message.offset_range.end - 1
1783            {
1784                if self.buffer.read(cx).chars_at(range.end).next() == Some('\n') {
1785                    suffix_start = Some(range.end + 1);
1786                } else if self.buffer.read(cx).reversed_chars_at(range.end).next() == Some('\n') {
1787                    suffix_start = Some(range.end);
1788                }
1789            }
1790
1791            let version = self.version.clone();
1792            let suffix = if let Some(suffix_start) = suffix_start {
1793                MessageAnchor {
1794                    id: MessageId(self.next_timestamp()),
1795                    start: self.buffer.read(cx).anchor_before(suffix_start),
1796                }
1797            } else {
1798                self.buffer.update(cx, |buffer, cx| {
1799                    buffer.edit([(range.end..range.end, "\n")], None, cx);
1800                });
1801                edited_buffer = true;
1802                MessageAnchor {
1803                    id: MessageId(self.next_timestamp()),
1804                    start: self.buffer.read(cx).anchor_before(range.end + 1),
1805                }
1806            };
1807
1808            let suffix_metadata = MessageMetadata {
1809                role,
1810                status: MessageStatus::Done,
1811                timestamp: suffix.id.0,
1812            };
1813            self.insert_message(suffix.clone(), suffix_metadata.clone(), cx);
1814            self.push_op(
1815                ContextOperation::InsertMessage {
1816                    anchor: suffix.clone(),
1817                    metadata: suffix_metadata,
1818                    version,
1819                },
1820                cx,
1821            );
1822
1823            let new_messages =
1824                if range.start == range.end || range.start == message.offset_range.start {
1825                    (None, Some(suffix))
1826                } else {
1827                    let mut prefix_end = None;
1828                    if range.start > message.offset_range.start
1829                        && range.end < message.offset_range.end - 1
1830                    {
1831                        if self.buffer.read(cx).chars_at(range.start).next() == Some('\n') {
1832                            prefix_end = Some(range.start + 1);
1833                        } else if self.buffer.read(cx).reversed_chars_at(range.start).next()
1834                            == Some('\n')
1835                        {
1836                            prefix_end = Some(range.start);
1837                        }
1838                    }
1839
1840                    let version = self.version.clone();
1841                    let selection = if let Some(prefix_end) = prefix_end {
1842                        MessageAnchor {
1843                            id: MessageId(self.next_timestamp()),
1844                            start: self.buffer.read(cx).anchor_before(prefix_end),
1845                        }
1846                    } else {
1847                        self.buffer.update(cx, |buffer, cx| {
1848                            buffer.edit([(range.start..range.start, "\n")], None, cx)
1849                        });
1850                        edited_buffer = true;
1851                        MessageAnchor {
1852                            id: MessageId(self.next_timestamp()),
1853                            start: self.buffer.read(cx).anchor_before(range.end + 1),
1854                        }
1855                    };
1856
1857                    let selection_metadata = MessageMetadata {
1858                        role,
1859                        status: MessageStatus::Done,
1860                        timestamp: selection.id.0,
1861                    };
1862                    self.insert_message(selection.clone(), selection_metadata.clone(), cx);
1863                    self.push_op(
1864                        ContextOperation::InsertMessage {
1865                            anchor: selection.clone(),
1866                            metadata: selection_metadata,
1867                            version,
1868                        },
1869                        cx,
1870                    );
1871
1872                    (Some(selection), Some(suffix))
1873                };
1874
1875            if !edited_buffer {
1876                cx.emit(ContextEvent::MessagesEdited);
1877            }
1878            new_messages
1879        } else {
1880            (None, None)
1881        }
1882    }
1883
1884    fn insert_message(
1885        &mut self,
1886        new_anchor: MessageAnchor,
1887        new_metadata: MessageMetadata,
1888        cx: &mut ModelContext<Self>,
1889    ) {
1890        cx.emit(ContextEvent::MessagesEdited);
1891
1892        self.messages_metadata.insert(new_anchor.id, new_metadata);
1893
1894        let buffer = self.buffer.read(cx);
1895        let insertion_ix = self
1896            .message_anchors
1897            .iter()
1898            .position(|anchor| {
1899                let comparison = new_anchor.start.cmp(&anchor.start, buffer);
1900                comparison.is_lt() || (comparison.is_eq() && new_anchor.id > anchor.id)
1901            })
1902            .unwrap_or(self.message_anchors.len());
1903        self.message_anchors.insert(insertion_ix, new_anchor);
1904    }
1905
    /// Requests a short title for this context from the completion provider and streams it into
    /// `self.summary`.
    ///
    /// It runs when `replace_old` is true, or when there are at least two message anchors and no
    /// summary yet. It does nothing if the completion provider is not authenticated.
    ///
    /// The request is every message followed by a user instruction to summarize. For each
    /// streamed chunk, the first line is appended to the summary text, an `UpdateSummary`
    /// operation is pushed, and `SummaryChanged` is emitted. Streaming stops early once a chunk
    /// contains a second line.
    ///
    /// Finally the summary is marked `done` with one more `UpdateSummary` operation. The work is
    /// stored in `self.pending_summary`, replacing whatever task was there (NOTE(review):
    /// presumably dropping the old gpui `Task` cancels it — confirm). Errors are logged.
    pub(super) fn summarize(&mut self, replace_old: bool, cx: &mut ModelContext<Self>) {
        if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) {
            if !LanguageModelCompletionProvider::read_global(cx).is_authenticated(cx) {
                return;
            }

            let messages = self
                .messages(cx)
                .map(|message| message.to_request_message(self.buffer.read(cx)))
                .chain(Some(LanguageModelRequestMessage {
                    role: Role::User,
                    content: "Summarize the context into a short title without punctuation.".into(),
                }));
            let request = LanguageModelRequest {
                messages: messages.collect(),
                stop: vec![],
                temperature: 1.0,
            };

            let stream =
                LanguageModelCompletionProvider::read_global(cx).stream_completion(request, cx);
            self.pending_summary = cx.spawn(|this, mut cx| {
                async move {
                    let mut messages = stream.await?;

                    // When replacing, the old text is only cleared once the first chunk
                    // arrives, so the previous title stays in place until new text exists.
                    let mut replaced = !replace_old;
                    while let Some(message) = messages.next().await {
                        let text = message?;
                        let mut lines = text.lines();
                        this.update(&mut cx, |this, cx| {
                            let version = this.version.clone();
                            let timestamp = this.next_timestamp();
                            let summary = this.summary.get_or_insert(ContextSummary::default());
                            if !replaced && replace_old {
                                summary.text.clear();
                                replaced = true;
                            }
                            // Only the first line of this chunk is appended.
                            summary.text.extend(lines.next());
                            summary.timestamp = timestamp;
                            let operation = ContextOperation::UpdateSummary {
                                summary: summary.clone(),
                                version,
                            };
                            this.push_op(operation, cx);
                            cx.emit(ContextEvent::SummaryChanged);
                        })?;

                        // Stop if the LLM generated multiple lines.
                        if lines.next().is_some() {
                            break;
                        }
                    }

                    // Mark the summary as finished, if any text was produced.
                    this.update(&mut cx, |this, cx| {
                        let version = this.version.clone();
                        let timestamp = this.next_timestamp();
                        if let Some(summary) = this.summary.as_mut() {
                            summary.done = true;
                            summary.timestamp = timestamp;
                            let operation = ContextOperation::UpdateSummary {
                                summary: summary.clone(),
                                version,
                            };
                            this.push_op(operation, cx);
                            cx.emit(ContextEvent::SummaryChanged);
                        }
                    })?;

                    anyhow::Ok(())
                }
                .log_err()
            });
        }
    }
1980
1981    fn message_for_offset(&self, offset: usize, cx: &AppContext) -> Option<Message> {
1982        self.messages_for_offsets([offset], cx).pop()
1983    }
1984
    /// Returns the message containing each offset in `offsets`. Consecutive offsets that land in
    /// the same message produce a single entry.
    ///
    /// The scan only moves forward through the messages, so `offsets` is expected to be in
    /// ascending order. Unsorted input is not rejected: an offset that precedes the current
    /// message is never found and resolves to the last message.
    ///
    /// An offset beyond every message's range resolves to the last message. Once the last
    /// message is reached, all remaining offsets are folded into it. Returns an empty vector
    /// when there are no messages.
    pub fn messages_for_offsets(
        &self,
        offsets: impl IntoIterator<Item = usize>,
        cx: &AppContext,
    ) -> Vec<Message> {
        let mut result = Vec::new();

        let mut messages = self.messages(cx).peekable();
        let mut offsets = offsets.into_iter().peekable();
        let mut current_message = messages.next();
        while let Some(offset) = offsets.next() {
            // Locate the message that contains the offset.
            // Never advances past the last message, which acts as a catch-all.
            while current_message.as_ref().map_or(false, |message| {
                !message.offset_range.contains(&offset) && messages.peek().is_some()
            }) {
                current_message = messages.next();
            }
            let Some(message) = current_message.as_ref() else {
                break;
            };

            // Skip offsets that are in the same message.
            // When this is the last message, every remaining offset is skipped too.
            while offsets.peek().map_or(false, |offset| {
                message.offset_range.contains(offset) || messages.peek().is_none()
            }) {
                offsets.next();
            }

            result.push(message.clone());
        }
        result
    }
2017
    /// Iterates over the context's messages in `message_anchors` order, resolving each one to
    /// buffer offsets.
    ///
    /// A message runs from its anchor to the start of the next anchor that is still valid in the
    /// buffer, or to the end of the buffer for the last message. Anchors that are no longer
    /// valid are consumed without yielding a message, so their text belongs to the preceding
    /// message.
    ///
    /// `index_range` starts at the message's index in `message_anchors` and ends at the index of
    /// the last skipped anchor; it is empty when nothing was skipped. NOTE(review): as a
    /// half-open range it excludes that last skipped anchor — confirm callers expect this.
    ///
    /// Iteration ends early if an anchor has no entry in `messages_metadata`.
    pub fn messages<'a>(&'a self, cx: &'a AppContext) -> impl 'a + Iterator<Item = Message> {
        let buffer = self.buffer.read(cx);
        let mut message_anchors = self.message_anchors.iter().enumerate().peekable();
        iter::from_fn(move || {
            if let Some((start_ix, message_anchor)) = message_anchors.next() {
                let metadata = self.messages_metadata.get(&message_anchor.id)?;
                let message_start = message_anchor.start.to_offset(buffer);
                let mut message_end = None;
                let mut end_ix = start_ix;
                // The message ends where the next valid anchor starts. Invalid anchors in
                // between are consumed here, so they never start a message of their own.
                while let Some((_, next_message)) = message_anchors.peek() {
                    if next_message.start.is_valid(buffer) {
                        message_end = Some(next_message.start);
                        break;
                    } else {
                        end_ix += 1;
                        message_anchors.next();
                    }
                }
                // No later valid anchor: the message extends to the end of the buffer.
                let message_end = message_end
                    .unwrap_or(language::Anchor::MAX)
                    .to_offset(buffer);

                return Some(Message {
                    index_range: start_ix..end_ix,
                    offset_range: message_start..message_end,
                    id: message_anchor.id,
                    anchor: message_anchor.start,
                    role: metadata.role,
                    status: metadata.status.clone(),
                });
            }
            None
        })
    }
2052
2053    pub fn save(
2054        &mut self,
2055        debounce: Option<Duration>,
2056        fs: Arc<dyn Fs>,
2057        cx: &mut ModelContext<Context>,
2058    ) {
2059        if self.replica_id() != ReplicaId::default() {
2060            // Prevent saving a remote context for now.
2061            return;
2062        }
2063
2064        self.pending_save = cx.spawn(|this, mut cx| async move {
2065            if let Some(debounce) = debounce {
2066                cx.background_executor().timer(debounce).await;
2067            }
2068
2069            let (old_path, summary) = this.read_with(&cx, |this, _| {
2070                let path = this.path.clone();
2071                let summary = if let Some(summary) = this.summary.as_ref() {
2072                    if summary.done {
2073                        Some(summary.text.clone())
2074                    } else {
2075                        None
2076                    }
2077                } else {
2078                    None
2079                };
2080                (path, summary)
2081            })?;
2082
2083            if let Some(summary) = summary {
2084                let context = this.read_with(&cx, |this, cx| this.serialize(cx))?;
2085                let mut discriminant = 1;
2086                let mut new_path;
2087                loop {
2088                    new_path = contexts_dir().join(&format!(
2089                        "{} - {}.zed.json",
2090                        summary.trim(),
2091                        discriminant
2092                    ));
2093                    if fs.is_file(&new_path).await {
2094                        discriminant += 1;
2095                    } else {
2096                        break;
2097                    }
2098                }
2099
2100                fs.create_dir(contexts_dir().as_ref()).await?;
2101                fs.atomic_write(new_path.clone(), serde_json::to_string(&context).unwrap())
2102                    .await?;
2103                if let Some(old_path) = old_path {
2104                    if new_path != old_path {
2105                        fs.remove_file(
2106                            &old_path,
2107                            RemoveOptions {
2108                                recursive: false,
2109                                ignore_if_not_exists: true,
2110                            },
2111                        )
2112                        .await?;
2113                    }
2114                }
2115
2116                this.update(&mut cx, |this, _| this.path = Some(new_path))?;
2117            }
2118
2119            Ok(())
2120        });
2121    }
2122
2123    pub(crate) fn custom_summary(&mut self, custom_summary: String, cx: &mut ModelContext<Self>) {
2124        let timestamp = self.next_timestamp();
2125        let summary = self.summary.get_or_insert(ContextSummary::default());
2126        summary.timestamp = timestamp;
2127        summary.done = true;
2128        summary.text = custom_summary;
2129        cx.emit(ContextEvent::SummaryChanged);
2130    }
2131}
2132
/// A pair of version vectors (`clock::Global`): one for the context's own operations and one
/// for its text buffer. Exchanged over the wire as `proto::ContextVersion`.
#[derive(Debug, Default)]
pub struct ContextVersion {
    /// Version of the context-level operations (sent as `context_version`).
    context: clock::Global,
    /// Version of the underlying buffer (sent as `buffer_version`).
    buffer: clock::Global,
}
2138
2139impl ContextVersion {
2140    pub fn from_proto(proto: &proto::ContextVersion) -> Self {
2141        Self {
2142            context: language::proto::deserialize_version(&proto.context_version),
2143            buffer: language::proto::deserialize_version(&proto.buffer_version),
2144        }
2145    }
2146
2147    pub fn to_proto(&self, context_id: ContextId) -> proto::ContextVersion {
2148        proto::ContextVersion {
2149            context_id: context_id.to_proto(),
2150            context_version: language::proto::serialize_version(&self.context),
2151            buffer_version: language::proto::serialize_version(&self.buffer),
2152        }
2153    }
2154}
2155
/// A slash command invocation in the context's buffer, together with its execution status.
#[derive(Clone)]
pub struct PendingSlashCommand {
    /// The command's name.
    pub name: String,
    /// The argument text passed to the command, if any.
    pub argument: Option<String>,
    /// Whether the command is idle, running, or has failed.
    pub status: PendingSlashCommandStatus,
    /// Buffer range of the command's source text (presumably the `/name argument` text —
    /// confirm against the slash-command parser).
    pub source_range: Range<language::Anchor>,
}
2163
/// Execution state of a `PendingSlashCommand`.
#[derive(Clone)]
pub enum PendingSlashCommandStatus {
    /// Presumably not yet run.
    Idle,
    /// Running. `_task` is a shared handle to the task that inserts the command's output,
    /// held so the task stays alive while this status exists (NOTE(review): assumes dropping
    /// the last handle cancels it — confirm against gpui `Task`).
    Running { _task: Shared<Task<()>> },
    /// The command failed; holds the error's message.
    Error(String),
}
2170
/// A message as stored in a serialized context.
#[derive(Serialize, Deserialize)]
pub struct SavedMessage {
    /// The message's id.
    pub id: MessageId,
    /// Offset in the saved text where the message starts.
    pub start: usize,
    /// Role, status, and timestamp of the message.
    pub metadata: MessageMetadata,
}
2177
/// On-disk JSON form of a context in the current format version (`SavedContext::VERSION`).
/// Older versions are upgraded to this shape by `SavedContext::from_json`.
#[derive(Serialize, Deserialize)]
pub struct SavedContext {
    /// The context's id. NOTE(review): optional, presumably because some saved files lack
    /// it — confirm.
    pub id: Option<ContextId>,
    /// A `zed` marker string; its expected value is not visible here — confirm.
    pub zed: String,
    /// Format version string; `from_json` chooses the schema based on it.
    pub version: String,
    /// The full text of the context.
    pub text: String,
    /// Message boundaries and metadata, with start offsets into `text`.
    pub messages: Vec<SavedMessage>,
    /// The context's title. It is restored as a finished summary.
    pub summary: String,
    /// Slash command output sections, with ranges given as offsets into `text`.
    pub slash_command_output_sections:
        Vec<assistant_slash_command::SlashCommandOutputSection<usize>>,
}
2189
2190impl SavedContext {
    /// Current on-disk format version. `from_json` upgrades older recognized versions to it.
    pub const VERSION: &'static str = "0.4.0";
2192
2193    pub fn from_json(json: &str) -> Result<Self> {
2194        let saved_context_json = serde_json::from_str::<serde_json::Value>(json)?;
2195        match saved_context_json
2196            .get("version")
2197            .ok_or_else(|| anyhow!("version not found"))?
2198        {
2199            serde_json::Value::String(version) => match version.as_str() {
2200                SavedContext::VERSION => {
2201                    Ok(serde_json::from_value::<SavedContext>(saved_context_json)?)
2202                }
2203                SavedContextV0_3_0::VERSION => {
2204                    let saved_context =
2205                        serde_json::from_value::<SavedContextV0_3_0>(saved_context_json)?;
2206                    Ok(saved_context.upgrade())
2207                }
2208                SavedContextV0_2_0::VERSION => {
2209                    let saved_context =
2210                        serde_json::from_value::<SavedContextV0_2_0>(saved_context_json)?;
2211                    Ok(saved_context.upgrade())
2212                }
2213                SavedContextV0_1_0::VERSION => {
2214                    let saved_context =
2215                        serde_json::from_value::<SavedContextV0_1_0>(saved_context_json)?;
2216                    Ok(saved_context.upgrade())
2217                }
2218                _ => Err(anyhow!("unrecognized saved context version: {}", version)),
2219            },
2220            _ => Err(anyhow!("version not found on saved context")),
2221        }
2222    }
2223
2224    fn into_ops(
2225        self,
2226        buffer: &Model<Buffer>,
2227        cx: &mut ModelContext<Context>,
2228    ) -> Vec<ContextOperation> {
2229        let mut operations = Vec::new();
2230        let mut version = clock::Global::new();
2231        let mut next_timestamp = clock::Lamport::new(ReplicaId::default());
2232
2233        let mut first_message_metadata = None;
2234        for message in self.messages {
2235            if message.id == MessageId(clock::Lamport::default()) {
2236                first_message_metadata = Some(message.metadata);
2237            } else {
2238                operations.push(ContextOperation::InsertMessage {
2239                    anchor: MessageAnchor {
2240                        id: message.id,
2241                        start: buffer.read(cx).anchor_before(message.start),
2242                    },
2243                    metadata: MessageMetadata {
2244                        role: message.metadata.role,
2245                        status: message.metadata.status,
2246                        timestamp: message.metadata.timestamp,
2247                    },
2248                    version: version.clone(),
2249                });
2250                version.observe(message.id.0);
2251                next_timestamp.observe(message.id.0);
2252            }
2253        }
2254
2255        if let Some(metadata) = first_message_metadata {
2256            let timestamp = next_timestamp.tick();
2257            operations.push(ContextOperation::UpdateMessage {
2258                message_id: MessageId(clock::Lamport::default()),
2259                metadata: MessageMetadata {
2260                    role: metadata.role,
2261                    status: metadata.status,
2262                    timestamp,
2263                },
2264                version: version.clone(),
2265            });
2266            version.observe(timestamp);
2267        }
2268
2269        let timestamp = next_timestamp.tick();
2270        operations.push(ContextOperation::SlashCommandFinished {
2271            id: SlashCommandId(timestamp),
2272            output_range: language::Anchor::MIN..language::Anchor::MAX,
2273            sections: self
2274                .slash_command_output_sections
2275                .into_iter()
2276                .map(|section| {
2277                    let buffer = buffer.read(cx);
2278                    SlashCommandOutputSection {
2279                        range: buffer.anchor_after(section.range.start)
2280                            ..buffer.anchor_before(section.range.end),
2281                        icon: section.icon,
2282                        label: section.label,
2283                    }
2284                })
2285                .collect(),
2286            version: version.clone(),
2287        });
2288        version.observe(timestamp);
2289
2290        let timestamp = next_timestamp.tick();
2291        operations.push(ContextOperation::UpdateSummary {
2292            summary: ContextSummary {
2293                text: self.summary,
2294                done: true,
2295                timestamp,
2296            },
2297            version: version.clone(),
2298        });
2299        version.observe(timestamp);
2300
2301        operations
2302    }
2303}
2304
/// Message id as stored by saved-context formats older than 0.4.0: a plain
/// `usize`. `SavedContextV0_3_0::upgrade` turns it into the `value` of a
/// Lamport timestamp on the default replica.
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
struct SavedMessageIdPreV0_4_0(usize);
2307
/// A message entry in saved-context formats older than 0.4.0.
#[derive(Serialize, Deserialize)]
struct SavedMessagePreV0_4_0 {
    /// Key into the owning saved context's `message_metadata` map.
    id: SavedMessageIdPreV0_4_0,
    /// Offset in the saved text at which this message starts.
    start: usize,
}
2313
/// Per-message metadata in saved-context formats older than 0.4.0.
///
/// Unlike the current `MessageMetadata`, it carries no timestamp; one is
/// synthesized from the message id during upgrade.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
struct SavedMessageMetadataPreV0_4_0 {
    role: Role,
    status: MessageStatus,
}
2319
/// Saved-context format version 0.3.0.
///
/// Messages and their metadata are stored separately and joined by
/// `SavedMessageIdPreV0_4_0`; slash-command output section ranges are plain
/// `usize` offsets.
#[derive(Serialize, Deserialize)]
struct SavedContextV0_3_0 {
    id: Option<ContextId>,
    zed: String,
    version: String,
    text: String,
    messages: Vec<SavedMessagePreV0_4_0>,
    message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
    summary: String,
    slash_command_output_sections: Vec<assistant_slash_command::SlashCommandOutputSection<usize>>,
}
2331
2332impl SavedContextV0_3_0 {
2333    const VERSION: &'static str = "0.3.0";
2334
2335    fn upgrade(self) -> SavedContext {
2336        SavedContext {
2337            id: self.id,
2338            zed: self.zed,
2339            version: SavedContext::VERSION.into(),
2340            text: self.text,
2341            messages: self
2342                .messages
2343                .into_iter()
2344                .filter_map(|message| {
2345                    let metadata = self.message_metadata.get(&message.id)?;
2346                    let timestamp = clock::Lamport {
2347                        replica_id: ReplicaId::default(),
2348                        value: message.id.0 as u32,
2349                    };
2350                    Some(SavedMessage {
2351                        id: MessageId(timestamp),
2352                        start: message.start,
2353                        metadata: MessageMetadata {
2354                            role: metadata.role,
2355                            status: metadata.status.clone(),
2356                            timestamp,
2357                        },
2358                    })
2359                })
2360                .collect(),
2361            summary: self.summary,
2362            slash_command_output_sections: self.slash_command_output_sections,
2363        }
2364    }
2365}
2366
/// Saved-context format version 0.2.0.
///
/// Same shape as 0.3.0 except that it has no slash-command output sections.
#[derive(Serialize, Deserialize)]
struct SavedContextV0_2_0 {
    id: Option<ContextId>,
    zed: String,
    version: String,
    text: String,
    messages: Vec<SavedMessagePreV0_4_0>,
    message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
    summary: String,
}
2377
2378impl SavedContextV0_2_0 {
2379    const VERSION: &'static str = "0.2.0";
2380
2381    fn upgrade(self) -> SavedContext {
2382        SavedContextV0_3_0 {
2383            id: self.id,
2384            zed: self.zed,
2385            version: SavedContextV0_3_0::VERSION.to_string(),
2386            text: self.text,
2387            messages: self.messages,
2388            message_metadata: self.message_metadata,
2389            summary: self.summary,
2390            slash_command_output_sections: Vec::new(),
2391        }
2392        .upgrade()
2393    }
2394}
2395
/// Saved-context format version 0.1.0.
///
/// Same message layout as 0.2.0, plus `api_url` and `model` fields. Those two
/// fields are not carried forward by `upgrade`.
#[derive(Serialize, Deserialize)]
struct SavedContextV0_1_0 {
    id: Option<ContextId>,
    zed: String,
    version: String,
    text: String,
    messages: Vec<SavedMessagePreV0_4_0>,
    message_metadata: HashMap<SavedMessageIdPreV0_4_0, SavedMessageMetadataPreV0_4_0>,
    summary: String,
    api_url: Option<String>,
    model: OpenAiModel,
}
2408
2409impl SavedContextV0_1_0 {
2410    const VERSION: &'static str = "0.1.0";
2411
2412    fn upgrade(self) -> SavedContext {
2413        SavedContextV0_2_0 {
2414            id: self.id,
2415            zed: self.zed,
2416            version: SavedContextV0_2_0::VERSION.to_string(),
2417            text: self.text,
2418            messages: self.messages,
2419            message_metadata: self.message_metadata,
2420            summary: self.summary,
2421        }
2422        .upgrade()
2423    }
2424}
2425
/// Descriptive metadata for a saved context.
#[derive(Clone)]
pub struct SavedContextMetadata {
    /// Human-readable title.
    pub title: String,
    /// Filesystem path (presumably of the saved context file; confirm against callers).
    pub path: PathBuf,
    /// Presumably the file's last-modified time, expressed in the local time zone.
    pub mtime: chrono::DateTime<chrono::Local>,
}
2432
2433#[cfg(test)]
2434mod tests {
2435    use super::*;
2436    use crate::{
2437        assistant_panel, prompt_library,
2438        slash_command::{active_command, file_command},
2439        MessageId,
2440    };
2441    use assistant_slash_command::{ArgumentCompletion, SlashCommand};
2442    use fs::FakeFs;
2443    use gpui::{AppContext, TestAppContext, WeakView};
2444    use indoc::indoc;
2445    use language::LspAdapterDelegate;
2446    use parking_lot::Mutex;
2447    use project::Project;
2448    use rand::prelude::*;
2449    use serde_json::json;
2450    use settings::SettingsStore;
2451    use std::{cell::RefCell, env, rc::Rc, sync::atomic::AtomicBool};
2452    use text::{network::Network, ToPoint};
2453    use ui::WindowContext;
2454    use unindent::Unindent;
2455    use util::{test::marked_text_ranges, RandomCharIter};
2456    use workspace::Workspace;
2457
    // Exercises `insert_message_after` interleaved with direct buffer edits:
    // message ranges must track inserted text, deleting across message
    // boundaries must merge messages, and undo/redo must revert/re-apply merges.
    #[gpui::test]
    fn test_inserting_and_removing_messages(cx: &mut AppContext) {
        let settings_store = SettingsStore::test(cx);
        language_model::LanguageModelRegistry::test(cx);
        completion::LanguageModelCompletionProvider::test(cx);
        cx.set_global(settings_store);
        assistant_panel::init(cx);
        let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));

        let context = cx.new_model(|cx| Context::local(registry, None, cx));
        let buffer = context.read(cx).buffer.clone();

        // A fresh local context starts with a single, empty user message.
        let message_1 = context.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&context, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        // Inserting a message grows the preceding message by one character
        // (presumably a separating newline); the new message starts empty.
        let message_2 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..1),
                (message_2.id, Role::Assistant, 1..1)
            ]
        );

        // Text inserted at a message's start offset is attributed to that message.
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "1"), (1..1, "2")], None, cx)
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..3)
            ]
        );

        let message_3 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_3.id, Role::User, 4..4)
            ]
        );

        // Inserting after message_2 again places the new message directly after
        // message_2, ahead of message_3.
        let message_4 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..5),
                (message_3.id, Role::User, 5..5),
            ]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(4..4, "C"), (5..5, "D")], None, cx)
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Deleting across message boundaries merges the messages.
        buffer.update(cx, |buffer, cx| buffer.edit([(1..4, "")], None, cx));
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Undoing the deletion should also undo the merge.
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..2),
                (message_2.id, Role::Assistant, 2..4),
                (message_4.id, Role::User, 4..6),
                (message_3.id, Role::User, 6..7),
            ]
        );

        // Redoing the deletion should also redo the merge.
        buffer.update(cx, |buffer, cx| buffer.redo(cx));
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_3.id, Role::User, 3..4),
            ]
        );

        // Ensure we can still insert after a merged message.
        let message_5 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..3),
                (message_5.id, Role::System, 3..4),
                (message_3.id, Role::User, 4..5)
            ]
        );
    }
2589
    // Exercises `Context::split_message` with empty ranges (cursor splits) and
    // with a non-empty selection, checking both the resulting buffer text and the
    // message ranges. `split_message` returns a pair of optional new messages.
    #[gpui::test]
    fn test_message_splitting(cx: &mut AppContext) {
        let settings_store = SettingsStore::test(cx);
        cx.set_global(settings_store);
        language_model::LanguageModelRegistry::test(cx);
        completion::LanguageModelCompletionProvider::test(cx);
        assistant_panel::init(cx);
        let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));

        let context = cx.new_model(|cx| Context::local(registry, None, cx));
        let buffer = context.read(cx).buffer.clone();

        let message_1 = context.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&context, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "aaa\nbbb\nccc\nddd\n")], None, cx)
        });

        // Cursor splits only use the second element of the returned pair.
        let (_, message_2) = context.update(cx, |context, cx| context.split_message(3..3, cx));
        let message_2 = message_2.unwrap();

        // We recycle newlines in the middle of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..16),
            ]
        );

        let (_, message_3) = context.update(cx, |context, cx| context.split_message(3..3, cx));
        let message_3 = message_3.unwrap();

        // We don't recycle newlines at the end of a split message
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..17),
            ]
        );

        let (_, message_4) = context.update(cx, |context, cx| context.split_message(9..9, cx));
        let message_4 = message_4.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\nccc\nddd\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..17),
            ]
        );

        let (_, message_5) = context.update(cx, |context, cx| context.split_message(9..9, cx));
        let message_5 = message_5.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\nddd\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..18),
            ]
        );

        // Splitting a non-empty selection yields two new messages: one starting
        // at the selection and one for the text that followed it.
        let (message_6, message_7) =
            context.update(cx, |context, cx| context.split_message(14..16, cx));
        let message_6 = message_6.unwrap();
        let message_7 = message_7.unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\n\nbbb\n\nccc\ndd\nd\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_3.id, Role::User, 4..5),
                (message_2.id, Role::User, 5..9),
                (message_4.id, Role::User, 9..10),
                (message_5.id, Role::User, 10..14),
                (message_6.id, Role::User, 14..17),
                (message_7.id, Role::User, 17..19),
            ]
        );
    }
2684
    // Checks that `messages_for_offsets` maps buffer offsets to the messages that
    // contain them, reporting each message once even when several offsets fall
    // inside it, and handling the end-of-buffer offset.
    #[gpui::test]
    fn test_messages_for_offsets(cx: &mut AppContext) {
        let settings_store = SettingsStore::test(cx);
        language_model::LanguageModelRegistry::test(cx);
        completion::LanguageModelCompletionProvider::test(cx);
        cx.set_global(settings_store);
        assistant_panel::init(cx);
        let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
        let context = cx.new_model(|cx| Context::local(registry, None, cx));
        let buffer = context.read(cx).buffer.clone();

        let message_1 = context.read(cx).message_anchors[0].clone();
        assert_eq!(
            messages(&context, cx),
            vec![(message_1.id, Role::User, 0..0)]
        );

        buffer.update(cx, |buffer, cx| buffer.edit([(0..0, "aaa")], None, cx));
        let message_2 = context
            .update(cx, |context, cx| {
                context.insert_message_after(message_1.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "bbb")], None, cx));

        let message_3 = context
            .update(cx, |context, cx| {
                context.insert_message_after(message_2.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        buffer.update(cx, |buffer, cx| buffer.edit([(8..8, "ccc")], None, cx));

        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..11)
            ]
        );

        assert_eq!(
            message_ids_for_offsets(&context, &[0, 4, 9], cx),
            [message_1.id, message_2.id, message_3.id]
        );
        // Offsets 0 and 1 both lie in message_1, which is reported only once;
        // offset 11 (end of buffer) belongs to the last message.
        assert_eq!(
            message_ids_for_offsets(&context, &[0, 1, 11], cx),
            [message_1.id, message_3.id]
        );

        let message_4 = context
            .update(cx, |context, cx| {
                context.insert_message_after(message_3.id, Role::User, MessageStatus::Done, cx)
            })
            .unwrap();
        assert_eq!(buffer.read(cx).text(), "aaa\nbbb\nccc\n");
        assert_eq!(
            messages(&context, cx),
            vec![
                (message_1.id, Role::User, 0..4),
                (message_2.id, Role::User, 4..8),
                (message_3.id, Role::User, 8..12),
                (message_4.id, Role::User, 12..12)
            ]
        );
        assert_eq!(
            message_ids_for_offsets(&context, &[0, 4, 8, 12], cx),
            [message_1.id, message_2.id, message_3.id, message_4.id]
        );

        // Returns the ids of the messages `messages_for_offsets` reports for `offsets`.
        fn message_ids_for_offsets(
            context: &Model<Context>,
            offsets: &[usize],
            cx: &AppContext,
        ) -> Vec<MessageId> {
            context
                .read(cx)
                .messages_for_offsets(offsets.iter().copied(), cx)
                .into_iter()
                .map(|message| message.id)
                .collect()
        }
    }
2769
    // Tracks pending slash-command source ranges reported through
    // `ContextEvent::PendingSlashCommandsUpdated` while a command is typed and
    // edited. Renaming the command to an unregistered name must leave no
    // tracked range.
    #[gpui::test]
    async fn test_slash_commands(cx: &mut TestAppContext) {
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(completion::LanguageModelCompletionProvider::test);
        cx.update(Project::init_settings);
        cx.update(assistant_panel::init);
        let fs = FakeFs::new(cx.background_executor.clone());

        fs.insert_tree(
            "/test",
            json!({
                "src": {
                    "lib.rs": "fn one() -> usize { 1 }",
                    "main.rs": "
                        use crate::one;
                        fn main() { one(); }
                    ".unindent(),
                }
            }),
        )
        .await;

        let slash_command_registry = cx.update(SlashCommandRegistry::default_global);
        slash_command_registry.register_command(file_command::FileSlashCommand, false);
        slash_command_registry.register_command(active_command::ActiveSlashCommand, false);

        let registry = Arc::new(LanguageRegistry::test(cx.executor()));
        let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));

        // Mirror the pending slash-command source ranges reported by the context
        // into `output_ranges`: removed ranges are dropped, updated ones inserted.
        let output_ranges = Rc::new(RefCell::new(HashSet::default()));
        context.update(cx, |_, cx| {
            cx.subscribe(&context, {
                let ranges = output_ranges.clone();
                move |_, _, event, _| match event {
                    ContextEvent::PendingSlashCommandsUpdated { removed, updated } => {
                        for range in removed {
                            ranges.borrow_mut().remove(range);
                        }
                        for command in updated {
                            ranges.borrow_mut().insert(command.source_range.clone());
                        }
                    }
                    _ => {}
                }
            })
            .detach();
        });

        let buffer = context.read_with(cx, |context, _| context.buffer.clone());

        // Insert a slash command
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "/file src/lib.rs")], None, cx);
        });
        assert_text_and_output_ranges(
            &buffer,
            &output_ranges.borrow(),
            "
            «/file src/lib.rs»
            "
            .unindent()
            .trim_end(),
            cx,
        );

        // Edit the argument of the slash command.
        buffer.update(cx, |buffer, cx| {
            let edit_offset = buffer.text().find("lib.rs").unwrap();
            buffer.edit([(edit_offset..edit_offset + "lib".len(), "main")], None, cx);
        });
        assert_text_and_output_ranges(
            &buffer,
            &output_ranges.borrow(),
            "
            «/file src/main.rs»
            "
            .unindent()
            .trim_end(),
            cx,
        );

        // Edit the name of the slash command, using one that doesn't exist.
        buffer.update(cx, |buffer, cx| {
            let edit_offset = buffer.text().find("/file").unwrap();
            buffer.edit(
                [(edit_offset..edit_offset + "/file".len(), "/unknown")],
                None,
                cx,
            );
        });
        assert_text_and_output_ranges(
            &buffer,
            &output_ranges.borrow(),
            "
            /unknown src/main.rs
            "
            .unindent()
            .trim_end(),
            cx,
        );

        // Asserts that the buffer text equals `expected_marked_text` with its
        // markers stripped, and that `ranges` (as sorted offsets) equal the marked
        // ranges, both as parsed by `marked_text_ranges`.
        #[track_caller]
        fn assert_text_and_output_ranges(
            buffer: &Model<Buffer>,
            ranges: &HashSet<Range<language::Anchor>>,
            expected_marked_text: &str,
            cx: &mut TestAppContext,
        ) {
            let (expected_text, expected_ranges) = marked_text_ranges(expected_marked_text, false);
            let (actual_text, actual_ranges) = buffer.update(cx, |buffer, _| {
                let mut ranges = ranges
                    .iter()
                    .map(|range| range.to_offset(buffer))
                    .collect::<Vec<_>>();
                ranges.sort_by_key(|a| a.start);
                (buffer.text(), ranges)
            });

            assert_eq!(actual_text, expected_text);
            assert_eq!(actual_ranges, expected_ranges);
        }
    }
2894
    // Drives `Context::assist` with a fake model whose completion contains two
    // `<step>` blocks, then checks the source ranges recorded in `edit_steps`
    // relative to the row where the assistant message starts.
    #[gpui::test]
    async fn test_edit_step_parsing(cx: &mut TestAppContext) {
        cx.update(prompt_library::init);
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);

        let fake_provider = cx.update(language_model::LanguageModelRegistry::test);
        cx.update(completion::LanguageModelCompletionProvider::test);

        let fake_model = fake_provider.test_model();
        cx.update(assistant_panel::init);
        let registry = Arc::new(LanguageRegistry::test(cx.executor()));

        // Create a new context
        let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
        let buffer = context.read_with(cx, |context, _| context.buffer.clone());

        // Simulate user input
        let user_message = indoc! {r#"
            Please refactor this code:

            fn main() {
                println!("Hello, World!");
            }
        "#};
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, user_message)], None, cx);
        });

        // Simulate LLM response with edit steps
        let llm_response = indoc! {r#"
            Sure, I can help you refactor that code. Here's a step-by-step process:

            <step>
            First, let's extract the greeting into a separate function:

            ```rust
            fn greet() {
                println!("Hello, World!");
            }

            fn main() {
                greet();
            }
            ```
            </step>

            <step>
            Now, let's make the greeting customizable:

            ```rust
            fn greet(name: &str) {
                println!("Hello, {}!", name);
            }

            fn main() {
                greet("World");
            }
            ```
            </step>

            These changes make the code more modular and flexible.
        "#};

        // Simulate the assist method to trigger the LLM response
        context.update(cx, |context, cx| context.assist(cx));
        cx.run_until_parked();

        // Retrieve the assistant response message's start from the context
        let response_start_row = context.read_with(cx, |context, cx| {
            let buffer = context.buffer.read(cx);
            context.message_anchors[1].start.to_point(buffer).row
        });

        // Simulate the LLM completion
        fake_model.send_last_completion_chunk(llm_response.to_string());
        fake_model.finish_last_completion();

        // Wait for the completion to be processed
        cx.run_until_parked();

        // Verify that the edit steps were parsed correctly. Each expected range
        // starts at a `<step>` line (response rows 2 and 16) and ends just after
        // the matching `</step>` tag, which is 7 characters long.
        context.read_with(cx, |context, cx| {
            assert_eq!(
                edit_steps(context, cx),
                vec![
                    Point::new(response_start_row + 2, 0)..Point::new(response_start_row + 14, 7),
                    Point::new(response_start_row + 16, 0)..Point::new(response_start_row + 28, 7),
                ]
            );
        });

        // Returns each edit step's source range converted to buffer points.
        fn edit_steps(context: &Context, cx: &AppContext) -> Vec<Range<Point>> {
            context
                .edit_steps
                .iter()
                .map(|step| {
                    let buffer = context.buffer.read(cx);
                    step.source_range.to_point(buffer)
                })
                .collect()
        }
    }
2998
    // Round-trips a context through `serialize` / `Context::deserialize` and
    // checks that the buffer text and every message's id, role and range
    // survive unchanged.
    #[gpui::test]
    async fn test_serialization(cx: &mut TestAppContext) {
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(completion::LanguageModelCompletionProvider::test);
        cx.update(assistant_panel::init);
        let registry = Arc::new(LanguageRegistry::test(cx.executor()));
        let context = cx.new_model(|cx| Context::local(registry.clone(), None, cx));
        let buffer = context.read_with(cx, |context, _| context.buffer.clone());
        let message_0 = context.read_with(cx, |context, _| context.message_anchors[0].id);
        let message_1 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_0, Role::Assistant, MessageStatus::Done, cx)
                .unwrap()
        });
        let message_2 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_1.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        buffer.update(cx, |buffer, cx| {
            buffer.edit([(0..0, "a"), (1..1, "b\nc")], None, cx);
            buffer.finalize_last_transaction();
        });
        // The undo below presumably reverts the buffer edit made when inserting
        // `_message_3`; the following assertion shows only three messages remain.
        let _message_3 = context.update(cx, |context, cx| {
            context
                .insert_message_after(message_2.id, Role::System, MessageStatus::Done, cx)
                .unwrap()
        });
        buffer.update(cx, |buffer, cx| buffer.undo(cx));
        assert_eq!(buffer.read_with(cx, |buffer, _| buffer.text()), "a\nb\nc\n");
        assert_eq!(
            cx.read(|cx| messages(&context, cx)),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );

        let serialized_context = context.read_with(cx, |context, cx| context.serialize(cx));
        let deserialized_context = cx.new_model(|cx| {
            Context::deserialize(
                serialized_context,
                Default::default(),
                registry.clone(),
                None,
                cx,
            )
        });
        let deserialized_buffer =
            deserialized_context.read_with(cx, |context, _| context.buffer.clone());
        assert_eq!(
            deserialized_buffer.read_with(cx, |buffer, _| buffer.text()),
            "a\nb\nc\n"
        );
        assert_eq!(
            cx.read(|cx| messages(&deserialized_context, cx)),
            [
                (message_0, Role::User, 0..2),
                (message_1.id, Role::Assistant, 2..6),
                (message_2.id, Role::System, 6..6),
            ]
        );
    }
3065
    // Randomized convergence test for collaborative editing of one shared context.
    //
    // Creates between MIN_PEERS and MAX_PEERS replicas (defaults 2 and 5) with the same
    // `ContextId`. Each replica's `ContextEvent::Operation` events are forwarded to a
    // simulated `Network`. The test then performs OPERATIONS (default 50) random
    // mutations, interleaved with random network activity:
    //   - delivering queued operations,
    //   - disconnecting non-host peers,
    //   - reconnecting a peer by exchanging the operations it and replica 0 are missing.
    // After the mutation budget is spent, the loop keeps driving the network until it
    // reports idle with no disconnected peers. Finally it asserts that every replica's
    // state equals replica 0's.
    //
    // All randomness comes from `rng`, so the order in which the statements below draw
    // from it determines the scenario a given seed produces. Reorder with care.
    #[gpui::test(iterations = 100)]
    async fn test_random_context_collaboration(cx: &mut TestAppContext, mut rng: StdRng) {
        // Tunables, overridable through environment variables.
        let min_peers = env::var("MIN_PEERS")
            .map(|i| i.parse().expect("invalid `MIN_PEERS` variable"))
            .unwrap_or(2);
        let max_peers = env::var("MAX_PEERS")
            .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
            .unwrap_or(5);
        let operations = env::var("OPERATIONS")
            .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
            .unwrap_or(50);

        // Test globals: settings store and language-model registries.
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);
        cx.update(language_model::LanguageModelRegistry::test);
        cx.update(completion::LanguageModelCompletionProvider::test);

        cx.update(assistant_panel::init);
        // Fake commands that the slash-command mutation branch picks from at random.
        let slash_commands = cx.update(SlashCommandRegistry::default_global);
        slash_commands.register_command(FakeSlashCommand("cmd-1".into()), false);
        slash_commands.register_command(FakeSlashCommand("cmd-2".into()), false);
        slash_commands.register_command(FakeSlashCommand("cmd-3".into()), false);

        let registry = Arc::new(LanguageRegistry::test(cx.background_executor.clone()));
        // Simulated network shared by all peers. It is seeded from a clone of `rng`.
        let network = Arc::new(Mutex::new(Network::new(rng.clone())));
        let mut contexts = Vec::new();

        let num_peers = rng.gen_range(min_peers..=max_peers);
        let context_id = ContextId::new();
        for i in 0..num_peers {
            // Every replica shares `context_id`; replica `i` gets `ReplicaId` `i`.
            let context = cx.new_model(|cx| {
                Context::new(
                    context_id.clone(),
                    i as ReplicaId,
                    language::Capability::ReadWrite,
                    registry.clone(),
                    None,
                    cx,
                )
            });

            // Broadcast each locally generated operation (in proto form) to the network,
            // tagged with this replica's id.
            cx.update(|cx| {
                cx.subscribe(&context, {
                    let network = network.clone();
                    move |_, event, _| {
                        if let ContextEvent::Operation(op) = event {
                            network
                                .lock()
                                .broadcast(i as ReplicaId, vec![op.to_proto()]);
                        }
                    }
                })
                .detach();
            });

            contexts.push(context);
            network.lock().add_peer(i as ReplicaId);
        }

        let mut mutation_count = operations;

        // Run until the mutation budget is spent AND the network reports idle with every
        // peer connected.
        while mutation_count > 0
            || !network.lock().is_idle()
            || network.lock().contains_disconnected_peers()
        {
            let context_index = rng.gen_range(0..contexts.len());
            let context = &contexts[context_index];

            // Rolls 0..=84 mutate the chosen context, but only while budget remains.
            // Once it is exhausted, and always for 85..=99, the `_` arm exercises the
            // network instead.
            match rng.gen_range(0..100) {
                // Random local buffer edit.
                0..=29 if mutation_count > 0 => {
                    log::info!("Context {}: edit buffer", context_index);
                    context.update(cx, |context, cx| {
                        context
                            .buffer
                            .update(cx, |buffer, cx| buffer.randomly_edit(&mut rng, 1, cx));
                    });
                    mutation_count -= 1;
                }
                // Split message(s) at a random byte range of the buffer.
                30..=44 if mutation_count > 0 => {
                    context.update(cx, |context, cx| {
                        let range = context.buffer.read(cx).random_byte_range(0, &mut rng);
                        log::info!("Context {}: split message at {:?}", context_index, range);
                        context.split_message(range, cx);
                    });
                    mutation_count -= 1;
                }
                // Insert a message with a random role after a randomly chosen message.
                45..=59 if mutation_count > 0 => {
                    context.update(cx, |context, cx| {
                        if let Some(message) = context.messages(cx).choose(&mut rng) {
                            let role = *[Role::User, Role::Assistant, Role::System]
                                .choose(&mut rng)
                                .unwrap();
                            log::info!(
                                "Context {}: insert message after {:?} with {:?}",
                                context_index,
                                message.id,
                                role
                            );
                            context.insert_message_after(message.id, role, MessageStatus::Done, cx);
                        }
                    });
                    mutation_count -= 1;
                }
                // Type a random slash command on its own line at a random offset, then
                // replace it with synthetic command output that carries random sections.
                60..=74 if mutation_count > 0 => {
                    context.update(cx, |context, cx| {
                        let command_text = "/".to_string()
                            + slash_commands
                                .command_names()
                                .choose(&mut rng)
                                .unwrap()
                                .clone()
                                .as_ref();

                        // The returned range covers only the command text. `+ 1` skips the
                        // newline inserted before it.
                        let command_range = context.buffer.update(cx, |buffer, cx| {
                            let offset = buffer.random_byte_range(0, &mut rng).start;
                            buffer.edit(
                                [(offset..offset, format!("\n{}\n", command_text))],
                                None,
                                cx,
                            );
                            offset + 1..offset + 1 + command_text.len()
                        });

                        // Output of 1..=10 random chars, excluding '\r'.
                        let output_len = rng.gen_range(1..=10);
                        let output_text = RandomCharIter::new(&mut rng)
                            .filter(|c| *c != '\r')
                            .take(output_len)
                            .collect::<String>();

                        // NOTE(review): `output_len` is a char count, but `output_text` may
                        // contain multi-byte chars. If section ranges are byte offsets, these
                        // ranges only span a prefix of the output and may not fall on char
                        // boundaries. Confirm this is intended.
                        let num_sections = rng.gen_range(0..=3);
                        let mut sections = Vec::with_capacity(num_sections);
                        for _ in 0..num_sections {
                            let section_start = rng.gen_range(0..output_len);
                            let section_end = rng.gen_range(section_start..=output_len);
                            sections.push(SlashCommandOutputSection {
                                range: section_start..section_end,
                                icon: ui::IconName::Ai,
                                label: "section".into(),
                            });
                        }

                        log::info!(
                            "Context {}: insert slash command output at {:?} with {:?}",
                            context_index,
                            command_range,
                            sections
                        );

                        // Convert the offset range to anchors before handing it off.
                        let command_range =
                            context.buffer.read(cx).anchor_after(command_range.start)
                                ..context.buffer.read(cx).anchor_after(command_range.end);
                        context.insert_command_output(
                            command_range,
                            Task::ready(Ok(SlashCommandOutput {
                                text: output_text,
                                sections,
                                run_commands_in_text: false,
                            })),
                            true,
                            cx,
                        );
                    });
                    // Let any async work started by `insert_command_output` finish before
                    // the next step.
                    cx.run_until_parked();
                    mutation_count -= 1;
                }
                // Set a random message's status to Done, Pending, or Error.
                75..=84 if mutation_count > 0 => {
                    context.update(cx, |context, cx| {
                        if let Some(message) = context.messages(cx).choose(&mut rng) {
                            let new_status = match rng.gen_range(0..3) {
                                0 => MessageStatus::Done,
                                1 => MessageStatus::Pending,
                                _ => MessageStatus::Error(SharedString::from("Random error")),
                            };
                            log::info!(
                                "Context {}: update message {:?} status to {:?}",
                                context_index,
                                message.id,
                                new_status
                            );
                            context.update_metadata(message.id, cx, |metadata| {
                                metadata.status = new_status;
                            });
                        }
                    });
                    mutation_count -= 1;
                }
                // Network activity for the chosen peer.
                _ => {
                    let replica_id = context_index as ReplicaId;
                    if network.lock().is_disconnected(replica_id) {
                        network.lock().reconnect_peer(replica_id, 0);

                        // Catch the peer up against replica 0 (the host):
                        // - ops_to_send: ops this peer has that the host's version lacks.
                        // - ops_to_receive: ops the host has that this peer's version lacks.
                        let (ops_to_send, ops_to_receive) = cx.read(|cx| {
                            let host_context = &contexts[0].read(cx);
                            let guest_context = context.read(cx);
                            (
                                guest_context.serialize_ops(&host_context.version(cx), cx),
                                host_context.serialize_ops(&guest_context.version(cx), cx),
                            )
                        });
                        // Outgoing ops stay in proto form for the network. Incoming ops are
                        // decoded before being applied locally.
                        let ops_to_send = ops_to_send.await;
                        let ops_to_receive = ops_to_receive
                            .await
                            .into_iter()
                            .map(ContextOperation::from_proto)
                            .collect::<Result<Vec<_>>>()
                            .unwrap();
                        log::info!(
                            "Context {}: reconnecting. Sent {} operations, received {} operations",
                            context_index,
                            ops_to_send.len(),
                            ops_to_receive.len()
                        );

                        network.lock().broadcast(replica_id, ops_to_send);
                        context
                            .update(cx, |context, cx| context.apply_ops(ops_to_receive, cx))
                            .unwrap();
                    // Disconnect a non-host peer with 10% probability. `gen_bool` is evaluated
                    // before the replica check, so it draws from `rng` even for replica 0.
                    } else if rng.gen_bool(0.1) && replica_id != 0 {
                        log::info!("Context {}: disconnecting", context_index);
                        network.lock().disconnect_peer(replica_id);
                    // Otherwise deliver and apply any operations queued for this peer.
                    } else if network.lock().has_unreceived(replica_id) {
                        log::info!("Context {}: applying operations", context_index);
                        let ops = network.lock().receive(replica_id);
                        let ops = ops
                            .into_iter()
                            .map(ContextOperation::from_proto)
                            .collect::<Result<Vec<_>>>()
                            .unwrap();
                        context
                            .update(cx, |context, cx| context.apply_ops(ops, cx))
                            .unwrap();
                    }
                }
            }
        }

        // Every replica must have converged to replica 0's state: no pending ops, and
        // identical text, message anchors, message metadata, and slash command sections.
        cx.read(|cx| {
            let first_context = contexts[0].read(cx);
            for context in &contexts[1..] {
                let context = context.read(cx);
                assert!(context.pending_ops.is_empty());
                assert_eq!(
                    context.buffer.read(cx).text(),
                    first_context.buffer.read(cx).text(),
                    "Context {} text != Context 0 text",
                    context.buffer.read(cx).replica_id()
                );
                assert_eq!(
                    context.message_anchors,
                    first_context.message_anchors,
                    "Context {} messages != Context 0 messages",
                    context.buffer.read(cx).replica_id()
                );
                assert_eq!(
                    context.messages_metadata,
                    first_context.messages_metadata,
                    "Context {} message metadata != Context 0 message metadata",
                    context.buffer.read(cx).replica_id()
                );
                assert_eq!(
                    context.slash_command_output_sections,
                    first_context.slash_command_output_sections,
                    "Context {} slash command output sections != Context 0 slash command output sections",
                    context.buffer.read(cx).replica_id()
                );
            }
        });
    }
3334
3335    fn messages(context: &Model<Context>, cx: &AppContext) -> Vec<(MessageId, Role, Range<usize>)> {
3336        context
3337            .read(cx)
3338            .messages(cx)
3339            .map(|message| (message.id, message.role, message.offset_range))
3340            .collect()
3341    }
3342
3343    #[derive(Clone)]
3344    struct FakeSlashCommand(String);
3345
3346    impl SlashCommand for FakeSlashCommand {
3347        fn name(&self) -> String {
3348            self.0.clone()
3349        }
3350
3351        fn description(&self) -> String {
3352            format!("Fake slash command: {}", self.0)
3353        }
3354
3355        fn menu_text(&self) -> String {
3356            format!("Run fake command: {}", self.0)
3357        }
3358
3359        fn complete_argument(
3360            self: Arc<Self>,
3361            _query: String,
3362            _cancel: Arc<AtomicBool>,
3363            _workspace: Option<WeakView<Workspace>>,
3364            _cx: &mut AppContext,
3365        ) -> Task<Result<Vec<ArgumentCompletion>>> {
3366            Task::ready(Ok(vec![]))
3367        }
3368
3369        fn requires_argument(&self) -> bool {
3370            false
3371        }
3372
3373        fn run(
3374            self: Arc<Self>,
3375            _argument: Option<&str>,
3376            _workspace: WeakView<Workspace>,
3377            _delegate: Arc<dyn LspAdapterDelegate>,
3378            _cx: &mut WindowContext,
3379        ) -> Task<Result<SlashCommandOutput>> {
3380            Task::ready(Ok(SlashCommandOutput {
3381                text: format!("Executed fake command: {}", self.0),
3382                sections: vec![],
3383                run_commands_in_text: false,
3384            }))
3385        }
3386    }
3387}