context_store.rs

  1use std::ops::Range;
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use collections::{HashSet, IndexSet};
  7use futures::future::join_all;
  8use futures::{self, FutureExt};
  9use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity};
 10use language::Buffer;
 11use language_model::LanguageModelImage;
 12use project::{Project, ProjectItem, ProjectPath, Symbol};
 13use prompt_store::UserPromptId;
 14use ref_cast::RefCast as _;
 15use text::{Anchor, OffsetRangeExt};
 16use util::ResultExt as _;
 17
 18use crate::ThreadStore;
 19use crate::context::{
 20    AgentContextHandle, AgentContextKey, ContextId, DirectoryContextHandle, FetchedUrlContext,
 21    FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
 22    SymbolContextHandle, ThreadContextHandle,
 23};
 24use crate::context_strip::SuggestedContext;
 25use crate::thread::{Thread, ThreadId};
 26
/// Holds the context entries (files, directories, symbols, selections, fetched URLs,
/// threads, rules, images) that will be attached to an agent request, plus bookkeeping
/// for background thread-summary work.
pub struct ContextStore {
    /// Weak handle to the project; used to open buffers and resolve directory entries.
    project: WeakEntity<Project>,
    /// When present, threads are saved through this store after their detailed summary
    /// finishes, so the summary can be reused later.
    thread_store: Option<WeakEntity<ThreadStore>>,
    /// Pending summary tasks; drained and awaited by `wait_for_summaries`.
    thread_summary_tasks: Vec<Task<()>>,
    /// Counter from which each new handle's `ContextId` is taken (`post_inc`).
    next_context_id: ContextId,
    /// Insertion-ordered, deduplicated set of context entries.
    context_set: IndexSet<AgentContextKey>,
    /// IDs of threads currently included as context, kept for fast `includes_thread` checks.
    context_thread_ids: HashSet<ThreadId>,
}
 35
 36impl ContextStore {
 37    pub fn new(
 38        project: WeakEntity<Project>,
 39        thread_store: Option<WeakEntity<ThreadStore>>,
 40    ) -> Self {
 41        Self {
 42            project,
 43            thread_store,
 44            thread_summary_tasks: Vec::new(),
 45            next_context_id: ContextId::zero(),
 46            context_set: IndexSet::default(),
 47            context_thread_ids: HashSet::default(),
 48        }
 49    }
 50
 51    pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
 52        self.context_set.iter().map(|entry| entry.as_ref())
 53    }
 54
 55    pub fn clear(&mut self) {
 56        self.context_set.clear();
 57        self.context_thread_ids.clear();
 58    }
 59
 60    pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContextHandle> {
 61        let existing_context = thread
 62            .messages()
 63            .flat_map(|message| {
 64                message
 65                    .loaded_context
 66                    .contexts
 67                    .iter()
 68                    .map(|context| AgentContextKey(context.handle()))
 69            })
 70            .collect::<HashSet<_>>();
 71        self.context_set
 72            .iter()
 73            .filter(|context| !existing_context.contains(context))
 74            .map(|entry| entry.0.clone())
 75            .collect::<Vec<_>>()
 76    }
 77
 78    pub fn add_file_from_path(
 79        &mut self,
 80        project_path: ProjectPath,
 81        remove_if_exists: bool,
 82        cx: &mut Context<Self>,
 83    ) -> Task<Result<()>> {
 84        let Some(project) = self.project.upgrade() else {
 85            return Task::ready(Err(anyhow!("failed to read project")));
 86        };
 87
 88        cx.spawn(async move |this, cx| {
 89            let open_buffer_task = project.update(cx, |project, cx| {
 90                project.open_buffer(project_path.clone(), cx)
 91            })?;
 92            let buffer = open_buffer_task.await?;
 93            this.update(cx, |this, cx| {
 94                this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
 95            })
 96        })
 97    }
 98
 99    pub fn add_file_from_buffer(
100        &mut self,
101        project_path: &ProjectPath,
102        buffer: Entity<Buffer>,
103        remove_if_exists: bool,
104        cx: &mut Context<Self>,
105    ) {
106        let context_id = self.next_context_id.post_inc();
107        let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
108
109        let already_included = if self.has_context(&context) {
110            if remove_if_exists {
111                self.remove_context(&context, cx);
112            }
113            true
114        } else {
115            self.path_included_in_directory(project_path, cx).is_some()
116        };
117
118        if !already_included {
119            self.insert_context(context, cx);
120        }
121    }
122
123    pub fn add_directory(
124        &mut self,
125        project_path: &ProjectPath,
126        remove_if_exists: bool,
127        cx: &mut Context<Self>,
128    ) -> Result<()> {
129        let Some(project) = self.project.upgrade() else {
130            return Err(anyhow!("failed to read project"));
131        };
132
133        let Some(entry_id) = project
134            .read(cx)
135            .entry_for_path(project_path, cx)
136            .map(|entry| entry.id)
137        else {
138            return Err(anyhow!("no entry found for directory context"));
139        };
140
141        let context_id = self.next_context_id.post_inc();
142        let context = AgentContextHandle::Directory(DirectoryContextHandle {
143            entry_id,
144            context_id,
145        });
146
147        if self.has_context(&context) {
148            if remove_if_exists {
149                self.remove_context(&context, cx);
150            }
151        } else if self.path_included_in_directory(project_path, cx).is_none() {
152            self.insert_context(context, cx);
153        }
154
155        anyhow::Ok(())
156    }
157
158    pub fn add_symbol(
159        &mut self,
160        buffer: Entity<Buffer>,
161        symbol: SharedString,
162        range: Range<Anchor>,
163        enclosing_range: Range<Anchor>,
164        remove_if_exists: bool,
165        cx: &mut Context<Self>,
166    ) -> bool {
167        let context_id = self.next_context_id.post_inc();
168        let context = AgentContextHandle::Symbol(SymbolContextHandle {
169            buffer,
170            symbol,
171            range,
172            enclosing_range,
173            context_id,
174        });
175
176        if self.has_context(&context) {
177            if remove_if_exists {
178                self.remove_context(&context, cx);
179            }
180            return false;
181        }
182
183        self.insert_context(context, cx)
184    }
185
186    pub fn add_thread(
187        &mut self,
188        thread: Entity<Thread>,
189        remove_if_exists: bool,
190        cx: &mut Context<Self>,
191    ) {
192        let context_id = self.next_context_id.post_inc();
193        let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
194
195        if self.has_context(&context) {
196            if remove_if_exists {
197                self.remove_context(&context, cx);
198            }
199        } else {
200            self.insert_context(context, cx);
201        }
202    }
203
204    fn start_summarizing_thread_if_needed(
205        &mut self,
206        thread: &Entity<Thread>,
207        cx: &mut Context<Self>,
208    ) {
209        if let Some(summary_task) =
210            thread.update(cx, |thread, cx| thread.generate_detailed_summary(cx))
211        {
212            let thread = thread.clone();
213            let thread_store = self.thread_store.clone();
214
215            self.thread_summary_tasks.push(cx.spawn(async move |_, cx| {
216                summary_task.await;
217
218                if let Some(thread_store) = thread_store {
219                    // Save thread so its summary can be reused later
220                    let save_task = thread_store
221                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx));
222
223                    if let Some(save_task) = save_task.ok() {
224                        save_task.await.log_err();
225                    }
226                }
227            }));
228        }
229    }
230
231    pub fn wait_for_summaries(&mut self, cx: &App) -> Task<()> {
232        let tasks = std::mem::take(&mut self.thread_summary_tasks);
233
234        cx.spawn(async move |_cx| {
235            join_all(tasks).await;
236        })
237    }
238
239    pub fn add_rules(
240        &mut self,
241        prompt_id: UserPromptId,
242        remove_if_exists: bool,
243        cx: &mut Context<ContextStore>,
244    ) {
245        let context_id = self.next_context_id.post_inc();
246        let context = AgentContextHandle::Rules(RulesContextHandle {
247            prompt_id,
248            context_id,
249        });
250
251        if self.has_context(&context) {
252            if remove_if_exists {
253                self.remove_context(&context, cx);
254            }
255        } else {
256            self.insert_context(context, cx);
257        }
258    }
259
260    pub fn add_fetched_url(
261        &mut self,
262        url: String,
263        text: impl Into<SharedString>,
264        cx: &mut Context<ContextStore>,
265    ) {
266        let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
267            url: url.into(),
268            text: text.into(),
269            context_id: self.next_context_id.post_inc(),
270        });
271
272        self.insert_context(context, cx);
273    }
274
275    pub fn add_image(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
276        let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
277        let context = AgentContextHandle::Image(ImageContext {
278            original_image: image,
279            image_task,
280            context_id: self.next_context_id.post_inc(),
281        });
282        self.insert_context(context, cx);
283    }
284
285    pub fn add_selection(
286        &mut self,
287        buffer: Entity<Buffer>,
288        range: Range<Anchor>,
289        cx: &mut Context<ContextStore>,
290    ) {
291        let context_id = self.next_context_id.post_inc();
292        let context = AgentContextHandle::Selection(SelectionContextHandle {
293            buffer,
294            range,
295            context_id,
296        });
297        self.insert_context(context, cx);
298    }
299
300    pub fn add_suggested_context(
301        &mut self,
302        suggested: &SuggestedContext,
303        cx: &mut Context<ContextStore>,
304    ) {
305        match suggested {
306            SuggestedContext::File {
307                buffer,
308                icon_path: _,
309                name: _,
310            } => {
311                if let Some(buffer) = buffer.upgrade() {
312                    let context_id = self.next_context_id.post_inc();
313                    self.insert_context(
314                        AgentContextHandle::File(FileContextHandle { buffer, context_id }),
315                        cx,
316                    );
317                };
318            }
319            SuggestedContext::Thread { thread, name: _ } => {
320                if let Some(thread) = thread.upgrade() {
321                    let context_id = self.next_context_id.post_inc();
322                    self.insert_context(
323                        AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
324                        cx,
325                    );
326                }
327            }
328        }
329    }
330
331    fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
332        match &context {
333            AgentContextHandle::Thread(thread_context) => {
334                self.context_thread_ids
335                    .insert(thread_context.thread.read(cx).id().clone());
336                self.start_summarizing_thread_if_needed(&thread_context.thread, cx);
337            }
338            _ => {}
339        }
340        let inserted = self.context_set.insert(AgentContextKey(context));
341        if inserted {
342            cx.notify();
343        }
344        inserted
345    }
346
347    pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
348        if self
349            .context_set
350            .shift_remove(AgentContextKey::ref_cast(context))
351        {
352            match context {
353                AgentContextHandle::Thread(thread_context) => {
354                    self.context_thread_ids
355                        .remove(thread_context.thread.read(cx).id());
356                }
357                _ => {}
358            }
359            cx.notify();
360        }
361    }
362
363    pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
364        self.context_set
365            .contains(AgentContextKey::ref_cast(context))
366    }
367
368    /// Returns whether this file path is already included directly in the context, or if it will be
369    /// included in the context via a directory.
370    pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
371        let project = self.project.upgrade()?.read(cx);
372        self.context().find_map(|context| match context {
373            AgentContextHandle::File(file_context) => {
374                FileInclusion::check_file(file_context, path, cx)
375            }
376            AgentContextHandle::Directory(directory_context) => {
377                FileInclusion::check_directory(directory_context, path, project, cx)
378            }
379            _ => None,
380        })
381    }
382
383    pub fn path_included_in_directory(
384        &self,
385        path: &ProjectPath,
386        cx: &App,
387    ) -> Option<FileInclusion> {
388        let project = self.project.upgrade()?.read(cx);
389        self.context().find_map(|context| match context {
390            AgentContextHandle::Directory(directory_context) => {
391                FileInclusion::check_directory(directory_context, path, project, cx)
392            }
393            _ => None,
394        })
395    }
396
397    pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
398        self.context().any(|context| match context {
399            AgentContextHandle::Symbol(context) => {
400                if context.symbol != symbol.name {
401                    return false;
402                }
403                let buffer = context.buffer.read(cx);
404                let Some(context_path) = buffer.project_path(cx) else {
405                    return false;
406                };
407                if context_path != symbol.path {
408                    return false;
409                }
410                let context_range = context.range.to_point_utf16(&buffer.snapshot());
411                context_range.start == symbol.range.start.0
412                    && context_range.end == symbol.range.end.0
413            }
414            _ => false,
415        })
416    }
417
418    pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
419        self.context_thread_ids.contains(thread_id)
420    }
421
422    pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
423        self.context_set
424            .contains(&RulesContextHandle::lookup_key(prompt_id))
425    }
426
427    pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
428        self.context_set
429            .contains(&FetchedUrlContext::lookup_key(url.into()))
430    }
431
432    pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
433        self.context()
434            .filter_map(|context| match context {
435                AgentContextHandle::File(file) => {
436                    let buffer = file.buffer.read(cx);
437                    buffer.project_path(cx)
438                }
439                AgentContextHandle::Directory(_)
440                | AgentContextHandle::Symbol(_)
441                | AgentContextHandle::Selection(_)
442                | AgentContextHandle::FetchedUrl(_)
443                | AgentContextHandle::Thread(_)
444                | AgentContextHandle::Rules(_)
445                | AgentContextHandle::Image(_) => None,
446            })
447            .collect()
448    }
449
450    pub fn thread_ids(&self) -> &HashSet<ThreadId> {
451        &self.context_thread_ids
452    }
453}
454
/// Describes how a file path is covered by the current context.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileInclusion {
    /// The path itself is a context entry (a file context, or the directory entry itself).
    Direct,
    /// The path lies inside an included directory; `full_path` is the worktree-qualified
    /// path of the file.
    InDirectory { full_path: PathBuf },
}
459
460impl FileInclusion {
461    fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
462        let file_path = file_context.buffer.read(cx).project_path(cx)?;
463        if path == &file_path {
464            Some(FileInclusion::Direct)
465        } else {
466            None
467        }
468    }
469
470    fn check_directory(
471        directory_context: &DirectoryContextHandle,
472        path: &ProjectPath,
473        project: &Project,
474        cx: &App,
475    ) -> Option<Self> {
476        let worktree = project
477            .worktree_for_entry(directory_context.entry_id, cx)?
478            .read(cx);
479        let entry = worktree.entry_for_id(directory_context.entry_id)?;
480        let directory_path = ProjectPath {
481            worktree_id: worktree.id(),
482            path: entry.path.clone(),
483        };
484        if path.starts_with(&directory_path) {
485            if path == &directory_path {
486                Some(FileInclusion::Direct)
487            } else {
488                Some(FileInclusion::InDirectory {
489                    full_path: worktree.full_path(&entry.path),
490                })
491            }
492        } else {
493            None
494        }
495    }
496}