1pub mod assistant_panel;
2pub mod assistant_settings;
3mod codegen;
4mod completion_provider;
5mod prompts;
6mod saved_conversation;
7mod streaming_diff;
8
9pub use assistant_panel::AssistantPanel;
10use assistant_settings::{AssistantSettings, OpenAiModel, ZedDotDevModel};
11use chrono::{DateTime, Local};
12use client::{proto, Client};
13use command_palette_hooks::CommandPaletteFilter;
14pub(crate) use completion_provider::*;
15use gpui::{actions, AppContext, BorrowAppContext, Global, SharedString};
16pub(crate) use saved_conversation::*;
17use serde::{Deserialize, Serialize};
18use settings::{Settings, SettingsStore};
19use std::{
20 fmt::{self, Display},
21 sync::Arc,
22};
23
// Actions in the `assistant` namespace, declared with gpui's `actions!` macro.
//
// The namespace name is the same string as `Assistant::NAMESPACE` ("assistant").
// `Assistant::set_enabled` passes that constant to the command-palette filter to hide or show
// these commands, so keep the two in sync.
actions!(
    assistant,
    [
        NewConversation,
        Assist,
        Split,
        CycleMessageRole,
        QuoteSelection,
        ToggleFocus,
        ResetKey,
        InlineAssist,
        ToggleIncludeConversation,
    ]
);
38
39#[derive(
40 Copy, Clone, Debug, Default, Eq, PartialEq, PartialOrd, Ord, Hash, Serialize, Deserialize,
41)]
42struct MessageId(usize);
43
44#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
45#[serde(rename_all = "lowercase")]
46pub enum Role {
47 User,
48 Assistant,
49 System,
50}
51
52impl Role {
53 pub fn cycle(&mut self) {
54 *self = match self {
55 Role::User => Role::Assistant,
56 Role::Assistant => Role::System,
57 Role::System => Role::User,
58 }
59 }
60}
61
62impl Display for Role {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
64 match self {
65 Role::User => write!(f, "user"),
66 Role::Assistant => write!(f, "assistant"),
67 Role::System => write!(f, "system"),
68 }
69 }
70}
71
72#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
73pub enum LanguageModel {
74 ZedDotDev(ZedDotDevModel),
75 OpenAi(OpenAiModel),
76}
77
78impl Default for LanguageModel {
79 fn default() -> Self {
80 LanguageModel::ZedDotDev(ZedDotDevModel::default())
81 }
82}
83
84impl LanguageModel {
85 pub fn telemetry_id(&self) -> String {
86 match self {
87 LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
88 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.id()),
89 }
90 }
91
92 pub fn display_name(&self) -> String {
93 match self {
94 LanguageModel::OpenAi(model) => format!("openai/{}", model.display_name()),
95 LanguageModel::ZedDotDev(model) => format!("zed.dev/{}", model.display_name()),
96 }
97 }
98
99 pub fn max_token_count(&self) -> usize {
100 match self {
101 LanguageModel::OpenAi(model) => model.max_token_count(),
102 LanguageModel::ZedDotDev(model) => model.max_token_count(),
103 }
104 }
105
106 pub fn id(&self) -> &str {
107 match self {
108 LanguageModel::OpenAi(model) => model.id(),
109 LanguageModel::ZedDotDev(model) => model.id(),
110 }
111 }
112}
113
114#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
115pub struct LanguageModelRequestMessage {
116 pub role: Role,
117 pub content: String,
118}
119
120impl LanguageModelRequestMessage {
121 pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
122 proto::LanguageModelRequestMessage {
123 role: match self.role {
124 Role::User => proto::LanguageModelRole::LanguageModelUser,
125 Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
126 Role::System => proto::LanguageModelRole::LanguageModelSystem,
127 } as i32,
128 content: self.content.clone(),
129 }
130 }
131}
132
133#[derive(Debug, Default, Serialize)]
134pub struct LanguageModelRequest {
135 pub model: LanguageModel,
136 pub messages: Vec<LanguageModelRequestMessage>,
137 pub stop: Vec<String>,
138 pub temperature: f32,
139}
140
141impl LanguageModelRequest {
142 pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
143 proto::CompleteWithLanguageModel {
144 model: self.model.id().to_string(),
145 messages: self.messages.iter().map(|m| m.to_proto()).collect(),
146 stop: self.stop.clone(),
147 temperature: self.temperature,
148 }
149 }
150}
151
152#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
153pub struct LanguageModelResponseMessage {
154 pub role: Option<Role>,
155 pub content: Option<String>,
156}
157
158#[derive(Deserialize, Debug)]
159pub struct LanguageModelUsage {
160 pub prompt_tokens: u32,
161 pub completion_tokens: u32,
162 pub total_tokens: u32,
163}
164
165#[derive(Deserialize, Debug)]
166pub struct LanguageModelChoiceDelta {
167 pub index: u32,
168 pub delta: LanguageModelResponseMessage,
169 pub finish_reason: Option<String>,
170}
171
172#[derive(Clone, Debug, Serialize, Deserialize)]
173struct MessageMetadata {
174 role: Role,
175 sent_at: DateTime<Local>,
176 status: MessageStatus,
177}
178
179#[derive(Clone, Debug, Serialize, Deserialize)]
180enum MessageStatus {
181 Pending,
182 Done,
183 Error(SharedString),
184}
185
186/// The state pertaining to the Assistant.
187#[derive(Default)]
188struct Assistant {
189 /// Whether the Assistant is enabled.
190 enabled: bool,
191}
192
193impl Global for Assistant {}
194
195impl Assistant {
196 const NAMESPACE: &'static str = "assistant";
197
198 fn set_enabled(&mut self, enabled: bool, cx: &mut AppContext) {
199 if self.enabled == enabled {
200 return;
201 }
202
203 self.enabled = enabled;
204
205 if !enabled {
206 CommandPaletteFilter::update_global(cx, |filter, _cx| {
207 filter.hide_namespace(Self::NAMESPACE);
208 });
209
210 return;
211 }
212
213 CommandPaletteFilter::update_global(cx, |filter, _cx| {
214 filter.show_namespace(Self::NAMESPACE);
215 });
216 }
217}
218
219pub fn init(client: Arc<Client>, cx: &mut AppContext) {
220 cx.set_global(Assistant::default());
221 AssistantSettings::register(cx);
222 completion_provider::init(client, cx);
223 assistant_panel::init(cx);
224
225 CommandPaletteFilter::update_global(cx, |filter, _cx| {
226 filter.hide_namespace(Assistant::NAMESPACE);
227 });
228 cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
229 let settings = AssistantSettings::get_global(cx);
230
231 assistant.set_enabled(settings.enabled, cx);
232 });
233 cx.observe_global::<SettingsStore>(|cx| {
234 cx.update_global(|assistant: &mut Assistant, cx: &mut AppContext| {
235 let settings = AssistantSettings::get_global(cx);
236
237 assistant.set_enabled(settings.enabled, cx);
238 });
239 })
240 .detach();
241}
242
// Test-only constructor, registered through `ctor`. It turns on `env_logger` output only
// when `RUST_LOG` is set, so test runs are quiet by default.
#[cfg(test)]
#[ctor::ctor]
fn init_logger() {
    // `std::env::var` also returns `Err` for a non-Unicode value, which is treated as unset.
    if std::env::var("RUST_LOG").is_err() {
        return;
    }
    env_logger::init();
}