context_store.rs

  1use crate::{
  2    context::{
  3        AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle,
  4        FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle,
  5        SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
  6    },
  7    thread::{MessageId, ThreadId, ZedAgentThread},
  8    thread_store::ThreadStore,
  9};
 10use anyhow::{Context as _, Result, anyhow};
 11use assistant_context::AssistantContext;
 12use collections::{HashSet, IndexSet};
 13use futures::{self, FutureExt};
 14use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
 15use language::{Buffer, File as _};
 16use language_model::LanguageModelImage;
 17use project::{Project, ProjectItem, ProjectPath, Symbol, image_store::is_image_file};
 18use prompt_store::UserPromptId;
 19use ref_cast::RefCast as _;
 20use std::{
 21    ops::Range,
 22    path::{Path, PathBuf},
 23    sync::Arc,
 24};
 25use text::{Anchor, OffsetRangeExt};
 26
 27pub struct ContextStore {
 28    project: WeakEntity<Project>,
 29    thread_store: Option<WeakEntity<ThreadStore>>,
 30    next_context_id: ContextId,
 31    context_set: IndexSet<AgentContextKey>,
 32    context_thread_ids: HashSet<ThreadId>,
 33    context_text_thread_paths: HashSet<Arc<Path>>,
 34}
 35
 36pub enum ContextStoreEvent {
 37    ContextRemoved(AgentContextKey),
 38}
 39
 40impl EventEmitter<ContextStoreEvent> for ContextStore {}
 41
 42impl ContextStore {
 43    pub fn new(
 44        project: WeakEntity<Project>,
 45        thread_store: Option<WeakEntity<ThreadStore>>,
 46    ) -> Self {
 47        Self {
 48            project,
 49            thread_store,
 50            next_context_id: ContextId::zero(),
 51            context_set: IndexSet::default(),
 52            context_thread_ids: HashSet::default(),
 53            context_text_thread_paths: HashSet::default(),
 54        }
 55    }
 56
 57    pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
 58        self.context_set.iter().map(|entry| entry.as_ref())
 59    }
 60
 61    pub fn clear(&mut self, cx: &mut Context<Self>) {
 62        self.context_set.clear();
 63        self.context_thread_ids.clear();
 64        cx.notify();
 65    }
 66
 67    pub fn new_context_for_thread(
 68        &self,
 69        thread: &ZedAgentThread,
 70        exclude_messages_from_id: Option<MessageId>,
 71        _cx: &App,
 72    ) -> Vec<AgentContextHandle> {
 73        let existing_context = thread
 74            .messages()
 75            .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
 76            .flat_map(|message| {
 77                message
 78                    .loaded_context
 79                    .contexts
 80                    .iter()
 81                    .map(|context| AgentContextKey(context.handle()))
 82            })
 83            .collect::<HashSet<_>>();
 84        self.context_set
 85            .iter()
 86            .filter(|context| !existing_context.contains(context))
 87            .map(|entry| entry.0.clone())
 88            .collect::<Vec<_>>()
 89    }
 90
 91    pub fn add_file_from_path(
 92        &mut self,
 93        project_path: ProjectPath,
 94        remove_if_exists: bool,
 95        cx: &mut Context<Self>,
 96    ) -> Task<Result<Option<AgentContextHandle>>> {
 97        let Some(project) = self.project.upgrade() else {
 98            return Task::ready(Err(anyhow!("failed to read project")));
 99        };
100
101        if is_image_file(&project, &project_path, cx) {
102            self.add_image_from_path(project_path, remove_if_exists, cx)
103        } else {
104            cx.spawn(async move |this, cx| {
105                let open_buffer_task = project.update(cx, |project, cx| {
106                    project.open_buffer(project_path.clone(), cx)
107                })?;
108                let buffer = open_buffer_task.await?;
109                this.update(cx, |this, cx| {
110                    this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
111                })
112            })
113        }
114    }
115
116    pub fn add_file_from_buffer(
117        &mut self,
118        project_path: &ProjectPath,
119        buffer: Entity<Buffer>,
120        remove_if_exists: bool,
121        cx: &mut Context<Self>,
122    ) -> Option<AgentContextHandle> {
123        let context_id = self.next_context_id.post_inc();
124        let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
125
126        if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
127            if remove_if_exists {
128                self.remove_context(&context, cx);
129                None
130            } else {
131                Some(key.as_ref().clone())
132            }
133        } else if self.path_included_in_directory(project_path, cx).is_some() {
134            None
135        } else {
136            self.insert_context(context.clone(), cx);
137            Some(context)
138        }
139    }
140
141    pub fn add_directory(
142        &mut self,
143        project_path: &ProjectPath,
144        remove_if_exists: bool,
145        cx: &mut Context<Self>,
146    ) -> Result<Option<AgentContextHandle>> {
147        let project = self.project.upgrade().context("failed to read project")?;
148        let entry_id = project
149            .read(cx)
150            .entry_for_path(project_path, cx)
151            .map(|entry| entry.id)
152            .context("no entry found for directory context")?;
153
154        let context_id = self.next_context_id.post_inc();
155        let context = AgentContextHandle::Directory(DirectoryContextHandle {
156            entry_id,
157            context_id,
158        });
159
160        let context =
161            if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
162                if remove_if_exists {
163                    self.remove_context(&context, cx);
164                    None
165                } else {
166                    Some(existing.as_ref().clone())
167                }
168            } else {
169                self.insert_context(context.clone(), cx);
170                Some(context)
171            };
172
173        anyhow::Ok(context)
174    }
175
176    pub fn add_symbol(
177        &mut self,
178        buffer: Entity<Buffer>,
179        symbol: SharedString,
180        range: Range<Anchor>,
181        enclosing_range: Range<Anchor>,
182        remove_if_exists: bool,
183        cx: &mut Context<Self>,
184    ) -> (Option<AgentContextHandle>, bool) {
185        let context_id = self.next_context_id.post_inc();
186        let context = AgentContextHandle::Symbol(SymbolContextHandle {
187            buffer,
188            symbol,
189            range,
190            enclosing_range,
191            context_id,
192        });
193
194        if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
195            let handle = if remove_if_exists {
196                self.remove_context(&context, cx);
197                None
198            } else {
199                Some(key.as_ref().clone())
200            };
201            return (handle, false);
202        }
203
204        let included = self.insert_context(context.clone(), cx);
205        (Some(context), included)
206    }
207
208    pub fn add_thread(
209        &mut self,
210        thread: Entity<ZedAgentThread>,
211        remove_if_exists: bool,
212        cx: &mut Context<Self>,
213    ) -> Option<AgentContextHandle> {
214        let context_id = self.next_context_id.post_inc();
215        let context = AgentContextHandle::Thread(ThreadContextHandle {
216            agent: thread,
217            context_id,
218        });
219
220        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
221            if remove_if_exists {
222                self.remove_context(&context, cx);
223                None
224            } else {
225                Some(existing.as_ref().clone())
226            }
227        } else {
228            self.insert_context(context.clone(), cx);
229            Some(context)
230        }
231    }
232
233    pub fn add_text_thread(
234        &mut self,
235        context: Entity<AssistantContext>,
236        remove_if_exists: bool,
237        cx: &mut Context<Self>,
238    ) -> Option<AgentContextHandle> {
239        let context_id = self.next_context_id.post_inc();
240        let context = AgentContextHandle::TextThread(TextThreadContextHandle {
241            context,
242            context_id,
243        });
244
245        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
246            if remove_if_exists {
247                self.remove_context(&context, cx);
248                None
249            } else {
250                Some(existing.as_ref().clone())
251            }
252        } else {
253            self.insert_context(context.clone(), cx);
254            Some(context)
255        }
256    }
257
258    pub fn add_rules(
259        &mut self,
260        prompt_id: UserPromptId,
261        remove_if_exists: bool,
262        cx: &mut Context<ContextStore>,
263    ) -> Option<AgentContextHandle> {
264        let context_id = self.next_context_id.post_inc();
265        let context = AgentContextHandle::Rules(RulesContextHandle {
266            prompt_id,
267            context_id,
268        });
269
270        if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
271            if remove_if_exists {
272                self.remove_context(&context, cx);
273                None
274            } else {
275                Some(existing.as_ref().clone())
276            }
277        } else {
278            self.insert_context(context.clone(), cx);
279            Some(context)
280        }
281    }
282
283    pub fn add_fetched_url(
284        &mut self,
285        url: String,
286        text: impl Into<SharedString>,
287        cx: &mut Context<ContextStore>,
288    ) -> AgentContextHandle {
289        let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
290            url: url.into(),
291            text: text.into(),
292            context_id: self.next_context_id.post_inc(),
293        });
294
295        self.insert_context(context.clone(), cx);
296        context
297    }
298
299    pub fn add_image_from_path(
300        &mut self,
301        project_path: ProjectPath,
302        remove_if_exists: bool,
303        cx: &mut Context<ContextStore>,
304    ) -> Task<Result<Option<AgentContextHandle>>> {
305        let project = self.project.clone();
306        cx.spawn(async move |this, cx| {
307            let open_image_task = project.update(cx, |project, cx| {
308                project.open_image(project_path.clone(), cx)
309            })?;
310            let image_item = open_image_task.await?;
311
312            this.update(cx, |this, cx| {
313                let item = image_item.read(cx);
314                this.insert_image(
315                    Some(item.project_path(cx)),
316                    Some(item.file.full_path(cx).into()),
317                    item.image.clone(),
318                    remove_if_exists,
319                    cx,
320                )
321            })
322        })
323    }
324
325    pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
326        self.insert_image(None, None, image, false, cx);
327    }
328
329    fn insert_image(
330        &mut self,
331        project_path: Option<ProjectPath>,
332        full_path: Option<Arc<Path>>,
333        image: Arc<Image>,
334        remove_if_exists: bool,
335        cx: &mut Context<ContextStore>,
336    ) -> Option<AgentContextHandle> {
337        let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
338        let context = AgentContextHandle::Image(ImageContext {
339            project_path,
340            full_path,
341            original_image: image,
342            image_task,
343            context_id: self.next_context_id.post_inc(),
344        });
345        if self.has_context(&context) {
346            if remove_if_exists {
347                self.remove_context(&context, cx);
348                return None;
349            }
350        }
351
352        self.insert_context(context.clone(), cx);
353        Some(context)
354    }
355
356    pub fn add_selection(
357        &mut self,
358        buffer: Entity<Buffer>,
359        range: Range<Anchor>,
360        cx: &mut Context<ContextStore>,
361    ) {
362        let context_id = self.next_context_id.post_inc();
363        let context = AgentContextHandle::Selection(SelectionContextHandle {
364            buffer,
365            range,
366            context_id,
367        });
368        self.insert_context(context, cx);
369    }
370
371    pub fn add_suggested_context(
372        &mut self,
373        suggested: &SuggestedContext,
374        cx: &mut Context<ContextStore>,
375    ) {
376        match suggested {
377            SuggestedContext::File {
378                buffer,
379                icon_path: _,
380                name: _,
381            } => {
382                if let Some(buffer) = buffer.upgrade() {
383                    let context_id = self.next_context_id.post_inc();
384                    self.insert_context(
385                        AgentContextHandle::File(FileContextHandle { buffer, context_id }),
386                        cx,
387                    );
388                };
389            }
390            SuggestedContext::Thread { thread, name: _ } => {
391                if let Some(thread) = thread.upgrade() {
392                    let context_id = self.next_context_id.post_inc();
393                    self.insert_context(
394                        AgentContextHandle::Thread(ThreadContextHandle {
395                            agent: thread,
396                            context_id,
397                        }),
398                        cx,
399                    );
400                }
401            }
402            SuggestedContext::TextThread { context, name: _ } => {
403                if let Some(context) = context.upgrade() {
404                    let context_id = self.next_context_id.post_inc();
405                    self.insert_context(
406                        AgentContextHandle::TextThread(TextThreadContextHandle {
407                            context,
408                            context_id,
409                        }),
410                        cx,
411                    );
412                }
413            }
414        }
415    }
416
417    fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
418        match &context {
419            AgentContextHandle::Thread(thread_context) => {
420                if let Some(thread_store) = self.thread_store.clone() {
421                    thread_context.agent.update(cx, |thread, cx| {
422                        thread.start_generating_detailed_summary_if_needed(thread_store, cx);
423                    });
424                    self.context_thread_ids
425                        .insert(thread_context.agent.read(cx).id().clone());
426                } else {
427                    return false;
428                }
429            }
430            AgentContextHandle::TextThread(text_thread_context) => {
431                self.context_text_thread_paths
432                    .extend(text_thread_context.context.read(cx).path().cloned());
433            }
434            _ => {}
435        }
436        let inserted = self.context_set.insert(AgentContextKey(context));
437        if inserted {
438            cx.notify();
439        }
440        inserted
441    }
442
443    pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
444        if let Some((_, key)) = self
445            .context_set
446            .shift_remove_full(AgentContextKey::ref_cast(context))
447        {
448            match context {
449                AgentContextHandle::Thread(thread_context) => {
450                    self.context_thread_ids
451                        .remove(thread_context.agent.read(cx).id());
452                }
453                AgentContextHandle::TextThread(text_thread_context) => {
454                    if let Some(path) = text_thread_context.context.read(cx).path() {
455                        self.context_text_thread_paths.remove(path);
456                    }
457                }
458                _ => {}
459            }
460            cx.emit(ContextStoreEvent::ContextRemoved(key));
461            cx.notify();
462        }
463    }
464
465    pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
466        self.context_set
467            .contains(AgentContextKey::ref_cast(context))
468    }
469
470    /// Returns whether this file path is already included directly in the context, or if it will be
471    /// included in the context via a directory.
472    pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
473        let project = self.project.upgrade()?.read(cx);
474        self.context().find_map(|context| match context {
475            AgentContextHandle::File(file_context) => {
476                FileInclusion::check_file(file_context, path, cx)
477            }
478            AgentContextHandle::Image(image_context) => {
479                FileInclusion::check_image(image_context, path)
480            }
481            AgentContextHandle::Directory(directory_context) => {
482                FileInclusion::check_directory(directory_context, path, project, cx)
483            }
484            _ => None,
485        })
486    }
487
488    pub fn path_included_in_directory(
489        &self,
490        path: &ProjectPath,
491        cx: &App,
492    ) -> Option<FileInclusion> {
493        let project = self.project.upgrade()?.read(cx);
494        self.context().find_map(|context| match context {
495            AgentContextHandle::Directory(directory_context) => {
496                FileInclusion::check_directory(directory_context, path, project, cx)
497            }
498            _ => None,
499        })
500    }
501
502    pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
503        self.context().any(|context| match context {
504            AgentContextHandle::Symbol(context) => {
505                if context.symbol != symbol.name {
506                    return false;
507                }
508                let buffer = context.buffer.read(cx);
509                let Some(context_path) = buffer.project_path(cx) else {
510                    return false;
511                };
512                if context_path != symbol.path {
513                    return false;
514                }
515                let context_range = context.range.to_point_utf16(&buffer.snapshot());
516                context_range.start == symbol.range.start.0
517                    && context_range.end == symbol.range.end.0
518            }
519            _ => false,
520        })
521    }
522
523    pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
524        self.context_thread_ids.contains(thread_id)
525    }
526
527    pub fn includes_text_thread(&self, path: &Arc<Path>) -> bool {
528        self.context_text_thread_paths.contains(path)
529    }
530
531    pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
532        self.context_set
533            .contains(&RulesContextHandle::lookup_key(prompt_id))
534    }
535
536    pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
537        self.context_set
538            .contains(&FetchedUrlContext::lookup_key(url.into()))
539    }
540
541    pub fn get_url_context(&self, url: SharedString) -> Option<AgentContextHandle> {
542        self.context_set
543            .get(&FetchedUrlContext::lookup_key(url))
544            .map(|key| key.as_ref().clone())
545    }
546
547    pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
548        self.context()
549            .filter_map(|context| match context {
550                AgentContextHandle::File(file) => {
551                    let buffer = file.buffer.read(cx);
552                    buffer.project_path(cx)
553                }
554                AgentContextHandle::Directory(_)
555                | AgentContextHandle::Symbol(_)
556                | AgentContextHandle::Selection(_)
557                | AgentContextHandle::FetchedUrl(_)
558                | AgentContextHandle::Thread(_)
559                | AgentContextHandle::TextThread(_)
560                | AgentContextHandle::Rules(_)
561                | AgentContextHandle::Image(_) => None,
562            })
563            .collect()
564    }
565
566    pub fn thread_ids(&self) -> &HashSet<ThreadId> {
567        &self.context_thread_ids
568    }
569}
570
571#[derive(Clone)]
572pub enum SuggestedContext {
573    File {
574        name: SharedString,
575        icon_path: Option<SharedString>,
576        buffer: WeakEntity<Buffer>,
577    },
578    Thread {
579        name: SharedString,
580        thread: WeakEntity<ZedAgentThread>,
581    },
582    TextThread {
583        name: SharedString,
584        context: WeakEntity<AssistantContext>,
585    },
586}
587
588impl SuggestedContext {
589    pub fn name(&self) -> &SharedString {
590        match self {
591            Self::File { name, .. } => name,
592            Self::Thread { name, .. } => name,
593            Self::TextThread { name, .. } => name,
594        }
595    }
596
597    pub fn icon_path(&self) -> Option<SharedString> {
598        match self {
599            Self::File { icon_path, .. } => icon_path.clone(),
600            Self::Thread { .. } => None,
601            Self::TextThread { .. } => None,
602        }
603    }
604
605    pub fn kind(&self) -> ContextKind {
606        match self {
607            Self::File { .. } => ContextKind::File,
608            Self::Thread { .. } => ContextKind::Thread,
609            Self::TextThread { .. } => ContextKind::TextThread,
610        }
611    }
612}
613
614pub enum FileInclusion {
615    Direct,
616    InDirectory { full_path: PathBuf },
617}
618
619impl FileInclusion {
620    fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
621        let file_path = file_context.buffer.read(cx).project_path(cx)?;
622        if path == &file_path {
623            Some(FileInclusion::Direct)
624        } else {
625            None
626        }
627    }
628
629    fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
630        let image_path = image_context.project_path.as_ref()?;
631        if path == image_path {
632            Some(FileInclusion::Direct)
633        } else {
634            None
635        }
636    }
637
638    fn check_directory(
639        directory_context: &DirectoryContextHandle,
640        path: &ProjectPath,
641        project: &Project,
642        cx: &App,
643    ) -> Option<Self> {
644        let worktree = project
645            .worktree_for_entry(directory_context.entry_id, cx)?
646            .read(cx);
647        let entry = worktree.entry_for_id(directory_context.entry_id)?;
648        let directory_path = ProjectPath {
649            worktree_id: worktree.id(),
650            path: entry.path.clone(),
651        };
652        if path.starts_with(&directory_path) {
653            if path == &directory_path {
654                Some(FileInclusion::Direct)
655            } else {
656                Some(FileInclusion::InDirectory {
657                    full_path: worktree.full_path(&entry.path),
658                })
659            }
660        } else {
661            None
662        }
663    }
664}