diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 0015e498f986c67dd4477a6fb35e8846c8442b9e..13b65cccc79ded8f1f7267063898216defb38908 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -395,7 +395,7 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string defer log.RecoverPanic("agent.Run", func() { slog.Error("panic while generating title") }) - titleErr := a.generateTitle(context.Background(), sessionID, content) + titleErr := a.generateTitle(ctx, sessionID, content) if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) { slog.Error("failed to generate title", "error", titleErr) } @@ -996,11 +996,17 @@ func (a *agent) UpdateModel() error { return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider) } + var maxTitleTokens int64 = 40 + + // if the max output is too low for the gemini provider it won't return anything + if smallModelCfg.Provider == "gemini" { + maxTitleTokens = 1000 + } // Recreate title provider titleOpts := []provider.ProviderClientOption{ provider.WithModel(config.SelectedModelTypeSmall), provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), - provider.WithMaxTokens(40), + provider.WithMaxTokens(maxTitleTokens), } newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...) if err != nil { diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 2e02bd088b57c9434d1d534204b664f3ef7443ed..9d5164973a5ad86b4c0dee001e54b46b838b89e6 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -322,6 +322,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t for _, part := range lastMsg.Parts { lastMsgParts = append(lastMsgParts, *part) } + for resp, err := range chat.SendMessageStream(ctx, lastMsgParts...) { if err != nil { retry, after, retryErr := g.shouldRetry(attempts, err) @@ -385,6 +386,9 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t } } } + } else { + // no content received + break } } @@ -408,6 +412,11 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t }, } return + } else { + eventChan <- ProviderEvent{ + Type: EventError, + Error: errors.New("no content received"), + } } } }()