1use std::ops::Range;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{Result, anyhow};
6use collections::{HashSet, IndexSet};
7use futures::future::join_all;
8use futures::{self, FutureExt};
9use gpui::{App, Context, Entity, Image, SharedString, Task, WeakEntity};
10use language::Buffer;
11use language_model::LanguageModelImage;
12use project::image_store::is_image_file;
13use project::{Project, ProjectItem, ProjectPath, Symbol};
14use prompt_store::UserPromptId;
15use ref_cast::RefCast as _;
16use text::{Anchor, OffsetRangeExt};
17use util::ResultExt as _;
18
19use crate::ThreadStore;
20use crate::context::{
21 AgentContextHandle, AgentContextKey, ContextId, DirectoryContextHandle, FetchedUrlContext,
22 FileContextHandle, ImageContext, RulesContextHandle, SelectionContextHandle,
23 SymbolContextHandle, ThreadContextHandle,
24};
25use crate::context_strip::SuggestedContext;
26use crate::thread::{Thread, ThreadId};
27
/// Holds the set of context entries (files, directories, symbols, threads, rules, URLs,
/// images, selections) that are attached to the agent, plus bookkeeping for thread contexts.
pub struct ContextStore {
    /// Project used to resolve paths, open buffers and images.
    project: WeakEntity<Project>,
    /// Optional store used to persist a thread after its detailed summary has been generated.
    thread_store: Option<WeakEntity<ThreadStore>>,
    /// Pending thread-summary tasks; drained by `wait_for_summaries`.
    thread_summary_tasks: Vec<Task<()>>,
    /// Source of ids handed to newly created context handles.
    next_context_id: ContextId,
    /// Context entries in insertion order; keyed so duplicates are detected by `has_context`.
    context_set: IndexSet<AgentContextKey>,
    /// Ids of threads currently present in `context_set`, for fast `includes_thread` lookups.
    context_thread_ids: HashSet<ThreadId>,
}
36
37impl ContextStore {
38 pub fn new(
39 project: WeakEntity<Project>,
40 thread_store: Option<WeakEntity<ThreadStore>>,
41 ) -> Self {
42 Self {
43 project,
44 thread_store,
45 thread_summary_tasks: Vec::new(),
46 next_context_id: ContextId::zero(),
47 context_set: IndexSet::default(),
48 context_thread_ids: HashSet::default(),
49 }
50 }
51
52 pub fn context(&self) -> impl Iterator<Item = &AgentContextHandle> {
53 self.context_set.iter().map(|entry| entry.as_ref())
54 }
55
56 pub fn clear(&mut self) {
57 self.context_set.clear();
58 self.context_thread_ids.clear();
59 }
60
61 pub fn new_context_for_thread(&self, thread: &Thread) -> Vec<AgentContextHandle> {
62 let existing_context = thread
63 .messages()
64 .flat_map(|message| {
65 message
66 .loaded_context
67 .contexts
68 .iter()
69 .map(|context| AgentContextKey(context.handle()))
70 })
71 .collect::<HashSet<_>>();
72 self.context_set
73 .iter()
74 .filter(|context| !existing_context.contains(context))
75 .map(|entry| entry.0.clone())
76 .collect::<Vec<_>>()
77 }
78
79 pub fn add_file_from_path(
80 &mut self,
81 project_path: ProjectPath,
82 remove_if_exists: bool,
83 cx: &mut Context<Self>,
84 ) -> Task<Result<()>> {
85 let Some(project) = self.project.upgrade() else {
86 return Task::ready(Err(anyhow!("failed to read project")));
87 };
88
89 if is_image_file(&project, &project_path, cx) {
90 self.add_image_from_path(project_path, remove_if_exists, cx)
91 } else {
92 cx.spawn(async move |this, cx| {
93 let open_buffer_task = project.update(cx, |project, cx| {
94 project.open_buffer(project_path.clone(), cx)
95 })?;
96 let buffer = open_buffer_task.await?;
97 this.update(cx, |this, cx| {
98 this.add_file_from_buffer(&project_path, buffer, remove_if_exists, cx)
99 })
100 })
101 }
102 }
103
104 pub fn add_file_from_buffer(
105 &mut self,
106 project_path: &ProjectPath,
107 buffer: Entity<Buffer>,
108 remove_if_exists: bool,
109 cx: &mut Context<Self>,
110 ) {
111 let context_id = self.next_context_id.post_inc();
112 let context = AgentContextHandle::File(FileContextHandle { buffer, context_id });
113
114 let already_included = if self.has_context(&context) {
115 if remove_if_exists {
116 self.remove_context(&context, cx);
117 }
118 true
119 } else {
120 self.path_included_in_directory(project_path, cx).is_some()
121 };
122
123 if !already_included {
124 self.insert_context(context, cx);
125 }
126 }
127
128 pub fn add_directory(
129 &mut self,
130 project_path: &ProjectPath,
131 remove_if_exists: bool,
132 cx: &mut Context<Self>,
133 ) -> Result<()> {
134 let Some(project) = self.project.upgrade() else {
135 return Err(anyhow!("failed to read project"));
136 };
137
138 let Some(entry_id) = project
139 .read(cx)
140 .entry_for_path(project_path, cx)
141 .map(|entry| entry.id)
142 else {
143 return Err(anyhow!("no entry found for directory context"));
144 };
145
146 let context_id = self.next_context_id.post_inc();
147 let context = AgentContextHandle::Directory(DirectoryContextHandle {
148 entry_id,
149 context_id,
150 });
151
152 if self.has_context(&context) {
153 if remove_if_exists {
154 self.remove_context(&context, cx);
155 }
156 } else if self.path_included_in_directory(project_path, cx).is_none() {
157 self.insert_context(context, cx);
158 }
159
160 anyhow::Ok(())
161 }
162
163 pub fn add_symbol(
164 &mut self,
165 buffer: Entity<Buffer>,
166 symbol: SharedString,
167 range: Range<Anchor>,
168 enclosing_range: Range<Anchor>,
169 remove_if_exists: bool,
170 cx: &mut Context<Self>,
171 ) -> bool {
172 let context_id = self.next_context_id.post_inc();
173 let context = AgentContextHandle::Symbol(SymbolContextHandle {
174 buffer,
175 symbol,
176 range,
177 enclosing_range,
178 context_id,
179 });
180
181 if self.has_context(&context) {
182 if remove_if_exists {
183 self.remove_context(&context, cx);
184 }
185 return false;
186 }
187
188 self.insert_context(context, cx)
189 }
190
191 pub fn add_thread(
192 &mut self,
193 thread: Entity<Thread>,
194 remove_if_exists: bool,
195 cx: &mut Context<Self>,
196 ) {
197 let context_id = self.next_context_id.post_inc();
198 let context = AgentContextHandle::Thread(ThreadContextHandle { thread, context_id });
199
200 if self.has_context(&context) {
201 if remove_if_exists {
202 self.remove_context(&context, cx);
203 }
204 } else {
205 self.insert_context(context, cx);
206 }
207 }
208
209 fn start_summarizing_thread_if_needed(
210 &mut self,
211 thread: &Entity<Thread>,
212 cx: &mut Context<Self>,
213 ) {
214 if let Some(summary_task) =
215 thread.update(cx, |thread, cx| thread.generate_detailed_summary(cx))
216 {
217 let thread = thread.clone();
218 let thread_store = self.thread_store.clone();
219
220 self.thread_summary_tasks.push(cx.spawn(async move |_, cx| {
221 summary_task.await;
222
223 if let Some(thread_store) = thread_store {
224 // Save thread so its summary can be reused later
225 let save_task = thread_store
226 .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx));
227
228 if let Some(save_task) = save_task.ok() {
229 save_task.await.log_err();
230 }
231 }
232 }));
233 }
234 }
235
236 pub fn wait_for_summaries(&mut self, cx: &App) -> Task<()> {
237 let tasks = std::mem::take(&mut self.thread_summary_tasks);
238
239 cx.spawn(async move |_cx| {
240 join_all(tasks).await;
241 })
242 }
243
244 pub fn add_rules(
245 &mut self,
246 prompt_id: UserPromptId,
247 remove_if_exists: bool,
248 cx: &mut Context<ContextStore>,
249 ) {
250 let context_id = self.next_context_id.post_inc();
251 let context = AgentContextHandle::Rules(RulesContextHandle {
252 prompt_id,
253 context_id,
254 });
255
256 if self.has_context(&context) {
257 if remove_if_exists {
258 self.remove_context(&context, cx);
259 }
260 } else {
261 self.insert_context(context, cx);
262 }
263 }
264
265 pub fn add_fetched_url(
266 &mut self,
267 url: String,
268 text: impl Into<SharedString>,
269 cx: &mut Context<ContextStore>,
270 ) {
271 let context = AgentContextHandle::FetchedUrl(FetchedUrlContext {
272 url: url.into(),
273 text: text.into(),
274 context_id: self.next_context_id.post_inc(),
275 });
276
277 self.insert_context(context, cx);
278 }
279
280 pub fn add_image_from_path(
281 &mut self,
282 project_path: ProjectPath,
283 remove_if_exists: bool,
284 cx: &mut Context<ContextStore>,
285 ) -> Task<Result<()>> {
286 let project = self.project.clone();
287 cx.spawn(async move |this, cx| {
288 let open_image_task = project.update(cx, |project, cx| {
289 project.open_image(project_path.clone(), cx)
290 })?;
291 let image_item = open_image_task.await?;
292 let image = image_item.read_with(cx, |image_item, _| image_item.image.clone())?;
293 this.update(cx, |this, cx| {
294 this.insert_image(
295 Some(image_item.read(cx).project_path(cx)),
296 image,
297 remove_if_exists,
298 cx,
299 );
300 })
301 })
302 }
303
304 pub fn add_image_instance(&mut self, image: Arc<Image>, cx: &mut Context<ContextStore>) {
305 self.insert_image(None, image, false, cx);
306 }
307
308 fn insert_image(
309 &mut self,
310 project_path: Option<ProjectPath>,
311 image: Arc<Image>,
312 remove_if_exists: bool,
313 cx: &mut Context<ContextStore>,
314 ) {
315 let image_task = LanguageModelImage::from_image(image.clone(), cx).shared();
316 let context = AgentContextHandle::Image(ImageContext {
317 project_path,
318 original_image: image,
319 image_task,
320 context_id: self.next_context_id.post_inc(),
321 });
322 if self.has_context(&context) {
323 if remove_if_exists {
324 self.remove_context(&context, cx);
325 return;
326 }
327 }
328
329 self.insert_context(context, cx);
330 }
331
332 pub fn add_selection(
333 &mut self,
334 buffer: Entity<Buffer>,
335 range: Range<Anchor>,
336 cx: &mut Context<ContextStore>,
337 ) {
338 let context_id = self.next_context_id.post_inc();
339 let context = AgentContextHandle::Selection(SelectionContextHandle {
340 buffer,
341 range,
342 context_id,
343 });
344 self.insert_context(context, cx);
345 }
346
347 pub fn add_suggested_context(
348 &mut self,
349 suggested: &SuggestedContext,
350 cx: &mut Context<ContextStore>,
351 ) {
352 match suggested {
353 SuggestedContext::File {
354 buffer,
355 icon_path: _,
356 name: _,
357 } => {
358 if let Some(buffer) = buffer.upgrade() {
359 let context_id = self.next_context_id.post_inc();
360 self.insert_context(
361 AgentContextHandle::File(FileContextHandle { buffer, context_id }),
362 cx,
363 );
364 };
365 }
366 SuggestedContext::Thread { thread, name: _ } => {
367 if let Some(thread) = thread.upgrade() {
368 let context_id = self.next_context_id.post_inc();
369 self.insert_context(
370 AgentContextHandle::Thread(ThreadContextHandle { thread, context_id }),
371 cx,
372 );
373 }
374 }
375 }
376 }
377
378 fn insert_context(&mut self, context: AgentContextHandle, cx: &mut Context<Self>) -> bool {
379 match &context {
380 AgentContextHandle::Thread(thread_context) => {
381 self.context_thread_ids
382 .insert(thread_context.thread.read(cx).id().clone());
383 self.start_summarizing_thread_if_needed(&thread_context.thread, cx);
384 }
385 _ => {}
386 }
387 let inserted = self.context_set.insert(AgentContextKey(context));
388 if inserted {
389 cx.notify();
390 }
391 inserted
392 }
393
394 pub fn remove_context(&mut self, context: &AgentContextHandle, cx: &mut Context<Self>) {
395 if self
396 .context_set
397 .shift_remove(AgentContextKey::ref_cast(context))
398 {
399 match context {
400 AgentContextHandle::Thread(thread_context) => {
401 self.context_thread_ids
402 .remove(thread_context.thread.read(cx).id());
403 }
404 _ => {}
405 }
406 cx.notify();
407 }
408 }
409
410 pub fn has_context(&mut self, context: &AgentContextHandle) -> bool {
411 self.context_set
412 .contains(AgentContextKey::ref_cast(context))
413 }
414
415 /// Returns whether this file path is already included directly in the context, or if it will be
416 /// included in the context via a directory.
417 pub fn file_path_included(&self, path: &ProjectPath, cx: &App) -> Option<FileInclusion> {
418 let project = self.project.upgrade()?.read(cx);
419 self.context().find_map(|context| match context {
420 AgentContextHandle::File(file_context) => {
421 FileInclusion::check_file(file_context, path, cx)
422 }
423 AgentContextHandle::Image(image_context) => {
424 FileInclusion::check_image(image_context, path)
425 }
426 AgentContextHandle::Directory(directory_context) => {
427 FileInclusion::check_directory(directory_context, path, project, cx)
428 }
429 _ => None,
430 })
431 }
432
433 pub fn path_included_in_directory(
434 &self,
435 path: &ProjectPath,
436 cx: &App,
437 ) -> Option<FileInclusion> {
438 let project = self.project.upgrade()?.read(cx);
439 self.context().find_map(|context| match context {
440 AgentContextHandle::Directory(directory_context) => {
441 FileInclusion::check_directory(directory_context, path, project, cx)
442 }
443 _ => None,
444 })
445 }
446
447 pub fn includes_symbol(&self, symbol: &Symbol, cx: &App) -> bool {
448 self.context().any(|context| match context {
449 AgentContextHandle::Symbol(context) => {
450 if context.symbol != symbol.name {
451 return false;
452 }
453 let buffer = context.buffer.read(cx);
454 let Some(context_path) = buffer.project_path(cx) else {
455 return false;
456 };
457 if context_path != symbol.path {
458 return false;
459 }
460 let context_range = context.range.to_point_utf16(&buffer.snapshot());
461 context_range.start == symbol.range.start.0
462 && context_range.end == symbol.range.end.0
463 }
464 _ => false,
465 })
466 }
467
468 pub fn includes_thread(&self, thread_id: &ThreadId) -> bool {
469 self.context_thread_ids.contains(thread_id)
470 }
471
472 pub fn includes_user_rules(&self, prompt_id: UserPromptId) -> bool {
473 self.context_set
474 .contains(&RulesContextHandle::lookup_key(prompt_id))
475 }
476
477 pub fn includes_url(&self, url: impl Into<SharedString>) -> bool {
478 self.context_set
479 .contains(&FetchedUrlContext::lookup_key(url.into()))
480 }
481
482 pub fn file_paths(&self, cx: &App) -> HashSet<ProjectPath> {
483 self.context()
484 .filter_map(|context| match context {
485 AgentContextHandle::File(file) => {
486 let buffer = file.buffer.read(cx);
487 buffer.project_path(cx)
488 }
489 AgentContextHandle::Directory(_)
490 | AgentContextHandle::Symbol(_)
491 | AgentContextHandle::Selection(_)
492 | AgentContextHandle::FetchedUrl(_)
493 | AgentContextHandle::Thread(_)
494 | AgentContextHandle::Rules(_)
495 | AgentContextHandle::Image(_) => None,
496 })
497 .collect()
498 }
499
500 pub fn thread_ids(&self) -> &HashSet<ThreadId> {
501 &self.context_thread_ids
502 }
503}
504
/// Describes how a project path is covered by the current context.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FileInclusion {
    /// The path itself is a context entry (a file, an image, or the directory itself).
    Direct,
    /// The path lies inside a directory that is a context entry.
    InDirectory {
        /// Full path (as produced by the worktree) of the containing directory entry.
        full_path: PathBuf,
    },
}
509
510impl FileInclusion {
511 fn check_file(file_context: &FileContextHandle, path: &ProjectPath, cx: &App) -> Option<Self> {
512 let file_path = file_context.buffer.read(cx).project_path(cx)?;
513 if path == &file_path {
514 Some(FileInclusion::Direct)
515 } else {
516 None
517 }
518 }
519
520 fn check_image(image_context: &ImageContext, path: &ProjectPath) -> Option<Self> {
521 let image_path = image_context.project_path.as_ref()?;
522 if path == image_path {
523 Some(FileInclusion::Direct)
524 } else {
525 None
526 }
527 }
528
529 fn check_directory(
530 directory_context: &DirectoryContextHandle,
531 path: &ProjectPath,
532 project: &Project,
533 cx: &App,
534 ) -> Option<Self> {
535 let worktree = project
536 .worktree_for_entry(directory_context.entry_id, cx)?
537 .read(cx);
538 let entry = worktree.entry_for_id(directory_context.entry_id)?;
539 let directory_path = ProjectPath {
540 worktree_id: worktree.id(),
541 path: entry.path.clone(),
542 };
543 if path.starts_with(&directory_path) {
544 if path == &directory_path {
545 Some(FileInclusion::Direct)
546 } else {
547 Some(FileInclusion::InDirectory {
548 full_path: worktree.full_path(&entry.path),
549 })
550 }
551 } else {
552 None
553 }
554 }
555}