Commit 7739bcdd authored by akwizgran's avatar akwizgran

Second part of key rotation implementation. Work in progress.

parent 021b3c5a
package net.sf.briar.api.transport;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.TransportId;
......@@ -14,5 +15,12 @@ public interface ConnectionRecogniser {
* expected, or null if the connection was not expected.
*/
ConnectionContext acceptConnection(TransportId t, byte[] tag)
throws DbException;
throws DbException;
void addWindow(ContactId c, TransportId t, long period, boolean alice,
byte[] secret, long centre, byte[] bitmap) throws DbException;
void removeWindow(ContactId c, TransportId t, long period);
void removeWindows(ContactId c);
}
package net.sf.briar.api.transport;
import java.util.Set;
public interface ConnectionWindow {
boolean isSeen(long connection);
void setSeen(long connection);
Set<Long> getUnseen();
}
package net.sf.briar.transport;
import java.util.HashMap;
import java.util.Map;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.api.transport.ConnectionRecogniser;
import com.google.inject.Inject;
class ConnectionRecogniserImpl implements ConnectionRecogniser {
private final CryptoComponent crypto;
private final DatabaseComponent db;
// Locking: this
private final Map<TransportId, TransportConnectionRecogniser> recognisers;
@Inject
ConnectionRecogniserImpl(CryptoComponent crypto, DatabaseComponent db) {
this.crypto = crypto;
this.db = db;
recognisers = new HashMap<TransportId, TransportConnectionRecogniser>();
}
public ConnectionContext acceptConnection(TransportId t, byte[] tag)
throws DbException {
TransportConnectionRecogniser r;
synchronized(this) {
r = recognisers.get(t);
}
if(r == null) return null;
return r.acceptConnection(tag);
}
public void addWindow(ContactId c, TransportId t, long period,
boolean alice, byte[] secret, long centre, byte[] bitmap)
throws DbException {
TransportConnectionRecogniser r;
synchronized(this) {
r = recognisers.get(t);
if(r == null) {
r = new TransportConnectionRecogniser(crypto, db, t);
recognisers.put(t, r);
}
}
r.addWindow(c, period, alice, secret, centre, bitmap);
}
public void removeWindow(ContactId c, TransportId t, long period) {
TransportConnectionRecogniser r;
synchronized(this) {
r = recognisers.get(t);
}
if(r != null) r.removeWindow(c, period);
}
public synchronized void removeWindows(ContactId c) {
for(TransportConnectionRecogniser r : recognisers.values()) {
r.removeWindows(c);
}
}
}
......@@ -3,59 +3,75 @@ package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.CONNECTION_WINDOW_SIZE;
import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
import net.sf.briar.api.transport.ConnectionWindow;
// This class is not thread-safe
class ConnectionWindowImpl implements ConnectionWindow {
class ConnectionWindow {
private final Set<Long> unseen;
private long centre;
ConnectionWindowImpl() {
ConnectionWindow() {
unseen = new HashSet<Long>();
for(long l = 0; l < CONNECTION_WINDOW_SIZE / 2; l++) unseen.add(l);
centre = 0;
}
ConnectionWindowImpl(Set<Long> unseen) {
long min = Long.MAX_VALUE, max = Long.MIN_VALUE;
for(long l : unseen) {
if(l < 0 || l > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
if(l < min) min = l;
if(l > max) max = l;
}
if(max - min > CONNECTION_WINDOW_SIZE)
ConnectionWindow(long centre, byte[] bitmap) {
if(centre < 0 || centre > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
if(bitmap.length != CONNECTION_WINDOW_SIZE / 8)
throw new IllegalArgumentException();
centre = max - CONNECTION_WINDOW_SIZE / 2 + 1;
for(long l = centre; l <= max; l++) {
if(!unseen.contains(l)) throw new IllegalArgumentException();
unseen = new HashSet<Long>();
long bottom = getBottom(centre);
long top = getTop(centre);
for(long l = bottom; l < top; l++) {
int offset = (int) (l - bottom);
int bytes = offset / 8;
int bits = offset % 8;
if((bitmap[bytes] & (128 >> bits)) == 0) unseen.add(l);
}
this.unseen = unseen;
this.centre = centre;
}
public boolean isSeen(long connection) {
boolean isSeen(long connection) {
return !unseen.contains(connection);
}
public void setSeen(long connection) {
Collection<Long> setSeen(long connection) {
long bottom = getBottom(centre);
long top = getTop(centre);
if(connection < bottom || connection > top)
throw new IllegalArgumentException();
if(!unseen.remove(connection))
throw new IllegalArgumentException();
Collection<Long> changed = new ArrayList<Long>();
if(connection >= centre) {
centre = connection + 1;
long newBottom = getBottom(centre);
long newTop = getTop(centre);
for(long l = bottom; l < newBottom; l++) unseen.remove(l);
for(long l = top + 1; l <= newTop; l++) unseen.add(l);
for(long l = bottom; l < newBottom; l++) {
if(unseen.remove(l)) changed.add(l);
}
for(long l = top + 1; l <= newTop; l++) {
if(unseen.add(l)) changed.add(l);
}
}
return changed;
}
long getCentre() {
return centre;
}
byte[] getBitmap() {
byte[] bitmap = new byte[CONNECTION_WINDOW_SIZE / 8];
// FIXME
return bitmap;
}
// Returns the lowest value contained in a window with the given centre
......@@ -69,7 +85,7 @@ class ConnectionWindowImpl implements ConnectionWindow {
centre + CONNECTION_WINDOW_SIZE / 2 - 1);
}
public Set<Long> getUnseen() {
public Collection<Long> getUnseen() {
return unseen;
}
}
package net.sf.briar.transport;
import static javax.crypto.Cipher.DECRYPT_MODE;
import static javax.crypto.Cipher.ENCRYPT_MODE;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.security.GeneralSecurityException;
import javax.crypto.Cipher;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.util.ByteUtils;
class TagEncoder {
static void encodeTag(byte[] tag, Cipher tagCipher, ErasableKey tagKey) {
static void encodeTag(byte[] tag, Cipher tagCipher, ErasableKey tagKey,
long connection) {
if(tag.length < TAG_LENGTH) throw new IllegalArgumentException();
// Blank plaintext
if(connection < 0 || connection > MAX_32_BIT_UNSIGNED)
throw new IllegalArgumentException();
for(int i = 0; i < TAG_LENGTH; i++) tag[i] = 0;
ByteUtils.writeUint32(connection, tag, 0);
try {
tagCipher.init(ENCRYPT_MODE, tagKey);
int encrypted = tagCipher.doFinal(tag, 0, TAG_LENGTH, tag);
......@@ -25,19 +29,4 @@ class TagEncoder {
throw new IllegalArgumentException(e);
}
}
static boolean decodeTag(byte[] tag, Cipher tagCipher, ErasableKey tagKey) {
if(tag.length < TAG_LENGTH) throw new IllegalArgumentException();
try {
tagCipher.init(DECRYPT_MODE, tagKey);
int decrypted = tagCipher.doFinal(tag, 0, TAG_LENGTH, tag);
if(decrypted != TAG_LENGTH) throw new IllegalArgumentException();
//The plaintext should be blank
for(int i = 0; i < TAG_LENGTH; i++) if(tag[i] != 0) return false;
return true;
} catch(GeneralSecurityException e) {
// Unsuitable cipher or key
throw new IllegalArgumentException(e);
}
}
}
package net.sf.briar.transport;
import static net.sf.briar.api.transport.TransportConstants.TAG_LENGTH;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import javax.crypto.Cipher;
import net.sf.briar.api.Bytes;
import net.sf.briar.api.ContactId;
import net.sf.briar.api.crypto.CryptoComponent;
import net.sf.briar.api.crypto.ErasableKey;
import net.sf.briar.api.db.DatabaseComponent;
import net.sf.briar.api.db.DbException;
import net.sf.briar.api.protocol.TransportId;
import net.sf.briar.api.transport.ConnectionContext;
import net.sf.briar.util.ByteUtils;
/** A connection recogniser for a specific transport. */
class TransportConnectionRecogniser {
private final CryptoComponent crypto;
private final DatabaseComponent db;
private final TransportId transportId;
private final Map<Bytes, WindowContext> tagMap; // Locking: this
private final Map<RemovalKey, RemovalContext> windowMap; // Locking: this
TransportConnectionRecogniser(CryptoComponent crypto, DatabaseComponent db,
TransportId transportId) {
this.crypto = crypto;
this.db = db;
this.transportId = transportId;
tagMap = new HashMap<Bytes, WindowContext>();
windowMap = new HashMap<RemovalKey, RemovalContext>();
}
synchronized ConnectionContext acceptConnection(byte[] tag)
throws DbException {
WindowContext wctx = tagMap.remove(new Bytes(tag));
if(wctx == null) return null;
ConnectionWindow w = wctx.window;
ConnectionContext ctx = wctx.context;
long connection = ctx.getConnectionNumber();
Cipher cipher = crypto.getTagCipher();
ErasableKey key = crypto.deriveTagKey(ctx.getSecret(), ctx.getAlice());
byte[] changedTag = new byte[TAG_LENGTH];
Bytes changedTagWrapper = new Bytes(changedTag);
for(long conn : w.setSeen(connection)) {
TagEncoder.encodeTag(changedTag, cipher, key, conn);
WindowContext old;
if(conn <= connection) old = tagMap.remove(changedTagWrapper);
else old = tagMap.put(changedTagWrapper, wctx);
assert old == null;
}
key.erase();
ContactId c = ctx.getContactId();
long centre = w.getCentre();
byte[] bitmap = w.getBitmap();
db.setConnectionWindow(c, transportId, wctx.period, centre, bitmap);
return wctx.context;
}
synchronized void addWindow(ContactId c, long period, boolean alice,
byte[] secret, long centre, byte[] bitmap) throws DbException {
Cipher cipher = crypto.getTagCipher();
ErasableKey key = crypto.deriveTagKey(secret, alice);
ConnectionWindow w = new ConnectionWindow(centre, bitmap);
for(long conn : w.getUnseen()) {
byte[] tag = new byte[TAG_LENGTH];
TagEncoder.encodeTag(tag, cipher, key, conn);
ConnectionContext ctx = new ConnectionContext(c, transportId, tag,
secret, conn, alice);
WindowContext wctx = new WindowContext(w, ctx, period);
tagMap.put(new Bytes(tag), wctx);
}
db.setConnectionWindow(c, transportId, period, centre, bitmap);
RemovalContext ctx = new RemovalContext(w, secret, alice);
windowMap.put(new RemovalKey(c, period), ctx);
}
synchronized void removeWindow(ContactId c, long period) {
RemovalContext ctx = windowMap.remove(new RemovalKey(c, period));
if(ctx == null) throw new IllegalArgumentException();
Cipher cipher = crypto.getTagCipher();
ErasableKey key = crypto.deriveTagKey(ctx.secret, ctx.alice);
byte[] removedTag = new byte[TAG_LENGTH];
Bytes removedTagWrapper = new Bytes(removedTag);
for(long conn : ctx.window.getUnseen()) {
TagEncoder.encodeTag(removedTag, cipher, key, conn);
WindowContext old = tagMap.remove(removedTagWrapper);
assert old != null;
}
key.erase();
ByteUtils.erase(ctx.secret);
}
synchronized void removeWindows(ContactId c) {
Collection<RemovalKey> keysToRemove = new ArrayList<RemovalKey>();
for(RemovalKey k : windowMap.keySet()) {
if(k.contactId.equals(c)) keysToRemove.add(k);
}
for(RemovalKey k : keysToRemove) removeWindow(k.contactId, k.period);
}
private static class WindowContext {
private final ConnectionWindow window;
private final ConnectionContext context;
private final long period;
private WindowContext(ConnectionWindow window,
ConnectionContext context, long period) {
this.window = window;
this.context = context;
this.period = period;
}
}
private static class RemovalKey {
private final ContactId contactId;
private final long period;
private RemovalKey(ContactId contactId, long period) {
this.contactId = contactId;
this.period = period;
}
@Override
public int hashCode() {
return contactId.hashCode()+ (int) period;
}
@Override
public boolean equals(Object o) {
if(o instanceof RemovalKey) {
RemovalKey w = (RemovalKey) o;
return contactId.equals(w.contactId) && period == w.period;
}
return false;
}
}
private static class RemovalContext {
private final ConnectionWindow window;
private final byte[] secret;
private final boolean alice;
private RemovalContext(ConnectionWindow window, byte[] secret,
boolean alice) {
this.window = window;
this.secret = secret;
this.alice = alice;
}
}
}
......@@ -49,7 +49,7 @@
<test name='net.sf.briar.serial.WriterImplTest'/>
<test name='net.sf.briar.transport.ConnectionReaderImplTest'/>
<test name='net.sf.briar.transport.ConnectionRegistryImplTest'/>
<test name='net.sf.briar.transport.ConnectionWindowImplTest'/>
<test name='net.sf.briar.transport.ConnectionWindowTest'/>
<test name='net.sf.briar.transport.ConnectionWriterImplTest'/>
<test name='net.sf.briar.transport.IncomingEncryptionLayerTest'/>
<test name='net.sf.briar.transport.OutgoingEncryptionLayerTest'/>
......
package net.sf.briar.transport;
import static net.sf.briar.util.ByteUtils.MAX_32_BIT_UNSIGNED;
import java.util.HashSet;
import java.util.Set;
import java.util.Collection;
import net.sf.briar.BriarTestCase;
import net.sf.briar.api.transport.ConnectionWindow;
import org.junit.Test;
public class ConnectionWindowImplTest extends BriarTestCase {
public class ConnectionWindowTest extends BriarTestCase {
@Test
public void testWindowSliding() {
ConnectionWindow w = new ConnectionWindowImpl();
ConnectionWindow w = new ConnectionWindow();
for(int i = 0; i < 100; i++) {
assertFalse(w.isSeen(i));
w.setSeen(i);
......@@ -24,7 +20,7 @@ public class ConnectionWindowImplTest extends BriarTestCase {
@Test
public void testWindowJumping() {
ConnectionWindow w = new ConnectionWindowImpl();
ConnectionWindow w = new ConnectionWindow();
for(int i = 0; i < 100; i += 13) {
assertFalse(w.isSeen(i));
w.setSeen(i);
......@@ -34,7 +30,7 @@ public class ConnectionWindowImplTest extends BriarTestCase {
@Test
public void testWindowUpperLimit() {
ConnectionWindow w = new ConnectionWindowImpl();
ConnectionWindow w = new ConnectionWindow();
// Centre is 0, highest value in window is 15
w.setSeen(15);
// Centre is 16, highest value in window is 31
......@@ -44,20 +40,11 @@ public class ConnectionWindowImplTest extends BriarTestCase {
w.setSeen(48);
fail();
} catch(IllegalArgumentException expected) {}
// Values greater than 2^32 - 1 should never be allowed
Set<Long> unseen = new HashSet<Long>();
for(int i = 0; i < 32; i++) unseen.add(MAX_32_BIT_UNSIGNED - i);
w = new ConnectionWindowImpl(unseen);
w.setSeen(MAX_32_BIT_UNSIGNED);
try {
w.setSeen(MAX_32_BIT_UNSIGNED + 1);
fail();
} catch(IllegalArgumentException expected) {}
}
@Test
public void testWindowLowerLimit() {
ConnectionWindow w = new ConnectionWindowImpl();
ConnectionWindow w = new ConnectionWindow();
// Centre is 0, negative values should never be allowed
try {
w.setSeen(-1);
......@@ -87,7 +74,7 @@ public class ConnectionWindowImplTest extends BriarTestCase {
@Test
public void testCannotSetSeenTwice() {
ConnectionWindow w = new ConnectionWindowImpl();
ConnectionWindow w = new ConnectionWindow();
w.setSeen(15);
try {
w.setSeen(15);
......@@ -97,9 +84,9 @@ public class ConnectionWindowImplTest extends BriarTestCase {
@Test
public void testGetUnseenConnectionNumbers() {
ConnectionWindow w = new ConnectionWindowImpl();
ConnectionWindow w = new ConnectionWindow();
// Centre is 0; window should cover 0 to 15, inclusive, with none seen
Set<Long> unseen = w.getUnseen();
Collection<Long> unseen = w.getUnseen();
assertEquals(16, unseen.size());
for(int i = 0; i < 16; i++) {
assertTrue(unseen.contains(Long.valueOf(i)));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment