feat: remember provider/model info per-session

Christian Rocha created

When switching sessions use the last set provider and model if they
exist. Thinking and reasoning settings are also preserved.

This applies to both large and small models and targets both
standalone and server-client implementations.

This functionality can extend to additional agents when the time comes.

Change summary

internal/agent/agent.go                                          |  10 
internal/agent/coordinator.go                                    |   4 
internal/agent/tools/todos.go                                    |   5 
internal/app/resolve_session_test.go                             |   9 
internal/backend/session.go                                      |  10 
internal/client/proto.go                                         |  17 
internal/config/config.go                                        |  21 
internal/config/selected_model_test.go                           | 143 +
internal/db/db.go                                                |  16 
internal/db/migrations/20260401000000_add_models_to_sessions.sql |   9 
internal/db/models.go                                            |   1 
internal/db/querier.go                                           |   1 
internal/db/sessions.sql.go                                      |  47 
internal/db/sql/sessions.sql                                     |   6 
internal/proto/proto.go                                          |   2 
internal/proto/session.go                                        |  44 
internal/proto/session_test.go                                   | 168 ++
internal/server/events.go                                        |  25 
internal/server/proto.go                                         |  34 
internal/server/server.go                                        |   1 
internal/session/session.go                                      |  54 
internal/session/session_test.go                                 | 151 +
internal/ui/model/ui.go                                          |  48 
internal/workspace/app_workspace.go                              |   4 
internal/workspace/client_workspace.go                           |  52 
internal/workspace/workspace.go                                  |   1 
26 files changed, 858 insertions(+), 25 deletions(-)

Detailed changes

internal/agent/agent.go 🔗

@@ -423,7 +423,10 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 				return getSessionErr
 			}
 			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
-			_, sessionErr := a.sessions.Save(ctx, updatedSession)
+			_, sessionErr := a.sessions.SaveWithModels(ctx, updatedSession, map[config.SelectedModelType]config.SelectedModel{
+				config.SelectedModelTypeLarge: a.largeModel.Get().ModelCfg,
+				config.SelectedModelTypeSmall: a.smallModel.Get().ModelCfg,
+			})
 			if sessionErr != nil {
 				return sessionErr
 			}
@@ -723,7 +726,10 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fan
 	currentSession.SummaryMessageID = summaryMessage.ID
 	currentSession.CompletionTokens = usage.OutputTokens
 	currentSession.PromptTokens = 0
-	_, err = a.sessions.Save(genCtx, currentSession)
+	_, err = a.sessions.SaveWithModels(genCtx, currentSession, map[config.SelectedModelType]config.SelectedModel{
+		config.SelectedModelTypeLarge: a.largeModel.Get().ModelCfg,
+		config.SelectedModelTypeSmall: a.smallModel.Get().ModelCfg,
+	})
 	return err
 }
 

internal/agent/coordinator.go 🔗

@@ -476,7 +476,7 @@ func (c *coordinator) buildTools(ctx context.Context, agent config.Agent) ([]fan
 		tools.NewGrepTool(c.cfg.WorkingDir(), c.cfg.Config().Tools.Grep),
 		tools.NewLsTool(c.permissions, c.cfg.WorkingDir(), c.cfg.Config().Tools.Ls),
 		tools.NewSourcegraphTool(nil),
-		tools.NewTodosTool(c.sessions),
+		tools.NewTodosTool(c.sessions, c.cfg),
 		tools.NewViewTool(c.lspManager, c.permissions, c.filetracker, c.skillTracker, c.cfg.WorkingDir(), c.cfg.Config().Options.SkillsPaths...),
 		tools.NewWriteTool(c.lspManager, c.permissions, c.history, c.filetracker, c.cfg.WorkingDir()),
 	)
