diff --git a/lib/registration.rb b/lib/registration.rb index 12466d3cbaf2409ada01ba4e13a397b80c899ec8..91f0c78bfd8be9fe4ece9b46fda79a3f6a9dd565 100644 --- a/lib/registration.rb +++ b/lib/registration.rb @@ -523,23 +523,33 @@ class Registration put_default_fwd ]) }.then do - FinishOnboarding.for(@customer, @tel).write + FinishOnboarding.for(@customer, @tel).then(&:write) end end end module FinishOnboarding - def self.for(customer, tel) + def self.for(customer, tel, db: LazyObject.new { DB }) jid = ProxiedJID.new(customer.jid).unproxied if jid.domain == CONFIG[:onboarding_domain] - Snikket.new(customer, tel) + Snikket.for(customer, tel, db: db) else NotOnboarding.new(customer, tel) end end class Snikket - def initialize(customer, tel, error: nil, db: LazyObject.new { DB }) + def self.for(customer, tel, db:) + ::Snikket::Repo.new(db: db).find_by_customer(customer).then do |is| + if is.empty? + new(customer, tel, db: db) + else + GetInvite.for(is[0]) + end + end + end + + def initialize(customer, tel, error: nil, db:) @customer = customer @tel = tel @error = error @@ -577,7 +587,7 @@ class Registration }.catch { |e| next EMPromise.reject(e) unless e.respond_to?(:text) - Snikket.new(@customer, @tel, error: e.text).write + Snikket.new(@customer, @tel, error: e.text, db: @db).write } end diff --git a/lib/snikket.rb b/lib/snikket.rb index 50e1d137bbf2bc081272eb47f52847ebafb76ac0..1b9c638a7341b6f1a5e6daf665d5da546d9fff84 100644 --- a/lib/snikket.rb +++ b/lib/snikket.rb @@ -143,7 +143,7 @@ module Snikket end def find_by_customer(customer) - promise = @db.query(<<~SQL, [customer.customer_id]) + promise = @db.query_defer(<<~SQL, [customer.customer_id]) SELECT instance_id, bootstrap_token, customer_id, domain FROM snikket_instances WHERE customer_id=$1 @@ -158,7 +158,7 @@ module Snikket instance.instance_id, instance.bootstrap_token, instance.customer_id, instance.domain ] - @db.exec(<<~SQL, params) + @db.exec_defer(<<~SQL, params) INSERT INTO snikket_instances (instance_id, boostrap_token, customer_id, domain) VALUES diff --git a/test/test_helper.rb b/test/test_helper.rb index 7aae1a144134f4a704fb145e839df03e8cc10041..02627dc687142e020d5c5232e765f1b7f0590682 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -310,7 +310,7 @@ class FakeDB EMPromise.resolve(row || default) end - def exec(_, _) + def exec_defer(_, _) EMPromise.resolve(nil) end end diff --git a/test/test_registration.rb b/test/test_registration.rb index 7fd2463ceb89920b64d0b3ffc737d83c5835ade5..719fa30e6b3631155681fa0b78dc95da811c59e6 100644 --- a/test/test_registration.rb +++ b/test/test_registration.rb @@ -710,6 +710,7 @@ class RegistrationTest < Minitest::Test Registration::Finish::TEL_SELECTIONS = FakeTelSelections.new Registration::Finish::REDIS = Minitest::Mock.new Bwmsgsv2Repo::REDIS = Minitest::Mock.new + Registration::FinishOnboarding::DB = FakeDB.new def setup @sgx = Minitest::Mock.new(TrivialBackendSgxRepo.new.get("test"))