1use std::{path::PathBuf, sync::Arc};
2
3use anyhow::Context as _;
4use collections::HashMap;
5use fs::Fs;
6use futures::StreamExt as _;
7use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
8use language::{
9 proto::{deserialize_anchor, serialize_anchor},
10 ContextProvider as _, LanguageToolchainStore, Location,
11};
12use rpc::{proto, AnyProtoClient, TypedEnvelope};
13use settings::{watch_config_file, SettingsLocation, TaskKind};
14use task::{TaskContext, TaskVariables, VariableName};
15use text::{BufferId, OffsetRangeExt};
16use util::ResultExt;
17
18use crate::{
19 buffer_store::BufferStore, worktree_store::WorktreeStore, BasicContextProvider, Inventory,
20 ProjectEnvironment,
21};
22
/// Owns the task inventory and computes task contexts for buffer locations,
/// either in-process or by asking an upstream host.
///
/// A `Noop` store is inert: it has no inventory, resolves task contexts to `None`,
/// ignores user task file updates, and rejects incoming `TaskContextForLocation`
/// requests with an error.
#[allow(clippy::large_enum_variant)] // platform-dependent warning
pub enum TaskStore {
    /// A working store backed by [`StoreState`].
    Functional(StoreState),
    /// A store that does nothing.
    Noop,
}
28
/// State held by a functional (local or remote) [`TaskStore`].
pub struct StoreState {
    /// Whether task contexts are computed here or requested from an upstream host.
    mode: StoreMode,
    /// Task inventory; user task files are loaded into it via `update_user_tasks`.
    task_inventory: Entity<Inventory>,
    /// Resolves buffers named in incoming task-context requests. Held weakly, so such
    /// requests fail once the buffer store has been dropped.
    buffer_store: WeakEntity<BufferStore>,
    /// Used to find the worktree of a buffer and to build the basic task context.
    worktree_store: Entity<WorktreeStore>,
    /// Passed through to the context providers when building task variables.
    toolchain_store: Arc<dyn LanguageToolchainStore>,
    /// Watchers for the global script tasks file and the global debug tasks file.
    /// Stored only to keep them alive (presumably dropping a gpui `Task` cancels it —
    /// confirm).
    _global_task_config_watchers: (Task<()>, Task<()>),
}
37
/// How a [`TaskStore`] obtains task contexts.
enum StoreMode {
    /// Task contexts are computed in this process.
    Local {
        /// Set by `TaskStore::shared` and cleared by `TaskStore::unshared`: the
        /// downstream client and the project's remote id.
        downstream_client: Option<(AnyProtoClient, u64)>,
        /// Source of the project environment variables given to tasks.
        environment: Entity<ProjectEnvironment>,
    },
    /// Task contexts are requested from an upstream host.
    Remote {
        /// Client used to send `TaskContextForLocation` requests.
        upstream_client: AnyProtoClient,
        /// Project id included in those requests.
        project_id: u64,
    },
}
48
/// A task store emits project events, e.g. `Event::Toast` when a global tasks file
/// fails to load.
impl EventEmitter<crate::Event> for TaskStore {}
50
51impl TaskStore {
52 pub fn init(client: Option<&AnyProtoClient>) {
53 if let Some(client) = client {
54 client.add_entity_request_handler(Self::handle_task_context_for_location);
55 }
56 }
57
58 async fn handle_task_context_for_location(
59 store: Entity<Self>,
60 envelope: TypedEnvelope<proto::TaskContextForLocation>,
61 mut cx: AsyncApp,
62 ) -> anyhow::Result<proto::TaskContext> {
63 let location = envelope
64 .payload
65 .location
66 .context("no location given for task context handling")?;
67 let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
68 Ok(match store {
69 TaskStore::Functional(state) => (
70 state.buffer_store.clone(),
71 match &state.mode {
72 StoreMode::Local { .. } => false,
73 StoreMode::Remote { .. } => true,
74 },
75 ),
76 TaskStore::Noop => {
77 anyhow::bail!("empty task store cannot handle task context requests")
78 }
79 })
80 })??;
81 let buffer_store = buffer_store
82 .upgrade()
83 .context("no buffer store when handling task context request")?;
84
85 let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
86 format!(
87 "cannot handle task context request for invalid buffer id: {}",
88 location.buffer_id
89 )
90 })?;
91
92 let start = location
93 .start
94 .and_then(deserialize_anchor)
95 .context("missing task context location start")?;
96 let end = location
97 .end
98 .and_then(deserialize_anchor)
99 .context("missing task context location end")?;
100 let buffer = buffer_store
101 .update(&mut cx, |buffer_store, cx| {
102 if is_remote {
103 buffer_store.wait_for_remote_buffer(buffer_id, cx)
104 } else {
105 Task::ready(
106 buffer_store
107 .get(buffer_id)
108 .with_context(|| format!("no local buffer with id {buffer_id}")),
109 )
110 }
111 })?
112 .await?;
113
114 let location = Location {
115 buffer,
116 range: start..end,
117 };
118 let context_task = store.update(&mut cx, |store, cx| {
119 let captured_variables = {
120 let mut variables = TaskVariables::from_iter(
121 envelope
122 .payload
123 .task_variables
124 .into_iter()
125 .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
126 );
127
128 let snapshot = location.buffer.read(cx).snapshot();
129 let range = location.range.to_offset(&snapshot);
130
131 for range in snapshot.runnable_ranges(range) {
132 for (capture_name, value) in range.extra_captures {
133 variables.insert(VariableName::Custom(capture_name.into()), value);
134 }
135 }
136 variables
137 };
138 store.task_context_for_location(captured_variables, location, cx)
139 })?;
140 let task_context = context_task.await.unwrap_or_default();
141 Ok(proto::TaskContext {
142 project_env: task_context.project_env.into_iter().collect(),
143 cwd: task_context
144 .cwd
145 .map(|cwd| cwd.to_string_lossy().to_string()),
146 task_variables: task_context
147 .task_variables
148 .into_iter()
149 .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
150 .collect(),
151 })
152 }
153
154 pub fn local(
155 fs: Arc<dyn Fs>,
156 buffer_store: WeakEntity<BufferStore>,
157 worktree_store: Entity<WorktreeStore>,
158 toolchain_store: Arc<dyn LanguageToolchainStore>,
159 environment: Entity<ProjectEnvironment>,
160 cx: &mut Context<'_, Self>,
161 ) -> Self {
162 Self::Functional(StoreState {
163 mode: StoreMode::Local {
164 downstream_client: None,
165 environment,
166 },
167 task_inventory: Inventory::new(cx),
168 buffer_store,
169 toolchain_store,
170 worktree_store,
171 _global_task_config_watchers: (
172 Self::subscribe_to_global_task_file_changes(
173 fs.clone(),
174 TaskKind::Script,
175 paths::tasks_file().clone(),
176 cx,
177 ),
178 Self::subscribe_to_global_task_file_changes(
179 fs.clone(),
180 TaskKind::Debug,
181 paths::debug_tasks_file().clone(),
182 cx,
183 ),
184 ),
185 })
186 }
187
188 pub fn remote(
189 fs: Arc<dyn Fs>,
190 buffer_store: WeakEntity<BufferStore>,
191 worktree_store: Entity<WorktreeStore>,
192 toolchain_store: Arc<dyn LanguageToolchainStore>,
193 upstream_client: AnyProtoClient,
194 project_id: u64,
195 cx: &mut Context<'_, Self>,
196 ) -> Self {
197 Self::Functional(StoreState {
198 mode: StoreMode::Remote {
199 upstream_client,
200 project_id,
201 },
202 task_inventory: Inventory::new(cx),
203 buffer_store,
204 toolchain_store,
205 worktree_store,
206 _global_task_config_watchers: (
207 Self::subscribe_to_global_task_file_changes(
208 fs.clone(),
209 TaskKind::Script,
210 paths::tasks_file().clone(),
211 cx,
212 ),
213 Self::subscribe_to_global_task_file_changes(
214 fs.clone(),
215 TaskKind::Debug,
216 paths::debug_tasks_file().clone(),
217 cx,
218 ),
219 ),
220 })
221 }
222
223 pub fn task_context_for_location(
224 &self,
225 captured_variables: TaskVariables,
226 location: Location,
227 cx: &mut App,
228 ) -> Task<Option<TaskContext>> {
229 match self {
230 TaskStore::Functional(state) => match &state.mode {
231 StoreMode::Local { environment, .. } => local_task_context_for_location(
232 state.worktree_store.clone(),
233 state.toolchain_store.clone(),
234 environment.clone(),
235 captured_variables,
236 location,
237 cx,
238 ),
239 StoreMode::Remote {
240 upstream_client,
241 project_id,
242 } => remote_task_context_for_location(
243 *project_id,
244 upstream_client.clone(),
245 state.worktree_store.clone(),
246 captured_variables,
247 location,
248 state.toolchain_store.clone(),
249 cx,
250 ),
251 },
252 TaskStore::Noop => Task::ready(None),
253 }
254 }
255
256 pub fn task_inventory(&self) -> Option<&Entity<Inventory>> {
257 match self {
258 TaskStore::Functional(state) => Some(&state.task_inventory),
259 TaskStore::Noop => None,
260 }
261 }
262
263 pub fn shared(&mut self, remote_id: u64, new_downstream_client: AnyProtoClient, _cx: &mut App) {
264 if let Self::Functional(StoreState {
265 mode: StoreMode::Local {
266 downstream_client, ..
267 },
268 ..
269 }) = self
270 {
271 *downstream_client = Some((new_downstream_client, remote_id));
272 }
273 }
274
275 pub fn unshared(&mut self, _: &mut Context<Self>) {
276 if let Self::Functional(StoreState {
277 mode: StoreMode::Local {
278 downstream_client, ..
279 },
280 ..
281 }) = self
282 {
283 *downstream_client = None;
284 }
285 }
286
287 pub(super) fn update_user_tasks(
288 &self,
289 location: Option<SettingsLocation<'_>>,
290 raw_tasks_json: Option<&str>,
291 task_type: TaskKind,
292 cx: &mut Context<'_, Self>,
293 ) -> anyhow::Result<()> {
294 let task_inventory = match self {
295 TaskStore::Functional(state) => &state.task_inventory,
296 TaskStore::Noop => return Ok(()),
297 };
298 let raw_tasks_json = raw_tasks_json
299 .map(|json| json.trim())
300 .filter(|json| !json.is_empty());
301
302 task_inventory.update(cx, |inventory, _| {
303 inventory.update_file_based_tasks(location, raw_tasks_json, task_type)
304 })
305 }
306
307 fn subscribe_to_global_task_file_changes(
308 fs: Arc<dyn Fs>,
309 task_kind: TaskKind,
310 file_path: PathBuf,
311 cx: &mut Context<'_, Self>,
312 ) -> Task<()> {
313 let mut user_tasks_file_rx = watch_config_file(&cx.background_executor(), fs, file_path);
314 let user_tasks_content = cx.background_executor().block(user_tasks_file_rx.next());
315 cx.spawn(async move |task_store, cx| {
316 if let Some(user_tasks_content) = user_tasks_content {
317 let Ok(_) = task_store.update(cx, |task_store, cx| {
318 task_store
319 .update_user_tasks(None, Some(&user_tasks_content), task_kind, cx)
320 .log_err();
321 }) else {
322 return;
323 };
324 }
325 while let Some(user_tasks_content) = user_tasks_file_rx.next().await {
326 let Ok(()) = task_store.update(cx, |task_store, cx| {
327 let result = task_store.update_user_tasks(
328 None,
329 Some(&user_tasks_content),
330 task_kind,
331 cx,
332 );
333 if let Err(err) = &result {
334 log::error!("Failed to load user {:?} tasks: {err}", task_kind);
335 cx.emit(crate::Event::Toast {
336 notification_id: format!("load-user-{:?}-tasks", task_kind).into(),
337 message: format!("Invalid global {:?} tasks file\n{err}", task_kind),
338 });
339 }
340 cx.refresh_windows();
341 }) else {
342 break; // App dropped
343 };
344 }
345 })
346 }
347}
348
349fn local_task_context_for_location(
350 worktree_store: Entity<WorktreeStore>,
351 toolchain_store: Arc<dyn LanguageToolchainStore>,
352 environment: Entity<ProjectEnvironment>,
353 captured_variables: TaskVariables,
354 location: Location,
355 cx: &App,
356) -> Task<Option<TaskContext>> {
357 let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
358 let worktree_abs_path = worktree_id
359 .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
360 .and_then(|worktree| worktree.read(cx).root_dir());
361
362 cx.spawn(async move |cx| {
363 let worktree_abs_path = worktree_abs_path.clone();
364 let project_env = environment
365 .update(cx, |environment, cx| {
366 environment.get_environment(worktree_id, worktree_abs_path.clone(), cx)
367 })
368 .ok()?
369 .await;
370
371 let mut task_variables = cx
372 .update(|cx| {
373 combine_task_variables(
374 captured_variables,
375 location,
376 project_env.clone(),
377 BasicContextProvider::new(worktree_store),
378 toolchain_store,
379 cx,
380 )
381 })
382 .ok()?
383 .await
384 .log_err()?;
385 // Remove all custom entries starting with _, as they're not intended for use by the end user.
386 task_variables.sweep();
387
388 Some(TaskContext {
389 project_env: project_env.unwrap_or_default(),
390 cwd: worktree_abs_path.map(|p| p.to_path_buf()),
391 task_variables,
392 })
393 })
394}
395
396fn remote_task_context_for_location(
397 project_id: u64,
398 upstream_client: AnyProtoClient,
399 worktree_store: Entity<WorktreeStore>,
400 captured_variables: TaskVariables,
401 location: Location,
402 toolchain_store: Arc<dyn LanguageToolchainStore>,
403 cx: &mut App,
404) -> Task<Option<TaskContext>> {
405 cx.spawn(async move |cx| {
406 // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
407 let mut remote_context = cx
408 .update(|cx| {
409 BasicContextProvider::new(worktree_store).build_context(
410 &TaskVariables::default(),
411 &location,
412 None,
413 toolchain_store,
414 cx,
415 )
416 })
417 .ok()?
418 .await
419 .log_err()
420 .unwrap_or_default();
421 remote_context.extend(captured_variables);
422
423 let buffer_id = cx
424 .update(|cx| location.buffer.read(cx).remote_id().to_proto())
425 .ok()?;
426 let context_task = upstream_client.request(proto::TaskContextForLocation {
427 project_id,
428 location: Some(proto::Location {
429 buffer_id,
430 start: Some(serialize_anchor(&location.range.start)),
431 end: Some(serialize_anchor(&location.range.end)),
432 }),
433 task_variables: remote_context
434 .into_iter()
435 .map(|(k, v)| (k.to_string(), v))
436 .collect(),
437 });
438 let task_context = context_task.await.log_err()?;
439 Some(TaskContext {
440 cwd: task_context.cwd.map(PathBuf::from),
441 task_variables: task_context
442 .task_variables
443 .into_iter()
444 .filter_map(
445 |(variable_name, variable_value)| match variable_name.parse() {
446 Ok(variable_name) => Some((variable_name, variable_value)),
447 Err(()) => {
448 log::error!("Unknown variable name: {variable_name}");
449 None
450 }
451 },
452 )
453 .collect(),
454 project_env: task_context.project_env.into_iter().collect(),
455 })
456 })
457}
458
459fn combine_task_variables(
460 mut captured_variables: TaskVariables,
461 location: Location,
462 project_env: Option<HashMap<String, String>>,
463 baseline: BasicContextProvider,
464 toolchain_store: Arc<dyn LanguageToolchainStore>,
465 cx: &mut App,
466) -> Task<anyhow::Result<TaskVariables>> {
467 let language_context_provider = location
468 .buffer
469 .read(cx)
470 .language()
471 .and_then(|language| language.context_provider());
472 cx.spawn(async move |cx| {
473 let baseline = cx
474 .update(|cx| {
475 baseline.build_context(
476 &captured_variables,
477 &location,
478 project_env.clone(),
479 toolchain_store.clone(),
480 cx,
481 )
482 })?
483 .await
484 .context("building basic default context")?;
485 captured_variables.extend(baseline);
486 if let Some(provider) = language_context_provider {
487 captured_variables.extend(
488 cx.update(|cx| {
489 provider.build_context(
490 &captured_variables,
491 &location,
492 project_env,
493 toolchain_store,
494 cx,
495 )
496 })?
497 .await
498 .context("building provider context")?,
499 );
500 }
501 Ok(captured_variables)
502 })
503}