@@ -1051,7 +1051,7 @@ func (c *coordinator) updateParentSessionCost(ctx context.Context, childSessionI
 
 	parentSession.Cost += childSession.Cost
 
-	if _, err := c.sessions.Save(ctx, parentSession); err != nil {
+	if _, err := c.sessions.SaveWithModels(ctx, parentSession, c.cfg.Config().Models); err != nil {
 		return fmt.Errorf("save parent session: %w", err)
 	}
 

internal/agent/tools/todos.go 🔗

@@ -6,6 +6,7 @@ import (
 	"fmt"
 
 	"charm.land/fantasy"
+	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/session"
 )
 
@@ -33,7 +34,7 @@ type TodosResponseMetadata struct {
 	Total         int            `json:"total"`
 }
 
-func NewTodosTool(sessions session.Service) fantasy.AgentTool {
+func NewTodosTool(sessions session.Service, cfg *config.ConfigStore) fantasy.AgentTool {
 	return fantasy.NewAgentTool(
 		TodosToolName,
 		FirstLineDescription(todosDescription),
@@ -96,7 +97,7 @@ func NewTodosTool(sessions session.Service) fantasy.AgentTool {
 			}
 
 			currentSession.Todos = todos
-			_, err = sessions.Save(ctx, currentSession)
+			_, err = sessions.SaveWithModels(ctx, currentSession, cfg.Config().Models)
 			if err != nil {
 				return fantasy.ToolResponse{}, fmt.Errorf("failed to save todos: %w", err)
 			}

internal/app/resolve_session_test.go 🔗

@@ -7,6 +7,7 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/pubsub"
 	"github.com/charmbracelet/crush/internal/session"
 	"github.com/stretchr/testify/require"
@@ -60,6 +61,14 @@ func (m *mockSessionService) Save(_ context.Context, s session.Session) (session
 	return s, nil
 }
 
+func (m *mockSessionService) SaveWithModels(_ context.Context, s session.Session, _ map[config.SelectedModelType]config.SelectedModel) (session.Session, error) {
+	return s, nil
+}
+
+func (m *mockSessionService) UpdateSessionModels(context.Context, string, map[config.SelectedModelType]config.SelectedModel) error {
+	return nil
+}
+
 func (m *mockSessionService) UpdateTitleAndUsage(context.Context, string, string, int64, int64, float64) error {
 	return nil
 }

internal/backend/session.go 🔗

@@ -3,6 +3,7 @@ package backend
 import (
 	"context"
 
+	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/message"
 	"github.com/charmbracelet/crush/internal/proto"
 	"github.com/charmbracelet/crush/internal/session"
@@ -105,6 +106,15 @@ func (b *Backend) DeleteSession(ctx context.Context, workspaceID, sessionID stri
 	return ws.Sessions.Delete(ctx, sessionID)
 }
 
+// UpdateSessionModels updates the models for a session in the given workspace.
+func (b *Backend) UpdateSessionModels(ctx context.Context, workspaceID, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error {
+	ws, err := b.GetWorkspace(workspaceID)
+	if err != nil {
+		return err
+	}
+	return ws.Sessions.UpdateSessionModels(ctx, sessionID, models)
+}
+
 // ListUserMessages returns user-role messages for a session.
 func (b *Backend) ListUserMessages(ctx context.Context, workspaceID, sessionID string) ([]message.Message, error) {
 	ws, err := b.GetWorkspace(workspaceID)

internal/client/proto.go 🔗

@@ -585,6 +585,23 @@ func (c *Client) DeleteSession(ctx context.Context, id string, sessionID string)
 	return nil
 }
 
+// UpdateSessionModels updates the models for a session.
+func (c *Client) UpdateSessionModels(ctx context.Context, workspaceID, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error {
+	rsp, err := c.post(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/models", workspaceID, sessionID), nil,
+		jsonBody(struct {
+			Models map[config.SelectedModelType]config.SelectedModel `json:"models"`
+		}{Models: models}),
+		http.Header{"Content-Type": []string{"application/json"}})
+	if err != nil {
+		return fmt.Errorf("failed to update session models: %w", err)
+	}
+	defer rsp.Body.Close()
+	if rsp.StatusCode != http.StatusNoContent {
+		return fmt.Errorf("failed to update session models: status code %d", rsp.StatusCode)
+	}
+	return nil
+}
+
 // ListUserMessages retrieves user-role messages for a session as proto types.
 func (c *Client) ListUserMessages(ctx context.Context, id string, sessionID string) ([]proto.Message, error) {
 	rsp, err := c.get(ctx, fmt.Sprintf("/workspaces/%s/sessions/%s/messages/user", id, sessionID), nil, nil)

internal/config/config.go 🔗

@@ -89,6 +89,27 @@ type SelectedModel struct {
 	ProviderOptions map[string]any `json:"provider_options,omitempty" jsonschema:"description=Additional provider-specific options for the model"`
 }
 
+func (a SelectedModel) Equal(b SelectedModel) bool {
+	return a.Model == b.Model &&
+		a.Provider == b.Provider &&
+		a.ReasoningEffort == b.ReasoningEffort &&
+		a.Think == b.Think &&
+		a.MaxTokens == b.MaxTokens &&
+		ptrEqual(a.Temperature, b.Temperature) &&
+		ptrEqual(a.TopP, b.TopP) &&
+		ptrEqual(a.TopK, b.TopK) &&
+		ptrEqual(a.FrequencyPenalty, b.FrequencyPenalty) &&
+		ptrEqual(a.PresencePenalty, b.PresencePenalty) &&
+		maps.Equal(a.ProviderOptions, b.ProviderOptions)
+}
+
+func ptrEqual[T comparable](a, b *T) bool {
+	if a == nil || b == nil {
+		return a == b
+	}
+	return *a == *b
+}
+
 type ProviderConfig struct {
 	// The provider's id.
 	ID string `json:"id,omitempty" jsonschema:"description=Unique identifier for the provider,example=openai"`

internal/config/selected_model_test.go 🔗

@@ -0,0 +1,143 @@
+package config
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestSelectedModelEqual(t *testing.T) {
+	t.Parallel()
+
+	t.Run("equal models", func(t *testing.T) {
+		t.Parallel()
+		a := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+		b := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+		require.True(t, a.Equal(b))
+	})
+
+	t.Run("different model", func(t *testing.T) {
+		t.Parallel()
+		a := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+		b := SelectedModel{Model: "gpt-4o-mini", Provider: "openai"}
+		require.False(t, a.Equal(b))
+	})
+
+	t.Run("different provider", func(t *testing.T) {
+		t.Parallel()
+		a := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+		b := SelectedModel{Model: "gpt-4o", Provider: "anthropic"}
+		require.False(t, a.Equal(b))
+	})
+
+	t.Run("different reasoning effort", func(t *testing.T) {
+		t.Parallel()
+		a := SelectedModel{Model: "o3", Provider: "openai", ReasoningEffort: "high"}
+		b := SelectedModel{Model: "o3", Provider: "openai", ReasoningEffort: "low"}
+		require.False(t, a.Equal(b))
+	})
+
+	t.Run("different think", func(t *testing.T) {
+		t.Parallel()
+		a := SelectedModel{Model: "claude-sonnet", Provider: "anthropic", Think: true}
+		b := SelectedModel{Model: "claude-sonnet", Provider: "anthropic", Think: false}
+		require.False(t, a.Equal(b))
+	})
+
+	t.Run("different max tokens", func(t *testing.T) {
+		t.Parallel()
+		a := SelectedModel{Model: "gpt-4o", Provider: "openai", MaxTokens: 4096}
+		b := SelectedModel{Model: "gpt-4o", Provider: "openai", MaxTokens: 8192}
+		require.False(t, a.Equal(b))
+	})
+
+	t.Run("both nil pointers", func(t *testing.T) {
+		t.Parallel()
+		a := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+		b := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+		require.True(t, a.Equal(b))
+	})
+
+	t.Run("one nil one non-nil pointer", func(t *testing.T) {
+		t.Parallel()
+		temp := 0.7
+		a := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &temp}
+		b := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+		require.False(t, a.Equal(b))
+		require.False(t, b.Equal(a))
+	})
+
+	t.Run("both non-nil equal pointers", func(t *testing.T) {
+		t.Parallel()
+		temp := 0.7
+		a := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &temp}
+		b := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &temp}
+		require.True(t, a.Equal(b))
+	})
+
+	t.Run("both non-nil different pointers", func(t *testing.T) {
+		t.Parallel()
+		t1 := 0.7
+		t2 := 0.9
+		a := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &t1}
+		b := SelectedModel{Model: "gpt-4o", Provider: "openai", Temperature: &t2}
+		require.False(t, a.Equal(b))
+	})
+
+	t.Run("nil ProviderOptions", func(t *testing.T) {
+		t.Parallel()
+		a := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+		b := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+		require.True(t, a.Equal(b))
+	})
+
+	t.Run("empty ProviderOptions", func(t *testing.T) {
+		t.Parallel()
+		a := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{}}
+		b := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{}}
+		require.True(t, a.Equal(b))
+	})
+
+	t.Run("different ProviderOptions", func(t *testing.T) {
+		t.Parallel()
+		a := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{"key": "a"}}
+		b := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{"key": "b"}}
+		require.False(t, a.Equal(b))
+	})
+
+	t.Run("one nil one non-nil ProviderOptions", func(t *testing.T) {
+		t.Parallel()
+		a := SelectedModel{Model: "gpt-4o", Provider: "openai", ProviderOptions: map[string]any{"key": "a"}}
+		b := SelectedModel{Model: "gpt-4o", Provider: "openai"}
+		require.False(t, a.Equal(b))
+	})
+}
+
+func TestPtrEqual(t *testing.T) {
+	t.Parallel()
+
+	t.Run("both nil", func(t *testing.T) {
+		t.Parallel()
+		require.True(t, ptrEqual[int](nil, nil))
+	})
+
+	t.Run("one nil", func(t *testing.T) {
+		t.Parallel()
+		v := 42
+		require.False(t, ptrEqual(&v, nil))
+		require.False(t, ptrEqual(nil, &v))
+	})
+
+	t.Run("both equal", func(t *testing.T) {
+		t.Parallel()
+		v := 42
+		require.True(t, ptrEqual(&v, &v))
+	})
+
+	t.Run("both different", func(t *testing.T) {
+		t.Parallel()
+		a := 42
+		b := 43
+		require.False(t, ptrEqual(&a, &b))
+	})
+}

internal/db/db.go 🔗

@@ -11,10 +11,10 @@ import (
 )
 
 type DBTX interface {
-	ExecContext(context.Context, string, ...any) (sql.Result, error)
+	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
 	PrepareContext(context.Context, string) (*sql.Stmt, error)
-	QueryContext(context.Context, string, ...any) (*sql.Rows, error)
-	QueryRowContext(context.Context, string, ...any) *sql.Row
+	QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
+	QueryRowContext(context.Context, string, ...interface{}) *sql.Row
 }
 
 func New(db DBTX) *Queries {
@@ -132,6 +132,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
 	if q.updateSessionStmt, err = db.PrepareContext(ctx, updateSession); err != nil {
 		return nil, fmt.Errorf("error preparing query UpdateSession: %w", err)
 	}
+	if q.updateSessionModelsStmt, err = db.PrepareContext(ctx, updateSessionModels); err != nil {
+		return nil, fmt.Errorf("error preparing query UpdateSessionModels: %w", err)
+	}
 	if q.updateSessionTitleAndUsageStmt, err = db.PrepareContext(ctx, updateSessionTitleAndUsage); err != nil {
 		return nil, fmt.Errorf("error preparing query UpdateSessionTitleAndUsage: %w", err)
 	}
@@ -320,6 +323,11 @@ func (q *Queries) Close() error {
 			err = fmt.Errorf("error closing updateSessionStmt: %w", cerr)
 		}
 	}
+	if q.updateSessionModelsStmt != nil {
+		if cerr := q.updateSessionModelsStmt.Close(); cerr != nil {
+			err = fmt.Errorf("error closing updateSessionModelsStmt: %w", cerr)
+		}
+	}
 	if q.updateSessionTitleAndUsageStmt != nil {
 		if cerr := q.updateSessionTitleAndUsageStmt.Close(); cerr != nil {
 			err = fmt.Errorf("error closing updateSessionTitleAndUsageStmt: %w", cerr)
@@ -400,6 +408,7 @@ type Queries struct {
 	renameSessionStmt              *sql.Stmt
 	updateMessageStmt              *sql.Stmt
 	updateSessionStmt              *sql.Stmt
+	updateSessionModelsStmt        *sql.Stmt
 	updateSessionTitleAndUsageStmt *sql.Stmt
 }
 
@@ -443,6 +452,7 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
 		renameSessionStmt:              q.renameSessionStmt,
 		updateMessageStmt:              q.updateMessageStmt,
 		updateSessionStmt:              q.updateSessionStmt,
+		updateSessionModelsStmt:        q.updateSessionModelsStmt,
 		updateSessionTitleAndUsageStmt: q.updateSessionTitleAndUsageStmt,
 	}
 }

internal/db/models.go 🔗

@@ -49,4 +49,5 @@ type Session struct {
 	CreatedAt        int64          `json:"created_at"`
 	SummaryMessageID sql.NullString `json:"summary_message_id"`
 	Todos            sql.NullString `json:"todos"`
+	Models           sql.NullString `json:"models"`
 }

internal/db/querier.go 🔗

@@ -45,6 +45,7 @@ type Querier interface {
 	RenameSession(ctx context.Context, arg RenameSessionParams) error
 	UpdateMessage(ctx context.Context, arg UpdateMessageParams) error
 	UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
+	UpdateSessionModels(ctx context.Context, arg UpdateSessionModelsParams) (Session, error)
 	UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error
 }
 

internal/db/sessions.sql.go 🔗

@@ -33,7 +33,7 @@ INSERT INTO sessions (
     null,
     strftime('%s', 'now'),
     strftime('%s', 'now')
-) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos
+) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models
 `
 
 type CreateSessionParams struct {
@@ -69,6 +69,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
 		&i.CreatedAt,
 		&i.SummaryMessageID,
 		&i.Todos,
+		&i.Models,
 	)
 	return i, err
 }
@@ -84,7 +85,7 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error {
 }
 
 const getLastSession = `-- name: GetLastSession :one
-SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos
+SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models
 FROM sessions
 ORDER BY updated_at DESC
 LIMIT 1
@@ -105,12 +106,13 @@ func (q *Queries) GetLastSession(ctx context.Context) (Session, error) {
 		&i.CreatedAt,
 		&i.SummaryMessageID,
 		&i.Todos,
+		&i.Models,
 	)
 	return i, err
 }
 
 const getSessionByID = `-- name: GetSessionByID :one
-SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos
+SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models
 FROM sessions
 WHERE id = ? LIMIT 1
 `
