context_store.rs

  1use crate::context::{
  2    AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle,
  3    FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
  4    SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
  5};
  6use agent_client_protocol as acp;
  7use anyhow::{Context as _, Result, anyhow};
  8use assistant_context::AssistantContext;
  9use collections::{HashSet, IndexSet};
 10use futures::{self, FutureExt};
 11use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
 12use language::{Buffer, File as _};
 13use language_model::LanguageModelImage;
 14use project::{
 15    Project, ProjectItem, ProjectPath, Symbol, image_store::is_image_file,
 16    lsp_store::SymbolLocation,
 17};
 18use prompt_store::UserPromptId;
 19use ref_cast::RefCast as _;
 20use std::{
 21    ops::Range,
 22    path::{Path, PathBuf},
 23    sync::Arc,
 24};
 25use text::{Anchor, OffsetRangeExt};
 26
 27pub struct ContextStore {
 28    project: WeakEntity<Project>,
 29    next_context_id: ContextId,
 30    context_set: IndexSet<AgentContextKey>,
 31    context_thread_ids: HashSet<acp::SessionId>,
 32    context_text_thread_paths: HashSet<Arc<Path>>,
 33}
 34
 35pub enum ContextStoreEvent {
 36    ContextRemoved(AgentContextKey),
 37}
 38
 39impl EventEmitter<ContextStoreEvent> for ContextStore {}
 40
 41impl ContextStore {
 42    pub fn new(project: WeakEntity<Project>) -> Self {
 43        Self {
 44            project,
 45            next_context_id: ContextId::zero(),
 46            context_set: IndexSet::default(),
 47            context_thread_ids: HashSet::default(),
 48            context_text_thread_paths: HashSet::default(),
 49        }
 50    }
 51
 52    pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
 53        self.context_set.iter().map(|entry| entry.as_ref())
 54    }
 55
 56    pub fn clear(&mut self, cx: &mut Context<Self>) {
 57        self.context_set.clear();
 58        self.context_thread_ids.clear();
 59        cx.notify();
 60    }
 61
 62    pub fn add_file_from_path(
 63        &mut self,
 64        project_path: ProjectPath,
 65        remove_if_exists: bool,
 66        cx: &mut Context<Self>,
 67    ) -> Task<Result<Option<AgentContextHandle>>> {
 68        let Some(project) = self.project.upgrade() else {
 69            return Task::ready(Err(anyhow!("failed to read project")));
 70        };
 71
 72        if is_image_file(&project, &project_path, cx) {
 73            self.add_image_from_path(project_path, remove_if_exists, cx)
 74        } else {
 75            cx.spawn(async move |this, cx| {
 76                let open_buffer_task = project.update(cx, |project, cx| {
 77                    project.open_buffer(project_path.clone(), cx)
 78                })?;
 79                let buffer = open_buffer_task.await?;
 80                this.update(cx, |this, cx| {
 81                    this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
 82                })
 83            })
 84        }
 85    }
 86
 87    pub fn add_file_from_buffer(
 88        &mut self,
 89        project_path: &ProjectPath,
 90        buffer: Entity<Buffer>,
 91        remove_if_exists: bool,
 92        cx: &mut Context<Self>,
 93    ) -> Option<AgentContextHandle> {
 94        let context_id = self.next_context_id.post_inc();
 95        let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
 96
 97        if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
 98            if remove_if_exists {
 99                self.remove_context(&context, cx);
100                None
101            } else {
102                Some(key.as_ref().clone())
103            }
104        } else if self.path_included_in_directory(project_path, cx).is_some() {
105            None
106        } else {
107            self.insert_context(context.clone(), cx);
108            Some(context)
109        }
110    }
111
112    pub fn add_directory(
113        &mut self,
114        project_path: &ProjectPath,
115        remove_if_exists: bool,
116        cx: &mut Context<Self>,
117    ) -> Result<Option<AgentContextHandle>> {
118        let project = self.project.upgrade().context("failed to read project")?;
119        let entry_id = project
120            .read(cx)
121            .entry_for_path(project_path, cx)
122            .map(|entry| entry.id)
123            .context("no entry found for directory context")?;
124
125        let context_id = self.next_context_id.post_inc();
126        let context = AgentContextHandle::Directory(DirectoryContextHandle {
127            entry_id,
128            context_id,
129        });
130
131        let context =
132            if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
133                if remove_if_exists {
134                    self.remove_context(&context, cx);
135                    None
136                } else {
137                    Some(existing.as_ref().clone())
138                }
139            } else {
140                self.insert_context(context.clone(), cx);
141                Some(context)
142            };
143
144        anyhow::Ok(context)
145    }
146
147    pub fn add_symbol(
148        &mut self,
149        buffer: Entity<Buffer>,
150        symbol: SharedString,
151        range: Range<Anchor>,
152        enclosing_range: Range<Anchor>,
153        remove_if_exists: bool,
154        cx: &mut Context<Self>,
155    ) -> (Option<AgentContextHandle>, bool) {
156        let context_id = self.next_context_id.post_inc();
157        let context = AgentContextHandle::Symbol(SymbolContextHandle {
158            buffer,
159            symbol,
160            range,
161            enclosing_range,
162            context_id,
163        });
164
165        if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
166            let handle = if remove_if_exists {
167                self.remove_context(&context, cx);
168                None
169            } else {
170                Some(key.as_ref().clone())
171            };
172            return (handle, false);
173        }
174
175        let included = self.insert_context(context.clone(), cx);
176        (Some(context), included)
177    }
178
179    pub fn add_thread(
180        &mut self,
181        thread: Entity<agent::Thread>,
182        remove_if_exists: bool,
183        cx: &mut Context<Self>,
184    ) -> Option<AgentContextHandle> {
185        let context_id = self.next_context_id.post_inc();
186        let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
187
188        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
189            if remove_if_exists {
190                self.remove_context(&context, cx);
191                None
192            } else {
193                Some(existing.as_ref().clone())
194            }
195        } else {
196            self.insert_context(context.clone(), cx);
197            Some(context)
198        }
199    }
200
201    pub fn add_text_thread(
202        &mut self,
203        context: Entity<AssistantContext>,
204        remove_if_exists: bool,
205        cx: &mut Context<Self>,
206    ) -> Option<AgentContextHandle> {
207        let context_id = self.next_context_id.post_inc();
208        let context = AgentContextHandle::TextThread(TextThreadContextHandle {
209            context,
210            context_id,
211        });
212
213        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
214            if remove_if_exists {
215                self.remove_context(&context, cx);
216                None
217            } else {
218                Some(existing.as_ref().clone())
219            }
220        } else {
221            self.insert_context(context.clone(), cx);
222            Some(context)
223        }
224    }
225
226    pub fn add_rules(
227        &mut self,
228        prompt_id: UserPromptId,
229        remove_if_exists: bool,
230        cx: &mut Context<ContextStore>,
231    ) -> Option<AgentContextHandle> {
232        let context_id = self.next_context_id.post_inc();
233        let context = AgentContextHandle::Rules(RulesContextHandle {
234            prompt_id,
235            context_id,
236        });
237
238        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
239            if remove_if_exists {
240                self.remove_context(&context, cx);
241                None
242            } else {
243                Some(existing.as_ref().clone())
244            }
245        } else {
246            self.insert_context(context.clone(), cx);
247            Some(context)
248        }
249    }
250
251    pub fn add_fetched_url(
252        &mut self,
253        url: String,
254        text: impl Into<SharedString>,
255        cx: &mut Context<ContextStore>,
256    ) -> AgentContextHandle {
257        let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
258            url: url.into(),
259            text: text.into(),
260            context_id: self.next_context_id.post_inc(),
261        });
262
263        self.insert_context(context.clone(), cx);
264        context
265    }
266
267    pub fn add_image_from_path(
268        &mut self,
269        project_path: ProjectPath,
270        remove_if_exists: bool,
271        cx: &mut Context<ContextStore>,
272    ) -> Task<Result<Option<AgentContextHandle>>> {
273        let project = self.project.clone();
274        cx.spawn(async move |this, cx| {
275            let open_image_task = project.update(cx, |project, cx| {
276                project.open_image(project_path.clone(), cx)
277            })?;
278            let image_item = open_image_task.await?;
279
280            this.update(cx, |this, cx| {
281                let item = image_item.read(cx);
282                this.insert_image(
283                    Some(item.project_path(cx)),
284                    Some(item.file.full_path(cx).to_string_lossy().into_owned()),
285                    item.image.clone(),
286                    remove_if_exists,
287                    cx,
288                )
289            })
290        })
291    }
292
293    pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
294        self.insert_image(None, None, image, false, cx);
295    }
296
297    fn insert_image(
298        &mut self,
299        project_path: Option<ProjectPath>,
300        full_path: Option<String>,
301        image: Arc<Image>,
302        remove_if_exists: bool,
303        cx: &mut Context<ContextStore>,
304    ) -> Option<AgentContextHandle> {
305        let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
306        let context = AgentContextHandle::Image(ImageContext {
307            project_path,
308            full_path,
309            original_image: image,
310            image_task,
311            context_id: self.next_context_id.post_inc(),
312        });
313        if self.has_context(&context) && remove_if_exists {
314            self.remove_context(&context, cx);
315            return None;
316        }
317
318        self.insert_context(context.clone(), cx);
319        Some(context)
320    }
321
322    pub fn add_selection(
323        &mut self,
324        buffer: Entity<Buffer>,
325        range: Range<Anchor>,
326        cx: &mut Context<ContextStore>,
327    ) {
328        let context_id = self.next_context_id.post_inc();
329        let context = AgentContextHandle::Selection(SelectionContextHandle {
330            buffer,
331            range,
332            context_id,
333        });
334        self.insert_context(context, cx);
335    }
336
337    pub fn add_suggested_context(
338        &mut self,
339        suggested: &SuggestedContext,
340        cx: &mut Context<ContextStore>,
341    ) {
342        match suggested {
343            SuggestedContext::File {
344                buffer,
345                icon_path: _,
346                name: _,
347            } => {
348                if let Some(buffer) = buffer.upgrade() {
349                    let context_id = self.next_context_id.post_inc();
350                    self.insert_context(
351                        AgentContextHandle::File(FileContextHandle { buffer, context_id }),
352                        cx,
353                    );
354                };
355            }
356            // SuggestedContext::Thread { thread, name: _ } => {
357            //     if let Some(thread) = thread.upgrade() {
358            //         let context_id = self.next_context_id.post_inc();
359            //         self.insert_context(
360            //             AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
361            //             cx,
362            //         );
363            //     }
364            // }
365            SuggestedContext::TextThread { context, name: _ } => {
366                if let Some(context) = context.upgrade() {
367                    let context_id = self.next_context_id.post_inc();
368                    self.insert_context(
369                        AgentContextHandle::TextThread(TextThreadContextHandle {
370                            context,
371                            context_id,
372                        }),
373                        cx,
374                    );
375                }
376            }
377        }
378    }
379
380    fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
381        match &context {
382            // AgentContextHandle::Thread(thread_context) => {
383            //     if let Some(thread_store) = self.thread_store.clone() {
384            //         thread_context.thread.update(cx, |thread, cx| {
385            //             thread.start_generating_detailed_summary_if_needed(thread_store, cx);
386            //         });
387            //         self.context_thread_ids
388            //             .insert(thread_context.thread.read(cx).id().clone());
389            //     } else {
390            //         return false;
391            //     }
392            // }
393            AgentContextHandle::TextThread(text_thread_context) => {
394                self.context_text_thread_paths
395                    .extend(text_thread_context.context.read(cx).path().cloned());
396            }
397            _ => {}
398        }
399        let inserted = self.context_set.insert(AgentContextKey(context));
400        if inserted {
401            cx.notify();
402        }
403        inserted
404    }
405
406    pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
407        if let Some((_, key)) = self
408            .context_set
409            .shift_remove_full(AgentContextKey::ref_cast(context))
410        {
411            match context {
412                AgentContextHandle::Thread(thread_context) => {
413                    self.context_thread_ids
414                        .remove(thread_context.thread.read(cx).id());
415                }
416                AgentContextHandle::TextThread(text_thread_context) => {
417                    if let Some(path) = text_thread_context.context.read(cx).path() {
418                        self.context_text_thread_paths.remove(path);
419                    }
420                }
421                _ => {}
422            }
423            cx.emit(ContextStoreEvent::ContextRemoved(key));
424            cx.notify();
425        }
426    }
427
428    pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
429        self.context_set
430            .contains(AgentContextKey::ref_cast(context))
431    }
432
433    /// Returns whether this file path is already included directly in the context, or if it will be
434    /// included in the context via a directory.
435    pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
436        let project = self.project.upgrade()?.read(cx);
437        self.context().find_map(|context| match context {
438            AgentContextHandle::File(file_context) => {
439                FileInclusion::check_file(file_context, path, cx)
440            }
441            AgentContextHandle::Image(image_context) => {
442                FileInclusion::check_image(image_context, path)
443            }
444            AgentContextHandle::Directory(directory_context) => {
445                FileInclusion::check_directory(directory_context, path, project, cx)
446            }
447            _ => None,
448        })
449    }
450
451    pub fn path_included_in_directory(
452        &self,
453        path: &ProjectPath,
454        cx: &App,
455    ) -> Option<FileInclusion> {
456        let project = self.project.upgrade()?.read(cx);
457        self.context().find_map(|context| match context {
458            AgentContextHandle::Directory(directory_context) => {
459                FileInclusion::check_directory(directory_context, path, project, cx)
460            }
461            _ => None,
462        })
463    }
464
465    pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
466        self.context().any(|context| match context {
467            AgentContextHandle::Symbol(context) => {
468                if context.symbol != symbol.name {
469                    return false;
470                }
471                let buffer = context.buffer.read(cx);
472                let Some(context_path) = buffer.project_path(cx) else {
473                    return false;
474                };
475                if symbol.path != SymbolLocation::InProject(context_path) {
476                    return false;
477                }
478                let context_range = context.range.to_point_utf16(&buffer.snapshot());
479                context_range.start == symbol.range.start.0
480                    && context_range.end == symbol.range.end.0
481            }
482            _ => false,
483        })
484    }
485
486    pub fn includes_thread(&self, thread_id: &acp::SessionId) -> bool {
487        self.context_thread_ids.contains(thread_id)
488    }
489
490    pub fn includes_text_thread(&self, path: &Arc<Path>) -> bool {
491        self.context_text_thread_paths.contains(path)
492    }
493
494    pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
495        self.context_set
496            .contains(&RulesContextHandle::lookup_key(prompt_id))
497    }
498
499    pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
500        self.context_set
501            .contains(&FetchedUrlContext::lookup_key(url.into()))
502    }
503
504    pub fn get_url_context(&self, url: SharedString) -> Option<AgentContextHandle> {
505        self.context_set
506            .get(&FetchedUrlContext::lookup_key(url))
507            .map(|key| key.as_ref().clone())
508    }
509
510    pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
511        self.context()
512            .filter_map(|context| match context {
513                AgentContextHandle::File(file) => {
514                    let buffer = file.buffer.read(cx);
515                    buffer.project_path(cx)
516                }
517                AgentContextHandle::Directory(_)
518                | AgentContextHandle::Symbol(_)
519                | AgentContextHandle::Thread(_)
520                | AgentContextHandle::Selection(_)
521                | AgentContextHandle::FetchedUrl(_)
522                | AgentContextHandle::TextThread(_)
523                | AgentContextHandle::Rules(_)
524                | AgentContextHandle::Image(_) => None,
525            })
526            .collect()
527    }
528
529    pub const fn thread_ids(&self) -> &HashSet<acp::SessionId> {
530        &self.context_thread_ids
531    }
532}
533
534#[derive(Clone)]
535pub enum SuggestedContext {
536    File {
537        name: SharedString,
538        icon_path: Option<SharedString>,
539        buffer: WeakEntity<Buffer>,
540    },
541    // Thread {
542    //     name: SharedString,
543    //     thread: WeakEntity<Thread>,
544    // },
545    TextThread {
546        name: SharedString,
547        context: WeakEntity<AssistantContext>,
548    },
549}
550
551impl SuggestedContext {
552    pub const fn name(&self) -> &SharedString {
553        match self {
554            Self::File { name, .. } => name,
555            // Self::Thread { name, .. } => name,
556            Self::TextThread { name, .. } => name,
557        }
558    }
559
560    pub fn icon_path(&self) -> Option<SharedString> {
561        match self {
562            Self::File { icon_path, .. } => icon_path.clone(),
563            // Self::Thread { .. } => None,
564            Self::TextThread { .. } => None,
565        }
566    }
567
568    pub const fn kind(&self) -> ContextKind {
569        match self {
570            Self::File { .. } => ContextKind::File,
571            // Self::Thread { .. } => ContextKind::Thread,
572            Self::TextThread { .. } => ContextKind::TextThread,
573        }
574    }
575}
576
/// How a project path is covered by the context in a `ContextStore`.
pub enum FileInclusion {
    /// The path itself is in the store (as a file, an image, or a directory).
    Direct,
    /// The path lies inside a directory context. `full_path` is that directory's full path.
    InDirectory {
        full_path: PathBuf,
    },
}
581
582impl FileInclusion {
583    fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
584        let file_path = file_context.buffer.read(cx).project_path(cx)?;
585        if path == &file_path {
586            Some(FileInclusion::Direct)
587        } else {
588            None
589        }
590    }
591
592    fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
593        let image_path = image_context.project_path.as_ref()?;
594        if path == image_path {
595            Some(FileInclusion::Direct)
596        } else {
597            None
598        }
599    }
600
601    fn check_directory(
602        directory_context: &DirectoryContextHandle,
603        path: &ProjectPath,
604        project: &Project,
605        cx: &App,
606    ) -> Option<Self> {
607        let worktree = project
608            .worktree_for_entry(directory_context.entry_id, cx)?
609            .read(cx);
610        let entry = worktree.entry_for_id(directory_context.entry_id)?;
611        let directory_path = ProjectPath {
612            worktree_id: worktree.id(),
613            path: entry.path.clone(),
614        };
615        if path.starts_with(&directory_path) {
616            if path == &directory_path {
617                Some(FileInclusion::Direct)
618            } else {
619                Some(FileInclusion::InDirectory {
620                    full_path: worktree.full_path(&entry.path),
621                })
622            }
623        } else {
624            None
625        }
626    }
627}