Commit 27e50b84 authored by akwizgran's avatar akwizgran

Implemented KeyManager (untested).

A test is failing due to key derivation errors - must be fixed!
parent cc6e9d53
......@@ -6,10 +6,20 @@ import net.sf.briar.api.transport.ConnectionContext;
public interface KeyManager {
/**
* Starts the key manager and returns true if the manager started
* successfully. This method must be called after the database has been
* opened.
*/
boolean start();
/** Stops the key manager. */
void stop();
/**
* Returns a connection context for connecting to the given contact over
* the given transport, or null if the contact does not support the
* transport.
* the given transport, or null if an error occurs or the contact does not
* support the transport.
*/
ConnectionContext getConnectionContext(ContactId c, TransportId t);
}
......@@ -114,9 +114,6 @@ public interface DatabaseComponent {
/** Returns the IDs of all contacts. */
Collection<ContactId> getContacts() throws DbException;
/** Returns all contact transports. */
Collection<ContactTransport> getContactTransports() throws DbException;
/** Returns the local transport properties for the given transport. */
TransportProperties getLocalProperties(TransportId t) throws DbException;
......@@ -150,9 +147,9 @@ public interface DatabaseComponent {
/**
* Increments the outgoing connection counter for the given contact
* transport in the given rotation period.
* transport in the given rotation period and returns the old value.
*/
void incrementConnectionCounter(ContactId c, TransportId t, long period)
long incrementConnectionCounter(ContactId c, TransportId t, long period)
throws DbException;
/** Processes an acknowledgement from the given contact. */
......
package net.sf.briar.api.db;
import static net.sf.briar.api.transport.TransportConstants.CONNECTION_WINDOW_SIZE;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.protocol.TransportId;
public class TemporarySecret {
public class TemporarySecret extends ContactTransport {
private final ContactId contactId;
private final TransportId transportId;
private final long period, outgoing, centre;
private final byte[] secret, bitmap;
public TemporarySecret(ContactId contactId, TransportId transportId,
long epoch, long clockDiff, long latency, boolean alice,
long period, byte[] secret, long outgoing, long centre,
byte[] bitmap) {
this.contactId = contactId;
this.transportId = transportId;
super(contactId, transportId, epoch, clockDiff, latency, alice);
this.period = period;
this.secret = secret;
this.outgoing = outgoing;
......@@ -22,12 +21,14 @@ public class TemporarySecret {
this.bitmap = bitmap;
}
public ContactId getContactId() {
return contactId;
}
public TransportId getTransportId() {
return transportId;
public TemporarySecret(TemporarySecret old, long period, byte[] secret) {
super(old.getContactId(), old.getTransportId(), old.getEpoch(),
old.getClockDifference(), old.getLatency(), old.getAlice());
this.period = period;
this.secret = secret;
outgoing = 0L;
centre = 0L;
bitmap = new byte[CONNECTION_WINDOW_SIZE / 8];
}
public long getPeriod() {
......
......@@ -2,6 +2,7 @@ package net.sf.briar.api.transport;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.TemporarySecret;
import net.sf.briar.api.protocol.TransportId;
/**
......@@ -17,10 +18,11 @@ public interface ConnectionRecogniser {
ConnectionContext acceptConnection(TransportId t, byte[] tag)
throws DbException;
void addWindow(ContactId c, TransportId t, long period, boolean alice,
byte[] secret, long centre, byte[] bitmap) throws DbException;
void addSecret(TemporarySecret s) throws DbException;
void removeWindow(ContactId c, TransportId t, long period);
void removeSecret(ContactId c, TransportId t, long period);
void removeWindows(ContactId c);
void removeSecrets(ContactId c);
void removeSecrets();
}
package net.sf.briar.crypto;
import net.sf.briar.api.db.DbException;
interface KeyRotator {
/**
* Starts a new thread to rotate keys periodically. The rotator will pause
* for the given number of milliseconds between rotations.
*/
void startRotating(Callback callback, long msBetweenRotations);
/** Tells the rotator thread to exit. */
void stopRotating();
interface Callback {
/**
* Rotates keys, replacing and destroying any keys that have passed the
* ends of their respective retention periods.
*/
void rotateKeys() throws DbException;
}
}
package net.sf.briar.crypto;
import java.util.Timer;
import java.util.TimerTask;
import java.util.logging.Level;
import java.util.logging.Logger;
import net.sf.briar.api.db.DbException;
class KeyRotatorImpl extends TimerTask implements KeyRotator {
private static final Logger LOG =
Logger.getLogger(KeyRotatorImpl.class.getName());
private volatile Callback callback = null;
private volatile Timer timer = null;
public void startRotating(Callback callback, long msBetweenRotations) {
this.callback = callback;
timer = new Timer();
timer.scheduleAtFixedRate(this, 0L, msBetweenRotations);
}
public void stopRotating() {
if(timer == null) throw new IllegalStateException();
timer.cancel();
}
public void run() {
if(callback == null) throw new IllegalStateException();
try {
callback.rotateKeys();
} catch(DbException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
throw new Error(e); // Kill the application
} catch(RuntimeException e) {
if(LOG.isLoggable(Level.WARNING)) LOG.warning(e.toString());
throw new Error(e); // Kill the application
}
}
}
......@@ -468,11 +468,11 @@ interface Database<T> {
/**
* Increments the outgoing connection counter for the given contact
* transport in the given rotation period.
* transport in the given rotation period and returns the old value;
* <p>
* Locking: contact read, window write.
*/
void incrementConnectionCounter(T txn, ContactId c, TransportId t,
long incrementConnectionCounter(T txn, ContactId c, TransportId t,
long period) throws DbException;
/**
......
......@@ -758,30 +758,6 @@ DatabaseCleaner.Callback {
}
}
public Collection<ContactTransport> getContactTransports()
throws DbException {
contactLock.readLock().lock();
try {
windowLock.readLock().lock();
try {
T txn = db.startTransaction();
try {
Collection<ContactTransport> contactTransports =
db.getContactTransports(txn);
db.commitTransaction(txn);
return contactTransports;
} catch(DbException e) {
db.abortTransaction(txn);
throw e;
}
} finally {
windowLock.readLock().unlock();
}
} finally {
contactLock.readLock().unlock();
}
}
public TransportProperties getLocalProperties(TransportId t)
throws DbException {
transportLock.readLock().lock();
......@@ -1005,7 +981,7 @@ DatabaseCleaner.Callback {
}
}
public void incrementConnectionCounter(ContactId c, TransportId t,
public long incrementConnectionCounter(ContactId c, TransportId t,
long period) throws DbException {
contactLock.readLock().lock();
try {
......@@ -1015,8 +991,9 @@ DatabaseCleaner.Callback {
try {
if(!db.containsContactTransport(txn, c, t))
throw new NoSuchContactTransportException();
db.incrementConnectionCounter(txn, c, t, period);
long l = db.incrementConnectionCounter(txn, c, t, period);
db.commitTransaction(txn);
return l;
} catch(DbException e) {
db.abortTransaction(txn);
throw e;
......
......@@ -1557,22 +1557,30 @@ abstract class JdbcDatabase implements Database<Connection> {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "SELECT contactId, transportId, period, secret,"
+ " outgoing, centre, bitmap"
+ " FROM secrets";
String sql = "SELECT ct.contactId, ct.transportId, epoch,"
+ " clockDiff, latency, alice, period, secret, outgoing,"
+ " centre, bitmap"
+ " FROM contactTransports AS ct"
+ " JOIN secrets AS s"
+ " ON ct.contactId = s.contactId"
+ " AND ct.transportId = s.transportId";
ps = txn.prepareStatement(sql);
rs = ps.executeQuery();
List<TemporarySecret> secrets = new ArrayList<TemporarySecret>();
while(rs.next()) {
ContactId c = new ContactId(rs.getInt(1));
TransportId t = new TransportId(rs.getBytes(2));
long period = rs.getLong(3);
byte[] secret = rs.getBytes(4);
long outgoing = rs.getLong(5);
long centre = rs.getLong(6);
byte[] bitmap = rs.getBytes(7);
secrets.add(new TemporarySecret(c, t, period, secret, outgoing,
centre, bitmap));
long epoch = rs.getLong(3);
long clockDiff = rs.getLong(4);
long latency = rs.getLong(5);
boolean alice = rs.getBoolean(6);
long period = rs.getLong(7);
byte[] secret = rs.getBytes(8);
long outgoing = rs.getLong(9);
long centre = rs.getLong(10);
byte[] bitmap = rs.getBytes(11);
secrets.add(new TemporarySecret(c, t, epoch, clockDiff, latency,
alice, period, secret, outgoing, centre, bitmap));
}
rs.close();
ps.close();
......@@ -2021,11 +2029,26 @@ abstract class JdbcDatabase implements Database<Connection> {
}
}
public void incrementConnectionCounter(Connection txn, ContactId c,
public long incrementConnectionCounter(Connection txn, ContactId c,
TransportId t, long period) throws DbException {
PreparedStatement ps = null;
ResultSet rs = null;
try {
String sql = "UPDATE secrets SET outgoing = outgoing + 1"
// Get the current connection counter
String sql = "SELECT outgoing FROM secrets"
+ " WHERE contactId = ? AND transportId = ? AND period + ?";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
ps.setBytes(2, t.getBytes());
ps.setLong(3, period);
rs = ps.executeQuery();
if(!rs.next()) throw new DbStateException();
long connection = rs.getLong(1);
if(rs.next()) throw new DbStateException();
rs.close();
ps.close();
// Increment the connection counter
sql = "UPDATE secrets SET outgoing = outgoing + 1"
+ " WHERE contactId = ? AND transportId = ? AND period = ?";
ps = txn.prepareStatement(sql);
ps.setInt(1, c.getInt());
......@@ -2034,8 +2057,10 @@ abstract class JdbcDatabase implements Database<Connection> {
int affected = ps.executeUpdate();
if(affected > 1) throw new DbStateException();
ps.close();
return connection;
} catch(SQLException e) {
tryToClose(ps);
tryToClose(rs);
throw new DbException(e);
}
}
......
......@@ -7,6 +7,7 @@ import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.TemporarySecret;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser;
......@@ -37,9 +38,8 @@ class ConnectionRecogniserImpl implements ConnectionRecogniser {
return r.acceptConnection(tag);
}
public void addWindow(ContactId c, TransportId t, long period,
boolean alice, byte[] secret, long centre, byte[] bitmap)
throws DbException {
public void addSecret(TemporarySecret s) throws DbException {
TransportId t = s.getTransportId();
TransportConnectionRecogniser r;
synchronized(this) {
r = recognisers.get(t);
......@@ -48,20 +48,24 @@ class ConnectionRecogniserImpl implements ConnectionRecogniser {
recognisers.put(t, r);
}
}
r.addWindow(c, period, alice, secret, centre, bitmap);
r.addSecret(s);
}
public void removeWindow(ContactId c, TransportId t, long period) {
public void removeSecret(ContactId c, TransportId t, long period) {
TransportConnectionRecogniser r;
synchronized(this) {
r = recognisers.get(t);
}
if(r != null) r.removeWindow(c, period);
if(r != null) r.removeSecret(c, period);
}
public synchronized void removeWindows(ContactId c) {
for(TransportConnectionRecogniser r : recognisers.values()) {
r.removeWindows(c);
}
public synchronized void removeSecrets(ContactId c) {
for(TransportConnectionRecogniser r : recognisers.values())
r.removeSecrets(c);
}
public synchronized void removeSecrets() {
for(TransportConnectionRecogniser r : recognisers.values())
r.removeSecrets();
}
}
......@@ -13,6 +13,7 @@ import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionWriter;
import net.sf.briar.api.transport.ConnectionWriterFactory;
import com.google.inject.Inject;
class ConnectionWriterFactoryImpl implements ConnectionWriterFactory {
......
......@@ -6,6 +6,7 @@ import static net.sf.briar.api.transport.TransportConstants.MAC_LENGTH;
import java.io.IOException;
import java.io.OutputStream;
import net.sf.briar.api.transport.ConnectionWriter;
/**
......
This diff is collapsed.
......@@ -11,6 +11,7 @@ import java.io.IOException;
import java.io.OutputStream;
import java.security.GeneralSecurityException;
import net.sf.briar.api.crypto.AuthenticatedCipher;
import net.sf.briar.api.crypto.ErasableKey;
......
......@@ -15,6 +15,7 @@ import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.db.TemporarySecret;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.util.ByteUtils;
......@@ -74,8 +75,13 @@ class TransportConnectionRecogniser {
return ctx;
}
synchronized void addWindow(ContactId contactId, long period, boolean alice,
byte[] secret, long centre, byte[] bitmap) throws DbException {
synchronized void addSecret(TemporarySecret s) throws DbException {
ContactId contactId = s.getContactId();
long period = s.getPeriod();
byte[] secret = s.getSecret();
boolean alice = s.getAlice();
long centre = s.getWindowCentre();
byte[] bitmap = s.getWindowBitmap();
// Create the connection window and the expected tags
Cipher cipher = crypto.getTagCipher();
ErasableKey key = crypto.deriveTagKey(secret, alice);
......@@ -96,10 +102,15 @@ class TransportConnectionRecogniser {
removalMap.put(new RemovalKey(contactId, period), rctx);
}
synchronized void removeWindow(ContactId contactId, long period) {
synchronized void removeSecret(ContactId contactId, long period) {
RemovalKey rk = new RemovalKey(contactId, period);
RemovalContext rctx = removalMap.remove(rk);
if(rctx == null) throw new IllegalArgumentException();
removeSecret(rctx);
}
// Locking: this
private void removeSecret(RemovalContext rctx) {
// Remove the expected tags
Cipher cipher = crypto.getTagCipher();
ErasableKey key = crypto.deriveTagKey(rctx.secret, rctx.alice);
......@@ -114,12 +125,18 @@ class TransportConnectionRecogniser {
ByteUtils.erase(rctx.secret);
}
synchronized void removeWindows(ContactId c) {
synchronized void removeSecrets(ContactId c) {
Collection<RemovalKey> keysToRemove = new ArrayList<RemovalKey>();
for(RemovalKey k : removalMap.keySet()) {
if(k.contactId.equals(c)) keysToRemove.add(k);
}
for(RemovalKey k : keysToRemove) removeWindow(k.contactId, k.period);
for(RemovalKey k : keysToRemove) removeSecret(k.contactId, k.period);
}
synchronized void removeSecrets() {
for(RemovalContext rctx : removalMap.values()) removeSecret(rctx);
assert tagMap.isEmpty();
removalMap.clear();
}
private static class WindowContext {
......@@ -148,7 +165,7 @@ class TransportConnectionRecogniser {
@Override
public int hashCode() {
return contactId.hashCode()+ (int) period;
return contactId.hashCode() + (int) period;
}
@Override
......
......@@ -3,13 +3,16 @@ package net.sf.briar.transport;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import net.sf.briar.api.crypto.KeyManager;
import net.sf.briar.api.transport.ConnectionDispatcher;
import net.sf.briar.api.transport.ConnectionReaderFactory;
import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionRegistry;
import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.api.transport.IncomingConnectionExecutor;
import com.google.inject.AbstractModule;
import com.google.inject.Singleton;
public class TransportModule extends AbstractModule {
......@@ -18,6 +21,8 @@ public class TransportModule extends AbstractModule {
bind(ConnectionDispatcher.class).to(ConnectionDispatcherImpl.class);
bind(ConnectionReaderFactory.class).to(
ConnectionReaderFactoryImpl.class);
bind(ConnectionRecogniser.class).to(ConnectionRecogniserImpl.class).in(
Singleton.class);
bind(ConnectionRegistry.class).toInstance(new ConnectionRegistryImpl());
bind(ConnectionWriterFactory.class).to(
ConnectionWriterFactoryImpl.class);
......@@ -25,5 +30,6 @@ public class TransportModule extends AbstractModule {
bind(Executor.class).annotatedWith(
IncomingConnectionExecutor.class).toInstance(
Executors.newCachedThreadPool());
bind(KeyManager.class).to(KeyManagerImpl.class).in(Singleton.class);
}
}
......@@ -20,7 +20,6 @@
<test name='net.sf.briar.crypto.ErasableKeyTest'/>
<test name='net.sf.briar.crypto.KeyAgreementTest'/>
<test name='net.sf.briar.crypto.KeyDerivationTest'/>
<test name='net.sf.briar.crypto.KeyRotatorImplTest'/>
<test name='net.sf.briar.db.BasicH2Test'/>
<test name='net.sf.briar.db.DatabaseCleanerImplTest'/>
<test name='net.sf.briar.db.DatabaseComponentImplTest'/>
......
......@@ -189,7 +189,7 @@ public class ProtocolIntegrationTest extends BriarTestCase {
InputStream in = new ByteArrayInputStream(connectionData);
byte[] tag = new byte[TAG_LENGTH];
assertEquals(TAG_LENGTH, in.read(tag, 0, TAG_LENGTH));
assertArrayEquals(new byte[TAG_LENGTH], tag);
// FIXME: Check that the expected tag was received
ConnectionContext ctx = new ConnectionContext(contactId, transportId,
secret.clone(), 0L, true);
ConnectionReader conn = connectionReaderFactory.createConnectionReader(
......
......@@ -63,7 +63,7 @@ public class KeyDerivationTest extends BriarTestCase {
public void testConnectionNumberAffectsDerivation() {
List<byte[]> secrets = new ArrayList<byte[]>();
for(int i = 0; i < 20; i++) {
secrets.add(crypto.deriveNextSecret(secret, i));
secrets.add(crypto.deriveNextSecret(secret.clone(), i));
}
for(int i = 0; i < 20; i++) {
byte[] secretI = secrets.get(i);
......
package net.sf.briar.crypto;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import net.sf.briar.BriarTestCase;
import net.sf.briar.api.db.DbException;
import net.sf.briar.crypto.KeyRotatorImpl;
import net.sf.briar.crypto.KeyRotator.Callback;
import org.junit.Test;
public class KeyRotatorImplTest extends BriarTestCase {
@Test
public void testCleanerRunsPeriodically() throws Exception {
final CountDownLatch latch = new CountDownLatch(5);
Callback callback = new Callback() {
public void rotateKeys() throws DbException {
latch.countDown();
}
};
KeyRotatorImpl cleaner = new KeyRotatorImpl();
// Start the rotator
cleaner.startRotating(callback, 10L);
// The keys should be rotated five times (allow 5 secs for system load)
assertTrue(latch.await(5, TimeUnit.SECONDS));
// Stop the rotator
cleaner.stopRotating();
}
@Test
public void testStoppingCleanerWakesItUp() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
Callback callback = new Callback() {
public void rotateKeys() throws DbException {
latch.countDown();
}
};
KeyRotatorImpl cleaner = new KeyRotatorImpl();
long start = System.currentTimeMillis();
// Start the rotator
cleaner.startRotating(callback, 10L * 1000L);
// The keys should be rotated once at startup
assertTrue(latch.await(5, TimeUnit.SECONDS));
// Stop the rotator (it should be waiting between rotations)
cleaner.stopRotating();
long end = System.currentTimeMillis();
// Check that much less than 10 seconds expired
assertTrue(end - start < 10L * 1000L);
}
}
......@@ -88,8 +88,8 @@ public abstract class DatabaseComponentTest extends BriarTestCase {
transports = Collections.singletonList(transport);
contactTransport = new ContactTransport(contactId, transportId, 123L,
234L, 345L, true);
temporarySecret = new TemporarySecret(contactId, transportId, 0L,
new byte[32], 0L, 0L, new byte[4]);
temporarySecret = new TemporarySecret(contactId, transportId, 1L, 2L,
3L, false, 4L, new byte[32], 5L, 6L, new byte[4]);
}
protected abstract <T> DatabaseComponent createDatabaseComponent(
......
This diff is collapsed.
......@@ -14,7 +14,6 @@ import java.util.concurrent.Executors;
import net.sf.briar.BriarTestCase;
import net.sf.briar.TestUtils;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.KeyManager;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DatabaseExecutor;
import net.sf.briar.api.protocol.Ack;
......@@ -24,7 +23,6 @@ import net.sf.briar.api.protocol.RawBatch;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.protocol.UniqueId;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser;
import net.sf.briar.api.transport.ConnectionRegistry;
import net.sf.briar.api.transport.ConnectionWriterFactory;
import net.sf.briar.crypto.CryptoModule;
......@@ -48,8 +46,6 @@ public class OutgoingSimplexConnectionTest extends BriarTestCase {
private final Mockery context;
private final DatabaseComponent db;
private final KeyManager keyManager;
private final ConnectionRecogniser connRecogniser;
private final ConnectionRegistry connRegistry;
private final ConnectionWriterFactory connFactory;
private final ProtocolWriterFactory protoFactory;
......@@ -61,14 +57,10 @@ public class OutgoingSimplexConnectionTest extends BriarTestCase {
super();
context = new Mockery();
db = context.mock(DatabaseComponent.class);
keyManager = context.mock(KeyManager.class);
connRecogniser = context.mock(ConnectionRecogniser.class);
Module testModule = new AbstractModule() {
@Override
public void configure() {
bind(DatabaseComponent.class).toInstance(db);
bind(KeyManager.class).toInstance(keyManager);
bind(ConnectionRecogniser.class).toInstance(connRecogniser);
bind(Executor.class).annotatedWith(
DatabaseExecutor.class).toInstance(
Executors.newCachedThreadPool());
......
......@@ -8,6 +8,7 @@ import org.jmock.Expectations;
import org.jmock.Mockery;
import org.junit.Test;
public class ConnectionWriterImplTest extends BriarTestCase {