1use std::ops::Range;
2use std::path::{Path, PathBuf};
3use std::sync::Arc;
4
5use anyhow::{Context as _, Result, anyhow};
6use assistant_context_editor::AssistantContext;
7use collections::{HashSet, IndexSet};
8use futures::{self, FutureExt};
9use gpui::{App, Context, Entity, EventEmitter, Image, SharedString, Task, WeakEntity};
10use language::Buffer;
11use language_model::LanguageModelImage;
12use project::image_store::is_image_file;
13use project::{Project, ProjectItem, ProjectPath, Symbol};
14use prompt_store::UserPromptId;
15use ref_cast::RefCast as _;
16use text::{Anchor, OffsetRangeExt};
17
18use crate::ThreadStore;
19use crate::context::{
20 AgentContextHandle, AgentContextKey, ContextId, DirectoryContextHandle, FetchedUrlContext,
21 FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
22 SymbolContextHandle, TextThreadContextHandle, ThreadContextHandle,
23};
24use crate::context_strip::SuggestedContext;
25use crate::thread::{MessageId, Thread, ThreadId};
26
27pub struct ContextStore {
28 project: WeakEntity<Project>,
29 thread_store: Option<WeakEntity<ThreadStore>>,
30 next_context_id: ContextId,
31 context_set: IndexSet<AgentContextKey>,
32 context_thread_ids: HashSet<ThreadId>,
33 context_text_thread_paths: HashSet<Arc<Path>>,
34}
35
36pub enum ContextStoreEvent {
37 ContextRemoved(AgentContextKey),
38}
39
40impl EventEmitter<ContextStoreEvent> for ContextStore {}
41
42impl ContextStore {
43 pub fn new(
44 project: WeakEntity<Project>,
45 thread_store: Option<WeakEntity<ThreadStore>>,
46 ) -> Self {
47 Self {
48 project,
49 thread_store,
50 next_context_id: ContextId::zero(),
51 context_set: IndexSet::default(),
52 context_thread_ids: HashSet::default(),
53 context_text_thread_paths: HashSet::default(),
54 }
55 }
56
57 pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
58 self.context_set.iter().map(|entry| entry.as_ref())
59 }
60
61 pub fn clear(&mut self) {
62 self.context_set.clear();
63 self.context_thread_ids.clear();
64 }
65
66 pub fn new_context_for_thread(
67 &self,
68 thread: &Thread,
69 exclude_messages_from_id: Option<MessageId>,
70 ) -> Vec<AgentContextHandle> {
71 let existing_context = thread
72 .messages()
73 .take_while(|message| exclude_messages_from_id.is_none_or(|id| message.id != id))
74 .flat_map(|message| {
75 message
76 .loaded_context
77 .contexts
78 .iter()
79 .map(|context| AgentContextKey(context.handle()))
80 })
81 .collect::<HashSet<_>>();
82 self.context_set
83 .iter()
84 .filter(|context| !existing_context.contains(context))
85 .map(|entry| entry.0.clone())
86 .collect::<Vec<_>>()
87 }
88
89 pub fn add_file_from_path(
90 &mut self,
91 project_path: ProjectPath,
92 remove_if_exists: bool,
93 cx: &mut Context<Self>,
94 ) -> Task<Result<Option<AgentContextHandle>>> {
95 let Some(project) = self.project.upgrade() else {
96 return Task::ready(Err(anyhow!("failed to read project")));
97 };
98
99 if is_image_file(&project, &project_path, cx) {
100 self.add_image_from_path(project_path, remove_if_exists, cx)
101 } else {
102 cx.spawn(async move |this, cx| {
103 let open_buffer_task = project.update(cx, |project, cx| {
104 project.open_buffer(project_path.clone(), cx)
105 })?;
106 let buffer = open_buffer_task.await?;
107 this.update(cx, |this, cx| {
108 this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
109 })
110 })
111 }
112 }
113
114 pub fn add_file_from_buffer(
115 &mut self,
116 project_path: &ProjectPath,
117 buffer: Entity<Buffer>,
118 remove_if_exists: bool,
119 cx: &mut Context<Self>,
120 ) -> Option<AgentContextHandle> {
121 let context_id = self.next_context_id.post_inc();
122 let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
123
124 if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
125 if remove_if_exists {
126 self.remove_context(&context, cx);
127 None
128 } else {
129 Some(key.as_ref().clone())
130 }
131 } else if self.path_included_in_directory(project_path, cx).is_some() {
132 None
133 } else {
134 self.insert_context(context.clone(), cx);
135 Some(context)
136 }
137 }
138
139 pub fn add_directory(
140 &mut self,
141 project_path: &ProjectPath,
142 remove_if_exists: bool,
143 cx: &mut Context<Self>,
144 ) -> Result<Option<AgentContextHandle>> {
145 let project = self.project.upgrade().context("failed to read project")?;
146 let entry_id = project
147 .read(cx)
148 .entry_for_path(project_path, cx)
149 .map(|entry| entry.id)
150 .context("no entry found for directory context")?;
151
152 let context_id = self.next_context_id.post_inc();
153 let context = AgentContextHandle::Directory(DirectoryContextHandle {
154 entry_id,
155 context_id,
156 });
157
158 let context =
159 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
160 if remove_if_exists {
161 self.remove_context(&context, cx);
162 None
163 } else {
164 Some(existing.as_ref().clone())
165 }
166 } else {
167 self.insert_context(context.clone(), cx);
168 Some(context)
169 };
170
171 anyhow::Ok(context)
172 }
173
174 pub fn add_symbol(
175 &mut self,
176 buffer: Entity<Buffer>,
177 symbol: SharedString,
178 range: Range<Anchor>,
179 enclosing_range: Range<Anchor>,
180 remove_if_exists: bool,
181 cx: &mut Context<Self>,
182 ) -> (Option<AgentContextHandle>, bool) {
183 let context_id = self.next_context_id.post_inc();
184 let context = AgentContextHandle::Symbol(SymbolContextHandle {
185 buffer,
186 symbol,
187 range,
188 enclosing_range,
189 context_id,
190 });
191
192 if let Some(key) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
193 let handle = if remove_if_exists {
194 self.remove_context(&context, cx);
195 None
196 } else {
197 Some(key.as_ref().clone())
198 };
199 return (handle, false);
200 }
201
202 let included = self.insert_context(context.clone(), cx);
203 (Some(context), included)
204 }
205
206 pub fn add_thread(
207 &mut self,
208 thread: Entity<Thread>,
209 remove_if_exists: bool,
210 cx: &mut Context<Self>,
211 ) -> Option<AgentContextHandle> {
212 let context_id = self.next_context_id.post_inc();
213 let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
214
215 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
216 if remove_if_exists {
217 self.remove_context(&context, cx);
218 None
219 } else {
220 Some(existing.as_ref().clone())
221 }
222 } else {
223 self.insert_context(context.clone(), cx);
224 Some(context)
225 }
226 }
227
228 pub fn add_text_thread(
229 &mut self,
230 context: Entity<AssistantContext>,
231 remove_if_exists: bool,
232 cx: &mut Context<Self>,
233 ) -> Option<AgentContextHandle> {
234 let context_id = self.next_context_id.post_inc();
235 let context = AgentContextHandle::TextThread(TextThreadContextHandle {
236 context,
237 context_id,
238 });
239
240 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
241 if remove_if_exists {
242 self.remove_context(&context, cx);
243 None
244 } else {
245 Some(existing.as_ref().clone())
246 }
247 } else {
248 self.insert_context(context.clone(), cx);
249 Some(context)
250 }
251 }
252
253 pub fn add_rules(
254 &mut self,
255 prompt_id: UserPromptId,
256 remove_if_exists: bool,
257 cx: &mut Context<ContextStore>,
258 ) -> Option<AgentContextHandle> {
259 let context_id = self.next_context_id.post_inc();
260 let context = AgentContextHandle::Rules(RulesContextHandle {
261 prompt_id,
262 context_id,
263 });
264
265 if let Some(existing) = self.context_set.get(AgentContextKey::ref_cast(&context)) {
266 if remove_if_exists {
267 self.remove_context(&context, cx);
268 None
269 } else {
270 Some(existing.as_ref().clone())
271 }
272 } else {
273 self.insert_context(context.clone(), cx);
274 Some(context)
275 }
276 }
277
278 pub fn add_fetched_url(
279 &mut self,
280 url: String,
281 text: impl Into<SharedString>,
282 cx: &mut Context<ContextStore>,
283 ) -> AgentContextHandle {
284 let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
285 url: url.into(),
286 text: text.into(),
287 context_id: self.next_context_id.post_inc(),
288 });
289
290 self.insert_context(context.clone(), cx);
291 context
292 }
293
294 pub fn add_image_from_path(
295 &mut self,
296 project_path: ProjectPath,
297 remove_if_exists: bool,
298 cx: &mut Context<ContextStore>,
299 ) -> Task<Result<Option<AgentContextHandle>>> {
300 let project = self.project.clone();
301 cx.spawn(async move |this, cx| {
302 let open_image_task = project.update(cx, |project, cx| {
303 project.open_image(project_path.clone(), cx)
304 })?;
305 let image_item = open_image_task.await?;
306 let image = image_item.read_with(cx, |image_item, _| image_item.image.clone())?;
307 this.update(cx, |this, cx| {
308 this.insert_image(
309 Some(image_item.read(cx).project_path(cx)),
310 image,
311 remove_if_exists,
312 cx,
313 )
314 })
315 })
316 }
317
318 pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
319 self.insert_image(None, image, false, cx);
320 }
321
322 fn insert_image(
323 &mut self,
324 project_path: Option<ProjectPath>,
325 image: Arc<Image>,
326 remove_if_exists: bool,
327 cx: &mut Context<ContextStore>,
328 ) -> Option<AgentContextHandle> {
329 let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
330 let context = AgentContextHandle::Image(ImageContext {
331 project_path,
332 original_image: image,
333 image_task,
334 context_id: self.next_context_id.post_inc(),
335 });
336 if self.has_context(&context) {
337 if remove_if_exists {
338 self.remove_context(&context, cx);
339 return None;
340 }
341 }
342
343 self.insert_context(context.clone(), cx);
344 Some(context)
345 }
346
347 pub fn add_selection(
348 &mut self,
349 buffer: Entity<Buffer>,
350 range: Range<Anchor>,
351 cx: &mut Context<ContextStore>,
352 ) {
353 let context_id = self.next_context_id.post_inc();
354 let context = AgentContextHandle::Selection(SelectionContextHandle {
355 buffer,
356 range,
357 context_id,
358 });
359 self.insert_context(context, cx);
360 }
361
362 pub fn add_suggested_context(
363 &mut self,
364 suggested: &SuggestedContext,
365 cx: &mut Context<ContextStore>,
366 ) {
367 match suggested {
368 SuggestedContext::File {
369 buffer,
370 icon_path: _,
371 name: _,
372 } => {
373 if let Some(buffer) = buffer.upgrade() {
374 let context_id = self.next_context_id.post_inc();
375 self.insert_context(
376 AgentContextHandle::File(FileContextHandle { buffer, context_id }),
377 cx,
378 );
379 };
380 }
381 SuggestedContext::Thread { thread, name: _ } => {
382 if let Some(thread) = thread.upgrade() {
383 let context_id = self.next_context_id.post_inc();
384 self.insert_context(
385 AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
386 cx,
387 );
388 }
389 }
390 SuggestedContext::TextThread { context, name: _ } => {
391 if let Some(context) = context.upgrade() {
392 let context_id = self.next_context_id.post_inc();
393 self.insert_context(
394 AgentContextHandle::TextThread(TextThreadContextHandle {
395 context,
396 context_id,
397 }),
398 cx,
399 );
400 }
401 }
402 }
403 }
404
405 fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
406 match &context {
407 AgentContextHandle::Thread(thread_context) => {
408 if let Some(thread_store) = self.thread_store.clone() {
409 thread_context.thread.update(cx, |thread, cx| {
410 thread.start_generating_detailed_summary_if_needed(thread_store, cx);
411 });
412 self.context_thread_ids
413 .insert(thread_context.thread.read(cx).id().clone());
414 } else {
415 return false;
416 }
417 }
418 AgentContextHandle::TextThread(text_thread_context) => {
419 self.context_text_thread_paths
420 .extend(text_thread_context.context.read(cx).path().cloned());
421 }
422 _ => {}
423 }
424 let inserted = self.context_set.insert(AgentContextKey(context));
425 if inserted {
426 cx.notify();
427 }
428 inserted
429 }
430
431 pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
432 if let Some((_, key)) = self
433 .context_set
434 .shift_remove_full(AgentContextKey::ref_cast(context))
435 {
436 match context {
437 AgentContextHandle::Thread(thread_context) => {
438 self.context_thread_ids
439 .remove(thread_context.thread.read(cx).id());
440 }
441 AgentContextHandle::TextThread(text_thread_context) => {
442 if let Some(path) = text_thread_context.context.read(cx).path() {
443 self.context_text_thread_paths.remove(path);
444 }
445 }
446 _ => {}
447 }
448 cx.emit(ContextStoreEvent::ContextRemoved(key));
449 cx.notify();
450 }
451 }
452
453 pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
454 self.context_set
455 .contains(AgentContextKey::ref_cast(context))
456 }
457
458 /// Returns whether this file path is already included directly in the context, or if it will be
459 /// included in the context via a directory.
460 pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
461 let project = self.project.upgrade()?.read(cx);
462 self.context().find_map(|context| match context {
463 AgentContextHandle::File(file_context) => {
464 FileInclusion::check_file(file_context, path, cx)
465 }
466 AgentContextHandle::Image(image_context) => {
467 FileInclusion::check_image(image_context, path)
468 }
469 AgentContextHandle::Directory(directory_context) => {
470 FileInclusion::check_directory(directory_context, path, project, cx)
471 }
472 _ => None,
473 })
474 }
475
476 pub fn path_included_in_directory(
477 &self,
478 path: &ProjectPath,
479 cx: &App,
480 ) -> Option<FileInclusion> {
481 let project = self.project.upgrade()?.read(cx);
482 self.context().find_map(|context| match context {
483 AgentContextHandle::Directory(directory_context) => {
484 FileInclusion::check_directory(directory_context, path, project, cx)
485 }
486 _ => None,
487 })
488 }
489
490 pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
491 self.context().any(|context| match context {
492 AgentContextHandle::Symbol(context) => {
493 if context.symbol != symbol.name {
494 return false;
495 }
496 let buffer = context.buffer.read(cx);
497 let Some(context_path) = buffer.project_path(cx) else {
498 return false;
499 };
500 if context_path != symbol.path {
501 return false;
502 }
503 let context_range = context.range.to_point_utf16(&buffer.snapshot());
504 context_range.start == symbol.range.start.0
505 && context_range.end == symbol.range.end.0
506 }
507 _ => false,
508 })
509 }
510
511 pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
512 self.context_thread_ids.contains(thread_id)
513 }
514
515 pub fn includes_text_thread(&self, path: &Arc<Path>) -> bool {
516 self.context_text_thread_paths.contains(path)
517 }
518
519 pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
520 self.context_set
521 .contains(&RulesContextHandle::lookup_key(prompt_id))
522 }
523
524 pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
525 self.context_set
526 .contains(&FetchedUrlContext::lookup_key(url.into()))
527 }
528
529 pub fn get_url_context(&self, url: SharedString) -> Option<AgentContextHandle> {
530 self.context_set
531 .get(&FetchedUrlContext::lookup_key(url))
532 .map(|key| key.as_ref().clone())
533 }
534
535 pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
536 self.context()
537 .filter_map(|context| match context {
538 AgentContextHandle::File(file) => {
539 let buffer = file.buffer.read(cx);
540 buffer.project_path(cx)
541 }
542 AgentContextHandle::Directory(_)
543 | AgentContextHandle::Symbol(_)
544 | AgentContextHandle::Selection(_)
545 | AgentContextHandle::FetchedUrl(_)
546 | AgentContextHandle::Thread(_)
547 | AgentContextHandle::TextThread(_)
548 | AgentContextHandle::Rules(_)
549 | AgentContextHandle::Image(_) => None,
550 })
551 .collect()
552 }
553
554 pub fn thread_ids(&self) -> &HashSet<ThreadId> {
555 &self.context_thread_ids
556 }
557}
558
/// Describes how a path is covered by the context.
///
/// The common traits are derived so that callers can log, copy and compare
/// results.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileInclusion {
    /// The path itself is an entry of the context.
    Direct,
    /// The path lies inside a directory entry; `full_path` is the full path
    /// of that directory as reported by its worktree.
    InDirectory { full_path: PathBuf },
}
563
564impl FileInclusion {
565 fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
566 let file_path = file_context.buffer.read(cx).project_path(cx)?;
567 if path == &file_path {
568 Some(FileInclusion::Direct)
569 } else {
570 None
571 }
572 }
573
574 fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
575 let image_path = image_context.project_path.as_ref()?;
576 if path == image_path {
577 Some(FileInclusion::Direct)
578 } else {
579 None
580 }
581 }
582
583 fn check_directory(
584 directory_context: &DirectoryContextHandle,
585 path: &ProjectPath,
586 project: &Project,
587 cx: &App,
588 ) -> Option<Self> {
589 let worktree = project
590 .worktree_for_entry(directory_context.entry_id, cx)?
591 .read(cx);
592 let entry = worktree.entry_for_id(directory_context.entry_id)?;
593 let directory_path = ProjectPath {
594 worktree_id: worktree.id(),
595 path: entry.path.clone(),
596 };
597 if path.starts_with(&directory_path) {
598 if path == &directory_path {
599 Some(FileInclusion::Direct)
600 } else {
601 Some(FileInclusion::InDirectory {
602 full_path: worktree.full_path(&entry.path),
603 })
604 }
605 } else {
606 None
607 }
608 }
609}