context_store.rs

  1use std::ops::Range;
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use collections::{HashSet, IndexSet};
  7use futures::future::join_all;
  8use futures::{self, FutureExt};
  9use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity};
 10use language::Buffer;
 11use language_model::LanguageModelImage;
 12use project::{Project, ProjectItem, ProjectPath, Symbol};
 13use prompt_store::UserPromptId;
 14use ref_cast::RefCast as _;
 15use text::{Anchor, OffsetRangeExt};
 16use util::ResultExt as _;
 17
 18use crate::ThreadStore;
 19use crate::context::{
 20    AgentContext, AgentContextKey, ContextId, DirectoryContext, FetchedUrlContext, FileContext,
 21    ImageContext, RulesContext, SelectionContext, SymbolContext, ThreadContext,
 22};
 23use crate::context_strip::SuggestedContext;
 24use crate::thread::{Thread, ThreadId};
 25
 26pub struct ContextStore {
 27    project: WeakEntity<Project>,
 28    thread_store: Option<WeakEntity<ThreadStore>>,
 29    thread_summary_tasks: Vec<Task<()>>,
 30    next_context_id: ContextId,
 31    context_set: IndexSet<AgentContextKey>,
 32    context_thread_ids: HashSet<ThreadId>,
 33}
 34
 35impl ContextStore {
 36    pub fn new(
 37        project: WeakEntity<Project>,
 38        thread_store: Option<WeakEntity<ThreadStore>>,
 39    ) -> Self {
 40        Self {
 41            project,
 42            thread_store,
 43            thread_summary_tasks: Vec::new(),
 44            next_context_id: ContextId::zero(),
 45            context_set: IndexSet::default(),
 46            context_thread_ids: HashSet::default(),
 47        }
 48    }
 49
 50    pub fn context(&self) -> impl Iterator<Item = &AgentContext> {
 51        self.context_set.iter().map(|entry| entry.as_ref())
 52    }
 53
 54    pub fn clear(&mut self) {
 55        self.context_set.clear();
 56        self.context_thread_ids.clear();
 57    }
 58
 59    pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContext> {
 60        let existing_context = thread
 61            .messages()
 62            .flat_map(|message| &message.loaded_context.contexts)
 63            .map(AgentContextKey::ref_cast)
 64            .collect::<HashSet<_>>();
 65        self.context_set
 66            .iter()
 67            .filter(|context| !existing_context.contains(context))
 68            .map(|entry| entry.0.clone())
 69            .collect::<Vec<_>>()
 70    }
 71
 72    pub fn add_file_from_path(
 73        &mut self,
 74        project_path: ProjectPath,
 75        remove_if_exists: bool,
 76        cx: &mut Context<Self>,
 77    ) -> Task<Result<()>> {
 78        let Some(project) = self.project.upgrade() else {
 79            return Task::ready(Err(anyhow!("failed to read project")));
 80        };
 81
 82        cx.spawn(async move |this, cx| {
 83            let open_buffer_task = project.update(cx, |project, cx| {
 84                project.open_buffer(project_path.clone(), cx)
 85            })?;
 86            let buffer = open_buffer_task.await?;
 87            this.update(cx, |this, cx| {
 88                this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
 89            })
 90        })
 91    }
 92
 93    pub fn add_file_from_buffer(
 94        &mut self,
 95        project_path: &ProjectPath,
 96        buffer: Entity<Buffer>,
 97        remove_if_exists: bool,
 98        cx: &mut Context<Self>,
 99    ) {
100        let context_id = self.next_context_id.post_inc();
101        let context = AgentContext::File(FileContext { buffer, context_id });
102
103        let already_included = if self.has_context(&context) {
104            if remove_if_exists {
105                self.remove_context(&context, cx);
106            }
107            true
108        } else {
109            self.path_included_in_directory(project_path, cx).is_some()
110        };
111
112        if !already_included {
113            self.insert_context(context, cx);
114        }
115    }
116
117    pub fn add_directory(
118        &mut self,
119        project_path: &ProjectPath,
120        remove_if_exists: bool,
121        cx: &mut Context<Self>,
122    ) -> Result<()> {
123        let Some(project) = self.project.upgrade() else {
124            return Err(anyhow!("failed to read project"));
125        };
126
127        let Some(entry_id) = project
128            .read(cx)
129            .entry_for_path(project_path, cx)
130            .map(|entry| entry.id)
131        else {
132            return Err(anyhow!("no entry found for directory context"));
133        };
134
135        let context_id = self.next_context_id.post_inc();
136        let context = AgentContext::Directory(DirectoryContext {
137            entry_id,
138            context_id,
139        });
140
141        if self.has_context(&context) {
142            if remove_if_exists {
143                self.remove_context(&context, cx);
144            }
145        } else if self.path_included_in_directory(project_path, cx).is_none() {
146            self.insert_context(context, cx);
147        }
148
149        anyhow::Ok(())
150    }
151
152    pub fn add_symbol(
153        &mut self,
154        buffer: Entity<Buffer>,
155        symbol: SharedString,
156        range: Range<Anchor>,
157        enclosing_range: Range<Anchor>,
158        remove_if_exists: bool,
159        cx: &mut Context<Self>,
160    ) -> bool {
161        let context_id = self.next_context_id.post_inc();
162        let context = AgentContext::Symbol(SymbolContext {
163            buffer,
164            symbol,
165            range,
166            enclosing_range,
167            context_id,
168        });
169
170        if self.has_context(&context) {
171            if remove_if_exists {
172                self.remove_context(&context, cx);
173            }
174            return false;
175        }
176
177        self.insert_context(context, cx)
178    }
179
180    pub fn add_thread(
181        &mut self,
182        thread: Entity<Thread>,
183        remove_if_exists: bool,
184        cx: &mut Context<Self>,
185    ) {
186        let context_id = self.next_context_id.post_inc();
187        let context = AgentContext::Thread(ThreadContext { thread, context_id });
188
189        if self.has_context(&context) {
190            if remove_if_exists {
191                self.remove_context(&context, cx);
192            }
193        } else {
194            self.insert_context(context, cx);
195        }
196    }
197
198    fn start_summarizing_thread_if_needed(
199        &mut self,
200        thread: &Entity<Thread>,
201        cx: &mut Context<Self>,
202    ) {
203        if let Some(summary_task) =
204            thread.update(cx, |thread, cx| thread.generate_detailed_summary(cx))
205        {
206            let thread = thread.clone();
207            let thread_store = self.thread_store.clone();
208
209            self.thread_summary_tasks.push(cx.spawn(async move |_, cx| {
210                summary_task.await;
211
212                if let Some(thread_store) = thread_store {
213                    // Save thread so its summary can be reused later
214                    let save_task = thread_store
215                        .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx));
216
217                    if let Some(save_task) = save_task.ok() {
218                        save_task.await.log_err();
219                    }
220                }
221            }));
222        }
223    }
224
225    pub fn wait_for_summaries(&mut self, cx: &App) -> Task<()> {
226        let tasks = std::mem::take(&mut self.thread_summary_tasks);
227
228        cx.spawn(async move |_cx| {
229            join_all(tasks).await;
230        })
231    }
232
233    pub fn add_rules(
234        &mut self,
235        prompt_id: UserPromptId,
236        remove_if_exists: bool,
237        cx: &mut Context<ContextStore>,
238    ) {
239        let context_id = self.next_context_id.post_inc();
240        let context = AgentContext::Rules(RulesContext {
241            prompt_id,
242            context_id,
243        });
244
245        if self.has_context(&context) {
246            if remove_if_exists {
247                self.remove_context(&context, cx);
248            }
249        } else {
250            self.insert_context(context, cx);
251        }
252    }
253
254    pub fn add_fetched_url(
255        &mut self,
256        url: String,
257        text: impl Into<SharedString>,
258        cx: &mut Context<ContextStore>,
259    ) {
260        let context = AgentContext::FetchedUrl(FetchedUrlContext {
261            url: url.into(),
262            text: text.into(),
263            context_id: self.next_context_id.post_inc(),
264        });
265
266        self.insert_context(context, cx);
267    }
268
269    pub fn add_image(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
270        let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
271        let context = AgentContext::Image(ImageContext {
272            original_image: image,
273            image_task,
274            context_id: self.next_context_id.post_inc(),
275        });
276        self.insert_context(context, cx);
277    }
278
279    pub fn add_selection(
280        &mut self,
281        buffer: Entity<Buffer>,
282        range: Range<Anchor>,
283        cx: &mut Context<ContextStore>,
284    ) {
285        let context_id = self.next_context_id.post_inc();
286        let context = AgentContext::Selection(SelectionContext {
287            buffer,
288            range,
289            context_id,
290        });
291        self.insert_context(context, cx);
292    }
293
294    pub fn add_suggested_context(
295        &mut self,
296        suggested: &SuggestedContext,
297        cx: &mut Context<ContextStore>,
298    ) {
299        match suggested {
300            SuggestedContext::File {
301                buffer,
302                icon_path: _,
303                name: _,
304            } => {
305                if let Some(buffer) = buffer.upgrade() {
306                    let context_id = self.next_context_id.post_inc();
307                    self.insert_context(AgentContext::File(FileContext { buffer, context_id }), cx);
308                };
309            }
310            SuggestedContext::Thread { thread, name: _ } => {
311                if let Some(thread) = thread.upgrade() {
312                    let context_id = self.next_context_id.post_inc();
313                    self.insert_context(
314                        AgentContext::Thread(ThreadContext { thread, context_id }),
315                        cx,
316                    );
317                }
318            }
319        }
320    }
321
322    fn insert_context(&mut self, context: AgentContext, cx: &mut Context<Self>) -> bool {
323        match &context {
324            AgentContext::Thread(thread_context) => {
325                self.context_thread_ids
326                    .insert(thread_context.thread.read(cx).id().clone());
327                self.start_summarizing_thread_if_needed(&thread_context.thread, cx);
328            }
329            _ => {}
330        }
331        let inserted = self.context_set.insert(AgentContextKey(context));
332        if inserted {
333            cx.notify();
334        }
335        inserted
336    }
337
338    pub fn remove_context(&mut self, context: &AgentContext, cx: &mut Context<Self>) {
339        if self
340            .context_set
341            .shift_remove(AgentContextKey::ref_cast(context))
342        {
343            match context {
344                AgentContext::Thread(thread_context) => {
345                    self.context_thread_ids
346                        .remove(thread_context.thread.read(cx).id());
347                }
348                _ => {}
349            }
350            cx.notify();
351        }
352    }
353
354    pub fn has_context(&mut self, context: &AgentContext) -> bool {
355        self.context_set
356            .contains(AgentContextKey::ref_cast(context))
357    }
358
359    /// Returns whether this file path is already included directly in the context, or if it will be
360    /// included in the context via a directory.
361    pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
362        let project = self.project.upgrade()?.read(cx);
363        self.context().find_map(|context| match context {
364            AgentContext::File(file_context) => FileInclusion::check_file(file_context, path, cx),
365            AgentContext::Directory(directory_context) => {
366                FileInclusion::check_directory(directory_context, path, project, cx)
367            }
368            _ => None,
369        })
370    }
371
372    pub fn path_included_in_directory(
373        &self,
374        path: &ProjectPath,
375        cx: &App,
376    ) -> Option<FileInclusion> {
377        let project = self.project.upgrade()?.read(cx);
378        self.context().find_map(|context| match context {
379            AgentContext::Directory(directory_context) => {
380                FileInclusion::check_directory(directory_context, path, project, cx)
381            }
382            _ => None,
383        })
384    }
385
386    pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
387        self.context().any(|context| match context {
388            AgentContext::Symbol(context) => {
389                if context.symbol != symbol.name {
390                    return false;
391                }
392                let buffer = context.buffer.read(cx);
393                let Some(context_path) = buffer.project_path(cx) else {
394                    return false;
395                };
396                if context_path != symbol.path {
397                    return false;
398                }
399                let context_range = context.range.to_point_utf16(&buffer.snapshot());
400                context_range.start == symbol.range.start.0
401                    && context_range.end == symbol.range.end.0
402            }
403            _ => false,
404        })
405    }
406
407    pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
408        self.context_thread_ids.contains(thread_id)
409    }
410
411    pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
412        self.context_set
413            .contains(&RulesContext::lookup_key(prompt_id))
414    }
415
416    pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
417        self.context_set
418            .contains(&FetchedUrlContext::lookup_key(url.into()))
419    }
420
421    pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
422        self.context()
423            .filter_map(|context| match context {
424                AgentContext::File(file) => {
425                    let buffer = file.buffer.read(cx);
426                    buffer.project_path(cx)
427                }
428                AgentContext::Directory(_)
429                | AgentContext::Symbol(_)
430                | AgentContext::Selection(_)
431                | AgentContext::FetchedUrl(_)
432                | AgentContext::Thread(_)
433                | AgentContext::Rules(_)
434                | AgentContext::Image(_) => None,
435            })
436            .collect()
437    }
438
439    pub fn thread_ids(&self) -> &HashSet<ThreadId> {
440        &self.context_thread_ids
441    }
442}
443
444pub enum FileInclusion {
445    Direct,
446    InDirectory { full_path: PathBuf },
447}
448
449impl FileInclusion {
450    fn check_file(file_context: &FileContext, path: &ProjectPath, cx: &App) -> Option<Self> {
451        let file_path = file_context.buffer.read(cx).project_path(cx)?;
452        if path == &file_path {
453            Some(FileInclusion::Direct)
454        } else {
455            None
456        }
457    }
458
459    fn check_directory(
460        directory_context: &DirectoryContext,
461        path: &ProjectPath,
462        project: &Project,
463        cx: &App,
464    ) -> Option<Self> {
465        let worktree = project
466            .worktree_for_entry(directory_context.entry_id, cx)?
467            .read(cx);
468        let entry = worktree.entry_for_id(directory_context.entry_id)?;
469        let directory_path = ProjectPath {
470            worktree_id: worktree.id(),
471            path: entry.path.clone(),
472        };
473        if path.starts_with(&directory_path) {
474            if path == &directory_path {
475                Some(FileInclusion::Direct)
476            } else {
477                Some(FileInclusion::InDirectory {
478                    full_path: worktree.full_path(&entry.path),
479                })
480            }
481        } else {
482            None
483        }
484    }
485}