entity: support different author in staging operations

Michael Muré created

Change summary

entity/dag/clock.go               |  4 
entity/dag/common_test.go         |  8 +-
entity/dag/entity.go              | 99 ++++++++++++++++++---------------
entity/dag/entity_actions.go      | 20 +++---
entity/dag/entity_actions_test.go | 14 ++--
entity/dag/entity_test.go         | 47 +++++++++------
entity/dag/operation.go           |  8 --
entity/dag/operation_pack.go      | 14 ++--
8 files changed, 113 insertions(+), 101 deletions(-)

Detailed changes

entity/dag/clock.go 🔗

@@ -11,8 +11,8 @@ import (
 func ClockLoader(defs ...Definition) repository.ClockLoader {
 	clocks := make([]string, len(defs)*2)
 	for _, def := range defs {
-		clocks = append(clocks, fmt.Sprintf(creationClockPattern, def.namespace))
-		clocks = append(clocks, fmt.Sprintf(editClockPattern, def.namespace))
+		clocks = append(clocks, fmt.Sprintf(creationClockPattern, def.Namespace))
+		clocks = append(clocks, fmt.Sprintf(editClockPattern, def.Namespace))
 	}
 
 	return repository.ClockLoader{

entity/dag/common_test.go 🔗

@@ -152,10 +152,10 @@ func makeTestContextInternal(repo repository.ClockedRepo) (identity.Interface, i
 	})
 
 	def := Definition{
-		typename:             "foo",
-		namespace:            "foos",
-		operationUnmarshaler: unmarshaler,
-		formatVersion:        1,
+		Typename:             "foo",
+		Namespace:            "foos",
+		OperationUnmarshaler: unmarshaler,
+		FormatVersion:        1,
 	}
 
 	return id1, id2, resolver, def

entity/dag/entity.go 🔗

@@ -22,13 +22,13 @@ const editClockPattern = "%s-edit"
 // Definition hold the details defining one specialization of an Entity.
 type Definition struct {
 	// the name of the entity (bug, pull-request, ...)
-	typename string
-	// the namespace in git (bugs, prs, ...)
-	namespace string
+	Typename string
+	// the Namespace in git (bugs, prs, ...)
+	Namespace string
 	// a function decoding a JSON message into an Operation
-	operationUnmarshaler func(author identity.Interface, raw json.RawMessage) (Operation, error)
+	OperationUnmarshaler func(author identity.Interface, raw json.RawMessage) (Operation, error)
 	// the expected format version number, that can be used for data migration/upgrade
-	formatVersion uint
+	FormatVersion uint
 }
 
 // Entity is a data structure stored in a chain of git objects, supporting actions like Push, Pull and Merge.
@@ -62,7 +62,7 @@ func Read(def Definition, repo repository.ClockedRepo, resolver identity.Resolve
 		return nil, errors.Wrap(err, "invalid id")
 	}
 
-	ref := fmt.Sprintf("refs/%s/%s", def.namespace, id.String())
+	ref := fmt.Sprintf("refs/%s/%s", def.Namespace, id.String())
 
 	return read(def, repo, resolver, ref)
 }
@@ -73,7 +73,7 @@ func readRemote(def Definition, repo repository.ClockedRepo, resolver identity.R
 		return nil, errors.Wrap(err, "invalid id")
 	}
 
-	ref := fmt.Sprintf("refs/remotes/%s/%s/%s", def.namespace, remote, id.String())
+	ref := fmt.Sprintf("refs/remotes/%s/%s/%s", def.Namespace, remote, id.String())
 
 	return read(def, repo, resolver, ref)
 }
