context_store.rs

  1use crate::{
  2    MessageId, ThreadId,
  3    context::{
  4        AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle,
  5        FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle,
  6        SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
  7    },
  8    thread::Thread,
  9    thread_store::ThreadStore,
 10};
 11use anyhow::{Context as _, Result, anyhow};
 12use assistant_context::AssistantContext;
 13use collections::{HashSet, IndexSet};
 14use futures::{self, FutureExt};
 15use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
 16use language::{Buffer, File as _};
 17use language_model::LanguageModelImage;
 18use project::{Project, ProjectItem, ProjectPath, Symbol, image_store::is_image_file};
 19use prompt_store::UserPromptId;
 20use ref_cast::RefCast as _;
 21use std::{
 22    ops::Range,
 23    path::{Path, PathBuf},
 24    sync::Arc,
 25};
 26use text::{Anchor, OffsetRangeExt};
 27
/// Ordered collection of agent context entries, plus lookup sets that make
/// thread / text-thread membership checks cheap.
pub struct ContextStore {
    /// Project the context belongs to; held weakly and upgraded on demand.
    project: WeakEntity<Project>,
    /// Thread store handed to threads when they are added as context. Without it,
    /// thread context cannot be inserted (see `insert_context`).
    thread_store: Option<WeakEntity<ThreadStore>>,
    /// Source of `ContextId`s; advanced with `post_inc` for every handle built.
    next_context_id: ContextId,
    /// Context entries in insertion order, deduplicated by `AgentContextKey`.
    context_set: IndexSet<AgentContextKey>,
    /// Ids of threads currently in `context_set`, maintained by insert/remove.
    context_thread_ids: HashSet<ThreadId>,
    /// Paths of text threads currently in `context_set`, maintained by insert/remove.
    context_text_thread_paths: HashSet<Arc<Path>>,
}
 36
/// Events emitted by a [`ContextStore`].
pub enum ContextStoreEvent {
    /// Emitted by `remove_context` with the key that was taken out of the store.
    /// (`clear` does not emit this event.)
    ContextRemoved(AgentContextKey),
}

