context_store.rs

  1use std::ops::Range;
  2use std::path::{Path, PathBuf};
  3use std::sync::Arc;
  4
  5use anyhow::{Context as _, Result, anyhow};
  6use assistant_context_editor::AssistantContext;
  7use collections::{HashSet, IndexSet};
  8use futures::{self, FutureExt};
  9use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
 10use language::Buffer;
 11use language_model::LanguageModelImage;
 12use project::image_store::is_image_file;
 13use project::{Project, ProjectItem, ProjectPath, Symbol};
 14use prompt_store::UserPromptId;
 15use ref_cast::RefCast as _;
 16use text::{Anchor, OffsetRangeExt};
 17
 18use crate::ThreadStore;
 19use crate::context::{
 20    AgentContextHandle, AgentContextKey, ContextId, DirectoryContextHandle, FetchedUrlContext,
 21    FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
 22    SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
 23};
 24use crate::context_strip::SuggestedContext;
 25use crate::thread::{MessageId, Thread, ThreadId};
 26
 27pub struct ContextStore {
 28    project: WeakEntity<Project>,
 29    thread_store: Option<WeakEntity<ThreadStore>>,
 30    next_context_id: ContextId,
 31    context_set: IndexSet<AgentContextKey>,
 32    context_thread_ids: HashSet<ThreadId>,
 33    context_text_thread_paths: HashSet<Arc<Path>>,
 34}
 35
 36pub enum ContextStoreEvent {
 37    ContextRemoved(AgentContextKey),
 38}
 39
 40impl EventEmitter<ContextStoreEvent> for ContextStore {}
 41
 42impl ContextStore {
 43    pub fn new(
 44        project: WeakEntity<Project>,
 45        thread_store: Option<WeakEntity<ThreadStore>>,
 46    ) -> Self {
 47        Self {
 48            project,
 49            thread_store,
 50            next_context_id: ContextId::zero(),
 51            context_set: IndexSet::default(),
 52            context_thread_ids: HashSet::default(),
 53            context_text_thread_paths: HashSet::default(),
 54        }
 55    }
 56
 57    pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
 58        self.context_set.iter().map(|entry| entry.as_ref())
 59    }
 60
 61    pub fn clear(&mut self) {
 62        self.context_set.clear();
 63        self.context_thread_ids.clear();
 64    }
 65
 66    pub fn new_context_for_thread(
 67        &self,
 68        thread: &Thread,
 69        exclude_messages_from_id: Option<MessageId>,
 70    ) -> Vec<AgentContextHandle> {
 71        let existing_context = thread
 72            .messages()
 73            .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
 74            .flat_map(|message| {
 75                message
 76                    .loaded_context
 77                    .contexts
 78                    .iter()
 79                    .map(|context| AgentContextKey(context.handle()))
 80            })
 81            .collect::<HashSet<_>>();
 82        self.context_set
 83            .iter()
 84            .filter(|context| !existing_context.contains(context))
 85            .map(|entry| entry.0.clone())
 86            .collect::<Vec<_>>()
 87    }
 88
 89    pub fn add_file_from_path(
 90        &mut self,
 91        project_path: ProjectPath,
 92        remove_if_exists: bool,
 93        cx: &mut Context<Self>,
 94    ) -> Task<Result<Option<AgentContextHandle>>> {
 95        let Some(project) = self.project.upgrade() else {
 96            return Task::ready(Err(anyhow!("failed to read project")));
 97        };
 98
 99        if is_image_file(&project, &project_path, cx) {
100            self.add_image_from_path(project_path, remove_if_exists, cx)
101        } else {
102            cx.spawn(async move |this, cx| {
103                let open_buffer_task = project.update(cx, |project, cx| {
104                    project.open_buffer(project_path.clone(), cx)
105                })?;
106                let buffer = open_buffer_task.await?;
107                this.update(cx, |this, cx| {
108                    this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
109                })
110            })
111        }
112    }
113
114    pub fn add_file_from_buffer(
115        &mut self,
116        project_path: &ProjectPath,
117        buffer: Entity<Buffer>,
118        remove_if_exists: bool,
119        cx: &mut Context<Self>,
120    ) -> Option<AgentContextHandle> {
121        let context_id = self.next_context_id.post_inc();
122        let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
123
124        if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
125            if remove_if_exists {
126                self.remove_context(&context, cx);
127                None
128            } else {
129                Some(key.as_ref().clone())
130            }
131        } else if self.path_included_in_directory(project_path, cx).is_some() {
132            None
133        } else {
134            self.insert_context(context.clone(), cx);
135            Some(context)
136        }
137    }
138
139    pub fn add_directory(
140        &mut self,
141        project_path: &ProjectPath,
142        remove_if_exists: bool,
143        cx: &mut Context<Self>,
144    ) -> Result<Option<AgentContextHandle>> {
145        let project = self.project.upgrade().context("failed to read project")?;
146        let entry_id = project
147            .read(cx)
148            .entry_for_path(project_path, cx)
149            .map(|entry| entry.id)
150            .context("no entry found for directory context")?;
151
152        let context_id = self.next_context_id.post_inc();
153        let context = AgentContextHandle::Directory(DirectoryContextHandle {
154            entry_id,
155            context_id,
156        });
157
158        let context =
159            if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
160                if remove_if_exists {
161                    self.remove_context(&context, cx);
162                    None
163                } else {
164                    Some(existing.as_ref().clone())
165                }
166            } else {
167                self.insert_context(context.clone(), cx);
168                Some(context)
169            };
170
171        anyhow::Ok(context)
172    }
173
174    pub fn add_symbol(
175        &mut self,
176        buffer: Entity<Buffer>,
177        symbol: SharedString,
178        range: Range<Anchor>,
179        enclosing_range: Range<Anchor>,
180        remove_if_exists: bool,
181        cx: &mut Context<Self>,
182    ) -> (Option<AgentContextHandle>, bool) {
183        let context_id = self.next_context_id.post_inc();
184        let context = AgentContextHandle::Symbol(SymbolContextHandle {
185            buffer,
186            symbol,
187            range,
188            enclosing_range,
189            context_id,
190        });
191
192        if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
193            let handle = if remove_if_exists {
194                self.remove_context(&context, cx);
195                None
196            } else {
197                Some(key.as_ref().clone())
198            };
199            return (handle, false);
200        }
201
202        let included = self.insert_context(context.clone(), cx);
203        (Some(context), included)
204    }
205
206    pub fn add_thread(
207        &mut self,
208        thread: Entity<Thread>,
209        remove_if_exists: bool,
210        cx: &mut Context<Self>,
211    ) -> Option<AgentContextHandle> {
212        let context_id = self.next_context_id.post_inc();
213        let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
214
215        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
216            if remove_if_exists {
217                self.remove_context(&context, cx);
218                None
219            } else {
220                Some(existing.as_ref().clone())
221            }
222        } else {
223            self.insert_context(context.clone(), cx);
224            Some(context)
225        }
226    }
227
228    pub fn add_text_thread(
229        &mut self,
230        context: Entity<AssistantContext>,
231        remove_if_exists: bool,
232        cx: &mut Context<Self>,
233    ) -> Option<AgentContextHandle> {
234        let context_id = self.next_context_id.post_inc();
235        let context = AgentContextHandle::TextThread(TextThreadContextHandle {
236            context,
237            context_id,
238        });
239
240        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
241            if remove_if_exists {
242                self.remove_context(&context, cx);
243                None
244            } else {
245                Some(existing.as_ref().clone())
246            }
247        } else {
248            self.insert_context(context.clone(), cx);
249            Some(context)
250        }
251    }
252
253    pub fn add_rules(
254        &mut self,
255        prompt_id: UserPromptId,
256        remove_if_exists: bool,
257        cx: &mut Context<ContextStore>,
258    ) -> Option<AgentContextHandle> {
259        let context_id = self.next_context_id.post_inc();
260        let context = AgentContextHandle::Rules(RulesContextHandle {
261            prompt_id,
262            context_id,
263        });
264
265        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
266            if remove_if_exists {
267                self.remove_context(&context, cx);
268                None
269            } else {
270                Some(existing.as_ref().clone())
271            }
272        } else {
273            self.insert_context(context.clone(), cx);
274            Some(context)
275        }
276    }
277
278    pub fn add_fetched_url(
279        &mut self,
280        url: String,
281        text: impl Into<SharedString>,
282        cx: &mut Context<ContextStore>,
283    ) -> AgentContextHandle {
284        let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
285            url: url.into(),
286            text: text.into(),
287            context_id: self.next_context_id.post_inc(),
288        });
289
290        self.insert_context(context.clone(), cx);
291        context
292    }
293
294    pub fn add_image_from_path(
295        &mut self,
296        project_path: ProjectPath,
297        remove_if_exists: bool,
298        cx: &mut Context<ContextStore>,
299    ) -> Task<Result<Option<AgentContextHandle>>> {
300        let project = self.project.clone();
301        cx.spawn(async move |this, cx| {
302            let open_image_task = project.update(cx, |project, cx| {
303                project.open_image(project_path.clone(), cx)
304            })?;
305            let image_item = open_image_task.await?;
306            let image = image_item.read_with(cx, |image_item, _| image_item.image.clone())?;
307            this.update(cx, |this, cx| {
308                this.insert_image(
309                    Some(image_item.read(cx).project_path(cx)),
310                    image,
311                    remove_if_exists,
312                    cx,
313                )
314            })
315        })
316    }
317
318    pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
319        self.insert_image(None, image, false, cx);
320    }
321
322    fn insert_image(
323        &mut self,
324        project_path: Option<ProjectPath>,
325        image: Arc<Image>,
326        remove_if_exists: bool,
327        cx: &mut Context<ContextStore>,
328    ) -> Option<AgentContextHandle> {
329        let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
330        let context = AgentContextHandle::Image(ImageContext {
331            project_path,
332            original_image: image,
333            image_task,
334            context_id: self.next_context_id.post_inc(),
335        });
336        if self.has_context(&context) {
337            if remove_if_exists {
338                self.remove_context(&context, cx);
339                return None;
340            }
341        }
342
343        self.insert_context(context.clone(), cx);
344        Some(context)
345    }
346
347    pub fn add_selection(
348        &mut self,
349        buffer: Entity<Buffer>,
350        range: Range<Anchor>,
351        cx: &mut Context<ContextStore>,
352    ) {
353        let context_id = self.next_context_id.post_inc();
354        let context = AgentContextHandle::Selection(SelectionContextHandle {
355            buffer,
356            range,
357            context_id,
358        });
359        self.insert_context(context, cx);
360    }
361
362    pub fn add_suggested_context(
363        &mut self,
364        suggested: &SuggestedContext,
365        cx: &mut Context<ContextStore>,
366    ) {
367        match suggested {
368            SuggestedContext::File {
369                buffer,
370                icon_path: _,
371                name: _,
372            } => {
373                if let Some(buffer) = buffer.upgrade() {
374                    let context_id = self.next_context_id.post_inc();
375                    self.insert_context(
376                        AgentContextHandle::File(FileContextHandle { buffer, context_id }),
377                        cx,
378                    );
379                };
380            }
381            SuggestedContext::Thread { thread, name: _ } => {
382                if let Some(thread) = thread.upgrade() {
383                    let context_id = self.next_context_id.post_inc();
384                    self.insert_context(
385                        AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
386                        cx,
387                    );
388                }
389            }
390            SuggestedContext::TextThread { context, name: _ } => {
391                if let Some(context) = context.upgrade() {
392                    let context_id = self.next_context_id.post_inc();
393                    self.insert_context(
394                        AgentContextHandle::TextThread(TextThreadContextHandle {
395                            context,
396                            context_id,
397                        }),
398                        cx,
399                    );
400                }
401            }
402        }
403    }
404
405    fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
406        match &context {
407            AgentContextHandle::Thread(thread_context) => {
408                if let Some(thread_store) = self.thread_store.clone() {
409                    thread_context.thread.update(cx, |thread, cx| {
410                        thread.start_generating_detailed_summary_if_needed(thread_store, cx);
411                    });
412                    self.context_thread_ids
413                        .insert(thread_context.thread.read(cx).id().clone());
414                } else {
415                    return false;
416                }
417            }
418            AgentContextHandle::TextThread(text_thread_context) => {
419                self.context_text_thread_paths
420                    .extend(text_thread_context.context.read(cx).path().cloned());
421            }
422            _ => {}
423        }
424        let inserted = self.context_set.insert(AgentContextKey(context));
425        if inserted {
426            cx.notify();
427        }
428        inserted
429    }
430
431    pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
432        if let Some((_, key)) = self
433            .context_set
434            .shift_remove_full(AgentContextKey::ref_cast(context))
435        {
436            match context {
437                AgentContextHandle::Thread(thread_context) => {
438                    self.context_thread_ids
439                        .remove(thread_context.thread.read(cx).id());
440                }
441                AgentContextHandle::TextThread(text_thread_context) => {
442                    if let Some(path) = text_thread_context.context.read(cx).path() {
443                        self.context_text_thread_paths.remove(path);
444                    }
445                }
446                _ => {}
447            }
448            cx.emit(ContextStoreEvent::ContextRemoved(key));
449            cx.notify();
450        }
451    }
452
453    pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
454        self.context_set
455            .contains(AgentContextKey::ref_cast(context))
456    }
457
458    /// Returns whether this file path is already included directly in the context, or if it will be
459    /// included in the context via a directory.
460    pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
461        let project = self.project.upgrade()?.read(cx);
462        self.context().find_map(|context| match context {
463            AgentContextHandle::File(file_context) => {
464                FileInclusion::check_file(file_context, path, cx)
465            }
466            AgentContextHandle::Image(image_context) => {
467                FileInclusion::check_image(image_context, path)
468            }
469            AgentContextHandle::Directory(directory_context) => {
470                FileInclusion::check_directory(directory_context, path, project, cx)
471            }
472            _ => None,
473        })
474    }
475
476    pub fn path_included_in_directory(
477        &self,
478        path: &ProjectPath,
479        cx: &App,
480    ) -> Option<FileInclusion> {
481        let project = self.project.upgrade()?.read(cx);
482        self.context().find_map(|context| match context {
483            AgentContextHandle::Directory(directory_context) => {
484                FileInclusion::check_directory(directory_context, path, project, cx)
485            }
486            _ => None,
487        })
488    }
489
490    pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
491        self.context().any(|context| match context {
492            AgentContextHandle::Symbol(context) => {
493                if context.symbol != symbol.name {
494                    return false;
495                }
496                let buffer = context.buffer.read(cx);
497                let Some(context_path) = buffer.project_path(cx) else {
498                    return false;
499                };
500                if context_path != symbol.path {
501                    return false;
502                }
503                let context_range = context.range.to_point_utf16(&buffer.snapshot());
504                context_range.start == symbol.range.start.0
505                    && context_range.end == symbol.range.end.0
506            }
507            _ => false,
508        })
509    }
510
511    pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
512        self.context_thread_ids.contains(thread_id)
513    }
514
515    pub fn includes_text_thread(&self, path: &Arc<Path>) -> bool {
516        self.context_text_thread_paths.contains(path)
517    }
518
519    pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
520        self.context_set
521            .contains(&RulesContextHandle::lookup_key(prompt_id))
522    }
523
524    pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
525        self.context_set
526            .contains(&FetchedUrlContext::lookup_key(url.into()))
527    }
528
529    pub fn get_url_context(&self, url: SharedString) -> Option<AgentContextHandle> {
530        self.context_set
531            .get(&FetchedUrlContext::lookup_key(url))
532            .map(|key| key.as_ref().clone())
533    }
534
535    pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
536        self.context()
537            .filter_map(|context| match context {
538                AgentContextHandle::File(file) => {
539                    let buffer = file.buffer.read(cx);
540                    buffer.project_path(cx)
541                }
542                AgentContextHandle::Directory(_)
543                | AgentContextHandle::Symbol(_)
544                | AgentContextHandle::Selection(_)
545                | AgentContextHandle::FetchedUrl(_)
546                | AgentContextHandle::Thread(_)
547                | AgentContextHandle::TextThread(_)
548                | AgentContextHandle::Rules(_)
549                | AgentContextHandle::Image(_) => None,
550            })
551            .collect()
552    }
553
554    pub fn thread_ids(&self) -> &HashSet<ThreadId> {
555        &self.context_thread_ids
556    }
557}
558
/// How a project path is covered by the context in a `ContextStore`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileInclusion {
    /// The path itself is a context entry: a file, an image, or a directory context whose
    /// path is exactly this one.
    Direct,
    /// The path lies inside a directory context. `full_path` is the worktree's full path of
    /// that containing directory, not of the file itself.
    InDirectory { full_path: PathBuf },
}
563
564impl FileInclusion {
565    fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
566        let file_path = file_context.buffer.read(cx).project_path(cx)?;
567        if path == &file_path {
568            Some(FileInclusion::Direct)
569        } else {
570            None
571        }
572    }
573
574    fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
575        let image_path = image_context.project_path.as_ref()?;
576        if path == image_path {
577            Some(FileInclusion::Direct)
578        } else {
579            None
580        }
581    }
582
583    fn check_directory(
584        directory_context: &DirectoryContextHandle,
585        path: &ProjectPath,
586        project: &Project,
587        cx: &App,
588    ) -> Option<Self> {
589        let worktree = project
590            .worktree_for_entry(directory_context.entry_id, cx)?
591            .read(cx);
592        let entry = worktree.entry_for_id(directory_context.entry_id)?;
593        let directory_path = ProjectPath {
594            worktree_id: worktree.id(),
595            path: entry.path.clone(),
596        };
597        if path.starts_with(&directory_path) {
598            if path == &directory_path {
599                Some(FileInclusion::Direct)
600            } else {
601                Some(FileInclusion::InDirectory {
602                    full_path: worktree.full_path(&entry.path),
603                })
604            }
605        } else {
606            None
607        }
608    }
609}