diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java b/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java index 7e3e3307bc5b9395e0d0af09949097bb46018cab..e2b9ce2aa1fde11702ca36082935aac48e1428f9 100644 --- a/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java +++ b/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java @@ -68,8 +68,8 @@ import static org.briarproject.bramble.db.ExponentialBackoff.calculateExpiry; @NotNullByDefault abstract class JdbcDatabase implements Database<Connection> { - private static final int SCHEMA_VERSION = 31; - private static final int MIN_SCHEMA_VERSION = 31; + private static final int SCHEMA_VERSION = 32; + private static final int MIN_SCHEMA_VERSION = 32; private static final String CREATE_SETTINGS = "CREATE TABLE settings" @@ -148,11 +148,16 @@ abstract class JdbcDatabase implements Database<Connection> { private static final String CREATE_MESSAGE_METADATA = "CREATE TABLE messageMetadata" + " (messageId _HASH NOT NULL," + + " groupId _HASH NOT NULL," // Denormalised + + " state INT NOT NULL," // Denormalised + " metaKey _STRING NOT NULL," + " value _BINARY NOT NULL," + " PRIMARY KEY (messageId, metaKey)," + " FOREIGN KEY (messageId)" + " REFERENCES messages (messageId)" + + " ON DELETE CASCADE," + + " FOREIGN KEY (groupId)" + + " REFERENCES groups (groupId)" + " ON DELETE CASCADE)"; private static final String CREATE_MESSAGE_DEPENDENCIES = @@ -252,6 +257,10 @@ abstract class JdbcDatabase implements Database<Connection> { "CREATE INDEX IF NOT EXISTS messageMetadataByMessageId" + " ON messageMetadata (messageId)"; + private static final String INDEX_MESSAGE_METADATA_BY_GROUP_ID_STATE = + "CREATE INDEX IF NOT EXISTS messageMetadataByGroupIdState" + + " ON messageMetadata (groupId, state)"; + private static final String INDEX_GROUP_METADATA_BY_GROUP_ID = "CREATE INDEX IF NOT EXISTS groupMetadataByGroupId" + " ON groupMetadata (groupId)"; @@ -376,6 +385,7 @@ abstract class JdbcDatabase implements Database<Connection> { s.executeUpdate(INDEX_OFFERS_BY_CONTACT_ID); s.executeUpdate(INDEX_GROUPS_BY_CLIENT_ID); s.executeUpdate(INDEX_MESSAGE_METADATA_BY_MESSAGE_ID); + s.executeUpdate(INDEX_MESSAGE_METADATA_BY_GROUP_ID_STATE); s.executeUpdate(INDEX_GROUP_METADATA_BY_GROUP_ID); s.close(); } catch (SQLException e) { @@ -1334,16 +1344,13 @@ abstract class JdbcDatabase implements Database<Connection> { try { // Retrieve the message IDs for each query term and intersect Set<MessageId> intersection = null; - String sql = "SELECT m.messageId" - + " FROM messages AS m" - + " JOIN messageMetadata AS md" - + " ON m.messageId = md.messageId" - + " WHERE state = ? AND groupId = ?" + String sql = "SELECT messageId FROM messageMetadata" + + " WHERE groupId = ? AND state = ?" + " AND metaKey = ? AND value = ?"; for (Entry<String, byte[]> e : query.entrySet()) { ps = txn.prepareStatement(sql); - ps.setInt(1, DELIVERED.getValue()); - ps.setBytes(2, g.getBytes()); + ps.setBytes(1, g.getBytes()); + ps.setInt(2, DELIVERED.getValue()); ps.setString(3, e.getKey()); ps.setBytes(4, e.getValue()); rs = ps.executeQuery(); @@ -1371,25 +1378,20 @@ abstract class JdbcDatabase implements Database<Connection> { PreparedStatement ps = null; ResultSet rs = null; try { - String sql = "SELECT m.messageId, metaKey, value" - + " FROM messages AS m" - + " JOIN messageMetadata AS md" - + " ON m.messageId = md.messageId" - + " WHERE state = ? AND groupId = ?" - + " ORDER BY m.messageId"; + String sql = "SELECT messageId, metaKey, value" + + " FROM messageMetadata" + + " WHERE groupId = ? AND state = ?"; ps = txn.prepareStatement(sql); - ps.setInt(1, DELIVERED.getValue()); - ps.setBytes(2, g.getBytes()); + ps.setBytes(1, g.getBytes()); + ps.setInt(2, DELIVERED.getValue()); rs = ps.executeQuery(); Map<MessageId, Metadata> all = new HashMap<>(); - Metadata metadata = null; - MessageId lastMessageId = null; while (rs.next()) { MessageId messageId = new MessageId(rs.getBytes(1)); - if (lastMessageId == null || !messageId.equals(lastMessageId)) { + Metadata metadata = all.get(messageId); + if (metadata == null) { metadata = new Metadata(); all.put(messageId, metadata); - lastMessageId = messageId; } metadata.put(rs.getString(2), rs.getBytes(3)); } @@ -1444,10 +1446,8 @@ abstract class JdbcDatabase implements Database<Connection> { PreparedStatement ps = null; ResultSet rs = null; try { - String sql = "SELECT metaKey, value FROM messageMetadata AS md" - + " JOIN messages AS m" - + " ON m.messageId = md.messageId" - + " WHERE m.state = ? AND md.messageId = ?"; + String sql = "SELECT metaKey, value FROM messageMetadata" + + " WHERE state = ? AND messageId = ?"; ps = txn.prepareStatement(sql); ps.setInt(1, DELIVERED.getValue()); ps.setBytes(2, m.getBytes()); @@ -1470,11 +1470,9 @@ abstract class JdbcDatabase implements Database<Connection> { PreparedStatement ps = null; ResultSet rs = null; try { - String sql = "SELECT metaKey, value FROM messageMetadata AS md" - + " JOIN messages AS m" - + " ON m.messageId = md.messageId" - + " WHERE (m.state = ? OR m.state = ?)" - + " AND md.messageId = ?"; + String sql = "SELECT metaKey, value FROM messageMetadata" + + " WHERE (state = ? OR state = ?)" + + " AND messageId = ?"; ps = txn.prepareStatement(sql); ps.setInt(1, DELIVERED.getValue()); ps.setInt(2, PENDING.getValue()); @@ -2051,7 +2049,7 @@ abstract class JdbcDatabase implements Database<Connection> { int[] batchAffected = ps.executeBatch(); if (batchAffected.length != requested.size()) throw new DbStateException(); - for (int rows: batchAffected) { + for (int rows : batchAffected) { if (rows < 0) throw new DbStateException(); if (rows > 1) throw new DbStateException(); } @@ -2065,25 +2063,92 @@ abstract class JdbcDatabase implements Database<Connection> { @Override public void mergeGroupMetadata(Connection txn, GroupId g, Metadata meta) throws DbException { - mergeMetadata(txn, g.getBytes(), meta, "groupMetadata", "groupId"); + PreparedStatement ps = null; + try { + Map<String, byte[]> added = removeOrUpdateMetadata(txn, + g.getBytes(), meta, "groupMetadata", "groupId"); + if (added.isEmpty()) return; + // Insert any keys that don't already exist + String sql = "INSERT INTO groupMetadata (groupId, metaKey, value)" + + " VALUES (?, ?, ?)"; + ps = txn.prepareStatement(sql); + ps.setBytes(1, g.getBytes()); + for (Entry<String, byte[]> e : added.entrySet()) { + ps.setString(2, e.getKey()); + ps.setBytes(3, e.getValue()); + ps.addBatch(); + } + int[] batchAffected = ps.executeBatch(); + if (batchAffected.length != added.size()) + throw new DbStateException(); + for (int rows : batchAffected) + if (rows != 1) throw new DbStateException(); + ps.close(); + } catch (SQLException e) { + tryToClose(ps); + throw new DbException(e); + } } @Override - public void mergeMessageMetadata(Connection txn, MessageId m, Metadata meta) - throws DbException { - mergeMetadata(txn, m.getBytes(), meta, "messageMetadata", "messageId"); + public void mergeMessageMetadata(Connection txn, MessageId m, + Metadata meta) throws DbException { + PreparedStatement ps = null; + ResultSet rs = null; + try { + Map<String, byte[]> added = removeOrUpdateMetadata(txn, + m.getBytes(), meta, "messageMetadata", "messageId"); + if (added.isEmpty()) return; + // Get the group ID and message state for the denormalised columns + String sql = "SELECT groupId, state FROM messages" + + " WHERE messageId = ?"; + ps = txn.prepareStatement(sql); + ps.setBytes(1, m.getBytes()); + rs = ps.executeQuery(); + if (!rs.next()) throw new DbStateException(); + GroupId g = new GroupId(rs.getBytes(1)); + State state = State.fromValue(rs.getInt(2)); + rs.close(); + ps.close(); + // Insert any keys that don't already exist + sql = "INSERT INTO messageMetadata" + + " (messageId, groupId, state, metaKey, value)" + + " VALUES (?, ?, ?, ?, ?)"; + ps = txn.prepareStatement(sql); + ps.setBytes(1, m.getBytes()); + ps.setBytes(2, g.getBytes()); + ps.setInt(3, state.getValue()); + for (Entry<String, byte[]> e : added.entrySet()) { + ps.setString(4, e.getKey()); + ps.setBytes(5, e.getValue()); + ps.addBatch(); + } + int[] batchAffected = ps.executeBatch(); + if (batchAffected.length != added.size()) + throw new DbStateException(); + for (int rows : batchAffected) + if (rows != 1) throw new DbStateException(); + ps.close(); + } catch (SQLException e) { + tryToClose(rs); + tryToClose(ps); + throw new DbException(e); + } } - private void mergeMetadata(Connection txn, byte[] id, Metadata meta, - String tableName, String columnName) throws DbException { + // Removes or updates any existing entries, returns any entries that + // need to be added + private Map<String, byte[]> removeOrUpdateMetadata(Connection txn, + byte[] id, Metadata meta, String tableName, String columnName) + throws DbException { PreparedStatement ps = null; try { // Determine which keys are being removed List<String> removed = new ArrayList<>(); - Map<String, byte[]> retained = new HashMap<>(); + Map<String, byte[]> notRemoved = new HashMap<>(); for (Entry<String, byte[]> e : meta.entrySet()) { if (e.getValue() == REMOVE) removed.add(e.getKey()); - else retained.put(e.getKey(), e.getValue()); + else notRemoved.put(e.getKey(), e.getValue()); } // Delete any keys that are being removed if (!removed.isEmpty()) { @@ -2104,45 +2169,33 @@ abstract class JdbcDatabase implements Database<Connection> { } ps.close(); } - if (retained.isEmpty()) return; + if (notRemoved.isEmpty()) return Collections.emptyMap(); // Update any keys that already exist String sql = "UPDATE " + tableName + " SET value = ?" + " WHERE " + columnName + " = ? AND metaKey = ?"; ps = txn.prepareStatement(sql); ps.setBytes(2, id); - for (Entry<String, byte[]> e : retained.entrySet()) { + for (Entry<String, byte[]> e : notRemoved.entrySet()) { ps.setBytes(1, e.getValue()); ps.setString(3, e.getKey()); ps.addBatch(); } int[] batchAffected = ps.executeBatch(); - if (batchAffected.length != retained.size()) + if (batchAffected.length != notRemoved.size()) throw new DbStateException(); for (int rows : batchAffected) { if (rows < 0) throw new DbStateException(); if (rows > 1) throw new DbStateException(); } - // Insert any keys that don't already exist - sql = "INSERT INTO " + tableName - + " (" + columnName + ", metaKey, value)" - + " VALUES (?, ?, ?)"; - ps = txn.prepareStatement(sql); - ps.setBytes(1, id); - int updateIndex = 0, inserted = 0; - for (Entry<String, byte[]> e : retained.entrySet()) { - if (batchAffected[updateIndex] == 0) { - ps.setString(2, e.getKey()); - ps.setBytes(3, e.getValue()); - ps.addBatch(); - inserted++; - } - updateIndex++; - } - batchAffected = ps.executeBatch(); - if (batchAffected.length != inserted) throw new DbStateException(); - for (int rows : batchAffected) - if (rows != 1) throw new DbStateException(); ps.close(); + // Are there any keys that don't already exist? + Map<String, byte[]> added = new HashMap<>(); + int updateIndex = 0; + for (Entry<String, byte[]> e : notRemoved.entrySet()) { + if (batchAffected[updateIndex++] == 0) + added.put(e.getKey(), e.getValue()); + } + return added; } catch (SQLException e) { tryToClose(ps); throw new DbException(e); @@ -2495,7 +2548,8 @@ abstract class JdbcDatabase implements Database<Connection> { } @Override - public void setMessageShared(Connection txn, MessageId m) throws DbException { + public void setMessageShared(Connection txn, MessageId m) + throws DbException { PreparedStatement ps = null; try { String sql = "UPDATE messages SET shared = TRUE" @@ -2523,6 +2577,14 @@ abstract class JdbcDatabase implements Database<Connection> { int affected = ps.executeUpdate(); if (affected < 0 || affected > 1) throw new DbStateException(); ps.close(); + // Update denormalised column in messageMetadata + sql = "UPDATE messageMetadata SET state = ? WHERE messageId = ?"; + ps = txn.prepareStatement(sql); + ps.setInt(1, state.getValue()); + ps.setBytes(2, m.getBytes()); + affected = ps.executeUpdate(); + if (affected < 0) throw new DbStateException(); + ps.close(); } catch (SQLException e) { tryToClose(ps); throw new DbException(e);