context_store.rs

  1use std::ops::Range;
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use collections::{HashSet, IndexSet};
  7use futures::future::join_all;
  8use futures::{self, FutureExt};
  9use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity};
 10use language::Buffer;
 11use language_model::LanguageModelImage;
 12use project::image_store::is_image_file;
 13use project::{Project, ProjectItem, ProjectPath, Symbol};
 14use prompt_store::UserPromptId;
 15use ref_cast::RefCast as _;
 16use text::{Anchor, OffsetRangeExt};
 17use util::ResultExt as _;
 18
 19use crate::ThreadStore;
 20use crate::context::{
 21    AgentContextHandle, AgentContextKey, ContextId, DirectoryContextHandle, FetchedUrlContext,
 22    FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
 23    SymbolContextHandle, ThreadContextHandle,
 24};
 25use crate::context_strip::SuggestedContext;
 26use crate::thread::{Thread, ThreadId};
 27
/// Collects the context items (files, directories, symbols, selections, threads, rules,
/// fetched URLs and images) that will be attached to an agent request, plus the bookkeeping
/// needed to summarize any threads that were added as context.
pub struct ContextStore {
    /// Project the context items belong to. Held weakly; operations that need it call
    /// `upgrade()` and fail (or return `None`) once the project is gone.
    project: WeakEntity<Project>,
    /// When present, threads are saved through this store after their detailed summary
    /// has been generated.
    thread_store: Option<WeakEntity<ThreadStore>>,
    /// Pending thread-summarization tasks; drained by `wait_for_summaries`.
    thread_summary_tasks: Vec<Task<()>>,
    /// Counter used to assign a `ContextId` to every newly created context handle.
    next_context_id: ContextId,
    /// Context entries in insertion order, deduplicated by `AgentContextKey`.
    context_set: IndexSet<AgentContextKey>,
    /// Ids of the threads currently in `context_set`, kept for cheap `includes_thread` lookups.
    context_thread_ids: HashSet<ThreadId>,
}
 36
 37impl ContextStore {
 38    pub fn new(
 39        project: WeakEntity<Project>,
 40        thread_store: Option<WeakEntity<ThreadStore>>,
 41    ) -> Self {
 42        Self {
 43            project,
 44            thread_store,
 45            thread_summary_tasks: Vec::new(),
 46            next_context_id: ContextId::zero(),
 47            context_set: IndexSet::default(),
 48            context_thread_ids: HashSet::default(),
 49        }
 50    }
 51
 52    pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
 53        self.context_set.iter().map(|entry| entry.as_ref())
 54    }
 55
 56    pub fn clear(&mut self) {
 57        self.context_set.clear();
 58        self.context_thread_ids.clear();
 59    }
 60
 61    pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContextHandle> {
 62        let existing_context = thread
 63            .messages()
 64            .flat_map(|message| {
 65                message
 66                    .loaded_context
 67                    .contexts
 68                    .iter()
 69                    .map(|context| AgentContextKey(context.handle()))
 70            })
 71            .collect::<HashSet<_>>();
 72        self.context_set
 73            .iter()
 74            .filter(|context| !existing_context.contains(context))
 75            .map(|entry| entry.0.clone())
 76            .collect::<Vec<_>>()
 77    }
 78
 79    pub fn add_file_from_path(
 80        &mut self,
 81        project_path: ProjectPath,
 82        remove_if_exists: bool,
 83        cx: &mut Context<Self>,
 84    ) -> Task<Result<()>> {
 85        let Some(project) = self.project.upgrade() else {
 86            return Task::ready(Err(anyhow!("failed to read project")));
 87        };
 88
 89        if is_image_file(&project, &project_path, cx) {
 90            self.add_image_from_path(project_path, remove_if_exists, cx)
 91        } else {
 92            cx.spawn(async move |this, cx| {
 93                let open_buffer_task = project.update(cx, |project, cx| {
 94                    project.open_buffer(project_path.clone(), cx)
 95                })?;
 96                let buffer = open_buffer_task.await?;
 97                this.update(cx, |this, cx| {
 98                    this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
 99                })
100            })
101        }
102    }
103
104    pub fn add_file_from_buffer(
105        &mut self,
106        project_path: &ProjectPath,
107        buffer: Entity<Buffer>,
108        remove_if_exists: bool,
109        cx: &mut Context<Self>,
110    ) {
111        let context_id = self.next_context_id.post_inc();
112        let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
113
114        let already_included = if self.has_context(&context) {
115            if remove_if_exists {
116                self.remove_context(&context, cx);
117            }
118            true
119        } else {
120            self.path_included_in_directory(project_path, cx).is_some()
121        };
122
123        if !already_included {
124            self.insert_context(context, cx);
125        }
126    }
127
128    pub fn add_directory(
129        &mut self,
130        project_path: &ProjectPath,
131        remove_if_exists: bool,
132        cx: &mut Context<Self>,
133    ) -> Result<()> {
134        let Some(project) = self.project.upgrade() else {
135            return Err(anyhow!("failed to read project"));
136        };
137
138        let Some(entry_id) = project
139            .read(cx)
140            .entry_for_path(project_path, cx)
141            .map(|entry| entry.id)
142        else {
143            return Err(anyhow!("no entry found for directory context"));
144        };
145
146        let context_id = self.next_context_id.post_inc();
147        let context = AgentContextHandle::Directory(DirectoryContextHandle {
148            entry_id,
149            context_id,
150        });
151
152        if self.has_context(&context) {
153            if remove_if_exists {
154                self.remove_context(&context, cx);
155            }
156        } else if self.path_included_in_directory(project_path, cx).is_none() {
157            self.insert_context(context, cx);
158        }
159
160        anyhow::Ok(())
161    }
162
163    pub fn add_symbol(
164        &mut self,
165        buffer: Entity<Buffer>,
166        symbol: SharedString,
167        range: Range<Anchor>,
168        enclosing_range: Range<Anchor>,
169        remove_if_exists: bool,
170        cx: &mut Context<Self>,
171    ) -> bool {
172        let context_id = self.next_context_id.post_inc();
173        let context = AgentContextHandle::Symbol(SymbolContextHandle {
174            buffer,
175            symbol,
176            range,
177            enclosing_range,
178            context_id,
179        });
180
181        if self.has_context(&context) {
182            if remove_if_exists {
183                self.remove_context(&context, cx);
184            }
185            return false;
186        }
187
188        self.insert_context(context, cx)
189    }
190
191    pub fn add_thread(
192        &mut self,
193        thread: Entity<Thread>,
194        remove_if_exists: bool,
195        cx: &mut Context<Self>,
196    ) {
197        let context_id = self.next_context_id.post_inc();
198        let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
199
200        if self.has_context(&context) {
201            if remove_if_exists {
202                self.remove_context(&context, cx);
203            }
204        } else {
205            self.insert_context(context, cx);
206        }
207    }
208
209    fn start_summarizing_thread_if_needed(
210        &mut self,
211        thread: &Entity<Thread>,
212        cx: &mut Context<Self>,
213    ) {
214        if let Some(summary_task) =
215            thread.update(cx, |thread, cx| thread.generate_detailed_summary(cx))
216        {
217            let thread = thread.clone();
218            let thread_store = self.thread_store.clone();
219
220            self.thread_summary_tasks.push(cx.spawn(async move |_, cx| {
221                summary_task.await;
222
223                if let Some(thread_store) = thread_store {
224                    // Save thread so its summary can be reused later
225                    let save_task = thread_store
226                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx));
227
228                    if let Some(save_task) = save_task.ok() {
229                        save_task.await.log_err();
230                    }
231                }
232            }));
233        }
234    }
235
236    pub fn wait_for_summaries(&mut self, cx: &App) -> Task<()> {
237        let tasks = std::mem::take(&mut self.thread_summary_tasks);
238
239        cx.spawn(async move |_cx| {
240            join_all(tasks).await;
241        })
242    }
243
244    pub fn add_rules(
245        &mut self,
246        prompt_id: UserPromptId,
247        remove_if_exists: bool,
248        cx: &mut Context<ContextStore>,
249    ) {
250        let context_id = self.next_context_id.post_inc();
251        let context = AgentContextHandle::Rules(RulesContextHandle {
252            prompt_id,
253            context_id,
254        });
255
256        if self.has_context(&context) {
257            if remove_if_exists {
258                self.remove_context(&context, cx);
259            }
260        } else {
261            self.insert_context(context, cx);
262        }
263    }
264
265    pub fn add_fetched_url(
266        &mut self,
267        url: String,
268        text: impl Into<SharedString>,
269        cx: &mut Context<ContextStore>,
270    ) {
271        let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
272            url: url.into(),
273            text: text.into(),
274            context_id: self.next_context_id.post_inc(),
275        });
276
277        self.insert_context(context, cx);
278    }
279
280    pub fn add_image_from_path(
281        &mut self,
282        project_path: ProjectPath,
283        remove_if_exists: bool,
284        cx: &mut Context<ContextStore>,
285    ) -> Task<Result<()>> {
286        let project = self.project.clone();
287        cx.spawn(async move |this, cx| {
288            let open_image_task = project.update(cx, |project, cx| {
289                project.open_image(project_path.clone(), cx)
290            })?;
291            let image_item = open_image_task.await?;
292            let image = image_item.read_with(cx, |image_item, _| image_item.image.clone())?;
293            this.update(cx, |this, cx| {
294                this.insert_image(
295                    Some(image_item.read(cx).project_path(cx)),
296                    image,
297                    remove_if_exists,
298                    cx,
299                );
300            })
301        })
302    }
303
304    pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
305        self.insert_image(None, image, false, cx);
306    }
307
308    fn insert_image(
309        &mut self,
310        project_path: Option<ProjectPath>,
311        image: Arc<Image>,
312        remove_if_exists: bool,
313        cx: &mut Context<ContextStore>,
314    ) {
315        let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
316        let context = AgentContextHandle::Image(ImageContext {
317            project_path,
318            original_image: image,
319            image_task,
320            context_id: self.next_context_id.post_inc(),
321        });
322        if self.has_context(&context) {
323            if remove_if_exists {
324                self.remove_context(&context, cx);
325                return;
326            }
327        }
328
329        self.insert_context(context, cx);
330    }
331
332    pub fn add_selection(
333        &mut self,
334        buffer: Entity<Buffer>,
335        range: Range<Anchor>,
336        cx: &mut Context<ContextStore>,
337    ) {
338        let context_id = self.next_context_id.post_inc();
339        let context = AgentContextHandle::Selection(SelectionContextHandle {
340            buffer,
341            range,
342            context_id,
343        });
344        self.insert_context(context, cx);
345    }
346
347    pub fn add_suggested_context(
348        &mut self,
349        suggested: &SuggestedContext,
350        cx: &mut Context<ContextStore>,
351    ) {
352        match suggested {
353            SuggestedContext::File {
354                buffer,
355                icon_path: _,
356                name: _,
357            } => {
358                if let Some(buffer) = buffer.upgrade() {
359                    let context_id = self.next_context_id.post_inc();
360                    self.insert_context(
361                        AgentContextHandle::File(FileContextHandle { buffer, context_id }),
362                        cx,
363                    );
364                };
365            }
366            SuggestedContext::Thread { thread, name: _ } => {
367                if let Some(thread) = thread.upgrade() {
368                    let context_id = self.next_context_id.post_inc();
369                    self.insert_context(
370                        AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
371                        cx,
372                    );
373                }
374            }
375        }
376    }
377
378    fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
379        match &context {
380            AgentContextHandle::Thread(thread_context) => {
381                self.context_thread_ids
382                    .insert(thread_context.thread.read(cx).id().clone());
383                self.start_summarizing_thread_if_needed(&thread_context.thread, cx);
384            }
385            _ => {}
386        }
387        let inserted = self.context_set.insert(AgentContextKey(context));
388        if inserted {
389            cx.notify();
390        }
391        inserted
392    }
393
394    pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
395        if self
396            .context_set
397            .shift_remove(AgentContextKey::ref_cast(context))
398        {
399            match context {
400                AgentContextHandle::Thread(thread_context) => {
401                    self.context_thread_ids
402                        .remove(thread_context.thread.read(cx).id());
403                }
404                _ => {}
405            }
406            cx.notify();
407        }
408    }
409
410    pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
411        self.context_set
412            .contains(AgentContextKey::ref_cast(context))
413    }
414
415    /// Returns whether this file path is already included directly in the context, or if it will be
416    /// included in the context via a directory.
417    pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
418        let project = self.project.upgrade()?.read(cx);
419        self.context().find_map(|context| match context {
420            AgentContextHandle::File(file_context) => {
421                FileInclusion::check_file(file_context, path, cx)
422            }
423            AgentContextHandle::Image(image_context) => {
424                FileInclusion::check_image(image_context, path)
425            }
426            AgentContextHandle::Directory(directory_context) => {
427                FileInclusion::check_directory(directory_context, path, project, cx)
428            }
429            _ => None,
430        })
431    }
432
433    pub fn path_included_in_directory(
434        &self,
435        path: &ProjectPath,
436        cx: &App,
437    ) -> Option<FileInclusion> {
438        let project = self.project.upgrade()?.read(cx);
439        self.context().find_map(|context| match context {
440            AgentContextHandle::Directory(directory_context) => {
441                FileInclusion::check_directory(directory_context, path, project, cx)
442            }
443            _ => None,
444        })
445    }
446
447    pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
448        self.context().any(|context| match context {
449            AgentContextHandle::Symbol(context) => {
450                if context.symbol != symbol.name {
451                    return false;
452                }
453                let buffer = context.buffer.read(cx);
454                let Some(context_path) = buffer.project_path(cx) else {
455                    return false;
456                };
457                if context_path != symbol.path {
458                    return false;
459                }
460                let context_range = context.range.to_point_utf16(&buffer.snapshot());
461                context_range.start == symbol.range.start.0
462                    && context_range.end == symbol.range.end.0
463            }
464            _ => false,
465        })
466    }
467
468    pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
469        self.context_thread_ids.contains(thread_id)
470    }
471
472    pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
473        self.context_set
474            .contains(&RulesContextHandle::lookup_key(prompt_id))
475    }
476
477    pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
478        self.context_set
479            .contains(&FetchedUrlContext::lookup_key(url.into()))
480    }
481
482    pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
483        self.context()
484            .filter_map(|context| match context {
485                AgentContextHandle::File(file) => {
486                    let buffer = file.buffer.read(cx);
487                    buffer.project_path(cx)
488                }
489                AgentContextHandle::Directory(_)
490                | AgentContextHandle::Symbol(_)
491                | AgentContextHandle::Selection(_)
492                | AgentContextHandle::FetchedUrl(_)
493                | AgentContextHandle::Thread(_)
494                | AgentContextHandle::Rules(_)
495                | AgentContextHandle::Image(_) => None,
496            })
497            .collect()
498    }
499
500    pub fn thread_ids(&self) -> &HashSet<ThreadId> {
501        &self.context_thread_ids
502    }
503}
504
/// Describes how a project path is covered by the entries of a `ContextStore`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileInclusion {
    /// The path itself is in the context: as a file, as an image, or as the exact directory
    /// entry that was added.
    Direct,
    /// The path lies beneath a directory that is in the context.
    InDirectory {
        /// Full path of the including directory, as produced by the worktree's `full_path`.
        full_path: PathBuf,
    },
}
509
510impl FileInclusion {
511    fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
512        let file_path = file_context.buffer.read(cx).project_path(cx)?;
513        if path == &file_path {
514            Some(FileInclusion::Direct)
515        } else {
516            None
517        }
518    }
519
520    fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
521        let image_path = image_context.project_path.as_ref()?;
522        if path == image_path {
523            Some(FileInclusion::Direct)
524        } else {
525            None
526        }
527    }
528
529    fn check_directory(
530        directory_context: &DirectoryContextHandle,
531        path: &ProjectPath,
532        project: &Project,
533        cx: &App,
534    ) -> Option<Self> {
535        let worktree = project
536            .worktree_for_entry(directory_context.entry_id, cx)?
537            .read(cx);
538        let entry = worktree.entry_for_id(directory_context.entry_id)?;
539        let directory_path = ProjectPath {
540            worktree_id: worktree.id(),
541            path: entry.path.clone(),
542        };
543        if path.starts_with(&directory_path) {
544            if path == &directory_path {
545                Some(FileInclusion::Direct)
546            } else {
547                Some(FileInclusion::InDirectory {
548                    full_path: worktree.full_path(&entry.path),
549                })
550            }
551        } else {
552            None
553        }
554    }
555}