1use crate::{
2 context::{
3 AgentContextHandle, AgentContextKey, ContextId, ContextKind, DirectoryContextHandle,
4 FetchedUrlContext, FileContextHandle, ImageContext, RulesContextHandle,
5 SelectionContextHandle, SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
6 },
7 thread::{MessageId, ThreadId, ZedAgentThread},
8 thread_store::ThreadStore,
9};
10use anyhow::{Context as _, Result, anyhow};
11use assistant_context::AssistantContext;
12use collections::{HashSet, IndexSet};
13use futures::{self, FutureExt};
14use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
15use language::{Buffer, File as _};
16use language_model::LanguageModelImage;
17use project::{Project, ProjectItem, ProjectPath, Symbol, image_store::is_image_file};
18use prompt_store::UserPromptId;
19use ref_cast::RefCast as _;
20use std::{
21 ops::Range,
22 path::{Path, PathBuf},
23 sync::Arc,
24};
25use text::{Anchor, OffsetRangeExt};
26
27pub struct ContextStore {
28 project: WeakEntity<Project>,
29 thread_store: Option<WeakEntity<ThreadStore>>,
30 next_context_id: ContextId,
31 context_set: IndexSet<AgentContextKey>,
32 context_thread_ids: HashSet<ThreadId>,
33 context_text_thread_paths: HashSet<Arc<Path>>,
34}
35
36pub enum ContextStoreEvent {
37 ContextRemoved(AgentContextKey),
38}
39
40impl EventEmitter<ContextStoreEvent> for ContextStore {}
41
42impl ContextStore {
43 pub fn new(
44 project: WeakEntity<Project>,
45 thread_store: Option<WeakEntity<ThreadStore>>,
46 ) -> Self {
47 Self {
48 project,
49 thread_store,
50 next_context_id: ContextId::zero(),
51 context_set: IndexSet::default(),
52 context_thread_ids: HashSet::default(),
53 context_text_thread_paths: HashSet::default(),
54 }
55 }
56
57 pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
58 self.context_set.iter().map(|entry| entry.as_ref())
59 }
60
61 pub fn clear(&mut self, cx: &mut Context<Self>) {
62 self.context_set.clear();
63 self.context_thread_ids.clear();
64 cx.notify();
65 }
66
67 pub fn new_context_for_thread(
68 &self,
69 thread: &ZedAgentThread,
70 exclude_messages_from_id: Option<MessageId>,
71 _cx: &App,
72 ) -> Vec<AgentContextHandle> {
73 let existing_context = thread
74 .messages()
75 .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
76 .flat_map(|message| {
77 message
78 .loaded_context
79 .contexts
80 .iter()
81 .map(|context| AgentContextKey(context.handle()))
82 })
83 .collect::<HashSet<_>>();
84 self.context_set
85 .iter()
86 .filter(|context| !existing_context.contains(context))
87 .map(|entry| entry.0.clone())
88 .collect::<Vec<_>>()
89 }
90
91 pub fn add_file_from_path(
92 &mut self,
93 project_path: ProjectPath,
94 remove_if_exists: bool,
95 cx: &mut Context<Self>,
96 ) -> Task<Result<Option<AgentContextHandle>>> {
97 let Some(project) = self.project.upgrade() else {
98 return Task::ready(Err(anyhow!("failed to read project")));
99 };
100
101 if is_image_file(&project, &project_path, cx) {
102 self.add_image_from_path(project_path, remove_if_exists, cx)
103 } else {
104 cx.spawn(async move |this, cx| {
105 let open_buffer_task = project.update(cx, |project, cx| {
106 project.open_buffer(project_path.clone(), cx)
107 })?;
108 let buffer = open_buffer_task.await?;
109 this.update(cx, |this, cx| {
110 this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
111 })
112 })
113 }
114 }
115
116 pub fn add_file_from_buffer(
117 &mut self,
118 project_path: &ProjectPath,
119 buffer: Entity<Buffer>,
120 remove_if_exists: bool,
121 cx: &mut Context<Self>,
122 ) -> Option<AgentContextHandle> {
123 let context_id = self.next_context_id.post_inc();
124 let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
125
126 if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
127 if remove_if_exists {
128 self.remove_context(&context, cx);
129 None
130 } else {
131 Some(key.as_ref().clone())
132 }
133 } else if self.path_included_in_directory(project_path, cx).is_some() {
134 None
135 } else {
136 self.insert_context(context.clone(), cx);
137 Some(context)
138 }
139 }
140
141 pub fn add_directory(
142 &mut self,
143 project_path: &ProjectPath,
144 remove_if_exists: bool,
145 cx: &mut Context<Self>,
146 ) -> Result<Option<AgentContextHandle>> {
147 let project = self.project.upgrade().context("failed to read project")?;
148 let entry_id = project
149 .read(cx)
150 .entry_for_path(project_path, cx)
151 .map(|entry| entry.id)
152 .context("no entry found for directory context")?;
153
154 let context_id = self.next_context_id.post_inc();
155 let context = AgentContextHandle::Directory(DirectoryContextHandle {
156 entry_id,
157 context_id,
158 });
159
160 let context =
161 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
162 if remove_if_exists {
163 self.remove_context(&context, cx);
164 None
165 } else {
166 Some(existing.as_ref().clone())
167 }
168 } else {
169 self.insert_context(context.clone(), cx);
170 Some(context)
171 };
172
173 anyhow::Ok(context)
174 }
175
176 pub fn add_symbol(
177 &mut self,
178 buffer: Entity<Buffer>,
179 symbol: SharedString,
180 range: Range<Anchor>,
181 enclosing_range: Range<Anchor>,
182 remove_if_exists: bool,
183 cx: &mut Context<Self>,
184 ) -> (Option<AgentContextHandle>, bool) {
185 let context_id = self.next_context_id.post_inc();
186 let context = AgentContextHandle::Symbol(SymbolContextHandle {
187 buffer,
188 symbol,
189 range,
190 enclosing_range,
191 context_id,
192 });
193
194 if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
195 let handle = if remove_if_exists {
196 self.remove_context(&context, cx);
197 None
198 } else {
199 Some(key.as_ref().clone())
200 };
201 return (handle, false);
202 }
203
204 let included = self.insert_context(context.clone(), cx);
205 (Some(context), included)
206 }
207
208 pub fn add_thread(
209 &mut self,
210 thread: Entity<ZedAgentThread>,
211 remove_if_exists: bool,
212 cx: &mut Context<Self>,
213 ) -> Option<AgentContextHandle> {
214 let context_id = self.next_context_id.post_inc();
215 let context = AgentContextHandle::Thread(ThreadContextHandle {
216 agent: thread,
217 context_id,
218 });
219
220 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
221 if remove_if_exists {
222 self.remove_context(&context, cx);
223 None
224 } else {
225 Some(existing.as_ref().clone())
226 }
227 } else {
228 self.insert_context(context.clone(), cx);
229 Some(context)
230 }
231 }
232
233 pub fn add_text_thread(
234 &mut self,
235 context: Entity<AssistantContext>,
236 remove_if_exists: bool,
237 cx: &mut Context<Self>,
238 ) -> Option<AgentContextHandle> {
239 let context_id = self.next_context_id.post_inc();
240 let context = AgentContextHandle::TextThread(TextThreadContextHandle {
241 context,
242 context_id,
243 });
244
245 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
246 if remove_if_exists {
247 self.remove_context(&context, cx);
248 None
249 } else {
250 Some(existing.as_ref().clone())
251 }
252 } else {
253 self.insert_context(context.clone(), cx);
254 Some(context)
255 }
256 }
257
258 pub fn add_rules(
259 &mut self,
260 prompt_id: UserPromptId,
261 remove_if_exists: bool,
262 cx: &mut Context<ContextStore>,
263 ) -> Option<AgentContextHandle> {
264 let context_id = self.next_context_id.post_inc();
265 let context = AgentContextHandle::Rules(RulesContextHandle {
266 prompt_id,
267 context_id,
268 });
269
270 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
271 if remove_if_exists {
272 self.remove_context(&context, cx);
273 None
274 } else {
275 Some(existing.as_ref().clone())
276 }
277 } else {
278 self.insert_context(context.clone(), cx);
279 Some(context)
280 }
281 }
282
283 pub fn add_fetched_url(
284 &mut self,
285 url: String,
286 text: impl Into<SharedString>,
287 cx: &mut Context<ContextStore>,
288 ) -> AgentContextHandle {
289 let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
290 url: url.into(),
291 text: text.into(),
292 context_id: self.next_context_id.post_inc(),
293 });
294
295 self.insert_context(context.clone(), cx);
296 context
297 }
298
299 pub fn add_image_from_path(
300 &mut self,
301 project_path: ProjectPath,
302 remove_if_exists: bool,
303 cx: &mut Context<ContextStore>,
304 ) -> Task<Result<Option<AgentContextHandle>>> {
305 let project = self.project.clone();
306 cx.spawn(async move |this, cx| {
307 let open_image_task = project.update(cx, |project, cx| {
308 project.open_image(project_path.clone(), cx)
309 })?;
310 let image_item = open_image_task.await?;
311
312 this.update(cx, |this, cx| {
313 let item = image_item.read(cx);
314 this.insert_image(
315 Some(item.project_path(cx)),
316 Some(item.file.full_path(cx).into()),
317 item.image.clone(),
318 remove_if_exists,
319 cx,
320 )
321 })
322 })
323 }
324
325 pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
326 self.insert_image(None, None, image, false, cx);
327 }
328
329 fn insert_image(
330 &mut self,
331 project_path: Option<ProjectPath>,
332 full_path: Option<Arc<Path>>,
333 image: Arc<Image>,
334 remove_if_exists: bool,
335 cx: &mut Context<ContextStore>,
336 ) -> Option<AgentContextHandle> {
337 let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
338 let context = AgentContextHandle::Image(ImageContext {
339 project_path,
340 full_path,
341 original_image: image,
342 image_task,
343 context_id: self.next_context_id.post_inc(),
344 });
345 if self.has_context(&context) {
346 if remove_if_exists {
347 self.remove_context(&context, cx);
348 return None;
349 }
350 }
351
352 self.insert_context(context.clone(), cx);
353 Some(context)
354 }
355
356 pub fn add_selection(
357 &mut self,
358 buffer: Entity<Buffer>,
359 range: Range<Anchor>,
360 cx: &mut Context<ContextStore>,
361 ) {
362 let context_id = self.next_context_id.post_inc();
363 let context = AgentContextHandle::Selection(SelectionContextHandle {
364 buffer,
365 range,
366 context_id,
367 });
368 self.insert_context(context, cx);
369 }
370
371 pub fn add_suggested_context(
372 &mut self,
373 suggested: &SuggestedContext,
374 cx: &mut Context<ContextStore>,
375 ) {
376 match suggested {
377 SuggestedContext::File {
378 buffer,
379 icon_path: _,
380 name: _,
381 } => {
382 if let Some(buffer) = buffer.upgrade() {
383 let context_id = self.next_context_id.post_inc();
384 self.insert_context(
385 AgentContextHandle::File(FileContextHandle { buffer, context_id }),
386 cx,
387 );
388 };
389 }
390 SuggestedContext::Thread { thread, name: _ } => {
391 if let Some(thread) = thread.upgrade() {
392 let context_id = self.next_context_id.post_inc();
393 self.insert_context(
394 AgentContextHandle::Thread(ThreadContextHandle {
395 agent: thread,
396 context_id,
397 }),
398 cx,
399 );
400 }
401 }
402 SuggestedContext::TextThread { context, name: _ } => {
403 if let Some(context) = context.upgrade() {
404 let context_id = self.next_context_id.post_inc();
405 self.insert_context(
406 AgentContextHandle::TextThread(TextThreadContextHandle {
407 context,
408 context_id,
409 }),
410 cx,
411 );
412 }
413 }
414 }
415 }
416
417 fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
418 match &context {
419 AgentContextHandle::Thread(thread_context) => {
420 if let Some(thread_store) = self.thread_store.clone() {
421 thread_context.agent.update(cx, |thread, cx| {
422 thread.start_generating_detailed_summary_if_needed(thread_store, cx);
423 });
424 self.context_thread_ids
425 .insert(thread_context.agent.read(cx).id().clone());
426 } else {
427 return false;
428 }
429 }
430 AgentContextHandle::TextThread(text_thread_context) => {
431 self.context_text_thread_paths
432 .extend(text_thread_context.context.read(cx).path().cloned());
433 }
434 _ => {}
435 }
436 let inserted = self.context_set.insert(AgentContextKey(context));
437 if inserted {
438 cx.notify();
439 }
440 inserted
441 }
442
443 pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
444 if let Some((_, key)) = self
445 .context_set
446 .shift_remove_full(AgentContextKey::ref_cast(context))
447 {
448 match context {
449 AgentContextHandle::Thread(thread_context) => {
450 self.context_thread_ids
451 .remove(thread_context.agent.read(cx).id());
452 }
453 AgentContextHandle::TextThread(text_thread_context) => {
454 if let Some(path) = text_thread_context.context.read(cx).path() {
455 self.context_text_thread_paths.remove(path);
456 }
457 }
458 _ => {}
459 }
460 cx.emit(ContextStoreEvent::ContextRemoved(key));
461 cx.notify();
462 }
463 }
464
465 pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
466 self.context_set
467 .contains(AgentContextKey::ref_cast(context))
468 }
469
470 /// Returns whether this file path is already included directly in the context, or if it will be
471 /// included in the context via a directory.
472 pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
473 let project = self.project.upgrade()?.read(cx);
474 self.context().find_map(|context| match context {
475 AgentContextHandle::File(file_context) => {
476 FileInclusion::check_file(file_context, path, cx)
477 }
478 AgentContextHandle::Image(image_context) => {
479 FileInclusion::check_image(image_context, path)
480 }
481 AgentContextHandle::Directory(directory_context) => {
482 FileInclusion::check_directory(directory_context, path, project, cx)
483 }
484 _ => None,
485 })
486 }
487
488 pub fn path_included_in_directory(
489 &self,
490 path: &ProjectPath,
491 cx: &App,
492 ) -> Option<FileInclusion> {
493 let project = self.project.upgrade()?.read(cx);
494 self.context().find_map(|context| match context {
495 AgentContextHandle::Directory(directory_context) => {
496 FileInclusion::check_directory(directory_context, path, project, cx)
497 }
498 _ => None,
499 })
500 }
501
502 pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
503 self.context().any(|context| match context {
504 AgentContextHandle::Symbol(context) => {
505 if context.symbol != symbol.name {
506 return false;
507 }
508 let buffer = context.buffer.read(cx);
509 let Some(context_path) = buffer.project_path(cx) else {
510 return false;
511 };
512 if context_path != symbol.path {
513 return false;
514 }
515 let context_range = context.range.to_point_utf16(&buffer.snapshot());
516 context_range.start == symbol.range.start.0
517 && context_range.end == symbol.range.end.0
518 }
519 _ => false,
520 })
521 }
522
523 pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
524 self.context_thread_ids.contains(thread_id)
525 }
526
527 pub fn includes_text_thread(&self, path: &Arc<Path>) -> bool {
528 self.context_text_thread_paths.contains(path)
529 }
530
531 pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
532 self.context_set
533 .contains(&RulesContextHandle::lookup_key(prompt_id))
534 }
535
536 pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
537 self.context_set
538 .contains(&FetchedUrlContext::lookup_key(url.into()))
539 }
540
541 pub fn get_url_context(&self, url: SharedString) -> Option<AgentContextHandle> {
542 self.context_set
543 .get(&FetchedUrlContext::lookup_key(url))
544 .map(|key| key.as_ref().clone())
545 }
546
547 pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
548 self.context()
549 .filter_map(|context| match context {
550 AgentContextHandle::File(file) => {
551 let buffer = file.buffer.read(cx);
552 buffer.project_path(cx)
553 }
554 AgentContextHandle::Directory(_)
555 | AgentContextHandle::Symbol(_)
556 | AgentContextHandle::Selection(_)
557 | AgentContextHandle::FetchedUrl(_)
558 | AgentContextHandle::Thread(_)
559 | AgentContextHandle::TextThread(_)
560 | AgentContextHandle::Rules(_)
561 | AgentContextHandle::Image(_) => None,
562 })
563 .collect()
564 }
565
566 pub fn thread_ids(&self) -> &HashSet<ThreadId> {
567 &self.context_thread_ids
568 }
569}
570
571#[derive(Clone)]
572pub enum SuggestedContext {
573 File {
574 name: SharedString,
575 icon_path: Option<SharedString>,
576 buffer: WeakEntity<Buffer>,
577 },
578 Thread {
579 name: SharedString,
580 thread: WeakEntity<ZedAgentThread>,
581 },
582 TextThread {
583 name: SharedString,
584 context: WeakEntity<AssistantContext>,
585 },
586}
587
588impl SuggestedContext {
589 pub fn name(&self) -> &SharedString {
590 match self {
591 Self::File { name, .. } => name,
592 Self::Thread { name, .. } => name,
593 Self::TextThread { name, .. } => name,
594 }
595 }
596
597 pub fn icon_path(&self) -> Option<SharedString> {
598 match self {
599 Self::File { icon_path, .. } => icon_path.clone(),
600 Self::Thread { .. } => None,
601 Self::TextThread { .. } => None,
602 }
603 }
604
605 pub fn kind(&self) -> ContextKind {
606 match self {
607 Self::File { .. } => ContextKind::File,
608 Self::Thread { .. } => ContextKind::Thread,
609 Self::TextThread { .. } => ContextKind::TextThread,
610 }
611 }
612}
613
/// How a project path is covered by the current context.
pub enum FileInclusion {
    /// The path itself is in the context: as a file, an image, or the directory entry itself.
    Direct,
    /// The path lies inside a directory that is in the context.
    InDirectory {
        /// Full path (from the worktree) of the context directory that contains the path.
        full_path: PathBuf,
    },
}
618
619impl FileInclusion {
620 fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
621 let file_path = file_context.buffer.read(cx).project_path(cx)?;
622 if path == &file_path {
623 Some(FileInclusion::Direct)
624 } else {
625 None
626 }
627 }
628
629 fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
630 let image_path = image_context.project_path.as_ref()?;
631 if path == image_path {
632 Some(FileInclusion::Direct)
633 } else {
634 None
635 }
636 }
637
638 fn check_directory(
639 directory_context: &DirectoryContextHandle,
640 path: &ProjectPath,
641 project: &Project,
642 cx: &App,
643 ) -> Option<Self> {
644 let worktree = project
645 .worktree_for_entry(directory_context.entry_id, cx)?
646 .read(cx);
647 let entry = worktree.entry_for_id(directory_context.entry_id)?;
648 let directory_path = ProjectPath {
649 worktree_id: worktree.id(),
650 path: entry.path.clone(),
651 };
652 if path.starts_with(&directory_path) {
653 if path == &directory_path {
654 Some(FileInclusion::Direct)
655 } else {
656 Some(FileInclusion::InDirectory {
657 full_path: worktree.full_path(&entry.path),
658 })
659 }
660 } else {
661 None
662 }
663 }
664}