context_store.rs

  1use crate::{
  2    context::{
  3        AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle,
  4        FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle,
  5        SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
  6    },
  7    thread::{MessageId, Thread, ThreadId},
  8    thread_store::ThreadStore,
  9};
 10use anyhow::{Context as _, Result, anyhow};
 11use assistant_context::AssistantContext;
 12use collections::{HashSet, IndexSet};
 13use futures::{self, FutureExt};
 14use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
 15use language::{Buffer, File as _};
 16use language_model::LanguageModelImage;
 17use project::{Project, ProjectItem, ProjectPath, Symbol, image_store::is_image_file};
 18use prompt_store::UserPromptId;
 19use ref_cast::RefCast as _;
 20use std::{
 21    ops::Range,
 22    path::{Path, PathBuf},
 23    sync::Arc,
 24};
 25use text::{Anchor, OffsetRangeExt};
 26
 27pub struct ContextStore {
 28    project: WeakEntity<Project>,
 29    thread_store: Option<WeakEntity<ThreadStore>>,
 30    next_context_id: ContextId,
 31    context_set: IndexSet<AgentContextKey>,
 32    context_thread_ids: HashSet<ThreadId>,
 33    context_text_thread_paths: HashSet<Arc<Path>>,
 34}
 35
 36pub enum ContextStoreEvent {
 37    ContextRemoved(AgentContextKey),
 38}
 39
 40impl EventEmitter<ContextStoreEvent> for ContextStore {}
 41
 42impl ContextStore {
 43    pub fn new(
 44        project: WeakEntity<Project>,
 45        thread_store: Option<WeakEntity<ThreadStore>>,
 46    ) -> Self {
 47        Self {
 48            project,
 49            thread_store,
 50            next_context_id: ContextId::zero(),
 51            context_set: IndexSet::default(),
 52            context_thread_ids: HashSet::default(),
 53            context_text_thread_paths: HashSet::default(),
 54        }
 55    }
 56
 57    pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
 58        self.context_set.iter().map(|entry| entry.as_ref())
 59    }
 60
 61    pub fn clear(&mut self, cx: &mut Context<Self>) {
 62        self.context_set.clear();
 63        self.context_thread_ids.clear();
 64        cx.notify();
 65    }
 66
 67    pub fn new_context_for_thread(
 68        &self,
 69        thread: &Thread,
 70        exclude_messages_from_id: Option<MessageId>,
 71    ) -> Vec<AgentContextHandle> {
 72        let existing_context = thread
 73            .messages()
 74            .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
 75            .flat_map(|message| {
 76                message
 77                    .loaded_context
 78                    .contexts
 79                    .iter()
 80                    .map(|context| AgentContextKey(context.handle()))
 81            })
 82            .collect::<HashSet<_>>();
 83        self.context_set
 84            .iter()
 85            .filter(|context| !existing_context.contains(context))
 86            .map(|entry| entry.0.clone())
 87            .collect::<Vec<_>>()
 88    }
 89
 90    pub fn add_file_from_path(
 91        &mut self,
 92        project_path: ProjectPath,
 93        remove_if_exists: bool,
 94        cx: &mut Context<Self>,
 95    ) -> Task<Result<Option<AgentContextHandle>>> {
 96        let Some(project) = self.project.upgrade() else {
 97            return Task::ready(Err(anyhow!("failed to read project")));
 98        };
 99
100        if is_image_file(&project, &project_path, cx) {
101            self.add_image_from_path(project_path, remove_if_exists, cx)
102        } else {
103            cx.spawn(async move |this, cx| {
104                let open_buffer_task = project.update(cx, |project, cx| {
105                    project.open_buffer(project_path.clone(), cx)
106                })?;
107                let buffer = open_buffer_task.await?;
108                this.update(cx, |this, cx| {
109                    this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
110                })
111            })
112        }
113    }
114
115    pub fn add_file_from_buffer(
116        &mut self,
117        project_path: &ProjectPath,
118        buffer: Entity<Buffer>,
119        remove_if_exists: bool,
120        cx: &mut Context<Self>,
121    ) -> Option<AgentContextHandle> {
122        let context_id = self.next_context_id.post_inc();
123        let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
124
125        if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
126            if remove_if_exists {
127                self.remove_context(&context, cx);
128                None
129            } else {
130                Some(key.as_ref().clone())
131            }
132        } else if self.path_included_in_directory(project_path, cx).is_some() {
133            None
134        } else {
135            self.insert_context(context.clone(), cx);
136            Some(context)
137        }
138    }
139
140    pub fn add_directory(
141        &mut self,
142        project_path: &ProjectPath,
143        remove_if_exists: bool,
144        cx: &mut Context<Self>,
145    ) -> Result<Option<AgentContextHandle>> {
146        let project = self.project.upgrade().context("failed to read project")?;
147        let entry_id = project
148            .read(cx)
149            .entry_for_path(project_path, cx)
150            .map(|entry| entry.id)
151            .context("no entry found for directory context")?;
152
153        let context_id = self.next_context_id.post_inc();
154        let context = AgentContextHandle::Directory(DirectoryContextHandle {
155            entry_id,
156            context_id,
157        });
158
159        let context =
160            if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
161                if remove_if_exists {
162                    self.remove_context(&context, cx);
163                    None
164                } else {
165                    Some(existing.as_ref().clone())
166                }
167            } else {
168                self.insert_context(context.clone(), cx);
169                Some(context)
170            };
171
172        anyhow::Ok(context)
173    }
174
175    pub fn add_symbol(
176        &mut self,
177        buffer: Entity<Buffer>,
178        symbol: SharedString,
179        range: Range<Anchor>,
180        enclosing_range: Range<Anchor>,
181        remove_if_exists: bool,
182        cx: &mut Context<Self>,
183    ) -> (Option<AgentContextHandle>, bool) {
184        let context_id = self.next_context_id.post_inc();
185        let context = AgentContextHandle::Symbol(SymbolContextHandle {
186            buffer,
187            symbol,
188            range,
189            enclosing_range,
190            context_id,
191        });
192
193        if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
194            let handle = if remove_if_exists {
195                self.remove_context(&context, cx);
196                None
197            } else {
198                Some(key.as_ref().clone())
199            };
200            return (handle, false);
201        }
202
203        let included = self.insert_context(context.clone(), cx);
204        (Some(context), included)
205    }
206
207    pub fn add_thread(
208        &mut self,
209        thread: Entity<Thread>,
210        remove_if_exists: bool,
211        cx: &mut Context<Self>,
212    ) -> Option<AgentContextHandle> {
213        let context_id = self.next_context_id.post_inc();
214        let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
215
216        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
217            if remove_if_exists {
218                self.remove_context(&context, cx);
219                None
220            } else {
221                Some(existing.as_ref().clone())
222            }
223        } else {
224            self.insert_context(context.clone(), cx);
225            Some(context)
226        }
227    }
228
229    pub fn add_text_thread(
230        &mut self,
231        context: Entity<AssistantContext>,
232        remove_if_exists: bool,
233        cx: &mut Context<Self>,
234    ) -> Option<AgentContextHandle> {
235        let context_id = self.next_context_id.post_inc();
236        let context = AgentContextHandle::TextThread(TextThreadContextHandle {
237            context,
238            context_id,
239        });
240
241        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
242            if remove_if_exists {
243                self.remove_context(&context, cx);
244                None
245            } else {
246                Some(existing.as_ref().clone())
247            }
248        } else {
249            self.insert_context(context.clone(), cx);
250            Some(context)
251        }
252    }
253
254    pub fn add_rules(
255        &mut self,
256        prompt_id: UserPromptId,
257        remove_if_exists: bool,
258        cx: &mut Context<ContextStore>,
259    ) -> Option<AgentContextHandle> {
260        let context_id = self.next_context_id.post_inc();
261        let context = AgentContextHandle::Rules(RulesContextHandle {
262            prompt_id,
263            context_id,
264        });
265
266        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
267            if remove_if_exists {
268                self.remove_context(&context, cx);
269                None
270            } else {
271                Some(existing.as_ref().clone())
272            }
273        } else {
274            self.insert_context(context.clone(), cx);
275            Some(context)
276        }
277    }
278
279    pub fn add_fetched_url(
280        &mut self,
281        url: String,
282        text: impl Into<SharedString>,
283        cx: &mut Context<ContextStore>,
284    ) -> AgentContextHandle {
285        let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
286            url: url.into(),
287            text: text.into(),
288            context_id: self.next_context_id.post_inc(),
289        });
290
291        self.insert_context(context.clone(), cx);
292        context
293    }
294
295    pub fn add_image_from_path(
296        &mut self,
297        project_path: ProjectPath,
298        remove_if_exists: bool,
299        cx: &mut Context<ContextStore>,
300    ) -> Task<Result<Option<AgentContextHandle>>> {
301        let project = self.project.clone();
302        cx.spawn(async move |this, cx| {
303            let open_image_task = project.update(cx, |project, cx| {
304                project.open_image(project_path.clone(), cx)
305            })?;
306            let image_item = open_image_task.await?;
307
308            this.update(cx, |this, cx| {
309                let item = image_item.read(cx);
310                this.insert_image(
311                    Some(item.project_path(cx)),
312                    Some(item.file.full_path(cx).into()),
313                    item.image.clone(),
314                    remove_if_exists,
315                    cx,
316                )
317            })
318        })
319    }
320
321    pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
322        self.insert_image(None, None, image, false, cx);
323    }
324
325    fn insert_image(
326        &mut self,
327        project_path: Option<ProjectPath>,
328        full_path: Option<Arc<Path>>,
329        image: Arc<Image>,
330        remove_if_exists: bool,
331        cx: &mut Context<ContextStore>,
332    ) -> Option<AgentContextHandle> {
333        let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
334        let context = AgentContextHandle::Image(ImageContext {
335            project_path,
336            full_path,
337            original_image: image,
338            image_task,
339            context_id: self.next_context_id.post_inc(),
340        });
341        if self.has_context(&context) {
342            if remove_if_exists {
343                self.remove_context(&context, cx);
344                return None;
345            }
346        }
347
348        self.insert_context(context.clone(), cx);
349        Some(context)
350    }
351
352    pub fn add_selection(
353        &mut self,
354        buffer: Entity<Buffer>,
355        range: Range<Anchor>,
356        cx: &mut Context<ContextStore>,
357    ) {
358        let context_id = self.next_context_id.post_inc();
359        let context = AgentContextHandle::Selection(SelectionContextHandle {
360            buffer,
361            range,
362            context_id,
363        });
364        self.insert_context(context, cx);
365    }
366
367    pub fn add_suggested_context(
368        &mut self,
369        suggested: &SuggestedContext,
370        cx: &mut Context<ContextStore>,
371    ) {
372        match suggested {
373            SuggestedContext::File {
374                buffer,
375                icon_path: _,
376                name: _,
377            } => {
378                if let Some(buffer) = buffer.upgrade() {
379                    let context_id = self.next_context_id.post_inc();
380                    self.insert_context(
381                        AgentContextHandle::File(FileContextHandle { buffer, context_id }),
382                        cx,
383                    );
384                };
385            }
386            SuggestedContext::Thread { thread, name: _ } => {
387                if let Some(thread) = thread.upgrade() {
388                    let context_id = self.next_context_id.post_inc();
389                    self.insert_context(
390                        AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
391                        cx,
392                    );
393                }
394            }
395            SuggestedContext::TextThread { context, name: _ } => {
396                if let Some(context) = context.upgrade() {
397                    let context_id = self.next_context_id.post_inc();
398                    self.insert_context(
399                        AgentContextHandle::TextThread(TextThreadContextHandle {
400                            context,
401                            context_id,
402                        }),
403                        cx,
404                    );
405                }
406            }
407        }
408    }
409
410    fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
411        match &context {
412            AgentContextHandle::Thread(thread_context) => {
413                if let Some(thread_store) = self.thread_store.clone() {
414                    thread_context.thread.update(cx, |thread, cx| {
415                        thread.start_generating_detailed_summary_if_needed(thread_store, cx);
416                    });
417                    self.context_thread_ids
418                        .insert(thread_context.thread.read(cx).id().clone());
419                } else {
420                    return false;
421                }
422            }
423            AgentContextHandle::TextThread(text_thread_context) => {
424                self.context_text_thread_paths
425                    .extend(text_thread_context.context.read(cx).path().cloned());
426            }
427            _ => {}
428        }
429        let inserted = self.context_set.insert(AgentContextKey(context));
430        if inserted {
431            cx.notify();
432        }
433        inserted
434    }
435
436    pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
437        if let Some((_, key)) = self
438            .context_set
439            .shift_remove_full(AgentContextKey::ref_cast(context))
440        {
441            match context {
442                AgentContextHandle::Thread(thread_context) => {
443                    self.context_thread_ids
444                        .remove(thread_context.thread.read(cx).id());
445                }
446                AgentContextHandle::TextThread(text_thread_context) => {
447                    if let Some(path) = text_thread_context.context.read(cx).path() {
448                        self.context_text_thread_paths.remove(path);
449                    }
450                }
451                _ => {}
452            }
453            cx.emit(ContextStoreEvent::ContextRemoved(key));
454            cx.notify();
455        }
456    }
457
458    pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
459        self.context_set
460            .contains(AgentContextKey::ref_cast(context))
461    }
462
463    /// Returns whether this file path is already included directly in the context, or if it will be
464    /// included in the context via a directory.
465    pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
466        let project = self.project.upgrade()?.read(cx);
467        self.context().find_map(|context| match context {
468            AgentContextHandle::File(file_context) => {
469                FileInclusion::check_file(file_context, path, cx)
470            }
471            AgentContextHandle::Image(image_context) => {
472                FileInclusion::check_image(image_context, path)
473            }
474            AgentContextHandle::Directory(directory_context) => {
475                FileInclusion::check_directory(directory_context, path, project, cx)
476            }
477            _ => None,
478        })
479    }
480
481    pub fn path_included_in_directory(
482        &self,
483        path: &ProjectPath,
484        cx: &App,
485    ) -> Option<FileInclusion> {
486        let project = self.project.upgrade()?.read(cx);
487        self.context().find_map(|context| match context {
488            AgentContextHandle::Directory(directory_context) => {
489                FileInclusion::check_directory(directory_context, path, project, cx)
490            }
491            _ => None,
492        })
493    }
494
495    pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
496        self.context().any(|context| match context {
497            AgentContextHandle::Symbol(context) => {
498                if context.symbol != symbol.name {
499                    return false;
500                }
501                let buffer = context.buffer.read(cx);
502                let Some(context_path) = buffer.project_path(cx) else {
503                    return false;
504                };
505                if context_path != symbol.path {
506                    return false;
507                }
508                let context_range = context.range.to_point_utf16(&buffer.snapshot());
509                context_range.start == symbol.range.start.0
510                    && context_range.end == symbol.range.end.0
511            }
512            _ => false,
513        })
514    }
515
516    pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
517        self.context_thread_ids.contains(thread_id)
518    }
519
520    pub fn includes_text_thread(&self, path: &Arc<Path>) -> bool {
521        self.context_text_thread_paths.contains(path)
522    }
523
524    pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
525        self.context_set
526            .contains(&RulesContextHandle::lookup_key(prompt_id))
527    }
528
529    pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
530        self.context_set
531            .contains(&FetchedUrlContext::lookup_key(url.into()))
532    }
533
534    pub fn get_url_context(&self, url: SharedString) -> Option<AgentContextHandle> {
535        self.context_set
536            .get(&FetchedUrlContext::lookup_key(url))
537            .map(|key| key.as_ref().clone())
538    }
539
540    pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
541        self.context()
542            .filter_map(|context| match context {
543                AgentContextHandle::File(file) => {
544                    let buffer = file.buffer.read(cx);
545                    buffer.project_path(cx)
546                }
547                AgentContextHandle::Directory(_)
548                | AgentContextHandle::Symbol(_)
549                | AgentContextHandle::Selection(_)
550                | AgentContextHandle::FetchedUrl(_)
551                | AgentContextHandle::Thread(_)
552                | AgentContextHandle::TextThread(_)
553                | AgentContextHandle::Rules(_)
554                | AgentContextHandle::Image(_) => None,
555            })
556            .collect()
557    }
558
559    pub fn thread_ids(&self) -> &HashSet<ThreadId> {
560        &self.context_thread_ids
561    }
562}
563
564#[derive(Clone)]
565pub enum SuggestedContext {
566    File {
567        name: SharedString,
568        icon_path: Option<SharedString>,
569        buffer: WeakEntity<Buffer>,
570    },
571    Thread {
572        name: SharedString,
573        thread: WeakEntity<Thread>,
574    },
575    TextThread {
576        name: SharedString,
577        context: WeakEntity<AssistantContext>,
578    },
579}
580
581impl SuggestedContext {
582    pub fn name(&self) -> &SharedString {
583        match self {
584            Self::File { name, .. } => name,
585            Self::Thread { name, .. } => name,
586            Self::TextThread { name, .. } => name,
587        }
588    }
589
590    pub fn icon_path(&self) -> Option<SharedString> {
591        match self {
592            Self::File { icon_path, .. } => icon_path.clone(),
593            Self::Thread { .. } => None,
594            Self::TextThread { .. } => None,
595        }
596    }
597
598    pub fn kind(&self) -> ContextKind {
599        match self {
600            Self::File { .. } => ContextKind::File,
601            Self::Thread { .. } => ContextKind::Thread,
602            Self::TextThread { .. } => ContextKind::TextThread,
603        }
604    }
605}
606
/// How a project path is covered by the current context.
pub enum FileInclusion {
    /// The path itself is in the context: as a file, as an image, or as the directory
    /// entry itself.
    Direct,
    /// The path lies beneath a directory that is in the context.
    InDirectory {
        /// Worktree-provided full path of the including directory.
        full_path: PathBuf,
    },
}
611
612impl FileInclusion {
613    fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
614        let file_path = file_context.buffer.read(cx).project_path(cx)?;
615        if path == &file_path {
616            Some(FileInclusion::Direct)
617        } else {
618            None
619        }
620    }
621
622    fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
623        let image_path = image_context.project_path.as_ref()?;
624        if path == image_path {
625            Some(FileInclusion::Direct)
626        } else {
627            None
628        }
629    }
630
631    fn check_directory(
632        directory_context: &DirectoryContextHandle,
633        path: &ProjectPath,
634        project: &Project,
635        cx: &App,
636    ) -> Option<Self> {
637        let worktree = project
638            .worktree_for_entry(directory_context.entry_id, cx)?
639            .read(cx);
640        let entry = worktree.entry_for_id(directory_context.entry_id)?;
641        let directory_path = ProjectPath {
642            worktree_id: worktree.id(),
643            path: entry.path.clone(),
644        };
645        if path.starts_with(&directory_path) {
646            if path == &directory_path {
647                Some(FileInclusion::Direct)
648            } else {
649                Some(FileInclusion::InDirectory {
650                    full_path: worktree.full_path(&entry.path),
651                })
652            }
653        } else {
654            None
655        }
656    }
657}