text_thread_store.rs

   1use crate::{
   2    SavedTextThread, SavedTextThreadMetadata, TextThread, TextThreadEvent, TextThreadId,
   3    TextThreadOperation, TextThreadVersion,
   4};
   5use anyhow::{Context as _, Result};
   6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
   7use client::{Client, TypedEnvelope, proto};
   8use clock::ReplicaId;
   9use collections::HashMap;
  10use context_server::ContextServerId;
  11use fs::{Fs, RemoveOptions};
  12use futures::StreamExt;
  13use fuzzy::StringMatchCandidate;
  14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
  15use itertools::Itertools;
  16use language::LanguageRegistry;
  17use paths::text_threads_dir;
  18use project::{
  19    Project,
  20    context_server_store::{ContextServerStatus, ContextServerStore},
  21};
  22use prompt_store::PromptBuilder;
  23use regex::Regex;
  24use rpc::AnyProtoClient;
  25use std::sync::LazyLock;
  26use std::{cmp::Reverse, ffi::OsStr, mem, path::Path, sync::Arc, time::Duration};
  27use util::{ResultExt, TryFutureExt};
  28use zed_env_vars::ZED_STATELESS;
  29
  30pub(crate) fn init(client: &AnyProtoClient) {
  31    client.add_entity_message_handler(TextThreadStore::handle_advertise_contexts);
  32    client.add_entity_request_handler(TextThreadStore::handle_open_context);
  33    client.add_entity_request_handler(TextThreadStore::handle_create_context);
  34    client.add_entity_message_handler(TextThreadStore::handle_update_context);
  35    client.add_entity_request_handler(TextThreadStore::handle_synchronize_contexts);
  36}
  37
/// Metadata describing a text thread advertised by the host of a shared project.
///
/// Instances are built from the entries of a `proto::AdvertiseContexts` message.
#[derive(Clone)]
pub struct RemoteTextThreadMetadata {
    /// Identifier of the advertised thread, decoded from the proto `context_id`.
    pub id: TextThreadId,
    /// Summary text supplied by the host, if any.
    pub summary: Option<String>,
}
  43
/// Tracks the text threads that belong to a project: threads loaded in memory,
/// metadata of saved threads, and threads advertised by a remote host.
pub struct TextThreadStore {
    /// Handles to every thread registered with this store. New handles are
    /// strong while the project is shared and weak otherwise.
    text_threads: Vec<TextThreadHandle>,
    /// Metadata of saved threads; exposed through `ordered_text_threads` and `search`.
    text_threads_metadata: Vec<SavedTextThreadMetadata>,
    /// Slash command ids keyed by the context server they are associated with.
    context_server_slash_command_ids: HashMap<ContextServerId, Vec<SlashCommandId>>,
    /// Threads most recently advertised by the host; replaced wholesale on each
    /// `AdvertiseContexts` message and cleared when disconnected from the host.
    host_text_threads: Vec<RemoteTextThreadMetadata>,
    /// File system used to load, watch, and delete saved threads.
    fs: Arc<dyn Fs>,
    /// Language registry handed to every thread this store constructs.
    languages: Arc<LanguageRegistry>,
    /// Slash command working set handed to every thread this store constructs.
    slash_commands: Arc<SlashCommandWorkingSet>,
    /// Task that reloads the store whenever the watched text threads directory
    /// emits an event (a ready no-op task in `fake`).
    _watch_updates: Task<Option<()>>,
    /// RPC client used to send and request context messages.
    client: Arc<Client>,
    /// The owning project, held weakly.
    project: WeakEntity<Project>,
    /// The value of `Project::is_shared` last observed by `handle_project_shared`.
    project_is_shared: bool,
    /// Client entity subscription; set while the project is shared, `None` otherwise.
    client_subscription: Option<client::Subscription>,
    /// Keeps the project event subscription created in `new` alive.
    _project_subscriptions: Vec<gpui::Subscription>,
    /// Prompt builder handed to every thread this store constructs.
    prompt_builder: Arc<PromptBuilder>,
}
  60
