Detailed changes
@@ -423,7 +423,10 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
return getSessionErr
}
a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
- _, sessionErr := a.sessions.Save(ctx, updatedSession)
+ _, sessionErr := a.sessions.SaveWithModels(ctx, updatedSession, map[config.SelectedModelType]config.SelectedModel{
+ config.SelectedModelTypeLarge: a.largeModel.Get().ModelCfg,
+ config.SelectedModelTypeSmall: a.smallModel.Get().ModelCfg,
+ })
if sessionErr != nil {
return sessionErr
}
@@ -723,7 +726,10 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
currentSession.SummaryMessageID = summaryMessage.ID
currentSession.CompletionTokens = usage.OutputTokens
currentSession.PromptTokens = 0
- _, err = a.sessions.Save(genCtx, currentSession)
+ _, err = a.sessions.SaveWithModels(genCtx, currentSession, map[config.SelectedModelType]config.SelectedModel{
+ config.SelectedModelTypeLarge: a.largeModel.Get().ModelCfg,
+ config.SelectedModelTypeSmall: a.smallModel.Get().ModelCfg,
+ })
return err
}
@@ -476,7 +476,7 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan
tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
tools.NewSourcegraphTool(nil),
- tools.NewTodosTool(c.sessions),
+ tools.NewTodosTool(c.sessions, c.cfg),
tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.skillTracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
)
@@ -1051,7 +1051,7 @@ func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionI
parentSession.Cost += childSession.Cost
- if _, err := c.sessions.Save(ctx, parentSession); err != nil {
+ if _, err := c.sessions.SaveWithModels(ctx, parentSession, c.cfg.Config().Models); err != nil {
return fmt.Errorf("save parent session: %w", err)
}
@@ -6,6 +6,7 @@ import (
"fmt"
"charm.land/fantasy"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/session"
)
@@ -33,7 +34,7 @@ type TodosResponseMetadata struct {
Total int `json:"total"`
}
-func NewTodosTool(sessions session.Service) fantasy.AgentTool {
+func NewTodosTool(sessions session.Service, cfg *config.ConfigStore) fantasy.AgentTool {
return fantasy.NewAgentTool(
TodosToolName,
FirstLineDescription(todosDescription),
@@ -96,7 +97,7 @@ func NewTodosTool(sessions session.Service) fantasy.AgentTool {
}
currentSession.Todos = todos
- _, err = sessions.Save(ctx, currentSession)
+ _, err = sessions.SaveWithModels(ctx, currentSession, cfg.Config().Models)
if err != nil {
return fantasy.ToolResponse{}, fmt.Errorf("failed to save todos: %w", err)
}
@@ -7,6 +7,7 @@ import (
"strings"
"testing"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/charmbracelet/crush/internal/session"
"github.com/stretchr/testify/require"
@@ -60,6 +61,14 @@ func (m *mockSessionService) Save(_ context.Context, s session.Session) (session
return s, nil
}
+func (m *mockSessionService) SaveWithModels(_ context.Context, s session.Session, _ map[config.SelectedModelType]config.SelectedModel) (session.Session, error) {
+ return s, nil
+}
+
+func (m *mockSessionService) UpdateSessionModels(context.Context, string, map[config.SelectedModelType]config.SelectedModel) error {
+ return nil
+}
+
func (m *mockSessionService) UpdateTitleAndUsage(context.Context, string, string, int64, int64, float64) error {
return nil
}
@@ -3,6 +3,7 @@ package backend
import (
"context"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/proto"
"github.com/charmbracelet/crush/internal/session"
@@ -105,6 +106,15 @@ func (b *Backend) DeleteSession(ctx context.Context, workspaceID, sessionID stri
return ws.Sessions.Delete(ctx, sessionID)
}
+// UpdateSessionModels updates the models for a session in the given workspace.
+func (b *Backend) UpdateSessionModels(ctx context.Context, workspaceID, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error {
+ ws, err := b.GetWorkspace(workspaceID)
+ if err != nil {
+ return err
+ }
+ return ws.Sessions.UpdateSessionModels(ctx, sessionID, models)
+}
+
// ListUserMessages returns user-role messages for a session.
func (b *Backend) ListUserMessages(ctx context.Context, workspaceID, sessionID string) ([]message.Message, error) {
ws, err := b.GetWorkspace(workspaceID)
@@ -585,6 +585,23 @@ func (c *Client) DeleteSession(ctx context.Context, id string, sessionID string)
return nil
}
+// UpdateSessionModels updates the models for a session.
+func (c *Client) UpdateSessionModels(ctx context.Context, workspaceID, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error {
+ rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/models", workspaceID, sessionID), nil,
+ jsonBody(struct {
+ Models map[config.SelectedModelType]config.SelectedModel `json:"models"`
+ }{Models: models}),
+ http.Header{"Content-Type": []string{"application/json"}})
+ if err != nil {
+ return fmt.Errorf("failed to update session models: %w", err)
+ }
+ defer rsp.Body.Close()
+ if rsp.StatusCode != http.StatusNoContent {
+ return fmt.Errorf("failed to update session models: status code %d", rsp.StatusCode)
+ }
+ return nil
+}
+
// ListUserMessages retrieves user-role messages for a session as proto types.
func (c *Client) ListUserMessages(ctx context.Context, id string, sessionID string) ([]proto.Message, error) {
rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/messages/user", id, sessionID), nil, nil)
@@ -89,6 +89,27 @@ type SelectedModel struct {
ProviderOptions map[string]any `json:"provider_options,omitempty" jsonschema:"description=Additional provider-specific options for the model"`
}
+func (a SelectedModel) Equal(b SelectedModel) bool {
+ return a.Model == b.Model &&
+ a.Provider == b.Provider &&
+ a.ReasoningEffort == b.ReasoningEffort &&
+ a.Think == b.Think &&
+ a.MaxTokens == b.MaxTokens &&
+ ptrEqual(a.Temperature, b.Temperature) &&
+ ptrEqual(a.TopP, b.TopP) &&
+ ptrEqual(a.TopK, b.TopK) &&
+ ptrEqual(a.FrequencyPenalty, b.FrequencyPenalty) &&
+ ptrEqual(a.PresencePenalty, b.PresencePenalty) &&
+		maps.EqualFunc(a.ProviderOptions, b.ProviderOptions, reflect.DeepEqual)
+}
+
+func ptrEqual[T comparable](a, b *T) bool {
+ if a == nil || b == nil {
+ return a == b
+ }
+ return *a == *b
+}
+
type ProviderConfig struct {
// The provider's id.
ID string `json:"id,omitempty" jsonschema:"description=Unique identifier for the provider,example=openai"`
@@ -0,0 +1,143 @@
+package config
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestSelectedModelEqual(t *testing.T) {
+ t.Parallel()
+
+ t.Run("equal models", func(t *testing.T) {
+ t.Parallel()
+ a := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+ b := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+ require.True(t, a.Equal(b))
+ })
+
+ t.Run("different model", func(t *testing.T) {
+ t.Parallel()
+ a := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+ b := SelectedModel{Model: "gpt-4o-mini", Provider: "openai"}
+ require.False(t, a.Equal(b))
+ })
+
+ t.Run("different provider", func(t *testing.T) {
+ t.Parallel()
+ a := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+ b := SelectedModel{Model: "gpt-4o", Provider: "anthropic"}
+ require.False(t, a.Equal(b))
+ })
+
+ t.Run("different reasoning effort", func(t *testing.T) {
+ t.Parallel()
+ a := SelectedModel{Model: "o3", Provider: "openai", ReasoningEffort: "high"}
+ b := SelectedModel{Model: "o3", Provider: "openai", ReasoningEffort: "low"}
+ require.False(t, a.Equal(b))
+ })
+
+ t.Run("different think", func(t *testing.T) {
+ t.Parallel()
+ a := SelectedModel{Model: "claude-sonnet", Provider: "anthropic", Think: true}
+ b := SelectedModel{Model: "claude-sonnet", Provider: "anthropic", Think: false}
+ require.False(t, a.Equal(b))
+ })
+
+ t.Run("different max tokens", func(t *testing.T) {
+ t.Parallel()
+ a := SelectedModel{Model: "gpt-4o", Provider: "openai", MaxTokens: 4096}
+ b := SelectedModel{Model: "gpt-4o", Provider: "openai", MaxTokens: 8192}
+ require.False(t, a.Equal(b))
+ })
+
+ t.Run("both nil pointers", func(t *testing.T) {
+ t.Parallel()
+ a := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+ b := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+ require.True(t, a.Equal(b))
+ })
+
+ t.Run("one nil one non-nil pointer", func(t *testing.T) {
+ t.Parallel()
+ temp := 0.7
+ a := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &temp}
+ b := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+ require.False(t, a.Equal(b))
+ require.False(t, b.Equal(a))
+ })
+
+ t.Run("both non-nil equal pointers", func(t *testing.T) {
+ t.Parallel()
+ temp := 0.7
+ a := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &temp}
+ b := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &temp}
+ require.True(t, a.Equal(b))
+ })
+
+ t.Run("both non-nil different pointers", func(t *testing.T) {
+ t.Parallel()
+ t1 := 0.7
+ t2 := 0.9
+ a := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &t1}
+ b := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &t2}
+ require.False(t, a.Equal(b))
+ })
+
+ t.Run("nil ProviderOptions", func(t *testing.T) {
+ t.Parallel()
+ a := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+ b := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+ require.True(t, a.Equal(b))
+ })
+
+ t.Run("empty ProviderOptions", func(t *testing.T) {
+ t.Parallel()
+ a := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{}}
+ b := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{}}
+ require.True(t, a.Equal(b))
+ })
+
+ t.Run("different ProviderOptions", func(t *testing.T) {
+ t.Parallel()
+ a := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{"key": "a"}}
+ b := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{"key": "b"}}
+ require.False(t, a.Equal(b))
+ })
+
+ t.Run("one nil one non-nil ProviderOptions", func(t *testing.T) {
+ t.Parallel()
+ a := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{"key": "a"}}
+ b := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+ require.False(t, a.Equal(b))
+ })
+}
+
+func TestPtrEqual(t *testing.T) {
+ t.Parallel()
+
+ t.Run("both nil", func(t *testing.T) {
+ t.Parallel()
+ require.True(t, ptrEqual[int](nil, nil))
+ })
+
+ t.Run("one nil", func(t *testing.T) {
+ t.Parallel()
+ v := 42
+ require.False(t, ptrEqual(&v, nil))
+ require.False(t, ptrEqual(nil, &v))
+ })
+
+ t.Run("both equal", func(t *testing.T) {
+ t.Parallel()
+ v := 42
+ require.True(t, ptrEqual(&v, &v))
+ })
+
+ t.Run("both different", func(t *testing.T) {
+ t.Parallel()
+ a := 42
+ b := 43
+ require.False(t, ptrEqual(&a, &b))
+ })
+}
@@ -11,10 +11,10 @@ import (
)
type DBTX interface {
	ExecContext(context.Context, string, ...any) (sql.Result, error)
	PrepareContext(context.Context, string) (*sql.Stmt, error)
	QueryContext(context.Context, string, ...any) (*sql.Rows, error)
	QueryRowContext(context.Context, string, ...any) *sql.Row
}
func New(db DBTX) *Queries {
@@ -132,6 +132,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.updateSessionStmt, err = db.PrepareContext(ctx, updateSession); err != nil {
return nil, fmt.Errorf("error preparing query UpdateSession: %w", err)
}
+ if q.updateSessionModelsStmt, err = db.PrepareContext(ctx, updateSessionModels); err != nil {
+ return nil, fmt.Errorf("error preparing query UpdateSessionModels: %w", err)
+ }
if q.updateSessionTitleAndUsageStmt, err = db.PrepareContext(ctx, updateSessionTitleAndUsage); err != nil {
return nil, fmt.Errorf("error preparing query UpdateSessionTitleAndUsage: %w", err)
}
@@ -320,6 +323,11 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing updateSessionStmt: %w", cerr)
}
}
+ if q.updateSessionModelsStmt != nil {
+ if cerr := q.updateSessionModelsStmt.Close(); cerr != nil {
+ err = fmt.Errorf("error closing updateSessionModelsStmt: %w", cerr)
+ }
+ }
if q.updateSessionTitleAndUsageStmt != nil {
if cerr := q.updateSessionTitleAndUsageStmt.Close(); cerr != nil {
err = fmt.Errorf("error closing updateSessionTitleAndUsageStmt: %w", cerr)
@@ -400,6 +408,7 @@ type Queries struct {
renameSessionStmt *sql.Stmt
updateMessageStmt *sql.Stmt
updateSessionStmt *sql.Stmt
+ updateSessionModelsStmt *sql.Stmt
updateSessionTitleAndUsageStmt *sql.Stmt
}
@@ -443,6 +452,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
renameSessionStmt: q.renameSessionStmt,
updateMessageStmt: q.updateMessageStmt,
updateSessionStmt: q.updateSessionStmt,
+ updateSessionModelsStmt: q.updateSessionModelsStmt,
updateSessionTitleAndUsageStmt: q.updateSessionTitleAndUsageStmt,
}
}
@@ -0,0 +1,9 @@
+-- +goose Up
+-- +goose StatementBegin
+ALTER TABLE sessions ADD COLUMN models TEXT;
+-- +goose StatementEnd
+
+-- +goose Down
+-- +goose StatementBegin
+ALTER TABLE sessions DROP COLUMN models;
+-- +goose StatementEnd
@@ -49,4 +49,5 @@ type Session struct {
CreatedAt int64 `json:"created_at"`
SummaryMessageID sql.NullString `json:"summary_message_id"`
Todos sql.NullString `json:"todos"`
+ Models sql.NullString `json:"models"`
}
@@ -45,6 +45,7 @@ type Querier interface {
RenameSession(ctx context.Context, arg RenameSessionParams) error
UpdateMessage(ctx context.Context, arg UpdateMessageParams) error
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
+ UpdateSessionModels(ctx context.Context, arg UpdateSessionModelsParams) (Session, error)
UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error
}
@@ -33,7 +33,7 @@ INSERT INTO sessions (
null,
strftime('%s', 'now'),
strftime('%s', 'now')
-) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos
+) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models
`
type CreateSessionParams struct {
@@ -69,6 +69,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
&i.CreatedAt,
&i.SummaryMessageID,
&i.Todos,
+ &i.Models,
)
return i, err
}
@@ -84,7 +85,7 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error {
}
const getLastSession = `-- name: GetLastSession :one
-SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos
+SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models
FROM sessions
ORDER BY updated_at DESC
LIMIT 1
@@ -105,12 +106,13 @@ func (q *Queries) GetLastSession(ctx context.Context) (Session, error) {
&i.CreatedAt,
&i.SummaryMessageID,
&i.Todos,
+ &i.Models,
)
return i, err
}
const getSessionByID = `-- name: GetSessionByID :one
-SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos
+SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models
FROM sessions
WHERE id = ? LIMIT 1
`
@@ -130,12 +132,13 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
&i.CreatedAt,
&i.SummaryMessageID,
&i.Todos,
+ &i.Models,
)
return i, err
}
const listSessions = `-- name: ListSessions :many
-SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos
+SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models
FROM sessions
WHERE parent_session_id is NULL
ORDER BY updated_at DESC
@@ -162,6 +165,7 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) {
&i.CreatedAt,
&i.SummaryMessageID,
&i.Todos,
+ &i.Models,
); err != nil {
return nil, err
}
@@ -203,7 +207,7 @@ SET
cost = ?,
todos = ?
WHERE id = ?
-RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos
+RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models
`
type UpdateSessionParams struct {
@@ -239,6 +243,39 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
&i.CreatedAt,
&i.SummaryMessageID,
&i.Todos,
+ &i.Models,
+ )
+ return i, err
+}
+
+const updateSessionModels = `-- name: UpdateSessionModels :one
+UPDATE sessions
+SET models = ?
+WHERE id = ?
+RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models
+`
+
+type UpdateSessionModelsParams struct {
+ Models sql.NullString `json:"models"`
+ ID string `json:"id"`
+}
+
+func (q *Queries) UpdateSessionModels(ctx context.Context, arg UpdateSessionModelsParams) (Session, error) {
+ row := q.queryRow(ctx, q.updateSessionModelsStmt, updateSessionModels, arg.Models, arg.ID)
+ var i Session
+ err := row.Scan(
+ &i.ID,
+ &i.ParentSessionID,
+ &i.Title,
+ &i.MessageCount,
+ &i.PromptTokens,
+ &i.CompletionTokens,
+ &i.Cost,
+ &i.UpdatedAt,
+ &i.CreatedAt,
+ &i.SummaryMessageID,
+ &i.Todos,
+ &i.Models,
)
return i, err
}
@@ -69,6 +69,12 @@ SET
title = ?
WHERE id = ?;
+-- name: UpdateSessionModels :one
+UPDATE sessions
+SET models = ?
+WHERE id = ?
+RETURNING *;
+
-- name: DeleteSession :exec
DELETE FROM sessions
WHERE id = ?;
@@ -56,7 +56,7 @@ type AgentSession struct {
// IsZero checks if the AgentSession is zero-valued.
func (a AgentSession) IsZero() bool {
- return a == AgentSession{}
+ return a.ID == "" && !a.IsBusy
}
// PermissionAction represents an action taken on a permission request.
@@ -1,15 +1,39 @@
package proto
+// SelectedModelType represents the type of model selection (large or small).
+type SelectedModelType string
+
+const (
+ SelectedModelTypeLarge SelectedModelType = "large"
+ SelectedModelTypeSmall SelectedModelType = "small"
+)
+
+// SelectedModel represents a model selection with provider and configuration.
+type SelectedModel struct {
+ Model string `json:"model"`
+ Provider string `json:"provider"`
+ ReasoningEffort string `json:"reasoning_effort,omitempty"`
+ Think bool `json:"think,omitempty"`
+ MaxTokens int64 `json:"max_tokens,omitempty"`
+ Temperature *float64 `json:"temperature,omitempty"`
+ TopP *float64 `json:"top_p,omitempty"`
+ TopK *int64 `json:"top_k,omitempty"`
+ FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
+ PresencePenalty *float64 `json:"presence_penalty,omitempty"`
+ ProviderOptions map[string]any `json:"provider_options,omitempty"`
+}
+
// Session represents a session in the proto layer.
type Session struct {
- ID string `json:"id"`
- ParentSessionID string `json:"parent_session_id"`
- Title string `json:"title"`
- MessageCount int64 `json:"message_count"`
- PromptTokens int64 `json:"prompt_tokens"`
- CompletionTokens int64 `json:"completion_tokens"`
- SummaryMessageID string `json:"summary_message_id"`
- Cost float64 `json:"cost"`
- CreatedAt int64 `json:"created_at"`
- UpdatedAt int64 `json:"updated_at"`
+ ID string `json:"id"`
+ ParentSessionID string `json:"parent_session_id"`
+ Title string `json:"title"`
+ MessageCount int64 `json:"message_count"`
+ PromptTokens int64 `json:"prompt_tokens"`
+ CompletionTokens int64 `json:"completion_tokens"`
+ SummaryMessageID string `json:"summary_message_id"`
+ Cost float64 `json:"cost"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
+ Models map[SelectedModelType]SelectedModel `json:"models,omitempty"`
}
@@ -0,0 +1,168 @@
+package proto
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+func TestModelsRoundTrip(t *testing.T) {
+ t.Parallel()
+
+ t.Run("nil models", func(t *testing.T) {
+ t.Parallel()
+ s := Session{
+ ID: "test-id",
+ Title: "Test",
+ }
+ data, err := json.Marshal(s)
+ require.NoError(t, err)
+ var decoded Session
+ require.NoError(t, json.Unmarshal(data, &decoded))
+ require.Nil(t, decoded.Models)
+ })
+
+ t.Run("populated models", func(t *testing.T) {
+ t.Parallel()
+ temp := 0.7
+ s := Session{
+ ID: "test-id",
+ Title: "Test",
+ Models: map[SelectedModelType]SelectedModel{
+ SelectedModelTypeLarge: {
+ Model: "gpt-4o",
+ Provider: "openai",
+ ReasoningEffort: "high",
+ Think: true,
+ MaxTokens: 4096,
+ Temperature: &temp,
+ ProviderOptions: map[string]any{"key": "value"},
+ },
+ SelectedModelTypeSmall: {
+ Model: "gpt-4o-mini",
+ Provider: "openai",
+ },
+ },
+ }
+ data, err := json.Marshal(s)
+ require.NoError(t, err)
+ var decoded Session
+ require.NoError(t, json.Unmarshal(data, &decoded))
+ require.Equal(t, "gpt-4o", decoded.Models[SelectedModelTypeLarge].Model)
+ require.Equal(t, "openai", decoded.Models[SelectedModelTypeLarge].Provider)
+ require.Equal(t, "high", decoded.Models[SelectedModelTypeLarge].ReasoningEffort)
+ require.True(t, decoded.Models[SelectedModelTypeLarge].Think)
+ require.Equal(t, int64(4096), decoded.Models[SelectedModelTypeLarge].MaxTokens)
+ require.NotNil(t, decoded.Models[SelectedModelTypeLarge].Temperature)
+ require.Equal(t, 0.7, *decoded.Models[SelectedModelTypeLarge].Temperature)
+ require.Equal(t, "gpt-4o-mini", decoded.Models[SelectedModelTypeSmall].Model)
+ })
+
+ t.Run("empty map models", func(t *testing.T) {
+ t.Parallel()
+ s := Session{
+ ID: "test-id",
+ Title: "Test",
+ Models: map[SelectedModelType]SelectedModel{},
+ }
+ data, err := json.Marshal(s)
+ require.NoError(t, err)
+ var decoded Session
+ require.NoError(t, json.Unmarshal(data, &decoded))
+ // Empty map with omitempty is dropped during marshaling.
+ require.Nil(t, decoded.Models)
+ })
+}
+
+func TestProtoToDomainRoundTrip(t *testing.T) {
+ t.Parallel()
+
+ t.Run("models through proto", func(t *testing.T) {
+ t.Parallel()
+ temp := 0.7
+ domainModels := map[config.SelectedModelType]config.SelectedModel{
+ config.SelectedModelTypeLarge: {
+ Model: "gpt-4o",
+ Provider: "openai",
+ ReasoningEffort: "high",
+ Think: true,
+ MaxTokens: 4096,
+ Temperature: &temp,
+ ProviderOptions: map[string]any{"key": "value"},
+ },
+ }
+
+ // Domain → Proto
+ protoModels := convertModelsToProtoLocal(domainModels)
+	require.Contains(t, protoModels, SelectedModelTypeLarge)
+ require.Equal(t, "gpt-4o", protoModels[SelectedModelTypeLarge].Model)
+ require.Equal(t, "openai", protoModels[SelectedModelTypeLarge].Provider)
+
+ // Proto → Domain
+ result := convertModelsFromProtoLocal(protoModels)
+ require.Equal(t, "gpt-4o", result[config.SelectedModelTypeLarge].Model)
+ require.Equal(t, "openai", result[config.SelectedModelTypeLarge].Provider)
+ require.Equal(t, "high", result[config.SelectedModelTypeLarge].ReasoningEffort)
+ require.True(t, result[config.SelectedModelTypeLarge].Think)
+ require.Equal(t, int64(4096), result[config.SelectedModelTypeLarge].MaxTokens)
+ require.NotNil(t, result[config.SelectedModelTypeLarge].Temperature)
+ require.Equal(t, 0.7, *result[config.SelectedModelTypeLarge].Temperature)
+ })
+
+ t.Run("nil models round-trip", func(t *testing.T) {
+ t.Parallel()
+ protoModels := convertModelsToProtoLocal(nil)
+ require.Nil(t, protoModels)
+
+ domainModels := convertModelsFromProtoLocal(nil)
+ require.Nil(t, domainModels)
+ })
+}
+
+func convertModelsToProtoLocal(models map[config.SelectedModelType]config.SelectedModel) map[SelectedModelType]SelectedModel {
+ if models == nil {
+ return nil
+ }
+ result := make(map[SelectedModelType]SelectedModel, len(models))
+ for k, v := range models {
+ result[SelectedModelType(k)] = SelectedModel{
+ Model: v.Model,
+ Provider: v.Provider,
+ ReasoningEffort: v.ReasoningEffort,
+ Think: v.Think,
+ MaxTokens: v.MaxTokens,
+ Temperature: v.Temperature,
+ TopP: v.TopP,
+ TopK: v.TopK,
+ FrequencyPenalty: v.FrequencyPenalty,
+ PresencePenalty: v.PresencePenalty,
+ ProviderOptions: v.ProviderOptions,
+ }
+ }
+ return result
+}
+
+func convertModelsFromProtoLocal(models map[SelectedModelType]SelectedModel) map[config.SelectedModelType]config.SelectedModel {
+ if models == nil {
+ return nil
+ }
+ result := make(map[config.SelectedModelType]config.SelectedModel, len(models))
+ for k, v := range models {
+ result[config.SelectedModelType(k)] = config.SelectedModel{
+ Model: v.Model,
+ Provider: v.Provider,
+ ReasoningEffort: v.ReasoningEffort,
+ Think: v.Think,
+ MaxTokens: v.MaxTokens,
+ Temperature: v.Temperature,
+ TopP: v.TopP,
+ TopK: v.TopK,
+ FrequencyPenalty: v.FrequencyPenalty,
+ PresencePenalty: v.PresencePenalty,
+ ProviderOptions: v.ProviderOptions,
+ }
+ }
+ return result
+}
@@ -8,6 +8,7 @@ import (
"github.com/charmbracelet/crush/internal/agent/notify"
"github.com/charmbracelet/crush/internal/agent/tools/mcp"
"github.com/charmbracelet/crush/internal/app"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/history"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/permission"
@@ -137,9 +138,33 @@ func sessionToProto(s session.Session) proto.Session {
Cost: s.Cost,
CreatedAt: s.CreatedAt,
UpdatedAt: s.UpdatedAt,
+ Models: convertModelsToProto(s.Models),
}
}
+func convertModelsToProto(models map[config.SelectedModelType]config.SelectedModel) map[proto.SelectedModelType]proto.SelectedModel {
+ if models == nil {
+ return nil
+ }
+ result := make(map[proto.SelectedModelType]proto.SelectedModel, len(models))
+ for k, v := range models {
+ result[proto.SelectedModelType(k)] = proto.SelectedModel{
+ Model: v.Model,
+ Provider: v.Provider,
+ ReasoningEffort: v.ReasoningEffort,
+ Think: v.Think,
+ MaxTokens: v.MaxTokens,
+ Temperature: v.Temperature,
+ TopP: v.TopP,
+ TopK: v.TopK,
+ FrequencyPenalty: v.FrequencyPenalty,
+ PresencePenalty: v.PresencePenalty,
+ ProviderOptions: v.ProviderOptions,
+ }
+ }
+ return result
+}
+
func fileToProto(f history.File) proto.File {
return proto.File{
ID: f.ID,
@@ -7,6 +7,7 @@ import (
"net/http"
"github.com/charmbracelet/crush/internal/backend"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/proto"
"github.com/charmbracelet/crush/internal/session"
)
@@ -460,6 +461,39 @@ func (c *controllerV1) handleDeleteWorkspaceSession(w http.ResponseWriter, r *ht
w.WriteHeader(http.StatusOK)
}
+// handlePostWorkspaceSessionModels updates the models for a session.
+//
+// @Summary Update session models
+// @Tags sessions
+// @Accept json
+// @Param id path string true "Workspace ID"
+// @Param sid path string true "Session ID"
+// @Param body body map[config.SelectedModelType]config.SelectedModel true "Models"
+// @Success 204
+// @Failure 400 {object} proto.Error
+// @Failure 500 {object} proto.Error
+// @Router /workspaces/{id}/sessions/{sid}/models [post]
+func (c *controllerV1) handlePostWorkspaceSessionModels(w http.ResponseWriter, r *http.Request) {
+ workspaceID := r.PathValue("id")
+ sessionID := r.PathValue("sid")
+
+ var req struct {
+ Models map[config.SelectedModelType]config.SelectedModel `json:"models"`
+ }
+ if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+ c.server.logError(r, "Failed to decode request", "error", err)
+ jsonError(w, http.StatusBadRequest, "failed to decode request")
+ return
+ }
+
+ if err := c.backend.UpdateSessionModels(r.Context(), workspaceID, sessionID, req.Models); err != nil {
+ c.handleError(w, r, err)
+ return
+ }
+
+ w.WriteHeader(http.StatusNoContent)
+}
+
// handleGetWorkspaceSessionUserMessages returns user messages for a session.
//
// @Summary Get user messages for session
@@ -122,6 +122,7 @@ func NewServer(cfg *config.ConfigStore, network, address string) *Server {
mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}", c.handleGetWorkspaceSession)
mux.HandleFunc("PUT /v1/workspaces/{id}/sessions/{sid}", c.handlePutWorkspaceSession)
mux.HandleFunc("DELETE /v1/workspaces/{id}/sessions/{sid}", c.handleDeleteWorkspaceSession)
+ mux.HandleFunc("POST /v1/workspaces/{id}/sessions/{sid}/models", c.handlePostWorkspaceSessionModels)
mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/history", c.handleGetWorkspaceSessionHistory)
mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages", c.handleGetWorkspaceSessionMessages)
mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages/user", c.handleGetWorkspaceSessionUserMessages)
@@ -8,6 +8,7 @@ import (
"log/slog"
"strings"
+ "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/db"
"github.com/charmbracelet/crush/internal/event"
"github.com/charmbracelet/crush/internal/pubsub"
@@ -56,6 +57,7 @@ type Session struct {
SummaryMessageID string
Cost float64
Todos []Todo
+ Models map[config.SelectedModelType]config.SelectedModel
CreatedAt int64
UpdatedAt int64
}
@@ -69,6 +71,8 @@ type Service interface {
GetLast(ctx context.Context) (Session, error)
List(ctx context.Context) ([]Session, error)
Save(ctx context.Context, session Session) (Session, error)
+ SaveWithModels(ctx context.Context, session Session, models map[config.SelectedModelType]config.SelectedModel) (Session, error)
+ UpdateSessionModels(ctx context.Context, id string, models map[config.SelectedModelType]config.SelectedModel) error
UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
Rename(ctx context.Context, id string, title string) error
Delete(ctx context.Context, id string) error
@@ -242,6 +246,10 @@ func (s service) fromDBItem(item db.Session) Session {
if err != nil {
slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
}
+ models, err := unmarshalModels(item.Models.String)
+ if err != nil {
+ slog.Error("Failed to unmarshal models", "session_id", item.ID, "error", err)
+ }
return Session{
ID: item.ID,
ParentSessionID: item.ParentSessionID.String,
@@ -252,6 +260,7 @@ func (s service) fromDBItem(item db.Session) Session {
SummaryMessageID: item.SummaryMessageID.String,
Cost: item.Cost,
Todos: todos,
+ Models: models,
CreatedAt: item.CreatedAt,
UpdatedAt: item.UpdatedAt,
}
@@ -279,6 +288,51 @@ func unmarshalTodos(data string) ([]Todo, error) {
return todos, nil
}
+func marshalModels(models map[config.SelectedModelType]config.SelectedModel) (string, error) {
+ if len(models) == 0 {
+ return "", nil
+ }
+ data, err := json.Marshal(models)
+ if err != nil {
+ return "", err
+ }
+ return string(data), nil
+}
+
+func unmarshalModels(data string) (map[config.SelectedModelType]config.SelectedModel, error) {
+ if data == "" {
+ return nil, nil
+ }
+ var models map[config.SelectedModelType]config.SelectedModel
+ if err := json.Unmarshal([]byte(data), &models); err != nil {
+ return nil, err
+ }
+ return models, nil
+}
+
+func (s *service) UpdateSessionModels(ctx context.Context, id string, models map[config.SelectedModelType]config.SelectedModel) error {
+ modelsJSON, err := marshalModels(models)
+ if err != nil {
+ return fmt.Errorf("failed to marshal models: %w", err)
+ }
+ _, err = s.q.UpdateSessionModels(ctx, db.UpdateSessionModelsParams{
+ Models: sql.NullString{String: modelsJSON, Valid: modelsJSON != ""},
+ ID: id,
+ })
+ return err
+}
+
+func (s *service) SaveWithModels(ctx context.Context, session Session, models map[config.SelectedModelType]config.SelectedModel) (Session, error) {
+ saved, err := s.Save(ctx, session)
+ if err != nil {
+ return Session{}, err
+ }
+	if err := s.UpdateSessionModels(ctx, saved.ID, models); err != nil {
+ return Session{}, fmt.Errorf("failed to persist models: %w", err)
+ }
+ return saved, nil
+}
+
func NewService(q *db.Queries, conn *sql.DB) Service {
broker := pubsub.NewBroker[Session]()
return &service{
@@ -0,0 +1,151 @@
+package session
+
+import (
+ "database/sql"
+ "testing"
+
+ "github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/db"
+ "github.com/stretchr/testify/require"
+)
+
+func TestMarshalModels(t *testing.T) {
+ t.Parallel()
+
+ t.Run("empty", func(t *testing.T) {
+ t.Parallel()
+ result, err := marshalModels(map[config.SelectedModelType]config.SelectedModel{})
+ require.NoError(t, err)
+ require.Equal(t, "", result)
+ })
+
+ t.Run("nil", func(t *testing.T) {
+ t.Parallel()
+ result, err := marshalModels(nil)
+ require.NoError(t, err)
+ require.Equal(t, "", result)
+ })
+
+ t.Run("single entry", func(t *testing.T) {
+ t.Parallel()
+ models := map[config.SelectedModelType]config.SelectedModel{
+ config.SelectedModelTypeLarge: {
+ Model: "claude-sonnet-4-20250514",
+ Provider: "anthropic",
+ },
+ }
+ result, err := marshalModels(models)
+ require.NoError(t, err)
+ require.Contains(t, result, "claude-sonnet-4-20250514")
+ require.Contains(t, result, "anthropic")
+ })
+
+ t.Run("round-trip", func(t *testing.T) {
+ t.Parallel()
+ temp := 0.7
+ topP := 0.9
+ topK := int64(50)
+ freqPen := 0.1
+ presPen := 0.2
+ models := map[config.SelectedModelType]config.SelectedModel{
+ config.SelectedModelTypeLarge: {
+ Model: "gpt-4o",
+ Provider: "openai",
+ ReasoningEffort: "high",
+ Think: true,
+ MaxTokens: 4096,
+ Temperature: &temp,
+ TopP: &topP,
+ TopK: &topK,
+ FrequencyPenalty: &freqPen,
+ PresencePenalty: &presPen,
+ ProviderOptions: map[string]any{"key": "value"},
+ },
+ config.SelectedModelTypeSmall: {
+ Model: "gpt-4o-mini",
+ Provider: "openai",
+ },
+ }
+ data, err := marshalModels(models)
+ require.NoError(t, err)
+ result, err := unmarshalModels(data)
+ require.NoError(t, err)
+ require.Equal(t, models, result)
+ })
+}
+
+func TestUnmarshalModels(t *testing.T) {
+ t.Parallel()
+
+ t.Run("empty string", func(t *testing.T) {
+ t.Parallel()
+ result, err := unmarshalModels("")
+ require.NoError(t, err)
+ require.Nil(t, result)
+ })
+
+ t.Run("valid JSON", func(t *testing.T) {
+ t.Parallel()
+ data := `{"large":{"model":"gpt-4o","provider":"openai"}}`
+ result, err := unmarshalModels(data)
+ require.NoError(t, err)
+ require.Equal(t, "gpt-4o", result[config.SelectedModelTypeLarge].Model)
+ require.Equal(t, "openai", result[config.SelectedModelTypeLarge].Provider)
+ })
+
+ t.Run("invalid JSON", func(t *testing.T) {
+ t.Parallel()
+ _, err := unmarshalModels("{invalid}")
+ require.Error(t, err)
+ })
+}
+
+func TestFromDBItemWithModels(t *testing.T) {
+ t.Parallel()
+
+ t.Run("null models", func(t *testing.T) {
+ t.Parallel()
+ item := testDBSession()
+ item.Models = sql.NullString{Valid: false}
+ result := service{}.fromDBItem(item)
+ require.Nil(t, result.Models)
+ })
+
+ t.Run("empty models", func(t *testing.T) {
+ t.Parallel()
+ item := testDBSession()
+ item.Models = sql.NullString{String: "", Valid: false}
+ result := service{}.fromDBItem(item)
+ require.Nil(t, result.Models)
+ })
+
+ t.Run("valid models", func(t *testing.T) {
+ t.Parallel()
+ item := testDBSession()
+ item.Models = sql.NullString{
+ String: `{"large":{"model":"gpt-4o","provider":"openai"}}`,
+ Valid: true,
+ }
+ result := service{}.fromDBItem(item)
+ require.NotNil(t, result.Models)
+ require.Equal(t, "gpt-4o", result.Models[config.SelectedModelTypeLarge].Model)
+ })
+
+ t.Run("invalid JSON models", func(t *testing.T) {
+ t.Parallel()
+ item := testDBSession()
+ item.Models = sql.NullString{
+ String: "{invalid}",
+ Valid: true,
+ }
+ result := service{}.fromDBItem(item)
+ require.Nil(t, result.Models)
+ })
+}
+
+func testDBSession() db.Session {
+ return db.Session{
+ ID: "test-id",
+ Title: "Test",
+ }
+}
@@ -498,6 +498,29 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
m.setState(uiChat, m.focus)
m.session = msg.session
m.sessionFiles = msg.files
+ if msg.session.Models != nil {
+ cfg := m.com.Config()
+ anyChanged := false
+ for modelType, selectedModel := range msg.session.Models {
+ if current, ok := cfg.Models[modelType]; ok && current.Equal(selectedModel) {
+ continue
+ }
+ if cfg.GetModel(selectedModel.Provider, selectedModel.Model) != nil {
+ if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, modelType, selectedModel); err != nil {
+ slog.Error("Failed to restore model", "type", modelType, "error", err)
+ }
+ anyChanged = true
+ }
+ }
+ if anyChanged {
+ cmds = append(cmds, func() tea.Msg {
+ if err := m.com.Workspace.UpdateAgentModel(context.TODO()); err != nil {
+ return util.ReportError(err)
+ }
+ return nil
+ })
+ }
+ }
cmds = append(cmds, m.startLSPs(msg.lspFilePaths()))
msgs, err := m.com.Workspace.ListMessages(context.Background(), m.session.ID)
if err != nil {
@@ -1376,6 +1399,9 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
}
return util.NewInfoMsg("Thinking mode " + status)
})
+ if m.session != nil {
+ cmds = append(cmds, m.saveSessionModels())
+ }
m.dialog.CloseDialog(dialog.CommandsID)
case dialog.ActionToggleTransparentBackground:
cmds = append(cmds, func() tea.Msg {
@@ -1465,6 +1491,10 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
return util.NewInfoMsg(modelMsg)
})
+ if m.session != nil {
+ cmds = append(cmds, m.saveSessionModels())
+ }
+
m.dialog.CloseDialog(dialog.APIKeyInputID)
m.dialog.CloseDialog(dialog.OAuthID)
m.dialog.CloseDialog(dialog.ModelsID)
@@ -1505,6 +1535,9 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
m.com.Workspace.UpdateAgentModel(context.TODO())
return util.NewInfoMsg("Reasoning effort set to " + msg.Effort)
})
+ if m.session != nil {
+ cmds = append(cmds, m.saveSessionModels())
+ }
m.dialog.CloseDialog(dialog.ReasoningID)
case dialog.ActionPermissionResponse:
m.dialog.CloseDialog(dialog.PermissionsID)
@@ -3652,3 +3685,18 @@ func renderLogo(t *styles.Styles, compact bool, width int) string {
Width: width,
})
}
+
+func (m *UI) saveSessionModels() tea.Cmd {
+ session := m.session
+ if session == nil {
+ return nil
+ }
+ models := m.com.Config().Models
+ return func() tea.Msg {
+ err := m.com.Workspace.UpdateSessionModels(context.Background(), session.ID, models)
+ if err != nil {
+ return util.ReportError(err)
+ }
+ return nil
+ }
+}
@@ -59,6 +59,10 @@ func (w *AppWorkspace) DeleteSession(ctx context.Context, sessionID string) erro
return w.app.Sessions.Delete(ctx, sessionID)
}
// UpdateSessionModels persists the given model selections for the session,
// delegating to the app's session service.
func (w *AppWorkspace) UpdateSessionModels(ctx context.Context, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error {
	return w.app.Sessions.UpdateSessionModels(ctx, sessionID, models)
}
+
// CreateAgentToolSessionID builds the synthetic session ID for an agent
// tool call, delegating to the session service.
func (w *AppWorkspace) CreateAgentToolSessionID(messageID, toolCallID string) string {
	return w.app.Sessions.CreateAgentToolSessionID(messageID, toolCallID)
}
@@ -119,6 +119,10 @@ func (w *ClientWorkspace) DeleteSession(ctx context.Context, sessionID string) e
return w.client.DeleteSession(ctx, w.workspaceID(), sessionID)
}
// UpdateSessionModels persists the given model selections for the session
// in the remote workspace via the client.
func (w *ClientWorkspace) UpdateSessionModels(ctx context.Context, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error {
	return w.client.UpdateSessionModels(ctx, w.workspaceID(), sessionID, models)
}
+
// CreateAgentToolSessionID builds the synthetic session ID for an agent
// tool call by joining the message and tool-call IDs with "$$".
// NOTE(review): must stay in sync with ParseAgentToolSessionID and the
// server-side format — confirm before changing the separator.
func (w *ClientWorkspace) CreateAgentToolSessionID(messageID, toolCallID string) string {
	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
}
@@ -673,6 +677,7 @@ func protoToSession(s proto.Session) session.Session {
Cost: s.Cost,
CreatedAt: s.CreatedAt,
UpdatedAt: s.UpdatedAt,
+ Models: convertModelsFromProto(s.Models),
}
}
@@ -769,5 +774,52 @@ func sessionToProto(s session.Session) proto.Session {
Cost: s.Cost,
CreatedAt: s.CreatedAt,
UpdatedAt: s.UpdatedAt,
+ Models: convertModelsToProtoClient(s.Models),
+ }
+}
+
+func convertModelsFromProto(models map[proto.SelectedModelType]proto.SelectedModel) map[config.SelectedModelType]config.SelectedModel {
+ if models == nil {
+ return nil
+ }
+ result := make(map[config.SelectedModelType]config.SelectedModel, len(models))
+ for k, v := range models {
+ result[config.SelectedModelType(k)] = config.SelectedModel{
+ Model: v.Model,
+ Provider: v.Provider,
+ ReasoningEffort: v.ReasoningEffort,
+ Think: v.Think,
+ MaxTokens: v.MaxTokens,
+ Temperature: v.Temperature,
+ TopP: v.TopP,
+ TopK: v.TopK,
+ FrequencyPenalty: v.FrequencyPenalty,
+ PresencePenalty: v.PresencePenalty,
+ ProviderOptions: v.ProviderOptions,
+ }
}
+ return result
+}
+
+func convertModelsToProtoClient(models map[config.SelectedModelType]config.SelectedModel) map[proto.SelectedModelType]proto.SelectedModel {
+ if models == nil {
+ return nil
+ }
+ result := make(map[proto.SelectedModelType]proto.SelectedModel, len(models))
+ for k, v := range models {
+ result[proto.SelectedModelType(k)] = proto.SelectedModel{
+ Model: v.Model,
+ Provider: v.Provider,
+ ReasoningEffort: v.ReasoningEffort,
+ Think: v.Think,
+ MaxTokens: v.MaxTokens,
+ Temperature: v.Temperature,
+ TopP: v.TopP,
+ TopK: v.TopK,
+ FrequencyPenalty: v.FrequencyPenalty,
+ PresencePenalty: v.PresencePenalty,
+ ProviderOptions: v.ProviderOptions,
+ }
+ }
+ return result
}
@@ -65,6 +65,7 @@ type Workspace interface {
ListSessions(ctx context.Context) ([]session.Session, error)
SaveSession(ctx context.Context, sess session.Session) (session.Session, error)
DeleteSession(ctx context.Context, sessionID string) error
+ UpdateSessionModels(ctx context.Context, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error
CreateAgentToolSessionID(messageID, toolCallID string) string
ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)