context_store.rs

  1use std::ops::Range;
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use collections::{HashSet, IndexSet};
  7use futures::{self, FutureExt};
  8use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity};
  9use language::Buffer;
 10use language_model::LanguageModelImage;
 11use project::image_store::is_image_file;
 12use project::{Project, ProjectItem, ProjectPath, Symbol};
 13use prompt_store::UserPromptId;
 14use ref_cast::RefCast as _;
 15use text::{Anchor, OffsetRangeExt};
 16
 17use crate::ThreadStore;
 18use crate::context::{
 19    AgentContextHandle, AgentContextKey, ContextId, DirectoryContextHandle, FetchedUrlContext,
 20    FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
 21    SymbolContextHandle, ThreadContextHandle,
 22};
 23use crate::context_strip::SuggestedContext;
 24use crate::thread::{Thread, ThreadId};
 25
 26pub struct ContextStore {
 27    project: WeakEntity<Project>,
 28    thread_store: Option<WeakEntity<ThreadStore>>,
 29    next_context_id: ContextId,
 30    context_set: IndexSet<AgentContextKey>,
 31    context_thread_ids: HashSet<ThreadId>,
 32}
 33
 34impl ContextStore {
 35    pub fn new(
 36        project: WeakEntity<Project>,
 37        thread_store: Option<WeakEntity<ThreadStore>>,
 38    ) -> Self {
 39        Self {
 40            project,
 41            thread_store,
 42            next_context_id: ContextId::zero(),
 43            context_set: IndexSet::default(),
 44            context_thread_ids: HashSet::default(),
 45        }
 46    }
 47
 48    pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
 49        self.context_set.iter().map(|entry| entry.as_ref())
 50    }
 51
 52    pub fn clear(&mut self) {
 53        self.context_set.clear();
 54        self.context_thread_ids.clear();
 55    }
 56
 57    pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContextHandle> {
 58        let existing_context = thread
 59            .messages()
 60            .flat_map(|message| {
 61                message
 62                    .loaded_context
 63                    .contexts
 64                    .iter()
 65                    .map(|context| AgentContextKey(context.handle()))
 66            })
 67            .collect::<HashSet<_>>();
 68        self.context_set
 69            .iter()
 70            .filter(|context| !existing_context.contains(context))
 71            .map(|entry| entry.0.clone())
 72            .collect::<Vec<_>>()
 73    }
 74
 75    pub fn add_file_from_path(
 76        &mut self,
 77        project_path: ProjectPath,
 78        remove_if_exists: bool,
 79        cx: &mut Context<Self>,
 80    ) -> Task<Result<()>> {
 81        let Some(project) = self.project.upgrade() else {
 82            return Task::ready(Err(anyhow!("failed to read project")));
 83        };
 84
 85        if is_image_file(&project, &project_path, cx) {
 86            self.add_image_from_path(project_path, remove_if_exists, cx)
 87        } else {
 88            cx.spawn(async move |this, cx| {
 89                let open_buffer_task = project.update(cx, |project, cx| {
 90                    project.open_buffer(project_path.clone(), cx)
 91                })?;
 92                let buffer = open_buffer_task.await?;
 93                this.update(cx, |this, cx| {
 94                    this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
 95                })
 96            })
 97        }
 98    }
 99
