use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, KeepAlive, OllamaToolCall,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::LanguageModelCompletionEvent;
use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

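/// Provider-level settings for Ollama: the API URL of the local server, an
/// optional low-speed timeout, and any models configured manually in settings.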
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
    pub available_models: Vec<AvailableModel>,
}

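/// A model declared manually in Zed's settings, in addition to (or overriding)
/// whatever the local Ollama server reports.
///
/// A minimal sketch of a configured entry (the values are illustrative):
///
/// ```ignore
/// AvailableModel {
///     name: "llama3.1:latest".into(),
///     display_name: Some("Llama 3.1".into()),
///     max_tokens: 8192,
///     keep_alive: None,
/// }
/// ```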
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.1:latest")
    pub name: String,
    /// The model's name as shown in Zed's UI, such as the model selector
    /// dropdown in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context length, passed to Ollama as the `num_ctx`
    /// (a.k.a. `n_ctx`) parameter.
    pub max_tokens: usize,
    /// How long Ollama should keep the model loaded in memory after the last
    /// request (Ollama's `keep_alive` parameter).
    pub keep_alive: Option<KeepAlive>,
}

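/// The Ollama language model provider: holds the HTTP client plus the shared
/// [`State`] entity tracking which models the local server currently offers.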
pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
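    /// Ollama exposes no real authentication, so "authenticated" here simply
    /// means the server has responded with at least one installed model.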
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if
        // it's up by fetching the models.
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name.
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut ModelContext<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
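    /// Creates the provider, subscribing to settings changes so the model
    /// list is re-fetched whenever the Ollama settings change, and kicking
    /// off an initial fetch so models are available on startup.
    ///
    /// A minimal usage sketch (assumes an `http_client` and a
    /// `&mut AppContext` named `cx` are in scope):
    ///
    /// ```ignore
    /// let provider = OllamaLanguageModelProvider::new(http_client.clone(), cx);
    /// ```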
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: model.keep_alive.clone(),
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

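/// A single chat model served by Ollama, addressed by name
/// (e.g. "llama3.1:latest").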
pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
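    /// Converts Zed's provider-agnostic request into an Ollama [`ChatRequest`],
    /// mapping each message's [`Role`] to the matching [`ChatMessage`] variant
    /// and forwarding the model's context length (`num_ctx`) and keep-alive
    /// settings.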
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
            tools: vec![],
        }
    }
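
    /// Issues a single non-streaming chat completion. Used for tool calls,
    /// where the whole response is needed before the tool-call arguments can
    /// be extracted.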
    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
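        // As a rough approximation, assume ~4 characters per token, so a
        // 2,000-character conversation is estimated at 500 tokens.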
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            (settings.api_url.clone(), settings.low_speed_timeout)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
                    .await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

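    /// Sends a completion request with a single tool attached and returns the
    /// arguments of the first call to that tool as a one-item stream, failing
    /// if the model never calls the named tool.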
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        use ollama::{OllamaFunctionTool, OllamaTool};
        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(futures::stream::once(async move {
                                Ok(function.arguments.to_string())
                            })
                            .boxed());
                        }
                    }
                } else {
                    bail!("assistant message does not have any tool calls");
                }

                bail!("tool not used")
            })
            .boxed()
    }
}

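/// The provider's configuration UI: shows installation instructions while no
/// models are available, and a "Connected" indicator once the local server
/// responds with models.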
struct ConfigurationView {
    state: gpui::Model<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut WindowContext) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Get up and running with Llama 3.1, Mistral, Gemma 2, and other large language models using Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

        let mut inline_code_bg = cx.theme().colors().editor_background;
        inline_code_bg.fade_out(0.5);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.1")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ExternalLink)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ExternalLink)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ExternalLink)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is a disabled button rather than a label only so
                            // that its spacing matches the "Connect" button; it
                            // should stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use
                                // the arrow cursor instead of the pointer.
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}