@@ -179,11 +179,11 @@ func read(def Definition, repo repository.ClockedRepo, resolver identity.Resolve
 
 	// The clocks are fine, we witness them
 	for _, opp := range oppMap {
-		err = repo.Witness(fmt.Sprintf(creationClockPattern, def.namespace), opp.CreateTime)
+		err = repo.Witness(fmt.Sprintf(creationClockPattern, def.Namespace), opp.CreateTime)
 		if err != nil {
 			return nil, err
 		}
-		err = repo.Witness(fmt.Sprintf(editClockPattern, def.namespace), opp.EditTime)
+		err = repo.Witness(fmt.Sprintf(editClockPattern, def.Namespace), opp.EditTime)
 		if err != nil {
 			return nil, err
 		}
@@ -247,7 +247,7 @@ func ReadAll(def Definition, repo repository.ClockedRepo, resolver identity.Reso
 	go func() {
 		defer close(out)
 
-		refPrefix := fmt.Sprintf("refs/%s/", def.namespace)
+		refPrefix := fmt.Sprintf("refs/%s/", def.Namespace)
 
 		refs, err := repo.ListRefs(refPrefix)
 		if err != nil {
@@ -346,9 +346,9 @@ func (e *Entity) NeedCommit() bool {
 	return len(e.staging) > 0
 }
 
-// CommitAdNeeded execute a Commit only if necessary. This function is useful to avoid getting an error if the Entity
+// CommitAsNeeded execute a Commit only if necessary. This function is useful to avoid getting an error if the Entity
 // is already in sync with the repository.
-func (e *Entity) CommitAdNeeded(repo repository.ClockedRepo) error {
+func (e *Entity) CommitAsNeeded(repo repository.ClockedRepo) error {
 	if e.NeedCommit() {
 		return e.Commit(repo)
 	}
@@ -363,56 +363,65 @@ func (e *Entity) Commit(repo repository.ClockedRepo) error {
 
 	err := e.Validate()
 	if err != nil {
-		return errors.Wrapf(err, "can't commit a %s with invalid data", e.Definition.typename)
+		return errors.Wrapf(err, "can't commit a %s with invalid data", e.Definition.Typename)
 	}
 
-	var author identity.Interface
-	for _, op := range e.staging {
-		if author != nil && op.Author() != author {
-			return fmt.Errorf("operations with different author")
+	for len(e.staging) > 0 {
+		var author identity.Interface
+		var toCommit []Operation
+
+		// Split into chunks with the same author
+		for len(e.staging) > 0 {
+			op := e.staging[0]
+			if author != nil && op.Author().Id() != author.Id() {
+				break
+			}
+			author = e.staging[0].Author()
+			toCommit = append(toCommit, op)
+			e.staging = e.staging[1:]
 		}
-		author = op.Author()
-	}
 
-	e.editTime, err = repo.Increment(fmt.Sprintf(editClockPattern, e.namespace))
-	if err != nil {
-		return err
-	}
+		e.editTime, err = repo.Increment(fmt.Sprintf(editClockPattern, e.Namespace))
+		if err != nil {
+			return err
+		}
 
-	opp := &operationPack{
-		Author:     author,
-		Operations: e.staging,
-		EditTime:   e.editTime,
-	}
+		opp := &operationPack{
+			Author:     author,
+			Operations: toCommit,
+			EditTime:   e.editTime,
+		}
 
-	if e.lastCommit == "" {
-		e.createTime, err = repo.Increment(fmt.Sprintf(creationClockPattern, e.namespace))
+		if e.lastCommit == "" {
+			e.createTime, err = repo.Increment(fmt.Sprintf(creationClockPattern, e.Namespace))
+			if err != nil {
+				return err
+			}
+			opp.CreateTime = e.createTime
+		}
+
+		var parentCommit []repository.Hash
+		if e.lastCommit != "" {
+			parentCommit = []repository.Hash{e.lastCommit}
+		}
+
+		commitHash, err := opp.Write(e.Definition, repo, parentCommit...)
 		if err != nil {
 			return err
 		}
-		opp.CreateTime = e.createTime
-	}
 
-	var commitHash repository.Hash
-	if e.lastCommit == "" {
-		commitHash, err = opp.Write(e.Definition, repo)
-	} else {
-		commitHash, err = opp.Write(e.Definition, repo, e.lastCommit)
-	}
-
-	if err != nil {
-		return err
+		e.lastCommit = commitHash
+		e.ops = append(e.ops, toCommit...)
 	}
 
-	e.lastCommit = commitHash
-	e.ops = append(e.ops, e.staging...)
+	// not strictly necessary but make equality testing easier in tests
 	e.staging = nil
 
 	// Create or update the Git reference for this entity
 	// When pushing later, the remote will ensure that this ref update
 	// is fast-forward, that is no data has been overwritten.
-	ref := fmt.Sprintf(refsPattern, e.namespace, e.Id().String())
-	return repo.UpdateRef(ref, commitHash)
+	ref := fmt.Sprintf(refsPattern, e.Namespace, e.Id().String())
+	return repo.UpdateRef(ref, e.lastCommit)
 }
 
 // CreateLamportTime return the Lamport time of creation

entity/dag/entity_actions.go 🔗

@@ -12,7 +12,7 @@ import (
 
 // ListLocalIds list all the available local Entity's Id
 func ListLocalIds(def Definition, repo repository.RepoData) ([]entity.Id, error) {
-	refs, err := repo.ListRefs(fmt.Sprintf("refs/%s/", def.namespace))
+	refs, err := repo.ListRefs(fmt.Sprintf("refs/%s/", def.Namespace))
 	if err != nil {
 		return nil, err
 	}
@@ -22,12 +22,12 @@ func ListLocalIds(def Definition, repo repository.RepoData) ([]entity.Id, error)
 // Fetch retrieve updates from a remote
 // This does not change the local entity state
 func Fetch(def Definition, repo repository.Repo, remote string) (string, error) {
-	return repo.FetchRefs(remote, def.namespace)
+	return repo.FetchRefs(remote, def.Namespace)
 }
 
 // Push update a remote with the local changes
 func Push(def Definition, repo repository.Repo, remote string) (string, error) {
-	return repo.PushRefs(remote, def.namespace)
+	return repo.PushRefs(remote, def.Namespace)
 }
 
 // Pull will do a Fetch + MergeAll
@@ -74,7 +74,7 @@ func MergeAll(def Definition, repo repository.ClockedRepo, resolver identity.Res
 	go func() {
 		defer close(out)
 
-		remoteRefSpec := fmt.Sprintf("refs/remotes/%s/%s/", remote, def.namespace)
+		remoteRefSpec := fmt.Sprintf("refs/remotes/%s/%s/", remote, def.Namespace)
 		remoteRefs, err := repo.ListRefs(remoteRefSpec)
 		if err != nil {
 			out <- entity.MergeResult{Err: err}
@@ -101,16 +101,16 @@ func merge(def Definition, repo repository.ClockedRepo, resolver identity.Resolv
 	remoteEntity, err := read(def, repo, resolver, remoteRef)
 	if err != nil {
 		return entity.NewMergeInvalidStatus(id,
-			errors.Wrapf(err, "remote %s is not readable", def.typename).Error())
+			errors.Wrapf(err, "remote %s is not readable", def.Typename).Error())
 	}
 
 	// Check for error in remote data
 	if err := remoteEntity.Validate(); err != nil {
 		return entity.NewMergeInvalidStatus(id,
-			errors.Wrapf(err, "remote %s data is invalid", def.typename).Error())
+			errors.Wrapf(err, "remote %s data is invalid", def.Typename).Error())
 	}
 
-	localRef := fmt.Sprintf("refs/%s/%s", def.namespace, id.String())
+	localRef := fmt.Sprintf("refs/%s/%s", def.Namespace, id.String())
 
 	// SCENARIO 1
 	// if the remote Entity doesn't exist locally, it's created
@@ -202,7 +202,7 @@ func merge(def Definition, repo repository.ClockedRepo, resolver identity.Resolv
 		return entity.NewMergeError(err, id)
 	}
 
-	editTime, err := repo.Increment(fmt.Sprintf(editClockPattern, def.namespace))
+	editTime, err := repo.Increment(fmt.Sprintf(editClockPattern, def.Namespace))
 	if err != nil {
 		return entity.NewMergeError(err, id)
 	}
@@ -236,7 +236,7 @@ func merge(def Definition, repo repository.ClockedRepo, resolver identity.Resolv
 func Remove(def Definition, repo repository.ClockedRepo, id entity.Id) error {
 	var matches []string
 
-	ref := fmt.Sprintf("refs/%s/%s", def.namespace, id.String())
+	ref := fmt.Sprintf("refs/%s/%s", def.Namespace, id.String())
 	matches = append(matches, ref)
 
 	remotes, err := repo.GetRemotes()
@@ -245,7 +245,7 @@ func Remove(def Definition, repo repository.ClockedRepo, id entity.Id) error {
 	}
 
 	for remote := range remotes {
-		ref = fmt.Sprintf("refs/remotes/%s/%s/%s", remote, def.namespace, id.String())
+		ref = fmt.Sprintf("refs/remotes/%s/%s/%s", remote, def.Namespace, id.String())
 		matches = append(matches, ref)
 	}
 

entity/dag/entity_actions_test.go 🔗

@@ -244,7 +244,7 @@ func TestMerge(t *testing.T) {
 		},
 	}, results)
 
-	assertEqualRefs(t, repoA, repoB, "refs/"+def.namespace)
+	assertEqualRefs(t, repoA, repoB, "refs/"+def.Namespace)
 
 	// SCENARIO 2
 	// if the remote and local Entity have the same state, nothing is changed
@@ -262,7 +262,7 @@ func TestMerge(t *testing.T) {
 		},
 	}, results)
 
-	assertEqualRefs(t, repoA, repoB, "refs/"+def.namespace)
+	assertEqualRefs(t, repoA, repoB, "refs/"+def.Namespace)
 
 	// SCENARIO 3
 	// if the local Entity has new commits but the remote don't, nothing is changed
@@ -288,7 +288,7 @@ func TestMerge(t *testing.T) {
 		},
 	}, results)
 
-	assertNotEqualRefs(t, repoA, repoB, "refs/"+def.namespace)
+	assertNotEqualRefs(t, repoA, repoB, "refs/"+def.Namespace)
 
 	// SCENARIO 4
 	// if the remote has new commit, the local bug is updated to match the same history
@@ -313,7 +313,7 @@ func TestMerge(t *testing.T) {
 		},
 	}, results)
 
-	assertEqualRefs(t, repoA, repoB, "refs/"+def.namespace)
+	assertEqualRefs(t, repoA, repoB, "refs/"+def.Namespace)
 
 	// SCENARIO 5
 	// if both local and remote Entity have new commits (that is, we have a concurrent edition),
@@ -360,7 +360,7 @@ func TestMerge(t *testing.T) {
 		},
 	}, results)
 
-	assertNotEqualRefs(t, repoA, repoB, "refs/"+def.namespace)
+	assertNotEqualRefs(t, repoA, repoB, "refs/"+def.Namespace)
 
 	_, err = Push(def, repoB, "remote")
 	require.NoError(t, err)
@@ -368,7 +368,7 @@ func TestMerge(t *testing.T) {
 	_, err = Fetch(def, repoA, "remote")
 	require.NoError(t, err)
 
-	results = MergeAll(def, repoA, "remote", id1)
+	results = MergeAll(def, repoA, resolver, "remote", id1)
 
 	assertMergeResults(t, []entity.MergeResult{
 		{
@@ -383,7 +383,7 @@ func TestMerge(t *testing.T) {
 
 	// make sure that the graphs become stable over multiple repo, due to the
 	// fast-forward
-	assertEqualRefs(t, repoA, repoB, "refs/"+def.namespace)
+	assertEqualRefs(t, repoA, repoB, "refs/"+def.Namespace)
 }
 
 func TestRemove(t *testing.T) {

entity/dag/entity_test.go 🔗

@@ -7,7 +7,7 @@ import (
 )
 
 func TestWriteRead(t *testing.T) {
-	repo, id1, id2, def := makeTestContext()
+	repo, id1, id2, resolver, def := makeTestContext()
 
 	entity := New(def)
 	require.False(t, entity.NeedCommit())
@@ -16,15 +16,34 @@ func TestWriteRead(t *testing.T) {
 	entity.Append(newOp2(id1, "bar"))
 
 	require.True(t, entity.NeedCommit())
-	require.NoError(t, entity.CommitAdNeeded(repo))
+	require.NoError(t, entity.CommitAsNeeded(repo))
 	require.False(t, entity.NeedCommit())
 
 	entity.Append(newOp2(id2, "foobar"))
 	require.True(t, entity.NeedCommit())
-	require.NoError(t, entity.CommitAdNeeded(repo))
+	require.NoError(t, entity.CommitAsNeeded(repo))
 	require.False(t, entity.NeedCommit())
 
-	read, err := Read(def, repo, entity.Id())
+	read, err := Read(def, repo, resolver, entity.Id())
+	require.NoError(t, err)
+
+	assertEqualEntities(t, entity, read)
+}
+
+func TestWriteReadMultipleAuthor(t *testing.T) {
+	repo, id1, id2, resolver, def := makeTestContext()
+
+	entity := New(def)
+
+	entity.Append(newOp1(id1, "foo"))
+	entity.Append(newOp2(id2, "bar"))
+
+	require.NoError(t, entity.CommitAsNeeded(repo))
+
+	entity.Append(newOp2(id1, "foobar"))
+	require.NoError(t, entity.CommitAsNeeded(repo))
+
+	read, err := Read(def, repo, resolver, entity.Id())
 	require.NoError(t, err)
 
 	assertEqualEntities(t, entity, read)
@@ -34,23 +53,15 @@ func assertEqualEntities(t *testing.T, a, b *Entity) {
 	// testify doesn't support comparing functions and systematically fail if they are not nil
 	// so we have to set them to nil temporarily
 
-	backOpUnA := a.Definition.operationUnmarshaler
-	backOpUnB := b.Definition.operationUnmarshaler
-
-	a.Definition.operationUnmarshaler = nil
-	b.Definition.operationUnmarshaler = nil
-
-	backIdResA := a.Definition.identityResolver
-	backIdResB := b.Definition.identityResolver
+	backOpUnA := a.Definition.OperationUnmarshaler
+	backOpUnB := b.Definition.OperationUnmarshaler
 
-	a.Definition.identityResolver = nil
-	b.Definition.identityResolver = nil
+	a.Definition.OperationUnmarshaler = nil
+	b.Definition.OperationUnmarshaler = nil
 
 	defer func() {
-		a.Definition.operationUnmarshaler = backOpUnA
-		b.Definition.operationUnmarshaler = backOpUnB
-		a.Definition.identityResolver = backIdResA
-		b.Definition.identityResolver = backIdResB
+		a.Definition.OperationUnmarshaler = backOpUnA
+		b.Definition.OperationUnmarshaler = backOpUnB
 	}()
 
 	require.Equal(t, a, b)

entity/dag/operation.go 🔗

@@ -23,11 +23,3 @@ type Operation interface {
 	// Author returns the author of this operation
 	Author() identity.Interface
 }
-
-// TODO: remove?
-type operationBase struct {
-	author identity.Interface
-
-	// Not serialized. Store the op's id in memory.
-	id entity.Id
-}

entity/dag/operation_pack.go 🔗

@@ -72,7 +72,7 @@ func (opp *operationPack) Validate() error {
 		return fmt.Errorf("missing author")
 	}
 	for _, op := range opp.Operations {
-		if op.Author() != opp.Author {
+		if op.Author().Id() != opp.Author.Id() {
 			return fmt.Errorf("operation has different author than the operationPack's")
 		}
 	}
@@ -120,7 +120,7 @@ func (opp *operationPack) Write(def Definition, repo repository.Repo, parentComm
 	// - clocks
 	tree := []repository.TreeEntry{
 		{ObjectType: repository.Blob, Hash: emptyBlobHash,
-			Name: fmt.Sprintf(versionEntryPrefix+"%d", def.formatVersion)},
+			Name: fmt.Sprintf(versionEntryPrefix+"%d", def.FormatVersion)},
 		{ObjectType: repository.Blob, Hash: hash,
 			Name: opsEntryName},
 		{ObjectType: repository.Blob, Hash: emptyBlobHash,
@@ -188,10 +188,10 @@ func readOperationPack(def Definition, repo repository.RepoData, resolver identi
 		}
 	}
 	if version == 0 {
-		return nil, entity.NewErrUnknowFormat(def.formatVersion)
+		return nil, entity.NewErrUnknowFormat(def.FormatVersion)
 	}
-	if version != def.formatVersion {
-		return nil, entity.NewErrInvalidFormat(version, def.formatVersion)
+	if version != def.FormatVersion {
+		return nil, entity.NewErrInvalidFormat(version, def.FormatVersion)
 	}
 
 	var id entity.Id
@@ -230,7 +230,7 @@ func readOperationPack(def Definition, repo repository.RepoData, resolver identi
 	}
 
 	// Verify signature if we expect one
-	keys := author.ValidKeysAtTime(fmt.Sprintf(editClockPattern, def.namespace), editTime)
+	keys := author.ValidKeysAtTime(fmt.Sprintf(editClockPattern, def.Namespace), editTime)
 	if len(keys) > 0 {
 		keyring := PGPKeyring(keys)
 		_, err = openpgp.CheckDetachedSignature(keyring, commit.SignedData, commit.Signature)
@@ -274,7 +274,7 @@ func unmarshallPack(def Definition, resolver identity.Resolver, data []byte) ([]
 
 	for _, raw := range aux.Operations {
 		// delegate to specialized unmarshal function
-		op, err := def.operationUnmarshaler(author, raw)
+		op, err := def.OperationUnmarshaler(author, raw)
 		if err != nil {
 			return nil, nil, err
 		}