context_store.rs

  1use std::ops::Range;
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use collections::{HashSet, IndexSet};
  7use futures::{self, FutureExt};
  8use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity};
  9use language::Buffer;
 10use language_model::LanguageModelImage;
 11use project::{Project, ProjectItem, ProjectPath, Symbol};
 12use prompt_store::UserPromptId;
 13use ref_cast::RefCast as _;
 14use text::{Anchor, OffsetRangeExt};
 15
 16use crate::ThreadStore;
 17use crate::context::{
 18    AgentContextHandle, AgentContextKey, ContextId, DirectoryContextHandle, FetchedUrlContext,
 19    FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
 20    SymbolContextHandle, ThreadContextHandle,
 21};
 22use crate::context_strip::SuggestedContext;
 23use crate::thread::{Thread, ThreadId};
 24
 25pub struct ContextStore {
 26    project: WeakEntity<Project>,
 27    thread_store: Option<WeakEntity<ThreadStore>>,
 28    next_context_id: ContextId,
 29    context_set: IndexSet<AgentContextKey>,
 30    context_thread_ids: HashSet<ThreadId>,
 31}
 32
 33impl ContextStore {
 34    pub fn new(
 35        project: WeakEntity<Project>,
 36        thread_store: Option<WeakEntity<ThreadStore>>,
 37    ) -> Self {
 38        Self {
 39            project,
 40            thread_store,
 41            next_context_id: ContextId::zero(),
 42            context_set: IndexSet::default(),
 43            context_thread_ids: HashSet::default(),
 44        }
 45    }
 46
 47    pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
 48        self.context_set.iter().map(|entry| entry.as_ref())
 49    }
 50
 51    pub fn clear(&mut self) {
 52        self.context_set.clear();
 53        self.context_thread_ids.clear();
 54    }
 55
 56    pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContextHandle> {
 57        let existing_context = thread
 58            .messages()
 59            .flat_map(|message| {
 60                message
 61                    .loaded_context
 62                    .contexts
 63                    .iter()
 64                    .map(|context| AgentContextKey(context.handle()))
 65            })
 66            .collect::<HashSet<_>>();
 67        self.context_set
 68            .iter()
 69            .filter(|context| !existing_context.contains(context))
 70            .map(|entry| entry.0.clone())
 71            .collect::<Vec<_>>()
 72    }
 73
 74    pub fn add_file_from_path(
 75        &mut self,
 76        project_path: ProjectPath,
 77        remove_if_exists: bool,
 78        cx: &mut Context<Self>,
 79    ) -> Task<Result<()>> {
 80        let Some(project) = self.project.upgrade() else {
 81            return Task::ready(Err(anyhow!("failed to read project")));
 82        };
 83
 84        cx.spawn(async move |this, cx| {
 85            let open_buffer_task = project.update(cx, |project, cx| {
 86                project.open_buffer(project_path.clone(), cx)
 87            })?;
 88            let buffer = open_buffer_task.await?;
 89            this.update(cx, |this, cx| {
 90                this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
 91            })
 92        })
 93    }
 94
 95    pub fn add_file_from_buffer(
 96        &mut self,
 97        project_path: &ProjectPath,
 98        buffer: Entity<Buffer>,
 99        remove_if_exists: bool,
100        cx: &mut Context<Self>,
101    ) {
102        let context_id = self.next_context_id.post_inc();
103        let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
104
105        let already_included = if self.has_context(&context) {
106            if remove_if_exists {
107                self.remove_context(&context, cx);
108            }
109            true
110        } else {
111            self.path_included_in_directory(project_path, cx).is_some()
112        };
113
114        if !already_included {
115            self.insert_context(context, cx);
116        }
117    }
118
119    pub fn add_directory(
120        &mut self,
121        project_path: &ProjectPath,
122        remove_if_exists: bool,
123        cx: &mut Context<Self>,
124    ) -> Result<()> {
125        let Some(project) = self.project.upgrade() else {
126            return Err(anyhow!("failed to read project"));
127        };
128
129        let Some(entry_id) = project
130            .read(cx)
131            .entry_for_path(project_path, cx)
132            .map(|entry| entry.id)
133        else {
134            return Err(anyhow!("no entry found for directory context"));
135        };
136
137        let context_id = self.next_context_id.post_inc();
138        let context = AgentContextHandle::Directory(DirectoryContextHandle {
139            entry_id,
140            context_id,
141        });
142
143        if self.has_context(&context) {
144            if remove_if_exists {
145                self.remove_context(&context, cx);
146            }
147        } else if self.path_included_in_directory(project_path, cx).is_none() {
148            self.insert_context(context, cx);
149        }
150
151        anyhow::Ok(())
152    }
153
154    pub fn add_symbol(
155        &mut self,
156        buffer: Entity<Buffer>,
157        symbol: SharedString,
158        range: Range<Anchor>,
159        enclosing_range: Range<Anchor>,
160        remove_if_exists: bool,
161        cx: &mut Context<Self>,
162    ) -> bool {
163        let context_id = self.next_context_id.post_inc();
164        let context = AgentContextHandle::Symbol(SymbolContextHandle {
165            buffer,
166            symbol,
167            range,
168            enclosing_range,
169            context_id,
170        });
171
172        if self.has_context(&context) {
173            if remove_if_exists {
174                self.remove_context(&context, cx);
175            }
176            return false;
177        }
178
179        self.insert_context(context, cx)
180    }
181
182    pub fn add_thread(
183        &mut self,
184        thread: Entity<Thread>,
185        remove_if_exists: bool,
186        cx: &mut Context<Self>,
187    ) {
188        let context_id = self.next_context_id.post_inc();
189        let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
190
191        if self.has_context(&context) {
192            if remove_if_exists {
193                self.remove_context(&context, cx);
194            }
195        } else {
196            self.insert_context(context, cx);
197        }
198    }
199
200    pub fn add_rules(
201        &mut self,
202        prompt_id: UserPromptId,
203        remove_if_exists: bool,
204        cx: &mut Context<ContextStore>,
205    ) {
206        let context_id = self.next_context_id.post_inc();
207        let context = AgentContextHandle::Rules(RulesContextHandle {
208            prompt_id,
209            context_id,
210        });
211
212        if self.has_context(&context) {
213            if remove_if_exists {
214                self.remove_context(&context, cx);
215            }
216        } else {
217            self.insert_context(context, cx);
218        }
219    }
220
221    pub fn add_fetched_url(
222        &mut self,
223        url: String,
224        text: impl Into<SharedString>,
225        cx: &mut Context<ContextStore>,
226    ) {
227        let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
228            url: url.into(),
229            text: text.into(),
230            context_id: self.next_context_id.post_inc(),
231        });
232
233        self.insert_context(context, cx);
234    }
235
236    pub fn add_image(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
237        let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
238        let context = AgentContextHandle::Image(ImageContext {
239            original_image: image,
240            image_task,
241            context_id: self.next_context_id.post_inc(),
242        });
243        self.insert_context(context, cx);
244    }
245
246    pub fn add_selection(
247        &mut self,
248        buffer: Entity<Buffer>,
249        range: Range<Anchor>,
250        cx: &mut Context<ContextStore>,
251    ) {
252        let context_id = self.next_context_id.post_inc();
253        let context = AgentContextHandle::Selection(SelectionContextHandle {
254            buffer,
255            range,
256            context_id,
257        });
258        self.insert_context(context, cx);
259    }
260
261    pub fn add_suggested_context(
262        &mut self,
263        suggested: &SuggestedContext,
264        cx: &mut Context<ContextStore>,
265    ) {
266        match suggested {
267            SuggestedContext::File {
268                buffer,
269                icon_path: _,
270                name: _,
271            } => {
272                if let Some(buffer) = buffer.upgrade() {
273                    let context_id = self.next_context_id.post_inc();
274                    self.insert_context(
275                        AgentContextHandle::File(FileContextHandle { buffer, context_id }),
276                        cx,
277                    );
278                };
279            }
280            SuggestedContext::Thread { thread, name: _ } => {
281                if let Some(thread) = thread.upgrade() {
282                    let context_id = self.next_context_id.post_inc();
283                    self.insert_context(
284                        AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
285                        cx,
286                    );
287                }
288            }
289        }
290    }
291
292    fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
293        match &context {
294            AgentContextHandle::Thread(thread_context) => {
295                if let Some(thread_store) = self.thread_store.clone() {
296                    thread_context.thread.update(cx, |thread, cx| {
297                        thread.start_generating_detailed_summary_if_needed(thread_store, cx);
298                    });
299                    self.context_thread_ids
300                        .insert(thread_context.thread.read(cx).id().clone());
301                } else {
302                    return false;
303                }
304            }
305            _ => {}
306        }
307        let inserted = self.context_set.insert(AgentContextKey(context));
308        if inserted {
309            cx.notify();
310        }
311        inserted
312    }
313
314    pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
315        if self
316            .context_set
317            .shift_remove(AgentContextKey::ref_cast(context))
318        {
319            match context {
320                AgentContextHandle::Thread(thread_context) => {
321                    self.context_thread_ids
322                        .remove(thread_context.thread.read(cx).id());
323                }
324                _ => {}
325            }
326            cx.notify();
327        }
328    }
329
330    pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
331        self.context_set
332            .contains(AgentContextKey::ref_cast(context))
333    }
334
335    /// Returns whether this file path is already included directly in the context, or if it will be
336    /// included in the context via a directory.
337    pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
338        let project = self.project.upgrade()?.read(cx);
339        self.context().find_map(|context| match context {
340            AgentContextHandle::File(file_context) => {
341                FileInclusion::check_file(file_context, path, cx)
342            }
343            AgentContextHandle::Directory(directory_context) => {
344                FileInclusion::check_directory(directory_context, path, project, cx)
345            }
346            _ => None,
347        })
348    }
349
350    pub fn path_included_in_directory(
351        &self,
352        path: &ProjectPath,
353        cx: &App,
354    ) -> Option<FileInclusion> {
355        let project = self.project.upgrade()?.read(cx);
356        self.context().find_map(|context| match context {
357            AgentContextHandle::Directory(directory_context) => {
358                FileInclusion::check_directory(directory_context, path, project, cx)
359            }
360            _ => None,
361        })
362    }
363
364    pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
365        self.context().any(|context| match context {
366            AgentContextHandle::Symbol(context) => {
367                if context.symbol != symbol.name {
368                    return false;
369                }
370                let buffer = context.buffer.read(cx);
371                let Some(context_path) = buffer.project_path(cx) else {
372                    return false;
373                };
374                if context_path != symbol.path {
375                    return false;
376                }
377                let context_range = context.range.to_point_utf16(&buffer.snapshot());
378                context_range.start == symbol.range.start.0
379                    && context_range.end == symbol.range.end.0
380            }
381            _ => false,
382        })
383    }
384
385    pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
386        self.context_thread_ids.contains(thread_id)
387    }
388
389    pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
390        self.context_set
391            .contains(&RulesContextHandle::lookup_key(prompt_id))
392    }
393
394    pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
395        self.context_set
396            .contains(&FetchedUrlContext::lookup_key(url.into()))
397    }
398
399    pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
400        self.context()
401            .filter_map(|context| match context {
402                AgentContextHandle::File(file) => {
403                    let buffer = file.buffer.read(cx);
404                    buffer.project_path(cx)
405                }
406                AgentContextHandle::Directory(_)
407                | AgentContextHandle::Symbol(_)
408                | AgentContextHandle::Selection(_)
409                | AgentContextHandle::FetchedUrl(_)
410                | AgentContextHandle::Thread(_)
411                | AgentContextHandle::Rules(_)
412                | AgentContextHandle::Image(_) => None,
413            })
414            .collect()
415    }
416
417    pub fn thread_ids(&self) -> &HashSet<ThreadId> {
418        &self.context_thread_ids
419    }
420}
421
/// How a project path is covered by the context in a [`ContextStore`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileInclusion {
    /// The path itself was added, either as a file context or as a directory context whose
    /// entry is exactly this path.
    Direct,
    /// The path lies inside a directory context.
    ///
    /// `full_path` is the value of `Worktree::full_path` for that directory's entry, not for
    /// the file.
    InDirectory { full_path: PathBuf },
}
426
427impl FileInclusion {
428    fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
429        let file_path = file_context.buffer.read(cx).project_path(cx)?;
430        if path == &file_path {
431            Some(FileInclusion::Direct)
432        } else {
433            None
434        }
435    }
436
437    fn check_directory(
438        directory_context: &DirectoryContextHandle,
439        path: &ProjectPath,
440        project: &Project,
441        cx: &App,
442    ) -> Option<Self> {
443        let worktree = project
444            .worktree_for_entry(directory_context.entry_id, cx)?
445            .read(cx);
446        let entry = worktree.entry_for_id(directory_context.entry_id)?;
447        let directory_path = ProjectPath {
448            worktree_id: worktree.id(),
449            path: entry.path.clone(),
450        };
451        if path.starts_with(&directory_path) {
452            if path == &directory_path {
453                Some(FileInclusion::Direct)
454            } else {
455                Some(FileInclusion::InDirectory {
456                    full_path: worktree.full_path(&entry.path),
457                })
458            }
459        } else {
460            None
461        }
462    }
463}