100    pub fn add_file_from_buffer(
101        &mut self,
102        project_path: &ProjectPath,
103        buffer: Entity<Buffer>,
104        remove_if_exists: bool,
105        cx: &mut Context<Self>,
106    ) {
107        let context_id = self.next_context_id.post_inc();
108        let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
109
110        let already_included = if self.has_context(&context) {
111            if remove_if_exists {
112                self.remove_context(&context, cx);
113            }
114            true
115        } else {
116            self.path_included_in_directory(project_path, cx).is_some()
117        };
118
119        if !already_included {
120            self.insert_context(context, cx);
121        }
122    }
123
124    pub fn add_directory(
125        &mut self,
126        project_path: &ProjectPath,
127        remove_if_exists: bool,
128        cx: &mut Context<Self>,
129    ) -> Result<()> {
130        let Some(project) = self.project.upgrade() else {
131            return Err(anyhow!("failed to read project"));
132        };
133
134        let Some(entry_id) = project
135            .read(cx)
136            .entry_for_path(project_path, cx)
137            .map(|entry| entry.id)
138        else {
139            return Err(anyhow!("no entry found for directory context"));
140        };
141
142        let context_id = self.next_context_id.post_inc();
143        let context = AgentContextHandle::Directory(DirectoryContextHandle {
144            entry_id,
145            context_id,
146        });
147
148        if self.has_context(&context) {
149            if remove_if_exists {
150                self.remove_context(&context, cx);
151            }
152        } else if self.path_included_in_directory(project_path, cx).is_none() {
153            self.insert_context(context, cx);
154        }
155
156        anyhow::Ok(())
157    }
158
159    pub fn add_symbol(
160        &mut self,
161        buffer: Entity<Buffer>,
162        symbol: SharedString,
163        range: Range<Anchor>,
164        enclosing_range: Range<Anchor>,
165        remove_if_exists: bool,
166        cx: &mut Context<Self>,
167    ) -> bool {
168        let context_id = self.next_context_id.post_inc();
169        let context = AgentContextHandle::Symbol(SymbolContextHandle {
170            buffer,
171            symbol,
172            range,
173            enclosing_range,
174            context_id,
175        });
176
177        if self.has_context(&context) {
178            if remove_if_exists {
179                self.remove_context(&context, cx);
180            }
181            return false;
182        }
183
184        self.insert_context(context, cx)
185    }
186
187    pub fn add_thread(
188        &mut self,
189        thread: Entity<Thread>,
190        remove_if_exists: bool,
191        cx: &mut Context<Self>,
192    ) {
193        let context_id = self.next_context_id.post_inc();
194        let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
195
196        if self.has_context(&context) {
197            if remove_if_exists {
198                self.remove_context(&context, cx);
199            }
200        } else {
201            self.insert_context(context, cx);
202        }
203    }
204
205    pub fn add_rules(
206        &mut self,
207        prompt_id: UserPromptId,
208        remove_if_exists: bool,
209        cx: &mut Context<ContextStore>,
210    ) {
211        let context_id = self.next_context_id.post_inc();
212        let context = AgentContextHandle::Rules(RulesContextHandle {
213            prompt_id,
214            context_id,
215        });
216
217        if self.has_context(&context) {
218            if remove_if_exists {
219                self.remove_context(&context, cx);
220            }
221        } else {
222            self.insert_context(context, cx);
223        }
224    }
225
226    pub fn add_fetched_url(
227        &mut self,
228        url: String,
229        text: impl Into<SharedString>,
230        cx: &mut Context<ContextStore>,
231    ) {
232        let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
233            url: url.into(),
234            text: text.into(),
235            context_id: self.next_context_id.post_inc(),
236        });
237
238        self.insert_context(context, cx);
239    }
240
241    pub fn add_image_from_path(
242        &mut self,
243        project_path: ProjectPath,
244        remove_if_exists: bool,
245        cx: &mut Context<ContextStore>,
246    ) -> Task<Result<()>> {
247        let project = self.project.clone();
248        cx.spawn(async move |this, cx| {
249            let open_image_task = project.update(cx, |project, cx| {
250                project.open_image(project_path.clone(), cx)
251            })?;
252            let image_item = open_image_task.await?;
253            let image = image_item.read_with(cx, |image_item, _| image_item.image.clone())?;
254            this.update(cx, |this, cx| {
255                this.insert_image(
256                    Some(image_item.read(cx).project_path(cx)),
257                    image,
258                    remove_if_exists,
259                    cx,
260                );
261            })
262        })
263    }
264
265    pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
266        self.insert_image(None, image, false, cx);
267    }
268
269    fn insert_image(
270        &mut self,
271        project_path: Option<ProjectPath>,
272        image: Arc<Image>,
273        remove_if_exists: bool,
274        cx: &mut Context<ContextStore>,
275    ) {
276        let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
277        let context = AgentContextHandle::Image(ImageContext {
278            project_path,
279            original_image: image,
280            image_task,
281            context_id: self.next_context_id.post_inc(),
282        });
283        if self.has_context(&context) {
284            if remove_if_exists {
285                self.remove_context(&context, cx);
286                return;
287            }
288        }
289
290        self.insert_context(context, cx);
291    }
292
293    pub fn add_selection(
294        &mut self,
295        buffer: Entity<Buffer>,
296        range: Range<Anchor>,
297        cx: &mut Context<ContextStore>,
298    ) {
299        let context_id = self.next_context_id.post_inc();
300        let context = AgentContextHandle::Selection(SelectionContextHandle {
301            buffer,
302            range,
303            context_id,
304        });
305        self.insert_context(context, cx);
306    }
307
308    pub fn add_suggested_context(
309        &mut self,
310        suggested: &SuggestedContext,
311        cx: &mut Context<ContextStore>,
312    ) {
313        match suggested {
314            SuggestedContext::File {
315                buffer,
316                icon_path: _,
317                name: _,
318            } => {
319                if let Some(buffer) = buffer.upgrade() {
320                    let context_id = self.next_context_id.post_inc();
321                    self.insert_context(
322                        AgentContextHandle::File(FileContextHandle { buffer, context_id }),
323                        cx,
324                    );
325                };
326            }
327            SuggestedContext::Thread { thread, name: _ } => {
328                if let Some(thread) = thread.upgrade() {
329                    let context_id = self.next_context_id.post_inc();
330                    self.insert_context(
331                        AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
332                        cx,
333                    );
334                }
335            }
336        }
337    }
338
339    fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
340        match &context {
341            AgentContextHandle::Thread(thread_context) => {
342                if let Some(thread_store) = self.thread_store.clone() {
343                    thread_context.thread.update(cx, |thread, cx| {
344                        thread.start_generating_detailed_summary_if_needed(thread_store, cx);
345                    });
346                    self.context_thread_ids
347                        .insert(thread_context.thread.read(cx).id().clone());
348                } else {
349                    return false;
350                }
351            }
352            _ => {}
353        }
354        let inserted = self.context_set.insert(AgentContextKey(context));
355        if inserted {
356            cx.notify();
357        }
358        inserted
359    }
360
361    pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
362        if self
363            .context_set
364            .shift_remove(AgentContextKey::ref_cast(context))
365        {
366            match context {
367                AgentContextHandle::Thread(thread_context) => {
368                    self.context_thread_ids
369                        .remove(thread_context.thread.read(cx).id());
370                }
371                _ => {}
372            }
373            cx.notify();
374        }
375    }
376
377    pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
378        self.context_set
379            .contains(AgentContextKey::ref_cast(context))
380    }
381
382    /// Returns whether this file path is already included directly in the context, or if it will be
383    /// included in the context via a directory.
384    pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
385        let project = self.project.upgrade()?.read(cx);
386        self.context().find_map(|context| match context {
387            AgentContextHandle::File(file_context) => {
388                FileInclusion::check_file(file_context, path, cx)
389            }
390            AgentContextHandle::Image(image_context) => {
391                FileInclusion::check_image(image_context, path)
392            }
393            AgentContextHandle::Directory(directory_context) => {
394                FileInclusion::check_directory(directory_context, path, project, cx)
395            }
396            _ => None,
397        })
398    }
399
400    pub fn path_included_in_directory(
401        &self,
402        path: &ProjectPath,
403        cx: &App,
404    ) -> Option<FileInclusion> {
405        let project = self.project.upgrade()?.read(cx);
406        self.context().find_map(|context| match context {
407            AgentContextHandle::Directory(directory_context) => {
408                FileInclusion::check_directory(directory_context, path, project, cx)
409            }
410            _ => None,
411        })
412    }
413
414    pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
415        self.context().any(|context| match context {
416            AgentContextHandle::Symbol(context) => {
417                if context.symbol != symbol.name {
418                    return false;
419                }
420                let buffer = context.buffer.read(cx);
421                let Some(context_path) = buffer.project_path(cx) else {
422                    return false;
423                };
424                if context_path != symbol.path {
425                    return false;
426                }
427                let context_range = context.range.to_point_utf16(&buffer.snapshot());
428                context_range.start == symbol.range.start.0
429                    && context_range.end == symbol.range.end.0
430            }
431            _ => false,
432        })
433    }
434
435    pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
436        self.context_thread_ids.contains(thread_id)
437    }
438
439    pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
440        self.context_set
441            .contains(&RulesContextHandle::lookup_key(prompt_id))
442    }
443
444    pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
445        self.context_set
446            .contains(&FetchedUrlContext::lookup_key(url.into()))
447    }
448
449    pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
450        self.context()
451            .filter_map(|context| match context {
452                AgentContextHandle::File(file) => {
453                    let buffer = file.buffer.read(cx);
454                    buffer.project_path(cx)
455                }
456                AgentContextHandle::Directory(_)
457                | AgentContextHandle::Symbol(_)
458                | AgentContextHandle::Selection(_)
459                | AgentContextHandle::FetchedUrl(_)
460                | AgentContextHandle::Thread(_)
461                | AgentContextHandle::Rules(_)
462                | AgentContextHandle::Image(_) => None,
463            })
464            .collect()
465    }
466
467    pub fn thread_ids(&self) -> &HashSet<ThreadId> {
468        &self.context_thread_ids
469    }
470}
471
/// Describes how a project path is already covered by the context in a `ContextStore`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileInclusion {
    /// The path itself is attached, as a file, an image, or as the directory entry itself.
    Direct,
    /// The path lies inside a directory that is attached as context.
    InDirectory {
        /// Full path (as reported by the worktree) of the attached directory that contains
        /// the path.
        full_path: PathBuf,
    },
}
476
477impl FileInclusion {
478    fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
479        let file_path = file_context.buffer.read(cx).project_path(cx)?;
480        if path == &file_path {
481            Some(FileInclusion::Direct)
482        } else {
483            None
484        }
485    }
486
487    fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
488        let image_path = image_context.project_path.as_ref()?;
489        if path == image_path {
490            Some(FileInclusion::Direct)
491        } else {
492            None
493        }
494    }
495
496    fn check_directory(
497        directory_context: &DirectoryContextHandle,
498        path: &ProjectPath,
499        project: &Project,
500        cx: &App,
501    ) -> Option<Self> {
502        let worktree = project
503            .worktree_for_entry(directory_context.entry_id, cx)?
504            .read(cx);
505        let entry = worktree.entry_for_id(directory_context.entry_id)?;
506        let directory_path = ProjectPath {
507            worktree_id: worktree.id(),
508            path: entry.path.clone(),
509        };
510        if path.starts_with(&directory_path) {
511            if path == &directory_path {
512                Some(FileInclusion::Direct)
513            } else {
514                Some(FileInclusion::InDirectory {
515                    full_path: worktree.full_path(&entry.path),
516                })
517            }
518        } else {
519            None
520        }
521    }
522}