1pub mod assistant_panel;
2pub mod assistant_settings;
3mod codegen;
4mod completion_provider;
5mod prompts;
6mod saved_conversation;
7mod search;
8mod slash_command;
9mod streaming_diff;
10
11pub use assistant_panel::AssistantPanel;
12
13use assistant_settings::{AnthropicModel, AssistantSettings, OpenAiModel, ZedDotDevModel};
14use client::{proto, Client};
15use command_palette_hooks::CommandPaletteFilter;
16pub(crate) use completion_provider::*;
17use gpui::{actions, AppContext, Global, SharedString, UpdateGlobal};
18pub(crate) use saved_conversation::*;
19use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
20use serde::{Deserialize, Serialize};
21use settings::{Settings, SettingsStore};
22use std::{
23 fmt::{self, Display},
24 sync::Arc,
25};
26use util::paths::EMBEDDINGS_DIR;
27
// Actions in the `assistant` namespace, dispatched from keybindings and the
// command palette. This namespace is hidden/shown as a unit by
// `Assistant::set_enabled` below via `CommandPaletteFilter`.
actions!(
    assistant,
    [
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        InsertActivePrompt,
        ToggleIncludeConversation,
        ToggleHistory,
        ApplyEdit,
        ConfirmCommand
    ]
);
45
/// Identifier for a single message within a conversation.
///
/// Wraps a `usize`; ordering/hashing derive from the inner value.
/// NOTE(review): presumably allocated in increasing order by the conversation —
/// confirm at the allocation site.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
)]
struct MessageId(usize);
50
/// The author of a chat message.
///
/// Serialized in lowercase (`"user"`, `"assistant"`, `"system"`), matching the
/// wire names written by the `Display` impl below.
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    User,
    Assistant,
    System,
}
58
59impl Role {
60 pub fn cycle(&mut self) {
61 *self = match self {
62 Role::User => Role::Assistant,
63 Role::Assistant => Role::System,
64 Role::System => Role::User,
65 }
66 }
67}
68
69impl Display for Role {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
71 match self {
72 Role::User => write!(f, "user"),
73 Role::Assistant => write!(f, "assistant"),
74 Role::System => write!(f, "system"),
75 }
76 }
77}
78
/// The language-model backends the assistant can use, each variant carrying
/// its provider-specific model selection.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum LanguageModel {
    /// A model hosted via zed.dev.
    ZedDotDev(ZedDotDevModel),
    /// An OpenAI model accessed directly.
    OpenAi(OpenAiModel),
    /// An Anthropic model accessed directly.
    Anthropic(AnthropicModel),
}
85
86impl Default for LanguageModel {
87 fn default() -> Self {
88 LanguageModel::ZedDotDev(ZedDotDevModel::default())
89 }
90}
91
92impl LanguageModel {
93 pub fn telemetry_id(&self) -> String {
94 match self {
95 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
96 LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
97 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
98 }
99 }
100
101 pub fn display_name(&self) -> String {
102 match self {
103 LanguageModel::OpenAi(model) => model.display_name().into(),
104 LanguageModel::Anthropic(model) => model.display_name().into(),
105 LanguageModel::ZedDotDev(model) => model.display_name().into(),
106 }
107 }
108
109 pub fn max_token_count(&self) -> usize {
110 match self {
111 LanguageModel::OpenAi(model) => model.max_token_count(),
112 LanguageModel::Anthropic(model) => model.max_token_count(),
113 LanguageModel::ZedDotDev(model) => model.max_token_count(),
114 }
115 }
116
117 pub fn id(&self) -> &str {
118 match self {
119 LanguageModel::OpenAi(model) => model.id(),
120 LanguageModel::Anthropic(model) => model.id(),
121 LanguageModel::ZedDotDev(model) => model.id(),
122 }
123 }
124}
125
/// A single chat message included in a completion request.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelRequestMessage {
    /// Who authored the message.
    pub role: Role,
    /// The plain-text body of the message.
    pub content: String,
}
131
132impl LanguageModelRequestMessage {
133 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
134 proto::LanguageModelRequestMessage {
135 role: match self.role {
136 Role::User => proto::LanguageModelRole::LanguageModelUser,
137 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
138 Role::System => proto::LanguageModelRole::LanguageModelSystem,
139 } as i32,
140 content: self.content.clone(),
141 tool_calls: Vec::new(),
142 tool_call_id: None,
143 }
144 }
145}
146
/// A completion request to send to a language model.
#[derive(Debug, Default, Serialize)]
pub struct LanguageModelRequest {
    /// Which model should serve the request.
    pub model: LanguageModel,
    /// The conversation so far, in order.
    pub messages: Vec<LanguageModelRequestMessage>,
    /// Sequences that terminate generation when produced.
    pub stop: Vec<String>,
    /// Sampling temperature passed through to the provider.
    pub temperature: f32,
}
154
155impl LanguageModelRequest {
156 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
157 proto::CompleteWithLanguageModel {
158 model: self.model.id().to_string(),
159 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
160 stop: self.stop.clone(),
161 temperature: self.temperature,
162 tool_choice: None,
163 tools: Vec::new(),
164 }
165 }
166}
167
/// A (possibly partial) message received from a language model.
///
/// Both fields are optional; presumably a streamed delta carries only the
/// pieces that changed — confirm against the provider response format.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
pub struct LanguageModelResponseMessage {
    pub role: Option<Role>,
    pub content: Option<String>,
}
173
/// Token accounting reported by the provider for a completion.
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
    /// Tokens consumed by the request/prompt.
    pub prompt_tokens: u32,
    /// Tokens generated in the response.
    pub completion_tokens: u32,
    /// Sum of prompt and completion tokens.
    pub total_tokens: u32,
}
180
/// One choice's incremental update within a streamed completion response.
#[derive(Deserialize, Debug)]
pub struct LanguageModelChoiceDelta {
    /// Index of the choice this delta belongs to.
    pub index: u32,
    /// The partial message content for this update.
    pub delta: LanguageModelResponseMessage,
    /// Provider-reported reason generation stopped, once it has (e.g. a stop
    /// sequence or length limit) — `None` while streaming continues.
    pub finish_reason: Option<String>,
}
187
/// Bookkeeping stored alongside each conversation message.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct MessageMetadata {
    // Who authored the message.
    role: Role,
    // Whether the message is still streaming, finished, or failed.
    status: MessageStatus,
}
193
/// Lifecycle state of a message's generation.
#[derive(Clone, Debug, Serialize, Deserialize)]
enum MessageStatus {
    /// Generation is still in progress.
    Pending,
    /// Generation completed successfully.
    Done,
    /// Generation failed; carries the error text to display.
    Error(SharedString),
}
200
/// The state pertaining to the Assistant.
#[derive(Default)]
struct Assistant {
    /// Whether the Assistant is enabled.
    enabled: bool,
}

