diff --git a/briar-core/src/net/sf/briar/db/Database.java b/briar-core/src/net/sf/briar/db/Database.java index 0333dcfa60faf670862d55b13d065d1518c9d6e1..9caffadad56cd42875d0f3fbac036c22fd184a69 100644 --- a/briar-core/src/net/sf/briar/db/Database.java +++ b/briar-core/src/net/sf/briar/db/Database.java @@ -418,6 +418,15 @@ interface Database<T> { SubscriptionUpdate getSubscriptionUpdate(T txn, ContactId c, long maxLatency) throws DbException; + /** + * Returns the transmission count of the given message with respect to the + * given contact. + * <p> + * Locking: contact read, message read. + */ + int getTransmissionCount(T txn, ContactId c, MessageId m) + throws DbException; + /** * Returns a collection of transport acks for the given contact, or null if * no acks are due. @@ -567,15 +576,6 @@ interface Database<T> { void setConnectionWindow(T txn, ContactId c, TransportId t, long period, long centre, byte[] bitmap) throws DbException; - /** - * Updates the expiry times of the given messages with respect to the given - * contact, using the latency of the transport over which they were sent. - * <p> - * Locking: contact read, message write. - */ - void setMessageExpiry(T txn, ContactId c, Collection<MessageId> sent, - long maxLatency) throws DbException; - /** * Sets the user's rating for the given author. * <p> @@ -674,4 +674,14 @@ interface Database<T> { */ void setTransportUpdateAcked(T txn, ContactId c, TransportId t, long version) throws DbException; + + /** + * Updates the expiry times of the given messages with respect to the given + * contact, using the given transmission counts and the latency of the + * transport over which they were sent. + * <p> + * Locking: contact read, message write. + */ + void updateExpiryTimes(T txn, ContactId c, Map<MessageId, Integer> sent, + long maxLatency) throws DbException; } diff --git a/briar-core/src/net/sf/briar/db/DatabaseComponentImpl.java b/briar-core/src/net/sf/briar/db/DatabaseComponentImpl.java index 1338c8c33a79891152ee08e36f23524aa969cb72..a1d766df65aa79431a0bd13c14593632ce8af710 100644 --- a/briar-core/src/net/sf/briar/db/DatabaseComponentImpl.java +++ b/briar-core/src/net/sf/briar/db/DatabaseComponentImpl.java @@ -13,6 +13,7 @@ import java.util.ArrayList; import java.util.BitSet; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.List; @@ -491,6 +492,7 @@ DatabaseCleaner.Callback { public Collection<byte[]> generateBatch(ContactId c, int maxLength, long maxLatency) throws DbException { Collection<MessageId> ids; + Map<MessageId, Integer> sent = new HashMap<MessageId, Integer>(); List<byte[]> messages = new ArrayList<byte[]>(); // Get some sendable messages from the database contactLock.readLock().lock(); @@ -504,8 +506,10 @@ DatabaseCleaner.Callback { if(!db.containsContact(txn, c)) throw new NoSuchContactException(); ids = db.getSendableMessages(txn, c, maxLength); - for(MessageId m : ids) + for(MessageId m : ids) { messages.add(db.getRawMessage(txn, m)); + sent.put(m, db.getTransmissionCount(txn, c, m)); + } db.commitTransaction(txn); } catch(DbException e) { db.abortTransaction(txn); @@ -523,7 +527,7 @@ DatabaseCleaner.Callback { try { T txn = db.startTransaction(); try { - db.setMessageExpiry(txn, c, ids, maxLatency); + db.updateExpiryTimes(txn, c, sent, maxLatency); db.commitTransaction(txn); } catch(DbException e) { db.abortTransaction(txn); @@ -541,7 +545,7 @@ DatabaseCleaner.Callback { public Collection<byte[]> generateBatch(ContactId c, int maxLength, long maxLatency, Collection<MessageId> requested) throws DbException { - Collection<MessageId> ids = new ArrayList<MessageId>(); + Map<MessageId, Integer> sent = new HashMap<MessageId, Integer>(); List<byte[]> messages = new ArrayList<byte[]>(); // Get some sendable messages from the database contactLock.readLock().lock(); @@ -561,7 +565,7 @@ DatabaseCleaner.Callback { if(raw != null) { if(raw.length > maxLength) break; messages.add(raw); - ids.add(m); + sent.put(m, db.getTransmissionCount(txn, c, m)); maxLength -= raw.length; } it.remove(); @@ -583,7 +587,7 @@ DatabaseCleaner.Callback { try { T txn = db.startTransaction(); try { - db.setMessageExpiry(txn, c, ids, maxLatency); + db.updateExpiryTimes(txn, c, sent, maxLatency); db.commitTransaction(txn); } catch(DbException e) { db.abortTransaction(txn); diff --git a/briar-core/src/net/sf/briar/db/JdbcDatabase.java b/briar-core/src/net/sf/briar/db/JdbcDatabase.java index 2a3d215cf97db2b824412166099fd2735d495913..2725a22486b246737de0c53c2d01be63f7ca3cd0 100644 --- a/briar-core/src/net/sf/briar/db/JdbcDatabase.java +++ b/briar-core/src/net/sf/briar/db/JdbcDatabase.java @@ -99,6 +99,7 @@ abstract class JdbcDatabase implements Database<Connection> { + " remoteVersion BIGINT UNSIGNED NOT NULL," + " remoteAcked BOOLEAN NOT NULL," + " expiry BIGINT UNSIGNED NOT NULL," + + " txCount INT UNSIGNED NOT NULL," + " PRIMARY KEY (contactId)," + " FOREIGN KEY (contactid)" + " REFERENCES contacts (contactId)" @@ -158,6 +159,7 @@ abstract class JdbcDatabase implements Database<Connection> { + " contactId INT UNSIGNED NOT NULL," + " seen BOOLEAN NOT NULL," + " expiry BIGINT UNSIGNED NOT NULL," + + " txCount INT UNSIGNED NOT NULL," + " PRIMARY KEY (messageId, contactId)," + " FOREIGN KEY (messageId)" + " REFERENCES messages (messageId)" @@ -189,6 +191,7 @@ abstract class JdbcDatabase implements Database<Connection> { + " remoteVersion BIGINT UNSIGNED NOT NULL," + " remoteAcked BOOLEAN NOT NULL," + " expiry BIGINT UNSIGNED NOT NULL," + + " txCount INT UNSIGNED NOT NULL," + " PRIMARY KEY (contactId)," + " FOREIGN KEY (contactId)" + " REFERENCES contacts (contactId)" @@ -228,6 +231,7 @@ abstract class JdbcDatabase implements Database<Connection> { + " localVersion BIGINT UNSIGNED NOT NULL," + " localAcked BIGINT UNSIGNED NOT NULL," + " expiry BIGINT UNSIGNED NOT NULL," + + " txCount INT UNSIGNED NOT NULL," + " PRIMARY KEY (contactId, transportId)," + " FOREIGN KEY (contactId)" + " REFERENCES contacts (contactId)" @@ -508,10 +512,11 @@ abstract class JdbcDatabase implements Database<Connection> { rs.close(); ps.close(); // Create a retention version row - sql = "INSERT INTO retentionVersions" - + " (contactId, retention, localVersion, localAcked," - + " remoteVersion, remoteAcked, expiry)" - + " VALUES (?, ZERO(), ?, ZERO(), ZERO(), TRUE, ZERO())"; + sql = "INSERT INTO retentionVersions (contactId, retention," + + " localVersion, localAcked, remoteVersion, remoteAcked," + + " expiry, txCount)" + + " VALUES (?, ZERO(), ?, ZERO(), ZERO(), TRUE, ZERO()," + + " ZERO())"; ps = txn.prepareStatement(sql); ps.setInt(1, c.getInt()); ps.setInt(2, 1); @@ -520,8 +525,9 @@ abstract class JdbcDatabase implements Database<Connection> { ps.close(); // Create a group version row sql = "INSERT INTO groupVersions (contactId, localVersion," - + " localAcked, remoteVersion, remoteAcked, expiry)" - + " VALUES (?, ?, ZERO(), ZERO(), TRUE, ZERO())"; + + " localAcked, remoteVersion, remoteAcked, expiry," + + " txCount)" + + " VALUES (?, ?, ZERO(), ZERO(), TRUE, ZERO(), ZERO())"; ps = txn.prepareStatement(sql); ps.setInt(1, c.getInt()); ps.setInt(2, 1); @@ -538,8 +544,8 @@ abstract class JdbcDatabase implements Database<Connection> { ps.close(); if(transports.isEmpty()) return c; sql = "INSERT INTO transportVersions (contactId, transportId," - + " localVersion, localAcked, expiry)" - + " VALUES (?, ?, ?, ZERO(), ZERO())"; + + " localVersion, localAcked, expiry, txCount)" + + " VALUES (?, ?, ?, ZERO(), ZERO(), ZERO())"; ps = txn.prepareStatement(sql); ps.setInt(1, c.getInt()); ps.setInt(3, 1); @@ -687,8 +693,8 @@ abstract class JdbcDatabase implements Database<Connection> { PreparedStatement ps = null; try { String sql = "INSERT INTO statuses" - + " (messageId, contactId, seen, expiry)" - + " VALUES (?, ?, ?, ZERO())"; + + " (messageId, contactId, seen, expiry, txCount)" + + " VALUES (?, ?, ?, ZERO(), ZERO())"; ps = txn.prepareStatement(sql); ps.setBytes(1, m.getBytes()); ps.setInt(2, c.getInt()); @@ -790,8 +796,8 @@ abstract class JdbcDatabase implements Database<Connection> { ps.close(); if(contacts.isEmpty()) return; sql = "INSERT INTO transportVersions (contactId, transportId," - + " localVersion, localAcked, expiry)" - + " VALUES (?, ?, ?, ZERO(), ZERO())"; + + " localVersion, localAcked, expiry, txCount)" + + " VALUES (?, ?, ?, ZERO(), ZERO(), ZERO())"; ps = txn.prepareStatement(sql); ps.setBytes(2, t.getBytes()); ps.setInt(3, 1); @@ -826,7 +832,8 @@ abstract class JdbcDatabase implements Database<Connection> { ps.close(); // Bump the subscription version sql = "UPDATE groupVersions" - + " SET localVersion = localVersion + ?, expiry = ZERO()" + + " SET localVersion = localVersion + ?," + + " expiry = ZERO(), txCount = ZERO()" + " WHERE contactId = ?"; ps = txn.prepareStatement(sql); ps.setInt(1, 1); @@ -1541,7 +1548,7 @@ abstract class JdbcDatabase implements Database<Connection> { PreparedStatement ps = null; ResultSet rs = null; try { - String sql = "SELECT timestamp, localVersion" + String sql = "SELECT timestamp, localVersion, txCount" + " FROM messages AS m" + " JOIN retentionVersions AS rv" + " WHERE rv.contactId = ?" @@ -1561,13 +1568,17 @@ abstract class JdbcDatabase implements Database<Connection> { long retention = rs.getLong(1); retention -= retention % RETENTION_MODULUS; long version = rs.getLong(2); + int txCount = rs.getInt(3); if(rs.next()) throw new DbStateException(); rs.close(); ps.close(); - sql = "UPDATE retentionVersions SET expiry = ? WHERE contactId = ?"; + sql = "UPDATE retentionVersions" + + " SET expiry = ?, txCount = txCount + ?" + + " WHERE contactId = ?"; ps = txn.prepareStatement(sql); - ps.setLong(1, calculateExpiry(now, maxLatency)); - ps.setInt(2, c.getInt()); + ps.setLong(1, calculateExpiry(now, maxLatency, txCount)); + ps.setInt(2, 1); + ps.setInt(3, c.getInt()); int affected = ps.executeUpdate(); if(affected != 1) throw new DbStateException(); ps.close(); @@ -1817,7 +1828,7 @@ abstract class JdbcDatabase implements Database<Connection> { PreparedStatement ps = null; ResultSet rs = null; try { - String sql = "SELECT g.groupId, name, key, localVersion" + String sql = "SELECT g.groupId, name, key, localVersion, txCount" + " FROM groups AS g" + " JOIN groupVisibilities AS vis" + " ON g.groupId = vis.groupId" @@ -1832,20 +1843,25 @@ abstract class JdbcDatabase implements Database<Connection> { rs = ps.executeQuery(); List<Group> subs = new ArrayList<Group>(); long version = 0; + int txCount = 0; while(rs.next()) { byte[] id = rs.getBytes(1); String name = rs.getString(2); byte[] key = rs.getBytes(3); - version = rs.getLong(4); subs.add(new Group(new GroupId(id), name, key)); + version = rs.getLong(4); + txCount = rs.getInt(5); } rs.close(); ps.close(); if(subs.isEmpty()) return null; - sql = "UPDATE groupVersions SET expiry = ? WHERE contactId = ?"; + sql = "UPDATE groupVersions" + + " SET expiry = ?, txCount = txCount + ?" + + " WHERE contactId = ?"; ps = txn.prepareStatement(sql); - ps.setLong(1, calculateExpiry(now, maxLatency)); - ps.setInt(2, c.getInt()); + ps.setLong(1, calculateExpiry(now, maxLatency, txCount)); + ps.setInt(2, 1); + ps.setInt(3, c.getInt()); int affected = ps.executeUpdate(); if(affected != 1) throw new DbStateException(); ps.close(); @@ -1858,6 +1874,30 @@ abstract class JdbcDatabase implements Database<Connection> { } } + public int getTransmissionCount(Connection txn, ContactId c, MessageId m) + throws DbException { + PreparedStatement ps = null; + ResultSet rs = null; + try { + String sql = "SELECT txCount FROM statuses" + + " WHERE messageId = ? AND contactId = ?"; + ps = txn.prepareStatement(sql); + ps.setBytes(1, m.getBytes()); + ps.setInt(2, c.getInt()); + rs = ps.executeQuery(); + if(!rs.next()) throw new DbStateException(); + int txCount = rs.getInt(1); + if(rs.next()) throw new DbStateException(); + rs.close(); + ps.close(); + return txCount; + } catch(SQLException e) { + tryToClose(ps); + tryToClose(rs); + throw new DbException(e); + } + } + public Collection<TransportAck> getTransportAcks(Connection txn, ContactId c) throws DbException { PreparedStatement ps = null; @@ -1906,7 +1946,8 @@ abstract class JdbcDatabase implements Database<Connection> { PreparedStatement ps = null; ResultSet rs = null; try { - String sql = "SELECT tp.transportId, key, value, localVersion" + String sql = "SELECT tp.transportId, key, value, localVersion," + + " txCount" + " FROM transportProperties AS tp" + " JOIN transportVersions AS tv" + " ON tp.transportId = tv.transportId" @@ -1920,32 +1961,39 @@ abstract class JdbcDatabase implements Database<Connection> { List<TransportUpdate> updates = new ArrayList<TransportUpdate>(); TransportId lastId = null; TransportProperties p = null; + List<Integer> txCounts = new ArrayList<Integer>(); while(rs.next()) { TransportId id = new TransportId(rs.getBytes(1)); String key = rs.getString(2), value = rs.getString(3); long version = rs.getLong(4); + int txCount = rs.getInt(5); if(!id.equals(lastId)) { p = new TransportProperties(); updates.add(new TransportUpdate(id, p, version)); + txCounts.add(txCount); } p.put(key, value); } rs.close(); ps.close(); if(updates.isEmpty()) return null; - sql = "UPDATE transportVersions SET expiry = ?" + sql = "UPDATE transportVersions" + + " SET expiry = ?, txCount = txCount + ?" + " WHERE contactId = ? AND transportId = ?"; ps = txn.prepareStatement(sql); - ps.setLong(1, calculateExpiry(now, maxLatency)); - ps.setInt(2, c.getInt()); + ps.setInt(2, 1); + ps.setInt(3, c.getInt()); + int i = 0; for(TransportUpdate u : updates) { + int txCount = txCounts.get(i++); + ps.setLong(1, calculateExpiry(now, maxLatency, txCount)); ps.setBytes(3, u.getId().getBytes()); ps.addBatch(); } int [] batchAffected = ps.executeBatch(); if(batchAffected.length != updates.size()) throw new DbStateException(); - for(int i = 0; i < batchAffected.length; i++) { + for(i = 0; i < batchAffected.length; i++) { if(batchAffected[i] != 1) throw new DbStateException(); } ps.close(); @@ -2405,41 +2453,6 @@ abstract class JdbcDatabase implements Database<Connection> { } } - public void setMessageExpiry(Connection txn, ContactId c, - Collection<MessageId> sent, long maxLatency) throws DbException { - long now = clock.currentTimeMillis(); - PreparedStatement ps = null; - try { - String sql = "UPDATE statuses SET expiry = ?" - + " WHERE messageId = ? AND contactId = ?"; - ps = txn.prepareStatement(sql); - ps.setLong(1, calculateExpiry(now, maxLatency)); - ps.setInt(3, c.getInt()); - for(MessageId m : sent) { - ps.setBytes(2, m.getBytes()); - ps.addBatch(); - } - int[] batchAffected = ps.executeBatch(); - if(batchAffected.length != sent.size()) - throw new DbStateException(); - for(int i = 0; i < batchAffected.length; i++) { - if(batchAffected[i] > 1) throw new DbStateException(); - } - ps.close(); - } catch(SQLException e) { - tryToClose(ps); - throw new DbException(e); - } - } - - private long calculateExpiry(long now, long maxLatency) { - long roundTrip = maxLatency * 2; - if(roundTrip < 0) return Long.MAX_VALUE; // Overflow; - long expiry = now + roundTrip; - if(expiry < 0) return Long.MAX_VALUE; // Overflow - return expiry; - } - public Rating setRating(Connection txn, AuthorId a, Rating r) throws DbException { PreparedStatement ps = null; @@ -2835,4 +2848,46 @@ abstract class JdbcDatabase implements Database<Connection> { throw new DbException(e); } } + + public void updateExpiryTimes(Connection txn, ContactId c, + Map<MessageId, Integer> sent, long maxLatency) throws DbException { + long now = clock.currentTimeMillis(); + PreparedStatement ps = null; + try { + String sql = "UPDATE statuses" + + " SET expiry = ?, txCount = txCount + ?" + + " WHERE messageId = ? AND contactId = ?"; + ps = txn.prepareStatement(sql); + ps.setInt(2, 1); + ps.setInt(4, c.getInt()); + for(Entry<MessageId, Integer> e : sent.entrySet()) { + ps.setLong(1, calculateExpiry(now, maxLatency, e.getValue())); + ps.setBytes(3, e.getKey().getBytes()); + ps.addBatch(); + } + int[] batchAffected = ps.executeBatch(); + if(batchAffected.length != sent.size()) + throw new DbStateException(); + for(int i = 0; i < batchAffected.length; i++) { + if(batchAffected[i] > 1) throw new DbStateException(); + } + ps.close(); + } catch(SQLException e) { + tryToClose(ps); + throw new DbException(e); + } + } + + // FIXME: Refactor the exponential backoff logic into a separate class + private long calculateExpiry(long now, long maxLatency, int txCount) { + long roundTrip = maxLatency * 2; + if(roundTrip < 0) return Long.MAX_VALUE; + for(int i = 0; i < txCount; i++) { + roundTrip <<= 1; + if(roundTrip < 0) return Long.MAX_VALUE; + } + long expiry = now + roundTrip; + if(expiry < 0) return Long.MAX_VALUE; + return expiry; + } } diff --git a/briar-tests/src/net/sf/briar/db/DatabaseComponentTest.java b/briar-tests/src/net/sf/briar/db/DatabaseComponentTest.java index 83b0cbf621a54af0abc5af6503b96f9c0bcc95f9..ae2c9a9f10477f77cb51075287f2ed85787c7547 100644 --- a/briar-tests/src/net/sf/briar/db/DatabaseComponentTest.java +++ b/briar-tests/src/net/sf/briar/db/DatabaseComponentTest.java @@ -8,6 +8,8 @@ import java.util.Arrays; import java.util.BitSet; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import net.sf.briar.BriarTestCase; import net.sf.briar.TestMessage; @@ -677,6 +679,9 @@ public abstract class DatabaseComponentTest extends BriarTestCase { final Collection<MessageId> sendable = Arrays.asList(messageId, messageId1); final Collection<byte[]> messages = Arrays.asList(raw, raw1); + final Map<MessageId, Integer> sent = new HashMap<MessageId, Integer>(); + sent.put(messageId, 1); + sent.put(messageId1, 2); Mockery context = new Mockery(); @SuppressWarnings("unchecked") final Database<Object> database = context.mock(Database.class); @@ -688,15 +693,19 @@ public abstract class DatabaseComponentTest extends BriarTestCase { allowing(database).commitTransaction(txn); allowing(database).containsContact(txn, contactId); will(returnValue(true)); - // Get the sendable messages + // Get the sendable messages and their transmission counts oneOf(database).getSendableMessages(txn, contactId, size * 2); will(returnValue(sendable)); oneOf(database).getRawMessage(txn, messageId); will(returnValue(raw)); + oneOf(database).getTransmissionCount(txn, contactId, messageId); + will(returnValue(1)); oneOf(database).getRawMessage(txn, messageId1); will(returnValue(raw1)); + oneOf(database).getTransmissionCount(txn, contactId, messageId1); + will(returnValue(2)); // Record the outstanding messages - oneOf(database).setMessageExpiry(txn, contactId, sendable, + oneOf(database).updateExpiryTimes(txn, contactId, sent, Long.MAX_VALUE); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, @@ -731,11 +740,13 @@ public abstract class DatabaseComponentTest extends BriarTestCase { will(returnValue(null)); // Message is not sendable oneOf(database).getRawMessageIfSendable(txn, contactId, messageId1); will(returnValue(raw1)); // Message is sendable + oneOf(database).getTransmissionCount(txn, contactId, messageId1); + will(returnValue(2)); oneOf(database).getRawMessageIfSendable(txn, contactId, messageId2); will(returnValue(null)); // Message is not sendable // Mark the message as sent - oneOf(database).setMessageExpiry(txn, contactId, - Arrays.asList(messageId1), Long.MAX_VALUE); + oneOf(database).updateExpiryTimes(txn, contactId, + Collections.singletonMap(messageId1, 2), Long.MAX_VALUE); }}); DatabaseComponent db = createDatabaseComponent(database, cleaner, shutdown); diff --git a/briar-tests/src/net/sf/briar/db/H2DatabaseTest.java b/briar-tests/src/net/sf/briar/db/H2DatabaseTest.java index 74bc6b4a9b6ea1a33ef9d37a916aa85b4895cac1..927871d755b462c2286d76280e21196119519a09 100644 --- a/briar-tests/src/net/sf/briar/db/H2DatabaseTest.java +++ b/briar-tests/src/net/sf/briar/db/H2DatabaseTest.java @@ -523,8 +523,8 @@ public class H2DatabaseTest extends BriarTestCase { assertTrue(it.hasNext()); assertEquals(messageId, it.next()); assertFalse(it.hasNext()); - db.setMessageExpiry(txn, contactId, Arrays.asList(messageId), - Long.MAX_VALUE); + db.updateExpiryTimes(txn, contactId, + Collections.singletonMap(messageId, 0), Long.MAX_VALUE); // The message should no longer be sendable it = db.getSendableMessages(txn, contactId, ONE_MEGABYTE).iterator();