From 6b011d2a7d29553c3e0d3c1295c67a668731062e Mon Sep 17 00:00:00 2001 From: akwizgran <michael@briarproject.org> Date: Sat, 28 Apr 2018 22:15:59 +0100 Subject: [PATCH] Update transport keys in-place to retain key set IDs. --- .../org/briarproject/bramble/db/Database.java | 4 +- .../bramble/db/DatabaseComponentImpl.java | 4 +- .../briarproject/bramble/db/JdbcDatabase.java | 85 +++++++++++-- .../bramble/db/DatabaseComponentImplTest.java | 6 +- .../bramble/db/JdbcDatabaseTest.java | 115 +++++++++++++----- 5 files changed, 166 insertions(+), 48 deletions(-) diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java b/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java index dcf96ea4e1..0c3b2b1bcc 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/db/Database.java @@ -654,7 +654,7 @@ interface Database<T> { throws DbException; /** - * Stores the given transport keys, deleting any keys they have replaced. + * Updates the given transport keys following key rotation. */ - void updateTransportKeys(T txn, Collection<KeySet> keys) throws DbException; + void updateTransportKeys(T txn, KeySet ks) throws DbException; } diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java b/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java index 00f0bb6229..aaab4b9dd8 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/db/DatabaseComponentImpl.java @@ -903,11 +903,9 @@ class DatabaseComponentImpl<T> implements DatabaseComponent { Collection<KeySet> keys) throws DbException { if (transaction.isReadOnly()) throw new IllegalArgumentException(); T txn = unbox(transaction); - Collection<KeySet> filtered = new ArrayList<>(); for (KeySet ks : keys) { TransportId t = ks.getTransportKeys().getTransportId(); - if (db.containsTransport(txn, t)) filtered.add(ks); + if (db.containsTransport(txn, t)) db.updateTransportKeys(txn, ks); } - db.updateTransportKeys(txn, filtered); } } diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java b/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java index f63bcb0a6b..dd0e928a85 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java @@ -74,7 +74,12 @@ import static org.briarproject.bramble.db.ExponentialBackoff.calculateExpiry; abstract class JdbcDatabase implements Database<Connection> { // Package access for testing - static final int CODE_SCHEMA_VERSION = 36; + static final int CODE_SCHEMA_VERSION = 37; + + // Rotation period offsets for incoming transport keys + private static final int OFFSET_PREV = -1; + private static final int OFFSET_CURR = 0; + private static final int OFFSET_NEXT = 1; private static final String CREATE_SETTINGS = "CREATE TABLE settings" @@ -254,7 +259,8 @@ abstract class JdbcDatabase implements Database<Connection> { + " headerKey _SECRET NOT NULL," + " base BIGINT NOT NULL," + " bitmap _BINARY NOT NULL," - + " PRIMARY KEY (transportId, keySetId, rotationPeriod)," + + " periodOffset INT NOT NULL," + + " PRIMARY KEY (transportId, keySetId, periodOffset)," + " FOREIGN KEY (transportId)" + " REFERENCES transports (transportId)" + " ON DELETE CASCADE," @@ -908,8 +914,9 @@ abstract class JdbcDatabase implements Database<Connection> { ps.close(); // Store the incoming keys sql = "INSERT INTO incomingKeys (keySetId, contactId, transportId," - + " rotationPeriod, tagKey, headerKey, base, bitmap)" - + " VALUES (?, ?, ?, ?, ?, ?, ?, ?)"; + + " rotationPeriod, tagKey, headerKey, base, bitmap," + + " periodOffset)" + + " VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; ps = txn.prepareStatement(sql); ps.setInt(1, keySetId.getInt()); if (c == null) ps.setNull(2, INTEGER); @@ -922,6 +929,7 @@ abstract class JdbcDatabase implements Database<Connection> { ps.setBytes(6, inPrev.getHeaderKey().getBytes()); ps.setLong(7, inPrev.getWindowBase()); ps.setBytes(8, inPrev.getWindowBitmap()); + ps.setInt(9, OFFSET_PREV); ps.addBatch(); // Current rotation period IncomingKeys inCurr = k.getCurrentIncomingKeys(); @@ -930,6 +938,7 @@ abstract class JdbcDatabase implements Database<Connection> { ps.setBytes(6, inCurr.getHeaderKey().getBytes()); ps.setLong(7, inCurr.getWindowBase()); ps.setBytes(8, inCurr.getWindowBitmap()); + ps.setInt(9, OFFSET_CURR); ps.addBatch(); // Next rotation period IncomingKeys inNext = k.getNextIncomingKeys(); @@ -938,6 +947,7 @@ abstract class JdbcDatabase implements Database<Connection> { ps.setBytes(6, inNext.getHeaderKey().getBytes()); ps.setLong(7, inNext.getWindowBase()); ps.setBytes(8, inNext.getWindowBitmap()); + ps.setInt(9, OFFSET_NEXT); ps.addBatch(); int[] batchAffected = ps.executeBatch(); if (batchAffected.length != 3) throw new DbStateException(); @@ -2141,7 +2151,7 @@ abstract class JdbcDatabase implements Database<Connection> { + " base, bitmap" + " FROM incomingKeys" + " WHERE transportId = ?" - + " ORDER BY keySetId, rotationPeriod"; + + " ORDER BY keySetId, periodOffset"; ps = txn.prepareStatement(sql); ps.setString(1, t.getString()); rs = ps.executeQuery(); @@ -2944,12 +2954,69 @@ abstract class JdbcDatabase implements Database<Connection> { } @Override - public void updateTransportKeys(Connection txn, Collection<KeySet> keys) + public void updateTransportKeys(Connection txn, KeySet ks) throws DbException { - for (KeySet ks : keys) { + PreparedStatement ps = null; + try { + // Update the outgoing keys + String sql = "UPDATE outgoingKeys SET rotationPeriod = ?," + + " tagKey = ?, headerKey = ?, stream = ?" + + " WHERE transportId = ? AND keySetId = ?"; + ps = txn.prepareStatement(sql); TransportKeys k = ks.getTransportKeys(); - removeTransportKeys(txn, k.getTransportId(), ks.getKeySetId()); - addTransportKeys(txn, ks.getContactId(), k); + OutgoingKeys outCurr = k.getCurrentOutgoingKeys(); + ps.setLong(1, outCurr.getRotationPeriod()); + ps.setBytes(2, outCurr.getTagKey().getBytes()); + ps.setBytes(3, outCurr.getHeaderKey().getBytes()); + ps.setLong(4, outCurr.getStreamCounter()); + ps.setString(5, k.getTransportId().getString()); + ps.setInt(6, ks.getKeySetId().getInt()); + int affected = ps.executeUpdate(); + if (affected < 0 || affected > 1) throw new DbStateException(); + ps.close(); + // Update the incoming keys + sql = "UPDATE incomingKeys SET rotationPeriod = ?," + + " tagKey = ?, headerKey = ?, base = ?, bitmap = ?" + + " WHERE transportId = ? AND keySetId = ?" + + " AND periodOffset = ?"; + ps = txn.prepareStatement(sql); + ps.setString(6, k.getTransportId().getString()); + ps.setInt(7, ks.getKeySetId().getInt()); + // Previous rotation period + IncomingKeys inPrev = k.getPreviousIncomingKeys(); + ps.setLong(1, inPrev.getRotationPeriod()); + ps.setBytes(2, inPrev.getTagKey().getBytes()); + ps.setBytes(3, inPrev.getHeaderKey().getBytes()); + ps.setLong(4, inPrev.getWindowBase()); + ps.setBytes(5, inPrev.getWindowBitmap()); + ps.setInt(8, OFFSET_PREV); + ps.addBatch(); + // Current rotation period + IncomingKeys inCurr = k.getCurrentIncomingKeys(); + ps.setLong(1, inCurr.getRotationPeriod()); + ps.setBytes(2, inCurr.getTagKey().getBytes()); + ps.setBytes(3, inCurr.getHeaderKey().getBytes()); + ps.setLong(4, inCurr.getWindowBase()); + ps.setBytes(5, inCurr.getWindowBitmap()); + ps.setInt(8, OFFSET_CURR); + ps.addBatch(); + // Next rotation period + IncomingKeys inNext = k.getNextIncomingKeys(); + ps.setLong(1, inNext.getRotationPeriod()); + ps.setBytes(2, inNext.getTagKey().getBytes()); + ps.setBytes(3, inNext.getHeaderKey().getBytes()); + ps.setLong(4, inNext.getWindowBase()); + ps.setBytes(5, inNext.getWindowBitmap()); + ps.setInt(8, OFFSET_NEXT); + ps.addBatch(); + int[] batchAffected = ps.executeBatch(); + if (batchAffected.length != 3) throw new DbStateException(); + for (int rows : batchAffected) + if (rows < 0 || rows > 1) throw new DbStateException(); + ps.close(); + } catch (SQLException e) { + tryToClose(ps); + throw new DbException(e); } } } diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java index 35bda79dd4..ae3e203405 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/db/DatabaseComponentImplTest.java @@ -1315,8 +1315,8 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { @Test public void testTransportKeys() throws Exception { TransportKeys transportKeys = createTransportKeys(); - Collection<KeySet> keys = - singletonList(new KeySet(keySetId, contactId, transportKeys)); + KeySet ks = new KeySet(keySetId, contactId, transportKeys); + Collection<KeySet> keys = singletonList(ks); context.checking(new Expectations() {{ // startTransaction() oneOf(database).startTransaction(); @@ -1324,7 +1324,7 @@ public class DatabaseComponentImplTest extends BrambleMockTestCase { // updateTransportKeys() oneOf(database).containsTransport(txn, transportId); will(returnValue(true)); - oneOf(database).updateTransportKeys(txn, keys); + oneOf(database).updateTransportKeys(txn, ks); // getTransportKeys() oneOf(database).containsTransport(txn, transportId); will(returnValue(true)); diff --git a/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java b/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java index 7b4618fbd9..97268d16b3 100644 --- a/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java +++ b/bramble-core/src/test/java/org/briarproject/bramble/db/JdbcDatabaseTest.java @@ -667,8 +667,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { @Test public void testTransportKeys() throws Exception { - TransportKeys keys = createTransportKeys(); - TransportKeys keys1 = createTransportKeys(); + long rotationPeriod = 123, rotationPeriod1 = 234; + TransportKeys keys = createTransportKeys(rotationPeriod); + TransportKeys keys1 = createTransportKeys(rotationPeriod1); Database<Connection> db = open(false); Connection txn = db.startTransaction(); @@ -697,6 +698,25 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { } } + // Rotate the transport keys + TransportKeys rotated = createTransportKeys(rotationPeriod + 1); + TransportKeys rotated1 = createTransportKeys(rotationPeriod1 + 1); + db.updateTransportKeys(txn, new KeySet(keySetId, contactId, rotated)); + db.updateTransportKeys(txn, new KeySet(keySetId1, contactId, rotated1)); + + // Retrieve the transport keys again + allKeys = db.getTransportKeys(txn, transportId); + assertEquals(2, allKeys.size()); + for (KeySet ks : allKeys) { + assertEquals(contactId, ks.getContactId()); + if (ks.getKeySetId().equals(keySetId)) { + assertKeysEquals(rotated, ks.getTransportKeys()); + } else { + assertEquals(keySetId1, ks.getKeySetId()); + assertKeysEquals(rotated1, ks.getTransportKeys()); + } + } + // Removing the contact should remove the transport keys db.removeContact(txn, contactId); assertEquals(emptyList(), db.getTransportKeys(txn, transportId)); @@ -707,8 +727,9 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { @Test public void testUnboundTransportKeys() throws Exception { - TransportKeys keys = createTransportKeys(); - TransportKeys keys1 = createTransportKeys(); + long rotationPeriod = 123, rotationPeriod1 = 234; + TransportKeys keys = createTransportKeys(rotationPeriod); + TransportKeys keys1 = createTransportKeys(rotationPeriod1); Database<Connection> db = open(false); Connection txn = db.startTransaction(); @@ -754,6 +775,26 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { } } + // Rotate the transport keys + TransportKeys rotated = createTransportKeys(rotationPeriod + 1); + TransportKeys rotated1 = createTransportKeys(rotationPeriod1 + 1); + db.updateTransportKeys(txn, new KeySet(keySetId, contactId, rotated)); + db.updateTransportKeys(txn, new KeySet(keySetId1, null, rotated1)); + + // Retrieve the transport keys again + allKeys = db.getTransportKeys(txn, transportId); + assertEquals(2, allKeys.size()); + for (KeySet ks : allKeys) { + if (ks.getKeySetId().equals(keySetId)) { + assertEquals(contactId, ks.getContactId()); + assertKeysEquals(rotated, ks.getTransportKeys()); + } else { + assertEquals(keySetId1, ks.getKeySetId()); + assertNull(ks.getContactId()); + assertKeysEquals(rotated1, ks.getTransportKeys()); + } + } + // Remove the unbound transport keys db.removeTransportKeys(txn, transportId, keySetId1); @@ -763,7 +804,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { KeySet ks = allKeys.iterator().next(); assertEquals(keySetId, ks.getKeySetId()); assertEquals(contactId, ks.getContactId()); - assertKeysEquals(keys, ks.getTransportKeys()); + assertKeysEquals(rotated, ks.getTransportKeys()); // Removing the transport should remove the remaining transport keys db.removeTransport(txn, transportId); @@ -809,8 +850,8 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { @Test public void testIncrementStreamCounter() throws Exception { - TransportKeys keys = createTransportKeys(); - long rotationPeriod = keys.getCurrentOutgoingKeys().getRotationPeriod(); + long rotationPeriod = 123; + TransportKeys keys = createTransportKeys(rotationPeriod); long streamCounter = keys.getCurrentOutgoingKeys().getStreamCounter(); Database<Connection> db = open(false); @@ -821,8 +862,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), true, true)); db.addTransport(txn, transportId, 123); - db.updateTransportKeys(txn, - singletonList(new KeySet(keySetId, contactId, keys))); + assertEquals(keySetId, db.addTransportKeys(txn, contactId, keys)); // Increment the stream counter twice and retrieve the transport keys db.incrementStreamCounter(txn, transportId, keySetId); @@ -838,14 +878,21 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { assertEquals(rotationPeriod, outCurr.getRotationPeriod()); assertEquals(streamCounter + 2, outCurr.getStreamCounter()); + // The rest of the keys should be unaffected + assertKeysEquals(keys.getPreviousIncomingKeys(), + k.getPreviousIncomingKeys()); + assertKeysEquals(keys.getCurrentIncomingKeys(), + k.getCurrentIncomingKeys()); + assertKeysEquals(keys.getNextIncomingKeys(), k.getNextIncomingKeys()); + db.commitTransaction(txn); db.close(); } @Test public void testSetReorderingWindow() throws Exception { - TransportKeys keys = createTransportKeys(); - long rotationPeriod = keys.getCurrentIncomingKeys().getRotationPeriod(); + long rotationPeriod = 123; + TransportKeys keys = createTransportKeys(rotationPeriod); long base = keys.getCurrentIncomingKeys().getWindowBase(); byte[] bitmap = keys.getCurrentIncomingKeys().getWindowBitmap(); @@ -857,8 +904,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { assertEquals(contactId, db.addContact(txn, author, localAuthor.getId(), true, true)); db.addTransport(txn, transportId, 123); - db.updateTransportKeys(txn, - singletonList(new KeySet(keySetId, contactId, keys))); + assertEquals(keySetId, db.addTransportKeys(txn, contactId, keys)); // Update the reordering window and retrieve the transport keys new Random().nextBytes(bitmap); @@ -876,6 +922,13 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { assertEquals(base + 1, inCurr.getWindowBase()); assertArrayEquals(bitmap, inCurr.getWindowBitmap()); + // The rest of the keys should be unaffected + assertKeysEquals(keys.getPreviousIncomingKeys(), + k.getPreviousIncomingKeys()); + assertKeysEquals(keys.getNextIncomingKeys(), k.getNextIncomingKeys()); + assertKeysEquals(keys.getCurrentOutgoingKeys(), + k.getCurrentOutgoingKeys()); + db.commitTransaction(txn); db.close(); } @@ -973,8 +1026,8 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // Attach some metadata to the group Metadata metadata = new Metadata(); - metadata.put("foo", new byte[]{'b', 'a', 'r'}); - metadata.put("baz", new byte[]{'b', 'a', 'm'}); + metadata.put("foo", new byte[] {'b', 'a', 'r'}); + metadata.put("baz", new byte[] {'b', 'a', 'm'}); db.mergeGroupMetadata(txn, groupId, metadata); // Retrieve the metadata for the group @@ -1012,8 +1065,8 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // Attach some metadata to the message Metadata metadata = new Metadata(); - metadata.put("foo", new byte[]{'b', 'a', 'r'}); - metadata.put("baz", new byte[]{'b', 'a', 'm'}); + metadata.put("foo", new byte[] {'b', 'a', 'r'}); + metadata.put("baz", new byte[] {'b', 'a', 'm'}); db.mergeMessageMetadata(txn, messageId, metadata); // Retrieve the metadata for the message @@ -1083,8 +1136,8 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // Attach some metadata to the message Metadata metadata = new Metadata(); - metadata.put("foo", new byte[]{'b', 'a', 'r'}); - metadata.put("baz", new byte[]{'b', 'a', 'm'}); + metadata.put("foo", new byte[] {'b', 'a', 'r'}); + metadata.put("baz", new byte[] {'b', 'a', 'm'}); db.mergeMessageMetadata(txn, messageId, metadata); // Retrieve the metadata for the message @@ -1145,11 +1198,11 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // Attach some metadata to the messages Metadata metadata = new Metadata(); - metadata.put("foo", new byte[]{'b', 'a', 'r'}); - metadata.put("baz", new byte[]{'b', 'a', 'm'}); + metadata.put("foo", new byte[] {'b', 'a', 'r'}); + metadata.put("baz", new byte[] {'b', 'a', 'm'}); db.mergeMessageMetadata(txn, messageId, metadata); Metadata metadata1 = new Metadata(); - metadata1.put("foo", new byte[]{'q', 'u', 'x'}); + metadata1.put("foo", new byte[] {'q', 'u', 'x'}); db.mergeMessageMetadata(txn, messageId1, metadata1); // Retrieve all the metadata for the group @@ -1249,11 +1302,11 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { // Attach some metadata to the messages Metadata metadata = new Metadata(); - metadata.put("foo", new byte[]{'b', 'a', 'r'}); - metadata.put("baz", new byte[]{'b', 'a', 'm'}); + metadata.put("foo", new byte[] {'b', 'a', 'r'}); + metadata.put("baz", new byte[] {'b', 'a', 'm'}); db.mergeMessageMetadata(txn, messageId, metadata); Metadata metadata1 = new Metadata(); - metadata1.put("foo", new byte[]{'b', 'a', 'r'}); + metadata1.put("foo", new byte[] {'b', 'a', 'r'}); db.mergeMessageMetadata(txn, messageId1, metadata1); for (int i = 0; i < 2; i++) { @@ -1264,7 +1317,7 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { } else { // Query for foo query = new Metadata(); - query.put("foo", new byte[]{'b', 'a', 'r'}); + query.put("foo", new byte[] {'b', 'a', 'r'}); } db.setMessageState(txn, messageId, DELIVERED); @@ -1804,23 +1857,23 @@ public abstract class JdbcDatabaseTest extends BrambleTestCase { return db; } - private TransportKeys createTransportKeys() { + private TransportKeys createTransportKeys(long rotationPeriod) { SecretKey inPrevTagKey = getSecretKey(); SecretKey inPrevHeaderKey = getSecretKey(); IncomingKeys inPrev = new IncomingKeys(inPrevTagKey, inPrevHeaderKey, - 1, 123, new byte[4]); + rotationPeriod - 1, 123, new byte[4]); SecretKey inCurrTagKey = getSecretKey(); SecretKey inCurrHeaderKey = getSecretKey(); IncomingKeys inCurr = new IncomingKeys(inCurrTagKey, inCurrHeaderKey, - 2, 234, new byte[4]); + rotationPeriod, 234, new byte[4]); SecretKey inNextTagKey = getSecretKey(); SecretKey inNextHeaderKey = getSecretKey(); IncomingKeys inNext = new IncomingKeys(inNextTagKey, inNextHeaderKey, - 3, 345, new byte[4]); + rotationPeriod + 1, 345, new byte[4]); SecretKey outCurrTagKey = getSecretKey(); SecretKey outCurrHeaderKey = getSecretKey(); OutgoingKeys outCurr = new OutgoingKeys(outCurrTagKey, outCurrHeaderKey, - 2, 456, true); + rotationPeriod, 456, true); return new TransportKeys(transportId, inPrev, inCurr, inNext, outCurr); } -- GitLab