text_thread_store.rs

   1use crate::{
   2    SavedTextThread, SavedTextThreadMetadata, TextThread, TextThreadEvent, TextThreadId,
   3    TextThreadOperation, TextThreadVersion,
   4};
   5use anyhow::{Context as _, Result};
   6use assistant_slash_command::{SlashCommandId, SlashCommandWorkingSet};
   7use client::{Client, TypedEnvelope, proto};
   8use clock::ReplicaId;
   9use collections::HashMap;
  10use context_server::ContextServerId;
  11use fs::{Fs, RemoveOptions};
  12use futures::StreamExt;
  13use fuzzy::StringMatchCandidate;
  14use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
  15use itertools::Itertools;
  16use language::LanguageRegistry;
  17use paths::text_threads_dir;
  18use project::{
  19    Project,
  20    context_server_store::{ContextServerStatus, ContextServerStore},
  21};
  22use prompt_store::PromptBuilder;
  23use regex::Regex;
  24use rpc::AnyProtoClient;
  25use std::sync::LazyLock;
  26use std::{cmp::Reverse, ffi::OsStr, mem, path::Path, sync::Arc, time::Duration};
  27use util::{ResultExt, TryFutureExt};
  28use zed_env_vars::ZED_STATELESS;
  29
  30pub(crate) fn init(client: &AnyProtoClient) {
  31    client.add_entity_message_handler(TextThreadStore::handle_advertise_contexts);
  32    client.add_entity_request_handler(TextThreadStore::handle_open_context);
  33    client.add_entity_request_handler(TextThreadStore::handle_create_context);
  34    client.add_entity_message_handler(TextThreadStore::handle_update_context);
  35    client.add_entity_request_handler(TextThreadStore::handle_synchronize_contexts);
  36}
  37
    /// Identifier and optional summary of a text thread received through an
    /// `AdvertiseContexts` message.
  38#[derive(Clone)]
  39pub struct RemoteTextThreadMetadata {
        /// Identifier of the advertised text thread.
  40    pub id: TextThreadId,
        /// Summary text carried by the advertisement, if one was sent.
  41    pub summary: Option<String>,
  42}
  43
    /// Tracks the text threads loaded for a project, the metadata of saved text
    /// threads, and the text threads advertised by the host of a shared project.
  44pub struct TextThreadStore {
        /// Handles to the loaded text threads; strong while the project is shared,
        /// weak otherwise.
  45    text_threads: Vec<TextThreadHandle>,
        /// Metadata of the saved text threads (exposed via `ordered_text_threads`
        /// and `search`).
  46    text_threads_metadata: Vec<SavedTextThreadMetadata>,
        /// Slash command ids grouped by context server.
        /// NOTE(review): populated outside the visible part of this file — confirm usage.
  47    context_server_slash_command_ids: HashMap<ContextServerId, Vec<SlashCommandId>>,
        /// Text threads most recently advertised via `AdvertiseContexts`; cleared
        /// when disconnected from the host.
  48    host_text_threads: Vec<RemoteTextThreadMetadata>,
        /// File system handle taken from the project.
  49    fs: Arc<dyn Fs>,
        /// Language registry taken from the project; passed to every text thread.
  50    languages: Arc<LanguageRegistry>,
        /// Slash command working set passed to every text thread.
  51    slash_commands: Arc<SlashCommandWorkingSet>,
        /// Task that calls `reload` whenever the watch on the text threads
        /// directory yields an event.
  52    _watch_updates: Task<Option<()>>,
        /// The project's client, used to send and request protocol messages.
  53    client: Arc<Client>,
        /// Weak handle to the owning project.
  54    project: WeakEntity<Project>,
        /// Last value of `Project::is_shared` observed by `handle_project_shared`.
  55    project_is_shared: bool,
        /// Entity subscription on the client, held only while the project is shared.
  56    client_subscription: Option<client::Subscription>,
        /// Keeps the project event subscription alive.
  57    _project_subscriptions: Vec<gpui::Subscription>,
        /// Prompt builder passed to every text thread.
  58    prompt_builder: Arc<PromptBuilder>,
  59}
  60
  61enum TextThreadHandle {
  62    Weak(WeakEntity<TextThread>),
  63    Strong(Entity<TextThread>),
  64}
  65
  66impl TextThreadHandle {
  67    fn upgrade(&self) -> Option<Entity<TextThread>> {
  68        match self {
  69            TextThreadHandle::Weak(weak) => weak.upgrade(),
  70            TextThreadHandle::Strong(strong) => Some(strong.clone()),
  71        }
  72    }
  73
  74    fn downgrade(&self) -> WeakEntity<TextThread> {
  75        match self {
  76            TextThreadHandle::Weak(weak) => weak.clone(),
  77            TextThreadHandle::Strong(strong) => strong.downgrade(),
  78        }
  79    }
  80}
  81
  82impl TextThreadStore {
  83    pub fn new(
  84        project: Entity<Project>,
  85        prompt_builder: Arc<PromptBuilder>,
  86        slash_commands: Arc<SlashCommandWorkingSet>,
  87        cx: &mut App,
  88    ) -> Task<Result<Entity<Self>>> {
  89        let fs = project.read(cx).fs().clone();
  90        let languages = project.read(cx).languages().clone();
  91        cx.spawn(async move |cx| {
  92            const CONTEXT_WATCH_DURATION: Duration = Duration::from_millis(100);
  93            let (mut events, _) = fs.watch(text_threads_dir(), CONTEXT_WATCH_DURATION).await;
  94
  95            let this = cx.new(|cx: &mut Context<Self>| {
  96                let mut this = Self {
  97                    text_threads: Vec::new(),
  98                    text_threads_metadata: Vec::new(),
  99                    context_server_slash_command_ids: HashMap::default(),
 100                    host_text_threads: Vec::new(),
 101                    fs,
 102                    languages,
 103                    slash_commands,
 104                    _watch_updates: cx.spawn(async move |this, cx| {
 105                        async move {
 106                            while events.next().await.is_some() {
 107                                this.update(cx, |this, cx| this.reload(cx))?.await.log_err();
 108                            }
 109                            anyhow::Ok(())
 110                        }
 111                        .log_err()
 112                        .await
 113                    }),
 114                    client_subscription: None,
 115                    _project_subscriptions: vec![
 116                        cx.subscribe(&project, Self::handle_project_event),
 117                    ],
 118                    project_is_shared: false,
 119                    client: project.read(cx).client(),
 120                    project: project.downgrade(),
 121                    prompt_builder,
 122                };
 123                this.handle_project_shared(cx);
 124                this.synchronize_contexts(cx);
 125                this.register_context_server_handlers(cx);
 126                this.reload(cx).detach_and_log_err(cx);
 127                this
 128            });
 129
 130            Ok(this)
 131        })
 132    }
 133
 134    #[cfg(any(test, feature = "test-support"))]
 135    pub fn fake(project: Entity<Project>, cx: &mut Context<Self>) -> Self {
 136        Self {
 137            text_threads: Default::default(),
 138            text_threads_metadata: Default::default(),
 139            context_server_slash_command_ids: Default::default(),
 140            host_text_threads: Default::default(),
 141            fs: project.read(cx).fs().clone(),
 142            languages: project.read(cx).languages().clone(),
 143            slash_commands: Arc::default(),
 144            _watch_updates: Task::ready(None),
 145            client: project.read(cx).client(),
 146            project: project.downgrade(),
 147            project_is_shared: false,
 148            client_subscription: None,
 149            _project_subscriptions: Default::default(),
 150            prompt_builder: Arc::new(PromptBuilder::new(None).unwrap()),
 151        }
 152    }
 153
 154    async fn handle_advertise_contexts(
 155        this: Entity<Self>,
 156        envelope: TypedEnvelope<proto::AdvertiseContexts>,
 157        mut cx: AsyncApp,
 158    ) -> Result<()> {
 159        this.update(&mut cx, |this, cx| {
 160            this.host_text_threads = envelope
 161                .payload
 162                .contexts
 163                .into_iter()
 164                .map(|text_thread| RemoteTextThreadMetadata {
 165                    id: TextThreadId::from_proto(text_thread.context_id),
 166                    summary: text_thread.summary,
 167                })
 168                .collect();
 169            cx.notify();
 170        });
 171        Ok(())
 172    }
 173
 174    async fn handle_open_context(
 175        this: Entity<Self>,
 176        envelope: TypedEnvelope<proto::OpenContext>,
 177        mut cx: AsyncApp,
 178    ) -> Result<proto::OpenContextResponse> {
 179        let context_id = TextThreadId::from_proto(envelope.payload.context_id);
 180        let operations = this.update(&mut cx, |this, cx| {
 181            let project = this.project.upgrade().context("project not found")?;
 182
 183            anyhow::ensure!(
 184                !project.read(cx).is_via_collab(),
 185                "only the host contexts can be opened"
 186            );
 187
 188            let text_thread = this
 189                .loaded_text_thread_for_id(&context_id, cx)
 190                .context("context not found")?;
 191            anyhow::ensure!(
 192                text_thread.read(cx).replica_id() == ReplicaId::default(),
 193                "context must be opened via the host"
 194            );
 195
 196            anyhow::Ok(
 197                text_thread
 198                    .read(cx)
 199                    .serialize_ops(&TextThreadVersion::default(), cx),
 200            )
 201        })?;
 202        let operations = operations.await;
 203        Ok(proto::OpenContextResponse {
 204            context: Some(proto::Context { operations }),
 205        })
 206    }
 207
 208    async fn handle_create_context(
 209        this: Entity<Self>,
 210        _: TypedEnvelope<proto::CreateContext>,
 211        mut cx: AsyncApp,
 212    ) -> Result<proto::CreateContextResponse> {
 213        let (context_id, operations) = this.update(&mut cx, |this, cx| {
 214            let project = this.project.upgrade().context("project not found")?;
 215            anyhow::ensure!(
 216                !project.read(cx).is_via_collab(),
 217                "can only create contexts as the host"
 218            );
 219
 220            let text_thread = this.create(cx);
 221            let context_id = text_thread.read(cx).id().clone();
 222
 223            anyhow::Ok((
 224                context_id,
 225                text_thread
 226                    .read(cx)
 227                    .serialize_ops(&TextThreadVersion::default(), cx),
 228            ))
 229        })?;
 230        let operations = operations.await;
 231        Ok(proto::CreateContextResponse {
 232            context_id: context_id.to_proto(),
 233            context: Some(proto::Context { operations }),
 234        })
 235    }
 236
 237    async fn handle_update_context(
 238        this: Entity<Self>,
 239        envelope: TypedEnvelope<proto::UpdateContext>,
 240        mut cx: AsyncApp,
 241    ) -> Result<()> {
 242        this.update(&mut cx, |this, cx| {
 243            let context_id = TextThreadId::from_proto(envelope.payload.context_id);
 244            if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
 245                let operation_proto = envelope.payload.operation.context("invalid operation")?;
 246                let operation = TextThreadOperation::from_proto(operation_proto)?;
 247                text_thread.update(cx, |text_thread, cx| text_thread.apply_ops([operation], cx));
 248            }
 249            Ok(())
 250        })
 251    }
 252
 253    async fn handle_synchronize_contexts(
 254        this: Entity<Self>,
 255        envelope: TypedEnvelope<proto::SynchronizeContexts>,
 256        mut cx: AsyncApp,
 257    ) -> Result<proto::SynchronizeContextsResponse> {
 258        this.update(&mut cx, |this, cx| {
 259            let project = this.project.upgrade().context("project not found")?;
 260            anyhow::ensure!(
 261                !project.read(cx).is_via_collab(),
 262                "only the host can synchronize contexts"
 263            );
 264
 265            let mut local_versions = Vec::new();
 266            for remote_version_proto in envelope.payload.contexts {
 267                let remote_version = TextThreadVersion::from_proto(&remote_version_proto);
 268                let context_id = TextThreadId::from_proto(remote_version_proto.context_id);
 269                if let Some(text_thread) = this.loaded_text_thread_for_id(&context_id, cx) {
 270                    let text_thread = text_thread.read(cx);
 271                    let operations = text_thread.serialize_ops(&remote_version, cx);
 272                    local_versions.push(text_thread.version(cx).to_proto(context_id.clone()));
 273                    let client = this.client.clone();
 274                    let project_id = envelope.payload.project_id;
 275                    cx.background_spawn(async move {
 276                        let operations = operations.await;
 277                        for operation in operations {
 278                            client.send(proto::UpdateContext {
 279                                project_id,
 280                                context_id: context_id.to_proto(),
 281                                operation: Some(operation),
 282                            })?;
 283                        }
 284                        anyhow::Ok(())
 285                    })
 286                    .detach_and_log_err(cx);
 287                }
 288            }
 289
 290            this.advertise_contexts(cx);
 291
 292            anyhow::Ok(proto::SynchronizeContextsResponse {
 293                contexts: local_versions,
 294            })
 295        })
 296    }
 297
 298    fn handle_project_shared(&mut self, cx: &mut Context<Self>) {
 299        let Some(project) = self.project.upgrade() else {
 300            return;
 301        };
 302
 303        let is_shared = project.read(cx).is_shared();
 304        let was_shared = mem::replace(&mut self.project_is_shared, is_shared);
 305        if is_shared == was_shared {
 306            return;
 307        }
 308
 309        if is_shared {
 310            self.text_threads.retain_mut(|text_thread| {
 311                if let Some(strong_context) = text_thread.upgrade() {
 312                    *text_thread = TextThreadHandle::Strong(strong_context);
 313                    true
 314                } else {
 315                    false
 316                }
 317            });
 318            let remote_id = project.read(cx).remote_id().unwrap();
 319            self.client_subscription = self
 320                .client
 321                .subscribe_to_entity(remote_id)
 322                .log_err()
 323                .map(|subscription| subscription.set_entity(&cx.entity(), &cx.to_async()));
 324            self.advertise_contexts(cx);
 325        } else {
 326            self.client_subscription = None;
 327        }
 328    }
 329
 330    fn handle_project_event(
 331        &mut self,
 332        _project: Entity<Project>,
 333        event: &project::Event,
 334        cx: &mut Context<Self>,
 335    ) {
 336        match event {
 337            project::Event::RemoteIdChanged(_) => {
 338                self.handle_project_shared(cx);
 339            }
 340            project::Event::Reshared => {
 341                self.advertise_contexts(cx);
 342            }
 343            project::Event::HostReshared | project::Event::Rejoined => {
 344                self.synchronize_contexts(cx);
 345            }
 346            project::Event::DisconnectedFromHost => {
 347                self.text_threads.retain_mut(|text_thread| {
 348                    if let Some(strong_context) = text_thread.upgrade() {
 349                        *text_thread = TextThreadHandle::Weak(text_thread.downgrade());
 350                        strong_context.update(cx, |text_thread, cx| {
 351                            if text_thread.replica_id() != ReplicaId::default() {
 352                                text_thread.set_capability(language::Capability::ReadOnly, cx);
 353                            }
 354                        });
 355                        true
 356                    } else {
 357                        false
 358                    }
 359                });
 360                self.host_text_threads.clear();
 361                cx.notify();
 362            }
 363            _ => {}
 364        }
 365    }
 366
 367    /// Returns saved threads ordered by `mtime` descending (newest first).
 368    pub fn ordered_text_threads(&self) -> impl Iterator<Item = &SavedTextThreadMetadata> {
 369        self.text_threads_metadata
 370            .iter()
 371            .sorted_by(|a, b| b.mtime.cmp(&a.mtime))
 372    }
 373
 374    pub fn has_saved_text_threads(&self) -> bool {
 375        !self.text_threads_metadata.is_empty()
 376    }
 377
 378    pub fn host_text_threads(&self) -> impl Iterator<Item = &RemoteTextThreadMetadata> {
 379        self.host_text_threads.iter()
 380    }
 381
 382    pub fn create(&mut self, cx: &mut Context<Self>) -> Entity<TextThread> {
 383        let context = cx.new(|cx| {
 384            TextThread::local(
 385                self.languages.clone(),
 386                self.prompt_builder.clone(),
 387                self.slash_commands.clone(),
 388                cx,
 389            )
 390        });
 391        self.register_text_thread(&context, cx);
 392        context
 393    }
 394
 395    pub fn create_remote(&mut self, cx: &mut Context<Self>) -> Task<Result<Entity<TextThread>>> {
 396        let Some(project) = self.project.upgrade() else {
 397            return Task::ready(Err(anyhow::anyhow!("project was dropped")));
 398        };
 399        let project = project.read(cx);
 400        let Some(project_id) = project.remote_id() else {
 401            return Task::ready(Err(anyhow::anyhow!("project was not remote")));
 402        };
 403
 404        let replica_id = project.replica_id();
 405        let capability = project.capability();
 406        let language_registry = self.languages.clone();
 407
 408        let prompt_builder = self.prompt_builder.clone();
 409        let slash_commands = self.slash_commands.clone();
 410        let request = self.client.request(proto::CreateContext { project_id });
 411        cx.spawn(async move |this, cx| {
 412            let response = request.await?;
 413            let context_id = TextThreadId::from_proto(response.context_id);
 414            let context_proto = response.context.context("invalid context")?;
 415            let text_thread = cx.new(|cx| {
 416                TextThread::new(
 417                    context_id.clone(),
 418                    replica_id,
 419                    capability,
 420                    language_registry,
 421                    prompt_builder,
 422                    slash_commands,
 423                    cx,
 424                )
 425            });
 426            let operations = cx
 427                .background_spawn(async move {
 428                    context_proto
 429                        .operations
 430                        .into_iter()
 431                        .map(TextThreadOperation::from_proto)
 432                        .collect::<Result<Vec<_>>>()
 433                })
 434                .await?;
 435            text_thread.update(cx, |context, cx| context.apply_ops(operations, cx));
 436            this.update(cx, |this, cx| {
 437                if let Some(existing_context) = this.loaded_text_thread_for_id(&context_id, cx) {
 438                    existing_context
 439                } else {
 440                    this.register_text_thread(&text_thread, cx);
 441                    this.synchronize_contexts(cx);
 442                    text_thread
 443                }
 444            })
 445        })
 446    }
 447
 448    pub fn open_local(
 449        &mut self,
 450        path: Arc<Path>,
 451        cx: &Context<Self>,
 452    ) -> Task<Result<Entity<TextThread>>> {
 453        if let Some(existing_context) = self.loaded_text_thread_for_path(&path, cx) {
 454            return Task::ready(Ok(existing_context));
 455        }
 456
 457        let fs = self.fs.clone();
 458        let languages = self.languages.clone();
 459        let load = cx.background_spawn({
 460            let path = path.clone();
 461            async move {
 462                let saved_context = fs.load(&path).await?;
 463                SavedTextThread::from_json(&saved_context)
 464            }
 465        });
 466        let prompt_builder = self.prompt_builder.clone();
 467        let slash_commands = self.slash_commands.clone();
 468
 469        cx.spawn(async move |this, cx| {
 470            let saved_context = load.await?;
 471            let context = cx.new(|cx| {
 472                TextThread::deserialize(
 473                    saved_context,
 474                    path.clone(),
 475                    languages,
 476                    prompt_builder,
 477                    slash_commands,
 478                    cx,
 479                )
 480            });
 481            this.update(cx, |this, cx| {
 482                if let Some(existing_context) = this.loaded_text_thread_for_path(&path, cx) {
 483                    existing_context
 484                } else {
 485                    this.register_text_thread(&context, cx);
 486                    context
 487                }
 488            })
 489        })
 490    }
 491
 492    pub fn delete_local(&mut self, path: Arc<Path>, cx: &mut Context<Self>) -> Task<Result<()>> {
 493        let fs = self.fs.clone();
 494
 495        cx.spawn(async move |this, cx| {
 496            fs.remove_file(
 497                &path,
 498                RemoveOptions {
 499                    recursive: false,
 500                    ignore_if_not_exists: true,
 501                },
 502            )
 503            .await?;
 504
 505            this.update(cx, |this, cx| {
 506                this.text_threads.retain(|text_thread| {
 507                    text_thread
 508                        .upgrade()
 509                        .and_then(|text_thread| text_thread.read(cx).path())
 510                        != Some(&path)
 511                });
 512                this.text_threads_metadata
 513                    .retain(|text_thread| text_thread.path.as_ref() != path.as_ref());
 514            })?;
 515
 516            Ok(())
 517        })
 518    }
 519
 520    pub fn delete_all_local(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 521        let fs = self.fs.clone();
 522        let paths = self
 523            .text_threads_metadata
 524            .iter()
 525            .map(|metadata| metadata.path.clone())
 526            .collect::<Vec<_>>();
 527
 528        cx.spawn(async move |this, cx| {
 529            for path in paths {
 530                fs.remove_file(
 531                    &path,
 532                    RemoveOptions {
 533                        recursive: false,
 534                        ignore_if_not_exists: true,
 535                    },
 536                )
 537                .await?;
 538            }
 539
 540            this.update(cx, |this, cx| {
 541                this.text_threads.clear();
 542                this.text_threads_metadata.clear();
 543                cx.notify();
 544            })?;
 545
 546            Ok(())
 547        })
 548    }
 549
 550    fn loaded_text_thread_for_path(&self, path: &Path, cx: &App) -> Option<Entity<TextThread>> {
 551        self.text_threads.iter().find_map(|text_thread| {
 552            let text_thread = text_thread.upgrade()?;
 553            if text_thread.read(cx).path().map(Arc::as_ref) == Some(path) {
 554                Some(text_thread)
 555            } else {
 556                None
 557            }
 558        })
 559    }
 560
 561    pub fn loaded_text_thread_for_id(
 562        &self,
 563        id: &TextThreadId,
 564        cx: &App,
 565    ) -> Option<Entity<TextThread>> {
 566        self.text_threads.iter().find_map(|text_thread| {
 567            let text_thread = text_thread.upgrade()?;
 568            if text_thread.read(cx).id() == id {
 569                Some(text_thread)
 570            } else {
 571                None
 572            }
 573        })
 574    }
 575
 576    pub fn open_remote(
 577        &mut self,
 578        text_thread_id: TextThreadId,
 579        cx: &mut Context<Self>,
 580    ) -> Task<Result<Entity<TextThread>>> {
 581        let Some(project) = self.project.upgrade() else {
 582            return Task::ready(Err(anyhow::anyhow!("project was dropped")));
 583        };
 584        let project = project.read(cx);
 585        let Some(project_id) = project.remote_id() else {
 586            return Task::ready(Err(anyhow::anyhow!("project was not remote")));
 587        };
 588
 589        if let Some(context) = self.loaded_text_thread_for_id(&text_thread_id, cx) {
 590            return Task::ready(Ok(context));
 591        }
 592
 593        let replica_id = project.replica_id();
 594        let capability = project.capability();
 595        let language_registry = self.languages.clone();
 596        let request = self.client.request(proto::OpenContext {
 597            project_id,
 598            context_id: text_thread_id.to_proto(),
 599        });
 600        let prompt_builder = self.prompt_builder.clone();
 601        let slash_commands = self.slash_commands.clone();
 602        cx.spawn(async move |this, cx| {
 603            let response = request.await?;
 604            let context_proto = response.context.context("invalid context")?;
 605            let text_thread = cx.new(|cx| {
 606                TextThread::new(
 607                    text_thread_id.clone(),
 608                    replica_id,
 609                    capability,
 610                    language_registry,
 611                    prompt_builder,
 612                    slash_commands,
 613                    cx,
 614                )
 615            });
 616            let operations = cx
 617                .background_spawn(async move {
 618                    context_proto
 619                        .operations
 620                        .into_iter()
 621                        .map(TextThreadOperation::from_proto)
 622                        .collect::<Result<Vec<_>>>()
 623                })
 624                .await?;
 625            text_thread.update(cx, |context, cx| context.apply_ops(operations, cx));
 626            this.update(cx, |this, cx| {
 627                if let Some(existing_context) = this.loaded_text_thread_for_id(&text_thread_id, cx)
 628                {
 629                    existing_context
 630                } else {
 631                    this.register_text_thread(&text_thread, cx);
 632                    this.synchronize_contexts(cx);
 633                    text_thread
 634                }
 635            })
 636        })
 637    }
 638
 639    fn register_text_thread(&mut self, text_thread: &Entity<TextThread>, cx: &mut Context<Self>) {
 640        let handle = if self.project_is_shared {
 641            TextThreadHandle::Strong(text_thread.clone())
 642        } else {
 643            TextThreadHandle::Weak(text_thread.downgrade())
 644        };
 645        self.text_threads.push(handle);
 646        self.advertise_contexts(cx);
 647        cx.subscribe(text_thread, Self::handle_context_event)
 648            .detach();
 649    }
 650
 651    fn handle_context_event(
 652        &mut self,
 653        text_thread: Entity<TextThread>,
 654        event: &TextThreadEvent,
 655        cx: &mut Context<Self>,
 656    ) {
 657        let Some(project) = self.project.upgrade() else {
 658            return;
 659        };
 660        let Some(project_id) = project.read(cx).remote_id() else {
 661            return;
 662        };
 663
 664        match event {
 665            TextThreadEvent::SummaryChanged => {
 666                self.advertise_contexts(cx);
 667            }
 668            TextThreadEvent::PathChanged { old_path, new_path } => {
 669                if let Some(old_path) = old_path.as_ref() {
 670                    for metadata in &mut self.text_threads_metadata {
 671                        if &metadata.path == old_path {
 672                            metadata.path = new_path.clone();
 673                            break;
 674                        }
 675                    }
 676                }
 677            }
 678            TextThreadEvent::Operation(operation) => {
 679                let context_id = text_thread.read(cx).id().to_proto();
 680                let operation = operation.to_proto();
 681                self.client
 682                    .send(proto::UpdateContext {
 683                        project_id,
 684                        context_id,
 685                        operation: Some(operation),
 686                    })
 687                    .log_err();
 688            }
 689            _ => {}
 690        }
 691    }
 692
 693    fn advertise_contexts(&self, cx: &App) {
 694        let Some(project) = self.project.upgrade() else {
 695            return;
 696        };
 697        let Some(project_id) = project.read(cx).remote_id() else {
 698            return;
 699        };
 700        // For now, only the host can advertise their open contexts.
 701        if project.read(cx).is_via_collab() {
 702            return;
 703        }
 704
 705        let contexts = self
 706            .text_threads
 707            .iter()
 708            .rev()
 709            .filter_map(|text_thread| {
 710                let text_thread = text_thread.upgrade()?.read(cx);
 711                if text_thread.replica_id() == ReplicaId::default() {
 712                    Some(proto::ContextMetadata {
 713                        context_id: text_thread.id().to_proto(),
 714                        summary: text_thread
 715                            .summary()
 716                            .content()
 717                            .map(|summary| summary.text.clone()),
 718                    })
 719                } else {
 720                    None
 721                }
 722            })
 723            .collect();
 724        self.client
 725            .send(proto::AdvertiseContexts {
 726                project_id,
 727                contexts,
 728            })
 729            .ok();
 730    }
 731
 732    fn synchronize_contexts(&mut self, cx: &mut Context<Self>) {
 733        let Some(project) = self.project.upgrade() else {
 734            return;
 735        };
 736        let Some(project_id) = project.read(cx).remote_id() else {
 737            return;
 738        };
 739
 740        let text_threads = self
 741            .text_threads
 742            .iter()
 743            .filter_map(|text_thread| {
 744                let text_thread = text_thread.upgrade()?.read(cx);
 745                if text_thread.replica_id() != ReplicaId::default() {
 746                    Some(text_thread.version(cx).to_proto(text_thread.id().clone()))
 747                } else {
 748                    None
 749                }
 750            })
 751            .collect();
 752
 753        let client = self.client.clone();
 754        let request = self.client.request(proto::SynchronizeContexts {
 755            project_id,
 756            contexts: text_threads,
 757        });
 758        cx.spawn(async move |this, cx| {
 759            let response = request.await?;
 760
 761            let mut text_thread_ids = Vec::new();
 762            let mut operations = Vec::new();
 763            this.read_with(cx, |this, cx| {
 764                for context_version_proto in response.contexts {
 765                    let text_thread_version = TextThreadVersion::from_proto(&context_version_proto);
 766                    let text_thread_id = TextThreadId::from_proto(context_version_proto.context_id);
 767                    if let Some(text_thread) = this.loaded_text_thread_for_id(&text_thread_id, cx) {
 768                        text_thread_ids.push(text_thread_id);
 769                        operations
 770                            .push(text_thread.read(cx).serialize_ops(&text_thread_version, cx));
 771                    }
 772                }
 773            })?;
 774
 775            let operations = futures::future::join_all(operations).await;
 776            for (context_id, operations) in text_thread_ids.into_iter().zip(operations) {
 777                for operation in operations {
 778                    client.send(proto::UpdateContext {
 779                        project_id,
 780                        context_id: context_id.to_proto(),
 781                        operation: Some(operation),
 782                    })?;
 783                }
 784            }
 785
 786            anyhow::Ok(())
 787        })
 788        .detach_and_log_err(cx);
 789    }
 790
 791    pub fn search(&self, query: String, cx: &App) -> Task<Vec<SavedTextThreadMetadata>> {
 792        let metadata = self.text_threads_metadata.clone();
 793        let executor = cx.background_executor().clone();
 794        cx.background_spawn(async move {
 795            if query.is_empty() {
 796                metadata
 797            } else {
 798                let candidates = metadata
 799                    .iter()
 800                    .enumerate()
 801                    .map(|(id, metadata)| StringMatchCandidate::new(id, &metadata.title))
 802                    .collect::<Vec<_>>();
 803                let matches = fuzzy::match_strings(
 804                    &candidates,
 805                    &query,
 806                    false,
 807                    true,
 808                    100,
 809                    &Default::default(),
 810                    executor,
 811                )
 812                .await;
 813
 814                matches
 815                    .into_iter()
 816                    .map(|mat| metadata[mat.candidate_id].clone())
 817                    .collect()
 818            }
 819        })
 820    }
 821
 822    fn reload(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 823        let fs = self.fs.clone();
 824        cx.spawn(async move |this, cx| {
 825            if *ZED_STATELESS {
 826                return Ok(());
 827            }
 828            fs.create_dir(text_threads_dir()).await?;
 829
 830            let mut paths = fs.read_dir(text_threads_dir()).await?;
 831            let mut contexts = Vec::<SavedTextThreadMetadata>::new();
 832            while let Some(path) = paths.next().await {
 833                let path = path?;
 834                if path.extension() != Some(OsStr::new("json")) {
 835                    continue;
 836                }
 837
 838                static ASSISTANT_CONTEXT_REGEX: LazyLock<Regex> =
 839                    LazyLock::new(|| Regex::new(r" - \d+.zed.json$").unwrap());
 840
 841                let metadata = fs.metadata(&path).await?;
 842                if let Some((file_name, metadata)) = path
 843                    .file_name()
 844                    .and_then(|name| name.to_str())
 845                    .zip(metadata)
 846                {
 847                    // This is used to filter out contexts saved by the new assistant.
 848                    if !ASSISTANT_CONTEXT_REGEX.is_match(file_name) {
 849                        continue;
 850                    }
 851
 852                    if let Some(title) = ASSISTANT_CONTEXT_REGEX
 853                        .replace(file_name, "")
 854                        .lines()
 855                        .next()
 856                    {
 857                        contexts.push(SavedTextThreadMetadata {
 858                            title: title.to_string().into(),
 859                            path: path.into(),
 860                            mtime: metadata.mtime.timestamp_for_user().into(),
 861                        });
 862                    }
 863                }
 864            }
 865            contexts.sort_unstable_by_key(|text_thread| Reverse(text_thread.mtime));
 866
 867            this.update(cx, |this, cx| {
 868                this.text_threads_metadata = contexts;
 869                cx.notify();
 870            })
 871        })
 872    }
 873
 874    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
 875        let Some(project) = self.project.upgrade() else {
 876            return;
 877        };
 878        let context_server_store = project.read(cx).context_server_store();
 879        cx.subscribe(&context_server_store, Self::handle_context_server_event)
 880            .detach();
 881
 882        // Check for any servers that were already running before the handler was registered
 883        for server in context_server_store.read(cx).running_servers() {
 884            self.load_context_server_slash_commands(server.id(), context_server_store.clone(), cx);
 885        }
 886    }
 887
 888    fn handle_context_server_event(
 889        &mut self,
 890        context_server_store: Entity<ContextServerStore>,
 891        event: &project::context_server_store::Event,
 892        cx: &mut Context<Self>,
 893    ) {
 894        match event {
 895            project::context_server_store::Event::ServerStatusChanged { server_id, status } => {
 896                match status {
 897                    ContextServerStatus::Running => {
 898                        self.load_context_server_slash_commands(
 899                            server_id.clone(),
 900                            context_server_store,
 901                            cx,
 902                        );
 903                    }
 904                    ContextServerStatus::Stopped | ContextServerStatus::Error(_) => {
 905                        if let Some(slash_command_ids) =
 906                            self.context_server_slash_command_ids.remove(server_id)
 907                        {
 908                            self.slash_commands.remove(&slash_command_ids);
 909                        }
 910                    }
 911                    _ => {}
 912                }
 913            }
 914        }
 915    }
 916
 917    fn load_context_server_slash_commands(
 918        &self,
 919        server_id: ContextServerId,
 920        context_server_store: Entity<ContextServerStore>,
 921        cx: &mut Context<Self>,
 922    ) {
 923        let Some(server) = context_server_store.read(cx).get_running_server(&server_id) else {
 924            return;
 925        };
 926        let slash_command_working_set = self.slash_commands.clone();
 927        cx.spawn(async move |this, cx| {
 928            let Some(protocol) = server.client() else {
 929                return;
 930            };
 931
 932            if protocol.capable(context_server::protocol::ServerCapability::Prompts)
 933                && let Some(response) = protocol
 934                    .request::<context_server::types::requests::PromptsList>(())
 935                    .await
 936                    .log_err()
 937            {
 938                let slash_command_ids = response
 939                    .prompts
 940                    .into_iter()
 941                    .filter(assistant_slash_commands::acceptable_prompt)
 942                    .map(|prompt| {
 943                        log::info!("registering context server command: {:?}", prompt.name);
 944                        slash_command_working_set.insert(Arc::new(
 945                            assistant_slash_commands::ContextServerSlashCommand::new(
 946                                context_server_store.clone(),
 947                                server.id(),
 948                                prompt,
 949                            ),
 950                        ))
 951                    })
 952                    .collect::<Vec<_>>();
 953
 954                this.update(cx, |this, _cx| {
 955                    this.context_server_slash_command_ids
 956                        .insert(server_id.clone(), slash_command_ids);
 957                })
 958                .log_err();
 959            }
 960        })
 961        .detach();
 962    }
 963}
 964
 965#[cfg(test)]
 966mod tests {
 967    use super::*;
 968    use fs::FakeFs;
 969    use language_model::LanguageModelRegistry;
 970    use project::Project;
 971    use serde_json::json;
 972    use settings::SettingsStore;
 973    use std::path::{Path, PathBuf};
 974    use std::sync::Arc;
 975
 976    fn init_test(cx: &mut gpui::TestAppContext) {
 977        cx.update(|cx| {
 978            let settings_store = SettingsStore::test(cx);
 979            prompt_store::init(cx);
 980            LanguageModelRegistry::test(cx);
 981            cx.set_global(settings_store);
 982        });
 983    }
 984
 985    #[gpui::test]
 986    async fn ordered_text_threads_sort_by_mtime(cx: &mut gpui::TestAppContext) {
 987        init_test(cx);
 988
 989        let fs = FakeFs::new(cx.background_executor.clone());
 990        fs.insert_tree("/root", json!({})).await;
 991
 992        let project = Project::test(fs, [Path::new("/root")], cx).await;
 993        let store = cx.new(|cx| TextThreadStore::fake(project, cx));
 994
 995        let now = chrono::Local::now();
 996        let older = SavedTextThreadMetadata {
 997            title: "older".into(),
 998            path: Arc::from(PathBuf::from("/root/older.zed.json")),
 999            mtime: now - chrono::TimeDelta::days(1),
1000        };
1001        let middle = SavedTextThreadMetadata {
1002            title: "middle".into(),
1003            path: Arc::from(PathBuf::from("/root/middle.zed.json")),
1004            mtime: now - chrono::TimeDelta::hours(1),
1005        };
1006        let newer = SavedTextThreadMetadata {
1007            title: "newer".into(),
1008            path: Arc::from(PathBuf::from("/root/newer.zed.json")),
1009            mtime: now,
1010        };
1011
1012        store.update(cx, |store, _| {
1013            store.text_threads_metadata = vec![middle, older, newer];
1014        });
1015
1016        let ordered = store.read_with(cx, |store, _| {
1017            store
1018                .ordered_text_threads()
1019                .map(|entry| entry.title.to_string())
1020                .collect::<Vec<_>>()
1021        });
1022
1023        assert_eq!(ordered, vec!["newer", "middle", "older"]);
1024    }
1025
1026    #[gpui::test]
1027    async fn has_saved_text_threads_reflects_metadata(cx: &mut gpui::TestAppContext) {
1028        init_test(cx);
1029
1030        let fs = FakeFs::new(cx.background_executor.clone());
1031        fs.insert_tree("/root", json!({})).await;
1032
1033        let project = Project::test(fs, [Path::new("/root")], cx).await;
1034        let store = cx.new(|cx| TextThreadStore::fake(project, cx));
1035
1036        assert!(!store.read_with(cx, |store, _| store.has_saved_text_threads()));
1037
1038        store.update(cx, |store, _| {
1039            store.text_threads_metadata = vec![SavedTextThreadMetadata {
1040                title: "thread".into(),
1041                path: Arc::from(PathBuf::from("/root/thread.zed.json")),
1042                mtime: chrono::Local::now(),
1043            }];
1044        });
1045
1046        assert!(store.read_with(cx, |store, _| store.has_saved_text_threads()));
1047    }
1048
1049    #[gpui::test]
1050    async fn delete_all_local_clears_metadata_and_files(cx: &mut gpui::TestAppContext) {
1051        init_test(cx);
1052
1053        let fs = FakeFs::new(cx.background_executor.clone());
1054        fs.insert_tree("/root", json!({})).await;
1055
1056        let thread_a = PathBuf::from("/root/thread-a.zed.json");
1057        let thread_b = PathBuf::from("/root/thread-b.zed.json");
1058        fs.touch_path(&thread_a).await;
1059        fs.touch_path(&thread_b).await;
1060
1061        let project = Project::test(fs.clone(), [Path::new("/root")], cx).await;
1062        let store = cx.new(|cx| TextThreadStore::fake(project, cx));
1063
1064        let now = chrono::Local::now();
1065        store.update(cx, |store, cx| {
1066            store.create(cx);
1067            store.text_threads_metadata = vec![
1068                SavedTextThreadMetadata {
1069                    title: "thread-a".into(),
1070                    path: Arc::from(thread_a.clone()),
1071                    mtime: now,
1072                },
1073                SavedTextThreadMetadata {
1074                    title: "thread-b".into(),
1075                    path: Arc::from(thread_b.clone()),
1076                    mtime: now - chrono::TimeDelta::seconds(1),
1077                },
1078            ];
1079        });
1080
1081        let task = store.update(cx, |store, cx| store.delete_all_local(cx));
1082        task.await.unwrap();
1083
1084        assert!(!store.read_with(cx, |store, _| store.has_saved_text_threads()));
1085        assert_eq!(store.read_with(cx, |store, _| store.text_threads.len()), 0);
1086        assert!(fs.metadata(&thread_a).await.unwrap().is_none());
1087        assert!(fs.metadata(&thread_b).await.unwrap().is_none());
1088    }
1089}