context_store.rs

  1use std::ops::Range;
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use collections::{HashSet, IndexSet};
  7use futures::{self, FutureExt};
  8use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity};
  9use language::Buffer;
 10use language_model::LanguageModelImage;
 11use project::image_store::is_image_file;
 12use project::{Project, ProjectItem, ProjectPath, Symbol};
 13use prompt_store::UserPromptId;
 14use ref_cast::RefCast as _;
 15use text::{Anchor, OffsetRangeExt};
 16
 17use crate::ThreadStore;
 18use crate::context::{
 19    AgentContextHandle, AgentContextKey, ContextId, DirectoryContextHandle, FetchedUrlContext,
 20    FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
 21    SymbolContextHandle, ThreadContextHandle,
 22};
 23use crate::context_strip::SuggestedContext;
 24use crate::thread::{MessageId, Thread, ThreadId};
 25
 26pub struct ContextStore {
 27    project: WeakEntity<Project>,
 28    thread_store: Option<WeakEntity<ThreadStore>>,
 29    next_context_id: ContextId,
 30    context_set: IndexSet<AgentContextKey>,
 31    context_thread_ids: HashSet<ThreadId>,
 32}
 33
 34impl ContextStore {
 35    pub fn new(
 36        project: WeakEntity<Project>,
 37        thread_store: Option<WeakEntity<ThreadStore>>,
 38    ) -> Self {
 39        Self {
 40            project,
 41            thread_store,
 42            next_context_id: ContextId::zero(),
 43            context_set: IndexSet::default(),
 44            context_thread_ids: HashSet::default(),
 45        }
 46    }
 47
 48    pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
 49        self.context_set.iter().map(|entry| entry.as_ref())
 50    }
 51
 52    pub fn clear(&mut self) {
 53        self.context_set.clear();
 54        self.context_thread_ids.clear();
 55    }
 56
 57    pub fn new_context_for_thread(
 58        &self,
 59        thread: &Thread,
 60        exclude_messages_from_id: Option<MessageId>,
 61    ) -> Vec<AgentContextHandle> {
 62        let existing_context = thread
 63            .messages()
 64            .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
 65            .flat_map(|message| {
 66                message
 67                    .loaded_context
 68                    .contexts
 69                    .iter()
 70                    .map(|context| AgentContextKey(context.handle()))
 71            })
 72            .collect::<HashSet<_>>();
 73        self.context_set
 74            .iter()
 75            .filter(|context| !existing_context.contains(context))
 76            .map(|entry| entry.0.clone())
 77            .collect::<Vec<_>>()
 78    }
 79
 80    pub fn add_file_from_path(
 81        &mut self,
 82        project_path: ProjectPath,
 83        remove_if_exists: bool,
 84        cx: &mut Context<Self>,
 85    ) -> Task<Result<()>> {
 86        let Some(project) = self.project.upgrade() else {
 87            return Task::ready(Err(anyhow!("failed to read project")));
 88        };
 89
 90        if is_image_file(&project, &project_path, cx) {
 91            self.add_image_from_path(project_path, remove_if_exists, cx)
 92        } else {
 93            cx.spawn(async move |this, cx| {
 94                let open_buffer_task = project.update(cx, |project, cx| {
 95                    project.open_buffer(project_path.clone(), cx)
 96                })?;
 97                let buffer = open_buffer_task.await?;
 98                this.update(cx, |this, cx| {
 99                    this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
100                })
101            })
102        }
103    }
104
105    pub fn add_file_from_buffer(
106        &mut self,
107        project_path: &ProjectPath,
108        buffer: Entity<Buffer>,
109        remove_if_exists: bool,
110        cx: &mut Context<Self>,
111    ) {
112        let context_id = self.next_context_id.post_inc();
113        let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
114
115        let already_included = if self.has_context(&context) {
116            if remove_if_exists {
117                self.remove_context(&context, cx);
118            }
119            true
120        } else {
121            self.path_included_in_directory(project_path, cx).is_some()
122        };
123
124        if !already_included {
125            self.insert_context(context, cx);
126        }
127    }
128
129    pub fn add_directory(
130        &mut self,
131        project_path: &ProjectPath,
132        remove_if_exists: bool,
133        cx: &mut Context<Self>,
134    ) -> Result<()> {
135        let Some(project) = self.project.upgrade() else {
136            return Err(anyhow!("failed to read project"));
137        };
138
139        let Some(entry_id) = project
140            .read(cx)
141            .entry_for_path(project_path, cx)
142            .map(|entry| entry.id)
143        else {
144            return Err(anyhow!("no entry found for directory context"));
145        };
146
147        let context_id = self.next_context_id.post_inc();
148        let context = AgentContextHandle::Directory(DirectoryContextHandle {
149            entry_id,
150            context_id,
151        });
152
153        if self.has_context(&context) {
154            if remove_if_exists {
155                self.remove_context(&context, cx);
156            }
157        } else if self.path_included_in_directory(project_path, cx).is_none() {
158            self.insert_context(context, cx);
159        }
160
161        anyhow::Ok(())
162    }
163
164    pub fn add_symbol(
165        &mut self,
166        buffer: Entity<Buffer>,
167        symbol: SharedString,
168        range: Range<Anchor>,
169        enclosing_range: Range<Anchor>,
170        remove_if_exists: bool,
171        cx: &mut Context<Self>,
172    ) -> bool {
173        let context_id = self.next_context_id.post_inc();
174        let context = AgentContextHandle::Symbol(SymbolContextHandle {
175            buffer,
176            symbol,
177            range,
178            enclosing_range,
179            context_id,
180        });
181
182        if self.has_context(&context) {
183            if remove_if_exists {
184                self.remove_context(&context, cx);
185            }
186            return false;
187        }
188
189        self.insert_context(context, cx)
190    }
191
192    pub fn add_thread(
193        &mut self,
194        thread: Entity<Thread>,
195        remove_if_exists: bool,
196        cx: &mut Context<Self>,
197    ) {
198        let context_id = self.next_context_id.post_inc();
199        let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
200
201        if self.has_context(&context) {
202            if remove_if_exists {
203                self.remove_context(&context, cx);
204            }
205        } else {
206            self.insert_context(context, cx);
207        }
208    }
209
210    pub fn add_rules(
211        &mut self,
212        prompt_id: UserPromptId,
213        remove_if_exists: bool,
214        cx: &mut Context<ContextStore>,
215    ) {
216        let context_id = self.next_context_id.post_inc();
217        let context = AgentContextHandle::Rules(RulesContextHandle {
218            prompt_id,
219            context_id,
220        });
221
222        if self.has_context(&context) {
223            if remove_if_exists {
224                self.remove_context(&context, cx);
225            }
226        } else {
227            self.insert_context(context, cx);
228        }
229    }
230
231    pub fn add_fetched_url(
232        &mut self,
233        url: String,
234        text: impl Into<SharedString>,
235        cx: &mut Context<ContextStore>,
236    ) {
237        let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
238            url: url.into(),
239            text: text.into(),
240            context_id: self.next_context_id.post_inc(),
241        });
242
243        self.insert_context(context, cx);
244    }
245
246    pub fn add_image_from_path(
247        &mut self,
248        project_path: ProjectPath,
249        remove_if_exists: bool,
250        cx: &mut Context<ContextStore>,
251    ) -> Task<Result<()>> {
252        let project = self.project.clone();
253        cx.spawn(async move |this, cx| {
254            let open_image_task = project.update(cx, |project, cx| {
255                project.open_image(project_path.clone(), cx)
256            })?;
257            let image_item = open_image_task.await?;
258            let image = image_item.read_with(cx, |image_item, _| image_item.image.clone())?;
259            this.update(cx, |this, cx| {
260                this.insert_image(
261                    Some(image_item.read(cx).project_path(cx)),
262                    image,
263                    remove_if_exists,
264                    cx,
265                );
266            })
267        })
268    }
269
270    pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
271        self.insert_image(None, image, false, cx);
272    }
273
274    fn insert_image(
275        &mut self,
276        project_path: Option<ProjectPath>,
277        image: Arc<Image>,
278        remove_if_exists: bool,
279        cx: &mut Context<ContextStore>,
280    ) {
281        let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
282        let context = AgentContextHandle::Image(ImageContext {
283            project_path,
284            original_image: image,
285            image_task,
286            context_id: self.next_context_id.post_inc(),
287        });
288        if self.has_context(&context) {
289            if remove_if_exists {
290                self.remove_context(&context, cx);
291                return;
292            }
293        }
294
295        self.insert_context(context, cx);
296    }
297
298    pub fn add_selection(
299        &mut self,
300        buffer: Entity<Buffer>,
301        range: Range<Anchor>,
302        cx: &mut Context<ContextStore>,
303    ) {
304        let context_id = self.next_context_id.post_inc();
305        let context = AgentContextHandle::Selection(SelectionContextHandle {
306            buffer,
307            range,
308            context_id,
309        });
310        self.insert_context(context, cx);
311    }
312
313    pub fn add_suggested_context(
314        &mut self,
315        suggested: &SuggestedContext,
316        cx: &mut Context<ContextStore>,
317    ) {
318        match suggested {
319            SuggestedContext::File {
320                buffer,
321                icon_path: _,
322                name: _,
323            } => {
324                if let Some(buffer) = buffer.upgrade() {
325                    let context_id = self.next_context_id.post_inc();
326                    self.insert_context(
327                        AgentContextHandle::File(FileContextHandle { buffer, context_id }),
328                        cx,
329                    );
330                };
331            }
332            SuggestedContext::Thread { thread, name: _ } => {
333                if let Some(thread) = thread.upgrade() {
334                    let context_id = self.next_context_id.post_inc();
335                    self.insert_context(
336                        AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
337                        cx,
338                    );
339                }
340            }
341        }
342    }
343
344    fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
345        match &context {
346            AgentContextHandle::Thread(thread_context) => {
347                if let Some(thread_store) = self.thread_store.clone() {
348                    thread_context.thread.update(cx, |thread, cx| {
349                        thread.start_generating_detailed_summary_if_needed(thread_store, cx);
350                    });
351                    self.context_thread_ids
352                        .insert(thread_context.thread.read(cx).id().clone());
353                } else {
354                    return false;
355                }
356            }
357            _ => {}
358        }
359        let inserted = self.context_set.insert(AgentContextKey(context));
360        if inserted {
361            cx.notify();
362        }
363        inserted
364    }
365
366    pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
367        if self
368            .context_set
369            .shift_remove(AgentContextKey::ref_cast(context))
370        {
371            match context {
372                AgentContextHandle::Thread(thread_context) => {
373                    self.context_thread_ids
374                        .remove(thread_context.thread.read(cx).id());
375                }
376                _ => {}
377            }
378            cx.notify();
379        }
380    }
381
382    pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
383        self.context_set
384            .contains(AgentContextKey::ref_cast(context))
385    }
386
387    /// Returns whether this file path is already included directly in the context, or if it will be
388    /// included in the context via a directory.
389    pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
390        let project = self.project.upgrade()?.read(cx);
391        self.context().find_map(|context| match context {
392            AgentContextHandle::File(file_context) => {
393                FileInclusion::check_file(file_context, path, cx)
394            }
395            AgentContextHandle::Image(image_context) => {
396                FileInclusion::check_image(image_context, path)
397            }
398            AgentContextHandle::Directory(directory_context) => {
399                FileInclusion::check_directory(directory_context, path, project, cx)
400            }
401            _ => None,
402        })
403    }
404
405    pub fn path_included_in_directory(
406        &self,
407        path: &ProjectPath,
408        cx: &App,
409    ) -> Option<FileInclusion> {
410        let project = self.project.upgrade()?.read(cx);
411        self.context().find_map(|context| match context {
412            AgentContextHandle::Directory(directory_context) => {
413                FileInclusion::check_directory(directory_context, path, project, cx)
414            }
415            _ => None,
416        })
417    }
418
419    pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
420        self.context().any(|context| match context {
421            AgentContextHandle::Symbol(context) => {
422                if context.symbol != symbol.name {
423                    return false;
424                }
425                let buffer = context.buffer.read(cx);
426                let Some(context_path) = buffer.project_path(cx) else {
427                    return false;
428                };
429                if context_path != symbol.path {
430                    return false;
431                }
432                let context_range = context.range.to_point_utf16(&buffer.snapshot());
433                context_range.start == symbol.range.start.0
434                    && context_range.end == symbol.range.end.0
435            }
436            _ => false,
437        })
438    }
439
440    pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
441        self.context_thread_ids.contains(thread_id)
442    }
443
444    pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
445        self.context_set
446            .contains(&RulesContextHandle::lookup_key(prompt_id))
447    }
448
449    pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
450        self.context_set
451            .contains(&FetchedUrlContext::lookup_key(url.into()))
452    }
453
454    pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
455        self.context()
456            .filter_map(|context| match context {
457                AgentContextHandle::File(file) => {
458                    let buffer = file.buffer.read(cx);
459                    buffer.project_path(cx)
460                }
461                AgentContextHandle::Directory(_)
462                | AgentContextHandle::Symbol(_)
463                | AgentContextHandle::Selection(_)
464                | AgentContextHandle::FetchedUrl(_)
465                | AgentContextHandle::Thread(_)
466                | AgentContextHandle::Rules(_)
467                | AgentContextHandle::Image(_) => None,
468            })
469            .collect()
470    }
471
472    pub fn thread_ids(&self) -> &HashSet<ThreadId> {
473        &self.context_thread_ids
474    }
475}
476
/// How a project path is covered by the current context.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileInclusion {
    /// The path itself is an entry: a file, image, or the directory entry itself.
    Direct,
    /// The path lies inside a directory entry; `full_path` is the path as reported by the
    /// worktree's `full_path`.
    InDirectory { full_path: PathBuf },
}
481
482impl FileInclusion {
483    fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
484        let file_path = file_context.buffer.read(cx).project_path(cx)?;
485        if path == &file_path {
486            Some(FileInclusion::Direct)
487        } else {
488            None
489        }
490    }
491
492    fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
493        let image_path = image_context.project_path.as_ref()?;
494        if path == image_path {
495            Some(FileInclusion::Direct)
496        } else {
497            None
498        }
499    }
500
501    fn check_directory(
502        directory_context: &DirectoryContextHandle,
503        path: &ProjectPath,
504        project: &Project,
505        cx: &App,
506    ) -> Option<Self> {
507        let worktree = project
508            .worktree_for_entry(directory_context.entry_id, cx)?
509            .read(cx);
510        let entry = worktree.entry_for_id(directory_context.entry_id)?;
511        let directory_path = ProjectPath {
512            worktree_id: worktree.id(),
513            path: entry.path.clone(),
514        };
515        if path.starts_with(&directory_path) {
516            if path == &directory_path {
517                Some(FileInclusion::Direct)
518            } else {
519                Some(FileInclusion::InDirectory {
520                    full_path: worktree.full_path(&entry.path),
521                })
522            }
523        } else {
524            None
525        }
526    }
527}