/// A handle to a registered text thread that either keeps the thread alive
/// (`Strong`) or merely observes it (`Weak`).
enum TextThreadHandle {
    /// Does not keep the thread alive; upgrading may fail.
    Weak(WeakEntity<TextThread>),
    /// Keeps the thread alive for as long as the handle exists.
    Strong(Entity<TextThread>),
}
  65
  66impl TextThreadHandle {
  67    fn upgrade(&self) -> Option<Entity<TextThread>> {
  68        match self {
  69            TextThreadHandle::Weak(weak) => weak.upgrade(),
  70            TextThreadHandle::Strong(strong) => Some(strong.clone()),
  71        }
  72    }
  73
  74    fn downgrade(&self) -> WeakEntity<TextThread> {
  75        match self {
  76            TextThreadHandle::Weak(weak) => weak.clone(),
  77            TextThreadHandle::Strong(strong) => strong.downgrade(),
  78        }
  79    }
  80}
  81
  82impl TextThreadStore {
  83    pub fn new(
  84        project: Entity<Project>,
  85        prompt_builder: Arc<PromptBuilder>,
  86        slash_commands: Arc<SlashCommandWorkingSet>,
  87        cx: &mut App,
  88    ) -> Task<Result<Entity<Self>>> {
  89        let fs = project.read(cx).fs().clone();
  90        let languages = project.read(cx).languages().clone();
  91        cx.spawn(async move |cx| {
  92            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
  93            let (mut events, _) = fs.watch(text_threads_dir(), CONTEXT_WATCH_DURATION).await;
  94
  95            let this = cx.new(|cx: &mut Context<Self>| {
  96                let mut this = Self {
  97                    text_threads: Vec::new(),
  98                    text_threads_metadata: Vec::new(),
  99                    context_server_slash_command_ids: HashMap::default(),
 100                    host_text_threads: Vec::new(),
 101                    fs,
 102                    languages,
 103                    slash_commands,
 104                    _watch_updates: cx.spawn(async move |this, cx| {
 105                        async move {
 106                            while events.next().await.is_some() {
 107                                this.update(cx, |this, cx| this.reload(cx))?.await.log_err();
 108                            }
 109                            anyhow::Ok(())
 110                        }
 111                        .log_err()
 112                        .await
 113                    }),
 114                    client_subscription: None,
 115                    _project_subscriptions: vec![
 116                        cx.subscribe(&project, Self::handle_project_event),
 117                    ],
 118                    project_is_shared: false,
 119                    client: project.read(cx).client(),
 120                    project: project.downgrade(),
 121                    prompt_builder,
 122                };
 123                this.handle_project_shared(cx);
 124                this.synchronize_contexts(cx);
 125                this.register_context_server_handlers(cx);
 126                this.reload(cx).detach_and_log_err(cx);
 127                this
 128            });
 129
 130            Ok(this)
 131        })
 132    }
 133
 134    #[cfg(any(test, feature = "test-support"))]
 135    pub fn fake(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
 136        Self {
 137            text_threads: Default::default(),
 138            text_threads_metadata: Default::default(),
 139            context_server_slash_command_ids: Default::default(),
 140            host_text_threads: Default::default(),
 141            fs: project.read(cx).fs().clone(),
 142            languages: project.read(cx).languages().clone(),
 143            slash_commands: Arc::default(),
 144            _watch_updates: Task::ready(None),
 145            client: project.read(cx).client(),
 146            project: project.downgrade(),
 147            project_is_shared: false,
 148            client_subscription: None,
 149            _project_subscriptions: Default::default(),
 150            prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
 151        }
 152    }
 153
 154    async fn handle_advertise_contexts(
 155        this: Entity<Self>,
 156        envelope: TypedEnvelope<proto::AdvertiseContexts>,
 157        mut cx: AsyncApp,
 158    ) -> Result<()> {
 159        this.update(&mut cx, |this, cx| {
 160            this.host_text_threads = envelope
 161                .payload
 162                .contexts
 163                .into_iter()
 164                .map(|text_thread| RemoteTextThreadMetadata {
 165                    id: TextThreadId::from_proto(text_thread.context_id),
 166                    summary: text_thread.summary,
 167                })
 168                .collect();
 169            cx.notify();
 170        });
 171        Ok(())
 172    }
 173
 174    async fn handle_open_context(
 175        this: Entity<Self>,
 176        envelope: TypedEnvelope<proto::OpenContext>,
 177        mut cx: AsyncApp,
 178    ) -> Result<proto::OpenContextResponse> {
 179        let context_id = TextThreadId::from_proto(envelope.payload.context_id);
 180        let operations = this.update(&mut cx, |this, cx| {
 181            let project = this.project.upgrade().context("project not found")?;
 182
 183            anyhow::ensure!(
 184                !project.read(cx).is_via_collab(),
 185                "only the host contexts can be opened"
 186            );
 187
 188            let text_thread = this
 189                .loaded_text_thread_for_id(&context_id, cx)
 190                .context("context not found")?;
 191            anyhow::ensure!(
 192                text_thread.read(cx).replica_id() == ReplicaId::default(),
 193                "context must be opened via the host"
 194            );
 195
 196            anyhow::Ok(
 197                text_thread
 198                    .read(cx)
 199                    .serialize_ops(&TextThreadVersion::default(), cx),
 200            )
 201        })?;
 202        let operations = operations.await;
 203        Ok(proto::OpenContextResponse {
 204            context: Some(proto::Context { operations }),
 205        })
 206    }
 207
 208    async fn handle_create_context(
 209        this: Entity<Self>,
 210        _: TypedEnvelope<proto::CreateContext>,
 211        mut cx: AsyncApp,
 212    ) -> Result<proto::CreateContextResponse> {
 213        let (context_id, operations) = this.update(&mut cx, |this, cx| {
 214            let project = this.project.upgrade().context("project not found")?;
 215            anyhow::ensure!(
 216                !project.read(cx).is_via_collab(),
 217                "can only create contexts as the host"
 218            );
 219
 220            let text_thread = this.create(cx);
 221            let context_id = text_thread.read(cx).id().clone();
 222
 223            anyhow::Ok((
 224                context_id,
 225                text_thread
 226                    .read(cx)
 227                    .serialize_ops(&TextThreadVersion::default(), cx),
 228            ))
 229        })?;
 230        let operations = operations.await;
 231        Ok(proto::CreateContextResponse {
 232            context_id: context_id.to_proto(),
 233            context: Some(proto::Context { operations }),
 234        })
 235    }
 236
 237    async fn handle_update_context(
 238        this: Entity<Self>,
 239        envelope: TypedEnvelope<proto::UpdateContext>,
 240        mut cx: AsyncApp,
 241    ) -> Result<()> {
 242        this.update(&mut cx, |this, cx| {
 243            let context_id = TextThreadId::from_proto(envelope.payload.context_id);
 244            if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
 245                let operation_proto = envelope.payload.operation.context("invalid operation")?;
 246                let operation = TextThreadOperation::from_proto(operation_proto)?;
 247                text_thread.update(cx, |text_thread, cx| text_thread.apply_ops([operation], cx));
 248            }
 249            Ok(())
 250        })
 251    }
 252
 253    async fn handle_synchronize_contexts(
 254        this: Entity<Self>,
 255        envelope: TypedEnvelope<proto::SynchronizeContexts>,
 256        mut cx: AsyncApp,
 257    ) -> Result<proto::SynchronizeContextsResponse> {
 258        this.update(&mut cx, |this, cx| {
 259            let project = this.project.upgrade().context("project not found")?;
 260            anyhow::ensure!(
 261                !project.read(cx).is_via_collab(),
 262                "only the host can synchronize contexts"
 263            );
 264
 265            let mut local_versions = Vec::new();
 266            for remote_version_proto in envelope.payload.contexts {
 267                let remote_version = TextThreadVersion::from_proto(&remote_version_proto);
 268                let context_id = TextThreadId::from_proto(remote_version_proto.context_id);
 269                if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
 270                    let text_thread = text_thread.read(cx);
 271                    let operations = text_thread.serialize_ops(&remote_version, cx);
 272                    local_versions.push(text_thread.version(cx).to_proto(context_id.clone()));
 273                    let client = this.client.clone();
 274                    let project_id = envelope.payload.project_id;
 275                    cx.background_spawn(async move {
 276                        let operations = operations.await;
 277                        for operation in operations {
 278                            client.send(proto::UpdateContext {
 279                                project_id,
 280                                context_id: context_id.to_proto(),
 281                                operation: Some(operation),
 282                            })?;
 283                        }
 284                        anyhow::Ok(())
 285                    })
 286                    .detach_and_log_err(cx);
 287                }
 288            }
 289
 290            this.advertise_contexts(cx);
 291
 292            anyhow::Ok(proto::SynchronizeContextsResponse {
 293                contexts: local_versions,
 294            })
 295        })
 296    }
 297
 298    fn handle_project_shared(&mut self, cx: &mut Context<Self>) {
 299        let Some(project) = self.project.upgrade() else {
 300            return;
 301        };
 302
 303        let is_shared = project.read(cx).is_shared();
 304        let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
 305        if is_shared == was_shared {
 306            return;
 307        }
 308
 309        if is_shared {
 310            self.text_threads.retain_mut(|text_thread| {
 311                if let Some(strong_context) = text_thread.upgrade() {
 312                    *text_thread = TextThreadHandle::Strong(strong_context);
 313                    true
 314                } else {
 315                    false
 316                }
 317            });
 318            let remote_id = project.read(cx).remote_id().unwrap();
 319            self.client_subscription = self
 320                .client
 321                .subscribe_to_entity(remote_id)
 322                .log_err()
 323                .map(|subscription| subscription.set_entity(&cx.entity(), &cx.to_async()));
 324            self.advertise_contexts(cx);
 325        } else {
 326            self.client_subscription = None;
 327        }
 328    }
 329
 330    fn handle_project_event(
 331        &mut self,
 332        _project: Entity<Project>,
 333        event: &project::Event,
 334        cx: &mut Context<Self>,
 335    ) {
 336        match event {
 337            project::Event::RemoteIdChanged(_) => {
 338                self.handle_project_shared(cx);
 339            }
 340            project::Event::Reshared => {
 341                self.advertise_contexts(cx);
 342            }
 343            project::Event::HostReshared | project::Event::Rejoined => {
 344                self.synchronize_contexts(cx);
 345            }
 346            project::Event::DisconnectedFromHost => {
 347                self.text_threads.retain_mut(|text_thread| {
 348                    if let Some(strong_context) = text_thread.upgrade() {
 349                        *text_thread = TextThreadHandle::Weak(text_thread.downgrade());
 350                        strong_context.update(cx, |text_thread, cx| {
 351                            if text_thread.replica_id() != ReplicaId::default() {
 352                                text_thread.set_capability(language::Capability::ReadOnly, cx);
 353                            }
 354                        });
 355                        true
 356                    } else {
 357                        false
 358                    }
 359                });
 360                self.host_text_threads.clear();
 361                cx.notify();
 362            }
 363            _ => {}
 364        }
 365    }
 366
 367    /// Returns saved threads ordered by `mtime` descending (newest first).
 368    pub fn ordered_text_threads(&self) -> impl Iterator<Item = &SavedTextThreadMetadata> {
 369        self.text_threads_metadata
 370            .iter()
 371            .sorted_by(|a, b| b.mtime.cmp(&a.mtime))
 372    }
 373
    /// Returns `true` when the store currently holds metadata for at least one
    /// saved thread.
    pub fn has_saved_text_threads(&self) -> bool {
        !self.text_threads_metadata.is_empty()
    }
 377
    /// Iterates over the threads most recently advertised by the host, in the
    /// order they are stored.
    pub fn host_text_threads(&self) -> impl Iterator<Item = &RemoteTextThreadMetadata> {
        self.host_text_threads.iter()
    }
 381
 382    pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<TextThread> {
 383        let context = cx.new(|cx| {
 384            TextThread::local(
 385                self.languages.clone(),
 386                Some(self.project.clone()),
 387                self.prompt_builder.clone(),
 388                self.slash_commands.clone(),
 389                cx,
 390            )
 391        });
 392        self.register_text_thread(&context, cx);
 393        context
 394    }
 395
 396    pub fn create_remote(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<TextThread>>> {
 397        let Some(project) = self.project.upgrade() else {
 398            return Task::ready(Err(anyhow::anyhow!("project was dropped")));
 399        };
 400        let project = project.read(cx);
 401        let Some(project_id) = project.remote_id() else {
 402            return Task::ready(Err(anyhow::anyhow!("project was not remote")));
 403        };
 404
 405        let replica_id = project.replica_id();
 406        let capability = project.capability();
 407        let language_registry = self.languages.clone();
 408        let project = self.project.clone();
 409
 410        let prompt_builder = self.prompt_builder.clone();
 411        let slash_commands = self.slash_commands.clone();
 412        let request = self.client.request(proto::CreateContext { project_id });
 413        cx.spawn(async move |this, cx| {
 414            let response = request.await?;
 415            let context_id = TextThreadId::from_proto(response.context_id);
 416            let context_proto = response.context.context("invalid context")?;
 417            let text_thread = cx.new(|cx| {
 418                TextThread::new(
 419                    context_id.clone(),
 420                    replica_id,
 421                    capability,
 422                    language_registry,
 423                    prompt_builder,
 424                    slash_commands,
 425                    Some(project),
 426                    cx,
 427                )
 428            });
 429            let operations = cx
 430                .background_spawn(async move {
 431                    context_proto
 432                        .operations
 433                        .into_iter()
 434                        .map(TextThreadOperation::from_proto)
 435                        .collect::<Result<Vec<_>>>()
 436                })
 437                .await?;
 438            text_thread.update(cx, |context, cx| context.apply_ops(operations, cx));
 439            this.update(cx, |this, cx| {
 440                if let Some(existing_context) = this.loaded_text_thread_for_id(&context_id, cx) {
 441                    existing_context
 442                } else {
 443                    this.register_text_thread(&text_thread, cx);
 444                    this.synchronize_contexts(cx);
 445                    text_thread
 446                }
 447            })
 448        })
 449    }
 450
 451    pub fn open_local(
 452        &mut self,
 453        path: Arc<Path>,
 454        cx: &Context<Self>,
 455    ) -> Task<Result<Entity<TextThread>>> {
 456        if let Some(existing_context) = self.loaded_text_thread_for_path(&path, cx) {
 457            return Task::ready(Ok(existing_context));
 458        }
 459
 460        let fs = self.fs.clone();
 461        let languages = self.languages.clone();
 462        let project = self.project.clone();
 463        let load = cx.background_spawn({
 464            let path = path.clone();
 465            async move {
 466                let saved_context = fs.load(&path).await?;
 467                SavedTextThread::from_json(&saved_context)
 468            }
 469        });
 470        let prompt_builder = self.prompt_builder.clone();
 471        let slash_commands = self.slash_commands.clone();
 472
 473        cx.spawn(async move |this, cx| {
 474            let saved_context = load.await?;
 475            let context = cx.new(|cx| {
 476                TextThread::deserialize(
 477                    saved_context,
 478                    path.clone(),
 479                    languages,
 480                    prompt_builder,
 481                    slash_commands,
 482                    Some(project),
 483                    cx,
 484                )
 485            });
 486            this.update(cx, |this, cx| {
 487                if let Some(existing_context) = this.loaded_text_thread_for_path(&path, cx) {
 488                    existing_context
 489                } else {
 490                    this.register_text_thread(&context, cx);
 491                    context
 492                }
 493            })
 494        })
 495    }
 496
 497    pub fn delete_local(&mut self, path: Arc<Path>, cx: &mut Context<Self>) -> Task<Result<()>> {
 498        let fs = self.fs.clone();
 499
 500        cx.spawn(async move |this, cx| {
 501            fs.remove_file(
 502                &path,
 503                RemoveOptions {
 504                    recursive: false,
 505                    ignore_if_not_exists: true,
 506                },
 507            )
 508            .await?;
 509
 510            this.update(cx, |this, cx| {
 511                this.text_threads.retain(|text_thread| {
 512                    text_thread
 513                        .upgrade()
 514                        .and_then(|text_thread| text_thread.read(cx).path())
 515                        != Some(&path)
 516                });
 517                this.text_threads_metadata
 518                    .retain(|text_thread| text_thread.path.as_ref() != path.as_ref());
 519            })?;
 520
 521            Ok(())
 522        })
 523    }
 524
 525    pub fn delete_all_local(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 526        let fs = self.fs.clone();
 527        let paths = self
 528            .text_threads_metadata
 529            .iter()
 530            .map(|metadata| metadata.path.clone())
 531            .collect::<Vec<_>>();
 532
 533        cx.spawn(async move |this, cx| {
 534            for path in paths {
 535                fs.remove_file(
 536                    &path,
 537                    RemoveOptions {
 538                        recursive: false,
 539                        ignore_if_not_exists: true,
 540                    },
 541                )
 542                .await?;
 543            }
 544
 545            this.update(cx, |this, cx| {
 546                this.text_threads.clear();
 547                this.text_threads_metadata.clear();
 548                cx.notify();
 549            })?;
 550
 551            Ok(())
 552        })
 553    }
 554
 555    fn loaded_text_thread_for_path(&self, path: &Path, cx: &App) -> Option<Entity<TextThread>> {
 556        self.text_threads.iter().find_map(|text_thread| {
 557            let text_thread = text_thread.upgrade()?;
 558            if text_thread.read(cx).path().map(Arc::as_ref) == Some(path) {
 559                Some(text_thread)
 560            } else {
 561                None
 562            }
 563        })
 564    }
 565
 566    pub fn loaded_text_thread_for_id(
 567        &self,
 568        id: &TextThreadId,
 569        cx: &App,
 570    ) -> Option<Entity<TextThread>> {
 571        self.text_threads.iter().find_map(|text_thread| {
 572            let text_thread = text_thread.upgrade()?;
 573            if text_thread.read(cx).id() == id {
 574                Some(text_thread)
 575            } else {
 576                None
 577            }
 578        })
 579    }
 580
 581    pub fn open_remote(
 582        &mut self,
 583        text_thread_id: TextThreadId,
 584        cx: &mut Context<Self>,
 585    ) -> Task<Result<Entity<TextThread>>> {
 586        let Some(project) = self.project.upgrade() else {
 587            return Task::ready(Err(anyhow::anyhow!("project was dropped")));
 588        };
 589        let project = project.read(cx);
 590        let Some(project_id) = project.remote_id() else {
 591            return Task::ready(Err(anyhow::anyhow!("project was not remote")));
 592        };
 593
 594        if let Some(context) = self.loaded_text_thread_for_id(&text_thread_id, cx) {
 595            return Task::ready(Ok(context));
 596        }
 597
 598        let replica_id = project.replica_id();
 599        let capability = project.capability();
 600        let language_registry = self.languages.clone();
 601        let project = self.project.clone();
 602        let request = self.client.request(proto::OpenContext {
 603            project_id,
 604            context_id: text_thread_id.to_proto(),
 605        });
 606        let prompt_builder = self.prompt_builder.clone();
 607        let slash_commands = self.slash_commands.clone();
 608        cx.spawn(async move |this, cx| {
 609            let response = request.await?;
 610            let context_proto = response.context.context("invalid context")?;
 611            let text_thread = cx.new(|cx| {
 612                TextThread::new(
 613                    text_thread_id.clone(),
 614                    replica_id,
 615                    capability,
 616                    language_registry,
 617                    prompt_builder,
 618                    slash_commands,
 619                    Some(project),
 620                    cx,
 621                )
 622            });
 623            let operations = cx
 624                .background_spawn(async move {
 625                    context_proto
 626                        .operations
 627                        .into_iter()
 628                        .map(TextThreadOperation::from_proto)
 629                        .collect::<Result<Vec<_>>>()
 630                })
 631                .await?;
 632            text_thread.update(cx, |context, cx| context.apply_ops(operations, cx));
 633            this.update(cx, |this, cx| {
 634                if let Some(existing_context) = this.loaded_text_thread_for_id(&text_thread_id, cx)
 635                {
 636                    existing_context
 637                } else {
 638                    this.register_text_thread(&text_thread, cx);
 639                    this.synchronize_contexts(cx);
 640                    text_thread
 641                }
 642            })
 643        })
 644    }
 645
 646    fn register_text_thread(&mut self, text_thread: &Entity<TextThread>, cx: &mut Context<Self>) {
 647        let handle = if self.project_is_shared {
 648            TextThreadHandle::Strong(text_thread.clone())
 649        } else {
 650            TextThreadHandle::Weak(text_thread.downgrade())
 651        };
 652        self.text_threads.push(handle);
 653        self.advertise_contexts(cx);
 654        cx.subscribe(text_thread, Self::handle_context_event)
 655            .detach();
 656    }
 657
 658    fn handle_context_event(
 659        &mut self,
 660        text_thread: Entity<TextThread>,
 661        event: &TextThreadEvent,
 662        cx: &mut Context<Self>,
 663    ) {
 664        let Some(project) = self.project.upgrade() else {
 665            return;
 666        };
 667        let Some(project_id) = project.read(cx).remote_id() else {
 668            return;
 669        };
 670
 671        match event {
 672            TextThreadEvent::SummaryChanged => {
 673                self.advertise_contexts(cx);
 674            }
 675            TextThreadEvent::PathChanged { old_path, new_path } => {
 676                if let Some(old_path) = old_path.as_ref() {
 677                    for metadata in &mut self.text_threads_metadata {
 678                        if &metadata.path == old_path {
 679                            metadata.path = new_path.clone();
 680                            break;
 681                        }
 682                    }
 683                }
 684            }
 685            TextThreadEvent::Operation(operation) => {
 686                let context_id = text_thread.read(cx).id().to_proto();
 687                let operation = operation.to_proto();
 688                self.client
 689                    .send(proto::UpdateContext {
 690                        project_id,
 691                        context_id,
 692                        operation: Some(operation),
 693                    })
 694                    .log_err();
 695            }
 696            _ => {}
 697        }
 698    }
 699
 700    fn advertise_contexts(&self, cx: &App) {
 701        let Some(project) = self.project.upgrade() else {
 702            return;
 703        };
 704        let Some(project_id) = project.read(cx).remote_id() else {
 705            return;
 706        };
 707        // For now, only the host can advertise their open contexts.
 708        if project.read(cx).is_via_collab() {
 709            return;
 710        }
 711
 712        let contexts = self
 713            .text_threads
 714            .iter()
 715            .rev()
 716            .filter_map(|text_thread| {
 717                let text_thread = text_thread.upgrade()?.read(cx);
 718                if text_thread.replica_id() == ReplicaId::default() {
 719                    Some(proto::ContextMetadata {
 720                        context_id: text_thread.id().to_proto(),
 721                        summary: text_thread
 722                            .summary()
 723                            .content()
 724                            .map(|summary| summary.text.clone()),
 725                    })
 726                } else {
 727                    None
 728                }
 729            })
 730            .collect();
 731        self.client
 732            .send(proto::AdvertiseContexts {
 733                project_id,
 734                contexts,
 735            })
 736            .ok();
 737    }
 738
 739    fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
 740        let Some(project) = self.project.upgrade() else {
 741            return;
 742        };
 743        let Some(project_id) = project.read(cx).remote_id() else {
 744            return;
 745        };
 746
 747        let text_threads = self
 748            .text_threads
 749            .iter()
 750            .filter_map(|text_thread| {
 751                let text_thread = text_thread.upgrade()?.read(cx);
 752                if text_thread.replica_id() != ReplicaId::default() {
 753                    Some(text_thread.version(cx).to_proto(text_thread.id().clone()))
 754                } else {
 755                    None
 756                }
 757            })
 758            .collect();
 759
 760        let client = self.client.clone();
 761        let request = self.client.request(proto::SynchronizeContexts {
 762            project_id,
 763            contexts: text_threads,
 764        });
 765        cx.spawn(async move |this, cx| {
 766            let response = request.await?;
 767
 768            let mut text_thread_ids = Vec::new();
 769            let mut operations = Vec::new();
 770            this.read_with(cx, |this, cx| {
 771                for context_version_proto in response.contexts {
 772                    let text_thread_version = TextThreadVersion::from_proto(&context_version_proto);
 773                    let text_thread_id = TextThreadId::from_proto(context_version_proto.context_id);
 774                    if let Some(text_thread) = this.loaded_text_thread_for_id(&text_thread_id, cx) {
 775                        text_thread_ids.push(text_thread_id);
 776                        operations
 777                            .push(text_thread.read(cx).serialize_ops(&text_thread_version, cx));
 778                    }
 779                }
 780            })?;
 781
 782            let operations = futures::future::join_all(operations).await;
 783            for (context_id, operations) in text_thread_ids.into_iter().zip(operations) {
 784                for operation in operations {
 785                    client.send(proto::UpdateContext {
 786                        project_id,
 787                        context_id: context_id.to_proto(),
 788                        operation: Some(operation),
 789                    })?;
 790                }
 791            }
 792
 793            anyhow::Ok(())
 794        })
 795        .detach_and_log_err(cx);
 796    }
 797
 798    pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedTextThreadMetadata>> {
 799        let metadata = self.text_threads_metadata.clone();
 800        let executor = cx.background_executor().clone();
 801        cx.background_spawn(async move {
 802            if query.is_empty() {
 803                metadata
 804            } else {
 805                let candidates = metadata
 806                    .iter()
 807                    .enumerate()
 808                    .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
 809                    .collect::<Vec<_>>();
 810                let matches = fuzzy::match_strings(
 811                    &candidates,
 812                    &query,
 813                    false,
 814                    true,
 815                    100,
 816                    &Default::default(),
 817                    executor,
 818                )
 819                .await;
 820
 821                matches
 822                    .into_iter()
 823                    .map(|mat| metadata[mat.candidate_id].clone())
 824                    .collect()
 825            }
 826        })
 827    }
 828
 829    fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 830        let fs = self.fs.clone();
 831        cx.spawn(async move |this, cx| {
 832            if *ZED_STATELESS {
 833                return Ok(());
 834            }
 835            fs.create_dir(text_threads_dir()).await?;
 836
 837            let mut paths = fs.read_dir(text_threads_dir()).await?;
 838            let mut contexts = Vec::<SavedTextThreadMetadata>::new();
 839            while let Some(path) = paths.next().await {
 840                let path = path?;
 841                if path.extension() != Some(OsStr::new("json")) {
 842                    continue;
 843                }
 844
 845                static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
 846                    LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
 847
 848                let metadata = fs.metadata(&path).await?;
 849                if let Some((file_name, metadata)) = path
 850                    .file_name()
 851                    .and_then(|name| name.to_str())
 852                    .zip(metadata)
 853                {
 854                    // This is used to filter out contexts saved by the new assistant.
 855                    if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
 856                        continue;
 857                    }
 858
 859                    if let Some(title) = ASSISTANT_CONTEXT_REGEX
 860                        .replace(file_name, "")
 861                        .lines()
 862                        .next()
 863                    {
 864                        contexts.push(SavedTextThreadMetadata {
 865                            title: title.to_string().into(),
 866                            path: path.into(),
 867                            mtime: metadata.mtime.timestamp_for_user().into(),
 868                        });
 869                    }
 870                }
 871            }
 872            contexts.sort_unstable_by_key(|text_thread| Reverse(text_thread.mtime));
 873
 874            this.update(cx, |this, cx| {
 875                this.text_threads_metadata = contexts;
 876                cx.notify();
 877            })
 878        })
 879    }
 880
 881    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
 882        let Some(project) = self.project.upgrade() else {
 883            return;
 884        };
 885        let context_server_store = project.read(cx).context_server_store();
 886        cx.subscribe(&context_server_store, Self::handle_context_server_event)
 887            .detach();
 888
 889        // Check for any servers that were already running before the handler was registered
 890        for server in context_server_store.read(cx).running_servers() {
 891            self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
 892        }
 893    }
 894
 895    fn handle_context_server_event(
 896        &mut self,
 897        context_server_store: Entity<ContextServerStore>,
 898        event: &project::context_server_store::Event,
 899        cx: &mut Context<Self>,
 900    ) {
 901        match event {
 902            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
 903                match status {
 904                    ContextServerStatus::Running => {
 905                        self.load_context_server_slash_commands(
 906                            server_id.clone(),
 907                            context_server_store,
 908                            cx,
 909                        );
 910                    }
 911                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
 912                        if let Some(slash_command_ids) =
 913                            self.context_server_slash_command_ids.remove(server_id)
 914                        {
 915                            self.slash_commands.remove(&slash_command_ids);
 916                        }
 917                    }
 918                    _ => {}
 919                }
 920            }
 921        }
 922    }
 923
 924    fn load_context_server_slash_commands(
 925        &self,
 926        server_id: ContextServerId,
 927        context_server_store: Entity<ContextServerStore>,
 928        cx: &mut Context<Self>,
 929    ) {
 930        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
 931            return;
 932        };
 933        let slash_command_working_set = self.slash_commands.clone();
 934        cx.spawn(async move |this, cx| {
 935            let Some(protocol) = server.client() else {
 936                return;
 937            };
 938
 939            if protocol.capable(context_server::protocol::ServerCapability::Prompts)
 940                && let Some(response) = protocol
 941                    .request::<context_server::types::requests::PromptsList>(())
 942                    .await
 943                    .log_err()
 944            {
 945                let slash_command_ids = response
 946                    .prompts
 947                    .into_iter()
 948                    .filter(assistant_slash_commands::acceptable_prompt)
 949                    .map(|prompt| {
 950                        log::info!("registering context server command: {:?}", prompt.name);
 951                        slash_command_working_set.insert(Arc::new(
 952                            assistant_slash_commands::ContextServerSlashCommand::new(
 953                                context_server_store.clone(),
 954                                server.id(),
 955                                prompt,
 956                            ),
 957                        ))
 958                    })
 959                    .collect::<Vec<_>>();
 960
 961                this.update(cx, |this, _cx| {
 962                    this.context_server_slash_command_ids
 963                        .insert(server_id.clone(), slash_command_ids);
 964                })
 965                .log_err();
 966            }
 967        })
 968        .detach();
 969    }
 970}
 971
 972#[cfg(test)]
 973mod tests {
 974    use super::*;
 975    use fs::FakeFs;
 976    use language_model::LanguageModelRegistry;
 977    use project::Project;
 978    use serde_json::json;
 979    use settings::SettingsStore;
 980    use std::path::{Path, PathBuf};
 981    use std::sync::Arc;
 982
 983    fn init_test(cx: &mut gpui::TestAppContext) {
 984        cx.update(|cx| {
 985            let settings_store = SettingsStore::test(cx);
 986            prompt_store::init(cx);
 987            LanguageModelRegistry::test(cx);
 988            cx.set_global(settings_store);
 989        });
 990    }
 991
 992    #[gpui::test]
 993    async fn ordered_text_threads_sort_by_mtime(cx: &mut gpui::TestAppContext) {
 994        init_test(cx);
 995
 996        let fs = FakeFs::new(cx.background_executor.clone());
 997        fs.insert_tree("/root", json!({})).await;
 998
 999        let project = Project::test(fs, [Path::new("/root")], cx).await;
1000        let store = cx.new(|cx| TextThreadStore::fake(project, cx));
1001
1002        let now = chrono::Local::now();
1003        let older = SavedTextThreadMetadata {
1004            title: "older".into(),
1005            path: Arc::from(PathBuf::from("/root/older.zed.json")),
1006            mtime: now - chrono::TimeDelta::days(1),
1007        };
1008        let middle = SavedTextThreadMetadata {
1009            title: "middle".into(),
1010            path: Arc::from(PathBuf::from("/root/middle.zed.json")),
1011            mtime: now - chrono::TimeDelta::hours(1),
1012        };
1013        let newer = SavedTextThreadMetadata {
1014            title: "newer".into(),
1015            path: Arc::from(PathBuf::from("/root/newer.zed.json")),
1016            mtime: now,
1017        };
1018
1019        store.update(cx, |store, _| {
1020            store.text_threads_metadata = vec![middle, older, newer];
1021        });
1022
1023        let ordered = store.read_with(cx, |store, _| {
1024            store
1025                .ordered_text_threads()
1026                .map(|entry| entry.title.to_string())
1027                .collect::<Vec<_>>()
1028        });
1029
1030        assert_eq!(ordered, vec!["newer", "middle", "older"]);
1031    }
1032
1033    #[gpui::test]
1034    async fn has_saved_text_threads_reflects_metadata(cx: &mut gpui::TestAppContext) {
1035        init_test(cx);
1036
1037        let fs = FakeFs::new(cx.background_executor.clone());
1038        fs.insert_tree("/root", json!({})).await;
1039
1040        let project = Project::test(fs, [Path::new("/root")], cx).await;
1041        let store = cx.new(|cx| TextThreadStore::fake(project, cx));
1042
1043        assert!(!store.read_with(cx, |store, _| store.has_saved_text_threads()));
1044
1045        store.update(cx, |store, _| {
1046            store.text_threads_metadata = vec![SavedTextThreadMetadata {
1047                title: "thread".into(),
1048                path: Arc::from(PathBuf::from("/root/thread.zed.json")),
1049                mtime: chrono::Local::now(),
1050            }];
1051        });
1052
1053        assert!(store.read_with(cx, |store, _| store.has_saved_text_threads()));
1054    }
1055
1056    #[gpui::test]
1057    async fn delete_all_local_clears_metadata_and_files(cx: &mut gpui::TestAppContext) {
1058        init_test(cx);
1059
1060        let fs = FakeFs::new(cx.background_executor.clone());
1061        fs.insert_tree("/root", json!({})).await;
1062
1063        let thread_a = PathBuf::from("/root/thread-a.zed.json");
1064        let thread_b = PathBuf::from("/root/thread-b.zed.json");
1065        fs.touch_path(&thread_a).await;
1066        fs.touch_path(&thread_b).await;
1067
1068        let project = Project::test(fs.clone(), [Path::new("/root")], cx).await;
1069        let store = cx.new(|cx| TextThreadStore::fake(project, cx));
1070
1071        let now = chrono::Local::now();
1072        store.update(cx, |store, cx| {
1073            store.create(cx);
1074            store.text_threads_metadata = vec![
1075                SavedTextThreadMetadata {
1076                    title: "thread-a".into(),
1077                    path: Arc::from(thread_a.clone()),
1078                    mtime: now,
1079                },
1080                SavedTextThreadMetadata {
1081                    title: "thread-b".into(),
1082                    path: Arc::from(thread_b.clone()),
1083                    mtime: now - chrono::TimeDelta::seconds(1),
1084                },
1085            ];
1086        });
1087
1088        let task = store.update(cx, |store, cx| store.delete_all_local(cx));
1089        task.await.unwrap();
1090
1091        assert!(!store.read_with(cx, |store, _| store.has_saved_text_threads()));
1092        assert_eq!(store.read_with(cx, |store, _| store.text_threads.len()), 0);
1093        assert!(fs.metadata(&thread_a).await.unwrap().is_none());
1094        assert!(fs.metadata(&thread_b).await.unwrap().is_none());
1095    }
1096}