1pub mod assistant_panel;
2pub mod assistant_settings;
3mod completion_provider;
4mod context_store;
5mod inline_assistant;
6mod model_selector;
7mod prompt_library;
8mod prompts;
9mod search;
10mod slash_command;
11mod streaming_diff;
12
13pub use assistant_panel::AssistantPanel;
14
15use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OpenAiModel};
16use assistant_slash_command::SlashCommandRegistry;
17use client::{proto, Client};
18use command_palette_hooks::CommandPaletteFilter;
19pub(crate) use completion_provider::*;
20pub(crate) use context_store::*;
21use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
22pub(crate) use inline_assistant::*;
23pub(crate) use model_selector::*;
24use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
25use serde::{Deserialize, Serialize};
26use settings::{Settings, SettingsStore};
27use slash_command::{
28 active_command, default_command, fetch_command, file_command, now_command, project_command,
29 prompt_command, rustdoc_command, search_command, tabs_command,
30};
31use std::{
32 fmt::{self, Display},
33 sync::Arc,
34};
35pub(crate) use streaming_diff::*;
36use util::paths::EMBEDDINGS_DIR;
37
// Actions exposed by the assistant. They are declared under the `assistant`
// namespace, which is the same string as `Assistant::NAMESPACE`, the namespace
// shown or hidden in the command palette as the assistant is enabled/disabled.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand,
        ToggleModelSelector
    ]
);
55
56#[derive(
57 Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
58)]
59struct MessageId(usize);
60
61#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
62#[serde(rename_all = "lowercase")]
63pub enum Role {
64 User,
65 Assistant,
66 System,
67}
68
69impl Role {
70 pub fn cycle(&mut self) {
71 *self = match self {
72 Role::User => Role::Assistant,
73 Role::Assistant => Role::System,
74 Role::System => Role::User,
75 }
76 }
77}
78
79impl Display for Role {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
81 match self {
82 Role::User => write!(f, "user"),
83 Role::Assistant => write!(f, "assistant"),
84 Role::System => write!(f, "system"),
85 }
86 }
87}
88
89#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
90pub enum LanguageModel {
91 Cloud(CloudModel),
92 OpenAi(OpenAiModel),
93 Anthropic(AnthropicModel),
94}
95
96impl Default for LanguageModel {
97 fn default() -> Self {
98 LanguageModel::Cloud(CloudModel::default())
99 }
100}
101
102impl LanguageModel {
103 pub fn telemetry_id(&self) -> String {
104 match self {
105 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
106 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
107 LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
108 }
109 }
110
111 pub fn display_name(&self) -> String {
112 match self {
113 LanguageModel::OpenAi(model) => model.display_name().into(),
114 LanguageModel::Anthropic(model) => model.display_name().into(),
115 LanguageModel::Cloud(model) => model.display_name().into(),
116 }
117 }
118
119 pub fn max_token_count(&self) -> usize {
120 match self {
121 LanguageModel::OpenAi(model) => model.max_token_count(),
122 LanguageModel::Anthropic(model) => model.max_token_count(),
123 LanguageModel::Cloud(model) => model.max_token_count(),
124 }
125 }
126
127 pub fn id(&self) -> &str {
128 match self {
129 LanguageModel::OpenAi(model) => model.id(),
130 LanguageModel::Anthropic(model) => model.id(),
131 LanguageModel::Cloud(model) => model.id(),
132 }
133 }
134}
135
136#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
137pub struct LanguageModelRequestMessage {
138 pub role: Role,
139 pub content: String,
140}
141
142impl LanguageModelRequestMessage {
143 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
144 proto::LanguageModelRequestMessage {
145 role: match self.role {
146 Role::User => proto::LanguageModelRole::LanguageModelUser,
147 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
148 Role::System => proto::LanguageModelRole::LanguageModelSystem,
149 } as i32,
150 content: self.content.clone(),
151 tool_calls: Vec::new(),
152 tool_call_id: None,
153 }
154 }
155}
156
157#[derive(Debug, Default, Serialize)]
158pub struct LanguageModelRequest {
159 pub model: LanguageModel,
160 pub messages: Vec<LanguageModelRequestMessage>,
161 pub stop: Vec<String>,
162 pub temperature: f32,
163}
164
165impl LanguageModelRequest {
166 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
167 proto::CompleteWithLanguageModel {
168 model: self.model.id().to_string(),
169 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
170 stop: self.stop.clone(),
171 temperature: self.temperature,
172 tool_choice: None,
173 tools: Vec::new(),
174 }
175 }
176
177 /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
178 pub fn preprocess(&mut self) {
179 match &self.model {
180 LanguageModel::OpenAi(_) => {}
181 LanguageModel::Anthropic(_) => {}
182 LanguageModel::Cloud(model) => match model {
183 CloudModel::Claude3Opus | CloudModel::Claude3Sonnet | CloudModel::Claude3Haiku => {
184 preprocess_anthropic_request(self);
185 }
186 _ => {}
187 },
188 }
189 }
190}
191
192#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
193pub struct LanguageModelResponseMessage {
194 pub role: Option<Role>,
195 pub content: Option<String>,
196}
197
198#[derive(Deserialize, Debug)]
199pub struct LanguageModelUsage {
200 pub prompt_tokens: u32,
201 pub completion_tokens: u32,
202 pub total_tokens: u32,
203}
204
205#[derive(Deserialize, Debug)]
206pub struct LanguageModelChoiceDelta {
207 pub index: u32,
208 pub delta: LanguageModelResponseMessage,
209 pub finish_reason: Option<String>,
210}
211
212#[derive(Clone, Debug, Serialize, Deserialize)]
213struct MessageMetadata {
214 role: Role,
215 status: MessageStatus,
216}
217
218#[derive(Clone, Debug, Serialize, Deserialize)]
219enum MessageStatus {
220 Pending,
221 Done,
222 Error(SharedString),
223}
224
225/// The state pertaining to the Assistant.
226#[derive(Default)]
227struct Assistant {
228 /// Whether the Assistant is enabled.
229 enabled: bool,
230}
231
232impl Global for Assistant {}
233
234impl Assistant {
235 const NAMESPACE: &'static str = "assistant";
236
237 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
238 if self.enabled == enabled {
239 return;
240 }
241
242 self.enabled = enabled;
243
244 if !enabled {
245 CommandPaletteFilter::update_global(cx, |filter, _cx| {
246 filter.hide_namespace(Self::NAMESPACE);
247 });
248
249 return;
250 }
251
252 CommandPaletteFilter::update_global(cx, |filter, _cx| {
253 filter.show_namespace(Self::NAMESPACE);
254 });
255 }
256}
257
258pub fn init(client: Arc<Client>, cx: &mut AppContext) {
259 cx.set_global(Assistant::default());
260 AssistantSettings::register(cx);
261
262 cx.spawn(|mut cx| {
263 let client = client.clone();
264 async move {
265 let embedding_provider = CloudEmbeddingProvider::new(client.clone());
266 let semantic_index = SemanticIndex::new(
267 EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
268 Arc::new(embedding_provider),
269 &mut cx,
270 )
271 .await?;
272 cx.update(|cx| cx.set_global(semantic_index))
273 }
274 })
275 .detach();
276
277 prompt_library::init(cx);
278 completion_provider::init(client.clone(), cx);
279 assistant_slash_command::init(cx);
280 register_slash_commands(cx);
281 assistant_panel::init(cx);
282 inline_assistant::init(client.telemetry().clone(), cx);
283
284 CommandPaletteFilter::update_global(cx, |filter, _cx| {
285 filter.hide_namespace(Assistant::NAMESPACE);
286 });
287 Assistant::update_global(cx, |assistant, cx| {
288 let settings = AssistantSettings::get_global(cx);
289
290 assistant.set_enabled(settings.enabled, cx);
291 });
292 cx.observe_global::<SettingsStore>(|cx| {
293 Assistant::update_global(cx, |assistant, cx| {
294 let settings = AssistantSettings::get_global(cx);
295 assistant.set_enabled(settings.enabled, cx);
296 });
297 })
298 .detach();
299}
300
301fn register_slash_commands(cx: &mut AppContext) {
302 let slash_command_registry = SlashCommandRegistry::global(cx);
303 slash_command_registry.register_command(file_command::FileSlashCommand, true);
304 slash_command_registry.register_command(active_command::ActiveSlashCommand, true);
305 slash_command_registry.register_command(tabs_command::TabsSlashCommand, true);
306 slash_command_registry.register_command(project_command::ProjectSlashCommand, true);
307 slash_command_registry.register_command(search_command::SearchSlashCommand, true);
308 slash_command_registry.register_command(prompt_command::PromptSlashCommand, true);
309 slash_command_registry.register_command(default_command::DefaultSlashCommand, true);
310 slash_command_registry.register_command(now_command::NowSlashCommand, true);
311 slash_command_registry.register_command(rustdoc_command::RustdocSlashCommand, false);
312 slash_command_registry.register_command(fetch_command::FetchSlashCommand, false);
313}
314
/// Test-only constructor that installs `env_logger` before tests run, but only
/// when `RUST_LOG` is set to a valid Unicode value.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}