context_store.rs

  1use crate::{
  2    context::{
  3        AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle,
  4        FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle,
  5        SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
  6    },
  7    thread::{MessageId, Thread, ThreadId},
  8    thread_store::ThreadStore,
  9};
 10use anyhow::{Context as _, Result, anyhow};
 11use assistant_context::AssistantContext;
 12use collections::{HashSet, IndexSet};
 13use futures::{self, FutureExt};
 14use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
 15use language::{Buffer, File as _};
 16use language_model::LanguageModelImage;
 17use project::{
 18    Project, ProjectItem, ProjectPath, Symbol, image_store::is_image_file,
 19    lsp_store::SymbolLocation,
 20};
 21use prompt_store::UserPromptId;
 22use ref_cast::RefCast as _;
 23use std::{
 24    ops::Range,
 25    path::{Path, PathBuf},
 26    sync::Arc,
 27};
 28use text::{Anchor, OffsetRangeExt};
 29
 30pub struct ContextStore {
 31    project: WeakEntity<Project>,
 32    thread_store: Option<WeakEntity<ThreadStore>>,
 33    next_context_id: ContextId,
 34    context_set: IndexSet<AgentContextKey>,
 35    context_thread_ids: HashSet<ThreadId>,
 36    context_text_thread_paths: HashSet<Arc<Path>>,
 37}
 38
 39pub enum ContextStoreEvent {
 40    ContextRemoved(AgentContextKey),
 41}
 42
 43impl EventEmitter<ContextStoreEvent> for ContextStore {}
 44
 45impl ContextStore {
 46    pub fn new(
 47        project: WeakEntity<Project>,
 48        thread_store: Option<WeakEntity<ThreadStore>>,
 49    ) -> Self {
 50        Self {
 51            project,
 52            thread_store,
 53            next_context_id: ContextId::zero(),
 54            context_set: IndexSet::default(),
 55            context_thread_ids: HashSet::default(),
 56            context_text_thread_paths: HashSet::default(),
 57        }
 58    }
 59
 60    pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
 61        self.context_set.iter().map(|entry| entry.as_ref())
 62    }
 63
 64    pub fn clear(&mut self, cx: &mut Context<Self>) {
 65        self.context_set.clear();
 66        self.context_thread_ids.clear();
 67        cx.notify();
 68    }
 69
 70    pub fn new_context_for_thread(
 71        &self,
 72        thread: &Thread,
 73        exclude_messages_from_id: Option<MessageId>,
 74    ) -> Vec<AgentContextHandle> {
 75        let existing_context = thread
 76            .messages()
 77            .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
 78            .flat_map(|message| {
 79                message
 80                    .loaded_context
 81                    .contexts
 82                    .iter()
 83                    .map(|context| AgentContextKey(context.handle()))
 84            })
 85            .collect::<HashSet<_>>();
 86        self.context_set
 87            .iter()
 88            .filter(|context| !existing_context.contains(context))
 89            .map(|entry| entry.0.clone())
 90            .collect::<Vec<_>>()
 91    }
 92
 93    pub fn add_file_from_path(
 94        &mut self,
 95        project_path: ProjectPath,
 96        remove_if_exists: bool,
 97        cx: &mut Context<Self>,
 98    ) -> Task<Result<Option<AgentContextHandle>>> {
 99        let Some(project) = self.project.upgrade() else {
100            return Task::ready(Err(anyhow!("failed to read project")));
101        };
102
103        if is_image_file(&project, &project_path, cx) {
104            self.add_image_from_path(project_path, remove_if_exists, cx)
105        } else {
106            cx.spawn(async move |this, cx| {
107                let open_buffer_task = project.update(cx, |project, cx| {
108                    project.open_buffer(project_path.clone(), cx)
109                })?;
110                let buffer = open_buffer_task.await?;
111                this.update(cx, |this, cx| {
112                    this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
113                })
114            })
115        }
116    }
117
118    pub fn add_file_from_buffer(
119        &mut self,
120        project_path: &ProjectPath,
121        buffer: Entity<Buffer>,
122        remove_if_exists: bool,
123        cx: &mut Context<Self>,
124    ) -> Option<AgentContextHandle> {
125        let context_id = self.next_context_id.post_inc();
126        let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
127
128        if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
129            if remove_if_exists {
130                self.remove_context(&context, cx);
131                None
132            } else {
133                Some(key.as_ref().clone())
134            }
135        } else if self.path_included_in_directory(project_path, cx).is_some() {
136            None
137        } else {
138            self.insert_context(context.clone(), cx);
139            Some(context)
140        }
141    }
142
143    pub fn add_directory(
144        &mut self,
145        project_path: &ProjectPath,
146        remove_if_exists: bool,
147        cx: &mut Context<Self>,
148    ) -> Result<Option<AgentContextHandle>> {
149        let project = self.project.upgrade().context("failed to read project")?;
150        let entry_id = project
151            .read(cx)
152            .entry_for_path(project_path, cx)
153            .map(|entry| entry.id)
154            .context("no entry found for directory context")?;
155
156        let context_id = self.next_context_id.post_inc();
157        let context = AgentContextHandle::Directory(DirectoryContextHandle {
158            entry_id,
159            context_id,
160        });
161
162        let context =
163            if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
164                if remove_if_exists {
165                    self.remove_context(&context, cx);
166                    None
167                } else {
168                    Some(existing.as_ref().clone())
169                }
170            } else {
171                self.insert_context(context.clone(), cx);
172                Some(context)
173            };
174
175        anyhow::Ok(context)
176    }
177
178    pub fn add_symbol(
179        &mut self,
180        buffer: Entity<Buffer>,
181        symbol: SharedString,
182        range: Range<Anchor>,
183        enclosing_range: Range<Anchor>,
184        remove_if_exists: bool,
185        cx: &mut Context<Self>,
186    ) -> (Option<AgentContextHandle>, bool) {
187        let context_id = self.next_context_id.post_inc();
188        let context = AgentContextHandle::Symbol(SymbolContextHandle {
189            buffer,
190            symbol,
191            range,
192            enclosing_range,
193            context_id,
194        });
195
196        if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
197            let handle = if remove_if_exists {
198                self.remove_context(&context, cx);
199                None
200            } else {
201                Some(key.as_ref().clone())
202            };
203            return (handle, false);
204        }
205
206        let included = self.insert_context(context.clone(), cx);
207        (Some(context), included)
208    }
209
210    pub fn add_thread(
211        &mut self,
212        thread: Entity<Thread>,
213        remove_if_exists: bool,
214        cx: &mut Context<Self>,
215    ) -> Option<AgentContextHandle> {
216        let context_id = self.next_context_id.post_inc();
217        let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
218
219        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
220            if remove_if_exists {
221                self.remove_context(&context, cx);
222                None
223            } else {
224                Some(existing.as_ref().clone())
225            }
226        } else {
227            self.insert_context(context.clone(), cx);
228            Some(context)
229        }
230    }
231
232    pub fn add_text_thread(
233        &mut self,
234        context: Entity<AssistantContext>,
235        remove_if_exists: bool,
236        cx: &mut Context<Self>,
237    ) -> Option<AgentContextHandle> {
238        let context_id = self.next_context_id.post_inc();
239        let context = AgentContextHandle::TextThread(TextThreadContextHandle {
240            context,
241            context_id,
242        });
243
244        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
245            if remove_if_exists {
246                self.remove_context(&context, cx);
247                None
248            } else {
249                Some(existing.as_ref().clone())
250            }
251        } else {
252            self.insert_context(context.clone(), cx);
253            Some(context)
254        }
255    }
256
257    pub fn add_rules(
258        &mut self,
259        prompt_id: UserPromptId,
260        remove_if_exists: bool,
261        cx: &mut Context<ContextStore>,
262    ) -> Option<AgentContextHandle> {
263        let context_id = self.next_context_id.post_inc();
264        let context = AgentContextHandle::Rules(RulesContextHandle {
265            prompt_id,
266            context_id,
267        });
268
269        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
270            if remove_if_exists {
271                self.remove_context(&context, cx);
272                None
273            } else {
274                Some(existing.as_ref().clone())
275            }
276        } else {
277            self.insert_context(context.clone(), cx);
278            Some(context)
279        }
280    }
281
282    pub fn add_fetched_url(
283        &mut self,
284        url: String,
285        text: impl Into<SharedString>,
286        cx: &mut Context<ContextStore>,
287    ) -> AgentContextHandle {
288        let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
289            url: url.into(),
290            text: text.into(),
291            context_id: self.next_context_id.post_inc(),
292        });
293
294        self.insert_context(context.clone(), cx);
295        context
296    }
297
298    pub fn add_image_from_path(
299        &mut self,
300        project_path: ProjectPath,
301        remove_if_exists: bool,
302        cx: &mut Context<ContextStore>,
303    ) -> Task<Result<Option<AgentContextHandle>>> {
304        let project = self.project.clone();
305        cx.spawn(async move |this, cx| {
306            let open_image_task = project.update(cx, |project, cx| {
307                project.open_image(project_path.clone(), cx)
308            })?;
309            let image_item = open_image_task.await?;
310
311            this.update(cx, |this, cx| {
312                let item = image_item.read(cx);
313                this.insert_image(
314                    Some(item.project_path(cx)),
315                    Some(item.file.full_path(cx).to_string_lossy().into_owned()),
316                    item.image.clone(),
317                    remove_if_exists,
318                    cx,
319                )
320            })
321        })
322    }
323
324    pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
325        self.insert_image(None, None, image, false, cx);
326    }
327
328    fn insert_image(
329        &mut self,
330        project_path: Option<ProjectPath>,
331        full_path: Option<String>,
332        image: Arc<Image>,
333        remove_if_exists: bool,
334        cx: &mut Context<ContextStore>,
335    ) -> Option<AgentContextHandle> {
336        let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
337        let context = AgentContextHandle::Image(ImageContext {
338            project_path,
339            full_path,
340            original_image: image,
341            image_task,
342            context_id: self.next_context_id.post_inc(),
343        });
344        if self.has_context(&context) && remove_if_exists {
345            self.remove_context(&context, cx);
346            return None;
347        }
348
349        self.insert_context(context.clone(), cx);
350        Some(context)
351    }
352
353    pub fn add_selection(
354        &mut self,
355        buffer: Entity<Buffer>,
356        range: Range<Anchor>,
357        cx: &mut Context<ContextStore>,
358    ) {
359        let context_id = self.next_context_id.post_inc();
360        let context = AgentContextHandle::Selection(SelectionContextHandle {
361            buffer,
362            range,
363            context_id,
364        });
365        self.insert_context(context, cx);
366    }
367
368    pub fn add_suggested_context(
369        &mut self,
370        suggested: &SuggestedContext,
371        cx: &mut Context<ContextStore>,
372    ) {
373        match suggested {
374            SuggestedContext::File {
375                buffer,
376                icon_path: _,
377                name: _,
378            } => {
379                if let Some(buffer) = buffer.upgrade() {
380                    let context_id = self.next_context_id.post_inc();
381                    self.insert_context(
382                        AgentContextHandle::File(FileContextHandle { buffer, context_id }),
383                        cx,
384                    );
385                };
386            }
387            SuggestedContext::Thread { thread, name: _ } => {
388                if let Some(thread) = thread.upgrade() {
389                    let context_id = self.next_context_id.post_inc();
390                    self.insert_context(
391                        AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
392                        cx,
393                    );
394                }
395            }
396            SuggestedContext::TextThread { context, name: _ } => {
397                if let Some(context) = context.upgrade() {
398                    let context_id = self.next_context_id.post_inc();
399                    self.insert_context(
400                        AgentContextHandle::TextThread(TextThreadContextHandle {
401                            context,
402                            context_id,
403                        }),
404                        cx,
405                    );
406                }
407            }
408        }
409    }
410
411    fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
412        match &context {
413            AgentContextHandle::Thread(thread_context) => {
414                if let Some(thread_store) = self.thread_store.clone() {
415                    thread_context.thread.update(cx, |thread, cx| {
416                        thread.start_generating_detailed_summary_if_needed(thread_store, cx);
417                    });
418                    self.context_thread_ids
419                        .insert(thread_context.thread.read(cx).id().clone());
420                } else {
421                    return false;
422                }
423            }
424            AgentContextHandle::TextThread(text_thread_context) => {
425                self.context_text_thread_paths
426                    .extend(text_thread_context.context.read(cx).path().cloned());
427            }
428            _ => {}
429        }
430        let inserted = self.context_set.insert(AgentContextKey(context));
431        if inserted {
432            cx.notify();
433        }
434        inserted
435    }
436
437    pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
438        if let Some((_, key)) = self
439            .context_set
440            .shift_remove_full(AgentContextKey::ref_cast(context))
441        {
442            match context {
443                AgentContextHandle::Thread(thread_context) => {
444                    self.context_thread_ids
445                        .remove(thread_context.thread.read(cx).id());
446                }
447                AgentContextHandle::TextThread(text_thread_context) => {
448                    if let Some(path) = text_thread_context.context.read(cx).path() {
449                        self.context_text_thread_paths.remove(path);
450                    }
451                }
452                _ => {}
453            }
454            cx.emit(ContextStoreEvent::ContextRemoved(key));
455            cx.notify();
456        }
457    }
458
459    pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
460        self.context_set
461            .contains(AgentContextKey::ref_cast(context))
462    }
463
464    /// Returns whether this file path is already included directly in the context, or if it will be
465    /// included in the context via a directory.
466    pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
467        let project = self.project.upgrade()?.read(cx);
468        self.context().find_map(|context| match context {
469            AgentContextHandle::File(file_context) => {
470                FileInclusion::check_file(file_context, path, cx)
471            }
472            AgentContextHandle::Image(image_context) => {
473                FileInclusion::check_image(image_context, path)
474            }
475            AgentContextHandle::Directory(directory_context) => {
476                FileInclusion::check_directory(directory_context, path, project, cx)
477            }
478            _ => None,
479        })
480    }
481
482    pub fn path_included_in_directory(
483        &self,
484        path: &ProjectPath,
485        cx: &App,
486    ) -> Option<FileInclusion> {
487        let project = self.project.upgrade()?.read(cx);
488        self.context().find_map(|context| match context {
489            AgentContextHandle::Directory(directory_context) => {
490                FileInclusion::check_directory(directory_context, path, project, cx)
491            }
492            _ => None,
493        })
494    }
495
496    pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
497        self.context().any(|context| match context {
498            AgentContextHandle::Symbol(context) => {
499                if context.symbol != symbol.name {
500                    return false;
501                }
502                let buffer = context.buffer.read(cx);
503                let Some(context_path) = buffer.project_path(cx) else {
504                    return false;
505                };
506                if symbol.path != SymbolLocation::InProject(context_path) {
507                    return false;
508                }
509                let context_range = context.range.to_point_utf16(&buffer.snapshot());
510                context_range.start == symbol.range.start.0
511                    && context_range.end == symbol.range.end.0
512            }
513            _ => false,
514        })
515    }
516
517    pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
518        self.context_thread_ids.contains(thread_id)
519    }
520
521    pub fn includes_text_thread(&self, path: &Arc<Path>) -> bool {
522        self.context_text_thread_paths.contains(path)
523    }
524
525    pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
526        self.context_set
527            .contains(&RulesContextHandle::lookup_key(prompt_id))
528    }
529
530    pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
531        self.context_set
532            .contains(&FetchedUrlContext::lookup_key(url.into()))
533    }
534
535    pub fn get_url_context(&self, url: SharedString) -> Option<AgentContextHandle> {
536        self.context_set
537            .get(&FetchedUrlContext::lookup_key(url))
538            .map(|key| key.as_ref().clone())
539    }
540
541    pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
542        self.context()
543            .filter_map(|context| match context {
544                AgentContextHandle::File(file) => {
545                    let buffer = file.buffer.read(cx);
546                    buffer.project_path(cx)
547                }
548                AgentContextHandle::Directory(_)
549                | AgentContextHandle::Symbol(_)
550                | AgentContextHandle::Selection(_)
551                | AgentContextHandle::FetchedUrl(_)
552                | AgentContextHandle::Thread(_)
553                | AgentContextHandle::TextThread(_)
554                | AgentContextHandle::Rules(_)
555                | AgentContextHandle::Image(_) => None,
556            })
557            .collect()
558    }
559
560    pub fn thread_ids(&self) -> &HashSet<ThreadId> {
561        &self.context_thread_ids
562    }
563}
564
565#[derive(Clone)]
566pub enum SuggestedContext {
567    File {
568        name: SharedString,
569        icon_path: Option<SharedString>,
570        buffer: WeakEntity<Buffer>,
571    },
572    Thread {
573        name: SharedString,
574        thread: WeakEntity<Thread>,
575    },
576    TextThread {
577        name: SharedString,
578        context: WeakEntity<AssistantContext>,
579    },
580}
581
582impl SuggestedContext {
583    pub fn name(&self) -> &SharedString {
584        match self {
585            Self::File { name, .. } => name,
586            Self::Thread { name, .. } => name,
587            Self::TextThread { name, .. } => name,
588        }
589    }
590
591    pub fn icon_path(&self) -> Option<SharedString> {
592        match self {
593            Self::File { icon_path, .. } => icon_path.clone(),
594            Self::Thread { .. } => None,
595            Self::TextThread { .. } => None,
596        }
597    }
598
599    pub fn kind(&self) -> ContextKind {
600        match self {
601            Self::File { .. } => ContextKind::File,
602            Self::Thread { .. } => ContextKind::Thread,
603            Self::TextThread { .. } => ContextKind::TextThread,
604        }
605    }
606}
607
/// How a project path is covered by the current context.
pub enum FileInclusion {
    /// The path itself is in context (as a file, an image, or a directory entry).
    Direct,
    /// The path lies beneath a directory that is in context.
    InDirectory {
        /// Full path of the including directory, as reported by its worktree.
        full_path: PathBuf,
    },
}
612
613impl FileInclusion {
614    fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
615        let file_path = file_context.buffer.read(cx).project_path(cx)?;
616        if path == &file_path {
617            Some(FileInclusion::Direct)
618        } else {
619            None
620        }
621    }
622
623    fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
624        let image_path = image_context.project_path.as_ref()?;
625        if path == image_path {
626            Some(FileInclusion::Direct)
627        } else {
628            None
629        }
630    }
631
632    fn check_directory(
633        directory_context: &DirectoryContextHandle,
634        path: &ProjectPath,
635        project: &Project,
636        cx: &App,
637    ) -> Option<Self> {
638        let worktree = project
639            .worktree_for_entry(directory_context.entry_id, cx)?
640            .read(cx);
641        let entry = worktree.entry_for_id(directory_context.entry_id)?;
642        let directory_path = ProjectPath {
643            worktree_id: worktree.id(),
644            path: entry.path.clone(),
645        };
646        if path.starts_with(&directory_path) {
647            if path == &directory_path {
648                Some(FileInclusion::Direct)
649            } else {
650                Some(FileInclusion::InDirectory {
651                    full_path: worktree.full_path(&entry.path),
652                })
653            }
654        } else {
655            None
656        }
657    }
658}