Detailed changes
@@ -91,6 +91,7 @@ dependencies = [
"futures 0.3.28",
"gpui",
"isahc",
+ "language",
"lazy_static",
"log",
"matrixmultiply",
@@ -1467,7 +1468,7 @@ dependencies = [
[[package]]
name = "collab"
-version = "0.24.0"
+version = "0.25.0"
dependencies = [
"anyhow",
"async-trait",
@@ -1503,6 +1504,7 @@ dependencies = [
"lsp",
"nanoid",
"node_runtime",
+ "notifications",
"parking_lot 0.11.2",
"pretty_assertions",
"project",
@@ -1558,13 +1560,17 @@ dependencies = [
"fuzzy",
"gpui",
"language",
+ "lazy_static",
"log",
"menu",
+ "notifications",
"picker",
"postage",
+ "pretty_assertions",
"project",
"recent_projects",
"rich_text",
+ "rpc",
"schemars",
"serde",
"serde_derive",
@@ -1573,6 +1579,7 @@ dependencies = [
"theme",
"theme_selector",
"time",
+ "tree-sitter-markdown",
"util",
"vcs_menu",
"workspace",
@@ -4730,6 +4737,26 @@ dependencies = [
"minimal-lexical",
]
+[[package]]
+name = "notifications"
+version = "0.1.0"
+dependencies = [
+ "anyhow",
+ "channel",
+ "client",
+ "clock",
+ "collections",
+ "db",
+ "feature_flags",
+ "gpui",
+ "rpc",
+ "settings",
+ "sum_tree",
+ "text",
+ "time",
+ "util",
+]
+
[[package]]
name = "ntapi"
version = "0.3.7"
@@ -6404,8 +6431,10 @@ dependencies = [
"rsa 0.4.0",
"serde",
"serde_derive",
+ "serde_json",
"smol",
"smol-timeout",
+ "strum",
"tempdir",
"tracing",
"util",
@@ -6626,6 +6655,12 @@ dependencies = [
"untrusted",
]
+[[package]]
+name = "rustversion"
+version = "1.0.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4"
+
[[package]]
name = "rustybuzz"
version = "0.3.0"
@@ -7700,6 +7735,22 @@ name = "strum"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125"
+dependencies = [
+ "strum_macros",
+]
+
+[[package]]
+name = "strum_macros"
+version = "0.25.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ad8d03b598d3d0fff69bf533ee3ef19b8eeb342729596df84bcc7e1f96ec4059"
+dependencies = [
+ "heck 0.4.1",
+ "proc-macro2",
+ "quote",
+ "rustversion",
+ "syn 2.0.37",
+]
[[package]]
name = "subtle"
@@ -9098,6 +9149,7 @@ name = "vcs_menu"
version = "0.1.0"
dependencies = [
"anyhow",
+ "fs",
"fuzzy",
"gpui",
"picker",
@@ -10042,7 +10094,7 @@ dependencies = [
[[package]]
name = "zed"
-version = "0.109.0"
+version = "0.110.0"
dependencies = [
"activity_indicator",
"ai",
@@ -10097,6 +10149,7 @@ dependencies = [
"log",
"lsp",
"node_runtime",
+ "notifications",
"num_cpus",
"outline",
"parking_lot 0.11.2",
@@ -47,6 +47,7 @@ members = [
"crates/media",
"crates/menu",
"crates/node_runtime",
+ "crates/notifications",
"crates/outline",
"crates/picker",
"crates/plugin",
@@ -112,6 +113,7 @@ serde_derive = { version = "1.0", features = ["deserialize_in_place"] }
serde_json = { version = "1.0", features = ["preserve_order", "raw_value"] }
smallvec = { version = "1.6", features = ["union"] }
smol = { version = "1.2" }
+strum = { version = "0.25.0", features = ["derive"] }
sysinfo = "0.29.10"
tempdir = { version = "0.3.7" }
thiserror = { version = "1.0.29" }
@@ -0,0 +1,8 @@
+<svg width="15" height="15" viewBox="0 0 15 15" fill="none" xmlns="http://www.w3.org/2000/svg">
+ <path
+ fill-rule="evenodd"
+ clip-rule="evenodd"
@@ -142,6 +142,14 @@
// Default width of the channels panel.
"default_width": 240
},
+ "notification_panel": {
+ // Whether to show the collaboration panel button in the status bar.
+ "button": true,
+ // Where to dock channels panel. Can be 'left' or 'right'.
+ "dock": "right",
+ // Default width of the channels panel.
+ "default_width": 240
+ },
"assistant": {
// Whether to show the assistant panel button in the status bar.
"button": true,
@@ -11,6 +11,7 @@ doctest = false
[dependencies]
gpui = { path = "../gpui" }
util = { path = "../util" }
+language = { path = "../language" }
async-trait.workspace = true
anyhow.workspace = true
futures.workspace = true
@@ -1,2 +1,4 @@
pub mod completion;
pub mod embedding;
+pub mod models;
+pub mod templates;
@@ -53,6 +53,8 @@ pub struct OpenAIRequest {
pub model: String,
pub messages: Vec<RequestMessage>,
pub stream: bool,
+ pub stop: Vec<String>,
+ pub temperature: f32,
}
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::AsyncReadExt;
use gpui::executor::Background;
-use gpui::serde_json;
+use gpui::{serde_json, ViewContext};
use isahc::http::StatusCode;
use isahc::prelude::Configurable;
use isahc::{AsyncBody, Response};
@@ -20,9 +20,11 @@ use std::sync::Arc;
use std::time::{Duration, Instant};
use tiktoken_rs::{cl100k_base, CoreBPE};
use util::http::{HttpClient, Request};
+use util::ResultExt;
+
+use crate::completion::OPENAI_API_URL;
lazy_static! {
- static ref OPENAI_API_KEY: Option<String> = env::var("OPENAI_API_KEY").ok();
static ref OPENAI_BPE_TOKENIZER: CoreBPE = cl100k_base().unwrap();
}
@@ -87,6 +89,7 @@ impl Embedding {
#[derive(Clone)]
pub struct OpenAIEmbeddings {
+ pub api_key: Option<String>,
pub client: Arc<dyn HttpClient>,
pub executor: Arc<Background>,
rate_limit_count_rx: watch::Receiver<Option<Instant>>,
@@ -166,11 +169,36 @@ impl EmbeddingProvider for DummyEmbeddings {
const OPENAI_INPUT_LIMIT: usize = 8190;
impl OpenAIEmbeddings {
- pub fn new(client: Arc<dyn HttpClient>, executor: Arc<Background>) -> Self {
+ pub fn authenticate(&mut self, cx: &mut ViewContext<Self>) {
+ if self.api_key.is_none() {
+ let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ Some(api_key)
+ } else if let Some((_, api_key)) = cx
+ .platform()
+ .read_credentials(OPENAI_API_URL)
+ .log_err()
+ .flatten()
+ {
+ String::from_utf8(api_key).log_err()
+ } else {
+ None
+ };
+
+ if let Some(api_key) = api_key {
+ self.api_key = Some(api_key);
+ }
+ }
+ }
+ pub fn new(
+ api_key: Option<String>,
+ client: Arc<dyn HttpClient>,
+ executor: Arc<Background>,
+ ) -> Self {
let (rate_limit_count_tx, rate_limit_count_rx) = watch::channel_with(None);
let rate_limit_count_tx = Arc::new(Mutex::new(rate_limit_count_tx));
OpenAIEmbeddings {
+ api_key,
client,
executor,
rate_limit_count_rx,
@@ -237,8 +265,9 @@ impl OpenAIEmbeddings {
#[async_trait]
impl EmbeddingProvider for OpenAIEmbeddings {
fn is_authenticated(&self) -> bool {
- OPENAI_API_KEY.as_ref().is_some()
+ self.api_key.is_some()
}
+
fn max_tokens_per_batch(&self) -> usize {
50000
}
@@ -265,9 +294,9 @@ impl EmbeddingProvider for OpenAIEmbeddings {
const BACKOFF_SECONDS: [usize; 4] = [3, 5, 15, 45];
const MAX_RETRIES: usize = 4;
- let api_key = OPENAI_API_KEY
- .as_ref()
- .ok_or_else(|| anyhow!("no api key"))?;
+ let Some(api_key) = self.api_key.clone() else {
+ return Err(anyhow!("no open ai key provided"));
+ };
let mut request_number = 0;
let mut rate_limiting = false;
@@ -276,7 +305,7 @@ impl EmbeddingProvider for OpenAIEmbeddings {
while request_number < MAX_RETRIES {
response = self
.send_request(
- api_key,
+ &api_key,
spans.iter().map(|x| &**x).collect(),
request_timeout,
)
@@ -0,0 +1,66 @@
+use anyhow::anyhow;
+use tiktoken_rs::CoreBPE;
+use util::ResultExt;
+
+pub trait LanguageModel {
+ fn name(&self) -> String;
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize>;
+ fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String>;
+ fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String>;
+ fn capacity(&self) -> anyhow::Result<usize>;
+}
+
+pub struct OpenAILanguageModel {
+ name: String,
+ bpe: Option<CoreBPE>,
+}
+
+impl OpenAILanguageModel {
+ pub fn load(model_name: &str) -> Self {
+ let bpe = tiktoken_rs::get_bpe_from_model(model_name).log_err();
+ OpenAILanguageModel {
+ name: model_name.to_string(),
+ bpe,
+ }
+ }
+}
+
+impl LanguageModel for OpenAILanguageModel {
+ fn name(&self) -> String {
+ self.name.clone()
+ }
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+ if let Some(bpe) = &self.bpe {
+ anyhow::Ok(bpe.encode_with_special_tokens(content).len())
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
+ if let Some(bpe) = &self.bpe {
+ let tokens = bpe.encode_with_special_tokens(content);
+ if tokens.len() > length {
+ bpe.decode(tokens[..length].to_vec())
+ } else {
+ bpe.decode(tokens)
+ }
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+ if let Some(bpe) = &self.bpe {
+ let tokens = bpe.encode_with_special_tokens(content);
+ if tokens.len() > length {
+ bpe.decode(tokens[length..].to_vec())
+ } else {
+ bpe.decode(tokens)
+ }
+ } else {
+ Err(anyhow!("bpe for open ai model was not retrieved"))
+ }
+ }
+ fn capacity(&self) -> anyhow::Result<usize> {
+ anyhow::Ok(tiktoken_rs::model::get_context_size(&self.name))
+ }
+}
@@ -0,0 +1,350 @@
+use std::cmp::Reverse;
+use std::ops::Range;
+use std::sync::Arc;
+
+use language::BufferSnapshot;
+use util::ResultExt;
+
+use crate::models::LanguageModel;
+use crate::templates::repository_context::PromptCodeSnippet;
+
+pub(crate) enum PromptFileType {
+ Text,
+ Code,
+}
+
+// TODO: Set this up to manage for defaults well
+pub struct PromptArguments {
+ pub model: Arc<dyn LanguageModel>,
+ pub user_prompt: Option<String>,
+ pub language_name: Option<String>,
+ pub project_name: Option<String>,
+ pub snippets: Vec<PromptCodeSnippet>,
+ pub reserved_tokens: usize,
+ pub buffer: Option<BufferSnapshot>,
+ pub selected_range: Option<Range<usize>>,
+}
+
+impl PromptArguments {
+ pub(crate) fn get_file_type(&self) -> PromptFileType {
+ if self
+ .language_name
+ .as_ref()
+ .and_then(|name| Some(!["Markdown", "Plain Text"].contains(&name.as_str())))
+ .unwrap_or(true)
+ {
+ PromptFileType::Code
+ } else {
+ PromptFileType::Text
+ }
+ }
+}
+
+pub trait PromptTemplate {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)>;
+}
+
+#[repr(i8)]
+#[derive(PartialEq, Eq, Ord)]
+pub enum PromptPriority {
+ Mandatory, // Ignores truncation
+ Ordered { order: usize }, // Truncates based on priority
+}
+
+impl PartialOrd for PromptPriority {
+ fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+ match (self, other) {
+ (Self::Mandatory, Self::Mandatory) => Some(std::cmp::Ordering::Equal),
+ (Self::Mandatory, Self::Ordered { .. }) => Some(std::cmp::Ordering::Greater),
+ (Self::Ordered { .. }, Self::Mandatory) => Some(std::cmp::Ordering::Less),
+ (Self::Ordered { order: a }, Self::Ordered { order: b }) => b.partial_cmp(a),
+ }
+ }
+}
+
+pub struct PromptChain {
+ args: PromptArguments,
+ templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+}
+
+impl PromptChain {
+ pub fn new(
+ args: PromptArguments,
+ templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)>,
+ ) -> Self {
+ PromptChain { args, templates }
+ }
+
+ pub fn generate(&self, truncate: bool) -> anyhow::Result<(String, usize)> {
+ // Argsort based on Prompt Priority
+ let seperator = "\n";
+ let seperator_tokens = self.args.model.count_tokens(seperator)?;
+ let mut sorted_indices = (0..self.templates.len()).collect::<Vec<_>>();
+ sorted_indices.sort_by_key(|&i| Reverse(&self.templates[i].0));
+
+ // If Truncate
+ let mut tokens_outstanding = if truncate {
+ Some(self.args.model.capacity()? - self.args.reserved_tokens)
+ } else {
+ None
+ };
+
+ let mut prompts = vec!["".to_string(); sorted_indices.len()];
+ for idx in sorted_indices {
+ let (_, template) = &self.templates[idx];
+
+ if let Some((template_prompt, prompt_token_count)) =
+ template.generate(&self.args, tokens_outstanding).log_err()
+ {
+ if template_prompt != "" {
+ prompts[idx] = template_prompt;
+
+ if let Some(remaining_tokens) = tokens_outstanding {
+ let new_tokens = prompt_token_count + seperator_tokens;
+ tokens_outstanding = if remaining_tokens > new_tokens {
+ Some(remaining_tokens - new_tokens)
+ } else {
+ Some(0)
+ };
+ }
+ }
+ }
+ }
+
+ prompts.retain(|x| x != "");
+
+ let full_prompt = prompts.join(seperator);
+ let total_token_count = self.args.model.count_tokens(&full_prompt)?;
+ anyhow::Ok((prompts.join(seperator), total_token_count))
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod tests {
+ use super::*;
+
+ #[test]
+ pub fn test_prompt_chain() {
+ struct TestPromptTemplate {}
+ impl PromptTemplate for TestPromptTemplate {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let mut content = "This is a test prompt template".to_string();
+
+ let mut token_count = args.model.count_tokens(&content)?;
+ if let Some(max_token_length) = max_token_length {
+ if token_count > max_token_length {
+ content = args.model.truncate(&content, max_token_length)?;
+ token_count = max_token_length;
+ }
+ }
+
+ anyhow::Ok((content, token_count))
+ }
+ }
+
+ struct TestLowPriorityTemplate {}
+ impl PromptTemplate for TestLowPriorityTemplate {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let mut content = "This is a low priority test prompt template".to_string();
+
+ let mut token_count = args.model.count_tokens(&content)?;
+ if let Some(max_token_length) = max_token_length {
+ if token_count > max_token_length {
+ content = args.model.truncate(&content, max_token_length)?;
+ token_count = max_token_length;
+ }
+ }
+
+ anyhow::Ok((content, token_count))
+ }
+ }
+
+ #[derive(Clone)]
+ struct DummyLanguageModel {
+ capacity: usize,
+ }
+
+ impl LanguageModel for DummyLanguageModel {
+ fn name(&self) -> String {
+ "dummy".to_string()
+ }
+ fn count_tokens(&self, content: &str) -> anyhow::Result<usize> {
+ anyhow::Ok(content.chars().collect::<Vec<char>>().len())
+ }
+ fn truncate(&self, content: &str, length: usize) -> anyhow::Result<String> {
+ anyhow::Ok(
+ content.chars().collect::<Vec<char>>()[..length]
+ .into_iter()
+ .collect::<String>(),
+ )
+ }
+ fn truncate_start(&self, content: &str, length: usize) -> anyhow::Result<String> {
+ anyhow::Ok(
+ content.chars().collect::<Vec<char>>()[length..]
+ .into_iter()
+ .collect::<String>(),
+ )
+ }
+ fn capacity(&self) -> anyhow::Result<usize> {
+ anyhow::Ok(self.capacity)
+ }
+ }
+
+ let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 100 });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens: 0,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(false).unwrap();
+
+ assert_eq!(
+ prompt,
+ "This is a test prompt template\nThis is a low priority test prompt template"
+ .to_string()
+ );
+
+ assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
+
+ // Testing with Truncation Off
+ // Should ignore capacity and return all prompts
+ let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity: 20 });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens: 0,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(false).unwrap();
+
+ assert_eq!(
+ prompt,
+ "This is a test prompt template\nThis is a low priority test prompt template"
+ .to_string()
+ );
+
+ assert_eq!(model.count_tokens(&prompt).unwrap(), token_count);
+
+ // Testing with Truncation Off
+ // Should ignore capacity and return all prompts
+ let capacity = 20;
+ let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens: 0,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 2 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(true).unwrap();
+
+ assert_eq!(prompt, "This is a test promp".to_string());
+ assert_eq!(token_count, capacity);
+
+ // Change Ordering of Prompts Based on Priority
+ let capacity = 120;
+ let reserved_tokens = 10;
+ let model: Arc<dyn LanguageModel> = Arc::new(DummyLanguageModel { capacity });
+ let args = PromptArguments {
+ model: model.clone(),
+ language_name: None,
+ project_name: None,
+ snippets: Vec::new(),
+ reserved_tokens,
+ buffer: None,
+ selected_range: None,
+ user_prompt: None,
+ };
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (
+ PromptPriority::Mandatory,
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(TestPromptTemplate {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(TestLowPriorityTemplate {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+
+ let (prompt, token_count) = chain.generate(true).unwrap();
+
+ assert_eq!(
+ prompt,
+ "This is a low priority test prompt template\nThis is a test prompt template\nThis is a low priority test prompt "
+ .to_string()
+ );
+ assert_eq!(token_count, capacity - reserved_tokens);
+ }
+}
@@ -0,0 +1,160 @@
+use anyhow::anyhow;
+use language::BufferSnapshot;
+use language::ToOffset;
+
+use crate::models::LanguageModel;
+use crate::templates::base::PromptArguments;
+use crate::templates::base::PromptTemplate;
+use std::fmt::Write;
+use std::ops::Range;
+use std::sync::Arc;
+
+fn retrieve_context(
+ buffer: &BufferSnapshot,
+ selected_range: &Option<Range<usize>>,
+ model: Arc<dyn LanguageModel>,
+ max_token_count: Option<usize>,
+) -> anyhow::Result<(String, usize, bool)> {
+ let mut prompt = String::new();
+ let mut truncated = false;
+ if let Some(selected_range) = selected_range {
+ let start = selected_range.start.to_offset(buffer);
+ let end = selected_range.end.to_offset(buffer);
+
+ let start_window = buffer.text_for_range(0..start).collect::<String>();
+
+ let mut selected_window = String::new();
+ if start == end {
+ write!(selected_window, "<|START|>").unwrap();
+ } else {
+ write!(selected_window, "<|START|").unwrap();
+ }
+
+ write!(
+ selected_window,
+ "{}",
+ buffer.text_for_range(start..end).collect::<String>()
+ )
+ .unwrap();
+
+ if start != end {
+ write!(selected_window, "|END|>").unwrap();
+ }
+
+ let end_window = buffer.text_for_range(end..buffer.len()).collect::<String>();
+
+ if let Some(max_token_count) = max_token_count {
+ let selected_tokens = model.count_tokens(&selected_window)?;
+ if selected_tokens > max_token_count {
+ return Err(anyhow!(
+ "selected range is greater than model context window, truncation not possible"
+ ));
+ };
+
+ let mut remaining_tokens = max_token_count - selected_tokens;
+ let start_window_tokens = model.count_tokens(&start_window)?;
+ let end_window_tokens = model.count_tokens(&end_window)?;
+ let outside_tokens = start_window_tokens + end_window_tokens;
+ if outside_tokens > remaining_tokens {
+ let (start_goal_tokens, end_goal_tokens) =
+ if start_window_tokens < end_window_tokens {
+ let start_goal_tokens = (remaining_tokens / 2).min(start_window_tokens);
+ remaining_tokens -= start_goal_tokens;
+ let end_goal_tokens = remaining_tokens.min(end_window_tokens);
+ (start_goal_tokens, end_goal_tokens)
+ } else {
+ let end_goal_tokens = (remaining_tokens / 2).min(end_window_tokens);
+ remaining_tokens -= end_goal_tokens;
+ let start_goal_tokens = remaining_tokens.min(start_window_tokens);
+ (start_goal_tokens, end_goal_tokens)
+ };
+
+ let truncated_start_window =
+ model.truncate_start(&start_window, start_goal_tokens)?;
+ let truncated_end_window = model.truncate(&end_window, end_goal_tokens)?;
+ writeln!(
+ prompt,
+ "{truncated_start_window}{selected_window}{truncated_end_window}"
+ )
+ .unwrap();
+ truncated = true;
+ } else {
+ writeln!(prompt, "{start_window}{selected_window}{end_window}").unwrap();
+ }
+ } else {
+ // If we dont have a selected range, include entire file.
+ writeln!(prompt, "{}", &buffer.text()).unwrap();
+
+ // Dumb truncation strategy
+ if let Some(max_token_count) = max_token_count {
+ if model.count_tokens(&prompt)? > max_token_count {
+ truncated = true;
+ prompt = model.truncate(&prompt, max_token_count)?;
+ }
+ }
+ }
+ }
+
+ let token_count = model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, token_count, truncated))
+}
+
+pub struct FileContext {}
+
+impl PromptTemplate for FileContext {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ if let Some(buffer) = &args.buffer {
+ let mut prompt = String::new();
+ // Add Initial Preamble
+ // TODO: Do we want to add the path in here?
+ writeln!(
+ prompt,
+ "The file you are currently working on has the following content:"
+ )
+ .unwrap();
+
+ let language_name = args
+ .language_name
+ .clone()
+ .unwrap_or("".to_string())
+ .to_lowercase();
+
+ let (context, _, truncated) = retrieve_context(
+ buffer,
+ &args.selected_range,
+ args.model.clone(),
+ max_token_length,
+ )?;
+ writeln!(prompt, "```{language_name}\n{context}\n```").unwrap();
+
+ if truncated {
+ writeln!(prompt, "Note the content has been truncated and only represents a portion of the file.").unwrap();
+ }
+
+ if let Some(selected_range) = &args.selected_range {
+ let start = selected_range.start.to_offset(buffer);
+ let end = selected_range.end.to_offset(buffer);
+
+ if start == end {
+ writeln!(prompt, "In particular, the user's cursor is currently on the '<|START|>' span in the above content, with no text selected.").unwrap();
+ } else {
+ writeln!(prompt, "In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.").unwrap();
+ }
+ }
+
+ // Really dumb truncation strategy
+ if let Some(max_tokens) = max_token_length {
+ prompt = args.model.truncate(&prompt, max_tokens)?;
+ }
+
+ let token_count = args.model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, token_count))
+ } else {
+ Err(anyhow!("no buffer provided to retrieve file context from"))
+ }
+ }
+}
@@ -0,0 +1,95 @@
+use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use anyhow::anyhow;
+use std::fmt::Write;
+
+pub fn capitalize(s: &str) -> String {
+ let mut c = s.chars();
+ match c.next() {
+ None => String::new(),
+ Some(f) => f.to_uppercase().collect::<String>() + c.as_str(),
+ }
+}
+
+pub struct GenerateInlineContent {}
+
+impl PromptTemplate for GenerateInlineContent {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let Some(user_prompt) = &args.user_prompt else {
+ return Err(anyhow!("user prompt not provided"));
+ };
+
+ let file_type = args.get_file_type();
+ let content_type = match &file_type {
+ PromptFileType::Code => "code",
+ PromptFileType::Text => "text",
+ };
+
+ let mut prompt = String::new();
+
+ if let Some(selected_range) = &args.selected_range {
+ if selected_range.start == selected_range.end {
+ writeln!(
+ prompt,
+ "Assume the cursor is located where the `<|START|>` span is."
+ )
+ .unwrap();
+ writeln!(
+ prompt,
+ "{} can't be replaced, so assume your answer will be inserted at the cursor.",
+ capitalize(content_type)
+ )
+ .unwrap();
+ writeln!(
+ prompt,
+ "Generate {content_type} based on the users prompt: {user_prompt}",
+ )
+ .unwrap();
+ } else {
+ writeln!(prompt, "Modify the user's selected {content_type} based upon the users prompt: '{user_prompt}'").unwrap();
+ writeln!(prompt, "You must reply with only the adjusted {content_type} (within the '<|START|' and '|END|>' spans) not the entire file.").unwrap();
+ writeln!(prompt, "Double check that you only return code and not the '<|START|' and '|END|'> spans").unwrap();
+ }
+ } else {
+ writeln!(
+ prompt,
+ "Generate {content_type} based on the users prompt: {user_prompt}"
+ )
+ .unwrap();
+ }
+
+ if let Some(language_name) = &args.language_name {
+ writeln!(
+ prompt,
+ "Your answer MUST always and only be valid {}.",
+ language_name
+ )
+ .unwrap();
+ }
+ writeln!(prompt, "Never make remarks about the output.").unwrap();
+ writeln!(
+ prompt,
+ "Do not return anything else, except the generated {content_type}."
+ )
+ .unwrap();
+
+ match file_type {
+ PromptFileType::Code => {
+ // writeln!(prompt, "Always wrap your code in a Markdown block.").unwrap();
+ }
+ _ => {}
+ }
+
+ // Really dumb truncation strategy
+ if let Some(max_tokens) = max_token_length {
+ prompt = args.model.truncate(&prompt, max_tokens)?;
+ }
+
+ let token_count = args.model.count_tokens(&prompt)?;
+
+ anyhow::Ok((prompt, token_count))
+ }
+}
@@ -0,0 +1,5 @@
+pub mod base;
+pub mod file_context;
+pub mod generate;
+pub mod preamble;
+pub mod repository_context;
@@ -0,0 +1,52 @@
+use crate::templates::base::{PromptArguments, PromptFileType, PromptTemplate};
+use std::fmt::Write;
+
+pub struct EngineerPreamble {}
+
+impl PromptTemplate for EngineerPreamble {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ let mut prompts = Vec::new();
+
+ match args.get_file_type() {
+ PromptFileType::Code => {
+ prompts.push(format!(
+ "You are an expert {}engineer.",
+ args.language_name.clone().unwrap_or("".to_string()) + " "
+ ));
+ }
+ PromptFileType::Text => {
+ prompts.push("You are an expert engineer.".to_string());
+ }
+ }
+
+ if let Some(project_name) = args.project_name.clone() {
+ prompts.push(format!(
+ "You are currently working inside the '{project_name}' project in code editor Zed."
+ ));
+ }
+
+ if let Some(mut remaining_tokens) = max_token_length {
+ let mut prompt = String::new();
+ let mut total_count = 0;
+ for prompt_piece in prompts {
+ let prompt_token_count =
+ args.model.count_tokens(&prompt_piece)? + args.model.count_tokens("\n")?;
+ if remaining_tokens > prompt_token_count {
+ writeln!(prompt, "{prompt_piece}").unwrap();
+ remaining_tokens -= prompt_token_count;
+ total_count += prompt_token_count;
+ }
+ }
+
+ anyhow::Ok((prompt, total_count))
+ } else {
+ let prompt = prompts.join("\n");
+ let token_count = args.model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, token_count))
+ }
+ }
+}
@@ -0,0 +1,94 @@
+use crate::templates::base::{PromptArguments, PromptTemplate};
+use std::fmt::Write;
+use std::{ops::Range, path::PathBuf};
+
+use gpui::{AsyncAppContext, ModelHandle};
+use language::{Anchor, Buffer};
+
+#[derive(Clone)]
+pub struct PromptCodeSnippet {
+ path: Option<PathBuf>,
+ language_name: Option<String>,
+ content: String,
+}
+
+impl PromptCodeSnippet {
+ pub fn new(buffer: ModelHandle<Buffer>, range: Range<Anchor>, cx: &AsyncAppContext) -> Self {
+ let (content, language_name, file_path) = buffer.read_with(cx, |buffer, _| {
+ let snapshot = buffer.snapshot();
+ let content = snapshot.text_for_range(range.clone()).collect::<String>();
+
+ let language_name = buffer
+ .language()
+ .and_then(|language| Some(language.name().to_string().to_lowercase()));
+
+ let file_path = buffer
+ .file()
+ .and_then(|file| Some(file.path().to_path_buf()));
+
+ (content, language_name, file_path)
+ });
+
+ PromptCodeSnippet {
+ path: file_path,
+ language_name,
+ content,
+ }
+ }
+}
+
+impl ToString for PromptCodeSnippet {
+ fn to_string(&self) -> String {
+ let path = self
+ .path
+ .as_ref()
+ .and_then(|path| Some(path.to_string_lossy().to_string()))
+ .unwrap_or("".to_string());
+ let language_name = self.language_name.clone().unwrap_or("".to_string());
+ let content = self.content.clone();
+
+ format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
+ }
+}
+
+pub struct RepositoryContext {}
+
+impl PromptTemplate for RepositoryContext {
+ fn generate(
+ &self,
+ args: &PromptArguments,
+ max_token_length: Option<usize>,
+ ) -> anyhow::Result<(String, usize)> {
+ const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
+ let template = "You are working inside a large repository, here are a few code snippets that may be useful.";
+ let mut prompt = String::new();
+
+ let mut remaining_tokens = max_token_length.clone();
+ let seperator_token_length = args.model.count_tokens("\n")?;
+ for snippet in &args.snippets {
+ let mut snippet_prompt = template.to_string();
+ let content = snippet.to_string();
+ writeln!(snippet_prompt, "{content}").unwrap();
+
+ let token_count = args.model.count_tokens(&snippet_prompt)?;
+ if token_count <= MAXIMUM_SNIPPET_TOKEN_COUNT {
+ if let Some(tokens_left) = remaining_tokens {
+ if tokens_left >= token_count {
+ writeln!(prompt, "{snippet_prompt}").unwrap();
+ remaining_tokens = if tokens_left >= (token_count + seperator_token_length)
+ {
+ Some(tokens_left - token_count - seperator_token_length)
+ } else {
+ Some(0)
+ };
+ }
+ } else {
+ writeln!(prompt, "{snippet_prompt}").unwrap();
+ }
+ }
+ }
+
+ let total_token_count = args.model.count_tokens(&prompt)?;
+ anyhow::Ok((prompt, total_token_count))
+ }
+}
@@ -1,12 +1,15 @@
use crate::{
assistant_settings::{AssistantDockPosition, AssistantSettings, OpenAIModel},
codegen::{self, Codegen, CodegenKind},
- prompts::{generate_content_prompt, PromptCodeSnippet},
+ prompts::generate_content_prompt,
MessageId, MessageMetadata, MessageStatus, Role, SavedConversation, SavedConversationMetadata,
SavedMessage,
};
-use ai::completion::{
- stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+use ai::{
+ completion::{
+ stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL,
+ },
+ templates::repository_context::PromptCodeSnippet,
};
use anyhow::{anyhow, Result};
use chrono::{DateTime, Local};
@@ -609,6 +612,18 @@ impl AssistantPanel {
let project = pending_assist.project.clone();
+ let project_name = if let Some(project) = project.upgrade(cx) {
+ Some(
+ project
+ .read(cx)
+ .worktree_root_names(cx)
+ .collect::<Vec<&str>>()
+ .join("/"),
+ )
+ } else {
+ None
+ };
+
self.inline_prompt_history
.retain(|prompt| prompt != user_prompt);
self.inline_prompt_history.push_back(user_prompt.into());
@@ -646,7 +661,19 @@ impl AssistantPanel {
None
};
- let codegen_kind = codegen.read(cx).kind().clone();
+ // Higher Temperature increases the randomness of model outputs.
+ // If Markdown or No Language is Known, increase the randomness for more creative output
+ // If Code, decrease temperature to get more deterministic outputs
+ let temperature = if let Some(language) = language_name.clone() {
+ if language.to_string() != "Markdown".to_string() {
+ 0.5
+ } else {
+ 1.0
+ }
+ } else {
+ 1.0
+ };
+
let user_prompt = user_prompt.to_string();
let snippets = if retrieve_context {
@@ -668,14 +695,7 @@ impl AssistantPanel {
let snippets = cx.spawn(|_, cx| async move {
let mut snippets = Vec::new();
for result in search_results.await {
- snippets.push(PromptCodeSnippet::new(result, &cx));
-
- // snippets.push(result.buffer.read_with(&cx, |buffer, _| {
- // buffer
- // .snapshot()
- // .text_for_range(result.range)
- // .collect::<String>()
- // }));
+ snippets.push(PromptCodeSnippet::new(result.buffer, result.range, &cx));
}
snippets
});
@@ -696,11 +716,11 @@ impl AssistantPanel {
generate_content_prompt(
user_prompt,
language_name,
- &buffer,
+ buffer,
range,
- codegen_kind,
snippets,
model_name,
+ project_name,
)
});
@@ -717,18 +737,23 @@ impl AssistantPanel {
}
cx.spawn(|_, mut cx| async move {
- let prompt = prompt.await;
+ // I Don't know if we want to return a ? here.
+ let prompt = prompt.await?;
messages.push(RequestMessage {
role: Role::User,
content: prompt,
});
+
let request = OpenAIRequest {
model: model.full_name().into(),
messages,
stream: true,
+ stop: vec!["|END|>".to_string()],
+ temperature,
};
codegen.update(&mut cx, |codegen, cx| codegen.start(request, cx));
+ anyhow::Ok(())
})
.detach();
}
@@ -1718,6 +1743,8 @@ impl Conversation {
.map(|message| message.to_open_ai_message(self.buffer.read(cx)))
.collect(),
stream: true,
+ stop: vec![],
+ temperature: 1.0,
};
let stream = stream_completion(api_key, cx.background().clone(), request);
@@ -2002,6 +2029,8 @@ impl Conversation {
model: self.model.full_name().to_string(),
messages: messages.collect(),
stream: true,
+ stop: vec![],
+ temperature: 1.0,
};
let stream = stream_completion(api_key, cx.background().clone(), request);
@@ -1,60 +1,13 @@
-use crate::codegen::CodegenKind;
-use gpui::AsyncAppContext;
+use ai::models::{LanguageModel, OpenAILanguageModel};
+use ai::templates::base::{PromptArguments, PromptChain, PromptPriority, PromptTemplate};
+use ai::templates::file_context::FileContext;
+use ai::templates::generate::GenerateInlineContent;
+use ai::templates::preamble::EngineerPreamble;
+use ai::templates::repository_context::{PromptCodeSnippet, RepositoryContext};
use language::{BufferSnapshot, OffsetRangeExt, ToOffset};
-use semantic_index::SearchResult;
use std::cmp::{self, Reverse};
-use std::fmt::Write;
use std::ops::Range;
-use std::path::PathBuf;
-use tiktoken_rs::ChatCompletionRequestMessage;
-
-pub struct PromptCodeSnippet {
- path: Option<PathBuf>,
- language_name: Option<String>,
- content: String,
-}
-
-impl PromptCodeSnippet {
- pub fn new(search_result: SearchResult, cx: &AsyncAppContext) -> Self {
- let (content, language_name, file_path) =
- search_result.buffer.read_with(cx, |buffer, _| {
- let snapshot = buffer.snapshot();
- let content = snapshot
- .text_for_range(search_result.range.clone())
- .collect::<String>();
-
- let language_name = buffer
- .language()
- .and_then(|language| Some(language.name().to_string()));
-
- let file_path = buffer
- .file()
- .and_then(|file| Some(file.path().to_path_buf()));
-
- (content, language_name, file_path)
- });
-
- PromptCodeSnippet {
- path: file_path,
- language_name,
- content,
- }
- }
-}
-
-impl ToString for PromptCodeSnippet {
- fn to_string(&self) -> String {
- let path = self
- .path
- .as_ref()
- .and_then(|path| Some(path.to_string_lossy().to_string()))
- .unwrap_or("".to_string());
- let language_name = self.language_name.clone().unwrap_or("".to_string());
- let content = self.content.clone();
-
- format!("The below code snippet may be relevant from file: {path}\n```{language_name}\n{content}\n```")
- }
-}
+use std::sync::Arc;
#[allow(dead_code)]
fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> String {
@@ -170,134 +123,50 @@ fn summarize(buffer: &BufferSnapshot, selected_range: Range<impl ToOffset>) -> S
pub fn generate_content_prompt(
user_prompt: String,
language_name: Option<&str>,
- buffer: &BufferSnapshot,
- range: Range<impl ToOffset>,
- kind: CodegenKind,
+ buffer: BufferSnapshot,
+ range: Range<usize>,
search_results: Vec<PromptCodeSnippet>,
model: &str,
-) -> String {
- const MAXIMUM_SNIPPET_TOKEN_COUNT: usize = 500;
- const RESERVED_TOKENS_FOR_GENERATION: usize = 1000;
-
- let mut prompts = Vec::new();
- let range = range.to_offset(buffer);
-
- // General Preamble
- if let Some(language_name) = language_name {
- prompts.push(format!("You're an expert {language_name} engineer.\n"));
- } else {
- prompts.push("You're an expert engineer.\n".to_string());
- }
-
- // Snippets
- let mut snippet_position = prompts.len() - 1;
-
- let mut content = String::new();
- content.extend(buffer.text_for_range(0..range.start));
- if range.start == range.end {
- content.push_str("<|START|>");
+ project_name: Option<String>,
+) -> anyhow::Result<String> {
+ // Using new Prompt Templates
+ let openai_model: Arc<dyn LanguageModel> = Arc::new(OpenAILanguageModel::load(model));
+ let lang_name = if let Some(language_name) = language_name {
+ Some(language_name.to_string())
} else {
- content.push_str("<|START|");
- }
- content.extend(buffer.text_for_range(range.clone()));
- if range.start != range.end {
- content.push_str("|END|>");
- }
- content.extend(buffer.text_for_range(range.end..buffer.len()));
-
- prompts.push("The file you are currently working on has the following content:\n".to_string());
-
- if let Some(language_name) = language_name {
- let language_name = language_name.to_lowercase();
- prompts.push(format!("```{language_name}\n{content}\n```"));
- } else {
- prompts.push(format!("```\n{content}\n```"));
- }
-
- match kind {
- CodegenKind::Generate { position: _ } => {
- prompts.push("In particular, the user's cursor is currently on the '<|START|>' span in the above outline, with no text selected.".to_string());
- prompts
- .push("Assume the cursor is located where the `<|START|` marker is.".to_string());
- prompts.push(
- "Text can't be replaced, so assume your answer will be inserted at the cursor."
- .to_string(),
- );
- prompts.push(format!(
- "Generate text based on the users prompt: {user_prompt}"
- ));
- }
- CodegenKind::Transform { range: _ } => {
- prompts.push("In particular, the user has selected a section of the text between the '<|START|' and '|END|>' spans.".to_string());
- prompts.push(format!(
- "Modify the users code selected text based upon the users prompt: '{user_prompt}'"
- ));
- prompts.push("You MUST reply with only the adjusted code (within the '<|START|' and '|END|>' spans), not the entire file.".to_string());
- }
- }
-
- if let Some(language_name) = language_name {
- prompts.push(format!(
- "Your answer MUST always and only be valid {language_name}"
- ));
- }
- prompts.push("Never make remarks about the output.".to_string());
- prompts.push("Do not return any text, except the generated code.".to_string());
- prompts.push("Do not wrap your text in a Markdown block".to_string());
-
- let current_messages = [ChatCompletionRequestMessage {
- role: "user".to_string(),
- content: Some(prompts.join("\n")),
- function_call: None,
- name: None,
- }];
-
- let mut remaining_token_count = if let Ok(current_token_count) =
- tiktoken_rs::num_tokens_from_messages(model, ¤t_messages)
- {
- let max_token_count = tiktoken_rs::model::get_context_size(model);
- let intermediate_token_count = max_token_count - current_token_count;
-
- if intermediate_token_count < RESERVED_TOKENS_FOR_GENERATION {
- 0
- } else {
- intermediate_token_count - RESERVED_TOKENS_FOR_GENERATION
- }
- } else {
- // If tiktoken fails to count token count, assume we have no space remaining.
- 0
+ None
};
- // TODO:
- // - add repository name to snippet
- // - add file path
- // - add language
- if let Ok(encoding) = tiktoken_rs::get_bpe_from_model(model) {
- let mut template = "You are working inside a large repository, here are a few code snippets that may be useful";
-
- for search_result in search_results {
- let mut snippet_prompt = template.to_string();
- let snippet = search_result.to_string();
- writeln!(snippet_prompt, "```\n{snippet}\n```").unwrap();
-
- let token_count = encoding
- .encode_with_special_tokens(snippet_prompt.as_str())
- .len();
- if token_count <= remaining_token_count {
- if token_count < MAXIMUM_SNIPPET_TOKEN_COUNT {
- prompts.insert(snippet_position, snippet_prompt);
- snippet_position += 1;
- remaining_token_count -= token_count;
- // If you have already added the template to the prompt, remove the template.
- template = "";
- }
- } else {
- break;
- }
- }
- }
+ let args = PromptArguments {
+ model: openai_model,
+ language_name: lang_name.clone(),
+ project_name,
+ snippets: search_results.clone(),
+ reserved_tokens: 1000,
+ buffer: Some(buffer),
+ selected_range: Some(range),
+ user_prompt: Some(user_prompt.clone()),
+ };
- prompts.join("\n")
+ let templates: Vec<(PromptPriority, Box<dyn PromptTemplate>)> = vec![
+ (PromptPriority::Mandatory, Box::new(EngineerPreamble {})),
+ (
+ PromptPriority::Ordered { order: 1 },
+ Box::new(RepositoryContext {}),
+ ),
+ (
+ PromptPriority::Ordered { order: 0 },
+ Box::new(FileContext {}),
+ ),
+ (
+ PromptPriority::Mandatory,
+ Box::new(GenerateInlineContent {}),
+ ),
+ ];
+ let chain = PromptChain::new(args, templates);
+ let (prompt, _) = chain.generate(true)?;
+
+ anyhow::Ok(prompt)
}
#[cfg(test)]
@@ -7,7 +7,10 @@ use gpui::{AppContext, ModelHandle};
use std::sync::Arc;
pub use channel_buffer::{ChannelBuffer, ChannelBufferEvent, ACKNOWLEDGE_DEBOUNCE_INTERVAL};
-pub use channel_chat::{ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId};
+pub use channel_chat::{
+ mentions_to_proto, ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId,
+ MessageParams,
+};
pub use channel_store::{
Channel, ChannelData, ChannelEvent, ChannelId, ChannelMembership, ChannelPath, ChannelStore,
};
@@ -3,12 +3,17 @@ use anyhow::{anyhow, Result};
use client::{
proto,
user::{User, UserStore},
- Client, Subscription, TypedEnvelope,
+ Client, Subscription, TypedEnvelope, UserId,
};
use futures::lock::Mutex;
use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
use rand::prelude::*;
-use std::{collections::HashSet, mem, ops::Range, sync::Arc};
+use std::{
+ collections::HashSet,
+ mem,
+ ops::{ControlFlow, Range},
+ sync::Arc,
+};
use sum_tree::{Bias, SumTree};
use time::OffsetDateTime;
use util::{post_inc, ResultExt as _, TryFutureExt};
@@ -16,6 +21,7 @@ use util::{post_inc, ResultExt as _, TryFutureExt};
pub struct ChannelChat {
pub channel_id: ChannelId,
messages: SumTree<ChannelMessage>,
+ acknowledged_message_ids: HashSet<u64>,
channel_store: ModelHandle<ChannelStore>,
loaded_all_messages: bool,
last_acknowledged_id: Option<u64>,
@@ -27,6 +33,12 @@ pub struct ChannelChat {
_subscription: Subscription,
}
+#[derive(Debug, PartialEq, Eq)]
+pub struct MessageParams {
+ pub text: String,
+ pub mentions: Vec<(Range<usize>, UserId)>,
+}
+
#[derive(Clone, Debug)]
pub struct ChannelMessage {
pub id: ChannelMessageId,
@@ -34,6 +46,7 @@ pub struct ChannelMessage {
pub timestamp: OffsetDateTime,
pub sender: Arc<User>,
pub nonce: u128,
+ pub mentions: Vec<(Range<usize>, UserId)>,
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
@@ -105,6 +118,7 @@ impl ChannelChat {
rpc: client,
outgoing_messages_lock: Default::default(),
messages: Default::default(),
+ acknowledged_message_ids: Default::default(),
loaded_all_messages,
next_pending_message_id: 0,
last_acknowledged_id: None,
@@ -123,12 +137,16 @@ impl ChannelChat {
.cloned()
}
+ pub fn client(&self) -> &Arc<Client> {
+ &self.rpc
+ }
+
pub fn send_message(
&mut self,
- body: String,
+ message: MessageParams,
cx: &mut ModelContext<Self>,
- ) -> Result<Task<Result<()>>> {
- if body.is_empty() {
+ ) -> Result<Task<Result<u64>>> {
+ if message.text.is_empty() {
Err(anyhow!("message body can't be empty"))?;
}
@@ -145,9 +163,10 @@ impl ChannelChat {
SumTree::from_item(
ChannelMessage {
id: pending_id,
- body: body.clone(),
+ body: message.text.clone(),
sender: current_user,
timestamp: OffsetDateTime::now_utc(),
+ mentions: message.mentions.clone(),
nonce,
},
&(),
@@ -161,20 +180,18 @@ impl ChannelChat {
let outgoing_message_guard = outgoing_messages_lock.lock().await;
let request = rpc.request(proto::SendChannelMessage {
channel_id,
- body,
+ body: message.text,
nonce: Some(nonce.into()),
+ mentions: mentions_to_proto(&message.mentions),
});
let response = request.await?;
drop(outgoing_message_guard);
- let message = ChannelMessage::from_proto(
- response.message.ok_or_else(|| anyhow!("invalid message"))?,
- &user_store,
- &mut cx,
- )
- .await?;
+ let response = response.message.ok_or_else(|| anyhow!("invalid message"))?;
+ let id = response.id;
+ let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
this.update(&mut cx, |this, cx| {
this.insert_messages(SumTree::from_item(message, &()), cx);
- Ok(())
+ Ok(id)
})
}))
}
@@ -194,41 +211,76 @@ impl ChannelChat {
})
}
- pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
- if !self.loaded_all_messages {
- let rpc = self.rpc.clone();
- let user_store = self.user_store.clone();
- let channel_id = self.channel_id;
- if let Some(before_message_id) =
- self.messages.first().and_then(|message| match message.id {
- ChannelMessageId::Saved(id) => Some(id),
- ChannelMessageId::Pending(_) => None,
- })
- {
- cx.spawn(|this, mut cx| {
- async move {
- let response = rpc
- .request(proto::GetChannelMessages {
- channel_id,
- before_message_id,
- })
- .await?;
- let loaded_all_messages = response.done;
- let messages =
- messages_from_proto(response.messages, &user_store, &mut cx).await?;
- this.update(&mut cx, |this, cx| {
- this.loaded_all_messages = loaded_all_messages;
- this.insert_messages(messages, cx);
- });
- anyhow::Ok(())
+ pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Option<()>>> {
+ if self.loaded_all_messages {
+ return None;
+ }
+
+ let rpc = self.rpc.clone();
+ let user_store = self.user_store.clone();
+ let channel_id = self.channel_id;
+ let before_message_id = self.first_loaded_message_id()?;
+ Some(cx.spawn(|this, mut cx| {
+ async move {
+ let response = rpc
+ .request(proto::GetChannelMessages {
+ channel_id,
+ before_message_id,
+ })
+ .await?;
+ let loaded_all_messages = response.done;
+ let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?;
+ this.update(&mut cx, |this, cx| {
+ this.loaded_all_messages = loaded_all_messages;
+ this.insert_messages(messages, cx);
+ });
+ anyhow::Ok(())
+ }
+ .log_err()
+ }))
+ }
+
+ pub fn first_loaded_message_id(&mut self) -> Option<u64> {
+ self.messages.first().and_then(|message| match message.id {
+ ChannelMessageId::Saved(id) => Some(id),
+ ChannelMessageId::Pending(_) => None,
+ })
+ }
+
+ /// Load all of the chat messages since a certain message id.
+ ///
+ /// For now, we always maintain a suffix of the channel's messages.
+ pub async fn load_history_since_message(
+ chat: ModelHandle<Self>,
+ message_id: u64,
+ mut cx: AsyncAppContext,
+ ) -> Option<usize> {
+ loop {
+ let step = chat.update(&mut cx, |chat, cx| {
+ if let Some(first_id) = chat.first_loaded_message_id() {
+ if first_id <= message_id {
+ let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>();
+ let message_id = ChannelMessageId::Saved(message_id);
+ cursor.seek(&message_id, Bias::Left, &());
+ return ControlFlow::Break(
+ if cursor
+ .item()
+ .map_or(false, |message| message.id == message_id)
+ {
+ Some(cursor.start().1 .0)
+ } else {
+ None
+ },
+ );
}
- .log_err()
- })
- .detach();
- return true;
+ }
+ ControlFlow::Continue(chat.load_more_messages(cx))
+ });
+ match step {
+ ControlFlow::Break(ix) => return ix,
+ ControlFlow::Continue(task) => task?.await?,
}
}
- false
}
pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext<Self>) {
@@ -287,6 +339,7 @@ impl ChannelChat {
let request = rpc.request(proto::SendChannelMessage {
channel_id,
body: pending_message.body,
+ mentions: mentions_to_proto(&pending_message.mentions),
nonce: Some(pending_message.nonce.into()),
});
let response = request.await?;
@@ -322,6 +375,17 @@ impl ChannelChat {
cursor.item().unwrap()
}
+ pub fn acknowledge_message(&mut self, id: u64) {
+ if self.acknowledged_message_ids.insert(id) {
+ self.rpc
+ .send(proto::AckChannelMessage {
+ channel_id: self.channel_id,
+ message_id: id,
+ })
+ .ok();
+ }
+ }
+
pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
let mut cursor = self.messages.cursor::<Count>();
cursor.seek(&Count(range.start), Bias::Right, &());
@@ -454,22 +518,7 @@ async fn messages_from_proto(
user_store: &ModelHandle<UserStore>,
cx: &mut AsyncAppContext,
) -> Result<SumTree<ChannelMessage>> {
- let unique_user_ids = proto_messages
- .iter()
- .map(|m| m.sender_id)
- .collect::<HashSet<_>>()
- .into_iter()
- .collect();
- user_store
- .update(cx, |user_store, cx| {
- user_store.get_users(unique_user_ids, cx)
- })
- .await?;
-
- let mut messages = Vec::with_capacity(proto_messages.len());
- for message in proto_messages {
- messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
- }
+ let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?;
let mut result = SumTree::new();
result.extend(messages, &());
Ok(result)
@@ -489,6 +538,14 @@ impl ChannelMessage {
Ok(ChannelMessage {
id: ChannelMessageId::Saved(message.id),
body: message.body,
+ mentions: message
+ .mentions
+ .into_iter()
+ .filter_map(|mention| {
+ let range = mention.range?;
+ Some((range.start as usize..range.end as usize, mention.user_id))
+ })
+ .collect(),
timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?,
sender,
nonce: message
@@ -501,6 +558,43 @@ impl ChannelMessage {
pub fn is_pending(&self) -> bool {
matches!(self.id, ChannelMessageId::Pending(_))
}
+
+ pub async fn from_proto_vec(
+ proto_messages: Vec<proto::ChannelMessage>,
+ user_store: &ModelHandle<UserStore>,
+ cx: &mut AsyncAppContext,
+ ) -> Result<Vec<Self>> {
+ let unique_user_ids = proto_messages
+ .iter()
+ .map(|m| m.sender_id)
+ .collect::<HashSet<_>>()
+ .into_iter()
+ .collect();
+ user_store
+ .update(cx, |user_store, cx| {
+ user_store.get_users(unique_user_ids, cx)
+ })
+ .await?;
+
+ let mut messages = Vec::with_capacity(proto_messages.len());
+ for message in proto_messages {
+ messages.push(ChannelMessage::from_proto(message, user_store, cx).await?);
+ }
+ Ok(messages)
+ }
+}
+
+pub fn mentions_to_proto(mentions: &[(Range<usize>, UserId)]) -> Vec<proto::ChatMention> {
+ mentions
+ .iter()
+ .map(|(range, user_id)| proto::ChatMention {
+ range: Some(proto::Range {
+ start: range.start as u64,
+ end: range.end as u64,
+ }),
+ user_id: *user_id as u64,
+ })
+ .collect()
}
impl sum_tree::Item for ChannelMessage {
@@ -541,3 +635,12 @@ impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
self.0 += summary.count;
}
}
+
+impl<'a> From<&'a str> for MessageParams {
+ fn from(value: &'a str) -> Self {
+ Self {
+ text: value.into(),
+ mentions: Vec::new(),
+ }
+ }
+}
@@ -1,6 +1,6 @@
mod channel_index;
-use crate::{channel_buffer::ChannelBuffer, channel_chat::ChannelChat};
+use crate::{channel_buffer::ChannelBuffer, channel_chat::ChannelChat, ChannelMessage};
use anyhow::{anyhow, Result};
use channel_index::ChannelIndex;
use client::{Client, Subscription, User, UserId, UserStore};
@@ -157,9 +157,6 @@ impl ChannelStore {
this.update(&mut cx, |this, cx| this.handle_disconnect(true, cx));
}
}
- if status.is_connected() {
- } else {
- }
}
Some(())
});
@@ -245,6 +242,12 @@ impl ChannelStore {
self.channel_index.by_id().values().nth(ix)
}
+ pub fn has_channel_invitation(&self, channel_id: ChannelId) -> bool {
+ self.channel_invitations
+ .iter()
+ .any(|channel| channel.id == channel_id)
+ }
+
pub fn channel_invitations(&self) -> &[Arc<Channel>] {
&self.channel_invitations
}
@@ -278,6 +281,33 @@ impl ChannelStore {
)
}
+ pub fn fetch_channel_messages(
+ &self,
+ message_ids: Vec<u64>,
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<Vec<ChannelMessage>>> {
+ let request = if message_ids.is_empty() {
+ None
+ } else {
+ Some(
+ self.client
+ .request(proto::GetChannelMessagesById { message_ids }),
+ )
+ };
+ cx.spawn_weak(|this, mut cx| async move {
+ if let Some(request) = request {
+ let response = request.await?;
+ let this = this
+ .upgrade(&cx)
+ .ok_or_else(|| anyhow!("channel store dropped"))?;
+ let user_store = this.read_with(&cx, |this, _| this.user_store.clone());
+ ChannelMessage::from_proto_vec(response.messages, &user_store, &mut cx).await
+ } else {
+ Ok(Vec::new())
+ }
+ })
+ }
+
pub fn has_channel_buffer_changed(&self, channel_id: ChannelId) -> Option<bool> {
self.channel_index
.by_id()
@@ -689,14 +719,15 @@ impl ChannelStore {
&mut self,
channel_id: ChannelId,
accept: bool,
- ) -> impl Future<Output = Result<()>> {
+ cx: &mut ModelContext<Self>,
+ ) -> Task<Result<()>> {
let client = self.client.clone();
- async move {
+ cx.background().spawn(async move {
client
.request(proto::RespondToChannelInvite { channel_id, accept })
.await?;
Ok(())
- }
+ })
}
pub fn get_channel_member_details(
@@ -764,6 +795,11 @@ impl ChannelStore {
}
fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+ self.channel_index.clear();
+ self.channel_invitations.clear();
+ self.channel_participants.clear();
+ self.channel_index.clear();
+ self.outgoing_invites.clear();
self.disconnect_channel_buffers_task.take();
for chat in self.opened_chats.values() {
@@ -873,11 +909,6 @@ impl ChannelStore {
}
fn handle_disconnect(&mut self, wait_for_reconnect: bool, cx: &mut ModelContext<Self>) {
- self.channel_index.clear();
- self.channel_invitations.clear();
- self.channel_participants.clear();
- self.channel_index.clear();
- self.outgoing_invites.clear();
cx.notify();
self.disconnect_channel_buffers_task.get_or_insert_with(|| {
@@ -210,6 +210,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
body: "a".into(),
timestamp: 1000,
sender_id: 5,
+ mentions: vec![],
nonce: Some(1.into()),
},
proto::ChannelMessage {
@@ -217,6 +218,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
body: "b".into(),
timestamp: 1001,
sender_id: 6,
+ mentions: vec![],
nonce: Some(2.into()),
},
],
@@ -263,6 +265,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
body: "c".into(),
timestamp: 1002,
sender_id: 7,
+ mentions: vec![],
nonce: Some(3.into()),
}),
});
@@ -300,7 +303,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
// Scroll up to view older messages.
channel.update(cx, |channel, cx| {
- assert!(channel.load_more_messages(cx));
+ channel.load_more_messages(cx).unwrap().detach();
});
let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
assert_eq!(get_messages.payload.channel_id, 5);
@@ -316,6 +319,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
timestamp: 998,
sender_id: 5,
nonce: Some(4.into()),
+ mentions: vec![],
},
proto::ChannelMessage {
id: 9,
@@ -323,6 +327,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) {
timestamp: 999,
sender_id: 6,
nonce: Some(5.into()),
+ mentions: vec![],
},
],
},
@@ -293,21 +293,19 @@ impl UserStore {
// No need to paralellize here
let mut updated_contacts = Vec::new();
for contact in message.contacts {
- let should_notify = contact.should_notify;
- updated_contacts.push((
- Arc::new(Contact::from_proto(contact, &this, &mut cx).await?),
- should_notify,
+ updated_contacts.push(Arc::new(
+ Contact::from_proto(contact, &this, &mut cx).await?,
));
}
let mut incoming_requests = Vec::new();
for request in message.incoming_requests {
- incoming_requests.push({
- let user = this
- .update(&mut cx, |this, cx| this.get_user(request.requester_id, cx))
- .await?;
- (user, request.should_notify)
- });
+ incoming_requests.push(
+ this.update(&mut cx, |this, cx| {
+ this.get_user(request.requester_id, cx)
+ })
+ .await?,
+ );
}
let mut outgoing_requests = Vec::new();
@@ -330,13 +328,7 @@ impl UserStore {
this.contacts
.retain(|contact| !removed_contacts.contains(&contact.user.id));
// Update existing contacts and insert new ones
- for (updated_contact, should_notify) in updated_contacts {
- if should_notify {
- cx.emit(Event::Contact {
- user: updated_contact.user.clone(),
- kind: ContactEventKind::Accepted,
- });
- }
+ for updated_contact in updated_contacts {
match this.contacts.binary_search_by_key(
&&updated_contact.user.github_login,
|contact| &contact.user.github_login,
@@ -359,14 +351,7 @@ impl UserStore {
}
});
// Update existing incoming requests and insert new ones
- for (user, should_notify) in incoming_requests {
- if should_notify {
- cx.emit(Event::Contact {
- user: user.clone(),
- kind: ContactEventKind::Requested,
- });
- }
-
+ for user in incoming_requests {
match this
.incoming_contact_requests
.binary_search_by_key(&&user.github_login, |contact| {
@@ -415,6 +400,12 @@ impl UserStore {
&self.incoming_contact_requests
}
+ pub fn has_incoming_contact_request(&self, user_id: u64) -> bool {
+ self.incoming_contact_requests
+ .iter()
+ .any(|user| user.id == user_id)
+ }
+
pub fn outgoing_contact_requests(&self) -> &[Arc<User>] {
&self.outgoing_contact_requests
}
@@ -3,7 +3,7 @@ authors = ["Nathan Sobo <nathan@zed.dev>"]
default-run = "collab"
edition = "2021"
name = "collab"
-version = "0.24.0"
+version = "0.25.0"
publish = false
[[bin]]
@@ -73,6 +73,7 @@ git = { path = "../git", features = ["test-support"] }
live_kit_client = { path = "../live_kit_client", features = ["test-support"] }
lsp = { path = "../lsp", features = ["test-support"] }
node_runtime = { path = "../node_runtime" }
+notifications = { path = "../notifications", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
rpc = { path = "../rpc", features = ["test-support"] }
settings = { path = "../settings", features = ["test-support"] }
@@ -192,7 +192,7 @@ CREATE INDEX "index_followers_on_room_id" ON "followers" ("room_id");
CREATE TABLE "channels" (
"id" INTEGER PRIMARY KEY AUTOINCREMENT,
"name" VARCHAR NOT NULL,
- "created_at" TIMESTAMP NOT NULL DEFAULT now,
+ "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"visibility" VARCHAR NOT NULL
);
@@ -214,7 +214,15 @@ CREATE TABLE IF NOT EXISTS "channel_messages" (
"nonce" BLOB NOT NULL
);
CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id");
-CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce");
+CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce");
+
+CREATE TABLE "channel_message_mentions" (
+ "message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE,
+ "start_offset" INTEGER NOT NULL,
+ "end_offset" INTEGER NOT NULL,
+ "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+ PRIMARY KEY(message_id, start_offset)
+);
CREATE TABLE "channel_paths" (
"id_path" TEXT NOT NULL PRIMARY KEY,
@@ -314,3 +322,26 @@ CREATE TABLE IF NOT EXISTS "observed_channel_messages" (
);
CREATE UNIQUE INDEX "index_observed_channel_messages_user_and_channel_id" ON "observed_channel_messages" ("user_id", "channel_id");
+
+CREATE TABLE "notification_kinds" (
+ "id" INTEGER PRIMARY KEY AUTOINCREMENT,
+ "name" VARCHAR NOT NULL
+);
+
+CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name");
+
+CREATE TABLE "notifications" (
+ "id" INTEGER PRIMARY KEY AUTOINCREMENT,
+ "created_at" TIMESTAMP NOT NULL default CURRENT_TIMESTAMP,
+ "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+ "kind" INTEGER NOT NULL REFERENCES notification_kinds (id),
+ "entity_id" INTEGER,
+ "content" TEXT,
+ "is_read" BOOLEAN NOT NULL DEFAULT FALSE,
+ "response" BOOLEAN
+);
+
+CREATE INDEX
+ "index_notifications_on_recipient_id_is_read_kind_entity_id"
+ ON "notifications"
+ ("recipient_id", "is_read", "kind", "entity_id");
@@ -0,0 +1,22 @@
+CREATE TABLE "notification_kinds" (
+ "id" SERIAL PRIMARY KEY,
+ "name" VARCHAR NOT NULL
+);
+
+CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name");
+
+CREATE TABLE notifications (
+ "id" SERIAL PRIMARY KEY,
+ "created_at" TIMESTAMP NOT NULL DEFAULT now(),
+ "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+ "kind" INTEGER NOT NULL REFERENCES notification_kinds (id),
+ "entity_id" INTEGER,
+ "content" TEXT,
+ "is_read" BOOLEAN NOT NULL DEFAULT FALSE,
+ "response" BOOLEAN
+);
+
+CREATE INDEX
+ "index_notifications_on_recipient_id_is_read_kind_entity_id"
+ ON "notifications"
+ ("recipient_id", "is_read", "kind", "entity_id");
@@ -0,0 +1,11 @@
+CREATE TABLE "channel_message_mentions" (
+ "message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE,
+ "start_offset" INTEGER NOT NULL,
+ "end_offset" INTEGER NOT NULL,
+ "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE,
+ PRIMARY KEY(message_id, start_offset)
+);
+
+-- We use 'on conflict update' with this index, so it should be per-user.
+CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce");
+DROP INDEX "index_channel_messages_on_nonce";
@@ -71,7 +71,6 @@ async fn main() {
db::NewUserParams {
github_login: github_user.login,
github_user_id: github_user.id,
- invite_count: 5,
},
)
.await
@@ -13,6 +13,7 @@ use anyhow::anyhow;
use collections::{BTreeMap, HashMap, HashSet};
use dashmap::DashMap;
use futures::StreamExt;
+use queries::channels::ChannelGraph;
use rand::{prelude::StdRng, Rng, SeedableRng};
use rpc::{
proto::{self},
@@ -20,7 +21,7 @@ use rpc::{
};
use sea_orm::{
entity::prelude::*,
- sea_query::{Alias, Expr, OnConflict, Query},
+ sea_query::{Alias, Expr, OnConflict},
ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr,
FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement,
TransactionTrait,
@@ -47,14 +48,14 @@ pub use ids::*;
pub use sea_orm::ConnectOptions;
pub use tables::user::Model as User;
-use self::queries::channels::ChannelGraph;
-
pub struct Database {
options: ConnectOptions,
pool: DatabaseConnection,
rooms: DashMap<RoomId, Arc<Mutex<()>>>,
rng: Mutex<StdRng>,
executor: Executor,
+ notification_kinds_by_id: HashMap<NotificationKindId, &'static str>,
+ notification_kinds_by_name: HashMap<String, NotificationKindId>,
#[cfg(test)]
runtime: Option<tokio::runtime::Runtime>,
}
@@ -69,6 +70,8 @@ impl Database {
pool: sea_orm::Database::connect(options).await?,
rooms: DashMap::with_capacity(16384),
rng: Mutex::new(StdRng::seed_from_u64(0)),
+ notification_kinds_by_id: HashMap::default(),
+ notification_kinds_by_name: HashMap::default(),
executor,
#[cfg(test)]
runtime: None,
@@ -121,6 +124,11 @@ impl Database {
Ok(new_migrations)
}
+ pub async fn initialize_static_data(&mut self) -> Result<()> {
+ self.initialize_notification_kinds().await?;
+ Ok(())
+ }
+
pub async fn transaction<F, Fut, T>(&self, f: F) -> Result<T>
where
F: Send + Fn(TransactionHandle) -> Fut,
@@ -361,18 +369,9 @@ impl<T> RoomGuard<T> {
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Contact {
- Accepted {
- user_id: UserId,
- should_notify: bool,
- busy: bool,
- },
- Outgoing {
- user_id: UserId,
- },
- Incoming {
- user_id: UserId,
- should_notify: bool,
- },
+ Accepted { user_id: UserId, busy: bool },
+ Outgoing { user_id: UserId },
+ Incoming { user_id: UserId },
}
impl Contact {
@@ -385,6 +384,15 @@ impl Contact {
}
}
+pub type NotificationBatch = Vec<(UserId, proto::Notification)>;
+
+pub struct CreatedChannelMessage {
+ pub message_id: MessageId,
+ pub participant_connection_ids: Vec<ConnectionId>,
+ pub channel_members: Vec<UserId>,
+ pub notifications: NotificationBatch,
+}
+
#[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
pub struct Invite {
pub email_address: String,
@@ -417,7 +425,6 @@ pub struct WaitlistSummary {
pub struct NewUserParams {
pub github_login: String,
pub github_user_id: i32,
- pub invite_count: i32,
}
#[derive(Debug)]
@@ -466,6 +473,24 @@ pub enum SetMemberRoleResult {
MembershipUpdated(MembershipUpdated),
}
+#[derive(Debug)]
+pub struct InviteMemberResult {
+ pub channel: Channel,
+ pub notifications: NotificationBatch,
+}
+
+#[derive(Debug)]
+pub struct RespondToChannelInvite {
+ pub membership_update: Option<MembershipUpdated>,
+ pub notifications: NotificationBatch,
+}
+
+#[derive(Debug)]
+pub struct RemoveChannelMemberResult {
+ pub membership_update: MembershipUpdated,
+ pub notification_id: Option<NotificationId>,
+}
+
#[derive(FromQueryResult, Debug, PartialEq, Eq, Hash)]
pub struct Channel {
pub id: ChannelId,
@@ -81,6 +81,8 @@ id_type!(SignupId);
id_type!(UserId);
id_type!(ChannelBufferCollaboratorId);
id_type!(FlagId);
+id_type!(NotificationId);
+id_type!(NotificationKindId);
#[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default, Hash)]
#[sea_orm(rs_type = "String", db_type = "String(None)")]
@@ -5,6 +5,7 @@ pub mod buffers;
pub mod channels;
pub mod contacts;
pub mod messages;
+pub mod notifications;
pub mod projects;
pub mod rooms;
pub mod servers;
@@ -1,4 +1,5 @@
use super::*;
+use sea_orm::sea_query::Query;
impl Database {
pub async fn create_access_token(
@@ -349,11 +349,11 @@ impl Database {
&self,
channel_id: ChannelId,
invitee_id: UserId,
- admin_id: UserId,
+ inviter_id: UserId,
role: ChannelRole,
- ) -> Result<Channel> {
+ ) -> Result<InviteMemberResult> {
self.transaction(move |tx| async move {
- self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
+ self.check_user_is_channel_admin(channel_id, inviter_id, &*tx)
.await?;
channel_member::ActiveModel {
@@ -371,11 +371,31 @@ impl Database {
.await?
.unwrap();
- Ok(Channel {
+ let channel = Channel {
id: channel.id,
visibility: channel.visibility,
name: channel.name,
role,
+ };
+
+ let notifications = self
+ .create_notification(
+ invitee_id,
+ rpc::Notification::ChannelInvitation {
+ channel_id: channel_id.to_proto(),
+ channel_name: channel.name.clone(),
+ inviter_id: inviter_id.to_proto(),
+ },
+ true,
+ &*tx,
+ )
+ .await?
+ .into_iter()
+ .collect();
+
+ Ok(InviteMemberResult {
+ channel,
+ notifications,
})
})
.await
@@ -445,9 +465,9 @@ impl Database {
channel_id: ChannelId,
user_id: UserId,
accept: bool,
- ) -> Result<Option<MembershipUpdated>> {
+ ) -> Result<RespondToChannelInvite> {
self.transaction(move |tx| async move {
- if accept {
+ let membership_update = if accept {
let rows_affected = channel_member::Entity::update_many()
.set(channel_member::ActiveModel {
accepted: ActiveValue::Set(accept),
@@ -467,26 +487,45 @@ impl Database {
Err(anyhow!("no such invitation"))?;
}
- return Ok(Some(
+ Some(
self.calculate_membership_updated(channel_id, user_id, &*tx)
.await?,
- ));
- }
-
- let rows_affected = channel_member::ActiveModel {
- channel_id: ActiveValue::Unchanged(channel_id),
- user_id: ActiveValue::Unchanged(user_id),
- ..Default::default()
- }
- .delete(&*tx)
- .await?
- .rows_affected;
+ )
+ } else {
+ let rows_affected = channel_member::Entity::delete_many()
+ .filter(
+ channel_member::Column::ChannelId
+ .eq(channel_id)
+ .and(channel_member::Column::UserId.eq(user_id))
+ .and(channel_member::Column::Accepted.eq(false)),
+ )
+ .exec(&*tx)
+ .await?
+ .rows_affected;
+ if rows_affected == 0 {
+ Err(anyhow!("no such invitation"))?;
+ }
- if rows_affected == 0 {
- Err(anyhow!("no such invitation"))?;
- }
+ None
+ };
- Ok(None)
+ Ok(RespondToChannelInvite {
+ membership_update,
+ notifications: self
+ .mark_notification_as_read_with_response(
+ user_id,
+ &rpc::Notification::ChannelInvitation {
+ channel_id: channel_id.to_proto(),
+ channel_name: Default::default(),
+ inviter_id: Default::default(),
+ },
+ accept,
+ &*tx,
+ )
+ .await?
+ .into_iter()
+ .collect(),
+ })
})
.await
}
@@ -550,7 +589,7 @@ impl Database {
channel_id: ChannelId,
member_id: UserId,
admin_id: UserId,
- ) -> Result<MembershipUpdated> {
+ ) -> Result<RemoveChannelMemberResult> {
self.transaction(|tx| async move {
self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
.await?;
@@ -568,9 +607,22 @@ impl Database {
Err(anyhow!("no such member"))?;
}
- Ok(self
- .calculate_membership_updated(channel_id, member_id, &*tx)
- .await?)
+ Ok(RemoveChannelMemberResult {
+ membership_update: self
+ .calculate_membership_updated(channel_id, member_id, &*tx)
+ .await?,
+ notification_id: self
+ .remove_notification(
+ member_id,
+ rpc::Notification::ChannelInvitation {
+ channel_id: channel_id.to_proto(),
+ channel_name: Default::default(),
+ inviter_id: Default::default(),
+ },
+ &*tx,
+ )
+ .await?,
+ })
})
.await
}
@@ -911,6 +963,47 @@ impl Database {
.await
}
+ pub async fn get_channel_participant_details(
+ &self,
+ channel_id: ChannelId,
+ user_id: UserId,
+ ) -> Result<Vec<proto::ChannelMember>> {
+ let (role, members) = self
+ .transaction(move |tx| async move {
+ let role = self
+ .check_user_is_channel_participant(channel_id, user_id, &*tx)
+ .await?;
+ Ok((
+ role,
+ self.get_channel_participant_details_internal(channel_id, &*tx)
+ .await?,
+ ))
+ })
+ .await?;
+
+ if role == ChannelRole::Admin {
+ Ok(members
+ .into_iter()
+ .map(|channel_member| channel_member.to_proto())
+ .collect())
+ } else {
+ return Ok(members
+ .into_iter()
+ .filter_map(|member| {
+ if member.kind == proto::channel_member::Kind::Invitee {
+ return None;
+ }
+ Some(ChannelMember {
+ role: member.role,
+ user_id: member.user_id,
+ kind: proto::channel_member::Kind::Member,
+ })
+ })
+ .map(|channel_member| channel_member.to_proto())
+ .collect());
+ }
+ }
+
async fn get_channel_participant_details_internal(
&self,
channel_id: ChannelId,
@@ -1003,28 +1096,6 @@ impl Database {
.collect())
}
- pub async fn get_channel_participant_details(
- &self,
- channel_id: ChannelId,
- admin_id: UserId,
- ) -> Result<Vec<proto::ChannelMember>> {
- let members = self
- .transaction(move |tx| async move {
- self.check_user_is_channel_admin(channel_id, admin_id, &*tx)
- .await?;
-
- Ok(self
- .get_channel_participant_details_internal(channel_id, &*tx)
- .await?)
- })
- .await?;
-
- Ok(members
- .into_iter()
- .map(|channel_member| channel_member.to_proto())
- .collect())
- }
-
pub async fn get_channel_participants(
&self,
channel_id: ChannelId,
@@ -1062,9 +1133,10 @@ impl Database {
channel_id: ChannelId,
user_id: UserId,
tx: &DatabaseTransaction,
- ) -> Result<()> {
- match self.channel_role_for_user(channel_id, user_id, tx).await? {
- Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(()),
+ ) -> Result<ChannelRole> {
+ let channel_role = self.channel_role_for_user(channel_id, user_id, tx).await?;
+ match channel_role {
+ Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(channel_role.unwrap()),
Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!(
"user is not a channel member or channel does not exist"
))?,
@@ -8,7 +8,6 @@ impl Database {
user_id_b: UserId,
a_to_b: bool,
accepted: bool,
- should_notify: bool,
user_a_busy: bool,
user_b_busy: bool,
}
@@ -53,7 +52,6 @@ impl Database {
if db_contact.accepted {
contacts.push(Contact::Accepted {
user_id: db_contact.user_id_b,
- should_notify: db_contact.should_notify && db_contact.a_to_b,
busy: db_contact.user_b_busy,
});
} else if db_contact.a_to_b {
@@ -63,19 +61,16 @@ impl Database {
} else {
contacts.push(Contact::Incoming {
user_id: db_contact.user_id_b,
- should_notify: db_contact.should_notify,
});
}
} else if db_contact.accepted {
contacts.push(Contact::Accepted {
user_id: db_contact.user_id_a,
- should_notify: db_contact.should_notify && !db_contact.a_to_b,
busy: db_contact.user_a_busy,
});
} else if db_contact.a_to_b {
contacts.push(Contact::Incoming {
user_id: db_contact.user_id_a,
- should_notify: db_contact.should_notify,
});
} else {
contacts.push(Contact::Outgoing {
@@ -124,7 +119,11 @@ impl Database {
.await
}
- pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
+ pub async fn send_contact_request(
+ &self,
+ sender_id: UserId,
+ receiver_id: UserId,
+ ) -> Result<NotificationBatch> {
self.transaction(|tx| async move {
let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
(sender_id, receiver_id, true)
@@ -161,11 +160,22 @@ impl Database {
.exec_without_returning(&*tx)
.await?;
- if rows_affected == 1 {
- Ok(())
- } else {
- Err(anyhow!("contact already requested"))?
+ if rows_affected == 0 {
+ Err(anyhow!("contact already requested"))?;
}
+
+ Ok(self
+ .create_notification(
+ receiver_id,
+ rpc::Notification::ContactRequest {
+ sender_id: sender_id.to_proto(),
+ },
+ true,
+ &*tx,
+ )
+ .await?
+ .into_iter()
+ .collect())
})
.await
}
@@ -179,7 +189,11 @@ impl Database {
///
/// * `requester_id` - The user that initiates this request
/// * `responder_id` - The user that will be removed
- pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<bool> {
+ pub async fn remove_contact(
+ &self,
+ requester_id: UserId,
+ responder_id: UserId,
+ ) -> Result<(bool, Option<NotificationId>)> {
self.transaction(|tx| async move {
let (id_a, id_b) = if responder_id < requester_id {
(responder_id, requester_id)
@@ -198,7 +212,21 @@ impl Database {
.ok_or_else(|| anyhow!("no such contact"))?;
contact::Entity::delete_by_id(contact.id).exec(&*tx).await?;
- Ok(contact.accepted)
+
+ let mut deleted_notification_id = None;
+ if !contact.accepted {
+ deleted_notification_id = self
+ .remove_notification(
+ responder_id,
+ rpc::Notification::ContactRequest {
+ sender_id: requester_id.to_proto(),
+ },
+ &*tx,
+ )
+ .await?;
+ }
+
+ Ok((contact.accepted, deleted_notification_id))
})
.await
}
@@ -249,7 +277,7 @@ impl Database {
responder_id: UserId,
requester_id: UserId,
accept: bool,
- ) -> Result<()> {
+ ) -> Result<NotificationBatch> {
self.transaction(|tx| async move {
let (id_a, id_b, a_to_b) = if responder_id < requester_id {
(responder_id, requester_id, false)
@@ -287,11 +315,38 @@ impl Database {
result.rows_affected
};
- if rows_affected == 1 {
- Ok(())
- } else {
+ if rows_affected == 0 {
Err(anyhow!("no such contact request"))?
}
+
+ let mut notifications = Vec::new();
+ notifications.extend(
+ self.mark_notification_as_read_with_response(
+ responder_id,
+ &rpc::Notification::ContactRequest {
+ sender_id: requester_id.to_proto(),
+ },
+ accept,
+ &*tx,
+ )
+ .await?,
+ );
+
+ if accept {
+ notifications.extend(
+ self.create_notification(
+ requester_id,
+ rpc::Notification::ContactRequestAccepted {
+ responder_id: responder_id.to_proto(),
+ },
+ true,
+ &*tx,
+ )
+ .await?,
+ );
+ }
+
+ Ok(notifications)
})
.await
}
@@ -1,4 +1,7 @@
use super::*;
+use futures::Stream;
+use rpc::Notification;
+use sea_orm::TryInsertResult;
use time::OffsetDateTime;
impl Database {
@@ -87,43 +90,118 @@ impl Database {
condition = condition.add(channel_message::Column::Id.lt(before_message_id));
}
- let mut rows = channel_message::Entity::find()
+ let rows = channel_message::Entity::find()
.filter(condition)
.order_by_desc(channel_message::Column::Id)
.limit(count as u64)
.stream(&*tx)
.await?;
- let mut messages = Vec::new();
- while let Some(row) = rows.next().await {
- let row = row?;
- let nonce = row.nonce.as_u64_pair();
- messages.push(proto::ChannelMessage {
- id: row.id.to_proto(),
- sender_id: row.sender_id.to_proto(),
- body: row.body,
- timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
- nonce: Some(proto::Nonce {
- upper_half: nonce.0,
- lower_half: nonce.1,
+ self.load_channel_messages(rows, &*tx).await
+ })
+ .await
+ }
+
+ pub async fn get_channel_messages_by_id(
+ &self,
+ user_id: UserId,
+ message_ids: &[MessageId],
+ ) -> Result<Vec<proto::ChannelMessage>> {
+ self.transaction(|tx| async move {
+ let rows = channel_message::Entity::find()
+ .filter(channel_message::Column::Id.is_in(message_ids.iter().copied()))
+ .order_by_desc(channel_message::Column::Id)
+ .stream(&*tx)
+ .await?;
+
+ let mut channel_ids = HashSet::<ChannelId>::default();
+ let messages = self
+ .load_channel_messages(
+ rows.map(|row| {
+ row.map(|row| {
+ channel_ids.insert(row.channel_id);
+ row
+ })
}),
- });
+ &*tx,
+ )
+ .await?;
+
+ for channel_id in channel_ids {
+ self.check_user_is_channel_member(channel_id, user_id, &*tx)
+ .await?;
}
- drop(rows);
- messages.reverse();
+
Ok(messages)
})
.await
}
+ async fn load_channel_messages(
+ &self,
+ mut rows: impl Send + Unpin + Stream<Item = Result<channel_message::Model, sea_orm::DbErr>>,
+ tx: &DatabaseTransaction,
+ ) -> Result<Vec<proto::ChannelMessage>> {
+ let mut messages = Vec::new();
+ while let Some(row) = rows.next().await {
+ let row = row?;
+ let nonce = row.nonce.as_u64_pair();
+ messages.push(proto::ChannelMessage {
+ id: row.id.to_proto(),
+ sender_id: row.sender_id.to_proto(),
+ body: row.body,
+ timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
+ mentions: vec![],
+ nonce: Some(proto::Nonce {
+ upper_half: nonce.0,
+ lower_half: nonce.1,
+ }),
+ });
+ }
+ drop(rows);
+ messages.reverse();
+
+ let mut mentions = channel_message_mention::Entity::find()
+ .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)))
+ .order_by_asc(channel_message_mention::Column::MessageId)
+ .order_by_asc(channel_message_mention::Column::StartOffset)
+ .stream(&*tx)
+ .await?;
+
+ let mut message_ix = 0;
+ while let Some(mention) = mentions.next().await {
+ let mention = mention?;
+ let message_id = mention.message_id.to_proto();
+ while let Some(message) = messages.get_mut(message_ix) {
+ if message.id < message_id {
+ message_ix += 1;
+ } else {
+ if message.id == message_id {
+ message.mentions.push(proto::ChatMention {
+ range: Some(proto::Range {
+ start: mention.start_offset as u64,
+ end: mention.end_offset as u64,
+ }),
+ user_id: mention.user_id.to_proto(),
+ });
+ }
+ break;
+ }
+ }
+ }
+
+ Ok(messages)
+ }
+
pub async fn create_channel_message(
&self,
channel_id: ChannelId,
user_id: UserId,
body: &str,
+ mentions: &[proto::ChatMention],
timestamp: OffsetDateTime,
nonce: u128,
- ) -> Result<(MessageId, Vec<ConnectionId>, Vec<UserId>)> {
+ ) -> Result<CreatedChannelMessage> {
self.transaction(|tx| async move {
self.check_user_is_channel_participant(channel_id, user_id, &*tx)
.await?;
@@ -153,7 +231,7 @@ impl Database {
let timestamp = timestamp.to_offset(time::UtcOffset::UTC);
let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time());
- let message = channel_message::Entity::insert(channel_message::ActiveModel {
+ let result = channel_message::Entity::insert(channel_message::ActiveModel {
channel_id: ActiveValue::Set(channel_id),
sender_id: ActiveValue::Set(user_id),
body: ActiveValue::Set(body.to_string()),
@@ -162,35 +240,85 @@ impl Database {
id: ActiveValue::NotSet,
})
.on_conflict(
- OnConflict::column(channel_message::Column::Nonce)
- .update_column(channel_message::Column::Nonce)
- .to_owned(),
+ OnConflict::columns([
+ channel_message::Column::SenderId,
+ channel_message::Column::Nonce,
+ ])
+ .do_nothing()
+ .to_owned(),
)
+ .do_nothing()
.exec(&*tx)
.await?;
- #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)]
- enum QueryConnectionId {
- ConnectionId,
- }
+ let message_id;
+ let mut notifications = Vec::new();
+ match result {
+ TryInsertResult::Inserted(result) => {
+ message_id = result.last_insert_id;
+ let mentioned_user_ids =
+ mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
+ let mentions = mentions
+ .iter()
+ .filter_map(|mention| {
+ let range = mention.range.as_ref()?;
+ if !body.is_char_boundary(range.start as usize)
+ || !body.is_char_boundary(range.end as usize)
+ {
+ return None;
+ }
+ Some(channel_message_mention::ActiveModel {
+ message_id: ActiveValue::Set(message_id),
+ start_offset: ActiveValue::Set(range.start as i32),
+ end_offset: ActiveValue::Set(range.end as i32),
+ user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)),
+ })
+ })
+ .collect::<Vec<_>>();
+ if !mentions.is_empty() {
+ channel_message_mention::Entity::insert_many(mentions)
+ .exec(&*tx)
+ .await?;
+ }
- // Observe this message for the sender
- self.observe_channel_message_internal(
- channel_id,
- user_id,
- message.last_insert_id,
- &*tx,
- )
- .await?;
+ for mentioned_user in mentioned_user_ids {
+ notifications.extend(
+ self.create_notification(
+ UserId::from_proto(mentioned_user),
+ rpc::Notification::ChannelMessageMention {
+ message_id: message_id.to_proto(),
+ sender_id: user_id.to_proto(),
+ channel_id: channel_id.to_proto(),
+ },
+ false,
+ &*tx,
+ )
+ .await?,
+ );
+ }
+
+ self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
+ .await?;
+ }
+ _ => {
+ message_id = channel_message::Entity::find()
+ .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce)))
+ .one(&*tx)
+ .await?
+ .ok_or_else(|| anyhow!("failed to insert message"))?
+ .id;
+ }
+ }
let mut channel_members = self.get_channel_participants(channel_id, &*tx).await?;
channel_members.retain(|member| !participant_user_ids.contains(member));
- Ok((
- message.last_insert_id,
+ Ok(CreatedChannelMessage {
+ message_id,
participant_connection_ids,
channel_members,
- ))
+ notifications,
+ })
})
.await
}
@@ -200,11 +328,24 @@ impl Database {
channel_id: ChannelId,
user_id: UserId,
message_id: MessageId,
- ) -> Result<()> {
+ ) -> Result<NotificationBatch> {
self.transaction(|tx| async move {
self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
.await?;
- Ok(())
+ let mut batch = NotificationBatch::default();
+ batch.extend(
+ self.mark_notification_as_read(
+ user_id,
+ &Notification::ChannelMessageMention {
+ message_id: message_id.to_proto(),
+ sender_id: Default::default(),
+ channel_id: Default::default(),
+ },
+ &*tx,
+ )
+ .await?,
+ );
+ Ok(batch)
})
.await
}
@@ -0,0 +1,262 @@
+use super::*;
+use rpc::Notification;
+
+impl Database {
+ pub async fn initialize_notification_kinds(&mut self) -> Result<()> {
+ notification_kind::Entity::insert_many(Notification::all_variant_names().iter().map(
+ |kind| notification_kind::ActiveModel {
+ name: ActiveValue::Set(kind.to_string()),
+ ..Default::default()
+ },
+ ))
+ .on_conflict(OnConflict::new().do_nothing().to_owned())
+ .exec_without_returning(&self.pool)
+ .await?;
+
+ let mut rows = notification_kind::Entity::find().stream(&self.pool).await?;
+ while let Some(row) = rows.next().await {
+ let row = row?;
+ self.notification_kinds_by_name.insert(row.name, row.id);
+ }
+
+ for name in Notification::all_variant_names() {
+ if let Some(id) = self.notification_kinds_by_name.get(*name).copied() {
+ self.notification_kinds_by_id.insert(id, name);
+ }
+ }
+
+ Ok(())
+ }
+
+ pub async fn get_notifications(
+ &self,
+ recipient_id: UserId,
+ limit: usize,
+ before_id: Option<NotificationId>,
+ ) -> Result<Vec<proto::Notification>> {
+ self.transaction(|tx| async move {
+ let mut result = Vec::new();
+ let mut condition =
+ Condition::all().add(notification::Column::RecipientId.eq(recipient_id));
+
+ if let Some(before_id) = before_id {
+ condition = condition.add(notification::Column::Id.lt(before_id));
+ }
+
+ let mut rows = notification::Entity::find()
+ .filter(condition)
+ .order_by_desc(notification::Column::Id)
+ .limit(limit as u64)
+ .stream(&*tx)
+ .await?;
+ while let Some(row) = rows.next().await {
+ let row = row?;
+ let kind = row.kind;
+ if let Some(proto) = model_to_proto(self, row) {
+ result.push(proto);
+ } else {
+ log::warn!("unknown notification kind {:?}", kind);
+ }
+ }
+ result.reverse();
+ Ok(result)
+ })
+ .await
+ }
+
+ /// Create a notification. If `avoid_duplicates` is set to true, then avoid
+ /// creating a new notification if the given recipient already has an
+ /// unread notification with the given kind and entity id.
+ pub async fn create_notification(
+ &self,
+ recipient_id: UserId,
+ notification: Notification,
+ avoid_duplicates: bool,
+ tx: &DatabaseTransaction,
+ ) -> Result<Option<(UserId, proto::Notification)>> {
+ if avoid_duplicates {
+ if self
+ .find_notification(recipient_id, ¬ification, tx)
+ .await?
+ .is_some()
+ {
+ return Ok(None);
+ }
+ }
+
+ let proto = notification.to_proto();
+ let kind = notification_kind_from_proto(self, &proto)?;
+ let model = notification::ActiveModel {
+ recipient_id: ActiveValue::Set(recipient_id),
+ kind: ActiveValue::Set(kind),
+ entity_id: ActiveValue::Set(proto.entity_id.map(|id| id as i32)),
+ content: ActiveValue::Set(proto.content.clone()),
+ ..Default::default()
+ }
+ .save(&*tx)
+ .await?;
+
+ Ok(Some((
+ recipient_id,
+ proto::Notification {
+ id: model.id.as_ref().to_proto(),
+ kind: proto.kind,
+ timestamp: model.created_at.as_ref().assume_utc().unix_timestamp() as u64,
+ is_read: false,
+ response: None,
+ content: proto.content,
+ entity_id: proto.entity_id,
+ },
+ )))
+ }
+
+ /// Remove an unread notification with the given recipient, kind and
+ /// entity id.
+ pub async fn remove_notification(
+ &self,
+ recipient_id: UserId,
+ notification: Notification,
+ tx: &DatabaseTransaction,
+ ) -> Result<Option<NotificationId>> {
+ let id = self
+ .find_notification(recipient_id, ¬ification, tx)
+ .await?;
+ if let Some(id) = id {
+ notification::Entity::delete_by_id(id).exec(tx).await?;
+ }
+ Ok(id)
+ }
+
+ /// Populate the response for the notification with the given kind and
+ /// entity id.
+ pub async fn mark_notification_as_read_with_response(
+ &self,
+ recipient_id: UserId,
+ notification: &Notification,
+ response: bool,
+ tx: &DatabaseTransaction,
+ ) -> Result<Option<(UserId, proto::Notification)>> {
+ self.mark_notification_as_read_internal(recipient_id, notification, Some(response), tx)
+ .await
+ }
+
+ pub async fn mark_notification_as_read(
+ &self,
+ recipient_id: UserId,
+ notification: &Notification,
+ tx: &DatabaseTransaction,
+ ) -> Result<Option<(UserId, proto::Notification)>> {
+ self.mark_notification_as_read_internal(recipient_id, notification, None, tx)
+ .await
+ }
+
+ pub async fn mark_notification_as_read_by_id(
+ &self,
+ recipient_id: UserId,
+ notification_id: NotificationId,
+ ) -> Result<NotificationBatch> {
+ self.transaction(|tx| async move {
+ let row = notification::Entity::update(notification::ActiveModel {
+ id: ActiveValue::Unchanged(notification_id),
+ recipient_id: ActiveValue::Unchanged(recipient_id),
+ is_read: ActiveValue::Set(true),
+ ..Default::default()
+ })
+ .exec(&*tx)
+ .await?;
+ Ok(model_to_proto(self, row)
+ .map(|notification| (recipient_id, notification))
+ .into_iter()
+ .collect())
+ })
+ .await
+ }
+
+ async fn mark_notification_as_read_internal(
+ &self,
+ recipient_id: UserId,
+ notification: &Notification,
+ response: Option<bool>,
+ tx: &DatabaseTransaction,
+ ) -> Result<Option<(UserId, proto::Notification)>> {
+ if let Some(id) = self
+ .find_notification(recipient_id, notification, &*tx)
+ .await?
+ {
+ let row = notification::Entity::update(notification::ActiveModel {
+ id: ActiveValue::Unchanged(id),
+ recipient_id: ActiveValue::Unchanged(recipient_id),
+ is_read: ActiveValue::Set(true),
+ response: if let Some(response) = response {
+ ActiveValue::Set(Some(response))
+ } else {
+ ActiveValue::NotSet
+ },
+ ..Default::default()
+ })
+ .exec(tx)
+ .await?;
+ Ok(model_to_proto(self, row).map(|notification| (recipient_id, notification)))
+ } else {
+ Ok(None)
+ }
+ }
+
+ /// Find an unread notification by its recipient, kind and entity id.
+ async fn find_notification(
+ &self,
+ recipient_id: UserId,
+ notification: &Notification,
+ tx: &DatabaseTransaction,
+ ) -> Result<Option<NotificationId>> {
+ let proto = notification.to_proto();
+ let kind = notification_kind_from_proto(self, &proto)?;
+
+ #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+ enum QueryIds {
+ Id,
+ }
+
+ Ok(notification::Entity::find()
+ .select_only()
+ .column(notification::Column::Id)
+ .filter(
+ Condition::all()
+ .add(notification::Column::RecipientId.eq(recipient_id))
+ .add(notification::Column::IsRead.eq(false))
+ .add(notification::Column::Kind.eq(kind))
+ .add(if proto.entity_id.is_some() {
+ notification::Column::EntityId.eq(proto.entity_id)
+ } else {
+ notification::Column::EntityId.is_null()
+ }),
+ )
+ .into_values::<_, QueryIds>()
+ .one(&*tx)
+ .await?)
+ }
+}
+
+fn model_to_proto(this: &Database, row: notification::Model) -> Option<proto::Notification> {
+ let kind = this.notification_kinds_by_id.get(&row.kind)?;
+ Some(proto::Notification {
+ id: row.id.to_proto(),
+ kind: kind.to_string(),
+ timestamp: row.created_at.assume_utc().unix_timestamp() as u64,
+ is_read: row.is_read,
+ response: row.response,
+ content: row.content,
+ entity_id: row.entity_id.map(|id| id as u64),
+ })
+}
+
+fn notification_kind_from_proto(
+ this: &Database,
+ proto: &proto::Notification,
+) -> Result<NotificationKindId> {
+ Ok(this
+ .notification_kinds_by_name
+ .get(&proto.kind)
+ .copied()
+ .ok_or_else(|| anyhow!("invalid notification kind {:?}", proto.kind))?)
+}
@@ -7,11 +7,14 @@ pub mod channel_buffer_collaborator;
pub mod channel_chat_participant;
pub mod channel_member;
pub mod channel_message;
+pub mod channel_message_mention;
pub mod channel_path;
pub mod contact;
pub mod feature_flag;
pub mod follower;
pub mod language_server;
+pub mod notification;
+pub mod notification_kind;
pub mod observed_buffer_edits;
pub mod observed_channel_messages;
pub mod project;
@@ -0,0 +1,43 @@
+use crate::db::{MessageId, UserId};
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "channel_message_mentions")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub message_id: MessageId,
+ #[sea_orm(primary_key)]
+ pub start_offset: i32,
+ pub end_offset: i32,
+ pub user_id: UserId,
+}
+
+impl ActiveModelBehavior for ActiveModel {}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+ #[sea_orm(
+ belongs_to = "super::channel_message::Entity",
+ from = "Column::MessageId",
+ to = "super::channel_message::Column::Id"
+ )]
+ Message,
+ #[sea_orm(
+ belongs_to = "super::user::Entity",
+ from = "Column::UserId",
+ to = "super::user::Column::Id"
+ )]
+ MentionedUser,
+}
+
+impl Related<super::channel::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::Message.def()
+ }
+}
+
+impl Related<super::user::Entity> for Entity {
+ fn to() -> RelationDef {
+ Relation::MentionedUser.def()
+ }
+}
@@ -0,0 +1,29 @@
+use crate::db::{NotificationId, NotificationKindId, UserId};
+use sea_orm::entity::prelude::*;
+use time::PrimitiveDateTime;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "notifications")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub id: NotificationId,
+ pub created_at: PrimitiveDateTime,
+ pub recipient_id: UserId,
+ pub kind: NotificationKindId,
+ pub entity_id: Option<i32>,
+ pub content: String,
+ pub is_read: bool,
+ pub response: Option<bool>,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {
+ #[sea_orm(
+ belongs_to = "super::user::Entity",
+ from = "Column::RecipientId",
+ to = "super::user::Column::Id"
+ )]
+ Recipient,
+}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -0,0 +1,15 @@
+use crate::db::NotificationKindId;
+use sea_orm::entity::prelude::*;
+
+#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
+#[sea_orm(table_name = "notification_kinds")]
+pub struct Model {
+ #[sea_orm(primary_key)]
+ pub id: NotificationKindId,
+ pub name: String,
+}
+
+#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
+pub enum Relation {}
+
+impl ActiveModelBehavior for ActiveModel {}
@@ -10,7 +10,10 @@ use parking_lot::Mutex;
use rpc::proto::ChannelEdge;
use sea_orm::ConnectionTrait;
use sqlx::migrate::MigrateDatabase;
-use std::sync::Arc;
+use std::sync::{
+ atomic::{AtomicI32, AtomicU32, Ordering::SeqCst},
+ Arc,
+};
const TEST_RELEASE_CHANNEL: &'static str = "test";
@@ -31,7 +34,7 @@ impl TestDb {
let mut db = runtime.block_on(async {
let mut options = ConnectOptions::new(url);
options.max_connections(5);
- let db = Database::new(options, Executor::Deterministic(background))
+ let mut db = Database::new(options, Executor::Deterministic(background))
.await
.unwrap();
let sql = include_str!(concat!(
@@ -45,6 +48,7 @@ impl TestDb {
))
.await
.unwrap();
+ db.initialize_notification_kinds().await.unwrap();
db
});
@@ -79,11 +83,12 @@ impl TestDb {
options
.max_connections(5)
.idle_timeout(Duration::from_secs(0));
- let db = Database::new(options, Executor::Deterministic(background))
+ let mut db = Database::new(options, Executor::Deterministic(background))
.await
.unwrap();
let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations");
db.migrate(Path::new(migrations_path), false).await.unwrap();
+ db.initialize_notification_kinds().await.unwrap();
db
});
@@ -176,3 +181,27 @@ fn graph(
graph
}
+
+static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
+
+async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
+ db.create_user(
+ email,
+ false,
+ NewUserParams {
+ github_login: email[0..email.find("@").unwrap()].to_string(),
+ github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst),
+ },
+ )
+ .await
+ .unwrap()
+ .user_id
+}
+
+static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
+fn new_test_connection(server: ServerId) -> ConnectionId {
+ ConnectionId {
+ id: TEST_CONNECTION_ID.fetch_add(1, SeqCst),
+ owner_id: server.0 as u32,
+ }
+}
@@ -17,7 +17,6 @@ async fn test_channel_buffers(db: &Arc<Database>) {
NewUserParams {
github_login: "user_a".into(),
github_user_id: 101,
- invite_count: 0,
},
)
.await
@@ -30,7 +29,6 @@ async fn test_channel_buffers(db: &Arc<Database>) {
NewUserParams {
github_login: "user_b".into(),
github_user_id: 102,
- invite_count: 0,
},
)
.await
@@ -45,7 +43,6 @@ async fn test_channel_buffers(db: &Arc<Database>) {
NewUserParams {
github_login: "user_c".into(),
github_user_id: 102,
- invite_count: 0,
},
)
.await
@@ -178,7 +175,6 @@ async fn test_channel_buffers_last_operations(db: &Database) {
NewUserParams {
github_login: "user_a".into(),
github_user_id: 101,
- invite_count: 0,
},
)
.await
@@ -191,7 +187,6 @@ async fn test_channel_buffers_last_operations(db: &Database) {
NewUserParams {
github_login: "user_b".into(),
github_user_id: 102,
- invite_count: 0,
},
)
.await
@@ -1,20 +1,17 @@
-use collections::{HashMap, HashSet};
-use rpc::{
- proto::{self},
- ConnectionId,
-};
+use std::sync::Arc;
use crate::{
db::{
queries::channels::ChannelGraph,
- tests::{graph, TEST_RELEASE_CHANNEL},
- ChannelId, ChannelRole, Database, NewUserParams, RoomId, ServerId, UserId,
+ tests::{graph, new_test_connection, new_test_user, TEST_RELEASE_CHANNEL},
+ ChannelId, ChannelRole, Database, NewUserParams, RoomId,
},
test_both_dbs,
};
-use std::sync::{
- atomic::{AtomicI32, AtomicU32, Ordering},
- Arc,
+use collections::{HashMap, HashSet};
+use rpc::{
+ proto::{self},
+ ConnectionId,
};
test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite);
@@ -305,7 +302,6 @@ async fn test_channel_renames(db: &Arc<Database>) {
NewUserParams {
github_login: "user1".into(),
github_user_id: 5,
- invite_count: 0,
},
)
.await
@@ -319,7 +315,6 @@ async fn test_channel_renames(db: &Arc<Database>) {
NewUserParams {
github_login: "user2".into(),
github_user_id: 6,
- invite_count: 0,
},
)
.await
@@ -360,7 +355,6 @@ async fn test_db_channel_moving(db: &Arc<Database>) {
NewUserParams {
github_login: "user1".into(),
github_user_id: 5,
- invite_count: 0,
},
)
.await
@@ -727,7 +721,6 @@ async fn test_db_channel_moving_bugs(db: &Arc<Database>) {
NewUserParams {
github_login: "user1".into(),
github_user_id: 5,
- invite_count: 0,
},
)
.await
@@ -1122,28 +1115,3 @@ fn assert_dag(actual: ChannelGraph, expected: &[(ChannelId, Option<ChannelId>)])
pretty_assertions::assert_eq!(actual_map, expected_map)
}
-
-static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5);
-
-async fn new_test_user(db: &Arc<Database>, email: &str) -> UserId {
- db.create_user(
- email,
- false,
- NewUserParams {
- github_login: email[0..email.find("@").unwrap()].to_string(),
- github_user_id: GITHUB_USER_ID.fetch_add(1, Ordering::SeqCst),
- invite_count: 0,
- },
- )
- .await
- .unwrap()
- .user_id
-}
-
-static TEST_CONNECTION_ID: AtomicU32 = AtomicU32::new(1);
-fn new_test_connection(server: ServerId) -> ConnectionId {
- ConnectionId {
- id: TEST_CONNECTION_ID.fetch_add(1, Ordering::SeqCst),
- owner_id: server.0 as u32,
- }
-}
@@ -22,7 +22,6 @@ async fn test_get_users(db: &Arc<Database>) {
NewUserParams {
github_login: format!("user{i}"),
github_user_id: i,
- invite_count: 0,
},
)
.await
@@ -88,7 +87,6 @@ async fn test_get_or_create_user_by_github_account(db: &Arc<Database>) {
NewUserParams {
github_login: "login1".into(),
github_user_id: 101,
- invite_count: 0,
},
)
.await
@@ -101,7 +99,6 @@ async fn test_get_or_create_user_by_github_account(db: &Arc<Database>) {
NewUserParams {
github_login: "login2".into(),
github_user_id: 102,
- invite_count: 0,
},
)
.await
@@ -156,7 +153,6 @@ async fn test_create_access_tokens(db: &Arc<Database>) {
NewUserParams {
github_login: "u1".into(),
github_user_id: 1,
- invite_count: 0,
},
)
.await
@@ -238,7 +234,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
NewUserParams {
github_login: format!("user{i}"),
github_user_id: i,
- invite_count: 0,
},
)
.await
@@ -264,10 +259,7 @@ async fn test_add_contacts(db: &Arc<Database>) {
);
assert_eq!(
db.get_contacts(user_2).await.unwrap(),
- &[Contact::Incoming {
- user_id: user_1,
- should_notify: true
- }]
+ &[Contact::Incoming { user_id: user_1 }]
);
// User 2 dismisses the contact request notification without accepting or rejecting.
@@ -280,10 +272,7 @@ async fn test_add_contacts(db: &Arc<Database>) {
.unwrap();
assert_eq!(
db.get_contacts(user_2).await.unwrap(),
- &[Contact::Incoming {
- user_id: user_1,
- should_notify: false
- }]
+ &[Contact::Incoming { user_id: user_1 }]
);
// User can't accept their own contact request
@@ -299,7 +288,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
db.get_contacts(user_1).await.unwrap(),
&[Contact::Accepted {
user_id: user_2,
- should_notify: true,
busy: false,
}],
);
@@ -309,7 +297,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
db.get_contacts(user_2).await.unwrap(),
&[Contact::Accepted {
user_id: user_1,
- should_notify: false,
busy: false,
}]
);
@@ -326,7 +313,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
db.get_contacts(user_1).await.unwrap(),
&[Contact::Accepted {
user_id: user_2,
- should_notify: true,
busy: false,
}]
);
@@ -339,7 +325,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
db.get_contacts(user_1).await.unwrap(),
&[Contact::Accepted {
user_id: user_2,
- should_notify: false,
busy: false,
}]
);
@@ -353,12 +338,10 @@ async fn test_add_contacts(db: &Arc<Database>) {
&[
Contact::Accepted {
user_id: user_2,
- should_notify: false,
busy: false,
},
Contact::Accepted {
user_id: user_3,
- should_notify: false,
busy: false,
}
]
@@ -367,7 +350,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
db.get_contacts(user_3).await.unwrap(),
&[Contact::Accepted {
user_id: user_1,
- should_notify: false,
busy: false,
}],
);
@@ -383,7 +365,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
db.get_contacts(user_2).await.unwrap(),
&[Contact::Accepted {
user_id: user_1,
- should_notify: false,
busy: false,
}]
);
@@ -391,7 +372,6 @@ async fn test_add_contacts(db: &Arc<Database>) {
db.get_contacts(user_3).await.unwrap(),
&[Contact::Accepted {
user_id: user_1,
- should_notify: false,
busy: false,
}],
);
@@ -415,7 +395,6 @@ async fn test_metrics_id(db: &Arc<Database>) {
NewUserParams {
github_login: "person1".into(),
github_user_id: 101,
- invite_count: 5,
},
)
.await
@@ -431,7 +410,6 @@ async fn test_metrics_id(db: &Arc<Database>) {
NewUserParams {
github_login: "person2".into(),
github_user_id: 102,
- invite_count: 5,
},
)
.await
@@ -460,7 +438,6 @@ async fn test_project_count(db: &Arc<Database>) {
NewUserParams {
github_login: "admin".into(),
github_user_id: 0,
- invite_count: 0,
},
)
.await
@@ -472,7 +449,6 @@ async fn test_project_count(db: &Arc<Database>) {
NewUserParams {
github_login: "user".into(),
github_user_id: 1,
- invite_count: 0,
},
)
.await
@@ -554,7 +530,6 @@ async fn test_fuzzy_search_users() {
NewUserParams {
github_login: github_login.into(),
github_user_id: i as i32,
- invite_count: 0,
},
)
.await
@@ -596,7 +571,6 @@ async fn test_non_matching_release_channels(db: &Arc<Database>) {
NewUserParams {
github_login: "admin".into(),
github_user_id: 0,
- invite_count: 0,
},
)
.await
@@ -608,7 +582,6 @@ async fn test_non_matching_release_channels(db: &Arc<Database>) {
NewUserParams {
github_login: "user".into(),
github_user_id: 1,
- invite_count: 0,
},
)
.await
@@ -18,7 +18,6 @@ async fn test_get_user_flags(db: &Arc<Database>) {
NewUserParams {
github_login: format!("user1"),
github_user_id: 1,
- invite_count: 0,
},
)
.await
@@ -32,7 +31,6 @@ async fn test_get_user_flags(db: &Arc<Database>) {
NewUserParams {
github_login: format!("user2"),
github_user_id: 2,
- invite_count: 0,
},
)
.await
@@ -1,7 +1,9 @@
+use super::new_test_user;
use crate::{
- db::{ChannelRole, Database, MessageId, NewUserParams},
+ db::{ChannelRole, Database, MessageId},
test_both_dbs,
};
+use channel::mentions_to_proto;
use std::sync::Arc;
use time::OffsetDateTime;
@@ -12,39 +14,38 @@ test_both_dbs!(
);
async fn test_channel_message_retrieval(db: &Arc<Database>) {
- let user = db
- .create_user(
- "user@example.com",
- false,
- NewUserParams {
- github_login: "user".into(),
- github_user_id: 1,
- invite_count: 0,
- },
- )
- .await
- .unwrap()
- .user_id;
- let channel = db.create_root_channel("channel", user).await.unwrap();
+ let user = new_test_user(db, "user@example.com").await;
+ let result = db.create_channel("channel", None, user).await.unwrap();
let owner_id = db.create_server("test").await.unwrap().0 as u32;
- db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user)
- .await
- .unwrap();
+ db.join_channel_chat(
+ result.channel.id,
+ rpc::ConnectionId { owner_id, id: 0 },
+ user,
+ )
+ .await
+ .unwrap();
let mut all_messages = Vec::new();
for i in 0..10 {
all_messages.push(
- db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
- .await
- .unwrap()
- .0
- .to_proto(),
+ db.create_channel_message(
+ result.channel.id,
+ user,
+ &i.to_string(),
+ &[],
+ OffsetDateTime::now_utc(),
+ i,
+ )
+ .await
+ .unwrap()
+ .message_id
+ .to_proto(),
);
}
let messages = db
- .get_channel_messages(channel, user, 3, None)
+ .get_channel_messages(result.channel.id, user, 3, None)
.await
.unwrap()
.into_iter()
@@ -54,7 +55,7 @@ async fn test_channel_message_retrieval(db: &Arc<Database>) {
let messages = db
.get_channel_messages(
- channel,
+ result.channel.id,
user,
4,
Some(MessageId::from_proto(all_messages[6])),
@@ -74,99 +75,154 @@ test_both_dbs!(
);
async fn test_channel_message_nonces(db: &Arc<Database>) {
- let user = db
- .create_user(
- "user@example.com",
- false,
- NewUserParams {
- github_login: "user".into(),
- github_user_id: 1,
- invite_count: 0,
- },
- )
+ let user_a = new_test_user(db, "user_a@example.com").await;
+ let user_b = new_test_user(db, "user_b@example.com").await;
+ let user_c = new_test_user(db, "user_c@example.com").await;
+ let channel = db.create_root_channel("channel", user_a).await.unwrap();
+ db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member)
.await
- .unwrap()
- .user_id;
- let channel = db.create_root_channel("channel", user).await.unwrap();
-
- let owner_id = db.create_server("test").await.unwrap().0 as u32;
-
- db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user)
+ .unwrap();
+ db.invite_channel_member(channel, user_c, user_a, ChannelRole::Member)
.await
.unwrap();
-
- let msg1_id = db
- .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
+ db.respond_to_channel_invite(channel, user_b, true)
.await
.unwrap();
- let msg2_id = db
- .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
+ db.respond_to_channel_invite(channel, user_c, true)
.await
.unwrap();
- let msg3_id = db
- .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
+
+ let owner_id = db.create_server("test").await.unwrap().0 as u32;
+ db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user_a)
.await
.unwrap();
- let msg4_id = db
- .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
+ db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 1 }, user_b)
.await
.unwrap();
- assert_ne!(msg1_id, msg2_id);
- assert_eq!(msg1_id, msg3_id);
- assert_eq!(msg2_id, msg4_id);
-}
-
-test_both_dbs!(
- test_channel_message_new_notification,
- test_channel_message_new_notification_postgres,
- test_channel_message_new_notification_sqlite
-);
-
-async fn test_channel_message_new_notification(db: &Arc<Database>) {
- let user = db
- .create_user(
- "user_a@example.com",
- false,
- NewUserParams {
- github_login: "user_a".into(),
- github_user_id: 1,
- invite_count: 0,
- },
+ // As user A, create messages that re-use the same nonces. The requests
+ // succeed, but return the same ids.
+ let id1 = db
+ .create_channel_message(
+ channel,
+ user_a,
+ "hi @user_b",
+ &mentions_to_proto(&[(3..10, user_b.to_proto())]),
+ OffsetDateTime::now_utc(),
+ 100,
)
.await
.unwrap()
- .user_id;
- let observer = db
- .create_user(
- "user_b@example.com",
- false,
- NewUserParams {
- github_login: "user_b".into(),
- github_user_id: 1,
- invite_count: 0,
- },
+ .message_id;
+ let id2 = db
+ .create_channel_message(
+ channel,
+ user_a,
+ "hello, fellow users",
+ &mentions_to_proto(&[]),
+ OffsetDateTime::now_utc(),
+ 200,
+ )
+ .await
+ .unwrap()
+ .message_id;
+ let id3 = db
+ .create_channel_message(
+ channel,
+ user_a,
+ "bye @user_c (same nonce as first message)",
+ &mentions_to_proto(&[(4..11, user_c.to_proto())]),
+ OffsetDateTime::now_utc(),
+ 100,
+ )
+ .await
+ .unwrap()
+ .message_id;
+ let id4 = db
+ .create_channel_message(
+ channel,
+ user_a,
+ "omg (same nonce as second message)",
+ &mentions_to_proto(&[]),
+ OffsetDateTime::now_utc(),
+ 200,
)
.await
.unwrap()
- .user_id;
+ .message_id;
- let channel_1 = db.create_root_channel("channel", user).await.unwrap();
+ // As a different user, reuse one of the same nonces. This request succeeds
+ // and returns a different id.
+ let id5 = db
+ .create_channel_message(
+ channel,
+ user_b,
+ "omg @user_a (same nonce as user_a's first message)",
+ &mentions_to_proto(&[(4..11, user_a.to_proto())]),
+ OffsetDateTime::now_utc(),
+ 100,
+ )
+ .await
+ .unwrap()
+ .message_id;
+
+ assert_ne!(id1, id2);
+ assert_eq!(id1, id3);
+ assert_eq!(id2, id4);
+ assert_ne!(id5, id1);
+
+ let messages = db
+ .get_channel_messages(channel, user_a, 5, None)
+ .await
+ .unwrap()
+ .into_iter()
+ .map(|m| (m.id, m.body, m.mentions))
+ .collect::<Vec<_>>();
+ assert_eq!(
+ messages,
+ &[
+ (
+ id1.to_proto(),
+ "hi @user_b".into(),
+ mentions_to_proto(&[(3..10, user_b.to_proto())]),
+ ),
+ (
+ id2.to_proto(),
+ "hello, fellow users".into(),
+ mentions_to_proto(&[])
+ ),
+ (
+ id5.to_proto(),
+ "omg @user_a (same nonce as user_a's first message)".into(),
+ mentions_to_proto(&[(4..11, user_a.to_proto())]),
+ ),
+ ]
+ );
+}
+
+test_both_dbs!(
+ test_unseen_channel_messages,
+ test_unseen_channel_messages_postgres,
+ test_unseen_channel_messages_sqlite
+);
+
+async fn test_unseen_channel_messages(db: &Arc<Database>) {
+ let user = new_test_user(db, "user_a@example.com").await;
+ let observer = new_test_user(db, "user_b@example.com").await;
+ let channel_1 = db.create_root_channel("channel", user).await.unwrap();
let channel_2 = db.create_root_channel("channel-2", user).await.unwrap();
db.invite_channel_member(channel_1, observer, user, ChannelRole::Member)
.await
.unwrap();
-
- db.respond_to_channel_invite(channel_1, observer, true)
+ db.invite_channel_member(channel_2, observer, user, ChannelRole::Member)
.await
.unwrap();
- db.invite_channel_member(channel_2, observer, user, ChannelRole::Member)
+ db.respond_to_channel_invite(channel_1, observer, true)
.await
.unwrap();
-
db.respond_to_channel_invite(channel_2, observer, true)
.await
.unwrap();
@@ -179,28 +235,31 @@ async fn test_channel_message_new_notification(db: &Arc<Database>) {
.unwrap();
let _ = db
- .create_channel_message(channel_1, user, "1_1", OffsetDateTime::now_utc(), 1)
+ .create_channel_message(channel_1, user, "1_1", &[], OffsetDateTime::now_utc(), 1)
.await
.unwrap();
- let (second_message, _, _) = db
- .create_channel_message(channel_1, user, "1_2", OffsetDateTime::now_utc(), 2)
+ let second_message = db
+ .create_channel_message(channel_1, user, "1_2", &[], OffsetDateTime::now_utc(), 2)
.await
- .unwrap();
+ .unwrap()
+ .message_id;
- let (third_message, _, _) = db
- .create_channel_message(channel_1, user, "1_3", OffsetDateTime::now_utc(), 3)
+ let third_message = db
+ .create_channel_message(channel_1, user, "1_3", &[], OffsetDateTime::now_utc(), 3)
.await
- .unwrap();
+ .unwrap()
+ .message_id;
db.join_channel_chat(channel_2, user_connection_id, user)
.await
.unwrap();
- let (fourth_message, _, _) = db
- .create_channel_message(channel_2, user, "2_1", OffsetDateTime::now_utc(), 4)
+ let fourth_message = db
+ .create_channel_message(channel_2, user, "2_1", &[], OffsetDateTime::now_utc(), 4)
.await
- .unwrap();
+ .unwrap()
+ .message_id;
// Check that observer has new messages
let unseen_messages = db
@@ -295,3 +354,101 @@ async fn test_channel_message_new_notification(db: &Arc<Database>) {
}]
);
}
+
+test_both_dbs!(
+ test_channel_message_mentions,
+ test_channel_message_mentions_postgres,
+ test_channel_message_mentions_sqlite
+);
+
+async fn test_channel_message_mentions(db: &Arc<Database>) {
+ let user_a = new_test_user(db, "user_a@example.com").await;
+ let user_b = new_test_user(db, "user_b@example.com").await;
+ let user_c = new_test_user(db, "user_c@example.com").await;
+
+ let channel = db
+ .create_channel("channel", None, user_a)
+ .await
+ .unwrap()
+ .channel
+ .id;
+ db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member)
+ .await
+ .unwrap();
+ db.respond_to_channel_invite(channel, user_b, true)
+ .await
+ .unwrap();
+
+ let owner_id = db.create_server("test").await.unwrap().0 as u32;
+ let connection_id = rpc::ConnectionId { owner_id, id: 0 };
+ db.join_channel_chat(channel, connection_id, user_a)
+ .await
+ .unwrap();
+
+ db.create_channel_message(
+ channel,
+ user_a,
+ "hi @user_b and @user_c",
+ &mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]),
+ OffsetDateTime::now_utc(),
+ 1,
+ )
+ .await
+ .unwrap();
+ db.create_channel_message(
+ channel,
+ user_a,
+ "bye @user_c",
+ &mentions_to_proto(&[(4..11, user_c.to_proto())]),
+ OffsetDateTime::now_utc(),
+ 2,
+ )
+ .await
+ .unwrap();
+ db.create_channel_message(
+ channel,
+ user_a,
+ "umm",
+ &mentions_to_proto(&[]),
+ OffsetDateTime::now_utc(),
+ 3,
+ )
+ .await
+ .unwrap();
+ db.create_channel_message(
+ channel,
+ user_a,
+ "@user_b, stop.",
+ &mentions_to_proto(&[(0..7, user_b.to_proto())]),
+ OffsetDateTime::now_utc(),
+ 4,
+ )
+ .await
+ .unwrap();
+
+ let messages = db
+ .get_channel_messages(channel, user_b, 5, None)
+ .await
+ .unwrap()
+ .into_iter()
+ .map(|m| (m.body, m.mentions))
+ .collect::<Vec<_>>();
+ assert_eq!(
+ &messages,
+ &[
+ (
+ "hi @user_b and @user_c".into(),
+ mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]),
+ ),
+ (
+ "bye @user_c".into(),
+ mentions_to_proto(&[(4..11, user_c.to_proto())]),
+ ),
+ ("umm".into(), mentions_to_proto(&[]),),
+ (
+ "@user_b, stop.".into(),
+ mentions_to_proto(&[(0..7, user_b.to_proto())]),
+ ),
+ ]
+ );
+}
@@ -119,7 +119,9 @@ impl AppState {
pub async fn new(config: Config) -> Result<Arc<Self>> {
let mut db_options = db::ConnectOptions::new(config.database_url.clone());
db_options.max_connections(config.database_max_connections);
- let db = Database::new(db_options, Executor::Production).await?;
+ let mut db = Database::new(db_options, Executor::Production).await?;
+ db.initialize_notification_kinds().await?;
+
let live_kit_client = if let Some(((server, key), secret)) = config
.live_kit_server
.as_ref()
@@ -3,9 +3,11 @@ mod connection_pool;
use crate::{
auth,
db::{
- self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult, Database,
- MembershipUpdated, MessageId, MoveChannelResult, ProjectId, RenameChannelResult, RoomId,
- ServerId, SetChannelVisibilityResult, User, UserId,
+ self, BufferId, ChannelId, ChannelRole, ChannelsForUser, CreateChannelResult,
+ CreatedChannelMessage, Database, InviteMemberResult, MembershipUpdated, MessageId,
+ MoveChannelResult, NotificationId, ProjectId, RemoveChannelMemberResult,
+ RenameChannelResult, RespondToChannelInvite, RoomId, ServerId, SetChannelVisibilityResult,
+ User, UserId,
},
executor::Executor,
AppState, Result,
@@ -71,6 +73,7 @@ pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10);
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
+const NOTIFICATION_COUNT_PER_PAGE: usize = 50;
lazy_static! {
static ref METRIC_CONNECTIONS: IntGauge =
@@ -271,6 +274,9 @@ impl Server {
.add_request_handler(send_channel_message)
.add_request_handler(remove_channel_message)
.add_request_handler(get_channel_messages)
+ .add_request_handler(get_channel_messages_by_id)
+ .add_request_handler(get_notifications)
+ .add_request_handler(mark_notification_as_read)
.add_request_handler(link_channel)
.add_request_handler(unlink_channel)
.add_request_handler(move_channel)
@@ -390,7 +396,7 @@ impl Server {
let contacts = app_state.db.get_contacts(user_id).await.trace_err();
if let Some((busy, contacts)) = busy.zip(contacts) {
let pool = pool.lock();
- let updated_contact = contact_for_user(user_id, false, busy, &pool);
+ let updated_contact = contact_for_user(user_id, busy, &pool);
for contact in contacts {
if let db::Contact::Accepted {
user_id: contact_user_id,
@@ -584,7 +590,7 @@ impl Server {
let (contacts, channels_for_user, channel_invites) = future::try_join3(
this.app_state.db.get_contacts(user_id),
this.app_state.db.get_channels_for_user(user_id),
- this.app_state.db.get_channel_invites_for_user(user_id)
+ this.app_state.db.get_channel_invites_for_user(user_id),
).await?;
{
@@ -690,7 +696,7 @@ impl Server {
if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? {
if let Some(code) = &user.invite_code {
let pool = self.connection_pool.lock();
- let invitee_contact = contact_for_user(invitee_id, true, false, &pool);
+ let invitee_contact = contact_for_user(invitee_id, false, &pool);
for connection_id in pool.user_connection_ids(inviter_id) {
self.peer.send(
connection_id,
@@ -2066,7 +2072,7 @@ async fn request_contact(
return Err(anyhow!("cannot add yourself as a contact"))?;
}
- session
+ let notifications = session
.db()
.await
.send_contact_request(requester_id, responder_id)
@@ -2089,16 +2095,14 @@ async fn request_contact(
.incoming_requests
.push(proto::IncomingContactRequest {
requester_id: requester_id.to_proto(),
- should_notify: true,
});
- for connection_id in session
- .connection_pool()
- .await
- .user_connection_ids(responder_id)
- {
+ let connection_pool = session.connection_pool().await;
+ for connection_id in connection_pool.user_connection_ids(responder_id) {
session.peer.send(connection_id, update.clone())?;
}
+ send_notifications(&*connection_pool, &session.peer, notifications);
+
response.send(proto::Ack {})?;
Ok(())
}
@@ -2117,7 +2121,8 @@ async fn respond_to_contact_request(
} else {
let accept = request.response == proto::ContactRequestResponse::Accept as i32;
- db.respond_to_contact_request(responder_id, requester_id, accept)
+ let notifications = db
+ .respond_to_contact_request(responder_id, requester_id, accept)
.await?;
let requester_busy = db.is_user_busy(requester_id).await?;
let responder_busy = db.is_user_busy(responder_id).await?;
@@ -2128,7 +2133,7 @@ async fn respond_to_contact_request(
if accept {
update
.contacts
- .push(contact_for_user(requester_id, false, requester_busy, &pool));
+ .push(contact_for_user(requester_id, requester_busy, &pool));
}
update
.remove_incoming_requests
@@ -2142,14 +2147,17 @@ async fn respond_to_contact_request(
if accept {
update
.contacts
- .push(contact_for_user(responder_id, true, responder_busy, &pool));
+ .push(contact_for_user(responder_id, responder_busy, &pool));
}
update
.remove_outgoing_requests
.push(responder_id.to_proto());
+
for connection_id in pool.user_connection_ids(requester_id) {
session.peer.send(connection_id, update.clone())?;
}
+
+ send_notifications(&*pool, &session.peer, notifications);
}
response.send(proto::Ack {})?;
@@ -2164,7 +2172,8 @@ async fn remove_contact(
let requester_id = session.user_id;
let responder_id = UserId::from_proto(request.user_id);
let db = session.db().await;
- let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
+ let (contact_accepted, deleted_notification_id) =
+ db.remove_contact(requester_id, responder_id).await?;
let pool = session.connection_pool().await;
// Update outgoing contact requests of requester
@@ -2191,6 +2200,14 @@ async fn remove_contact(
}
for connection_id in pool.user_connection_ids(responder_id) {
session.peer.send(connection_id, update.clone())?;
+ if let Some(notification_id) = deleted_notification_id {
+ session.peer.send(
+ connection_id,
+ proto::DeleteNotification {
+ notification_id: notification_id.to_proto(),
+ },
+ )?;
+ }
}
response.send(proto::Ack {})?;
@@ -2268,7 +2285,10 @@ async fn invite_channel_member(
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
let invitee_id = UserId::from_proto(request.user_id);
- let channel = db
+ let InviteMemberResult {
+ channel,
+ notifications,
+ } = db
.invite_channel_member(
channel_id,
invitee_id,
@@ -2282,14 +2302,13 @@ async fn invite_channel_member(
..Default::default()
};
- for connection_id in session
- .connection_pool()
- .await
- .user_connection_ids(invitee_id)
- {
+ let connection_pool = session.connection_pool().await;
+ for connection_id in connection_pool.user_connection_ids(invitee_id) {
session.peer.send(connection_id, update.clone())?;
}
+ send_notifications(&*connection_pool, &session.peer, notifications);
+
response.send(proto::Ack {})?;
Ok(())
}
@@ -2303,13 +2322,33 @@ async fn remove_channel_member(
let channel_id = ChannelId::from_proto(request.channel_id);
let member_id = UserId::from_proto(request.user_id);
- let membership_updated = db
+ let RemoveChannelMemberResult {
+ membership_update,
+ notification_id,
+ } = db
.remove_channel_member(channel_id, member_id, session.user_id)
.await?;
- dbg!(&membership_updated);
-
- notify_membership_updated(membership_updated, member_id, &session).await?;
+ let connection_pool = &session.connection_pool().await;
+ notify_membership_updated(
+ &connection_pool,
+ membership_update,
+ member_id,
+ &session.peer,
+ );
+ for connection_id in connection_pool.user_connection_ids(member_id) {
+ if let Some(notification_id) = notification_id {
+ session
+ .peer
+ .send(
+ connection_id,
+ proto::DeleteNotification {
+ notification_id: notification_id.to_proto(),
+ },
+ )
+ .trace_err();
+ }
+ }
response.send(proto::Ack {})?;
Ok(())
@@ -2374,7 +2413,13 @@ async fn set_channel_member_role(
match result {
db::SetMemberRoleResult::MembershipUpdated(membership_update) => {
- notify_membership_updated(membership_update, member_id, &session).await?;
+ let connection_pool = session.connection_pool().await;
+ notify_membership_updated(
+ &connection_pool,
+ membership_update,
+ member_id,
+ &session.peer,
+ )
}
db::SetMemberRoleResult::InviteUpdated(channel) => {
let update = proto::UpdateChannels {
@@ -2535,24 +2580,34 @@ async fn respond_to_channel_invite(
) -> Result<()> {
let db = session.db().await;
let channel_id = ChannelId::from_proto(request.channel_id);
- let result = db
+ let RespondToChannelInvite {
+ membership_update,
+ notifications,
+ } = db
.respond_to_channel_invite(channel_id, session.user_id, request.accept)
.await?;
- if let Some(accept_invite_result) = result {
- notify_membership_updated(accept_invite_result, session.user_id, &session).await?;
+ let connection_pool = session.connection_pool().await;
+ if let Some(membership_update) = membership_update {
+ notify_membership_updated(
+ &connection_pool,
+ membership_update,
+ session.user_id,
+ &session.peer,
+ );
} else {
let update = proto::UpdateChannels {
remove_channel_invitations: vec![channel_id.to_proto()],
..Default::default()
};
- let connection_pool = session.connection_pool().await;
for connection_id in connection_pool.user_connection_ids(session.user_id) {
session.peer.send(connection_id, update.clone())?;
}
};
+ send_notifications(&*connection_pool, &session.peer, notifications);
+
response.send(proto::Ack {})?;
Ok(())
@@ -2635,8 +2690,14 @@ async fn join_channel_internal(
live_kit_connection_info,
})?;
+ let connection_pool = session.connection_pool().await;
if let Some(accept_invite_result) = accept_invite_result {
- notify_membership_updated(accept_invite_result, session.user_id, &session).await?;
+ notify_membership_updated(
+ &connection_pool,
+ accept_invite_result,
+ session.user_id,
+ &session.peer,
+ );
}
room_updated(&joined_room.room, &session.peer);
@@ -2805,6 +2866,29 @@ fn channel_buffer_updated<T: EnvelopedMessage>(
});
}
+fn send_notifications(
+ connection_pool: &ConnectionPool,
+ peer: &Peer,
+ notifications: db::NotificationBatch,
+) {
+ for (user_id, notification) in notifications {
+ for connection_id in connection_pool.user_connection_ids(user_id) {
+ if let Err(error) = peer.send(
+ connection_id,
+ proto::AddNotification {
+ notification: Some(notification.clone()),
+ },
+ ) {
+ tracing::error!(
+ "failed to send notification to {:?} {}",
+ connection_id,
+ error
+ );
+ }
+ }
+ }
+}
+
async fn send_channel_message(
request: proto::SendChannelMessage,
response: Response<proto::SendChannelMessage>,
@@ -2819,19 +2903,27 @@ async fn send_channel_message(
return Err(anyhow!("message can't be blank"))?;
}
+ // TODO: adjust mentions if body is trimmed
+
let timestamp = OffsetDateTime::now_utc();
let nonce = request
.nonce
.ok_or_else(|| anyhow!("nonce can't be blank"))?;
let channel_id = ChannelId::from_proto(request.channel_id);
- let (message_id, connection_ids, non_participants) = session
+ let CreatedChannelMessage {
+ message_id,
+ participant_connection_ids,
+ channel_members,
+ notifications,
+ } = session
.db()
.await
.create_channel_message(
channel_id,
session.user_id,
&body,
+ &request.mentions,
timestamp,
nonce.clone().into(),
)
@@ -2840,18 +2932,23 @@ async fn send_channel_message(
sender_id: session.user_id.to_proto(),
id: message_id.to_proto(),
body,
+ mentions: request.mentions,
timestamp: timestamp.unix_timestamp() as u64,
nonce: Some(nonce),
};
- broadcast(Some(session.connection_id), connection_ids, |connection| {
- session.peer.send(
- connection,
- proto::ChannelMessageSent {
- channel_id: channel_id.to_proto(),
- message: Some(message.clone()),
- },
- )
- });
+ broadcast(
+ Some(session.connection_id),
+ participant_connection_ids,
+ |connection| {
+ session.peer.send(
+ connection,
+ proto::ChannelMessageSent {
+ channel_id: channel_id.to_proto(),
+ message: Some(message.clone()),
+ },
+ )
+ },
+ );
response.send(proto::SendChannelMessageResponse {
message: Some(message),
})?;
@@ -2859,7 +2956,7 @@ async fn send_channel_message(
let pool = &*session.connection_pool().await;
broadcast(
None,
- non_participants
+ channel_members
.iter()
.flat_map(|user_id| pool.user_connection_ids(*user_id)),
|peer_id| {
@@ -2875,6 +2972,7 @@ async fn send_channel_message(
)
},
);
+ send_notifications(pool, &session.peer, notifications);
Ok(())
}
@@ -2904,11 +3002,16 @@ async fn acknowledge_channel_message(
) -> Result<()> {
let channel_id = ChannelId::from_proto(request.channel_id);
let message_id = MessageId::from_proto(request.message_id);
- session
+ let notifications = session
.db()
.await
.observe_channel_message(channel_id, session.user_id, message_id)
.await?;
+ send_notifications(
+ &*session.connection_pool().await,
+ &session.peer,
+ notifications,
+ );
Ok(())
}
@@ -2983,6 +3086,72 @@ async fn get_channel_messages(
Ok(())
}
+async fn get_channel_messages_by_id(
+ request: proto::GetChannelMessagesById,
+ response: Response<proto::GetChannelMessagesById>,
+ session: Session,
+) -> Result<()> {
+ let message_ids = request
+ .message_ids
+ .iter()
+ .map(|id| MessageId::from_proto(*id))
+ .collect::<Vec<_>>();
+ let messages = session
+ .db()
+ .await
+ .get_channel_messages_by_id(session.user_id, &message_ids)
+ .await?;
+ response.send(proto::GetChannelMessagesResponse {
+ done: messages.len() < MESSAGE_COUNT_PER_PAGE,
+ messages,
+ })?;
+ Ok(())
+}
+
+async fn get_notifications(
+ request: proto::GetNotifications,
+ response: Response<proto::GetNotifications>,
+ session: Session,
+) -> Result<()> {
+ let notifications = session
+ .db()
+ .await
+ .get_notifications(
+ session.user_id,
+ NOTIFICATION_COUNT_PER_PAGE,
+ request
+ .before_id
+ .map(|id| db::NotificationId::from_proto(id)),
+ )
+ .await?;
+ response.send(proto::GetNotificationsResponse {
+ done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE,
+ notifications,
+ })?;
+ Ok(())
+}
+
+async fn mark_notification_as_read(
+ request: proto::MarkNotificationRead,
+ response: Response<proto::MarkNotificationRead>,
+ session: Session,
+) -> Result<()> {
+ let database = &session.db().await;
+ let notifications = database
+ .mark_notification_as_read_by_id(
+ session.user_id,
+ NotificationId::from_proto(request.notification_id),
+ )
+ .await?;
+ send_notifications(
+ &*session.connection_pool().await,
+ &session.peer,
+ notifications,
+ );
+ response.send(proto::Ack {})?;
+ Ok(())
+}
+
async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
let project_id = ProjectId::from_proto(request.project_id);
let project_connection_ids = session
@@ -3052,11 +3221,12 @@ fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage {
}
}
-async fn notify_membership_updated(
+fn notify_membership_updated(
+ connection_pool: &ConnectionPool,
result: MembershipUpdated,
user_id: UserId,
- session: &Session,
-) -> Result<()> {
+ peer: &Peer,
+) {
let mut update = build_channels_update(result.new_channels, vec![]);
update.delete_channels = result
.removed_channels
@@ -3065,11 +3235,9 @@ async fn notify_membership_updated(
.collect();
update.remove_channel_invitations = vec![result.channel_id.to_proto()];
- let connection_pool = session.connection_pool().await;
for connection_id in connection_pool.user_connection_ids(user_id) {
- session.peer.send(connection_id, update.clone())?;
+ peer.send(connection_id, update.clone()).trace_err();
}
- Ok(())
}
fn build_channels_update(
@@ -3120,42 +3288,28 @@ fn build_initial_contacts_update(
for contact in contacts {
match contact {
- db::Contact::Accepted {
- user_id,
- should_notify,
- busy,
- } => {
- update
- .contacts
- .push(contact_for_user(user_id, should_notify, busy, &pool));
+ db::Contact::Accepted { user_id, busy } => {
+ update.contacts.push(contact_for_user(user_id, busy, &pool));
}
db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()),
- db::Contact::Incoming {
- user_id,
- should_notify,
- } => update
- .incoming_requests
- .push(proto::IncomingContactRequest {
- requester_id: user_id.to_proto(),
- should_notify,
- }),
+ db::Contact::Incoming { user_id } => {
+ update
+ .incoming_requests
+ .push(proto::IncomingContactRequest {
+ requester_id: user_id.to_proto(),
+ })
+ }
}
}
update
}
-fn contact_for_user(
- user_id: UserId,
- should_notify: bool,
- busy: bool,
- pool: &ConnectionPool,
-) -> proto::Contact {
+fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact {
proto::Contact {
user_id: user_id.to_proto(),
online: pool.is_user_online(user_id),
busy,
- should_notify,
}
}
@@ -3216,7 +3370,7 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()>
let busy = db.is_user_busy(user_id).await?;
let pool = session.connection_pool().await;
- let updated_contact = contact_for_user(user_id, false, busy, &pool);
+ let updated_contact = contact_for_user(user_id, busy, &pool);
for contact in contacts {
if let db::Contact::Accepted {
user_id: contact_user_id,
@@ -6,6 +6,7 @@ mod channel_message_tests;
mod channel_tests;
mod following_tests;
mod integration_tests;
+mod notification_tests;
mod random_channel_buffer_tests;
mod random_project_collaboration_tests;
mod randomized_test_helpers;
@@ -1,27 +1,30 @@
use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
-use channel::{ChannelChat, ChannelMessageId};
+use channel::{ChannelChat, ChannelMessageId, MessageParams};
use collab_ui::chat_panel::ChatPanel;
use gpui::{executor::Deterministic, BorrowAppContext, ModelHandle, TestAppContext};
+use rpc::Notification;
use std::sync::Arc;
use workspace::dock::Panel;
#[gpui::test]
async fn test_basic_channel_messages(
deterministic: Arc<Deterministic>,
- cx_a: &mut TestAppContext,
- cx_b: &mut TestAppContext,
+ mut cx_a: &mut TestAppContext,
+ mut cx_b: &mut TestAppContext,
+ mut cx_c: &mut TestAppContext,
) {
deterministic.forbid_parking();
let mut server = TestServer::start(&deterministic).await;
let client_a = server.create_client(cx_a, "user_a").await;
let client_b = server.create_client(cx_b, "user_b").await;
+ let client_c = server.create_client(cx_c, "user_c").await;
let channel_id = server
.make_channel(
"the-channel",
None,
(&client_a, cx_a),
- &mut [(&client_b, cx_b)],
+ &mut [(&client_b, cx_b), (&client_c, cx_c)],
)
.await;
@@ -36,8 +39,17 @@ async fn test_basic_channel_messages(
.await
.unwrap();
- channel_chat_a
- .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap())
+ let message_id = channel_chat_a
+ .update(cx_a, |c, cx| {
+ c.send_message(
+ MessageParams {
+ text: "hi @user_c!".into(),
+ mentions: vec![(3..10, client_c.id())],
+ },
+ cx,
+ )
+ .unwrap()
+ })
.await
.unwrap();
channel_chat_a
@@ -52,15 +64,55 @@ async fn test_basic_channel_messages(
.unwrap();
deterministic.run_until_parked();
- channel_chat_a.update(cx_a, |c, _| {
+
+ let channel_chat_c = client_c
+ .channel_store()
+ .update(cx_c, |store, cx| store.open_channel_chat(channel_id, cx))
+ .await
+ .unwrap();
+
+ for (chat, cx) in [
+ (&channel_chat_a, &mut cx_a),
+ (&channel_chat_b, &mut cx_b),
+ (&channel_chat_c, &mut cx_c),
+ ] {
+ chat.update(*cx, |c, _| {
+ assert_eq!(
+ c.messages()
+ .iter()
+ .map(|m| (m.body.as_str(), m.mentions.as_slice()))
+ .collect::<Vec<_>>(),
+ vec![
+ ("hi @user_c!", [(3..10, client_c.id())].as_slice()),
+ ("two", &[]),
+ ("three", &[])
+ ],
+ "results for user {}",
+ c.client().id(),
+ );
+ });
+ }
+
+ client_c.notification_store().update(cx_c, |store, _| {
+ assert_eq!(store.notification_count(), 2);
+ assert_eq!(store.unread_notification_count(), 1);
assert_eq!(
- c.messages()
- .iter()
- .map(|m| m.body.as_str())
- .collect::<Vec<_>>(),
- vec!["one", "two", "three"]
+ store.notification_at(0).unwrap().notification,
+ Notification::ChannelMessageMention {
+ message_id,
+ sender_id: client_a.id(),
+ channel_id,
+ }
);
- })
+ assert_eq!(
+ store.notification_at(1).unwrap().notification,
+ Notification::ChannelInvitation {
+ channel_id,
+ channel_name: "the-channel".to_string(),
+ inviter_id: client_a.id()
+ }
+ );
+ });
}
#[gpui::test]
@@ -280,7 +332,7 @@ async fn test_channel_message_changes(
chat_panel_b
.update(cx_b, |chat_panel, cx| {
chat_panel.set_active(true, cx);
- chat_panel.select_channel(channel_id, cx)
+ chat_panel.select_channel(channel_id, None, cx)
})
.await
.unwrap();
@@ -126,8 +126,8 @@ async fn test_core_channels(
// Client B accepts the invitation.
client_b
.channel_store()
- .update(cx_b, |channels, _| {
- channels.respond_to_channel_invite(channel_a_id, true)
+ .update(cx_b, |channels, cx| {
+ channels.respond_to_channel_invite(channel_a_id, true, cx)
})
.await
.unwrap();
@@ -153,7 +153,6 @@ async fn test_core_channels(
},
],
);
- dbg!("-------");
let channel_c_id = client_a
.channel_store()
@@ -289,11 +288,17 @@ async fn test_core_channels(
// Client B no longer has access to the channel
assert_channels(client_b.channel_store(), cx_b, &[]);
- // When disconnected, client A sees no channels.
server.forbid_connections();
server.disconnect_client(client_a.peer_id().unwrap());
deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
- assert_channels(client_a.channel_store(), cx_a, &[]);
+
+ client_b
+ .channel_store()
+ .update(cx_b, |channel_store, cx| {
+ channel_store.rename(channel_a_id, "channel-a-renamed", cx)
+ })
+ .await
+ .unwrap();
server.allow_connections();
deterministic.advance_clock(RECEIVE_TIMEOUT + RECONNECT_TIMEOUT);
@@ -302,7 +307,7 @@ async fn test_core_channels(
cx_a,
&[ExpectedChannel {
id: channel_a_id,
- name: "channel-a".to_string(),
+ name: "channel-a-renamed".to_string(),
depth: 0,
role: ChannelRole::Admin,
}],
@@ -886,8 +891,8 @@ async fn test_lost_channel_creation(
// Client B accepts the invite
client_b
.channel_store()
- .update(cx_b, |channel_store, _| {
- channel_store.respond_to_channel_invite(channel_id, true)
+ .update(cx_b, |channel_store, cx| {
+ channel_store.respond_to_channel_invite(channel_id, true, cx)
})
.await
.unwrap();
@@ -951,16 +956,16 @@ async fn test_channel_link_notifications(
client_b
.channel_store()
- .update(cx_b, |channel_store, _| {
- channel_store.respond_to_channel_invite(zed_channel, true)
+ .update(cx_b, |channel_store, cx| {
+ channel_store.respond_to_channel_invite(zed_channel, true, cx)
})
.await
.unwrap();
client_c
.channel_store()
- .update(cx_c, |channel_store, _| {
- channel_store.respond_to_channel_invite(zed_channel, true)
+ .update(cx_c, |channel_store, cx| {
+ channel_store.respond_to_channel_invite(zed_channel, true, cx)
})
.await
.unwrap();
@@ -1162,16 +1167,16 @@ async fn test_channel_membership_notifications(
client_b
.channel_store()
- .update(cx_b, |channel_store, _| {
- channel_store.respond_to_channel_invite(zed_channel, true)
+ .update(cx_b, |channel_store, cx| {
+ channel_store.respond_to_channel_invite(zed_channel, true, cx)
})
.await
.unwrap();
client_b
.channel_store()
- .update(cx_b, |channel_store, _| {
- channel_store.respond_to_channel_invite(vim_channel, true)
+ .update(cx_b, |channel_store, cx| {
+ channel_store.respond_to_channel_invite(vim_channel, true, cx)
})
.await
.unwrap();
@@ -1,6 +1,6 @@
use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
use call::ActiveCall;
-use collab_ui::project_shared_notification::ProjectSharedNotification;
+use collab_ui::notifications::project_shared_notification::ProjectSharedNotification;
use editor::{Editor, ExcerptRange, MultiBuffer};
use gpui::{executor::Deterministic, geometry::vector::vec2f, TestAppContext, ViewHandle};
use live_kit_client::MacOSDisplay;
@@ -15,8 +15,8 @@ use gpui::{executor::Deterministic, test::EmptyView, AppContext, ModelHandle, Te
use indoc::indoc;
use language::{
language_settings::{AllLanguageSettings, Formatter, InlayHintSettings},
- tree_sitter_rust, Anchor, BundledFormatter, Diagnostic, DiagnosticEntry, FakeLspAdapter,
- Language, LanguageConfig, LineEnding, OffsetRangeExt, Point, Rope,
+ tree_sitter_rust, Anchor, Diagnostic, DiagnosticEntry, FakeLspAdapter, Language,
+ LanguageConfig, LineEnding, OffsetRangeExt, Point, Rope,
};
use live_kit_client::MacOSDisplay;
use lsp::LanguageServerId;
@@ -4530,6 +4530,7 @@ async fn test_prettier_formatting_buffer(
LanguageConfig {
name: "Rust".into(),
path_suffixes: vec!["rs".to_string()],
+ prettier_parser_name: Some("test_parser".to_string()),
..Default::default()
},
Some(tree_sitter_rust::language()),
@@ -4537,10 +4538,7 @@ async fn test_prettier_formatting_buffer(
let test_plugin = "test_plugin";
let mut fake_language_servers = language
.set_fake_lsp_adapter(Arc::new(FakeLspAdapter {
- enabled_formatters: vec![BundledFormatter::Prettier {
- parser_name: Some("test_parser"),
- plugin_names: vec![test_plugin],
- }],
+ prettier_plugins: vec![test_plugin],
..Default::default()
}))
.await;
@@ -0,0 +1,159 @@
+use crate::tests::TestServer;
+use gpui::{executor::Deterministic, TestAppContext};
+use notifications::NotificationEvent;
+use parking_lot::Mutex;
+use rpc::{proto, Notification};
+use std::sync::Arc;
+
+#[gpui::test]
+async fn test_notifications(
+ deterministic: Arc<Deterministic>,
+ cx_a: &mut TestAppContext,
+ cx_b: &mut TestAppContext,
+) {
+ deterministic.forbid_parking();
+ let mut server = TestServer::start(&deterministic).await;
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+
+ let notification_events_a = Arc::new(Mutex::new(Vec::new()));
+ let notification_events_b = Arc::new(Mutex::new(Vec::new()));
+ client_a.notification_store().update(cx_a, |_, cx| {
+ let events = notification_events_a.clone();
+ cx.subscribe(&cx.handle(), move |_, _, event, _| {
+ events.lock().push(event.clone());
+ })
+ .detach()
+ });
+ client_b.notification_store().update(cx_b, |_, cx| {
+ let events = notification_events_b.clone();
+ cx.subscribe(&cx.handle(), move |_, _, event, _| {
+ events.lock().push(event.clone());
+ })
+ .detach()
+ });
+
+ // Client A sends a contact request to client B.
+ client_a
+ .user_store()
+ .update(cx_a, |store, cx| store.request_contact(client_b.id(), cx))
+ .await
+ .unwrap();
+
+ // Client B receives a contact request notification and responds to the
+ // request, accepting it.
+ deterministic.run_until_parked();
+ client_b.notification_store().update(cx_b, |store, cx| {
+ assert_eq!(store.notification_count(), 1);
+ assert_eq!(store.unread_notification_count(), 1);
+
+ let entry = store.notification_at(0).unwrap();
+ assert_eq!(
+ entry.notification,
+ Notification::ContactRequest {
+ sender_id: client_a.id()
+ }
+ );
+ assert!(!entry.is_read);
+ assert_eq!(
+ ¬ification_events_b.lock()[0..],
+ &[
+ NotificationEvent::NewNotification {
+ entry: entry.clone(),
+ },
+ NotificationEvent::NotificationsUpdated {
+ old_range: 0..0,
+ new_count: 1
+ }
+ ]
+ );
+
+ store.respond_to_notification(entry.notification.clone(), true, cx);
+ });
+
+ // Client B sees the notification is now read, and that they responded.
+ deterministic.run_until_parked();
+ client_b.notification_store().read_with(cx_b, |store, _| {
+ assert_eq!(store.notification_count(), 1);
+ assert_eq!(store.unread_notification_count(), 0);
+
+ let entry = store.notification_at(0).unwrap();
+ assert!(entry.is_read);
+ assert_eq!(entry.response, Some(true));
+ assert_eq!(
+ ¬ification_events_b.lock()[2..],
+ &[
+ NotificationEvent::NotificationRead {
+ entry: entry.clone(),
+ },
+ NotificationEvent::NotificationsUpdated {
+ old_range: 0..1,
+ new_count: 1
+ }
+ ]
+ );
+ });
+
+ // Client A receives a notification that client B accepted their request.
+ client_a.notification_store().read_with(cx_a, |store, _| {
+ assert_eq!(store.notification_count(), 1);
+ assert_eq!(store.unread_notification_count(), 1);
+
+ let entry = store.notification_at(0).unwrap();
+ assert_eq!(
+ entry.notification,
+ Notification::ContactRequestAccepted {
+ responder_id: client_b.id()
+ }
+ );
+ assert!(!entry.is_read);
+ });
+
+ // Client A creates a channel and invites client B to be a member.
+ let channel_id = client_a
+ .channel_store()
+ .update(cx_a, |store, cx| {
+ store.create_channel("the-channel", None, cx)
+ })
+ .await
+ .unwrap();
+ client_a
+ .channel_store()
+ .update(cx_a, |store, cx| {
+ store.invite_member(channel_id, client_b.id(), proto::ChannelRole::Member, cx)
+ })
+ .await
+ .unwrap();
+
+ // Client B receives a channel invitation notification and responds to the
+ // invitation, accepting it.
+ deterministic.run_until_parked();
+ client_b.notification_store().update(cx_b, |store, cx| {
+ assert_eq!(store.notification_count(), 2);
+ assert_eq!(store.unread_notification_count(), 1);
+
+ let entry = store.notification_at(0).unwrap();
+ assert_eq!(
+ entry.notification,
+ Notification::ChannelInvitation {
+ channel_id,
+ channel_name: "the-channel".to_string(),
+ inviter_id: client_a.id()
+ }
+ );
+ assert!(!entry.is_read);
+
+ store.respond_to_notification(entry.notification.clone(), true, cx);
+ });
+
+ // Client B sees the notification is now read, and that they responded.
+ deterministic.run_until_parked();
+ client_b.notification_store().read_with(cx_b, |store, _| {
+ assert_eq!(store.notification_count(), 2);
+ assert_eq!(store.unread_notification_count(), 0);
+
+ let entry = store.notification_at(0).unwrap();
+ assert!(entry.is_read);
+ assert_eq!(entry.response, Some(true));
+ });
+}
@@ -208,8 +208,7 @@ impl<T: RandomizedTest> TestPlan<T> {
false,
NewUserParams {
github_login: username.clone(),
- github_user_id: (ix + 1) as i32,
- invite_count: 0,
+ github_user_id: ix as i32,
},
)
.await
@@ -16,6 +16,7 @@ use futures::{channel::oneshot, StreamExt as _};
use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle};
use language::LanguageRegistry;
use node_runtime::FakeNodeRuntime;
+use notifications::NotificationStore;
use parking_lot::Mutex;
use project::{Project, WorktreeId};
use rpc::{proto::ChannelRole, RECEIVE_TIMEOUT};
@@ -46,6 +47,7 @@ pub struct TestClient {
pub username: String,
pub app_state: Arc<workspace::AppState>,
channel_store: ModelHandle<ChannelStore>,
+ notification_store: ModelHandle<NotificationStore>,
state: RefCell<TestClientState>,
}
@@ -138,7 +140,6 @@ impl TestServer {
NewUserParams {
github_login: name.into(),
github_user_id: 0,
- invite_count: 0,
},
)
.await
@@ -231,7 +232,8 @@ impl TestServer {
workspace::init(app_state.clone(), cx);
audio::init((), cx);
call::init(client.clone(), user_store.clone(), cx);
- channel::init(&client, user_store, cx);
+ channel::init(&client, user_store.clone(), cx);
+ notifications::init(client.clone(), user_store, cx);
});
client
@@ -243,6 +245,7 @@ impl TestServer {
app_state,
username: name.to_string(),
channel_store: cx.read(ChannelStore::global).clone(),
+ notification_store: cx.read(NotificationStore::global).clone(),
state: Default::default(),
};
client.wait_for_current_user(cx).await;
@@ -338,8 +341,8 @@ impl TestServer {
member_cx
.read(ChannelStore::global)
- .update(*member_cx, |channels, _| {
- channels.respond_to_channel_invite(channel_id, true)
+ .update(*member_cx, |channels, cx| {
+ channels.respond_to_channel_invite(channel_id, true, cx)
})
.await
.unwrap();
@@ -448,6 +451,10 @@ impl TestClient {
&self.channel_store
}
+ pub fn notification_store(&self) -> &ModelHandle<NotificationStore> {
+ &self.notification_store
+ }
+
pub fn user_store(&self) -> &ModelHandle<UserStore> {
&self.app_state.user_store
}
@@ -37,10 +37,12 @@ fuzzy = { path = "../fuzzy" }
gpui = { path = "../gpui" }
language = { path = "../language" }
menu = { path = "../menu" }
+notifications = { path = "../notifications" }
rich_text = { path = "../rich_text" }
picker = { path = "../picker" }
project = { path = "../project" }
-recent_projects = {path = "../recent_projects"}
+recent_projects = { path = "../recent_projects" }
+rpc = { path = "../rpc" }
settings = { path = "../settings" }
feature_flags = {path = "../feature_flags"}
theme = { path = "../theme" }
@@ -52,6 +54,7 @@ zed-actions = {path = "../zed-actions"}
anyhow.workspace = true
futures.workspace = true
+lazy_static.workspace = true
log.workspace = true
schemars.workspace = true
postage.workspace = true
@@ -66,7 +69,12 @@ client = { path = "../client", features = ["test-support"] }
collections = { path = "../collections", features = ["test-support"] }
editor = { path = "../editor", features = ["test-support"] }
gpui = { path = "../gpui", features = ["test-support"] }
+notifications = { path = "../notifications", features = ["test-support"] }
project = { path = "../project", features = ["test-support"] }
+rpc = { path = "../rpc", features = ["test-support"] }
settings = { path = "../settings", features = ["test-support"] }
util = { path = "../util", features = ["test-support"] }
workspace = { path = "../workspace", features = ["test-support"] }
+
+pretty_assertions.workspace = true
+tree-sitter-markdown.workspace = true
@@ -1,4 +1,6 @@
-use crate::{channel_view::ChannelView, ChatPanelSettings};
+use crate::{
+ channel_view::ChannelView, is_channels_feature_enabled, render_avatar, ChatPanelSettings,
+};
use anyhow::Result;
use call::ActiveCall;
use channel::{ChannelChat, ChannelChatEvent, ChannelMessageId, ChannelStore};
@@ -6,18 +8,18 @@ use client::Client;
use collections::HashMap;
use db::kvp::KEY_VALUE_STORE;
use editor::Editor;
-use feature_flags::{ChannelsAlpha, FeatureFlagAppExt};
use gpui::{
actions,
elements::*,
platform::{CursorStyle, MouseButton},
serde_json,
views::{ItemType, Select, SelectStyle},
- AnyViewHandle, AppContext, AsyncAppContext, Entity, ImageData, ModelHandle, Subscription, Task,
- View, ViewContext, ViewHandle, WeakViewHandle,
+ AnyViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Subscription, Task, View,
+ ViewContext, ViewHandle, WeakViewHandle,
};
-use language::{language_settings::SoftWrap, LanguageRegistry};
+use language::LanguageRegistry;
use menu::Confirm;
+use message_editor::MessageEditor;
use project::Fs;
use rich_text::RichText;
use serde::{Deserialize, Serialize};
@@ -31,6 +33,8 @@ use workspace::{
Workspace,
};
+mod message_editor;
+
const MESSAGE_LOADING_THRESHOLD: usize = 50;
const CHAT_PANEL_KEY: &'static str = "ChatPanel";
@@ -40,7 +44,7 @@ pub struct ChatPanel {
languages: Arc<LanguageRegistry>,
active_chat: Option<(ModelHandle<ChannelChat>, Subscription)>,
message_list: ListState<ChatPanel>,
- input_editor: ViewHandle<Editor>,
+ input_editor: ViewHandle<MessageEditor>,
channel_select: ViewHandle<Select>,
local_timezone: UtcOffset,
fs: Arc<dyn Fs>,
@@ -49,6 +53,7 @@ pub struct ChatPanel {
pending_serialization: Task<Option<()>>,
subscriptions: Vec<gpui::Subscription>,
workspace: WeakViewHandle<Workspace>,
+ is_scrolled_to_bottom: bool,
has_focus: bool,
markdown_data: HashMap<ChannelMessageId, RichText>,
}
@@ -85,13 +90,18 @@ impl ChatPanel {
let languages = workspace.app_state().languages.clone();
let input_editor = cx.add_view(|cx| {
- let mut editor = Editor::auto_height(
- 4,
- Some(Arc::new(|theme| theme.chat_panel.input_editor.clone())),
+ MessageEditor::new(
+ languages.clone(),
+ channel_store.clone(),
+ cx.add_view(|cx| {
+ Editor::auto_height(
+ 4,
+ Some(Arc::new(|theme| theme.chat_panel.input_editor.clone())),
+ cx,
+ )
+ }),
cx,
- );
- editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
- editor
+ )
});
let workspace_handle = workspace.weak_handle();
@@ -121,13 +131,14 @@ impl ChatPanel {
});
let mut message_list =
- ListState::<Self>::new(0, Orientation::Bottom, 1000., move |this, ix, cx| {
+ ListState::<Self>::new(0, Orientation::Bottom, 10., move |this, ix, cx| {
this.render_message(ix, cx)
});
- message_list.set_scroll_handler(|visible_range, this, cx| {
+ message_list.set_scroll_handler(|visible_range, count, this, cx| {
if visible_range.start < MESSAGE_LOADING_THRESHOLD {
this.load_more_messages(&LoadMoreMessages, cx);
}
+ this.is_scrolled_to_bottom = visible_range.end == count;
});
cx.add_view(|cx| {
@@ -136,7 +147,6 @@ impl ChatPanel {
client,
channel_store,
languages,
-
active_chat: Default::default(),
pending_serialization: Task::ready(None),
message_list,
@@ -146,6 +156,7 @@ impl ChatPanel {
has_focus: false,
subscriptions: Vec::new(),
workspace: workspace_handle,
+ is_scrolled_to_bottom: true,
active: false,
width: None,
markdown_data: Default::default(),
@@ -179,35 +190,20 @@ impl ChatPanel {
.channel_at(selected_ix)
.map(|e| e.id);
if let Some(selected_channel_id) = selected_channel_id {
- this.select_channel(selected_channel_id, cx)
+ this.select_channel(selected_channel_id, None, cx)
.detach_and_log_err(cx);
}
})
.detach();
- let markdown = this.languages.language_for_name("Markdown");
- cx.spawn(|this, mut cx| async move {
- let markdown = markdown.await?;
-
- this.update(&mut cx, |this, cx| {
- this.input_editor.update(cx, |editor, cx| {
- editor.buffer().update(cx, |multi_buffer, cx| {
- multi_buffer
- .as_singleton()
- .unwrap()
- .update(cx, |buffer, cx| buffer.set_language(Some(markdown), cx))
- })
- })
- })?;
-
- anyhow::Ok(())
- })
- .detach_and_log_err(cx);
-
this
})
}
+ pub fn is_scrolled_to_bottom(&self) -> bool {
+ self.is_scrolled_to_bottom
+ }
+
pub fn active_chat(&self) -> Option<ModelHandle<ChannelChat>> {
self.active_chat.as_ref().map(|(chat, _)| chat.clone())
}
@@ -267,24 +263,22 @@ impl ChatPanel {
fn set_active_chat(&mut self, chat: ModelHandle<ChannelChat>, cx: &mut ViewContext<Self>) {
if self.active_chat.as_ref().map(|e| &e.0) != Some(&chat) {
- let id = chat.read(cx).channel_id;
+ let channel_id = chat.read(cx).channel_id;
{
+ self.markdown_data.clear();
let chat = chat.read(cx);
self.message_list.reset(chat.message_count());
- let placeholder = if let Some(channel) = chat.channel(cx) {
- format!("Message #{}", channel.name)
- } else {
- "Message Channel".to_string()
- };
- self.input_editor.update(cx, move |editor, cx| {
- editor.set_placeholder_text(placeholder, cx);
+
+ let channel_name = chat.channel(cx).map(|channel| channel.name.clone());
+ self.input_editor.update(cx, |editor, cx| {
+ editor.set_channel(channel_id, channel_name, cx);
});
- }
+ };
let subscription = cx.subscribe(&chat, Self::channel_did_change);
self.active_chat = Some((chat, subscription));
self.acknowledge_last_message(cx);
self.channel_select.update(cx, |select, cx| {
- if let Some(ix) = self.channel_store.read(cx).index_of_channel(id) {
+ if let Some(ix) = self.channel_store.read(cx).index_of_channel(channel_id) {
select.set_selected_index(ix, cx);
}
});
@@ -323,7 +317,7 @@ impl ChatPanel {
}
fn acknowledge_last_message(&mut self, cx: &mut ViewContext<'_, '_, ChatPanel>) {
- if self.active {
+ if self.active && self.is_scrolled_to_bottom {
if let Some((chat, _)) = &self.active_chat {
chat.update(cx, |chat, cx| {
chat.acknowledge_last_message(cx);
@@ -359,33 +353,48 @@ impl ChatPanel {
}
fn render_message(&mut self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
- let (message, is_continuation, is_last, is_admin) = {
- let active_chat = self.active_chat.as_ref().unwrap().0.read(cx);
- let is_admin = self
- .channel_store
- .read(cx)
- .is_channel_admin(active_chat.channel_id);
- let last_message = active_chat.message(ix.saturating_sub(1));
- let this_message = active_chat.message(ix);
- let is_continuation = last_message.id != this_message.id
- && this_message.sender.id == last_message.sender.id;
-
- (
- active_chat.message(ix).clone(),
- is_continuation,
- active_chat.message_count() == ix + 1,
- is_admin,
- )
- };
+ let (message, is_continuation, is_last, is_admin) = self
+ .active_chat
+ .as_ref()
+ .unwrap()
+ .0
+ .update(cx, |active_chat, cx| {
+ let is_admin = self
+ .channel_store
+ .read(cx)
+ .is_channel_admin(active_chat.channel_id);
+
+ let last_message = active_chat.message(ix.saturating_sub(1));
+ let this_message = active_chat.message(ix).clone();
+ let is_continuation = last_message.id != this_message.id
+ && this_message.sender.id == last_message.sender.id;
+
+ if let ChannelMessageId::Saved(id) = this_message.id {
+ if this_message
+ .mentions
+ .iter()
+ .any(|(_, user_id)| Some(*user_id) == self.client.user_id())
+ {
+ active_chat.acknowledge_message(id);
+ }
+ }
+
+ (
+ this_message,
+ is_continuation,
+ active_chat.message_count() == ix + 1,
+ is_admin,
+ )
+ });
let is_pending = message.is_pending();
- let text = self
- .markdown_data
- .entry(message.id)
- .or_insert_with(|| rich_text::render_markdown(message.body, &self.languages, None));
+ let theme = theme::current(cx);
+ let text = self.markdown_data.entry(message.id).or_insert_with(|| {
+ Self::render_markdown_with_mentions(&self.languages, self.client.id(), &message)
+ });
let now = OffsetDateTime::now_utc();
- let theme = theme::current(cx);
+
let style = if is_pending {
&theme.chat_panel.pending_message
} else if is_continuation {
@@ -405,14 +414,13 @@ impl ChatPanel {
enum MessageBackgroundHighlight {}
MouseEventHandler::new::<MessageBackgroundHighlight, _>(ix, cx, |state, cx| {
- let container = style.container.style_for(state);
+ let container = style.style_for(state);
if is_continuation {
Flex::row()
.with_child(
text.element(
theme.editor.syntax.clone(),
- style.body.clone(),
- theme.editor.document_highlight_read_background,
+ theme.chat_panel.rich_text.clone(),
cx,
)
.flex(1., true),
@@ -434,15 +442,16 @@ impl ChatPanel {
Flex::row()
.with_child(render_avatar(
message.sender.avatar.clone(),
- &theme,
+ &theme.chat_panel.avatar,
+ theme.chat_panel.avatar_container,
))
.with_child(
Label::new(
message.sender.github_login.clone(),
- style.sender.text.clone(),
+ theme.chat_panel.message_sender.text.clone(),
)
.contained()
- .with_style(style.sender.container),
+ .with_style(theme.chat_panel.message_sender.container),
)
.with_child(
Label::new(
@@ -451,10 +460,10 @@ impl ChatPanel {
now,
self.local_timezone,
),
- style.timestamp.text.clone(),
+ theme.chat_panel.message_timestamp.text.clone(),
)
.contained()
- .with_style(style.timestamp.container),
+ .with_style(theme.chat_panel.message_timestamp.container),
)
.align_children_center()
.flex(1., true),
@@ -467,8 +476,7 @@ impl ChatPanel {
.with_child(
text.element(
theme.editor.syntax.clone(),
- style.body.clone(),
- theme.editor.document_highlight_read_background,
+ theme.chat_panel.rich_text.clone(),
cx,
)
.flex(1., true),
@@ -489,6 +497,23 @@ impl ChatPanel {
.into_any()
}
+ fn render_markdown_with_mentions(
+ language_registry: &Arc<LanguageRegistry>,
+ current_user_id: u64,
+ message: &channel::ChannelMessage,
+ ) -> RichText {
+ let mentions = message
+ .mentions
+ .iter()
+ .map(|(range, user_id)| rich_text::Mention {
+ range: range.clone(),
+ is_self_mention: *user_id == current_user_id,
+ })
+ .collect::<Vec<_>>();
+
+ rich_text::render_markdown(message.body.clone(), &mentions, language_registry, None)
+ }
+
fn render_input_box(&self, theme: &Arc<Theme>, cx: &AppContext) -> AnyElement<Self> {
ChildView::new(&self.input_editor, cx)
.contained()
@@ -614,14 +639,12 @@ impl ChatPanel {
fn send(&mut self, _: &Confirm, cx: &mut ViewContext<Self>) {
if let Some((chat, _)) = self.active_chat.as_ref() {
- let body = self.input_editor.update(cx, |editor, cx| {
- let body = editor.text(cx);
- editor.clear(cx);
- body
- });
+ let message = self
+ .input_editor
+ .update(cx, |editor, cx| editor.take_message(cx));
if let Some(task) = chat
- .update(cx, |chat, cx| chat.send_message(body, cx))
+ .update(cx, |chat, cx| chat.send_message(message, cx))
.log_err()
{
task.detach();
@@ -638,7 +661,9 @@ impl ChatPanel {
fn load_more_messages(&mut self, _: &LoadMoreMessages, cx: &mut ViewContext<Self>) {
if let Some((chat, _)) = self.active_chat.as_ref() {
chat.update(cx, |channel, cx| {
- channel.load_more_messages(cx);
+ if let Some(task) = channel.load_more_messages(cx) {
+ task.detach();
+ }
})
}
}
@@ -646,23 +671,46 @@ impl ChatPanel {
pub fn select_channel(
&mut self,
selected_channel_id: u64,
+ scroll_to_message_id: Option<u64>,
cx: &mut ViewContext<ChatPanel>,
) -> Task<Result<()>> {
- if let Some((chat, _)) = &self.active_chat {
- if chat.read(cx).channel_id == selected_channel_id {
- return Task::ready(Ok(()));
- }
- }
+ let open_chat = self
+ .active_chat
+ .as_ref()
+ .and_then(|(chat, _)| {
+ (chat.read(cx).channel_id == selected_channel_id)
+ .then(|| Task::ready(anyhow::Ok(chat.clone())))
+ })
+ .unwrap_or_else(|| {
+ self.channel_store.update(cx, |store, cx| {
+ store.open_channel_chat(selected_channel_id, cx)
+ })
+ });
- let open_chat = self.channel_store.update(cx, |store, cx| {
- store.open_channel_chat(selected_channel_id, cx)
- });
cx.spawn(|this, mut cx| async move {
let chat = open_chat.await?;
this.update(&mut cx, |this, cx| {
- this.markdown_data = Default::default();
- this.set_active_chat(chat, cx);
- })
+ this.set_active_chat(chat.clone(), cx);
+ })?;
+
+ if let Some(message_id) = scroll_to_message_id {
+ if let Some(item_ix) =
+ ChannelChat::load_history_since_message(chat.clone(), message_id, cx.clone())
+ .await
+ {
+ this.update(&mut cx, |this, cx| {
+ if this.active_chat.as_ref().map_or(false, |(c, _)| *c == chat) {
+ this.message_list.scroll_to(ListOffset {
+ item_ix,
+ offset_in_item: 0.,
+ });
+ cx.notify();
+ }
+ })?;
+ }
+ }
+
+ Ok(())
})
}
@@ -685,32 +733,6 @@ impl ChatPanel {
}
}
-fn render_avatar(avatar: Option<Arc<ImageData>>, theme: &Arc<Theme>) -> AnyElement<ChatPanel> {
- let avatar_style = theme.chat_panel.avatar;
-
- avatar
- .map(|avatar| {
- Image::from_data(avatar)
- .with_style(avatar_style.image)
- .aligned()
- .contained()
- .with_corner_radius(avatar_style.outer_corner_radius)
- .constrained()
- .with_width(avatar_style.outer_width)
- .with_height(avatar_style.outer_width)
- .into_any()
- })
- .unwrap_or_else(|| {
- Empty::new()
- .constrained()
- .with_width(avatar_style.outer_width)
- .into_any()
- })
- .contained()
- .with_style(theme.chat_panel.avatar_container)
- .into_any()
-}
-
fn render_remove(
message_id_to_remove: Option<u64>,
cx: &mut ViewContext<'_, '_, ChatPanel>,
@@ -781,7 +803,8 @@ impl View for ChatPanel {
*self.client.status().borrow(),
client::Status::Connected { .. }
) {
- cx.focus(&self.input_editor);
+ let editor = self.input_editor.read(cx).editor.clone();
+ cx.focus(&editor);
}
}
@@ -820,14 +843,14 @@ impl Panel for ChatPanel {
self.active = active;
if active {
self.acknowledge_last_message(cx);
- if !is_chat_feature_enabled(cx) {
+ if !is_channels_feature_enabled(cx) {
cx.emit(Event::Dismissed);
}
}
}
fn icon_path(&self, cx: &gpui::WindowContext) -> Option<&'static str> {
- (settings::get::<ChatPanelSettings>(cx).button && is_chat_feature_enabled(cx))
+ (settings::get::<ChatPanelSettings>(cx).button && is_channels_feature_enabled(cx))
.then(|| "icons/conversations.svg")
}
@@ -852,10 +875,6 @@ impl Panel for ChatPanel {
}
}
-fn is_chat_feature_enabled(cx: &gpui::WindowContext<'_>) -> bool {
- cx.is_staff() || cx.has_flag::<ChannelsAlpha>()
-}
-
fn format_timestamp(
mut timestamp: OffsetDateTime,
mut now: OffsetDateTime,
@@ -893,3 +912,72 @@ fn render_icon_button<V: View>(style: &IconButton, svg_path: &'static str) -> im
.contained()
.with_style(style.container)
}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use gpui::fonts::HighlightStyle;
+ use pretty_assertions::assert_eq;
+ use rich_text::{BackgroundKind, Highlight, RenderedRegion};
+ use util::test::marked_text_ranges;
+
+ #[gpui::test]
+ fn test_render_markdown_with_mentions() {
+ let language_registry = Arc::new(LanguageRegistry::test());
+ let (body, ranges) = marked_text_ranges("*hi*, Β«@abcΒ», let's **call** Β«@fghΒ»", false);
+ let message = channel::ChannelMessage {
+ id: ChannelMessageId::Saved(0),
+ body,
+ timestamp: OffsetDateTime::now_utc(),
+ sender: Arc::new(client::User {
+ github_login: "fgh".into(),
+ avatar: None,
+ id: 103,
+ }),
+ nonce: 5,
+ mentions: vec![(ranges[0].clone(), 101), (ranges[1].clone(), 102)],
+ };
+
+ let message = ChatPanel::render_markdown_with_mentions(&language_registry, 102, &message);
+
+ // Note that the "'" was replaced with β due to smart punctuation.
+ let (body, ranges) = marked_text_ranges("Β«hiΒ», Β«@abcΒ», letβs Β«callΒ» Β«@fghΒ»", false);
+ assert_eq!(message.text, body);
+ assert_eq!(
+ message.highlights,
+ vec![
+ (
+ ranges[0].clone(),
+ HighlightStyle {
+ italic: Some(true),
+ ..Default::default()
+ }
+ .into()
+ ),
+ (ranges[1].clone(), Highlight::Mention),
+ (
+ ranges[2].clone(),
+ HighlightStyle {
+ weight: Some(gpui::fonts::Weight::BOLD),
+ ..Default::default()
+ }
+ .into()
+ ),
+ (ranges[3].clone(), Highlight::SelfMention)
+ ]
+ );
+ assert_eq!(
+ message.regions,
+ vec![
+ RenderedRegion {
+ background_kind: Some(BackgroundKind::Mention),
+ link_url: None
+ },
+ RenderedRegion {
+ background_kind: Some(BackgroundKind::SelfMention),
+ link_url: None
+ },
+ ]
+ );
+ }
+}
@@ -0,0 +1,313 @@
+use channel::{ChannelId, ChannelMembership, ChannelStore, MessageParams};
+use client::UserId;
+use collections::HashMap;
+use editor::{AnchorRangeExt, Editor};
+use gpui::{
+ elements::ChildView, AnyElement, AsyncAppContext, Element, Entity, ModelHandle, Task, View,
+ ViewContext, ViewHandle, WeakViewHandle,
+};
+use language::{language_settings::SoftWrap, Buffer, BufferSnapshot, LanguageRegistry};
+use lazy_static::lazy_static;
+use project::search::SearchQuery;
+use std::{sync::Arc, time::Duration};
+
+const MENTIONS_DEBOUNCE_INTERVAL: Duration = Duration::from_millis(50);
+
+lazy_static! {
+ static ref MENTIONS_SEARCH: SearchQuery = SearchQuery::regex(
+ "@[-_\\w]+",
+ false,
+ false,
+ Default::default(),
+ Default::default()
+ )
+ .unwrap();
+}
+
+pub struct MessageEditor {
+ pub editor: ViewHandle<Editor>,
+ channel_store: ModelHandle<ChannelStore>,
+ users: HashMap<String, UserId>,
+ mentions: Vec<UserId>,
+ mentions_task: Option<Task<()>>,
+ channel_id: Option<ChannelId>,
+}
+
+impl MessageEditor {
+ pub fn new(
+ language_registry: Arc<LanguageRegistry>,
+ channel_store: ModelHandle<ChannelStore>,
+ editor: ViewHandle<Editor>,
+ cx: &mut ViewContext<Self>,
+ ) -> Self {
+ editor.update(cx, |editor, cx| {
+ editor.set_soft_wrap_mode(SoftWrap::EditorWidth, cx);
+ });
+
+ let buffer = editor
+ .read(cx)
+ .buffer()
+ .read(cx)
+ .as_singleton()
+ .expect("message editor must be singleton");
+
+ cx.subscribe(&buffer, Self::on_buffer_event).detach();
+
+ let markdown = language_registry.language_for_name("Markdown");
+ cx.app_context()
+ .spawn(|mut cx| async move {
+ let markdown = markdown.await?;
+ buffer.update(&mut cx, |buffer, cx| {
+ buffer.set_language(Some(markdown), cx)
+ });
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+
+ Self {
+ editor,
+ channel_store,
+ users: HashMap::default(),
+ channel_id: None,
+ mentions: Vec::new(),
+ mentions_task: None,
+ }
+ }
+
+ pub fn set_channel(
+ &mut self,
+ channel_id: u64,
+ channel_name: Option<String>,
+ cx: &mut ViewContext<Self>,
+ ) {
+ self.editor.update(cx, |editor, cx| {
+ if let Some(channel_name) = channel_name {
+ editor.set_placeholder_text(format!("Message #{}", channel_name), cx);
+ } else {
+ editor.set_placeholder_text(format!("Message Channel"), cx);
+ }
+ });
+ self.channel_id = Some(channel_id);
+ self.refresh_users(cx);
+ }
+
+ pub fn refresh_users(&mut self, cx: &mut ViewContext<Self>) {
+ if let Some(channel_id) = self.channel_id {
+ let members = self.channel_store.update(cx, |store, cx| {
+ store.get_channel_member_details(channel_id, cx)
+ });
+ cx.spawn(|this, mut cx| async move {
+ let members = members.await?;
+ this.update(&mut cx, |this, cx| this.set_members(members, cx))?;
+ anyhow::Ok(())
+ })
+ .detach_and_log_err(cx);
+ }
+ }
+
+ pub fn set_members(&mut self, members: Vec<ChannelMembership>, _: &mut ViewContext<Self>) {
+ self.users.clear();
+ self.users.extend(
+ members
+ .into_iter()
+ .map(|member| (member.user.github_login.clone(), member.user.id)),
+ );
+ }
+
+ pub fn take_message(&mut self, cx: &mut ViewContext<Self>) -> MessageParams {
+ self.editor.update(cx, |editor, cx| {
+ let highlights = editor.text_highlights::<Self>(cx);
+ let text = editor.text(cx);
+ let snapshot = editor.buffer().read(cx).snapshot(cx);
+ let mentions = if let Some((_, ranges)) = highlights {
+ ranges
+ .iter()
+ .map(|range| range.to_offset(&snapshot))
+ .zip(self.mentions.iter().copied())
+ .collect()
+ } else {
+ Vec::new()
+ };
+
+ editor.clear(cx);
+ self.mentions.clear();
+
+ MessageParams { text, mentions }
+ })
+ }
+
+ fn on_buffer_event(
+ &mut self,
+ buffer: ModelHandle<Buffer>,
+ event: &language::Event,
+ cx: &mut ViewContext<Self>,
+ ) {
+ if let language::Event::Reparsed | language::Event::Edited = event {
+ let buffer = buffer.read(cx).snapshot();
+ self.mentions_task = Some(cx.spawn(|this, cx| async move {
+ cx.background().timer(MENTIONS_DEBOUNCE_INTERVAL).await;
+ Self::find_mentions(this, buffer, cx).await;
+ }));
+ }
+ }
+
+ async fn find_mentions(
+ this: WeakViewHandle<MessageEditor>,
+ buffer: BufferSnapshot,
+ mut cx: AsyncAppContext,
+ ) {
+ let (buffer, ranges) = cx
+ .background()
+ .spawn(async move {
+ let ranges = MENTIONS_SEARCH.search(&buffer, None).await;
+ (buffer, ranges)
+ })
+ .await;
+
+ this.update(&mut cx, |this, cx| {
+ let mut anchor_ranges = Vec::new();
+ let mut mentioned_user_ids = Vec::new();
+ let mut text = String::new();
+
+ this.editor.update(cx, |editor, cx| {
+ let multi_buffer = editor.buffer().read(cx).snapshot(cx);
+ for range in ranges {
+ text.clear();
+ text.extend(buffer.text_for_range(range.clone()));
+ if let Some(username) = text.strip_prefix("@") {
+ if let Some(user_id) = this.users.get(username) {
+ let start = multi_buffer.anchor_after(range.start);
+ let end = multi_buffer.anchor_after(range.end);
+
+ mentioned_user_ids.push(*user_id);
+ anchor_ranges.push(start..end);
+ }
+ }
+ }
+
+ editor.clear_highlights::<Self>(cx);
+ editor.highlight_text::<Self>(
+ anchor_ranges,
+ theme::current(cx).chat_panel.rich_text.mention_highlight,
+ cx,
+ )
+ });
+
+ this.mentions = mentioned_user_ids;
+ this.mentions_task.take();
+ })
+ .ok();
+ }
+}
+
+impl Entity for MessageEditor {
+ type Event = ();
+}
+
+impl View for MessageEditor {
+ fn render(&mut self, cx: &mut ViewContext<'_, '_, Self>) -> AnyElement<Self> {
+ ChildView::new(&self.editor, cx).into_any()
+ }
+
+ fn focus_in(&mut self, _: gpui::AnyViewHandle, cx: &mut ViewContext<Self>) {
+ if cx.is_self_focused() {
+ cx.focus(&self.editor);
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use client::{Client, User, UserStore};
+ use gpui::{TestAppContext, WindowHandle};
+ use language::{Language, LanguageConfig};
+ use rpc::proto;
+ use settings::SettingsStore;
+ use util::{http::FakeHttpClient, test::marked_text_ranges};
+
+ #[gpui::test]
+ async fn test_message_editor(cx: &mut TestAppContext) {
+ let editor = init_test(cx);
+ let editor = editor.root(cx);
+
+ editor.update(cx, |editor, cx| {
+ editor.set_members(
+ vec![
+ ChannelMembership {
+ user: Arc::new(User {
+ github_login: "a-b".into(),
+ id: 101,
+ avatar: None,
+ }),
+ kind: proto::channel_member::Kind::Member,
+ role: proto::ChannelRole::Member,
+ },
+ ChannelMembership {
+ user: Arc::new(User {
+ github_login: "C_D".into(),
+ id: 102,
+ avatar: None,
+ }),
+ kind: proto::channel_member::Kind::Member,
+ role: proto::ChannelRole::Member,
+ },
+ ],
+ cx,
+ );
+
+ editor.editor.update(cx, |editor, cx| {
+ editor.set_text("Hello, @a-b! Have you met @C_D?", cx)
+ });
+ });
+
+ cx.foreground().advance_clock(MENTIONS_DEBOUNCE_INTERVAL);
+
+ editor.update(cx, |editor, cx| {
+ let (text, ranges) = marked_text_ranges("Hello, Β«@a-bΒ»! Have you met Β«@C_DΒ»?", false);
+ assert_eq!(
+ editor.take_message(cx),
+ MessageParams {
+ text,
+ mentions: vec![(ranges[0].clone(), 101), (ranges[1].clone(), 102)],
+ }
+ );
+ });
+ }
+
+ fn init_test(cx: &mut TestAppContext) -> WindowHandle<MessageEditor> {
+ cx.foreground().forbid_parking();
+
+ cx.update(|cx| {
+ let http = FakeHttpClient::with_404_response();
+ let client = Client::new(http.clone(), cx);
+ let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx));
+ cx.set_global(SettingsStore::test(cx));
+ theme::init((), cx);
+ language::init(cx);
+ editor::init(cx);
+ client::init(&client, cx);
+ channel::init(&client, user_store, cx);
+ });
+
+ let language_registry = Arc::new(LanguageRegistry::test());
+ language_registry.add(Arc::new(Language::new(
+ LanguageConfig {
+ name: "Markdown".into(),
+ ..Default::default()
+ },
+ Some(tree_sitter_markdown::language()),
+ )));
+
+ let editor = cx.add_window(|cx| {
+ MessageEditor::new(
+ language_registry,
+ ChannelStore::global(cx),
+ cx.add_view(|cx| Editor::auto_height(4, None, cx)),
+ cx,
+ )
+ });
+ cx.foreground().run_until_parked();
+ editor
+ }
+}
@@ -3220,10 +3220,11 @@ impl CollabPanel {
accept: bool,
cx: &mut ViewContext<Self>,
) {
- let respond = self.channel_store.update(cx, |store, _| {
- store.respond_to_channel_invite(channel_id, accept)
- });
- cx.foreground().spawn(respond).detach();
+ self.channel_store
+ .update(cx, |store, cx| {
+ store.respond_to_channel_invite(channel_id, accept, cx)
+ })
+ .detach();
}
fn call(
@@ -3262,7 +3263,9 @@ impl CollabPanel {
workspace.update(cx, |workspace, cx| {
if let Some(panel) = workspace.focus_panel::<ChatPanel>(cx) {
panel.update(cx, |panel, cx| {
- panel.select_channel(channel_id, cx).detach_and_log_err(cx);
+ panel
+ .select_channel(channel_id, None, cx)
+ .detach_and_log_err(cx);
});
}
});
@@ -1,10 +1,10 @@
use crate::{
- contact_notification::ContactNotification, face_pile::FacePile, toggle_deafen, toggle_mute,
- toggle_screen_sharing, LeaveCall, ToggleDeafen, ToggleMute, ToggleScreenSharing,
+ face_pile::FacePile, toggle_deafen, toggle_mute, toggle_screen_sharing, LeaveCall,
+ ToggleDeafen, ToggleMute, ToggleScreenSharing,
};
use auto_update::AutoUpdateStatus;
use call::{ActiveCall, ParticipantLocation, Room};
-use client::{proto::PeerId, Client, ContactEventKind, SignIn, SignOut, User, UserStore};
+use client::{proto::PeerId, Client, SignIn, SignOut, User, UserStore};
use clock::ReplicaId;
use context_menu::{ContextMenu, ContextMenuItem};
use gpui::{
@@ -158,28 +158,6 @@ impl CollabTitlebarItem {
this.window_activation_changed(active, cx)
}));
subscriptions.push(cx.observe(&user_store, |_, _, cx| cx.notify()));
- subscriptions.push(
- cx.subscribe(&user_store, move |this, user_store, event, cx| {
- if let Some(workspace) = this.workspace.upgrade(cx) {
- workspace.update(cx, |workspace, cx| {
- if let client::Event::Contact { user, kind } = event {
- if let ContactEventKind::Requested | ContactEventKind::Accepted = kind {
- workspace.show_notification(user.id as usize, cx, |cx| {
- cx.add_view(|cx| {
- ContactNotification::new(
- user.clone(),
- *kind,
- user_store,
- cx,
- )
- })
- })
- }
- }
- });
- }
- }),
- );
Self {
workspace: workspace.weak_handle(),
@@ -495,7 +473,11 @@ impl CollabTitlebarItem {
pub fn toggle_vcs_menu(&mut self, _: &ToggleVcsMenu, cx: &mut ViewContext<Self>) {
if self.branch_popover.take().is_none() {
if let Some(workspace) = self.workspace.upgrade(cx) {
- let view = cx.add_view(|cx| build_branch_list(workspace, cx));
+ let Some(view) =
+ cx.add_option_view(|cx| build_branch_list(workspace, cx).log_err())
+ else {
+ return;
+ };
cx.subscribe(&view, |this, _, event, cx| {
match event {
PickerEvent::Dismiss => {
@@ -2,30 +2,32 @@ pub mod channel_view;
pub mod chat_panel;
pub mod collab_panel;
mod collab_titlebar_item;
-mod contact_notification;
mod face_pile;
-mod incoming_call_notification;
-mod notifications;
+pub mod notification_panel;
+pub mod notifications;
mod panel_settings;
-pub mod project_shared_notification;
-mod sharing_status_indicator;
use call::{report_call_event_for_room, ActiveCall, Room};
+use feature_flags::{ChannelsAlpha, FeatureFlagAppExt};
use gpui::{
actions,
+ elements::{ContainerStyle, Empty, Image},
geometry::{
rect::RectF,
vector::{vec2f, Vector2F},
},
platform::{Screen, WindowBounds, WindowKind, WindowOptions},
- AppContext, Task,
+ AnyElement, AppContext, Element, ImageData, Task,
};
use std::{rc::Rc, sync::Arc};
+use theme::AvatarStyle;
use util::ResultExt;
use workspace::AppState;
pub use collab_titlebar_item::CollabTitlebarItem;
-pub use panel_settings::{ChatPanelSettings, CollaborationPanelSettings};
+pub use panel_settings::{
+ ChatPanelSettings, CollaborationPanelSettings, NotificationPanelSettings,
+};
actions!(
collab,
@@ -35,14 +37,13 @@ actions!(
pub fn init(app_state: &Arc<AppState>, cx: &mut AppContext) {
settings::register::<CollaborationPanelSettings>(cx);
settings::register::<ChatPanelSettings>(cx);
+ settings::register::<NotificationPanelSettings>(cx);
vcs_menu::init(cx);
collab_titlebar_item::init(cx);
collab_panel::init(cx);
chat_panel::init(cx);
- incoming_call_notification::init(&app_state, cx);
- project_shared_notification::init(&app_state, cx);
- sharing_status_indicator::init(cx);
+ notifications::init(&app_state, cx);
cx.add_global_action(toggle_screen_sharing);
cx.add_global_action(toggle_mute);
@@ -130,3 +131,35 @@ fn notification_window_options(
screen: Some(screen),
}
}
+
+fn render_avatar<T: 'static>(
+ avatar: Option<Arc<ImageData>>,
+ avatar_style: &AvatarStyle,
+ container: ContainerStyle,
+) -> AnyElement<T> {
+ avatar
+ .map(|avatar| {
+ Image::from_data(avatar)
+ .with_style(avatar_style.image)
+ .aligned()
+ .contained()
+ .with_corner_radius(avatar_style.outer_corner_radius)
+ .constrained()
+ .with_width(avatar_style.outer_width)
+ .with_height(avatar_style.outer_width)
+ .into_any()
+ })
+ .unwrap_or_else(|| {
+ Empty::new()
+ .constrained()
+ .with_width(avatar_style.outer_width)
+ .into_any()
+ })
+ .contained()
+ .with_style(container)
+ .into_any()
+}
+
+fn is_channels_feature_enabled(cx: &gpui::WindowContext<'_>) -> bool {
+ cx.is_staff() || cx.has_flag::<ChannelsAlpha>()
+}
@@ -1,121 +0,0 @@
-use std::sync::Arc;
-
-use crate::notifications::render_user_notification;
-use client::{ContactEventKind, User, UserStore};
-use gpui::{elements::*, Entity, ModelHandle, View, ViewContext};
-use workspace::notifications::Notification;
-
-pub struct ContactNotification {
- user_store: ModelHandle<UserStore>,
- user: Arc<User>,
- kind: client::ContactEventKind,
-}
-
-#[derive(Clone, PartialEq)]
-struct Dismiss(u64);
-
-#[derive(Clone, PartialEq)]
-pub struct RespondToContactRequest {
- pub user_id: u64,
- pub accept: bool,
-}
-
-pub enum Event {
- Dismiss,
-}
-
-impl Entity for ContactNotification {
- type Event = Event;
-}
-
-impl View for ContactNotification {
- fn ui_name() -> &'static str {
- "ContactNotification"
- }
-
- fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
- match self.kind {
- ContactEventKind::Requested => render_user_notification(
- self.user.clone(),
- "wants to add you as a contact",
- Some("They won't be alerted if you decline."),
- |notification, cx| notification.dismiss(cx),
- vec![
- (
- "Decline",
- Box::new(|notification, cx| {
- notification.respond_to_contact_request(false, cx)
- }),
- ),
- (
- "Accept",
- Box::new(|notification, cx| {
- notification.respond_to_contact_request(true, cx)
- }),
- ),
- ],
- cx,
- ),
- ContactEventKind::Accepted => render_user_notification(
- self.user.clone(),
- "accepted your contact request",
- None,
- |notification, cx| notification.dismiss(cx),
- vec![],
- cx,
- ),
- _ => unreachable!(),
- }
- }
-}
-
-impl Notification for ContactNotification {
- fn should_dismiss_notification_on_event(&self, event: &<Self as Entity>::Event) -> bool {
- matches!(event, Event::Dismiss)
- }
-}
-
-impl ContactNotification {
- pub fn new(
- user: Arc<User>,
- kind: client::ContactEventKind,
- user_store: ModelHandle<UserStore>,
- cx: &mut ViewContext<Self>,
- ) -> Self {
- cx.subscribe(&user_store, move |this, _, event, cx| {
- if let client::Event::Contact {
- kind: ContactEventKind::Cancelled,
- user,
- } = event
- {
- if user.id == this.user.id {
- cx.emit(Event::Dismiss);
- }
- }
- })
- .detach();
-
- Self {
- user,
- kind,
- user_store,
- }
- }
-
- fn dismiss(&mut self, cx: &mut ViewContext<Self>) {
- self.user_store.update(cx, |store, cx| {
- store
- .dismiss_contact_request(self.user.id, cx)
- .detach_and_log_err(cx);
- });
- cx.emit(Event::Dismiss);
- }
-
- fn respond_to_contact_request(&mut self, accept: bool, cx: &mut ViewContext<Self>) {
- self.user_store
- .update(cx, |store, cx| {
- store.respond_to_contact_request(self.user.id, accept, cx)
- })
- .detach();
- }
-}
@@ -0,0 +1,884 @@
+use crate::{chat_panel::ChatPanel, render_avatar, NotificationPanelSettings};
+use anyhow::Result;
+use channel::ChannelStore;
+use client::{Client, Notification, User, UserStore};
+use collections::HashMap;
+use db::kvp::KEY_VALUE_STORE;
+use futures::StreamExt;
+use gpui::{
+ actions,
+ elements::*,
+ platform::{CursorStyle, MouseButton},
+ serde_json, AnyViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Task, View,
+ ViewContext, ViewHandle, WeakViewHandle, WindowContext,
+};
+use notifications::{NotificationEntry, NotificationEvent, NotificationStore};
+use project::Fs;
+use rpc::proto;
+use serde::{Deserialize, Serialize};
+use settings::SettingsStore;
+use std::{sync::Arc, time::Duration};
+use theme::{ui, Theme};
+use time::{OffsetDateTime, UtcOffset};
+use util::{ResultExt, TryFutureExt};
+use workspace::{
+ dock::{DockPosition, Panel},
+ Workspace,
+};
+
+const LOADING_THRESHOLD: usize = 30;
+const MARK_AS_READ_DELAY: Duration = Duration::from_secs(1);
+const TOAST_DURATION: Duration = Duration::from_secs(5);
+const NOTIFICATION_PANEL_KEY: &'static str = "NotificationPanel";
+
+pub struct NotificationPanel {
+ client: Arc<Client>,
+ user_store: ModelHandle<UserStore>,
+ channel_store: ModelHandle<ChannelStore>,
+ notification_store: ModelHandle<NotificationStore>,
+ fs: Arc<dyn Fs>,
+ width: Option<f32>,
+ active: bool,
+ notification_list: ListState<Self>,
+ pending_serialization: Task<Option<()>>,
+ subscriptions: Vec<gpui::Subscription>,
+ workspace: WeakViewHandle<Workspace>,
+ current_notification_toast: Option<(u64, Task<()>)>,
+ local_timezone: UtcOffset,
+ has_focus: bool,
+ mark_as_read_tasks: HashMap<u64, Task<Result<()>>>,
+}
+
+#[derive(Serialize, Deserialize)]
+struct SerializedNotificationPanel {
+ width: Option<f32>,
+}
+
+#[derive(Debug)]
+pub enum Event {
+ DockPositionChanged,
+ Focus,
+ Dismissed,
+}
+
+pub struct NotificationPresenter {
+ pub actor: Option<Arc<client::User>>,
+ pub text: String,
+ pub icon: &'static str,
+ pub needs_response: bool,
+ pub can_navigate: bool,
+}
+
+actions!(notification_panel, [ToggleFocus]);
+
+pub fn init(_cx: &mut AppContext) {}
+
+impl NotificationPanel {
+ pub fn new(workspace: &mut Workspace, cx: &mut ViewContext<Workspace>) -> ViewHandle<Self> {
+ let fs = workspace.app_state().fs.clone();
+ let client = workspace.app_state().client.clone();
+ let user_store = workspace.app_state().user_store.clone();
+ let workspace_handle = workspace.weak_handle();
+
+ cx.add_view(|cx| {
+ let mut status = client.status();
+ cx.spawn(|this, mut cx| async move {
+ while let Some(_) = status.next().await {
+ if this
+ .update(&mut cx, |_, cx| {
+ cx.notify();
+ })
+ .is_err()
+ {
+ break;
+ }
+ }
+ })
+ .detach();
+
+ let mut notification_list =
+ ListState::<Self>::new(0, Orientation::Top, 1000., move |this, ix, cx| {
+ this.render_notification(ix, cx)
+ .unwrap_or_else(|| Empty::new().into_any())
+ });
+ notification_list.set_scroll_handler(|visible_range, count, this, cx| {
+ if count.saturating_sub(visible_range.end) < LOADING_THRESHOLD {
+ if let Some(task) = this
+ .notification_store
+ .update(cx, |store, cx| store.load_more_notifications(false, cx))
+ {
+ task.detach();
+ }
+ }
+ });
+
+ let mut this = Self {
+ fs,
+ client,
+ user_store,
+ local_timezone: cx.platform().local_timezone(),
+ channel_store: ChannelStore::global(cx),
+ notification_store: NotificationStore::global(cx),
+ notification_list,
+ pending_serialization: Task::ready(None),
+ workspace: workspace_handle,
+ has_focus: false,
+ current_notification_toast: None,
+ subscriptions: Vec::new(),
+ active: false,
+ mark_as_read_tasks: HashMap::default(),
+ width: None,
+ };
+
+ let mut old_dock_position = this.position(cx);
+ this.subscriptions.extend([
+ cx.observe(&this.notification_store, |_, _, cx| cx.notify()),
+ cx.subscribe(&this.notification_store, Self::on_notification_event),
+ cx.observe_global::<SettingsStore, _>(move |this: &mut Self, cx| {
+ let new_dock_position = this.position(cx);
+ if new_dock_position != old_dock_position {
+ old_dock_position = new_dock_position;
+ cx.emit(Event::DockPositionChanged);
+ }
+ cx.notify();
+ }),
+ ]);
+ this
+ })
+ }
+
+ pub fn load(
+ workspace: WeakViewHandle<Workspace>,
+ cx: AsyncAppContext,
+ ) -> Task<Result<ViewHandle<Self>>> {
+ cx.spawn(|mut cx| async move {
+ let serialized_panel = if let Some(panel) = cx
+ .background()
+ .spawn(async move { KEY_VALUE_STORE.read_kvp(NOTIFICATION_PANEL_KEY) })
+ .await
+ .log_err()
+ .flatten()
+ {
+ Some(serde_json::from_str::<SerializedNotificationPanel>(&panel)?)
+ } else {
+ None
+ };
+
+ workspace.update(&mut cx, |workspace, cx| {
+ let panel = Self::new(workspace, cx);
+ if let Some(serialized_panel) = serialized_panel {
+ panel.update(cx, |panel, cx| {
+ panel.width = serialized_panel.width;
+ cx.notify();
+ });
+ }
+ panel
+ })
+ })
+ }
+
+ fn serialize(&mut self, cx: &mut ViewContext<Self>) {
+ let width = self.width;
+ self.pending_serialization = cx.background().spawn(
+ async move {
+ KEY_VALUE_STORE
+ .write_kvp(
+ NOTIFICATION_PANEL_KEY.into(),
+ serde_json::to_string(&SerializedNotificationPanel { width })?,
+ )
+ .await?;
+ anyhow::Ok(())
+ }
+ .log_err(),
+ );
+ }
+
+ fn render_notification(
+ &mut self,
+ ix: usize,
+ cx: &mut ViewContext<Self>,
+ ) -> Option<AnyElement<Self>> {
+ let entry = self.notification_store.read(cx).notification_at(ix)?;
+ let notification_id = entry.id;
+ let now = OffsetDateTime::now_utc();
+ let timestamp = entry.timestamp;
+ let NotificationPresenter {
+ actor,
+ text,
+ needs_response,
+ can_navigate,
+ ..
+ } = self.present_notification(entry, cx)?;
+
+ let theme = theme::current(cx);
+ let style = &theme.notification_panel;
+ let response = entry.response;
+ let notification = entry.notification.clone();
+
+ let message_style = if entry.is_read {
+ style.read_text.clone()
+ } else {
+ style.unread_text.clone()
+ };
+
+ if self.active && !entry.is_read {
+ self.did_render_notification(notification_id, ¬ification, cx);
+ }
+
+ enum Decline {}
+ enum Accept {}
+
+ Some(
+ MouseEventHandler::new::<NotificationEntry, _>(ix, cx, |_, cx| {
+ let container = message_style.container;
+
+ Flex::row()
+ .with_children(actor.map(|actor| {
+ render_avatar(actor.avatar.clone(), &style.avatar, style.avatar_container)
+ }))
+ .with_child(
+ Flex::column()
+ .with_child(Text::new(text, message_style.text.clone()))
+ .with_child(
+ Flex::row()
+ .with_child(
+ Label::new(
+ format_timestamp(timestamp, now, self.local_timezone),
+ style.timestamp.text.clone(),
+ )
+ .contained()
+ .with_style(style.timestamp.container),
+ )
+ .with_children(if let Some(is_accepted) = response {
+ Some(
+ Label::new(
+ if is_accepted {
+ "You accepted"
+ } else {
+ "You declined"
+ },
+ style.read_text.text.clone(),
+ )
+ .flex_float()
+ .into_any(),
+ )
+ } else if needs_response {
+ Some(
+ Flex::row()
+ .with_children([
+ MouseEventHandler::new::<Decline, _>(
+ ix,
+ cx,
+ |state, _| {
+ let button =
+ style.button.style_for(state);
+ Label::new(
+ "Decline",
+ button.text.clone(),
+ )
+ .contained()
+ .with_style(button.container)
+ },
+ )
+ .with_cursor_style(CursorStyle::PointingHand)
+ .on_click(MouseButton::Left, {
+ let notification = notification.clone();
+ move |_, view, cx| {
+ view.respond_to_notification(
+ notification.clone(),
+ false,
+ cx,
+ );
+ }
+ }),
+ MouseEventHandler::new::<Accept, _>(
+ ix,
+ cx,
+ |state, _| {
+ let button =
+ style.button.style_for(state);
+ Label::new(
+ "Accept",
+ button.text.clone(),
+ )
+ .contained()
+ .with_style(button.container)
+ },
+ )
+ .with_cursor_style(CursorStyle::PointingHand)
+ .on_click(MouseButton::Left, {
+ let notification = notification.clone();
+ move |_, view, cx| {
+ view.respond_to_notification(
+ notification.clone(),
+ true,
+ cx,
+ );
+ }
+ }),
+ ])
+ .flex_float()
+ .into_any(),
+ )
+ } else {
+ None
+ }),
+ )
+ .flex(1.0, true),
+ )
+ .contained()
+ .with_style(container)
+ .into_any()
+ })
+ .with_cursor_style(if can_navigate {
+ CursorStyle::PointingHand
+ } else {
+ CursorStyle::default()
+ })
+ .on_click(MouseButton::Left, {
+ let notification = notification.clone();
+ move |_, this, cx| this.did_click_notification(¬ification, cx)
+ })
+ .into_any(),
+ )
+ }
+
+ fn present_notification(
+ &self,
+ entry: &NotificationEntry,
+ cx: &AppContext,
+ ) -> Option<NotificationPresenter> {
+ let user_store = self.user_store.read(cx);
+ let channel_store = self.channel_store.read(cx);
+ match entry.notification {
+ Notification::ContactRequest { sender_id } => {
+ let requester = user_store.get_cached_user(sender_id)?;
+ Some(NotificationPresenter {
+ icon: "icons/plus.svg",
+ text: format!("{} wants to add you as a contact", requester.github_login),
+ needs_response: user_store.has_incoming_contact_request(requester.id),
+ actor: Some(requester),
+ can_navigate: false,
+ })
+ }
+ Notification::ContactRequestAccepted { responder_id } => {
+ let responder = user_store.get_cached_user(responder_id)?;
+ Some(NotificationPresenter {
+ icon: "icons/plus.svg",
+ text: format!("{} accepted your contact invite", responder.github_login),
+ needs_response: false,
+ actor: Some(responder),
+ can_navigate: false,
+ })
+ }
+ Notification::ChannelInvitation {
+ ref channel_name,
+ channel_id,
+ inviter_id,
+ } => {
+ let inviter = user_store.get_cached_user(inviter_id)?;
+ Some(NotificationPresenter {
+ icon: "icons/hash.svg",
+ text: format!(
+ "{} invited you to join the #{channel_name} channel",
+ inviter.github_login
+ ),
+ needs_response: channel_store.has_channel_invitation(channel_id),
+ actor: Some(inviter),
+ can_navigate: false,
+ })
+ }
+ Notification::ChannelMessageMention {
+ sender_id,
+ channel_id,
+ message_id,
+ } => {
+ let sender = user_store.get_cached_user(sender_id)?;
+ let channel = channel_store.channel_for_id(channel_id)?;
+ let message = self
+ .notification_store
+ .read(cx)
+ .channel_message_for_id(message_id)?;
+ Some(NotificationPresenter {
+ icon: "icons/conversations.svg",
+ text: format!(
+ "{} mentioned you in #{}:\n{}",
+ sender.github_login, channel.name, message.body,
+ ),
+ needs_response: false,
+ actor: Some(sender),
+ can_navigate: true,
+ })
+ }
+ }
+ }
+
+ fn did_render_notification(
+ &mut self,
+ notification_id: u64,
+ notification: &Notification,
+ cx: &mut ViewContext<Self>,
+ ) {
+ let should_mark_as_read = match notification {
+ Notification::ContactRequestAccepted { .. } => true,
+ Notification::ContactRequest { .. }
+ | Notification::ChannelInvitation { .. }
+ | Notification::ChannelMessageMention { .. } => false,
+ };
+
+ if should_mark_as_read {
+ self.mark_as_read_tasks
+ .entry(notification_id)
+ .or_insert_with(|| {
+ let client = self.client.clone();
+ cx.spawn(|this, mut cx| async move {
+ cx.background().timer(MARK_AS_READ_DELAY).await;
+ client
+ .request(proto::MarkNotificationRead { notification_id })
+ .await?;
+ this.update(&mut cx, |this, _| {
+ this.mark_as_read_tasks.remove(¬ification_id);
+ })?;
+ Ok(())
+ })
+ });
+ }
+ }
+
+ fn did_click_notification(&mut self, notification: &Notification, cx: &mut ViewContext<Self>) {
+ if let Notification::ChannelMessageMention {
+ message_id,
+ channel_id,
+ ..
+ } = notification.clone()
+ {
+ if let Some(workspace) = self.workspace.upgrade(cx) {
+ cx.app_context().defer(move |cx| {
+ workspace.update(cx, |workspace, cx| {
+ if let Some(panel) = workspace.focus_panel::<ChatPanel>(cx) {
+ panel.update(cx, |panel, cx| {
+ panel
+ .select_channel(channel_id, Some(message_id), cx)
+ .detach_and_log_err(cx);
+ });
+ }
+ });
+ });
+ }
+ }
+ }
+
+ fn is_showing_notification(&self, notification: &Notification, cx: &AppContext) -> bool {
+ if let Notification::ChannelMessageMention { channel_id, .. } = ¬ification {
+ if let Some(workspace) = self.workspace.upgrade(cx) {
+ return workspace
+ .read_with(cx, |workspace, cx| {
+ if let Some(panel) = workspace.panel::<ChatPanel>(cx) {
+ return panel.read_with(cx, |panel, cx| {
+ panel.is_scrolled_to_bottom()
+ && panel.active_chat().map_or(false, |chat| {
+ chat.read(cx).channel_id == *channel_id
+ })
+ });
+ }
+ false
+ })
+ .unwrap_or_default();
+ }
+ }
+
+ false
+ }
+
+ fn render_sign_in_prompt(
+ &self,
+ theme: &Arc<Theme>,
+ cx: &mut ViewContext<Self>,
+ ) -> AnyElement<Self> {
+ enum SignInPromptLabel {}
+
+ MouseEventHandler::new::<SignInPromptLabel, _>(0, cx, |mouse_state, _| {
+ Label::new(
+ "Sign in to view your notifications".to_string(),
+ theme
+ .chat_panel
+ .sign_in_prompt
+ .style_for(mouse_state)
+ .clone(),
+ )
+ })
+ .with_cursor_style(CursorStyle::PointingHand)
+ .on_click(MouseButton::Left, move |_, this, cx| {
+ let client = this.client.clone();
+ cx.spawn(|_, cx| async move {
+ client.authenticate_and_connect(true, &cx).log_err().await;
+ })
+ .detach();
+ })
+ .aligned()
+ .into_any()
+ }
+
+ fn render_empty_state(
+ &self,
+ theme: &Arc<Theme>,
+ _cx: &mut ViewContext<Self>,
+ ) -> AnyElement<Self> {
+ Label::new(
+ "You have no notifications".to_string(),
+ theme.chat_panel.sign_in_prompt.default.clone(),
+ )
+ .aligned()
+ .into_any()
+ }
+
+ fn on_notification_event(
+ &mut self,
+ _: ModelHandle<NotificationStore>,
+ event: &NotificationEvent,
+ cx: &mut ViewContext<Self>,
+ ) {
+ match event {
+ NotificationEvent::NewNotification { entry } => self.add_toast(entry, cx),
+ NotificationEvent::NotificationRemoved { entry }
+ | NotificationEvent::NotificationRead { entry } => self.remove_toast(entry.id, cx),
+ NotificationEvent::NotificationsUpdated {
+ old_range,
+ new_count,
+ } => {
+ self.notification_list.splice(old_range.clone(), *new_count);
+ cx.notify();
+ }
+ }
+ }
+
+ fn add_toast(&mut self, entry: &NotificationEntry, cx: &mut ViewContext<Self>) {
+ if self.is_showing_notification(&entry.notification, cx) {
+ return;
+ }
+
+ let Some(NotificationPresenter { actor, text, .. }) = self.present_notification(entry, cx)
+ else {
+ return;
+ };
+
+ let notification_id = entry.id;
+ self.current_notification_toast = Some((
+ notification_id,
+ cx.spawn(|this, mut cx| async move {
+ cx.background().timer(TOAST_DURATION).await;
+ this.update(&mut cx, |this, cx| this.remove_toast(notification_id, cx))
+ .ok();
+ }),
+ ));
+
+ self.workspace
+ .update(cx, |workspace, cx| {
+ workspace.dismiss_notification::<NotificationToast>(0, cx);
+ workspace.show_notification(0, cx, |cx| {
+ let workspace = cx.weak_handle();
+ cx.add_view(|_| NotificationToast {
+ notification_id,
+ actor,
+ text,
+ workspace,
+ })
+ })
+ })
+ .ok();
+ }
+
+ fn remove_toast(&mut self, notification_id: u64, cx: &mut ViewContext<Self>) {
+ if let Some((current_id, _)) = &self.current_notification_toast {
+ if *current_id == notification_id {
+ self.current_notification_toast.take();
+ self.workspace
+ .update(cx, |workspace, cx| {
+ workspace.dismiss_notification::<NotificationToast>(0, cx)
+ })
+ .ok();
+ }
+ }
+ }
+
+ fn respond_to_notification(
+ &mut self,
+ notification: Notification,
+ response: bool,
+ cx: &mut ViewContext<Self>,
+ ) {
+ self.notification_store.update(cx, |store, cx| {
+ store.respond_to_notification(notification, response, cx);
+ });
+ }
+}
+
+impl Entity for NotificationPanel {
+ type Event = Event;
+}
+
+impl View for NotificationPanel {
+ fn ui_name() -> &'static str {
+ "NotificationPanel"
+ }
+
+ fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
+ let theme = theme::current(cx);
+ let style = &theme.notification_panel;
+ let element = if self.client.user_id().is_none() {
+ self.render_sign_in_prompt(&theme, cx)
+ } else if self.notification_list.item_count() == 0 {
+ self.render_empty_state(&theme, cx)
+ } else {
+ Flex::column()
+ .with_child(
+ Flex::row()
+ .with_child(Label::new("Notifications", style.title.text.clone()))
+ .with_child(ui::svg(&style.title_icon).flex_float())
+ .align_children_center()
+ .contained()
+ .with_style(style.title.container)
+ .constrained()
+ .with_height(style.title_height),
+ )
+ .with_child(
+ List::new(self.notification_list.clone())
+ .contained()
+ .with_style(style.list)
+ .flex(1., true),
+ )
+ .into_any()
+ };
+ element
+ .contained()
+ .with_style(style.container)
+ .constrained()
+ .with_min_width(150.)
+ .into_any()
+ }
+
+ fn focus_in(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
+ self.has_focus = true;
+ }
+
+ fn focus_out(&mut self, _: AnyViewHandle, _: &mut ViewContext<Self>) {
+ self.has_focus = false;
+ }
+}
+
+impl Panel for NotificationPanel {
+ fn position(&self, cx: &gpui::WindowContext) -> DockPosition {
+ settings::get::<NotificationPanelSettings>(cx).dock
+ }
+
+ fn position_is_valid(&self, position: DockPosition) -> bool {
+ matches!(position, DockPosition::Left | DockPosition::Right)
+ }
+
+ fn set_position(&mut self, position: DockPosition, cx: &mut ViewContext<Self>) {
+ settings::update_settings_file::<NotificationPanelSettings>(
+ self.fs.clone(),
+ cx,
+ move |settings| settings.dock = Some(position),
+ );
+ }
+
+ fn size(&self, cx: &gpui::WindowContext) -> f32 {
+ self.width
+ .unwrap_or_else(|| settings::get::<NotificationPanelSettings>(cx).default_width)
+ }
+
+ fn set_size(&mut self, size: Option<f32>, cx: &mut ViewContext<Self>) {
+ self.width = size;
+ self.serialize(cx);
+ cx.notify();
+ }
+
+ fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
+ self.active = active;
+ if self.notification_store.read(cx).notification_count() == 0 {
+ cx.emit(Event::Dismissed);
+ }
+ }
+
+ fn icon_path(&self, cx: &gpui::WindowContext) -> Option<&'static str> {
+ (settings::get::<NotificationPanelSettings>(cx).button
+ && self.notification_store.read(cx).notification_count() > 0)
+ .then(|| "icons/bell.svg")
+ }
+
+ fn icon_tooltip(&self) -> (String, Option<Box<dyn gpui::Action>>) {
+ (
+ "Notification Panel".to_string(),
+ Some(Box::new(ToggleFocus)),
+ )
+ }
+
+ fn icon_label(&self, cx: &WindowContext) -> Option<String> {
+ let count = self.notification_store.read(cx).unread_notification_count();
+ if count == 0 {
+ None
+ } else {
+ Some(count.to_string())
+ }
+ }
+
+ fn should_change_position_on_event(event: &Self::Event) -> bool {
+ matches!(event, Event::DockPositionChanged)
+ }
+
+ fn should_close_on_event(event: &Self::Event) -> bool {
+ matches!(event, Event::Dismissed)
+ }
+
+ fn has_focus(&self, _cx: &gpui::WindowContext) -> bool {
+ self.has_focus
+ }
+
+ fn is_focus_event(event: &Self::Event) -> bool {
+ matches!(event, Event::Focus)
+ }
+}
+
+pub struct NotificationToast {
+ notification_id: u64,
+ actor: Option<Arc<User>>,
+ text: String,
+ workspace: WeakViewHandle<Workspace>,
+}
+
+pub enum ToastEvent {
+ Dismiss,
+}
+
+impl NotificationToast {
+ fn focus_notification_panel(&self, cx: &mut AppContext) {
+ let workspace = self.workspace.clone();
+ let notification_id = self.notification_id;
+ cx.defer(move |cx| {
+ workspace
+ .update(cx, |workspace, cx| {
+ if let Some(panel) = workspace.focus_panel::<NotificationPanel>(cx) {
+ panel.update(cx, |panel, cx| {
+ let store = panel.notification_store.read(cx);
+ if let Some(entry) = store.notification_for_id(notification_id) {
+ panel.did_click_notification(&entry.clone().notification, cx);
+ }
+ });
+ }
+ })
+ .ok();
+ })
+ }
+}
+
+impl Entity for NotificationToast {
+ type Event = ToastEvent;
+}
+
+impl View for NotificationToast {
+ fn ui_name() -> &'static str {
+ "ContactNotification"
+ }
+
+ fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
+ let user = self.actor.clone();
+ let theme = theme::current(cx).clone();
+ let theme = &theme.contact_notification;
+
+ MouseEventHandler::new::<Self, _>(0, cx, |_, cx| {
+ Flex::row()
+ .with_children(user.and_then(|user| {
+ Some(
+ Image::from_data(user.avatar.clone()?)
+ .with_style(theme.header_avatar)
+ .aligned()
+ .constrained()
+ .with_height(
+ cx.font_cache()
+ .line_height(theme.header_message.text.font_size),
+ )
+ .aligned()
+ .top(),
+ )
+ }))
+ .with_child(
+ Text::new(self.text.clone(), theme.header_message.text.clone())
+ .contained()
+ .with_style(theme.header_message.container)
+ .aligned()
+ .top()
+ .left()
+ .flex(1., true),
+ )
+ .with_child(
+ MouseEventHandler::new::<ToastEvent, _>(0, cx, |state, _| {
+ let style = theme.dismiss_button.style_for(state);
+ Svg::new("icons/x.svg")
+ .with_color(style.color)
+ .constrained()
+ .with_width(style.icon_width)
+ .aligned()
+ .contained()
+ .with_style(style.container)
+ .constrained()
+ .with_width(style.button_width)
+ .with_height(style.button_width)
+ })
+ .with_cursor_style(CursorStyle::PointingHand)
+ .with_padding(Padding::uniform(5.))
+ .on_click(MouseButton::Left, move |_, _, cx| {
+ cx.emit(ToastEvent::Dismiss)
+ })
+ .aligned()
+ .constrained()
+ .with_height(
+ cx.font_cache()
+ .line_height(theme.header_message.text.font_size),
+ )
+ .aligned()
+ .top()
+ .flex_float(),
+ )
+ .contained()
+ })
+ .with_cursor_style(CursorStyle::PointingHand)
+ .on_click(MouseButton::Left, move |_, this, cx| {
+ this.focus_notification_panel(cx);
+ cx.emit(ToastEvent::Dismiss);
+ })
+ .into_any()
+ }
+}
+
+impl workspace::notifications::Notification for NotificationToast {
+ fn should_dismiss_notification_on_event(&self, event: &<Self as Entity>::Event) -> bool {
+ matches!(event, ToastEvent::Dismiss)
+ }
+}
+
+fn format_timestamp(
+ mut timestamp: OffsetDateTime,
+ mut now: OffsetDateTime,
+ local_timezone: UtcOffset,
+) -> String {
+ timestamp = timestamp.to_offset(local_timezone);
+ now = now.to_offset(local_timezone);
+
+ let today = now.date();
+ let date = timestamp.date();
+ if date == today {
+ let difference = now - timestamp;
+ if difference >= Duration::from_secs(3600) {
+ format!("{}h", difference.whole_seconds() / 3600)
+ } else if difference >= Duration::from_secs(60) {
+ format!("{}m", difference.whole_seconds() / 60)
+ } else {
+ "just now".to_string()
+ }
+ } else if date.next_day() == Some(today) {
+ format!("yesterday")
+ } else {
+ format!("{:02}/{}/{}", date.month() as u32, date.day(), date.year())
+ }
+}
@@ -1,110 +1,11 @@
-use client::User;
-use gpui::{
- elements::*,
- platform::{CursorStyle, MouseButton},
- AnyElement, Element, ViewContext,
-};
+use gpui::AppContext;
use std::sync::Arc;
+use workspace::AppState;
-enum Dismiss {}
-enum Button {}
+pub mod incoming_call_notification;
+pub mod project_shared_notification;
-pub fn render_user_notification<F, V: 'static>(
- user: Arc<User>,
- title: &'static str,
- body: Option<&'static str>,
- on_dismiss: F,
- buttons: Vec<(&'static str, Box<dyn Fn(&mut V, &mut ViewContext<V>)>)>,
- cx: &mut ViewContext<V>,
-) -> AnyElement<V>
-where
- F: 'static + Fn(&mut V, &mut ViewContext<V>),
-{
- let theme = theme::current(cx).clone();
- let theme = &theme.contact_notification;
-
- Flex::column()
- .with_child(
- Flex::row()
- .with_children(user.avatar.clone().map(|avatar| {
- Image::from_data(avatar)
- .with_style(theme.header_avatar)
- .aligned()
- .constrained()
- .with_height(
- cx.font_cache()
- .line_height(theme.header_message.text.font_size),
- )
- .aligned()
- .top()
- }))
- .with_child(
- Text::new(
- format!("{} {}", user.github_login, title),
- theme.header_message.text.clone(),
- )
- .contained()
- .with_style(theme.header_message.container)
- .aligned()
- .top()
- .left()
- .flex(1., true),
- )
- .with_child(
- MouseEventHandler::new::<Dismiss, _>(user.id as usize, cx, |state, _| {
- let style = theme.dismiss_button.style_for(state);
- Svg::new("icons/x.svg")
- .with_color(style.color)
- .constrained()
- .with_width(style.icon_width)
- .aligned()
- .contained()
- .with_style(style.container)
- .constrained()
- .with_width(style.button_width)
- .with_height(style.button_width)
- })
- .with_cursor_style(CursorStyle::PointingHand)
- .with_padding(Padding::uniform(5.))
- .on_click(MouseButton::Left, move |_, view, cx| on_dismiss(view, cx))
- .aligned()
- .constrained()
- .with_height(
- cx.font_cache()
- .line_height(theme.header_message.text.font_size),
- )
- .aligned()
- .top()
- .flex_float(),
- )
- .into_any_named("contact notification header"),
- )
- .with_children(body.map(|body| {
- Label::new(body, theme.body_message.text.clone())
- .contained()
- .with_style(theme.body_message.container)
- }))
- .with_children(if buttons.is_empty() {
- None
- } else {
- Some(
- Flex::row()
- .with_children(buttons.into_iter().enumerate().map(
- |(ix, (message, handler))| {
- MouseEventHandler::new::<Button, _>(ix, cx, |state, _| {
- let button = theme.button.style_for(state);
- Label::new(message, button.text.clone())
- .contained()
- .with_style(button.container)
- })
- .with_cursor_style(CursorStyle::PointingHand)
- .on_click(MouseButton::Left, move |_, view, cx| handler(view, cx))
- },
- ))
- .aligned()
- .right(),
- )
- })
- .contained()
- .into_any()
+pub fn init(app_state: &Arc<AppState>, cx: &mut AppContext) {
+ incoming_call_notification::init(app_state, cx);
+ project_shared_notification::init(app_state, cx);
}
@@ -18,6 +18,13 @@ pub struct ChatPanelSettings {
pub default_width: f32,
}
+#[derive(Deserialize, Debug)]
+pub struct NotificationPanelSettings {
+ pub button: bool,
+ pub dock: DockPosition,
+ pub default_width: f32,
+}
+
#[derive(Clone, Default, Serialize, Deserialize, JsonSchema, Debug)]
pub struct PanelSettingsContent {
pub button: Option<bool>,
@@ -27,9 +34,7 @@ pub struct PanelSettingsContent {
impl Setting for CollaborationPanelSettings {
const KEY: Option<&'static str> = Some("collaboration_panel");
-
type FileContent = PanelSettingsContent;
-
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
@@ -41,9 +46,19 @@ impl Setting for CollaborationPanelSettings {
impl Setting for ChatPanelSettings {
const KEY: Option<&'static str> = Some("chat_panel");
-
type FileContent = PanelSettingsContent;
+ fn load(
+ default_value: &Self::FileContent,
+ user_values: &[&Self::FileContent],
+ _: &gpui::AppContext,
+ ) -> anyhow::Result<Self> {
+ Self::load_via_json_merge(default_value, user_values)
+ }
+}
+impl Setting for NotificationPanelSettings {
+ const KEY: Option<&'static str> = Some("notification_panel");
+ type FileContent = PanelSettingsContent;
fn load(
default_value: &Self::FileContent,
user_values: &[&Self::FileContent],
@@ -1,62 +0,0 @@
-use crate::toggle_screen_sharing;
-use call::ActiveCall;
-use gpui::{
- color::Color,
- elements::{MouseEventHandler, Svg},
- platform::{Appearance, MouseButton},
- AnyElement, AppContext, Element, Entity, View, ViewContext,
-};
-use workspace::WorkspaceSettings;
-
-pub fn init(cx: &mut AppContext) {
- let active_call = ActiveCall::global(cx);
-
- let mut status_indicator = None;
- cx.observe(&active_call, move |call, cx| {
- if let Some(room) = call.read(cx).room() {
- if room.read(cx).is_screen_sharing() {
- if status_indicator.is_none()
- && settings::get::<WorkspaceSettings>(cx).show_call_status_icon
- {
- status_indicator = Some(cx.add_status_bar_item(|_| SharingStatusIndicator));
- }
- } else if let Some(window) = status_indicator.take() {
- window.update(cx, |cx| cx.remove_window());
- }
- } else if let Some(window) = status_indicator.take() {
- window.update(cx, |cx| cx.remove_window());
- }
- })
- .detach();
-}
-
-pub struct SharingStatusIndicator;
-
-impl Entity for SharingStatusIndicator {
- type Event = ();
-}
-
-impl View for SharingStatusIndicator {
- fn ui_name() -> &'static str {
- "SharingStatusIndicator"
- }
-
- fn render(&mut self, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
- let color = match cx.window_appearance() {
- Appearance::Light | Appearance::VibrantLight => Color::black(),
- Appearance::Dark | Appearance::VibrantDark => Color::white(),
- };
-
- MouseEventHandler::new::<Self, _>(0, cx, |_, _| {
- Svg::new("icons/desktop.svg")
- .with_color(color)
- .constrained()
- .with_width(18.)
- .aligned()
- })
- .on_click(MouseButton::Left, |_, _, cx| {
- toggle_screen_sharing(&Default::default(), cx)
- })
- .into_any()
- }
-}
@@ -5,22 +5,24 @@ mod tab_map;
mod wrap_map;
use crate::{
- link_go_to_definition::InlayHighlight, Anchor, AnchorRangeExt, InlayId, MultiBuffer,
- MultiBufferSnapshot, ToOffset, ToPoint,
+ link_go_to_definition::InlayHighlight, movement::TextLayoutDetails, Anchor, AnchorRangeExt,
+ EditorStyle, InlayId, MultiBuffer, MultiBufferSnapshot, ToOffset, ToPoint,
};
pub use block_map::{BlockMap, BlockPoint};
use collections::{BTreeMap, HashMap, HashSet};
use fold_map::FoldMap;
use gpui::{
color::Color,
- fonts::{FontId, HighlightStyle},
+ fonts::{FontId, HighlightStyle, Underline},
+ text_layout::{Line, RunStyle},
Entity, ModelContext, ModelHandle,
};
use inlay_map::InlayMap;
use language::{
language_settings::language_settings, OffsetUtf16, Point, Subscription as BufferSubscription,
};
-use std::{any::TypeId, fmt::Debug, num::NonZeroU32, ops::Range, sync::Arc};
+use lsp::DiagnosticSeverity;
+use std::{any::TypeId, borrow::Cow, fmt::Debug, num::NonZeroU32, ops::Range, sync::Arc};
use sum_tree::{Bias, TreeMap};
use tab_map::TabMap;
use wrap_map::WrapMap;
@@ -316,6 +318,12 @@ pub struct Highlights<'a> {
pub suggestion_highlight_style: Option<HighlightStyle>,
}
+pub struct HighlightedChunk<'a> {
+ pub chunk: &'a str,
+ pub style: Option<HighlightStyle>,
+ pub is_tab: bool,
+}
+
pub struct DisplaySnapshot {
pub buffer_snapshot: MultiBufferSnapshot,
pub fold_snapshot: fold_map::FoldSnapshot,
@@ -485,7 +493,7 @@ impl DisplaySnapshot {
language_aware: bool,
inlay_highlight_style: Option<HighlightStyle>,
suggestion_highlight_style: Option<HighlightStyle>,
- ) -> DisplayChunks<'_> {
+ ) -> DisplayChunks<'a> {
self.block_snapshot.chunks(
display_rows,
language_aware,
@@ -498,6 +506,140 @@ impl DisplaySnapshot {
)
}
+ pub fn highlighted_chunks<'a>(
+ &'a self,
+ display_rows: Range<u32>,
+ language_aware: bool,
+ style: &'a EditorStyle,
+ ) -> impl Iterator<Item = HighlightedChunk<'a>> {
+ self.chunks(
+ display_rows,
+ language_aware,
+ Some(style.theme.hint),
+ Some(style.theme.suggestion),
+ )
+ .map(|chunk| {
+ let mut highlight_style = chunk
+ .syntax_highlight_id
+ .and_then(|id| id.style(&style.syntax));
+
+ if let Some(chunk_highlight) = chunk.highlight_style {
+ if let Some(highlight_style) = highlight_style.as_mut() {
+ highlight_style.highlight(chunk_highlight);
+ } else {
+ highlight_style = Some(chunk_highlight);
+ }
+ }
+
+ let mut diagnostic_highlight = HighlightStyle::default();
+
+ if chunk.is_unnecessary {
+ diagnostic_highlight.fade_out = Some(style.unnecessary_code_fade);
+ }
+
+ if let Some(severity) = chunk.diagnostic_severity {
+ // Omit underlines for HINT/INFO diagnostics on 'unnecessary' code.
+ if severity <= DiagnosticSeverity::WARNING || !chunk.is_unnecessary {
+ let diagnostic_style = super::diagnostic_style(severity, true, style);
+ diagnostic_highlight.underline = Some(Underline {
+ color: Some(diagnostic_style.message.text.color),
+ thickness: 1.0.into(),
+ squiggly: true,
+ });
+ }
+ }
+
+ if let Some(highlight_style) = highlight_style.as_mut() {
+ highlight_style.highlight(diagnostic_highlight);
+ } else {
+ highlight_style = Some(diagnostic_highlight);
+ }
+
+ HighlightedChunk {
+ chunk: chunk.text,
+ style: highlight_style,
+ is_tab: chunk.is_tab,
+ }
+ })
+ }
+
+ pub fn lay_out_line_for_row(
+ &self,
+ display_row: u32,
+ TextLayoutDetails {
+ font_cache,
+ text_layout_cache,
+ editor_style,
+ }: &TextLayoutDetails,
+ ) -> Line {
+ let mut styles = Vec::new();
+ let mut line = String::new();
+ let mut ended_in_newline = false;
+
+ let range = display_row..display_row + 1;
+ for chunk in self.highlighted_chunks(range, false, editor_style) {
+ line.push_str(chunk.chunk);
+
+ let text_style = if let Some(style) = chunk.style {
+ editor_style
+ .text
+ .clone()
+ .highlight(style, font_cache)
+ .map(Cow::Owned)
+ .unwrap_or_else(|_| Cow::Borrowed(&editor_style.text))
+ } else {
+ Cow::Borrowed(&editor_style.text)
+ };
+ ended_in_newline = chunk.chunk.ends_with("\n");
+
+ styles.push((
+ chunk.chunk.len(),
+ RunStyle {
+ font_id: text_style.font_id,
+ color: text_style.color,
+ underline: text_style.underline,
+ },
+ ));
+ }
+
+ // our pixel positioning logic assumes each line ends in \n,
+ // this is almost always true except for the last line which
+ // may have no trailing newline.
+ if !ended_in_newline && display_row == self.max_point().row() {
+ line.push_str("\n");
+
+ styles.push((
+ "\n".len(),
+ RunStyle {
+ font_id: editor_style.text.font_id,
+ color: editor_style.text_color,
+ underline: editor_style.text.underline,
+ },
+ ));
+ }
+
+ text_layout_cache.layout_str(&line, editor_style.text.font_size, &styles)
+ }
+
+ pub fn x_for_point(
+ &self,
+ display_point: DisplayPoint,
+ text_layout_details: &TextLayoutDetails,
+ ) -> f32 {
+ let layout_line = self.lay_out_line_for_row(display_point.row(), text_layout_details);
+ layout_line.x_for_index(display_point.column() as usize)
+ }
+
+ pub fn column_for_x(
+ &self,
+ display_row: u32,
+ x_coordinate: f32,
+ text_layout_details: &TextLayoutDetails,
+ ) -> u32 {
+ let layout_line = self.lay_out_line_for_row(display_row, text_layout_details);
+ layout_line.closest_index_for_x(x_coordinate) as u32
+ }
+
pub fn chars_at(
&self,
mut point: DisplayPoint,
@@ -869,12 +1011,16 @@ pub fn next_rows(display_row: u32, display_map: &DisplaySnapshot) -> impl Iterat
#[cfg(test)]
pub mod tests {
use super::*;
- use crate::{movement, test::marked_display_snapshot};
+ use crate::{
+ movement,
+ test::{editor_test_context::EditorTestContext, marked_display_snapshot},
+ };
use gpui::{color::Color, elements::*, test::observe, AppContext};
use language::{
language_settings::{AllLanguageSettings, AllLanguageSettingsContent},
Buffer, Language, LanguageConfig, SelectionGoal,
};
+ use project::Project;
use rand::{prelude::*, Rng};
use settings::SettingsStore;
use smol::stream::StreamExt;
@@ -1148,95 +1294,120 @@ pub mod tests {
}
#[gpui::test(retries = 5)]
- fn test_soft_wraps(cx: &mut AppContext) {
+ async fn test_soft_wraps(cx: &mut gpui::TestAppContext) {
cx.foreground().set_block_on_ticks(usize::MAX..=usize::MAX);
- init_test(cx, |_| {});
+ cx.update(|cx| {
+ init_test(cx, |_| {});
+ });
- let font_cache = cx.font_cache();
+ let mut cx = EditorTestContext::new(cx).await;
+ let editor = cx.editor.clone();
+ let window = cx.window.clone();
- let family_id = font_cache
- .load_family(&["Helvetica"], &Default::default())
- .unwrap();
- let font_id = font_cache
- .select_font(family_id, &Default::default())
- .unwrap();
- let font_size = 12.0;
- let wrap_width = Some(64.);
+ cx.update_window(window, |cx| {
+ let text_layout_details =
+ editor.read_with(cx, |editor, cx| editor.text_layout_details(cx));
- let text = "one two three four five\nsix seven eight";
- let buffer = MultiBuffer::build_simple(text, cx);
- let map = cx.add_model(|cx| {
- DisplayMap::new(buffer.clone(), font_id, font_size, wrap_width, 1, 1, cx)
- });
+ let font_cache = cx.font_cache().clone();
- let snapshot = map.update(cx, |map, cx| map.snapshot(cx));
- assert_eq!(
- snapshot.text_chunks(0).collect::<String>(),
- "one two \nthree four \nfive\nsix seven \neight"
- );
- assert_eq!(
- snapshot.clip_point(DisplayPoint::new(0, 8), Bias::Left),
- DisplayPoint::new(0, 7)
- );
- assert_eq!(
- snapshot.clip_point(DisplayPoint::new(0, 8), Bias::Right),
- DisplayPoint::new(1, 0)
- );
- assert_eq!(
- movement::right(&snapshot, DisplayPoint::new(0, 7)),
- DisplayPoint::new(1, 0)
- );
- assert_eq!(
- movement::left(&snapshot, DisplayPoint::new(1, 0)),
- DisplayPoint::new(0, 7)
- );
- assert_eq!(
- movement::up(
- &snapshot,
- DisplayPoint::new(1, 10),
- SelectionGoal::None,
- false
- ),
- (DisplayPoint::new(0, 7), SelectionGoal::Column(10))
- );
- assert_eq!(
- movement::down(
- &snapshot,
- DisplayPoint::new(0, 7),
- SelectionGoal::Column(10),
- false
- ),
- (DisplayPoint::new(1, 10), SelectionGoal::Column(10))
- );
- assert_eq!(
- movement::down(
- &snapshot,
- DisplayPoint::new(1, 10),
- SelectionGoal::Column(10),
- false
- ),
- (DisplayPoint::new(2, 4), SelectionGoal::Column(10))
- );
+ let family_id = font_cache
+ .load_family(&["Helvetica"], &Default::default())
+ .unwrap();
+ let font_id = font_cache
+ .select_font(family_id, &Default::default())
+ .unwrap();
+ let font_size = 12.0;
+ let wrap_width = Some(64.);
- let ix = snapshot.buffer_snapshot.text().find("seven").unwrap();
- buffer.update(cx, |buffer, cx| {
- buffer.edit([(ix..ix, "and ")], None, cx);
- });
+ let text = "one two three four five\nsix seven eight";
+ let buffer = MultiBuffer::build_simple(text, cx);
+ let map = cx.add_model(|cx| {
+ DisplayMap::new(buffer.clone(), font_id, font_size, wrap_width, 1, 1, cx)
+ });
- let snapshot = map.update(cx, |map, cx| map.snapshot(cx));
- assert_eq!(
- snapshot.text_chunks(1).collect::<String>(),
- "three four \nfive\nsix and \nseven eight"
- );
+ let snapshot = map.update(cx, |map, cx| map.snapshot(cx));
+ assert_eq!(
+ snapshot.text_chunks(0).collect::<String>(),
+ "one two \nthree four \nfive\nsix seven \neight"
+ );
+ assert_eq!(
+ snapshot.clip_point(DisplayPoint::new(0, 8), Bias::Left),
+ DisplayPoint::new(0, 7)
+ );
+ assert_eq!(
+ snapshot.clip_point(DisplayPoint::new(0, 8), Bias::Right),
+ DisplayPoint::new(1, 0)
+ );
+ assert_eq!(
+ movement::right(&snapshot, DisplayPoint::new(0, 7)),
+ DisplayPoint::new(1, 0)
+ );
+ assert_eq!(
+ movement::left(&snapshot, DisplayPoint::new(1, 0)),
+ DisplayPoint::new(0, 7)
+ );
- // Re-wrap on font size changes
- map.update(cx, |map, cx| map.set_font(font_id, font_size + 3., cx));
+ let x = snapshot.x_for_point(DisplayPoint::new(1, 10), &text_layout_details);
+ assert_eq!(
+ movement::up(
+ &snapshot,
+ DisplayPoint::new(1, 10),
+ SelectionGoal::None,
+ false,
+ &text_layout_details,
+ ),
+ (
+ DisplayPoint::new(0, 7),
+ SelectionGoal::HorizontalPosition(x)
+ )
+ );
+ assert_eq!(
+ movement::down(
+ &snapshot,
+ DisplayPoint::new(0, 7),
+ SelectionGoal::HorizontalPosition(x),
+ false,
+ &text_layout_details
+ ),
+ (
+ DisplayPoint::new(1, 10),
+ SelectionGoal::HorizontalPosition(x)
+ )
+ );
+ assert_eq!(
+ movement::down(
+ &snapshot,
+ DisplayPoint::new(1, 10),
+ SelectionGoal::HorizontalPosition(x),
+ false,
+ &text_layout_details
+ ),
+ (
+ DisplayPoint::new(2, 4),
+ SelectionGoal::HorizontalPosition(x)
+ )
+ );
- let snapshot = map.update(cx, |map, cx| map.snapshot(cx));
- assert_eq!(
- snapshot.text_chunks(1).collect::<String>(),
- "three \nfour five\nsix and \nseven \neight"
- )
+ let ix = snapshot.buffer_snapshot.text().find("seven").unwrap();
+ buffer.update(cx, |buffer, cx| {
+ buffer.edit([(ix..ix, "and ")], None, cx);
+ });
+
+ let snapshot = map.update(cx, |map, cx| map.snapshot(cx));
+ assert_eq!(
+ snapshot.text_chunks(1).collect::<String>(),
+ "three four \nfive\nsix and \nseven eight"
+ );
+
+ // Re-wrap on font size changes
+ map.update(cx, |map, cx| map.set_font(font_id, font_size + 3., cx));
+
+ let snapshot = map.update(cx, |map, cx| map.snapshot(cx));
+ assert_eq!(
+ snapshot.text_chunks(1).collect::<String>(),
+ "three \nfour five\nsix and \nseven \neight"
+ )
+ });
}
#[gpui::test]
@@ -1731,6 +1902,9 @@ pub mod tests {
cx.foreground().forbid_parking();
cx.set_global(SettingsStore::test(cx));
language::init(cx);
+ crate::init(cx);
+ Project::init_settings(cx);
+ theme::init((), cx);
cx.update_global::<SettingsStore, _, _>(|store, cx| {
store.update_user_settings::<AllLanguageSettings>(cx, f);
});
@@ -71,6 +71,7 @@ use link_go_to_definition::{
};
use log::error;
use lsp::LanguageServerId;
+use movement::TextLayoutDetails;
use multi_buffer::ToOffsetUtf16;
pub use multi_buffer::{
Anchor, AnchorRangeExt, ExcerptId, ExcerptRange, MultiBuffer, MultiBufferSnapshot, ToOffset,
@@ -3286,8 +3287,10 @@ impl Editor {
i = 0;
} else if pair_state.range.start.to_offset(buffer) > range.end {
break;
- } else if pair_state.selection_id == selection.id {
- enclosing = Some(pair_state);
+ } else {
+ if pair_state.selection_id == selection.id {
+ enclosing = Some(pair_state);
+ }
i += 1;
}
}
@@ -3474,6 +3477,14 @@ impl Editor {
.collect()
}
+ pub fn text_layout_details(&self, cx: &WindowContext) -> TextLayoutDetails {
+ TextLayoutDetails {
+ font_cache: cx.font_cache().clone(),
+ text_layout_cache: cx.text_layout_cache().clone(),
+ editor_style: self.style(cx),
+ }
+ }
+
fn splice_inlay_hints(
&self,
to_remove: Vec<InlayId>,
@@ -5408,6 +5419,7 @@ impl Editor {
}
pub fn transpose(&mut self, _: &Transpose, cx: &mut ViewContext<Self>) {
+ let text_layout_details = &self.text_layout_details(cx);
self.transact(cx, |this, cx| {
let edits = this.change_selections(Some(Autoscroll::fit()), cx, |s| {
let mut edits: Vec<(Range<usize>, String)> = Default::default();
@@ -5431,7 +5443,10 @@ impl Editor {
*head.column_mut() += 1;
head = display_map.clip_point(head, Bias::Right);
- selection.collapse_to(head, SelectionGoal::Column(head.column()));
+ let goal = SelectionGoal::HorizontalPosition(
+ display_map.x_for_point(head, &text_layout_details),
+ );
+ selection.collapse_to(head, goal);
let transpose_start = display_map
.buffer_snapshot
@@ -5695,13 +5710,21 @@ impl Editor {
return;
}
+ let text_layout_details = &self.text_layout_details(cx);
+
self.change_selections(Some(Autoscroll::fit()), cx, |s| {
let line_mode = s.line_mode;
s.move_with(|map, selection| {
if !selection.is_empty() && !line_mode {
selection.goal = SelectionGoal::None;
}
- let (cursor, goal) = movement::up(map, selection.start, selection.goal, false);
+ let (cursor, goal) = movement::up(
+ map,
+ selection.start,
+ selection.goal,
+ false,
+ &text_layout_details,
+ );
selection.collapse_to(cursor, goal);
});
})
@@ -5729,22 +5752,33 @@ impl Editor {
Autoscroll::fit()
};
+ let text_layout_details = &self.text_layout_details(cx);
+
self.change_selections(Some(autoscroll), cx, |s| {
let line_mode = s.line_mode;
s.move_with(|map, selection| {
if !selection.is_empty() && !line_mode {
selection.goal = SelectionGoal::None;
}
- let (cursor, goal) =
- movement::up_by_rows(map, selection.end, row_count, selection.goal, false);
+ let (cursor, goal) = movement::up_by_rows(
+ map,
+ selection.end,
+ row_count,
+ selection.goal,
+ false,
+ &text_layout_details,
+ );
selection.collapse_to(cursor, goal);
});
});
}
pub fn select_up(&mut self, _: &SelectUp, cx: &mut ViewContext<Self>) {
+ let text_layout_details = &self.text_layout_details(cx);
self.change_selections(Some(Autoscroll::fit()), cx, |s| {
- s.move_heads_with(|map, head, goal| movement::up(map, head, goal, false))
+ s.move_heads_with(|map, head, goal| {
+ movement::up(map, head, goal, false, &text_layout_details)
+ })
})
}
@@ -5756,13 +5790,20 @@ impl Editor {
return;
}
+ let text_layout_details = &self.text_layout_details(cx);
self.change_selections(Some(Autoscroll::fit()), cx, |s| {
let line_mode = s.line_mode;
s.move_with(|map, selection| {
if !selection.is_empty() && !line_mode {
selection.goal = SelectionGoal::None;
}
- let (cursor, goal) = movement::down(map, selection.end, selection.goal, false);
+ let (cursor, goal) = movement::down(
+ map,
+ selection.end,
+ selection.goal,
+ false,
+ &text_layout_details,
+ );
selection.collapse_to(cursor, goal);
});
});
@@ -5800,22 +5841,32 @@ impl Editor {
Autoscroll::fit()
};
+ let text_layout_details = &self.text_layout_details(cx);
self.change_selections(Some(autoscroll), cx, |s| {
let line_mode = s.line_mode;
s.move_with(|map, selection| {
if !selection.is_empty() && !line_mode {
selection.goal = SelectionGoal::None;
}
- let (cursor, goal) =
- movement::down_by_rows(map, selection.end, row_count, selection.goal, false);
+ let (cursor, goal) = movement::down_by_rows(
+ map,
+ selection.end,
+ row_count,
+ selection.goal,
+ false,
+ &text_layout_details,
+ );
selection.collapse_to(cursor, goal);
});
});
}
pub fn select_down(&mut self, _: &SelectDown, cx: &mut ViewContext<Self>) {
+ let text_layout_details = &self.text_layout_details(cx);
self.change_selections(Some(Autoscroll::fit()), cx, |s| {
- s.move_heads_with(|map, head, goal| movement::down(map, head, goal, false))
+ s.move_heads_with(|map, head, goal| {
+ movement::down(map, head, goal, false, &text_layout_details)
+ })
});
}
@@ -6334,11 +6385,14 @@ impl Editor {
fn add_selection(&mut self, above: bool, cx: &mut ViewContext<Self>) {
let display_map = self.display_map.update(cx, |map, cx| map.snapshot(cx));
let mut selections = self.selections.all::<Point>(cx);
+ let text_layout_details = self.text_layout_details(cx);
let mut state = self.add_selections_state.take().unwrap_or_else(|| {
let oldest_selection = selections.iter().min_by_key(|s| s.id).unwrap().clone();
let range = oldest_selection.display_range(&display_map).sorted();
- let columns = cmp::min(range.start.column(), range.end.column())
- ..cmp::max(range.start.column(), range.end.column());
+
+ let start_x = display_map.x_for_point(range.start, &text_layout_details);
+ let end_x = display_map.x_for_point(range.end, &text_layout_details);
+ let positions = start_x.min(end_x)..start_x.max(end_x);
selections.clear();
let mut stack = Vec::new();
@@ -6346,8 +6400,9 @@ impl Editor {
if let Some(selection) = self.selections.build_columnar_selection(
&display_map,
row,
- &columns,
+ &positions,
oldest_selection.reversed,
+ &text_layout_details,
) {
stack.push(selection.id);
selections.push(selection);
@@ -6375,12 +6430,15 @@ impl Editor {
let range = selection.display_range(&display_map).sorted();
debug_assert_eq!(range.start.row(), range.end.row());
let mut row = range.start.row();
- let columns = if let SelectionGoal::ColumnRange { start, end } = selection.goal
+ let positions = if let SelectionGoal::HorizontalRange { start, end } =
+ selection.goal
{
start..end
} else {
- cmp::min(range.start.column(), range.end.column())
- ..cmp::max(range.start.column(), range.end.column())
+ let start_x = display_map.x_for_point(range.start, &text_layout_details);
+ let end_x = display_map.x_for_point(range.end, &text_layout_details);
+
+ start_x.min(end_x)..start_x.max(end_x)
};
while row != end_row {
@@ -6393,8 +6451,9 @@ impl Editor {
if let Some(new_selection) = self.selections.build_columnar_selection(
&display_map,
row,
- &columns,
+ &positions,
selection.reversed,
+ &text_layout_details,
) {
state.stack.push(new_selection.id);
if above {
@@ -6688,6 +6747,7 @@ impl Editor {
}
pub fn toggle_comments(&mut self, action: &ToggleComments, cx: &mut ViewContext<Self>) {
+ let text_layout_details = &self.text_layout_details(cx);
self.transact(cx, |this, cx| {
let mut selections = this.selections.all::<Point>(cx);
let mut edits = Vec::new();
@@ -6930,7 +6990,10 @@ impl Editor {
point.row += 1;
point = snapshot.clip_point(point, Bias::Left);
let display_point = point.to_display_point(display_snapshot);
- (display_point, SelectionGoal::Column(display_point.column()))
+ let goal = SelectionGoal::HorizontalPosition(
+ display_snapshot.x_for_point(display_point, &text_layout_details),
+ );
+ (display_point, goal)
})
});
}
@@ -19,8 +19,8 @@ use gpui::{
use indoc::indoc;
use language::{
language_settings::{AllLanguageSettings, AllLanguageSettingsContent, LanguageSettingsContent},
- BracketPairConfig, BundledFormatter, FakeLspAdapter, LanguageConfig, LanguageConfigOverride,
- LanguageRegistry, Override, Point,
+ BracketPairConfig, FakeLspAdapter, LanguageConfig, LanguageConfigOverride, LanguageRegistry,
+ Override, Point,
};
use parking_lot::Mutex;
use project::project_settings::{LspSettings, ProjectSettings};
@@ -851,7 +851,7 @@ fn test_move_cursor_multibyte(cx: &mut TestAppContext) {
let view = cx
.add_window(|cx| {
- let buffer = MultiBuffer::build_simple("βββββ\nabcde\nαβγδΡ\n", cx);
+ let buffer = MultiBuffer::build_simple("βββββ\nabcde\nαβγδΡ", cx);
build_editor(buffer.clone(), cx)
})
.root(cx);
@@ -869,7 +869,7 @@ fn test_move_cursor_multibyte(cx: &mut TestAppContext) {
true,
cx,
);
- assert_eq!(view.display_text(cx), "βββ―β\nabβ―e\nΞ±Ξ²β―Ξ΅\n");
+ assert_eq!(view.display_text(cx), "βββ―β\nabβ―e\nΞ±Ξ²β―Ξ΅");
view.move_right(&MoveRight, cx);
assert_eq!(
@@ -888,6 +888,11 @@ fn test_move_cursor_multibyte(cx: &mut TestAppContext) {
);
view.move_down(&MoveDown, cx);
+ assert_eq!(
+ view.selections.display_ranges(cx),
+ &[empty_range(1, "abβ―e".len())]
+ );
+ view.move_left(&MoveLeft, cx);
assert_eq!(
view.selections.display_ranges(cx),
&[empty_range(1, "abβ―".len())]
@@ -929,17 +934,18 @@ fn test_move_cursor_multibyte(cx: &mut TestAppContext) {
view.selections.display_ranges(cx),
&[empty_range(1, "abβ―e".len())]
);
- view.move_up(&MoveUp, cx);
+ view.move_down(&MoveDown, cx);
assert_eq!(
view.selections.display_ranges(cx),
- &[empty_range(0, "βββ―β".len())]
+ &[empty_range(2, "Ξ±Ξ²β―Ξ΅".len())]
);
- view.move_left(&MoveLeft, cx);
+ view.move_up(&MoveUp, cx);
assert_eq!(
view.selections.display_ranges(cx),
- &[empty_range(0, "βββ―".len())]
+ &[empty_range(1, "abβ―e".len())]
);
- view.move_left(&MoveLeft, cx);
+
+ view.move_up(&MoveUp, cx);
assert_eq!(
view.selections.display_ranges(cx),
&[empty_range(0, "ββ".len())]
@@ -949,6 +955,11 @@ fn test_move_cursor_multibyte(cx: &mut TestAppContext) {
view.selections.display_ranges(cx),
&[empty_range(0, "β".len())]
);
+ view.move_left(&MoveLeft, cx);
+ assert_eq!(
+ view.selections.display_ranges(cx),
+ &[empty_range(0, "".len())]
+ );
});
}
@@ -5084,6 +5095,9 @@ async fn test_document_format_manual_trigger(cx: &mut gpui::TestAppContext) {
LanguageConfig {
name: "Rust".into(),
path_suffixes: vec!["rs".to_string()],
+ // Enable Prettier formatting for the same buffer, and ensure
+ // LSP is called instead of Prettier.
+ prettier_parser_name: Some("test_parser".to_string()),
..Default::default()
},
Some(tree_sitter_rust::language()),
@@ -5094,12 +5108,6 @@ async fn test_document_format_manual_trigger(cx: &mut gpui::TestAppContext) {
document_formatting_provider: Some(lsp::OneOf::Left(true)),
..Default::default()
},
- // Enable Prettier formatting for the same buffer, and ensure
- // LSP is called instead of Prettier.
- enabled_formatters: vec![BundledFormatter::Prettier {
- parser_name: Some("test_parser"),
- plugin_names: Vec::new(),
- }],
..Default::default()
}))
.await;
@@ -7838,6 +7846,7 @@ async fn test_document_format_with_prettier(cx: &mut gpui::TestAppContext) {
LanguageConfig {
name: "Rust".into(),
path_suffixes: vec!["rs".to_string()],
+ prettier_parser_name: Some("test_parser".to_string()),
..Default::default()
},
Some(tree_sitter_rust::language()),
@@ -7846,10 +7855,7 @@ async fn test_document_format_with_prettier(cx: &mut gpui::TestAppContext) {
let test_plugin = "test_plugin";
let _ = language
.set_fake_lsp_adapter(Arc::new(FakeLspAdapter {
- enabled_formatters: vec![BundledFormatter::Prettier {
- parser_name: Some("test_parser"),
- plugin_names: vec![test_plugin],
- }],
+ prettier_plugins: vec![test_plugin],
..Default::default()
}))
.await;
@@ -4,7 +4,7 @@ use super::{
MAX_LINE_LEN,
};
use crate::{
- display_map::{BlockStyle, DisplaySnapshot, FoldStatus, TransformBlock},
+ display_map::{BlockStyle, DisplaySnapshot, FoldStatus, HighlightedChunk, TransformBlock},
editor_settings::ShowScrollbar,
git::{diff_hunk_to_display, DisplayDiffHunk},
hover_popover::{
@@ -22,7 +22,7 @@ use git::diff::DiffHunkStatus;
use gpui::{
color::Color,
elements::*,
- fonts::{HighlightStyle, TextStyle, Underline},
+ fonts::TextStyle,
geometry::{
rect::RectF,
vector::{vec2f, Vector2F},
@@ -37,8 +37,7 @@ use gpui::{
use itertools::Itertools;
use json::json;
use language::{
- language_settings::ShowWhitespaceSetting, Bias, CursorShape, DiagnosticSeverity, OffsetUtf16,
- Selection,
+ language_settings::ShowWhitespaceSetting, Bias, CursorShape, OffsetUtf16, Selection,
};
use project::{
project_settings::{GitGutterSetting, ProjectSettings},
@@ -1584,56 +1583,7 @@ impl EditorElement {
.collect()
} else {
let style = &self.style;
- let chunks = snapshot
- .chunks(
- rows.clone(),
- true,
- Some(style.theme.hint),
- Some(style.theme.suggestion),
- )
- .map(|chunk| {
- let mut highlight_style = chunk
- .syntax_highlight_id
- .and_then(|id| id.style(&style.syntax));
-
- if let Some(chunk_highlight) = chunk.highlight_style {
- if let Some(highlight_style) = highlight_style.as_mut() {
- highlight_style.highlight(chunk_highlight);
- } else {
- highlight_style = Some(chunk_highlight);
- }
- }
-
- let mut diagnostic_highlight = HighlightStyle::default();
-
- if chunk.is_unnecessary {
- diagnostic_highlight.fade_out = Some(style.unnecessary_code_fade);
- }
-
- if let Some(severity) = chunk.diagnostic_severity {
- // Omit underlines for HINT/INFO diagnostics on 'unnecessary' code.
- if severity <= DiagnosticSeverity::WARNING || !chunk.is_unnecessary {
- let diagnostic_style = super::diagnostic_style(severity, true, style);
- diagnostic_highlight.underline = Some(Underline {
- color: Some(diagnostic_style.message.text.color),
- thickness: 1.0.into(),
- squiggly: true,
- });
- }
- }
-
- if let Some(highlight_style) = highlight_style.as_mut() {
- highlight_style.highlight(diagnostic_highlight);
- } else {
- highlight_style = Some(diagnostic_highlight);
- }
-
- HighlightedChunk {
- chunk: chunk.text,
- style: highlight_style,
- is_tab: chunk.is_tab,
- }
- });
+ let chunks = snapshot.highlighted_chunks(rows.clone(), true, style);
LineWithInvisibles::from_chunks(
chunks,
@@ -1870,12 +1820,6 @@ impl EditorElement {
}
}
-struct HighlightedChunk<'a> {
- chunk: &'a str,
- style: Option<HighlightStyle>,
- is_tab: bool,
-}
-
#[derive(Debug)]
pub struct LineWithInvisibles {
pub line: Line,
@@ -2138,7 +2138,7 @@ pub mod tests {
});
}
- #[gpui::test]
+ #[gpui::test(iterations = 10)]
async fn test_large_buffer_inlay_requests_split(cx: &mut gpui::TestAppContext) {
init_test(cx, |settings| {
settings.defaults.inlay_hints = Some(InlayHintSettings {
@@ -2400,11 +2400,13 @@ pub mod tests {
));
cx.foreground().run_until_parked();
editor.update(cx, |editor, cx| {
- let ranges = lsp_request_ranges.lock().drain(..).collect::<Vec<_>>();
+ let mut ranges = lsp_request_ranges.lock().drain(..).collect::<Vec<_>>();
+ ranges.sort_by_key(|r| r.start);
+
assert_eq!(ranges.len(), 3,
"On edit, should scroll to selection and query a range around it: visible + same range above and below. Instead, got query ranges {ranges:?}");
- let visible_query_range = &ranges[0];
- let above_query_range = &ranges[1];
+ let above_query_range = &ranges[0];
+ let visible_query_range = &ranges[1];
let below_query_range = &ranges[2];
assert!(above_query_range.end.character < visible_query_range.start.character || above_query_range.end.line + 1 == visible_query_range.start.line,
"Above range {above_query_range:?} should be before visible range {visible_query_range:?}");
@@ -1,7 +1,8 @@
use super::{Bias, DisplayPoint, DisplaySnapshot, SelectionGoal, ToDisplayPoint};
-use crate::{char_kind, CharKind, ToOffset, ToPoint};
+use crate::{char_kind, CharKind, EditorStyle, ToOffset, ToPoint};
+use gpui::{FontCache, TextLayoutCache};
use language::Point;
-use std::ops::Range;
+use std::{ops::Range, sync::Arc};
#[derive(Debug, PartialEq)]
pub enum FindRange {
@@ -9,6 +10,14 @@ pub enum FindRange {
MultiLine,
}
+/// TextLayoutDetails encompasses everything we need to move vertically
+/// taking into account variable width characters.
+pub struct TextLayoutDetails {
+ pub font_cache: Arc<FontCache>,
+ pub text_layout_cache: Arc<TextLayoutCache>,
+ pub editor_style: EditorStyle,
+}
+
pub fn left(map: &DisplaySnapshot, mut point: DisplayPoint) -> DisplayPoint {
if point.column() > 0 {
*point.column_mut() -= 1;
@@ -47,8 +56,16 @@ pub fn up(
start: DisplayPoint,
goal: SelectionGoal,
preserve_column_at_start: bool,
+ text_layout_details: &TextLayoutDetails,
) -> (DisplayPoint, SelectionGoal) {
- up_by_rows(map, start, 1, goal, preserve_column_at_start)
+ up_by_rows(
+ map,
+ start,
+ 1,
+ goal,
+ preserve_column_at_start,
+ text_layout_details,
+ )
}
pub fn down(
@@ -56,8 +73,16 @@ pub fn down(
start: DisplayPoint,
goal: SelectionGoal,
preserve_column_at_end: bool,
+ text_layout_details: &TextLayoutDetails,
) -> (DisplayPoint, SelectionGoal) {
- down_by_rows(map, start, 1, goal, preserve_column_at_end)
+ down_by_rows(
+ map,
+ start,
+ 1,
+ goal,
+ preserve_column_at_end,
+ text_layout_details,
+ )
}
pub fn up_by_rows(
@@ -66,11 +91,13 @@ pub fn up_by_rows(
row_count: u32,
goal: SelectionGoal,
preserve_column_at_start: bool,
+ text_layout_details: &TextLayoutDetails,
) -> (DisplayPoint, SelectionGoal) {
- let mut goal_column = match goal {
- SelectionGoal::Column(column) => column,
- SelectionGoal::ColumnRange { end, .. } => end,
- _ => map.column_to_chars(start.row(), start.column()),
+ let mut goal_x = match goal {
+ SelectionGoal::HorizontalPosition(x) => x,
+ SelectionGoal::WrappedHorizontalPosition((_, x)) => x,
+ SelectionGoal::HorizontalRange { end, .. } => end,
+ _ => map.x_for_point(start, text_layout_details),
};
let prev_row = start.row().saturating_sub(row_count);
@@ -79,19 +106,19 @@ pub fn up_by_rows(
Bias::Left,
);
if point.row() < start.row() {
- *point.column_mut() = map.column_from_chars(point.row(), goal_column);
+ *point.column_mut() = map.column_for_x(point.row(), goal_x, text_layout_details)
} else if preserve_column_at_start {
return (start, goal);
} else {
point = DisplayPoint::new(0, 0);
- goal_column = 0;
+ goal_x = 0.0;
}
let mut clipped_point = map.clip_point(point, Bias::Left);
if clipped_point.row() < point.row() {
clipped_point = map.clip_point(point, Bias::Right);
}
- (clipped_point, SelectionGoal::Column(goal_column))
+ (clipped_point, SelectionGoal::HorizontalPosition(goal_x))
}
pub fn down_by_rows(
@@ -100,29 +127,31 @@ pub fn down_by_rows(
row_count: u32,
goal: SelectionGoal,
preserve_column_at_end: bool,
+ text_layout_details: &TextLayoutDetails,
) -> (DisplayPoint, SelectionGoal) {
- let mut goal_column = match goal {
- SelectionGoal::Column(column) => column,
- SelectionGoal::ColumnRange { end, .. } => end,
- _ => map.column_to_chars(start.row(), start.column()),
+ let mut goal_x = match goal {
+ SelectionGoal::HorizontalPosition(x) => x,
+ SelectionGoal::WrappedHorizontalPosition((_, x)) => x,
+ SelectionGoal::HorizontalRange { end, .. } => end,
+ _ => map.x_for_point(start, text_layout_details),
};
let new_row = start.row() + row_count;
let mut point = map.clip_point(DisplayPoint::new(new_row, 0), Bias::Right);
if point.row() > start.row() {
- *point.column_mut() = map.column_from_chars(point.row(), goal_column);
+ *point.column_mut() = map.column_for_x(point.row(), goal_x, text_layout_details)
} else if preserve_column_at_end {
return (start, goal);
} else {
point = map.max_point();
- goal_column = map.column_to_chars(point.row(), point.column())
+ goal_x = map.x_for_point(point, text_layout_details)
}
let mut clipped_point = map.clip_point(point, Bias::Right);
if clipped_point.row() > point.row() {
clipped_point = map.clip_point(point, Bias::Left);
}
- (clipped_point, SelectionGoal::Column(goal_column))
+ (clipped_point, SelectionGoal::HorizontalPosition(goal_x))
}
pub fn line_beginning(
@@ -396,9 +425,11 @@ pub fn split_display_range_by_lines(
mod tests {
use super::*;
use crate::{
- display_map::Inlay, test::marked_display_snapshot, Buffer, DisplayMap, ExcerptRange,
- InlayId, MultiBuffer,
+ display_map::Inlay,
+ test::{editor_test_context::EditorTestContext, marked_display_snapshot},
+ Buffer, DisplayMap, ExcerptRange, InlayId, MultiBuffer,
};
+ use project::Project;
use settings::SettingsStore;
use util::post_inc;
@@ -691,123 +722,173 @@ mod tests {
}
#[gpui::test]
- fn test_move_up_and_down_with_excerpts(cx: &mut gpui::AppContext) {
- init_test(cx);
-
- let family_id = cx
- .font_cache()
- .load_family(&["Helvetica"], &Default::default())
- .unwrap();
- let font_id = cx
- .font_cache()
- .select_font(family_id, &Default::default())
- .unwrap();
+ async fn test_move_up_and_down_with_excerpts(cx: &mut gpui::TestAppContext) {
+ cx.update(|cx| {
+ init_test(cx);
+ });
- let buffer =
- cx.add_model(|cx| Buffer::new(0, cx.model_id() as u64, "abc\ndefg\nhijkl\nmn"));
- let multibuffer = cx.add_model(|cx| {
- let mut multibuffer = MultiBuffer::new(0);
- multibuffer.push_excerpts(
- buffer.clone(),
- [
- ExcerptRange {
- context: Point::new(0, 0)..Point::new(1, 4),
- primary: None,
- },
- ExcerptRange {
- context: Point::new(2, 0)..Point::new(3, 2),
- primary: None,
- },
- ],
- cx,
+ let mut cx = EditorTestContext::new(cx).await;
+ let editor = cx.editor.clone();
+ let window = cx.window.clone();
+ cx.update_window(window, |cx| {
+ let text_layout_details =
+ editor.read_with(cx, |editor, cx| editor.text_layout_details(cx));
+
+ let family_id = cx
+ .font_cache()
+ .load_family(&["Helvetica"], &Default::default())
+ .unwrap();
+ let font_id = cx
+ .font_cache()
+ .select_font(family_id, &Default::default())
+ .unwrap();
+
+ let buffer =
+ cx.add_model(|cx| Buffer::new(0, cx.model_id() as u64, "abc\ndefg\nhijkl\nmn"));
+ let multibuffer = cx.add_model(|cx| {
+ let mut multibuffer = MultiBuffer::new(0);
+ multibuffer.push_excerpts(
+ buffer.clone(),
+ [
+ ExcerptRange {
+ context: Point::new(0, 0)..Point::new(1, 4),
+ primary: None,
+ },
+ ExcerptRange {
+ context: Point::new(2, 0)..Point::new(3, 2),
+ primary: None,
+ },
+ ],
+ cx,
+ );
+ multibuffer
+ });
+ let display_map =
+ cx.add_model(|cx| DisplayMap::new(multibuffer, font_id, 14.0, None, 2, 2, cx));
+ let snapshot = display_map.update(cx, |map, cx| map.snapshot(cx));
+
+ assert_eq!(snapshot.text(), "\n\nabc\ndefg\n\n\nhijkl\nmn");
+
+ let col_2_x = snapshot.x_for_point(DisplayPoint::new(2, 2), &text_layout_details);
+
+ // Can't move up into the first excerpt's header
+ assert_eq!(
+ up(
+ &snapshot,
+ DisplayPoint::new(2, 2),
+ SelectionGoal::HorizontalPosition(col_2_x),
+ false,
+ &text_layout_details
+ ),
+ (
+ DisplayPoint::new(2, 0),
+ SelectionGoal::HorizontalPosition(0.0)
+ ),
+ );
+ assert_eq!(
+ up(
+ &snapshot,
+ DisplayPoint::new(2, 0),
+ SelectionGoal::None,
+ false,
+ &text_layout_details
+ ),
+ (
+ DisplayPoint::new(2, 0),
+ SelectionGoal::HorizontalPosition(0.0)
+ ),
);
- multibuffer
- });
- let display_map =
- cx.add_model(|cx| DisplayMap::new(multibuffer, font_id, 14.0, None, 2, 2, cx));
- let snapshot = display_map.update(cx, |map, cx| map.snapshot(cx));
- assert_eq!(snapshot.text(), "\n\nabc\ndefg\n\n\nhijkl\nmn");
+ let col_4_x = snapshot.x_for_point(DisplayPoint::new(3, 4), &text_layout_details);
- // Can't move up into the first excerpt's header
- assert_eq!(
- up(
- &snapshot,
- DisplayPoint::new(2, 2),
- SelectionGoal::Column(2),
- false
- ),
- (DisplayPoint::new(2, 0), SelectionGoal::Column(0)),
- );
- assert_eq!(
- up(
- &snapshot,
- DisplayPoint::new(2, 0),
- SelectionGoal::None,
- false
- ),
- (DisplayPoint::new(2, 0), SelectionGoal::Column(0)),
- );
+ // Move up and down within first excerpt
+ assert_eq!(
+ up(
+ &snapshot,
+ DisplayPoint::new(3, 4),
+ SelectionGoal::HorizontalPosition(col_4_x),
+ false,
+ &text_layout_details
+ ),
+ (
+ DisplayPoint::new(2, 3),
+ SelectionGoal::HorizontalPosition(col_4_x)
+ ),
+ );
+ assert_eq!(
+ down(
+ &snapshot,
+ DisplayPoint::new(2, 3),
+ SelectionGoal::HorizontalPosition(col_4_x),
+ false,
+ &text_layout_details
+ ),
+ (
+ DisplayPoint::new(3, 4),
+ SelectionGoal::HorizontalPosition(col_4_x)
+ ),
+ );
- // Move up and down within first excerpt
- assert_eq!(
- up(
- &snapshot,
- DisplayPoint::new(3, 4),
- SelectionGoal::Column(4),
- false
- ),
- (DisplayPoint::new(2, 3), SelectionGoal::Column(4)),
- );
- assert_eq!(
- down(
- &snapshot,
- DisplayPoint::new(2, 3),
- SelectionGoal::Column(4),
- false
- ),
- (DisplayPoint::new(3, 4), SelectionGoal::Column(4)),
- );
+ let col_5_x = snapshot.x_for_point(DisplayPoint::new(6, 5), &text_layout_details);
- // Move up and down across second excerpt's header
- assert_eq!(
- up(
- &snapshot,
- DisplayPoint::new(6, 5),
- SelectionGoal::Column(5),
- false
- ),
- (DisplayPoint::new(3, 4), SelectionGoal::Column(5)),
- );
- assert_eq!(
- down(
- &snapshot,
- DisplayPoint::new(3, 4),
- SelectionGoal::Column(5),
- false
- ),
- (DisplayPoint::new(6, 5), SelectionGoal::Column(5)),
- );
+ // Move up and down across second excerpt's header
+ assert_eq!(
+ up(
+ &snapshot,
+ DisplayPoint::new(6, 5),
+ SelectionGoal::HorizontalPosition(col_5_x),
+ false,
+ &text_layout_details
+ ),
+ (
+ DisplayPoint::new(3, 4),
+ SelectionGoal::HorizontalPosition(col_5_x)
+ ),
+ );
+ assert_eq!(
+ down(
+ &snapshot,
+ DisplayPoint::new(3, 4),
+ SelectionGoal::HorizontalPosition(col_5_x),
+ false,
+ &text_layout_details
+ ),
+ (
+ DisplayPoint::new(6, 5),
+ SelectionGoal::HorizontalPosition(col_5_x)
+ ),
+ );
- // Can't move down off the end
- assert_eq!(
- down(
- &snapshot,
- DisplayPoint::new(7, 0),
- SelectionGoal::Column(0),
- false
- ),
- (DisplayPoint::new(7, 2), SelectionGoal::Column(2)),
- );
- assert_eq!(
- down(
- &snapshot,
- DisplayPoint::new(7, 2),
- SelectionGoal::Column(2),
- false
- ),
- (DisplayPoint::new(7, 2), SelectionGoal::Column(2)),
- );
+ let max_point_x = snapshot.x_for_point(DisplayPoint::new(7, 2), &text_layout_details);
+
+ // Can't move down off the end
+ assert_eq!(
+ down(
+ &snapshot,
+ DisplayPoint::new(7, 0),
+ SelectionGoal::HorizontalPosition(0.0),
+ false,
+ &text_layout_details
+ ),
+ (
+ DisplayPoint::new(7, 2),
+ SelectionGoal::HorizontalPosition(max_point_x)
+ ),
+ );
+ assert_eq!(
+ down(
+ &snapshot,
+ DisplayPoint::new(7, 2),
+ SelectionGoal::HorizontalPosition(max_point_x),
+ false,
+ &text_layout_details
+ ),
+ (
+ DisplayPoint::new(7, 2),
+ SelectionGoal::HorizontalPosition(max_point_x)
+ ),
+ );
+ });
}
fn init_test(cx: &mut gpui::AppContext) {
@@ -815,5 +896,6 @@ mod tests {
theme::init((), cx);
language::init(cx);
crate::init(cx);
+ Project::init_settings(cx);
}
}
@@ -1,6 +1,6 @@
use std::{
cell::Ref,
- cmp, iter, mem,
+ iter, mem,
ops::{Deref, DerefMut, Range, Sub},
sync::Arc,
};
@@ -13,6 +13,7 @@ use util::post_inc;
use crate::{
display_map::{DisplayMap, DisplaySnapshot, ToDisplayPoint},
+ movement::TextLayoutDetails,
Anchor, DisplayPoint, ExcerptId, MultiBuffer, MultiBufferSnapshot, SelectMode, ToOffset,
};
@@ -305,23 +306,29 @@ impl SelectionsCollection {
&mut self,
display_map: &DisplaySnapshot,
row: u32,
- columns: &Range<u32>,
+ positions: &Range<f32>,
reversed: bool,
+ text_layout_details: &TextLayoutDetails,
) -> Option<Selection<Point>> {
- let is_empty = columns.start == columns.end;
+ let is_empty = positions.start == positions.end;
let line_len = display_map.line_len(row);
- if columns.start < line_len || (is_empty && columns.start == line_len) {
- let start = DisplayPoint::new(row, columns.start);
- let end = DisplayPoint::new(row, cmp::min(columns.end, line_len));
+
+ let layed_out_line = display_map.lay_out_line_for_row(row, &text_layout_details);
+
+ let start_col = layed_out_line.closest_index_for_x(positions.start) as u32;
+ if start_col < line_len || (is_empty && positions.start == layed_out_line.width()) {
+ let start = DisplayPoint::new(row, start_col);
+ let end_col = layed_out_line.closest_index_for_x(positions.end) as u32;
+ let end = DisplayPoint::new(row, end_col);
Some(Selection {
id: post_inc(&mut self.next_selection_id),
start: start.to_point(display_map),
end: end.to_point(display_map),
reversed,
- goal: SelectionGoal::ColumnRange {
- start: columns.start,
- end: columns.end,
+ goal: SelectionGoal::HorizontalRange {
+ start: positions.start,
+ end: positions.end,
},
})
} else {
@@ -30,7 +30,7 @@ struct StateInner<V> {
orientation: Orientation,
overdraw: f32,
#[allow(clippy::type_complexity)]
- scroll_handler: Option<Box<dyn FnMut(Range<usize>, &mut V, &mut ViewContext<V>)>>,
+ scroll_handler: Option<Box<dyn FnMut(Range<usize>, usize, &mut V, &mut ViewContext<V>)>>,
}
#[derive(Clone, Copy, Debug, Default, PartialEq)]
@@ -378,6 +378,10 @@ impl<V: 'static> ListState<V> {
.extend((0..element_count).map(|_| ListItem::Unrendered), &());
}
+ pub fn item_count(&self) -> usize {
+ self.0.borrow().items.summary().count
+ }
+
pub fn splice(&self, old_range: Range<usize>, count: usize) {
let state = &mut *self.0.borrow_mut();
@@ -416,7 +420,7 @@ impl<V: 'static> ListState<V> {
pub fn set_scroll_handler(
&mut self,
- handler: impl FnMut(Range<usize>, &mut V, &mut ViewContext<V>) + 'static,
+ handler: impl FnMut(Range<usize>, usize, &mut V, &mut ViewContext<V>) + 'static,
) {
self.0.borrow_mut().scroll_handler = Some(Box::new(handler))
}
@@ -529,7 +533,12 @@ impl<V: 'static> StateInner<V> {
if self.scroll_handler.is_some() {
let visible_range = self.visible_range(height, scroll_top);
- self.scroll_handler.as_mut().unwrap()(visible_range, view, cx);
+ self.scroll_handler.as_mut().unwrap()(
+ visible_range,
+ self.items.summary().count,
+ view,
+ cx,
+ );
}
cx.notify();
@@ -266,6 +266,8 @@ impl Line {
self.layout.len == 0
}
+ /// index_for_x returns the character containing the given x coordinate.
+ /// (e.g. to handle a mouse-click)
pub fn index_for_x(&self, x: f32) -> Option<usize> {
if x >= self.layout.width {
None
@@ -281,6 +283,28 @@ impl Line {
}
}
+ /// closest_index_for_x returns the character boundary closest to the given x coordinate
+ /// (e.g. to handle aligning up/down arrow keys)
+ pub fn closest_index_for_x(&self, x: f32) -> usize {
+ let mut prev_index = 0;
+ let mut prev_x = 0.0;
+
+ for run in self.layout.runs.iter() {
+ for glyph in run.glyphs.iter() {
+ if glyph.position.x() >= x {
+ if glyph.position.x() - x < x - prev_x {
+ return glyph.index;
+ } else {
+ return prev_index;
+ }
+ }
+ prev_index = glyph.index;
+ prev_x = glyph.position.x();
+ }
+ }
+ prev_index
+ }
+
pub fn paint(
&self,
origin: Vector2F,
@@ -201,7 +201,7 @@ pub struct CodeAction {
pub lsp_action: lsp::CodeAction,
}
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, PartialEq)]
pub enum Operation {
Buffer(text::Operation),
@@ -224,7 +224,7 @@ pub enum Operation {
},
}
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, PartialEq)]
pub enum Event {
Operation(Operation),
Edited,
@@ -226,8 +226,8 @@ impl CachedLspAdapter {
self.adapter.label_for_symbol(name, kind, language).await
}
- pub fn enabled_formatters(&self) -> Vec<BundledFormatter> {
- self.adapter.enabled_formatters()
+ pub fn prettier_plugins(&self) -> &[&'static str] {
+ self.adapter.prettier_plugins()
}
}
@@ -336,31 +336,8 @@ pub trait LspAdapter: 'static + Send + Sync {
Default::default()
}
- fn enabled_formatters(&self) -> Vec<BundledFormatter> {
- Vec::new()
- }
-}
-
-#[derive(Clone, Debug, PartialEq, Eq)]
-pub enum BundledFormatter {
- Prettier {
- // See https://prettier.io/docs/en/options.html#parser for a list of valid values.
- // Usually, every language has a single parser (standard or plugin-provided), hence `Some("parser_name")` can be used.
- // There can not be multiple parsers for a single language, in case of a conflict, we would attempt to select the one with most plugins.
- //
- // But exceptions like Tailwind CSS exist, which uses standard parsers for CSS/JS/HTML/etc. but require an extra plugin to be installed.
- // For those cases, `None` will install the plugin but apply other, regular parser defined for the language, and this would not be a conflict.
- parser_name: Option<&'static str>,
- plugin_names: Vec<&'static str>,
- },
-}
-
-impl BundledFormatter {
- pub fn prettier(parser_name: &'static str) -> Self {
- Self::Prettier {
- parser_name: Some(parser_name),
- plugin_names: Vec::new(),
- }
+ fn prettier_plugins(&self) -> &[&'static str] {
+ &[]
}
}
@@ -398,6 +375,8 @@ pub struct LanguageConfig {
pub overrides: HashMap<String, LanguageConfigOverride>,
#[serde(default)]
pub word_characters: HashSet<char>,
+ #[serde(default)]
+ pub prettier_parser_name: Option<String>,
}
#[derive(Debug, Default)]
@@ -471,6 +450,7 @@ impl Default for LanguageConfig {
overrides: Default::default(),
collapsed_placeholder: Default::default(),
word_characters: Default::default(),
+ prettier_parser_name: None,
}
}
}
@@ -496,7 +476,7 @@ pub struct FakeLspAdapter {
pub initializer: Option<Box<dyn 'static + Send + Sync + Fn(&mut lsp::FakeLanguageServer)>>,
pub disk_based_diagnostics_progress_token: Option<String>,
pub disk_based_diagnostics_sources: Vec<String>,
- pub enabled_formatters: Vec<BundledFormatter>,
+ pub prettier_plugins: Vec<&'static str>,
}
#[derive(Clone, Debug, Default)]
@@ -1597,6 +1577,10 @@ impl Language {
override_id: None,
}
}
+
+ pub fn prettier_parser_name(&self) -> Option<&str> {
+ self.config.prettier_parser_name.as_deref()
+ }
}
impl LanguageScope {
@@ -1759,7 +1743,7 @@ impl Default for FakeLspAdapter {
disk_based_diagnostics_progress_token: None,
initialization_options: None,
disk_based_diagnostics_sources: Vec::new(),
- enabled_formatters: Vec::new(),
+ prettier_plugins: Vec::new(),
}
}
}
@@ -1817,8 +1801,8 @@ impl LspAdapter for Arc<FakeLspAdapter> {
self.initialization_options.clone()
}
- fn enabled_formatters(&self) -> Vec<BundledFormatter> {
- self.enabled_formatters.clone()
+ fn prettier_plugins(&self) -> &[&'static str] {
+ &self.prettier_plugins
}
}
@@ -0,0 +1,42 @@
+[package]
+name = "notifications"
+version = "0.1.0"
+edition = "2021"
+publish = false
+
+[lib]
+path = "src/notification_store.rs"
+doctest = false
+
+[features]
+test-support = [
+ "channel/test-support",
+ "collections/test-support",
+ "gpui/test-support",
+ "rpc/test-support",
+]
+
+[dependencies]
+channel = { path = "../channel" }
+client = { path = "../client" }
+clock = { path = "../clock" }
+collections = { path = "../collections" }
+db = { path = "../db" }
+feature_flags = { path = "../feature_flags" }
+gpui = { path = "../gpui" }
+rpc = { path = "../rpc" }
+settings = { path = "../settings" }
+sum_tree = { path = "../sum_tree" }
+text = { path = "../text" }
+util = { path = "../util" }
+
+anyhow.workspace = true
+time.workspace = true
+
+[dev-dependencies]
+client = { path = "../client", features = ["test-support"] }
+collections = { path = "../collections", features = ["test-support"] }
+gpui = { path = "../gpui", features = ["test-support"] }
+rpc = { path = "../rpc", features = ["test-support"] }
+settings = { path = "../settings", features = ["test-support"] }
+util = { path = "../util", features = ["test-support"] }
@@ -0,0 +1,459 @@
+use anyhow::Result;
+use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
+use client::{Client, UserStore};
+use collections::HashMap;
+use db::smol::stream::StreamExt;
+use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
+use rpc::{proto, Notification, TypedEnvelope};
+use std::{ops::Range, sync::Arc};
+use sum_tree::{Bias, SumTree};
+use time::OffsetDateTime;
+use util::ResultExt;
+
+pub fn init(client: Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
+ let notification_store = cx.add_model(|cx| NotificationStore::new(client, user_store, cx));
+ cx.set_global(notification_store);
+}
+
+pub struct NotificationStore {
+ client: Arc<Client>,
+ user_store: ModelHandle<UserStore>,
+ channel_messages: HashMap<u64, ChannelMessage>,
+ channel_store: ModelHandle<ChannelStore>,
+ notifications: SumTree<NotificationEntry>,
+ loaded_all_notifications: bool,
+ _watch_connection_status: Task<Option<()>>,
+ _subscriptions: Vec<client::Subscription>,
+}
+
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub enum NotificationEvent {
+ NotificationsUpdated {
+ old_range: Range<usize>,
+ new_count: usize,
+ },
+ NewNotification {
+ entry: NotificationEntry,
+ },
+ NotificationRemoved {
+ entry: NotificationEntry,
+ },
+ NotificationRead {
+ entry: NotificationEntry,
+ },
+}
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub struct NotificationEntry {
+ pub id: u64,
+ pub notification: Notification,
+ pub timestamp: OffsetDateTime,
+ pub is_read: bool,
+ pub response: Option<bool>,
+}
+
+#[derive(Clone, Debug, Default)]
+pub struct NotificationSummary {
+ max_id: u64,
+ count: usize,
+ unread_count: usize,
+}
+
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
+struct Count(usize);
+
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
+struct UnreadCount(usize);
+
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
+struct NotificationId(u64);
+
+impl NotificationStore {
+ pub fn global(cx: &AppContext) -> ModelHandle<Self> {
+ cx.global::<ModelHandle<Self>>().clone()
+ }
+
+ pub fn new(
+ client: Arc<Client>,
+ user_store: ModelHandle<UserStore>,
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
+ let mut connection_status = client.status();
+ let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
+ while let Some(status) = connection_status.next().await {
+ let this = this.upgrade(&cx)?;
+ match status {
+ client::Status::Connected { .. } => {
+ if let Some(task) = this.update(&mut cx, |this, cx| this.handle_connect(cx))
+ {
+ task.await.log_err()?;
+ }
+ }
+ _ => this.update(&mut cx, |this, cx| this.handle_disconnect(cx)),
+ }
+ }
+ Some(())
+ });
+
+ Self {
+ channel_store: ChannelStore::global(cx),
+ notifications: Default::default(),
+ loaded_all_notifications: false,
+ channel_messages: Default::default(),
+ _watch_connection_status: watch_connection_status,
+ _subscriptions: vec![
+ client.add_message_handler(cx.handle(), Self::handle_new_notification),
+ client.add_message_handler(cx.handle(), Self::handle_delete_notification),
+ ],
+ user_store,
+ client,
+ }
+ }
+
+ pub fn notification_count(&self) -> usize {
+ self.notifications.summary().count
+ }
+
+ pub fn unread_notification_count(&self) -> usize {
+ self.notifications.summary().unread_count
+ }
+
+ pub fn channel_message_for_id(&self, id: u64) -> Option<&ChannelMessage> {
+ self.channel_messages.get(&id)
+ }
+
+ // Get the nth newest notification.
+ pub fn notification_at(&self, ix: usize) -> Option<&NotificationEntry> {
+ let count = self.notifications.summary().count;
+ if ix >= count {
+ return None;
+ }
+ let ix = count - 1 - ix;
+ let mut cursor = self.notifications.cursor::<Count>();
+ cursor.seek(&Count(ix), Bias::Right, &());
+ cursor.item()
+ }
+
+ pub fn notification_for_id(&self, id: u64) -> Option<&NotificationEntry> {
+ let mut cursor = self.notifications.cursor::<NotificationId>();
+ cursor.seek(&NotificationId(id), Bias::Left, &());
+ if let Some(item) = cursor.item() {
+ if item.id == id {
+ return Some(item);
+ }
+ }
+ None
+ }
+
+ pub fn load_more_notifications(
+ &self,
+ clear_old: bool,
+ cx: &mut ModelContext<Self>,
+ ) -> Option<Task<Result<()>>> {
+ if self.loaded_all_notifications && !clear_old {
+ return None;
+ }
+
+ let before_id = if clear_old {
+ None
+ } else {
+ self.notifications.first().map(|entry| entry.id)
+ };
+ let request = self.client.request(proto::GetNotifications { before_id });
+ Some(cx.spawn(|this, mut cx| async move {
+ let response = request.await?;
+ this.update(&mut cx, |this, _| {
+ this.loaded_all_notifications = response.done
+ });
+ Self::add_notifications(
+ this,
+ response.notifications,
+ AddNotificationsOptions {
+ is_new: false,
+ clear_old,
+ includes_first: response.done,
+ },
+ cx,
+ )
+ .await?;
+ Ok(())
+ }))
+ }
+
+ fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Option<Task<Result<()>>> {
+ self.notifications = Default::default();
+ self.channel_messages = Default::default();
+ cx.notify();
+ self.load_more_notifications(true, cx)
+ }
+
+ fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
+ cx.notify()
+ }
+
+ async fn handle_new_notification(
+ this: ModelHandle<Self>,
+ envelope: TypedEnvelope<proto::AddNotification>,
+ _: Arc<Client>,
+ cx: AsyncAppContext,
+ ) -> Result<()> {
+ Self::add_notifications(
+ this,
+ envelope.payload.notification.into_iter().collect(),
+ AddNotificationsOptions {
+ is_new: true,
+ clear_old: false,
+ includes_first: false,
+ },
+ cx,
+ )
+ .await
+ }
+
+ async fn handle_delete_notification(
+ this: ModelHandle<Self>,
+ envelope: TypedEnvelope<proto::DeleteNotification>,
+ _: Arc<Client>,
+ mut cx: AsyncAppContext,
+ ) -> Result<()> {
+ this.update(&mut cx, |this, cx| {
+ this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
+ Ok(())
+ })
+ }
+
+ async fn add_notifications(
+ this: ModelHandle<Self>,
+ notifications: Vec<proto::Notification>,
+ options: AddNotificationsOptions,
+ mut cx: AsyncAppContext,
+ ) -> Result<()> {
+ let mut user_ids = Vec::new();
+ let mut message_ids = Vec::new();
+
+ let notifications = notifications
+ .into_iter()
+ .filter_map(|message| {
+ Some(NotificationEntry {
+ id: message.id,
+ is_read: message.is_read,
+ timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)
+ .ok()?,
+ notification: Notification::from_proto(&message)?,
+ response: message.response,
+ })
+ })
+ .collect::<Vec<_>>();
+ if notifications.is_empty() {
+ return Ok(());
+ }
+
+ for entry in ¬ifications {
+ match entry.notification {
+ Notification::ChannelInvitation { inviter_id, .. } => {
+ user_ids.push(inviter_id);
+ }
+ Notification::ContactRequest {
+ sender_id: requester_id,
+ } => {
+ user_ids.push(requester_id);
+ }
+ Notification::ContactRequestAccepted {
+ responder_id: contact_id,
+ } => {
+ user_ids.push(contact_id);
+ }
+ Notification::ChannelMessageMention {
+ sender_id,
+ message_id,
+ ..
+ } => {
+ user_ids.push(sender_id);
+ message_ids.push(message_id);
+ }
+ }
+ }
+
+ let (user_store, channel_store) = this.read_with(&cx, |this, _| {
+ (this.user_store.clone(), this.channel_store.clone())
+ });
+
+ user_store
+ .update(&mut cx, |store, cx| store.get_users(user_ids, cx))
+ .await?;
+ let messages = channel_store
+ .update(&mut cx, |store, cx| {
+ store.fetch_channel_messages(message_ids, cx)
+ })
+ .await?;
+ this.update(&mut cx, |this, cx| {
+ if options.clear_old {
+ cx.emit(NotificationEvent::NotificationsUpdated {
+ old_range: 0..this.notifications.summary().count,
+ new_count: 0,
+ });
+ this.notifications = SumTree::default();
+ this.channel_messages.clear();
+ this.loaded_all_notifications = false;
+ }
+
+ if options.includes_first {
+ this.loaded_all_notifications = true;
+ }
+
+ this.channel_messages
+ .extend(messages.into_iter().filter_map(|message| {
+ if let ChannelMessageId::Saved(id) = message.id {
+ Some((id, message))
+ } else {
+ None
+ }
+ }));
+
+ this.splice_notifications(
+ notifications
+ .into_iter()
+ .map(|notification| (notification.id, Some(notification))),
+ options.is_new,
+ cx,
+ );
+ });
+
+ Ok(())
+ }
+
+ fn splice_notifications(
+ &mut self,
+ notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
+ is_new: bool,
+ cx: &mut ModelContext<'_, NotificationStore>,
+ ) {
+ let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
+ let mut new_notifications = SumTree::new();
+ let mut old_range = 0..0;
+
+ for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
+ new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
+
+ if i == 0 {
+ old_range.start = cursor.start().1 .0;
+ }
+
+ let old_notification = cursor.item();
+ if let Some(old_notification) = old_notification {
+ if old_notification.id == id {
+ cursor.next(&());
+
+ if let Some(new_notification) = &new_notification {
+ if new_notification.is_read {
+ cx.emit(NotificationEvent::NotificationRead {
+ entry: new_notification.clone(),
+ });
+ }
+ } else {
+ cx.emit(NotificationEvent::NotificationRemoved {
+ entry: old_notification.clone(),
+ });
+ }
+ }
+ } else if let Some(new_notification) = &new_notification {
+ if is_new {
+ cx.emit(NotificationEvent::NewNotification {
+ entry: new_notification.clone(),
+ });
+ }
+ }
+
+ if let Some(notification) = new_notification {
+ new_notifications.push(notification, &());
+ }
+ }
+
+ old_range.end = cursor.start().1 .0;
+ let new_count = new_notifications.summary().count - old_range.start;
+ new_notifications.append(cursor.suffix(&()), &());
+ drop(cursor);
+
+ self.notifications = new_notifications;
+ cx.emit(NotificationEvent::NotificationsUpdated {
+ old_range,
+ new_count,
+ });
+ }
+
+ pub fn respond_to_notification(
+ &mut self,
+ notification: Notification,
+ response: bool,
+ cx: &mut ModelContext<Self>,
+ ) {
+ match notification {
+ Notification::ContactRequest { sender_id } => {
+ self.user_store
+ .update(cx, |store, cx| {
+ store.respond_to_contact_request(sender_id, response, cx)
+ })
+ .detach();
+ }
+ Notification::ChannelInvitation { channel_id, .. } => {
+ self.channel_store
+ .update(cx, |store, cx| {
+ store.respond_to_channel_invite(channel_id, response, cx)
+ })
+ .detach();
+ }
+ _ => {}
+ }
+ }
+}
+
+impl Entity for NotificationStore {
+ type Event = NotificationEvent;
+}
+
+impl sum_tree::Item for NotificationEntry {
+ type Summary = NotificationSummary;
+
+ fn summary(&self) -> Self::Summary {
+ NotificationSummary {
+ max_id: self.id,
+ count: 1,
+ unread_count: if self.is_read { 0 } else { 1 },
+ }
+ }
+}
+
+impl sum_tree::Summary for NotificationSummary {
+ type Context = ();
+
+ fn add_summary(&mut self, summary: &Self, _: &()) {
+ self.max_id = self.max_id.max(summary.max_id);
+ self.count += summary.count;
+ self.unread_count += summary.unread_count;
+ }
+}
+
+impl<'a> sum_tree::Dimension<'a, NotificationSummary> for NotificationId {
+ fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
+ debug_assert!(summary.max_id > self.0);
+ self.0 = summary.max_id;
+ }
+}
+
+impl<'a> sum_tree::Dimension<'a, NotificationSummary> for Count {
+ fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
+ self.0 += summary.count;
+ }
+}
+
+impl<'a> sum_tree::Dimension<'a, NotificationSummary> for UnreadCount {
+ fn add_summary(&mut self, summary: &NotificationSummary, _: &()) {
+ self.0 += summary.unread_count;
+ }
+}
+
+struct AddNotificationsOptions {
+ is_new: bool,
+ clear_old: bool,
+ includes_first: bool,
+}
@@ -3,11 +3,11 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use anyhow::Context;
-use collections::{HashMap, HashSet};
+use collections::HashMap;
use fs::Fs;
use gpui::{AsyncAppContext, ModelHandle};
use language::language_settings::language_settings;
-use language::{Buffer, BundledFormatter, Diff};
+use language::{Buffer, Diff};
use lsp::{LanguageServer, LanguageServerId};
use node_runtime::NodeRuntime;
use serde::{Deserialize, Serialize};
@@ -242,40 +242,16 @@ impl Prettier {
Self::Real(local) => {
let params = buffer.read_with(cx, |buffer, cx| {
let buffer_language = buffer.language();
- let parsers_with_plugins = buffer_language
- .into_iter()
- .flat_map(|language| {
- language
- .lsp_adapters()
- .iter()
- .flat_map(|adapter| adapter.enabled_formatters())
- .filter_map(|formatter| match formatter {
- BundledFormatter::Prettier {
- parser_name,
- plugin_names,
- } => Some((parser_name, plugin_names)),
- })
- })
- .fold(
- HashMap::default(),
- |mut parsers_with_plugins, (parser_name, plugins)| {
- match parser_name {
- Some(parser_name) => parsers_with_plugins
- .entry(parser_name)
- .or_insert_with(HashSet::default)
- .extend(plugins),
- None => parsers_with_plugins.values_mut().for_each(|existing_plugins| {
- existing_plugins.extend(plugins.iter());
- }),
- }
- parsers_with_plugins
- },
- );
-
- let selected_parser_with_plugins = parsers_with_plugins.iter().max_by_key(|(_, plugins)| plugins.len());
- if parsers_with_plugins.len() > 1 {
- log::warn!("Found multiple parsers with plugins {parsers_with_plugins:?}, will select only one: {selected_parser_with_plugins:?}");
- }
+ let parser_with_plugins = buffer_language.and_then(|l| {
+ let prettier_parser = l.prettier_parser_name()?;
+ let mut prettier_plugins = l
+ .lsp_adapters()
+ .iter()
+ .flat_map(|adapter| adapter.prettier_plugins())
+ .collect::<Vec<_>>();
+ prettier_plugins.dedup();
+ Some((prettier_parser, prettier_plugins))
+ });
let prettier_node_modules = self.prettier_dir().join("node_modules");
anyhow::ensure!(prettier_node_modules.is_dir(), "Prettier node_modules dir does not exist: {prettier_node_modules:?}");
@@ -296,7 +272,7 @@ impl Prettier {
}
None
};
- let (parser, located_plugins) = match selected_parser_with_plugins {
+ let (parser, located_plugins) = match parser_with_plugins {
Some((parser, plugins)) => {
// Tailwind plugin requires being added last
// https://github.com/tailwindlabs/prettier-plugin-tailwindcss#compatibility-with-other-prettier-plugins
@@ -39,11 +39,11 @@ use language::{
deserialize_anchor, deserialize_fingerprint, deserialize_line_ending, deserialize_version,
serialize_anchor, serialize_version, split_operations,
},
- range_from_lsp, range_to_lsp, Bias, Buffer, BufferSnapshot, BundledFormatter, CachedLspAdapter,
- CodeAction, CodeLabel, Completion, Diagnostic, DiagnosticEntry, DiagnosticSet, Diff,
- Event as BufferEvent, File as _, Language, LanguageRegistry, LanguageServerName, LocalFile,
- LspAdapterDelegate, OffsetRangeExt, Operation, Patch, PendingLanguageServer, PointUtf16,
- TextBufferSnapshot, ToOffset, ToPointUtf16, Transaction, Unclipped,
+ range_from_lsp, range_to_lsp, Bias, Buffer, BufferSnapshot, CachedLspAdapter, CodeAction,
+ CodeLabel, Completion, Diagnostic, DiagnosticEntry, DiagnosticSet, Diff, Event as BufferEvent,
+ File as _, Language, LanguageRegistry, LanguageServerName, LocalFile, LspAdapterDelegate,
+ OffsetRangeExt, Operation, Patch, PendingLanguageServer, PointUtf16, TextBufferSnapshot,
+ ToOffset, ToPointUtf16, Transaction, Unclipped,
};
use log::error;
use lsp::{
@@ -8352,12 +8352,7 @@ impl Project {
let Some(buffer_language) = buffer.language() else {
return Task::ready(None);
};
- if !buffer_language
- .lsp_adapters()
- .iter()
- .flat_map(|adapter| adapter.enabled_formatters())
- .any(|formatter| matches!(formatter, BundledFormatter::Prettier { .. }))
- {
+ if buffer_language.prettier_parser_name().is_none() {
return Task::ready(None);
}
@@ -8510,16 +8505,15 @@ impl Project {
};
let mut prettier_plugins = None;
- for formatter in new_language
- .lsp_adapters()
- .into_iter()
- .flat_map(|adapter| adapter.enabled_formatters())
- {
- match formatter {
- BundledFormatter::Prettier { plugin_names, .. } => prettier_plugins
- .get_or_insert_with(|| HashSet::default())
- .extend(plugin_names),
- }
+ if new_language.prettier_parser_name().is_some() {
+ prettier_plugins
+ .get_or_insert_with(|| HashSet::default())
+ .extend(
+ new_language
+ .lsp_adapters()
+ .iter()
+ .flat_map(|adapter| adapter.prettier_plugins()),
+ )
}
let Some(prettier_plugins) = prettier_plugins else {
return Task::ready(Ok(()));
@@ -1,20 +1,35 @@
use std::{ops::Range, sync::Arc};
+use anyhow::bail;
use futures::FutureExt;
use gpui::{
- color::Color,
elements::Text,
- fonts::{HighlightStyle, TextStyle, Underline, Weight},
+ fonts::{HighlightStyle, Underline, Weight},
platform::{CursorStyle, MouseButton},
AnyElement, CursorRegion, Element, MouseRegion, ViewContext,
};
use language::{HighlightId, Language, LanguageRegistry};
-use theme::SyntaxTheme;
+use theme::{RichTextStyle, SyntaxTheme};
+use util::RangeExt;
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Highlight {
Id(HighlightId),
Highlight(HighlightStyle),
+ Mention,
+ SelfMention,
+}
+
+impl From<HighlightStyle> for Highlight {
+ fn from(style: HighlightStyle) -> Self {
+ Self::Highlight(style)
+ }
+}
+
+impl From<HighlightId> for Highlight {
+ fn from(style: HighlightId) -> Self {
+ Self::Id(style)
+ }
}
#[derive(Debug, Clone)]
@@ -25,18 +40,32 @@ pub struct RichText {
pub regions: Vec<RenderedRegion>,
}
-#[derive(Debug, Clone)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum BackgroundKind {
+ Code,
+ /// A mention background for non-self user.
+ Mention,
+ SelfMention,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RenderedRegion {
- code: bool,
- link_url: Option<String>,
+ pub background_kind: Option<BackgroundKind>,
+ pub link_url: Option<String>,
+}
+
+/// Allows one to specify extra links to the rendered markdown, which can be used
+/// for e.g. mentions.
+pub struct Mention {
+ pub range: Range<usize>,
+ pub is_self_mention: bool,
}
impl RichText {
pub fn element<V: 'static>(
&self,
syntax: Arc<SyntaxTheme>,
- style: TextStyle,
- code_span_background_color: Color,
+ style: RichTextStyle,
cx: &mut ViewContext<V>,
) -> AnyElement<V> {
let mut region_id = 0;
@@ -45,7 +74,7 @@ impl RichText {
let regions = self.regions.clone();
enum Markdown {}
- Text::new(self.text.clone(), style.clone())
+ Text::new(self.text.clone(), style.text.clone())
.with_highlights(
self.highlights
.iter()
@@ -53,6 +82,8 @@ impl RichText {
let style = match highlight {
Highlight::Id(id) => id.style(&syntax)?,
Highlight::Highlight(style) => style.clone(),
+ Highlight::Mention => style.mention_highlight,
+ Highlight::SelfMention => style.self_mention_highlight,
};
Some((range.clone(), style))
})
@@ -73,22 +104,55 @@ impl RichText {
}),
);
}
- if region.code {
- cx.scene().push_quad(gpui::Quad {
- bounds,
- background: Some(code_span_background_color),
- border: Default::default(),
- corner_radii: (2.0).into(),
- });
+ if let Some(region_kind) = ®ion.background_kind {
+ let background = match region_kind {
+ BackgroundKind::Code => style.code_background,
+ BackgroundKind::Mention => style.mention_background,
+ BackgroundKind::SelfMention => style.self_mention_background,
+ };
+ if background.is_some() {
+ cx.scene().push_quad(gpui::Quad {
+ bounds,
+ background,
+ border: Default::default(),
+ corner_radii: (2.0).into(),
+ });
+ }
}
})
.with_soft_wrap(true)
.into_any()
}
+
+ pub fn add_mention(
+ &mut self,
+ range: Range<usize>,
+ is_current_user: bool,
+ mention_style: HighlightStyle,
+ ) -> anyhow::Result<()> {
+ if range.end > self.text.len() {
+ bail!(
+ "Mention in range {range:?} is outside of bounds for a message of length {}",
+ self.text.len()
+ );
+ }
+
+ if is_current_user {
+ self.region_ranges.push(range.clone());
+ self.regions.push(RenderedRegion {
+ background_kind: Some(BackgroundKind::Mention),
+ link_url: None,
+ });
+ }
+ self.highlights
+ .push((range, Highlight::Highlight(mention_style)));
+ Ok(())
+ }
}
pub fn render_markdown_mut(
block: &str,
+ mut mentions: &[Mention],
language_registry: &Arc<LanguageRegistry>,
language: Option<&Arc<Language>>,
data: &mut RichText,
@@ -101,15 +165,40 @@ pub fn render_markdown_mut(
let mut current_language = None;
let mut list_stack = Vec::new();
- for event in Parser::new_ext(&block, Options::all()) {
+ let options = Options::all();
+ for (event, source_range) in Parser::new_ext(&block, options).into_offset_iter() {
let prev_len = data.text.len();
match event {
Event::Text(t) => {
if let Some(language) = ¤t_language {
render_code(&mut data.text, &mut data.highlights, t.as_ref(), language);
} else {
- data.text.push_str(t.as_ref());
+ if let Some(mention) = mentions.first() {
+ if source_range.contains_inclusive(&mention.range) {
+ mentions = &mentions[1..];
+ let range = (prev_len + mention.range.start - source_range.start)
+ ..(prev_len + mention.range.end - source_range.start);
+ data.highlights.push((
+ range.clone(),
+ if mention.is_self_mention {
+ Highlight::SelfMention
+ } else {
+ Highlight::Mention
+ },
+ ));
+ data.region_ranges.push(range);
+ data.regions.push(RenderedRegion {
+ background_kind: Some(if mention.is_self_mention {
+ BackgroundKind::SelfMention
+ } else {
+ BackgroundKind::Mention
+ }),
+ link_url: None,
+ });
+ }
+ }
+ data.text.push_str(t.as_ref());
let mut style = HighlightStyle::default();
if bold_depth > 0 {
style.weight = Some(Weight::BOLD);
@@ -121,7 +210,7 @@ pub fn render_markdown_mut(
data.region_ranges.push(prev_len..data.text.len());
data.regions.push(RenderedRegion {
link_url: Some(link_url),
- code: false,
+ background_kind: None,
});
style.underline = Some(Underline {
thickness: 1.0.into(),
@@ -162,7 +251,7 @@ pub fn render_markdown_mut(
));
}
data.regions.push(RenderedRegion {
- code: true,
+ background_kind: Some(BackgroundKind::Code),
link_url: link_url.clone(),
});
}
@@ -228,6 +317,7 @@ pub fn render_markdown_mut(
pub fn render_markdown(
block: String,
+ mentions: &[Mention],
language_registry: &Arc<LanguageRegistry>,
language: Option<&Arc<Language>>,
) -> RichText {
@@ -238,7 +328,7 @@ pub fn render_markdown(
regions: Default::default(),
};
- render_markdown_mut(&block, language_registry, language, &mut data);
+ render_markdown_mut(&block, mentions, language_registry, language, &mut data);
data.text = data.text.trim().to_string();
@@ -17,6 +17,7 @@ clock = { path = "../clock" }
collections = { path = "../collections" }
gpui = { path = "../gpui", optional = true }
util = { path = "../util" }
+
anyhow.workspace = true
async-lock = "2.4"
async-tungstenite = "0.16"
@@ -27,8 +28,10 @@ prost.workspace = true
rand.workspace = true
rsa = "0.4"
serde.workspace = true
+serde_json.workspace = true
serde_derive.workspace = true
smol-timeout = "0.6"
+strum.workspace = true
tracing = { version = "0.1.34", features = ["log"] }
zstd = "0.11"
@@ -157,23 +157,30 @@ message Envelope {
UpdateChannelBufferCollaborators update_channel_buffer_collaborators = 130;
RejoinChannelBuffers rejoin_channel_buffers = 131;
RejoinChannelBuffersResponse rejoin_channel_buffers_response = 132;
- AckBufferOperation ack_buffer_operation = 145;
-
- JoinChannelChat join_channel_chat = 133;
- JoinChannelChatResponse join_channel_chat_response = 134;
- LeaveChannelChat leave_channel_chat = 135;
- SendChannelMessage send_channel_message = 136;
- SendChannelMessageResponse send_channel_message_response = 137;
- ChannelMessageSent channel_message_sent = 138;
- GetChannelMessages get_channel_messages = 139;
- GetChannelMessagesResponse get_channel_messages_response = 140;
- RemoveChannelMessage remove_channel_message = 141;
- AckChannelMessage ack_channel_message = 146;
-
- LinkChannel link_channel = 142;
- UnlinkChannel unlink_channel = 143;
- MoveChannel move_channel = 144;
- SetChannelVisibility set_channel_visibility = 147; // current max: 147
+ AckBufferOperation ack_buffer_operation = 133;
+
+ JoinChannelChat join_channel_chat = 134;
+ JoinChannelChatResponse join_channel_chat_response = 135;
+ LeaveChannelChat leave_channel_chat = 136;
+ SendChannelMessage send_channel_message = 137;
+ SendChannelMessageResponse send_channel_message_response = 138;
+ ChannelMessageSent channel_message_sent = 139;
+ GetChannelMessages get_channel_messages = 140;
+ GetChannelMessagesResponse get_channel_messages_response = 141;
+ RemoveChannelMessage remove_channel_message = 142;
+ AckChannelMessage ack_channel_message = 143;
+ GetChannelMessagesById get_channel_messages_by_id = 144;
+
+ LinkChannel link_channel = 145;
+ UnlinkChannel unlink_channel = 146;
+ MoveChannel move_channel = 147;
+ SetChannelVisibility set_channel_visibility = 148;
+
+ AddNotification add_notification = 149;
+ GetNotifications get_notifications = 150;
+ GetNotificationsResponse get_notifications_response = 151;
+ DeleteNotification delete_notification = 152;
+ MarkNotificationRead mark_notification_read = 153; // Current max
}
}
@@ -1094,6 +1101,7 @@ message SendChannelMessage {
uint64 channel_id = 1;
string body = 2;
Nonce nonce = 3;
+ repeated ChatMention mentions = 4;
}
message RemoveChannelMessage {
@@ -1125,6 +1133,10 @@ message GetChannelMessagesResponse {
bool done = 2;
}
+message GetChannelMessagesById {
+ repeated uint64 message_ids = 1;
+}
+
message LinkChannel {
uint64 channel_id = 1;
uint64 to = 2;
@@ -1151,6 +1163,12 @@ message ChannelMessage {
uint64 timestamp = 3;
uint64 sender_id = 4;
Nonce nonce = 5;
+ repeated ChatMention mentions = 6;
+}
+
+message ChatMention {
+ Range range = 1;
+ uint64 user_id = 2;
}
message RejoinChannelBuffers {
@@ -1242,7 +1260,6 @@ message ShowContacts {}
message IncomingContactRequest {
uint64 requester_id = 1;
- bool should_notify = 2;
}
message UpdateDiagnostics {
@@ -1575,7 +1592,6 @@ message Contact {
uint64 user_id = 1;
bool online = 2;
bool busy = 3;
- bool should_notify = 4;
}
message WorktreeMetadata {
@@ -1590,3 +1606,34 @@ message UpdateDiffBase {
uint64 buffer_id = 2;
optional string diff_base = 3;
}
+
+message GetNotifications {
+ optional uint64 before_id = 1;
+}
+
+message AddNotification {
+ Notification notification = 1;
+}
+
+message GetNotificationsResponse {
+ repeated Notification notifications = 1;
+ bool done = 2;
+}
+
+message DeleteNotification {
+ uint64 notification_id = 1;
+}
+
+message MarkNotificationRead {
+ uint64 notification_id = 1;
+}
+
+message Notification {
+ uint64 id = 1;
+ uint64 timestamp = 2;
+ string kind = 3;
+ optional uint64 entity_id = 4;
+ string content = 5;
+ bool is_read = 6;
+ optional bool response = 7;
+}
@@ -0,0 +1,105 @@
+use crate::proto;
+use serde::{Deserialize, Serialize};
+use serde_json::{map, Value};
+use strum::{EnumVariantNames, VariantNames as _};
+
+const KIND: &'static str = "kind";
+const ENTITY_ID: &'static str = "entity_id";
+
+/// A notification that can be stored, associated with a given recipient.
+///
+/// This struct is stored in the collab database as JSON, so it shouldn't be
+/// changed in a backward-incompatible way. For example, when renaming a
+/// variant, add a serde alias for the old name.
+///
+/// Most notification types have a special field which is aliased to
+/// `entity_id`. This field is stored in its own database column, and can
+/// be used to query the notification.
+#[derive(Debug, Clone, PartialEq, Eq, EnumVariantNames, Serialize, Deserialize)]
+#[serde(tag = "kind")]
+pub enum Notification {
+ ContactRequest {
+ #[serde(rename = "entity_id")]
+ sender_id: u64,
+ },
+ ContactRequestAccepted {
+ #[serde(rename = "entity_id")]
+ responder_id: u64,
+ },
+ ChannelInvitation {
+ #[serde(rename = "entity_id")]
+ channel_id: u64,
+ channel_name: String,
+ inviter_id: u64,
+ },
+ ChannelMessageMention {
+ #[serde(rename = "entity_id")]
+ message_id: u64,
+ sender_id: u64,
+ channel_id: u64,
+ },
+}
+
+impl Notification {
+ pub fn to_proto(&self) -> proto::Notification {
+ let mut value = serde_json::to_value(self).unwrap();
+ let mut entity_id = None;
+ let value = value.as_object_mut().unwrap();
+ let Some(Value::String(kind)) = value.remove(KIND) else {
+ unreachable!("kind is the enum tag")
+ };
+ if let map::Entry::Occupied(e) = value.entry(ENTITY_ID) {
+ if e.get().is_u64() {
+ entity_id = e.remove().as_u64();
+ }
+ }
+ proto::Notification {
+ kind,
+ entity_id,
+ content: serde_json::to_string(&value).unwrap(),
+ ..Default::default()
+ }
+ }
+
+ pub fn from_proto(notification: &proto::Notification) -> Option<Self> {
+ let mut value = serde_json::from_str::<Value>(¬ification.content).ok()?;
+ let object = value.as_object_mut()?;
+ object.insert(KIND.into(), notification.kind.to_string().into());
+ if let Some(entity_id) = notification.entity_id {
+ object.insert(ENTITY_ID.into(), entity_id.into());
+ }
+ serde_json::from_value(value).ok()
+ }
+
+ pub fn all_variant_names() -> &'static [&'static str] {
+ Self::VARIANTS
+ }
+}
+
+#[test]
+fn test_notification() {
+ // Notifications can be serialized and deserialized.
+ for notification in [
+ Notification::ContactRequest { sender_id: 1 },
+ Notification::ContactRequestAccepted { responder_id: 2 },
+ Notification::ChannelInvitation {
+ channel_id: 100,
+ channel_name: "the-channel".into(),
+ inviter_id: 50,
+ },
+ Notification::ChannelMessageMention {
+ sender_id: 200,
+ channel_id: 30,
+ message_id: 1,
+ },
+ ] {
+ let message = notification.to_proto();
+ let deserialized = Notification::from_proto(&message).unwrap();
+ assert_eq!(deserialized, notification);
+ }
+
+ // When notifications are serialized, the `kind` and `actor_id` fields are
+ // stored separately, and do not appear redundantly in the JSON.
+ let notification = Notification::ContactRequest { sender_id: 1 };
+ assert_eq!(notification.to_proto().content, "{}");
+}
@@ -133,6 +133,9 @@ impl fmt::Display for PeerId {
messages!(
(Ack, Foreground),
+ (AckBufferOperation, Background),
+ (AckChannelMessage, Background),
+ (AddNotification, Foreground),
(AddProjectCollaborator, Foreground),
(ApplyCodeAction, Background),
(ApplyCodeActionResponse, Background),
@@ -143,57 +146,75 @@ messages!(
(Call, Foreground),
(CallCanceled, Foreground),
(CancelCall, Foreground),
+ (ChannelMessageSent, Foreground),
(CopyProjectEntry, Foreground),
(CreateBufferForPeer, Foreground),
(CreateChannel, Foreground),
(CreateChannelResponse, Foreground),
- (ChannelMessageSent, Foreground),
(CreateProjectEntry, Foreground),
(CreateRoom, Foreground),
(CreateRoomResponse, Foreground),
(DeclineCall, Foreground),
+ (DeleteChannel, Foreground),
+ (DeleteNotification, Foreground),
(DeleteProjectEntry, Foreground),
(Error, Foreground),
(ExpandProjectEntry, Foreground),
+ (ExpandProjectEntryResponse, Foreground),
(Follow, Foreground),
(FollowResponse, Foreground),
(FormatBuffers, Foreground),
(FormatBuffersResponse, Foreground),
(FuzzySearchUsers, Foreground),
- (GetCodeActions, Background),
- (GetCodeActionsResponse, Background),
- (GetHover, Background),
- (GetHoverResponse, Background),
+ (GetChannelMembers, Foreground),
+ (GetChannelMembersResponse, Foreground),
(GetChannelMessages, Background),
+ (GetChannelMessagesById, Background),
(GetChannelMessagesResponse, Background),
- (SendChannelMessage, Background),
- (SendChannelMessageResponse, Background),
+ (GetCodeActions, Background),
+ (GetCodeActionsResponse, Background),
(GetCompletions, Background),
(GetCompletionsResponse, Background),
(GetDefinition, Background),
(GetDefinitionResponse, Background),
- (GetTypeDefinition, Background),
- (GetTypeDefinitionResponse, Background),
(GetDocumentHighlights, Background),
(GetDocumentHighlightsResponse, Background),
- (GetReferences, Background),
- (GetReferencesResponse, Background),
+ (GetHover, Background),
+ (GetHoverResponse, Background),
+ (GetNotifications, Foreground),
+ (GetNotificationsResponse, Foreground),
+ (GetPrivateUserInfo, Foreground),
+ (GetPrivateUserInfoResponse, Foreground),
(GetProjectSymbols, Background),
(GetProjectSymbolsResponse, Background),
+ (GetReferences, Background),
+ (GetReferencesResponse, Background),
+ (GetTypeDefinition, Background),
+ (GetTypeDefinitionResponse, Background),
(GetUsers, Foreground),
(Hello, Foreground),
(IncomingCall, Foreground),
+ (InlayHints, Background),
+ (InlayHintsResponse, Background),
(InviteChannelMember, Foreground),
- (UsersResponse, Foreground),
+ (JoinChannel, Foreground),
+ (JoinChannelBuffer, Foreground),
+ (JoinChannelBufferResponse, Foreground),
+ (JoinChannelChat, Foreground),
+ (JoinChannelChatResponse, Foreground),
(JoinProject, Foreground),
(JoinProjectResponse, Foreground),
(JoinRoom, Foreground),
(JoinRoomResponse, Foreground),
- (JoinChannelChat, Foreground),
- (JoinChannelChatResponse, Foreground),
+ (LeaveChannelBuffer, Background),
(LeaveChannelChat, Foreground),
(LeaveProject, Foreground),
(LeaveRoom, Foreground),
+ (LinkChannel, Foreground),
+ (MarkNotificationRead, Foreground),
+ (MoveChannel, Foreground),
+ (OnTypeFormatting, Background),
+ (OnTypeFormattingResponse, Background),
(OpenBufferById, Background),
(OpenBufferByPath, Background),
(OpenBufferForSymbol, Background),
@@ -201,61 +222,57 @@ messages!(
(OpenBufferResponse, Background),
(PerformRename, Background),
(PerformRenameResponse, Background),
- (OnTypeFormatting, Background),
- (OnTypeFormattingResponse, Background),
- (InlayHints, Background),
- (InlayHintsResponse, Background),
- (ResolveCompletionDocumentation, Background),
- (ResolveCompletionDocumentationResponse, Background),
- (ResolveInlayHint, Background),
- (ResolveInlayHintResponse, Background),
- (RefreshInlayHints, Foreground),
(Ping, Foreground),
(PrepareRename, Background),
(PrepareRenameResponse, Background),
- (ExpandProjectEntryResponse, Foreground),
(ProjectEntryResponse, Foreground),
+ (RefreshInlayHints, Foreground),
+ (RejoinChannelBuffers, Foreground),
+ (RejoinChannelBuffersResponse, Foreground),
(RejoinRoom, Foreground),
(RejoinRoomResponse, Foreground),
- (RemoveContact, Foreground),
- (RemoveChannelMember, Foreground),
- (RemoveChannelMessage, Foreground),
(ReloadBuffers, Foreground),
(ReloadBuffersResponse, Foreground),
+ (RemoveChannelMember, Foreground),
+ (RemoveChannelMessage, Foreground),
+ (RemoveContact, Foreground),
(RemoveProjectCollaborator, Foreground),
+ (RenameChannel, Foreground),
+ (RenameChannelResponse, Foreground),
(RenameProjectEntry, Foreground),
(RequestContact, Foreground),
- (RespondToContactRequest, Foreground),
+ (ResolveCompletionDocumentation, Background),
+ (ResolveCompletionDocumentationResponse, Background),
+ (ResolveInlayHint, Background),
+ (ResolveInlayHintResponse, Background),
(RespondToChannelInvite, Foreground),
- (JoinChannel, Foreground),
+ (RespondToContactRequest, Foreground),
(RoomUpdated, Foreground),
(SaveBuffer, Foreground),
- (RenameChannel, Foreground),
- (RenameChannelResponse, Foreground),
(SetChannelMemberRole, Foreground),
(SetChannelVisibility, Foreground),
(SearchProject, Background),
(SearchProjectResponse, Background),
+ (SendChannelMessage, Background),
+ (SendChannelMessageResponse, Background),
(ShareProject, Foreground),
(ShareProjectResponse, Foreground),
(ShowContacts, Foreground),
(StartLanguageServer, Foreground),
(SynchronizeBuffers, Foreground),
(SynchronizeBuffersResponse, Foreground),
- (RejoinChannelBuffers, Foreground),
- (RejoinChannelBuffersResponse, Foreground),
(Test, Foreground),
(Unfollow, Foreground),
+ (UnlinkChannel, Foreground),
(UnshareProject, Foreground),
(UpdateBuffer, Foreground),
(UpdateBufferFile, Foreground),
- (UpdateContacts, Foreground),
- (DeleteChannel, Foreground),
- (MoveChannel, Foreground),
- (LinkChannel, Foreground),
- (UnlinkChannel, Foreground),
+ (UpdateChannelBuffer, Foreground),
+ (UpdateChannelBufferCollaborators, Foreground),
(UpdateChannels, Foreground),
+ (UpdateContacts, Foreground),
(UpdateDiagnosticSummary, Foreground),
+ (UpdateDiffBase, Foreground),
(UpdateFollowers, Foreground),
(UpdateInviteInfo, Foreground),
(UpdateLanguageServer, Foreground),
@@ -264,18 +281,7 @@ messages!(
(UpdateProjectCollaborator, Foreground),
(UpdateWorktree, Foreground),
(UpdateWorktreeSettings, Foreground),
- (UpdateDiffBase, Foreground),
- (GetPrivateUserInfo, Foreground),
- (GetPrivateUserInfoResponse, Foreground),
- (GetChannelMembers, Foreground),
- (GetChannelMembersResponse, Foreground),
- (JoinChannelBuffer, Foreground),
- (JoinChannelBufferResponse, Foreground),
- (LeaveChannelBuffer, Background),
- (UpdateChannelBuffer, Foreground),
- (UpdateChannelBufferCollaborators, Foreground),
- (AckBufferOperation, Background),
- (AckChannelMessage, Background),
+ (UsersResponse, Foreground),
);
request_messages!(
@@ -287,77 +293,80 @@ request_messages!(
(Call, Ack),
(CancelCall, Ack),
(CopyProjectEntry, ProjectEntryResponse),
+ (CreateChannel, CreateChannelResponse),
(CreateProjectEntry, ProjectEntryResponse),
(CreateRoom, CreateRoomResponse),
- (CreateChannel, CreateChannelResponse),
(DeclineCall, Ack),
+ (DeleteChannel, Ack),
(DeleteProjectEntry, ProjectEntryResponse),
(ExpandProjectEntry, ExpandProjectEntryResponse),
(Follow, FollowResponse),
(FormatBuffers, FormatBuffersResponse),
+ (FuzzySearchUsers, UsersResponse),
+ (GetChannelMembers, GetChannelMembersResponse),
+ (GetChannelMessages, GetChannelMessagesResponse),
+ (GetChannelMessagesById, GetChannelMessagesResponse),
(GetCodeActions, GetCodeActionsResponse),
- (GetHover, GetHoverResponse),
(GetCompletions, GetCompletionsResponse),
(GetDefinition, GetDefinitionResponse),
- (GetTypeDefinition, GetTypeDefinitionResponse),
(GetDocumentHighlights, GetDocumentHighlightsResponse),
- (GetReferences, GetReferencesResponse),
+ (GetHover, GetHoverResponse),
+ (GetNotifications, GetNotificationsResponse),
(GetPrivateUserInfo, GetPrivateUserInfoResponse),
(GetProjectSymbols, GetProjectSymbolsResponse),
- (FuzzySearchUsers, UsersResponse),
+ (GetReferences, GetReferencesResponse),
+ (GetTypeDefinition, GetTypeDefinitionResponse),
(GetUsers, UsersResponse),
+ (IncomingCall, Ack),
+ (InlayHints, InlayHintsResponse),
(InviteChannelMember, Ack),
+ (JoinChannel, JoinRoomResponse),
+ (JoinChannelBuffer, JoinChannelBufferResponse),
+ (JoinChannelChat, JoinChannelChatResponse),
(JoinProject, JoinProjectResponse),
(JoinRoom, JoinRoomResponse),
- (JoinChannelChat, JoinChannelChatResponse),
+ (LeaveChannelBuffer, Ack),
(LeaveRoom, Ack),
- (RejoinRoom, RejoinRoomResponse),
- (IncomingCall, Ack),
+ (LinkChannel, Ack),
+ (MarkNotificationRead, Ack),
+ (MoveChannel, Ack),
+ (OnTypeFormatting, OnTypeFormattingResponse),
(OpenBufferById, OpenBufferResponse),
(OpenBufferByPath, OpenBufferResponse),
(OpenBufferForSymbol, OpenBufferForSymbolResponse),
- (Ping, Ack),
(PerformRename, PerformRenameResponse),
+ (Ping, Ack),
(PrepareRename, PrepareRenameResponse),
- (OnTypeFormatting, OnTypeFormattingResponse),
- (InlayHints, InlayHintsResponse),
+ (RefreshInlayHints, Ack),
+ (RejoinChannelBuffers, RejoinChannelBuffersResponse),
+ (RejoinRoom, RejoinRoomResponse),
+ (ReloadBuffers, ReloadBuffersResponse),
+ (RemoveChannelMember, Ack),
+ (RemoveChannelMessage, Ack),
+ (RemoveContact, Ack),
+ (RenameChannel, RenameChannelResponse),
+ (RenameProjectEntry, ProjectEntryResponse),
+ (RequestContact, Ack),
(
ResolveCompletionDocumentation,
ResolveCompletionDocumentationResponse
),
(ResolveInlayHint, ResolveInlayHintResponse),
- (RefreshInlayHints, Ack),
- (ReloadBuffers, ReloadBuffersResponse),
- (RequestContact, Ack),
- (RemoveChannelMember, Ack),
- (RemoveContact, Ack),
- (RespondToContactRequest, Ack),
(RespondToChannelInvite, Ack),
- (SetChannelMemberRole, Ack),
- (SetChannelVisibility, Ack),
- (SendChannelMessage, SendChannelMessageResponse),
- (GetChannelMessages, GetChannelMessagesResponse),
- (GetChannelMembers, GetChannelMembersResponse),
- (JoinChannel, JoinRoomResponse),
- (RemoveChannelMessage, Ack),
- (DeleteChannel, Ack),
- (RenameProjectEntry, ProjectEntryResponse),
- (RenameChannel, RenameChannelResponse),
- (LinkChannel, Ack),
- (UnlinkChannel, Ack),
- (MoveChannel, Ack),
+ (RespondToContactRequest, Ack),
(SaveBuffer, BufferSaved),
(SearchProject, SearchProjectResponse),
+ (SendChannelMessage, SendChannelMessageResponse),
+ (SetChannelMemberRole, Ack),
+ (SetChannelVisibility, Ack),
(ShareProject, ShareProjectResponse),
(SynchronizeBuffers, SynchronizeBuffersResponse),
- (RejoinChannelBuffers, RejoinChannelBuffersResponse),
(Test, Test),
+ (UnlinkChannel, Ack),
(UpdateBuffer, Ack),
(UpdateParticipantLocation, Ack),
(UpdateProject, Ack),
(UpdateWorktree, Ack),
- (JoinChannelBuffer, JoinChannelBufferResponse),
- (LeaveChannelBuffer, Ack)
);
entity_messages!(
@@ -376,26 +385,26 @@ entity_messages!(
GetCodeActions,
GetCompletions,
GetDefinition,
- GetTypeDefinition,
GetDocumentHighlights,
GetHover,
- GetReferences,
GetProjectSymbols,
+ GetReferences,
+ GetTypeDefinition,
+ InlayHints,
JoinProject,
LeaveProject,
+ OnTypeFormatting,
OpenBufferById,
OpenBufferByPath,
OpenBufferForSymbol,
PerformRename,
- OnTypeFormatting,
- InlayHints,
- ResolveCompletionDocumentation,
- ResolveInlayHint,
- RefreshInlayHints,
PrepareRename,
+ RefreshInlayHints,
ReloadBuffers,
RemoveProjectCollaborator,
RenameProjectEntry,
+ ResolveCompletionDocumentation,
+ ResolveInlayHint,
SaveBuffer,
SearchProject,
StartLanguageServer,
@@ -404,19 +413,19 @@ entity_messages!(
UpdateBuffer,
UpdateBufferFile,
UpdateDiagnosticSummary,
+ UpdateDiffBase,
UpdateLanguageServer,
UpdateProject,
UpdateProjectCollaborator,
UpdateWorktree,
UpdateWorktreeSettings,
- UpdateDiffBase
);
entity_messages!(
channel_id,
ChannelMessageSent,
- UpdateChannelBuffer,
RemoveChannelMessage,
+ UpdateChannelBuffer,
UpdateChannelBufferCollaborators,
);
@@ -1,8 +1,11 @@
pub mod auth;
mod conn;
+mod notification;
mod peer;
pub mod proto;
+
pub use conn::Connection;
+pub use notification::*;
pub use peer::*;
mod macros;
@@ -537,6 +537,7 @@ impl BufferSearchBar {
self.active_searchable_item
.as_ref()
.map(|searchable_item| searchable_item.query_suggestion(cx))
+ .filter(|suggestion| !suggestion.is_empty())
}
pub fn set_replacement(&mut self, replacement: Option<&str>, cx: &mut ViewContext<Self>) {
@@ -351,33 +351,32 @@ impl View for ProjectSearchView {
SemanticIndexStatus::NotAuthenticated => {
major_text = Cow::Borrowed("Not Authenticated");
show_minor_text = false;
- Some(
- "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables"
- .to_string(),
- )
+ Some(vec![
+ "API Key Missing: Please set 'OPENAI_API_KEY' in Environment Variables."
+ .to_string(), "If you authenticated using the Assistant Panel, please restart Zed to Authenticate.".to_string()])
}
- SemanticIndexStatus::Indexed => Some("Indexing complete".to_string()),
+ SemanticIndexStatus::Indexed => Some(vec!["Indexing complete".to_string()]),
SemanticIndexStatus::Indexing {
remaining_files,
rate_limit_expiry,
} => {
if remaining_files == 0 {
- Some(format!("Indexing..."))
+ Some(vec![format!("Indexing...")])
} else {
if let Some(rate_limit_expiry) = rate_limit_expiry {
let remaining_seconds =
rate_limit_expiry.duration_since(Instant::now());
if remaining_seconds > Duration::from_secs(0) {
- Some(format!(
+ Some(vec![format!(
"Remaining files to index (rate limit resets in {}s): {}",
remaining_seconds.as_secs(),
remaining_files
- ))
+ )])
} else {
- Some(format!("Remaining files to index: {}", remaining_files))
+ Some(vec![format!("Remaining files to index: {}", remaining_files)])
}
} else {
- Some(format!("Remaining files to index: {}", remaining_files))
+ Some(vec![format!("Remaining files to index: {}", remaining_files)])
}
}
}
@@ -394,9 +393,11 @@ impl View for ProjectSearchView {
} else {
match current_mode {
SearchMode::Semantic => {
- let mut minor_text = Vec::new();
+ let mut minor_text: Vec<String> = Vec::new();
minor_text.push("".into());
- minor_text.extend(semantic_status);
+ if let Some(semantic_status) = semantic_status {
+ minor_text.extend(semantic_status);
+ }
if show_minor_text {
minor_text
.push("Simply explain the code you are looking to find.".into());
@@ -7,7 +7,10 @@ pub mod semantic_index_settings;
mod semantic_index_tests;
use crate::semantic_index_settings::SemanticIndexSettings;
-use ai::embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings};
+use ai::{
+ completion::OPENAI_API_URL,
+ embedding::{Embedding, EmbeddingProvider, OpenAIEmbeddings},
+};
use anyhow::{anyhow, Result};
use collections::{BTreeMap, HashMap, HashSet};
use db::VectorDatabase;
@@ -55,6 +58,19 @@ pub fn init(
.join(Path::new(RELEASE_CHANNEL_NAME.as_str()))
.join("embeddings_db");
+ let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ Some(api_key)
+ } else if let Some((_, api_key)) = cx
+ .platform()
+ .read_credentials(OPENAI_API_URL)
+ .log_err()
+ .flatten()
+ {
+ String::from_utf8(api_key).log_err()
+ } else {
+ None
+ };
+
cx.subscribe_global::<WorkspaceCreated, _>({
move |event, cx| {
let Some(semantic_index) = SemanticIndex::global(cx) else {
@@ -88,7 +104,7 @@ pub fn init(
let semantic_index = SemanticIndex::new(
fs,
db_file_path,
- Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+ Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
language_registry,
cx.clone(),
)
@@ -2,14 +2,15 @@ use crate::{Anchor, BufferSnapshot, TextDimension};
use std::cmp::Ordering;
use std::ops::Range;
-#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+#[derive(Copy, Clone, Debug, PartialEq)]
pub enum SelectionGoal {
None,
- Column(u32),
- ColumnRange { start: u32, end: u32 },
+ HorizontalPosition(f32),
+ HorizontalRange { start: f32, end: f32 },
+ WrappedHorizontalPosition((u32, f32)),
}
-#[derive(Clone, Debug, Eq, PartialEq)]
+#[derive(Clone, Debug, PartialEq)]
pub struct Selection<T> {
pub id: usize,
pub start: T,
@@ -53,6 +53,7 @@ pub struct Theme {
pub collab_panel: CollabPanel,
pub project_panel: ProjectPanel,
pub chat_panel: ChatPanel,
+ pub notification_panel: NotificationPanel,
pub command_palette: CommandPalette,
pub picker: Picker,
pub editor: Editor,
@@ -638,21 +639,43 @@ pub struct ChatPanel {
pub input_editor: FieldEditor,
pub avatar: AvatarStyle,
pub avatar_container: ContainerStyle,
- pub message: ChatMessage,
- pub continuation_message: ChatMessage,
+ pub rich_text: RichTextStyle,
+ pub message_sender: ContainedText,
+ pub message_timestamp: ContainedText,
+ pub message: Interactive<ContainerStyle>,
+ pub continuation_message: Interactive<ContainerStyle>,
+ pub pending_message: Interactive<ContainerStyle>,
pub last_message_bottom_spacing: f32,
- pub pending_message: ChatMessage,
pub sign_in_prompt: Interactive<TextStyle>,
pub icon_button: Interactive<IconButton>,
}
+#[derive(Clone, Deserialize, Default, JsonSchema)]
+pub struct RichTextStyle {
+ pub text: TextStyle,
+ pub mention_highlight: HighlightStyle,
+ pub mention_background: Option<Color>,
+ pub self_mention_highlight: HighlightStyle,
+ pub self_mention_background: Option<Color>,
+ pub code_background: Option<Color>,
+}
+
#[derive(Deserialize, Default, JsonSchema)]
-pub struct ChatMessage {
+pub struct NotificationPanel {
#[serde(flatten)]
- pub container: Interactive<ContainerStyle>,
- pub body: TextStyle,
- pub sender: ContainedText,
+ pub container: ContainerStyle,
+ pub title: ContainedText,
+ pub title_icon: SvgStyle,
+ pub title_height: f32,
+ pub list: ContainerStyle,
+ pub avatar: AvatarStyle,
+ pub avatar_container: ContainerStyle,
+ pub sign_in_prompt: Interactive<TextStyle>,
+ pub icon_button: Interactive<IconButton>,
+ pub unread_text: ContainedText,
+ pub read_text: ContainedText,
pub timestamp: ContainedText,
+ pub button: Interactive<ContainedText>,
}
#[derive(Deserialize, Default, JsonSchema)]
@@ -7,6 +7,7 @@ publish = false
[dependencies]
fuzzy = {path = "../fuzzy"}
+fs = {path = "../fs"}
gpui = {path = "../gpui"}
picker = {path = "../picker"}
util = {path = "../util"}
@@ -1,4 +1,5 @@
use anyhow::{anyhow, bail, Result};
+use fs::repository::Branch;
use fuzzy::{StringMatch, StringMatchCandidate};
use gpui::{
actions,
@@ -22,18 +23,9 @@ pub type BranchList = Picker<BranchListDelegate>;
pub fn build_branch_list(
workspace: ViewHandle<Workspace>,
cx: &mut ViewContext<BranchList>,
-) -> BranchList {
- Picker::new(
- BranchListDelegate {
- matches: vec![],
- workspace,
- selected_index: 0,
- last_query: String::default(),
- branch_name_trailoff_after: 29,
- },
- cx,
- )
- .with_theme(|theme| theme.picker.clone())
+) -> Result<BranchList> {
+ Ok(Picker::new(BranchListDelegate::new(workspace, 29, cx)?, cx)
+ .with_theme(|theme| theme.picker.clone()))
}
fn toggle(
@@ -43,31 +35,24 @@ fn toggle(
) -> Option<Task<Result<()>>> {
Some(cx.spawn(|workspace, mut cx| async move {
workspace.update(&mut cx, |workspace, cx| {
+ // Modal branch picker has a longer trailoff than a popover one.
+ let delegate = BranchListDelegate::new(cx.handle(), 70, cx)?;
workspace.toggle_modal(cx, |_, cx| {
- let workspace = cx.handle();
cx.add_view(|cx| {
- Picker::new(
- BranchListDelegate {
- matches: vec![],
- workspace,
- selected_index: 0,
- last_query: String::default(),
- /// Modal branch picker has a longer trailoff than a popover one.
- branch_name_trailoff_after: 70,
- },
- cx,
- )
- .with_theme(|theme| theme.picker.clone())
- .with_max_size(800., 1200.)
+ Picker::new(delegate, cx)
+ .with_theme(|theme| theme.picker.clone())
+ .with_max_size(800., 1200.)
})
});
- })?;
+ Ok::<_, anyhow::Error>(())
+ })??;
Ok(())
}))
}
pub struct BranchListDelegate {
matches: Vec<StringMatch>,
+ all_branches: Vec<Branch>,
workspace: ViewHandle<Workspace>,
selected_index: usize,
last_query: String,
@@ -76,6 +61,31 @@ pub struct BranchListDelegate {
}
impl BranchListDelegate {
+ fn new(
+ workspace: ViewHandle<Workspace>,
+ branch_name_trailoff_after: usize,
+ cx: &AppContext,
+ ) -> Result<Self> {
+ let project = workspace.read(cx).project().read(&cx);
+
+ let Some(worktree) = project.visible_worktrees(cx).next() else {
+ bail!("Cannot update branch list as there are no visible worktrees")
+ };
+ let mut cwd = worktree.read(cx).abs_path().to_path_buf();
+ cwd.push(".git");
+ let Some(repo) = project.fs().open_repo(&cwd) else {
+ bail!("Project does not have associated git repository.")
+ };
+ let all_branches = repo.lock().branches()?;
+ Ok(Self {
+ matches: vec![],
+ workspace,
+ all_branches,
+ selected_index: 0,
+ last_query: Default::default(),
+ branch_name_trailoff_after,
+ })
+ }
fn display_error_toast(&self, message: String, cx: &mut ViewContext<BranchList>) {
const GIT_CHECKOUT_FAILURE_ID: usize = 2048;
self.workspace.update(cx, |model, ctx| {
@@ -83,6 +93,7 @@ impl BranchListDelegate {
});
}
}
+
impl PickerDelegate for BranchListDelegate {
fn placeholder_text(&self) -> Arc<str> {
"Select branch...".into()
@@ -102,45 +113,28 @@ impl PickerDelegate for BranchListDelegate {
fn update_matches(&mut self, query: String, cx: &mut ViewContext<Picker<Self>>) -> Task<()> {
cx.spawn(move |picker, mut cx| async move {
- let Some(candidates) = picker
- .read_with(&mut cx, |view, cx| {
- let delegate = view.delegate();
- let project = delegate.workspace.read(cx).project().read(&cx);
-
- let Some(worktree) = project.visible_worktrees(cx).next() else {
- bail!("Cannot update branch list as there are no visible worktrees")
- };
- let mut cwd = worktree.read(cx).abs_path().to_path_buf();
- cwd.push(".git");
- let Some(repo) = project.fs().open_repo(&cwd) else {
- bail!("Project does not have associated git repository.")
- };
- let mut branches = repo.lock().branches()?;
- const RECENT_BRANCHES_COUNT: usize = 10;
- if query.is_empty() && branches.len() > RECENT_BRANCHES_COUNT {
- // Truncate list of recent branches
- // Do a partial sort to show recent-ish branches first.
- branches.select_nth_unstable_by(RECENT_BRANCHES_COUNT - 1, |lhs, rhs| {
- rhs.unix_timestamp.cmp(&lhs.unix_timestamp)
- });
- branches.truncate(RECENT_BRANCHES_COUNT);
- branches.sort_unstable_by(|lhs, rhs| lhs.name.cmp(&rhs.name));
- }
- Ok(branches
- .iter()
- .cloned()
- .enumerate()
- .map(|(ix, command)| StringMatchCandidate {
- id: ix,
- char_bag: command.name.chars().collect(),
- string: command.name.into(),
- })
- .collect::<Vec<_>>())
- })
- .log_err()
- else {
- return;
- };
+ let candidates = picker.read_with(&mut cx, |view, _| {
+ const RECENT_BRANCHES_COUNT: usize = 10;
+ let mut branches = view.delegate().all_branches.clone();
+ if query.is_empty() && branches.len() > RECENT_BRANCHES_COUNT {
+ // Truncate list of recent branches
+ // Do a partial sort to show recent-ish branches first.
+ branches.select_nth_unstable_by(RECENT_BRANCHES_COUNT - 1, |lhs, rhs| {
+ rhs.unix_timestamp.cmp(&lhs.unix_timestamp)
+ });
+ branches.truncate(RECENT_BRANCHES_COUNT);
+ branches.sort_unstable_by(|lhs, rhs| lhs.name.cmp(&rhs.name));
+ }
+ branches
+ .into_iter()
+ .enumerate()
+ .map(|(ix, command)| StringMatchCandidate {
+ id: ix,
+ char_bag: command.name.chars().collect(),
+ string: command.name.into(),
+ })
+ .collect::<Vec<StringMatchCandidate>>()
+ });
let Some(candidates) = candidates.log_err() else {
return;
};
@@ -1,9 +1,7 @@
-use std::cmp;
-
use editor::{
char_kind,
display_map::{DisplaySnapshot, FoldPoint, ToDisplayPoint},
- movement::{self, find_boundary, find_preceding_boundary, FindRange},
+ movement::{self, find_boundary, find_preceding_boundary, FindRange, TextLayoutDetails},
Bias, CharKind, DisplayPoint, ToOffset,
};
use gpui::{actions, impl_actions, AppContext, WindowContext};
@@ -361,6 +359,7 @@ impl Motion {
point: DisplayPoint,
goal: SelectionGoal,
maybe_times: Option<usize>,
+ text_layout_details: &TextLayoutDetails,
) -> Option<(DisplayPoint, SelectionGoal)> {
let times = maybe_times.unwrap_or(1);
use Motion::*;
@@ -370,16 +369,16 @@ impl Motion {
Backspace => (backspace(map, point, times), SelectionGoal::None),
Down {
display_lines: false,
- } => down(map, point, goal, times),
+ } => up_down_buffer_rows(map, point, goal, times as isize, &text_layout_details),
Down {
display_lines: true,
- } => down_display(map, point, goal, times),
+ } => down_display(map, point, goal, times, &text_layout_details),
Up {
display_lines: false,
- } => up(map, point, goal, times),
+ } => up_down_buffer_rows(map, point, goal, 0 - times as isize, &text_layout_details),
Up {
display_lines: true,
- } => up_display(map, point, goal, times),
+ } => up_display(map, point, goal, times, &text_layout_details),
Right => (right(map, point, times), SelectionGoal::None),
NextWordStart { ignore_punctuation } => (
next_word_start(map, point, *ignore_punctuation, times),
@@ -442,10 +441,15 @@ impl Motion {
selection: &mut Selection<DisplayPoint>,
times: Option<usize>,
expand_to_surrounding_newline: bool,
+ text_layout_details: &TextLayoutDetails,
) -> bool {
- if let Some((new_head, goal)) =
- self.move_point(map, selection.head(), selection.goal, times)
- {
+ if let Some((new_head, goal)) = self.move_point(
+ map,
+ selection.head(),
+ selection.goal,
+ times,
+ &text_layout_details,
+ ) {
selection.set_head(new_head, goal);
if self.linewise() {
@@ -530,35 +534,85 @@ fn backspace(map: &DisplaySnapshot, mut point: DisplayPoint, times: usize) -> Di
point
}
-fn down(
+pub(crate) fn start_of_relative_buffer_row(
+ map: &DisplaySnapshot,
+ point: DisplayPoint,
+ times: isize,
+) -> DisplayPoint {
+ let start = map.display_point_to_fold_point(point, Bias::Left);
+ let target = start.row() as isize + times;
+ let new_row = (target.max(0) as u32).min(map.fold_snapshot.max_point().row());
+
+ map.clip_point(
+ map.fold_point_to_display_point(
+ map.fold_snapshot
+ .clip_point(FoldPoint::new(new_row, 0), Bias::Right),
+ ),
+ Bias::Right,
+ )
+}
+
+fn up_down_buffer_rows(
map: &DisplaySnapshot,
point: DisplayPoint,
mut goal: SelectionGoal,
- times: usize,
+ times: isize,
+ text_layout_details: &TextLayoutDetails,
) -> (DisplayPoint, SelectionGoal) {
let start = map.display_point_to_fold_point(point, Bias::Left);
+ let begin_folded_line = map.fold_point_to_display_point(
+ map.fold_snapshot
+ .clip_point(FoldPoint::new(start.row(), 0), Bias::Left),
+ );
+ let select_nth_wrapped_row = point.row() - begin_folded_line.row();
- let goal_column = match goal {
- SelectionGoal::Column(column) => column,
- SelectionGoal::ColumnRange { end, .. } => end,
+ let (goal_wrap, goal_x) = match goal {
+ SelectionGoal::WrappedHorizontalPosition((row, x)) => (row, x),
+ SelectionGoal::HorizontalRange { end, .. } => (select_nth_wrapped_row, end),
+ SelectionGoal::HorizontalPosition(x) => (select_nth_wrapped_row, x),
_ => {
- goal = SelectionGoal::Column(start.column());
- start.column()
+ let x = map.x_for_point(point, text_layout_details);
+ goal = SelectionGoal::WrappedHorizontalPosition((select_nth_wrapped_row, x));
+ (select_nth_wrapped_row, x)
}
};
- let new_row = cmp::min(
- start.row() + times as u32,
- map.fold_snapshot.max_point().row(),
- );
- let new_col = cmp::min(goal_column, map.fold_snapshot.line_len(new_row));
- let point = map.fold_point_to_display_point(
+ let target = start.row() as isize + times;
+ let new_row = (target.max(0) as u32).min(map.fold_snapshot.max_point().row());
+
+ let mut begin_folded_line = map.fold_point_to_display_point(
map.fold_snapshot
- .clip_point(FoldPoint::new(new_row, new_col), Bias::Left),
+ .clip_point(FoldPoint::new(new_row, 0), Bias::Left),
);
- // clip twice to "clip at end of line"
- (map.clip_point(point, Bias::Left), goal)
+ let mut i = 0;
+ while i < goal_wrap && begin_folded_line.row() < map.max_point().row() {
+ let next_folded_line = DisplayPoint::new(begin_folded_line.row() + 1, 0);
+ if map
+ .display_point_to_fold_point(next_folded_line, Bias::Right)
+ .row()
+ == new_row
+ {
+ i += 1;
+ begin_folded_line = next_folded_line;
+ } else {
+ break;
+ }
+ }
+
+ let new_col = if i == goal_wrap {
+ map.column_for_x(begin_folded_line.row(), goal_x, text_layout_details)
+ } else {
+ map.line_len(begin_folded_line.row())
+ };
+
+ (
+ map.clip_point(
+ DisplayPoint::new(begin_folded_line.row(), new_col),
+ Bias::Left,
+ ),
+ goal,
+ )
}
fn down_display(
@@ -566,49 +620,24 @@ fn down_display(
mut point: DisplayPoint,
mut goal: SelectionGoal,
times: usize,
+ text_layout_details: &TextLayoutDetails,
) -> (DisplayPoint, SelectionGoal) {
for _ in 0..times {
- (point, goal) = movement::down(map, point, goal, true);
+ (point, goal) = movement::down(map, point, goal, true, text_layout_details);
}
(point, goal)
}
-pub(crate) fn up(
- map: &DisplaySnapshot,
- point: DisplayPoint,
- mut goal: SelectionGoal,
- times: usize,
-) -> (DisplayPoint, SelectionGoal) {
- let start = map.display_point_to_fold_point(point, Bias::Left);
-
- let goal_column = match goal {
- SelectionGoal::Column(column) => column,
- SelectionGoal::ColumnRange { end, .. } => end,
- _ => {
- goal = SelectionGoal::Column(start.column());
- start.column()
- }
- };
-
- let new_row = start.row().saturating_sub(times as u32);
- let new_col = cmp::min(goal_column, map.fold_snapshot.line_len(new_row));
- let point = map.fold_point_to_display_point(
- map.fold_snapshot
- .clip_point(FoldPoint::new(new_row, new_col), Bias::Left),
- );
-
- (map.clip_point(point, Bias::Left), goal)
-}
-
fn up_display(
map: &DisplaySnapshot,
mut point: DisplayPoint,
mut goal: SelectionGoal,
times: usize,
+ text_layout_details: &TextLayoutDetails,
) -> (DisplayPoint, SelectionGoal) {
for _ in 0..times {
- (point, goal) = movement::up(map, point, goal, true);
+ (point, goal) = movement::up(map, point, goal, true, &text_layout_details);
}
(point, goal)
@@ -707,7 +736,7 @@ fn previous_word_start(
point
}
-fn first_non_whitespace(
+pub(crate) fn first_non_whitespace(
map: &DisplaySnapshot,
display_lines: bool,
from: DisplayPoint,
@@ -886,13 +915,17 @@ fn find_backward(
}
fn next_line_start(map: &DisplaySnapshot, point: DisplayPoint, times: usize) -> DisplayPoint {
- let correct_line = down(map, point, SelectionGoal::None, times).0;
+ let correct_line = start_of_relative_buffer_row(map, point, times as isize);
first_non_whitespace(map, false, correct_line)
}
-fn next_line_end(map: &DisplaySnapshot, mut point: DisplayPoint, times: usize) -> DisplayPoint {
+pub(crate) fn next_line_end(
+ map: &DisplaySnapshot,
+ mut point: DisplayPoint,
+ times: usize,
+) -> DisplayPoint {
if times > 1 {
- point = down(map, point, SelectionGoal::None, times - 1).0;
+ point = start_of_relative_buffer_row(map, point, times as isize - 1);
}
end_of_line(map, false, point)
}
@@ -12,7 +12,7 @@ mod yank;
use std::sync::Arc;
use crate::{
- motion::{self, Motion},
+ motion::{self, first_non_whitespace, next_line_end, right, Motion},
object::Object,
state::{Mode, Operator},
Vim,
@@ -179,10 +179,11 @@ pub(crate) fn move_cursor(
cx: &mut WindowContext,
) {
vim.update_active_editor(cx, |editor, cx| {
+ let text_layout_details = editor.text_layout_details(cx);
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
s.move_cursors_with(|map, cursor, goal| {
motion
- .move_point(map, cursor, goal, times)
+ .move_point(map, cursor, goal, times, &text_layout_details)
.unwrap_or((cursor, goal))
})
})
@@ -195,9 +196,7 @@ fn insert_after(_: &mut Workspace, _: &InsertAfter, cx: &mut ViewContext<Workspa
vim.switch_mode(Mode::Insert, false, cx);
vim.update_active_editor(cx, |editor, cx| {
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
- s.maybe_move_cursors_with(|map, cursor, goal| {
- Motion::Right.move_point(map, cursor, goal, None)
- });
+ s.move_cursors_with(|map, cursor, _| (right(map, cursor, 1), SelectionGoal::None));
});
});
});
@@ -220,11 +219,11 @@ fn insert_first_non_whitespace(
vim.switch_mode(Mode::Insert, false, cx);
vim.update_active_editor(cx, |editor, cx| {
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
- s.maybe_move_cursors_with(|map, cursor, goal| {
- Motion::FirstNonWhitespace {
- display_lines: false,
- }
- .move_point(map, cursor, goal, None)
+ s.move_cursors_with(|map, cursor, _| {
+ (
+ first_non_whitespace(map, false, cursor),
+ SelectionGoal::None,
+ )
});
});
});
@@ -237,8 +236,8 @@ fn insert_end_of_line(_: &mut Workspace, _: &InsertEndOfLine, cx: &mut ViewConte
vim.switch_mode(Mode::Insert, false, cx);
vim.update_active_editor(cx, |editor, cx| {
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
- s.maybe_move_cursors_with(|map, cursor, goal| {
- Motion::CurrentLine.move_point(map, cursor, goal, None)
+ s.move_cursors_with(|map, cursor, _| {
+ (next_line_end(map, cursor, 1), SelectionGoal::None)
});
});
});
@@ -268,7 +267,7 @@ fn insert_line_above(_: &mut Workspace, _: &InsertLineAbove, cx: &mut ViewContex
editor.edit_with_autoindent(edits, cx);
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
s.move_cursors_with(|map, cursor, _| {
- let previous_line = motion::up(map, cursor, SelectionGoal::None, 1).0;
+ let previous_line = motion::start_of_relative_buffer_row(map, cursor, -1);
let insert_point = motion::end_of_line(map, false, previous_line);
(insert_point, SelectionGoal::None)
});
@@ -283,6 +282,7 @@ fn insert_line_below(_: &mut Workspace, _: &InsertLineBelow, cx: &mut ViewContex
vim.start_recording(cx);
vim.switch_mode(Mode::Insert, false, cx);
vim.update_active_editor(cx, |editor, cx| {
+ let text_layout_details = editor.text_layout_details(cx);
editor.transact(cx, |editor, cx| {
let (map, old_selections) = editor.selections.all_display(cx);
@@ -301,7 +301,13 @@ fn insert_line_below(_: &mut Workspace, _: &InsertLineBelow, cx: &mut ViewContex
});
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
s.maybe_move_cursors_with(|map, cursor, goal| {
- Motion::CurrentLine.move_point(map, cursor, goal, None)
+ Motion::CurrentLine.move_point(
+ map,
+ cursor,
+ goal,
+ None,
+ &text_layout_details,
+ )
});
});
editor.edit_with_autoindent(edits, cx);
@@ -399,12 +405,26 @@ mod test {
#[gpui::test]
async fn test_j(cx: &mut gpui::TestAppContext) {
- let mut cx = NeovimBackedTestContext::new(cx).await.binding(["j"]);
- cx.assert_all(indoc! {"
- ΛThe qΛuick broΛwn
- Λfox jumps"
+ let mut cx = NeovimBackedTestContext::new(cx).await;
+
+ cx.set_shared_state(indoc! {"
+ aaΛaa
+ ππ"
+ })
+ .await;
+ cx.simulate_shared_keystrokes(["j"]).await;
+ cx.assert_shared_state(indoc! {"
+ aaaa
+ πΛπ"
})
.await;
+
+ for marked_position in cx.each_marked_position(indoc! {"
+ ΛThe qΛuick broΛwn
+ Λfox jumps"
+ }) {
+ cx.assert_neovim_compatible(&marked_position, ["j"]).await;
+ }
}
#[gpui::test]
@@ -2,7 +2,7 @@ use crate::{motion::Motion, object::Object, state::Mode, utils::copy_selections_
use editor::{
char_kind,
display_map::DisplaySnapshot,
- movement::{self, FindRange},
+ movement::{self, FindRange, TextLayoutDetails},
scroll::autoscroll::Autoscroll,
CharKind, DisplayPoint,
};
@@ -20,6 +20,7 @@ pub fn change_motion(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &m
| Motion::StartOfLine { .. }
);
vim.update_active_editor(cx, |editor, cx| {
+ let text_layout_details = editor.text_layout_details(cx);
editor.transact(cx, |editor, cx| {
// We are swapping to insert mode anyway. Just set the line end clipping behavior now
editor.set_clip_at_line_ends(false, cx);
@@ -27,9 +28,15 @@ pub fn change_motion(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &m
s.move_with(|map, selection| {
motion_succeeded |= if let Motion::NextWordStart { ignore_punctuation } = motion
{
- expand_changed_word_selection(map, selection, times, ignore_punctuation)
+ expand_changed_word_selection(
+ map,
+ selection,
+ times,
+ ignore_punctuation,
+ &text_layout_details,
+ )
} else {
- motion.expand_selection(map, selection, times, false)
+ motion.expand_selection(map, selection, times, false, &text_layout_details)
};
});
});
@@ -81,6 +88,7 @@ fn expand_changed_word_selection(
selection: &mut Selection<DisplayPoint>,
times: Option<usize>,
ignore_punctuation: bool,
+ text_layout_details: &TextLayoutDetails,
) -> bool {
if times.is_none() || times.unwrap() == 1 {
let scope = map
@@ -103,11 +111,22 @@ fn expand_changed_word_selection(
});
true
} else {
- Motion::NextWordStart { ignore_punctuation }
- .expand_selection(map, selection, None, false)
+ Motion::NextWordStart { ignore_punctuation }.expand_selection(
+ map,
+ selection,
+ None,
+ false,
+ &text_layout_details,
+ )
}
} else {
- Motion::NextWordStart { ignore_punctuation }.expand_selection(map, selection, times, false)
+ Motion::NextWordStart { ignore_punctuation }.expand_selection(
+ map,
+ selection,
+ times,
+ false,
+ &text_layout_details,
+ )
}
}
@@ -7,6 +7,7 @@ use language::Point;
pub fn delete_motion(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &mut WindowContext) {
vim.stop_recording();
vim.update_active_editor(cx, |editor, cx| {
+ let text_layout_details = editor.text_layout_details(cx);
editor.transact(cx, |editor, cx| {
editor.set_clip_at_line_ends(false, cx);
let mut original_columns: HashMap<_, _> = Default::default();
@@ -14,7 +15,7 @@ pub fn delete_motion(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &m
s.move_with(|map, selection| {
let original_head = selection.head();
original_columns.insert(selection.id, original_head.column());
- motion.expand_selection(map, selection, times, true);
+ motion.expand_selection(map, selection, times, true, &text_layout_details);
// Motion::NextWordStart on an empty line should delete it.
if let Motion::NextWordStart {
@@ -255,8 +255,18 @@ mod test {
4
5"})
.await;
- cx.simulate_shared_keystrokes(["shift-g", "ctrl-v", "g", "g", "g", "ctrl-x"])
+
+ cx.simulate_shared_keystrokes(["shift-g", "ctrl-v", "g", "g"])
+ .await;
+ cx.assert_shared_state(indoc! {"
+ Β«1ΛΒ»
+ Β«2ΛΒ»
+ Β«3ΛΒ» 2
+ Β«4ΛΒ»
+ Β«5ΛΒ»"})
.await;
+
+ cx.simulate_shared_keystrokes(["g", "ctrl-x"]).await;
cx.assert_shared_state(indoc! {"
Λ0
0
@@ -30,6 +30,7 @@ fn paste(_: &mut Workspace, action: &Paste, cx: &mut ViewContext<Workspace>) {
Vim::update(cx, |vim, cx| {
vim.record_current_action(cx);
vim.update_active_editor(cx, |editor, cx| {
+ let text_layout_details = editor.text_layout_details(cx);
editor.transact(cx, |editor, cx| {
editor.set_clip_at_line_ends(false, cx);
@@ -168,8 +169,14 @@ fn paste(_: &mut Workspace, action: &Paste, cx: &mut ViewContext<Workspace>) {
let mut cursor = anchor.to_display_point(map);
if *line_mode {
if !before {
- cursor =
- movement::down(map, cursor, SelectionGoal::None, false).0;
+ cursor = movement::down(
+ map,
+ cursor,
+ SelectionGoal::None,
+ false,
+ &text_layout_details,
+ )
+ .0;
}
cursor = movement::indented_line_beginning(map, cursor, true);
} else if !is_multiline {
@@ -32,10 +32,17 @@ pub fn substitute(vim: &mut Vim, count: Option<usize>, line_mode: bool, cx: &mut
vim.update_active_editor(cx, |editor, cx| {
editor.set_clip_at_line_ends(false, cx);
editor.transact(cx, |editor, cx| {
+ let text_layout_details = editor.text_layout_details(cx);
editor.change_selections(None, cx, |s| {
s.move_with(|map, selection| {
if selection.start == selection.end {
- Motion::Right.expand_selection(map, selection, count, true);
+ Motion::Right.expand_selection(
+ map,
+ selection,
+ count,
+ true,
+ &text_layout_details,
+ );
}
if line_mode {
// in Visual mode when the selection contains the newline at the end
@@ -43,7 +50,13 @@ pub fn substitute(vim: &mut Vim, count: Option<usize>, line_mode: bool, cx: &mut
if !selection.is_empty() && selection.end.column() == 0 {
selection.end = movement::left(map, selection.end);
}
- Motion::CurrentLine.expand_selection(map, selection, None, false);
+ Motion::CurrentLine.expand_selection(
+ map,
+ selection,
+ None,
+ false,
+ &text_layout_details,
+ );
if let Some((point, _)) = (Motion::FirstNonWhitespace {
display_lines: false,
})
@@ -52,6 +65,7 @@ pub fn substitute(vim: &mut Vim, count: Option<usize>, line_mode: bool, cx: &mut
selection.start,
selection.goal,
None,
+ &text_layout_details,
) {
selection.start = point;
}
@@ -4,6 +4,7 @@ use gpui::WindowContext;
pub fn yank_motion(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &mut WindowContext) {
vim.update_active_editor(cx, |editor, cx| {
+ let text_layout_details = editor.text_layout_details(cx);
editor.transact(cx, |editor, cx| {
editor.set_clip_at_line_ends(false, cx);
let mut original_positions: HashMap<_, _> = Default::default();
@@ -11,7 +12,7 @@ pub fn yank_motion(vim: &mut Vim, motion: Motion, times: Option<usize>, cx: &mut
s.move_with(|map, selection| {
let original_position = (selection.head(), selection.goal);
original_positions.insert(selection.id, original_position);
- motion.expand_selection(map, selection, times, true);
+ motion.expand_selection(map, selection, times, true, &text_layout_details);
});
});
copy_selections_content(editor, motion.linewise(), cx);
@@ -653,6 +653,63 @@ async fn test_selection_goal(cx: &mut gpui::TestAppContext) {
.await;
}
+#[gpui::test]
+async fn test_wrapped_motions(cx: &mut gpui::TestAppContext) {
+ let mut cx = NeovimBackedTestContext::new(cx).await;
+
+ cx.set_shared_wrap(12).await;
+
+ cx.set_shared_state(indoc! {"
+ aaΛaa
+ ππ"
+ })
+ .await;
+ cx.simulate_shared_keystrokes(["j"]).await;
+ cx.assert_shared_state(indoc! {"
+ aaaa
+ πΛπ"
+ })
+ .await;
+
+ cx.set_shared_state(indoc! {"
+ 123456789012aaΛaa
+ 123456789012ππ"
+ })
+ .await;
+ cx.simulate_shared_keystrokes(["j"]).await;
+ cx.assert_shared_state(indoc! {"
+ 123456789012aaaa
+ 123456789012πΛπ"
+ })
+ .await;
+
+ cx.set_shared_state(indoc! {"
+ 123456789012aaΛaa
+ 123456789012ππ"
+ })
+ .await;
+ cx.simulate_shared_keystrokes(["j"]).await;
+ cx.assert_shared_state(indoc! {"
+ 123456789012aaaa
+ 123456789012πΛπ"
+ })
+ .await;
+
+ cx.set_shared_state(indoc! {"
+ 123456789012aaaaΛaaaaaaaa123456789012
+ wow
+ 123456789012ππππππ123456789012"
+ })
+ .await;
+ cx.simulate_shared_keystrokes(["j", "j"]).await;
+ cx.assert_shared_state(indoc! {"
+ 123456789012aaaaaaaaaaaa123456789012
+ wow
+ 123456789012ππΛππππ123456789012"
+ })
+ .await;
+}
+
#[gpui::test]
async fn test_paragraphs_dont_wrap(cx: &mut gpui::TestAppContext) {
let mut cx = NeovimBackedTestContext::new(cx).await;
@@ -25,7 +25,7 @@ pub use mode_indicator::ModeIndicator;
use motion::Motion;
use normal::normal_replace;
use serde::Deserialize;
-use settings::{Setting, SettingsStore};
+use settings::{update_settings_file, Setting, SettingsStore};
use state::{EditorState, Mode, Operator, RecordedSelection, WorkspaceState};
use std::{ops::Range, sync::Arc};
use visual::{visual_block_motion, visual_replace};
@@ -48,6 +48,7 @@ actions!(
vim,
[Tab, Enter, Object, InnerObject, FindForward, FindBackward]
);
+actions!(workspace, [ToggleVimMode]);
impl_actions!(vim, [Number, SwitchMode, PushOperator]);
#[derive(Copy, Clone, Debug)]
@@ -88,6 +89,14 @@ pub fn init(cx: &mut AppContext) {
Vim::active_editor_input_ignored("\n".into(), cx)
});
+ cx.add_action(|workspace: &mut Workspace, _: &ToggleVimMode, cx| {
+ let fs = workspace.app_state().fs.clone();
+ let currently_enabled = settings::get::<VimModeSetting>(cx).0;
+ update_settings_file::<VimModeSetting>(fs, cx, move |setting| {
+ *setting = Some(!currently_enabled)
+ })
+ });
+
// Any time settings change, update vim mode to match. The Vim struct
// will be initialized as disabled by default, so we filter its commands
// out when starting up.
@@ -581,7 +590,7 @@ impl Setting for VimModeSetting {
fn local_selections_changed(newest: Selection<usize>, cx: &mut WindowContext) {
Vim::update(cx, |vim, cx| {
if vim.enabled && vim.state().mode == Mode::Normal && !newest.is_empty() {
- if matches!(newest.goal, SelectionGoal::ColumnRange { .. }) {
+ if matches!(newest.goal, SelectionGoal::HorizontalRange { .. }) {
vim.switch_mode(Mode::VisualBlock, false, cx);
} else {
vim.switch_mode(Mode::Visual, false, cx)
@@ -57,6 +57,7 @@ pub fn init(cx: &mut AppContext) {
pub fn visual_motion(motion: Motion, times: Option<usize>, cx: &mut WindowContext) {
Vim::update(cx, |vim, cx| {
vim.update_active_editor(cx, |editor, cx| {
+ let text_layout_details = editor.text_layout_details(cx);
if vim.state().mode == Mode::VisualBlock
&& !matches!(
motion,
@@ -67,7 +68,7 @@ pub fn visual_motion(motion: Motion, times: Option<usize>, cx: &mut WindowContex
{
let is_up_or_down = matches!(motion, Motion::Up { .. } | Motion::Down { .. });
visual_block_motion(is_up_or_down, editor, cx, |map, point, goal| {
- motion.move_point(map, point, goal, times)
+ motion.move_point(map, point, goal, times, &text_layout_details)
})
} else {
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
@@ -89,9 +90,13 @@ pub fn visual_motion(motion: Motion, times: Option<usize>, cx: &mut WindowContex
current_head = movement::left(map, selection.end)
}
- let Some((new_head, goal)) =
- motion.move_point(map, current_head, selection.goal, times)
- else {
+ let Some((new_head, goal)) = motion.move_point(
+ map,
+ current_head,
+ selection.goal,
+ times,
+ &text_layout_details,
+ ) else {
return;
};
@@ -135,19 +140,23 @@ pub fn visual_block_motion(
SelectionGoal,
) -> Option<(DisplayPoint, SelectionGoal)>,
) {
+ let text_layout_details = editor.text_layout_details(cx);
editor.change_selections(Some(Autoscroll::fit()), cx, |s| {
let map = &s.display_map();
let mut head = s.newest_anchor().head().to_display_point(map);
let mut tail = s.oldest_anchor().tail().to_display_point(map);
+ let mut head_x = map.x_for_point(head, &text_layout_details);
+ let mut tail_x = map.x_for_point(tail, &text_layout_details);
+
let (start, end) = match s.newest_anchor().goal {
- SelectionGoal::ColumnRange { start, end } if preserve_goal => (start, end),
- SelectionGoal::Column(start) if preserve_goal => (start, start + 1),
- _ => (tail.column(), head.column()),
+ SelectionGoal::HorizontalRange { start, end } if preserve_goal => (start, end),
+ SelectionGoal::HorizontalPosition(start) if preserve_goal => (start, start),
+ _ => (tail_x, head_x),
};
- let goal = SelectionGoal::ColumnRange { start, end };
+ let mut goal = SelectionGoal::HorizontalRange { start, end };
- let was_reversed = tail.column() > head.column();
+ let was_reversed = tail_x > head_x;
if !was_reversed && !preserve_goal {
head = movement::saturating_left(map, head);
}
@@ -156,32 +165,56 @@ pub fn visual_block_motion(
return;
};
head = new_head;
+ head_x = map.x_for_point(head, &text_layout_details);
- let is_reversed = tail.column() > head.column();
+ let is_reversed = tail_x > head_x;
if was_reversed && !is_reversed {
- tail = movement::left(map, tail)
+ tail = movement::saturating_left(map, tail);
+ tail_x = map.x_for_point(tail, &text_layout_details);
} else if !was_reversed && is_reversed {
- tail = movement::right(map, tail)
+ tail = movement::saturating_right(map, tail);
+ tail_x = map.x_for_point(tail, &text_layout_details);
}
if !is_reversed && !preserve_goal {
- head = movement::saturating_right(map, head)
+ head = movement::saturating_right(map, head);
+ head_x = map.x_for_point(head, &text_layout_details);
}
- let columns = if is_reversed {
- head.column()..tail.column()
- } else if head.column() == tail.column() {
- head.column()..(head.column() + 1)
+ let positions = if is_reversed {
+ head_x..tail_x
} else {
- tail.column()..head.column()
+ tail_x..head_x
};
+ if !preserve_goal {
+ goal = SelectionGoal::HorizontalRange {
+ start: positions.start,
+ end: positions.end,
+ };
+ }
+
let mut selections = Vec::new();
let mut row = tail.row();
loop {
- let start = map.clip_point(DisplayPoint::new(row, columns.start), Bias::Left);
- let end = map.clip_point(DisplayPoint::new(row, columns.end), Bias::Left);
- if columns.start <= map.line_len(row) {
+ let layed_out_line = map.lay_out_line_for_row(row, &text_layout_details);
+ let start = DisplayPoint::new(
+ row,
+ layed_out_line.closest_index_for_x(positions.start) as u32,
+ );
+ let mut end = DisplayPoint::new(
+ row,
+ layed_out_line.closest_index_for_x(positions.end) as u32,
+ );
+ if end <= start {
+ if start.column() == map.line_len(start.row()) {
+ end = start;
+ } else {
+ end = movement::saturating_right(map, start);
+ }
+ }
+
+ if positions.start <= layed_out_line.width() {
let selection = Selection {
id: s.new_selection_id(),
start: start.to_point(map),
@@ -888,6 +921,28 @@ mod test {
.await;
}
+ #[gpui::test]
+ async fn test_visual_block_issue_2123(cx: &mut gpui::TestAppContext) {
+ let mut cx = NeovimBackedTestContext::new(cx).await;
+
+ cx.set_shared_state(indoc! {
+ "The Λquick brown
+ fox jumps over
+ the lazy dog
+ "
+ })
+ .await;
+ cx.simulate_shared_keystrokes(["ctrl-v", "right", "down"])
+ .await;
+ cx.assert_shared_state(indoc! {
+ "The Β«quΛΒ»ick brown
+ fox Β«juΛΒ»mps over
+ the lazy dog
+ "
+ })
+ .await;
+ }
+
#[gpui::test]
async fn test_visual_block_insert(cx: &mut gpui::TestAppContext) {
let mut cx = NeovimBackedTestContext::new(cx).await;
@@ -9,6 +9,7 @@
{"Key":"ctrl-v"}
{"Key":"g"}
{"Key":"g"}
+{"Get":{"state":"Β«1ΛΒ»\nΒ«2ΛΒ»\nΒ«3ΛΒ» 2\nΒ«4ΛΒ»\nΒ«5ΛΒ»","mode":"VisualBlock"}}
{"Key":"g"}
{"Key":"ctrl-x"}
{"Get":{"state":"Λ0\n0\n0 2\n0\n0","mode":"Normal"}}
@@ -1,3 +1,6 @@
+{"Put":{"state":"aaΛaa\nππ"}}
+{"Key":"j"}
+{"Get":{"state":"aaaa\nπΛπ","mode":"Normal"}}
{"Put":{"state":"ΛThe quick brown\nfox jumps"}}
{"Key":"j"}
{"Get":{"state":"The quick brown\nΛfox jumps","mode":"Normal"}}
@@ -0,0 +1,5 @@
+{"Put":{"state":"The Λquick brown\nfox jumps over\nthe lazy dog\n"}}
+{"Key":"ctrl-v"}
+{"Key":"right"}
+{"Key":"down"}
+{"Get":{"state":"The Β«quΛΒ»ick brown\nfox Β«juΛΒ»mps over\nthe lazy dog\n","mode":"VisualBlock"}}
@@ -0,0 +1,15 @@
+{"SetOption":{"value":"wrap"}}
+{"SetOption":{"value":"columns=12"}}
+{"Put":{"state":"aaΛaa\nππ"}}
+{"Key":"j"}
+{"Get":{"state":"aaaa\nπΛπ","mode":"Normal"}}
+{"Put":{"state":"123456789012aaΛaa\n123456789012ππ"}}
+{"Key":"j"}
+{"Get":{"state":"123456789012aaaa\n123456789012πΛπ","mode":"Normal"}}
+{"Put":{"state":"123456789012aaΛaa\n123456789012ππ"}}
+{"Key":"j"}
+{"Get":{"state":"123456789012aaaa\n123456789012πΛπ","mode":"Normal"}}
+{"Put":{"state":"123456789012aaaaΛaaaaaaaa123456789012\nwow\n123456789012ππππππ123456789012"}}
+{"Key":"j"}
+{"Key":"j"}
+{"Get":{"state":"123456789012aaaaaaaaaaaa123456789012\nwow\n123456789012ππΛππππ123456789012","mode":"Normal"}}
@@ -3,7 +3,7 @@ authors = ["Nathan Sobo <nathansobo@gmail.com>"]
description = "The fast, collaborative code editor."
edition = "2021"
name = "zed"
-version = "0.109.0"
+version = "0.110.0"
publish = false
[lib]
@@ -53,6 +53,7 @@ language_selector = { path = "../language_selector" }
lsp = { path = "../lsp" }
language_tools = { path = "../language_tools" }
node_runtime = { path = "../node_runtime" }
+notifications = { path = "../notifications" }
assistant = { path = "../assistant" }
outline = { path = "../outline" }
plugin_runtime = { path = "../plugin_runtime",optional = true }
@@ -1,3 +1,4 @@
+use ai::completion::OPENAI_API_URL;
use ai::embedding::OpenAIEmbeddings;
use anyhow::{anyhow, Result};
use client::{self, UserStore};
@@ -17,6 +18,7 @@ use std::{cmp, env, fs};
use util::channel::{RELEASE_CHANNEL, RELEASE_CHANNEL_NAME};
use util::http::{self};
use util::paths::EMBEDDINGS_DIR;
+use util::ResultExt;
use zed::languages;
#[derive(Deserialize, Clone, Serialize)]
@@ -469,12 +471,26 @@ fn main() {
.join("embeddings_db");
let languages = languages.clone();
+
+ let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") {
+ Some(api_key)
+ } else if let Some((_, api_key)) = cx
+ .platform()
+ .read_credentials(OPENAI_API_URL)
+ .log_err()
+ .flatten()
+ {
+ String::from_utf8(api_key).log_err()
+ } else {
+ None
+ };
+
let fs = fs.clone();
cx.spawn(|mut cx| async move {
let semantic_index = SemanticIndex::new(
fs.clone(),
db_file_path,
- Arc::new(OpenAIEmbeddings::new(http_client, cx.background())),
+ Arc::new(OpenAIEmbeddings::new(api_key, http_client, cx.background())),
languages.clone(),
cx.clone(),
)
@@ -76,7 +76,10 @@ pub fn init(
elixir::ElixirLspSetting::ElixirLs => language(
"elixir",
tree_sitter_elixir::language(),
- vec![Arc::new(elixir::ElixirLspAdapter)],
+ vec![
+ Arc::new(elixir::ElixirLspAdapter),
+ Arc::new(tailwind::TailwindLspAdapter::new(node_runtime.clone())),
+ ],
),
elixir::ElixirLspSetting::NextLs => language(
"elixir",
@@ -101,7 +104,10 @@ pub fn init(
language(
"heex",
tree_sitter_heex::language(),
- vec![Arc::new(elixir::ElixirLspAdapter)],
+ vec![
+ Arc::new(elixir::ElixirLspAdapter),
+ Arc::new(tailwind::TailwindLspAdapter::new(node_runtime.clone())),
+ ],
);
language(
"json",
@@ -167,7 +173,10 @@ pub fn init(
language(
"erb",
tree_sitter_embedded_template::language(),
- vec![Arc::new(ruby::RubyLanguageServer)],
+ vec![
+ Arc::new(ruby::RubyLanguageServer),
+ Arc::new(tailwind::TailwindLspAdapter::new(node_runtime.clone())),
+ ],
);
language("scheme", tree_sitter_scheme::language(), vec![]);
language("racket", tree_sitter_racket::language(), vec![]);
@@ -184,16 +193,18 @@ pub fn init(
language(
"svelte",
tree_sitter_svelte::language(),
- vec![Arc::new(svelte::SvelteLspAdapter::new(
- node_runtime.clone(),
- ))],
+ vec![
+ Arc::new(svelte::SvelteLspAdapter::new(node_runtime.clone())),
+ Arc::new(tailwind::TailwindLspAdapter::new(node_runtime.clone())),
+ ],
);
language(
"php",
tree_sitter_php::language(),
- vec![Arc::new(php::IntelephenseLspAdapter::new(
- node_runtime.clone(),
- ))],
+ vec![
+ Arc::new(php::IntelephenseLspAdapter::new(node_runtime.clone())),
+ Arc::new(tailwind::TailwindLspAdapter::new(node_runtime.clone())),
+ ],
);
language("elm", tree_sitter_elm::language(), vec![]);
@@ -1,7 +1,7 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::StreamExt;
-use language::{BundledFormatter, LanguageServerName, LspAdapter, LspAdapterDelegate};
+use language::{LanguageServerName, LspAdapter, LspAdapterDelegate};
use lsp::LanguageServerBinary;
use node_runtime::NodeRuntime;
use serde_json::json;
@@ -96,10 +96,6 @@ impl LspAdapter for CssLspAdapter {
"provideFormatter": true
}))
}
-
- fn enabled_formatters(&self) -> Vec<BundledFormatter> {
- vec![BundledFormatter::prettier("css")]
- }
}
async fn get_cached_server_binary(
@@ -10,3 +10,4 @@ brackets = [
]
word_characters = ["-"]
block_comment = ["/* ", " */"]
+prettier_parser_name = "css"
@@ -9,3 +9,8 @@ brackets = [
{ start = "\"", end = "\"", close = true, newline = false, not_in = ["string", "comment"] },
{ start = "'", end = "'", close = true, newline = false, not_in = ["string", "comment"] },
]
+scope_opt_in_language_servers = ["tailwindcss-language-server"]
+
+[overrides.string]
+word_characters = ["-"]
+opt_into_language_servers = ["tailwindcss-language-server"]
@@ -5,3 +5,4 @@ brackets = [
{ start = "<", end = ">", close = true, newline = true },
]
block_comment = ["<%#", "%>"]
+scope_opt_in_language_servers = ["tailwindcss-language-server"]
@@ -5,3 +5,8 @@ brackets = [
{ start = "<", end = ">", close = true, newline = true },
]
block_comment = ["<%!-- ", " --%>"]
+scope_opt_in_language_servers = ["tailwindcss-language-server"]
+
+[overrides.string]
+word_characters = ["-"]
+opt_into_language_servers = ["tailwindcss-language-server"]
@@ -0,0 +1,4 @@
+[
+ (attribute_value)
+ (quoted_attribute_value)
+] @string
@@ -1,7 +1,7 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::StreamExt;
-use language::{BundledFormatter, LanguageServerName, LspAdapter, LspAdapterDelegate};
+use language::{LanguageServerName, LspAdapter, LspAdapterDelegate};
use lsp::LanguageServerBinary;
use node_runtime::NodeRuntime;
use serde_json::json;
@@ -96,10 +96,6 @@ impl LspAdapter for HtmlLspAdapter {
"provideFormatter": true
}))
}
-
- fn enabled_formatters(&self) -> Vec<BundledFormatter> {
- vec![BundledFormatter::prettier("html")]
- }
}
async fn get_cached_server_binary(
@@ -11,3 +11,4 @@ brackets = [
{ start = "!--", end = " --", close = true, newline = false, not_in = ["comment", "string"] },
]
word_characters = ["-"]
+prettier_parser_name = "html"
@@ -15,6 +15,7 @@ brackets = [
]
word_characters = ["$", "#"]
scope_opt_in_language_servers = ["tailwindcss-language-server"]
+prettier_parser_name = "babel"
[overrides.element]
line_comment = { remove = true }
@@ -4,9 +4,7 @@ use collections::HashMap;
use feature_flags::FeatureFlagAppExt;
use futures::{future::BoxFuture, FutureExt, StreamExt};
use gpui::AppContext;
-use language::{
- BundledFormatter, LanguageRegistry, LanguageServerName, LspAdapter, LspAdapterDelegate,
-};
+use language::{LanguageRegistry, LanguageServerName, LspAdapter, LspAdapterDelegate};
use lsp::LanguageServerBinary;
use node_runtime::NodeRuntime;
use serde_json::json;
@@ -146,10 +144,6 @@ impl LspAdapter for JsonLspAdapter {
async fn language_ids(&self) -> HashMap<String, String> {
[("JSON".into(), "jsonc".into())].into_iter().collect()
}
-
- fn enabled_formatters(&self) -> Vec<BundledFormatter> {
- vec![BundledFormatter::prettier("json")]
- }
}
async fn get_cached_server_binary(
@@ -7,3 +7,4 @@ brackets = [
{ start = "[", end = "]", close = true, newline = true },
{ start = "\"", end = "\"", close = true, newline = false, not_in = ["string"] },
]
+prettier_parser_name = "json"
@@ -11,3 +11,4 @@ brackets = [
]
collapsed_placeholder = "/* ... */"
word_characters = ["$"]
+scope_opt_in_language_servers = ["tailwindcss-language-server"]
@@ -1,7 +1,7 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use futures::StreamExt;
-use language::{BundledFormatter, LanguageServerName, LspAdapter, LspAdapterDelegate};
+use language::{LanguageServerName, LspAdapter, LspAdapterDelegate};
use lsp::LanguageServerBinary;
use node_runtime::NodeRuntime;
use serde_json::json;
@@ -96,11 +96,8 @@ impl LspAdapter for SvelteLspAdapter {
}))
}
- fn enabled_formatters(&self) -> Vec<BundledFormatter> {
- vec![BundledFormatter::Prettier {
- parser_name: Some("svelte"),
- plugin_names: vec!["prettier-plugin-svelte"],
- }]
+ fn prettier_plugins(&self) -> &[&'static str] {
+ &["prettier-plugin-svelte"]
}
}
@@ -12,7 +12,9 @@ brackets = [
{ start = "`", end = "`", close = true, newline = false, not_in = ["string"] },
{ start = "/*", end = " */", close = true, newline = false, not_in = ["string", "comment"] },
]
+scope_opt_in_language_servers = ["tailwindcss-language-server"]
+prettier_parser_name = "svelte"
-[overrides.element]
-line_comment = { remove = true }
-block_comment = ["{/* ", " */}"]
+[overrides.string]
+word_characters = ["-"]
+opt_into_language_servers = ["tailwindcss-language-server"]
@@ -0,0 +1,7 @@
+(comment) @comment
+
+[
+ (raw_text)
+ (attribute_value)
+ (quoted_attribute_value)
+] @string
@@ -6,7 +6,7 @@ use futures::{
FutureExt, StreamExt,
};
use gpui::AppContext;
-use language::{BundledFormatter, LanguageServerName, LspAdapter, LspAdapterDelegate};
+use language::{LanguageServerName, LspAdapter, LspAdapterDelegate};
use lsp::LanguageServerBinary;
use node_runtime::NodeRuntime;
use serde_json::{json, Value};
@@ -117,22 +117,21 @@ impl LspAdapter for TailwindLspAdapter {
}
async fn language_ids(&self) -> HashMap<String, String> {
- HashMap::from_iter(
- [
- ("HTML".to_string(), "html".to_string()),
- ("CSS".to_string(), "css".to_string()),
- ("JavaScript".to_string(), "javascript".to_string()),
- ("TSX".to_string(), "typescriptreact".to_string()),
- ]
- .into_iter(),
- )
+ HashMap::from_iter([
+ ("HTML".to_string(), "html".to_string()),
+ ("CSS".to_string(), "css".to_string()),
+ ("JavaScript".to_string(), "javascript".to_string()),
+ ("TSX".to_string(), "typescriptreact".to_string()),
+ ("Svelte".to_string(), "svelte".to_string()),
+ ("Elixir".to_string(), "phoenix-heex".to_string()),
+ ("HEEX".to_string(), "phoenix-heex".to_string()),
+ ("ERB".to_string(), "erb".to_string()),
+ ("PHP".to_string(), "php".to_string()),
+ ])
}
- fn enabled_formatters(&self) -> Vec<BundledFormatter> {
- vec![BundledFormatter::Prettier {
- parser_name: None,
- plugin_names: vec!["prettier-plugin-tailwindcss"],
- }]
+ fn prettier_plugins(&self) -> &[&'static str] {
+ &["prettier-plugin-tailwindcss"]
}
}
@@ -14,6 +14,7 @@ brackets = [
]
word_characters = ["#", "$"]
scope_opt_in_language_servers = ["tailwindcss-language-server"]
+prettier_parser_name = "typescript"
[overrides.element]
line_comment = { remove = true }
@@ -4,7 +4,7 @@ use async_tar::Archive;
use async_trait::async_trait;
use futures::{future::BoxFuture, FutureExt};
use gpui::AppContext;
-use language::{BundledFormatter, LanguageServerName, LspAdapter, LspAdapterDelegate};
+use language::{LanguageServerName, LspAdapter, LspAdapterDelegate};
use lsp::{CodeActionKind, LanguageServerBinary};
use node_runtime::NodeRuntime;
use serde_json::{json, Value};
@@ -161,10 +161,6 @@ impl LspAdapter for TypeScriptLspAdapter {
"provideFormatter": true
}))
}
-
- fn enabled_formatters(&self) -> Vec<BundledFormatter> {
- vec![BundledFormatter::prettier("typescript")]
- }
}
async fn get_cached_ts_server_binary(
@@ -313,10 +309,6 @@ impl LspAdapter for EsLintLspAdapter {
async fn initialization_options(&self) -> Option<serde_json::Value> {
None
}
-
- fn enabled_formatters(&self) -> Vec<BundledFormatter> {
- vec![BundledFormatter::prettier("babel")]
- }
}
async fn get_cached_eslint_server_binary(
@@ -13,3 +13,4 @@ brackets = [
{ start = "/*", end = " */", close = true, newline = false, not_in = ["string", "comment"] },
]
word_characters = ["#", "$"]
+prettier_parser_name = "typescript"
@@ -3,8 +3,7 @@ use async_trait::async_trait;
use futures::{future::BoxFuture, FutureExt, StreamExt};
use gpui::AppContext;
use language::{
- language_settings::all_language_settings, BundledFormatter, LanguageServerName, LspAdapter,
- LspAdapterDelegate,
+ language_settings::all_language_settings, LanguageServerName, LspAdapter, LspAdapterDelegate,
};
use lsp::LanguageServerBinary;
use node_runtime::NodeRuntime;
@@ -109,10 +108,6 @@ impl LspAdapter for YamlLspAdapter {
}))
.boxed()
}
-
- fn enabled_formatters(&self) -> Vec<BundledFormatter> {
- vec![BundledFormatter::prettier("yaml")]
- }
}
async fn get_cached_server_binary(
@@ -9,3 +9,4 @@ brackets = [
]
increase_indent_pattern = ":\\s*[|>]?\\s*$"
+prettier_parser_name = "yaml"
@@ -191,6 +191,7 @@ fn main() {
activity_indicator::init(cx);
language_tools::init(cx);
call::init(app_state.client.clone(), app_state.user_store.clone(), cx);
+ notifications::init(app_state.client.clone(), app_state.user_store.clone(), cx);
collab_ui::init(&app_state, cx);
feedback::init(cx);
welcome::init(cx);
@@ -227,6 +227,13 @@ pub fn init(app_state: &Arc<AppState>, cx: &mut gpui::AppContext) {
workspace.toggle_panel_focus::<collab_ui::chat_panel::ChatPanel>(cx);
},
);
+ cx.add_action(
+ |workspace: &mut Workspace,
+ _: &collab_ui::notification_panel::ToggleFocus,
+ cx: &mut ViewContext<Workspace>| {
+ workspace.toggle_panel_focus::<collab_ui::notification_panel::NotificationPanel>(cx);
+ },
+ );
cx.add_action(
|workspace: &mut Workspace,
_: &terminal_panel::ToggleFocus,
@@ -281,9 +288,8 @@ pub fn initialize_workspace(
QuickActionBar::new(buffer_search_bar, workspace)
});
toolbar.add_item(quick_action_bar, cx);
- let diagnostic_editor_controls = cx.add_view(|_| {
- diagnostics::ToolbarControls::new()
- });
+ let diagnostic_editor_controls =
+ cx.add_view(|_| diagnostics::ToolbarControls::new());
toolbar.add_item(diagnostic_editor_controls, cx);
let project_search_bar = cx.add_view(|_| ProjectSearchBar::new());
toolbar.add_item(project_search_bar, cx);
@@ -357,12 +363,24 @@ pub fn initialize_workspace(
collab_ui::collab_panel::CollabPanel::load(workspace_handle.clone(), cx.clone());
let chat_panel =
collab_ui::chat_panel::ChatPanel::load(workspace_handle.clone(), cx.clone());
- let (project_panel, terminal_panel, assistant_panel, channels_panel, chat_panel) = futures::try_join!(
+ let notification_panel = collab_ui::notification_panel::NotificationPanel::load(
+ workspace_handle.clone(),
+ cx.clone(),
+ );
+ let (
+ project_panel,
+ terminal_panel,
+ assistant_panel,
+ channels_panel,
+ chat_panel,
+ notification_panel,
+ ) = futures::try_join!(
project_panel,
terminal_panel,
assistant_panel,
channels_panel,
chat_panel,
+ notification_panel,
)?;
workspace_handle.update(&mut cx, |workspace, cx| {
let project_panel_position = project_panel.position(cx);
@@ -383,6 +401,7 @@ pub fn initialize_workspace(
workspace.add_panel(assistant_panel, cx);
workspace.add_panel(channels_panel, cx);
workspace.add_panel(chat_panel, cx);
+ workspace.add_panel(notification_panel, cx);
if !was_deserialized
&& workspace
@@ -2432,6 +2451,7 @@ mod tests {
audio::init((), cx);
channel::init(&app_state.client, app_state.user_store.clone(), cx);
call::init(app_state.client.clone(), app_state.user_store.clone(), cx);
+ notifications::init(app_state.client.clone(), app_state.user_store.clone(), cx);
workspace::init(app_state.clone(), cx);
Project::init_settings(cx);
language::init(cx);
@@ -55,6 +55,8 @@ let users = [
'iamnbutler'
]
+const RUST_LOG = process.env.RUST_LOG || 'info'
+
// If a user is specified, make sure it's first in the list
const user = process.env.ZED_IMPERSONATE
if (user) {
@@ -81,7 +83,9 @@ setTimeout(() => {
ZED_ALWAYS_ACTIVE: '1',
ZED_SERVER_URL: 'http://localhost:8080',
ZED_ADMIN_API_TOKEN: 'secret',
- ZED_WINDOW_SIZE: `${instanceWidth},${instanceHeight}`
+ ZED_WINDOW_SIZE: `${instanceWidth},${instanceHeight}`,
+ PATH: process.env.PATH,
+ RUST_LOG,
}
})
}
@@ -13,6 +13,7 @@ import project_shared_notification from "./project_shared_notification"
import tooltip from "./tooltip"
import terminal from "./terminal"
import chat_panel from "./chat_panel"
+import notification_panel from "./notification_panel"
import collab_panel from "./collab_panel"
import toolbar_dropdown_menu from "./toolbar_dropdown_menu"
import incoming_call_notification from "./incoming_call_notification"
@@ -57,6 +58,7 @@ export default function app(): any {
assistant: assistant(),
feedback: feedback(),
chat_panel: chat_panel(),
+ notification_panel: notification_panel(),
component_test: component_test(),
}
}
@@ -1,10 +1,6 @@
-import {
- background,
- border,
- text,
-} from "./components"
+import { background, border, foreground, text } from "./components"
import { icon_button } from "../component/icon_button"
-import { useTheme } from "../theme"
+import { useTheme, with_opacity } from "../theme"
import { interactive } from "../element"
export default function chat_panel(): any {
@@ -41,15 +37,13 @@ export default function chat_panel(): any {
left: 2,
top: 2,
bottom: 2,
- }
- },
- list: {
-
+ },
},
+ list: {},
channel_select: {
header: {
...channel_name,
- border: border(layer, { bottom: true })
+ border: border(layer, { bottom: true }),
},
item: channel_name,
active_item: {
@@ -62,8 +56,8 @@ export default function chat_panel(): any {
},
menu: {
background: background(layer, "on"),
- border: border(layer, { bottom: true })
- }
+ border: border(layer, { bottom: true }),
+ },
},
icon_button: icon_button({
variant: "ghost",
@@ -91,6 +85,21 @@ export default function chat_panel(): any {
top: 4,
},
},
+
+ rich_text: {
+ text: text(layer, "sans", "base"),
+ code_background: with_opacity(foreground(layer, "accent"), 0.1),
+ mention_highlight: { weight: "bold" },
+ self_mention_highlight: { weight: "bold" },
+ self_mention_background: background(layer, "active"),
+ },
+ message_sender: {
+ margin: {
+ right: 8,
+ },
+ ...text(layer, "sans", "base", { weight: "bold" }),
+ },
+ message_timestamp: text(layer, "sans", "base", "disabled"),
message: {
...interactive({
base: {
@@ -100,7 +109,7 @@ export default function chat_panel(): any {
bottom: 4,
left: SPACING / 2,
right: SPACING / 3,
- }
+ },
},
state: {
hovered: {
@@ -108,25 +117,9 @@ export default function chat_panel(): any {
},
},
}),
- body: text(layer, "sans", "base"),
- sender: {
- margin: {
- right: 8,
- },
- ...text(layer, "sans", "base", { weight: "bold" }),
- },
- timestamp: text(layer, "sans", "base", "disabled"),
},
last_message_bottom_spacing: SPACING,
continuation_message: {
- body: text(layer, "sans", "base"),
- sender: {
- margin: {
- right: 8,
- },
- ...text(layer, "sans", "base", { weight: "bold" }),
- },
- timestamp: text(layer, "sans", "base", "disabled"),
...interactive({
base: {
padding: {
@@ -134,7 +127,7 @@ export default function chat_panel(): any {
bottom: 4,
left: SPACING / 2,
right: SPACING / 3,
- }
+ },
},
state: {
hovered: {
@@ -144,14 +137,6 @@ export default function chat_panel(): any {
}),
},
pending_message: {
- body: text(layer, "sans", "base"),
- sender: {
- margin: {
- right: 8,
- },
- ...text(layer, "sans", "base", "disabled"),
- },
- timestamp: text(layer, "sans", "base"),
...interactive({
base: {
padding: {
@@ -159,7 +144,7 @@ export default function chat_panel(): any {
bottom: 4,
left: SPACING / 2,
right: SPACING / 3,
- }
+ },
},
state: {
hovered: {
@@ -170,6 +155,6 @@ export default function chat_panel(): any {
},
sign_in_prompt: {
default: text(layer, "sans", "base"),
- }
+ },
}
}
@@ -0,0 +1,80 @@
+import { background, border, text } from "./components"
+import { icon_button } from "../component/icon_button"
+import { useTheme } from "../theme"
+import { interactive } from "../element"
+
+export default function (): any {
+ const theme = useTheme()
+ const layer = theme.middle
+
+ return {
+ background: background(layer),
+ avatar: {
+ icon_width: 24,
+ icon_height: 24,
+ corner_radius: 12,
+ outer_width: 24,
+ outer_corner_radius: 24,
+ },
+ title: {
+ ...text(layer, "sans", "default"),
+ padding: { left: 8, right: 8 },
+ border: border(layer, { bottom: true }),
+ },
+ title_height: 32,
+ title_icon: {
+ asset: "icons/feedback.svg",
+ color: text(theme.lowest, "sans", "default").color,
+ dimensions: {
+ width: 16,
+ height: 16,
+ },
+ },
+ read_text: {
+ padding: { top: 4, bottom: 4 },
+ ...text(layer, "sans", "disabled"),
+ },
+ unread_text: {
+ padding: { top: 4, bottom: 4 },
+ ...text(layer, "sans", "base"),
+ },
+ button: interactive({
+ base: {
+ ...text(theme.lowest, "sans", "on", { size: "xs" }),
+ background: background(theme.lowest, "on"),
+ padding: 4,
+ corner_radius: 6,
+ margin: { left: 6 },
+ },
+
+ state: {
+ hovered: {
+ background: background(theme.lowest, "on", "hovered"),
+ },
+ },
+ }),
+ timestamp: text(layer, "sans", "base", "disabled"),
+ avatar_container: {
+ padding: {
+ right: 6,
+ left: 2,
+ top: 2,
+ bottom: 2,
+ },
+ },
+ list: {
+ padding: {
+ left: 8,
+ right: 8,
+ },
+ },
+ icon_button: icon_button({
+ variant: "ghost",
+ color: "variant",
+ size: "sm",
+ }),
+ sign_in_prompt: {
+ default: text(layer, "sans", "base"),
+ },
+ }
+}