@@ -130,12 +132,13 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
 		&i.CreatedAt,
 		&i.SummaryMessageID,
 		&i.Todos,
+		&i.Models,
 	)
 	return i, err
 }
 
 const listSessions = `-- name: ListSessions :many
-SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos
+SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models
 FROM sessions
 WHERE parent_session_id is NULL
 ORDER BY updated_at DESC
@@ -162,6 +165,7 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) {
 			&i.CreatedAt,
 			&i.SummaryMessageID,
 			&i.Todos,
+			&i.Models,
 		); err != nil {
 			return nil, err
 		}
@@ -203,7 +207,7 @@ SET
     cost = ?,
     todos = ?
 WHERE id = ?
-RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos
+RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models
 `
 
 type UpdateSessionParams struct {
@@ -239,6 +243,39 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
 		&i.CreatedAt,
 		&i.SummaryMessageID,
 		&i.Todos,
+		&i.Models,
+	)
+	return i, err
+}
+
+const updateSessionModels = `-- name: UpdateSessionModels :one
+UPDATE sessions
+SET models = ?
+WHERE id = ?
+RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id, todos, models
+`
+
+type UpdateSessionModelsParams struct {
+	Models sql.NullString `json:"models"`
+	ID     string         `json:"id"`
+}
+
+func (q *Queries) UpdateSessionModels(ctx context.Context, arg UpdateSessionModelsParams) (Session, error) {
+	row := q.queryRow(ctx, q.updateSessionModelsStmt, updateSessionModels, arg.Models, arg.ID)
+	var i Session
+	err := row.Scan(
+		&i.ID,
+		&i.ParentSessionID,
+		&i.Title,
+		&i.MessageCount,
+		&i.PromptTokens,
+		&i.CompletionTokens,
+		&i.Cost,
+		&i.UpdatedAt,
+		&i.CreatedAt,
+		&i.SummaryMessageID,
+		&i.Todos,
+		&i.Models,
 	)
 	return i, err
 }