// Marks `Assistant` as a gpui global so it can be stored on the `AppContext`.
impl Global for Assistant {}
209
210impl Assistant {
211 const NAMESPACE: &'static str = "assistant";
212
213 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
214 if self.enabled == enabled {
215 return;
216 }
217
218 self.enabled = enabled;
219
220 if !enabled {
221 CommandPaletteFilter::update_global(cx, |filter, _cx| {
222 filter.hide_namespace(Self::NAMESPACE);
223 });
224
225 return;
226 }
227
228 CommandPaletteFilter::update_global(cx, |filter, _cx| {
229 filter.show_namespace(Self::NAMESPACE);
230 });
231 }
232}
233
/// Initializes the assistant subsystem: registers settings, starts building
/// the semantic index in the background, wires up the completion provider,
/// slash commands, and panel, and keeps the `Assistant` global's enabled
/// state in sync with `AssistantSettings`.
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
    cx.set_global(Assistant::default());
    AssistantSettings::register(cx);

    // Construct the semantic index asynchronously (it opens its on-disk
    // database), then install it as a global.
    // NOTE(review): errors from construction are silently dropped by
    // `detach()` — consider logging them if a detach-and-log variant exists.
    cx.spawn(|mut cx| {
        let client = client.clone();
        async move {
            let embedding_provider = CloudEmbeddingProvider::new(client.clone());
            let semantic_index = SemanticIndex::new(
                EMBEDDINGS_DIR.join("semantic-index-db.0.mdb"),
                Arc::new(embedding_provider),
                &mut cx,
            )
            .await?;
            cx.update(|cx| cx.set_global(semantic_index))
        }
    })
    .detach();
    completion_provider::init(client, cx);
    assistant_slash_command::init(cx);
    assistant_panel::init(cx);

    // Start with the namespace hidden; `set_enabled` below re-shows it only
    // when the setting enables the assistant. (If the setting is off,
    // `set_enabled` early-returns since the default state is also disabled,
    // which is why the hide must happen here.)
    CommandPaletteFilter::update_global(cx, |filter, _cx| {
        filter.hide_namespace(Assistant::NAMESPACE);
    });
    Assistant::update_global(cx, |assistant, cx| {
        let settings = AssistantSettings::get_global(cx);

        assistant.set_enabled(settings.enabled, cx);
    });
    // Re-apply the enabled flag whenever the settings store changes.
    cx.observe_global::<SettingsStore>(|cx| {
        Assistant::update_global(cx, |assistant, cx| {
            let settings = AssistantSettings::get_global(cx);

            assistant.set_enabled(settings.enabled, cx);
        });
    })
    .detach();
}
273
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // Only wire up logging for tests when the caller opted in via RUST_LOG.
    let logging_requested = std::env::var("RUST_LOG").is_ok();
    if logging_requested {
        env_logger::init();
    }
}