diff --git a/components/net/sf/briar/db/JdbcDatabase.java b/components/net/sf/briar/db/JdbcDatabase.java index 562d90785ce068e983b02258bb2d252f97fd11da..c2ec4eeb0106e520bb07187911b01b976035de2c 100644 --- a/components/net/sf/briar/db/JdbcDatabase.java +++ b/components/net/sf/briar/db/JdbcDatabase.java @@ -42,6 +42,7 @@ import net.sf.briar.api.transport.ConnectionContext; import net.sf.briar.api.transport.ConnectionContextFactory; import net.sf.briar.api.transport.ConnectionWindow; import net.sf.briar.api.transport.ConnectionWindowFactory; +import net.sf.briar.util.ByteUtils; import net.sf.briar.util.FileUtils; /** @@ -58,6 +59,12 @@ abstract class JdbcDatabase implements Database<Connection> { + " start BIGINT NOT NULL," + " PRIMARY KEY (groupId))"; + private static final String CREATE_SUBSCRIPTION_IDS = + "CREATE TABLE subscriptionIds" + + " (groupId HASH," // Null for the head of the list + + " nextId HASH," // Null for the tail of the list + + " deleted BIGINT NOT NULL)"; + private static final String CREATE_CONTACTS = "CREATE TABLE contacts" + " (contactId COUNTER," @@ -326,6 +333,7 @@ abstract class JdbcDatabase implements Database<Connection> { try { s = txn.createStatement(); s.executeUpdate(insertTypeNames(CREATE_SUBSCRIPTIONS)); + s.executeUpdate(insertTypeNames(CREATE_SUBSCRIPTION_IDS)); s.executeUpdate(insertTypeNames(CREATE_CONTACTS)); s.executeUpdate(insertTypeNames(CREATE_MESSAGES)); s.executeUpdate(INDEX_MESSAGES_BY_PARENT); @@ -721,7 +729,9 @@ abstract class JdbcDatabase implements Database<Connection> { public void addSubscription(Connection txn, Group g) throws DbException { PreparedStatement ps = null; + ResultSet rs = null; try { + // Add the group to the subscriptions table String sql = "INSERT INTO subscriptions" + " (groupId, groupName, groupKey, start)" + " VALUES (?, ?, ?, ?)"; @@ -733,7 +743,78 @@ abstract class JdbcDatabase implements Database<Connection> { int affected = ps.executeUpdate(); if(affected != 1) throw new DbStateException(); ps.close(); + // Insert the group ID into the linked list + byte[] id = g.getId().getBytes(); + sql = "SELECT groupId, nextId, deleted FROM subscriptionIds" + + " ORDER BY groupId"; + ps = txn.prepareStatement(sql); + rs = ps.executeQuery(); + if(rs.next()) { + // The head pointer of the list exists + byte[] groupId = rs.getBytes(1); + if(groupId != null) throw new DbStateException(); + byte[] nextId = rs.getBytes(2); + long deleted = rs.getLong(3); + // Scan through the list to find the insertion point + while(nextId != null && ByteUtils.compare(id, nextId) > 0) { + if(!rs.next()) throw new DbStateException(); + groupId = rs.getBytes(1); + if(groupId == null) throw new DbStateException(); + nextId = rs.getBytes(2); + deleted = rs.getLong(3); + } + rs.close(); + ps.close(); + // Update the previous element + if(groupId == null) { + // Inserting at the head of the list + sql = "UPDATE subscriptionIds SET nextId = ?" + + " WHERE groupId IS NULL"; + ps = txn.prepareStatement(sql); + ps.setBytes(1, id); + } else { + // Inserting in the middle or at the tail of the list + sql = "UPDATE subscriptionIds SET nextId = ?" + + " WHERE groupId = ?"; + ps = txn.prepareStatement(sql); + ps.setBytes(1, id); + ps.setBytes(2, groupId); + } + affected = ps.executeUpdate(); + if(affected != 1) throw new DbStateException(); + ps.close(); + // Insert the new element + sql = "INSERT INTO subscriptionIds (groupId, nextId, deleted)" + + " VALUES (?, ?, ?)"; + ps = txn.prepareStatement(sql); + ps.setBytes(1, id); + if(nextId == null) ps.setNull(2, Types.BINARY); // At the tail + else ps.setBytes(2, nextId); // In the middle + ps.setLong(3, deleted); + affected = ps.executeUpdate(); + if(affected != 1) throw new DbStateException(); + ps.close(); + } else { + // The head pointer of the list does not exist + rs.close(); + ps.close(); + sql = "INSERT INTO subscriptionIds (nextId, deleted)" + + " VALUES (?, ZERO())"; + ps = txn.prepareStatement(sql); + ps.setBytes(1, id); + affected = ps.executeUpdate(); + if(affected != 1) throw new DbStateException(); + ps.close(); + sql = "INSERT INTO subscriptionIds (groupId, deleted)" + + " VALUES (?, ZERO())"; + ps = txn.prepareStatement(sql); + ps.setBytes(1, id); + affected = ps.executeUpdate(); + if(affected != 1) throw new DbStateException(); + ps.close(); + } } catch(SQLException e) { + tryToClose(rs); tryToClose(ps); throw new DbException(e); } @@ -2096,14 +2177,43 @@ abstract class JdbcDatabase implements Database<Connection> { public void removeSubscription(Connection txn, GroupId g) throws DbException { PreparedStatement ps = null; + ResultSet rs = null; try { + // Remove the group from the subscriptions table String sql = "DELETE FROM subscriptions WHERE groupId = ?"; ps = txn.prepareStatement(sql); ps.setBytes(1, g.getBytes()); int affected = ps.executeUpdate(); if(affected != 1) throw new DbStateException(); ps.close(); + // Remove the group ID from the linked list + sql = "SELECT nextId FROM subscriptionIds WHERE groupId = ?"; + ps = txn.prepareStatement(sql); + ps.setBytes(1, g.getBytes()); + rs = ps.executeQuery(); + if(!rs.next()) throw new DbStateException(); + byte[] nextId = rs.getBytes(1); + if(rs.next()) throw new DbStateException(); + rs.close(); + ps.close(); + sql = "DELETE FROM subscriptionIds WHERE groupId = ?"; + ps = txn.prepareStatement(sql); + ps.setBytes(1, g.getBytes()); + affected = ps.executeUpdate(); + if(affected != 1) throw new DbStateException(); + ps.close(); + sql = "UPDATE subscriptionIds SET nextId = ?, deleted = ?" + + " WHERE nextId = ?"; + ps = txn.prepareStatement(sql); + if(nextId == null) ps.setNull(1, Types.BINARY); // At the tail + else ps.setBytes(1, nextId); // At the head or in the middle + ps.setLong(2, System.currentTimeMillis()); + ps.setBytes(3, g.getBytes()); + affected = ps.executeUpdate(); + if(affected != 1) throw new DbStateException(); + ps.close(); } catch(SQLException e) { + tryToClose(rs); tryToClose(ps); throw new DbException(e); } diff --git a/test/net/sf/briar/db/H2DatabaseTest.java b/test/net/sf/briar/db/H2DatabaseTest.java index cbc223c06f1192fe437461640e9a156bd710408e..6cdad47d48ccf04dd0c7dcae33e26680a9f0ef36 100644 --- a/test/net/sf/briar/db/H2DatabaseTest.java +++ b/test/net/sf/briar/db/H2DatabaseTest.java @@ -10,6 +10,7 @@ import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.Iterator; +import java.util.List; import java.util.Map; import java.util.Random; import java.util.concurrent.CountDownLatch; @@ -1882,6 +1883,29 @@ public class H2DatabaseTest extends BriarTestCase { db.close(); } + @Test + public void testMultipleSubscriptionsAndUnsubscriptions() throws Exception { + // Create some groups + List<Group> groups = new ArrayList<Group>(); + for(int i = 0; i < 100; i++) { + GroupId id = new GroupId(TestUtils.getRandomId()); + groups.add(groupFactory.createGroup(id, "Group name", null)); + } + + Database<Connection> db = open(false); + Connection txn = db.startTransaction(); + + // Add the groups to the database + for(Group g : groups) db.addSubscription(txn, g); + + // Remove the groups in a different order + Collections.shuffle(groups); + for(Group g : groups) db.removeSubscription(txn, g.getId()); + + db.commitTransaction(txn); + db.close(); + } + @Test public void testExceptionHandling() throws Exception { Database<Connection> db = open(false); diff --git a/util/net/sf/briar/util/ByteUtils.java b/util/net/sf/briar/util/ByteUtils.java index e19c7e3ca39ce68393b7e6a6bb24e9597ab3397a..d858b8f9b07f28d3eb6e2aefe96459bbc2cf8bdf 100644 --- a/util/net/sf/briar/util/ByteUtils.java +++ b/util/net/sf/briar/util/ByteUtils.java @@ -62,4 +62,22 @@ public class ByteUtils { assert result < 1 << bits; return result; } + + /** + * Compares two byte arrays and returns -1, 0, or +1 if the first array is + * less than, equal to, or greater than the second array, respectively. + * <p> + * If one of the arrays is a prefix of the other, the longer array is + * considered to be greater. Bytes are treated as unsigned. + */ + public static int compare(byte[] b1, byte[] b2) { + for(int i = 0; i < b1.length || i < b2.length; i++) { + if(i == b1.length) return -1; + if(i == b2.length) return 1; + int b1i = b1[i] & 0xff, b2i = b2[i] & 0xff; + if(b1i < b2i) return -1; + if(b1i > b2i) return 1; + } + return 0; + } }