diff --git a/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java b/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java
index 7e3e3307bc5b9395e0d0af09949097bb46018cab..e2b9ce2aa1fde11702ca36082935aac48e1428f9 100644
--- a/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java
+++ b/bramble-core/src/main/java/org/briarproject/bramble/db/JdbcDatabase.java
@@ -68,8 +68,8 @@ import static org.briarproject.bramble.db.ExponentialBackoff.calculateExpiry;
 @NotNullByDefault
 abstract class JdbcDatabase implements Database<Connection> {
 
-	private static final int SCHEMA_VERSION = 31;
-	private static final int MIN_SCHEMA_VERSION = 31;
+	private static final int SCHEMA_VERSION = 32;
+	private static final int MIN_SCHEMA_VERSION = 32;
 
 	private static final String CREATE_SETTINGS =
 			"CREATE TABLE settings"
@@ -148,11 +148,16 @@ abstract class JdbcDatabase implements Database<Connection> {
 	private static final String CREATE_MESSAGE_METADATA =
 			"CREATE TABLE messageMetadata"
 					+ " (messageId _HASH NOT NULL,"
+					+ " groupId _HASH NOT NULL," // Denormalised
+					+ " state INT NOT NULL," // Denormalised
 					+ " metaKey _STRING NOT NULL,"
 					+ " value _BINARY NOT NULL,"
 					+ " PRIMARY KEY (messageId, metaKey),"
 					+ " FOREIGN KEY (messageId)"
 					+ " REFERENCES messages (messageId)"
+					+ " ON DELETE CASCADE,"
+					+ " FOREIGN KEY (groupId)"
+					+ " REFERENCES groups (groupId)"
 					+ " ON DELETE CASCADE)";
 
 	private static final String CREATE_MESSAGE_DEPENDENCIES =
@@ -252,6 +257,10 @@ abstract class JdbcDatabase implements Database<Connection> {
 			"CREATE INDEX IF NOT EXISTS messageMetadataByMessageId"
 					+ " ON messageMetadata (messageId)";
 
+	private static final String INDEX_MESSAGE_METADATA_BY_GROUP_ID_STATE =
+			"CREATE INDEX IF NOT EXISTS messageMetadataByGroupIdState"
+					+ " ON messageMetadata (groupId, state)";
+
 	private static final String INDEX_GROUP_METADATA_BY_GROUP_ID =
 			"CREATE INDEX IF NOT EXISTS groupMetadataByGroupId"
 					+ " ON groupMetadata (groupId)";
@@ -376,6 +385,7 @@ abstract class JdbcDatabase implements Database<Connection> {
 			s.executeUpdate(INDEX_OFFERS_BY_CONTACT_ID);
 			s.executeUpdate(INDEX_GROUPS_BY_CLIENT_ID);
 			s.executeUpdate(INDEX_MESSAGE_METADATA_BY_MESSAGE_ID);
+			s.executeUpdate(INDEX_MESSAGE_METADATA_BY_GROUP_ID_STATE);
 			s.executeUpdate(INDEX_GROUP_METADATA_BY_GROUP_ID);
 			s.close();
 		} catch (SQLException e) {
@@ -1334,16 +1344,13 @@ abstract class JdbcDatabase implements Database<Connection> {
 		try {
 			// Retrieve the message IDs for each query term and intersect
 			Set<MessageId> intersection = null;
-			String sql = "SELECT m.messageId"
-					+ " FROM messages AS m"
-					+ " JOIN messageMetadata AS md"
-					+ " ON m.messageId = md.messageId"
-					+ " WHERE state = ? AND groupId = ?"
+			String sql = "SELECT messageId FROM messageMetadata"
+					+ " WHERE groupId = ? AND state = ?"
 					+ " AND metaKey = ? AND value = ?";
 			for (Entry<String, byte[]> e : query.entrySet()) {
 				ps = txn.prepareStatement(sql);
-				ps.setInt(1, DELIVERED.getValue());
-				ps.setBytes(2, g.getBytes());
+				ps.setBytes(1, g.getBytes());
+				ps.setInt(2, DELIVERED.getValue());
 				ps.setString(3, e.getKey());
 				ps.setBytes(4, e.getValue());
 				rs = ps.executeQuery();
@@ -1371,25 +1378,20 @@ abstract class JdbcDatabase implements Database<Connection> {
 		PreparedStatement ps = null;
 		ResultSet rs = null;
 		try {
-			String sql = "SELECT m.messageId, metaKey, value"
-					+ " FROM messages AS m"
-					+ " JOIN messageMetadata AS md"
-					+ " ON m.messageId = md.messageId"
-					+ " WHERE state = ? AND groupId = ?"
-					+ " ORDER BY m.messageId";
+			String sql = "SELECT messageId, metaKey, value"
+					+ " FROM messageMetadata"
+					+ " WHERE groupId = ? AND state = ?";
 			ps = txn.prepareStatement(sql);
-			ps.setInt(1, DELIVERED.getValue());
-			ps.setBytes(2, g.getBytes());
+			ps.setBytes(1, g.getBytes());
+			ps.setInt(2, DELIVERED.getValue());
 			rs = ps.executeQuery();
 			Map<MessageId, Metadata> all = new HashMap<>();
-			Metadata metadata = null;
-			MessageId lastMessageId = null;
 			while (rs.next()) {
 				MessageId messageId = new MessageId(rs.getBytes(1));
-				if (lastMessageId == null || !messageId.equals(lastMessageId)) {
+				Metadata metadata = all.get(messageId);
+				if (metadata == null) {
 					metadata = new Metadata();
 					all.put(messageId, metadata);
-					lastMessageId = messageId;
 				}
 				metadata.put(rs.getString(2), rs.getBytes(3));
 			}
@@ -1444,10 +1446,8 @@ abstract class JdbcDatabase implements Database<Connection> {
 		PreparedStatement ps = null;
 		ResultSet rs = null;
 		try {
-			String sql = "SELECT metaKey, value FROM messageMetadata AS md"
-					+ " JOIN messages AS m"
-					+ " ON m.messageId = md.messageId"
-					+ " WHERE m.state = ? AND md.messageId = ?";
+			String sql = "SELECT metaKey, value FROM messageMetadata"
+					+ " WHERE state = ? AND messageId = ?";
 			ps = txn.prepareStatement(sql);
 			ps.setInt(1, DELIVERED.getValue());
 			ps.setBytes(2, m.getBytes());
@@ -1470,11 +1470,9 @@ abstract class JdbcDatabase implements Database<Connection> {
 		PreparedStatement ps = null;
 		ResultSet rs = null;
 		try {
-			String sql = "SELECT metaKey, value FROM messageMetadata AS md"
-					+ " JOIN messages AS m"
-					+ " ON m.messageId = md.messageId"
-					+ " WHERE (m.state = ? OR m.state = ?)"
-					+ " AND md.messageId = ?";
+			String sql = "SELECT metaKey, value FROM messageMetadata"
+					+ " WHERE (state = ? OR state = ?)"
+					+ " AND messageId = ?";
 			ps = txn.prepareStatement(sql);
 			ps.setInt(1, DELIVERED.getValue());
 			ps.setInt(2, PENDING.getValue());
@@ -2051,7 +2049,7 @@ abstract class JdbcDatabase implements Database<Connection> {
 			int[] batchAffected = ps.executeBatch();
 			if (batchAffected.length != requested.size())
 				throw new DbStateException();
-			for (int rows: batchAffected) {
+			for (int rows : batchAffected) {
 				if (rows < 0) throw new DbStateException();
 				if (rows > 1) throw new DbStateException();
 			}
@@ -2065,25 +2063,92 @@ abstract class JdbcDatabase implements Database<Connection> {
 	@Override
 	public void mergeGroupMetadata(Connection txn, GroupId g, Metadata meta)
 			throws DbException {
-		mergeMetadata(txn, g.getBytes(), meta, "groupMetadata", "groupId");
+		PreparedStatement ps = null;
+		try {
+			Map<String, byte[]> added = removeOrUpdateMetadata(txn,
+					g.getBytes(), meta, "groupMetadata", "groupId");
+			if (added.isEmpty()) return;
+			// Insert any keys that don't already exist
+			String sql = "INSERT INTO groupMetadata (groupId, metaKey, value)"
+					+ " VALUES (?, ?, ?)";
+			ps = txn.prepareStatement(sql);
+			ps.setBytes(1, g.getBytes());
+			for (Entry<String, byte[]> e : added.entrySet()) {
+				ps.setString(2, e.getKey());
+				ps.setBytes(3, e.getValue());
+				ps.addBatch();
+			}
+			int[] batchAffected = ps.executeBatch();
+			if (batchAffected.length != added.size())
+				throw new DbStateException();
+			for (int rows : batchAffected)
+				if (rows != 1) throw new DbStateException();
+			ps.close();
+		} catch (SQLException e) {
+			tryToClose(ps);
+			throw new DbException(e);
+		}
 	}
 
 	@Override
-	public void mergeMessageMetadata(Connection txn, MessageId m, Metadata meta)
-			throws DbException {
-		mergeMetadata(txn, m.getBytes(), meta, "messageMetadata", "messageId");
+	public void mergeMessageMetadata(Connection txn, MessageId m,
+			Metadata meta) throws DbException {
+		PreparedStatement ps = null;
+		ResultSet rs = null;
+		try {
+			Map<String, byte[]> added = removeOrUpdateMetadata(txn,
+					m.getBytes(), meta, "messageMetadata", "messageId");
+			if (added.isEmpty()) return;
+			// Get the group ID and message state for the denormalised columns
+			String sql = "SELECT groupId, state FROM messages"
+					+ " WHERE messageId = ?";
+			ps = txn.prepareStatement(sql);
+			ps.setBytes(1, m.getBytes());
+			rs = ps.executeQuery();
+			if (!rs.next()) throw new DbStateException();
+			GroupId g = new GroupId(rs.getBytes(1));
+			State state = State.fromValue(rs.getInt(2));
+			rs.close();
+			ps.close();
+			// Insert any keys that don't already exist
+			sql = "INSERT INTO messageMetadata"
+					+ " (messageId, groupId, state, metaKey, value)"
+					+ " VALUES (?, ?, ?, ?, ?)";
+			ps = txn.prepareStatement(sql);
+			ps.setBytes(1, m.getBytes());
+			ps.setBytes(2, g.getBytes());
+			ps.setInt(3, state.getValue());
+			for (Entry<String, byte[]> e : added.entrySet()) {
+				ps.setString(4, e.getKey());
+				ps.setBytes(5, e.getValue());
+				ps.addBatch();
+			}
+			int[] batchAffected = ps.executeBatch();
+			if (batchAffected.length != added.size())
+				throw new DbStateException();
+			for (int rows : batchAffected)
+				if (rows != 1) throw new DbStateException();
+			ps.close();
+		} catch (SQLException e) {
+			tryToClose(rs);
+			tryToClose(ps);
+			throw new DbException(e);
+		}
 	}
 
-	private void mergeMetadata(Connection txn, byte[] id, Metadata meta,
-			String tableName, String columnName) throws DbException {
+	// Removes or updates any existing entries, returns any entries that
+	// need to be added
+	private Map<String, byte[]> removeOrUpdateMetadata(Connection txn,
+			byte[] id, Metadata meta, String tableName, String columnName)
+			throws DbException {
 		PreparedStatement ps = null;
 		try {
 			// Determine which keys are being removed
 			List<String> removed = new ArrayList<>();
-			Map<String, byte[]> retained = new HashMap<>();
+			Map<String, byte[]> notRemoved = new HashMap<>();
 			for (Entry<String, byte[]> e : meta.entrySet()) {
 				if (e.getValue() == REMOVE) removed.add(e.getKey());
-				else retained.put(e.getKey(), e.getValue());
+				else notRemoved.put(e.getKey(), e.getValue());
 			}
 			// Delete any keys that are being removed
 			if (!removed.isEmpty()) {
@@ -2104,45 +2169,33 @@ abstract class JdbcDatabase implements Database<Connection> {
 				}
 				ps.close();
 			}
-			if (retained.isEmpty()) return;
+			if (notRemoved.isEmpty()) return Collections.emptyMap();
 			// Update any keys that already exist
 			String sql = "UPDATE " + tableName + " SET value = ?"
 					+ " WHERE " + columnName + " = ? AND metaKey = ?";
 			ps = txn.prepareStatement(sql);
 			ps.setBytes(2, id);
-			for (Entry<String, byte[]> e : retained.entrySet()) {
+			for (Entry<String, byte[]> e : notRemoved.entrySet()) {
 				ps.setBytes(1, e.getValue());
 				ps.setString(3, e.getKey());
 				ps.addBatch();
 			}
 			int[] batchAffected = ps.executeBatch();
-			if (batchAffected.length != retained.size())
+			if (batchAffected.length != notRemoved.size())
 				throw new DbStateException();
 			for (int rows : batchAffected) {
 				if (rows < 0) throw new DbStateException();
 				if (rows > 1) throw new DbStateException();
 			}
-			// Insert any keys that don't already exist
-			sql = "INSERT INTO " + tableName
-					+ " (" + columnName + ", metaKey, value)"
-					+ " VALUES (?, ?, ?)";
-			ps = txn.prepareStatement(sql);
-			ps.setBytes(1, id);
-			int updateIndex = 0, inserted = 0;
-			for (Entry<String, byte[]> e : retained.entrySet()) {
-				if (batchAffected[updateIndex] == 0) {
-					ps.setString(2, e.getKey());
-					ps.setBytes(3, e.getValue());
-					ps.addBatch();
-					inserted++;
-				}
-				updateIndex++;
-			}
-			batchAffected = ps.executeBatch();
-			if (batchAffected.length != inserted) throw new DbStateException();
-			for (int rows : batchAffected)
-				if (rows != 1) throw new DbStateException();
 			ps.close();
+			// Are there any keys that don't already exist?
+			Map<String, byte[]> added = new HashMap<>();
+			int updateIndex = 0;
+			for (Entry<String, byte[]> e : notRemoved.entrySet()) {
+				if (batchAffected[updateIndex++] == 0)
+					added.put(e.getKey(), e.getValue());
+			}
+			return added;
 		} catch (SQLException e) {
 			tryToClose(ps);
 			throw new DbException(e);
@@ -2495,7 +2548,8 @@ abstract class JdbcDatabase implements Database<Connection> {
 	}
 
 	@Override
-	public void setMessageShared(Connection txn, MessageId m) throws DbException {
+	public void setMessageShared(Connection txn, MessageId m)
+			throws DbException {
 		PreparedStatement ps = null;
 		try {
 			String sql = "UPDATE messages SET shared = TRUE"
@@ -2523,6 +2577,14 @@ abstract class JdbcDatabase implements Database<Connection> {
 			int affected = ps.executeUpdate();
 			if (affected < 0 || affected > 1) throw new DbStateException();
 			ps.close();
+			// Update denormalised column in messageMetadata
+			sql = "UPDATE messageMetadata SET state = ? WHERE messageId = ?";
+			ps = txn.prepareStatement(sql);
+			ps.setInt(1, state.getValue());
+			ps.setBytes(2, m.getBytes());
+			affected = ps.executeUpdate();
+			if (affected < 0) throw new DbStateException();
+			ps.close();
 		} catch (SQLException e) {
 			tryToClose(ps);
 			throw new DbException(e);