// Lets observers subscribe to `ContextStoreEvent`s emitted by the store.
impl EventEmitter<ContextStoreEvent> for ContextStore {}
 42
 43impl ContextStore {
 44    pub fn new(
 45        project: WeakEntity<Project>,
 46        thread_store: Option<WeakEntity<ThreadStore>>,
 47    ) -> Self {
 48        Self {
 49            project,
 50            thread_store,
 51            next_context_id: ContextId::zero(),
 52            context_set: IndexSet::default(),
 53            context_thread_ids: HashSet::default(),
 54            context_text_thread_paths: HashSet::default(),
 55        }
 56    }
 57
 58    pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
 59        self.context_set.iter().map(|entry| entry.as_ref())
 60    }
 61
 62    pub fn clear(&mut self, cx: &mut Context<Self>) {
 63        self.context_set.clear();
 64        self.context_thread_ids.clear();
 65        cx.notify();
 66    }
 67
 68    pub fn new_context_for_thread(
 69        &self,
 70        thread: &Thread,
 71        exclude_messages_from_id: Option<MessageId>,
 72    ) -> Vec<AgentContextHandle> {
 73        let existing_context = thread
 74            .messages()
 75            .iter()
 76            .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id() != id))
 77            .flat_map(|message| {
 78                message
 79                    .loaded_context
 80                    .contexts
 81                    .iter()
 82                    .map(|context| AgentContextKey(context.handle()))
 83            })
 84            .collect::<HashSet<_>>();
 85        self.context_set
 86            .iter()
 87            .filter(|context| !existing_context.contains(context))
 88            .map(|entry| entry.0.clone())
 89            .collect::<Vec<_>>()
 90    }
 91
 92    pub fn add_file_from_path(
 93        &mut self,
 94        project_path: ProjectPath,
 95        remove_if_exists: bool,
 96        cx: &mut Context<Self>,
 97    ) -> Task<Result<Option<AgentContextHandle>>> {
 98        let Some(project) = self.project.upgrade() else {
 99            return Task::ready(Err(anyhow!("failed to read project")));
100        };
101
102        if is_image_file(&project, &project_path, cx) {
103            self.add_image_from_path(project_path, remove_if_exists, cx)
104        } else {
105            cx.spawn(async move |this, cx| {
106                let open_buffer_task = project.update(cx, |project, cx| {
107                    project.open_buffer(project_path.clone(), cx)
108                })?;
109                let buffer = open_buffer_task.await?;
110                this.update(cx, |this, cx| {
111                    this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
112                })
113            })
114        }
115    }
116
117    pub fn add_file_from_buffer(
118        &mut self,
119        project_path: &ProjectPath,
120        buffer: Entity<Buffer>,
121        remove_if_exists: bool,
122        cx: &mut Context<Self>,
123    ) -> Option<AgentContextHandle> {
124        let context_id = self.next_context_id.post_inc();
125        let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
126
127        if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
128            if remove_if_exists {
129                self.remove_context(&context, cx);
130                None
131            } else {
132                Some(key.as_ref().clone())
133            }
134        } else if self.path_included_in_directory(project_path, cx).is_some() {
135            None
136        } else {
137            self.insert_context(context.clone(), cx);
138            Some(context)
139        }
140    }
141
142    pub fn add_directory(
143        &mut self,
144        project_path: &ProjectPath,
145        remove_if_exists: bool,
146        cx: &mut Context<Self>,
147    ) -> Result<Option<AgentContextHandle>> {
148        let project = self.project.upgrade().context("failed to read project")?;
149        let entry_id = project
150            .read(cx)
151            .entry_for_path(project_path, cx)
152            .map(|entry| entry.id)
153            .context("no entry found for directory context")?;
154
155        let context_id = self.next_context_id.post_inc();
156        let context = AgentContextHandle::Directory(DirectoryContextHandle {
157            entry_id,
158            context_id,
159        });
160
161        let context =
162            if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
163                if remove_if_exists {
164                    self.remove_context(&context, cx);
165                    None
166                } else {
167                    Some(existing.as_ref().clone())
168                }
169            } else {
170                self.insert_context(context.clone(), cx);
171                Some(context)
172            };
173
174        anyhow::Ok(context)
175    }
176
177    pub fn add_symbol(
178        &mut self,
179        buffer: Entity<Buffer>,
180        symbol: SharedString,
181        range: Range<Anchor>,
182        enclosing_range: Range<Anchor>,
183        remove_if_exists: bool,
184        cx: &mut Context<Self>,
185    ) -> (Option<AgentContextHandle>, bool) {
186        let context_id = self.next_context_id.post_inc();
187        let context = AgentContextHandle::Symbol(SymbolContextHandle {
188            buffer,
189            symbol,
190            range,
191            enclosing_range,
192            context_id,
193        });
194
195        if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
196            let handle = if remove_if_exists {
197                self.remove_context(&context, cx);
198                None
199            } else {
200                Some(key.as_ref().clone())
201            };
202            return (handle, false);
203        }
204
205        let included = self.insert_context(context.clone(), cx);
206        (Some(context), included)
207    }
208
209    pub fn add_thread(
210        &mut self,
211        thread: Entity<Thread>,
212        remove_if_exists: bool,
213        cx: &mut Context<Self>,
214    ) -> Option<AgentContextHandle> {
215        let context_id = self.next_context_id.post_inc();
216        let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
217
218        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
219            if remove_if_exists {
220                self.remove_context(&context, cx);
221                None
222            } else {
223                Some(existing.as_ref().clone())
224            }
225        } else {
226            self.insert_context(context.clone(), cx);
227            Some(context)
228        }
229    }
230
231    pub fn add_text_thread(
232        &mut self,
233        context: Entity<AssistantContext>,
234        remove_if_exists: bool,
235        cx: &mut Context<Self>,
236    ) -> Option<AgentContextHandle> {
237        let context_id = self.next_context_id.post_inc();
238        let context = AgentContextHandle::TextThread(TextThreadContextHandle {
239            context,
240            context_id,
241        });
242
243        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
244            if remove_if_exists {
245                self.remove_context(&context, cx);
246                None
247            } else {
248                Some(existing.as_ref().clone())
249            }
250        } else {
251            self.insert_context(context.clone(), cx);
252            Some(context)
253        }
254    }
255
256    pub fn add_rules(
257        &mut self,
258        prompt_id: UserPromptId,
259        remove_if_exists: bool,
260        cx: &mut Context<ContextStore>,
261    ) -> Option<AgentContextHandle> {
262        let context_id = self.next_context_id.post_inc();
263        let context = AgentContextHandle::Rules(RulesContextHandle {
264            prompt_id,
265            context_id,
266        });
267
268        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
269            if remove_if_exists {
270                self.remove_context(&context, cx);
271                None
272            } else {
273                Some(existing.as_ref().clone())
274            }
275        } else {
276            self.insert_context(context.clone(), cx);
277            Some(context)
278        }
279    }
280
281    pub fn add_fetched_url(
282        &mut self,
283        url: String,
284        text: impl Into<SharedString>,
285        cx: &mut Context<ContextStore>,
286    ) -> AgentContextHandle {
287        let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
288            url: url.into(),
289            text: text.into(),
290            context_id: self.next_context_id.post_inc(),
291        });
292
293        self.insert_context(context.clone(), cx);
294        context
295    }
296
297    pub fn add_image_from_path(
298        &mut self,
299        project_path: ProjectPath,
300        remove_if_exists: bool,
301        cx: &mut Context<ContextStore>,
302    ) -> Task<Result<Option<AgentContextHandle>>> {
303        let project = self.project.clone();
304        cx.spawn(async move |this, cx| {
305            let open_image_task = project.update(cx, |project, cx| {
306                project.open_image(project_path.clone(), cx)
307            })?;
308            let image_item = open_image_task.await?;
309
310            this.update(cx, |this, cx| {
311                let item = image_item.read(cx);
312                this.insert_image(
313                    Some(item.project_path(cx)),
314                    Some(item.file.full_path(cx).into()),
315                    item.image.clone(),
316                    remove_if_exists,
317                    cx,
318                )
319            })
320        })
321    }
322
323    pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
324        self.insert_image(None, None, image, false, cx);
325    }
326
327    fn insert_image(
328        &mut self,
329        project_path: Option<ProjectPath>,
330        full_path: Option<Arc<Path>>,
331        image: Arc<Image>,
332        remove_if_exists: bool,
333        cx: &mut Context<ContextStore>,
334    ) -> Option<AgentContextHandle> {
335        let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
336        let context = AgentContextHandle::Image(ImageContext {
337            project_path,
338            full_path,
339            original_image: image,
340            image_task,
341            context_id: self.next_context_id.post_inc(),
342        });
343        if self.has_context(&context) {
344            if remove_if_exists {
345                self.remove_context(&context, cx);
346                return None;
347            }
348        }
349
350        self.insert_context(context.clone(), cx);
351        Some(context)
352    }
353
354    pub fn add_selection(
355        &mut self,
356        buffer: Entity<Buffer>,
357        range: Range<Anchor>,
358        cx: &mut Context<ContextStore>,
359    ) {
360        let context_id = self.next_context_id.post_inc();
361        let context = AgentContextHandle::Selection(SelectionContextHandle {
362            buffer,
363            range,
364            context_id,
365        });
366        self.insert_context(context, cx);
367    }
368
369    pub fn add_suggested_context(
370        &mut self,
371        suggested: &SuggestedContext,
372        cx: &mut Context<ContextStore>,
373    ) {
374        match suggested {
375            SuggestedContext::File {
376                buffer,
377                icon_path: _,
378                name: _,
379            } => {
380                if let Some(buffer) = buffer.upgrade() {
381                    let context_id = self.next_context_id.post_inc();
382                    self.insert_context(
383                        AgentContextHandle::File(FileContextHandle { buffer, context_id }),
384                        cx,
385                    );
386                };
387            }
388            SuggestedContext::Thread { thread, name: _ } => {
389                if let Some(thread) = thread.upgrade() {
390                    let context_id = self.next_context_id.post_inc();
391                    self.insert_context(
392                        AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
393                        cx,
394                    );
395                }
396            }
397            SuggestedContext::TextThread { context, name: _ } => {
398                if let Some(context) = context.upgrade() {
399                    let context_id = self.next_context_id.post_inc();
400                    self.insert_context(
401                        AgentContextHandle::TextThread(TextThreadContextHandle {
402                            context,
403                            context_id,
404                        }),
405                        cx,
406                    );
407                }
408            }
409        }
410    }
411
412    fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
413        match &context {
414            AgentContextHandle::Thread(thread_context) => {
415                if let Some(thread_store) = self.thread_store.clone() {
416                    thread_context.thread.update(cx, |thread, cx| {
417                        thread.start_generating_detailed_summary_if_needed(thread_store, cx);
418                    });
419                    self.context_thread_ids
420                        .insert(thread_context.thread.read(cx).id().clone());
421                } else {
422                    return false;
423                }
424            }
425            AgentContextHandle::TextThread(text_thread_context) => {
426                self.context_text_thread_paths
427                    .extend(text_thread_context.context.read(cx).path().cloned());
428            }
429            _ => {}
430        }
431        let inserted = self.context_set.insert(AgentContextKey(context));
432        if inserted {
433            cx.notify();
434        }
435        inserted
436    }
437
438    pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
439        if let Some((_, key)) = self
440            .context_set
441            .shift_remove_full(AgentContextKey::ref_cast(context))
442        {
443            match context {
444                AgentContextHandle::Thread(thread_context) => {
445                    self.context_thread_ids
446                        .remove(&thread_context.thread.read(cx).id());
447                }
448                AgentContextHandle::TextThread(text_thread_context) => {
449                    if let Some(path) = text_thread_context.context.read(cx).path() {
450                        self.context_text_thread_paths.remove(path);
451                    }
452                }
453                _ => {}
454            }
455            cx.emit(ContextStoreEvent::ContextRemoved(key));
456            cx.notify();
457        }
458    }
459
460    pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
461        self.context_set
462            .contains(AgentContextKey::ref_cast(context))
463    }
464
465    /// Returns whether this file path is already included directly in the context, or if it will be
466    /// included in the context via a directory.
467    pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
468        let project = self.project.upgrade()?.read(cx);
469        self.context().find_map(|context| match context {
470            AgentContextHandle::File(file_context) => {
471                FileInclusion::check_file(file_context, path, cx)
472            }
473            AgentContextHandle::Image(image_context) => {
474                FileInclusion::check_image(image_context, path)
475            }
476            AgentContextHandle::Directory(directory_context) => {
477                FileInclusion::check_directory(directory_context, path, project, cx)
478            }
479            _ => None,
480        })
481    }
482
483    pub fn path_included_in_directory(
484        &self,
485        path: &ProjectPath,
486        cx: &App,
487    ) -> Option<FileInclusion> {
488        let project = self.project.upgrade()?.read(cx);
489        self.context().find_map(|context| match context {
490            AgentContextHandle::Directory(directory_context) => {
491                FileInclusion::check_directory(directory_context, path, project, cx)
492            }
493            _ => None,
494        })
495    }
496
497    pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
498        self.context().any(|context| match context {
499            AgentContextHandle::Symbol(context) => {
500                if context.symbol != symbol.name {
501                    return false;
502                }
503                let buffer = context.buffer.read(cx);
504                let Some(context_path) = buffer.project_path(cx) else {
505                    return false;
506                };
507                if context_path != symbol.path {
508                    return false;
509                }
510                let context_range = context.range.to_point_utf16(&buffer.snapshot());
511                context_range.start == symbol.range.start.0
512                    && context_range.end == symbol.range.end.0
513            }
514            _ => false,
515        })
516    }
517
518    pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
519        self.context_thread_ids.contains(thread_id)
520    }
521
522    pub fn includes_text_thread(&self, path: &Arc<Path>) -> bool {
523        self.context_text_thread_paths.contains(path)
524    }
525
526    pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
527        self.context_set
528            .contains(&RulesContextHandle::lookup_key(prompt_id))
529    }
530
531    pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
532        self.context_set
533            .contains(&FetchedUrlContext::lookup_key(url.into()))
534    }
535
536    pub fn get_url_context(&self, url: SharedString) -> Option<AgentContextHandle> {
537        self.context_set
538            .get(&FetchedUrlContext::lookup_key(url))
539            .map(|key| key.as_ref().clone())
540    }
541
542    pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
543        self.context()
544            .filter_map(|context| match context {
545                AgentContextHandle::File(file) => {
546                    let buffer = file.buffer.read(cx);
547                    buffer.project_path(cx)
548                }
549                AgentContextHandle::Directory(_)
550                | AgentContextHandle::Symbol(_)
551                | AgentContextHandle::Selection(_)
552                | AgentContextHandle::FetchedUrl(_)
553                | AgentContextHandle::Thread(_)
554                | AgentContextHandle::TextThread(_)
555                | AgentContextHandle::Rules(_)
556                | AgentContextHandle::Image(_) => None,
557            })
558            .collect()
559    }
560
561    pub fn thread_ids(&self) -> &HashSet<ThreadId> {
562        &self.context_thread_ids
563    }
564}
565
/// A context entry offered as a suggestion. It is only added to a [`ContextStore`]
/// through `ContextStore::add_suggested_context`.
#[derive(Clone)]
pub enum SuggestedContext {
    /// A suggested file, backed by an open buffer.
    File {
        // Label returned by `SuggestedContext::name`.
        name: SharedString,
        // Icon returned by `SuggestedContext::icon_path`.
        icon_path: Option<SharedString>,
        // Held weakly; ignored by `add_suggested_context` if the buffer was dropped.
        buffer: WeakEntity<Buffer>,
    },
    /// A suggested agent thread.
    Thread {
        name: SharedString,
        // Held weakly; ignored by `add_suggested_context` if the thread was dropped.
        thread: WeakEntity<Thread>,
    },
    /// A suggested text thread.
    TextThread {
        name: SharedString,
        // Held weakly; ignored by `add_suggested_context` if it was dropped.
        context: WeakEntity<AssistantContext>,
    },
}
582
583impl SuggestedContext {
584    pub fn name(&self) -> &SharedString {
585        match self {
586            Self::File { name, .. } => name,
587            Self::Thread { name, .. } => name,
588            Self::TextThread { name, .. } => name,
589        }
590    }
591
592    pub fn icon_path(&self) -> Option<SharedString> {
593        match self {
594            Self::File { icon_path, .. } => icon_path.clone(),
595            Self::Thread { .. } => None,
596            Self::TextThread { .. } => None,
597        }
598    }
599
600    pub fn kind(&self) -> ContextKind {
601        match self {
602            Self::File { .. } => ContextKind::File,
603            Self::Thread { .. } => ContextKind::Thread,
604            Self::TextThread { .. } => ContextKind::TextThread,
605        }
606    }
607}
608
/// How a project path is covered by the context in a [`ContextStore`].
pub enum FileInclusion {
    /// The path itself is a context entry (file, image, or the directory itself).
    Direct,
    /// The path lies inside a directory entry. `full_path` is the worktree full path
    /// of that containing directory, not of the file.
    InDirectory { full_path: PathBuf },
}
613
614impl FileInclusion {
615    fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
616        let file_path = file_context.buffer.read(cx).project_path(cx)?;
617        if path == &file_path {
618            Some(FileInclusion::Direct)
619        } else {
620            None
621        }
622    }
623
624    fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
625        let image_path = image_context.project_path.as_ref()?;
626        if path == image_path {
627            Some(FileInclusion::Direct)
628        } else {
629            None
630        }
631    }
632
633    fn check_directory(
634        directory_context: &DirectoryContextHandle,
635        path: &ProjectPath,
636        project: &Project,
637        cx: &App,
638    ) -> Option<Self> {
639        let worktree = project
640            .worktree_for_entry(directory_context.entry_id, cx)?
641            .read(cx);
642        let entry = worktree.entry_for_id(directory_context.entry_id)?;
643        let directory_path = ProjectPath {
644            worktree_id: worktree.id(),
645            path: entry.path.clone(),
646        };
647        if path.starts_with(&directory_path) {
648            if path == &directory_path {
649                Some(FileInclusion::Direct)
650            } else {
651                Some(FileInclusion::InDirectory {
652                    full_path: worktree.full_path(&entry.path),
653                })
654            }
655        } else {
656            None
657        }
658    }
659}