internal/db/sql/sessions.sql 🔗

@@ -69,6 +69,12 @@ SET
     title = ?
 WHERE id = ?;
 
+-- name: UpdateSessionModels :one
+UPDATE sessions
+SET models = ?
+WHERE id = ?
+RETURNING *;
+
 -- name: DeleteSession :exec
 DELETE FROM sessions
 WHERE id = ?;

internal/proto/proto.go 🔗

@@ -56,7 +56,7 @@ type AgentSession struct {
 
 // IsZero checks if the AgentSession is zero-valued.
 func (a AgentSession) IsZero() bool {
-	return a == AgentSession{}
+	return a.ID == "" && !a.IsBusy
 }
 
 // PermissionAction represents an action taken on a permission request.

internal/proto/session.go 🔗

@@ -1,15 +1,39 @@
 package proto
 
+// SelectedModelType represents the type of model selection (large or small).
+type SelectedModelType string
+
+const (
+	SelectedModelTypeLarge SelectedModelType = "large"
+	SelectedModelTypeSmall SelectedModelType = "small"
+)
+
+// SelectedModel represents a model selection with provider and configuration.
+type SelectedModel struct {
+	Model            string         `json:"model"`
+	Provider         string         `json:"provider"`
+	ReasoningEffort  string         `json:"reasoning_effort,omitempty"`
+	Think            bool           `json:"think,omitempty"`
+	MaxTokens        int64          `json:"max_tokens,omitempty"`
+	Temperature      *float64       `json:"temperature,omitempty"`
+	TopP             *float64       `json:"top_p,omitempty"`
+	TopK             *int64         `json:"top_k,omitempty"`
+	FrequencyPenalty *float64       `json:"frequency_penalty,omitempty"`
+	PresencePenalty  *float64       `json:"presence_penalty,omitempty"`
+	ProviderOptions  map[string]any `json:"provider_options,omitempty"`
+}
+
 // Session represents a session in the proto layer.
 type Session struct {
-	ID               string  `json:"id"`
-	ParentSessionID  string  `json:"parent_session_id"`
-	Title            string  `json:"title"`
-	MessageCount     int64   `json:"message_count"`
-	PromptTokens     int64   `json:"prompt_tokens"`
-	CompletionTokens int64   `json:"completion_tokens"`
-	SummaryMessageID string  `json:"summary_message_id"`
-	Cost             float64 `json:"cost"`
-	CreatedAt        int64   `json:"created_at"`
-	UpdatedAt        int64   `json:"updated_at"`
+	ID               string                              `json:"id"`
+	ParentSessionID  string                              `json:"parent_session_id"`
+	Title            string                              `json:"title"`
+	MessageCount     int64                               `json:"message_count"`
+	PromptTokens     int64                               `json:"prompt_tokens"`
+	CompletionTokens int64                               `json:"completion_tokens"`
+	SummaryMessageID string                              `json:"summary_message_id"`
+	Cost             float64                             `json:"cost"`
+	CreatedAt        int64                               `json:"created_at"`
+	UpdatedAt        int64                               `json:"updated_at"`
+	Models           map[SelectedModelType]SelectedModel `json:"models,omitempty"`
 }

internal/proto/session_test.go 🔗

@@ -0,0 +1,168 @@
+package proto
+
+import (
+	"encoding/json"
+	"testing"
+
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/stretchr/testify/require"
+)
+
+func TestModelsRoundTrip(t *testing.T) {
+	t.Parallel()
+
+	t.Run("nil models", func(t *testing.T) {
+		t.Parallel()
+		s := Session{
+			ID:    "test-id",
+			Title: "Test",
+		}
+		data, err := json.Marshal(s)
+		require.NoError(t, err)
+		var decoded Session
+		require.NoError(t, json.Unmarshal(data, &decoded))
+		require.Nil(t, decoded.Models)
+	})
+
+	t.Run("populated models", func(t *testing.T) {
+		t.Parallel()
+		temp := 0.7
+		s := Session{
+			ID:    "test-id",
+			Title: "Test",
+			Models: map[SelectedModelType]SelectedModel{
+				SelectedModelTypeLarge: {
+					Model:           "gpt-4o",
+					Provider:        "openai",
+					ReasoningEffort: "high",
+					Think:           true,
+					MaxTokens:       4096,
+					Temperature:     &temp,
+					ProviderOptions: map[string]any{"key": "value"},
+				},
+				SelectedModelTypeSmall: {
+					Model:    "gpt-4o-mini",
+					Provider: "openai",
+				},
+			},
+		}
+		data, err := json.Marshal(s)
+		require.NoError(t, err)
+		var decoded Session
+		require.NoError(t, json.Unmarshal(data, &decoded))
+		require.Equal(t, "gpt-4o", decoded.Models[SelectedModelTypeLarge].Model)
+		require.Equal(t, "openai", decoded.Models[SelectedModelTypeLarge].Provider)
+		require.Equal(t, "high", decoded.Models[SelectedModelTypeLarge].ReasoningEffort)
+		require.True(t, decoded.Models[SelectedModelTypeLarge].Think)
+		require.Equal(t, int64(4096), decoded.Models[SelectedModelTypeLarge].MaxTokens)
+		require.NotNil(t, decoded.Models[SelectedModelTypeLarge].Temperature)
+		require.Equal(t, 0.7, *decoded.Models[SelectedModelTypeLarge].Temperature)
+		require.Equal(t, "gpt-4o-mini", decoded.Models[SelectedModelTypeSmall].Model)
+	})
+
+	t.Run("empty map models", func(t *testing.T) {
+		t.Parallel()
+		s := Session{
+			ID:     "test-id",
+			Title:  "Test",
+			Models: map[SelectedModelType]SelectedModel{},
+		}
+		data, err := json.Marshal(s)
+		require.NoError(t, err)
+		var decoded Session
+		require.NoError(t, json.Unmarshal(data, &decoded))
+		// Empty map with omitempty is dropped during marshaling.
+		require.Nil(t, decoded.Models)
+	})
+}
+
+func TestProtoToDomainRoundTrip(t *testing.T) {
+	t.Parallel()
+
+	t.Run("models through proto", func(t *testing.T) {
+		t.Parallel()
+		temp := 0.7
+		domainModels := map[config.SelectedModelType]config.SelectedModel{
+			config.SelectedModelTypeLarge: {
+				Model:           "gpt-4o",
+				Provider:        "openai",
+				ReasoningEffort: "high",
+				Think:           true,
+				MaxTokens:       4096,
+				Temperature:     &temp,
+				ProviderOptions: map[string]any{"key": "value"},
+			},
+		}
+
+		// Domain → Proto
+		protoModels := convertModelsToProtoLocal(domainModels)
+		require.Equal(t, SelectedModelTypeLarge, SelectedModelType(config.SelectedModelTypeLarge))
+		require.Equal(t, "gpt-4o", protoModels[SelectedModelTypeLarge].Model)
+		require.Equal(t, "openai", protoModels[SelectedModelTypeLarge].Provider)
+
+		// Proto → Domain
+		result := convertModelsFromProtoLocal(protoModels)
+		require.Equal(t, "gpt-4o", result[config.SelectedModelTypeLarge].Model)
+		require.Equal(t, "openai", result[config.SelectedModelTypeLarge].Provider)
+		require.Equal(t, "high", result[config.SelectedModelTypeLarge].ReasoningEffort)
+		require.True(t, result[config.SelectedModelTypeLarge].Think)
+		require.Equal(t, int64(4096), result[config.SelectedModelTypeLarge].MaxTokens)
+		require.NotNil(t, result[config.SelectedModelTypeLarge].Temperature)
+		require.Equal(t, 0.7, *result[config.SelectedModelTypeLarge].Temperature)
+	})
+
+	t.Run("nil models round-trip", func(t *testing.T) {
+		t.Parallel()
+		protoModels := convertModelsToProtoLocal(nil)
+		require.Nil(t, protoModels)
+
+		domainModels := convertModelsFromProtoLocal(nil)
+		require.Nil(t, domainModels)
+	})
+}
+
+func convertModelsToProtoLocal(models map[config.SelectedModelType]config.SelectedModel) map[SelectedModelType]SelectedModel {
+	if models == nil {
+		return nil
+	}
+	result := make(map[SelectedModelType]SelectedModel, len(models))
+	for k, v := range models {
+		result[SelectedModelType(k)] = SelectedModel{
+			Model:            v.Model,
+			Provider:         v.Provider,
+			ReasoningEffort:  v.ReasoningEffort,
+			Think:            v.Think,
+			MaxTokens:        v.MaxTokens,
+			Temperature:      v.Temperature,
+			TopP:             v.TopP,
+			TopK:             v.TopK,
+			FrequencyPenalty: v.FrequencyPenalty,
+			PresencePenalty:  v.PresencePenalty,
+			ProviderOptions:  v.ProviderOptions,
+		}
+	}
+	return result
+}
+
+func convertModelsFromProtoLocal(models map[SelectedModelType]SelectedModel) map[config.SelectedModelType]config.SelectedModel {
+	if models == nil {
+		return nil
+	}
+	result := make(map[config.SelectedModelType]config.SelectedModel, len(models))
+	for k, v := range models {
+		result[config.SelectedModelType(k)] = config.SelectedModel{
+			Model:            v.Model,
+			Provider:         v.Provider,
+			ReasoningEffort:  v.ReasoningEffort,
+			Think:            v.Think,
+			MaxTokens:        v.MaxTokens,
+			Temperature:      v.Temperature,
+			TopP:             v.TopP,
+			TopK:             v.TopK,
+			FrequencyPenalty: v.FrequencyPenalty,
+			PresencePenalty:  v.PresencePenalty,
+			ProviderOptions:  v.ProviderOptions,
+		}
+	}
+	return result
+}

internal/server/events.go 🔗

@@ -8,6 +8,7 @@ import (
 	"github.com/charmbracelet/crush/internal/agent/notify"
 	"github.com/charmbracelet/crush/internal/agent/tools/mcp"
 	"github.com/charmbracelet/crush/internal/app"
+	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/history"
 	"github.com/charmbracelet/crush/internal/message"
 	"github.com/charmbracelet/crush/internal/permission"
@@ -137,9 +138,33 @@ func sessionToProto(s session.Session) proto.Session {
 		Cost:             s.Cost,
 		CreatedAt:        s.CreatedAt,
 		UpdatedAt:        s.UpdatedAt,
+		Models:           convertModelsToProto(s.Models),
 	}
 }
 
+func convertModelsToProto(models map[config.SelectedModelType]config.SelectedModel) map[proto.SelectedModelType]proto.SelectedModel {
+	if models == nil {
+		return nil
+	}
+	result := make(map[proto.SelectedModelType]proto.SelectedModel, len(models))
+	for k, v := range models {
+		result[proto.SelectedModelType(k)] = proto.SelectedModel{
+			Model:            v.Model,
+			Provider:         v.Provider,
+			ReasoningEffort:  v.ReasoningEffort,
+			Think:            v.Think,
+			MaxTokens:        v.MaxTokens,
+			Temperature:      v.Temperature,
+			TopP:             v.TopP,
+			TopK:             v.TopK,
+			FrequencyPenalty: v.FrequencyPenalty,
+			PresencePenalty:  v.PresencePenalty,
+			ProviderOptions:  v.ProviderOptions,
+		}
+	}
+	return result
+}
+
 func fileToProto(f history.File) proto.File {
 	return proto.File{
 		ID:        f.ID,

internal/server/proto.go 🔗

@@ -7,6 +7,7 @@ import (
 	"net/http"
 
 	"github.com/charmbracelet/crush/internal/backend"
+	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/proto"
 	"github.com/charmbracelet/crush/internal/session"
 )
@@ -460,6 +461,39 @@ func (c *controllerV1) handleDeleteWorkspaceSession(w http.ResponseWriter, r *ht
 	w.WriteHeader(http.StatusOK)
 }
 
+// handlePostWorkspaceSessionModels updates the models for a session.
+//
+//	@Summary		Update session models
+//	@Tags			sessions
+//	@Accept			json
+//	@Param			id		path		string										true	"Workspace ID"
+//	@Param			sid		path		string										true	"Session ID"
+//	@Param			body	body		map[config.SelectedModelType]config.SelectedModel	true	"Models"
+//	@Success		204
+//	@Failure		400	{object}	proto.Error
+//	@Failure		500	{object}	proto.Error
+//	@Router			/workspaces/{id}/sessions/{sid}/models [post]
+func (c *controllerV1) handlePostWorkspaceSessionModels(w http.ResponseWriter, r *http.Request) {
+	workspaceID := r.PathValue("id")
+	sessionID := r.PathValue("sid")
+
+	var req struct {
+		Models map[config.SelectedModelType]config.SelectedModel `json:"models"`
+	}
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		c.server.logError(r, "Failed to decode request", "error", err)
+		jsonError(w, http.StatusBadRequest, "failed to decode request")
+		return
+	}
+
+	if err := c.backend.UpdateSessionModels(r.Context(), workspaceID, sessionID, req.Models); err != nil {
+		c.handleError(w, r, err)
+		return
+	}
+
+	w.WriteHeader(http.StatusNoContent)
+}
+
 // handleGetWorkspaceSessionUserMessages returns user messages for a session.
 //
 //	@Summary		Get user messages for session

internal/server/server.go 🔗

@@ -122,6 +122,7 @@ func NewServer(cfg *config.ConfigStore, network, address string) *Server {
 	mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}", c.handleGetWorkspaceSession)
 	mux.HandleFunc("PUT /v1/workspaces/{id}/sessions/{sid}", c.handlePutWorkspaceSession)
 	mux.HandleFunc("DELETE /v1/workspaces/{id}/sessions/{sid}", c.handleDeleteWorkspaceSession)
+	mux.HandleFunc("POST /v1/workspaces/{id}/sessions/{sid}/models", c.handlePostWorkspaceSessionModels)
 	mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/history", c.handleGetWorkspaceSessionHistory)
 	mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages", c.handleGetWorkspaceSessionMessages)
 	mux.HandleFunc("GET /v1/workspaces/{id}/sessions/{sid}/messages/user", c.handleGetWorkspaceSessionUserMessages)

internal/session/session.go 🔗

@@ -8,6 +8,7 @@ import (
 	"log/slog"
 	"strings"
 
+	"github.com/charmbracelet/crush/internal/config"
 	"github.com/charmbracelet/crush/internal/db"
 	"github.com/charmbracelet/crush/internal/event"
 	"github.com/charmbracelet/crush/internal/pubsub"
@@ -56,6 +57,7 @@ type Session struct {
 	SummaryMessageID string
 	Cost             float64
 	Todos            []Todo
+	Models           map[config.SelectedModelType]config.SelectedModel
 	CreatedAt        int64
 	UpdatedAt        int64
 }
@@ -69,6 +71,8 @@ type Service interface {
 	GetLast(ctx context.Context) (Session, error)
 	List(ctx context.Context) ([]Session, error)
 	Save(ctx context.Context, session Session) (Session, error)
+	SaveWithModels(ctx context.Context, session Session, models map[config.SelectedModelType]config.SelectedModel) (Session, error)
+	UpdateSessionModels(ctx context.Context, id string, models map[config.SelectedModelType]config.SelectedModel) error
 	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
 	Rename(ctx context.Context, id string, title string) error
 	Delete(ctx context.Context, id string) error
@@ -242,6 +246,10 @@ func (s service) fromDBItem(item db.Session) Session {
 	if err != nil {
 		slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
 	}
+	models, err := unmarshalModels(item.Models.String)
+	if err != nil {
+		slog.Error("Failed to unmarshal models", "session_id", item.ID, "error", err)
+	}
 	return Session{
 		ID:               item.ID,
 		ParentSessionID:  item.ParentSessionID.String,
@@ -252,6 +260,7 @@ func (s service) fromDBItem(item db.Session) Session {
 		SummaryMessageID: item.SummaryMessageID.String,
 		Cost:             item.Cost,
 		Todos:            todos,
+		Models:           models,
 		CreatedAt:        item.CreatedAt,
 		UpdatedAt:        item.UpdatedAt,
 	}
@@ -279,6 +288,51 @@ func unmarshalTodos(data string) ([]Todo, error) {
 	return todos, nil
 }
 
+func marshalModels(models map[config.SelectedModelType]config.SelectedModel) (string, error) {
+	if len(models) == 0 {
+		return "", nil
+	}
+	data, err := json.Marshal(models)
+	if err != nil {
+		return "", err
+	}
+	return string(data), nil
+}
+
+func unmarshalModels(data string) (map[config.SelectedModelType]config.SelectedModel, error) {
+	if data == "" {
+		return nil, nil
+	}
+	var models map[config.SelectedModelType]config.SelectedModel
+	if err := json.Unmarshal([]byte(data), &models); err != nil {
+		return nil, err
+	}
+	return models, nil
+}
+
+func (s *service) UpdateSessionModels(ctx context.Context, id string, models map[config.SelectedModelType]config.SelectedModel) error {
+	modelsJSON, err := marshalModels(models)
+	if err != nil {
+		return fmt.Errorf("failed to marshal models: %w", err)
+	}
+	_, err = s.q.UpdateSessionModels(ctx, db.UpdateSessionModelsParams{
+		Models: sql.NullString{String: modelsJSON, Valid: modelsJSON != ""},
+		ID:     id,
+	})
+	return err
+}
+
+func (s *service) SaveWithModels(ctx context.Context, session Session, models map[config.SelectedModelType]config.SelectedModel) (Session, error) {
+	saved, err := s.Save(ctx, session)
+	if err != nil {
+		return Session{}, err
+	}
+	if err := s.UpdateSessionModels(ctx, session.ID, models); err != nil {
+		return Session{}, fmt.Errorf("failed to persist models: %w", err)
+	}
+	return saved, nil
+}
+
 func NewService(q *db.Queries, conn *sql.DB) Service {
 	broker := pubsub.NewBroker[Session]()
 	return &service{

internal/session/session_test.go 🔗

@@ -0,0 +1,151 @@
+package session
+
+import (
+	"database/sql"
+	"testing"
+
+	"github.com/charmbracelet/crush/internal/config"
+	"github.com/charmbracelet/crush/internal/db"
+	"github.com/stretchr/testify/require"
+)
+
+func TestMarshalModels(t *testing.T) {
+	t.Parallel()
+
+	t.Run("empty", func(t *testing.T) {
+		t.Parallel()
+		result, err := marshalModels(map[config.SelectedModelType]config.SelectedModel{})
+		require.NoError(t, err)
+		require.Equal(t, "", result)
+	})
+
+	t.Run("nil", func(t *testing.T) {
+		t.Parallel()
+		result, err := marshalModels(nil)
+		require.NoError(t, err)
+		require.Equal(t, "", result)
+	})
+
+	t.Run("single entry", func(t *testing.T) {
+		t.Parallel()
+		models := map[config.SelectedModelType]config.SelectedModel{
+			config.SelectedModelTypeLarge: {
+				Model:    "claude-sonnet-4-20250514",
+				Provider: "anthropic",
+			},
+		}
+		result, err := marshalModels(models)
+		require.NoError(t, err)
+		require.Contains(t, result, "claude-sonnet-4-20250514")
+		require.Contains(t, result, "anthropic")
+	})
+
+	t.Run("round-trip", func(t *testing.T) {
+		t.Parallel()
+		temp := 0.7
+		topP := 0.9
+		topK := int64(50)
+		freqPen := 0.1
+		presPen := 0.2
+		models := map[config.SelectedModelType]config.SelectedModel{
+			config.SelectedModelTypeLarge: {
+				Model:            "gpt-4o",
+				Provider:         "openai",
+				ReasoningEffort:  "high",
+				Think:            true,
+				MaxTokens:        4096,
+				Temperature:      &temp,
+				TopP:             &topP,
+				TopK:             &topK,
+				FrequencyPenalty: &freqPen,
+				PresencePenalty:  &presPen,
+				ProviderOptions:  map[string]any{"key": "value"},
+			},
+			config.SelectedModelTypeSmall: {
+				Model:    "gpt-4o-mini",
+				Provider: "openai",
+			},
+		}
+		data, err := marshalModels(models)
+		require.NoError(t, err)
+		result, err := unmarshalModels(data)
+		require.NoError(t, err)
+		require.Equal(t, models, result)
+	})
+}
+
+func TestUnmarshalModels(t *testing.T) {
+	t.Parallel()
+
+	t.Run("empty string", func(t *testing.T) {
+		t.Parallel()
+		result, err := unmarshalModels("")
+		require.NoError(t, err)
+		require.Nil(t, result)
+	})
+
+	t.Run("valid JSON", func(t *testing.T) {
+		t.Parallel()
+		data := `{"large":{"model":"gpt-4o","provider":"openai"}}`
+		result, err := unmarshalModels(data)
+		require.NoError(t, err)
+		require.Equal(t, "gpt-4o", result[config.SelectedModelTypeLarge].Model)
+		require.Equal(t, "openai", result[config.SelectedModelTypeLarge].Provider)
+	})
+
+	t.Run("invalid JSON", func(t *testing.T) {
+		t.Parallel()
+		_, err := unmarshalModels("{invalid}")
+		require.Error(t, err)
+	})
+}
+
+func TestFromDBItemWithModels(t *testing.T) {
+	t.Parallel()
+
+	t.Run("null models", func(t *testing.T) {
+		t.Parallel()
+		item := testDBSession()
+		item.Models = sql.NullString{Valid: false}
+		result := service{}.fromDBItem(item)
+		require.Nil(t, result.Models)
+	})
+
+	t.Run("empty models", func(t *testing.T) {
+		t.Parallel()
+		item := testDBSession()
+		item.Models = sql.NullString{String: "", Valid: false}
+		result := service{}.fromDBItem(item)
+		require.Nil(t, result.Models)
+	})
+
+	t.Run("valid models", func(t *testing.T) {
+		t.Parallel()
+		item := testDBSession()
+		item.Models = sql.NullString{
+			String: `{"large":{"model":"gpt-4o","provider":"openai"}}`,
+			Valid:  true,
+		}
+		result := service{}.fromDBItem(item)
+		require.NotNil(t, result.Models)
+		require.Equal(t, "gpt-4o", result.Models[config.SelectedModelTypeLarge].Model)
+	})
+
+	t.Run("invalid JSON models", func(t *testing.T) {
+		t.Parallel()
+		item := testDBSession()
+		item.Models = sql.NullString{
+			String: "{invalid}",
+			Valid:  true,
+		}
+		result := service{}.fromDBItem(item)
+		require.Nil(t, result.Models)
+	})
+}
+
+func testDBSession() db.Session {
+	return db.Session{
+		ID:    "test-id",
+		Title: "Test",
+	}
+}

internal/ui/model/ui.go 🔗

@@ -498,6 +498,29 @@ func (m *UI) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 		m.setState(uiChat, m.focus)
 		m.session = msg.session
 		m.sessionFiles = msg.files
+		if msg.session.Models != nil {
+			cfg := m.com.Config()
+			anyChanged := false
+			for modelType, selectedModel := range msg.session.Models {
+				if current, ok := cfg.Models[modelType]; ok && current.Equal(selectedModel) {
+					continue
+				}
+				if cfg.GetModel(selectedModel.Provider, selectedModel.Model) != nil {
+					if err := m.com.Workspace.UpdatePreferredModel(config.ScopeGlobal, modelType, selectedModel); err != nil {
+						slog.Error("Failed to restore model", "type", modelType, "error", err)
+					}
+					anyChanged = true
+				}
+			}
+			if anyChanged {
+				cmds = append(cmds, func() tea.Msg {
+					if err := m.com.Workspace.UpdateAgentModel(context.TODO()); err != nil {
+						return util.ReportError(err)
+					}
+					return nil
+				})
+			}
+		}
 		cmds = append(cmds, m.startLSPs(msg.lspFilePaths()))
 		msgs, err := m.com.Workspace.ListMessages(context.Background(), m.session.ID)
 		if err != nil {
@@ -1376,6 +1399,9 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
 			}
 			return util.NewInfoMsg("Thinking mode " + status)
 		})
+		if m.session != nil {
+			cmds = append(cmds, m.saveSessionModels())
+		}
 		m.dialog.CloseDialog(dialog.CommandsID)
 	case dialog.ActionToggleTransparentBackground:
 		cmds = append(cmds, func() tea.Msg {
@@ -1465,6 +1491,10 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
 			return util.NewInfoMsg(modelMsg)
 		})
 
+		if m.session != nil {
+			cmds = append(cmds, m.saveSessionModels())
+		}
+
 		m.dialog.CloseDialog(dialog.APIKeyInputID)
 		m.dialog.CloseDialog(dialog.OAuthID)
 		m.dialog.CloseDialog(dialog.ModelsID)
@@ -1505,6 +1535,9 @@ func (m *UI) handleDialogMsg(msg tea.Msg) tea.Cmd {
 			m.com.Workspace.UpdateAgentModel(context.TODO())
 			return util.NewInfoMsg("Reasoning effort set to " + msg.Effort)
 		})
+		if m.session != nil {
+			cmds = append(cmds, m.saveSessionModels())
+		}
 		m.dialog.CloseDialog(dialog.ReasoningID)
 	case dialog.ActionPermissionResponse:
 		m.dialog.CloseDialog(dialog.PermissionsID)
@@ -3652,3 +3685,18 @@ func renderLogo(t *styles.Styles, compact bool, width int) string {
 		Width:        width,
 	})
 }
+
+func (m *UI) saveSessionModels() tea.Cmd {
+	session := m.session
+	if session == nil {
+		return nil
+	}
+	models := m.com.Config().Models
+	return func() tea.Msg {
+		err := m.com.Workspace.UpdateSessionModels(context.Background(), session.ID, models)
+		if err != nil {
+			return util.ReportError(err)
+		}
+		return nil
+	}
+}

internal/workspace/app_workspace.go 🔗

@@ -59,6 +59,10 @@ func (w *AppWorkspace) DeleteSession(ctx context.Context, sessionID string) erro
 	return w.app.Sessions.Delete(ctx, sessionID)
 }
 
+func (w *AppWorkspace) UpdateSessionModels(ctx context.Context, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error {
+	return w.app.Sessions.UpdateSessionModels(ctx, sessionID, models)
+}
+
 func (w *AppWorkspace) CreateAgentToolSessionID(messageID, toolCallID string) string {
 	return w.app.Sessions.CreateAgentToolSessionID(messageID, toolCallID)
 }

internal/workspace/client_workspace.go 🔗

@@ -119,6 +119,10 @@ func (w *ClientWorkspace) DeleteSession(ctx context.Context, sessionID string) e
 	return w.client.DeleteSession(ctx, w.workspaceID(), sessionID)
 }
 
+func (w *ClientWorkspace) UpdateSessionModels(ctx context.Context, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error {
+	return w.client.UpdateSessionModels(ctx, w.workspaceID(), sessionID, models)
+}
+
 func (w *ClientWorkspace) CreateAgentToolSessionID(messageID, toolCallID string) string {
 	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
 }
@@ -673,6 +677,7 @@ func protoToSession(s proto.Session) session.Session {
 		Cost:             s.Cost,
 		CreatedAt:        s.CreatedAt,
 		UpdatedAt:        s.UpdatedAt,
+		Models:           convertModelsFromProto(s.Models),
 	}
 }
 
@@ -769,5 +774,52 @@ func sessionToProto(s session.Session) proto.Session {
 		Cost:             s.Cost,
 		CreatedAt:        s.CreatedAt,
 		UpdatedAt:        s.UpdatedAt,
+		Models:           convertModelsToProtoClient(s.Models),
+	}
+}
+
+func convertModelsFromProto(models map[proto.SelectedModelType]proto.SelectedModel) map[config.SelectedModelType]config.SelectedModel {
+	if models == nil {
+		return nil
+	}
+	result := make(map[config.SelectedModelType]config.SelectedModel, len(models))
+	for k, v := range models {
+		result[config.SelectedModelType(k)] = config.SelectedModel{
+			Model:            v.Model,
+			Provider:         v.Provider,
+			ReasoningEffort:  v.ReasoningEffort,
+			Think:            v.Think,
+			MaxTokens:        v.MaxTokens,
+			Temperature:      v.Temperature,
+			TopP:             v.TopP,
+			TopK:             v.TopK,
+			FrequencyPenalty: v.FrequencyPenalty,
+			PresencePenalty:  v.PresencePenalty,
+			ProviderOptions:  v.ProviderOptions,
+		}
 	}
+	return result
+}
+
+func convertModelsToProtoClient(models map[config.SelectedModelType]config.SelectedModel) map[proto.SelectedModelType]proto.SelectedModel {
+	if models == nil {
+		return nil
+	}
+	result := make(map[proto.SelectedModelType]proto.SelectedModel, len(models))
+	for k, v := range models {
+		result[proto.SelectedModelType(k)] = proto.SelectedModel{
+			Model:            v.Model,
+			Provider:         v.Provider,
+			ReasoningEffort:  v.ReasoningEffort,
+			Think:            v.Think,
+			MaxTokens:        v.MaxTokens,
+			Temperature:      v.Temperature,
+			TopP:             v.TopP,
+			TopK:             v.TopK,
+			FrequencyPenalty: v.FrequencyPenalty,
+			PresencePenalty:  v.PresencePenalty,
+			ProviderOptions:  v.ProviderOptions,
+		}
+	}
+	return result
 }

internal/workspace/workspace.go 🔗

@@ -65,6 +65,7 @@ type Workspace interface {
 	ListSessions(ctx context.Context) ([]session.Session, error)
 	SaveSession(ctx context.Context, sess session.Session) (session.Session, error)
 	DeleteSession(ctx context.Context, sessionID string) error
+	UpdateSessionModels(ctx context.Context, sessionID string, models map[config.SelectedModelType]config.SelectedModel) error
 	CreateAgentToolSessionID(messageID, toolCallID string) string
 	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)