diff --git a/bramble-core/src/main/java/org/briarproject/bramble/rendezvous/RendezvousPollerImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/rendezvous/RendezvousPollerImpl.java index b7d355548a545fdeaa77d1b41f93969586b18330..11340b2338175b71e9f37db7eb6870f658d29277 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/rendezvous/RendezvousPollerImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/rendezvous/RendezvousPollerImpl.java @@ -47,6 +47,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.Executor; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.logging.Logger; import javax.annotation.Nullable; @@ -79,6 +80,7 @@ class RendezvousPollerImpl implements RendezvousPoller, Service, EventListener { private final EventBus eventBus; private final Clock clock; + private final AtomicBoolean used = new AtomicBoolean(false); // Executor that runs one task at a time private final Executor worker; // The following fields are only accessed on the worker @@ -113,6 +115,7 @@ class RendezvousPollerImpl implements RendezvousPoller, Service, EventListener { @Override public void startService() throws ServiceException { + if (used.getAndSet(true)) throw new IllegalStateException(); try { db.transaction(true, txn -> { Collection pending = db.getPendingContacts(txn); diff --git a/bramble-core/src/test/java/org/briarproject/bramble/rendezvous/RendezvousPollerImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/rendezvous/RendezvousPollerImplTest.java index 92af0f780eda0db9d74b35890f61d2fda9d0f488..4832ff8704e5fd1f868559e4db874d964e10382c 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/rendezvous/RendezvousPollerImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/rendezvous/RendezvousPollerImplTest.java @@ -115,20 +115,10 @@ public class RendezvousPollerImplTest extends BrambleMockTestCase { with(MILLISECONDS)); will(new CaptureArgumentAction<>(captureExpiryTask, Runnable.class, 0)); - // Load our handshake key pair - oneOf(db).transactionWithResult(with(true), withDbCallable(txn)); - will(returnValue(handshakeKeyPair)); - // Derive the rendezvous key - oneOf(transportCrypto).deriveStaticMasterKey( - pendingContact.getPublicKey(), handshakeKeyPair); - will(returnValue(staticMasterKey)); - oneOf(rendezvousCrypto).deriveRendezvousKey(staticMasterKey); - will(returnValue(rendezvousKey)); - oneOf(transportCrypto).isAlice(pendingContact.getPublicKey(), - handshakeKeyPair); - will(returnValue(alice)); }}); + expectDeriveRendezvousKey(); + rendezvousPoller.startService(); context.assertIsSatisfied(); @@ -157,45 +147,27 @@ public class RendezvousPollerImplTest extends BrambleMockTestCase { @Test public void testCreatesAndClosesEndpointsWhenPendingContactIsAddedAndRemoved() throws Exception { - Transaction txn = new Transaction(null, true); long now = pendingContact.getTimestamp(); // Enable the transport - no endpoints should be created yet - context.checking(new Expectations() {{ - oneOf(pluginManager).getPlugin(transportId); - will(returnValue(plugin)); - oneOf(plugin).supportsRendezvous(); - will(returnValue(true)); - allowing(plugin).getId(); - will(returnValue(transportId)); - }}); + expectGetPlugin(); rendezvousPoller.eventOccurred(new TransportEnabledEvent(transportId)); context.assertIsSatisfied(); // Add the pending contact - endpoint should be created and polled - context.checking(new DbExpectations() {{ + context.checking(new Expectations() {{ // Add pending contact oneOf(clock).currentTimeMillis(); will(returnValue(now)); oneOf(scheduler).schedule(with(any(Runnable.class)), with(RENDEZVOUS_TIMEOUT_MS), with(MILLISECONDS)); - oneOf(db).transactionWithResult(with(true), withDbCallable(txn)); - will(returnValue(handshakeKeyPair)); - oneOf(transportCrypto).deriveStaticMasterKey( - pendingContact.getPublicKey(), handshakeKeyPair); - will(returnValue(staticMasterKey)); - oneOf(rendezvousCrypto).deriveRendezvousKey(staticMasterKey); - will(returnValue(rendezvousKey)); - oneOf(transportCrypto).isAlice(pendingContact.getPublicKey(), - handshakeKeyPair); - will(returnValue(alice)); - oneOf(rendezvousCrypto).createKeyMaterialSource(rendezvousKey, - transportId); - will(returnValue(keyMaterialSource)); - oneOf(plugin).createRendezvousEndpoint(with(keyMaterialSource), - with(alice), with(any(ConnectionHandler.class))); - will(returnValue(rendezvousEndpoint)); + }}); + + expectDeriveRendezvousKey(); + expectCreateEndpoint(); + + context.checking(new Expectations() {{ // Poll newly added pending contact oneOf(rendezvousEndpoint).getRemoteTransportProperties(); will(returnValue(transportProperties)); @@ -224,25 +196,17 @@ public class RendezvousPollerImplTest extends BrambleMockTestCase { @Test public void testCreatesAndClosesEndpointsWhenPendingContactIsAddedAndExpired() throws Exception { - Transaction txn = new Transaction(null, true); long now = pendingContact.getTimestamp(); AtomicReference captureExpiryTask = new AtomicReference<>(); // Enable the transport - no endpoints should be created yet - context.checking(new Expectations() {{ - oneOf(pluginManager).getPlugin(transportId); - will(returnValue(plugin)); - oneOf(plugin).supportsRendezvous(); - will(returnValue(true)); - allowing(plugin).getId(); - will(returnValue(transportId)); - }}); + expectGetPlugin(); rendezvousPoller.eventOccurred(new TransportEnabledEvent(transportId)); context.assertIsSatisfied(); // Add the pending contact - endpoint should be created and polled - context.checking(new DbExpectations() {{ + context.checking(new Expectations() {{ // Add pending contact oneOf(clock).currentTimeMillis(); will(returnValue(now)); @@ -251,22 +215,12 @@ public class RendezvousPollerImplTest extends BrambleMockTestCase { with(RENDEZVOUS_TIMEOUT_MS), with(MILLISECONDS)); will(new CaptureArgumentAction<>(captureExpiryTask, Runnable.class, 0)); - oneOf(db).transactionWithResult(with(true), withDbCallable(txn)); - will(returnValue(handshakeKeyPair)); - oneOf(transportCrypto).deriveStaticMasterKey( - pendingContact.getPublicKey(), handshakeKeyPair); - will(returnValue(staticMasterKey)); - oneOf(rendezvousCrypto).deriveRendezvousKey(staticMasterKey); - will(returnValue(rendezvousKey)); - oneOf(transportCrypto).isAlice(pendingContact.getPublicKey(), - handshakeKeyPair); - will(returnValue(alice)); - oneOf(rendezvousCrypto).createKeyMaterialSource(rendezvousKey, - transportId); - will(returnValue(keyMaterialSource)); - oneOf(plugin).createRendezvousEndpoint(with(keyMaterialSource), - with(alice), with(any(ConnectionHandler.class))); - will(returnValue(rendezvousEndpoint)); + }}); + + expectDeriveRendezvousKey(); + expectCreateEndpoint(); + + context.checking(new Expectations() {{ // Poll newly added pending contact oneOf(rendezvousEndpoint).getRemoteTransportProperties(); will(returnValue(transportProperties)); @@ -300,7 +254,6 @@ public class RendezvousPollerImplTest extends BrambleMockTestCase { @Test public void testCreatesAndClosesEndpointsWhenTransportIsEnabledAndDisabled() throws Exception { - Transaction txn = new Transaction(null, true); long now = pendingContact.getTimestamp(); // Add the pending contact - no endpoints should be created yet @@ -309,6 +262,38 @@ public class RendezvousPollerImplTest extends BrambleMockTestCase { will(returnValue(now)); oneOf(scheduler).schedule(with(any(Runnable.class)), with(RENDEZVOUS_TIMEOUT_MS), with(MILLISECONDS)); + }}); + + expectDeriveRendezvousKey(); + + rendezvousPoller.eventOccurred( + new PendingContactAddedEvent(pendingContact)); + context.assertIsSatisfied(); + + // Enable the transport - endpoint should be created + expectGetPlugin(); + expectCreateEndpoint(); + + rendezvousPoller.eventOccurred(new TransportEnabledEvent(transportId)); + context.assertIsSatisfied(); + + // Disable the transport - endpoint should be closed + context.checking(new Expectations() {{ + oneOf(rendezvousEndpoint).close(); + }}); + + rendezvousPoller.eventOccurred(new TransportDisabledEvent(transportId)); + context.assertIsSatisfied(); + + // Remove the pending contact - endpoint is already closed + rendezvousPoller.eventOccurred( + new PendingContactRemovedEvent(pendingContact.getId())); + } + + private void expectDeriveRendezvousKey() throws Exception { + Transaction txn = new Transaction(null, true); + + context.checking(new DbExpectations() {{ oneOf(db).transactionWithResult(with(true), withDbCallable(txn)); will(returnValue(handshakeKeyPair)); oneOf(transportCrypto).deriveStaticMasterKey( @@ -320,19 +305,10 @@ public class RendezvousPollerImplTest extends BrambleMockTestCase { handshakeKeyPair); will(returnValue(alice)); }}); + } - rendezvousPoller.eventOccurred( - new PendingContactAddedEvent(pendingContact)); - context.assertIsSatisfied(); - - // Enable the transport - endpoint should be created + private void expectCreateEndpoint() { context.checking(new Expectations() {{ - oneOf(pluginManager).getPlugin(transportId); - will(returnValue(plugin)); - oneOf(plugin).supportsRendezvous(); - will(returnValue(true)); - allowing(plugin).getId(); - will(returnValue(transportId)); oneOf(rendezvousCrypto).createKeyMaterialSource(rendezvousKey, transportId); will(returnValue(keyMaterialSource)); @@ -340,20 +316,16 @@ public class RendezvousPollerImplTest extends BrambleMockTestCase { with(alice), with(any(ConnectionHandler.class))); will(returnValue(rendezvousEndpoint)); }}); + } - rendezvousPoller.eventOccurred(new TransportEnabledEvent(transportId)); - context.assertIsSatisfied(); - - // Disable the transport - endpoint should be closed + private void expectGetPlugin() { context.checking(new Expectations() {{ - oneOf(rendezvousEndpoint).close(); + oneOf(pluginManager).getPlugin(transportId); + will(returnValue(plugin)); + oneOf(plugin).supportsRendezvous(); + will(returnValue(true)); + allowing(plugin).getId(); + will(returnValue(transportId)); }}); - - rendezvousPoller.eventOccurred(new TransportDisabledEvent(transportId)); - context.assertIsSatisfied(); - - // Remove the pending contact - endpoint is already closed - rendezvousPoller.eventOccurred( - new PendingContactRemovedEvent(pendingContact.getId())); } }