use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{Stream, TryFutureExt, stream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
};
use ollama::{
    ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, OllamaFunctionCall,
    OllamaFunctionTool, OllamaToolCall, get_models, show_model, stream_chat_completion,
};
pub use settings::OllamaAvailableModel as AvailableModel;
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::{collections::HashMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama");

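/// Ollama settings as configured by the user: the server's base URL plus any
/// models declared in settings (these override models discovered via the API).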
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

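/// The Ollama provider: discovers models from a locally running Ollama server
/// and exposes each one as a `LanguageModel`.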
pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

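/// Shared provider state: models fetched from the server, the in-flight fetch
/// task, and a settings subscription that triggers a re-fetch on change.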
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
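    /// Ollama has no real auth; a non-empty model list is our proxy for
    /// "the server is reachable and has something to serve".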
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

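    /// Fetches the list of models from the server, then queries each model's
    /// capabilities (tools, vision, thinking) with bounded concurrency.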
    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = Arc::clone(&self.http_client);
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let tasks = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| {
                    let http_client = Arc::clone(&http_client);
                    let api_url = api_url.clone();
                    async move {
                        let name = model.name.as_str();
                        let capabilities = show_model(http_client.as_ref(), &api_url, name).await?;
                        let ollama_model = ollama::Model::new(
                            name,
                            None,
                            None,
                            Some(capabilities.supports_tools()),
                            Some(capabilities.supports_vision()),
                            Some(capabilities.supports_thinking()),
                        );
                        Ok(ollama_model)
                    }
                });

            // Rate-limit capability fetches,
            // since there is an arbitrary number of models available
            let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
                .buffer_unordered(5)
                .collect::<Vec<Result<_>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>>>()?;

            ollama_models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = ollama_models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

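    /// For Ollama, "authenticating" just means the server answered a model
    /// listing request; this is a no-op if models are already loaded.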
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
        // We shouldn't pick a default model, because doing so could trigger a
        // load of a model that isn't resident. On a resource-constrained
        // machine, loading a model the user didn't ask for is bad UX.
        None
    }

    fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
        // See the explanation in `default_model`.
        None
    }

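    /// Merges models discovered from the Ollama API with models declared in
    /// settings; a settings entry with the same name takes precedence.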
    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: HashMap<String, ollama::Model> = HashMap::new();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: model.keep_alive.clone(),
                    supports_tools: model.supports_tools,
                    supports_vision: model.supports_images,
                    supports_thinking: model.supports_thinking,
                },
            );
        }

        let mut models = models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model,
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect::<Vec<_>>();
        models.sort_by_key(|model| model.name());
        models
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

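/// A single Ollama model exposed through the `LanguageModel` trait, with a
/// rate limiter bounding concurrent requests to the server.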
pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
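    /// Converts a `LanguageModelRequest` into Ollama's chat request format.
    /// Tool results are pulled out of user messages into dedicated `Tool`
    /// messages, and images are only forwarded when the model supports vision.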
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        let supports_vision = self.model.supports_vision.unwrap_or(false);

        let mut messages = Vec::with_capacity(request.messages.len());

        for mut msg in request.messages.into_iter() {
            let images = if supports_vision {
                msg.content
                    .iter()
                    .filter_map(|content| match content {
                        MessageContent::Image(image) => Some(image.source.to_string()),
                        _ => None,
                    })
                    .collect::<Vec<String>>()
            } else {
                vec![]
            };

            match msg.role {
                Role::User => {
                    for tool_result in msg
                        .content
                        .extract_if(.., |x| matches!(x, MessageContent::ToolResult(..)))
                    {
                        match tool_result {
                            MessageContent::ToolResult(tool_result) => {
                                messages.push(ChatMessage::Tool {
                                    tool_name: tool_result.tool_name.to_string(),
                                    content: tool_result.content.to_str().unwrap_or("").to_string(),
                                })
                            }
                            _ => unreachable!("Only tool results should be extracted"),
                        }
                    }
                    if !msg.content.is_empty() {
                        messages.push(ChatMessage::User {
                            content: msg.string_contents(),
                            images: if images.is_empty() {
                                None
                            } else {
                                Some(images)
                            },
                        })
                    }
                }
                Role::Assistant => {
                    let content = msg.string_contents();
                    let mut thinking = None;
                    let mut tool_calls = Vec::new();
                    for content in msg.content.into_iter() {
                        match content {
                            MessageContent::Thinking { text, .. } if !text.is_empty() => {
                                thinking = Some(text)
                            }
                            MessageContent::ToolUse(tool_use) => {
                                tool_calls.push(OllamaToolCall::Function(OllamaFunctionCall {
                                    name: tool_use.name.to_string(),
                                    arguments: tool_use.input,
                                }));
                            }
                            _ => (),
                        }
                    }
                    messages.push(ChatMessage::Assistant {
                        content,
                        tool_calls: Some(tool_calls),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(images)
                        },
                        thinking,
                    })
                }
                Role::System => messages.push(ChatMessage::System {
                    content: msg.string_contents(),
                }),
            }
        }
        ChatRequest {
            model: self.model.name.clone(),
            messages,
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
            think: self
                .model
                .supports_thinking
                .map(|supports_thinking| supports_thinking && request.thinking_allowed),
            tools: if self.model.supports_tools.unwrap_or(false) {
                request.tools.into_iter().map(tool_into_ollama).collect()
            } else {
                vec![]
            },
        }
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools.unwrap_or(false)
    }

    fn supports_images(&self) -> bool {
        self.model.supports_vision.unwrap_or(false)
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto => false,
            LanguageModelToolChoice::Any => false,
            LanguageModelToolChoice::None => false,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
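        // Fall back to a rough approximation of ~4 characters per token,
        // which is a common rule of thumb for English text.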
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count as u64) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = map_to_language_model_completion_events(stream);
            Ok(stream)
        });

        future.map_ok(|f| f.boxed()).boxed()
    }
}

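/// Maps Ollama's streamed chat deltas onto `LanguageModelCompletionEvent`s.
/// A single delta may expand into several events (thinking text, a tool use,
/// a usage update, a stop reason), so each delta is unfolded into a batch and
/// the batches are flattened into a single event stream at the end.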
fn map_to_language_model_completion_events(
    stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
    // Used for creating unique tool use ids
    static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

    struct State {
        stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
        used_tools: bool,
    }

    // A single response from the original stream can require both a ToolUse
    // and a Stop event, so each unfold step yields a batch of events
    let stream = stream::unfold(
        State {
            stream,
            used_tools: false,
        },
        async move |mut state| {
            let response = state.stream.next().await?;

            let delta = match response {
                Ok(delta) => delta,
                Err(e) => {
                    let event = Err(LanguageModelCompletionError::from(anyhow!(e)));
                    return Some((vec![event], state));
                }
            };

            let mut events = Vec::new();

            match delta.message {
                ChatMessage::User { content, images: _ } => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                }
                ChatMessage::System { content } => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                }
                ChatMessage::Tool { content, .. } => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                }
                ChatMessage::Assistant {
                    content,
                    tool_calls,
                    images: _,
                    thinking,
                } => {
                    if let Some(text) = thinking {
                        events.push(Ok(LanguageModelCompletionEvent::Thinking {
                            text,
                            signature: None,
                        }));
                    }

                    if let Some(tool_call) = tool_calls.and_then(|v| v.into_iter().next()) {
                        match tool_call {
                            OllamaToolCall::Function(function) => {
                                let tool_id = format!(
                                    "{}-{}",
                                    &function.name,
                                    TOOL_CALL_COUNTER.fetch_add(1, Ordering::Relaxed)
                                );
                                let event =
                                    LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
                                        id: LanguageModelToolUseId::from(tool_id),
                                        name: Arc::from(function.name),
                                        raw_input: function.arguments.to_string(),
                                        input: function.arguments,
                                        is_input_complete: true,
                                    });
                                events.push(Ok(event));
                                state.used_tools = true;
                            }
                        }
                    } else if !content.is_empty() {
                        events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                    }
                }
            };

            if delta.done {
                events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                    input_tokens: delta.prompt_eval_count.unwrap_or(0),
                    output_tokens: delta.eval_count.unwrap_or(0),
                    cache_creation_input_tokens: 0,
                    cache_read_input_tokens: 0,
                })));
                if state.used_tools {
                    state.used_tools = false;
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
                } else {
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                }
            }

            Some((events, state))
        },
    );

    stream.flat_map(futures::stream::iter)
}

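/// Settings UI for the provider: shows install/connect instructions while no
/// models are available and a "Connected" indicator once they are.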
struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro =
            "Get up & running with Llama 3.3, Mistral, Gemma 2, and other LLMs with Ollama.";

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    v_flex().gap_1().child(Label::new(ollama_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only(
                                "Ollama must be running with at least one model installed to use it in the assistant.",
                            ))
                            .child(InstructionListItem::text_only(
                                "Once installed, try `ollama run llama3.2`",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::Small)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::Small)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _, cx| {
                                                cx.open_url(OLLAMA_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "View All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::Small)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .map(|this| {
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_ollama_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::PlayFilled)
                                        .on_click(cx.listener(move |this, _, _, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        }),
                )
                .into_any()
        }
    }
}

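/// Converts a tool definition from the request into Ollama's function tool
/// schema.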
fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
    ollama::OllamaTool::Function {
        function: OllamaFunctionTool {
            name: tool.name,
            description: Some(tool.description),
            parameters: Some(tool.input_schema),
        },
    }
}