1use std::{
2 path::{Path, PathBuf},
3 sync::Arc,
4};
5
6use anyhow::Context as _;
7use collections::HashMap;
8use gpui::{App, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
9use language::{
10 ContextProvider as _, LanguageToolchainStore, Location,
11 proto::{deserialize_anchor, serialize_anchor},
12};
13use rpc::{AnyProtoClient, TypedEnvelope, proto};
14use settings::{InvalidSettingsError, SettingsLocation};
15use task::{TaskContext, TaskVariables, VariableName};
16use text::{BufferId, OffsetRangeExt};
17use util::ResultExt;
18
19use crate::{
20 BasicContextProvider, Inventory, ProjectEnvironment, buffer_store::BufferStore,
21 worktree_store::WorktreeStore,
22};
23
/// Project-level store for task-related state.
///
/// A `Functional` store carries real state and can resolve task contexts and
/// accept task definitions. A `Noop` store accepts the same calls but does
/// nothing: context lookups resolve to `None`, there is no inventory, task
/// file updates are ignored, and incoming task context requests are rejected.
#[allow(clippy::large_enum_variant)] // platform-dependent warning
pub enum TaskStore {
    /// A working store backed by a [`StoreState`].
    Functional(StoreState),
    /// An inert placeholder that holds no state.
    Noop,
}
29
/// State held by a [`TaskStore::Functional`] store.
pub struct StoreState {
    /// Whether task contexts are built locally or requested from an upstream peer.
    mode: StoreMode,
    /// Inventory that receives file-based task and debug scenario definitions.
    task_inventory: Entity<Inventory>,
    /// Weakly-held buffer store, used to resolve buffer ids of incoming task context requests.
    buffer_store: WeakEntity<BufferStore>,
    /// Worktree store handed to context providers and used to find a buffer's worktree root.
    worktree_store: Entity<WorktreeStore>,
    /// Toolchain store handed to context providers when building task variables.
    toolchain_store: Arc<dyn LanguageToolchainStore>,
}
37
/// How a functional store obtains task contexts.
enum StoreMode {
    /// Task contexts are built in this process.
    Local {
        /// Client and remote id recorded by `shared` and cleared by `unshared`.
        downstream_client: Option<(AnyProtoClient, u64)>,
        /// Source of the per-buffer project environment.
        environment: Entity<ProjectEnvironment>,
    },
    /// Task contexts are requested from an upstream peer.
    Remote {
        /// Client used to send `TaskContextForLocation` requests.
        upstream_client: AnyProtoClient,
        /// Project id included in every upstream request.
        project_id: u64,
    },
}
48
// Declares `TaskStore` as an emitter of `crate::Event`; the emission sites are not visible in this file.
impl EventEmitter<crate::Event> for TaskStore {}
50
/// Identifies where a task or debug scenario definition file comes from.
///
/// Passed to `TaskStore::update_user_tasks` and
/// `TaskStore::update_user_debug_scenarios`, which forward it to the inventory.
#[derive(Debug)]
pub enum TaskSettingsLocation<'a> {
    /// A file identified only by its path.
    /// NOTE(review): presumably the user's global tasks file — confirm against callers.
    Global(&'a Path),
    /// A file identified by a worktree settings location.
    Worktree(SettingsLocation<'a>),
}
56
57impl TaskStore {
58 pub fn init(client: Option<&AnyProtoClient>) {
59 if let Some(client) = client {
60 client.add_entity_request_handler(Self::handle_task_context_for_location);
61 }
62 }
63
64 async fn handle_task_context_for_location(
65 store: Entity<Self>,
66 envelope: TypedEnvelope<proto::TaskContextForLocation>,
67 mut cx: AsyncApp,
68 ) -> anyhow::Result<proto::TaskContext> {
69 let location = envelope
70 .payload
71 .location
72 .context("no location given for task context handling")?;
73 let (buffer_store, is_remote) = store.update(&mut cx, |store, _| {
74 Ok(match store {
75 TaskStore::Functional(state) => (
76 state.buffer_store.clone(),
77 match &state.mode {
78 StoreMode::Local { .. } => false,
79 StoreMode::Remote { .. } => true,
80 },
81 ),
82 TaskStore::Noop => {
83 anyhow::bail!("empty task store cannot handle task context requests")
84 }
85 })
86 })??;
87 let buffer_store = buffer_store
88 .upgrade()
89 .context("no buffer store when handling task context request")?;
90
91 let buffer_id = BufferId::new(location.buffer_id).with_context(|| {
92 format!(
93 "cannot handle task context request for invalid buffer id: {}",
94 location.buffer_id
95 )
96 })?;
97
98 let start = location
99 .start
100 .and_then(deserialize_anchor)
101 .context("missing task context location start")?;
102 let end = location
103 .end
104 .and_then(deserialize_anchor)
105 .context("missing task context location end")?;
106 let buffer = buffer_store
107 .update(&mut cx, |buffer_store, cx| {
108 if is_remote {
109 buffer_store.wait_for_remote_buffer(buffer_id, cx)
110 } else {
111 Task::ready(
112 buffer_store
113 .get(buffer_id)
114 .with_context(|| format!("no local buffer with id {buffer_id}")),
115 )
116 }
117 })?
118 .await?;
119
120 let location = Location {
121 buffer,
122 range: start..end,
123 };
124 let context_task = store.update(&mut cx, |store, cx| {
125 let captured_variables = {
126 let mut variables = TaskVariables::from_iter(
127 envelope
128 .payload
129 .task_variables
130 .into_iter()
131 .filter_map(|(k, v)| Some((k.parse().log_err()?, v))),
132 );
133
134 let snapshot = location.buffer.read(cx).snapshot();
135 let range = location.range.to_offset(&snapshot);
136
137 for range in snapshot.runnable_ranges(range) {
138 for (capture_name, value) in range.extra_captures {
139 variables.insert(VariableName::Custom(capture_name.into()), value);
140 }
141 }
142 variables
143 };
144 store.task_context_for_location(captured_variables, location, cx)
145 })?;
146 let task_context = context_task.await.unwrap_or_default();
147 Ok(proto::TaskContext {
148 project_env: task_context.project_env.into_iter().collect(),
149 cwd: task_context
150 .cwd
151 .map(|cwd| cwd.to_string_lossy().to_string()),
152 task_variables: task_context
153 .task_variables
154 .into_iter()
155 .map(|(variable_name, variable_value)| (variable_name.to_string(), variable_value))
156 .collect(),
157 })
158 }
159
160 pub fn local(
161 buffer_store: WeakEntity<BufferStore>,
162 worktree_store: Entity<WorktreeStore>,
163 toolchain_store: Arc<dyn LanguageToolchainStore>,
164 environment: Entity<ProjectEnvironment>,
165 cx: &mut Context<Self>,
166 ) -> Self {
167 Self::Functional(StoreState {
168 mode: StoreMode::Local {
169 downstream_client: None,
170 environment,
171 },
172 task_inventory: Inventory::new(cx),
173 buffer_store,
174 toolchain_store,
175 worktree_store,
176 })
177 }
178
179 pub fn remote(
180 buffer_store: WeakEntity<BufferStore>,
181 worktree_store: Entity<WorktreeStore>,
182 toolchain_store: Arc<dyn LanguageToolchainStore>,
183 upstream_client: AnyProtoClient,
184 project_id: u64,
185 cx: &mut Context<Self>,
186 ) -> Self {
187 Self::Functional(StoreState {
188 mode: StoreMode::Remote {
189 upstream_client,
190 project_id,
191 },
192 task_inventory: Inventory::new(cx),
193 buffer_store,
194 toolchain_store,
195 worktree_store,
196 })
197 }
198
199 pub fn task_context_for_location(
200 &self,
201 captured_variables: TaskVariables,
202 location: Location,
203 cx: &mut App,
204 ) -> Task<Option<TaskContext>> {
205 match self {
206 TaskStore::Functional(state) => match &state.mode {
207 StoreMode::Local { environment, .. } => local_task_context_for_location(
208 state.worktree_store.clone(),
209 state.toolchain_store.clone(),
210 environment.clone(),
211 captured_variables,
212 location,
213 cx,
214 ),
215 StoreMode::Remote {
216 upstream_client,
217 project_id,
218 } => remote_task_context_for_location(
219 *project_id,
220 upstream_client.clone(),
221 state.worktree_store.clone(),
222 captured_variables,
223 location,
224 state.toolchain_store.clone(),
225 cx,
226 ),
227 },
228 TaskStore::Noop => Task::ready(None),
229 }
230 }
231
232 pub fn task_inventory(&self) -> Option<&Entity<Inventory>> {
233 match self {
234 TaskStore::Functional(state) => Some(&state.task_inventory),
235 TaskStore::Noop => None,
236 }
237 }
238
239 pub fn shared(&mut self, remote_id: u64, new_downstream_client: AnyProtoClient, _cx: &mut App) {
240 if let Self::Functional(StoreState {
241 mode: StoreMode::Local {
242 downstream_client, ..
243 },
244 ..
245 }) = self
246 {
247 *downstream_client = Some((new_downstream_client, remote_id));
248 }
249 }
250
251 pub fn unshared(&mut self, _: &mut Context<Self>) {
252 if let Self::Functional(StoreState {
253 mode: StoreMode::Local {
254 downstream_client, ..
255 },
256 ..
257 }) = self
258 {
259 *downstream_client = None;
260 }
261 }
262
263 pub(super) fn update_user_tasks(
264 &self,
265 location: TaskSettingsLocation<'_>,
266 raw_tasks_json: Option<&str>,
267 cx: &mut Context<Self>,
268 ) -> Result<(), InvalidSettingsError> {
269 let task_inventory = match self {
270 TaskStore::Functional(state) => &state.task_inventory,
271 TaskStore::Noop => return Ok(()),
272 };
273 let raw_tasks_json = raw_tasks_json
274 .map(|json| json.trim())
275 .filter(|json| !json.is_empty());
276
277 task_inventory.update(cx, |inventory, _| {
278 inventory.update_file_based_tasks(location, raw_tasks_json)
279 })
280 }
281
282 pub(super) fn update_user_debug_scenarios(
283 &self,
284 location: TaskSettingsLocation<'_>,
285 raw_tasks_json: Option<&str>,
286 cx: &mut Context<Self>,
287 ) -> Result<(), InvalidSettingsError> {
288 let task_inventory = match self {
289 TaskStore::Functional(state) => &state.task_inventory,
290 TaskStore::Noop => return Ok(()),
291 };
292 let raw_tasks_json = raw_tasks_json
293 .map(|json| json.trim())
294 .filter(|json| !json.is_empty());
295
296 task_inventory.update(cx, |inventory, _| {
297 inventory.update_file_based_scenarios(location, raw_tasks_json)
298 })
299 }
300}
301
302fn local_task_context_for_location(
303 worktree_store: Entity<WorktreeStore>,
304 toolchain_store: Arc<dyn LanguageToolchainStore>,
305 environment: Entity<ProjectEnvironment>,
306 captured_variables: TaskVariables,
307 location: Location,
308 cx: &App,
309) -> Task<Option<TaskContext>> {
310 let worktree_id = location.buffer.read(cx).file().map(|f| f.worktree_id(cx));
311 let worktree_abs_path = worktree_id
312 .and_then(|worktree_id| worktree_store.read(cx).worktree_for_id(worktree_id, cx))
313 .and_then(|worktree| worktree.read(cx).root_dir());
314
315 cx.spawn(async move |cx| {
316 let project_env = environment
317 .update(cx, |environment, cx| {
318 environment.get_buffer_environment(
319 location.buffer.clone(),
320 worktree_store.clone(),
321 cx,
322 )
323 })
324 .ok()?
325 .await;
326
327 let mut task_variables = cx
328 .update(|cx| {
329 combine_task_variables(
330 captured_variables,
331 location,
332 project_env.clone(),
333 BasicContextProvider::new(worktree_store),
334 toolchain_store,
335 cx,
336 )
337 })
338 .ok()?
339 .await
340 .log_err()?;
341 // Remove all custom entries starting with _, as they're not intended for use by the end user.
342 task_variables.sweep();
343
344 Some(TaskContext {
345 project_env: project_env.unwrap_or_default(),
346 cwd: worktree_abs_path.map(|p| p.to_path_buf()),
347 task_variables,
348 })
349 })
350}
351
352fn remote_task_context_for_location(
353 project_id: u64,
354 upstream_client: AnyProtoClient,
355 worktree_store: Entity<WorktreeStore>,
356 captured_variables: TaskVariables,
357 location: Location,
358 toolchain_store: Arc<dyn LanguageToolchainStore>,
359 cx: &mut App,
360) -> Task<Option<TaskContext>> {
361 cx.spawn(async move |cx| {
362 // We need to gather a client context, as the headless one may lack certain information (e.g. tree-sitter parsing is disabled there, so symbols are not available).
363 let mut remote_context = cx
364 .update(|cx| {
365 BasicContextProvider::new(worktree_store).build_context(
366 &TaskVariables::default(),
367 &location,
368 None,
369 toolchain_store,
370 cx,
371 )
372 })
373 .ok()?
374 .await
375 .log_err()
376 .unwrap_or_default();
377 remote_context.extend(captured_variables);
378
379 let buffer_id = cx
380 .update(|cx| location.buffer.read(cx).remote_id().to_proto())
381 .ok()?;
382 let context_task = upstream_client.request(proto::TaskContextForLocation {
383 project_id,
384 location: Some(proto::Location {
385 buffer_id,
386 start: Some(serialize_anchor(&location.range.start)),
387 end: Some(serialize_anchor(&location.range.end)),
388 }),
389 task_variables: remote_context
390 .into_iter()
391 .map(|(k, v)| (k.to_string(), v))
392 .collect(),
393 });
394 let task_context = context_task.await.log_err()?;
395 Some(TaskContext {
396 cwd: task_context.cwd.map(PathBuf::from),
397 task_variables: task_context
398 .task_variables
399 .into_iter()
400 .filter_map(
401 |(variable_name, variable_value)| match variable_name.parse() {
402 Ok(variable_name) => Some((variable_name, variable_value)),
403 Err(()) => {
404 log::error!("Unknown variable name: {variable_name}");
405 None
406 }
407 },
408 )
409 .collect(),
410 project_env: task_context.project_env.into_iter().collect(),
411 })
412 })
413}
414
415fn combine_task_variables(
416 mut captured_variables: TaskVariables,
417 location: Location,
418 project_env: Option<HashMap<String, String>>,
419 baseline: BasicContextProvider,
420 toolchain_store: Arc<dyn LanguageToolchainStore>,
421 cx: &mut App,
422) -> Task<anyhow::Result<TaskVariables>> {
423 let language_context_provider = location
424 .buffer
425 .read(cx)
426 .language()
427 .and_then(|language| language.context_provider());
428 cx.spawn(async move |cx| {
429 let baseline = cx
430 .update(|cx| {
431 baseline.build_context(
432 &captured_variables,
433 &location,
434 project_env.clone(),
435 toolchain_store.clone(),
436 cx,
437 )
438 })?
439 .await
440 .context("building basic default context")?;
441 captured_variables.extend(baseline);
442 if let Some(provider) = language_context_provider {
443 captured_variables.extend(
444 cx.update(|cx| {
445 provider.build_context(
446 &captured_variables,
447 &location,
448 project_env,
449 toolchain_store,
450 cx,
451 )
452 })?
453 .await
454 .context("building provider context")?,
455 );
456 }
457 Ok(captured_variables)
458 })
459}