context_store.rs

  1use std::ops::Range;
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use collections::{HashSet, IndexSet};
  7use futures::{self, FutureExt};
  8use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
  9use language::Buffer;
 10use language_model::LanguageModelImage;
 11use project::image_store::is_image_file;
 12use project::{Project, ProjectItem, ProjectPath, Symbol};
 13use prompt_store::UserPromptId;
 14use ref_cast::RefCast as _;
 15use text::{Anchor, OffsetRangeExt};
 16
 17use crate::ThreadStore;
 18use crate::context::{
 19    AgentContextHandle, AgentContextKey, ContextId, DirectoryContextHandle, FetchedUrlContext,
 20    FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
 21    SymbolContextHandle, ThreadContextHandle,
 22};
 23use crate::context_strip::SuggestedContext;
 24use crate::thread::{MessageId, Thread, ThreadId};
 25
 26pub struct ContextStore {
 27    project: WeakEntity<Project>,
 28    thread_store: Option<WeakEntity<ThreadStore>>,
 29    next_context_id: ContextId,
 30    context_set: IndexSet<AgentContextKey>,
 31    context_thread_ids: HashSet<ThreadId>,
 32}
 33
 34pub enum ContextStoreEvent {
 35    ContextRemoved(AgentContextKey),
 36}
 37
 38impl EventEmitter<ContextStoreEvent> for ContextStore {}
 39
 40impl ContextStore {
 41    pub fn new(
 42        project: WeakEntity<Project>,
 43        thread_store: Option<WeakEntity<ThreadStore>>,
 44    ) -> Self {
 45        Self {
 46            project,
 47            thread_store,
 48            next_context_id: ContextId::zero(),
 49            context_set: IndexSet::default(),
 50            context_thread_ids: HashSet::default(),
 51        }
 52    }
 53
 54    pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
 55        self.context_set.iter().map(|entry| entry.as_ref())
 56    }
 57
 58    pub fn clear(&mut self) {
 59        self.context_set.clear();
 60        self.context_thread_ids.clear();
 61    }
 62
 63    pub fn new_context_for_thread(
 64        &self,
 65        thread: &Thread,
 66        exclude_messages_from_id: Option<MessageId>,
 67    ) -> Vec<AgentContextHandle> {
 68        let existing_context = thread
 69            .messages()
 70            .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
 71            .flat_map(|message| {
 72                message
 73                    .loaded_context
 74                    .contexts
 75                    .iter()
 76                    .map(|context| AgentContextKey(context.handle()))
 77            })
 78            .collect::<HashSet<_>>();
 79        self.context_set
 80            .iter()
 81            .filter(|context| !existing_context.contains(context))
 82            .map(|entry| entry.0.clone())
 83            .collect::<Vec<_>>()
 84    }
 85
 86    pub fn add_file_from_path(
 87        &mut self,
 88        project_path: ProjectPath,
 89        remove_if_exists: bool,
 90        cx: &mut Context<Self>,
 91    ) -> Task<Result<Option<AgentContextHandle>>> {
 92        let Some(project) = self.project.upgrade() else {
 93            return Task::ready(Err(anyhow!("failed to read project")));
 94        };
 95
 96        if is_image_file(&project, &project_path, cx) {
 97            self.add_image_from_path(project_path, remove_if_exists, cx)
 98        } else {
 99            cx.spawn(async move |this, cx| {
100                let open_buffer_task = project.update(cx, |project, cx| {
101                    project.open_buffer(project_path.clone(), cx)
102                })?;
103                let buffer = open_buffer_task.await?;
104                this.update(cx, |this, cx| {
105                    this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
106                })
107            })
108        }
109    }
110
111    pub fn add_file_from_buffer(
112        &mut self,
113        project_path: &ProjectPath,
114        buffer: Entity<Buffer>,
115        remove_if_exists: bool,
116        cx: &mut Context<Self>,
117    ) -> Option<AgentContextHandle> {
118        let context_id = self.next_context_id.post_inc();
119        let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
120
121        if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
122            if remove_if_exists {
123                self.remove_context(&context, cx);
124                None
125            } else {
126                Some(key.as_ref().clone())
127            }
128        } else if self.path_included_in_directory(project_path, cx).is_some() {
129            None
130        } else {
131            self.insert_context(context.clone(), cx);
132            Some(context)
133        }
134    }
135
136    pub fn add_directory(
137        &mut self,
138        project_path: &ProjectPath,
139        remove_if_exists: bool,
140        cx: &mut Context<Self>,
141    ) -> Result<Option<AgentContextHandle>> {
142        let Some(project) = self.project.upgrade() else {
143            return Err(anyhow!("failed to read project"));
144        };
145
146        let Some(entry_id) = project
147            .read(cx)
148            .entry_for_path(project_path, cx)
149            .map(|entry| entry.id)
150        else {
151            return Err(anyhow!("no entry found for directory context"));
152        };
153
154        let context_id = self.next_context_id.post_inc();
155        let context = AgentContextHandle::Directory(DirectoryContextHandle {
156            entry_id,
157            context_id,
158        });
159
160        let context =
161            if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
162                if remove_if_exists {
163                    self.remove_context(&context, cx);
164                    None
165                } else {
166                    Some(existing.as_ref().clone())
167                }
168            } else {
169                self.insert_context(context.clone(), cx);
170                Some(context)
171            };
172
173        anyhow::Ok(context)
174    }
175
176    pub fn add_symbol(
177        &mut self,
178        buffer: Entity<Buffer>,
179        symbol: SharedString,
180        range: Range<Anchor>,
181        enclosing_range: Range<Anchor>,
182        remove_if_exists: bool,
183        cx: &mut Context<Self>,
184    ) -> (Option<AgentContextHandle>, bool) {
185        let context_id = self.next_context_id.post_inc();
186        let context = AgentContextHandle::Symbol(SymbolContextHandle {
187            buffer,
188            symbol,
189            range,
190            enclosing_range,
191            context_id,
192        });
193
194        if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
195            let handle = if remove_if_exists {
196                self.remove_context(&context, cx);
197                None
198            } else {
199                Some(key.as_ref().clone())
200            };
201            return (handle, false);
202        }
203
204        let included = self.insert_context(context.clone(), cx);
205        (Some(context), included)
206    }
207
208    pub fn add_thread(
209        &mut self,
210        thread: Entity<Thread>,
211        remove_if_exists: bool,
212        cx: &mut Context<Self>,
213    ) -> Option<AgentContextHandle> {
214        let context_id = self.next_context_id.post_inc();
215        let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
216
217        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
218            if remove_if_exists {
219                self.remove_context(&context, cx);
220                None
221            } else {
222                Some(existing.as_ref().clone())
223            }
224        } else {
225            self.insert_context(context.clone(), cx);
226            Some(context)
227        }
228    }
229
230    pub fn add_rules(
231        &mut self,
232        prompt_id: UserPromptId,
233        remove_if_exists: bool,
234        cx: &mut Context<ContextStore>,
235    ) -> Option<AgentContextHandle> {
236        let context_id = self.next_context_id.post_inc();
237        let context = AgentContextHandle::Rules(RulesContextHandle {
238            prompt_id,
239            context_id,
240        });
241
242        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
243            if remove_if_exists {
244                self.remove_context(&context, cx);
245                None
246            } else {
247                Some(existing.as_ref().clone())
248            }
249        } else {
250            self.insert_context(context.clone(), cx);
251            Some(context)
252        }
253    }
254
255    pub fn add_fetched_url(
256        &mut self,
257        url: String,
258        text: impl Into<SharedString>,
259        cx: &mut Context<ContextStore>,
260    ) -> AgentContextHandle {
261        let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
262            url: url.into(),
263            text: text.into(),
264            context_id: self.next_context_id.post_inc(),
265        });
266
267        self.insert_context(context.clone(), cx);
268        context
269    }
270
271    pub fn add_image_from_path(
272        &mut self,
273        project_path: ProjectPath,
274        remove_if_exists: bool,
275        cx: &mut Context<ContextStore>,
276    ) -> Task<Result<Option<AgentContextHandle>>> {
277        let project = self.project.clone();
278        cx.spawn(async move |this, cx| {
279            let open_image_task = project.update(cx, |project, cx| {
280                project.open_image(project_path.clone(), cx)
281            })?;
282            let image_item = open_image_task.await?;
283            let image = image_item.read_with(cx, |image_item, _| image_item.image.clone())?;
284            this.update(cx, |this, cx| {
285                this.insert_image(
286                    Some(image_item.read(cx).project_path(cx)),
287                    image,
288                    remove_if_exists,
289                    cx,
290                )
291            })
292        })
293    }
294
295    pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
296        self.insert_image(None, image, false, cx);
297    }
298
299    fn insert_image(
300        &mut self,
301        project_path: Option<ProjectPath>,
302        image: Arc<Image>,
303        remove_if_exists: bool,
304        cx: &mut Context<ContextStore>,
305    ) -> Option<AgentContextHandle> {
306        let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
307        let context = AgentContextHandle::Image(ImageContext {
308            project_path,
309            original_image: image,
310            image_task,
311            context_id: self.next_context_id.post_inc(),
312        });
313        if self.has_context(&context) {
314            if remove_if_exists {
315                self.remove_context(&context, cx);
316                return None;
317            }
318        }
319
320        self.insert_context(context.clone(), cx);
321        Some(context)
322    }
323
324    pub fn add_selection(
325        &mut self,
326        buffer: Entity<Buffer>,
327        range: Range<Anchor>,
328        cx: &mut Context<ContextStore>,
329    ) {
330        let context_id = self.next_context_id.post_inc();
331        let context = AgentContextHandle::Selection(SelectionContextHandle {
332            buffer,
333            range,
334            context_id,
335        });
336        self.insert_context(context, cx);
337    }
338
339    pub fn add_suggested_context(
340        &mut self,
341        suggested: &SuggestedContext,
342        cx: &mut Context<ContextStore>,
343    ) {
344        match suggested {
345            SuggestedContext::File {
346                buffer,
347                icon_path: _,
348                name: _,
349            } => {
350                if let Some(buffer) = buffer.upgrade() {
351                    let context_id = self.next_context_id.post_inc();
352                    self.insert_context(
353                        AgentContextHandle::File(FileContextHandle { buffer, context_id }),
354                        cx,
355                    );
356                };
357            }
358            SuggestedContext::Thread { thread, name: _ } => {
359                if let Some(thread) = thread.upgrade() {
360                    let context_id = self.next_context_id.post_inc();
361                    self.insert_context(
362                        AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
363                        cx,
364                    );
365                }
366            }
367        }
368    }
369
370    fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
371        match &context {
372            AgentContextHandle::Thread(thread_context) => {
373                if let Some(thread_store) = self.thread_store.clone() {
374                    thread_context.thread.update(cx, |thread, cx| {
375                        thread.start_generating_detailed_summary_if_needed(thread_store, cx);
376                    });
377                    self.context_thread_ids
378                        .insert(thread_context.thread.read(cx).id().clone());
379                } else {
380                    return false;
381                }
382            }
383            _ => {}
384        }
385        let inserted = self.context_set.insert(AgentContextKey(context));
386        if inserted {
387            cx.notify();
388        }
389        inserted
390    }
391
392    pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
393        if let Some((_, key)) = self
394            .context_set
395            .shift_remove_full(AgentContextKey::ref_cast(context))
396        {
397            match context {
398                AgentContextHandle::Thread(thread_context) => {
399                    self.context_thread_ids
400                        .remove(thread_context.thread.read(cx).id());
401                }
402                _ => {}
403            }
404            cx.emit(ContextStoreEvent::ContextRemoved(key));
405            cx.notify();
406        }
407    }
408
409    pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
410        self.context_set
411            .contains(AgentContextKey::ref_cast(context))
412    }
413
414    /// Returns whether this file path is already included directly in the context, or if it will be
415    /// included in the context via a directory.
416    pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
417        let project = self.project.upgrade()?.read(cx);
418        self.context().find_map(|context| match context {
419            AgentContextHandle::File(file_context) => {
420                FileInclusion::check_file(file_context, path, cx)
421            }
422            AgentContextHandle::Image(image_context) => {
423                FileInclusion::check_image(image_context, path)
424            }
425            AgentContextHandle::Directory(directory_context) => {
426                FileInclusion::check_directory(directory_context, path, project, cx)
427            }
428            _ => None,
429        })
430    }
431
432    pub fn path_included_in_directory(
433        &self,
434        path: &ProjectPath,
435        cx: &App,
436    ) -> Option<FileInclusion> {
437        let project = self.project.upgrade()?.read(cx);
438        self.context().find_map(|context| match context {
439            AgentContextHandle::Directory(directory_context) => {
440                FileInclusion::check_directory(directory_context, path, project, cx)
441            }
442            _ => None,
443        })
444    }
445
446    pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
447        self.context().any(|context| match context {
448            AgentContextHandle::Symbol(context) => {
449                if context.symbol != symbol.name {
450                    return false;
451                }
452                let buffer = context.buffer.read(cx);
453                let Some(context_path) = buffer.project_path(cx) else {
454                    return false;
455                };
456                if context_path != symbol.path {
457                    return false;
458                }
459                let context_range = context.range.to_point_utf16(&buffer.snapshot());
460                context_range.start == symbol.range.start.0
461                    && context_range.end == symbol.range.end.0
462            }
463            _ => false,
464        })
465    }
466
467    pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
468        self.context_thread_ids.contains(thread_id)
469    }
470
471    pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
472        self.context_set
473            .contains(&RulesContextHandle::lookup_key(prompt_id))
474    }
475
476    pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
477        self.context_set
478            .contains(&FetchedUrlContext::lookup_key(url.into()))
479    }
480
481    pub fn get_url_context(&self, url: SharedString) -> Option<AgentContextHandle> {
482        self.context_set
483            .get(&FetchedUrlContext::lookup_key(url))
484            .map(|key| key.as_ref().clone())
485    }
486
487    pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
488        self.context()
489            .filter_map(|context| match context {
490                AgentContextHandle::File(file) => {
491                    let buffer = file.buffer.read(cx);
492                    buffer.project_path(cx)
493                }
494                AgentContextHandle::Directory(_)
495                | AgentContextHandle::Symbol(_)
496                | AgentContextHandle::Selection(_)
497                | AgentContextHandle::FetchedUrl(_)
498                | AgentContextHandle::Thread(_)
499                | AgentContextHandle::Rules(_)
500                | AgentContextHandle::Image(_) => None,
501            })
502            .collect()
503    }
504
505    pub fn thread_ids(&self) -> &HashSet<ThreadId> {
506        &self.context_thread_ids
507    }
508}
509
/// Describes how a project path is already covered by the context store.
pub enum FileInclusion {
    /// The path itself is in context (as a file, an image, or the directory entry itself).
    Direct,
    /// The path lies inside a directory that is in context; `full_path` is the worktree-qualified
    /// path of that directory.
    InDirectory { full_path: PathBuf },
}
514
515impl FileInclusion {
516    fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
517        let file_path = file_context.buffer.read(cx).project_path(cx)?;
518        if path == &file_path {
519            Some(FileInclusion::Direct)
520        } else {
521            None
522        }
523    }
524
525    fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
526        let image_path = image_context.project_path.as_ref()?;
527        if path == image_path {
528            Some(FileInclusion::Direct)
529        } else {
530            None
531        }
532    }
533
534    fn check_directory(
535        directory_context: &DirectoryContextHandle,
536        path: &ProjectPath,
537        project: &Project,
538        cx: &App,
539    ) -> Option<Self> {
540        let worktree = project
541            .worktree_for_entry(directory_context.entry_id, cx)?
542            .read(cx);
543        let entry = worktree.entry_for_id(directory_context.entry_id)?;
544        let directory_path = ProjectPath {
545            worktree_id: worktree.id(),
546            path: entry.path.clone(),
547        };
548        if path.starts_with(&directory_path) {
549            if path == &directory_path {
550                Some(FileInclusion::Direct)
551            } else {
552                Some(FileInclusion::InDirectory {
553                    full_path: worktree.full_path(&entry.path),
554                })
555            }
556        } else {
557            None
558        }
559    }
560}